diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..f18ce75863557f81cafb3154a0b0bdccfab3be1b 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,38 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+github_page/control.pdf filter=lfs diff=lfs merge=lfs -text
+github_page/p10.png filter=lfs diff=lfs merge=lfs -text
+github_page/p11.png filter=lfs diff=lfs merge=lfs -text
+github_page/p12.png filter=lfs diff=lfs merge=lfs -text
+github_page/p13.png filter=lfs diff=lfs merge=lfs -text
+github_page/p14.png filter=lfs diff=lfs merge=lfs -text
+github_page/p15.png filter=lfs diff=lfs merge=lfs -text
+github_page/p16b.png filter=lfs diff=lfs merge=lfs -text
+github_page/p17.png filter=lfs diff=lfs merge=lfs -text
+github_page/p18.png filter=lfs diff=lfs merge=lfs -text
+github_page/p19.png filter=lfs diff=lfs merge=lfs -text
+github_page/p2.png filter=lfs diff=lfs merge=lfs -text
+github_page/p20.png filter=lfs diff=lfs merge=lfs -text
+github_page/p21.png filter=lfs diff=lfs merge=lfs -text
+github_page/p3.png filter=lfs diff=lfs merge=lfs -text
+github_page/p4.png filter=lfs diff=lfs merge=lfs -text
+github_page/p5.png filter=lfs diff=lfs merge=lfs -text
+github_page/p6.png filter=lfs diff=lfs merge=lfs -text
+github_page/p7.png filter=lfs diff=lfs merge=lfs -text
+github_page/p8.png filter=lfs diff=lfs merge=lfs -text
+github_page/p9.png filter=lfs diff=lfs merge=lfs -text
+github_page/t/op.png filter=lfs diff=lfs merge=lfs -text
+github_page/uc2a.png filter=lfs diff=lfs merge=lfs -text
+github_page/uc2b.png filter=lfs diff=lfs merge=lfs -text
+github_page/uc3.png filter=lfs diff=lfs merge=lfs -text
+github_page/uc4.png filter=lfs diff=lfs merge=lfs -text
+github_page/uc6.png filter=lfs diff=lfs merge=lfs -text
+github_page/uci1.png filter=lfs diff=lfs merge=lfs -text
+github_page/uci2.png filter=lfs diff=lfs merge=lfs -text
+github_page/uci3.png filter=lfs diff=lfs merge=lfs -text
+github_page/uci4.png filter=lfs diff=lfs merge=lfs -text
+output-images/rgb2.png filter=lfs diff=lfs merge=lfs -text
+test_imgs/bird.png filter=lfs diff=lfs merge=lfs -text
+test_imgs/building.png filter=lfs diff=lfs merge=lfs -text
+test_imgs/building2.png filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..325e5fd26480f88769892b2c263e91daca4ba719
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,143 @@
+.idea/
+
+training/
+lightning_logs/
+image_log/
+
+*.pth
+*.pt
+*.ckpt
+*.safetensors
+
+gradio_pose2image_private.py
+gradio_canny2image_private.py
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/README.md b/README.md
index 4a970c5f4a764a53c10ea9c06adf1604108a7df0..2bab0977eb779ba65c1ab85cb52a8d190e831f20 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,354 @@
---
title: HumanSD
-emoji: 👀
-colorFrom: red
-colorTo: yellow
+app_file: gradio_humanpose2image.py
sdk: gradio
-sdk_version: 3.45.1
-app_file: app.py
-pinned: false
+sdk_version: 3.44.3
---
+# News: A nightly version of ControlNet 1.1 is released!
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+[ControlNet 1.1](https://github.com/lllyasviel/ControlNet-v1-1-nightly) is released. Those new models will be merged into this repo after we make sure that everything is good.
+
+# Below is ControlNet 1.0
+
+Official implementation of [Adding Conditional Control to Text-to-Image Diffusion Models](https://arxiv.org/abs/2302.05543).
+
+ControlNet is a neural network structure to control diffusion models by adding extra conditions.
+
+
+
+It copies the weights of neural network blocks into a "locked" copy and a "trainable" copy.
+
+The "trainable" one learns your condition. The "locked" one preserves your model.
+
+Thanks to this, training with a small dataset of image pairs will not destroy the production-ready diffusion models.
+
+The "zero convolution" is 1×1 convolution with both weight and bias initialized as zeros.
+
+Before training, all zero convolutions output zeros, and ControlNet will not cause any distortion.
+
+No layer is trained from scratch. You are still fine-tuning. Your original model is safe.
+
+This allows training on small-scale or even personal devices.
+
+This is also friendly to merge/replacement/offsetting of models/weights/blocks/layers.
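+
+To make the structure concrete, here is a minimal PyTorch sketch of the idea (an illustrative toy block, not the repository's actual implementation; the class and argument names are assumptions):
+
+    import copy
+    import torch.nn as nn
+
+    def zero_conv(channels):
+        # 1x1 convolution with weight and bias initialized to zero
+        conv = nn.Conv2d(channels, channels, kernel_size=1)
+        nn.init.zeros_(conv.weight)
+        nn.init.zeros_(conv.bias)
+        return conv
+
+    class ControlledBlock(nn.Module):
+        def __init__(self, block, channels):
+            super().__init__()
+            self.locked = block                    # frozen original block
+            self.trainable = copy.deepcopy(block)  # trainable copy that learns the condition
+            self.zero_in = zero_conv(channels)
+            self.zero_out = zero_conv(channels)
+            for p in self.locked.parameters():
+                p.requires_grad_(False)
+
+        def forward(self, x, condition):
+            # Before training, both zero convolutions output zeros, so this block
+            # behaves exactly like the locked original and adds no distortion.
+            return self.locked(x) + self.zero_out(self.trainable(x + self.zero_in(condition)))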
+
+### FAQ
+
+**Q:** But wait, if the weight of a conv layer is zero, the gradient will also be zero, and the network will not learn anything. Why does "zero convolution" work?
+
+**A:** This is not true. [See an explanation here](docs/faq.md).
+
+# Stable Diffusion + ControlNet
+
+By repeating the above simple structure 14 times, we can control stable diffusion in this way:
+
+
+
+In this way, the ControlNet can **reuse** the SD encoder as a **deep, strong, robust, and powerful backbone** to learn diverse controls. Much evidence (like [this](https://jerryxu.net/ODISE/) and [this](https://vpd.ivg-research.xyz/)) validates that the SD encoder is an excellent backbone.
+
+Note that the way we connect layers is computationally efficient. The original SD encoder does not need to store gradients (the locked original SD Encoder Blocks 1-4 and Middle). The required GPU memory is not much larger than that of the original SD, although many layers are added. Great!
+
+# Features & News
+
+2023/04/14 - We released [ControlNet 1.1](https://github.com/lllyasviel/ControlNet-v1-1-nightly). Those new models will be merged into this repo after we make sure that everything is good.
+
+2023/03/03 - We released a discussion - [Precomputed ControlNet: Speed up ControlNet by 45%, but is it necessary?](https://github.com/lllyasviel/ControlNet/discussions/216)
+
+2023/02/26 - We released a blog - [Ablation Study: Why ControlNets use deep encoder? What if it was lighter? Or even an MLP?](https://github.com/lllyasviel/ControlNet/discussions/188)
+
+2023/02/20 - Implementation for non-prompt mode released. See also [Guess Mode / Non-Prompt Mode](#guess-anchor).
+
+2023/02/12 - Now you can play with any community model by [Transferring the ControlNet](https://github.com/lllyasviel/ControlNet/discussions/12).
+
+2023/02/11 - [Low VRAM mode](docs/low_vram.md) is added. Please use this mode if you are using an 8GB GPU or if you want a larger batch size.
+
+# Production-Ready Pretrained Models
+
+First, create a new conda environment:
+
+ conda env create -f environment.yaml
+ conda activate control
+
+All models and detectors can be downloaded from [our Hugging Face page](https://huggingface.co/lllyasviel/ControlNet). Put the SD models in "ControlNet/models" and the detectors in "ControlNet/annotator/ckpts". Make sure that you download all necessary pretrained weights and detector models from that Hugging Face page, including the HED edge detection model, the Midas depth estimation model, Openpose, and so on.
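+
+For orientation, a typical layout after downloading looks roughly like this (the file names shown are illustrative; check the Hugging Face page for the exact list):
+
+    ControlNet/
+    ├── models/
+    │   ├── control_sd15_canny.pth
+    │   ├── control_sd15_depth.pth
+    │   └── ...                       # one .pth file per control type
+    └── annotator/
+        └── ckpts/
+            ├── body_pose_model.pth   # Openpose
+            ├── network-bsds500.pth   # HED
+            └── ...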
+
+We provide 9 Gradio apps with these models.
+
+All test images can be found at the folder "test_imgs".
+
+## ControlNet with Canny Edge
+
+Stable Diffusion 1.5 + ControlNet (using simple Canny edge detection)
+
+ python gradio_canny2image.py
+
+The Gradio app also allows you to change the Canny edge thresholds. Just try it for more details.
+
+Prompt: "bird"
+
+
+Prompt: "cute dog"
+
+
+## ControlNet with M-LSD Lines
+
+Stable Diffusion 1.5 + ControlNet (using simple M-LSD straight line detection)
+
+ python gradio_hough2image.py
+
+The Gradio app also allows you to change the M-LSD thresholds. Just try it for more details.
+
+Prompt: "room"
+
+
+Prompt: "building"
+
+
+## ControlNet with HED Boundary
+
+Stable Diffusion 1.5 + ControlNet (using soft HED Boundary)
+
+ python gradio_hed2image.py
+
+The soft HED Boundary will preserve many details in input images, making this app suitable for recoloring and stylizing. Just try it for more details.
+
+Prompt: "oil painting of handsome old man, masterpiece"
+
+
+Prompt: "Cyberpunk robot"
+
+
+## ControlNet with User Scribbles
+
+Stable Diffusion 1.5 + ControlNet (using Scribbles)
+
+ python gradio_scribble2image.py
+
+Note that the UI is based on Gradio, and Gradio is somewhat difficult to customize. Right now you need to draw scribbles outside the UI (using your favorite drawing software, for example, MS Paint) and then import the scribble image to Gradio.
+
+Prompt: "turtle"
+
+
+Prompt: "hot air balloon"
+
+
+### Interactive Interface
+
+We actually provide an interactive interface
+
+ python gradio_scribble2image_interactive.py
+
+~~However, because gradio is very [buggy](https://github.com/gradio-app/gradio/issues/3166) and difficult to customize, right now users need to first set the canvas width and height and then click "Open drawing canvas" to get a drawing area. Please do not upload an image to that drawing canvas. Also, the drawing area is very small; it should be bigger. But I failed to find out how to make it larger. Again, gradio is really buggy.~~ (Now fixed, will update asap)
+
+The dog sketch below was drawn by me. Perhaps we should draw a better dog for the showcase.
+
+Prompt: "dog in a room"
+
+
+## ControlNet with Fake Scribbles
+
+Stable Diffusion 1.5 + ControlNet (using fake scribbles)
+
+ python gradio_fake_scribble2image.py
+
+Sometimes we are lazy, and we do not want to draw scribbles. This script uses exactly the same scribble-based model but applies a simple algorithm to synthesize scribbles from input images.
+
+Prompt: "bag"
+
+
+Prompt: "shose" (Note that "shose" is a typo; it should be "shoes". But it still seems to work.)
+
+
+## ControlNet with Human Pose
+
+Stable Diffusion 1.5 + ControlNet (using human pose)
+
+ python gradio_pose2image.py
+
+Apparently, this model deserves a better UI to directly manipulate the pose skeleton. However, again, Gradio is somewhat difficult to customize. Right now you need to input an image, and then Openpose will detect the pose for you.
+
+Prompt: "Chief in the kitchen"
+
+
+Prompt: "An astronaut on the moon"
+
+
+## ControlNet with Semantic Segmentation
+
+Stable Diffusion 1.5 + ControlNet (using semantic segmentation)
+
+ python gradio_seg2image.py
+
+This model uses ADE20K's segmentation protocol. Again, this model deserves a better UI to directly draw the segmentations. However, again, Gradio is somewhat difficult to customize. Right now you need to input an image, and then a model called Uniformer will detect the segmentations for you. Just try it for more details.
+
+Prompt: "House"
+
+
+Prompt: "River"
+
+
+## ControlNet with Depth
+
+Stable Diffusion 1.5 + ControlNet (using depth map)
+
+ python gradio_depth2image.py
+
+Great! Now SD 1.5 also has depth control. FINALLY. So many possibilities (considering SD1.5 has many more community models than SD2).
+
+Note that, unlike Stability's model, this ControlNet receives the full 512×512 depth map rather than a 64×64 depth map (Stability's SD2 depth model uses 64×64 depth maps). This means that the ControlNet preserves more details from the depth map.
+
+This is always a strength, because if users do not want to preserve more details they can simply use another SD to post-process with image-to-image, but if they do want to preserve more details, ControlNet becomes their only choice. Again, SD2 uses 64×64 depth; we use 512×512.
+
+Prompt: "Stormtrooper's lecture"
+
+
+## ControlNet with Normal Map
+
+Stable Diffusion 1.5 + ControlNet (using normal map)
+
+ python gradio_normal2image.py
+
+This model uses normal maps. Right now, in the app, the normal map is computed from the Midas depth map and a user threshold (which determines how much of the area is treated as background with an identity normal facing the viewer; tune the "Normal background threshold" in the Gradio app to get a feeling for it).
+
+Prompt: "Cute toy"
+
+
+Prompt: "Plaster statue of Abraham Lincoln"
+
+
+Compared to the depth model, this model seems to be a bit better at preserving geometry. This is intuitive: minor details are not salient in depth maps but are salient in normal maps. Below is the depth result with the same inputs. You can see that the hairstyle of the man in the input image is modified by the depth model but preserved by the normal model.
+
+Prompt: "Plaster statue of Abraham Lincoln"
+
+
+## ControlNet with Anime Line Drawing
+
+We also trained a relatively simple ControlNet for anime line drawings. This tool may be useful for artistic creations. (Although the image details in the results are a bit modified, since it still diffuses latent images.)
+
+This model is not available right now. We need to evaluate the potential risks before releasing this model. Nevertheless, you may be interested in [transferring the ControlNet to any community model](https://github.com/lllyasviel/ControlNet/discussions/12).
+
+
+
+
+
+# Guess Mode / Non-Prompt Mode
+
+The "guess mode" (or called non-prompt mode) will completely unleash all the power of the very powerful ControlNet encoder.
+
+See also the blog - [Ablation Study: Why ControlNets use deep encoder? What if it was lighter? Or even an MLP?](https://github.com/lllyasviel/ControlNet/discussions/188)
+
+You need to manually check the "Guess Mode" toggle to enable this mode.
+
+In this mode, the ControlNet encoder will try its best to recognize the content of the input control map (depth map, edge map, scribbles, etc.), even if you remove all prompts.
+
+**Let's have fun with some very challenging experimental settings!**
+
+**No prompts. No "positive" prompts. No "negative" prompts. No extra caption detector. One single diffusion loop.**
+
+For this mode, we recommend using 50 steps and a guidance scale between 3 and 5.
+
+
+
+No prompts:
+
+
+
+Note that the example below is 768×768. No prompts. No "positive" prompts. No "negative" prompts.
+
+
+
+By tuning the parameters, you can get some very interesting results like the ones below:
+
+
+
+Because no prompt is available, the ControlNet encoder will "guess" what is in the control map. Sometimes the guess is really interesting. Because the diffusion algorithm can essentially give multiple results, the ControlNet seems able to give multiple guesses, like this:
+
+
+
+Without a prompt, HED seems good at generating images that look like paintings when the control strength is relatively low:
+
+
+
+The Guess Mode is also supported in [WebUI Plugin](https://github.com/Mikubill/sd-webui-controlnet):
+
+
+
+No prompts. Default WebUI parameters. Pure random results with the seed being 12345. Standard SD1.5. Input scribble is in "test_imgs" folder to reproduce.
+
+
+
+Below is another challenging example:
+
+
+
+No prompts. Default WebUI parameters. Pure random results with the seed being 12345. Standard SD1.5. Input scribble is in "test_imgs" folder to reproduce.
+
+
+
+Note that in guess mode, you will still be able to input prompts. The only difference is that the model will "try harder" to guess what is in the control map even if you do not provide a prompt. Just try it yourself!
+
+Besides, if you write a script that uses a captioning model (like BLIP) to generate image captions from the "guess mode" images, and then uses the generated captions as prompts to diffuse again, you will get a SOTA pipeline for fully automatic conditional image generation.
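+
+As a rough sketch of that caption-then-diffuse idea (the captioning model ID, file name, and second-stage step below are assumptions, not part of this repository):
+
+    from PIL import Image
+    from transformers import BlipProcessor, BlipForConditionalGeneration
+
+    # Caption an image that was produced in guess mode (path is illustrative).
+    processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
+    captioner = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
+
+    image = Image.open("guess_mode_result.png").convert("RGB")
+    inputs = processor(images=image, return_tensors="pt")
+    caption = processor.decode(captioner.generate(**inputs)[0], skip_special_tokens=True)
+
+    # Feed `caption` back as the prompt for a second, prompt-conditioned
+    # ControlNet pass (for example via the Gradio apps above).
+    print(caption)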
+
+# Combining Multiple ControlNets
+
+ControlNets are composable: more than one ControlNet can easily be combined for multi-condition control.
+
+Right now this feature is at an experimental stage in [Mikubill's A1111 WebUI Plugin](https://github.com/Mikubill/sd-webui-controlnet):
+
+
+
+
+
+As long as the models are controlling the same SD, the "boundary" between different research projects does not even exist. This plugin also allows different methods to work together!
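+
+Outside the WebUI plugin, the same composability can be sketched with the diffusers library (an illustration rather than this repository's code path; the model IDs, file names, and parameters below are assumptions):
+
+    import torch
+    from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
+    from diffusers.utils import load_image
+
+    # Two conditions (Canny edges + human pose) driving the same SD 1.5 base model.
+    controlnets = [
+        ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16),
+        ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-openpose", torch_dtype=torch.float16),
+    ]
+    pipe = StableDiffusionControlNetPipeline.from_pretrained(
+        "runwayml/stable-diffusion-v1-5", controlnet=controlnets, torch_dtype=torch.float16
+    ).to("cuda")
+
+    canny_map = load_image("canny_condition.png")  # pre-computed edge map
+    pose_map = load_image("pose_condition.png")    # pre-computed pose skeleton
+    result = pipe(
+        "a dancer in a large hall",
+        image=[canny_map, pose_map],
+        controlnet_conditioning_scale=[1.0, 0.8],  # per-condition control strength
+        num_inference_steps=30,
+    ).images[0]
+    result.save("multi_controlnet_result.png")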
+
+# Use ControlNet in Any Community Model (SD1.X)
+
+This is an experimental feature.
+
+[See the steps here](https://github.com/lllyasviel/ControlNet/discussions/12).
+
+Or you may want to use [Mikubill's A1111 WebUI Plugin](https://github.com/Mikubill/sd-webui-controlnet), which is plug-and-play and does not need manual merging.
+
+# Annotate Your Own Data
+
+We provide simple Python scripts to process images.
+
+[See a Gradio example here](docs/annotator.md).
+
+# Train with Your Own Data
+
+Training a ControlNet is as easy as (or even easier than) training a simple pix2pix.
+
+[See the steps here](docs/train.md).
+
+# Related Resources
+
+Special thanks to the great project [Mikubill's A1111 WebUI Plugin](https://github.com/Mikubill/sd-webui-controlnet)!
+
+We also thank Hysts for making the [Hugging Face Space](https://huggingface.co/spaces/hysts/ControlNet), as well as more than 65 models in that amazing [Colab list](https://github.com/camenduru/controlnet-colab)!
+
+Thanks to haofanwang for making [ControlNet-for-Diffusers](https://github.com/haofanwang/ControlNet-for-Diffusers)!
+
+We also thank all the authors who made ControlNet demos, including but not limited to [fffiloni](https://huggingface.co/spaces/fffiloni/ControlNet-Video), [other-model](https://huggingface.co/spaces/hysts/ControlNet-with-other-models), [ThereforeGames](https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions/7784), [RamAnanth1](https://huggingface.co/spaces/RamAnanth1/ControlNet), etc.!
+
+Besides, you may also want to read these amazing related works:
+
+[Composer: Creative and Controllable Image Synthesis with Composable Conditions](https://github.com/damo-vilab/composer): A much bigger model to control diffusion!
+
+[T2I-Adapter: Learning Adapters to Dig out More Controllable Ability for Text-to-Image Diffusion Models](https://github.com/TencentARC/T2I-Adapter): A much smaller model to control stable diffusion!
+
+[ControlLoRA: A Light Neural Network To Control Stable Diffusion Spatial Information](https://github.com/HighCWu/ControlLoRA): Implements ControlNet using LoRA!
+
+And these amazing recent projects: [InstructPix2Pix Learning to Follow Image Editing Instructions](https://www.timothybrooks.com/instruct-pix2pix), [Pix2pix-zero: Zero-shot Image-to-Image Translation](https://github.com/pix2pixzero/pix2pix-zero), [Plug-and-Play Diffusion Features for Text-Driven Image-to-Image Translation](https://github.com/MichalGeyer/plug-and-play), [MaskSketch: Unpaired Structure-guided Masked Image Generation](https://arxiv.org/abs/2302.05496), [SEGA: Instructing Diffusion using Semantic Dimensions](https://arxiv.org/abs/2301.12247), [Universal Guidance for Diffusion Models](https://github.com/arpitbansal297/Universal-Guided-Diffusion), [Region-Aware Diffusion for Zero-shot Text-driven Image Editing](https://github.com/haha-lisa/RDM-Region-Aware-Diffusion-Model), [Domain Expansion of Image Generators](https://arxiv.org/abs/2301.05225), [Image Mixer](https://twitter.com/LambdaAPI/status/1626327289288957956), [MultiDiffusion: Fusing Diffusion Paths for Controlled Image Generation](https://multidiffusion.github.io/)
+
+# Citation
+
+    @misc{zhang2023adding,
+      title={Adding Conditional Control to Text-to-Image Diffusion Models},
+      author={Lvmin Zhang and Anyi Rao and Maneesh Agrawala},
+      booktitle={IEEE International Conference on Computer Vision (ICCV)},
+      year={2023},
+    }
+
+[Arxiv Link](https://arxiv.org/abs/2302.05543)
+
+[Supplementary Materials](https://lllyasviel.github.io/misc/202309/cnet_supp.pdf)
diff --git a/__pycache__/aagenerator.cpython-38.pyc b/__pycache__/aagenerator.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9849d946ed6ffc8e681519d81f42162c2d1755b9
Binary files /dev/null and b/__pycache__/aagenerator.cpython-38.pyc differ
diff --git a/__pycache__/config.cpython-38.pyc b/__pycache__/config.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a4b3684bf5087f62eda612c93506e51ccbd85195
Binary files /dev/null and b/__pycache__/config.cpython-38.pyc differ
diff --git a/__pycache__/share.cpython-38.pyc b/__pycache__/share.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7a7eb816a3213506c903db1c7013205ab17a2724
Binary files /dev/null and b/__pycache__/share.cpython-38.pyc differ
diff --git a/aa-pose-inference.py b/aa-pose-inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc13efe2228a0bee05289ab004173ab6a4433f94
--- /dev/null
+++ b/aa-pose-inference.py
@@ -0,0 +1,993 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import math
+import os
+import random
+from pathlib import Path
+
+import accelerate
+import datasets
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.state import AcceleratorState
+from accelerate.utils import ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPTextModel, CLIPTokenizer
+from transformers import AutoTokenizer, PretrainedConfig
+from transformers.utils import ContextManagers
+from PIL import Image
+import PIL
+from PIL import ImageFile
+
+import diffusers
+from diffusers import AutoencoderKL, UNet2DConditionModel, DiffusionPipeline
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import EMAModel
+from diffusers.utils import check_min_version, deprecate, is_wandb_available
+from diffusers.utils.import_utils import is_xformers_available
+from openclip.training.data import get_wds_dataset, get_wds_dataset_cond
+from diffusers.schedulers import DDIMScheduler, DDPMScheduler, \
+ DEISMultistepScheduler, DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler, \
+ PNDMScheduler, LMSDiscreteScheduler, UniPCMultistepScheduler
+
+from models.embedder import Embedder
+from pipelines.pipeline_stable_diffusion_mb_downup import StableDiffusionPipeline
+from collections import OrderedDict
+import boto3
+from diffusers.models.controlnet_composer import ControlNetModel
+# from pipelines.pipeline_controlnet_composer import StableDiffusionControlNetPipeline
+from pipelines.pipeline_controlnet_composer_sdxl import StableDiffusionXLControlNetPipeline
+from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler, IterableDataset, get_worker_info
+import json
+import cv2
+import seaborn as sns
+
+if is_wandb_available():
+ import wandb
+
+logger = get_logger(__name__, log_level="INFO")
+
+# Will raise an error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.17.0.dev0")
+
+
+def draw_humansd_skeleton(image, pose, mmpose_detection_thresh=0.3, height=None, width=None, humansd_skeleton_width=10):
+ humansd_skeleton = [
+ [0, 0, 1],
+ [1, 0, 2],
+ [2, 1, 3],
+ [3, 2, 4],
+ [4, 3, 5],
+ [5, 4, 6],
+ [6, 5, 7],
+ [7, 6, 8],
+ [8, 7, 9],
+ [9, 8, 10],
+ [10, 5, 11],
+ [11, 6, 12],
+ [12, 11, 13],
+ [13, 12, 14],
+ [14, 13, 15],
+ [15, 14, 16],
+ ]
+ # humansd_skeleton_width=10
+ humansd_color = sns.color_palette("hls", len(humansd_skeleton))
+
+ def plot_kpts(img_draw, kpts, color, edgs, width):
+ for idx, kpta, kptb in edgs:
+ if kpts[kpta, 2] > mmpose_detection_thresh and \
+ kpts[kptb, 2] > mmpose_detection_thresh:
+ line_color = tuple([int(255 * color_i) for color_i in color[idx]])
+
+ cv2.line(img_draw, (int(kpts[kpta, 0]), int(kpts[kpta, 1])), (int(kpts[kptb, 0]), int(kpts[kptb, 1])),
+ line_color, width)
+ cv2.circle(img_draw, (int(kpts[kpta, 0]), int(kpts[kpta, 1])), width // 2, line_color, -1)
+ cv2.circle(img_draw, (int(kpts[kptb, 0]), int(kpts[kptb, 1])), width // 2, line_color, -1)
+
+ if image is None:
+ pose_image = np.zeros((height, width, 3), dtype=np.uint8)
+ else:
+ pose_image = np.array(image, dtype=np.uint8)
+ for person_i in range(len(pose)):
+ if np.sum(pose[person_i]) > 0:
+ plot_kpts(pose_image, pose[person_i], humansd_color, humansd_skeleton, humansd_skeleton_width)
+
+ return pose_image
+
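+
+# Illustrative usage of draw_humansd_skeleton (not called elsewhere in this script):
+# `pose` is expected to be an array of shape (num_people, 17, 3) holding
+# (x, y, confidence) keypoints.
+#
+#   dummy_pose = np.zeros((1, 17, 3), dtype=np.float32)
+#   dummy_pose[0, :, :2] = 256.0  # place every joint at the image centre
+#   dummy_pose[0, :, 2] = 1.0     # confidence above mmpose_detection_thresh
+#   canvas = draw_humansd_skeleton(None, dummy_pose, height=512, width=512)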
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ ################################### newly added args ###################################
+ parser.add_argument("--ref_path", type=str, default="/data_laion/alvin/Dataset/evaluation/debug/42361.png")
+ parser.add_argument("--prompt", type=str, default="A person riding skis down a snow covered slope.")
+ parser.add_argument("--t2mn_path", type=str,
+ default="/data_laion/alvin/sd4human/a-ranstart-body-sdv20-v-nd-flaw-avg-copy1-glc-resume288k-512-ft1024/checkpoint-388000")
+ parser.add_argument("--controlnet_model_name_or_path", type=str,
+ default="/data_laion/alvin/sd4human/ctrl-sdxl10-eps-glc-composer-bmn-sum-1024/checkpoint-91000")
+ parser.add_argument('--step_num1', default=50, type=int)
+ parser.add_argument('--step_num2', default=50, type=int)
+ parser.add_argument('--size', default=2048, type=int)
+ parser.add_argument("--pretrained_vae_model_name_or_path", type=str,
+ default='/fsx_laion/alvin/pretrain/sdxl-vae-fp16-fix')
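+    # Note: the flags below default to True and use action="store_false",
+    # so passing any of them on the command line turns that behaviour off.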
+ parser.add_argument('--normalize_dist', default=True, action="store_false")
+ parser.add_argument('--change_whole_to_body', default=True, action="store_false")
+ parser.add_argument('--off_wa', default=True, action="store_false")
+ parser.add_argument('--flaw', default=True, action="store_false")
+ parser.add_argument("--enable_xformers_memory_efficient_attention", default=True, action="store_false",
+ help="Whether or not to use xformers.")
+ # statistics for three datasets, laion+coyo+getty
+ parser.add_argument("--rgb_mean", type=float, default=0.14654)
+ parser.add_argument("--rgb_std", type=float, default=1.03744)
+ # parser.add_argument("--whole_mean", type=float, default=0.14713)
+ # parser.add_argument("--whole_std", type=float, default=0.96812)
+ parser.add_argument("--whole_mean", type=float, default=-0.2599426086956522)
+ parser.add_argument("--whole_std", type=float, default=1.3836632689065582)
+ parser.add_argument("--body_mean", type=float, default=-0.2481)
+ parser.add_argument("--body_std", type=float, default=1.45647)
+ parser.add_argument("--depth_mean", type=float, default=0.21360)
+ parser.add_argument("--depth_std", type=float, default=1.20629)
+ parser.add_argument("--normal_mean", type=float, default=0.60303)
+ parser.add_argument("--normal_std", type=float, default=0.91429)
+
+    # # statistics for two datasets, laion+coyo
+ # parser.add_argument("--rgb_mean", type=float, default=0.144028)
+ # parser.add_argument("--rgb_std", type=float, default=1.0420677550094796)
+ # parser.add_argument("--whole_mean", type=float, default=-0.2598586666666667)
+ # parser.add_argument("--whole_std", type=float, default=1.3824869261991977)
+ # parser.add_argument("--body_mean", type=float, default=-0.2481)
+ # parser.add_argument("--body_std", type=float, default=1.45647)
+ # parser.add_argument("--depth_mean", type=float, default=0.22104533333333334)
+ # parser.add_argument("--depth_std", type=float, default=1.2044201368629092)
+ # parser.add_argument("--normal_mean", type=float, default=0.6173293333333333)
+ # parser.add_argument("--normal_std", type=float, default=0.9108628719489077)
+ parser.add_argument('--start', default=0, type=int)
+ parser.add_argument('--end', default=8236, type=int)
+ parser.add_argument("--pretrained_model_name_or_path", type=str,
+ default='/fsx_laion/alvin/pretrain/stable-diffusion-2-base')
+ parser.add_argument("--pretrained_model_name_or_path2", type=str,
+ default='/fsx_laion/alvin/pretrain/stable-diffusion-xl-base-1.0')
+ parser.add_argument('--prediction_type', type=str, default='v_prediction',
+ choices=['epsilon', 'v_prediction', 'target'], help='Select a mode')
+ parser.add_argument('--prediction_type2', type=str, default='epsilon',
+ choices=['epsilon', 'v_prediction', 'target'], help='Select a mode')
+ parser.add_argument("--cond_num", type=int, default=3)
+ parser.add_argument('--fusion', type=str, default="sum")
+ parser.add_argument("--validation_steps", type=int, default=500, )
+ parser.add_argument("--test_data_dir", nargs='+', type=str, default=None, )
+ parser.add_argument('--filter_lowres', default=False, action="store_true")
+ parser.add_argument("--filter_res", type=int)
+ parser.add_argument('--noisy_cond', type=str, default=[], nargs="+", help='add which types of conditions')
+ parser.add_argument("--output_dir2", type=str, default="sd-model-finetuned")
+ parser.add_argument('--cond_reshape2', type=str, choices=['resize', 'vae', 'learn_conv'],
+ help='how to reshape the spatial condition to the same shape as the latent space size')
+    parser.add_argument('--inference_folder_name2', type=str,
+                        help='folder name used to save the inference outputs')
+ parser.add_argument('--cond_inject2', type=str, choices=['concat', 'spade', 'sum'],
+ help='how to inject the spatial condition')
+ parser.add_argument('--cond_type2', type=str, default=[], nargs="+", help='add which types of conditions')
+ parser.add_argument('--cond_type_test2', type=str, default=None, nargs="+", help='add which types of conditions')
+ parser.add_argument("--resume_from_checkpoint2", type=str, default=None)
+ parser.add_argument('--pred_cond2', default=False, action="store_true")
+ parser.add_argument('--save_cond2', default=False, action="store_true")
+    parser.add_argument('--inference_folder_name', type=str,
+                        default="/data_laion/yli12/code_new/ControlNet/output-images",
+                        help='folder used to save the inference outputs')
+ parser.add_argument('--grid_dnc', default=False, action="store_true")
+ parser.add_argument('--pred_cond', default=False, action="store_true")
+ parser.add_argument('--save_cond', default=False, action="store_true")
+ parser.add_argument('--cond_reshape', type=str, choices=['resize', 'vae', 'learn_conv'],
+ help='how to reshape the spatial condition to the same shape as the latent space size')
+ parser.add_argument('--cond_inject', type=str, choices=['concat', 'spade', 'sum'],
+ help='how to inject the spatial condition')
+ parser.add_argument('--cond_type', type=str, default=["body", "midas_depth", "normal"], nargs="+",
+ help='add which types of conditions')
+ parser.add_argument('--cond_type_test', type=str, default=None, nargs="+", help='add which types of conditions')
+ parser.add_argument("--embedder_channel", default=4, type=int, help="channel number.")
+ ################################### newly added args ###################################
+ parser.add_argument(
+ "--input_perturbation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1."
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ nargs='+',
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--validation_prompts",
+ type=str,
+ default=None,
+ nargs="+",
+ help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="sd-model-finetuned",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=7, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--snr_gamma",
+ type=float,
+ default=None,
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+ "More details here: https://arxiv.org/abs/2303.09556.",
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
+ parser.add_argument(
+ "--non_ema_revision",
+ type=str,
+ default=None,
+ required=False,
+ help=(
+ "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
+ " remote repository specified with --pretrained_model_name_or_path."
+ ),
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=(
+ "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
+ " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
+ " for more docs"
+ ),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=5,
+ help="Run validation every X epochs.",
+ )
+ parser.add_argument(
+ "--tracker_project_name",
+ type=str,
+ default="text2image-fine-tune",
+ help=(
+ "The `project_name` argument passed to Accelerator.init_trackers for"
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+ ),
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ # # Sanity checks
+ # if args.dataset_name is None and args.train_data_dir is None:
+ # raise ValueError("Need either a dataset name or a training folder.")
+
+ # default to using the same revision for the non-ema model if not specified
+ if args.non_ema_revision is None:
+ args.non_ema_revision = args.revision
+
+ return args
+
+
+def import_model_class_from_model_name_or_path(
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision, use_auth_token=True
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "CLIPTextModelWithProjection":
+ from transformers import CLIPTextModelWithProjection
+
+ return CLIPTextModelWithProjection
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def main():
+ args = parse_args()
+ if args.change_whole_to_body:
+ args.whole_mean = args.body_mean
+ args.whole_std = args.body_std
+
+ if args.non_ema_revision is not None:
+ deprecate(
+ "non_ema_revision!=None",
+ "0.15.0",
+ message=(
+ "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to"
+ " use `--variant=non_ema` instead."
+ ),
+ )
+ # logging_dir = os.path.join(args.output_dir, args.logging_dir)
+
+ # accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ # log_with=args.report_to,
+ # logging_dir=logging_dir,
+ # project_config=accelerator_project_config,
+ )
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ def deepspeed_zero_init_disabled_context_manager():
+ """
+ returns either a context list that includes one that will disable zero.Init or an empty context list
+ """
+ deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None
+ if deepspeed_plugin is None:
+ return []
+
+ return [deepspeed_plugin.zero3_init_context_manager(enable=False)]
+
+ # Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3.
+ # For this to work properly all models must be run through `accelerate.prepare`. But accelerate
+ # will try to assign the same optimizer with the same weights to all models during
+ # `deepspeed.initialize`, which of course doesn't work.
+ #
+ # For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the 2
+ # frozen models from being partitioned during `zero.Init` which gets called during
+    # `from_pretrained`. So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding
+    # across multiple GPUs and only UNet2DConditionModel will get ZeRO sharded.
+ with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
+ # text_encoder = CLIPTextModel.from_pretrained(
+ # args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+ # )
+ vae = AutoencoderKL.from_pretrained(
+ "/fsx_laion/alvin/pretrain/sd-vae-ft-mse"
+ )
+
+ vae_path = (
+ args.pretrained_model_name_or_path2
+ if args.pretrained_vae_model_name_or_path is None
+ else args.pretrained_vae_model_name_or_path
+ )
+ vae2 = AutoencoderKL.from_pretrained(
+ vae_path,
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+ revision=args.revision,
+ )
+
+ from diffusers.models.unet_2d_condition_multi_branch_downup import UNet2DConditionModel
+ unet_t2mn = UNet2DConditionModel.from_pretrained(args.t2mn_path, subfolder="unet_ema")
+ unet_t2mn.requires_grad_(False)
+
+ unet = diffusers.UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path2, subfolder="unet", revision=args.revision, use_auth_token=True
+ )
+
+ # if args.controlnet_model_name_or_path:
+ logger.info("Loading existing controlnet weights")
+ controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path, subfolder="controlnet")
+
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ i = len(weights) - 1
+
+ while len(weights) > 0:
+ weights.pop()
+ model = models[i]
+
+ sub_dir = "controlnet"
+ model.save_pretrained(os.path.join(output_dir, sub_dir))
+
+ i -= 1
+
+ def load_model_hook(models, input_dir):
+ while len(models) > 0:
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ # load diffusers style into model
+ load_model = ControlNetModel.from_pretrained(input_dir, subfolder="controlnet")
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ vae.requires_grad_(False)
+ vae2.requires_grad_(False)
+ unet.requires_grad_(False)
+ unet_t2mn.requires_grad_(False)
+ # text_encoder.requires_grad_(False)
+ controlnet.requires_grad_(False)
+
+ unet.eval()
+ unet_t2mn.eval()
+ controlnet.eval()
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+ unet_t2mn.enable_gradient_checkpointing()
+ controlnet.enable_gradient_checkpointing()
+
+ # Check that all trainable models are in full precision
+ low_precision_error_string = (
+ " Please make sure to always have all model weights in full float32 precision when starting training - even if"
+ " doing mixed precision training, copy of the weights should still be float32."
+ )
+
+ if accelerator.unwrap_model(controlnet).dtype != torch.float32:
+ raise ValueError(
+ f"Controlnet loaded as datatype {accelerator.unwrap_model(controlnet).dtype}. {low_precision_error_string}"
+ )
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
+ # as these models are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ tf = transforms.Compose(
+ [transforms.Resize(512, interpolation=transforms.InterpolationMode.BICUBIC),
+ transforms.CenterCrop(512),
+ ]
+ )
+
+ from mmpose.apis import MMPoseInferencer
+ # import mmcv
+
+ body_inferencer = MMPoseInferencer(
+ pose2d='/fsx_laion/alvin/mmpose/configs/body_2d_keypoint/topdown_heatmap/coco/td-hm_ViTPose-huge-simple_8xb64-210e_coco-256x192.py',
+ pose2d_weights='/fsx_laion/alvin/pretrain/ViTPose/td-hm_ViTPose-huge-simple_8xb64-210e_coco-256x192-ffd48c05_20230314.pth',
+ scope="mmpose"
+ # det_model='/fsx_laion/alvin/mmpose/demo/mmdetection_cfg/faster_rcnn_r50_fpn_coco.py',
+ # det_weights="/fsx_laion/alvin/pretrain/ViTPose/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth"
+ )
+
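+    # Load the reference image, crop it to 512x512, and run the ViTPose-based MMPose
+    # inferencer to obtain body keypoints and their confidence scores.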
+ input_img = PIL.Image.open(args.ref_path)
+ input_img = tf(input_img)
+
+ image = np.array(input_img.convert("RGB"))
+ img_list = [image]
+ result_generator = body_inferencer(img_list, return_datasample=True)
+ result = next(result_generator)
+
+ # output[img_id]["new_body_bbox"] = result['predictions'][0].pred_instances.bboxes.tolist()
+ # output[img_id]["new_body_bbox_score"] = result['predictions'][0].pred_instances.bbox_scores.tolist()
+ # output[img_id]["new_body_kp"] = result['predictions'][0].pred_instances.keypoints.tolist()
+ # output[img_id]["new_body_kp_score"] = result['predictions'][0].pred_instances.keypoint_scores.tolist()
+
+ kp_coord = result['predictions'][0].pred_instances.keypoints
+ kp_coord_1024 = kp_coord * 2.
+ kp_conf = result['predictions'][0].pred_instances.keypoint_scores
+ kp = np.concatenate([kp_coord, kp_conf[..., np.newaxis]], axis=-1)
+ kp_1024 = np.concatenate([kp_coord_1024, kp_conf[..., np.newaxis]], axis=-1)
+
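+    # Render the detected keypoints as HumanSD-style skeleton maps at 512 (stage-1 pose
+    # condition) and at 1024 (the `body` condition for the SDXL ControlNet stage).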
+ whole_draw = draw_humansd_skeleton(
+ image=None,
+ pose=kp,
+ height=512,
+ width=512,
+ humansd_skeleton_width=10,
+ )
+ whole_image = Image.fromarray(whole_draw)
+
+ whole_draw_1024 = draw_humansd_skeleton(
+ # image=np.array(sample["image"]),
+ image=None,
+ pose=kp_1024,
+ height=1024,
+ width=1024,
+ humansd_skeleton_width=20,
+ )
+ whole_image_1024 = Image.fromarray(whole_draw_1024)
+
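+    # ToTensor + Normalize(0.5, 0.5) maps the skeleton images to [-1, 1] at the two working resolutions.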
+ preprocess = transforms.Compose(
+ [
+ transforms.Resize((512, 512), interpolation=transforms.InterpolationMode.BICUBIC),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ preprocess_1024 = transforms.Compose(
+ [
+ transforms.Resize((1024, 1024), interpolation=transforms.InterpolationMode.BICUBIC),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ whole = preprocess(whole_image)
+ whole_1024 = preprocess_1024(whole_image_1024)
+
+ # dataset = CustomDataset(args)
+ # test_dataloader = DataLoader(dataset, batch_size=1, shuffle=False, drop_last=False, collate_fn=collate_fn)
+
+ # lr_scheduler = get_scheduler(
+ # args.lr_scheduler,
+ # optimizer=optimizer,
+ # num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+ # num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+ # )
+
+ # # Prepare everything with our `accelerator`.
+ unet, unet_t2mn, controlnet = accelerator.prepare(
+ unet, unet_t2mn, controlnet
+ )
+
+ # Move text_encode and vae to gpu and cast to weight_dtype
+ # text_encoder.to(accelerator.device, dtype=weight_dtype)
+ vae.to(accelerator.device, dtype=weight_dtype)
+ vae2.to(accelerator.device, dtype=weight_dtype)
+
+ pipeline = StableDiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=vae,
+ # text_encoder=text_encoder,
+ # tokenizer=tokenizer,
+ unet=accelerator.unwrap_model(unet_t2mn),
+ safety_checker=None,
+ revision=args.revision,
+ torch_dtype=weight_dtype,
+ )
+ if args.flaw:
+ pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config, rescale_betas_zero_snr=True,
+ timestep_spacing="trailing")
+ pipeline.scheduler.config.rescale_betas_zero_snr = True
+ pipeline.scheduler.config['rescale_betas_zero_snr'] = True
+ pipeline.scheduler.config.timestep_spacing = "trailing"
+ pipeline.scheduler.config['timestep_spacing'] = "trailing"
+ else:
+ pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
+ pipeline.scheduler.set_timesteps(args.step_num1)
+
+ pipeline.scheduler.config.prediction_type = args.prediction_type
+ pipeline.scheduler.config['prediction_type'] = args.prediction_type
+
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=False)
+
+ controlnet = accelerator.unwrap_model(controlnet)
+
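+    # Second stage: SDXL base UNet with the composer ControlNet; the conditioning maps are
+    # passed through the `image` argument at 1024x1024.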
+ pipeline2 = StableDiffusionXLControlNetPipeline.from_pretrained(
+ args.pretrained_model_name_or_path2,
+ vae=vae2,
+ # text_encoder=text_encoder,
+ # tokenizer=tokenizer,
+ unet=accelerator.unwrap_model(unet),
+ controlnet=controlnet,
+ safety_checker=None,
+ revision=args.revision,
+ torch_dtype=weight_dtype,
+ )
+ # pipeline2.scheduler = UniPCMultistepScheduler.from_config(pipeline2.scheduler.config)
+ pipeline2.scheduler = DDPMScheduler.from_config(pipeline2.scheduler.config)
+ pipeline2.scheduler.config.prediction_type = args.prediction_type2
+ pipeline2.scheduler.config['prediction_type'] = args.prediction_type2
+ pipeline2 = pipeline2.to(accelerator.device)
+ pipeline2.set_progress_bar_config(disable=False)
+
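+    # The SDXL refiner shares text_encoder_2 and the VAE with the base pipeline and finishes
+    # the tail of the denoising schedule (from `denoising_start` onward).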
+ refiner = DiffusionPipeline.from_pretrained(
+ "/fsx_laion/alvin/pretrain/stable-diffusion-xl-refiner-1.0",
+ text_encoder_2=pipeline2.text_encoder_2,
+ vae=pipeline2.vae,
+ torch_dtype=torch.float16,
+ use_safetensors=True,
+ variant="fp16",
+ )
+ # refiner.scheduler = UniPCMultistepScheduler.from_config(pipeline2.scheduler.config)
+ refiner.scheduler = DDPMScheduler.from_config(refiner.scheduler.config)
+ refiner.scheduler.config.prediction_type = args.prediction_type2
+ refiner.scheduler.config['prediction_type'] = args.prediction_type2
+ refiner = refiner.to(accelerator.device)
+ refiner.set_progress_bar_config(disable=False)
+
+ if args.enable_xformers_memory_efficient_attention:
+ pipeline.enable_xformers_memory_efficient_attention()
+ pipeline2.enable_xformers_memory_efficient_attention()
+ refiner.enable_xformers_memory_efficient_attention()
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ # step1 = args.t2mn_path.split('/')[-1].split("-")[1]
+ # step2 = args.controlnet_model_name_or_path.split('/')[-1].split("-")[1]
+
+ os.makedirs(args.inference_folder_name, exist_ok=True)
+ # save_path_body = os.path.join(save_path, 'body')
+ # save_path_depth = os.path.join(save_path, 'depth')
+ # save_path_normal = os.path.join(save_path, 'normal')
+ # save_path_rgb1 = os.path.join(save_path, 'rgb1')
+ # save_path_rgb2 = os.path.join(save_path, 'rgb2')
+ # os.makedirs(save_path_body, exist_ok=True)
+ # os.makedirs(save_path_depth, exist_ok=True)
+ # os.makedirs(save_path_normal, exist_ok=True)
+ # os.makedirs(save_path_rgb1, exist_ok=True)
+ # os.makedirs(save_path_rgb2, exist_ok=True)
+
+ batch = {}
+ whole = whole.to(unet.device)
+ whole_1024 = whole_1024.to(unet.device)
+ batch["whole"] = whole.unsqueeze(0)
+ batch["body"] = whole_1024.unsqueeze(0)
+
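+    # Stage 1: the multi-branch UNet generates a 512x512 RGB image together with depth and
+    # normal predictions, conditioned on the text prompt and the skeleton map in `batch`.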
+ with torch.autocast("cuda"):
+ output = pipeline(
+ args.prompt,
+ height=args.resolution,
+ width=args.resolution,
+ num_inference_steps=args.step_num1,
+ generator=generator,
+ batch=batch,
+ args=args,
+ original_size=(args.size, args.size),
+ guidance_rescale=0.7 if args.flaw else 0.,
+ negative_prompt="(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck",
+ )
+
+ image = output.images[0]
+ image.save(os.path.join(args.inference_folder_name, "rgb.png"))
+ midas_depth_image = output.midas_depth_image[0]
+ midas_depth_image.save(os.path.join(args.inference_folder_name, "depth.png"))
+ normal_image = output.normal_image[0]
+ normal_image.save(os.path.join(args.inference_folder_name, "normal.png"))
+
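+    # Prepare the stage-1 depth and normal predictions as 1024x1024 conditions: the depth is
+    # averaged over channels, min-max normalized to [0, 1], replicated to 3 channels, and both
+    # maps are scaled to [-1, 1] to match the range of the other conditioning images.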
+ resize_transform = transforms.Compose(
+ [
+ transforms.Resize((1024, 1024), interpolation=transforms.InterpolationMode.BICUBIC),
+ transforms.ToTensor(),
+ # transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+ normalize_transform = transforms.Normalize([0.5], [0.5])
+ # midas_depth_tensor = 2 * (transforms.ToTensor()(midas_depth_image)) - 1
+ midas_depth_tensor = resize_transform(midas_depth_image)
+ # print(midas_depth_tensor.shape)
+ midas_depth_tensor = torch.mean(midas_depth_tensor, dim=0)
+ # print(midas_depth_tensor.shape)
+ depth_min = torch.amin(midas_depth_tensor, dim=[0, 1], keepdim=True)
+ depth_max = torch.amax(midas_depth_tensor, dim=[0, 1], keepdim=True)
+ midas_depth_tensor = (midas_depth_tensor - depth_min) / (depth_max - depth_min)
+ midas_depth_tensor = normalize_transform(midas_depth_tensor.unsqueeze(0).repeat(3, 1, 1))
+ batch["midas_depth"] = midas_depth_tensor.unsqueeze(0).to(unet.device)
+
+ # normal_tensor = 2 * (transforms.ToTensor()(normal_image)) - 1
+ normal_tensor = resize_transform(normal_image)
+ normal_tensor = normal_tensor.clamp(min=0, max=1)
+ normal_tensor = normalize_transform(normal_tensor)
+ batch["normal"] = normal_tensor.unsqueeze(0).to(unet.device)
+
+ body_denormalize = (batch["body"] + 1) / 2.0
+ body_numpy = body_denormalize.cpu().permute(0, 2, 3, 1).float().numpy()[0]
+ body_numpy = (body_numpy * 255).round().astype("uint8")
+ body_pil = Image.fromarray(body_numpy)
+ body_pil.save(os.path.join(args.inference_folder_name, "body.png"))
+ # batch["body"] = batch["body"][0].unsqueeze(0)
+
+ # batch["whole"] = batch["whole_1024"]
+
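+    # Collect the conditioning maps in a fixed key order; only the keys listed in --cond_type
+    # (default: body, midas_depth, normal) are passed to the ControlNet.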
+ controlnet_image = []
+ for key in ['depth', 'midas_depth', 'normal', 'canny', 'body', 'face', 'hand', 'whole']:
+ if key in args.cond_type:
+ controlnet_image.append(batch[key][0])
+
+ n_steps = args.step_num2
+ high_noise_frac = 0.8
+
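+    # Stage 2: run the SDXL ControlNet base model for the first `high_noise_frac` of the
+    # schedule in latent space, then hand the latents to the refiner below.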
+ with torch.autocast("cuda"):
+ output = pipeline2(
+ args.prompt,
+ image=controlnet_image,
+ height=1024,
+ width=1024,
+ num_inference_steps=n_steps,
+ denoising_end=high_noise_frac,
+ output_type="latent",
+ generator=generator,
+ original_size=(args.size, args.size),
+ negative_prompt="(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck",
+ )
+
+ # image = output.images[0]
+ # image.save(os.path.join(save_path_rgb2, f"{int(id[i_batch]):012d}.jpg"))
+
+ image = output.images
+ image = refiner(
+ args.prompt,
+ # height=1024,
+ # width=1024,
+ num_inference_steps=n_steps,
+ denoising_start=high_noise_frac,
+ image=image,
+ # guidance_scale=args.cfg,
+ generator=generator,
+ negative_prompt="(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck",
+ ).images[0]
+ image.save(os.path.join(args.inference_folder_name, "rgb2.png"))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/aagenerator.py b/aagenerator.py
new file mode 100644
index 0000000000000000000000000000000000000000..ddd96a8970b34d097127db664ab244714d7e5a0b
--- /dev/null
+++ b/aagenerator.py
@@ -0,0 +1,981 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import math
+import os
+import random
+from pathlib import Path
+
+import accelerate
+import datasets
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.state import AcceleratorState
+from accelerate.utils import ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPTextModel, CLIPTokenizer
+from transformers import AutoTokenizer, PretrainedConfig
+from transformers.utils import ContextManagers
+from PIL import Image
+import PIL
+from PIL import ImageFile
+
+import diffusers
+from diffusers import AutoencoderKL, UNet2DConditionModel, DiffusionPipeline
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import EMAModel
+from diffusers.utils import check_min_version, deprecate, is_wandb_available
+from diffusers.utils.import_utils import is_xformers_available
+from openclip.training.data import get_wds_dataset, get_wds_dataset_cond
+from diffusers.schedulers import DDIMScheduler, DDPMScheduler, \
+ DEISMultistepScheduler, DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler, \
+ PNDMScheduler, LMSDiscreteScheduler, UniPCMultistepScheduler
+
+from models.embedder import Embedder
+from pipelines.pipeline_stable_diffusion_mb_downup import StableDiffusionPipeline
+from collections import OrderedDict
+import boto3
+from diffusers.models.controlnet_composer import ControlNetModel
+# from pipelines.pipeline_controlnet_composer import StableDiffusionControlNetPipeline
+from pipelines.pipeline_controlnet_composer_sdxl import StableDiffusionXLControlNetPipeline
+from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler, IterableDataset, get_worker_info
+import json
+import cv2
+import seaborn as sns
+from mmpose.apis import MMPoseInferencer
+from diffusers.models.unet_2d_condition_multi_branch_downup import UNet2DConditionModel
+
+if is_wandb_available():
+ import wandb
+
+logger = get_logger(__name__, log_level="INFO")
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.17.0.dev0")
+
+
+def draw_humansd_skeleton(image, pose, mmpose_detection_thresh=0.3, height=None, width=None, humansd_skeleton_width=10):
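+    """Render 17 COCO-style body keypoints as a HumanSD-style skeleton image.
+
+    `pose` has shape (num_people, 17, 3) with (x, y, confidence) per joint; joints whose
+    confidence is below `mmpose_detection_thresh` are skipped. When `image` is None an empty
+    (height, width, 3) canvas is used, otherwise the skeleton is drawn on top of `image`.
+    """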
+ humansd_skeleton = [
+ [0, 0, 1],
+ [1, 0, 2],
+ [2, 1, 3],
+ [3, 2, 4],
+ [4, 3, 5],
+ [5, 4, 6],
+ [6, 5, 7],
+ [7, 6, 8],
+ [8, 7, 9],
+ [9, 8, 10],
+ [10, 5, 11],
+ [11, 6, 12],
+ [12, 11, 13],
+ [13, 12, 14],
+ [14, 13, 15],
+ [15, 14, 16],
+ ]
+ # humansd_skeleton_width=10
+ humansd_color = sns.color_palette("hls", len(humansd_skeleton))
+
+ def plot_kpts(img_draw, kpts, color, edgs, width):
+ for idx, kpta, kptb in edgs:
+ if kpts[kpta, 2] > mmpose_detection_thresh and \
+ kpts[kptb, 2] > mmpose_detection_thresh:
+ line_color = tuple([int(255 * color_i) for color_i in color[idx]])
+
+ cv2.line(img_draw, (int(kpts[kpta, 0]), int(kpts[kpta, 1])), (int(kpts[kptb, 0]), int(kpts[kptb, 1])),
+ line_color, width)
+ cv2.circle(img_draw, (int(kpts[kpta, 0]), int(kpts[kpta, 1])), width // 2, line_color, -1)
+ cv2.circle(img_draw, (int(kpts[kptb, 0]), int(kpts[kptb, 1])), width // 2, line_color, -1)
+
+ if image is None:
+ pose_image = np.zeros((height, width, 3), dtype=np.uint8)
+ else:
+ pose_image = np.array(image, dtype=np.uint8)
+ for person_i in range(len(pose)):
+ if np.sum(pose[person_i]) > 0:
+ plot_kpts(pose_image, pose[person_i], humansd_color, humansd_skeleton, humansd_skeleton_width)
+
+ return pose_image
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Two-stage pose-conditioned human image generation script.")
+ ################################### newly added args ###################################
+ parser.add_argument("--ref_path", type=str, default="/data_laion/alvin/Dataset/evaluation/debug/42361.png")
+ parser.add_argument("--prompt", type=str, default="A person riding skis down a snow covered slope.")
+ parser.add_argument("--t2mn_path", type=str,
+ default="/data_laion/alvin/sd4human/ckpts/a-ranstart-body-sdv20-v-nd-flaw-avg-copy1-glc-resume288k-512-ft1024/checkpoint-388000")
+ parser.add_argument("--controlnet_model_name_or_path", type=str,
+ default="/data_laion/alvin/sd4human/ckpts/ctrl-sdxl10-eps-glc-composer-bmn-sum-1024/checkpoint-91000")
+ parser.add_argument('--step_num1', default=50, type=int)
+ parser.add_argument('--step_num2', default=50, type=int)
+ parser.add_argument('--size', default=2048, type=int)
+ parser.add_argument("--pretrained_vae_model_name_or_path", type=str,
+ default='/fsx_laion/alvin/pretrain/sdxl-vae-fp16-fix')
+ parser.add_argument('--normalize_dist', default=True, action="store_false")
+ parser.add_argument('--change_whole_to_body', default=True, action="store_false")
+ parser.add_argument('--off_wa', default=True, action="store_false")
+ parser.add_argument('--flaw', default=True, action="store_false")
+ parser.add_argument("--enable_xformers_memory_efficient_attention", default=True, action="store_false",
+ help="Whether or not to use xformers.")
+ # statistics for three datasets, laion+coyo+getty
+ parser.add_argument("--rgb_mean", type=float, default=0.14654)
+ parser.add_argument("--rgb_std", type=float, default=1.03744)
+ # parser.add_argument("--whole_mean", type=float, default=0.14713)
+ # parser.add_argument("--whole_std", type=float, default=0.96812)
+ parser.add_argument("--whole_mean", type=float, default=-0.2599426086956522)
+ parser.add_argument("--whole_std", type=float, default=1.3836632689065582)
+ parser.add_argument("--body_mean", type=float, default=-0.2481)
+ parser.add_argument("--body_std", type=float, default=1.45647)
+ parser.add_argument("--depth_mean", type=float, default=0.21360)
+ parser.add_argument("--depth_std", type=float, default=1.20629)
+ parser.add_argument("--normal_mean", type=float, default=0.60303)
+ parser.add_argument("--normal_std", type=float, default=0.91429)
+
+    # # statistics for two datasets, laion+coyo
+ # parser.add_argument("--rgb_mean", type=float, default=0.144028)
+ # parser.add_argument("--rgb_std", type=float, default=1.0420677550094796)
+ # parser.add_argument("--whole_mean", type=float, default=-0.2598586666666667)
+ # parser.add_argument("--whole_std", type=float, default=1.3824869261991977)
+ # parser.add_argument("--body_mean", type=float, default=-0.2481)
+ # parser.add_argument("--body_std", type=float, default=1.45647)
+ # parser.add_argument("--depth_mean", type=float, default=0.22104533333333334)
+ # parser.add_argument("--depth_std", type=float, default=1.2044201368629092)
+ # parser.add_argument("--normal_mean", type=float, default=0.6173293333333333)
+ # parser.add_argument("--normal_std", type=float, default=0.9108628719489077)
+ parser.add_argument('--start', default=0, type=int)
+ parser.add_argument('--end', default=8236, type=int)
+ parser.add_argument("--pretrained_model_name_or_path", type=str,
+ default='/fsx_laion/alvin/pretrain/stable-diffusion-2-base')
+ parser.add_argument("--pretrained_model_name_or_path2", type=str,
+ default='/fsx_laion/alvin/pretrain/stable-diffusion-xl-base-1.0')
+ parser.add_argument('--prediction_type', type=str, default='v_prediction',
+ choices=['epsilon', 'v_prediction', 'target'], help='Select a mode')
+ parser.add_argument('--prediction_type2', type=str, default='epsilon',
+ choices=['epsilon', 'v_prediction', 'target'], help='Select a mode')
+ parser.add_argument("--cond_num", type=int, default=3)
+ parser.add_argument('--fusion', type=str, default="sum")
+ parser.add_argument("--validation_steps", type=int, default=500, )
+ parser.add_argument("--test_data_dir", nargs='+', type=str, default=None, )
+ parser.add_argument('--filter_lowres', default=False, action="store_true")
+ parser.add_argument("--filter_res", type=int)
+ parser.add_argument('--noisy_cond', type=str, default=[], nargs="+", help='add which types of conditions')
+ parser.add_argument("--output_dir2", type=str, default="sd-model-finetuned")
+ parser.add_argument('--cond_reshape2', type=str, choices=['resize', 'vae', 'learn_conv'],
+ help='how to reshape the spatial condition to the same shape as the latent space size')
+    parser.add_argument('--inference_folder_name2', type=str,
+                        help='folder where inference outputs are saved')
+ parser.add_argument('--cond_inject2', type=str, choices=['concat', 'spade', 'sum'],
+ help='how to inject the spatial condition')
+ parser.add_argument('--cond_type2', type=str, default=[], nargs="+", help='add which types of conditions')
+ parser.add_argument('--cond_type_test2', type=str, default=None, nargs="+", help='add which types of conditions')
+ parser.add_argument("--resume_from_checkpoint2", type=str, default=None)
+ parser.add_argument('--pred_cond2', default=False, action="store_true")
+ parser.add_argument('--save_cond2', default=False, action="store_true")
+    parser.add_argument('--inference_folder_name', type=str,
+                        default="/data_laion/yli12/code_new/ControlNet/output-images",
+                        help='folder where the generated body/depth/normal/rgb images are saved')
+ parser.add_argument('--grid_dnc', default=False, action="store_true")
+ parser.add_argument('--pred_cond', default=False, action="store_true")
+ parser.add_argument('--save_cond', default=False, action="store_true")
+ parser.add_argument('--cond_reshape', type=str, choices=['resize', 'vae', 'learn_conv'],
+ help='how to reshape the spatial condition to the same shape as the latent space size')
+ parser.add_argument('--cond_inject', type=str, choices=['concat', 'spade', 'sum'],
+ help='how to inject the spatial condition')
+ parser.add_argument('--cond_type', type=str, default=["body", "midas_depth", "normal"], nargs="+",
+ help='add which types of conditions')
+ parser.add_argument('--cond_type_test', type=str, default=None, nargs="+", help='add which types of conditions')
+ parser.add_argument("--embedder_channel", default=4, type=int, help="channel number.")
+ ################################### newly added args ###################################
+ parser.add_argument(
+ "--input_perturbation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1."
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ nargs='+',
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--validation_prompts",
+ type=str,
+ default=None,
+ nargs="+",
+ help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="sd-model-finetuned",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=7, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--snr_gamma",
+ type=float,
+ default=None,
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+ "More details here: https://arxiv.org/abs/2303.09556.",
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
+ parser.add_argument(
+ "--non_ema_revision",
+ type=str,
+ default=None,
+ required=False,
+ help=(
+ "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
+ " remote repository specified with --pretrained_model_name_or_path."
+ ),
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=(
+ "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
+            " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
+ " for more docs"
+ ),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=5,
+ help="Run validation every X epochs.",
+ )
+ parser.add_argument(
+ "--tracker_project_name",
+ type=str,
+ default="text2image-fine-tune",
+ help=(
+            "The `project_name` argument passed to Accelerator.init_trackers; for"
+            " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+ ),
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ # # Sanity checks
+ # if args.dataset_name is None and args.train_data_dir is None:
+ # raise ValueError("Need either a dataset name or a training folder.")
+
+ # default to using the same revision for the non-ema model if not specified
+ if args.non_ema_revision is None:
+ args.non_ema_revision = args.revision
+
+ return args
+
+
+def import_model_class_from_model_name_or_path(
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision, use_auth_token=True
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "CLIPTextModelWithProjection":
+ from transformers import CLIPTextModelWithProjection
+
+ return CLIPTextModelWithProjection
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+class Generator:
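+    """Programmatic wrapper around the two-stage generation pipeline.
+
+    All models (pose inferencer, text-to-RGB/depth/normal UNet, SDXL ControlNet, refiner) are
+    loaded once in __init__ so that `run` / `save_images` can be called repeatedly with
+    different reference images and prompts.
+    """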
+    def __init__(self, args=None):
+        super().__init__()
+        # Parse CLI defaults lazily so importing this module does not trigger argparse at class-definition time.
+        if args is None:
+            args = parse_args()
+        self.args = args
+ if args.change_whole_to_body:
+ args.whole_mean = args.body_mean
+ args.whole_std = args.body_std
+
+ if args.non_ema_revision is not None:
+ deprecate(
+ "non_ema_revision!=None",
+ "0.15.0",
+ message=(
+ "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to"
+ " use `--variant=non_ema` instead."
+ ),
+ )
+ # logging_dir = os.path.join(args.output_dir, args.logging_dir)
+
+ # accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit, logging_dir=logging_dir)
+
+ self.accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ # log_with=args.report_to,
+ # logging_dir=logging_dir,
+ # project_config=accelerator_project_config,
+ )
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(self.accelerator.state, main_process_only=False)
+ if self.accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if self.accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ def deepspeed_zero_init_disabled_context_manager():
+ """
+ returns either a context list that includes one that will disable zero.Init or an empty context list
+ """
+ deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None
+ if deepspeed_plugin is None:
+ return []
+
+ return [deepspeed_plugin.zero3_init_context_manager(enable=False)]
+
+ # Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3.
+ # For this to work properly all models must be run through `accelerate.prepare`. But accelerate
+ # will try to assign the same optimizer with the same weights to all models during
+ # `deepspeed.initialize`, which of course doesn't work.
+ #
+ # For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the 2
+ # frozen models from being partitioned during `zero.Init` which gets called during
+        # `from_pretrained`. So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding
+        # across multiple GPUs and only UNet2DConditionModel will get ZeRO sharded.
+ with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
+ # text_encoder = CLIPTextModel.from_pretrained(
+ # args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+ # )
+ self.vae = AutoencoderKL.from_pretrained(
+ "/fsx_laion/alvin/pretrain/sd-vae-ft-mse"
+ )
+
+ vae_path = (
+ args.pretrained_model_name_or_path2
+ if args.pretrained_vae_model_name_or_path is None
+ else args.pretrained_vae_model_name_or_path
+ )
+ self.vae2 = AutoencoderKL.from_pretrained(
+ vae_path,
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+ revision=args.revision,
+ )
+
+ self.unet_t2mn = UNet2DConditionModel.from_pretrained(args.t2mn_path, subfolder="unet_ema")
+ self.unet_t2mn.requires_grad_(False)
+
+ self.unet = diffusers.UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path2, subfolder="unet", revision=args.revision, use_auth_token=True
+ )
+
+ # if args.controlnet_model_name_or_path:
+ logger.info("Loading existing controlnet weights")
+ self.controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path, subfolder="controlnet")
+
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `self.accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ i = len(weights) - 1
+
+ while len(weights) > 0:
+ weights.pop()
+ model = models[i]
+
+ sub_dir = "controlnet"
+ model.save_pretrained(os.path.join(output_dir, sub_dir))
+
+ i -= 1
+
+ def load_model_hook(models, input_dir):
+ while len(models) > 0:
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ # load diffusers style into model
+ load_model = ControlNetModel.from_pretrained(input_dir, subfolder="controlnet")
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ self.accelerator.register_save_state_pre_hook(save_model_hook)
+ self.accelerator.register_load_state_pre_hook(load_model_hook)
+
+ self.vae.requires_grad_(False)
+ self.vae2.requires_grad_(False)
+ self.unet.requires_grad_(False)
+ self.unet_t2mn.requires_grad_(False)
+ # text_encoder.requires_grad_(False)
+ self.controlnet.requires_grad_(False)
+
+ self.unet.eval()
+ self.unet_t2mn.eval()
+ self.controlnet.eval()
+
+ if args.gradient_checkpointing:
+ self.unet.enable_gradient_checkpointing()
+ self.unet_t2mn.enable_gradient_checkpointing()
+ self.controlnet.enable_gradient_checkpointing()
+
+ # Check that all trainable models are in full precision
+ low_precision_error_string = (
+ " Please make sure to always have all model weights in full float32 precision when starting training - even if"
+ " doing mixed precision training, copy of the weights should still be float32."
+ )
+
+ if self.accelerator.unwrap_model(self.controlnet).dtype != torch.float32:
+ raise ValueError(
+ f"Controlnet loaded as datatype {self.accelerator.unwrap_model(self.controlnet).dtype}. {low_precision_error_string}"
+ )
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * self.accelerator.num_processes
+ )
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ self.optimizer_class = bnb.optim.AdamW8bit
+ else:
+ self.optimizer_class = torch.optim.AdamW
+
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
+ # as these models are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if self.accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif self.accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ self.tf = transforms.Compose(
+ [transforms.Resize(512, interpolation=transforms.InterpolationMode.BICUBIC),
+ transforms.CenterCrop(512),
+ ]
+ )
+
+ self.body_inferencer = MMPoseInferencer(
+ pose2d='/fsx_laion/alvin/mmpose/configs/body_2d_keypoint/topdown_heatmap/coco/td-hm_ViTPose-huge-simple_8xb64-210e_coco-256x192.py',
+ pose2d_weights='/fsx_laion/alvin/pretrain/ViTPose/td-hm_ViTPose-huge-simple_8xb64-210e_coco-256x192-ffd48c05_20230314.pth',
+ scope="mmpose"
+ # det_model='/fsx_laion/alvin/mmpose/demo/mmdetection_cfg/faster_rcnn_r50_fpn_coco.py',
+ # det_weights="/fsx_laion/alvin/pretrain/ViTPose/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth"
+ )
+
+ # # Prepare everything with our `accelerator`.
+ self.unet, self.unet_t2mn, self.controlnet = self.accelerator.prepare(
+ self.unet, self.unet_t2mn, self.controlnet
+ )
+
+ # Move text_encode and vae to gpu and cast to weight_dtype
+ # text_encoder.to(self.accelerator.device, dtype=weight_dtype)
+ self.vae.to(self.accelerator.device, dtype=weight_dtype)
+ self.vae2.to(self.accelerator.device, dtype=weight_dtype)
+
+ self.pipeline = StableDiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=self.vae,
+ # text_encoder=text_encoder,
+ # tokenizer=tokenizer,
+ unet=self.accelerator.unwrap_model(self.unet_t2mn),
+ safety_checker=None,
+ revision=args.revision,
+ torch_dtype=weight_dtype,
+ )
+ if args.flaw:
+ self.pipeline.scheduler = DDIMScheduler.from_config(self.pipeline.scheduler.config,
+ rescale_betas_zero_snr=True,
+ timestep_spacing="trailing")
+ self.pipeline.scheduler.config.rescale_betas_zero_snr = True
+ self.pipeline.scheduler.config['rescale_betas_zero_snr'] = True
+ self.pipeline.scheduler.config.timestep_spacing = "trailing"
+ self.pipeline.scheduler.config['timestep_spacing'] = "trailing"
+ else:
+ self.pipeline.scheduler = DDIMScheduler.from_config(self.pipeline.scheduler.config)
+ self.pipeline.scheduler.set_timesteps(args.step_num1)
+
+ self.pipeline.scheduler.config.prediction_type = args.prediction_type
+ self.pipeline.scheduler.config['prediction_type'] = args.prediction_type
+
+ self.pipeline = self.pipeline.to(self.accelerator.device)
+ self.pipeline.set_progress_bar_config(disable=False)
+
+ self.controlnet = self.accelerator.unwrap_model(self.controlnet)
+
+ self.pipeline2 = StableDiffusionXLControlNetPipeline.from_pretrained(
+ args.pretrained_model_name_or_path2,
+ vae=self.vae2,
+ # text_encoder=text_encoder,
+ # tokenizer=tokenizer,
+ unet=self.accelerator.unwrap_model(self.unet),
+ controlnet=self.controlnet,
+ safety_checker=None,
+ revision=args.revision,
+ torch_dtype=weight_dtype,
+ )
+ # pipeline2.scheduler = UniPCMultistepScheduler.from_config(pipeline2.scheduler.config)
+ self.pipeline2.scheduler = DDPMScheduler.from_config(self.pipeline2.scheduler.config)
+ self.pipeline2.scheduler.config.prediction_type = args.prediction_type2
+ self.pipeline2.scheduler.config['prediction_type'] = args.prediction_type2
+ self.pipeline2 = self.pipeline2.to(self.accelerator.device)
+ self.pipeline2.set_progress_bar_config(disable=False)
+
+ self.refiner = DiffusionPipeline.from_pretrained(
+ "/fsx_laion/alvin/pretrain/stable-diffusion-xl-refiner-1.0",
+ text_encoder_2=self.pipeline2.text_encoder_2,
+ vae=self.pipeline2.vae,
+ torch_dtype=torch.float16,
+ use_safetensors=True,
+ variant="fp16",
+ )
+ # refiner.scheduler = UniPCMultistepScheduler.from_config(pipeline2.scheduler.config)
+ self.refiner.scheduler = DDPMScheduler.from_config(self.refiner.scheduler.config)
+ self.refiner.scheduler.config.prediction_type = args.prediction_type2
+ self.refiner.scheduler.config['prediction_type'] = args.prediction_type2
+ self.refiner = self.refiner.to(self.accelerator.device)
+ self.refiner.set_progress_bar_config(disable=False)
+
+ if args.enable_xformers_memory_efficient_attention:
+ self.pipeline.enable_xformers_memory_efficient_attention()
+ self.pipeline2.enable_xformers_memory_efficient_attention()
+ self.refiner.enable_xformers_memory_efficient_attention()
+
+ if args.seed is None:
+ self.generator = None
+ else:
+ self.generator = torch.Generator(device=self.accelerator.device).manual_seed(args.seed)
+
+ # step1 = args.t2mn_path.split('/')[-1].split("-")[1]
+ # step2 = args.controlnet_model_name_or_path.split('/')[-1].split("-")[1]
+
+ self.preprocess = transforms.Compose(
+ [
+ transforms.Resize((512, 512), interpolation=transforms.InterpolationMode.BICUBIC),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ self.preprocess_1024 = transforms.Compose(
+ [
+ transforms.Resize((1024, 1024), interpolation=transforms.InterpolationMode.BICUBIC),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ self.resize_transform = transforms.Compose(
+ [
+ transforms.Resize((1024, 1024), interpolation=transforms.InterpolationMode.BICUBIC),
+ transforms.ToTensor(),
+ # transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+ self.normalize_transform = transforms.Normalize([0.5], [0.5])
+
+ os.makedirs(args.inference_folder_name, exist_ok=True)
+
+ def run(self, input_img=None, prompt=None, steps=None):
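+        """Generate images for a reference pose image and a text prompt.
+
+        input_img: optional RGB image as a numpy array; when None, self.args.ref_path is used.
+        prompt / steps: override self.args.prompt and the default step counts when given.
+        Returns (skeleton map as uint8 array, stage-1 RGB, depth, normal, refined 1024 RGB).
+        """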
+ if input_img is None:
+ input_img = PIL.Image.open(self.args.ref_path)
+ else:
+ input_img = PIL.Image.fromarray(input_img)
+ input_img = self.tf(input_img)
+ image = np.array(input_img.convert("RGB"))
+ img_list = [image]
+ result_generator = self.body_inferencer(img_list, return_datasample=True)
+ result = next(result_generator)
+
+ kp_coord = result['predictions'][0].pred_instances.keypoints
+ kp_coord_1024 = kp_coord * 2.
+ kp_conf = result['predictions'][0].pred_instances.keypoint_scores
+ kp = np.concatenate([kp_coord, kp_conf[..., np.newaxis]], axis=-1)
+ kp_1024 = np.concatenate([kp_coord_1024, kp_conf[..., np.newaxis]], axis=-1)
+
+ whole_draw = draw_humansd_skeleton(
+ image=None,
+ pose=kp,
+ height=512,
+ width=512,
+ humansd_skeleton_width=10,
+ )
+ whole_image = Image.fromarray(whole_draw)
+
+ whole_draw_1024 = draw_humansd_skeleton(
+ # image=np.array(sample["image"]),
+ image=None,
+ pose=kp_1024,
+ height=1024,
+ width=1024,
+ humansd_skeleton_width=20,
+ )
+ whole_image_1024 = Image.fromarray(whole_draw_1024)
+
+ whole = self.preprocess(whole_image)
+ whole_1024 = self.preprocess_1024(whole_image_1024)
+
+ batch = {}
+ whole = whole.to(self.unet.device)
+ whole_1024 = whole_1024.to(self.unet.device)
+ batch["whole"] = whole.unsqueeze(0)
+ batch["body"] = whole_1024.unsqueeze(0)
+
+ body_denormalize = (batch["body"] + 1) / 2.0
+ body_numpy = body_denormalize.cpu().permute(0, 2, 3, 1).float().numpy()[0]
+ body_numpy = (body_numpy * 255).round().astype("uint8")
+
+ with torch.autocast("cuda"):
+ output = self.pipeline(
+ prompt or self.args.prompt,
+ height=self.args.resolution,
+ width=self.args.resolution,
+ num_inference_steps=steps or self.args.step_num1,
+ generator=self.generator,
+ batch=batch,
+ args=self.args,
+ original_size=(self.args.size, self.args.size),
+ guidance_rescale=0.7 if self.args.flaw else 0.,
+ negative_prompt="(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck",
+ )
+
+ rgb = output.images[0]
+ midas_depth_image = output.midas_depth_image[0]
+ normal_image = output.normal_image[0]
+
+ # midas_depth_tensor = 2 * (transforms.ToTensor()(midas_depth_image)) - 1
+ midas_depth_tensor = self.resize_transform(midas_depth_image)
+ # print(midas_depth_tensor.shape)
+ midas_depth_tensor = torch.mean(midas_depth_tensor, dim=0)
+ # print(midas_depth_tensor.shape)
+ depth_min = torch.amin(midas_depth_tensor, dim=[0, 1], keepdim=True)
+ depth_max = torch.amax(midas_depth_tensor, dim=[0, 1], keepdim=True)
+ midas_depth_tensor = (midas_depth_tensor - depth_min) / (depth_max - depth_min)
+ midas_depth_tensor = self.normalize_transform(midas_depth_tensor.unsqueeze(0).repeat(3, 1, 1))
+ batch["midas_depth"] = midas_depth_tensor.unsqueeze(0).to(self.unet.device)
+
+ # normal_tensor = 2 * (transforms.ToTensor()(normal_image)) - 1
+ normal_tensor = self.resize_transform(normal_image)
+ normal_tensor = normal_tensor.clamp(min=0, max=1)
+ normal_tensor = self.normalize_transform(normal_tensor)
+ batch["normal"] = normal_tensor.unsqueeze(0).to(self.unet.device)
+
+ # batch["body"] = batch["body"][0].unsqueeze(0)
+ # batch["whole"] = batch["whole_1024"]
+
+ controlnet_image = []
+ for key in ['depth', 'midas_depth', 'normal', 'canny', 'body', 'face', 'hand', 'whole']:
+ if key in self.args.cond_type:
+ controlnet_image.append(batch[key][0])
+
+ n_steps = steps or self.args.step_num2
+ high_noise_frac = 0.8
+
+ with torch.autocast("cuda"):
+ output2 = self.pipeline2(
+ prompt or self.args.prompt,
+ image=controlnet_image,
+ height=1024,
+ width=1024,
+ num_inference_steps=n_steps,
+ denoising_end=high_noise_frac,
+ output_type="latent",
+ generator=self.generator,
+ original_size=(self.args.size, self.args.size),
+ negative_prompt="(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck",
+ )
+
+ # image = output.images[0]
+ # image.save(os.path.join(save_path_rgb2, f"{int(id[i_batch]):012d}.jpg"))
+
+ image = output2.images
+ rgb2 = self.refiner(
+ prompt or self.args.prompt,
+ # height=1024,
+ # width=1024,
+ num_inference_steps=n_steps,
+ denoising_start=high_noise_frac,
+ image=image,
+ # guidance_scale=self.args.cfg,
+ generator=self.generator,
+ negative_prompt="(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck",
+ ).images[0]
+
+ return body_numpy, rgb, midas_depth_image, normal_image, rgb2
+
+ def save_images(self, input_img=None, prompt=None):
+ body_numpy, rgb, midas_depth_image, normal_image, rgb2 = self.run(input_img=input_img, prompt=prompt)
+ body_pil = Image.fromarray(body_numpy)
+ body_pil.save(os.path.join(self.args.inference_folder_name, "body.png"))
+
+ rgb.save(os.path.join(self.args.inference_folder_name, "rgb.png"))
+ midas_depth_image.save(os.path.join(self.args.inference_folder_name, "depth.png"))
+ normal_image.save(os.path.join(self.args.inference_folder_name, "normal.png"))
+
+ rgb2.save(os.path.join(self.args.inference_folder_name, "rgb2.png"))
+
+
+if __name__ == "__main__":
+ gen = Generator()
+ gen.save_images()
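+    # Hypothetical usage with a custom reference image and prompt (the path and prompt below are placeholders):
+    #   ref = np.array(Image.open("reference.jpg").convert("RGB"))
+    #   gen.save_images(input_img=ref, prompt="a chef cooking in a kitchen")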
diff --git a/annotator/__pycache__/util.cpython-38.pyc b/annotator/__pycache__/util.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..645da9c23f62d4922b09d826528259552daa38d5
Binary files /dev/null and b/annotator/__pycache__/util.cpython-38.pyc differ
diff --git a/annotator/canny/__init__.py b/annotator/canny/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb0da951dc838ec9dec2131007e036113281800b
--- /dev/null
+++ b/annotator/canny/__init__.py
@@ -0,0 +1,6 @@
+import cv2
+
+
+class CannyDetector:
+ def __call__(self, img, low_threshold, high_threshold):
+ return cv2.Canny(img, low_threshold, high_threshold)
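The annotator exposes the thin callable interface the rest of the pipeline expects. A minimal usage sketch (the image path and thresholds are just example values):

```python
# Hypothetical usage of CannyDetector on an RGB uint8 image.
import cv2
from annotator.canny import CannyDetector

apply_canny = CannyDetector()
img = cv2.cvtColor(cv2.imread("test_imgs/bird.png"), cv2.COLOR_BGR2RGB)   # any RGB image
edge = apply_canny(img, low_threshold=100, high_threshold=200)            # uint8 HxW edge map
cv2.imwrite("bird_canny.png", edge)
```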
diff --git a/annotator/ckpts/body_pose_model.pth b/annotator/ckpts/body_pose_model.pth
new file mode 100644
index 0000000000000000000000000000000000000000..9acb77e68f31906a8875f1daef2f3f7ef94acb1e
--- /dev/null
+++ b/annotator/ckpts/body_pose_model.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:25a948c16078b0f08e236bda51a385d855ef4c153598947c28c0d47ed94bb746
+size 209267595
diff --git a/annotator/ckpts/ckpts.txt b/annotator/ckpts/ckpts.txt
new file mode 100644
index 0000000000000000000000000000000000000000..1978551fb2a9226814eaf58459f414fcfac4e69b
--- /dev/null
+++ b/annotator/ckpts/ckpts.txt
@@ -0,0 +1 @@
+Weights here.
\ No newline at end of file
diff --git a/annotator/ckpts/hand_pose_model.pth b/annotator/ckpts/hand_pose_model.pth
new file mode 100644
index 0000000000000000000000000000000000000000..f23ccf3413cc8ac8581a82338a3037bc10d573f0
--- /dev/null
+++ b/annotator/ckpts/hand_pose_model.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b76b00d1750901abd07b9f9d8c98cc3385b8fe834a26d4b4f0aad439e75fc600
+size 147341049
diff --git a/annotator/hed/__init__.py b/annotator/hed/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6a8fc712fba02b033dea13bfe33204b8d3c9139
--- /dev/null
+++ b/annotator/hed/__init__.py
@@ -0,0 +1,96 @@
+# This is an improved version and model of HED edge detection, released under the Apache License, Version 2.0.
+# Please use this implementation in your products.
+# This implementation may produce slightly different results from Saining Xie's official implementation,
+# but it generates smoother edges and is better suited to ControlNet and other image-to-image translation tasks.
+# Unlike the official models and other implementations, this is an RGB-input model (rather than BGR),
+# which makes it work better with gradio's RGB convention.

+
+import os
+import cv2
+import torch
+import numpy as np
+
+from einops import rearrange
+from annotator.util import annotator_ckpts_path
+
+
+class DoubleConvBlock(torch.nn.Module):
+ def __init__(self, input_channel, output_channel, layer_number):
+ super().__init__()
+ self.convs = torch.nn.Sequential()
+ self.convs.append(torch.nn.Conv2d(in_channels=input_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
+ for i in range(1, layer_number):
+ self.convs.append(torch.nn.Conv2d(in_channels=output_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
+ self.projection = torch.nn.Conv2d(in_channels=output_channel, out_channels=1, kernel_size=(1, 1), stride=(1, 1), padding=0)
+
+ def __call__(self, x, down_sampling=False):
+ h = x
+ if down_sampling:
+ h = torch.nn.functional.max_pool2d(h, kernel_size=(2, 2), stride=(2, 2))
+ for conv in self.convs:
+ h = conv(h)
+ h = torch.nn.functional.relu(h)
+ return h, self.projection(h)
+
+
+class ControlNetHED_Apache2(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.norm = torch.nn.Parameter(torch.zeros(size=(1, 3, 1, 1)))
+ self.block1 = DoubleConvBlock(input_channel=3, output_channel=64, layer_number=2)
+ self.block2 = DoubleConvBlock(input_channel=64, output_channel=128, layer_number=2)
+ self.block3 = DoubleConvBlock(input_channel=128, output_channel=256, layer_number=3)
+ self.block4 = DoubleConvBlock(input_channel=256, output_channel=512, layer_number=3)
+ self.block5 = DoubleConvBlock(input_channel=512, output_channel=512, layer_number=3)
+
+ def __call__(self, x):
+ h = x - self.norm
+ h, projection1 = self.block1(h)
+ h, projection2 = self.block2(h, down_sampling=True)
+ h, projection3 = self.block3(h, down_sampling=True)
+ h, projection4 = self.block4(h, down_sampling=True)
+ h, projection5 = self.block5(h, down_sampling=True)
+ return projection1, projection2, projection3, projection4, projection5
+
+
+class HEDdetector:
+ def __init__(self):
+ remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/ControlNetHED.pth"
+ modelpath = os.path.join(annotator_ckpts_path, "ControlNetHED.pth")
+ if not os.path.exists(modelpath):
+ from basicsr.utils.download_util import load_file_from_url
+ load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
+ self.netNetwork = ControlNetHED_Apache2().float().cuda().eval()
+ self.netNetwork.load_state_dict(torch.load(modelpath))
+
+ def __call__(self, input_image):
+ assert input_image.ndim == 3
+ H, W, C = input_image.shape
+ with torch.no_grad():
+ image_hed = torch.from_numpy(input_image.copy()).float().cuda()
+ image_hed = rearrange(image_hed, 'h w c -> 1 c h w')
+ edges = self.netNetwork(image_hed)
+ edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges]
+ edges = [cv2.resize(e, (W, H), interpolation=cv2.INTER_LINEAR) for e in edges]
+ edges = np.stack(edges, axis=2)
+ edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64)))
+ edge = (edge * 255.0).clip(0, 255).astype(np.uint8)
+ return edge
+
+
+def nms(x, t, s):
+ x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)
+
+ f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
+ f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
+ f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
+ f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
+
+ y = np.zeros_like(x)
+
+ for f in [f1, f2, f3, f4]:
+ np.putmask(y, cv2.dilate(x, kernel=f) == x, x)
+
+ z = np.zeros_like(y, dtype=np.uint8)
+ z[y > t] = 255
+ return z
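`HEDdetector` returns a soft edge-probability map; `nms` thins it into a scribble-like binary map by keeping only directional maxima. A sketch of chaining the two (assumes a CUDA device; the 127/3.0 values mirror thresholds commonly used for scribble-style conditioning and are an assumption here):

```python
# Sketch: soft HED edges -> thinned "scribble"-style map via nms().
# Assumes a CUDA device; ControlNetHED.pth is downloaded on first use.
import cv2
from annotator.hed import HEDdetector, nms

apply_hed = HEDdetector()
img = cv2.cvtColor(cv2.imread("test_imgs/building.png"), cv2.COLOR_BGR2RGB)
hed_map = apply_hed(img)            # uint8 HxW, edge probabilities scaled to 0..255
scribble = nms(hed_map, 127, 3.0)   # keep directional maxima, binarize above 127
cv2.imwrite("building_scribble.png", scribble)
```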
diff --git a/annotator/midas/LICENSE b/annotator/midas/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..277b5c11be103f028a8d10985139f1da10c2f08e
--- /dev/null
+++ b/annotator/midas/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2019 Intel ISL (Intel Intelligent Systems Lab)
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/annotator/midas/__init__.py b/annotator/midas/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..36789767f35bcc169c2cbf096e2747539df4f14d
--- /dev/null
+++ b/annotator/midas/__init__.py
@@ -0,0 +1,42 @@
+# Midas Depth Estimation
+# From https://github.com/isl-org/MiDaS
+# MIT LICENSE
+
+import cv2
+import numpy as np
+import torch
+
+from einops import rearrange
+from .api import MiDaSInference
+
+
+class MidasDetector:
+ def __init__(self):
+ self.model = MiDaSInference(model_type="dpt_hybrid").cuda()
+
+ def __call__(self, input_image, a=np.pi * 2.0, bg_th=0.1):
+ assert input_image.ndim == 3
+ image_depth = input_image
+ with torch.no_grad():
+ image_depth = torch.from_numpy(image_depth).float().cuda()
+ image_depth = image_depth / 127.5 - 1.0
+ image_depth = rearrange(image_depth, 'h w c -> 1 c h w')
+ depth = self.model(image_depth)[0]
+
+ depth_pt = depth.clone()
+ depth_pt -= torch.min(depth_pt)
+ depth_pt /= torch.max(depth_pt)
+ depth_pt = depth_pt.cpu().numpy()
+ depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8)
+
+ depth_np = depth.cpu().numpy()
+ x = cv2.Sobel(depth_np, cv2.CV_32F, 1, 0, ksize=3)
+ y = cv2.Sobel(depth_np, cv2.CV_32F, 0, 1, ksize=3)
+ z = np.ones_like(x) * a
+ x[depth_pt < bg_th] = 0
+ y[depth_pt < bg_th] = 0
+ normal = np.stack([x, y, z], axis=2)
+ normal /= np.sum(normal ** 2.0, axis=2, keepdims=True) ** 0.5
+ normal_image = (normal * 127.5 + 127.5).clip(0, 255).astype(np.uint8)
+
+ return depth_image, normal_image
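The pseudo-normal map is built directly from the depth gradients: the Sobel x/y derivatives become the first two components, the constant `a` becomes the z component, and the vector is normalized and remapped from [-1, 1] to [0, 255]. A self-contained sketch of just that conversion on a synthetic depth ramp (no model or GPU needed):

```python
# Depth -> pseudo-normal conversion as in MidasDetector.__call__, on a synthetic ramp.
import cv2
import numpy as np

a = np.pi * 2.0
depth = np.tile(np.linspace(0.0, 1.0, 256, dtype=np.float32), (256, 1))  # left-to-right ramp

x = cv2.Sobel(depth, cv2.CV_32F, 1, 0, ksize=3)
y = cv2.Sobel(depth, cv2.CV_32F, 0, 1, ksize=3)
z = np.ones_like(x) * a

normal = np.stack([x, y, z], axis=2)
normal /= np.sum(normal ** 2.0, axis=2, keepdims=True) ** 0.5            # unit vectors
normal_image = (normal * 127.5 + 127.5).clip(0, 255).astype(np.uint8)    # [-1,1] -> [0,255]
print(normal_image.shape, normal_image[128, 128])  # (256, 256, 3); z dominates -> ~(128, 127, 255)
```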
diff --git a/annotator/midas/api.py b/annotator/midas/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ab9f15bf96bbaffcee0e3e29fc9d3979d6c32e8
--- /dev/null
+++ b/annotator/midas/api.py
@@ -0,0 +1,169 @@
+# based on https://github.com/isl-org/MiDaS
+
+import cv2
+import os
+import torch
+import torch.nn as nn
+from torchvision.transforms import Compose
+
+from .midas.dpt_depth import DPTDepthModel
+from .midas.midas_net import MidasNet
+from .midas.midas_net_custom import MidasNet_small
+from .midas.transforms import Resize, NormalizeImage, PrepareForNet
+from annotator.util import annotator_ckpts_path
+
+
+ISL_PATHS = {
+ "dpt_large": os.path.join(annotator_ckpts_path, "dpt_large-midas-2f21e586.pt"),
+ "dpt_hybrid": os.path.join(annotator_ckpts_path, "dpt_hybrid-midas-501f0c75.pt"),
+ "midas_v21": "",
+ "midas_v21_small": "",
+}
+
+remote_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/dpt_hybrid-midas-501f0c75.pt"
+
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+
+def load_midas_transform(model_type):
+ # https://github.com/isl-org/MiDaS/blob/master/run.py
+ # load transform only
+ if model_type == "dpt_large": # DPT-Large
+ net_w, net_h = 384, 384
+ resize_mode = "minimal"
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+ elif model_type == "dpt_hybrid": # DPT-Hybrid
+ net_w, net_h = 384, 384
+ resize_mode = "minimal"
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+ elif model_type == "midas_v21":
+ net_w, net_h = 384, 384
+ resize_mode = "upper_bound"
+ normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+ elif model_type == "midas_v21_small":
+ net_w, net_h = 256, 256
+ resize_mode = "upper_bound"
+ normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+ else:
+ assert False, f"model_type '{model_type}' not implemented, use: --model_type large"
+
+ transform = Compose(
+ [
+ Resize(
+ net_w,
+ net_h,
+ resize_target=None,
+ keep_aspect_ratio=True,
+ ensure_multiple_of=32,
+ resize_method=resize_mode,
+ image_interpolation_method=cv2.INTER_CUBIC,
+ ),
+ normalization,
+ PrepareForNet(),
+ ]
+ )
+
+ return transform
+
+
+def load_model(model_type):
+ # https://github.com/isl-org/MiDaS/blob/master/run.py
+ # load network
+ model_path = ISL_PATHS[model_type]
+ if model_type == "dpt_large": # DPT-Large
+ model = DPTDepthModel(
+ path=model_path,
+ backbone="vitl16_384",
+ non_negative=True,
+ )
+ net_w, net_h = 384, 384
+ resize_mode = "minimal"
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+ elif model_type == "dpt_hybrid": # DPT-Hybrid
+ if not os.path.exists(model_path):
+ from basicsr.utils.download_util import load_file_from_url
+ load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
+
+ model = DPTDepthModel(
+ path=model_path,
+ backbone="vitb_rn50_384",
+ non_negative=True,
+ )
+ net_w, net_h = 384, 384
+ resize_mode = "minimal"
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+ elif model_type == "midas_v21":
+ model = MidasNet(model_path, non_negative=True)
+ net_w, net_h = 384, 384
+ resize_mode = "upper_bound"
+ normalization = NormalizeImage(
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+ )
+
+ elif model_type == "midas_v21_small":
+ model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
+ non_negative=True, blocks={'expand': True})
+ net_w, net_h = 256, 256
+ resize_mode = "upper_bound"
+ normalization = NormalizeImage(
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+ )
+
+ else:
+ print(f"model_type '{model_type}' not implemented, use: --model_type large")
+ assert False
+
+ transform = Compose(
+ [
+ Resize(
+ net_w,
+ net_h,
+ resize_target=None,
+ keep_aspect_ratio=True,
+ ensure_multiple_of=32,
+ resize_method=resize_mode,
+ image_interpolation_method=cv2.INTER_CUBIC,
+ ),
+ normalization,
+ PrepareForNet(),
+ ]
+ )
+
+ return model.eval(), transform
+
+
+class MiDaSInference(nn.Module):
+ MODEL_TYPES_TORCH_HUB = [
+ "DPT_Large",
+ "DPT_Hybrid",
+ "MiDaS_small"
+ ]
+ MODEL_TYPES_ISL = [
+ "dpt_large",
+ "dpt_hybrid",
+ "midas_v21",
+ "midas_v21_small",
+ ]
+
+ def __init__(self, model_type):
+ super().__init__()
+ assert (model_type in self.MODEL_TYPES_ISL)
+ model, _ = load_model(model_type)
+ self.model = model
+ self.model.train = disabled_train
+
+ def forward(self, x):
+ with torch.no_grad():
+ prediction = self.model(x)
+ return prediction
+
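`MiDaSInference` wraps the loaded network, and `load_midas_transform` provides the matching pre-processing. A sketch of the full dpt_hybrid path (assumes a CUDA device; the checkpoint is fetched on first use, and the image path is only an example):

```python
# Sketch: dpt_hybrid depth inference with the matching MiDaS transform.
import cv2
import torch
from annotator.midas.api import MiDaSInference, load_midas_transform

model = MiDaSInference(model_type="dpt_hybrid").cuda()
transform = load_midas_transform("dpt_hybrid")

img = cv2.cvtColor(cv2.imread("test_imgs/building2.png"), cv2.COLOR_BGR2RGB) / 255.0
sample = transform({"image": img})                       # resize + normalize + CHW float32
x = torch.from_numpy(sample["image"]).unsqueeze(0).cuda()
with torch.no_grad():
    depth = model(x)[0].cpu().numpy()                    # relative inverse depth, HxW
print(depth.shape, float(depth.min()), float(depth.max()))
```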
diff --git a/annotator/midas/midas/__init__.py b/annotator/midas/midas/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/annotator/midas/midas/base_model.py b/annotator/midas/midas/base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cf430239b47ec5ec07531263f26f5c24a2311cd
--- /dev/null
+++ b/annotator/midas/midas/base_model.py
@@ -0,0 +1,16 @@
+import torch
+
+
+class BaseModel(torch.nn.Module):
+ def load(self, path):
+ """Load model from file.
+
+ Args:
+ path (str): file path
+ """
+ parameters = torch.load(path, map_location=torch.device('cpu'))
+
+ if "optimizer" in parameters:
+ parameters = parameters["model"]
+
+ self.load_state_dict(parameters)
diff --git a/annotator/midas/midas/blocks.py b/annotator/midas/midas/blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..2145d18fa98060a618536d9a64fe6589e9be4f78
--- /dev/null
+++ b/annotator/midas/midas/blocks.py
@@ -0,0 +1,342 @@
+import torch
+import torch.nn as nn
+
+from .vit import (
+ _make_pretrained_vitb_rn50_384,
+ _make_pretrained_vitl16_384,
+ _make_pretrained_vitb16_384,
+ forward_vit,
+)
+
+def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
+ if backbone == "vitl16_384":
+ pretrained = _make_pretrained_vitl16_384(
+ use_pretrained, hooks=hooks, use_readout=use_readout
+ )
+ scratch = _make_scratch(
+ [256, 512, 1024, 1024], features, groups=groups, expand=expand
+ ) # ViT-L/16 - 85.0% Top1 (backbone)
+ elif backbone == "vitb_rn50_384":
+ pretrained = _make_pretrained_vitb_rn50_384(
+ use_pretrained,
+ hooks=hooks,
+ use_vit_only=use_vit_only,
+ use_readout=use_readout,
+ )
+ scratch = _make_scratch(
+ [256, 512, 768, 768], features, groups=groups, expand=expand
+        ) # ViT-Hybrid (ResNet-50 + ViT-B/16) backbone
+ elif backbone == "vitb16_384":
+ pretrained = _make_pretrained_vitb16_384(
+ use_pretrained, hooks=hooks, use_readout=use_readout
+ )
+ scratch = _make_scratch(
+ [96, 192, 384, 768], features, groups=groups, expand=expand
+ ) # ViT-B/16 - 84.6% Top1 (backbone)
+ elif backbone == "resnext101_wsl":
+ pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
+        scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # resnext101_wsl
+ elif backbone == "efficientnet_lite3":
+ pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
+ scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3
+ else:
+ print(f"Backbone '{backbone}' not implemented")
+ assert False
+
+ return pretrained, scratch
+
+
+def _make_scratch(in_shape, out_shape, groups=1, expand=False):
+ scratch = nn.Module()
+
+ out_shape1 = out_shape
+ out_shape2 = out_shape
+ out_shape3 = out_shape
+ out_shape4 = out_shape
+ if expand==True:
+ out_shape1 = out_shape
+ out_shape2 = out_shape*2
+ out_shape3 = out_shape*4
+ out_shape4 = out_shape*8
+
+ scratch.layer1_rn = nn.Conv2d(
+ in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ scratch.layer2_rn = nn.Conv2d(
+ in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ scratch.layer3_rn = nn.Conv2d(
+ in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ scratch.layer4_rn = nn.Conv2d(
+ in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+
+ return scratch
+
+
+def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
+ efficientnet = torch.hub.load(
+ "rwightman/gen-efficientnet-pytorch",
+ "tf_efficientnet_lite3",
+ pretrained=use_pretrained,
+ exportable=exportable
+ )
+ return _make_efficientnet_backbone(efficientnet)
+
+
+def _make_efficientnet_backbone(effnet):
+ pretrained = nn.Module()
+
+ pretrained.layer1 = nn.Sequential(
+ effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
+ )
+ pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
+ pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
+ pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
+
+ return pretrained
+
+
+def _make_resnet_backbone(resnet):
+ pretrained = nn.Module()
+ pretrained.layer1 = nn.Sequential(
+ resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
+ )
+
+ pretrained.layer2 = resnet.layer2
+ pretrained.layer3 = resnet.layer3
+ pretrained.layer4 = resnet.layer4
+
+ return pretrained
+
+
+def _make_pretrained_resnext101_wsl(use_pretrained):
+ resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
+ return _make_resnet_backbone(resnet)
+
+
+
+class Interpolate(nn.Module):
+ """Interpolation module.
+ """
+
+ def __init__(self, scale_factor, mode, align_corners=False):
+ """Init.
+
+ Args:
+ scale_factor (float): scaling
+ mode (str): interpolation mode
+ """
+ super(Interpolate, self).__init__()
+
+ self.interp = nn.functional.interpolate
+ self.scale_factor = scale_factor
+ self.mode = mode
+ self.align_corners = align_corners
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input
+
+ Returns:
+ tensor: interpolated data
+ """
+
+ x = self.interp(
+ x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
+ )
+
+ return x
+
+
+class ResidualConvUnit(nn.Module):
+ """Residual convolution module.
+ """
+
+ def __init__(self, features):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super().__init__()
+
+ self.conv1 = nn.Conv2d(
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
+ )
+
+ self.conv2 = nn.Conv2d(
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
+ )
+
+ self.relu = nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input
+
+ Returns:
+ tensor: output
+ """
+ out = self.relu(x)
+ out = self.conv1(out)
+ out = self.relu(out)
+ out = self.conv2(out)
+
+ return out + x
+
+
+class FeatureFusionBlock(nn.Module):
+ """Feature fusion block.
+ """
+
+ def __init__(self, features):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super(FeatureFusionBlock, self).__init__()
+
+ self.resConfUnit1 = ResidualConvUnit(features)
+ self.resConfUnit2 = ResidualConvUnit(features)
+
+ def forward(self, *xs):
+ """Forward pass.
+
+ Returns:
+ tensor: output
+ """
+ output = xs[0]
+
+ if len(xs) == 2:
+ output += self.resConfUnit1(xs[1])
+
+ output = self.resConfUnit2(output)
+
+ output = nn.functional.interpolate(
+ output, scale_factor=2, mode="bilinear", align_corners=True
+ )
+
+ return output
+
+
+
+
+class ResidualConvUnit_custom(nn.Module):
+ """Residual convolution module.
+ """
+
+ def __init__(self, features, activation, bn):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super().__init__()
+
+ self.bn = bn
+
+ self.groups=1
+
+ self.conv1 = nn.Conv2d(
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
+ )
+
+ self.conv2 = nn.Conv2d(
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
+ )
+
+ if self.bn==True:
+ self.bn1 = nn.BatchNorm2d(features)
+ self.bn2 = nn.BatchNorm2d(features)
+
+ self.activation = activation
+
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input
+
+ Returns:
+ tensor: output
+ """
+
+ out = self.activation(x)
+ out = self.conv1(out)
+ if self.bn==True:
+ out = self.bn1(out)
+
+ out = self.activation(out)
+ out = self.conv2(out)
+ if self.bn==True:
+ out = self.bn2(out)
+
+ if self.groups > 1:
+ out = self.conv_merge(out)
+
+ return self.skip_add.add(out, x)
+
+ # return out + x
+
+
+class FeatureFusionBlock_custom(nn.Module):
+ """Feature fusion block.
+ """
+
+ def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super(FeatureFusionBlock_custom, self).__init__()
+
+ self.deconv = deconv
+ self.align_corners = align_corners
+
+ self.groups=1
+
+ self.expand = expand
+ out_features = features
+ if self.expand==True:
+ out_features = features//2
+
+ self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
+
+ self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
+ self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
+
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ def forward(self, *xs):
+ """Forward pass.
+
+ Returns:
+ tensor: output
+ """
+ output = xs[0]
+
+ if len(xs) == 2:
+ res = self.resConfUnit1(xs[1])
+ output = self.skip_add.add(output, res)
+ # output += res
+
+ output = self.resConfUnit2(output)
+
+ output = nn.functional.interpolate(
+ output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
+ )
+
+ output = self.out_conv(output)
+
+ return output
+
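With `expand=True`, `_make_scratch` doubles the output channels at each level (64/128/256/512 for `out_shape=64`), which `MidasNet_small` pairs with `FeatureFusionBlock_custom(expand=True)` halving the channels again on the way back up. A quick check (requires timm, which `blocks.py` imports transitively):

```python
# Quick check of the expand=True channel schedule in _make_scratch.
from annotator.midas.midas.blocks import _make_scratch

scratch = _make_scratch([32, 48, 136, 384], out_shape=64, groups=1, expand=True)
print(scratch.layer1_rn.out_channels,   # 64
      scratch.layer2_rn.out_channels,   # 128
      scratch.layer3_rn.out_channels,   # 256
      scratch.layer4_rn.out_channels)   # 512
```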
diff --git a/annotator/midas/midas/dpt_depth.py b/annotator/midas/midas/dpt_depth.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e9aab5d2767dffea39da5b3f30e2798688216f1
--- /dev/null
+++ b/annotator/midas/midas/dpt_depth.py
@@ -0,0 +1,109 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .base_model import BaseModel
+from .blocks import (
+ FeatureFusionBlock,
+ FeatureFusionBlock_custom,
+ Interpolate,
+ _make_encoder,
+ forward_vit,
+)
+
+
+def _make_fusion_block(features, use_bn):
+ return FeatureFusionBlock_custom(
+ features,
+ nn.ReLU(False),
+ deconv=False,
+ bn=use_bn,
+ expand=False,
+ align_corners=True,
+ )
+
+
+class DPT(BaseModel):
+ def __init__(
+ self,
+ head,
+ features=256,
+ backbone="vitb_rn50_384",
+ readout="project",
+ channels_last=False,
+ use_bn=False,
+ ):
+
+ super(DPT, self).__init__()
+
+ self.channels_last = channels_last
+
+ hooks = {
+ "vitb_rn50_384": [0, 1, 8, 11],
+ "vitb16_384": [2, 5, 8, 11],
+ "vitl16_384": [5, 11, 17, 23],
+ }
+
+ # Instantiate backbone and reassemble blocks
+ self.pretrained, self.scratch = _make_encoder(
+ backbone,
+ features,
+            False,  # use_pretrained: set to True if you want to train from scratch (loads ImageNet backbone weights)
+ groups=1,
+ expand=False,
+ exportable=False,
+ hooks=hooks[backbone],
+ use_readout=readout,
+ )
+
+ self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
+
+ self.scratch.output_conv = head
+
+
+ def forward(self, x):
+ if self.channels_last == True:
+            x = x.contiguous(memory_format=torch.channels_last)
+
+ layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
+
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+ path_4 = self.scratch.refinenet4(layer_4_rn)
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+ out = self.scratch.output_conv(path_1)
+
+ return out
+
+
+class DPTDepthModel(DPT):
+ def __init__(self, path=None, non_negative=True, **kwargs):
+ features = kwargs["features"] if "features" in kwargs else 256
+
+ head = nn.Sequential(
+ nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
+ Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
+ nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(True),
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+ nn.ReLU(True) if non_negative else nn.Identity(),
+ nn.Identity(),
+ )
+
+ super().__init__(head, **kwargs)
+
+ if path is not None:
+ self.load(path)
+
+ def forward(self, x):
+ return super().forward(x).squeeze(dim=1)
+
diff --git a/annotator/midas/midas/midas_net.py b/annotator/midas/midas/midas_net.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a954977800b0a0f48807e80fa63041910e33c1f
--- /dev/null
+++ b/annotator/midas/midas/midas_net.py
@@ -0,0 +1,76 @@
+"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
+This file contains code that is adapted from
+https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
+"""
+import torch
+import torch.nn as nn
+
+from .base_model import BaseModel
+from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
+
+
+class MidasNet(BaseModel):
+ """Network for monocular depth estimation.
+ """
+
+ def __init__(self, path=None, features=256, non_negative=True):
+ """Init.
+
+ Args:
+ path (str, optional): Path to saved model. Defaults to None.
+ features (int, optional): Number of features. Defaults to 256.
+            non_negative (bool, optional): Clamp the output to non-negative values. Defaults to True.
+ """
+ print("Loading weights: ", path)
+
+ super(MidasNet, self).__init__()
+
+ use_pretrained = False if path is None else True
+
+ self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
+
+ self.scratch.refinenet4 = FeatureFusionBlock(features)
+ self.scratch.refinenet3 = FeatureFusionBlock(features)
+ self.scratch.refinenet2 = FeatureFusionBlock(features)
+ self.scratch.refinenet1 = FeatureFusionBlock(features)
+
+ self.scratch.output_conv = nn.Sequential(
+ nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
+ Interpolate(scale_factor=2, mode="bilinear"),
+ nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(True),
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+ nn.ReLU(True) if non_negative else nn.Identity(),
+ )
+
+ if path:
+ self.load(path)
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input data (image)
+
+ Returns:
+ tensor: depth
+ """
+
+ layer_1 = self.pretrained.layer1(x)
+ layer_2 = self.pretrained.layer2(layer_1)
+ layer_3 = self.pretrained.layer3(layer_2)
+ layer_4 = self.pretrained.layer4(layer_3)
+
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+ path_4 = self.scratch.refinenet4(layer_4_rn)
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+ out = self.scratch.output_conv(path_1)
+
+ return torch.squeeze(out, dim=1)
diff --git a/annotator/midas/midas/midas_net_custom.py b/annotator/midas/midas/midas_net_custom.py
new file mode 100644
index 0000000000000000000000000000000000000000..50e4acb5e53d5fabefe3dde16ab49c33c2b7797c
--- /dev/null
+++ b/annotator/midas/midas/midas_net_custom.py
@@ -0,0 +1,128 @@
+"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
+This file contains code that is adapted from
+https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
+"""
+import torch
+import torch.nn as nn
+
+from .base_model import BaseModel
+from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
+
+
+class MidasNet_small(BaseModel):
+ """Network for monocular depth estimation.
+ """
+
+ def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
+ blocks={'expand': True}):
+ """Init.
+
+ Args:
+ path (str, optional): Path to saved model. Defaults to None.
+ features (int, optional): Number of features. Defaults to 256.
+            backbone (str, optional): Backbone network for the encoder. Defaults to efficientnet_lite3.
+ """
+ print("Loading weights: ", path)
+
+ super(MidasNet_small, self).__init__()
+
+ use_pretrained = False if path else True
+
+ self.channels_last = channels_last
+ self.blocks = blocks
+ self.backbone = backbone
+
+ self.groups = 1
+
+ features1=features
+ features2=features
+ features3=features
+ features4=features
+ self.expand = False
+ if "expand" in self.blocks and self.blocks['expand'] == True:
+ self.expand = True
+ features1=features
+ features2=features*2
+ features3=features*4
+ features4=features*8
+
+ self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)
+
+ self.scratch.activation = nn.ReLU(False)
+
+ self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
+ self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
+ self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
+ self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)
+
+
+ self.scratch.output_conv = nn.Sequential(
+ nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups),
+ Interpolate(scale_factor=2, mode="bilinear"),
+ nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1),
+ self.scratch.activation,
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+ nn.ReLU(True) if non_negative else nn.Identity(),
+ nn.Identity(),
+ )
+
+ if path:
+ self.load(path)
+
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input data (image)
+
+ Returns:
+ tensor: depth
+ """
+ if self.channels_last==True:
+ print("self.channels_last = ", self.channels_last)
+            x = x.contiguous(memory_format=torch.channels_last)
+
+
+ layer_1 = self.pretrained.layer1(x)
+ layer_2 = self.pretrained.layer2(layer_1)
+ layer_3 = self.pretrained.layer3(layer_2)
+ layer_4 = self.pretrained.layer4(layer_3)
+
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+
+ path_4 = self.scratch.refinenet4(layer_4_rn)
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+ out = self.scratch.output_conv(path_1)
+
+ return torch.squeeze(out, dim=1)
+
+
+
+def fuse_model(m):
+ prev_previous_type = nn.Identity()
+ prev_previous_name = ''
+ previous_type = nn.Identity()
+ previous_name = ''
+ for name, module in m.named_modules():
+ if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
+ # print("FUSED ", prev_previous_name, previous_name, name)
+ torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
+ elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
+ # print("FUSED ", prev_previous_name, previous_name)
+ torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
+ # elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
+ # print("FUSED ", previous_name, name)
+ # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
+
+ prev_previous_type = previous_type
+ prev_previous_name = previous_name
+ previous_type = type(module)
+ previous_name = name
\ No newline at end of file
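`fuse_model` walks `named_modules()` and folds Conv→BN(→ReLU) runs into fused modules for quantization. A minimal sketch on a toy stack rather than the full `MidasNet_small` (eval mode is required by `torch.quantization.fuse_modules`):

```python
# Minimal sketch of fuse_model on a toy Conv-BN-ReLU stack.
import torch
import torch.nn as nn
from annotator.midas.midas.midas_net_custom import fuse_model

m = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
).eval()                  # fusion only works in eval mode

fuse_model(m)
print(m)                  # Conv/BN/ReLU folded into one fused module, the rest become Identity
print(m(torch.randn(1, 3, 16, 16)).shape)  # torch.Size([1, 8, 16, 16])
```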
diff --git a/annotator/midas/midas/transforms.py b/annotator/midas/midas/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..350cbc11662633ad7f8968eb10be2e7de6e384e9
--- /dev/null
+++ b/annotator/midas/midas/transforms.py
@@ -0,0 +1,234 @@
+import numpy as np
+import cv2
+import math
+
+
+def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
+ """Rezise the sample to ensure the given size. Keeps aspect ratio.
+
+ Args:
+ sample (dict): sample
+ size (tuple): image size
+
+ Returns:
+ tuple: new size
+ """
+ shape = list(sample["disparity"].shape)
+
+ if shape[0] >= size[0] and shape[1] >= size[1]:
+ return sample
+
+ scale = [0, 0]
+ scale[0] = size[0] / shape[0]
+ scale[1] = size[1] / shape[1]
+
+ scale = max(scale)
+
+ shape[0] = math.ceil(scale * shape[0])
+ shape[1] = math.ceil(scale * shape[1])
+
+ # resize
+ sample["image"] = cv2.resize(
+ sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
+ )
+
+ sample["disparity"] = cv2.resize(
+ sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
+ )
+ sample["mask"] = cv2.resize(
+ sample["mask"].astype(np.float32),
+ tuple(shape[::-1]),
+ interpolation=cv2.INTER_NEAREST,
+ )
+ sample["mask"] = sample["mask"].astype(bool)
+
+ return tuple(shape)
+
+
+class Resize(object):
+ """Resize sample to given size (width, height).
+ """
+
+ def __init__(
+ self,
+ width,
+ height,
+ resize_target=True,
+ keep_aspect_ratio=False,
+ ensure_multiple_of=1,
+ resize_method="lower_bound",
+ image_interpolation_method=cv2.INTER_AREA,
+ ):
+ """Init.
+
+ Args:
+ width (int): desired output width
+ height (int): desired output height
+ resize_target (bool, optional):
+ True: Resize the full sample (image, mask, target).
+ False: Resize image only.
+ Defaults to True.
+ keep_aspect_ratio (bool, optional):
+ True: Keep the aspect ratio of the input sample.
+ Output sample might not have the given width and height, and
+ resize behaviour depends on the parameter 'resize_method'.
+ Defaults to False.
+ ensure_multiple_of (int, optional):
+ Output width and height is constrained to be multiple of this parameter.
+ Defaults to 1.
+ resize_method (str, optional):
+ "lower_bound": Output will be at least as large as the given size.
+ "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
+ "minimal": Scale as least as possible. (Output size might be smaller than given size.)
+ Defaults to "lower_bound".
+ """
+ self.__width = width
+ self.__height = height
+
+ self.__resize_target = resize_target
+ self.__keep_aspect_ratio = keep_aspect_ratio
+ self.__multiple_of = ensure_multiple_of
+ self.__resize_method = resize_method
+ self.__image_interpolation_method = image_interpolation_method
+
+ def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
+ y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+ if max_val is not None and y > max_val:
+ y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+ if y < min_val:
+ y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+ return y
+
+ def get_size(self, width, height):
+ # determine new height and width
+ scale_height = self.__height / height
+ scale_width = self.__width / width
+
+ if self.__keep_aspect_ratio:
+ if self.__resize_method == "lower_bound":
+ # scale such that output size is lower bound
+ if scale_width > scale_height:
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ elif self.__resize_method == "upper_bound":
+ # scale such that output size is upper bound
+ if scale_width < scale_height:
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ elif self.__resize_method == "minimal":
+                # scale as little as possible
+ if abs(1 - scale_width) < abs(1 - scale_height):
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ else:
+ raise ValueError(
+ f"resize_method {self.__resize_method} not implemented"
+ )
+
+ if self.__resize_method == "lower_bound":
+ new_height = self.constrain_to_multiple_of(
+ scale_height * height, min_val=self.__height
+ )
+ new_width = self.constrain_to_multiple_of(
+ scale_width * width, min_val=self.__width
+ )
+ elif self.__resize_method == "upper_bound":
+ new_height = self.constrain_to_multiple_of(
+ scale_height * height, max_val=self.__height
+ )
+ new_width = self.constrain_to_multiple_of(
+ scale_width * width, max_val=self.__width
+ )
+ elif self.__resize_method == "minimal":
+ new_height = self.constrain_to_multiple_of(scale_height * height)
+ new_width = self.constrain_to_multiple_of(scale_width * width)
+ else:
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
+
+ return (new_width, new_height)
+
+ def __call__(self, sample):
+ width, height = self.get_size(
+ sample["image"].shape[1], sample["image"].shape[0]
+ )
+
+ # resize sample
+ sample["image"] = cv2.resize(
+ sample["image"],
+ (width, height),
+ interpolation=self.__image_interpolation_method,
+ )
+
+ if self.__resize_target:
+ if "disparity" in sample:
+ sample["disparity"] = cv2.resize(
+ sample["disparity"],
+ (width, height),
+ interpolation=cv2.INTER_NEAREST,
+ )
+
+ if "depth" in sample:
+ sample["depth"] = cv2.resize(
+ sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
+ )
+
+ sample["mask"] = cv2.resize(
+ sample["mask"].astype(np.float32),
+ (width, height),
+ interpolation=cv2.INTER_NEAREST,
+ )
+ sample["mask"] = sample["mask"].astype(bool)
+
+ return sample
+
+
+class NormalizeImage(object):
+ """Normlize image by given mean and std.
+ """
+
+ def __init__(self, mean, std):
+ self.__mean = mean
+ self.__std = std
+
+ def __call__(self, sample):
+ sample["image"] = (sample["image"] - self.__mean) / self.__std
+
+ return sample
+
+
+class PrepareForNet(object):
+ """Prepare sample for usage as network input.
+ """
+
+ def __init__(self):
+ pass
+
+ def __call__(self, sample):
+ image = np.transpose(sample["image"], (2, 0, 1))
+ sample["image"] = np.ascontiguousarray(image).astype(np.float32)
+
+ if "mask" in sample:
+ sample["mask"] = sample["mask"].astype(np.float32)
+ sample["mask"] = np.ascontiguousarray(sample["mask"])
+
+ if "disparity" in sample:
+ disparity = sample["disparity"].astype(np.float32)
+ sample["disparity"] = np.ascontiguousarray(disparity)
+
+ if "depth" in sample:
+ depth = sample["depth"].astype(np.float32)
+ sample["depth"] = np.ascontiguousarray(depth)
+
+ return sample
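With the dpt_hybrid settings (`keep_aspect_ratio=True`, `resize_method="minimal"`, `ensure_multiple_of=32`), a 640×480 input is scaled by the factor closer to 1 (384/480 = 0.8) and both sides are snapped to multiples of 32, giving 512×384. A quick check on a random image:

```python
# Quick check of the dpt_hybrid-style Resize settings on a 640x480 input.
import cv2
import numpy as np
from torchvision.transforms import Compose
from annotator.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet

transform = Compose([
    Resize(384, 384, resize_target=None, keep_aspect_ratio=True,
           ensure_multiple_of=32, resize_method="minimal",
           image_interpolation_method=cv2.INTER_CUBIC),
    NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    PrepareForNet(),
])

img = np.random.rand(480, 640, 3).astype(np.float32)    # HxWxC in [0, 1]
out = transform({"image": img})["image"]
print(out.shape)  # (3, 384, 512)
```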
diff --git a/annotator/midas/midas/vit.py b/annotator/midas/midas/vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea46b1be88b261b0dec04f3da0256f5f66f88a74
--- /dev/null
+++ b/annotator/midas/midas/vit.py
@@ -0,0 +1,491 @@
+import torch
+import torch.nn as nn
+import timm
+import types
+import math
+import torch.nn.functional as F
+
+
+class Slice(nn.Module):
+ def __init__(self, start_index=1):
+ super(Slice, self).__init__()
+ self.start_index = start_index
+
+ def forward(self, x):
+ return x[:, self.start_index :]
+
+
+class AddReadout(nn.Module):
+ def __init__(self, start_index=1):
+ super(AddReadout, self).__init__()
+ self.start_index = start_index
+
+ def forward(self, x):
+ if self.start_index == 2:
+ readout = (x[:, 0] + x[:, 1]) / 2
+ else:
+ readout = x[:, 0]
+ return x[:, self.start_index :] + readout.unsqueeze(1)
+
+
+class ProjectReadout(nn.Module):
+ def __init__(self, in_features, start_index=1):
+ super(ProjectReadout, self).__init__()
+ self.start_index = start_index
+
+ self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
+
+ def forward(self, x):
+ readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
+ features = torch.cat((x[:, self.start_index :], readout), -1)
+
+ return self.project(features)
+
+
+class Transpose(nn.Module):
+ def __init__(self, dim0, dim1):
+ super(Transpose, self).__init__()
+ self.dim0 = dim0
+ self.dim1 = dim1
+
+ def forward(self, x):
+ x = x.transpose(self.dim0, self.dim1)
+ return x
+
+
+def forward_vit(pretrained, x):
+ b, c, h, w = x.shape
+
+ glob = pretrained.model.forward_flex(x)
+
+ layer_1 = pretrained.activations["1"]
+ layer_2 = pretrained.activations["2"]
+ layer_3 = pretrained.activations["3"]
+ layer_4 = pretrained.activations["4"]
+
+ layer_1 = pretrained.act_postprocess1[0:2](layer_1)
+ layer_2 = pretrained.act_postprocess2[0:2](layer_2)
+ layer_3 = pretrained.act_postprocess3[0:2](layer_3)
+ layer_4 = pretrained.act_postprocess4[0:2](layer_4)
+
+ unflatten = nn.Sequential(
+ nn.Unflatten(
+ 2,
+ torch.Size(
+ [
+ h // pretrained.model.patch_size[1],
+ w // pretrained.model.patch_size[0],
+ ]
+ ),
+ )
+ )
+
+ if layer_1.ndim == 3:
+ layer_1 = unflatten(layer_1)
+ if layer_2.ndim == 3:
+ layer_2 = unflatten(layer_2)
+ if layer_3.ndim == 3:
+ layer_3 = unflatten(layer_3)
+ if layer_4.ndim == 3:
+ layer_4 = unflatten(layer_4)
+
+ layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
+ layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
+ layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
+ layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
+
+ return layer_1, layer_2, layer_3, layer_4
+
+
+def _resize_pos_embed(self, posemb, gs_h, gs_w):
+ posemb_tok, posemb_grid = (
+ posemb[:, : self.start_index],
+ posemb[0, self.start_index :],
+ )
+
+ gs_old = int(math.sqrt(len(posemb_grid)))
+
+ posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
+ posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
+ posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
+
+ posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
+
+ return posemb
+
+
+def forward_flex(self, x):
+ b, c, h, w = x.shape
+
+ pos_embed = self._resize_pos_embed(
+ self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
+ )
+
+ B = x.shape[0]
+
+ if hasattr(self.patch_embed, "backbone"):
+ x = self.patch_embed.backbone(x)
+ if isinstance(x, (list, tuple)):
+ x = x[-1] # last feature if backbone outputs list/tuple of features
+
+ x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
+
+ if getattr(self, "dist_token", None) is not None:
+ cls_tokens = self.cls_token.expand(
+ B, -1, -1
+ ) # stole cls_tokens impl from Phil Wang, thanks
+ dist_token = self.dist_token.expand(B, -1, -1)
+ x = torch.cat((cls_tokens, dist_token, x), dim=1)
+ else:
+ cls_tokens = self.cls_token.expand(
+ B, -1, -1
+ ) # stole cls_tokens impl from Phil Wang, thanks
+ x = torch.cat((cls_tokens, x), dim=1)
+
+ x = x + pos_embed
+ x = self.pos_drop(x)
+
+ for blk in self.blocks:
+ x = blk(x)
+
+ x = self.norm(x)
+
+ return x
+
+
+activations = {}
+
+
+def get_activation(name):
+ def hook(model, input, output):
+ activations[name] = output
+
+ return hook
+
+
+def get_readout_oper(vit_features, features, use_readout, start_index=1):
+ if use_readout == "ignore":
+ readout_oper = [Slice(start_index)] * len(features)
+ elif use_readout == "add":
+ readout_oper = [AddReadout(start_index)] * len(features)
+ elif use_readout == "project":
+ readout_oper = [
+ ProjectReadout(vit_features, start_index) for out_feat in features
+ ]
+ else:
+ assert (
+ False
+ ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
+
+ return readout_oper
+
+
+def _make_vit_b16_backbone(
+ model,
+ features=[96, 192, 384, 768],
+ size=[384, 384],
+ hooks=[2, 5, 8, 11],
+ vit_features=768,
+ use_readout="ignore",
+ start_index=1,
+):
+ pretrained = nn.Module()
+
+ pretrained.model = model
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
+ pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
+
+ pretrained.activations = activations
+
+ readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
+
+ # 32, 48, 136, 384
+ pretrained.act_postprocess1 = nn.Sequential(
+ readout_oper[0],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[0],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.ConvTranspose2d(
+ in_channels=features[0],
+ out_channels=features[0],
+ kernel_size=4,
+ stride=4,
+ padding=0,
+ bias=True,
+ dilation=1,
+ groups=1,
+ ),
+ )
+
+ pretrained.act_postprocess2 = nn.Sequential(
+ readout_oper[1],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[1],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.ConvTranspose2d(
+ in_channels=features[1],
+ out_channels=features[1],
+ kernel_size=2,
+ stride=2,
+ padding=0,
+ bias=True,
+ dilation=1,
+ groups=1,
+ ),
+ )
+
+ pretrained.act_postprocess3 = nn.Sequential(
+ readout_oper[2],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[2],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ )
+
+ pretrained.act_postprocess4 = nn.Sequential(
+ readout_oper[3],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[3],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.Conv2d(
+ in_channels=features[3],
+ out_channels=features[3],
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ ),
+ )
+
+ pretrained.model.start_index = start_index
+ pretrained.model.patch_size = [16, 16]
+
+ # We inject this function into the VisionTransformer instances so that
+ # we can use it with interpolated position embeddings without modifying the library source.
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
+ pretrained.model._resize_pos_embed = types.MethodType(
+ _resize_pos_embed, pretrained.model
+ )
+
+ return pretrained
+
+
+def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
+ model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
+
+ hooks = [5, 11, 17, 23] if hooks == None else hooks
+ return _make_vit_b16_backbone(
+ model,
+ features=[256, 512, 1024, 1024],
+ hooks=hooks,
+ vit_features=1024,
+ use_readout=use_readout,
+ )
+
+
+def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
+ model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
+
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
+ return _make_vit_b16_backbone(
+ model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
+ )
+
+
+def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
+ model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
+
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
+ return _make_vit_b16_backbone(
+ model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
+ )
+
+
+def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
+ model = timm.create_model(
+ "vit_deit_base_distilled_patch16_384", pretrained=pretrained
+ )
+
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
+ return _make_vit_b16_backbone(
+ model,
+ features=[96, 192, 384, 768],
+ hooks=hooks,
+ use_readout=use_readout,
+ start_index=2,
+ )
+
+
+def _make_vit_b_rn50_backbone(
+ model,
+ features=[256, 512, 768, 768],
+ size=[384, 384],
+ hooks=[0, 1, 8, 11],
+ vit_features=768,
+ use_vit_only=False,
+ use_readout="ignore",
+ start_index=1,
+):
+ pretrained = nn.Module()
+
+ pretrained.model = model
+
+ if use_vit_only == True:
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
+ else:
+ pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
+ get_activation("1")
+ )
+ pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
+ get_activation("2")
+ )
+
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
+ pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
+
+ pretrained.activations = activations
+
+ readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
+
+ if use_vit_only == True:
+ pretrained.act_postprocess1 = nn.Sequential(
+ readout_oper[0],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[0],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.ConvTranspose2d(
+ in_channels=features[0],
+ out_channels=features[0],
+ kernel_size=4,
+ stride=4,
+ padding=0,
+ bias=True,
+ dilation=1,
+ groups=1,
+ ),
+ )
+
+ pretrained.act_postprocess2 = nn.Sequential(
+ readout_oper[1],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[1],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.ConvTranspose2d(
+ in_channels=features[1],
+ out_channels=features[1],
+ kernel_size=2,
+ stride=2,
+ padding=0,
+ bias=True,
+ dilation=1,
+ groups=1,
+ ),
+ )
+ else:
+ pretrained.act_postprocess1 = nn.Sequential(
+ nn.Identity(), nn.Identity(), nn.Identity()
+ )
+ pretrained.act_postprocess2 = nn.Sequential(
+ nn.Identity(), nn.Identity(), nn.Identity()
+ )
+
+ pretrained.act_postprocess3 = nn.Sequential(
+ readout_oper[2],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[2],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ )
+
+ pretrained.act_postprocess4 = nn.Sequential(
+ readout_oper[3],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[3],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.Conv2d(
+ in_channels=features[3],
+ out_channels=features[3],
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ ),
+ )
+
+ pretrained.model.start_index = start_index
+ pretrained.model.patch_size = [16, 16]
+
+ # We inject this function into the VisionTransformer instances so that
+ # we can use it with interpolated position embeddings without modifying the library source.
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
+
+ # We inject this function into the VisionTransformer instances so that
+ # we can use it with interpolated position embeddings without modifying the library source.
+ pretrained.model._resize_pos_embed = types.MethodType(
+ _resize_pos_embed, pretrained.model
+ )
+
+ return pretrained
+
+
+def _make_pretrained_vitb_rn50_384(
+ pretrained, use_readout="ignore", hooks=None, use_vit_only=False
+):
+ model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
+
+ hooks = [0, 1, 8, 11] if hooks == None else hooks
+ return _make_vit_b_rn50_backbone(
+ model,
+ features=[256, 512, 768, 768],
+ size=[384, 384],
+ hooks=hooks,
+ use_vit_only=use_vit_only,
+ use_readout=use_readout,
+ )
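`forward_flex` can accept inputs that are not 384×384 because `_resize_pos_embed` keeps the class-token embedding and bilinearly resizes the patch-grid embeddings. A standalone sketch of that interpolation with assumed sizes (a 24×24 grid resized for a 512×384 input):

```python
# Standalone sketch of the position-embedding interpolation done by _resize_pos_embed.
import math
import torch
import torch.nn.functional as F

posemb = torch.randn(1, 1 + 24 * 24, 768)         # [cls] token + 24x24 grid of 768-d embeddings
tok, grid = posemb[:, :1], posemb[0, 1:]

gs_old = int(math.sqrt(grid.shape[0]))             # 24
grid = grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
grid = F.interpolate(grid, size=(32, 24), mode="bilinear", align_corners=False)
grid = grid.permute(0, 2, 3, 1).reshape(1, 32 * 24, -1)

print(torch.cat([tok, grid], dim=1).shape)         # torch.Size([1, 769, 768])
```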
diff --git a/annotator/midas/utils.py b/annotator/midas/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a9d3b5b66370fa98da9e067ba53ead848ea9a59
--- /dev/null
+++ b/annotator/midas/utils.py
@@ -0,0 +1,189 @@
+"""Utils for monoDepth."""
+import sys
+import re
+import numpy as np
+import cv2
+import torch
+
+
+def read_pfm(path):
+ """Read pfm file.
+
+ Args:
+ path (str): path to file
+
+ Returns:
+ tuple: (data, scale)
+ """
+ with open(path, "rb") as file:
+
+ color = None
+ width = None
+ height = None
+ scale = None
+ endian = None
+
+ header = file.readline().rstrip()
+ if header.decode("ascii") == "PF":
+ color = True
+ elif header.decode("ascii") == "Pf":
+ color = False
+ else:
+ raise Exception("Not a PFM file: " + path)
+
+ dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
+ if dim_match:
+ width, height = list(map(int, dim_match.groups()))
+ else:
+ raise Exception("Malformed PFM header.")
+
+ scale = float(file.readline().decode("ascii").rstrip())
+ if scale < 0:
+ # little-endian
+ endian = "<"
+ scale = -scale
+ else:
+ # big-endian
+ endian = ">"
+
+ data = np.fromfile(file, endian + "f")
+ shape = (height, width, 3) if color else (height, width)
+
+ data = np.reshape(data, shape)
+ data = np.flipud(data)
+
+ return data, scale
+
+
+def write_pfm(path, image, scale=1):
+ """Write pfm file.
+
+ Args:
+        path (str): path to file
+ image (array): data
+ scale (int, optional): Scale. Defaults to 1.
+ """
+
+ with open(path, "wb") as file:
+ color = None
+
+ if image.dtype.name != "float32":
+ raise Exception("Image dtype must be float32.")
+
+ image = np.flipud(image)
+
+ if len(image.shape) == 3 and image.shape[2] == 3: # color image
+ color = True
+ elif (
+ len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1
+ ): # greyscale
+ color = False
+ else:
+ raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
+
+ file.write("PF\n" if color else "Pf\n".encode())
+ file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
+
+ endian = image.dtype.byteorder
+
+ if endian == "<" or endian == "=" and sys.byteorder == "little":
+ scale = -scale
+
+ file.write("%f\n".encode() % scale)
+
+ image.tofile(file)
+
+
+def read_image(path):
+ """Read image and output RGB image (0-1).
+
+ Args:
+ path (str): path to file
+
+ Returns:
+ array: RGB image (0-1)
+ """
+ img = cv2.imread(path)
+
+ if img.ndim == 2:
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
+
+ return img
+
+
+def resize_image(img):
+ """Resize image and make it fit for network.
+
+ Args:
+ img (array): image
+
+ Returns:
+ tensor: data ready for network
+ """
+ height_orig = img.shape[0]
+ width_orig = img.shape[1]
+
+ if width_orig > height_orig:
+ scale = width_orig / 384
+ else:
+ scale = height_orig / 384
+
+ height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
+ width = (np.ceil(width_orig / scale / 32) * 32).astype(int)
+
+ img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
+
+ img_resized = (
+ torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float()
+ )
+ img_resized = img_resized.unsqueeze(0)
+
+ return img_resized
+
+
+def resize_depth(depth, width, height):
+ """Resize depth map and bring to CPU (numpy).
+
+ Args:
+ depth (tensor): depth
+ width (int): image width
+ height (int): image height
+
+ Returns:
+ array: processed depth
+ """
+ depth = torch.squeeze(depth[0, :, :, :]).to("cpu")
+
+ depth_resized = cv2.resize(
+ depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC
+ )
+
+ return depth_resized
+
+def write_depth(path, depth, bits=1):
+ """Write depth map to pfm and png file.
+
+ Args:
+ path (str): filepath without extension
+ depth (array): depth
+ """
+ write_pfm(path + ".pfm", depth.astype(np.float32))
+
+ depth_min = depth.min()
+ depth_max = depth.max()
+
+ max_val = (2**(8*bits))-1
+
+ if depth_max - depth_min > np.finfo("float").eps:
+ out = max_val * (depth - depth_min) / (depth_max - depth_min)
+ else:
+        out = np.zeros(depth.shape, dtype=depth.dtype)
+
+ if bits == 1:
+ cv2.imwrite(path + ".png", out.astype("uint8"))
+ elif bits == 2:
+ cv2.imwrite(path + ".png", out.astype("uint16"))
+
+ return
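`write_depth` min–max scales the prediction into the full integer range before writing the PNG (255 for bits=1, 65535 for bits=2). The scaling on its own, without touching the filesystem:

```python
# The quantization write_depth applies before saving (bits=2 -> 16-bit PNG).
import numpy as np

depth = np.random.rand(4, 4).astype(np.float32) * 10.0
bits = 2
max_val = (2 ** (8 * bits)) - 1                    # 65535

d_min, d_max = depth.min(), depth.max()
out = max_val * (depth - d_min) / (d_max - d_min)
print(out.astype("uint16").min(), out.astype("uint16").max())  # 0 65535
```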
diff --git a/annotator/mlsd/LICENSE b/annotator/mlsd/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..d855c6db44b4e873eedd750d34fa2eaf22e22363
--- /dev/null
+++ b/annotator/mlsd/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "{}"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2021-present NAVER Corp.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
\ No newline at end of file
diff --git a/annotator/mlsd/__init__.py b/annotator/mlsd/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1860702df6150c5a93c9bb6bf34906a77048c7c
--- /dev/null
+++ b/annotator/mlsd/__init__.py
@@ -0,0 +1,43 @@
+# MLSD Line Detection
+# From https://github.com/navervision/mlsd
+# Apache-2.0 license
+
+import cv2
+import numpy as np
+import torch
+import os
+
+from einops import rearrange
+from .models.mbv2_mlsd_tiny import MobileV2_MLSD_Tiny
+from .models.mbv2_mlsd_large import MobileV2_MLSD_Large
+from .utils import pred_lines
+
+from annotator.util import annotator_ckpts_path
+
+
+remote_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/mlsd_large_512_fp32.pth"
+
+
+class MLSDdetector:
+ def __init__(self):
+ model_path = os.path.join(annotator_ckpts_path, "mlsd_large_512_fp32.pth")
+ if not os.path.exists(model_path):
+ from basicsr.utils.download_util import load_file_from_url
+ load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
+ model = MobileV2_MLSD_Large()
+ model.load_state_dict(torch.load(model_path), strict=True)
+ self.model = model.cuda().eval()
+
+ def __call__(self, input_image, thr_v, thr_d):
+ assert input_image.ndim == 3
+ img = input_image
+ img_output = np.zeros_like(img)
+ try:
+ with torch.no_grad():
+ lines = pred_lines(img, self.model, [img.shape[0], img.shape[1]], thr_v, thr_d)
+ for line in lines:
+ x_start, y_start, x_end, y_end = [int(val) for val in line]
+ cv2.line(img_output, (x_start, y_start), (x_end, y_end), [255, 255, 255], 1)
+ except Exception:
+ # line prediction can fail on degenerate inputs; fall back to an empty line map
+ pass
+ return img_output[:, :, 0]
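+
+
+# Illustrative usage sketch (not part of the upstream file; the thresholds are example values):
+#
+# apply_mlsd = MLSDdetector()                                 # downloads the checkpoint on first use
+# line_map = apply_mlsd(input_image, thr_v=0.1, thr_d=0.1)    # input_image: H x W x 3 uint8, returns H x W line map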
diff --git a/annotator/mlsd/models/mbv2_mlsd_large.py b/annotator/mlsd/models/mbv2_mlsd_large.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b9799e7573ca41549b3c3b13ac47b906b369603
--- /dev/null
+++ b/annotator/mlsd/models/mbv2_mlsd_large.py
@@ -0,0 +1,292 @@
+import os
+import sys
+import torch
+import torch.nn as nn
+import torch.utils.model_zoo as model_zoo
+from torch.nn import functional as F
+
+
+class BlockTypeA(nn.Module):
+ def __init__(self, in_c1, in_c2, out_c1, out_c2, upscale = True):
+ super(BlockTypeA, self).__init__()
+ self.conv1 = nn.Sequential(
+ nn.Conv2d(in_c2, out_c2, kernel_size=1),
+ nn.BatchNorm2d(out_c2),
+ nn.ReLU(inplace=True)
+ )
+ self.conv2 = nn.Sequential(
+ nn.Conv2d(in_c1, out_c1, kernel_size=1),
+ nn.BatchNorm2d(out_c1),
+ nn.ReLU(inplace=True)
+ )
+ self.upscale = upscale
+
+ def forward(self, a, b):
+ b = self.conv1(b)
+ a = self.conv2(a)
+ if self.upscale:
+ b = F.interpolate(b, scale_factor=2.0, mode='bilinear', align_corners=True)
+ return torch.cat((a, b), dim=1)
+
+
+class BlockTypeB(nn.Module):
+ def __init__(self, in_c, out_c):
+ super(BlockTypeB, self).__init__()
+ self.conv1 = nn.Sequential(
+ nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
+ nn.BatchNorm2d(in_c),
+ nn.ReLU()
+ )
+ self.conv2 = nn.Sequential(
+ nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
+ nn.BatchNorm2d(out_c),
+ nn.ReLU()
+ )
+
+ def forward(self, x):
+ x = self.conv1(x) + x
+ x = self.conv2(x)
+ return x
+
+class BlockTypeC(nn.Module):
+ def __init__(self, in_c, out_c):
+ super(BlockTypeC, self).__init__()
+ self.conv1 = nn.Sequential(
+ nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5),
+ nn.BatchNorm2d(in_c),
+ nn.ReLU()
+ )
+ self.conv2 = nn.Sequential(
+ nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
+ nn.BatchNorm2d(in_c),
+ nn.ReLU()
+ )
+ self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.conv2(x)
+ x = self.conv3(x)
+ return x
+
+def _make_divisible(v, divisor, min_value=None):
+ """
+ This function is taken from the original tf repo.
+ It ensures that all layers have a channel number that is divisible by 8
+ It can be seen here:
+ https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
+ :param v:
+ :param divisor:
+ :param min_value:
+ :return:
+ """
+ if min_value is None:
+ min_value = divisor
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+ # Make sure that round down does not go down by more than 10%.
+ if new_v < 0.9 * v:
+ new_v += divisor
+ return new_v
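+# For example (illustrative values): _make_divisible(12, 8) == 16, and
+# _make_divisible(18, 8) == 24 because rounding down to 16 would drop more than 10% of 18.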
+
+
+class ConvBNReLU(nn.Sequential):
+ def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
+ self.channel_pad = out_planes - in_planes
+ self.stride = stride
+ #padding = (kernel_size - 1) // 2
+
+ # TFLite uses slightly different padding than PyTorch
+ if stride == 2:
+ padding = 0
+ else:
+ padding = (kernel_size - 1) // 2
+
+ super(ConvBNReLU, self).__init__(
+ nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
+ nn.BatchNorm2d(out_planes),
+ nn.ReLU6(inplace=True)
+ )
+ self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride)
+
+
+ def forward(self, x):
+ # TFLite uses different padding
+ if self.stride == 2:
+ x = F.pad(x, (0, 1, 0, 1), "constant", 0)
+ #print(x.shape)
+
+ for module in self:
+ if not isinstance(module, nn.MaxPool2d):
+ x = module(x)
+ return x
+
+
+class InvertedResidual(nn.Module):
+ def __init__(self, inp, oup, stride, expand_ratio):
+ super(InvertedResidual, self).__init__()
+ self.stride = stride
+ assert stride in [1, 2]
+
+ hidden_dim = int(round(inp * expand_ratio))
+ self.use_res_connect = self.stride == 1 and inp == oup
+
+ layers = []
+ if expand_ratio != 1:
+ # pw
+ layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
+ layers.extend([
+ # dw
+ ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
+ # pw-linear
+ nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
+ nn.BatchNorm2d(oup),
+ ])
+ self.conv = nn.Sequential(*layers)
+
+ def forward(self, x):
+ if self.use_res_connect:
+ return x + self.conv(x)
+ else:
+ return self.conv(x)
+
+
+class MobileNetV2(nn.Module):
+ def __init__(self, pretrained=True):
+ """
+ MobileNet V2 main class
+ Args:
+ num_classes (int): Number of classes
+ width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
+ inverted_residual_setting: Network structure
+ round_nearest (int): Round the number of channels in each layer to be a multiple of this number
+ Set to 1 to turn off rounding
+ block: Module specifying inverted residual building block for mobilenet
+ """
+ super(MobileNetV2, self).__init__()
+
+ block = InvertedResidual
+ input_channel = 32
+ last_channel = 1280
+ width_mult = 1.0
+ round_nearest = 8
+
+ inverted_residual_setting = [
+ # t, c, n, s
+ [1, 16, 1, 1],
+ [6, 24, 2, 2],
+ [6, 32, 3, 2],
+ [6, 64, 4, 2],
+ [6, 96, 3, 1],
+ #[6, 160, 3, 2],
+ #[6, 320, 1, 1],
+ ]
+
+ # only check the first element, assuming user knows t,c,n,s are required
+ if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
+ raise ValueError("inverted_residual_setting should be non-empty "
+ "or a 4-element list, got {}".format(inverted_residual_setting))
+
+ # building first layer
+ input_channel = _make_divisible(input_channel * width_mult, round_nearest)
+ self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
+ features = [ConvBNReLU(4, input_channel, stride=2)]
+ # building inverted residual blocks
+ for t, c, n, s in inverted_residual_setting:
+ output_channel = _make_divisible(c * width_mult, round_nearest)
+ for i in range(n):
+ stride = s if i == 0 else 1
+ features.append(block(input_channel, output_channel, stride, expand_ratio=t))
+ input_channel = output_channel
+
+ self.features = nn.Sequential(*features)
+ self.fpn_selected = [1, 3, 6, 10, 13]
+ # weight initialization
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out')
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.ones_(m.weight)
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, nn.Linear):
+ nn.init.normal_(m.weight, 0, 0.01)
+ nn.init.zeros_(m.bias)
+ if pretrained:
+ self._load_pretrained_model()
+
+ def _forward_impl(self, x):
+ # This exists since TorchScript doesn't support inheritance, so the superclass method
+ # (this one) needs to have a name other than `forward` that can be accessed in a subclass
+ fpn_features = []
+ for i, f in enumerate(self.features):
+ if i > self.fpn_selected[-1]:
+ break
+ x = f(x)
+ if i in self.fpn_selected:
+ fpn_features.append(x)
+
+ c1, c2, c3, c4, c5 = fpn_features
+ return c1, c2, c3, c4, c5
+
+
+ def forward(self, x):
+ return self._forward_impl(x)
+
+ def _load_pretrained_model(self):
+ pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth')
+ model_dict = {}
+ state_dict = self.state_dict()
+ for k, v in pretrain_dict.items():
+ if k in state_dict:
+ model_dict[k] = v
+ state_dict.update(model_dict)
+ self.load_state_dict(state_dict)
+
+
+class MobileV2_MLSD_Large(nn.Module):
+ def __init__(self):
+ super(MobileV2_MLSD_Large, self).__init__()
+
+ self.backbone = MobileNetV2(pretrained=False)
+ ## A, B
+ self.block15 = BlockTypeA(in_c1= 64, in_c2= 96,
+ out_c1= 64, out_c2=64,
+ upscale=False)
+ self.block16 = BlockTypeB(128, 64)
+
+ ## A, B
+ self.block17 = BlockTypeA(in_c1 = 32, in_c2 = 64,
+ out_c1= 64, out_c2= 64)
+ self.block18 = BlockTypeB(128, 64)
+
+ ## A, B
+ self.block19 = BlockTypeA(in_c1=24, in_c2=64,
+ out_c1=64, out_c2=64)
+ self.block20 = BlockTypeB(128, 64)
+
+ ## A, B, C
+ self.block21 = BlockTypeA(in_c1=16, in_c2=64,
+ out_c1=64, out_c2=64)
+ self.block22 = BlockTypeB(128, 64)
+
+ self.block23 = BlockTypeC(64, 16)
+
+ def forward(self, x):
+ c1, c2, c3, c4, c5 = self.backbone(x)
+
+ x = self.block15(c4, c5)
+ x = self.block16(x)
+
+ x = self.block17(c3, x)
+ x = self.block18(x)
+
+ x = self.block19(c2, x)
+ x = self.block20(x)
+
+ x = self.block21(c1, x)
+ x = self.block22(x)
+ x = self.block23(x)
+ x = x[:, 7:, :, :]
+
+ return x
\ No newline at end of file
diff --git a/annotator/mlsd/models/mbv2_mlsd_tiny.py b/annotator/mlsd/models/mbv2_mlsd_tiny.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3ed633f2cc23ea1829a627fdb879ab39f641f83
--- /dev/null
+++ b/annotator/mlsd/models/mbv2_mlsd_tiny.py
@@ -0,0 +1,275 @@
+import os
+import sys
+import torch
+import torch.nn as nn
+import torch.utils.model_zoo as model_zoo
+from torch.nn import functional as F
+
+
+class BlockTypeA(nn.Module):
+ def __init__(self, in_c1, in_c2, out_c1, out_c2, upscale = True):
+ super(BlockTypeA, self).__init__()
+ self.conv1 = nn.Sequential(
+ nn.Conv2d(in_c2, out_c2, kernel_size=1),
+ nn.BatchNorm2d(out_c2),
+ nn.ReLU(inplace=True)
+ )
+ self.conv2 = nn.Sequential(
+ nn.Conv2d(in_c1, out_c1, kernel_size=1),
+ nn.BatchNorm2d(out_c1),
+ nn.ReLU(inplace=True)
+ )
+ self.upscale = upscale
+
+ def forward(self, a, b):
+ b = self.conv1(b)
+ a = self.conv2(a)
+ b = F.interpolate(b, scale_factor=2.0, mode='bilinear', align_corners=True)
+ return torch.cat((a, b), dim=1)
+
+
+class BlockTypeB(nn.Module):
+ def __init__(self, in_c, out_c):
+ super(BlockTypeB, self).__init__()
+ self.conv1 = nn.Sequential(
+ nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
+ nn.BatchNorm2d(in_c),
+ nn.ReLU()
+ )
+ self.conv2 = nn.Sequential(
+ nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
+ nn.BatchNorm2d(out_c),
+ nn.ReLU()
+ )
+
+ def forward(self, x):
+ x = self.conv1(x) + x
+ x = self.conv2(x)
+ return x
+
+class BlockTypeC(nn.Module):
+ def __init__(self, in_c, out_c):
+ super(BlockTypeC, self).__init__()
+ self.conv1 = nn.Sequential(
+ nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5),
+ nn.BatchNorm2d(in_c),
+ nn.ReLU()
+ )
+ self.conv2 = nn.Sequential(
+ nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
+ nn.BatchNorm2d(in_c),
+ nn.ReLU()
+ )
+ self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.conv2(x)
+ x = self.conv3(x)
+ return x
+
+def _make_divisible(v, divisor, min_value=None):
+ """
+ This function is taken from the original tf repo.
+ It ensures that all layers have a channel number that is divisible by 8
+ It can be seen here:
+ https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
+ :param v:
+ :param divisor:
+ :param min_value:
+ :return:
+ """
+ if min_value is None:
+ min_value = divisor
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+ # Make sure that round down does not go down by more than 10%.
+ if new_v < 0.9 * v:
+ new_v += divisor
+ return new_v
+
+
+class ConvBNReLU(nn.Sequential):
+ def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
+ self.channel_pad = out_planes - in_planes
+ self.stride = stride
+ #padding = (kernel_size - 1) // 2
+
+ # TFLite uses slightly different padding than PyTorch
+ if stride == 2:
+ padding = 0
+ else:
+ padding = (kernel_size - 1) // 2
+
+ super(ConvBNReLU, self).__init__(
+ nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
+ nn.BatchNorm2d(out_planes),
+ nn.ReLU6(inplace=True)
+ )
+ self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride)
+
+
+ def forward(self, x):
+ # TFLite uses different padding
+ if self.stride == 2:
+ x = F.pad(x, (0, 1, 0, 1), "constant", 0)
+ #print(x.shape)
+
+ for module in self:
+ if not isinstance(module, nn.MaxPool2d):
+ x = module(x)
+ return x
+
+
+class InvertedResidual(nn.Module):
+ def __init__(self, inp, oup, stride, expand_ratio):
+ super(InvertedResidual, self).__init__()
+ self.stride = stride
+ assert stride in [1, 2]
+
+ hidden_dim = int(round(inp * expand_ratio))
+ self.use_res_connect = self.stride == 1 and inp == oup
+
+ layers = []
+ if expand_ratio != 1:
+ # pw
+ layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
+ layers.extend([
+ # dw
+ ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
+ # pw-linear
+ nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
+ nn.BatchNorm2d(oup),
+ ])
+ self.conv = nn.Sequential(*layers)
+
+ def forward(self, x):
+ if self.use_res_connect:
+ return x + self.conv(x)
+ else:
+ return self.conv(x)
+
+
+class MobileNetV2(nn.Module):
+ def __init__(self, pretrained=True):
+ """
+ MobileNet V2 main class
+ Args:
+ num_classes (int): Number of classes
+ width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
+ inverted_residual_setting: Network structure
+ round_nearest (int): Round the number of channels in each layer to be a multiple of this number
+ Set to 1 to turn off rounding
+ block: Module specifying inverted residual building block for mobilenet
+ """
+ super(MobileNetV2, self).__init__()
+
+ block = InvertedResidual
+ input_channel = 32
+ last_channel = 1280
+ width_mult = 1.0
+ round_nearest = 8
+
+ inverted_residual_setting = [
+ # t, c, n, s
+ [1, 16, 1, 1],
+ [6, 24, 2, 2],
+ [6, 32, 3, 2],
+ [6, 64, 4, 2],
+ #[6, 96, 3, 1],
+ #[6, 160, 3, 2],
+ #[6, 320, 1, 1],
+ ]
+
+ # only check the first element, assuming user knows t,c,n,s are required
+ if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
+ raise ValueError("inverted_residual_setting should be non-empty "
+ "or a 4-element list, got {}".format(inverted_residual_setting))
+
+ # building first layer
+ input_channel = _make_divisible(input_channel * width_mult, round_nearest)
+ self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
+ features = [ConvBNReLU(4, input_channel, stride=2)]
+ # building inverted residual blocks
+ for t, c, n, s in inverted_residual_setting:
+ output_channel = _make_divisible(c * width_mult, round_nearest)
+ for i in range(n):
+ stride = s if i == 0 else 1
+ features.append(block(input_channel, output_channel, stride, expand_ratio=t))
+ input_channel = output_channel
+ self.features = nn.Sequential(*features)
+
+ self.fpn_selected = [3, 6, 10]
+ # weight initialization
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out')
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.ones_(m.weight)
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, nn.Linear):
+ nn.init.normal_(m.weight, 0, 0.01)
+ nn.init.zeros_(m.bias)
+
+ #if pretrained:
+ # self._load_pretrained_model()
+
+ def _forward_impl(self, x):
+ # This exists since TorchScript doesn't support inheritance, so the superclass method
+ # (this one) needs to have a name other than `forward` that can be accessed in a subclass
+ fpn_features = []
+ for i, f in enumerate(self.features):
+ if i > self.fpn_selected[-1]:
+ break
+ x = f(x)
+ if i in self.fpn_selected:
+ fpn_features.append(x)
+
+ c2, c3, c4 = fpn_features
+ return c2, c3, c4
+
+
+ def forward(self, x):
+ return self._forward_impl(x)
+
+ def _load_pretrained_model(self):
+ pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth')
+ model_dict = {}
+ state_dict = self.state_dict()
+ for k, v in pretrain_dict.items():
+ if k in state_dict:
+ model_dict[k] = v
+ state_dict.update(model_dict)
+ self.load_state_dict(state_dict)
+
+
+class MobileV2_MLSD_Tiny(nn.Module):
+ def __init__(self):
+ super(MobileV2_MLSD_Tiny, self).__init__()
+
+ self.backbone = MobileNetV2(pretrained=True)
+
+ self.block12 = BlockTypeA(in_c1= 32, in_c2= 64,
+ out_c1= 64, out_c2=64)
+ self.block13 = BlockTypeB(128, 64)
+
+ self.block14 = BlockTypeA(in_c1 = 24, in_c2 = 64,
+ out_c1= 32, out_c2= 32)
+ self.block15 = BlockTypeB(64, 64)
+
+ self.block16 = BlockTypeC(64, 16)
+
+ def forward(self, x):
+ c2, c3, c4 = self.backbone(x)
+
+ x = self.block12(c3, c4)
+ x = self.block13(x)
+ x = self.block14(c2, x)
+ x = self.block15(x)
+ x = self.block16(x)
+ x = x[:, 7:, :, :]
+ #print(x.shape)
+ x = F.interpolate(x, scale_factor=2.0, mode='bilinear', align_corners=True)
+
+ return x
\ No newline at end of file
diff --git a/annotator/mlsd/utils.py b/annotator/mlsd/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae3cf9420a33a4abae27c48ac4b90938c7d63cc3
--- /dev/null
+++ b/annotator/mlsd/utils.py
@@ -0,0 +1,580 @@
+'''
+modified by lihaoweicv
+pytorch version
+'''
+
+'''
+M-LSD
+Copyright 2021-present NAVER Corp.
+Apache License v2.0
+'''
+
+import os
+import numpy as np
+import cv2
+import torch
+from torch.nn import functional as F
+
+
+def deccode_output_score_and_ptss(tpMap, topk_n = 200, ksize = 5):
+ '''
+ tpMap:
+ center: tpMap[:, 0, :, :]
+ displacement: tpMap[:, 1:5, :, :]
+ '''
+ b, c, h, w = tpMap.shape
+ assert b==1, 'only support bsize==1'
+ displacement = tpMap[:, 1:5, :, :][0]
+ center = tpMap[:, 0, :, :]
+ heat = torch.sigmoid(center)
+ hmax = F.max_pool2d( heat, (ksize, ksize), stride=1, padding=(ksize-1)//2)
+ keep = (hmax == heat).float()
+ heat = heat * keep
+ heat = heat.reshape(-1, )
+
+ scores, indices = torch.topk(heat, topk_n, dim=-1, largest=True)
+ yy = torch.floor_divide(indices, w).unsqueeze(-1)
+ xx = torch.fmod(indices, w).unsqueeze(-1)
+ ptss = torch.cat((yy, xx),dim=-1)
+
+ ptss = ptss.detach().cpu().numpy()
+ scores = scores.detach().cpu().numpy()
+ displacement = displacement.detach().cpu().numpy()
+ displacement = displacement.transpose((1,2,0))
+ return ptss, scores, displacement
+
+
+def pred_lines(image, model,
+ input_shape=[512, 512],
+ score_thr=0.10,
+ dist_thr=20.0):
+ h, w, _ = image.shape
+ h_ratio, w_ratio = [h / input_shape[0], w / input_shape[1]]
+
+ resized_image = np.concatenate([cv2.resize(image, (input_shape[1], input_shape[0]), interpolation=cv2.INTER_AREA),
+ np.ones([input_shape[0], input_shape[1], 1])], axis=-1)
+
+ resized_image = resized_image.transpose((2,0,1))
+ batch_image = np.expand_dims(resized_image, axis=0).astype('float32')
+ batch_image = (batch_image / 127.5) - 1.0
+
+ batch_image = torch.from_numpy(batch_image).float().cuda()
+ outputs = model(batch_image)
+ pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3)
+ start = vmap[:, :, :2]
+ end = vmap[:, :, 2:]
+ dist_map = np.sqrt(np.sum((start - end) ** 2, axis=-1))
+
+ segments_list = []
+ for center, score in zip(pts, pts_score):
+ y, x = center
+ distance = dist_map[y, x]
+ if score > score_thr and distance > dist_thr:
+ disp_x_start, disp_y_start, disp_x_end, disp_y_end = vmap[y, x, :]
+ x_start = x + disp_x_start
+ y_start = y + disp_y_start
+ x_end = x + disp_x_end
+ y_end = y + disp_y_end
+ segments_list.append([x_start, y_start, x_end, y_end])
+
+ lines = 2 * np.array(segments_list) # scale from the 256 output map back to the 512 input
+ lines[:, 0] = lines[:, 0] * w_ratio
+ lines[:, 1] = lines[:, 1] * h_ratio
+ lines[:, 2] = lines[:, 2] * w_ratio
+ lines[:, 3] = lines[:, 3] * h_ratio
+
+ return lines
+
+
+def pred_squares(image,
+ model,
+ input_shape=[512, 512],
+ params={'score': 0.06,
+ 'outside_ratio': 0.28,
+ 'inside_ratio': 0.45,
+ 'w_overlap': 0.0,
+ 'w_degree': 1.95,
+ 'w_length': 0.0,
+ 'w_area': 1.86,
+ 'w_center': 0.14}):
+ '''
+ shape = [height, width]
+ '''
+ h, w, _ = image.shape
+ original_shape = [h, w]
+
+ resized_image = np.concatenate([cv2.resize(image, (input_shape[0], input_shape[1]), interpolation=cv2.INTER_AREA),
+ np.ones([input_shape[0], input_shape[1], 1])], axis=-1)
+ resized_image = resized_image.transpose((2, 0, 1))
+ batch_image = np.expand_dims(resized_image, axis=0).astype('float32')
+ batch_image = (batch_image / 127.5) - 1.0
+
+ batch_image = torch.from_numpy(batch_image).float().cuda()
+ outputs = model(batch_image)
+
+ pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3)
+ start = vmap[:, :, :2] # (x, y)
+ end = vmap[:, :, 2:] # (x, y)
+ dist_map = np.sqrt(np.sum((start - end) ** 2, axis=-1))
+
+ junc_list = []
+ segments_list = []
+ for junc, score in zip(pts, pts_score):
+ y, x = junc
+ distance = dist_map[y, x]
+ if score > params['score'] and distance > 20.0:
+ junc_list.append([x, y])
+ disp_x_start, disp_y_start, disp_x_end, disp_y_end = vmap[y, x, :]
+ d_arrow = 1.0
+ x_start = x + d_arrow * disp_x_start
+ y_start = y + d_arrow * disp_y_start
+ x_end = x + d_arrow * disp_x_end
+ y_end = y + d_arrow * disp_y_end
+ segments_list.append([x_start, y_start, x_end, y_end])
+
+ segments = np.array(segments_list)
+
+ ####### post processing for squares
+ # 1. get unique lines
+ point = np.array([[0, 0]])
+ point = point[0]
+ start = segments[:, :2]
+ end = segments[:, 2:]
+ diff = start - end
+ a = diff[:, 1]
+ b = -diff[:, 0]
+ c = a * start[:, 0] + b * start[:, 1]
+
+ d = np.abs(a * point[0] + b * point[1] - c) / np.sqrt(a ** 2 + b ** 2 + 1e-10)
+ theta = np.arctan2(diff[:, 0], diff[:, 1]) * 180 / np.pi
+ theta[theta < 0.0] += 180
+ hough = np.concatenate([d[:, None], theta[:, None]], axis=-1)
+
+ d_quant = 1
+ theta_quant = 2
+ hough[:, 0] //= d_quant
+ hough[:, 1] //= theta_quant
+ _, indices, counts = np.unique(hough, axis=0, return_index=True, return_counts=True)
+
+ acc_map = np.zeros([512 // d_quant + 1, 360 // theta_quant + 1], dtype='float32')
+ idx_map = np.zeros([512 // d_quant + 1, 360 // theta_quant + 1], dtype='int32') - 1
+ yx_indices = hough[indices, :].astype('int32')
+ acc_map[yx_indices[:, 0], yx_indices[:, 1]] = counts
+ idx_map[yx_indices[:, 0], yx_indices[:, 1]] = indices
+
+ acc_map_np = acc_map
+ # acc_map = acc_map[None, :, :, None]
+ #
+ # ### fast suppression using tensorflow op
+ # acc_map = tf.constant(acc_map, dtype=tf.float32)
+ # max_acc_map = tf.keras.layers.MaxPool2D(pool_size=(5, 5), strides=1, padding='same')(acc_map)
+ # acc_map = acc_map * tf.cast(tf.math.equal(acc_map, max_acc_map), tf.float32)
+ # flatten_acc_map = tf.reshape(acc_map, [1, -1])
+ # topk_values, topk_indices = tf.math.top_k(flatten_acc_map, k=len(pts))
+ # _, h, w, _ = acc_map.shape
+ # y = tf.expand_dims(topk_indices // w, axis=-1)
+ # x = tf.expand_dims(topk_indices % w, axis=-1)
+ # yx = tf.concat([y, x], axis=-1)
+
+ ### fast suppression using pytorch op
+ acc_map = torch.from_numpy(acc_map_np).unsqueeze(0).unsqueeze(0)
+ _,_, h, w = acc_map.shape
+ max_acc_map = F.max_pool2d(acc_map,kernel_size=5, stride=1, padding=2)
+ acc_map = acc_map * ( (acc_map == max_acc_map).float() )
+ flatten_acc_map = acc_map.reshape([-1, ])
+
+ scores, indices = torch.topk(flatten_acc_map, len(pts), dim=-1, largest=True)
+ yy = torch.div(indices, w, rounding_mode='floor').unsqueeze(-1)
+ xx = torch.fmod(indices, w).unsqueeze(-1)
+ yx = torch.cat((yy, xx), dim=-1)
+
+ yx = yx.detach().cpu().numpy()
+
+ topk_values = scores.detach().cpu().numpy()
+ indices = idx_map[yx[:, 0], yx[:, 1]]
+ basis = 5 // 2
+
+ merged_segments = []
+ for yx_pt, max_indice, value in zip(yx, indices, topk_values):
+ y, x = yx_pt
+ if max_indice == -1 or value == 0:
+ continue
+ segment_list = []
+ for y_offset in range(-basis, basis + 1):
+ for x_offset in range(-basis, basis + 1):
+ indice = idx_map[y + y_offset, x + x_offset]
+ cnt = int(acc_map_np[y + y_offset, x + x_offset])
+ if indice != -1:
+ segment_list.append(segments[indice])
+ if cnt > 1:
+ check_cnt = 1
+ current_hough = hough[indice]
+ for new_indice, new_hough in enumerate(hough):
+ if (current_hough == new_hough).all() and indice != new_indice:
+ segment_list.append(segments[new_indice])
+ check_cnt += 1
+ if check_cnt == cnt:
+ break
+ group_segments = np.array(segment_list).reshape([-1, 2])
+ sorted_group_segments = np.sort(group_segments, axis=0)
+ x_min, y_min = sorted_group_segments[0, :]
+ x_max, y_max = sorted_group_segments[-1, :]
+
+ deg = theta[max_indice]
+ if deg >= 90:
+ merged_segments.append([x_min, y_max, x_max, y_min])
+ else:
+ merged_segments.append([x_min, y_min, x_max, y_max])
+
+ # 2. get intersections
+ new_segments = np.array(merged_segments) # (x1, y1, x2, y2)
+ start = new_segments[:, :2] # (x1, y1)
+ end = new_segments[:, 2:] # (x2, y2)
+ new_centers = (start + end) / 2.0
+ diff = start - end
+ dist_segments = np.sqrt(np.sum(diff ** 2, axis=-1))
+
+ # ax + by = c
+ a = diff[:, 1]
+ b = -diff[:, 0]
+ c = a * start[:, 0] + b * start[:, 1]
+ pre_det = a[:, None] * b[None, :]
+ det = pre_det - np.transpose(pre_det)
+
+ pre_inter_y = a[:, None] * c[None, :]
+ inter_y = (pre_inter_y - np.transpose(pre_inter_y)) / (det + 1e-10)
+ pre_inter_x = c[:, None] * b[None, :]
+ inter_x = (pre_inter_x - np.transpose(pre_inter_x)) / (det + 1e-10)
+ inter_pts = np.concatenate([inter_x[:, :, None], inter_y[:, :, None]], axis=-1).astype('int32')
+
+ # 3. get corner information
+ # 3.1 get distance
+ '''
+ dist_segments:
+ | dist(0), dist(1), dist(2), ...|
+ dist_inter_to_segment1:
+ | dist(inter,0), dist(inter,0), dist(inter,0), ... |
+ | dist(inter,1), dist(inter,1), dist(inter,1), ... |
+ ...
+ dist_inter_to_segment2:
+ | dist(inter,0), dist(inter,1), dist(inter,2), ... |
+ | dist(inter,0), dist(inter,1), dist(inter,2), ... |
+ ...
+ '''
+
+ dist_inter_to_segment1_start = np.sqrt(
+ np.sum(((inter_pts - start[:, None, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1]
+ dist_inter_to_segment1_end = np.sqrt(
+ np.sum(((inter_pts - end[:, None, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1]
+ dist_inter_to_segment2_start = np.sqrt(
+ np.sum(((inter_pts - start[None, :, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1]
+ dist_inter_to_segment2_end = np.sqrt(
+ np.sum(((inter_pts - end[None, :, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1]
+
+ # sort ascending
+ dist_inter_to_segment1 = np.sort(
+ np.concatenate([dist_inter_to_segment1_start, dist_inter_to_segment1_end], axis=-1),
+ axis=-1) # [n_batch, n_batch, 2]
+ dist_inter_to_segment2 = np.sort(
+ np.concatenate([dist_inter_to_segment2_start, dist_inter_to_segment2_end], axis=-1),
+ axis=-1) # [n_batch, n_batch, 2]
+
+ # 3.2 get degree
+ inter_to_start = new_centers[:, None, :] - inter_pts
+ deg_inter_to_start = np.arctan2(inter_to_start[:, :, 1], inter_to_start[:, :, 0]) * 180 / np.pi
+ deg_inter_to_start[deg_inter_to_start < 0.0] += 360
+ inter_to_end = new_centers[None, :, :] - inter_pts
+ deg_inter_to_end = np.arctan2(inter_to_end[:, :, 1], inter_to_end[:, :, 0]) * 180 / np.pi
+ deg_inter_to_end[deg_inter_to_end < 0.0] += 360
+
+ '''
+ B -- G
+ | |
+ C -- R
+ B : blue / G: green / C: cyan / R: red
+
+ 0 -- 1
+ | |
+ 3 -- 2
+ '''
+ # rename variables
+ deg1_map, deg2_map = deg_inter_to_start, deg_inter_to_end
+ # sort deg ascending
+ deg_sort = np.sort(np.concatenate([deg1_map[:, :, None], deg2_map[:, :, None]], axis=-1), axis=-1)
+
+ deg_diff_map = np.abs(deg1_map - deg2_map)
+ # we only consider the smallest degree of intersect
+ deg_diff_map[deg_diff_map > 180] = 360 - deg_diff_map[deg_diff_map > 180]
+
+ # define available degree range
+ deg_range = [60, 120]
+
+ corner_dict = {corner_info: [] for corner_info in range(4)}
+ inter_points = []
+ for i in range(inter_pts.shape[0]):
+ for j in range(i + 1, inter_pts.shape[1]):
+ # i, j > line index, always i < j
+ x, y = inter_pts[i, j, :]
+ deg1, deg2 = deg_sort[i, j, :]
+ deg_diff = deg_diff_map[i, j]
+
+ check_degree = deg_diff > deg_range[0] and deg_diff < deg_range[1]
+
+ outside_ratio = params['outside_ratio'] # over ratio >>> drop it!
+ inside_ratio = params['inside_ratio'] # over ratio >>> drop it!
+ check_distance = ((dist_inter_to_segment1[i, j, 1] >= dist_segments[i] and \
+ dist_inter_to_segment1[i, j, 0] <= dist_segments[i] * outside_ratio) or \
+ (dist_inter_to_segment1[i, j, 1] <= dist_segments[i] and \
+ dist_inter_to_segment1[i, j, 0] <= dist_segments[i] * inside_ratio)) and \
+ ((dist_inter_to_segment2[i, j, 1] >= dist_segments[j] and \
+ dist_inter_to_segment2[i, j, 0] <= dist_segments[j] * outside_ratio) or \
+ (dist_inter_to_segment2[i, j, 1] <= dist_segments[j] and \
+ dist_inter_to_segment2[i, j, 0] <= dist_segments[j] * inside_ratio))
+
+ if check_degree and check_distance:
+ corner_info = None
+
+ if (deg1 >= 0 and deg1 <= 45 and deg2 >= 45 and deg2 <= 120) or \
+ (deg2 >= 315 and deg1 >= 45 and deg1 <= 120):
+ corner_info, color_info = 0, 'blue'
+ elif (deg1 >= 45 and deg1 <= 125 and deg2 >= 125 and deg2 <= 225):
+ corner_info, color_info = 1, 'green'
+ elif (deg1 >= 125 and deg1 <= 225 and deg2 >= 225 and deg2 <= 315):
+ corner_info, color_info = 2, 'black'
+ elif (deg1 >= 0 and deg1 <= 45 and deg2 >= 225 and deg2 <= 315) or \
+ (deg2 >= 315 and deg1 >= 225 and deg1 <= 315):
+ corner_info, color_info = 3, 'cyan'
+ else:
+ corner_info, color_info = 4, 'red' # we don't use it
+ continue
+
+ corner_dict[corner_info].append([x, y, i, j])
+ inter_points.append([x, y])
+
+ square_list = []
+ connect_list = []
+ segments_list = []
+ for corner0 in corner_dict[0]:
+ for corner1 in corner_dict[1]:
+ connect01 = False
+ for corner0_line in corner0[2:]:
+ if corner0_line in corner1[2:]:
+ connect01 = True
+ break
+ if connect01:
+ for corner2 in corner_dict[2]:
+ connect12 = False
+ for corner1_line in corner1[2:]:
+ if corner1_line in corner2[2:]:
+ connect12 = True
+ break
+ if connect12:
+ for corner3 in corner_dict[3]:
+ connect23 = False
+ for corner2_line in corner2[2:]:
+ if corner2_line in corner3[2:]:
+ connect23 = True
+ break
+ if connect23:
+ for corner3_line in corner3[2:]:
+ if corner3_line in corner0[2:]:
+ # SQUARE!!!
+ '''
+ 0 -- 1
+ | |
+ 3 -- 2
+ square_list:
+ order: 0 > 1 > 2 > 3
+ | x0, y0, x1, y1, x2, y2, x3, y3 |
+ | x0, y0, x1, y1, x2, y2, x3, y3 |
+ ...
+ connect_list:
+ order: 01 > 12 > 23 > 30
+ | line_idx01, line_idx12, line_idx23, line_idx30 |
+ | line_idx01, line_idx12, line_idx23, line_idx30 |
+ ...
+ segments_list:
+ order: 0 > 1 > 2 > 3
+ | line_idx0_i, line_idx0_j, line_idx1_i, line_idx1_j, line_idx2_i, line_idx2_j, line_idx3_i, line_idx3_j |
+ | line_idx0_i, line_idx0_j, line_idx1_i, line_idx1_j, line_idx2_i, line_idx2_j, line_idx3_i, line_idx3_j |
+ ...
+ '''
+ square_list.append(corner0[:2] + corner1[:2] + corner2[:2] + corner3[:2])
+ connect_list.append([corner0_line, corner1_line, corner2_line, corner3_line])
+ segments_list.append(corner0[2:] + corner1[2:] + corner2[2:] + corner3[2:])
+
+ def check_outside_inside(segments_info, connect_idx):
+ # return 'outside or inside', min distance, cover_param, peri_param
+ if connect_idx == segments_info[0]:
+ check_dist_mat = dist_inter_to_segment1
+ else:
+ check_dist_mat = dist_inter_to_segment2
+
+ i, j = segments_info
+ min_dist, max_dist = check_dist_mat[i, j, :]
+ connect_dist = dist_segments[connect_idx]
+ if max_dist > connect_dist:
+ return 'outside', min_dist, 0, 1
+ else:
+ return 'inside', min_dist, -1, -1
+
+ top_square = None
+
+ try:
+ map_size = input_shape[0] / 2
+ squares = np.array(square_list).reshape([-1, 4, 2])
+ score_array = []
+ connect_array = np.array(connect_list)
+ segments_array = np.array(segments_list).reshape([-1, 4, 2])
+
+ # get degree of corners:
+ squares_rollup = np.roll(squares, 1, axis=1)
+ squares_rolldown = np.roll(squares, -1, axis=1)
+ vec1 = squares_rollup - squares
+ normalized_vec1 = vec1 / (np.linalg.norm(vec1, axis=-1, keepdims=True) + 1e-10)
+ vec2 = squares_rolldown - squares
+ normalized_vec2 = vec2 / (np.linalg.norm(vec2, axis=-1, keepdims=True) + 1e-10)
+ inner_products = np.sum(normalized_vec1 * normalized_vec2, axis=-1) # [n_squares, 4]
+ squares_degree = np.arccos(inner_products) * 180 / np.pi # [n_squares, 4]
+
+ # get square score
+ overlap_scores = []
+ degree_scores = []
+ length_scores = []
+
+ for connects, segments, square, degree in zip(connect_array, segments_array, squares, squares_degree):
+ '''
+ 0 -- 1
+ | |
+ 3 -- 2
+
+ # segments: [4, 2]
+ # connects: [4]
+ '''
+
+ ###################################### OVERLAP SCORES
+ cover = 0
+ perimeter = 0
+ # check 0 > 1 > 2 > 3
+ square_length = []
+
+ for start_idx in range(4):
+ end_idx = (start_idx + 1) % 4
+
+ connect_idx = connects[start_idx] # segment idx of segment01
+ start_segments = segments[start_idx]
+ end_segments = segments[end_idx]
+
+ start_point = square[start_idx]
+ end_point = square[end_idx]
+
+ # check whether outside or inside
+ start_position, start_min, start_cover_param, start_peri_param = check_outside_inside(start_segments,
+ connect_idx)
+ end_position, end_min, end_cover_param, end_peri_param = check_outside_inside(end_segments, connect_idx)
+
+ cover += dist_segments[connect_idx] + start_cover_param * start_min + end_cover_param * end_min
+ perimeter += dist_segments[connect_idx] + start_peri_param * start_min + end_peri_param * end_min
+
+ square_length.append(
+ dist_segments[connect_idx] + start_peri_param * start_min + end_peri_param * end_min)
+
+ overlap_scores.append(cover / perimeter)
+ ######################################
+ ###################################### DEGREE SCORES
+ '''
+ deg0 vs deg2
+ deg1 vs deg3
+ '''
+ deg0, deg1, deg2, deg3 = degree
+ deg_ratio1 = deg0 / deg2
+ if deg_ratio1 > 1.0:
+ deg_ratio1 = 1 / deg_ratio1
+ deg_ratio2 = deg1 / deg3
+ if deg_ratio2 > 1.0:
+ deg_ratio2 = 1 / deg_ratio2
+ degree_scores.append((deg_ratio1 + deg_ratio2) / 2)
+ ######################################
+ ###################################### LENGTH SCORES
+ '''
+ len0 vs len2
+ len1 vs len3
+ '''
+ len0, len1, len2, len3 = square_length
+ len_ratio1 = len0 / len2 if len2 > len0 else len2 / len0
+ len_ratio2 = len1 / len3 if len3 > len1 else len3 / len1
+ length_scores.append((len_ratio1 + len_ratio2) / 2)
+
+ ######################################
+
+ overlap_scores = np.array(overlap_scores)
+ overlap_scores /= np.max(overlap_scores)
+
+ degree_scores = np.array(degree_scores)
+ # degree_scores /= np.max(degree_scores)
+
+ length_scores = np.array(length_scores)
+
+ ###################################### AREA SCORES
+ area_scores = np.reshape(squares, [-1, 4, 2])
+ area_x = area_scores[:, :, 0]
+ area_y = area_scores[:, :, 1]
+ correction = area_x[:, -1] * area_y[:, 0] - area_y[:, -1] * area_x[:, 0]
+ area_scores = np.sum(area_x[:, :-1] * area_y[:, 1:], axis=-1) - np.sum(area_y[:, :-1] * area_x[:, 1:], axis=-1)
+ area_scores = 0.5 * np.abs(area_scores + correction)
+ area_scores /= (map_size * map_size) # np.max(area_scores)
+ ######################################
+
+ ###################################### CENTER SCORES
+ centers = np.array([[256 // 2, 256 // 2]], dtype='float32') # [1, 2]
+ # squares: [n, 4, 2]
+ square_centers = np.mean(squares, axis=1) # [n, 2]
+ center2center = np.sqrt(np.sum((centers - square_centers) ** 2))
+ center_scores = center2center / (map_size / np.sqrt(2.0))
+
+ '''
+ score_w = [overlap, degree, area, center, length]
+ '''
+ score_w = [0.0, 1.0, 10.0, 0.5, 1.0]
+ score_array = params['w_overlap'] * overlap_scores \
+ + params['w_degree'] * degree_scores \
+ + params['w_area'] * area_scores \
+ - params['w_center'] * center_scores \
+ + params['w_length'] * length_scores
+
+ best_square = []
+
+ sorted_idx = np.argsort(score_array)[::-1]
+ score_array = score_array[sorted_idx]
+ squares = squares[sorted_idx]
+
+ except Exception:
+ # square detection is best-effort; fall back to the empty outputs below
+ pass
+
+ '''return list
+ merged_lines, squares, scores
+ '''
+
+ try:
+ new_segments[:, 0] = new_segments[:, 0] * 2 / input_shape[1] * original_shape[1]
+ new_segments[:, 1] = new_segments[:, 1] * 2 / input_shape[0] * original_shape[0]
+ new_segments[:, 2] = new_segments[:, 2] * 2 / input_shape[1] * original_shape[1]
+ new_segments[:, 3] = new_segments[:, 3] * 2 / input_shape[0] * original_shape[0]
+ except:
+ new_segments = []
+
+ try:
+ squares[:, :, 0] = squares[:, :, 0] * 2 / input_shape[1] * original_shape[1]
+ squares[:, :, 1] = squares[:, :, 1] * 2 / input_shape[0] * original_shape[0]
+ except:
+ squares = []
+ score_array = []
+
+ try:
+ inter_points = np.array(inter_points)
+ inter_points[:, 0] = inter_points[:, 0] * 2 / input_shape[1] * original_shape[1]
+ inter_points[:, 1] = inter_points[:, 1] * 2 / input_shape[0] * original_shape[0]
+ except:
+ inter_points = []
+
+ return new_segments, squares, score_array, inter_points
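+
+
+# Illustrative usage sketch (not part of the upstream file; assumes the model classes
+# from .models are importable and a CUDA device is available):
+#
+# model = MobileV2_MLSD_Large().cuda().eval()           # weight loading omitted here
+# lines = pred_lines(img, model, [img.shape[0], img.shape[1]], score_thr=0.1, dist_thr=20.0)
+# segments, squares, scores, corners = pred_squares(img, model)   # default params as defined above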
diff --git a/annotator/openpose/LICENSE b/annotator/openpose/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..6f60b76d35fa1012809985780964a5068adce4fd
--- /dev/null
+++ b/annotator/openpose/LICENSE
@@ -0,0 +1,108 @@
+OPENPOSE: MULTIPERSON KEYPOINT DETECTION
+SOFTWARE LICENSE AGREEMENT
+ACADEMIC OR NON-PROFIT ORGANIZATION NONCOMMERCIAL RESEARCH USE ONLY
+
+BY USING OR DOWNLOADING THE SOFTWARE, YOU ARE AGREEING TO THE TERMS OF THIS LICENSE AGREEMENT. IF YOU DO NOT AGREE WITH THESE TERMS, YOU MAY NOT USE OR DOWNLOAD THE SOFTWARE.
+
+This is a license agreement ("Agreement") between your academic institution or non-profit organization or self (called "Licensee" or "You" in this Agreement) and Carnegie Mellon University (called "Licensor" in this Agreement). All rights not specifically granted to you in this Agreement are reserved for Licensor.
+
+RESERVATION OF OWNERSHIP AND GRANT OF LICENSE:
+Licensor retains exclusive ownership of any copy of the Software (as defined below) licensed under this Agreement and hereby grants to Licensee a personal, non-exclusive,
+non-transferable license to use the Software for noncommercial research purposes, without the right to sublicense, pursuant to the terms and conditions of this Agreement. As used in this Agreement, the term "Software" means (i) the actual copy of all or any portion of code for program routines made accessible to Licensee by Licensor pursuant to this Agreement, inclusive of backups, updates, and/or merged copies permitted hereunder or subsequently supplied by Licensor, including all or any file structures, programming instructions, user interfaces and screen formats and sequences as well as any and all documentation and instructions related to it, and (ii) all or any derivatives and/or modifications created or made by You to any of the items specified in (i).
+
+CONFIDENTIALITY: Licensee acknowledges that the Software is proprietary to Licensor, and as such, Licensee agrees to receive all such materials in confidence and use the Software only in accordance with the terms of this Agreement. Licensee agrees to use reasonable effort to protect the Software from unauthorized use, reproduction, distribution, or publication.
+
+COPYRIGHT: The Software is owned by Licensor and is protected by United
+States copyright laws and applicable international treaties and/or conventions.
+
+PERMITTED USES: The Software may be used for your own noncommercial internal research purposes. You understand and agree that Licensor is not obligated to implement any suggestions and/or feedback you might provide regarding the Software, but to the extent Licensor does so, you are not entitled to any compensation related thereto.
+
+DERIVATIVES: You may create derivatives of or make modifications to the Software, however, You agree that all and any such derivatives and modifications will be owned by Licensor and become a part of the Software licensed to You under this Agreement. You may only use such derivatives and modifications for your own noncommercial internal research purposes, and you may not otherwise use, distribute or copy such derivatives and modifications in violation of this Agreement.
+
+BACKUPS: If Licensee is an organization, it may make that number of copies of the Software necessary for internal noncommercial use at a single site within its organization provided that all information appearing in or on the original labels, including the copyright and trademark notices are copied onto the labels of the copies.
+
+USES NOT PERMITTED: You may not distribute, copy or use the Software except as explicitly permitted herein. Licensee has not been granted any trademark license as part of this Agreement and may not use the name or mark “OpenPose", "Carnegie Mellon" or any renditions thereof without the prior written permission of Licensor.
+
+You may not sell, rent, lease, sublicense, lend, time-share or transfer, in whole or in part, or provide third parties access to prior or present versions (or any parts thereof) of the Software.
+
+ASSIGNMENT: You may not assign this Agreement or your rights hereunder without the prior written consent of Licensor. Any attempted assignment without such consent shall be null and void.
+
+TERM: The term of the license granted by this Agreement is from Licensee's acceptance of this Agreement by downloading the Software or by using the Software until terminated as provided below.
+
+The Agreement automatically terminates without notice if you fail to comply with any provision of this Agreement. Licensee may terminate this Agreement by ceasing using the Software. Upon any termination of this Agreement, Licensee will delete any and all copies of the Software. You agree that all provisions which operate to protect the proprietary rights of Licensor shall remain in force should breach occur and that the obligation of confidentiality described in this Agreement is binding in perpetuity and, as such, survives the term of the Agreement.
+
+FEE: Provided Licensee abides completely by the terms and conditions of this Agreement, there is no fee due to Licensor for Licensee's use of the Software in accordance with this Agreement.
+
+DISCLAIMER OF WARRANTIES: THE SOFTWARE IS PROVIDED "AS-IS" WITHOUT WARRANTY OF ANY KIND INCLUDING ANY WARRANTIES OF PERFORMANCE OR MERCHANTABILITY OR FITNESS FOR A PARTICULAR USE OR PURPOSE OR OF NON-INFRINGEMENT. LICENSEE BEARS ALL RISK RELATING TO QUALITY AND PERFORMANCE OF THE SOFTWARE AND RELATED MATERIALS.
+
+SUPPORT AND MAINTENANCE: No Software support or training by the Licensor is provided as part of this Agreement.
+
+EXCLUSIVE REMEDY AND LIMITATION OF LIABILITY: To the maximum extent permitted under applicable law, Licensor shall not be liable for direct, indirect, special, incidental, or consequential damages or lost profits related to Licensee's use of and/or inability to use the Software, even if Licensor is advised of the possibility of such damage.
+
+EXPORT REGULATION: Licensee agrees to comply with any and all applicable
+U.S. export control laws, regulations, and/or other laws related to embargoes and sanction programs administered by the Office of Foreign Assets Control.
+
+SEVERABILITY: If any provision(s) of this Agreement shall be held to be invalid, illegal, or unenforceable by a court or other tribunal of competent jurisdiction, the validity, legality and enforceability of the remaining provisions shall not in any way be affected or impaired thereby.
+
+NO IMPLIED WAIVERS: No failure or delay by Licensor in enforcing any right or remedy under this Agreement shall be construed as a waiver of any future or other exercise of such right or remedy by Licensor.
+
+GOVERNING LAW: This Agreement shall be construed and enforced in accordance with the laws of the Commonwealth of Pennsylvania without reference to conflict of laws principles. You consent to the personal jurisdiction of the courts of this County and waive their rights to venue outside of Allegheny County, Pennsylvania.
+
+ENTIRE AGREEMENT AND AMENDMENTS: This Agreement constitutes the sole and entire agreement between Licensee and Licensor as to the matter set forth herein and supersedes any previous agreements, understandings, and arrangements between the parties relating hereto.
+
+
+
+************************************************************************
+
+THIRD-PARTY SOFTWARE NOTICES AND INFORMATION
+
+This project incorporates material from the project(s) listed below (collectively, "Third Party Code"). This Third Party Code is licensed to you under their original license terms set forth below. We reserves all other rights not expressly granted, whether by implication, estoppel or otherwise.
+
+1. Caffe, version 1.0.0, (https://github.com/BVLC/caffe/)
+
+COPYRIGHT
+
+All contributions by the University of California:
+Copyright (c) 2014-2017 The Regents of the University of California (Regents)
+All rights reserved.
+
+All other contributions:
+Copyright (c) 2014-2017, the respective contributors
+All rights reserved.
+
+Caffe uses a shared copyright model: each contributor holds copyright over
+their contributions to Caffe. The project versioning records all such
+contribution and copyright details. If a contributor wants to further mark
+their specific copyright on a particular contribution, they should indicate
+their copyright solely in the commit message of the change when it is
+committed.
+
+LICENSE
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+2. Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
+ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+CONTRIBUTION AGREEMENT
+
+By contributing to the BVLC/caffe repository through pull-request, comment,
+or otherwise, the contributor releases their content to the
+license and copyright terms herein.
+
+************END OF THIRD-PARTY SOFTWARE NOTICES AND INFORMATION**********
\ No newline at end of file
diff --git a/annotator/openpose/__init__.py b/annotator/openpose/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..92e530fd6913a92b1e624d3e334252bcfdba902f
--- /dev/null
+++ b/annotator/openpose/__init__.py
@@ -0,0 +1,49 @@
+# Openpose
+# Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose
+# 2nd Edited by https://github.com/Hzzone/pytorch-openpose
+# 3rd Edited by ControlNet
+
+import os
+os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
+
+import torch
+import numpy as np
+from . import util
+from .body import Body
+from .hand import Hand
+from annotator.util import annotator_ckpts_path
+
+
+body_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/body_pose_model.pth"
+hand_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/hand_pose_model.pth"
+
+
+class OpenposeDetector:
+ def __init__(self):
+ body_modelpath = os.path.join(annotator_ckpts_path, "body_pose_model.pth")
+ hand_modelpath = os.path.join(annotator_ckpts_path, "hand_pose_model.pth")
+
+        if not os.path.exists(body_modelpath) or not os.path.exists(hand_modelpath):
+ from basicsr.utils.download_util import load_file_from_url
+ load_file_from_url(body_model_path, model_dir=annotator_ckpts_path)
+ load_file_from_url(hand_model_path, model_dir=annotator_ckpts_path)
+
+ self.body_estimation = Body(body_modelpath)
+ self.hand_estimation = Hand(hand_modelpath)
+
+ def __call__(self, oriImg, hand=False):
+ oriImg = oriImg[:, :, ::-1].copy()
+ with torch.no_grad():
+ candidate, subset = self.body_estimation(oriImg)
+ canvas = np.zeros_like(oriImg)
+ canvas = util.draw_bodypose(canvas, candidate, subset)
+ if hand:
+ hands_list = util.handDetect(candidate, subset, oriImg)
+ all_hand_peaks = []
+ for x, y, w, is_left in hands_list:
+ peaks = self.hand_estimation(oriImg[y:y+w, x:x+w, :])
+ peaks[:, 0] = np.where(peaks[:, 0] == 0, peaks[:, 0], peaks[:, 0] + x)
+ peaks[:, 1] = np.where(peaks[:, 1] == 0, peaks[:, 1], peaks[:, 1] + y)
+ all_hand_peaks.append(peaks)
+ canvas = util.draw_handpose(canvas, all_hand_peaks)
+ return canvas, dict(candidate=candidate.tolist(), subset=subset.tolist())
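+
+# Usage sketch (an assumption, not part of the original file): `img` is an HxWx3 uint8 RGB
+# array; the channel flip above converts it to BGR for the OpenPose models, and the
+# checkpoints are downloaded on first use.
+#   detector = OpenposeDetector()
+#   pose_map, pose_json = detector(img, hand=True)
+# pose_map is an HxWx3 canvas with the drawn skeleton; pose_json carries the raw
+# candidate/subset arrays as plain lists.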
diff --git a/annotator/openpose/__pycache__/__init__.cpython-38.pyc b/annotator/openpose/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c738a75676c80a4b350f2968045433638b765e6c
Binary files /dev/null and b/annotator/openpose/__pycache__/__init__.cpython-38.pyc differ
diff --git a/annotator/openpose/__pycache__/body.cpython-38.pyc b/annotator/openpose/__pycache__/body.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e978aabae0eca387d7d5cec107df39beb555210f
Binary files /dev/null and b/annotator/openpose/__pycache__/body.cpython-38.pyc differ
diff --git a/annotator/openpose/__pycache__/hand.cpython-38.pyc b/annotator/openpose/__pycache__/hand.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c65ade4e28747a7c8a823cab3b1e84f5ecb3b87a
Binary files /dev/null and b/annotator/openpose/__pycache__/hand.cpython-38.pyc differ
diff --git a/annotator/openpose/__pycache__/model.cpython-38.pyc b/annotator/openpose/__pycache__/model.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..82eec9bf06f2b7136945421ad13068197e2da22b
Binary files /dev/null and b/annotator/openpose/__pycache__/model.cpython-38.pyc differ
diff --git a/annotator/openpose/__pycache__/util.cpython-38.pyc b/annotator/openpose/__pycache__/util.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ff4f6a77483f0269317fed94fac4c8ba99420f66
Binary files /dev/null and b/annotator/openpose/__pycache__/util.cpython-38.pyc differ
diff --git a/annotator/openpose/body.py b/annotator/openpose/body.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c3cf7a388b4ac81004524e64125e383bdd455bd
--- /dev/null
+++ b/annotator/openpose/body.py
@@ -0,0 +1,219 @@
+import cv2
+import numpy as np
+import math
+import time
+from scipy.ndimage import gaussian_filter  # scipy.ndimage.filters is deprecated
+import matplotlib.pyplot as plt
+import matplotlib
+import torch
+from torchvision import transforms
+
+from . import util
+from .model import bodypose_model
+
+class Body(object):
+ def __init__(self, model_path):
+ self.model = bodypose_model()
+ if torch.cuda.is_available():
+ self.model = self.model.cuda()
+ print('cuda')
+ model_dict = util.transfer(self.model, torch.load(model_path))
+ self.model.load_state_dict(model_dict)
+ self.model.eval()
+
+ def __call__(self, oriImg):
+ # scale_search = [0.5, 1.0, 1.5, 2.0]
+ scale_search = [0.5]
+ boxsize = 368
+ stride = 8
+ padValue = 128
+ thre1 = 0.1
+ thre2 = 0.05
+ multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search]
+ heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 19))
+ paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38))
+
+ for m in range(len(multiplier)):
+ scale = multiplier[m]
+ imageToTest = cv2.resize(oriImg, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
+ imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue)
+ im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5
+ im = np.ascontiguousarray(im)
+
+ data = torch.from_numpy(im).float()
+ if torch.cuda.is_available():
+ data = data.cuda()
+ # data = data.permute([2, 0, 1]).unsqueeze(0).float()
+ with torch.no_grad():
+ Mconv7_stage6_L1, Mconv7_stage6_L2 = self.model(data)
+ Mconv7_stage6_L1 = Mconv7_stage6_L1.cpu().numpy()
+ Mconv7_stage6_L2 = Mconv7_stage6_L2.cpu().numpy()
+
+ # extract outputs, resize, and remove padding
+ # heatmap = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[1]].data), (1, 2, 0)) # output 1 is heatmaps
+ heatmap = np.transpose(np.squeeze(Mconv7_stage6_L2), (1, 2, 0)) # output 1 is heatmaps
+ heatmap = cv2.resize(heatmap, (0, 0), fx=stride, fy=stride, interpolation=cv2.INTER_CUBIC)
+ heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
+ heatmap = cv2.resize(heatmap, (oriImg.shape[1], oriImg.shape[0]), interpolation=cv2.INTER_CUBIC)
+
+ # paf = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[0]].data), (1, 2, 0)) # output 0 is PAFs
+ paf = np.transpose(np.squeeze(Mconv7_stage6_L1), (1, 2, 0)) # output 0 is PAFs
+ paf = cv2.resize(paf, (0, 0), fx=stride, fy=stride, interpolation=cv2.INTER_CUBIC)
+ paf = paf[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
+ paf = cv2.resize(paf, (oriImg.shape[1], oriImg.shape[0]), interpolation=cv2.INTER_CUBIC)
+
+            heatmap_avg += heatmap / len(multiplier)
+            paf_avg += paf / len(multiplier)
+
+ all_peaks = []
+ peak_counter = 0
+
+ for part in range(18):
+ map_ori = heatmap_avg[:, :, part]
+ one_heatmap = gaussian_filter(map_ori, sigma=3)
+
+ map_left = np.zeros(one_heatmap.shape)
+ map_left[1:, :] = one_heatmap[:-1, :]
+ map_right = np.zeros(one_heatmap.shape)
+ map_right[:-1, :] = one_heatmap[1:, :]
+ map_up = np.zeros(one_heatmap.shape)
+ map_up[:, 1:] = one_heatmap[:, :-1]
+ map_down = np.zeros(one_heatmap.shape)
+ map_down[:, :-1] = one_heatmap[:, 1:]
+
+ peaks_binary = np.logical_and.reduce(
+ (one_heatmap >= map_left, one_heatmap >= map_right, one_heatmap >= map_up, one_heatmap >= map_down, one_heatmap > thre1))
+ peaks = list(zip(np.nonzero(peaks_binary)[1], np.nonzero(peaks_binary)[0])) # note reverse
+ peaks_with_score = [x + (map_ori[x[1], x[0]],) for x in peaks]
+ peak_id = range(peak_counter, peak_counter + len(peaks))
+ peaks_with_score_and_id = [peaks_with_score[i] + (peak_id[i],) for i in range(len(peak_id))]
+
+ all_peaks.append(peaks_with_score_and_id)
+ peak_counter += len(peaks)
+
+ # find connection in the specified sequence, center 29 is in the position 15
+ limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
+ [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
+ [1, 16], [16, 18], [3, 17], [6, 18]]
+        # the middle joints heatmap correspondence
+ mapIdx = [[31, 32], [39, 40], [33, 34], [35, 36], [41, 42], [43, 44], [19, 20], [21, 22], \
+ [23, 24], [25, 26], [27, 28], [29, 30], [47, 48], [49, 50], [53, 54], [51, 52], \
+ [55, 56], [37, 38], [45, 46]]
+
+ connection_all = []
+ special_k = []
+ mid_num = 10
+
+ for k in range(len(mapIdx)):
+ score_mid = paf_avg[:, :, [x - 19 for x in mapIdx[k]]]
+ candA = all_peaks[limbSeq[k][0] - 1]
+ candB = all_peaks[limbSeq[k][1] - 1]
+ nA = len(candA)
+ nB = len(candB)
+ indexA, indexB = limbSeq[k]
+ if (nA != 0 and nB != 0):
+ connection_candidate = []
+ for i in range(nA):
+ for j in range(nB):
+ vec = np.subtract(candB[j][:2], candA[i][:2])
+ norm = math.sqrt(vec[0] * vec[0] + vec[1] * vec[1])
+ norm = max(0.001, norm)
+ vec = np.divide(vec, norm)
+
+ startend = list(zip(np.linspace(candA[i][0], candB[j][0], num=mid_num), \
+ np.linspace(candA[i][1], candB[j][1], num=mid_num)))
+
+ vec_x = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 0] \
+ for I in range(len(startend))])
+ vec_y = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 1] \
+ for I in range(len(startend))])
+
+ score_midpts = np.multiply(vec_x, vec[0]) + np.multiply(vec_y, vec[1])
+ score_with_dist_prior = sum(score_midpts) / len(score_midpts) + min(
+ 0.5 * oriImg.shape[0] / norm - 1, 0)
+ criterion1 = len(np.nonzero(score_midpts > thre2)[0]) > 0.8 * len(score_midpts)
+ criterion2 = score_with_dist_prior > 0
+ if criterion1 and criterion2:
+ connection_candidate.append(
+ [i, j, score_with_dist_prior, score_with_dist_prior + candA[i][2] + candB[j][2]])
+
+ connection_candidate = sorted(connection_candidate, key=lambda x: x[2], reverse=True)
+ connection = np.zeros((0, 5))
+ for c in range(len(connection_candidate)):
+ i, j, s = connection_candidate[c][0:3]
+ if (i not in connection[:, 3] and j not in connection[:, 4]):
+ connection = np.vstack([connection, [candA[i][3], candB[j][3], s, i, j]])
+ if (len(connection) >= min(nA, nB)):
+ break
+
+ connection_all.append(connection)
+ else:
+ special_k.append(k)
+ connection_all.append([])
+
+ # last number in each row is the total parts number of that person
+ # the second last number in each row is the score of the overall configuration
+ subset = -1 * np.ones((0, 20))
+ candidate = np.array([item for sublist in all_peaks for item in sublist])
+
+ for k in range(len(mapIdx)):
+ if k not in special_k:
+ partAs = connection_all[k][:, 0]
+ partBs = connection_all[k][:, 1]
+ indexA, indexB = np.array(limbSeq[k]) - 1
+
+ for i in range(len(connection_all[k])): # = 1:size(temp,1)
+ found = 0
+ subset_idx = [-1, -1]
+ for j in range(len(subset)): # 1:size(subset,1):
+ if subset[j][indexA] == partAs[i] or subset[j][indexB] == partBs[i]:
+ subset_idx[found] = j
+ found += 1
+
+ if found == 1:
+ j = subset_idx[0]
+ if subset[j][indexB] != partBs[i]:
+ subset[j][indexB] = partBs[i]
+ subset[j][-1] += 1
+ subset[j][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2]
+ elif found == 2: # if found 2 and disjoint, merge them
+ j1, j2 = subset_idx
+ membership = ((subset[j1] >= 0).astype(int) + (subset[j2] >= 0).astype(int))[:-2]
+ if len(np.nonzero(membership == 2)[0]) == 0: # merge
+ subset[j1][:-2] += (subset[j2][:-2] + 1)
+ subset[j1][-2:] += subset[j2][-2:]
+ subset[j1][-2] += connection_all[k][i][2]
+ subset = np.delete(subset, j2, 0)
+                        else: # same as the found == 1 case
+ subset[j1][indexB] = partBs[i]
+ subset[j1][-1] += 1
+ subset[j1][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2]
+
+                    # if no partA is found in any existing subset, create a new one
+ elif not found and k < 17:
+ row = -1 * np.ones(20)
+ row[indexA] = partAs[i]
+ row[indexB] = partBs[i]
+ row[-1] = 2
+ row[-2] = sum(candidate[connection_all[k][i, :2].astype(int), 2]) + connection_all[k][i][2]
+ subset = np.vstack([subset, row])
+        # delete rows of subset that have too few parts or a low average score
+ deleteIdx = []
+ for i in range(len(subset)):
+ if subset[i][-1] < 4 or subset[i][-2] / subset[i][-1] < 0.4:
+ deleteIdx.append(i)
+ subset = np.delete(subset, deleteIdx, axis=0)
+
+ # subset: n*20 array, 0-17 is the index in candidate, 18 is the total score, 19 is the total parts
+ # candidate: x, y, score, id
+ return candidate, subset
+
+if __name__ == "__main__":
+ body_estimation = Body('../model/body_pose_model.pth')
+
+ test_image = '../images/ski.jpg'
+ oriImg = cv2.imread(test_image) # B,G,R order
+ candidate, subset = body_estimation(oriImg)
+ canvas = util.draw_bodypose(oriImg, candidate, subset)
+ plt.imshow(canvas[:, :, [2, 1, 0]])
+ plt.show()
diff --git a/annotator/openpose/hand.py b/annotator/openpose/hand.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d0bf17165ad7eb225332b51f4a2aa16718664b2
--- /dev/null
+++ b/annotator/openpose/hand.py
@@ -0,0 +1,86 @@
+import cv2
+import json
+import numpy as np
+import math
+import time
+from scipy.ndimage import gaussian_filter  # scipy.ndimage.filters is deprecated
+import matplotlib.pyplot as plt
+import matplotlib
+import torch
+from skimage.measure import label
+
+from .model import handpose_model
+from . import util
+
+class Hand(object):
+ def __init__(self, model_path):
+ self.model = handpose_model()
+ if torch.cuda.is_available():
+ self.model = self.model.cuda()
+ print('cuda')
+ model_dict = util.transfer(self.model, torch.load(model_path))
+ self.model.load_state_dict(model_dict)
+ self.model.eval()
+
+ def __call__(self, oriImg):
+ scale_search = [0.5, 1.0, 1.5, 2.0]
+ # scale_search = [0.5]
+ boxsize = 368
+ stride = 8
+ padValue = 128
+ thre = 0.05
+ multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search]
+ heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 22))
+ # paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38))
+
+ for m in range(len(multiplier)):
+ scale = multiplier[m]
+ imageToTest = cv2.resize(oriImg, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
+ imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue)
+ im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5
+ im = np.ascontiguousarray(im)
+
+ data = torch.from_numpy(im).float()
+ if torch.cuda.is_available():
+ data = data.cuda()
+ # data = data.permute([2, 0, 1]).unsqueeze(0).float()
+ with torch.no_grad():
+ output = self.model(data).cpu().numpy()
+                # output = self.model(data).numpy()
+
+ # extract outputs, resize, and remove padding
+ heatmap = np.transpose(np.squeeze(output), (1, 2, 0)) # output 1 is heatmaps
+ heatmap = cv2.resize(heatmap, (0, 0), fx=stride, fy=stride, interpolation=cv2.INTER_CUBIC)
+ heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
+ heatmap = cv2.resize(heatmap, (oriImg.shape[1], oriImg.shape[0]), interpolation=cv2.INTER_CUBIC)
+
+ heatmap_avg += heatmap / len(multiplier)
+
+ all_peaks = []
+ for part in range(21):
+ map_ori = heatmap_avg[:, :, part]
+ one_heatmap = gaussian_filter(map_ori, sigma=3)
+ binary = np.ascontiguousarray(one_heatmap > thre, dtype=np.uint8)
+            # all values are below the threshold
+ if np.sum(binary) == 0:
+ all_peaks.append([0, 0])
+ continue
+ label_img, label_numbers = label(binary, return_num=True, connectivity=binary.ndim)
+ max_index = np.argmax([np.sum(map_ori[label_img == i]) for i in range(1, label_numbers + 1)]) + 1
+ label_img[label_img != max_index] = 0
+ map_ori[label_img == 0] = 0
+
+ y, x = util.npmax(map_ori)
+ all_peaks.append([x, y])
+ return np.array(all_peaks)
+
+if __name__ == "__main__":
+ hand_estimation = Hand('../model/hand_pose_model.pth')
+
+ test_image = '../images/hand.jpg'
+ oriImg = cv2.imread(test_image) # B,G,R order
+ peaks = hand_estimation(oriImg)
+ canvas = util.draw_handpose(oriImg, peaks, True)
+ cv2.imshow('', canvas)
+ cv2.waitKey(0)
\ No newline at end of file
diff --git a/annotator/openpose/model.py b/annotator/openpose/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..5dfc80de827a17beccb9b0f3f7588545be78c9de
--- /dev/null
+++ b/annotator/openpose/model.py
@@ -0,0 +1,219 @@
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+
+def make_layers(block, no_relu_layers):
+ layers = []
+ for layer_name, v in block.items():
+ if 'pool' in layer_name:
+ layer = nn.MaxPool2d(kernel_size=v[0], stride=v[1],
+ padding=v[2])
+ layers.append((layer_name, layer))
+ else:
+ conv2d = nn.Conv2d(in_channels=v[0], out_channels=v[1],
+ kernel_size=v[2], stride=v[3],
+ padding=v[4])
+ layers.append((layer_name, conv2d))
+ if layer_name not in no_relu_layers:
+ layers.append(('relu_'+layer_name, nn.ReLU(inplace=True)))
+
+ return nn.Sequential(OrderedDict(layers))
+
+class bodypose_model(nn.Module):
+ def __init__(self):
+ super(bodypose_model, self).__init__()
+
+ # these layers have no relu layer
+ no_relu_layers = ['conv5_5_CPM_L1', 'conv5_5_CPM_L2', 'Mconv7_stage2_L1',\
+ 'Mconv7_stage2_L2', 'Mconv7_stage3_L1', 'Mconv7_stage3_L2',\
+ 'Mconv7_stage4_L1', 'Mconv7_stage4_L2', 'Mconv7_stage5_L1',\
+                          'Mconv7_stage5_L2', 'Mconv7_stage6_L1', 'Mconv7_stage6_L2']
+ blocks = {}
+ block0 = OrderedDict([
+ ('conv1_1', [3, 64, 3, 1, 1]),
+ ('conv1_2', [64, 64, 3, 1, 1]),
+ ('pool1_stage1', [2, 2, 0]),
+ ('conv2_1', [64, 128, 3, 1, 1]),
+ ('conv2_2', [128, 128, 3, 1, 1]),
+ ('pool2_stage1', [2, 2, 0]),
+ ('conv3_1', [128, 256, 3, 1, 1]),
+ ('conv3_2', [256, 256, 3, 1, 1]),
+ ('conv3_3', [256, 256, 3, 1, 1]),
+ ('conv3_4', [256, 256, 3, 1, 1]),
+ ('pool3_stage1', [2, 2, 0]),
+ ('conv4_1', [256, 512, 3, 1, 1]),
+ ('conv4_2', [512, 512, 3, 1, 1]),
+ ('conv4_3_CPM', [512, 256, 3, 1, 1]),
+ ('conv4_4_CPM', [256, 128, 3, 1, 1])
+ ])
+
+
+ # Stage 1
+ block1_1 = OrderedDict([
+ ('conv5_1_CPM_L1', [128, 128, 3, 1, 1]),
+ ('conv5_2_CPM_L1', [128, 128, 3, 1, 1]),
+ ('conv5_3_CPM_L1', [128, 128, 3, 1, 1]),
+ ('conv5_4_CPM_L1', [128, 512, 1, 1, 0]),
+ ('conv5_5_CPM_L1', [512, 38, 1, 1, 0])
+ ])
+
+ block1_2 = OrderedDict([
+ ('conv5_1_CPM_L2', [128, 128, 3, 1, 1]),
+ ('conv5_2_CPM_L2', [128, 128, 3, 1, 1]),
+ ('conv5_3_CPM_L2', [128, 128, 3, 1, 1]),
+ ('conv5_4_CPM_L2', [128, 512, 1, 1, 0]),
+ ('conv5_5_CPM_L2', [512, 19, 1, 1, 0])
+ ])
+ blocks['block1_1'] = block1_1
+ blocks['block1_2'] = block1_2
+
+ self.model0 = make_layers(block0, no_relu_layers)
+
+ # Stages 2 - 6
+ for i in range(2, 7):
+ blocks['block%d_1' % i] = OrderedDict([
+ ('Mconv1_stage%d_L1' % i, [185, 128, 7, 1, 3]),
+ ('Mconv2_stage%d_L1' % i, [128, 128, 7, 1, 3]),
+ ('Mconv3_stage%d_L1' % i, [128, 128, 7, 1, 3]),
+ ('Mconv4_stage%d_L1' % i, [128, 128, 7, 1, 3]),
+ ('Mconv5_stage%d_L1' % i, [128, 128, 7, 1, 3]),
+ ('Mconv6_stage%d_L1' % i, [128, 128, 1, 1, 0]),
+ ('Mconv7_stage%d_L1' % i, [128, 38, 1, 1, 0])
+ ])
+
+ blocks['block%d_2' % i] = OrderedDict([
+ ('Mconv1_stage%d_L2' % i, [185, 128, 7, 1, 3]),
+ ('Mconv2_stage%d_L2' % i, [128, 128, 7, 1, 3]),
+ ('Mconv3_stage%d_L2' % i, [128, 128, 7, 1, 3]),
+ ('Mconv4_stage%d_L2' % i, [128, 128, 7, 1, 3]),
+ ('Mconv5_stage%d_L2' % i, [128, 128, 7, 1, 3]),
+ ('Mconv6_stage%d_L2' % i, [128, 128, 1, 1, 0]),
+ ('Mconv7_stage%d_L2' % i, [128, 19, 1, 1, 0])
+ ])
+
+ for k in blocks.keys():
+ blocks[k] = make_layers(blocks[k], no_relu_layers)
+
+ self.model1_1 = blocks['block1_1']
+ self.model2_1 = blocks['block2_1']
+ self.model3_1 = blocks['block3_1']
+ self.model4_1 = blocks['block4_1']
+ self.model5_1 = blocks['block5_1']
+ self.model6_1 = blocks['block6_1']
+
+ self.model1_2 = blocks['block1_2']
+ self.model2_2 = blocks['block2_2']
+ self.model3_2 = blocks['block3_2']
+ self.model4_2 = blocks['block4_2']
+ self.model5_2 = blocks['block5_2']
+ self.model6_2 = blocks['block6_2']
+
+
+ def forward(self, x):
+
+ out1 = self.model0(x)
+
+ out1_1 = self.model1_1(out1)
+ out1_2 = self.model1_2(out1)
+ out2 = torch.cat([out1_1, out1_2, out1], 1)
+
+ out2_1 = self.model2_1(out2)
+ out2_2 = self.model2_2(out2)
+ out3 = torch.cat([out2_1, out2_2, out1], 1)
+
+ out3_1 = self.model3_1(out3)
+ out3_2 = self.model3_2(out3)
+ out4 = torch.cat([out3_1, out3_2, out1], 1)
+
+ out4_1 = self.model4_1(out4)
+ out4_2 = self.model4_2(out4)
+ out5 = torch.cat([out4_1, out4_2, out1], 1)
+
+ out5_1 = self.model5_1(out5)
+ out5_2 = self.model5_2(out5)
+ out6 = torch.cat([out5_1, out5_2, out1], 1)
+
+ out6_1 = self.model6_1(out6)
+ out6_2 = self.model6_2(out6)
+
+ return out6_1, out6_2
+
+class handpose_model(nn.Module):
+ def __init__(self):
+ super(handpose_model, self).__init__()
+
+ # these layers have no relu layer
+ no_relu_layers = ['conv6_2_CPM', 'Mconv7_stage2', 'Mconv7_stage3',\
+ 'Mconv7_stage4', 'Mconv7_stage5', 'Mconv7_stage6']
+ # stage 1
+ block1_0 = OrderedDict([
+ ('conv1_1', [3, 64, 3, 1, 1]),
+ ('conv1_2', [64, 64, 3, 1, 1]),
+ ('pool1_stage1', [2, 2, 0]),
+ ('conv2_1', [64, 128, 3, 1, 1]),
+ ('conv2_2', [128, 128, 3, 1, 1]),
+ ('pool2_stage1', [2, 2, 0]),
+ ('conv3_1', [128, 256, 3, 1, 1]),
+ ('conv3_2', [256, 256, 3, 1, 1]),
+ ('conv3_3', [256, 256, 3, 1, 1]),
+ ('conv3_4', [256, 256, 3, 1, 1]),
+ ('pool3_stage1', [2, 2, 0]),
+ ('conv4_1', [256, 512, 3, 1, 1]),
+ ('conv4_2', [512, 512, 3, 1, 1]),
+ ('conv4_3', [512, 512, 3, 1, 1]),
+ ('conv4_4', [512, 512, 3, 1, 1]),
+ ('conv5_1', [512, 512, 3, 1, 1]),
+ ('conv5_2', [512, 512, 3, 1, 1]),
+ ('conv5_3_CPM', [512, 128, 3, 1, 1])
+ ])
+
+ block1_1 = OrderedDict([
+ ('conv6_1_CPM', [128, 512, 1, 1, 0]),
+ ('conv6_2_CPM', [512, 22, 1, 1, 0])
+ ])
+
+ blocks = {}
+ blocks['block1_0'] = block1_0
+ blocks['block1_1'] = block1_1
+
+ # stage 2-6
+ for i in range(2, 7):
+ blocks['block%d' % i] = OrderedDict([
+ ('Mconv1_stage%d' % i, [150, 128, 7, 1, 3]),
+ ('Mconv2_stage%d' % i, [128, 128, 7, 1, 3]),
+ ('Mconv3_stage%d' % i, [128, 128, 7, 1, 3]),
+ ('Mconv4_stage%d' % i, [128, 128, 7, 1, 3]),
+ ('Mconv5_stage%d' % i, [128, 128, 7, 1, 3]),
+ ('Mconv6_stage%d' % i, [128, 128, 1, 1, 0]),
+ ('Mconv7_stage%d' % i, [128, 22, 1, 1, 0])
+ ])
+
+ for k in blocks.keys():
+ blocks[k] = make_layers(blocks[k], no_relu_layers)
+
+ self.model1_0 = blocks['block1_0']
+ self.model1_1 = blocks['block1_1']
+ self.model2 = blocks['block2']
+ self.model3 = blocks['block3']
+ self.model4 = blocks['block4']
+ self.model5 = blocks['block5']
+ self.model6 = blocks['block6']
+
+ def forward(self, x):
+ out1_0 = self.model1_0(x)
+ out1_1 = self.model1_1(out1_0)
+ concat_stage2 = torch.cat([out1_1, out1_0], 1)
+ out_stage2 = self.model2(concat_stage2)
+ concat_stage3 = torch.cat([out_stage2, out1_0], 1)
+ out_stage3 = self.model3(concat_stage3)
+ concat_stage4 = torch.cat([out_stage3, out1_0], 1)
+ out_stage4 = self.model4(concat_stage4)
+ concat_stage5 = torch.cat([out_stage4, out1_0], 1)
+ out_stage5 = self.model5(concat_stage5)
+ concat_stage6 = torch.cat([out_stage5, out1_0], 1)
+ out_stage6 = self.model6(concat_stage6)
+ return out_stage6
+
+
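+# Shape sketch (illustrative only): with a 1x3x368x368 float input (H and W divisible by 8),
+# bodypose_model returns (PAFs, heatmaps) with 38 and 19 channels at 1/8 resolution, and
+# handpose_model returns a single 22-channel heatmap tensor at 1/8 resolution, e.g.:
+#   m = bodypose_model().eval()
+#   pafs, heatmaps = m(torch.zeros(1, 3, 368, 368))  # shapes [1, 38, 46, 46] and [1, 19, 46, 46]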
diff --git a/annotator/openpose/util.py b/annotator/openpose/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f91ae0e65abaf0cbd62d803f56498991141e61b
--- /dev/null
+++ b/annotator/openpose/util.py
@@ -0,0 +1,164 @@
+import math
+import numpy as np
+import matplotlib
+import cv2
+
+
+def padRightDownCorner(img, stride, padValue):
+ h = img.shape[0]
+ w = img.shape[1]
+
+ pad = 4 * [None]
+ pad[0] = 0 # up
+ pad[1] = 0 # left
+ pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down
+ pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right
+
+ img_padded = img
+ pad_up = np.tile(img_padded[0:1, :, :]*0 + padValue, (pad[0], 1, 1))
+ img_padded = np.concatenate((pad_up, img_padded), axis=0)
+ pad_left = np.tile(img_padded[:, 0:1, :]*0 + padValue, (1, pad[1], 1))
+ img_padded = np.concatenate((pad_left, img_padded), axis=1)
+ pad_down = np.tile(img_padded[-2:-1, :, :]*0 + padValue, (pad[2], 1, 1))
+ img_padded = np.concatenate((img_padded, pad_down), axis=0)
+ pad_right = np.tile(img_padded[:, -2:-1, :]*0 + padValue, (1, pad[3], 1))
+ img_padded = np.concatenate((img_padded, pad_right), axis=1)
+
+ return img_padded, pad
+
+# transfer the Caffe model weights to PyTorch, matching the layer names
+def transfer(model, model_weights):
+ transfered_model_weights = {}
+ for weights_name in model.state_dict().keys():
+ transfered_model_weights[weights_name] = model_weights['.'.join(weights_name.split('.')[1:])]
+ return transfered_model_weights
+
+# draw the body keypoints and limbs
+def draw_bodypose(canvas, candidate, subset):
+ stickwidth = 4
+ limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
+ [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
+ [1, 16], [16, 18], [3, 17], [6, 18]]
+
+ colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
+ [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
+ [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
+ for i in range(18):
+ for n in range(len(subset)):
+ index = int(subset[n][i])
+ if index == -1:
+ continue
+ x, y = candidate[index][0:2]
+ cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1)
+ for i in range(17):
+ for n in range(len(subset)):
+ index = subset[n][np.array(limbSeq[i]) - 1]
+ if -1 in index:
+ continue
+ cur_canvas = canvas.copy()
+ Y = candidate[index.astype(int), 0]
+ X = candidate[index.astype(int), 1]
+ mX = np.mean(X)
+ mY = np.mean(Y)
+ length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
+ angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
+ polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
+ cv2.fillConvexPoly(cur_canvas, polygon, colors[i])
+ canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)
+ # plt.imsave("preview.jpg", canvas[:, :, [2, 1, 0]])
+ # plt.imshow(canvas[:, :, [2, 1, 0]])
+ return canvas
+
+
+# note: the image drawn by OpenCV is not great.
+def draw_handpose(canvas, all_hand_peaks, show_number=False):
+ edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \
+ [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]]
+
+ for peaks in all_hand_peaks:
+ for ie, e in enumerate(edges):
+ if np.sum(np.all(peaks[e], axis=1)==0)==0:
+ x1, y1 = peaks[e[0]]
+ x2, y2 = peaks[e[1]]
+ cv2.line(canvas, (x1, y1), (x2, y2), matplotlib.colors.hsv_to_rgb([ie/float(len(edges)), 1.0, 1.0])*255, thickness=2)
+
+        for i, keypoint in enumerate(peaks):
+            x, y = keypoint
+ cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
+ if show_number:
+ cv2.putText(canvas, str(i), (x, y), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (0, 0, 0), lineType=cv2.LINE_AA)
+ return canvas
+
+# detect hand according to body pose keypoints
+# please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp
+def handDetect(candidate, subset, oriImg):
+ # right hand: wrist 4, elbow 3, shoulder 2
+ # left hand: wrist 7, elbow 6, shoulder 5
+ ratioWristElbow = 0.33
+ detect_result = []
+ image_height, image_width = oriImg.shape[0:2]
+ for person in subset.astype(int):
+ # if any of three not detected
+ has_left = np.sum(person[[5, 6, 7]] == -1) == 0
+ has_right = np.sum(person[[2, 3, 4]] == -1) == 0
+ if not (has_left or has_right):
+ continue
+ hands = []
+ #left hand
+ if has_left:
+ left_shoulder_index, left_elbow_index, left_wrist_index = person[[5, 6, 7]]
+ x1, y1 = candidate[left_shoulder_index][:2]
+ x2, y2 = candidate[left_elbow_index][:2]
+ x3, y3 = candidate[left_wrist_index][:2]
+ hands.append([x1, y1, x2, y2, x3, y3, True])
+ # right hand
+ if has_right:
+ right_shoulder_index, right_elbow_index, right_wrist_index = person[[2, 3, 4]]
+ x1, y1 = candidate[right_shoulder_index][:2]
+ x2, y2 = candidate[right_elbow_index][:2]
+ x3, y3 = candidate[right_wrist_index][:2]
+ hands.append([x1, y1, x2, y2, x3, y3, False])
+
+ for x1, y1, x2, y2, x3, y3, is_left in hands:
+            # pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbow) = (1 + ratio) * pos_wrist - ratio * pos_elbow
+ # handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]);
+ # handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]);
+ # const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow);
+ # const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder);
+ # handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder);
+ x = x3 + ratioWristElbow * (x3 - x2)
+ y = y3 + ratioWristElbow * (y3 - y2)
+ distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2)
+ distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
+ width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder)
+ # x-y refers to the center --> offset to topLeft point
+ # handRectangle.x -= handRectangle.width / 2.f;
+ # handRectangle.y -= handRectangle.height / 2.f;
+ x -= width / 2
+ y -= width / 2 # width = height
+            # clip the box so it does not overflow the image
+ if x < 0: x = 0
+ if y < 0: y = 0
+ width1 = width
+ width2 = width
+ if x + width > image_width: width1 = image_width - x
+ if y + width > image_height: width2 = image_height - y
+ width = min(width1, width2)
+            # discard hand boxes smaller than 20 pixels
+ if width >= 20:
+ detect_result.append([int(x), int(y), int(width), is_left])
+
+    '''
+    return value: [[x, y, w, True if left hand else False]].
+    width == height since the network requires a square input.
+    x, y are the coordinates of the top-left corner.
+    '''
+ return detect_result
+
+# get the (row, col) index of the maximum value of a 2d array
+def npmax(array):
+ arrayindex = array.argmax(1)
+ arrayvalue = array.max(1)
+ i = arrayvalue.argmax()
+ j = arrayindex[i]
+ return i, j
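+
+# Example (illustrative): for array = np.array([[0.1, 0.9], [0.4, 0.2]]), npmax(array)
+# returns (0, 1), i.e. the row and column of the global maximum.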
diff --git a/annotator/uniformer/LICENSE b/annotator/uniformer/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..c38dc639e6e238fbf59608f80b3a6ff1928ac429
--- /dev/null
+++ b/annotator/uniformer/LICENSE
@@ -0,0 +1,203 @@
+Copyright 2022 SenseTime X-Lab. All rights reserved.
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2022 SenseTime X-Lab.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
\ No newline at end of file
diff --git a/annotator/uniformer/__init__.py b/annotator/uniformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3364d40997447a4ec15ca7a525a4d0e92ab211bd
--- /dev/null
+++ b/annotator/uniformer/__init__.py
@@ -0,0 +1,27 @@
+# Uniformer
+# From https://github.com/Sense-X/UniFormer
+# Apache-2.0 license
+
+import os
+
+from annotator.uniformer.mmseg.apis import init_segmentor, inference_segmentor, show_result_pyplot
+from annotator.uniformer.mmseg.core.evaluation import get_palette
+from annotator.util import annotator_ckpts_path
+
+
+checkpoint_file = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/upernet_global_small.pth"
+
+
+class UniformerDetector:
+ def __init__(self):
+ modelpath = os.path.join(annotator_ckpts_path, "upernet_global_small.pth")
+ if not os.path.exists(modelpath):
+ from basicsr.utils.download_util import load_file_from_url
+ load_file_from_url(checkpoint_file, model_dir=annotator_ckpts_path)
+ config_file = os.path.join(os.path.dirname(annotator_ckpts_path), "uniformer", "exp", "upernet_global_small", "config.py")
+ self.model = init_segmentor(config_file, modelpath).cuda()
+
+ def __call__(self, img):
+ result = inference_segmentor(self.model, img)
+ res_img = show_result_pyplot(self.model, img, result, get_palette('ade'), opacity=1)
+ return res_img
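+
+# Usage sketch (an assumption, not part of the original file): requires a CUDA device, since
+# the segmentor is moved to .cuda(), and `img` should be an HxWx3 uint8 array accepted by
+# mmseg's inference_segmentor.
+#   detector = UniformerDetector()
+#   seg_map = detector(img)  # HxWx3 rendering of the segmentation with the ADE20K palette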
diff --git a/annotator/uniformer/configs/_base_/datasets/ade20k.py b/annotator/uniformer/configs/_base_/datasets/ade20k.py
new file mode 100644
index 0000000000000000000000000000000000000000..efc8b4bb20c981f3db6df7eb52b3dc0744c94cc0
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/datasets/ade20k.py
@@ -0,0 +1,54 @@
+# dataset settings
+dataset_type = 'ADE20KDataset'
+data_root = 'data/ade/ADEChallengeData2016'
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+crop_size = (512, 512)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations', reduce_zero_label=True),
+ dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img', 'gt_semantic_seg']),
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=(2048, 512),
+ # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img']),
+ ])
+]
+data = dict(
+ samples_per_gpu=4,
+ workers_per_gpu=4,
+ train=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='images/training',
+ ann_dir='annotations/training',
+ pipeline=train_pipeline),
+ val=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='images/validation',
+ ann_dir='annotations/validation',
+ pipeline=test_pipeline),
+ test=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='images/validation',
+ ann_dir='annotations/validation',
+ pipeline=test_pipeline))
diff --git a/annotator/uniformer/configs/_base_/datasets/chase_db1.py b/annotator/uniformer/configs/_base_/datasets/chase_db1.py
new file mode 100644
index 0000000000000000000000000000000000000000..298594ea925f87f22b37094a2ec50e370aec96a0
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/datasets/chase_db1.py
@@ -0,0 +1,59 @@
+# dataset settings
+dataset_type = 'ChaseDB1Dataset'
+data_root = 'data/CHASE_DB1'
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+img_scale = (960, 999)
+crop_size = (128, 128)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img', 'gt_semantic_seg'])
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=img_scale,
+ # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0],
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img'])
+ ])
+]
+
+data = dict(
+ samples_per_gpu=4,
+ workers_per_gpu=4,
+ train=dict(
+ type='RepeatDataset',
+ times=40000,
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='images/training',
+ ann_dir='annotations/training',
+ pipeline=train_pipeline)),
+ val=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='images/validation',
+ ann_dir='annotations/validation',
+ pipeline=test_pipeline),
+ test=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='images/validation',
+ ann_dir='annotations/validation',
+ pipeline=test_pipeline))
diff --git a/annotator/uniformer/configs/_base_/datasets/cityscapes.py b/annotator/uniformer/configs/_base_/datasets/cityscapes.py
new file mode 100644
index 0000000000000000000000000000000000000000..f21867c63e1835f6fceb61f066e802fd8fd2a735
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/datasets/cityscapes.py
@@ -0,0 +1,54 @@
+# dataset settings
+dataset_type = 'CityscapesDataset'
+data_root = 'data/cityscapes/'
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+crop_size = (512, 1024)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img', 'gt_semantic_seg']),
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=(2048, 1024),
+ # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img']),
+ ])
+]
+data = dict(
+ samples_per_gpu=2,
+ workers_per_gpu=2,
+ train=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='leftImg8bit/train',
+ ann_dir='gtFine/train',
+ pipeline=train_pipeline),
+ val=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='leftImg8bit/val',
+ ann_dir='gtFine/val',
+ pipeline=test_pipeline),
+ test=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='leftImg8bit/val',
+ ann_dir='gtFine/val',
+ pipeline=test_pipeline))
diff --git a/annotator/uniformer/configs/_base_/datasets/cityscapes_769x769.py b/annotator/uniformer/configs/_base_/datasets/cityscapes_769x769.py
new file mode 100644
index 0000000000000000000000000000000000000000..336c7b254fe392b4703039fec86a83acdbd2e1a5
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/datasets/cityscapes_769x769.py
@@ -0,0 +1,35 @@
+_base_ = './cityscapes.py'
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+crop_size = (769, 769)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='Resize', img_scale=(2049, 1025), ratio_range=(0.5, 2.0)),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img', 'gt_semantic_seg']),
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=(2049, 1025),
+ # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img']),
+ ])
+]
+data = dict(
+ train=dict(pipeline=train_pipeline),
+ val=dict(pipeline=test_pipeline),
+ test=dict(pipeline=test_pipeline))
diff --git a/annotator/uniformer/configs/_base_/datasets/drive.py b/annotator/uniformer/configs/_base_/datasets/drive.py
new file mode 100644
index 0000000000000000000000000000000000000000..06e8ff606e0d2a4514ec8b7d2c6c436a32efcbf4
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/datasets/drive.py
@@ -0,0 +1,59 @@
+# dataset settings
+dataset_type = 'DRIVEDataset'
+data_root = 'data/DRIVE'
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+img_scale = (584, 565)
+crop_size = (64, 64)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img', 'gt_semantic_seg'])
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=img_scale,
+ # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0],
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img'])
+ ])
+]
+
+data = dict(
+ samples_per_gpu=4,
+ workers_per_gpu=4,
+ train=dict(
+ type='RepeatDataset',
+ times=40000,
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='images/training',
+ ann_dir='annotations/training',
+ pipeline=train_pipeline)),
+ val=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='images/validation',
+ ann_dir='annotations/validation',
+ pipeline=test_pipeline),
+ test=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='images/validation',
+ ann_dir='annotations/validation',
+ pipeline=test_pipeline))
diff --git a/annotator/uniformer/configs/_base_/datasets/hrf.py b/annotator/uniformer/configs/_base_/datasets/hrf.py
new file mode 100644
index 0000000000000000000000000000000000000000..242d790eb1b83e75cf6b7eaa7a35c674099311ad
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/datasets/hrf.py
@@ -0,0 +1,59 @@
+# dataset settings
+dataset_type = 'HRFDataset'
+data_root = 'data/HRF'
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+img_scale = (2336, 3504)
+crop_size = (256, 256)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img', 'gt_semantic_seg'])
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=img_scale,
+ # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0],
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img'])
+ ])
+]
+
+data = dict(
+ samples_per_gpu=4,
+ workers_per_gpu=4,
+ train=dict(
+ type='RepeatDataset',
+ times=40000,
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='images/training',
+ ann_dir='annotations/training',
+ pipeline=train_pipeline)),
+ val=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='images/validation',
+ ann_dir='annotations/validation',
+ pipeline=test_pipeline),
+ test=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='images/validation',
+ ann_dir='annotations/validation',
+ pipeline=test_pipeline))
diff --git a/annotator/uniformer/configs/_base_/datasets/pascal_context.py b/annotator/uniformer/configs/_base_/datasets/pascal_context.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff65bad1b86d7e3a5980bb5b9fc55798dc8df5f4
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/datasets/pascal_context.py
@@ -0,0 +1,60 @@
+# dataset settings
+dataset_type = 'PascalContextDataset'
+data_root = 'data/VOCdevkit/VOC2010/'
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+
+img_scale = (520, 520)
+crop_size = (480, 480)
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img', 'gt_semantic_seg']),
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=img_scale,
+ # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img']),
+ ])
+]
+data = dict(
+ samples_per_gpu=4,
+ workers_per_gpu=4,
+ train=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='JPEGImages',
+ ann_dir='SegmentationClassContext',
+ split='ImageSets/SegmentationContext/train.txt',
+ pipeline=train_pipeline),
+ val=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='JPEGImages',
+ ann_dir='SegmentationClassContext',
+ split='ImageSets/SegmentationContext/val.txt',
+ pipeline=test_pipeline),
+ test=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='JPEGImages',
+ ann_dir='SegmentationClassContext',
+ split='ImageSets/SegmentationContext/val.txt',
+ pipeline=test_pipeline))
diff --git a/annotator/uniformer/configs/_base_/datasets/pascal_context_59.py b/annotator/uniformer/configs/_base_/datasets/pascal_context_59.py
new file mode 100644
index 0000000000000000000000000000000000000000..37585abab89834b95cd5bdd993b994fca1db65f6
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/datasets/pascal_context_59.py
@@ -0,0 +1,60 @@
+# dataset settings
+dataset_type = 'PascalContextDataset59'
+data_root = 'data/VOCdevkit/VOC2010/'
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+
+img_scale = (520, 520)
+crop_size = (480, 480)
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
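+ # reduce_zero_label maps label 0 (background) to the ignore index and shifts the
+ # remaining labels down by one, giving the 59-class PASCAL Context setting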
+ dict(type='LoadAnnotations', reduce_zero_label=True),
+ dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img', 'gt_semantic_seg']),
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=img_scale,
+ # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img']),
+ ])
+]
+data = dict(
+ samples_per_gpu=4,
+ workers_per_gpu=4,
+ train=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='JPEGImages',
+ ann_dir='SegmentationClassContext',
+ split='ImageSets/SegmentationContext/train.txt',
+ pipeline=train_pipeline),
+ val=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='JPEGImages',
+ ann_dir='SegmentationClassContext',
+ split='ImageSets/SegmentationContext/val.txt',
+ pipeline=test_pipeline),
+ test=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='JPEGImages',
+ ann_dir='SegmentationClassContext',
+ split='ImageSets/SegmentationContext/val.txt',
+ pipeline=test_pipeline))
diff --git a/annotator/uniformer/configs/_base_/datasets/pascal_voc12.py b/annotator/uniformer/configs/_base_/datasets/pascal_voc12.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba1d42d0c5781f56dc177d860d856bb34adce555
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/datasets/pascal_voc12.py
@@ -0,0 +1,57 @@
+# dataset settings
+dataset_type = 'PascalVOCDataset'
+data_root = 'data/VOCdevkit/VOC2012'
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+crop_size = (512, 512)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img', 'gt_semantic_seg']),
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=(2048, 512),
+ # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img']),
+ ])
+]
+data = dict(
+ samples_per_gpu=4,
+ workers_per_gpu=4,
+ train=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='JPEGImages',
+ ann_dir='SegmentationClass',
+ split='ImageSets/Segmentation/train.txt',
+ pipeline=train_pipeline),
+ val=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='JPEGImages',
+ ann_dir='SegmentationClass',
+ split='ImageSets/Segmentation/val.txt',
+ pipeline=test_pipeline),
+ test=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='JPEGImages',
+ ann_dir='SegmentationClass',
+ split='ImageSets/Segmentation/val.txt',
+ pipeline=test_pipeline))
diff --git a/annotator/uniformer/configs/_base_/datasets/pascal_voc12_aug.py b/annotator/uniformer/configs/_base_/datasets/pascal_voc12_aug.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f23b6717d53ad29f02dd15046802a2631a5076b
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/datasets/pascal_voc12_aug.py
@@ -0,0 +1,9 @@
+_base_ = './pascal_voc12.py'
+# dataset settings
+data = dict(
+ train=dict(
+ ann_dir=['SegmentationClass', 'SegmentationClassAug'],
+ split=[
+ 'ImageSets/Segmentation/train.txt',
+ 'ImageSets/Segmentation/aug.txt'
+ ]))
diff --git a/annotator/uniformer/configs/_base_/datasets/stare.py b/annotator/uniformer/configs/_base_/datasets/stare.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f71b25488cc11a6b4d582ac52b5a24e1ad1cf8e
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/datasets/stare.py
@@ -0,0 +1,59 @@
+# dataset settings
+dataset_type = 'STAREDataset'
+data_root = 'data/STARE'
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+img_scale = (605, 700)
+crop_size = (128, 128)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img', 'gt_semantic_seg'])
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=img_scale,
+ # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0],
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img'])
+ ])
+]
+
+data = dict(
+ samples_per_gpu=4,
+ workers_per_gpu=4,
+ train=dict(
+ type='RepeatDataset',
+ times=40000,
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='images/training',
+ ann_dir='annotations/training',
+ pipeline=train_pipeline)),
+ val=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='images/validation',
+ ann_dir='annotations/validation',
+ pipeline=test_pipeline),
+ test=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='images/validation',
+ ann_dir='annotations/validation',
+ pipeline=test_pipeline))
diff --git a/annotator/uniformer/configs/_base_/default_runtime.py b/annotator/uniformer/configs/_base_/default_runtime.py
new file mode 100644
index 0000000000000000000000000000000000000000..b564cc4e7e7d9a67dacaaddecb100e4d8f5c005b
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/default_runtime.py
@@ -0,0 +1,14 @@
+# yapf:disable
+log_config = dict(
+ interval=50,
+ hooks=[
+ dict(type='TextLoggerHook', by_epoch=False),
+ # dict(type='TensorboardLoggerHook')
+ ])
+# yapf:enable
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
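+# cudnn_benchmark lets cuDNN auto-tune convolution kernels; safe here because the
+# training pipelines feed fixed-size crops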
+cudnn_benchmark = True
diff --git a/annotator/uniformer/configs/_base_/models/ann_r50-d8.py b/annotator/uniformer/configs/_base_/models/ann_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2cb653827e44e6015b3b83bc578003e614a6aa1
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/ann_r50-d8.py
@@ -0,0 +1,46 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 2, 4),
+ strides=(1, 2, 1, 1),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ decode_head=dict(
+ type='ANNHead',
+ in_channels=[1024, 2048],
+ in_index=[2, 3],
+ channels=512,
+ project_channels=256,
+ query_scales=(1, ),
+ key_pool_scales=(1, 3, 6, 8),
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=1024,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/annotator/uniformer/configs/_base_/models/apcnet_r50-d8.py b/annotator/uniformer/configs/_base_/models/apcnet_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8f5316cbcf3896ba9de7ca2c801eba512f01d5e
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/apcnet_r50-d8.py
@@ -0,0 +1,44 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 2, 4),
+ strides=(1, 2, 1, 1),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ decode_head=dict(
+ type='APCHead',
+ in_channels=2048,
+ in_index=3,
+ channels=512,
+ pool_scales=(1, 2, 3, 6),
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=1024,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/annotator/uniformer/configs/_base_/models/ccnet_r50-d8.py b/annotator/uniformer/configs/_base_/models/ccnet_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..794148f576b9e215c3c6963e73dffe98204b7717
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/ccnet_r50-d8.py
@@ -0,0 +1,44 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 2, 4),
+ strides=(1, 2, 1, 1),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ decode_head=dict(
+ type='CCHead',
+ in_channels=2048,
+ in_index=3,
+ channels=512,
+ recurrence=2,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=1024,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/annotator/uniformer/configs/_base_/models/cgnet.py b/annotator/uniformer/configs/_base_/models/cgnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..eff8d9458c877c5db894957e0b1b4597e40da6ab
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/cgnet.py
@@ -0,0 +1,35 @@
+# model settings
+norm_cfg = dict(type='SyncBN', eps=1e-03, requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ backbone=dict(
+ type='CGNet',
+ norm_cfg=norm_cfg,
+ in_channels=3,
+ num_channels=(32, 64, 128),
+ num_blocks=(3, 21),
+ dilations=(2, 4),
+ reductions=(8, 16)),
+ decode_head=dict(
+ type='FCNHead',
+ in_channels=256,
+ in_index=2,
+ channels=256,
+ num_convs=0,
+ concat_input=False,
+ dropout_ratio=0,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ loss_decode=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=False,
+ loss_weight=1.0,
+ class_weight=[
+ 2.5959933, 6.7415504, 3.5354059, 9.8663225, 9.690899, 9.369352,
+ 10.289121, 9.953208, 4.3097677, 9.490387, 7.674431, 9.396905,
+ 10.347791, 6.3927646, 10.226669, 10.241062, 10.280587,
+ 10.396974, 10.055647
+ ])),
+ # model training and testing settings
+ train_cfg=dict(sampler=None),
+ test_cfg=dict(mode='whole'))
diff --git a/annotator/uniformer/configs/_base_/models/danet_r50-d8.py b/annotator/uniformer/configs/_base_/models/danet_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c934939fac48525f22ad86f489a041dd7db7d09
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/danet_r50-d8.py
@@ -0,0 +1,44 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 2, 4),
+ strides=(1, 2, 1, 1),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ decode_head=dict(
+ type='DAHead',
+ in_channels=2048,
+ in_index=3,
+ channels=512,
+ pam_channels=64,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=1024,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/annotator/uniformer/configs/_base_/models/deeplabv3_r50-d8.py b/annotator/uniformer/configs/_base_/models/deeplabv3_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7a43bee01422ad4795dd27874e0cd4bb6cbfecf
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/deeplabv3_r50-d8.py
@@ -0,0 +1,44 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 2, 4),
+ strides=(1, 2, 1, 1),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ decode_head=dict(
+ type='ASPPHead',
+ in_channels=2048,
+ in_index=3,
+ channels=512,
+ dilations=(1, 12, 24, 36),
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=1024,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/annotator/uniformer/configs/_base_/models/deeplabv3_unet_s5-d16.py b/annotator/uniformer/configs/_base_/models/deeplabv3_unet_s5-d16.py
new file mode 100644
index 0000000000000000000000000000000000000000..0cd262999d8b2cb8e14a5c32190ae73f479d8e81
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/deeplabv3_unet_s5-d16.py
@@ -0,0 +1,50 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained=None,
+ backbone=dict(
+ type='UNet',
+ in_channels=3,
+ base_channels=64,
+ num_stages=5,
+ strides=(1, 1, 1, 1, 1),
+ enc_num_convs=(2, 2, 2, 2, 2),
+ dec_num_convs=(2, 2, 2, 2),
+ downsamples=(True, True, True, True),
+ enc_dilations=(1, 1, 1, 1, 1),
+ dec_dilations=(1, 1, 1, 1),
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=norm_cfg,
+ act_cfg=dict(type='ReLU'),
+ upsample_cfg=dict(type='InterpConv'),
+ norm_eval=False),
+ decode_head=dict(
+ type='ASPPHead',
+ in_channels=64,
+ in_index=4,
+ channels=16,
+ dilations=(1, 12, 24, 36),
+ dropout_ratio=0.1,
+ num_classes=2,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=128,
+ in_index=3,
+ channels=64,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=2,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='slide', crop_size=256, stride=170))
diff --git a/annotator/uniformer/configs/_base_/models/deeplabv3plus_r50-d8.py b/annotator/uniformer/configs/_base_/models/deeplabv3plus_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..050e39e091d816df9028d23aa3ecf9db74e441e1
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/deeplabv3plus_r50-d8.py
@@ -0,0 +1,46 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 2, 4),
+ strides=(1, 2, 1, 1),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ decode_head=dict(
+ type='DepthwiseSeparableASPPHead',
+ in_channels=2048,
+ in_index=3,
+ channels=512,
+ dilations=(1, 12, 24, 36),
+ c1_in_channels=256,
+ c1_channels=48,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=1024,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/annotator/uniformer/configs/_base_/models/dmnet_r50-d8.py b/annotator/uniformer/configs/_base_/models/dmnet_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..d22ba52640bebd805b3b8d07025e276dfb023759
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/dmnet_r50-d8.py
@@ -0,0 +1,44 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 2, 4),
+ strides=(1, 2, 1, 1),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ decode_head=dict(
+ type='DMHead',
+ in_channels=2048,
+ in_index=3,
+ channels=512,
+ filter_sizes=(1, 3, 5, 7),
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=1024,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/annotator/uniformer/configs/_base_/models/dnl_r50-d8.py b/annotator/uniformer/configs/_base_/models/dnl_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..edb4c174c51e34c103737ba39bfc48bf831e561d
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/dnl_r50-d8.py
@@ -0,0 +1,46 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 2, 4),
+ strides=(1, 2, 1, 1),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ decode_head=dict(
+ type='DNLHead',
+ in_channels=2048,
+ in_index=3,
+ channels=512,
+ dropout_ratio=0.1,
+ reduction=2,
+ use_scale=True,
+ mode='embedded_gaussian',
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=1024,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/annotator/uniformer/configs/_base_/models/emanet_r50-d8.py b/annotator/uniformer/configs/_base_/models/emanet_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..26adcd430926de0862204a71d345f2543167f27b
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/emanet_r50-d8.py
@@ -0,0 +1,47 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 2, 4),
+ strides=(1, 2, 1, 1),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ decode_head=dict(
+ type='EMAHead',
+ in_channels=2048,
+ in_index=3,
+ channels=256,
+ ema_channels=512,
+ num_bases=64,
+ num_stages=3,
+ momentum=0.1,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=1024,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/annotator/uniformer/configs/_base_/models/encnet_r50-d8.py b/annotator/uniformer/configs/_base_/models/encnet_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..be777123a886503172a95fe0719e956a147bbd68
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/encnet_r50-d8.py
@@ -0,0 +1,48 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 2, 4),
+ strides=(1, 2, 1, 1),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ decode_head=dict(
+ type='EncHead',
+ in_channels=[512, 1024, 2048],
+ in_index=(1, 2, 3),
+ channels=512,
+ num_codes=32,
+ use_se_loss=True,
+ add_lateral=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+ loss_se_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.2)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=1024,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/annotator/uniformer/configs/_base_/models/fast_scnn.py b/annotator/uniformer/configs/_base_/models/fast_scnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..32fdeb659355a5ce5ef2cc7c2f30742703811cdf
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/fast_scnn.py
@@ -0,0 +1,57 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True, momentum=0.01)
+model = dict(
+ type='EncoderDecoder',
+ backbone=dict(
+ type='FastSCNN',
+ downsample_dw_channels=(32, 48),
+ global_in_channels=64,
+ global_block_channels=(64, 96, 128),
+ global_block_strides=(2, 2, 1),
+ global_out_channels=128,
+ higher_in_channels=64,
+ lower_in_channels=128,
+ fusion_out_channels=128,
+ out_indices=(0, 1, 2),
+ norm_cfg=norm_cfg,
+ align_corners=False),
+ decode_head=dict(
+ type='DepthwiseSeparableFCNHead',
+ in_channels=128,
+ channels=128,
+ concat_input=False,
+ num_classes=19,
+ in_index=-1,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.4)),
+ auxiliary_head=[
+ dict(
+ type='FCNHead',
+ in_channels=128,
+ channels=32,
+ num_convs=1,
+ num_classes=19,
+ in_index=-2,
+ norm_cfg=norm_cfg,
+ concat_input=False,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.4)),
+ dict(
+ type='FCNHead',
+ in_channels=64,
+ channels=32,
+ num_convs=1,
+ num_classes=19,
+ in_index=-3,
+ norm_cfg=norm_cfg,
+ concat_input=False,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.4)),
+ ],
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/annotator/uniformer/configs/_base_/models/fcn_hr18.py b/annotator/uniformer/configs/_base_/models/fcn_hr18.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3e299bc89ada56ca14bbffcbdb08a586b8ed9e9
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/fcn_hr18.py
@@ -0,0 +1,52 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://msra/hrnetv2_w18',
+ backbone=dict(
+ type='HRNet',
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ extra=dict(
+ stage1=dict(
+ num_modules=1,
+ num_branches=1,
+ block='BOTTLENECK',
+ num_blocks=(4, ),
+ num_channels=(64, )),
+ stage2=dict(
+ num_modules=1,
+ num_branches=2,
+ block='BASIC',
+ num_blocks=(4, 4),
+ num_channels=(18, 36)),
+ stage3=dict(
+ num_modules=4,
+ num_branches=3,
+ block='BASIC',
+ num_blocks=(4, 4, 4),
+ num_channels=(18, 36, 72)),
+ stage4=dict(
+ num_modules=3,
+ num_branches=4,
+ block='BASIC',
+ num_blocks=(4, 4, 4, 4),
+ num_channels=(18, 36, 72, 144)))),
+ decode_head=dict(
+ type='FCNHead',
+ in_channels=[18, 36, 72, 144],
+ in_index=(0, 1, 2, 3),
+ channels=sum([18, 36, 72, 144]),
+ input_transform='resize_concat',
+ kernel_size=1,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=-1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/annotator/uniformer/configs/_base_/models/fcn_r50-d8.py b/annotator/uniformer/configs/_base_/models/fcn_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e98f6cc918b6146fc6d613c6918e825ef1355c3
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/fcn_r50-d8.py
@@ -0,0 +1,45 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 2, 4),
+ strides=(1, 2, 1, 1),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ decode_head=dict(
+ type='FCNHead',
+ in_channels=2048,
+ in_index=3,
+ channels=512,
+ num_convs=2,
+ concat_input=True,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=1024,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/annotator/uniformer/configs/_base_/models/fcn_unet_s5-d16.py b/annotator/uniformer/configs/_base_/models/fcn_unet_s5-d16.py
new file mode 100644
index 0000000000000000000000000000000000000000..a33e7972877f902d0e7d18401ca675e3e4e60a18
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/fcn_unet_s5-d16.py
@@ -0,0 +1,51 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained=None,
+ backbone=dict(
+ type='UNet',
+ in_channels=3,
+ base_channels=64,
+ num_stages=5,
+ strides=(1, 1, 1, 1, 1),
+ enc_num_convs=(2, 2, 2, 2, 2),
+ dec_num_convs=(2, 2, 2, 2),
+ downsamples=(True, True, True, True),
+ enc_dilations=(1, 1, 1, 1, 1),
+ dec_dilations=(1, 1, 1, 1),
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=norm_cfg,
+ act_cfg=dict(type='ReLU'),
+ upsample_cfg=dict(type='InterpConv'),
+ norm_eval=False),
+ decode_head=dict(
+ type='FCNHead',
+ in_channels=64,
+ in_index=4,
+ channels=64,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=2,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=128,
+ in_index=3,
+ channels=64,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=2,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='slide', crop_size=256, stride=170))
diff --git a/annotator/uniformer/configs/_base_/models/fpn_r50.py b/annotator/uniformer/configs/_base_/models/fpn_r50.py
new file mode 100644
index 0000000000000000000000000000000000000000..86ab327db92e44c14822d65f1c9277cb007f17c1
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/fpn_r50.py
@@ -0,0 +1,36 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 1, 1),
+ strides=(1, 2, 2, 2),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ neck=dict(
+ type='FPN',
+ in_channels=[256, 512, 1024, 2048],
+ out_channels=256,
+ num_outs=4),
+ decode_head=dict(
+ type='FPNHead',
+ in_channels=[256, 256, 256, 256],
+ in_index=[0, 1, 2, 3],
+ feature_strides=[4, 8, 16, 32],
+ channels=128,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/annotator/uniformer/configs/_base_/models/fpn_uniformer.py b/annotator/uniformer/configs/_base_/models/fpn_uniformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..8aae98c5991055bfcc08e82ccdc09f8b1d9f8a8d
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/fpn_uniformer.py
@@ -0,0 +1,35 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ backbone=dict(
+ type='UniFormer',
+ embed_dim=[64, 128, 320, 512],
+ layers=[3, 4, 8, 3],
+ head_dim=64,
+ mlp_ratio=4.,
+ qkv_bias=True,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.1),
+ neck=dict(
+ type='FPN',
+ in_channels=[64, 128, 320, 512],
+ out_channels=256,
+ num_outs=4),
+ decode_head=dict(
+ type='FPNHead',
+ in_channels=[256, 256, 256, 256],
+ in_index=[0, 1, 2, 3],
+ feature_strides=[4, 8, 16, 32],
+ channels=128,
+ dropout_ratio=0.1,
+ num_classes=150,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole')
+)
diff --git a/annotator/uniformer/configs/_base_/models/gcnet_r50-d8.py b/annotator/uniformer/configs/_base_/models/gcnet_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d2ad69f5c22adfe79d5fdabf920217628987166
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/gcnet_r50-d8.py
@@ -0,0 +1,46 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 2, 4),
+ strides=(1, 2, 1, 1),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ decode_head=dict(
+ type='GCHead',
+ in_channels=2048,
+ in_index=3,
+ channels=512,
+ ratio=1 / 4.,
+ pooling_type='att',
+ fusion_types=('channel_add', ),
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=1024,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/annotator/uniformer/configs/_base_/models/lraspp_m-v3-d8.py b/annotator/uniformer/configs/_base_/models/lraspp_m-v3-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..93258242a90695cc94a7c6bd41562d6a75988771
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/lraspp_m-v3-d8.py
@@ -0,0 +1,25 @@
+# model settings
+norm_cfg = dict(type='SyncBN', eps=0.001, requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ backbone=dict(
+ type='MobileNetV3',
+ arch='large',
+ out_indices=(1, 3, 16),
+ norm_cfg=norm_cfg),
+ decode_head=dict(
+ type='LRASPPHead',
+ in_channels=(16, 24, 960),
+ in_index=(0, 1, 2),
+ channels=128,
+ input_transform='multiple_select',
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ act_cfg=dict(type='ReLU'),
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/annotator/uniformer/configs/_base_/models/nonlocal_r50-d8.py b/annotator/uniformer/configs/_base_/models/nonlocal_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..5674a39854cafd1f2e363bac99c58ccae62f24da
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/nonlocal_r50-d8.py
@@ -0,0 +1,46 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 2, 4),
+ strides=(1, 2, 1, 1),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ decode_head=dict(
+ type='NLHead',
+ in_channels=2048,
+ in_index=3,
+ channels=512,
+ dropout_ratio=0.1,
+ reduction=2,
+ use_scale=True,
+ mode='embedded_gaussian',
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=1024,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/annotator/uniformer/configs/_base_/models/ocrnet_hr18.py b/annotator/uniformer/configs/_base_/models/ocrnet_hr18.py
new file mode 100644
index 0000000000000000000000000000000000000000..c60f62a7cdf3f5c5096a7a7e725e8268fddcb057
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/ocrnet_hr18.py
@@ -0,0 +1,68 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='CascadeEncoderDecoder',
+ num_stages=2,
+ pretrained='open-mmlab://msra/hrnetv2_w18',
+ backbone=dict(
+ type='HRNet',
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ extra=dict(
+ stage1=dict(
+ num_modules=1,
+ num_branches=1,
+ block='BOTTLENECK',
+ num_blocks=(4, ),
+ num_channels=(64, )),
+ stage2=dict(
+ num_modules=1,
+ num_branches=2,
+ block='BASIC',
+ num_blocks=(4, 4),
+ num_channels=(18, 36)),
+ stage3=dict(
+ num_modules=4,
+ num_branches=3,
+ block='BASIC',
+ num_blocks=(4, 4, 4),
+ num_channels=(18, 36, 72)),
+ stage4=dict(
+ num_modules=3,
+ num_branches=4,
+ block='BASIC',
+ num_blocks=(4, 4, 4, 4),
+ num_channels=(18, 36, 72, 144)))),
+ decode_head=[
+ dict(
+ type='FCNHead',
+ in_channels=[18, 36, 72, 144],
+ channels=sum([18, 36, 72, 144]),
+ in_index=(0, 1, 2, 3),
+ input_transform='resize_concat',
+ kernel_size=1,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=-1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ dict(
+ type='OCRHead',
+ in_channels=[18, 36, 72, 144],
+ in_index=(0, 1, 2, 3),
+ input_transform='resize_concat',
+ channels=512,
+ ocr_channels=256,
+ dropout_ratio=-1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ ],
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/annotator/uniformer/configs/_base_/models/ocrnet_r50-d8.py b/annotator/uniformer/configs/_base_/models/ocrnet_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..615aa3ff703942b6c22b2d6e9642504dd3e41ebd
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/ocrnet_r50-d8.py
@@ -0,0 +1,47 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='CascadeEncoderDecoder',
+ num_stages=2,
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 2, 4),
+ strides=(1, 2, 1, 1),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ decode_head=[
+ dict(
+ type='FCNHead',
+ in_channels=1024,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ dict(
+ type='OCRHead',
+ in_channels=2048,
+ in_index=3,
+ channels=512,
+ ocr_channels=256,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))
+ ],
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/annotator/uniformer/configs/_base_/models/pointrend_r50.py b/annotator/uniformer/configs/_base_/models/pointrend_r50.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d323dbf9466d41e0800aa57ef84045f3d874bdf
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/pointrend_r50.py
@@ -0,0 +1,56 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='CascadeEncoderDecoder',
+ num_stages=2,
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 1, 1),
+ strides=(1, 2, 2, 2),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ neck=dict(
+ type='FPN',
+ in_channels=[256, 512, 1024, 2048],
+ out_channels=256,
+ num_outs=4),
+ decode_head=[
+ dict(
+ type='FPNHead',
+ in_channels=[256, 256, 256, 256],
+ in_index=[0, 1, 2, 3],
+ feature_strides=[4, 8, 16, 32],
+ channels=128,
+ dropout_ratio=-1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ dict(
+ type='PointHead',
+ in_channels=[256],
+ in_index=[0],
+ channels=256,
+ num_fcs=3,
+ coarse_pred_each_layer=True,
+ dropout_ratio=-1,
+ num_classes=19,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))
+ ],
+ # model training and testing settings
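+ # PointRend samples `num_points` uncertain points per image during training and,
+ # at test time, iteratively refines the coarse mask over `subdivision_steps`
+ # upsampling steps, re-predicting `subdivision_num_points` points at each step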
+ train_cfg=dict(
+ num_points=2048, oversample_ratio=3, importance_sample_ratio=0.75),
+ test_cfg=dict(
+ mode='whole',
+ subdivision_steps=2,
+ subdivision_num_points=8196,
+ scale_factor=2))
diff --git a/annotator/uniformer/configs/_base_/models/psanet_r50-d8.py b/annotator/uniformer/configs/_base_/models/psanet_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..689513fa9d2a40f14bf0ae4ae61f38f0dcc1b3da
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/psanet_r50-d8.py
@@ -0,0 +1,49 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 2, 4),
+ strides=(1, 2, 1, 1),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ decode_head=dict(
+ type='PSAHead',
+ in_channels=2048,
+ in_index=3,
+ channels=512,
+ mask_size=(97, 97),
+ psa_type='bi-direction',
+ compact=False,
+ shrink_factor=2,
+ normalization_factor=1.0,
+ psa_softmax=True,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=1024,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/annotator/uniformer/configs/_base_/models/pspnet_r50-d8.py b/annotator/uniformer/configs/_base_/models/pspnet_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..f451e08ad2eb0732dcb806b1851eb978d4acf136
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/pspnet_r50-d8.py
@@ -0,0 +1,44 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 2, 4),
+ strides=(1, 2, 1, 1),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ decode_head=dict(
+ type='PSPHead',
+ in_channels=2048,
+ in_index=3,
+ channels=512,
+ pool_scales=(1, 2, 3, 6),
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=1024,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/annotator/uniformer/configs/_base_/models/pspnet_unet_s5-d16.py b/annotator/uniformer/configs/_base_/models/pspnet_unet_s5-d16.py
new file mode 100644
index 0000000000000000000000000000000000000000..fcff9ec4f41fad158344ecd77313dc14564f3682
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/pspnet_unet_s5-d16.py
@@ -0,0 +1,50 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained=None,
+ backbone=dict(
+ type='UNet',
+ in_channels=3,
+ base_channels=64,
+ num_stages=5,
+ strides=(1, 1, 1, 1, 1),
+ enc_num_convs=(2, 2, 2, 2, 2),
+ dec_num_convs=(2, 2, 2, 2),
+ downsamples=(True, True, True, True),
+ enc_dilations=(1, 1, 1, 1, 1),
+ dec_dilations=(1, 1, 1, 1),
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=norm_cfg,
+ act_cfg=dict(type='ReLU'),
+ upsample_cfg=dict(type='InterpConv'),
+ norm_eval=False),
+ decode_head=dict(
+ type='PSPHead',
+ in_channels=64,
+ in_index=4,
+ channels=16,
+ pool_scales=(1, 2, 3, 6),
+ dropout_ratio=0.1,
+ num_classes=2,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=128,
+ in_index=3,
+ channels=64,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=2,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='slide', crop_size=256, stride=170))
diff --git a/annotator/uniformer/configs/_base_/models/upernet_r50.py b/annotator/uniformer/configs/_base_/models/upernet_r50.py
new file mode 100644
index 0000000000000000000000000000000000000000..10974962fdd7136031fd06de1700f497d355ceaa
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/upernet_r50.py
@@ -0,0 +1,44 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 1, 1),
+ strides=(1, 2, 2, 2),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ decode_head=dict(
+ type='UPerHead',
+ in_channels=[256, 512, 1024, 2048],
+ in_index=[0, 1, 2, 3],
+ pool_scales=(1, 2, 3, 6),
+ channels=512,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=1024,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/annotator/uniformer/configs/_base_/models/upernet_uniformer.py b/annotator/uniformer/configs/_base_/models/upernet_uniformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..41aa4db809dc6e2c508e98051f61807d07477903
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/upernet_uniformer.py
@@ -0,0 +1,43 @@
+# model settings
+norm_cfg = dict(type='BN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained=None,
+ backbone=dict(
+ type='UniFormer',
+ embed_dim=[64, 128, 320, 512],
+ layers=[3, 4, 8, 3],
+ head_dim=64,
+ mlp_ratio=4.,
+ qkv_bias=True,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.1),
+ decode_head=dict(
+ type='UPerHead',
+ in_channels=[64, 128, 320, 512],
+ in_index=[0, 1, 2, 3],
+ pool_scales=(1, 2, 3, 6),
+ channels=512,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=320,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
\ No newline at end of file
diff --git a/annotator/uniformer/configs/_base_/schedules/schedule_160k.py b/annotator/uniformer/configs/_base_/schedules/schedule_160k.py
new file mode 100644
index 0000000000000000000000000000000000000000..52603890b10f25faf8eec9f9e5a4468fae09b811
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/schedules/schedule_160k.py
@@ -0,0 +1,9 @@
+# optimizer
+optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
+optimizer_config = dict()
+# learning policy
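+# poly schedule: lr decays from the base lr towards min_lr as (1 - iter/max_iters)**power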
+lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
+# runtime settings
+runner = dict(type='IterBasedRunner', max_iters=160000)
+checkpoint_config = dict(by_epoch=False, interval=16000)
+evaluation = dict(interval=16000, metric='mIoU')
diff --git a/annotator/uniformer/configs/_base_/schedules/schedule_20k.py b/annotator/uniformer/configs/_base_/schedules/schedule_20k.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf780a1b6f6521833c6a5859675147824efa599d
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/schedules/schedule_20k.py
@@ -0,0 +1,9 @@
+# optimizer
+optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
+optimizer_config = dict()
+# learning policy
+lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
+# runtime settings
+runner = dict(type='IterBasedRunner', max_iters=20000)
+checkpoint_config = dict(by_epoch=False, interval=2000)
+evaluation = dict(interval=2000, metric='mIoU')
diff --git a/annotator/uniformer/configs/_base_/schedules/schedule_40k.py b/annotator/uniformer/configs/_base_/schedules/schedule_40k.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdbf841abcb26eed87bf76ab816aff4bae0630ee
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/schedules/schedule_40k.py
@@ -0,0 +1,9 @@
+# optimizer
+optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
+optimizer_config = dict()
+# learning policy
+lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
+# runtime settings
+runner = dict(type='IterBasedRunner', max_iters=40000)
+checkpoint_config = dict(by_epoch=False, interval=4000)
+evaluation = dict(interval=4000, metric='mIoU')
diff --git a/annotator/uniformer/configs/_base_/schedules/schedule_80k.py b/annotator/uniformer/configs/_base_/schedules/schedule_80k.py
new file mode 100644
index 0000000000000000000000000000000000000000..c190cee6bdc7922b688ea75dc8f152fa15c24617
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/schedules/schedule_80k.py
@@ -0,0 +1,9 @@
+# optimizer
+optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
+optimizer_config = dict()
+# learning policy
+lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
+# runtime settings
+runner = dict(type='IterBasedRunner', max_iters=80000)
+checkpoint_config = dict(by_epoch=False, interval=8000)
+evaluation = dict(interval=8000, metric='mIoU')
diff --git a/annotator/uniformer/exp/upernet_global_small/config.py b/annotator/uniformer/exp/upernet_global_small/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..01db96bf9b0be531aa0eaf62fee51543712f8670
--- /dev/null
+++ b/annotator/uniformer/exp/upernet_global_small/config.py
@@ -0,0 +1,38 @@
+_base_ = [
+ '../../configs/_base_/models/upernet_uniformer.py',
+ '../../configs/_base_/datasets/ade20k.py',
+ '../../configs/_base_/default_runtime.py',
+ '../../configs/_base_/schedules/schedule_160k.py'
+]
+model = dict(
+ backbone=dict(
+ type='UniFormer',
+ embed_dim=[64, 128, 320, 512],
+ layers=[3, 4, 8, 3],
+ head_dim=64,
+ drop_path_rate=0.25,
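+ # windows/hybrid select the attention pattern of the deeper UniFormer stages:
+ # both False -> global self-attention (this "global_small" config); the sibling
+ # test_config_w32 / test_config_h32 files enable window and hybrid attention
+ # with window_size=32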
+ windows=False,
+ hybrid=False
+ ),
+ decode_head=dict(
+ in_channels=[64, 128, 320, 512],
+ num_classes=150
+ ),
+ auxiliary_head=dict(
+ in_channels=320,
+ num_classes=150
+ ))
+
+# AdamW optimizer, no weight decay for position embedding & layer norm in backbone
+optimizer = dict(_delete_=True, type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01,
+ paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
+ 'relative_position_bias_table': dict(decay_mult=0.),
+ 'norm': dict(decay_mult=0.)}))
+
+lr_config = dict(_delete_=True, policy='poly',
+ warmup='linear',
+ warmup_iters=1500,
+ warmup_ratio=1e-6,
+ power=1.0, min_lr=0.0, by_epoch=False)
+
+data=dict(samples_per_gpu=2)
\ No newline at end of file
diff --git a/annotator/uniformer/exp/upernet_global_small/run.sh b/annotator/uniformer/exp/upernet_global_small/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..9fb22edfa7a32624ea08a63fe7d720c40db3b696
--- /dev/null
+++ b/annotator/uniformer/exp/upernet_global_small/run.sh
@@ -0,0 +1,10 @@
+#!/usr/bin/env bash
+
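+# 8-GPU distributed training via torch.distributed.launch; the ImageNet-1k pretrained
+# UniFormer-S weights are passed to the backbone through --options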
+work_path=$(dirname $0)
+PYTHONPATH="$(dirname $0)/../../":$PYTHONPATH \
+python -m torch.distributed.launch --nproc_per_node=8 \
+ tools/train.py ${work_path}/config.py \
+ --launcher pytorch \
+ --options model.backbone.pretrained_path='your_model_path/uniformer_small_in1k.pth' \
+ --work-dir ${work_path}/ckpt \
+ 2>&1 | tee -a ${work_path}/log.txt
diff --git a/annotator/uniformer/exp/upernet_global_small/test.sh b/annotator/uniformer/exp/upernet_global_small/test.sh
new file mode 100644
index 0000000000000000000000000000000000000000..d9a85e7a0d3b7c96b060f473d41254b37a382fcb
--- /dev/null
+++ b/annotator/uniformer/exp/upernet_global_small/test.sh
@@ -0,0 +1,10 @@
+#!/usr/bin/env bash
+
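+# Evaluate the latest checkpoint with the hybrid-window (h32) test config and report mIoU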
+work_path=$(dirname $0)
+PYTHONPATH="$(dirname $0)/../../":$PYTHONPATH \
+python -m torch.distributed.launch --nproc_per_node=8 \
+ tools/test.py ${work_path}/test_config_h32.py \
+ ${work_path}/ckpt/latest.pth \
+ --launcher pytorch \
+ --eval mIoU \
+ 2>&1 | tee -a ${work_path}/log.txt
diff --git a/annotator/uniformer/exp/upernet_global_small/test_config_g.py b/annotator/uniformer/exp/upernet_global_small/test_config_g.py
new file mode 100644
index 0000000000000000000000000000000000000000..e43737a98a3b174a9f2fe059c06d511144686459
--- /dev/null
+++ b/annotator/uniformer/exp/upernet_global_small/test_config_g.py
@@ -0,0 +1,38 @@
+_base_ = [
+ '../../configs/_base_/models/upernet_uniformer.py',
+ '../../configs/_base_/datasets/ade20k.py',
+ '../../configs/_base_/default_runtime.py',
+ '../../configs/_base_/schedules/schedule_160k.py'
+]
+model = dict(
+ backbone=dict(
+ type='UniFormer',
+ embed_dim=[64, 128, 320, 512],
+ layers=[3, 4, 8, 3],
+ head_dim=64,
+ drop_path_rate=0.25,
+ windows=False,
+ hybrid=False,
+ ),
+ decode_head=dict(
+ in_channels=[64, 128, 320, 512],
+ num_classes=150
+ ),
+ auxiliary_head=dict(
+ in_channels=320,
+ num_classes=150
+ ))
+
+# AdamW optimizer, no weight decay for position embedding & layer norm in backbone
+optimizer = dict(_delete_=True, type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01,
+ paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
+ 'relative_position_bias_table': dict(decay_mult=0.),
+ 'norm': dict(decay_mult=0.)}))
+
+lr_config = dict(_delete_=True, policy='poly',
+ warmup='linear',
+ warmup_iters=1500,
+ warmup_ratio=1e-6,
+ power=1.0, min_lr=0.0, by_epoch=False)
+
+data=dict(samples_per_gpu=2)
\ No newline at end of file
diff --git a/annotator/uniformer/exp/upernet_global_small/test_config_h32.py b/annotator/uniformer/exp/upernet_global_small/test_config_h32.py
new file mode 100644
index 0000000000000000000000000000000000000000..a31e3874f76f9f7b089ac8834d85df2441af9b0e
--- /dev/null
+++ b/annotator/uniformer/exp/upernet_global_small/test_config_h32.py
@@ -0,0 +1,39 @@
+_base_ = [
+ '../../configs/_base_/models/upernet_uniformer.py',
+ '../../configs/_base_/datasets/ade20k.py',
+ '../../configs/_base_/default_runtime.py',
+ '../../configs/_base_/schedules/schedule_160k.py'
+]
+model = dict(
+ backbone=dict(
+ type='UniFormer',
+ embed_dim=[64, 128, 320, 512],
+ layers=[3, 4, 8, 3],
+ head_dim=64,
+ drop_path_rate=0.25,
+ windows=False,
+ hybrid=True,
+ window_size=32
+ ),
+ decode_head=dict(
+ in_channels=[64, 128, 320, 512],
+ num_classes=150
+ ),
+ auxiliary_head=dict(
+ in_channels=320,
+ num_classes=150
+ ))
+
+# AdamW optimizer, no weight decay for position embedding & layer norm in backbone
+optimizer = dict(_delete_=True, type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01,
+ paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
+ 'relative_position_bias_table': dict(decay_mult=0.),
+ 'norm': dict(decay_mult=0.)}))
+
+lr_config = dict(_delete_=True, policy='poly',
+ warmup='linear',
+ warmup_iters=1500,
+ warmup_ratio=1e-6,
+ power=1.0, min_lr=0.0, by_epoch=False)
+
+data=dict(samples_per_gpu=2)
\ No newline at end of file
diff --git a/annotator/uniformer/exp/upernet_global_small/test_config_w32.py b/annotator/uniformer/exp/upernet_global_small/test_config_w32.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d9e06f029e46c14cb9ddb39319cabe86fef9b44
--- /dev/null
+++ b/annotator/uniformer/exp/upernet_global_small/test_config_w32.py
@@ -0,0 +1,39 @@
+_base_ = [
+ '../../configs/_base_/models/upernet_uniformer.py',
+ '../../configs/_base_/datasets/ade20k.py',
+ '../../configs/_base_/default_runtime.py',
+ '../../configs/_base_/schedules/schedule_160k.py'
+]
+model = dict(
+ backbone=dict(
+ type='UniFormer',
+ embed_dim=[64, 128, 320, 512],
+ layers=[3, 4, 8, 3],
+ head_dim=64,
+ drop_path_rate=0.25,
+ windows=True,
+ hybrid=False,
+ window_size=32
+ ),
+ decode_head=dict(
+ in_channels=[64, 128, 320, 512],
+ num_classes=150
+ ),
+ auxiliary_head=dict(
+ in_channels=320,
+ num_classes=150
+ ))
+
+# AdamW optimizer, no weight decay for position embedding & layer norm in backbone
+optimizer = dict(_delete_=True, type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01,
+ paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
+ 'relative_position_bias_table': dict(decay_mult=0.),
+ 'norm': dict(decay_mult=0.)}))
+
+lr_config = dict(_delete_=True, policy='poly',
+ warmup='linear',
+ warmup_iters=1500,
+ warmup_ratio=1e-6,
+ power=1.0, min_lr=0.0, by_epoch=False)
+
+data=dict(samples_per_gpu=2)
\ No newline at end of file
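Note on the three test configs above: they differ only in the backbone attention flags. test_config_g.py uses global attention only (windows=False, hybrid=False), test_config_h32.py mixes window and global attention (hybrid=True, window_size=32), and test_config_w32.py uses window attention only (windows=True, window_size=32). A minimal sketch of loading one of them, assuming the vendored Config loader is importable as shown (paths are illustrative):

    from annotator.uniformer.mmcv import Config

    # resolve the _base_ inheritance chain and inspect the merged settings
    cfg = Config.fromfile(
        'annotator/uniformer/exp/upernet_global_small/test_config_h32.py')
    print(cfg.model.backbone.hybrid, cfg.model.backbone.window_size)  # True 32
    print(cfg.data.samples_per_gpu)                                   # 2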
diff --git a/annotator/uniformer/mmcv/__init__.py b/annotator/uniformer/mmcv/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..210a2989138380559f23045b568d0fbbeb918c03
--- /dev/null
+++ b/annotator/uniformer/mmcv/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# flake8: noqa
+from .arraymisc import *
+from .fileio import *
+from .image import *
+from .utils import *
+from .version import *
+from .video import *
+from .visualization import *
+
+# The following modules are not imported to this level, so mmcv may be used
+# without PyTorch.
+# - runner
+# - parallel
+ # - ops
diff --git a/annotator/uniformer/mmcv/arraymisc/__init__.py b/annotator/uniformer/mmcv/arraymisc/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b4700d6139ae3d604ff6e542468cce4200c020c
--- /dev/null
+++ b/annotator/uniformer/mmcv/arraymisc/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .quantization import dequantize, quantize
+
+__all__ = ['quantize', 'dequantize']
diff --git a/annotator/uniformer/mmcv/arraymisc/quantization.py b/annotator/uniformer/mmcv/arraymisc/quantization.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e47a3545780cf071a1ef8195efb0b7b662c8186
--- /dev/null
+++ b/annotator/uniformer/mmcv/arraymisc/quantization.py
@@ -0,0 +1,55 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+
+
+def quantize(arr, min_val, max_val, levels, dtype=np.int64):
+ """Quantize an array of (-inf, inf) to [0, levels-1].
+
+ Args:
+ arr (ndarray): Input array.
+ min_val (scalar): Minimum value to be clipped.
+ max_val (scalar): Maximum value to be clipped.
+ levels (int): Quantization levels.
+ dtype (np.type): The type of the quantized array.
+
+ Returns:
+ ndarray: Quantized array.
+ """
+ if not (isinstance(levels, int) and levels > 1):
+ raise ValueError(
+ f'levels must be an integer greater than 1, but got {levels}')
+ if min_val >= max_val:
+ raise ValueError(
+ f'min_val ({min_val}) must be smaller than max_val ({max_val})')
+
+ arr = np.clip(arr, min_val, max_val) - min_val
+ quantized_arr = np.minimum(
+ np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1)
+
+ return quantized_arr
+
+
+def dequantize(arr, min_val, max_val, levels, dtype=np.float64):
+ """Dequantize an array.
+
+ Args:
+ arr (ndarray): Input array.
+ min_val (scalar): Minimum value to be clipped.
+ max_val (scalar): Maximum value to be clipped.
+ levels (int): Quantization levels.
+ dtype (np.type): The type of the dequantized array.
+
+ Returns:
+ ndarray: Dequantized array.
+ """
+ if not (isinstance(levels, int) and levels > 1):
+ raise ValueError(
+ f'levels must be an integer greater than 1, but got {levels}')
+ if min_val >= max_val:
+ raise ValueError(
+ f'min_val ({min_val}) must be smaller than max_val ({max_val})')
+
+ dequantized_arr = (arr + 0.5).astype(dtype) * (max_val -
+ min_val) / levels + min_val
+
+ return dequantized_arr
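A quick round-trip sketch for the two helpers above: quantize() maps values clipped to [min_val, max_val] onto integer bins 0..levels-1, and dequantize() returns each bin centre, so the reconstruction error is at most half a bin width (0.1 for the settings below; the input values are arbitrary):

    import numpy as np
    from annotator.uniformer.mmcv.arraymisc import quantize, dequantize

    arr = np.array([-3.0, -0.2, 0.0, 0.7, 5.0])
    q = quantize(arr, min_val=-1.0, max_val=1.0, levels=10)    # integers in [0, 9]
    r = dequantize(q, min_val=-1.0, max_val=1.0, levels=10)    # bin centres
    assert np.all(np.abs(np.clip(arr, -1.0, 1.0) - r) <= 0.1 + 1e-8)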
diff --git a/annotator/uniformer/mmcv/cnn/__init__.py b/annotator/uniformer/mmcv/cnn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7246c897430f0cc7ce12719ad8608824fc734446
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/__init__.py
@@ -0,0 +1,41 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .alexnet import AlexNet
+# yapf: disable
+from .bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
+ PADDING_LAYERS, PLUGIN_LAYERS, UPSAMPLE_LAYERS,
+ ContextBlock, Conv2d, Conv3d, ConvAWS2d, ConvModule,
+ ConvTranspose2d, ConvTranspose3d, ConvWS2d,
+ DepthwiseSeparableConvModule, GeneralizedAttention,
+ HSigmoid, HSwish, Linear, MaxPool2d, MaxPool3d,
+ NonLocal1d, NonLocal2d, NonLocal3d, Scale, Swish,
+ build_activation_layer, build_conv_layer,
+ build_norm_layer, build_padding_layer, build_plugin_layer,
+ build_upsample_layer, conv_ws_2d, is_norm)
+from .builder import MODELS, build_model_from_cfg
+# yapf: enable
+from .resnet import ResNet, make_res_layer
+from .utils import (INITIALIZERS, Caffe2XavierInit, ConstantInit, KaimingInit,
+ NormalInit, PretrainedInit, TruncNormalInit, UniformInit,
+ XavierInit, bias_init_with_prob, caffe2_xavier_init,
+ constant_init, fuse_conv_bn, get_model_complexity_info,
+ initialize, kaiming_init, normal_init, trunc_normal_init,
+ uniform_init, xavier_init)
+from .vgg import VGG, make_vgg_layer
+
+__all__ = [
+ 'AlexNet', 'VGG', 'make_vgg_layer', 'ResNet', 'make_res_layer',
+ 'constant_init', 'xavier_init', 'normal_init', 'trunc_normal_init',
+ 'uniform_init', 'kaiming_init', 'caffe2_xavier_init',
+ 'bias_init_with_prob', 'ConvModule', 'build_activation_layer',
+ 'build_conv_layer', 'build_norm_layer', 'build_padding_layer',
+ 'build_upsample_layer', 'build_plugin_layer', 'is_norm', 'NonLocal1d',
+ 'NonLocal2d', 'NonLocal3d', 'ContextBlock', 'HSigmoid', 'Swish', 'HSwish',
+ 'GeneralizedAttention', 'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS',
+ 'PADDING_LAYERS', 'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale',
+ 'get_model_complexity_info', 'conv_ws_2d', 'ConvAWS2d', 'ConvWS2d',
+ 'fuse_conv_bn', 'DepthwiseSeparableConvModule', 'Linear', 'Conv2d',
+ 'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d', 'MaxPool3d', 'Conv3d',
+ 'initialize', 'INITIALIZERS', 'ConstantInit', 'XavierInit', 'NormalInit',
+ 'TruncNormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit',
+ 'Caffe2XavierInit', 'MODELS', 'build_model_from_cfg'
+]
diff --git a/annotator/uniformer/mmcv/cnn/alexnet.py b/annotator/uniformer/mmcv/cnn/alexnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..89e36b8c7851f895d9ae7f07149f0e707456aab0
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/alexnet.py
@@ -0,0 +1,61 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import logging
+
+import torch.nn as nn
+
+
+class AlexNet(nn.Module):
+ """AlexNet backbone.
+
+ Args:
+ num_classes (int): number of classes for classification.
+ """
+
+ def __init__(self, num_classes=-1):
+ super(AlexNet, self).__init__()
+ self.num_classes = num_classes
+ self.features = nn.Sequential(
+ nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
+ nn.ReLU(inplace=True),
+ nn.MaxPool2d(kernel_size=3, stride=2),
+ nn.Conv2d(64, 192, kernel_size=5, padding=2),
+ nn.ReLU(inplace=True),
+ nn.MaxPool2d(kernel_size=3, stride=2),
+ nn.Conv2d(192, 384, kernel_size=3, padding=1),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(384, 256, kernel_size=3, padding=1),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(256, 256, kernel_size=3, padding=1),
+ nn.ReLU(inplace=True),
+ nn.MaxPool2d(kernel_size=3, stride=2),
+ )
+ if self.num_classes > 0:
+ self.classifier = nn.Sequential(
+ nn.Dropout(),
+ nn.Linear(256 * 6 * 6, 4096),
+ nn.ReLU(inplace=True),
+ nn.Dropout(),
+ nn.Linear(4096, 4096),
+ nn.ReLU(inplace=True),
+ nn.Linear(4096, num_classes),
+ )
+
+ def init_weights(self, pretrained=None):
+ if isinstance(pretrained, str):
+ logger = logging.getLogger()
+ from ..runner import load_checkpoint
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
+ elif pretrained is None:
+ # use default initializer
+ pass
+ else:
+ raise TypeError('pretrained must be a str or None')
+
+ def forward(self, x):
+
+ x = self.features(x)
+ if self.num_classes > 0:
+ x = x.view(x.size(0), 256 * 6 * 6)
+ x = self.classifier(x)
+
+ return x
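A minimal usage sketch for the backbone above (sizes assume the standard 224x224 ImageNet input; with the default num_classes=-1 the classifier head is omitted and raw convolutional features are returned):

    import torch
    from annotator.uniformer.mmcv.cnn import AlexNet

    model = AlexNet(num_classes=10)
    model.init_weights()                              # default (no-op) init
    logits = model(torch.randn(2, 3, 224, 224))       # shape (2, 10)

    backbone = AlexNet()                              # num_classes=-1, no head
    feats = backbone(torch.randn(2, 3, 224, 224))     # shape (2, 256, 6, 6)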
diff --git a/annotator/uniformer/mmcv/cnn/bricks/__init__.py b/annotator/uniformer/mmcv/cnn/bricks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f33124ed23fc6f27119a37bcb5ab004d3572be0
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/bricks/__init__.py
@@ -0,0 +1,35 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .activation import build_activation_layer
+from .context_block import ContextBlock
+from .conv import build_conv_layer
+from .conv2d_adaptive_padding import Conv2dAdaptivePadding
+from .conv_module import ConvModule
+from .conv_ws import ConvAWS2d, ConvWS2d, conv_ws_2d
+from .depthwise_separable_conv_module import DepthwiseSeparableConvModule
+from .drop import Dropout, DropPath
+from .generalized_attention import GeneralizedAttention
+from .hsigmoid import HSigmoid
+from .hswish import HSwish
+from .non_local import NonLocal1d, NonLocal2d, NonLocal3d
+from .norm import build_norm_layer, is_norm
+from .padding import build_padding_layer
+from .plugin import build_plugin_layer
+from .registry import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
+ PADDING_LAYERS, PLUGIN_LAYERS, UPSAMPLE_LAYERS)
+from .scale import Scale
+from .swish import Swish
+from .upsample import build_upsample_layer
+from .wrappers import (Conv2d, Conv3d, ConvTranspose2d, ConvTranspose3d,
+ Linear, MaxPool2d, MaxPool3d)
+
+__all__ = [
+ 'ConvModule', 'build_activation_layer', 'build_conv_layer',
+ 'build_norm_layer', 'build_padding_layer', 'build_upsample_layer',
+ 'build_plugin_layer', 'is_norm', 'HSigmoid', 'HSwish', 'NonLocal1d',
+ 'NonLocal2d', 'NonLocal3d', 'ContextBlock', 'GeneralizedAttention',
+ 'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS',
+ 'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale', 'ConvAWS2d', 'ConvWS2d',
+ 'conv_ws_2d', 'DepthwiseSeparableConvModule', 'Swish', 'Linear',
+ 'Conv2dAdaptivePadding', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d',
+ 'ConvTranspose3d', 'MaxPool3d', 'Conv3d', 'Dropout', 'DropPath'
+]
diff --git a/annotator/uniformer/mmcv/cnn/bricks/activation.py b/annotator/uniformer/mmcv/cnn/bricks/activation.py
new file mode 100644
index 0000000000000000000000000000000000000000..cab2712287d5ef7be2f079dcb54a94b96394eab5
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/bricks/activation.py
@@ -0,0 +1,92 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from annotator.uniformer.mmcv.utils import TORCH_VERSION, build_from_cfg, digit_version
+from .registry import ACTIVATION_LAYERS
+
+for module in [
+ nn.ReLU, nn.LeakyReLU, nn.PReLU, nn.RReLU, nn.ReLU6, nn.ELU,
+ nn.Sigmoid, nn.Tanh
+]:
+ ACTIVATION_LAYERS.register_module(module=module)
+
+
+@ACTIVATION_LAYERS.register_module(name='Clip')
+@ACTIVATION_LAYERS.register_module()
+class Clamp(nn.Module):
+ """Clamp activation layer.
+
+ This activation function is to clamp the feature map value within
+ :math:`[min, max]`. More details can be found in ``torch.clamp()``.
+
+ Args:
+ min (Number | optional): Lower-bound of the range to be clamped to.
+ Default to -1.
+ max (Number | optional): Upper-bound of the range to be clamped to.
+ Default to 1.
+ """
+
+ def __init__(self, min=-1., max=1.):
+ super(Clamp, self).__init__()
+ self.min = min
+ self.max = max
+
+ def forward(self, x):
+ """Forward function.
+
+ Args:
+ x (torch.Tensor): The input tensor.
+
+ Returns:
+ torch.Tensor: Clamped tensor.
+ """
+ return torch.clamp(x, min=self.min, max=self.max)
+
+
+class GELU(nn.Module):
+ r"""Applies the Gaussian Error Linear Units function:
+
+ .. math::
+ \text{GELU}(x) = x * \Phi(x)
+ where :math:`\Phi(x)` is the Cumulative Distribution Function for
+ Gaussian Distribution.
+
+ Shape:
+ - Input: :math:`(N, *)` where `*` means, any number of additional
+ dimensions
+ - Output: :math:`(N, *)`, same shape as the input
+
+ .. image:: scripts/activation_images/GELU.png
+
+ Examples::
+
+ >>> m = nn.GELU()
+ >>> input = torch.randn(2)
+ >>> output = m(input)
+ """
+
+ def forward(self, input):
+ return F.gelu(input)
+
+
+if (TORCH_VERSION == 'parrots'
+ or digit_version(TORCH_VERSION) < digit_version('1.4')):
+ ACTIVATION_LAYERS.register_module(module=GELU)
+else:
+ ACTIVATION_LAYERS.register_module(module=nn.GELU)
+
+
+def build_activation_layer(cfg):
+ """Build activation layer.
+
+ Args:
+ cfg (dict): The activation layer config, which should contain:
+ - type (str): Layer type.
+ - layer args: Args needed to instantiate an activation layer.
+
+ Returns:
+ nn.Module: Created activation layer.
+ """
+ return build_from_cfg(cfg, ACTIVATION_LAYERS)
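A small sketch of the registry-driven builder defined above: every key other than 'type' is forwarded to the layer constructor, and 'Clip' is the alias registered for Clamp:

    from annotator.uniformer.mmcv.cnn import build_activation_layer

    relu = build_activation_layer(dict(type='ReLU', inplace=True))
    lrelu = build_activation_layer(dict(type='LeakyReLU', negative_slope=0.2))
    clamp = build_activation_layer(dict(type='Clip', min=0., max=6.))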
diff --git a/annotator/uniformer/mmcv/cnn/bricks/context_block.py b/annotator/uniformer/mmcv/cnn/bricks/context_block.py
new file mode 100644
index 0000000000000000000000000000000000000000..d60fdb904c749ce3b251510dff3cc63cea70d42e
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/bricks/context_block.py
@@ -0,0 +1,125 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch import nn
+
+from ..utils import constant_init, kaiming_init
+from .registry import PLUGIN_LAYERS
+
+
+def last_zero_init(m):
+ if isinstance(m, nn.Sequential):
+ constant_init(m[-1], val=0)
+ else:
+ constant_init(m, val=0)
+
+
+@PLUGIN_LAYERS.register_module()
+class ContextBlock(nn.Module):
+ """ContextBlock module in GCNet.
+
+ See 'GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond'
+ (https://arxiv.org/abs/1904.11492) for details.
+
+ Args:
+ in_channels (int): Channels of the input feature map.
+ ratio (float): Ratio of channels of the transform bottleneck.
+ pooling_type (str): Pooling method for context modeling.
+ Options are 'att' and 'avg', standing for attention pooling and
+ average pooling respectively. Default: 'att'.
+ fusion_types (Sequence[str]): Fusion methods for feature fusion.
+ Options are 'channel_add' and 'channel_mul', standing for channel-wise
+ addition and multiplication respectively. Default: ('channel_add',)
+ """
+
+ _abbr_ = 'context_block'
+
+ def __init__(self,
+ in_channels,
+ ratio,
+ pooling_type='att',
+ fusion_types=('channel_add', )):
+ super(ContextBlock, self).__init__()
+ assert pooling_type in ['avg', 'att']
+ assert isinstance(fusion_types, (list, tuple))
+ valid_fusion_types = ['channel_add', 'channel_mul']
+ assert all([f in valid_fusion_types for f in fusion_types])
+ assert len(fusion_types) > 0, 'at least one fusion should be used'
+ self.in_channels = in_channels
+ self.ratio = ratio
+ self.planes = int(in_channels * ratio)
+ self.pooling_type = pooling_type
+ self.fusion_types = fusion_types
+ if pooling_type == 'att':
+ self.conv_mask = nn.Conv2d(in_channels, 1, kernel_size=1)
+ self.softmax = nn.Softmax(dim=2)
+ else:
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
+ if 'channel_add' in fusion_types:
+ self.channel_add_conv = nn.Sequential(
+ nn.Conv2d(self.in_channels, self.planes, kernel_size=1),
+ nn.LayerNorm([self.planes, 1, 1]),
+ nn.ReLU(inplace=True), # yapf: disable
+ nn.Conv2d(self.planes, self.in_channels, kernel_size=1))
+ else:
+ self.channel_add_conv = None
+ if 'channel_mul' in fusion_types:
+ self.channel_mul_conv = nn.Sequential(
+ nn.Conv2d(self.in_channels, self.planes, kernel_size=1),
+ nn.LayerNorm([self.planes, 1, 1]),
+ nn.ReLU(inplace=True), # yapf: disable
+ nn.Conv2d(self.planes, self.in_channels, kernel_size=1))
+ else:
+ self.channel_mul_conv = None
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ if self.pooling_type == 'att':
+ kaiming_init(self.conv_mask, mode='fan_in')
+ self.conv_mask.inited = True
+
+ if self.channel_add_conv is not None:
+ last_zero_init(self.channel_add_conv)
+ if self.channel_mul_conv is not None:
+ last_zero_init(self.channel_mul_conv)
+
+ def spatial_pool(self, x):
+ batch, channel, height, width = x.size()
+ if self.pooling_type == 'att':
+ input_x = x
+ # [N, C, H * W]
+ input_x = input_x.view(batch, channel, height * width)
+ # [N, 1, C, H * W]
+ input_x = input_x.unsqueeze(1)
+ # [N, 1, H, W]
+ context_mask = self.conv_mask(x)
+ # [N, 1, H * W]
+ context_mask = context_mask.view(batch, 1, height * width)
+ # [N, 1, H * W]
+ context_mask = self.softmax(context_mask)
+ # [N, 1, H * W, 1]
+ context_mask = context_mask.unsqueeze(-1)
+ # [N, 1, C, 1]
+ context = torch.matmul(input_x, context_mask)
+ # [N, C, 1, 1]
+ context = context.view(batch, channel, 1, 1)
+ else:
+ # [N, C, 1, 1]
+ context = self.avg_pool(x)
+
+ return context
+
+ def forward(self, x):
+ # [N, C, 1, 1]
+ context = self.spatial_pool(x)
+
+ out = x
+ if self.channel_mul_conv is not None:
+ # [N, C, 1, 1]
+ channel_mul_term = torch.sigmoid(self.channel_mul_conv(context))
+ out = out * channel_mul_term
+ if self.channel_add_conv is not None:
+ # [N, C, 1, 1]
+ channel_add_term = self.channel_add_conv(context)
+ out = out + channel_add_term
+
+ return out
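A usage sketch for ContextBlock: the module is shape-preserving, and because the last conv of the fusion branch is zero-initialized it starts out as an identity mapping:

    import torch
    from annotator.uniformer.mmcv.cnn import ContextBlock

    gc_block = ContextBlock(in_channels=64, ratio=1. / 4)   # 16-channel bottleneck
    x = torch.randn(2, 64, 32, 32)
    out = gc_block(x)                                       # shape unchanged
    assert torch.allclose(out, x)                           # identity at init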
diff --git a/annotator/uniformer/mmcv/cnn/bricks/conv.py b/annotator/uniformer/mmcv/cnn/bricks/conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf54491997a48ac3e7fadc4183ab7bf3e831024c
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/bricks/conv.py
@@ -0,0 +1,44 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from torch import nn
+
+from .registry import CONV_LAYERS
+
+CONV_LAYERS.register_module('Conv1d', module=nn.Conv1d)
+CONV_LAYERS.register_module('Conv2d', module=nn.Conv2d)
+CONV_LAYERS.register_module('Conv3d', module=nn.Conv3d)
+CONV_LAYERS.register_module('Conv', module=nn.Conv2d)
+
+
+def build_conv_layer(cfg, *args, **kwargs):
+ """Build convolution layer.
+
+ Args:
+ cfg (None or dict): The conv layer config, which should contain:
+ - type (str): Layer type.
+ - layer args: Args needed to instantiate a conv layer.
+ args (argument list): Arguments passed to the `__init__`
+ method of the corresponding conv layer.
+ kwargs (keyword arguments): Keyword arguments passed to the `__init__`
+ method of the corresponding conv layer.
+
+ Returns:
+ nn.Module: Created conv layer.
+ """
+ if cfg is None:
+ cfg_ = dict(type='Conv2d')
+ else:
+ if not isinstance(cfg, dict):
+ raise TypeError('cfg must be a dict')
+ if 'type' not in cfg:
+ raise KeyError('the cfg dict must contain the key "type"')
+ cfg_ = cfg.copy()
+
+ layer_type = cfg_.pop('type')
+ if layer_type not in CONV_LAYERS:
+ raise KeyError(f'Unrecognized conv type {layer_type}')
+ else:
+ conv_layer = CONV_LAYERS.get(layer_type)
+
+ layer = conv_layer(*args, **kwargs, **cfg_)
+
+ return layer
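A sketch of the builder above: positional and keyword arguments go straight to the selected conv class, and cfg=None falls back to a plain Conv2d:

    from annotator.uniformer.mmcv.cnn import build_conv_layer

    conv = build_conv_layer(None, 3, 16, kernel_size=3, padding=1)        # nn.Conv2d
    conv1d = build_conv_layer(dict(type='Conv1d'), 8, 8, kernel_size=3)   # nn.Conv1d
    conv_ws = build_conv_layer(dict(type='ConvWS'), 3, 16, 3, padding=1)  # ConvWS2d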
diff --git a/annotator/uniformer/mmcv/cnn/bricks/conv2d_adaptive_padding.py b/annotator/uniformer/mmcv/cnn/bricks/conv2d_adaptive_padding.py
new file mode 100644
index 0000000000000000000000000000000000000000..b45e758ac6cf8dfb0382d072fe09125bc7e9b888
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/bricks/conv2d_adaptive_padding.py
@@ -0,0 +1,62 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+from torch import nn
+from torch.nn import functional as F
+
+from .registry import CONV_LAYERS
+
+
+@CONV_LAYERS.register_module()
+class Conv2dAdaptivePadding(nn.Conv2d):
+ """2D convolution that mimics TensorFlow's "same" padding: the input is
+ padded (if needed) so that it is fully covered by the specified filter
+ and stride. With stride 1 the output size equals the input size; with
+ stride 2 the output dimensions are halved, for example.
+
+ Args:
+ in_channels (int): Number of channels in the input image
+ out_channels (int): Number of channels produced by the convolution
+ kernel_size (int or tuple): Size of the convolving kernel
+ stride (int or tuple, optional): Stride of the convolution. Default: 1
+ padding (int or tuple, optional): Zero-padding added to both sides of
+ the input. Default: 0
+ dilation (int or tuple, optional): Spacing between kernel elements.
+ Default: 1
+ groups (int, optional): Number of blocked connections from input
+ channels to output channels. Default: 1
+ bias (bool, optional): If ``True``, adds a learnable bias to the
+ output. Default: ``True``
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ bias=True):
+ super().__init__(in_channels, out_channels, kernel_size, stride, 0,
+ dilation, groups, bias)
+
+ def forward(self, x):
+ img_h, img_w = x.size()[-2:]
+ kernel_h, kernel_w = self.weight.size()[-2:]
+ stride_h, stride_w = self.stride
+ output_h = math.ceil(img_h / stride_h)
+ output_w = math.ceil(img_w / stride_w)
+ pad_h = (
+ max((output_h - 1) * self.stride[0] +
+ (kernel_h - 1) * self.dilation[0] + 1 - img_h, 0))
+ pad_w = (
+ max((output_w - 1) * self.stride[1] +
+ (kernel_w - 1) * self.dilation[1] + 1 - img_w, 0))
+ if pad_h > 0 or pad_w > 0:
+ x = F.pad(x, [
+ pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2
+ ])
+ return F.conv2d(x, self.weight, self.bias, self.stride, self.padding,
+ self.dilation, self.groups)
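A quick check of the "same"-style behaviour implemented above: the output spatial size is ceil(input / stride) regardless of the kernel size:

    import torch
    from annotator.uniformer.mmcv.cnn.bricks import Conv2dAdaptivePadding

    conv = Conv2dAdaptivePadding(3, 8, kernel_size=3, stride=2)
    y = conv(torch.randn(1, 3, 15, 20))
    assert y.shape[-2:] == (8, 10)    # ceil(15 / 2), ceil(20 / 2)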
diff --git a/annotator/uniformer/mmcv/cnn/bricks/conv_module.py b/annotator/uniformer/mmcv/cnn/bricks/conv_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..e60e7e62245071c77b652093fddebff3948d7c3e
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/bricks/conv_module.py
@@ -0,0 +1,206 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import torch.nn as nn
+
+from annotator.uniformer.mmcv.utils import _BatchNorm, _InstanceNorm
+from ..utils import constant_init, kaiming_init
+from .activation import build_activation_layer
+from .conv import build_conv_layer
+from .norm import build_norm_layer
+from .padding import build_padding_layer
+from .registry import PLUGIN_LAYERS
+
+
+@PLUGIN_LAYERS.register_module()
+class ConvModule(nn.Module):
+ """A conv block that bundles conv/norm/activation layers.
+
+ This block simplifies the usage of convolution layers, which are commonly
+ used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).
+ It is based upon three build methods: `build_conv_layer()`,
+ `build_norm_layer()` and `build_activation_layer()`.
+
+ Besides, we add some additional features in this module.
+ 1. Automatically set `bias` of the conv layer.
+ 2. Spectral norm is supported.
+ 3. More padding modes are supported. Before PyTorch 1.5, nn.Conv2d only
+ supported zero and circular padding; we additionally add a "reflect"
+ padding mode.
+
+ Args:
+ in_channels (int): Number of channels in the input feature map.
+ Same as that in ``nn._ConvNd``.
+ out_channels (int): Number of channels produced by the convolution.
+ Same as that in ``nn._ConvNd``.
+ kernel_size (int | tuple[int]): Size of the convolving kernel.
+ Same as that in ``nn._ConvNd``.
+ stride (int | tuple[int]): Stride of the convolution.
+ Same as that in ``nn._ConvNd``.
+ padding (int | tuple[int]): Zero-padding added to both sides of
+ the input. Same as that in ``nn._ConvNd``.
+ dilation (int | tuple[int]): Spacing between kernel elements.
+ Same as that in ``nn._ConvNd``.
+ groups (int): Number of blocked connections from input channels to
+ output channels. Same as that in ``nn._ConvNd``.
+ bias (bool | str): If specified as `auto`, it will be decided by the
+ norm_cfg. Bias will be set as True if `norm_cfg` is None, otherwise
+ False. Default: "auto".
+ conv_cfg (dict): Config dict for convolution layer. Default: None,
+ which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer. Default: None.
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='ReLU').
+ inplace (bool): Whether to use inplace mode for activation.
+ Default: True.
+ with_spectral_norm (bool): Whether use spectral norm in conv module.
+ Default: False.
+ padding_mode (str): If the `padding_mode` is not supported by the
+ current `Conv2d` in PyTorch, we will use our own padding layer
+ instead. Currently, we support ['zeros', 'circular'] with official
+ implementation and ['reflect'] with our own implementation.
+ Default: 'zeros'.
+ order (tuple[str]): The order of conv/norm/activation layers. It is a
+ sequence of "conv", "norm" and "act". Common examples are
+ ("conv", "norm", "act") and ("act", "conv", "norm").
+ Default: ('conv', 'norm', 'act').
+ """
+
+ _abbr_ = 'conv_block'
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ bias='auto',
+ conv_cfg=None,
+ norm_cfg=None,
+ act_cfg=dict(type='ReLU'),
+ inplace=True,
+ with_spectral_norm=False,
+ padding_mode='zeros',
+ order=('conv', 'norm', 'act')):
+ super(ConvModule, self).__init__()
+ assert conv_cfg is None or isinstance(conv_cfg, dict)
+ assert norm_cfg is None or isinstance(norm_cfg, dict)
+ assert act_cfg is None or isinstance(act_cfg, dict)
+ official_padding_mode = ['zeros', 'circular']
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ self.inplace = inplace
+ self.with_spectral_norm = with_spectral_norm
+ self.with_explicit_padding = padding_mode not in official_padding_mode
+ self.order = order
+ assert isinstance(self.order, tuple) and len(self.order) == 3
+ assert set(order) == set(['conv', 'norm', 'act'])
+
+ self.with_norm = norm_cfg is not None
+ self.with_activation = act_cfg is not None
+ # if the conv layer is before a norm layer, bias is unnecessary.
+ if bias == 'auto':
+ bias = not self.with_norm
+ self.with_bias = bias
+
+ if self.with_explicit_padding:
+ pad_cfg = dict(type=padding_mode)
+ self.padding_layer = build_padding_layer(pad_cfg, padding)
+
+ # reset padding to 0 for conv module
+ conv_padding = 0 if self.with_explicit_padding else padding
+ # build convolution layer
+ self.conv = build_conv_layer(
+ conv_cfg,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=stride,
+ padding=conv_padding,
+ dilation=dilation,
+ groups=groups,
+ bias=bias)
+ # export the attributes of self.conv to a higher level for convenience
+ self.in_channels = self.conv.in_channels
+ self.out_channels = self.conv.out_channels
+ self.kernel_size = self.conv.kernel_size
+ self.stride = self.conv.stride
+ self.padding = padding
+ self.dilation = self.conv.dilation
+ self.transposed = self.conv.transposed
+ self.output_padding = self.conv.output_padding
+ self.groups = self.conv.groups
+
+ if self.with_spectral_norm:
+ self.conv = nn.utils.spectral_norm(self.conv)
+
+ # build normalization layers
+ if self.with_norm:
+ # norm layer is after conv layer
+ if order.index('norm') > order.index('conv'):
+ norm_channels = out_channels
+ else:
+ norm_channels = in_channels
+ self.norm_name, norm = build_norm_layer(norm_cfg, norm_channels)
+ self.add_module(self.norm_name, norm)
+ if self.with_bias:
+ if isinstance(norm, (_BatchNorm, _InstanceNorm)):
+ warnings.warn(
+ 'Unnecessary conv bias before batch/instance norm')
+ else:
+ self.norm_name = None
+
+ # build activation layer
+ if self.with_activation:
+ act_cfg_ = act_cfg.copy()
+ # nn.Tanh has no 'inplace' argument
+ if act_cfg_['type'] not in [
+ 'Tanh', 'PReLU', 'Sigmoid', 'HSigmoid', 'Swish'
+ ]:
+ act_cfg_.setdefault('inplace', inplace)
+ self.activate = build_activation_layer(act_cfg_)
+
+ # Use msra init by default
+ self.init_weights()
+
+ @property
+ def norm(self):
+ if self.norm_name:
+ return getattr(self, self.norm_name)
+ else:
+ return None
+
+ def init_weights(self):
+ # 1. It is mainly for customized conv layers with their own
+ # initialization manners by calling their own ``init_weights()``,
+ # and we do not want ConvModule to override the initialization.
+ # 2. For customized conv layers without their own initialization
+ # manners (that is, they don't have their own ``init_weights()``)
+ # and PyTorch's conv layers, they will be initialized by
+ # this method with default ``kaiming_init``.
+ # Note: For PyTorch's conv layers, they will be overwritten by our
+ # initialization implementation using default ``kaiming_init``.
+ if not hasattr(self.conv, 'init_weights'):
+ if self.with_activation and self.act_cfg['type'] == 'LeakyReLU':
+ nonlinearity = 'leaky_relu'
+ a = self.act_cfg.get('negative_slope', 0.01)
+ else:
+ nonlinearity = 'relu'
+ a = 0
+ kaiming_init(self.conv, a=a, nonlinearity=nonlinearity)
+ if self.with_norm:
+ constant_init(self.norm, 1, bias=0)
+
+ def forward(self, x, activate=True, norm=True):
+ for layer in self.order:
+ if layer == 'conv':
+ if self.with_explicit_padding:
+ x = self.padding_layer(x)
+ x = self.conv(x)
+ elif layer == 'norm' and norm and self.with_norm:
+ x = self.norm(x)
+ elif layer == 'act' and activate and self.with_activation:
+ x = self.activate(x)
+ return x
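A usage sketch of ConvModule showing the bias='auto' behaviour described above: when a norm config is given, the conv bias is dropped and the layers run in the requested order (here the default conv, norm, act). The 'BN' norm type is assumed to be registered by the accompanying norm.py:

    import torch
    from annotator.uniformer.mmcv.cnn import ConvModule

    block = ConvModule(
        3, 16, 3,
        padding=1,
        norm_cfg=dict(type='BN'),      # bias='auto' -> conv bias disabled
        act_cfg=dict(type='ReLU'))
    y = block(torch.randn(2, 3, 32, 32))                  # (2, 16, 32, 32)
    assert block.conv.bias is None and block.norm is not None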
diff --git a/annotator/uniformer/mmcv/cnn/bricks/conv_ws.py b/annotator/uniformer/mmcv/cnn/bricks/conv_ws.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3941e27874993418b3b5708d5a7485f175ff9c8
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/bricks/conv_ws.py
@@ -0,0 +1,148 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .registry import CONV_LAYERS
+
+
+def conv_ws_2d(input,
+ weight,
+ bias=None,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ eps=1e-5):
+ c_in = weight.size(0)
+ weight_flat = weight.view(c_in, -1)
+ mean = weight_flat.mean(dim=1, keepdim=True).view(c_in, 1, 1, 1)
+ std = weight_flat.std(dim=1, keepdim=True).view(c_in, 1, 1, 1)
+ weight = (weight - mean) / (std + eps)
+ return F.conv2d(input, weight, bias, stride, padding, dilation, groups)
+
+
+@CONV_LAYERS.register_module('ConvWS')
+class ConvWS2d(nn.Conv2d):
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ bias=True,
+ eps=1e-5):
+ super(ConvWS2d, self).__init__(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ bias=bias)
+ self.eps = eps
+
+ def forward(self, x):
+ return conv_ws_2d(x, self.weight, self.bias, self.stride, self.padding,
+ self.dilation, self.groups, self.eps)
+
+
+@CONV_LAYERS.register_module(name='ConvAWS')
+class ConvAWS2d(nn.Conv2d):
+ """AWS (Adaptive Weight Standardization)
+
+ This is a variant of Weight Standardization
+ (https://arxiv.org/pdf/1903.10520.pdf)
+ It is used in DetectoRS to avoid NaN
+ (https://arxiv.org/pdf/2006.02334.pdf)
+
+ Args:
+ in_channels (int): Number of channels in the input image
+ out_channels (int): Number of channels produced by the convolution
+ kernel_size (int or tuple): Size of the conv kernel
+ stride (int or tuple, optional): Stride of the convolution. Default: 1
+ padding (int or tuple, optional): Zero-padding added to both sides of
+ the input. Default: 0
+ dilation (int or tuple, optional): Spacing between kernel elements.
+ Default: 1
+ groups (int, optional): Number of blocked connections from input
+ channels to output channels. Default: 1
+ bias (bool, optional): If set True, adds a learnable bias to the
+ output. Default: True
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ bias=True):
+ super().__init__(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ bias=bias)
+ self.register_buffer('weight_gamma',
+ torch.ones(self.out_channels, 1, 1, 1))
+ self.register_buffer('weight_beta',
+ torch.zeros(self.out_channels, 1, 1, 1))
+
+ def _get_weight(self, weight):
+ weight_flat = weight.view(weight.size(0), -1)
+ mean = weight_flat.mean(dim=1).view(-1, 1, 1, 1)
+ std = torch.sqrt(weight_flat.var(dim=1) + 1e-5).view(-1, 1, 1, 1)
+ weight = (weight - mean) / std
+ weight = self.weight_gamma * weight + self.weight_beta
+ return weight
+
+ def forward(self, x):
+ weight = self._get_weight(self.weight)
+ return F.conv2d(x, weight, self.bias, self.stride, self.padding,
+ self.dilation, self.groups)
+
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+ missing_keys, unexpected_keys, error_msgs):
+ """Override default load function.
+
+ AWS overrides the function _load_from_state_dict to recover
+ weight_gamma and weight_beta if they are missing. If weight_gamma and
+ weight_beta are found in the checkpoint, this function will return
+ after super()._load_from_state_dict. Otherwise, it will compute the
+ mean and std of the pretrained weights and store them in weight_beta
+ and weight_gamma.
+ """
+
+ self.weight_gamma.data.fill_(-1)
+ local_missing_keys = []
+ super()._load_from_state_dict(state_dict, prefix, local_metadata,
+ strict, local_missing_keys,
+ unexpected_keys, error_msgs)
+ if self.weight_gamma.data.mean() > 0:
+ for k in local_missing_keys:
+ missing_keys.append(k)
+ return
+ weight = self.weight.data
+ weight_flat = weight.view(weight.size(0), -1)
+ mean = weight_flat.mean(dim=1).view(-1, 1, 1, 1)
+ std = torch.sqrt(weight_flat.var(dim=1) + 1e-5).view(-1, 1, 1, 1)
+ self.weight_beta.data.copy_(mean)
+ self.weight_gamma.data.copy_(std)
+ missing_gamma_beta = [
+ k for k in local_missing_keys
+ if k.endswith('weight_gamma') or k.endswith('weight_beta')
+ ]
+ for k in missing_gamma_beta:
+ local_missing_keys.remove(k)
+ for k in local_missing_keys:
+ missing_keys.append(k)
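A short sketch of the weight-standardized conv defined above: conv_ws_2d() normalizes each output filter to zero mean and unit std before the convolution, so ConvWS2d behaves like nn.Conv2d apart from that reparameterization:

    import torch
    from annotator.uniformer.mmcv.cnn import ConvWS2d

    conv = ConvWS2d(3, 8, kernel_size=3, padding=1)
    y = conv(torch.randn(2, 3, 16, 16))                   # (2, 8, 16, 16)

    # recompute the effective (standardized) weight used inside forward()
    w = conv.weight
    flat = w.view(w.size(0), -1)
    w_std = (w - flat.mean(1).view(-1, 1, 1, 1)) / (flat.std(1).view(-1, 1, 1, 1) + conv.eps)
    assert torch.allclose(w_std.view(8, -1).mean(1), torch.zeros(8), atol=1e-5)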
diff --git a/annotator/uniformer/mmcv/cnn/bricks/depthwise_separable_conv_module.py b/annotator/uniformer/mmcv/cnn/bricks/depthwise_separable_conv_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..722d5d8d71f75486e2db3008907c4eadfca41d63
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/bricks/depthwise_separable_conv_module.py
@@ -0,0 +1,96 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+
+from .conv_module import ConvModule
+
+
+class DepthwiseSeparableConvModule(nn.Module):
+ """Depthwise separable convolution module.
+
+ See https://arxiv.org/pdf/1704.04861.pdf for details.
+
+ This module replaces the single conv block of a ConvModule with two conv
+ blocks: a depthwise conv block and a pointwise conv block. The depthwise
+ conv block contains depthwise-conv/norm/activation layers. The pointwise
+ conv block contains pointwise-conv/norm/activation layers. It should be
+ noted that there will be norm/activation layer in the depthwise conv block
+ if `norm_cfg` and `act_cfg` are specified.
+
+ Args:
+ in_channels (int): Number of channels in the input feature map.
+ Same as that in ``nn._ConvNd``.
+ out_channels (int): Number of channels produced by the convolution.
+ Same as that in ``nn._ConvNd``.
+ kernel_size (int | tuple[int]): Size of the convolving kernel.
+ Same as that in ``nn._ConvNd``.
+ stride (int | tuple[int]): Stride of the convolution.
+ Same as that in ``nn._ConvNd``. Default: 1.
+ padding (int | tuple[int]): Zero-padding added to both sides of
+ the input. Same as that in ``nn._ConvNd``. Default: 0.
+ dilation (int | tuple[int]): Spacing between kernel elements.
+ Same as that in ``nn._ConvNd``. Default: 1.
+ norm_cfg (dict): Default norm config for both depthwise ConvModule and
+ pointwise ConvModule. Default: None.
+ act_cfg (dict): Default activation config for both depthwise ConvModule
+ and pointwise ConvModule. Default: dict(type='ReLU').
+ dw_norm_cfg (dict): Norm config of depthwise ConvModule. If it is
+ 'default', it will be the same as `norm_cfg`. Default: 'default'.
+ dw_act_cfg (dict): Activation config of depthwise ConvModule. If it is
+ 'default', it will be the same as `act_cfg`. Default: 'default'.
+ pw_norm_cfg (dict): Norm config of pointwise ConvModule. If it is
+ 'default', it will be the same as `norm_cfg`. Default: 'default'.
+ pw_act_cfg (dict): Activation config of pointwise ConvModule. If it is
+ 'default', it will be the same as `act_cfg`. Default: 'default'.
+ kwargs (optional): Other shared arguments for depthwise and pointwise
+ ConvModule. See ConvModule for ref.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ norm_cfg=None,
+ act_cfg=dict(type='ReLU'),
+ dw_norm_cfg='default',
+ dw_act_cfg='default',
+ pw_norm_cfg='default',
+ pw_act_cfg='default',
+ **kwargs):
+ super(DepthwiseSeparableConvModule, self).__init__()
+ assert 'groups' not in kwargs, 'groups should not be specified'
+
+ # if norm/activation config of depthwise/pointwise ConvModule is not
+ # specified, use default config.
+ dw_norm_cfg = dw_norm_cfg if dw_norm_cfg != 'default' else norm_cfg
+ dw_act_cfg = dw_act_cfg if dw_act_cfg != 'default' else act_cfg
+ pw_norm_cfg = pw_norm_cfg if pw_norm_cfg != 'default' else norm_cfg
+ pw_act_cfg = pw_act_cfg if pw_act_cfg != 'default' else act_cfg
+
+ # depthwise convolution
+ self.depthwise_conv = ConvModule(
+ in_channels,
+ in_channels,
+ kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=in_channels,
+ norm_cfg=dw_norm_cfg,
+ act_cfg=dw_act_cfg,
+ **kwargs)
+
+ self.pointwise_conv = ConvModule(
+ in_channels,
+ out_channels,
+ 1,
+ norm_cfg=pw_norm_cfg,
+ act_cfg=pw_act_cfg,
+ **kwargs)
+
+ def forward(self, x):
+ x = self.depthwise_conv(x)
+ x = self.pointwise_conv(x)
+ return x
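A sketch comparing parameter counts, which is the point of the factorization above: a depthwise 3x3 followed by a pointwise 1x1 uses far fewer weights than a dense 3x3 conv with the same channel configuration:

    import torch
    from annotator.uniformer.mmcv.cnn import ConvModule, DepthwiseSeparableConvModule

    dense = ConvModule(64, 128, 3, padding=1, act_cfg=None)
    separable = DepthwiseSeparableConvModule(64, 128, 3, padding=1, act_cfg=None)

    n_dense = sum(p.numel() for p in dense.parameters())     # 64*128*9 + 128
    n_sep = sum(p.numel() for p in separable.parameters())   # 64*9 + 64 + 64*128 + 128
    assert n_sep < n_dense
    y = separable(torch.randn(1, 64, 32, 32))                # (1, 128, 32, 32)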
diff --git a/annotator/uniformer/mmcv/cnn/bricks/drop.py b/annotator/uniformer/mmcv/cnn/bricks/drop.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7b4fccd457a0d51fb10c789df3c8537fe7b67c1
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/bricks/drop.py
@@ -0,0 +1,65 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+
+from annotator.uniformer.mmcv import build_from_cfg
+from .registry import DROPOUT_LAYERS
+
+
+def drop_path(x, drop_prob=0., training=False):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of
+ residual blocks).
+
+ We follow the implementation
+ https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py # noqa: E501
+ """
+ if drop_prob == 0. or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ # handle tensors with different dimensions, not just 4D tensors.
+ shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)
+ random_tensor = keep_prob + torch.rand(
+ shape, dtype=x.dtype, device=x.device)
+ output = x.div(keep_prob) * random_tensor.floor()
+ return output
+
+
+@DROPOUT_LAYERS.register_module()
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of
+ residual blocks).
+
+ We follow the implementation
+ https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py # noqa: E501
+
+ Args:
+ drop_prob (float): Probability of the path to be zeroed. Default: 0.1
+ """
+
+ def __init__(self, drop_prob=0.1):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
+
+
+@DROPOUT_LAYERS.register_module()
+class Dropout(nn.Dropout):
+ """A wrapper for ``torch.nn.Dropout``. We rename the ``p`` of
+ ``torch.nn.Dropout`` to ``drop_prob`` so as to be consistent with
+ ``DropPath``.
+
+ Args:
+ drop_prob (float): Probability of the elements to be
+ zeroed. Default: 0.5.
+ inplace (bool): Do the operation inplace or not. Default: False.
+ """
+
+ def __init__(self, drop_prob=0.5, inplace=False):
+ super().__init__(p=drop_prob, inplace=inplace)
+
+
+def build_dropout(cfg, default_args=None):
+ """Builder for drop out layers."""
+ return build_from_cfg(cfg, DROPOUT_LAYERS, default_args)
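A short sketch of stochastic depth as wrapped above: in eval mode DropPath is the identity, while in training mode whole samples are zeroed with probability drop_prob and the survivors are rescaled by 1 / (1 - drop_prob):

    import torch
    from annotator.uniformer.mmcv.cnn.bricks import DropPath

    dp = DropPath(drop_prob=0.2)
    x = torch.randn(8, 197, 768)          # e.g. a transformer block output

    dp.eval()
    assert torch.equal(dp(x), x)          # identity at inference time

    dp.train()
    out = dp(x)                           # some samples zeroed, the rest scaled by 1/0.8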
diff --git a/annotator/uniformer/mmcv/cnn/bricks/generalized_attention.py b/annotator/uniformer/mmcv/cnn/bricks/generalized_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..988d9adf2f289ef223bd1c680a5ae1d3387f0269
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/bricks/generalized_attention.py
@@ -0,0 +1,412 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..utils import kaiming_init
+from .registry import PLUGIN_LAYERS
+
+
+@PLUGIN_LAYERS.register_module()
+class GeneralizedAttention(nn.Module):
+ """GeneralizedAttention module.
+
+ See 'An Empirical Study of Spatial Attention Mechanisms in Deep Networks'
+ (https://arxiv.org/abs/1904.05873) for details.
+
+ Args:
+ in_channels (int): Channels of the input feature map.
+ spatial_range (int): The spatial range. -1 indicates no spatial range
+ constraint. Default: -1.
+ num_heads (int): The head number of empirical_attention module.
+ Default: 9.
+ position_embedding_dim (int): The position embedding dimension.
+ Default: -1.
+ position_magnitude (int): A multiplier acting on coord difference.
+ Default: 1.
+ kv_stride (int): The feature stride acting on key/value feature map.
+ Default: 2.
+ q_stride (int): The feature stride acting on query feature map.
+ Default: 1.
+ attention_type (str): A binary indicator string for indicating which
+ items in generalized empirical_attention module are used.
+ Default: '1111'.
+
+ - '1000' indicates 'query and key content' (appr - appr) item,
+ - '0100' indicates 'query content and relative position'
+ (appr - position) item,
+ - '0010' indicates 'key content only' (bias - appr) item,
+ - '0001' indicates 'relative position only' (bias - position) item.
+ """
+
+ _abbr_ = 'gen_attention_block'
+
+ def __init__(self,
+ in_channels,
+ spatial_range=-1,
+ num_heads=9,
+ position_embedding_dim=-1,
+ position_magnitude=1,
+ kv_stride=2,
+ q_stride=1,
+ attention_type='1111'):
+
+ super(GeneralizedAttention, self).__init__()
+
+ # hard range means local range for non-local operation
+ self.position_embedding_dim = (
+ position_embedding_dim
+ if position_embedding_dim > 0 else in_channels)
+
+ self.position_magnitude = position_magnitude
+ self.num_heads = num_heads
+ self.in_channels = in_channels
+ self.spatial_range = spatial_range
+ self.kv_stride = kv_stride
+ self.q_stride = q_stride
+ self.attention_type = [bool(int(_)) for _ in attention_type]
+ self.qk_embed_dim = in_channels // num_heads
+ out_c = self.qk_embed_dim * num_heads
+
+ if self.attention_type[0] or self.attention_type[1]:
+ self.query_conv = nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_c,
+ kernel_size=1,
+ bias=False)
+ self.query_conv.kaiming_init = True
+
+ if self.attention_type[0] or self.attention_type[2]:
+ self.key_conv = nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_c,
+ kernel_size=1,
+ bias=False)
+ self.key_conv.kaiming_init = True
+
+ self.v_dim = in_channels // num_heads
+ self.value_conv = nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=self.v_dim * num_heads,
+ kernel_size=1,
+ bias=False)
+ self.value_conv.kaiming_init = True
+
+ if self.attention_type[1] or self.attention_type[3]:
+ self.appr_geom_fc_x = nn.Linear(
+ self.position_embedding_dim // 2, out_c, bias=False)
+ self.appr_geom_fc_x.kaiming_init = True
+
+ self.appr_geom_fc_y = nn.Linear(
+ self.position_embedding_dim // 2, out_c, bias=False)
+ self.appr_geom_fc_y.kaiming_init = True
+
+ if self.attention_type[2]:
+ stdv = 1.0 / math.sqrt(self.qk_embed_dim * 2)
+ appr_bias_value = -2 * stdv * torch.rand(out_c) + stdv
+ self.appr_bias = nn.Parameter(appr_bias_value)
+
+ if self.attention_type[3]:
+ stdv = 1.0 / math.sqrt(self.qk_embed_dim * 2)
+ geom_bias_value = -2 * stdv * torch.rand(out_c) + stdv
+ self.geom_bias = nn.Parameter(geom_bias_value)
+
+ self.proj_conv = nn.Conv2d(
+ in_channels=self.v_dim * num_heads,
+ out_channels=in_channels,
+ kernel_size=1,
+ bias=True)
+ self.proj_conv.kaiming_init = True
+ self.gamma = nn.Parameter(torch.zeros(1))
+
+ if self.spatial_range >= 0:
+ # only works when non local is after 3*3 conv
+ if in_channels == 256:
+ max_len = 84
+ elif in_channels == 512:
+ max_len = 42
+
+ max_len_kv = int((max_len - 1.0) / self.kv_stride + 1)
+ local_constraint_map = np.ones(
+ (max_len, max_len, max_len_kv, max_len_kv), dtype=int)
+ for iy in range(max_len):
+ for ix in range(max_len):
+ local_constraint_map[
+ iy, ix,
+ max((iy - self.spatial_range) //
+ self.kv_stride, 0):min((iy + self.spatial_range +
+ 1) // self.kv_stride +
+ 1, max_len),
+ max((ix - self.spatial_range) //
+ self.kv_stride, 0):min((ix + self.spatial_range +
+ 1) // self.kv_stride +
+ 1, max_len)] = 0
+
+ self.local_constraint_map = nn.Parameter(
+ torch.from_numpy(local_constraint_map).byte(),
+ requires_grad=False)
+
+ if self.q_stride > 1:
+ self.q_downsample = nn.AvgPool2d(
+ kernel_size=1, stride=self.q_stride)
+ else:
+ self.q_downsample = None
+
+ if self.kv_stride > 1:
+ self.kv_downsample = nn.AvgPool2d(
+ kernel_size=1, stride=self.kv_stride)
+ else:
+ self.kv_downsample = None
+
+ self.init_weights()
+
+ def get_position_embedding(self,
+ h,
+ w,
+ h_kv,
+ w_kv,
+ q_stride,
+ kv_stride,
+ device,
+ dtype,
+ feat_dim,
+ wave_length=1000):
+ # the default type of Tensor is float32, leading to type mismatch
+ # in fp16 mode. Cast it to support fp16 mode.
+ h_idxs = torch.linspace(0, h - 1, h).to(device=device, dtype=dtype)
+ h_idxs = h_idxs.view((h, 1)) * q_stride
+
+ w_idxs = torch.linspace(0, w - 1, w).to(device=device, dtype=dtype)
+ w_idxs = w_idxs.view((w, 1)) * q_stride
+
+ h_kv_idxs = torch.linspace(0, h_kv - 1, h_kv).to(
+ device=device, dtype=dtype)
+ h_kv_idxs = h_kv_idxs.view((h_kv, 1)) * kv_stride
+
+ w_kv_idxs = torch.linspace(0, w_kv - 1, w_kv).to(
+ device=device, dtype=dtype)
+ w_kv_idxs = w_kv_idxs.view((w_kv, 1)) * kv_stride
+
+ # (h, h_kv, 1)
+ h_diff = h_idxs.unsqueeze(1) - h_kv_idxs.unsqueeze(0)
+ h_diff *= self.position_magnitude
+
+ # (w, w_kv, 1)
+ w_diff = w_idxs.unsqueeze(1) - w_kv_idxs.unsqueeze(0)
+ w_diff *= self.position_magnitude
+
+ feat_range = torch.arange(0, feat_dim / 4).to(
+ device=device, dtype=dtype)
+
+ dim_mat = torch.Tensor([wave_length]).to(device=device, dtype=dtype)
+ dim_mat = dim_mat**((4. / feat_dim) * feat_range)
+ dim_mat = dim_mat.view((1, 1, -1))
+
+ embedding_x = torch.cat(
+ ((w_diff / dim_mat).sin(), (w_diff / dim_mat).cos()), dim=2)
+
+ embedding_y = torch.cat(
+ ((h_diff / dim_mat).sin(), (h_diff / dim_mat).cos()), dim=2)
+
+ return embedding_x, embedding_y
+
+ def forward(self, x_input):
+ num_heads = self.num_heads
+
+ # use empirical_attention
+ if self.q_downsample is not None:
+ x_q = self.q_downsample(x_input)
+ else:
+ x_q = x_input
+ n, _, h, w = x_q.shape
+
+ if self.kv_downsample is not None:
+ x_kv = self.kv_downsample(x_input)
+ else:
+ x_kv = x_input
+ _, _, h_kv, w_kv = x_kv.shape
+
+ if self.attention_type[0] or self.attention_type[1]:
+ proj_query = self.query_conv(x_q).view(
+ (n, num_heads, self.qk_embed_dim, h * w))
+ proj_query = proj_query.permute(0, 1, 3, 2)
+
+ if self.attention_type[0] or self.attention_type[2]:
+ proj_key = self.key_conv(x_kv).view(
+ (n, num_heads, self.qk_embed_dim, h_kv * w_kv))
+
+ if self.attention_type[1] or self.attention_type[3]:
+ position_embed_x, position_embed_y = self.get_position_embedding(
+ h, w, h_kv, w_kv, self.q_stride, self.kv_stride,
+ x_input.device, x_input.dtype, self.position_embedding_dim)
+ # (n, num_heads, w, w_kv, dim)
+ position_feat_x = self.appr_geom_fc_x(position_embed_x).\
+ view(1, w, w_kv, num_heads, self.qk_embed_dim).\
+ permute(0, 3, 1, 2, 4).\
+ repeat(n, 1, 1, 1, 1)
+
+ # (n, num_heads, h, h_kv, dim)
+ position_feat_y = self.appr_geom_fc_y(position_embed_y).\
+ view(1, h, h_kv, num_heads, self.qk_embed_dim).\
+ permute(0, 3, 1, 2, 4).\
+ repeat(n, 1, 1, 1, 1)
+
+ position_feat_x /= math.sqrt(2)
+ position_feat_y /= math.sqrt(2)
+
+ # accelerate for saliency only
+ if (np.sum(self.attention_type) == 1) and self.attention_type[2]:
+ appr_bias = self.appr_bias.\
+ view(1, num_heads, 1, self.qk_embed_dim).\
+ repeat(n, 1, 1, 1)
+
+ energy = torch.matmul(appr_bias, proj_key).\
+ view(n, num_heads, 1, h_kv * w_kv)
+
+ h = 1
+ w = 1
+ else:
+ # (n, num_heads, h*w, h_kv*w_kv), query before key, 540mb for
+ if not self.attention_type[0]:
+ energy = torch.zeros(
+ n,
+ num_heads,
+ h,
+ w,
+ h_kv,
+ w_kv,
+ dtype=x_input.dtype,
+ device=x_input.device)
+
+ # attention_type[0]: appr - appr
+ # attention_type[1]: appr - position
+ # attention_type[2]: bias - appr
+ # attention_type[3]: bias - position
+ if self.attention_type[0] or self.attention_type[2]:
+ if self.attention_type[0] and self.attention_type[2]:
+ appr_bias = self.appr_bias.\
+ view(1, num_heads, 1, self.qk_embed_dim)
+ energy = torch.matmul(proj_query + appr_bias, proj_key).\
+ view(n, num_heads, h, w, h_kv, w_kv)
+
+ elif self.attention_type[0]:
+ energy = torch.matmul(proj_query, proj_key).\
+ view(n, num_heads, h, w, h_kv, w_kv)
+
+ elif self.attention_type[2]:
+ appr_bias = self.appr_bias.\
+ view(1, num_heads, 1, self.qk_embed_dim).\
+ repeat(n, 1, 1, 1)
+
+ energy += torch.matmul(appr_bias, proj_key).\
+ view(n, num_heads, 1, 1, h_kv, w_kv)
+
+ if self.attention_type[1] or self.attention_type[3]:
+ if self.attention_type[1] and self.attention_type[3]:
+ geom_bias = self.geom_bias.\
+ view(1, num_heads, 1, self.qk_embed_dim)
+
+ proj_query_reshape = (proj_query + geom_bias).\
+ view(n, num_heads, h, w, self.qk_embed_dim)
+
+ energy_x = torch.matmul(
+ proj_query_reshape.permute(0, 1, 3, 2, 4),
+ position_feat_x.permute(0, 1, 2, 4, 3))
+ energy_x = energy_x.\
+ permute(0, 1, 3, 2, 4).unsqueeze(4)
+
+ energy_y = torch.matmul(
+ proj_query_reshape,
+ position_feat_y.permute(0, 1, 2, 4, 3))
+ energy_y = energy_y.unsqueeze(5)
+
+ energy += energy_x + energy_y
+
+ elif self.attention_type[1]:
+ proj_query_reshape = proj_query.\
+ view(n, num_heads, h, w, self.qk_embed_dim)
+ proj_query_reshape = proj_query_reshape.\
+ permute(0, 1, 3, 2, 4)
+ position_feat_x_reshape = position_feat_x.\
+ permute(0, 1, 2, 4, 3)
+ position_feat_y_reshape = position_feat_y.\
+ permute(0, 1, 2, 4, 3)
+
+ energy_x = torch.matmul(proj_query_reshape,
+ position_feat_x_reshape)
+ energy_x = energy_x.permute(0, 1, 3, 2, 4).unsqueeze(4)
+
+ energy_y = torch.matmul(proj_query_reshape,
+ position_feat_y_reshape)
+ energy_y = energy_y.unsqueeze(5)
+
+ energy += energy_x + energy_y
+
+ elif self.attention_type[3]:
+ geom_bias = self.geom_bias.\
+ view(1, num_heads, self.qk_embed_dim, 1).\
+ repeat(n, 1, 1, 1)
+
+ position_feat_x_reshape = position_feat_x.\
+ view(n, num_heads, w*w_kv, self.qk_embed_dim)
+
+ position_feat_y_reshape = position_feat_y.\
+ view(n, num_heads, h * h_kv, self.qk_embed_dim)
+
+ energy_x = torch.matmul(position_feat_x_reshape, geom_bias)
+ energy_x = energy_x.view(n, num_heads, 1, w, 1, w_kv)
+
+ energy_y = torch.matmul(position_feat_y_reshape, geom_bias)
+ energy_y = energy_y.view(n, num_heads, h, 1, h_kv, 1)
+
+ energy += energy_x + energy_y
+
+ energy = energy.view(n, num_heads, h * w, h_kv * w_kv)
+
+ if self.spatial_range >= 0:
+ cur_local_constraint_map = \
+ self.local_constraint_map[:h, :w, :h_kv, :w_kv].\
+ contiguous().\
+ view(1, 1, h*w, h_kv*w_kv)
+
+ energy = energy.masked_fill_(cur_local_constraint_map,
+ float('-inf'))
+
+ attention = F.softmax(energy, 3)
+
+ proj_value = self.value_conv(x_kv)
+ proj_value_reshape = proj_value.\
+ view((n, num_heads, self.v_dim, h_kv * w_kv)).\
+ permute(0, 1, 3, 2)
+
+ out = torch.matmul(attention, proj_value_reshape).\
+ permute(0, 1, 3, 2).\
+ contiguous().\
+ view(n, self.v_dim * self.num_heads, h, w)
+
+ out = self.proj_conv(out)
+
+ # output is downsampled, upsample back to input size
+ if self.q_downsample is not None:
+ out = F.interpolate(
+ out,
+ size=x_input.shape[2:],
+ mode='bilinear',
+ align_corners=False)
+
+ out = self.gamma * out + x_input
+ return out
+
+ def init_weights(self):
+ for m in self.modules():
+ if hasattr(m, 'kaiming_init') and m.kaiming_init:
+ kaiming_init(
+ m,
+ mode='fan_in',
+ nonlinearity='leaky_relu',
+ bias=0,
+ distribution='uniform',
+ a=1)
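A minimal forward-pass sketch for the attention plugin above: the output shape matches the input, and because gamma is zero-initialized the block starts out as an identity mapping, so it can be inserted into a pretrained backbone without disturbing it:

    import torch
    from annotator.uniformer.mmcv.cnn import GeneralizedAttention

    attn = GeneralizedAttention(in_channels=256, num_heads=8)   # default '1111'
    x = torch.randn(2, 256, 32, 32)
    out = attn(x)                         # (2, 256, 32, 32)
    assert torch.allclose(out, x)         # gamma starts at zero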
diff --git a/annotator/uniformer/mmcv/cnn/bricks/hsigmoid.py b/annotator/uniformer/mmcv/cnn/bricks/hsigmoid.py
new file mode 100644
index 0000000000000000000000000000000000000000..30b1a3d6580cf0360710426fbea1f05acdf07b4b
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/bricks/hsigmoid.py
@@ -0,0 +1,34 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+
+from .registry import ACTIVATION_LAYERS
+
+
+@ACTIVATION_LAYERS.register_module()
+class HSigmoid(nn.Module):
+ """Hard Sigmoid Module. Apply the hard sigmoid function:
+ Hsigmoid(x) = min(max((x + bias) / divisor, min_value), max_value)
+ Default: Hsigmoid(x) = min(max((x + 1) / 2, 0), 1)
+
+ Args:
+ bias (float): Bias of the input feature map. Default: 1.0.
+ divisor (float): Divisor of the input feature map. Default: 2.0.
+ min_value (float): Lower bound value. Default: 0.0.
+ max_value (float): Upper bound value. Default: 1.0.
+
+ Returns:
+ Tensor: The output tensor.
+ """
+
+ def __init__(self, bias=1.0, divisor=2.0, min_value=0.0, max_value=1.0):
+ super(HSigmoid, self).__init__()
+ self.bias = bias
+ self.divisor = divisor
+ assert self.divisor != 0
+ self.min_value = min_value
+ self.max_value = max_value
+
+ def forward(self, x):
+ x = (x + self.bias) / self.divisor
+
+ return x.clamp_(self.min_value, self.max_value)
diff --git a/annotator/uniformer/mmcv/cnn/bricks/hswish.py b/annotator/uniformer/mmcv/cnn/bricks/hswish.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e0c090ff037c99ee6c5c84c4592e87beae02208
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/bricks/hswish.py
@@ -0,0 +1,29 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+
+from .registry import ACTIVATION_LAYERS
+
+
+@ACTIVATION_LAYERS.register_module()
+class HSwish(nn.Module):
+ """Hard Swish Module.
+
+ This module applies the hard swish function:
+
+ .. math::
+ Hswish(x) = x * ReLU6(x + 3) / 6
+
+ Args:
+ inplace (bool): can optionally do the operation in-place.
+ Default: False.
+
+ Returns:
+ Tensor: The output tensor.
+ """
+
+ def __init__(self, inplace=False):
+ super(HSwish, self).__init__()
+ self.act = nn.ReLU6(inplace)
+
+ def forward(self, x):
+ return x * self.act(x + 3) / 6
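A tiny numeric check of the two hard activations defined above with their default parameters, i.e. HSigmoid(x) = clamp((x + 1) / 2, 0, 1) and HSwish(x) = x * ReLU6(x + 3) / 6:

    import torch
    from annotator.uniformer.mmcv.cnn import HSigmoid, HSwish

    x = torch.tensor([-4.0, -1.0, 0.0, 1.0, 4.0])
    print(HSigmoid()(x))   # values: [0.0, 0.0, 0.5, 1.0, 1.0]
    print(HSwish()(x))     # values: [0.0, -0.3333, 0.0, 0.6667, 4.0]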
diff --git a/annotator/uniformer/mmcv/cnn/bricks/non_local.py b/annotator/uniformer/mmcv/cnn/bricks/non_local.py
new file mode 100644
index 0000000000000000000000000000000000000000..92d00155ef275c1201ea66bba30470a1785cc5d7
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/bricks/non_local.py
@@ -0,0 +1,306 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABCMeta
+
+import torch
+import torch.nn as nn
+
+from ..utils import constant_init, normal_init
+from .conv_module import ConvModule
+from .registry import PLUGIN_LAYERS
+
+
+class _NonLocalNd(nn.Module, metaclass=ABCMeta):
+ """Basic Non-local module.
+
+ This module is proposed in
+ "Non-local Neural Networks"
+ Paper reference: https://arxiv.org/abs/1711.07971
+ Code reference: https://github.com/AlexHex7/Non-local_pytorch
+
+ Args:
+ in_channels (int): Channels of the input feature map.
+ reduction (int): Channel reduction ratio. Default: 2.
+ use_scale (bool): Whether to scale pairwise_weight by
+ `1/sqrt(inter_channels)` when the mode is `embedded_gaussian`.
+ Default: True.
+ conv_cfg (None | dict): The config dict for convolution layers.
+ If not specified, it will use `nn.Conv2d` for convolution layers.
+ Default: None.
+ norm_cfg (None | dict): The config dict for normalization layers.
+ Default: None. (This parameter is only applicable to conv_out.)
+ mode (str): Options are `gaussian`, `concatenation`,
+ `embedded_gaussian` and `dot_product`. Default: embedded_gaussian.
+ """
+
+ def __init__(self,
+ in_channels,
+ reduction=2,
+ use_scale=True,
+ conv_cfg=None,
+ norm_cfg=None,
+ mode='embedded_gaussian',
+ **kwargs):
+ super(_NonLocalNd, self).__init__()
+ self.in_channels = in_channels
+ self.reduction = reduction
+ self.use_scale = use_scale
+ self.inter_channels = max(in_channels // reduction, 1)
+ self.mode = mode
+
+ if mode not in [
+ 'gaussian', 'embedded_gaussian', 'dot_product', 'concatenation'
+ ]:
+ raise ValueError("Mode should be in 'gaussian', 'concatenation', "
+ f"'embedded_gaussian' or 'dot_product', but got "
+ f'{mode} instead.')
+
+        # g, theta and phi would normally be plain `nn.ConvNd` layers;
+        # ConvModule is used here so that `conv_cfg` can be customized.
+ self.g = ConvModule(
+ self.in_channels,
+ self.inter_channels,
+ kernel_size=1,
+ conv_cfg=conv_cfg,
+ act_cfg=None)
+ self.conv_out = ConvModule(
+ self.inter_channels,
+ self.in_channels,
+ kernel_size=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=None)
+
+ if self.mode != 'gaussian':
+ self.theta = ConvModule(
+ self.in_channels,
+ self.inter_channels,
+ kernel_size=1,
+ conv_cfg=conv_cfg,
+ act_cfg=None)
+ self.phi = ConvModule(
+ self.in_channels,
+ self.inter_channels,
+ kernel_size=1,
+ conv_cfg=conv_cfg,
+ act_cfg=None)
+
+ if self.mode == 'concatenation':
+ self.concat_project = ConvModule(
+ self.inter_channels * 2,
+ 1,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=False,
+ act_cfg=dict(type='ReLU'))
+
+ self.init_weights(**kwargs)
+
+ def init_weights(self, std=0.01, zeros_init=True):
+ if self.mode != 'gaussian':
+ for m in [self.g, self.theta, self.phi]:
+ normal_init(m.conv, std=std)
+ else:
+ normal_init(self.g.conv, std=std)
+ if zeros_init:
+ if self.conv_out.norm_cfg is None:
+ constant_init(self.conv_out.conv, 0)
+ else:
+ constant_init(self.conv_out.norm, 0)
+ else:
+ if self.conv_out.norm_cfg is None:
+ normal_init(self.conv_out.conv, std=std)
+ else:
+ normal_init(self.conv_out.norm, std=std)
+
+ def gaussian(self, theta_x, phi_x):
+ # NonLocal1d pairwise_weight: [N, H, H]
+ # NonLocal2d pairwise_weight: [N, HxW, HxW]
+ # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
+ pairwise_weight = torch.matmul(theta_x, phi_x)
+ pairwise_weight = pairwise_weight.softmax(dim=-1)
+ return pairwise_weight
+
+ def embedded_gaussian(self, theta_x, phi_x):
+ # NonLocal1d pairwise_weight: [N, H, H]
+ # NonLocal2d pairwise_weight: [N, HxW, HxW]
+ # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
+ pairwise_weight = torch.matmul(theta_x, phi_x)
+ if self.use_scale:
+ # theta_x.shape[-1] is `self.inter_channels`
+ pairwise_weight /= theta_x.shape[-1]**0.5
+ pairwise_weight = pairwise_weight.softmax(dim=-1)
+ return pairwise_weight
+
+ def dot_product(self, theta_x, phi_x):
+ # NonLocal1d pairwise_weight: [N, H, H]
+ # NonLocal2d pairwise_weight: [N, HxW, HxW]
+ # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
+ pairwise_weight = torch.matmul(theta_x, phi_x)
+ pairwise_weight /= pairwise_weight.shape[-1]
+ return pairwise_weight
+
+ def concatenation(self, theta_x, phi_x):
+ # NonLocal1d pairwise_weight: [N, H, H]
+ # NonLocal2d pairwise_weight: [N, HxW, HxW]
+ # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
+ h = theta_x.size(2)
+ w = phi_x.size(3)
+ theta_x = theta_x.repeat(1, 1, 1, w)
+ phi_x = phi_x.repeat(1, 1, h, 1)
+
+ concat_feature = torch.cat([theta_x, phi_x], dim=1)
+ pairwise_weight = self.concat_project(concat_feature)
+ n, _, h, w = pairwise_weight.size()
+ pairwise_weight = pairwise_weight.view(n, h, w)
+ pairwise_weight /= pairwise_weight.shape[-1]
+
+ return pairwise_weight
+
+ def forward(self, x):
+        # Assume `reduction = 1`, so `inter_channels = C`; in `gaussian`
+        # mode theta/phi work on the raw input, whose channel dim is
+        # `in_channels` (= C) as well.
+
+ # NonLocal1d x: [N, C, H]
+ # NonLocal2d x: [N, C, H, W]
+ # NonLocal3d x: [N, C, T, H, W]
+ n = x.size(0)
+
+ # NonLocal1d g_x: [N, H, C]
+ # NonLocal2d g_x: [N, HxW, C]
+ # NonLocal3d g_x: [N, TxHxW, C]
+ g_x = self.g(x).view(n, self.inter_channels, -1)
+ g_x = g_x.permute(0, 2, 1)
+
+ # NonLocal1d theta_x: [N, H, C], phi_x: [N, C, H]
+ # NonLocal2d theta_x: [N, HxW, C], phi_x: [N, C, HxW]
+ # NonLocal3d theta_x: [N, TxHxW, C], phi_x: [N, C, TxHxW]
+ if self.mode == 'gaussian':
+ theta_x = x.view(n, self.in_channels, -1)
+ theta_x = theta_x.permute(0, 2, 1)
+ if self.sub_sample:
+ phi_x = self.phi(x).view(n, self.in_channels, -1)
+ else:
+ phi_x = x.view(n, self.in_channels, -1)
+ elif self.mode == 'concatenation':
+ theta_x = self.theta(x).view(n, self.inter_channels, -1, 1)
+ phi_x = self.phi(x).view(n, self.inter_channels, 1, -1)
+ else:
+ theta_x = self.theta(x).view(n, self.inter_channels, -1)
+ theta_x = theta_x.permute(0, 2, 1)
+ phi_x = self.phi(x).view(n, self.inter_channels, -1)
+
+ pairwise_func = getattr(self, self.mode)
+ # NonLocal1d pairwise_weight: [N, H, H]
+ # NonLocal2d pairwise_weight: [N, HxW, HxW]
+ # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
+ pairwise_weight = pairwise_func(theta_x, phi_x)
+
+ # NonLocal1d y: [N, H, C]
+ # NonLocal2d y: [N, HxW, C]
+ # NonLocal3d y: [N, TxHxW, C]
+ y = torch.matmul(pairwise_weight, g_x)
+ # NonLocal1d y: [N, C, H]
+ # NonLocal2d y: [N, C, H, W]
+ # NonLocal3d y: [N, C, T, H, W]
+ y = y.permute(0, 2, 1).contiguous().reshape(n, self.inter_channels,
+ *x.size()[2:])
+
+ output = x + self.conv_out(y)
+
+ return output
+
+
+class NonLocal1d(_NonLocalNd):
+ """1D Non-local module.
+
+ Args:
+ in_channels (int): Same as `NonLocalND`.
+ sub_sample (bool): Whether to apply max pooling after pairwise
+ function (Note that the `sub_sample` is applied on spatial only).
+ Default: False.
+ conv_cfg (None | dict): Same as `NonLocalND`.
+ Default: dict(type='Conv1d').
+ """
+
+ def __init__(self,
+ in_channels,
+ sub_sample=False,
+ conv_cfg=dict(type='Conv1d'),
+ **kwargs):
+ super(NonLocal1d, self).__init__(
+ in_channels, conv_cfg=conv_cfg, **kwargs)
+
+ self.sub_sample = sub_sample
+
+ if sub_sample:
+ max_pool_layer = nn.MaxPool1d(kernel_size=2)
+ self.g = nn.Sequential(self.g, max_pool_layer)
+ if self.mode != 'gaussian':
+ self.phi = nn.Sequential(self.phi, max_pool_layer)
+ else:
+ self.phi = max_pool_layer
+
+
+@PLUGIN_LAYERS.register_module()
+class NonLocal2d(_NonLocalNd):
+ """2D Non-local module.
+
+ Args:
+ in_channels (int): Same as `NonLocalND`.
+ sub_sample (bool): Whether to apply max pooling after pairwise
+ function (Note that the `sub_sample` is applied on spatial only).
+ Default: False.
+ conv_cfg (None | dict): Same as `NonLocalND`.
+ Default: dict(type='Conv2d').
+ """
+
+ _abbr_ = 'nonlocal_block'
+
+ def __init__(self,
+ in_channels,
+ sub_sample=False,
+ conv_cfg=dict(type='Conv2d'),
+ **kwargs):
+ super(NonLocal2d, self).__init__(
+ in_channels, conv_cfg=conv_cfg, **kwargs)
+
+ self.sub_sample = sub_sample
+
+ if sub_sample:
+ max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
+ self.g = nn.Sequential(self.g, max_pool_layer)
+ if self.mode != 'gaussian':
+ self.phi = nn.Sequential(self.phi, max_pool_layer)
+ else:
+ self.phi = max_pool_layer
+
+
+class NonLocal3d(_NonLocalNd):
+ """3D Non-local module.
+
+ Args:
+ in_channels (int): Same as `NonLocalND`.
+ sub_sample (bool): Whether to apply max pooling after pairwise
+ function (Note that the `sub_sample` is applied on spatial only).
+ Default: False.
+ conv_cfg (None | dict): Same as `NonLocalND`.
+ Default: dict(type='Conv3d').
+ """
+
+ def __init__(self,
+ in_channels,
+ sub_sample=False,
+ conv_cfg=dict(type='Conv3d'),
+ **kwargs):
+ super(NonLocal3d, self).__init__(
+ in_channels, conv_cfg=conv_cfg, **kwargs)
+ self.sub_sample = sub_sample
+
+ if sub_sample:
+ max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
+ self.g = nn.Sequential(self.g, max_pool_layer)
+ if self.mode != 'gaussian':
+ self.phi = nn.Sequential(self.phi, max_pool_layer)
+ else:
+ self.phi = max_pool_layer
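
An illustrative sketch (not part of the patch) of the 2D block on a dummy feature map; the residual connection keeps the input shape regardless of `sub_sample`:

```python
import torch
from annotator.uniformer.mmcv.cnn.bricks.non_local import NonLocal2d

block = NonLocal2d(in_channels=16, reduction=2, mode='embedded_gaussian',
                   norm_cfg=dict(type='BN'), sub_sample=True)
x = torch.randn(2, 16, 20, 20)
out = block(x)      # phi/g see a 10x10 grid after the max pooling
print(out.shape)    # torch.Size([2, 16, 20, 20])
```
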
diff --git a/annotator/uniformer/mmcv/cnn/bricks/norm.py b/annotator/uniformer/mmcv/cnn/bricks/norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..408f4b42731b19a3beeef68b6a5e610d0bbc18b3
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/bricks/norm.py
@@ -0,0 +1,144 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import inspect
+
+import torch.nn as nn
+
+from annotator.uniformer.mmcv.utils import is_tuple_of
+from annotator.uniformer.mmcv.utils.parrots_wrapper import SyncBatchNorm, _BatchNorm, _InstanceNorm
+from .registry import NORM_LAYERS
+
+NORM_LAYERS.register_module('BN', module=nn.BatchNorm2d)
+NORM_LAYERS.register_module('BN1d', module=nn.BatchNorm1d)
+NORM_LAYERS.register_module('BN2d', module=nn.BatchNorm2d)
+NORM_LAYERS.register_module('BN3d', module=nn.BatchNorm3d)
+NORM_LAYERS.register_module('SyncBN', module=SyncBatchNorm)
+NORM_LAYERS.register_module('GN', module=nn.GroupNorm)
+NORM_LAYERS.register_module('LN', module=nn.LayerNorm)
+NORM_LAYERS.register_module('IN', module=nn.InstanceNorm2d)
+NORM_LAYERS.register_module('IN1d', module=nn.InstanceNorm1d)
+NORM_LAYERS.register_module('IN2d', module=nn.InstanceNorm2d)
+NORM_LAYERS.register_module('IN3d', module=nn.InstanceNorm3d)
+
+
+def infer_abbr(class_type):
+ """Infer abbreviation from the class name.
+
+ When we build a norm layer with `build_norm_layer()`, we want to preserve
+ the norm type in variable names, e.g, self.bn1, self.gn. This method will
+ infer the abbreviation to map class types to abbreviations.
+
+ Rule 1: If the class has the property "_abbr_", return the property.
+ Rule 2: If the parent class is _BatchNorm, GroupNorm, LayerNorm or
+ InstanceNorm, the abbreviation of this layer will be "bn", "gn", "ln" and
+ "in" respectively.
+ Rule 3: If the class name contains "batch", "group", "layer" or "instance",
+ the abbreviation of this layer will be "bn", "gn", "ln" and "in"
+ respectively.
+    Rule 4: Otherwise, the abbreviation falls back to "norm_layer".
+
+ Args:
+ class_type (type): The norm layer type.
+
+ Returns:
+ str: The inferred abbreviation.
+ """
+ if not inspect.isclass(class_type):
+ raise TypeError(
+ f'class_type must be a type, but got {type(class_type)}')
+ if hasattr(class_type, '_abbr_'):
+ return class_type._abbr_
+ if issubclass(class_type, _InstanceNorm): # IN is a subclass of BN
+ return 'in'
+ elif issubclass(class_type, _BatchNorm):
+ return 'bn'
+ elif issubclass(class_type, nn.GroupNorm):
+ return 'gn'
+ elif issubclass(class_type, nn.LayerNorm):
+ return 'ln'
+ else:
+ class_name = class_type.__name__.lower()
+ if 'batch' in class_name:
+ return 'bn'
+ elif 'group' in class_name:
+ return 'gn'
+ elif 'layer' in class_name:
+ return 'ln'
+ elif 'instance' in class_name:
+ return 'in'
+ else:
+ return 'norm_layer'
+
+
+def build_norm_layer(cfg, num_features, postfix=''):
+ """Build normalization layer.
+
+ Args:
+ cfg (dict): The norm layer config, which should contain:
+
+ - type (str): Layer type.
+ - layer args: Args needed to instantiate a norm layer.
+            - requires_grad (bool, optional): Whether to stop gradient updates.
+ num_features (int): Number of input channels.
+ postfix (int | str): The postfix to be appended into norm abbreviation
+ to create named layer.
+
+ Returns:
+ (str, nn.Module): The first element is the layer name consisting of
+ abbreviation and postfix, e.g., bn1, gn. The second element is the
+ created norm layer.
+ """
+ if not isinstance(cfg, dict):
+ raise TypeError('cfg must be a dict')
+ if 'type' not in cfg:
+ raise KeyError('the cfg dict must contain the key "type"')
+ cfg_ = cfg.copy()
+
+ layer_type = cfg_.pop('type')
+ if layer_type not in NORM_LAYERS:
+ raise KeyError(f'Unrecognized norm type {layer_type}')
+
+ norm_layer = NORM_LAYERS.get(layer_type)
+ abbr = infer_abbr(norm_layer)
+
+ assert isinstance(postfix, (int, str))
+ name = abbr + str(postfix)
+
+ requires_grad = cfg_.pop('requires_grad', True)
+ cfg_.setdefault('eps', 1e-5)
+ if layer_type != 'GN':
+ layer = norm_layer(num_features, **cfg_)
+ if layer_type == 'SyncBN' and hasattr(layer, '_specify_ddp_gpu_num'):
+ layer._specify_ddp_gpu_num(1)
+ else:
+ assert 'num_groups' in cfg_
+ layer = norm_layer(num_channels=num_features, **cfg_)
+
+ for param in layer.parameters():
+ param.requires_grad = requires_grad
+
+ return name, layer
+
+
+def is_norm(layer, exclude=None):
+ """Check if a layer is a normalization layer.
+
+ Args:
+ layer (nn.Module): The layer to be checked.
+ exclude (type | tuple[type]): Types to be excluded.
+
+ Returns:
+ bool: Whether the layer is a norm layer.
+ """
+ if exclude is not None:
+ if not isinstance(exclude, tuple):
+ exclude = (exclude, )
+ if not is_tuple_of(exclude, type):
+ raise TypeError(
+ f'"exclude" must be either None or type or a tuple of types, '
+ f'but got {type(exclude)}: {exclude}')
+
+ if exclude and isinstance(layer, exclude):
+ return False
+
+ all_norm_bases = (_BatchNorm, _InstanceNorm, nn.GroupNorm, nn.LayerNorm)
+ return isinstance(layer, all_norm_bases)
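
A usage sketch of `build_norm_layer` (illustration only): the returned name combines the inferred abbreviation with the postfix, and `requires_grad` from the config is applied to the layer's parameters:

```python
from annotator.uniformer.mmcv.cnn.bricks.norm import build_norm_layer

name, bn = build_norm_layer(dict(type='BN', requires_grad=True), 64, postfix=1)
print(name, bn)   # 'bn1' BatchNorm2d(64, eps=1e-05, ...)

name, gn = build_norm_layer(dict(type='GN', num_groups=8), 64)
print(name, gn)   # 'gn' GroupNorm(8, 64, eps=1e-05, affine=True)
```
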
diff --git a/annotator/uniformer/mmcv/cnn/bricks/padding.py b/annotator/uniformer/mmcv/cnn/bricks/padding.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4ac6b28a1789bd551c613a7d3e7b622433ac7ec
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/bricks/padding.py
@@ -0,0 +1,36 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+
+from .registry import PADDING_LAYERS
+
+PADDING_LAYERS.register_module('zero', module=nn.ZeroPad2d)
+PADDING_LAYERS.register_module('reflect', module=nn.ReflectionPad2d)
+PADDING_LAYERS.register_module('replicate', module=nn.ReplicationPad2d)
+
+
+def build_padding_layer(cfg, *args, **kwargs):
+ """Build padding layer.
+
+ Args:
+ cfg (None or dict): The padding layer config, which should contain:
+ - type (str): Layer type.
+ - layer args: Args needed to instantiate a padding layer.
+
+ Returns:
+ nn.Module: Created padding layer.
+ """
+ if not isinstance(cfg, dict):
+ raise TypeError('cfg must be a dict')
+ if 'type' not in cfg:
+ raise KeyError('the cfg dict must contain the key "type"')
+
+ cfg_ = cfg.copy()
+ padding_type = cfg_.pop('type')
+ if padding_type not in PADDING_LAYERS:
+ raise KeyError(f'Unrecognized padding type {padding_type}.')
+ else:
+ padding_layer = PADDING_LAYERS.get(padding_type)
+
+ layer = padding_layer(*args, **kwargs, **cfg_)
+
+ return layer
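
A short sketch of `build_padding_layer`; positional arguments after the config are forwarded to the padding module's constructor:

```python
import torch
from annotator.uniformer.mmcv.cnn.bricks.padding import build_padding_layer

pad = build_padding_layer(dict(type='reflect'), 2)  # -> nn.ReflectionPad2d(2)
print(pad(torch.zeros(1, 3, 8, 8)).shape)           # torch.Size([1, 3, 12, 12])
```
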
diff --git a/annotator/uniformer/mmcv/cnn/bricks/plugin.py b/annotator/uniformer/mmcv/cnn/bricks/plugin.py
new file mode 100644
index 0000000000000000000000000000000000000000..07c010d4053174dd41107aa654ea67e82b46a25c
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/bricks/plugin.py
@@ -0,0 +1,88 @@
+import inspect
+import platform
+
+from .registry import PLUGIN_LAYERS
+
+if platform.system() == 'Windows':
+ import regex as re
+else:
+ import re
+
+
+def infer_abbr(class_type):
+ """Infer abbreviation from the class name.
+
+ This method will infer the abbreviation to map class types to
+ abbreviations.
+
+    Rule 1: If the class has the property "_abbr_", return the property.
+ Rule 2: Otherwise, the abbreviation falls back to snake case of class
+ name, e.g. the abbreviation of ``FancyBlock`` will be ``fancy_block``.
+
+ Args:
+        class_type (type): The plugin layer type.
+
+ Returns:
+ str: The inferred abbreviation.
+ """
+
+ def camel2snack(word):
+        """Convert a camel-case word into snake case.
+
+        Modified from the `inflection` library.
+
+ Example::
+
+ >>> camel2snack("FancyBlock")
+ 'fancy_block'
+ """
+
+ word = re.sub(r'([A-Z]+)([A-Z][a-z])', r'\1_\2', word)
+ word = re.sub(r'([a-z\d])([A-Z])', r'\1_\2', word)
+ word = word.replace('-', '_')
+ return word.lower()
+
+ if not inspect.isclass(class_type):
+ raise TypeError(
+ f'class_type must be a type, but got {type(class_type)}')
+ if hasattr(class_type, '_abbr_'):
+ return class_type._abbr_
+ else:
+ return camel2snack(class_type.__name__)
+
+
+def build_plugin_layer(cfg, postfix='', **kwargs):
+ """Build plugin layer.
+
+ Args:
+ cfg (None or dict): cfg should contain:
+ type (str): identify plugin layer type.
+ layer args: args needed to instantiate a plugin layer.
+ postfix (int, str): appended into norm abbreviation to
+ create named layer. Default: ''.
+
+ Returns:
+ tuple[str, nn.Module]:
+ name (str): abbreviation + postfix
+ layer (nn.Module): created plugin layer
+ """
+ if not isinstance(cfg, dict):
+ raise TypeError('cfg must be a dict')
+ if 'type' not in cfg:
+ raise KeyError('the cfg dict must contain the key "type"')
+ cfg_ = cfg.copy()
+
+ layer_type = cfg_.pop('type')
+ if layer_type not in PLUGIN_LAYERS:
+ raise KeyError(f'Unrecognized plugin type {layer_type}')
+
+ plugin_layer = PLUGIN_LAYERS.get(layer_type)
+ abbr = infer_abbr(plugin_layer)
+
+ assert isinstance(postfix, (int, str))
+ name = abbr + str(postfix)
+
+ layer = plugin_layer(**kwargs, **cfg_)
+
+ return name, layer
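
A sketch of the naming behaviour, using `NonLocal2d` from earlier in this diff (registered in `PLUGIN_LAYERS` with `_abbr_ = 'nonlocal_block'`); importing the vendored `mmcv.cnn` package is assumed to trigger that registration:

```python
from annotator.uniformer.mmcv.cnn import build_plugin_layer

name, layer = build_plugin_layer(dict(type='NonLocal2d', in_channels=16),
                                 postfix='_1')
print(name)                  # 'nonlocal_block_1'
print(type(layer).__name__)  # NonLocal2d
```
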
diff --git a/annotator/uniformer/mmcv/cnn/bricks/registry.py b/annotator/uniformer/mmcv/cnn/bricks/registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..39eabc58db4b5954478a2ac1ab91cea5e45ab055
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/bricks/registry.py
@@ -0,0 +1,16 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from annotator.uniformer.mmcv.utils import Registry
+
+CONV_LAYERS = Registry('conv layer')
+NORM_LAYERS = Registry('norm layer')
+ACTIVATION_LAYERS = Registry('activation layer')
+PADDING_LAYERS = Registry('padding layer')
+UPSAMPLE_LAYERS = Registry('upsample layer')
+PLUGIN_LAYERS = Registry('plugin layer')
+
+DROPOUT_LAYERS = Registry('drop out layers')
+POSITIONAL_ENCODING = Registry('position encoding')
+ATTENTION = Registry('attention')
+FEEDFORWARD_NETWORK = Registry('feed-forward Network')
+TRANSFORMER_LAYER = Registry('transformerLayer')
+TRANSFORMER_LAYER_SEQUENCE = Registry('transformer-layers sequence')
diff --git a/annotator/uniformer/mmcv/cnn/bricks/scale.py b/annotator/uniformer/mmcv/cnn/bricks/scale.py
new file mode 100644
index 0000000000000000000000000000000000000000..c905fffcc8bf998d18d94f927591963c428025e2
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/bricks/scale.py
@@ -0,0 +1,21 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+
+
+class Scale(nn.Module):
+ """A learnable scale parameter.
+
+ This layer scales the input by a learnable factor. It multiplies a
+ learnable scale parameter of shape (1,) with input of any shape.
+
+ Args:
+ scale (float): Initial value of scale factor. Default: 1.0
+ """
+
+ def __init__(self, scale=1.0):
+ super(Scale, self).__init__()
+ self.scale = nn.Parameter(torch.tensor(scale, dtype=torch.float))
+
+ def forward(self, x):
+ return x * self.scale
diff --git a/annotator/uniformer/mmcv/cnn/bricks/swish.py b/annotator/uniformer/mmcv/cnn/bricks/swish.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2ca8ed7b749413f011ae54aac0cab27e6f0b51f
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/bricks/swish.py
@@ -0,0 +1,25 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+
+from .registry import ACTIVATION_LAYERS
+
+
+@ACTIVATION_LAYERS.register_module()
+class Swish(nn.Module):
+ """Swish Module.
+
+ This module applies the swish function:
+
+ .. math::
+ Swish(x) = x * Sigmoid(x)
+
+ Returns:
+ Tensor: The output tensor.
+ """
+
+ def __init__(self):
+ super(Swish, self).__init__()
+
+ def forward(self, x):
+ return x * torch.sigmoid(x)
diff --git a/annotator/uniformer/mmcv/cnn/bricks/transformer.py b/annotator/uniformer/mmcv/cnn/bricks/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e61ae0dd941a7be00b3e41a3de833ec50470a45f
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/bricks/transformer.py
@@ -0,0 +1,595 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import warnings
+
+import torch
+import torch.nn as nn
+
+from annotator.uniformer.mmcv import ConfigDict, deprecated_api_warning
+from annotator.uniformer.mmcv.cnn import Linear, build_activation_layer, build_norm_layer
+from annotator.uniformer.mmcv.runner.base_module import BaseModule, ModuleList, Sequential
+from annotator.uniformer.mmcv.utils import build_from_cfg
+from .drop import build_dropout
+from .registry import (ATTENTION, FEEDFORWARD_NETWORK, POSITIONAL_ENCODING,
+ TRANSFORMER_LAYER, TRANSFORMER_LAYER_SEQUENCE)
+
+# Avoid BC-breaking of importing MultiScaleDeformableAttention from this file
+try:
+ from annotator.uniformer.mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention # noqa F401
+ warnings.warn(
+ ImportWarning(
+ '``MultiScaleDeformableAttention`` has been moved to '
+ '``mmcv.ops.multi_scale_deform_attn``, please change original path ' # noqa E501
+ '``from annotator.uniformer.mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention`` ' # noqa E501
+ 'to ``from annotator.uniformer.mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention`` ' # noqa E501
+ ))
+
+except ImportError:
+ warnings.warn('Fail to import ``MultiScaleDeformableAttention`` from '
+ '``mmcv.ops.multi_scale_deform_attn``, '
+ 'You should install ``mmcv-full`` if you need this module. ')
+
+
+def build_positional_encoding(cfg, default_args=None):
+ """Builder for Position Encoding."""
+ return build_from_cfg(cfg, POSITIONAL_ENCODING, default_args)
+
+
+def build_attention(cfg, default_args=None):
+ """Builder for attention."""
+ return build_from_cfg(cfg, ATTENTION, default_args)
+
+
+def build_feedforward_network(cfg, default_args=None):
+ """Builder for feed-forward network (FFN)."""
+ return build_from_cfg(cfg, FEEDFORWARD_NETWORK, default_args)
+
+
+def build_transformer_layer(cfg, default_args=None):
+ """Builder for transformer layer."""
+ return build_from_cfg(cfg, TRANSFORMER_LAYER, default_args)
+
+
+def build_transformer_layer_sequence(cfg, default_args=None):
+ """Builder for transformer encoder and transformer decoder."""
+ return build_from_cfg(cfg, TRANSFORMER_LAYER_SEQUENCE, default_args)
+
+
+@ATTENTION.register_module()
+class MultiheadAttention(BaseModule):
+ """A wrapper for ``torch.nn.MultiheadAttention``.
+
+ This module implements MultiheadAttention with identity connection,
+ and positional encoding is also passed as input.
+
+ Args:
+ embed_dims (int): The embedding dimension.
+ num_heads (int): Parallel attention heads.
+ attn_drop (float): A Dropout layer on attn_output_weights.
+ Default: 0.0.
+ proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
+ Default: 0.0.
+ dropout_layer (obj:`ConfigDict`): The dropout_layer used
+ when adding the shortcut.
+ init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+ Default: None.
+        batch_first (bool): When it is True, Key, Query and Value have shape
+            (batch, n, embed_dim); otherwise (n, batch, embed_dim).
+            Defaults to False.
+ """
+
+ def __init__(self,
+ embed_dims,
+ num_heads,
+ attn_drop=0.,
+ proj_drop=0.,
+ dropout_layer=dict(type='Dropout', drop_prob=0.),
+ init_cfg=None,
+ batch_first=False,
+ **kwargs):
+ super(MultiheadAttention, self).__init__(init_cfg)
+ if 'dropout' in kwargs:
+            warnings.warn('The argument `dropout` in MultiheadAttention '
+                          'has been deprecated; please set `attn_drop` '
+                          '(float), `proj_drop` (float) and '
+                          '`dropout_layer` (dict) separately instead.')
+ attn_drop = kwargs['dropout']
+ dropout_layer['drop_prob'] = kwargs.pop('dropout')
+
+ self.embed_dims = embed_dims
+ self.num_heads = num_heads
+ self.batch_first = batch_first
+
+ self.attn = nn.MultiheadAttention(embed_dims, num_heads, attn_drop,
+ **kwargs)
+
+ self.proj_drop = nn.Dropout(proj_drop)
+ self.dropout_layer = build_dropout(
+ dropout_layer) if dropout_layer else nn.Identity()
+
+ @deprecated_api_warning({'residual': 'identity'},
+ cls_name='MultiheadAttention')
+ def forward(self,
+ query,
+ key=None,
+ value=None,
+ identity=None,
+ query_pos=None,
+ key_pos=None,
+ attn_mask=None,
+ key_padding_mask=None,
+ **kwargs):
+ """Forward function for `MultiheadAttention`.
+
+ **kwargs allow passing a more general data flow when combining
+ with other operations in `transformerlayer`.
+
+ Args:
+ query (Tensor): The input query with shape [num_queries, bs,
+ embed_dims] if self.batch_first is False, else
+                [bs, num_queries, embed_dims].
+ key (Tensor): The key tensor with shape [num_keys, bs,
+ embed_dims] if self.batch_first is False, else
+ [bs, num_keys, embed_dims] .
+ If None, the ``query`` will be used. Defaults to None.
+ value (Tensor): The value tensor with same shape as `key`.
+ Same in `nn.MultiheadAttention.forward`. Defaults to None.
+ If None, the `key` will be used.
+            identity (Tensor): This tensor, with the same shape as `query`,
+                will be used for the identity link.
+                If None, `query` will be used. Defaults to None.
+            query_pos (Tensor): The positional encoding for `query`, with
+                the same shape as `query`. If not None, it will
+                be added to `query` before the attention. Defaults to None.
+ key_pos (Tensor): The positional encoding for `key`, with the
+ same shape as `key`. Defaults to None. If not None, it will
+ be added to `key` before forward function. If None, and
+ `query_pos` has the same shape as `key`, then `query_pos`
+ will be used for `key_pos`. Defaults to None.
+ attn_mask (Tensor): ByteTensor mask with shape [num_queries,
+ num_keys]. Same in `nn.MultiheadAttention.forward`.
+ Defaults to None.
+ key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys].
+ Defaults to None.
+
+ Returns:
+ Tensor: forwarded results with shape
+ [num_queries, bs, embed_dims]
+ if self.batch_first is False, else
+                [bs, num_queries, embed_dims].
+ """
+
+ if key is None:
+ key = query
+ if value is None:
+ value = key
+ if identity is None:
+ identity = query
+ if key_pos is None:
+ if query_pos is not None:
+ # use query_pos if key_pos is not available
+ if query_pos.shape == key.shape:
+ key_pos = query_pos
+ else:
+                    warnings.warn(f'position encoding of key is '
+                                  f'missing in {self.__class__.__name__}.')
+ if query_pos is not None:
+ query = query + query_pos
+ if key_pos is not None:
+ key = key + key_pos
+
+ # Because the dataflow('key', 'query', 'value') of
+ # ``torch.nn.MultiheadAttention`` is (num_query, batch,
+ # embed_dims), We should adjust the shape of dataflow from
+ # batch_first (batch, num_query, embed_dims) to num_query_first
+ # (num_query ,batch, embed_dims), and recover ``attn_output``
+ # from num_query_first to batch_first.
+ if self.batch_first:
+ query = query.transpose(0, 1)
+ key = key.transpose(0, 1)
+ value = value.transpose(0, 1)
+
+ out = self.attn(
+ query=query,
+ key=key,
+ value=value,
+ attn_mask=attn_mask,
+ key_padding_mask=key_padding_mask)[0]
+
+ if self.batch_first:
+ out = out.transpose(0, 1)
+
+ return identity + self.dropout_layer(self.proj_drop(out))
+
+
+@FEEDFORWARD_NETWORK.register_module()
+class FFN(BaseModule):
+ """Implements feed-forward networks (FFNs) with identity connection.
+
+ Args:
+ embed_dims (int): The feature dimension. Same as
+ `MultiheadAttention`. Defaults: 256.
+ feedforward_channels (int): The hidden dimension of FFNs.
+ Defaults: 1024.
+ num_fcs (int, optional): The number of fully-connected layers in
+ FFNs. Default: 2.
+ act_cfg (dict, optional): The activation config for FFNs.
+ Default: dict(type='ReLU')
+ ffn_drop (float, optional): Probability of an element to be
+ zeroed in FFN. Default 0.0.
+ add_identity (bool, optional): Whether to add the
+ identity connection. Default: `True`.
+ dropout_layer (obj:`ConfigDict`): The dropout_layer used
+ when adding the shortcut.
+ init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+ Default: None.
+ """
+
+ @deprecated_api_warning(
+ {
+ 'dropout': 'ffn_drop',
+ 'add_residual': 'add_identity'
+ },
+ cls_name='FFN')
+ def __init__(self,
+ embed_dims=256,
+ feedforward_channels=1024,
+ num_fcs=2,
+ act_cfg=dict(type='ReLU', inplace=True),
+ ffn_drop=0.,
+ dropout_layer=None,
+ add_identity=True,
+ init_cfg=None,
+ **kwargs):
+ super(FFN, self).__init__(init_cfg)
+ assert num_fcs >= 2, 'num_fcs should be no less ' \
+ f'than 2. got {num_fcs}.'
+ self.embed_dims = embed_dims
+ self.feedforward_channels = feedforward_channels
+ self.num_fcs = num_fcs
+ self.act_cfg = act_cfg
+ self.activate = build_activation_layer(act_cfg)
+
+ layers = []
+ in_channels = embed_dims
+ for _ in range(num_fcs - 1):
+ layers.append(
+ Sequential(
+ Linear(in_channels, feedforward_channels), self.activate,
+ nn.Dropout(ffn_drop)))
+ in_channels = feedforward_channels
+ layers.append(Linear(feedforward_channels, embed_dims))
+ layers.append(nn.Dropout(ffn_drop))
+ self.layers = Sequential(*layers)
+ self.dropout_layer = build_dropout(
+ dropout_layer) if dropout_layer else torch.nn.Identity()
+ self.add_identity = add_identity
+
+ @deprecated_api_warning({'residual': 'identity'}, cls_name='FFN')
+ def forward(self, x, identity=None):
+ """Forward function for `FFN`.
+
+        The function adds `x` to the output tensor if `identity` is None.
+ """
+ out = self.layers(x)
+ if not self.add_identity:
+ return self.dropout_layer(out)
+ if identity is None:
+ identity = x
+ return identity + self.dropout_layer(out)
+
+
+@TRANSFORMER_LAYER.register_module()
+class BaseTransformerLayer(BaseModule):
+ """Base `TransformerLayer` for vision transformer.
+
+ It can be built from `mmcv.ConfigDict` and support more flexible
+ customization, for example, using any number of `FFN or LN ` and
+ use different kinds of `attention` by specifying a list of `ConfigDict`
+    named `attn_cfgs`. It is worth mentioning that it supports `prenorm`
+    when you specify `norm` as the first element of `operation_order`.
+    For more details about prenorm, see "On Layer Normalization in the
+    Transformer Architecture".
+
+ Args:
+        attn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None):
+            Configs for `self_attention` or `cross_attention` modules.
+            The order of the configs in the list should be consistent with
+            the corresponding attentions in operation_order.
+            If it is a dict, all of the attention modules in operation_order
+            will be built with this config. Default: None.
+        ffn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None):
+            Configs for FFN. The order of the configs in the list should be
+            consistent with the corresponding FFNs in operation_order.
+            If it is a dict, all of the FFN modules in operation_order
+            will be built with this config.
+        operation_order (tuple[str]): The execution order of operations
+            in the transformer, e.g. ('self_attn', 'norm', 'ffn', 'norm').
+            `prenorm` is supported when the first element is `norm`.
+            Default: None.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='LN').
+ init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+ Default: None.
+ batch_first (bool): Key, Query and Value are shape
+ of (batch, n, embed_dim)
+ or (n, batch, embed_dim). Default to False.
+ """
+
+ def __init__(self,
+ attn_cfgs=None,
+ ffn_cfgs=dict(
+ type='FFN',
+ embed_dims=256,
+ feedforward_channels=1024,
+ num_fcs=2,
+ ffn_drop=0.,
+ act_cfg=dict(type='ReLU', inplace=True),
+ ),
+ operation_order=None,
+ norm_cfg=dict(type='LN'),
+ init_cfg=None,
+ batch_first=False,
+ **kwargs):
+
+ deprecated_args = dict(
+ feedforward_channels='feedforward_channels',
+ ffn_dropout='ffn_drop',
+ ffn_num_fcs='num_fcs')
+ for ori_name, new_name in deprecated_args.items():
+ if ori_name in kwargs:
+ warnings.warn(
+                    f'The argument `{ori_name}` in BaseTransformerLayer '
+                    f'has been deprecated; please set `{new_name}` and '
+                    f'other FFN-related arguments in a dict named '
+                    f'`ffn_cfgs` instead.')
+ ffn_cfgs[new_name] = kwargs[ori_name]
+
+ super(BaseTransformerLayer, self).__init__(init_cfg)
+
+ self.batch_first = batch_first
+
+ assert set(operation_order) & set(
+ ['self_attn', 'norm', 'ffn', 'cross_attn']) == \
+            set(operation_order), f'The operation_order of' \
+            f' {self.__class__.__name__} should only ' \
+            f'contain operations from ' \
+            f"{['self_attn', 'norm', 'ffn', 'cross_attn']}"
+
+ num_attn = operation_order.count('self_attn') + operation_order.count(
+ 'cross_attn')
+ if isinstance(attn_cfgs, dict):
+ attn_cfgs = [copy.deepcopy(attn_cfgs) for _ in range(num_attn)]
+ else:
+            assert num_attn == len(attn_cfgs), f'The length of ' \
+                f'attn_cfgs ({len(attn_cfgs)}) is not consistent with ' \
+                f'the number of attentions ({num_attn}) ' \
+                f'in operation_order {operation_order}.'
+
+ self.num_attn = num_attn
+ self.operation_order = operation_order
+ self.norm_cfg = norm_cfg
+ self.pre_norm = operation_order[0] == 'norm'
+ self.attentions = ModuleList()
+
+ index = 0
+ for operation_name in operation_order:
+ if operation_name in ['self_attn', 'cross_attn']:
+ if 'batch_first' in attn_cfgs[index]:
+ assert self.batch_first == attn_cfgs[index]['batch_first']
+ else:
+ attn_cfgs[index]['batch_first'] = self.batch_first
+ attention = build_attention(attn_cfgs[index])
+ # Some custom attentions used as `self_attn`
+ # or `cross_attn` can have different behavior.
+ attention.operation_name = operation_name
+ self.attentions.append(attention)
+ index += 1
+
+ self.embed_dims = self.attentions[0].embed_dims
+
+ self.ffns = ModuleList()
+ num_ffns = operation_order.count('ffn')
+ if isinstance(ffn_cfgs, dict):
+ ffn_cfgs = ConfigDict(ffn_cfgs)
+ if isinstance(ffn_cfgs, dict):
+ ffn_cfgs = [copy.deepcopy(ffn_cfgs) for _ in range(num_ffns)]
+ assert len(ffn_cfgs) == num_ffns
+ for ffn_index in range(num_ffns):
+ if 'embed_dims' not in ffn_cfgs[ffn_index]:
+                ffn_cfgs[ffn_index]['embed_dims'] = self.embed_dims
+ else:
+ assert ffn_cfgs[ffn_index]['embed_dims'] == self.embed_dims
+ self.ffns.append(
+ build_feedforward_network(ffn_cfgs[ffn_index],
+ dict(type='FFN')))
+
+ self.norms = ModuleList()
+ num_norms = operation_order.count('norm')
+ for _ in range(num_norms):
+ self.norms.append(build_norm_layer(norm_cfg, self.embed_dims)[1])
+
+ def forward(self,
+ query,
+ key=None,
+ value=None,
+ query_pos=None,
+ key_pos=None,
+ attn_masks=None,
+ query_key_padding_mask=None,
+ key_padding_mask=None,
+ **kwargs):
+        """Forward function for `BaseTransformerLayer`.
+
+ **kwargs contains some specific arguments of attentions.
+
+ Args:
+ query (Tensor): The input query with shape
+ [num_queries, bs, embed_dims] if
+ self.batch_first is False, else
+                [bs, num_queries, embed_dims].
+ key (Tensor): The key tensor with shape [num_keys, bs,
+ embed_dims] if self.batch_first is False, else
+ [bs, num_keys, embed_dims] .
+ value (Tensor): The value tensor with same shape as `key`.
+ query_pos (Tensor): The positional encoding for `query`.
+ Default: None.
+ key_pos (Tensor): The positional encoding for `key`.
+ Default: None.
+ attn_masks (List[Tensor] | None): 2D Tensor used in
+ calculation of corresponding attention. The length of
+ it should equal to the number of `attention` in
+ `operation_order`. Default: None.
+ query_key_padding_mask (Tensor): ByteTensor for `query`, with
+ shape [bs, num_queries]. Only used in `self_attn` layer.
+ Defaults to None.
+            key_padding_mask (Tensor): ByteTensor for `key`, with
+ shape [bs, num_keys]. Default: None.
+
+ Returns:
+ Tensor: forwarded results with shape [num_queries, bs, embed_dims].
+ """
+
+ norm_index = 0
+ attn_index = 0
+ ffn_index = 0
+ identity = query
+ if attn_masks is None:
+ attn_masks = [None for _ in range(self.num_attn)]
+ elif isinstance(attn_masks, torch.Tensor):
+ attn_masks = [
+ copy.deepcopy(attn_masks) for _ in range(self.num_attn)
+ ]
+ warnings.warn(f'Use same attn_mask in all attentions in '
+ f'{self.__class__.__name__} ')
+ else:
+ assert len(attn_masks) == self.num_attn, f'The length of ' \
+ f'attn_masks {len(attn_masks)} must be equal ' \
+ f'to the number of attention in ' \
+ f'operation_order {self.num_attn}'
+
+ for layer in self.operation_order:
+ if layer == 'self_attn':
+ temp_key = temp_value = query
+ query = self.attentions[attn_index](
+ query,
+ temp_key,
+ temp_value,
+ identity if self.pre_norm else None,
+ query_pos=query_pos,
+ key_pos=query_pos,
+ attn_mask=attn_masks[attn_index],
+ key_padding_mask=query_key_padding_mask,
+ **kwargs)
+ attn_index += 1
+ identity = query
+
+ elif layer == 'norm':
+ query = self.norms[norm_index](query)
+ norm_index += 1
+
+ elif layer == 'cross_attn':
+ query = self.attentions[attn_index](
+ query,
+ key,
+ value,
+ identity if self.pre_norm else None,
+ query_pos=query_pos,
+ key_pos=key_pos,
+ attn_mask=attn_masks[attn_index],
+ key_padding_mask=key_padding_mask,
+ **kwargs)
+ attn_index += 1
+ identity = query
+
+ elif layer == 'ffn':
+ query = self.ffns[ffn_index](
+ query, identity if self.pre_norm else None)
+ ffn_index += 1
+
+ return query
+
+
+@TRANSFORMER_LAYER_SEQUENCE.register_module()
+class TransformerLayerSequence(BaseModule):
+ """Base class for TransformerEncoder and TransformerDecoder in vision
+ transformer.
+
+ As base-class of Encoder and Decoder in vision transformer.
+ Support customization such as specifying different kind
+ of `transformer_layer` in `transformer_coder`.
+
+ Args:
+        transformerlayers (list[obj:`mmcv.ConfigDict`] |
+            obj:`mmcv.ConfigDict`): Config of the transformer layers
+            in TransformerCoder. If it is an obj:`mmcv.ConfigDict`,
+            it will be repeated `num_layers` times to form a
+            list[`mmcv.ConfigDict`]. Default: None.
+ num_layers (int): The number of `TransformerLayer`. Default: None.
+ init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+ Default: None.
+ """
+
+ def __init__(self, transformerlayers=None, num_layers=None, init_cfg=None):
+ super(TransformerLayerSequence, self).__init__(init_cfg)
+ if isinstance(transformerlayers, dict):
+ transformerlayers = [
+ copy.deepcopy(transformerlayers) for _ in range(num_layers)
+ ]
+ else:
+ assert isinstance(transformerlayers, list) and \
+ len(transformerlayers) == num_layers
+ self.num_layers = num_layers
+ self.layers = ModuleList()
+ for i in range(num_layers):
+ self.layers.append(build_transformer_layer(transformerlayers[i]))
+ self.embed_dims = self.layers[0].embed_dims
+ self.pre_norm = self.layers[0].pre_norm
+
+ def forward(self,
+ query,
+ key,
+ value,
+ query_pos=None,
+ key_pos=None,
+ attn_masks=None,
+ query_key_padding_mask=None,
+ key_padding_mask=None,
+ **kwargs):
+ """Forward function for `TransformerCoder`.
+
+ Args:
+ query (Tensor): Input query with shape
+ `(num_queries, bs, embed_dims)`.
+ key (Tensor): The key tensor with shape
+ `(num_keys, bs, embed_dims)`.
+ value (Tensor): The value tensor with shape
+ `(num_keys, bs, embed_dims)`.
+ query_pos (Tensor): The positional encoding for `query`.
+ Default: None.
+ key_pos (Tensor): The positional encoding for `key`.
+ Default: None.
+ attn_masks (List[Tensor], optional): Each element is 2D Tensor
+ which is used in calculation of corresponding attention in
+ operation_order. Default: None.
+            query_key_padding_mask (Tensor): ByteTensor for `query`, with
+                shape [bs, num_queries]. Only used in self-attention.
+                Default: None.
+            key_padding_mask (Tensor): ByteTensor for `key`, with
+                shape [bs, num_keys]. Default: None.
+
+ Returns:
+ Tensor: results with shape [num_queries, bs, embed_dims].
+ """
+ for layer in self.layers:
+ query = layer(
+ query,
+ key,
+ value,
+ query_pos=query_pos,
+ key_pos=key_pos,
+ attn_masks=attn_masks,
+ query_key_padding_mask=query_key_padding_mask,
+ key_padding_mask=key_padding_mask,
+ **kwargs)
+ return query
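
A hypothetical config sketch (illustration only) showing how these bricks compose: a pre-norm layer built from `MultiheadAttention` and `FFN`, stacked twice by `TransformerLayerSequence`; all dimensions here are made up:

```python
import torch
from annotator.uniformer.mmcv.cnn.bricks.transformer import (
    build_transformer_layer_sequence)

layer_cfg = dict(
    type='BaseTransformerLayer',
    attn_cfgs=dict(type='MultiheadAttention', embed_dims=64, num_heads=4),
    ffn_cfgs=dict(type='FFN', embed_dims=64, feedforward_channels=128),
    operation_order=('norm', 'self_attn', 'norm', 'ffn'))  # prenorm variant

encoder = build_transformer_layer_sequence(
    dict(type='TransformerLayerSequence',
         transformerlayers=layer_cfg, num_layers=2))

query = torch.randn(10, 2, 64)          # (num_queries, bs, embed_dims)
out = encoder(query, key=None, value=None)  # self-attention only, no cross_attn
print(out.shape)                        # torch.Size([10, 2, 64])
```
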
diff --git a/annotator/uniformer/mmcv/cnn/bricks/upsample.py b/annotator/uniformer/mmcv/cnn/bricks/upsample.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1a353767d0ce8518f0d7289bed10dba0178ed12
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/bricks/upsample.py
@@ -0,0 +1,84 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..utils import xavier_init
+from .registry import UPSAMPLE_LAYERS
+
+UPSAMPLE_LAYERS.register_module('nearest', module=nn.Upsample)
+UPSAMPLE_LAYERS.register_module('bilinear', module=nn.Upsample)
+
+
+@UPSAMPLE_LAYERS.register_module(name='pixel_shuffle')
+class PixelShufflePack(nn.Module):
+ """Pixel Shuffle upsample layer.
+
+ This module packs `F.pixel_shuffle()` and a nn.Conv2d module together to
+ achieve a simple upsampling with pixel shuffle.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ scale_factor (int): Upsample ratio.
+ upsample_kernel (int): Kernel size of the conv layer to expand the
+ channels.
+ """
+
+ def __init__(self, in_channels, out_channels, scale_factor,
+ upsample_kernel):
+ super(PixelShufflePack, self).__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.scale_factor = scale_factor
+ self.upsample_kernel = upsample_kernel
+ self.upsample_conv = nn.Conv2d(
+ self.in_channels,
+ self.out_channels * scale_factor * scale_factor,
+ self.upsample_kernel,
+ padding=(self.upsample_kernel - 1) // 2)
+ self.init_weights()
+
+ def init_weights(self):
+ xavier_init(self.upsample_conv, distribution='uniform')
+
+ def forward(self, x):
+ x = self.upsample_conv(x)
+ x = F.pixel_shuffle(x, self.scale_factor)
+ return x
+
+
+def build_upsample_layer(cfg, *args, **kwargs):
+ """Build upsample layer.
+
+ Args:
+ cfg (dict): The upsample layer config, which should contain:
+
+ - type (str): Layer type.
+ - scale_factor (int): Upsample ratio, which is not applicable to
+ deconv.
+        - layer args: Args needed to instantiate an upsample layer.
+        args (argument list): Arguments passed to the ``__init__``
+            method of the corresponding upsample layer.
+        kwargs (keyword arguments): Keyword arguments passed to the
+            ``__init__`` method of the corresponding upsample layer.
+
+ Returns:
+ nn.Module: Created upsample layer.
+ """
+ if not isinstance(cfg, dict):
+ raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
+ if 'type' not in cfg:
+ raise KeyError(
+ f'the cfg dict must contain the key "type", but got {cfg}')
+ cfg_ = cfg.copy()
+
+ layer_type = cfg_.pop('type')
+ if layer_type not in UPSAMPLE_LAYERS:
+ raise KeyError(f'Unrecognized upsample type {layer_type}')
+ else:
+ upsample = UPSAMPLE_LAYERS.get(layer_type)
+
+ if upsample is nn.Upsample:
+ cfg_['mode'] = layer_type
+ layer = upsample(*args, **kwargs, **cfg_)
+ return layer
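
A usage sketch of `build_upsample_layer`; note that for the `'nearest'`/`'bilinear'` entries the registered module is `nn.Upsample` and the `type` string doubles as its `mode`:

```python
import torch
from annotator.uniformer.mmcv.cnn.bricks.upsample import build_upsample_layer

up = build_upsample_layer(dict(type='pixel_shuffle', in_channels=8,
                               out_channels=8, scale_factor=2,
                               upsample_kernel=3))
print(up(torch.randn(1, 8, 16, 16)).shape)  # torch.Size([1, 8, 32, 32])

bilinear = build_upsample_layer(dict(type='bilinear', scale_factor=2))
print(bilinear.mode)                        # 'bilinear'
```
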
diff --git a/annotator/uniformer/mmcv/cnn/bricks/wrappers.py b/annotator/uniformer/mmcv/cnn/bricks/wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..8aebf67bf52355a513f21756ee74fe510902d075
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/bricks/wrappers.py
@@ -0,0 +1,180 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+r"""Modified from https://github.com/facebookresearch/detectron2/blob/master/detectron2/layers/wrappers.py # noqa: E501
+
+Wrap some nn modules to support empty tensor input. Currently, these wrappers
+are mainly used in mask heads like fcn_mask_head and maskiou_heads since mask
+heads are trained on only positive RoIs.
+"""
+import math
+
+import torch
+import torch.nn as nn
+from torch.nn.modules.utils import _pair, _triple
+
+from .registry import CONV_LAYERS, UPSAMPLE_LAYERS
+
+if torch.__version__ == 'parrots':
+ TORCH_VERSION = torch.__version__
+else:
+ # torch.__version__ could be 1.3.1+cu92, we only need the first two
+ # for comparison
+ TORCH_VERSION = tuple(int(x) for x in torch.__version__.split('.')[:2])
+
+
+def obsolete_torch_version(torch_version, version_threshold):
+ return torch_version == 'parrots' or torch_version <= version_threshold
+
+
+class NewEmptyTensorOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, x, new_shape):
+ ctx.shape = x.shape
+ return x.new_empty(new_shape)
+
+ @staticmethod
+ def backward(ctx, grad):
+ shape = ctx.shape
+ return NewEmptyTensorOp.apply(grad, shape), None
+
+
+@CONV_LAYERS.register_module('Conv', force=True)
+class Conv2d(nn.Conv2d):
+
+ def forward(self, x):
+ if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
+ out_shape = [x.shape[0], self.out_channels]
+ for i, k, p, s, d in zip(x.shape[-2:], self.kernel_size,
+ self.padding, self.stride, self.dilation):
+ o = (i + 2 * p - (d * (k - 1) + 1)) // s + 1
+ out_shape.append(o)
+ empty = NewEmptyTensorOp.apply(x, out_shape)
+ if self.training:
+ # produce dummy gradient to avoid DDP warning.
+ dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
+ return empty + dummy
+ else:
+ return empty
+
+ return super().forward(x)
+
+
+@CONV_LAYERS.register_module('Conv3d', force=True)
+class Conv3d(nn.Conv3d):
+
+ def forward(self, x):
+ if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
+ out_shape = [x.shape[0], self.out_channels]
+ for i, k, p, s, d in zip(x.shape[-3:], self.kernel_size,
+ self.padding, self.stride, self.dilation):
+ o = (i + 2 * p - (d * (k - 1) + 1)) // s + 1
+ out_shape.append(o)
+ empty = NewEmptyTensorOp.apply(x, out_shape)
+ if self.training:
+ # produce dummy gradient to avoid DDP warning.
+ dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
+ return empty + dummy
+ else:
+ return empty
+
+ return super().forward(x)
+
+
+@CONV_LAYERS.register_module()
+@CONV_LAYERS.register_module('deconv')
+@UPSAMPLE_LAYERS.register_module('deconv', force=True)
+class ConvTranspose2d(nn.ConvTranspose2d):
+
+ def forward(self, x):
+ if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
+ out_shape = [x.shape[0], self.out_channels]
+ for i, k, p, s, d, op in zip(x.shape[-2:], self.kernel_size,
+ self.padding, self.stride,
+ self.dilation, self.output_padding):
+ out_shape.append((i - 1) * s - 2 * p + (d * (k - 1) + 1) + op)
+ empty = NewEmptyTensorOp.apply(x, out_shape)
+ if self.training:
+ # produce dummy gradient to avoid DDP warning.
+ dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
+ return empty + dummy
+ else:
+ return empty
+
+ return super().forward(x)
+
+
+@CONV_LAYERS.register_module()
+@CONV_LAYERS.register_module('deconv3d')
+@UPSAMPLE_LAYERS.register_module('deconv3d', force=True)
+class ConvTranspose3d(nn.ConvTranspose3d):
+
+ def forward(self, x):
+ if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
+ out_shape = [x.shape[0], self.out_channels]
+ for i, k, p, s, d, op in zip(x.shape[-3:], self.kernel_size,
+ self.padding, self.stride,
+ self.dilation, self.output_padding):
+ out_shape.append((i - 1) * s - 2 * p + (d * (k - 1) + 1) + op)
+ empty = NewEmptyTensorOp.apply(x, out_shape)
+ if self.training:
+ # produce dummy gradient to avoid DDP warning.
+ dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
+ return empty + dummy
+ else:
+ return empty
+
+ return super().forward(x)
+
+
+class MaxPool2d(nn.MaxPool2d):
+
+ def forward(self, x):
+ # PyTorch 1.9 does not support empty tensor inference yet
+ if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
+ out_shape = list(x.shape[:2])
+ for i, k, p, s, d in zip(x.shape[-2:], _pair(self.kernel_size),
+ _pair(self.padding), _pair(self.stride),
+ _pair(self.dilation)):
+ o = (i + 2 * p - (d * (k - 1) + 1)) / s + 1
+ o = math.ceil(o) if self.ceil_mode else math.floor(o)
+ out_shape.append(o)
+ empty = NewEmptyTensorOp.apply(x, out_shape)
+ return empty
+
+ return super().forward(x)
+
+
+class MaxPool3d(nn.MaxPool3d):
+
+ def forward(self, x):
+ # PyTorch 1.9 does not support empty tensor inference yet
+ if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
+ out_shape = list(x.shape[:2])
+ for i, k, p, s, d in zip(x.shape[-3:], _triple(self.kernel_size),
+ _triple(self.padding),
+ _triple(self.stride),
+ _triple(self.dilation)):
+ o = (i + 2 * p - (d * (k - 1) + 1)) / s + 1
+ o = math.ceil(o) if self.ceil_mode else math.floor(o)
+ out_shape.append(o)
+ empty = NewEmptyTensorOp.apply(x, out_shape)
+ return empty
+
+ return super().forward(x)
+
+
+class Linear(torch.nn.Linear):
+
+ def forward(self, x):
+        # Empty-tensor forward of the Linear layer is supported in PyTorch 1.6+.
+ if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 5)):
+ out_shape = [x.shape[0], self.out_features]
+ empty = NewEmptyTensorOp.apply(x, out_shape)
+ if self.training:
+ # produce dummy gradient to avoid DDP warning.
+ dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
+ return empty + dummy
+ else:
+ return empty
+
+ return super().forward(x)
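
A sketch of the empty-tensor behaviour these wrappers target: a zero-RoI batch still yields an output with the expected channel and spatial sizes (on recent PyTorch versions the guarded branch is skipped and the stock ops already handle the empty batch):

```python
import torch
from annotator.uniformer.mmcv.cnn.bricks.wrappers import Conv2d, MaxPool2d

conv = Conv2d(3, 8, kernel_size=3, padding=1)
pool = MaxPool2d(kernel_size=2, stride=2)

x = torch.zeros(0, 3, 16, 16)   # empty batch, e.g. no positive RoIs
print(pool(conv(x)).shape)      # torch.Size([0, 8, 8, 8])
```
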
diff --git a/annotator/uniformer/mmcv/cnn/builder.py b/annotator/uniformer/mmcv/cnn/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..7567316c566bd3aca6d8f65a84b00e9e890948a7
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/builder.py
@@ -0,0 +1,30 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..runner import Sequential
+from ..utils import Registry, build_from_cfg
+
+
+def build_model_from_cfg(cfg, registry, default_args=None):
+ """Build a PyTorch model from config dict(s). Different from
+ ``build_from_cfg``, if cfg is a list, a ``nn.Sequential`` will be built.
+
+ Args:
+        cfg (dict, list[dict]): The config of modules; it is either a config
+            dict or a list of config dicts. If cfg is a list, the built
+            modules will be wrapped with ``nn.Sequential``.
+ registry (:obj:`Registry`): A registry the module belongs to.
+ default_args (dict, optional): Default arguments to build the module.
+ Defaults to None.
+
+ Returns:
+ nn.Module: A built nn module.
+ """
+ if isinstance(cfg, list):
+ modules = [
+ build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
+ ]
+ return Sequential(*modules)
+ else:
+ return build_from_cfg(cfg, registry, default_args)
+
+
+MODELS = Registry('model', build_func=build_model_from_cfg)
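
A hypothetical sketch of the list-of-configs path: `ToyBlock` is an invented module registered in `MODELS` purely for illustration, and the example assumes the vendored Registry dispatches `build()` to its `build_func` as in upstream mmcv:

```python
import torch.nn as nn
from annotator.uniformer.mmcv.cnn.builder import MODELS


@MODELS.register_module()
class ToyBlock(nn.Module):  # hypothetical module, not part of the patch
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x):
        return self.conv(x)


model = MODELS.build([dict(type='ToyBlock', channels=8),
                      dict(type='ToyBlock', channels=8)])
print(type(model).__name__)  # Sequential
```
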
diff --git a/annotator/uniformer/mmcv/cnn/resnet.py b/annotator/uniformer/mmcv/cnn/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..1cb3ac057ee2d52c46fc94685b5d4e698aad8d5f
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/resnet.py
@@ -0,0 +1,316 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import logging
+
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+
+from .utils import constant_init, kaiming_init
+
+
+def conv3x3(in_planes, out_planes, stride=1, dilation=1):
+ """3x3 convolution with padding."""
+ return nn.Conv2d(
+ in_planes,
+ out_planes,
+ kernel_size=3,
+ stride=stride,
+ padding=dilation,
+ dilation=dilation,
+ bias=False)
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self,
+ inplanes,
+ planes,
+ stride=1,
+ dilation=1,
+ downsample=None,
+ style='pytorch',
+ with_cp=False):
+ super(BasicBlock, self).__init__()
+ assert style in ['pytorch', 'caffe']
+ self.conv1 = conv3x3(inplanes, planes, stride, dilation)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(planes, planes)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.downsample = downsample
+ self.stride = stride
+ self.dilation = dilation
+ assert not with_cp
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self,
+ inplanes,
+ planes,
+ stride=1,
+ dilation=1,
+ downsample=None,
+ style='pytorch',
+ with_cp=False):
+ """Bottleneck block.
+
+ If style is "pytorch", the stride-two layer is the 3x3 conv layer, if
+ it is "caffe", the stride-two layer is the first 1x1 conv layer.
+ """
+ super(Bottleneck, self).__init__()
+ assert style in ['pytorch', 'caffe']
+ if style == 'pytorch':
+ conv1_stride = 1
+ conv2_stride = stride
+ else:
+ conv1_stride = stride
+ conv2_stride = 1
+ self.conv1 = nn.Conv2d(
+ inplanes, planes, kernel_size=1, stride=conv1_stride, bias=False)
+ self.conv2 = nn.Conv2d(
+ planes,
+ planes,
+ kernel_size=3,
+ stride=conv2_stride,
+ padding=dilation,
+ dilation=dilation,
+ bias=False)
+
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.conv3 = nn.Conv2d(
+ planes, planes * self.expansion, kernel_size=1, bias=False)
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+ self.dilation = dilation
+ self.with_cp = with_cp
+
+ def forward(self, x):
+
+ def _inner_forward(x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+
+ return out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ out = self.relu(out)
+
+ return out
+
+
+def make_res_layer(block,
+ inplanes,
+ planes,
+ blocks,
+ stride=1,
+ dilation=1,
+ style='pytorch',
+ with_cp=False):
+ downsample = None
+ if stride != 1 or inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(
+ inplanes,
+ planes * block.expansion,
+ kernel_size=1,
+ stride=stride,
+ bias=False),
+ nn.BatchNorm2d(planes * block.expansion),
+ )
+
+ layers = []
+ layers.append(
+ block(
+ inplanes,
+ planes,
+ stride,
+ dilation,
+ downsample,
+ style=style,
+ with_cp=with_cp))
+ inplanes = planes * block.expansion
+ for _ in range(1, blocks):
+ layers.append(
+ block(inplanes, planes, 1, dilation, style=style, with_cp=with_cp))
+
+ return nn.Sequential(*layers)
+
+
+class ResNet(nn.Module):
+ """ResNet backbone.
+
+ Args:
+ depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
+ num_stages (int): Resnet stages, normally 4.
+ strides (Sequence[int]): Strides of the first block of each stage.
+ dilations (Sequence[int]): Dilation of each stage.
+ out_indices (Sequence[int]): Output from which stages.
+ style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
+ layer is the 3x3 conv layer, otherwise the stride-two layer is
+ the first 1x1 conv layer.
+ frozen_stages (int): Stages to be frozen (all param fixed). -1 means
+ not freezing any parameters.
+ bn_eval (bool): Whether to set BN layers as eval mode, namely, freeze
+ running stats (mean and var).
+ bn_frozen (bool): Whether to freeze weight and bias of BN layers.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed.
+ """
+
+ arch_settings = {
+ 18: (BasicBlock, (2, 2, 2, 2)),
+ 34: (BasicBlock, (3, 4, 6, 3)),
+ 50: (Bottleneck, (3, 4, 6, 3)),
+ 101: (Bottleneck, (3, 4, 23, 3)),
+ 152: (Bottleneck, (3, 8, 36, 3))
+ }
+
+ def __init__(self,
+ depth,
+ num_stages=4,
+ strides=(1, 2, 2, 2),
+ dilations=(1, 1, 1, 1),
+ out_indices=(0, 1, 2, 3),
+ style='pytorch',
+ frozen_stages=-1,
+ bn_eval=True,
+ bn_frozen=False,
+ with_cp=False):
+ super(ResNet, self).__init__()
+ if depth not in self.arch_settings:
+ raise KeyError(f'invalid depth {depth} for resnet')
+ assert num_stages >= 1 and num_stages <= 4
+ block, stage_blocks = self.arch_settings[depth]
+ stage_blocks = stage_blocks[:num_stages]
+ assert len(strides) == len(dilations) == num_stages
+ assert max(out_indices) < num_stages
+
+ self.out_indices = out_indices
+ self.style = style
+ self.frozen_stages = frozen_stages
+ self.bn_eval = bn_eval
+ self.bn_frozen = bn_frozen
+ self.with_cp = with_cp
+
+ self.inplanes = 64
+ self.conv1 = nn.Conv2d(
+ 3, 64, kernel_size=7, stride=2, padding=3, bias=False)
+ self.bn1 = nn.BatchNorm2d(64)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+ self.res_layers = []
+ for i, num_blocks in enumerate(stage_blocks):
+ stride = strides[i]
+ dilation = dilations[i]
+ planes = 64 * 2**i
+ res_layer = make_res_layer(
+ block,
+ self.inplanes,
+ planes,
+ num_blocks,
+ stride=stride,
+ dilation=dilation,
+ style=self.style,
+ with_cp=with_cp)
+ self.inplanes = planes * block.expansion
+ layer_name = f'layer{i + 1}'
+ self.add_module(layer_name, res_layer)
+ self.res_layers.append(layer_name)
+
+ self.feat_dim = block.expansion * 64 * 2**(len(stage_blocks) - 1)
+
+ def init_weights(self, pretrained=None):
+ if isinstance(pretrained, str):
+ logger = logging.getLogger()
+ from ..runner import load_checkpoint
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
+ elif pretrained is None:
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ kaiming_init(m)
+ elif isinstance(m, nn.BatchNorm2d):
+ constant_init(m, 1)
+ else:
+ raise TypeError('pretrained must be a str or None')
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+ outs = []
+ for i, layer_name in enumerate(self.res_layers):
+ res_layer = getattr(self, layer_name)
+ x = res_layer(x)
+ if i in self.out_indices:
+ outs.append(x)
+ if len(outs) == 1:
+ return outs[0]
+ else:
+ return tuple(outs)
+
+ def train(self, mode=True):
+ super(ResNet, self).train(mode)
+ if self.bn_eval:
+ for m in self.modules():
+ if isinstance(m, nn.BatchNorm2d):
+ m.eval()
+ if self.bn_frozen:
+ for params in m.parameters():
+ params.requires_grad = False
+ if mode and self.frozen_stages >= 0:
+ for param in self.conv1.parameters():
+ param.requires_grad = False
+ for param in self.bn1.parameters():
+ param.requires_grad = False
+ self.bn1.eval()
+ self.bn1.weight.requires_grad = False
+ self.bn1.bias.requires_grad = False
+ for i in range(1, self.frozen_stages + 1):
+ mod = getattr(self, f'layer{i}')
+ mod.eval()
+ for param in mod.parameters():
+ param.requires_grad = False
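+
+
+if __name__ == '__main__':
+    # Minimal usage sketch (not part of the upstream mmcv module): build a
+    # ResNet-50 backbone, freeze the stem and the first stage, and inspect the
+    # shapes of the four stage outputs for a dummy image. The 224x224 input
+    # size is an arbitrary choice for illustration.
+    import torch
+
+    net = ResNet(depth=50, frozen_stages=1)
+    net.init_weights()
+    net.train()
+    outs = net(torch.randn(1, 3, 224, 224))
+    print([tuple(o.shape) for o in outs])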
diff --git a/annotator/uniformer/mmcv/cnn/utils/__init__.py b/annotator/uniformer/mmcv/cnn/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a263e31c1e3977712827ca229bbc04910b4e928e
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/utils/__init__.py
@@ -0,0 +1,19 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .flops_counter import get_model_complexity_info
+from .fuse_conv_bn import fuse_conv_bn
+from .sync_bn import revert_sync_batchnorm
+from .weight_init import (INITIALIZERS, Caffe2XavierInit, ConstantInit,
+ KaimingInit, NormalInit, PretrainedInit,
+ TruncNormalInit, UniformInit, XavierInit,
+ bias_init_with_prob, caffe2_xavier_init,
+ constant_init, initialize, kaiming_init, normal_init,
+ trunc_normal_init, uniform_init, xavier_init)
+
+__all__ = [
+ 'get_model_complexity_info', 'bias_init_with_prob', 'caffe2_xavier_init',
+ 'constant_init', 'kaiming_init', 'normal_init', 'trunc_normal_init',
+ 'uniform_init', 'xavier_init', 'fuse_conv_bn', 'initialize',
+ 'INITIALIZERS', 'ConstantInit', 'XavierInit', 'NormalInit',
+ 'TruncNormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit',
+ 'Caffe2XavierInit', 'revert_sync_batchnorm'
+]
diff --git a/annotator/uniformer/mmcv/cnn/utils/flops_counter.py b/annotator/uniformer/mmcv/cnn/utils/flops_counter.py
new file mode 100644
index 0000000000000000000000000000000000000000..d10af5feca7f4b8c0ba359b7b1c826f754e048be
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/utils/flops_counter.py
@@ -0,0 +1,599 @@
+# Modified from flops-counter.pytorch by Vladislav Sovrasov
+# original repo: https://github.com/sovrasov/flops-counter.pytorch
+
+# MIT License
+
+# Copyright (c) 2018 Vladislav Sovrasov
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+import sys
+from functools import partial
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+import annotator.uniformer.mmcv as mmcv
+
+
+def get_model_complexity_info(model,
+ input_shape,
+ print_per_layer_stat=True,
+ as_strings=True,
+ input_constructor=None,
+ flush=False,
+ ost=sys.stdout):
+ """Get complexity information of a model.
+
+ This method can calculate FLOPs and parameter counts of a model with
+ corresponding input shape. It can also print complexity information for
+ each layer in a model.
+
+ Supported layers are listed as below:
+ - Convolutions: ``nn.Conv1d``, ``nn.Conv2d``, ``nn.Conv3d``.
+ - Activations: ``nn.ReLU``, ``nn.PReLU``, ``nn.ELU``, ``nn.LeakyReLU``,
+ ``nn.ReLU6``.
+ - Poolings: ``nn.MaxPool1d``, ``nn.MaxPool2d``, ``nn.MaxPool3d``,
+ ``nn.AvgPool1d``, ``nn.AvgPool2d``, ``nn.AvgPool3d``,
+ ``nn.AdaptiveMaxPool1d``, ``nn.AdaptiveMaxPool2d``,
+ ``nn.AdaptiveMaxPool3d``, ``nn.AdaptiveAvgPool1d``,
+ ``nn.AdaptiveAvgPool2d``, ``nn.AdaptiveAvgPool3d``.
+ - BatchNorms: ``nn.BatchNorm1d``, ``nn.BatchNorm2d``,
+ ``nn.BatchNorm3d``, ``nn.GroupNorm``, ``nn.InstanceNorm1d``,
+ ``InstanceNorm2d``, ``InstanceNorm3d``, ``nn.LayerNorm``.
+ - Linear: ``nn.Linear``.
+ - Deconvolution: ``nn.ConvTranspose2d``.
+ - Upsample: ``nn.Upsample``.
+
+ Args:
+ model (nn.Module): The model for complexity calculation.
+ input_shape (tuple): Input shape used for calculation.
+ print_per_layer_stat (bool): Whether to print complexity information
+ for each layer in a model. Default: True.
+ as_strings (bool): Output FLOPs and params counts in a string form.
+ Default: True.
+        input_constructor (None | callable): If specified, it takes a callable
+            method that generates the input. Otherwise, a random tensor with
+            the given input shape will be generated to calculate FLOPs.
+            Default: None.
+ flush (bool): same as that in :func:`print`. Default: False.
+ ost (stream): same as ``file`` param in :func:`print`.
+ Default: sys.stdout.
+
+ Returns:
+        tuple[float | str]: If ``as_strings`` is set to True, it will return
+            FLOPs and parameter counts in a string format. Otherwise, it will
+            return them as float numbers.
+ """
+ assert type(input_shape) is tuple
+ assert len(input_shape) >= 1
+ assert isinstance(model, nn.Module)
+ flops_model = add_flops_counting_methods(model)
+ flops_model.eval()
+ flops_model.start_flops_count()
+ if input_constructor:
+ input = input_constructor(input_shape)
+ _ = flops_model(**input)
+ else:
+ try:
+ batch = torch.ones(()).new_empty(
+ (1, *input_shape),
+ dtype=next(flops_model.parameters()).dtype,
+ device=next(flops_model.parameters()).device)
+ except StopIteration:
+ # Avoid StopIteration for models which have no parameters,
+ # like `nn.Relu()`, `nn.AvgPool2d`, etc.
+ batch = torch.ones(()).new_empty((1, *input_shape))
+
+ _ = flops_model(batch)
+
+ flops_count, params_count = flops_model.compute_average_flops_cost()
+ if print_per_layer_stat:
+ print_model_with_flops(
+ flops_model, flops_count, params_count, ost=ost, flush=flush)
+ flops_model.stop_flops_count()
+
+ if as_strings:
+ return flops_to_string(flops_count), params_to_string(params_count)
+
+ return flops_count, params_count
+
+
+def flops_to_string(flops, units='GFLOPs', precision=2):
+ """Convert FLOPs number into a string.
+
+    Note that here we count one multiply-add as one FLOP.
+
+ Args:
+ flops (float): FLOPs number to be converted.
+ units (str | None): Converted FLOPs units. Options are None, 'GFLOPs',
+ 'MFLOPs', 'KFLOPs', 'FLOPs'. If set to None, it will automatically
+ choose the most suitable unit for FLOPs. Default: 'GFLOPs'.
+ precision (int): Digit number after the decimal point. Default: 2.
+
+ Returns:
+ str: The converted FLOPs number with units.
+
+ Examples:
+ >>> flops_to_string(1e9)
+ '1.0 GFLOPs'
+ >>> flops_to_string(2e5, 'MFLOPs')
+ '0.2 MFLOPs'
+ >>> flops_to_string(3e-9, None)
+ '3e-09 FLOPs'
+ """
+ if units is None:
+ if flops // 10**9 > 0:
+ return str(round(flops / 10.**9, precision)) + ' GFLOPs'
+ elif flops // 10**6 > 0:
+ return str(round(flops / 10.**6, precision)) + ' MFLOPs'
+ elif flops // 10**3 > 0:
+ return str(round(flops / 10.**3, precision)) + ' KFLOPs'
+ else:
+ return str(flops) + ' FLOPs'
+ else:
+ if units == 'GFLOPs':
+ return str(round(flops / 10.**9, precision)) + ' ' + units
+ elif units == 'MFLOPs':
+ return str(round(flops / 10.**6, precision)) + ' ' + units
+ elif units == 'KFLOPs':
+ return str(round(flops / 10.**3, precision)) + ' ' + units
+ else:
+ return str(flops) + ' FLOPs'
+
+
+def params_to_string(num_params, units=None, precision=2):
+ """Convert parameter number into a string.
+
+ Args:
+ num_params (float): Parameter number to be converted.
+ units (str | None): Converted FLOPs units. Options are None, 'M',
+ 'K' and ''. If set to None, it will automatically choose the most
+ suitable unit for Parameter number. Default: None.
+ precision (int): Digit number after the decimal point. Default: 2.
+
+ Returns:
+ str: The converted parameter number with units.
+
+ Examples:
+ >>> params_to_string(1e9)
+ '1000.0 M'
+ >>> params_to_string(2e5)
+ '200.0 k'
+ >>> params_to_string(3e-9)
+ '3e-09'
+ """
+ if units is None:
+ if num_params // 10**6 > 0:
+ return str(round(num_params / 10**6, precision)) + ' M'
+ elif num_params // 10**3:
+ return str(round(num_params / 10**3, precision)) + ' k'
+ else:
+ return str(num_params)
+ else:
+ if units == 'M':
+ return str(round(num_params / 10.**6, precision)) + ' ' + units
+ elif units == 'K':
+ return str(round(num_params / 10.**3, precision)) + ' ' + units
+ else:
+ return str(num_params)
+
+
+def print_model_with_flops(model,
+ total_flops,
+ total_params,
+ units='GFLOPs',
+ precision=3,
+ ost=sys.stdout,
+ flush=False):
+ """Print a model with FLOPs for each layer.
+
+ Args:
+ model (nn.Module): The model to be printed.
+ total_flops (float): Total FLOPs of the model.
+ total_params (float): Total parameter counts of the model.
+ units (str | None): Converted FLOPs units. Default: 'GFLOPs'.
+ precision (int): Digit number after the decimal point. Default: 3.
+ ost (stream): same as `file` param in :func:`print`.
+ Default: sys.stdout.
+ flush (bool): same as that in :func:`print`. Default: False.
+
+ Example:
+ >>> class ExampleModel(nn.Module):
+
+ >>> def __init__(self):
+ >>> super().__init__()
+ >>> self.conv1 = nn.Conv2d(3, 8, 3)
+ >>> self.conv2 = nn.Conv2d(8, 256, 3)
+ >>> self.conv3 = nn.Conv2d(256, 8, 3)
+ >>> self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
+ >>> self.flatten = nn.Flatten()
+ >>> self.fc = nn.Linear(8, 1)
+
+ >>> def forward(self, x):
+ >>> x = self.conv1(x)
+ >>> x = self.conv2(x)
+ >>> x = self.conv3(x)
+ >>> x = self.avg_pool(x)
+ >>> x = self.flatten(x)
+ >>> x = self.fc(x)
+ >>> return x
+
+ >>> model = ExampleModel()
+ >>> x = (3, 16, 16)
+        To print the complexity information for each layer, you can use
+ >>> get_model_complexity_info(model, x)
+ or directly use
+ >>> print_model_with_flops(model, 4579784.0, 37361)
+ ExampleModel(
+ 0.037 M, 100.000% Params, 0.005 GFLOPs, 100.000% FLOPs,
+ (conv1): Conv2d(0.0 M, 0.600% Params, 0.0 GFLOPs, 0.959% FLOPs, 3, 8, kernel_size=(3, 3), stride=(1, 1)) # noqa: E501
+ (conv2): Conv2d(0.019 M, 50.020% Params, 0.003 GFLOPs, 58.760% FLOPs, 8, 256, kernel_size=(3, 3), stride=(1, 1))
+ (conv3): Conv2d(0.018 M, 49.356% Params, 0.002 GFLOPs, 40.264% FLOPs, 256, 8, kernel_size=(3, 3), stride=(1, 1))
+ (avg_pool): AdaptiveAvgPool2d(0.0 M, 0.000% Params, 0.0 GFLOPs, 0.017% FLOPs, output_size=(1, 1))
+ (flatten): Flatten(0.0 M, 0.000% Params, 0.0 GFLOPs, 0.000% FLOPs, )
+ (fc): Linear(0.0 M, 0.024% Params, 0.0 GFLOPs, 0.000% FLOPs, in_features=8, out_features=1, bias=True)
+ )
+ """
+
+ def accumulate_params(self):
+ if is_supported_instance(self):
+ return self.__params__
+ else:
+ sum = 0
+ for m in self.children():
+ sum += m.accumulate_params()
+ return sum
+
+ def accumulate_flops(self):
+ if is_supported_instance(self):
+ return self.__flops__ / model.__batch_counter__
+ else:
+ sum = 0
+ for m in self.children():
+ sum += m.accumulate_flops()
+ return sum
+
+ def flops_repr(self):
+ accumulated_num_params = self.accumulate_params()
+ accumulated_flops_cost = self.accumulate_flops()
+ return ', '.join([
+ params_to_string(
+ accumulated_num_params, units='M', precision=precision),
+ '{:.3%} Params'.format(accumulated_num_params / total_params),
+ flops_to_string(
+ accumulated_flops_cost, units=units, precision=precision),
+ '{:.3%} FLOPs'.format(accumulated_flops_cost / total_flops),
+ self.original_extra_repr()
+ ])
+
+ def add_extra_repr(m):
+ m.accumulate_flops = accumulate_flops.__get__(m)
+ m.accumulate_params = accumulate_params.__get__(m)
+ flops_extra_repr = flops_repr.__get__(m)
+ if m.extra_repr != flops_extra_repr:
+ m.original_extra_repr = m.extra_repr
+ m.extra_repr = flops_extra_repr
+ assert m.extra_repr != m.original_extra_repr
+
+ def del_extra_repr(m):
+ if hasattr(m, 'original_extra_repr'):
+ m.extra_repr = m.original_extra_repr
+ del m.original_extra_repr
+ if hasattr(m, 'accumulate_flops'):
+ del m.accumulate_flops
+
+ model.apply(add_extra_repr)
+ print(model, file=ost, flush=flush)
+ model.apply(del_extra_repr)
+
+
+def get_model_parameters_number(model):
+ """Calculate parameter number of a model.
+
+ Args:
+ model (nn.module): The model for parameter number calculation.
+
+ Returns:
+ float: Parameter number of the model.
+ """
+ num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ return num_params
+
+
+def add_flops_counting_methods(net_main_module):
+ # adding additional methods to the existing module object,
+ # this is done this way so that each function has access to self object
+ net_main_module.start_flops_count = start_flops_count.__get__(
+ net_main_module)
+ net_main_module.stop_flops_count = stop_flops_count.__get__(
+ net_main_module)
+ net_main_module.reset_flops_count = reset_flops_count.__get__(
+ net_main_module)
+ net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__( # noqa: E501
+ net_main_module)
+
+ net_main_module.reset_flops_count()
+
+ return net_main_module
+
+
+def compute_average_flops_cost(self):
+ """Compute average FLOPs cost.
+
+ A method to compute average FLOPs cost, which will be available after
+ `add_flops_counting_methods()` is called on a desired net object.
+
+ Returns:
+        tuple[float, float]: Current mean FLOPs consumption per image and the
+            number of parameters.
+ """
+ batches_count = self.__batch_counter__
+ flops_sum = 0
+ for module in self.modules():
+ if is_supported_instance(module):
+ flops_sum += module.__flops__
+ params_sum = get_model_parameters_number(self)
+ return flops_sum / batches_count, params_sum
+
+
+def start_flops_count(self):
+ """Activate the computation of mean flops consumption per image.
+
+    A method to activate the computation of mean flops consumption per image,
+    which will be available after ``add_flops_counting_methods()`` is called on
+ a desired net object. It should be called before running the network.
+ """
+ add_batch_counter_hook_function(self)
+
+ def add_flops_counter_hook_function(module):
+ if is_supported_instance(module):
+            if hasattr(module, '__flops_handle__'):
+                return
+            handle = module.register_forward_hook(
+                get_modules_mapping()[type(module)])
+            module.__flops_handle__ = handle
+
+ self.apply(partial(add_flops_counter_hook_function))
+
+
+def stop_flops_count(self):
+ """Stop computing the mean flops consumption per image.
+
+ A method to stop computing the mean flops consumption per image, which will
+ be available after ``add_flops_counting_methods()`` is called on a desired
+    net object. It can be called to pause the computation at any time.
+ """
+ remove_batch_counter_hook_function(self)
+ self.apply(remove_flops_counter_hook_function)
+
+
+def reset_flops_count(self):
+ """Reset statistics computed so far.
+
+    A method to reset the computed statistics, which will be available after
+ `add_flops_counting_methods()` is called on a desired net object.
+ """
+ add_batch_counter_variables_or_reset(self)
+ self.apply(add_flops_counter_variable_or_reset)
+
+
+# ---- Internal functions
+def empty_flops_counter_hook(module, input, output):
+ module.__flops__ += 0
+
+
+def upsample_flops_counter_hook(module, input, output):
+ output_size = output[0]
+ batch_size = output_size.shape[0]
+ output_elements_count = batch_size
+ for val in output_size.shape[1:]:
+ output_elements_count *= val
+ module.__flops__ += int(output_elements_count)
+
+
+def relu_flops_counter_hook(module, input, output):
+ active_elements_count = output.numel()
+ module.__flops__ += int(active_elements_count)
+
+
+def linear_flops_counter_hook(module, input, output):
+ input = input[0]
+ output_last_dim = output.shape[
+ -1] # pytorch checks dimensions, so here we don't care much
+ module.__flops__ += int(np.prod(input.shape) * output_last_dim)
+
+
+def pool_flops_counter_hook(module, input, output):
+ input = input[0]
+ module.__flops__ += int(np.prod(input.shape))
+
+
+def norm_flops_counter_hook(module, input, output):
+ input = input[0]
+
+ batch_flops = np.prod(input.shape)
+ if (getattr(module, 'affine', False)
+ or getattr(module, 'elementwise_affine', False)):
+ batch_flops *= 2
+ module.__flops__ += int(batch_flops)
+
+
+def deconv_flops_counter_hook(conv_module, input, output):
+ # Can have multiple inputs, getting the first one
+ input = input[0]
+
+ batch_size = input.shape[0]
+ input_height, input_width = input.shape[2:]
+
+ kernel_height, kernel_width = conv_module.kernel_size
+ in_channels = conv_module.in_channels
+ out_channels = conv_module.out_channels
+ groups = conv_module.groups
+
+ filters_per_channel = out_channels // groups
+ conv_per_position_flops = (
+ kernel_height * kernel_width * in_channels * filters_per_channel)
+
+ active_elements_count = batch_size * input_height * input_width
+ overall_conv_flops = conv_per_position_flops * active_elements_count
+ bias_flops = 0
+ if conv_module.bias is not None:
+ output_height, output_width = output.shape[2:]
+        bias_flops = out_channels * batch_size * output_height * output_width
+ overall_flops = overall_conv_flops + bias_flops
+
+ conv_module.__flops__ += int(overall_flops)
+
+
+def conv_flops_counter_hook(conv_module, input, output):
+ # Can have multiple inputs, getting the first one
+ input = input[0]
+
+ batch_size = input.shape[0]
+ output_dims = list(output.shape[2:])
+
+ kernel_dims = list(conv_module.kernel_size)
+ in_channels = conv_module.in_channels
+ out_channels = conv_module.out_channels
+ groups = conv_module.groups
+
+ filters_per_channel = out_channels // groups
+ conv_per_position_flops = int(
+ np.prod(kernel_dims)) * in_channels * filters_per_channel
+
+ active_elements_count = batch_size * int(np.prod(output_dims))
+
+ overall_conv_flops = conv_per_position_flops * active_elements_count
+
+ bias_flops = 0
+
+ if conv_module.bias is not None:
+
+ bias_flops = out_channels * active_elements_count
+
+ overall_flops = overall_conv_flops + bias_flops
+
+ conv_module.__flops__ += int(overall_flops)
+
+
+def batch_counter_hook(module, input, output):
+ batch_size = 1
+ if len(input) > 0:
+ # Can have multiple inputs, getting the first one
+ input = input[0]
+ batch_size = len(input)
+    else:
+        print('Warning! No positional inputs found for a module, '
+              'assuming batch size is 1.')
+ module.__batch_counter__ += batch_size
+
+
+def add_batch_counter_variables_or_reset(module):
+
+ module.__batch_counter__ = 0
+
+
+def add_batch_counter_hook_function(module):
+ if hasattr(module, '__batch_counter_handle__'):
+ return
+
+ handle = module.register_forward_hook(batch_counter_hook)
+ module.__batch_counter_handle__ = handle
+
+
+def remove_batch_counter_hook_function(module):
+ if hasattr(module, '__batch_counter_handle__'):
+ module.__batch_counter_handle__.remove()
+ del module.__batch_counter_handle__
+
+
+def add_flops_counter_variable_or_reset(module):
+ if is_supported_instance(module):
+ if hasattr(module, '__flops__') or hasattr(module, '__params__'):
+            print('Warning: variables __flops__ or __params__ are already '
+                  'defined for the module ' + type(module).__name__ +
+                  '. ptflops can affect your code!')
+ module.__flops__ = 0
+ module.__params__ = get_model_parameters_number(module)
+
+
+def is_supported_instance(module):
+ if type(module) in get_modules_mapping():
+ return True
+ return False
+
+
+def remove_flops_counter_hook_function(module):
+ if is_supported_instance(module):
+ if hasattr(module, '__flops_handle__'):
+ module.__flops_handle__.remove()
+ del module.__flops_handle__
+
+
+def get_modules_mapping():
+ return {
+ # convolutions
+ nn.Conv1d: conv_flops_counter_hook,
+ nn.Conv2d: conv_flops_counter_hook,
+ mmcv.cnn.bricks.Conv2d: conv_flops_counter_hook,
+ nn.Conv3d: conv_flops_counter_hook,
+ mmcv.cnn.bricks.Conv3d: conv_flops_counter_hook,
+ # activations
+ nn.ReLU: relu_flops_counter_hook,
+ nn.PReLU: relu_flops_counter_hook,
+ nn.ELU: relu_flops_counter_hook,
+ nn.LeakyReLU: relu_flops_counter_hook,
+ nn.ReLU6: relu_flops_counter_hook,
+ # poolings
+ nn.MaxPool1d: pool_flops_counter_hook,
+ nn.AvgPool1d: pool_flops_counter_hook,
+ nn.AvgPool2d: pool_flops_counter_hook,
+ nn.MaxPool2d: pool_flops_counter_hook,
+ mmcv.cnn.bricks.MaxPool2d: pool_flops_counter_hook,
+ nn.MaxPool3d: pool_flops_counter_hook,
+ mmcv.cnn.bricks.MaxPool3d: pool_flops_counter_hook,
+ nn.AvgPool3d: pool_flops_counter_hook,
+ nn.AdaptiveMaxPool1d: pool_flops_counter_hook,
+ nn.AdaptiveAvgPool1d: pool_flops_counter_hook,
+ nn.AdaptiveMaxPool2d: pool_flops_counter_hook,
+ nn.AdaptiveAvgPool2d: pool_flops_counter_hook,
+ nn.AdaptiveMaxPool3d: pool_flops_counter_hook,
+ nn.AdaptiveAvgPool3d: pool_flops_counter_hook,
+ # normalizations
+ nn.BatchNorm1d: norm_flops_counter_hook,
+ nn.BatchNorm2d: norm_flops_counter_hook,
+ nn.BatchNorm3d: norm_flops_counter_hook,
+ nn.GroupNorm: norm_flops_counter_hook,
+ nn.InstanceNorm1d: norm_flops_counter_hook,
+ nn.InstanceNorm2d: norm_flops_counter_hook,
+ nn.InstanceNorm3d: norm_flops_counter_hook,
+ nn.LayerNorm: norm_flops_counter_hook,
+ # FC
+ nn.Linear: linear_flops_counter_hook,
+ mmcv.cnn.bricks.Linear: linear_flops_counter_hook,
+ # Upscale
+ nn.Upsample: upsample_flops_counter_hook,
+ # Deconvolution
+ nn.ConvTranspose2d: deconv_flops_counter_hook,
+ mmcv.cnn.bricks.ConvTranspose2d: deconv_flops_counter_hook,
+ }
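+
+
+if __name__ == '__main__':
+    # Minimal usage sketch (not part of the upstream flops counter): measure
+    # the complexity of a tiny, arbitrary convolutional model. The layer sizes
+    # and the 3x32x32 input shape are illustrative only.
+    demo_model = nn.Sequential(
+        nn.Conv2d(3, 8, 3, padding=1),
+        nn.ReLU(),
+        nn.AdaptiveAvgPool2d(1),
+        nn.Flatten(),
+        nn.Linear(8, 10))
+    flops, params = get_model_complexity_info(
+        demo_model, (3, 32, 32), print_per_layer_stat=False)
+    print(f'FLOPs: {flops}, Params: {params}')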
diff --git a/annotator/uniformer/mmcv/cnn/utils/fuse_conv_bn.py b/annotator/uniformer/mmcv/cnn/utils/fuse_conv_bn.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb7076f80bf37f7931185bf0293ffcc1ce19c8ef
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/utils/fuse_conv_bn.py
@@ -0,0 +1,59 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+
+
+def _fuse_conv_bn(conv, bn):
+ """Fuse conv and bn into one module.
+
+ Args:
+ conv (nn.Module): Conv to be fused.
+ bn (nn.Module): BN to be fused.
+
+ Returns:
+ nn.Module: Fused module.
+ """
+ conv_w = conv.weight
+ conv_b = conv.bias if conv.bias is not None else torch.zeros_like(
+ bn.running_mean)
+
+ factor = bn.weight / torch.sqrt(bn.running_var + bn.eps)
+ conv.weight = nn.Parameter(conv_w *
+ factor.reshape([conv.out_channels, 1, 1, 1]))
+ conv.bias = nn.Parameter((conv_b - bn.running_mean) * factor + bn.bias)
+ return conv
+
+
+def fuse_conv_bn(module):
+ """Recursively fuse conv and bn in a module.
+
+    During inference, the functionality of batch norm layers is turned off;
+    only the running mean and variance along channels are used. This exposes
+    the chance to fuse them with the preceding conv layers to save computation
+    and simplify the network structure.
+
+ Args:
+ module (nn.Module): Module to be fused.
+
+ Returns:
+ nn.Module: Fused module.
+ """
+ last_conv = None
+ last_conv_name = None
+
+ for name, child in module.named_children():
+ if isinstance(child,
+ (nn.modules.batchnorm._BatchNorm, nn.SyncBatchNorm)):
+ if last_conv is None: # only fuse BN that is after Conv
+ continue
+ fused_conv = _fuse_conv_bn(last_conv, child)
+ module._modules[last_conv_name] = fused_conv
+ # To reduce changes, set BN as Identity instead of deleting it.
+ module._modules[name] = nn.Identity()
+ last_conv = None
+ elif isinstance(child, nn.Conv2d):
+ last_conv = child
+ last_conv_name = name
+ else:
+ fuse_conv_bn(child)
+ return module
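+
+
+if __name__ == '__main__':
+    # Minimal usage sketch (not part of the upstream module): fuse a Conv+BN
+    # pair in a small, arbitrary network and check that the fused model gives
+    # the same output in eval mode.
+    net = nn.Sequential(
+        nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())
+    net.eval()
+    x = torch.randn(1, 3, 16, 16)
+    y_ref = net(x)
+    fused = fuse_conv_bn(net)
+    y_fused = fused(x)
+    print('max abs difference after fusion:',
+          (y_ref - y_fused).abs().max().item())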
diff --git a/annotator/uniformer/mmcv/cnn/utils/sync_bn.py b/annotator/uniformer/mmcv/cnn/utils/sync_bn.py
new file mode 100644
index 0000000000000000000000000000000000000000..f78f39181d75bb85c53e8c7c8eaf45690e9f0bee
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/utils/sync_bn.py
@@ -0,0 +1,59 @@
+import torch
+
+import annotator.uniformer.mmcv as mmcv
+
+
+class _BatchNormXd(torch.nn.modules.batchnorm._BatchNorm):
+ """A general BatchNorm layer without input dimension check.
+
+ Reproduced from @kapily's work:
+ (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)
+    The only difference between BatchNorm1d, BatchNorm2d, BatchNorm3d, etc.
+    is `_check_input_dim`, which is designed for tensor sanity checks.
+ The check has been bypassed in this class for the convenience of converting
+ SyncBatchNorm.
+ """
+
+ def _check_input_dim(self, input):
+ return
+
+
+def revert_sync_batchnorm(module):
+ """Helper function to convert all `SyncBatchNorm` (SyncBN) and
+    `mmcv.ops.sync_bn.SyncBatchNorm` (MMSyncBN) layers in the model to
+ `BatchNormXd` layers.
+
+ Adapted from @kapily's work:
+ (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)
+
+ Args:
+ module (nn.Module): The module containing `SyncBatchNorm` layers.
+
+ Returns:
+ module_output: The converted module with `BatchNormXd` layers.
+ """
+ module_output = module
+ module_checklist = [torch.nn.modules.batchnorm.SyncBatchNorm]
+ if hasattr(mmcv, 'ops'):
+ module_checklist.append(mmcv.ops.SyncBatchNorm)
+ if isinstance(module, tuple(module_checklist)):
+ module_output = _BatchNormXd(module.num_features, module.eps,
+ module.momentum, module.affine,
+ module.track_running_stats)
+ if module.affine:
+ # no_grad() may not be needed here but
+ # just to be consistent with `convert_sync_batchnorm()`
+ with torch.no_grad():
+ module_output.weight = module.weight
+ module_output.bias = module.bias
+ module_output.running_mean = module.running_mean
+ module_output.running_var = module.running_var
+ module_output.num_batches_tracked = module.num_batches_tracked
+ module_output.training = module.training
+ # qconfig exists in quantized models
+ if hasattr(module, 'qconfig'):
+ module_output.qconfig = module.qconfig
+ for name, child in module.named_children():
+ module_output.add_module(name, revert_sync_batchnorm(child))
+ del module
+ return module_output
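+
+
+if __name__ == '__main__':
+    # Minimal usage sketch (not part of the upstream module): convert a model
+    # containing ``SyncBatchNorm`` back to plain batch norm layers so it can be
+    # run without any distributed process group. The layer sizes are arbitrary.
+    net = torch.nn.Sequential(
+        torch.nn.Conv2d(3, 8, 3, padding=1), torch.nn.SyncBatchNorm(8))
+    reverted = revert_sync_batchnorm(net)
+    print(reverted)
+    out = reverted(torch.randn(1, 3, 16, 16))
+    print(tuple(out.shape))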
diff --git a/annotator/uniformer/mmcv/cnn/utils/weight_init.py b/annotator/uniformer/mmcv/cnn/utils/weight_init.py
new file mode 100644
index 0000000000000000000000000000000000000000..287a1d0bffe26e023029d48634d9b761deda7ba4
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/utils/weight_init.py
@@ -0,0 +1,684 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import math
+import warnings
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torch import Tensor
+
+from annotator.uniformer.mmcv.utils import Registry, build_from_cfg, get_logger, print_log
+
+INITIALIZERS = Registry('initializer')
+
+
+def update_init_info(module, init_info):
+    """Update the `_params_init_info` in the module if the values of the
+    parameters have changed.
+
+ Args:
+ module (obj:`nn.Module`): The module of PyTorch with a user-defined
+ attribute `_params_init_info` which records the initialization
+ information.
+ init_info (str): The string that describes the initialization.
+ """
+ assert hasattr(
+ module,
+ '_params_init_info'), f'Can not find `_params_init_info` in {module}'
+ for name, param in module.named_parameters():
+
+ assert param in module._params_init_info, (
+ f'Find a new :obj:`Parameter` '
+ f'named `{name}` during executing the '
+ f'`init_weights` of '
+ f'`{module.__class__.__name__}`. '
+ f'Please do not add or '
+ f'replace parameters during executing '
+ f'the `init_weights`. ')
+
+ # The parameter has been changed during executing the
+ # `init_weights` of module
+ mean_value = param.data.mean()
+ if module._params_init_info[param]['tmp_mean_value'] != mean_value:
+ module._params_init_info[param]['init_info'] = init_info
+ module._params_init_info[param]['tmp_mean_value'] = mean_value
+
+
+def constant_init(module, val, bias=0):
+ if hasattr(module, 'weight') and module.weight is not None:
+ nn.init.constant_(module.weight, val)
+ if hasattr(module, 'bias') and module.bias is not None:
+ nn.init.constant_(module.bias, bias)
+
+
+def xavier_init(module, gain=1, bias=0, distribution='normal'):
+ assert distribution in ['uniform', 'normal']
+ if hasattr(module, 'weight') and module.weight is not None:
+ if distribution == 'uniform':
+ nn.init.xavier_uniform_(module.weight, gain=gain)
+ else:
+ nn.init.xavier_normal_(module.weight, gain=gain)
+ if hasattr(module, 'bias') and module.bias is not None:
+ nn.init.constant_(module.bias, bias)
+
+
+def normal_init(module, mean=0, std=1, bias=0):
+ if hasattr(module, 'weight') and module.weight is not None:
+ nn.init.normal_(module.weight, mean, std)
+ if hasattr(module, 'bias') and module.bias is not None:
+ nn.init.constant_(module.bias, bias)
+
+
+def trunc_normal_init(module: nn.Module,
+ mean: float = 0,
+ std: float = 1,
+ a: float = -2,
+ b: float = 2,
+ bias: float = 0) -> None:
+ if hasattr(module, 'weight') and module.weight is not None:
+ trunc_normal_(module.weight, mean, std, a, b) # type: ignore
+ if hasattr(module, 'bias') and module.bias is not None:
+ nn.init.constant_(module.bias, bias) # type: ignore
+
+
+def uniform_init(module, a=0, b=1, bias=0):
+ if hasattr(module, 'weight') and module.weight is not None:
+ nn.init.uniform_(module.weight, a, b)
+ if hasattr(module, 'bias') and module.bias is not None:
+ nn.init.constant_(module.bias, bias)
+
+
+def kaiming_init(module,
+ a=0,
+ mode='fan_out',
+ nonlinearity='relu',
+ bias=0,
+ distribution='normal'):
+ assert distribution in ['uniform', 'normal']
+ if hasattr(module, 'weight') and module.weight is not None:
+ if distribution == 'uniform':
+ nn.init.kaiming_uniform_(
+ module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
+ else:
+ nn.init.kaiming_normal_(
+ module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
+ if hasattr(module, 'bias') and module.bias is not None:
+ nn.init.constant_(module.bias, bias)
+
+
+def caffe2_xavier_init(module, bias=0):
+ # `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch
+ # Acknowledgment to FAIR's internal code
+ kaiming_init(
+ module,
+ a=1,
+ mode='fan_in',
+ nonlinearity='leaky_relu',
+ bias=bias,
+ distribution='uniform')
+
+
+def bias_init_with_prob(prior_prob):
+    """Initialize conv/fc bias value according to a given probability value."""
+ bias_init = float(-np.log((1 - prior_prob) / prior_prob))
+ return bias_init
+
+
+def _get_bases_name(m):
+ return [b.__name__ for b in m.__class__.__bases__]
+
+
+class BaseInit(object):
+
+ def __init__(self, *, bias=0, bias_prob=None, layer=None):
+ self.wholemodule = False
+ if not isinstance(bias, (int, float)):
+ raise TypeError(f'bias must be a number, but got a {type(bias)}')
+
+ if bias_prob is not None:
+ if not isinstance(bias_prob, float):
+ raise TypeError(f'bias_prob type must be float, \
+ but got {type(bias_prob)}')
+
+ if layer is not None:
+ if not isinstance(layer, (str, list)):
+ raise TypeError(f'layer must be a str or a list of str, \
+ but got a {type(layer)}')
+ else:
+ layer = []
+
+ if bias_prob is not None:
+ self.bias = bias_init_with_prob(bias_prob)
+ else:
+ self.bias = bias
+ self.layer = [layer] if isinstance(layer, str) else layer
+
+ def _get_init_info(self):
+ info = f'{self.__class__.__name__}, bias={self.bias}'
+ return info
+
+
+@INITIALIZERS.register_module(name='Constant')
+class ConstantInit(BaseInit):
+ """Initialize module parameters with constant values.
+
+ Args:
+ val (int | float): the value to fill the weights in the module with
+ bias (int | float): the value to fill the bias. Defaults to 0.
+ bias_prob (float, optional): the probability for bias initialization.
+ Defaults to None.
+        layer (str | list[str], optional): the layer(s) to be initialized.
+ Defaults to None.
+ """
+
+ def __init__(self, val, **kwargs):
+ super().__init__(**kwargs)
+ self.val = val
+
+ def __call__(self, module):
+
+ def init(m):
+ if self.wholemodule:
+ constant_init(m, self.val, self.bias)
+ else:
+ layername = m.__class__.__name__
+ basesname = _get_bases_name(m)
+ if len(set(self.layer) & set([layername] + basesname)):
+ constant_init(m, self.val, self.bias)
+
+ module.apply(init)
+ if hasattr(module, '_params_init_info'):
+ update_init_info(module, init_info=self._get_init_info())
+
+ def _get_init_info(self):
+ info = f'{self.__class__.__name__}: val={self.val}, bias={self.bias}'
+ return info
+
+
+@INITIALIZERS.register_module(name='Xavier')
+class XavierInit(BaseInit):
+ r"""Initialize module parameters with values according to the method
+ described in `Understanding the difficulty of training deep feedforward
+ neural networks - Glorot, X. & Bengio, Y. (2010).
+ `_
+
+ Args:
+ gain (int | float): an optional scaling factor. Defaults to 1.
+ bias (int | float): the value to fill the bias. Defaults to 0.
+ bias_prob (float, optional): the probability for bias initialization.
+ Defaults to None.
+ distribution (str): distribution either be ``'normal'``
+ or ``'uniform'``. Defaults to ``'normal'``.
+        layer (str | list[str], optional): the layer(s) to be initialized.
+ Defaults to None.
+ """
+
+ def __init__(self, gain=1, distribution='normal', **kwargs):
+ super().__init__(**kwargs)
+ self.gain = gain
+ self.distribution = distribution
+
+ def __call__(self, module):
+
+ def init(m):
+ if self.wholemodule:
+ xavier_init(m, self.gain, self.bias, self.distribution)
+ else:
+ layername = m.__class__.__name__
+ basesname = _get_bases_name(m)
+ if len(set(self.layer) & set([layername] + basesname)):
+ xavier_init(m, self.gain, self.bias, self.distribution)
+
+ module.apply(init)
+ if hasattr(module, '_params_init_info'):
+ update_init_info(module, init_info=self._get_init_info())
+
+ def _get_init_info(self):
+ info = f'{self.__class__.__name__}: gain={self.gain}, ' \
+ f'distribution={self.distribution}, bias={self.bias}'
+ return info
+
+
+@INITIALIZERS.register_module(name='Normal')
+class NormalInit(BaseInit):
+ r"""Initialize module parameters with the values drawn from the normal
+ distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`.
+
+ Args:
+        mean (int | float): the mean of the normal distribution.
+            Defaults to 0.
+ std (int | float): the standard deviation of the normal distribution.
+ Defaults to 1.
+ bias (int | float): the value to fill the bias. Defaults to 0.
+ bias_prob (float, optional): the probability for bias initialization.
+ Defaults to None.
+        layer (str | list[str], optional): the layer(s) to be initialized.
+ Defaults to None.
+
+ """
+
+ def __init__(self, mean=0, std=1, **kwargs):
+ super().__init__(**kwargs)
+ self.mean = mean
+ self.std = std
+
+ def __call__(self, module):
+
+ def init(m):
+ if self.wholemodule:
+ normal_init(m, self.mean, self.std, self.bias)
+ else:
+ layername = m.__class__.__name__
+ basesname = _get_bases_name(m)
+ if len(set(self.layer) & set([layername] + basesname)):
+ normal_init(m, self.mean, self.std, self.bias)
+
+ module.apply(init)
+ if hasattr(module, '_params_init_info'):
+ update_init_info(module, init_info=self._get_init_info())
+
+ def _get_init_info(self):
+ info = f'{self.__class__.__name__}: mean={self.mean},' \
+ f' std={self.std}, bias={self.bias}'
+ return info
+
+
+@INITIALIZERS.register_module(name='TruncNormal')
+class TruncNormalInit(BaseInit):
+ r"""Initialize module parameters with the values drawn from the normal
+ distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` with values
+ outside :math:`[a, b]`.
+
+ Args:
+ mean (float): the mean of the normal distribution. Defaults to 0.
+ std (float): the standard deviation of the normal distribution.
+ Defaults to 1.
+ a (float): The minimum cutoff value.
+        b (float): The maximum cutoff value.
+ bias (float): the value to fill the bias. Defaults to 0.
+ bias_prob (float, optional): the probability for bias initialization.
+ Defaults to None.
+        layer (str | list[str], optional): the layer(s) to be initialized.
+ Defaults to None.
+
+ """
+
+ def __init__(self,
+ mean: float = 0,
+ std: float = 1,
+ a: float = -2,
+ b: float = 2,
+ **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.mean = mean
+ self.std = std
+ self.a = a
+ self.b = b
+
+ def __call__(self, module: nn.Module) -> None:
+
+ def init(m):
+ if self.wholemodule:
+ trunc_normal_init(m, self.mean, self.std, self.a, self.b,
+ self.bias)
+ else:
+ layername = m.__class__.__name__
+ basesname = _get_bases_name(m)
+ if len(set(self.layer) & set([layername] + basesname)):
+ trunc_normal_init(m, self.mean, self.std, self.a, self.b,
+ self.bias)
+
+ module.apply(init)
+ if hasattr(module, '_params_init_info'):
+ update_init_info(module, init_info=self._get_init_info())
+
+ def _get_init_info(self):
+ info = f'{self.__class__.__name__}: a={self.a}, b={self.b},' \
+ f' mean={self.mean}, std={self.std}, bias={self.bias}'
+ return info
+
+
+@INITIALIZERS.register_module(name='Uniform')
+class UniformInit(BaseInit):
+ r"""Initialize module parameters with values drawn from the uniform
+ distribution :math:`\mathcal{U}(a, b)`.
+
+ Args:
+ a (int | float): the lower bound of the uniform distribution.
+ Defaults to 0.
+ b (int | float): the upper bound of the uniform distribution.
+ Defaults to 1.
+ bias (int | float): the value to fill the bias. Defaults to 0.
+ bias_prob (float, optional): the probability for bias initialization.
+ Defaults to None.
+        layer (str | list[str], optional): the layer(s) to be initialized.
+ Defaults to None.
+ """
+
+ def __init__(self, a=0, b=1, **kwargs):
+ super().__init__(**kwargs)
+ self.a = a
+ self.b = b
+
+ def __call__(self, module):
+
+ def init(m):
+ if self.wholemodule:
+ uniform_init(m, self.a, self.b, self.bias)
+ else:
+ layername = m.__class__.__name__
+ basesname = _get_bases_name(m)
+ if len(set(self.layer) & set([layername] + basesname)):
+ uniform_init(m, self.a, self.b, self.bias)
+
+ module.apply(init)
+ if hasattr(module, '_params_init_info'):
+ update_init_info(module, init_info=self._get_init_info())
+
+ def _get_init_info(self):
+ info = f'{self.__class__.__name__}: a={self.a},' \
+ f' b={self.b}, bias={self.bias}'
+ return info
+
+
+@INITIALIZERS.register_module(name='Kaiming')
+class KaimingInit(BaseInit):
+ r"""Initialize module parameters with the values according to the method
+ described in `Delving deep into rectifiers: Surpassing human-level
+ performance on ImageNet classification - He, K. et al. (2015).
+ `_
+
+ Args:
+ a (int | float): the negative slope of the rectifier used after this
+ layer (only used with ``'leaky_relu'``). Defaults to 0.
+ mode (str): either ``'fan_in'`` or ``'fan_out'``. Choosing
+ ``'fan_in'`` preserves the magnitude of the variance of the weights
+ in the forward pass. Choosing ``'fan_out'`` preserves the
+ magnitudes in the backwards pass. Defaults to ``'fan_out'``.
+ nonlinearity (str): the non-linear function (`nn.functional` name),
+            recommended to use only with ``'relu'`` or ``'leaky_relu'``.
+ Defaults to 'relu'.
+ bias (int | float): the value to fill the bias. Defaults to 0.
+ bias_prob (float, optional): the probability for bias initialization.
+ Defaults to None.
+ distribution (str): distribution either be ``'normal'`` or
+ ``'uniform'``. Defaults to ``'normal'``.
+        layer (str | list[str], optional): the layer(s) to be initialized.
+ Defaults to None.
+ """
+
+ def __init__(self,
+ a=0,
+ mode='fan_out',
+ nonlinearity='relu',
+ distribution='normal',
+ **kwargs):
+ super().__init__(**kwargs)
+ self.a = a
+ self.mode = mode
+ self.nonlinearity = nonlinearity
+ self.distribution = distribution
+
+ def __call__(self, module):
+
+ def init(m):
+ if self.wholemodule:
+ kaiming_init(m, self.a, self.mode, self.nonlinearity,
+ self.bias, self.distribution)
+ else:
+ layername = m.__class__.__name__
+ basesname = _get_bases_name(m)
+ if len(set(self.layer) & set([layername] + basesname)):
+ kaiming_init(m, self.a, self.mode, self.nonlinearity,
+ self.bias, self.distribution)
+
+ module.apply(init)
+ if hasattr(module, '_params_init_info'):
+ update_init_info(module, init_info=self._get_init_info())
+
+ def _get_init_info(self):
+ info = f'{self.__class__.__name__}: a={self.a}, mode={self.mode}, ' \
+ f'nonlinearity={self.nonlinearity}, ' \
+               f'distribution={self.distribution}, bias={self.bias}'
+ return info
+
+
+@INITIALIZERS.register_module(name='Caffe2Xavier')
+class Caffe2XavierInit(KaimingInit):
+ # `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch
+ # Acknowledgment to FAIR's internal code
+ def __init__(self, **kwargs):
+ super().__init__(
+ a=1,
+ mode='fan_in',
+ nonlinearity='leaky_relu',
+ distribution='uniform',
+ **kwargs)
+
+ def __call__(self, module):
+ super().__call__(module)
+
+
+@INITIALIZERS.register_module(name='Pretrained')
+class PretrainedInit(object):
+ """Initialize module by loading a pretrained model.
+
+ Args:
+        checkpoint (str): the checkpoint file of the pretrained model to be
+            loaded.
+        prefix (str, optional): the prefix of a sub-module in the pretrained
+            model. It is used for loading a part of the pretrained model to
+            initialize a sub-module. For example, if we would like to only
+            load the backbone of a detector model, we can set
+            ``prefix='backbone.'``. Defaults to None.
+ map_location (str): map tensors into proper locations.
+ """
+
+ def __init__(self, checkpoint, prefix=None, map_location=None):
+ self.checkpoint = checkpoint
+ self.prefix = prefix
+ self.map_location = map_location
+
+ def __call__(self, module):
+ from annotator.uniformer.mmcv.runner import (_load_checkpoint_with_prefix, load_checkpoint,
+ load_state_dict)
+ logger = get_logger('mmcv')
+ if self.prefix is None:
+ print_log(f'load model from: {self.checkpoint}', logger=logger)
+ load_checkpoint(
+ module,
+ self.checkpoint,
+ map_location=self.map_location,
+ strict=False,
+ logger=logger)
+ else:
+ print_log(
+ f'load {self.prefix} in model from: {self.checkpoint}',
+ logger=logger)
+ state_dict = _load_checkpoint_with_prefix(
+ self.prefix, self.checkpoint, map_location=self.map_location)
+ load_state_dict(module, state_dict, strict=False, logger=logger)
+
+ if hasattr(module, '_params_init_info'):
+ update_init_info(module, init_info=self._get_init_info())
+
+ def _get_init_info(self):
+ info = f'{self.__class__.__name__}: load from {self.checkpoint}'
+ return info
+
+
+def _initialize(module, cfg, wholemodule=False):
+ func = build_from_cfg(cfg, INITIALIZERS)
+    # The wholemodule flag is for override mode: there is no 'layer' key in
+    # override, and the initializer will give init values for the whole module
+    # named in override.
+ func.wholemodule = wholemodule
+ func(module)
+
+
+def _initialize_override(module, override, cfg):
+ if not isinstance(override, (dict, list)):
+ raise TypeError(f'override must be a dict or a list of dict, \
+ but got {type(override)}')
+
+ override = [override] if isinstance(override, dict) else override
+
+ for override_ in override:
+
+ cp_override = copy.deepcopy(override_)
+ name = cp_override.pop('name', None)
+ if name is None:
+            raise ValueError('`override` must contain the key "name", '
+                             f'but got {cp_override}')
+ # if override only has name key, it means use args in init_cfg
+ if not cp_override:
+ cp_override.update(cfg)
+ # if override has name key and other args except type key, it will
+ # raise error
+ elif 'type' not in cp_override.keys():
+ raise ValueError(
+                f'`override` needs the "type" key, but got {cp_override}')
+
+ if hasattr(module, name):
+ _initialize(getattr(module, name), cp_override, wholemodule=True)
+ else:
+ raise RuntimeError(f'module did not have attribute {name}, '
+ f'but init_cfg is {cp_override}.')
+
+
+def initialize(module, init_cfg):
+ """Initialize a module.
+
+ Args:
+        module (``torch.nn.Module``): the module to be initialized.
+ init_cfg (dict | list[dict]): initialization configuration dict to
+ define initializer. OpenMMLab has implemented 6 initializers
+ including ``Constant``, ``Xavier``, ``Normal``, ``Uniform``,
+ ``Kaiming``, and ``Pretrained``.
+ Example:
+ >>> module = nn.Linear(2, 3, bias=True)
+ >>> init_cfg = dict(type='Constant', layer='Linear', val =1 , bias =2)
+ >>> initialize(module, init_cfg)
+
+ >>> module = nn.Sequential(nn.Conv1d(3, 1, 3), nn.Linear(1,2))
+ >>> # define key ``'layer'`` for initializing layer with different
+ >>> # configuration
+ >>> init_cfg = [dict(type='Constant', layer='Conv1d', val=1),
+ dict(type='Constant', layer='Linear', val=2)]
+ >>> initialize(module, init_cfg)
+
+ >>> # define key``'override'`` to initialize some specific part in
+ >>> # module
+ >>> class FooNet(nn.Module):
+ >>> def __init__(self):
+ >>> super().__init__()
+ >>> self.feat = nn.Conv2d(3, 16, 3)
+ >>> self.reg = nn.Conv2d(16, 10, 3)
+ >>> self.cls = nn.Conv2d(16, 5, 3)
+ >>> model = FooNet()
+ >>> init_cfg = dict(type='Constant', val=1, bias=2, layer='Conv2d',
+ >>> override=dict(type='Constant', name='reg', val=3, bias=4))
+ >>> initialize(model, init_cfg)
+
+ >>> model = ResNet(depth=50)
+ >>> # Initialize weights with the pretrained model.
+ >>> init_cfg = dict(type='Pretrained',
+ checkpoint='torchvision://resnet50')
+ >>> initialize(model, init_cfg)
+
+ >>> # Initialize weights of a sub-module with the specific part of
+ >>> # a pretrained model by using "prefix".
+ >>> url = 'http://download.openmmlab.com/mmdetection/v2.0/retinanet/'\
+ >>> 'retinanet_r50_fpn_1x_coco/'\
+ >>> 'retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth'
+ >>> init_cfg = dict(type='Pretrained',
+ checkpoint=url, prefix='backbone.')
+ """
+ if not isinstance(init_cfg, (dict, list)):
+ raise TypeError(f'init_cfg must be a dict or a list of dict, \
+ but got {type(init_cfg)}')
+
+ if isinstance(init_cfg, dict):
+ init_cfg = [init_cfg]
+
+ for cfg in init_cfg:
+ # should deeply copy the original config because cfg may be used by
+ # other modules, e.g., one init_cfg shared by multiple bottleneck
+ # blocks, the expected cfg will be changed after pop and will change
+ # the initialization behavior of other modules
+ cp_cfg = copy.deepcopy(cfg)
+ override = cp_cfg.pop('override', None)
+ _initialize(module, cp_cfg)
+
+ if override is not None:
+ cp_cfg.pop('layer', None)
+ _initialize_override(module, override, cp_cfg)
+ else:
+ # All attributes in module have same initialization.
+ pass
+
+
+def _no_grad_trunc_normal_(tensor: Tensor, mean: float, std: float, a: float,
+ b: float) -> Tensor:
+ # Method based on
+ # https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+ # Modified from
+ # https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
+ def norm_cdf(x):
+ # Computes standard normal cumulative distribution function
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
+ warnings.warn(
+ 'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
+ 'The distribution of values may be incorrect.',
+ stacklevel=2)
+
+ with torch.no_grad():
+ # Values are generated by using a truncated uniform distribution and
+ # then using the inverse CDF for the normal distribution.
+ # Get upper and lower cdf values
+ lower = norm_cdf((a - mean) / std)
+ upper = norm_cdf((b - mean) / std)
+
+ # Uniformly fill tensor with values from [lower, upper], then translate
+ # to [2lower-1, 2upper-1].
+ tensor.uniform_(2 * lower - 1, 2 * upper - 1)
+
+ # Use inverse cdf transform for normal distribution to get truncated
+ # standard normal
+ tensor.erfinv_()
+
+ # Transform to proper mean, std
+ tensor.mul_(std * math.sqrt(2.))
+ tensor.add_(mean)
+
+ # Clamp to ensure it's in the proper range
+ tensor.clamp_(min=a, max=b)
+ return tensor
+
+
+def trunc_normal_(tensor: Tensor,
+ mean: float = 0.,
+ std: float = 1.,
+ a: float = -2.,
+ b: float = 2.) -> Tensor:
+ r"""Fills the input Tensor with values drawn from a truncated
+ normal distribution. The values are effectively drawn from the
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
+ with values outside :math:`[a, b]` redrawn until they are within
+ the bounds. The method used for generating the random values works
+ best when :math:`a \leq \text{mean} \leq b`.
+
+ Modified from
+ https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
+
+ Args:
+ tensor (``torch.Tensor``): an n-dimensional `torch.Tensor`.
+ mean (float): the mean of the normal distribution.
+ std (float): the standard deviation of the normal distribution.
+ a (float): the minimum cutoff value.
+ b (float): the maximum cutoff value.
+ """
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
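+
+
+if __name__ == '__main__':
+    # Minimal usage sketch (not part of the upstream module): initialize a
+    # small layer through the registry-based ``initialize`` API and draw a
+    # truncated normal tensor. The layer size and init values are arbitrary.
+    layer = nn.Linear(4, 4)
+    initialize(layer, dict(type='Constant', layer='Linear', val=1.0, bias=0.5))
+    print(layer.weight.data.unique(), layer.bias.data.unique())
+
+    t = trunc_normal_(torch.empty(1000), mean=0.0, std=1.0, a=-2.0, b=2.0)
+    print(float(t.min()), float(t.max()))  # all values fall inside [-2, 2]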
diff --git a/annotator/uniformer/mmcv/cnn/vgg.py b/annotator/uniformer/mmcv/cnn/vgg.py
new file mode 100644
index 0000000000000000000000000000000000000000..8778b649561a45a9652b1a15a26c2d171e58f3e1
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/vgg.py
@@ -0,0 +1,175 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import logging
+
+import torch.nn as nn
+
+from .utils import constant_init, kaiming_init, normal_init
+
+
+def conv3x3(in_planes, out_planes, dilation=1):
+ """3x3 convolution with padding."""
+ return nn.Conv2d(
+ in_planes,
+ out_planes,
+ kernel_size=3,
+ padding=dilation,
+ dilation=dilation)
+
+
+def make_vgg_layer(inplanes,
+ planes,
+ num_blocks,
+ dilation=1,
+ with_bn=False,
+ ceil_mode=False):
+ layers = []
+ for _ in range(num_blocks):
+ layers.append(conv3x3(inplanes, planes, dilation))
+ if with_bn:
+ layers.append(nn.BatchNorm2d(planes))
+ layers.append(nn.ReLU(inplace=True))
+ inplanes = planes
+ layers.append(nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=ceil_mode))
+
+ return layers
+
+
+class VGG(nn.Module):
+ """VGG backbone.
+
+ Args:
+ depth (int): Depth of vgg, from {11, 13, 16, 19}.
+ with_bn (bool): Use BatchNorm or not.
+ num_classes (int): number of classes for classification.
+ num_stages (int): VGG stages, normally 5.
+ dilations (Sequence[int]): Dilation of each stage.
+ out_indices (Sequence[int]): Output from which stages.
+ frozen_stages (int): Stages to be frozen (all param fixed). -1 means
+ not freezing any parameters.
+ bn_eval (bool): Whether to set BN layers as eval mode, namely, freeze
+ running stats (mean and var).
+ bn_frozen (bool): Whether to freeze weight and bias of BN layers.
+ """
+
+ arch_settings = {
+ 11: (1, 1, 2, 2, 2),
+ 13: (2, 2, 2, 2, 2),
+ 16: (2, 2, 3, 3, 3),
+ 19: (2, 2, 4, 4, 4)
+ }
+
+ def __init__(self,
+ depth,
+ with_bn=False,
+ num_classes=-1,
+ num_stages=5,
+ dilations=(1, 1, 1, 1, 1),
+ out_indices=(0, 1, 2, 3, 4),
+ frozen_stages=-1,
+ bn_eval=True,
+ bn_frozen=False,
+ ceil_mode=False,
+ with_last_pool=True):
+ super(VGG, self).__init__()
+ if depth not in self.arch_settings:
+ raise KeyError(f'invalid depth {depth} for vgg')
+ assert num_stages >= 1 and num_stages <= 5
+ stage_blocks = self.arch_settings[depth]
+ self.stage_blocks = stage_blocks[:num_stages]
+ assert len(dilations) == num_stages
+ assert max(out_indices) <= num_stages
+
+ self.num_classes = num_classes
+ self.out_indices = out_indices
+ self.frozen_stages = frozen_stages
+ self.bn_eval = bn_eval
+ self.bn_frozen = bn_frozen
+
+ self.inplanes = 3
+ start_idx = 0
+ vgg_layers = []
+ self.range_sub_modules = []
+ for i, num_blocks in enumerate(self.stage_blocks):
+ num_modules = num_blocks * (2 + with_bn) + 1
+ end_idx = start_idx + num_modules
+ dilation = dilations[i]
+ planes = 64 * 2**i if i < 4 else 512
+ vgg_layer = make_vgg_layer(
+ self.inplanes,
+ planes,
+ num_blocks,
+ dilation=dilation,
+ with_bn=with_bn,
+ ceil_mode=ceil_mode)
+ vgg_layers.extend(vgg_layer)
+ self.inplanes = planes
+ self.range_sub_modules.append([start_idx, end_idx])
+ start_idx = end_idx
+ if not with_last_pool:
+ vgg_layers.pop(-1)
+ self.range_sub_modules[-1][1] -= 1
+ self.module_name = 'features'
+ self.add_module(self.module_name, nn.Sequential(*vgg_layers))
+
+ if self.num_classes > 0:
+ self.classifier = nn.Sequential(
+ nn.Linear(512 * 7 * 7, 4096),
+ nn.ReLU(True),
+ nn.Dropout(),
+ nn.Linear(4096, 4096),
+ nn.ReLU(True),
+ nn.Dropout(),
+ nn.Linear(4096, num_classes),
+ )
+
+ def init_weights(self, pretrained=None):
+ if isinstance(pretrained, str):
+ logger = logging.getLogger()
+ from ..runner import load_checkpoint
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
+ elif pretrained is None:
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ kaiming_init(m)
+ elif isinstance(m, nn.BatchNorm2d):
+ constant_init(m, 1)
+ elif isinstance(m, nn.Linear):
+ normal_init(m, std=0.01)
+ else:
+ raise TypeError('pretrained must be a str or None')
+
+ def forward(self, x):
+ outs = []
+ vgg_layers = getattr(self, self.module_name)
+ for i in range(len(self.stage_blocks)):
+ for j in range(*self.range_sub_modules[i]):
+ vgg_layer = vgg_layers[j]
+ x = vgg_layer(x)
+ if i in self.out_indices:
+ outs.append(x)
+ if self.num_classes > 0:
+ x = x.view(x.size(0), -1)
+ x = self.classifier(x)
+ outs.append(x)
+ if len(outs) == 1:
+ return outs[0]
+ else:
+ return tuple(outs)
+
+ def train(self, mode=True):
+ super(VGG, self).train(mode)
+ if self.bn_eval:
+ for m in self.modules():
+ if isinstance(m, nn.BatchNorm2d):
+ m.eval()
+ if self.bn_frozen:
+ for params in m.parameters():
+ params.requires_grad = False
+ vgg_layers = getattr(self, self.module_name)
+ if mode and self.frozen_stages >= 0:
+ for i in range(self.frozen_stages):
+ for j in range(*self.range_sub_modules[i]):
+ mod = vgg_layers[j]
+ mod.eval()
+ for param in mod.parameters():
+ param.requires_grad = False
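+
+
+if __name__ == '__main__':
+    # Minimal usage sketch (not part of the upstream module): build a VGG-16
+    # backbone without a classification head and inspect the outputs of the
+    # last three stages for a dummy image. The 64x64 input size is arbitrary.
+    import torch
+
+    net = VGG(depth=16, out_indices=(2, 3, 4))
+    net.init_weights()
+    feats = net(torch.randn(1, 3, 64, 64))
+    print([tuple(f.shape) for f in feats])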
diff --git a/annotator/uniformer/mmcv/engine/__init__.py b/annotator/uniformer/mmcv/engine/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3193b7f664e19ce2458d81c836597fa22e4bb082
--- /dev/null
+++ b/annotator/uniformer/mmcv/engine/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .test import (collect_results_cpu, collect_results_gpu, multi_gpu_test,
+ single_gpu_test)
+
+__all__ = [
+ 'collect_results_cpu', 'collect_results_gpu', 'multi_gpu_test',
+ 'single_gpu_test'
+]
diff --git a/annotator/uniformer/mmcv/engine/test.py b/annotator/uniformer/mmcv/engine/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..8dbeef271db634ec2dadfda3bc0b5ef9c7a677ff
--- /dev/null
+++ b/annotator/uniformer/mmcv/engine/test.py
@@ -0,0 +1,202 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import pickle
+import shutil
+import tempfile
+import time
+
+import torch
+import torch.distributed as dist
+
+import annotator.uniformer.mmcv as mmcv
+from annotator.uniformer.mmcv.runner import get_dist_info
+
+
+def single_gpu_test(model, data_loader):
+ """Test model with a single gpu.
+
+ This method tests model with a single gpu and displays test progress bar.
+
+ Args:
+ model (nn.Module): Model to be tested.
+ data_loader (nn.Dataloader): Pytorch data loader.
+
+ Returns:
+ list: The prediction results.
+ """
+ model.eval()
+ results = []
+ dataset = data_loader.dataset
+ prog_bar = mmcv.ProgressBar(len(dataset))
+ for data in data_loader:
+ with torch.no_grad():
+ result = model(return_loss=False, **data)
+ results.extend(result)
+
+        # Assume result has the same length as batch_size
+ # refer to https://github.com/open-mmlab/mmcv/issues/985
+ batch_size = len(result)
+ for _ in range(batch_size):
+ prog_bar.update()
+ return results
+
+
+def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
+ """Test model with multiple gpus.
+
+ This method tests model with multiple gpus and collects the results
+ under two different modes: gpu and cpu modes. By setting
+    ``gpu_collect=True``, it encodes results to gpu tensors and uses gpu
+    communication for results collection. In cpu mode, it saves the results
+    on different gpus to ``tmpdir`` and the rank 0 worker collects them.
+
+ Args:
+ model (nn.Module): Model to be tested.
+ data_loader (nn.Dataloader): Pytorch data loader.
+ tmpdir (str): Path of directory to save the temporary results from
+ different gpus under cpu mode.
+ gpu_collect (bool): Option to use either gpu or cpu to collect results.
+
+ Returns:
+ list: The prediction results.
+ """
+ model.eval()
+ results = []
+ dataset = data_loader.dataset
+ rank, world_size = get_dist_info()
+ if rank == 0:
+ prog_bar = mmcv.ProgressBar(len(dataset))
+ time.sleep(2) # This line can prevent deadlock problem in some cases.
+ for i, data in enumerate(data_loader):
+ with torch.no_grad():
+ result = model(return_loss=False, **data)
+ results.extend(result)
+
+ if rank == 0:
+ batch_size = len(result)
+ batch_size_all = batch_size * world_size
+ if batch_size_all + prog_bar.completed > len(dataset):
+ batch_size_all = len(dataset) - prog_bar.completed
+ for _ in range(batch_size_all):
+ prog_bar.update()
+
+ # collect results from all ranks
+ if gpu_collect:
+ results = collect_results_gpu(results, len(dataset))
+ else:
+ results = collect_results_cpu(results, len(dataset), tmpdir)
+ return results
+
+
+def collect_results_cpu(result_part, size, tmpdir=None):
+ """Collect results under cpu mode.
+
+    In cpu mode, this function saves the results from different gpus to
+    ``tmpdir``, and the rank 0 worker collects them.
+
+ Args:
+ result_part (list): Result list containing result parts
+ to be collected.
+ size (int): Size of the results, commonly equal to length of
+ the results.
+        tmpdir (str | None): Temporary directory in which to store the
+            collected results. If set to None, a random temporary directory
+            will be created.
+
+ Returns:
+ list: The collected results.
+ """
+ rank, world_size = get_dist_info()
+ # create a tmp dir if it is not specified
+ if tmpdir is None:
+ MAX_LEN = 512
+ # 32 is the ASCII code of the whitespace character
+ dir_tensor = torch.full((MAX_LEN, ),
+ 32,
+ dtype=torch.uint8,
+ device='cuda')
+ if rank == 0:
+ mmcv.mkdir_or_exist('.dist_test')
+ tmpdir = tempfile.mkdtemp(dir='.dist_test')
+ tmpdir = torch.tensor(
+ bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
+ dir_tensor[:len(tmpdir)] = tmpdir
+ dist.broadcast(dir_tensor, 0)
+ tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
+ else:
+ mmcv.mkdir_or_exist(tmpdir)
+ # dump the part result to the dir
+ mmcv.dump(result_part, osp.join(tmpdir, f'part_{rank}.pkl'))
+ dist.barrier()
+ # collect all parts
+ if rank != 0:
+ return None
+ else:
+ # load results of all parts from tmp dir
+ part_list = []
+ for i in range(world_size):
+ part_file = osp.join(tmpdir, f'part_{i}.pkl')
+ part_result = mmcv.load(part_file)
+ # When data is severely insufficient, an empty part_result
+ # on a certain gpu could make the overall outputs empty.
+ if part_result:
+ part_list.append(part_result)
+ # sort the results
+ ordered_results = []
+ for res in zip(*part_list):
+ ordered_results.extend(list(res))
+ # the dataloader may pad some samples
+ ordered_results = ordered_results[:size]
+ # remove tmp dir
+ shutil.rmtree(tmpdir)
+ return ordered_results
+
+
+def collect_results_gpu(result_part, size):
+ """Collect results under gpu mode.
+
+ In gpu mode, this function encodes the results into gpu tensors and uses
+ gpu communication to collect them.
+
+ Args:
+ result_part (list): Result list containing result parts
+ to be collected.
+ size (int): Size of the results, commonly equal to the length of
+ the results.
+
+ Returns:
+ list: The collected results.
+ """
+ rank, world_size = get_dist_info()
+ # dump result part to tensor with pickle
+ part_tensor = torch.tensor(
+ bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
+ # gather all result part tensor shape
+ shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
+ shape_list = [shape_tensor.clone() for _ in range(world_size)]
+ dist.all_gather(shape_list, shape_tensor)
+ # padding result part tensor to max length
+ shape_max = torch.tensor(shape_list).max()
+ part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
+ part_send[:shape_tensor[0]] = part_tensor
+ part_recv_list = [
+ part_tensor.new_zeros(shape_max) for _ in range(world_size)
+ ]
+ # gather all result part
+ dist.all_gather(part_recv_list, part_send)
+
+ if rank == 0:
+ part_list = []
+ for recv, shape in zip(part_recv_list, shape_list):
+ part_result = pickle.loads(recv[:shape[0]].cpu().numpy().tobytes())
+ # When data is severely insufficient, an empty part_result
+ # on a certain gpu could make the overall outputs empty.
+ if part_result:
+ part_list.append(part_result)
+ # sort the results
+ ordered_results = []
+ for res in zip(*part_list):
+ ordered_results.extend(list(res))
+ # the dataloader may pad some samples
+ ordered_results = ordered_results[:size]
+ return ordered_results
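+
+
+if __name__ == '__main__':
+    # Editor's illustrative sketch, not part of upstream mmcv: exercise
+    # ``single_gpu_test`` with a toy model and loader. The model only needs to
+    # accept ``return_loss=False`` plus the batch keys and return one entry
+    # per sample. ``multi_gpu_test`` follows the same contract but must run
+    # under an initialized torch.distributed process group (one process per
+    # GPU), with either ``gpu_collect=True`` or a shared ``tmpdir``.
+    import torch.nn as nn
+    from torch.utils.data import DataLoader, TensorDataset
+
+    class ToyModel(nn.Module):
+
+        def forward(self, return_loss=False, img=None):
+            # one scalar prediction per sample in the batch
+            return [float(x.sum()) for x in img]
+
+    dataset = TensorDataset(torch.randn(8, 3, 4, 4))
+    loader = DataLoader(
+        dataset,
+        batch_size=4,
+        collate_fn=lambda batch: {'img': torch.stack([b[0] for b in batch])})
+    print(single_gpu_test(ToyModel(), loader))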
diff --git a/annotator/uniformer/mmcv/fileio/__init__.py b/annotator/uniformer/mmcv/fileio/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2051b85f7e59bff7bdbaa131849ce8cd31f059a4
--- /dev/null
+++ b/annotator/uniformer/mmcv/fileio/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .file_client import BaseStorageBackend, FileClient
+from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler
+from .io import dump, load, register_handler
+from .parse import dict_from_file, list_from_file
+
+__all__ = [
+ 'BaseStorageBackend', 'FileClient', 'load', 'dump', 'register_handler',
+ 'BaseFileHandler', 'JsonHandler', 'PickleHandler', 'YamlHandler',
+ 'list_from_file', 'dict_from_file'
+]
diff --git a/annotator/uniformer/mmcv/fileio/file_client.py b/annotator/uniformer/mmcv/fileio/file_client.py
new file mode 100644
index 0000000000000000000000000000000000000000..950f0c1aeab14b8e308a7455ccd64a95b5d98add
--- /dev/null
+++ b/annotator/uniformer/mmcv/fileio/file_client.py
@@ -0,0 +1,1148 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import inspect
+import os
+import os.path as osp
+import re
+import tempfile
+import warnings
+from abc import ABCMeta, abstractmethod
+from contextlib import contextmanager
+from pathlib import Path
+from typing import Iterable, Iterator, Optional, Tuple, Union
+from urllib.request import urlopen
+
+import annotator.uniformer.mmcv as mmcv
+from annotator.uniformer.mmcv.utils.misc import has_method
+from annotator.uniformer.mmcv.utils.path import is_filepath
+
+
+class BaseStorageBackend(metaclass=ABCMeta):
+ """Abstract class of storage backends.
+
+ All backends need to implement two apis: ``get()`` and ``get_text()``.
+ ``get()`` reads the file as a byte stream and ``get_text()`` reads the
+ file as text.
+ """
+
+ # a flag to indicate whether the backend can create a symlink for a file
+ _allow_symlink = False
+
+ @property
+ def name(self):
+ return self.__class__.__name__
+
+ @property
+ def allow_symlink(self):
+ return self._allow_symlink
+
+ @abstractmethod
+ def get(self, filepath):
+ pass
+
+ @abstractmethod
+ def get_text(self, filepath):
+ pass
+
+
+class CephBackend(BaseStorageBackend):
+ """Ceph storage backend (for internal use).
+
+ Args:
+ path_mapping (dict|None): path mapping dict from local path to Petrel
+ path. When ``path_mapping={'src': 'dst'}``, ``src`` in ``filepath``
+ will be replaced by ``dst``. Default: None.
+
+ .. warning::
+ :class:`mmcv.fileio.file_client.CephBackend` will be deprecated,
+ please use :class:`mmcv.fileio.file_client.PetrelBackend` instead.
+ """
+
+ def __init__(self, path_mapping=None):
+ try:
+ import ceph
+ except ImportError:
+ raise ImportError('Please install ceph to enable CephBackend.')
+
+ warnings.warn(
+ 'CephBackend will be deprecated, please use PetrelBackend instead')
+ self._client = ceph.S3Client()
+ assert isinstance(path_mapping, dict) or path_mapping is None
+ self.path_mapping = path_mapping
+
+ def get(self, filepath):
+ filepath = str(filepath)
+ if self.path_mapping is not None:
+ for k, v in self.path_mapping.items():
+ filepath = filepath.replace(k, v)
+ value = self._client.Get(filepath)
+ value_buf = memoryview(value)
+ return value_buf
+
+ def get_text(self, filepath, encoding=None):
+ raise NotImplementedError
+
+
+class PetrelBackend(BaseStorageBackend):
+ """Petrel storage backend (for internal use).
+
+ PetrelBackend supports reading and writing data to multiple clusters.
+ If the file path contains the cluster name, PetrelBackend will read data
+ from specified cluster or write data to it. Otherwise, PetrelBackend will
+ access the default cluster.
+
+ Args:
+ path_mapping (dict, optional): Path mapping dict from local path to
+ Petrel path. When ``path_mapping={'src': 'dst'}``, ``src`` in
+ ``filepath`` will be replaced by ``dst``. Default: None.
+ enable_mc (bool, optional): Whether to enable memcached support.
+ Default: True.
+
+ Examples:
+ >>> filepath1 = 's3://path/of/file'
+ >>> filepath2 = 'cluster-name:s3://path/of/file'
+ >>> client = PetrelBackend()
+ >>> client.get(filepath1) # get data from default cluster
+ >>> client.get(filepath2) # get data from 'cluster-name' cluster
+ """
+
+ def __init__(self,
+ path_mapping: Optional[dict] = None,
+ enable_mc: bool = True):
+ try:
+ from petrel_client import client
+ except ImportError:
+ raise ImportError('Please install petrel_client to enable '
+ 'PetrelBackend.')
+
+ self._client = client.Client(enable_mc=enable_mc)
+ assert isinstance(path_mapping, dict) or path_mapping is None
+ self.path_mapping = path_mapping
+
+ def _map_path(self, filepath: Union[str, Path]) -> str:
+ """Map ``filepath`` to a string path whose prefix will be replaced by
+ :attr:`self.path_mapping`.
+
+ Args:
+ filepath (str): Path to be mapped.
+ """
+ filepath = str(filepath)
+ if self.path_mapping is not None:
+ for k, v in self.path_mapping.items():
+ filepath = filepath.replace(k, v)
+ return filepath
+
+ def _format_path(self, filepath: str) -> str:
+ """Convert a ``filepath`` to standard format of petrel oss.
+
+ If the ``filepath`` is concatenated by ``os.path.join``, in a Windows
+ environment, the ``filepath`` will be the format of
+ 's3://bucket_name\\image.jpg'. By invoking :meth:`_format_path`, the
+ above ``filepath`` will be converted to 's3://bucket_name/image.jpg'.
+
+ Args:
+ filepath (str): Path to be formatted.
+ """
+ return re.sub(r'\\+', '/', filepath)
+
+ def get(self, filepath: Union[str, Path]) -> memoryview:
+ """Read data from a given ``filepath`` with 'rb' mode.
+
+ Args:
+ filepath (str or Path): Path to read data.
+
+ Returns:
+ memoryview: A memory view of expected bytes object to avoid
+ copying. The memoryview object can be converted to bytes by
+ ``value_buf.tobytes()``.
+ """
+ filepath = self._map_path(filepath)
+ filepath = self._format_path(filepath)
+ value = self._client.Get(filepath)
+ value_buf = memoryview(value)
+ return value_buf
+
+ def get_text(self,
+ filepath: Union[str, Path],
+ encoding: str = 'utf-8') -> str:
+ """Read data from a given ``filepath`` with 'r' mode.
+
+ Args:
+ filepath (str or Path): Path to read data.
+ encoding (str): The encoding format used to open the ``filepath``.
+ Default: 'utf-8'.
+
+ Returns:
+ str: Expected text reading from ``filepath``.
+ """
+ return str(self.get(filepath), encoding=encoding)
+
+ def put(self, obj: bytes, filepath: Union[str, Path]) -> None:
+ """Save data to a given ``filepath``.
+
+ Args:
+ obj (bytes): Data to be saved.
+ filepath (str or Path): Path to write data.
+ """
+ filepath = self._map_path(filepath)
+ filepath = self._format_path(filepath)
+ self._client.put(filepath, obj)
+
+ def put_text(self,
+ obj: str,
+ filepath: Union[str, Path],
+ encoding: str = 'utf-8') -> None:
+ """Save data to a given ``filepath``.
+
+ Args:
+ obj (str): Data to be written.
+ filepath (str or Path): Path to write data.
+ encoding (str): The encoding format used to encode the ``obj``.
+ Default: 'utf-8'.
+ """
+ self.put(bytes(obj, encoding=encoding), filepath)
+
+ def remove(self, filepath: Union[str, Path]) -> None:
+ """Remove a file.
+
+ Args:
+ filepath (str or Path): Path to be removed.
+ """
+ if not has_method(self._client, 'delete'):
+ raise NotImplementedError(
+ ('Current version of Petrel Python SDK has not supported '
+ 'the `delete` method, please use a higher version or dev'
+ ' branch instead.'))
+
+ filepath = self._map_path(filepath)
+ filepath = self._format_path(filepath)
+ self._client.delete(filepath)
+
+ def exists(self, filepath: Union[str, Path]) -> bool:
+ """Check whether a file path exists.
+
+ Args:
+ filepath (str or Path): Path to be checked whether exists.
+
+ Returns:
+ bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise.
+ """
+ if not (has_method(self._client, 'contains')
+ and has_method(self._client, 'isdir')):
+ raise NotImplementedError(
+ ('Current version of Petrel Python SDK has not supported '
+ 'the `contains` and `isdir` methods, please use a higher'
+ 'version or dev branch instead.'))
+
+ filepath = self._map_path(filepath)
+ filepath = self._format_path(filepath)
+ return self._client.contains(filepath) or self._client.isdir(filepath)
+
+ def isdir(self, filepath: Union[str, Path]) -> bool:
+ """Check whether a file path is a directory.
+
+ Args:
+ filepath (str or Path): Path to be checked whether it is a
+ directory.
+
+ Returns:
+ bool: Return ``True`` if ``filepath`` points to a directory,
+ ``False`` otherwise.
+ """
+ if not has_method(self._client, 'isdir'):
+ raise NotImplementedError(
+ ('Current version of Petrel Python SDK has not supported '
+ 'the `isdir` method, please use a higher version or dev'
+ ' branch instead.'))
+
+ filepath = self._map_path(filepath)
+ filepath = self._format_path(filepath)
+ return self._client.isdir(filepath)
+
+ def isfile(self, filepath: Union[str, Path]) -> bool:
+ """Check whether a file path is a file.
+
+ Args:
+ filepath (str or Path): Path to be checked whether it is a file.
+
+ Returns:
+ bool: Return ``True`` if ``filepath`` points to a file, ``False``
+ otherwise.
+ """
+ if not has_method(self._client, 'contains'):
+ raise NotImplementedError(
+ ('Current version of Petrel Python SDK has not supported '
+ 'the `contains` method, please use a higher version or '
+ 'dev branch instead.'))
+
+ filepath = self._map_path(filepath)
+ filepath = self._format_path(filepath)
+ return self._client.contains(filepath)
+
+ def join_path(self, filepath: Union[str, Path],
+ *filepaths: Union[str, Path]) -> str:
+ """Concatenate all file paths.
+
+ Args:
+ filepath (str or Path): Path to be concatenated.
+
+ Returns:
+ str: The result after concatenation.
+ """
+ filepath = self._format_path(self._map_path(filepath))
+ if filepath.endswith('/'):
+ filepath = filepath[:-1]
+ formatted_paths = [filepath]
+ for path in filepaths:
+ formatted_paths.append(self._format_path(self._map_path(path)))
+ return '/'.join(formatted_paths)
+
+ @contextmanager
+ def get_local_path(self, filepath: Union[str, Path]) -> Iterable[str]:
+ """Download a file from ``filepath`` and return a temporary path.
+
+ ``get_local_path`` is decorated by :meth:`contextlib.contextmanager`. It
+ can be called with a ``with`` statement, and when exiting the ``with``
+ statement, the temporary path will be released.
+
+ Args:
+ filepath (str | Path): Download a file from ``filepath``.
+
+ Examples:
+ >>> client = PetrelBackend()
+ >>> # After exiting from the ``with`` clause,
+ >>> # the path will be removed
+ >>> with client.get_local_path('s3://path/of/your/file') as path:
+ ... # do something here
+
+ Yields:
+ Iterable[str]: Only yield one temporary path.
+ """
+ filepath = self._map_path(filepath)
+ filepath = self._format_path(filepath)
+ assert self.isfile(filepath)
+ try:
+ f = tempfile.NamedTemporaryFile(delete=False)
+ f.write(self.get(filepath))
+ f.close()
+ yield f.name
+ finally:
+ os.remove(f.name)
+
+ def list_dir_or_file(self,
+ dir_path: Union[str, Path],
+ list_dir: bool = True,
+ list_file: bool = True,
+ suffix: Optional[Union[str, Tuple[str]]] = None,
+ recursive: bool = False) -> Iterator[str]:
+ """Scan a directory to find the interested directories or files in
+ arbitrary order.
+
+ Note:
+ Petrel has no concept of directories but it simulates the directory
+ hierarchy in the filesystem through public prefixes. In addition,
+ if the returned path ends with '/', it means the path is a public
+ prefix which is a logical directory.
+
+ Note:
+ :meth:`list_dir_or_file` returns the path relative to ``dir_path``.
+ In addition, the returned directory path will not contain the trailing
+ '/', which is consistent with other backends.
+
+ Args:
+ dir_path (str | Path): Path of the directory.
+ list_dir (bool): List the directories. Default: True.
+ list_file (bool): List the path of files. Default: True.
+ suffix (str or tuple[str], optional): File suffix
+ that we are interested in. Default: None.
+ recursive (bool): If set to True, recursively scan the
+ directory. Default: False.
+
+ Yields:
+ Iterable[str]: A relative path to ``dir_path``.
+ """
+ if not has_method(self._client, 'list'):
+ raise NotImplementedError(
+ ('Current version of Petrel Python SDK has not supported '
+ 'the `list` method, please use a higher version or dev'
+ ' branch instead.'))
+
+ dir_path = self._map_path(dir_path)
+ dir_path = self._format_path(dir_path)
+ if list_dir and suffix is not None:
+ raise TypeError(
+ '`list_dir` should be False when `suffix` is not None')
+
+ if (suffix is not None) and not isinstance(suffix, (str, tuple)):
+ raise TypeError('`suffix` must be a string or tuple of strings')
+
+ # Petrel's simulated directory hierarchy assumes that directory paths
+ # should end with `/`
+ if not dir_path.endswith('/'):
+ dir_path += '/'
+
+ root = dir_path
+
+ def _list_dir_or_file(dir_path, list_dir, list_file, suffix,
+ recursive):
+ for path in self._client.list(dir_path):
+ # the `self.isdir` is not used here to determine whether path
+ # is a directory, because `self.isdir` relies on
+ # `self._client.list`
+ if path.endswith('/'): # a directory path
+ next_dir_path = self.join_path(dir_path, path)
+ if list_dir:
+ # get the relative path and exclude the last
+ # character '/'
+ rel_dir = next_dir_path[len(root):-1]
+ yield rel_dir
+ if recursive:
+ yield from _list_dir_or_file(next_dir_path, list_dir,
+ list_file, suffix,
+ recursive)
+ else: # a file path
+ absolute_path = self.join_path(dir_path, path)
+ rel_path = absolute_path[len(root):]
+ if (suffix is None
+ or rel_path.endswith(suffix)) and list_file:
+ yield rel_path
+
+ return _list_dir_or_file(dir_path, list_dir, list_file, suffix,
+ recursive)
+
+
+class MemcachedBackend(BaseStorageBackend):
+ """Memcached storage backend.
+
+ Attributes:
+ server_list_cfg (str): Config file for memcached server list.
+ client_cfg (str): Config file for memcached client.
+ sys_path (str | None): Additional path to be appended to `sys.path`.
+ Default: None.
+ """
+
+ def __init__(self, server_list_cfg, client_cfg, sys_path=None):
+ if sys_path is not None:
+ import sys
+ sys.path.append(sys_path)
+ try:
+ import mc
+ except ImportError:
+ raise ImportError(
+ 'Please install memcached to enable MemcachedBackend.')
+
+ self.server_list_cfg = server_list_cfg
+ self.client_cfg = client_cfg
+ self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg,
+ self.client_cfg)
+ # mc.pyvector serves as a pointer that points to a memory cache
+ self._mc_buffer = mc.pyvector()
+
+ def get(self, filepath):
+ filepath = str(filepath)
+ import mc
+ self._client.Get(filepath, self._mc_buffer)
+ value_buf = mc.ConvertBuffer(self._mc_buffer)
+ return value_buf
+
+ def get_text(self, filepath, encoding=None):
+ raise NotImplementedError
+
+
+class LmdbBackend(BaseStorageBackend):
+ """Lmdb storage backend.
+
+ Args:
+ db_path (str): Lmdb database path.
+ readonly (bool, optional): Lmdb environment parameter. If True,
+ disallow any write operations. Default: True.
+ lock (bool, optional): Lmdb environment parameter. If False, when
+ concurrent access occurs, do not lock the database. Default: False.
+ readahead (bool, optional): Lmdb environment parameter. If False,
+ disable the OS filesystem readahead mechanism, which may improve
+ random read performance when a database is larger than RAM.
+ Default: False.
+
+ Attributes:
+ db_path (str): Lmdb database path.
+ """
+
+ def __init__(self,
+ db_path,
+ readonly=True,
+ lock=False,
+ readahead=False,
+ **kwargs):
+ try:
+ import lmdb
+ except ImportError:
+ raise ImportError('Please install lmdb to enable LmdbBackend.')
+
+ self.db_path = str(db_path)
+ self._client = lmdb.open(
+ self.db_path,
+ readonly=readonly,
+ lock=lock,
+ readahead=readahead,
+ **kwargs)
+
+ def get(self, filepath):
+ """Get values according to the filepath.
+
+ Args:
+ filepath (str | obj:`Path`): Here, filepath is the lmdb key.
+ """
+ filepath = str(filepath)
+ with self._client.begin(write=False) as txn:
+ value_buf = txn.get(filepath.encode('ascii'))
+ return value_buf
+
+ def get_text(self, filepath, encoding=None):
+ raise NotImplementedError
+
+
+class HardDiskBackend(BaseStorageBackend):
+ """Raw hard disks storage backend."""
+
+ _allow_symlink = True
+
+ def get(self, filepath: Union[str, Path]) -> bytes:
+ """Read data from a given ``filepath`` with 'rb' mode.
+
+ Args:
+ filepath (str or Path): Path to read data.
+
+ Returns:
+ bytes: Expected bytes object.
+ """
+ with open(filepath, 'rb') as f:
+ value_buf = f.read()
+ return value_buf
+
+ def get_text(self,
+ filepath: Union[str, Path],
+ encoding: str = 'utf-8') -> str:
+ """Read data from a given ``filepath`` with 'r' mode.
+
+ Args:
+ filepath (str or Path): Path to read data.
+ encoding (str): The encoding format used to open the ``filepath``.
+ Default: 'utf-8'.
+
+ Returns:
+ str: Expected text reading from ``filepath``.
+ """
+ with open(filepath, 'r', encoding=encoding) as f:
+ value_buf = f.read()
+ return value_buf
+
+ def put(self, obj: bytes, filepath: Union[str, Path]) -> None:
+ """Write data to a given ``filepath`` with 'wb' mode.
+
+ Note:
+ ``put`` will create a directory if the directory of ``filepath``
+ does not exist.
+
+ Args:
+ obj (bytes): Data to be written.
+ filepath (str or Path): Path to write data.
+ """
+ mmcv.mkdir_or_exist(osp.dirname(filepath))
+ with open(filepath, 'wb') as f:
+ f.write(obj)
+
+ def put_text(self,
+ obj: str,
+ filepath: Union[str, Path],
+ encoding: str = 'utf-8') -> None:
+ """Write data to a given ``filepath`` with 'w' mode.
+
+ Note:
+ ``put_text`` will create a directory if the directory of
+ ``filepath`` does not exist.
+
+ Args:
+ obj (str): Data to be written.
+ filepath (str or Path): Path to write data.
+ encoding (str): The encoding format used to open the ``filepath``.
+ Default: 'utf-8'.
+ """
+ mmcv.mkdir_or_exist(osp.dirname(filepath))
+ with open(filepath, 'w', encoding=encoding) as f:
+ f.write(obj)
+
+ def remove(self, filepath: Union[str, Path]) -> None:
+ """Remove a file.
+
+ Args:
+ filepath (str or Path): Path to be removed.
+ """
+ os.remove(filepath)
+
+ def exists(self, filepath: Union[str, Path]) -> bool:
+ """Check whether a file path exists.
+
+ Args:
+ filepath (str or Path): Path to be checked whether exists.
+
+ Returns:
+ bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise.
+ """
+ return osp.exists(filepath)
+
+ def isdir(self, filepath: Union[str, Path]) -> bool:
+ """Check whether a file path is a directory.
+
+ Args:
+ filepath (str or Path): Path to be checked whether it is a
+ directory.
+
+ Returns:
+ bool: Return ``True`` if ``filepath`` points to a directory,
+ ``False`` otherwise.
+ """
+ return osp.isdir(filepath)
+
+ def isfile(self, filepath: Union[str, Path]) -> bool:
+ """Check whether a file path is a file.
+
+ Args:
+ filepath (str or Path): Path to be checked whether it is a file.
+
+ Returns:
+ bool: Return ``True`` if ``filepath`` points to a file, ``False``
+ otherwise.
+ """
+ return osp.isfile(filepath)
+
+ def join_path(self, filepath: Union[str, Path],
+ *filepaths: Union[str, Path]) -> str:
+ """Concatenate all file paths.
+
+ Join one or more filepath components intelligently. The return value
+ is the concatenation of filepath and any members of *filepaths.
+
+ Args:
+ filepath (str or Path): Path to be concatenated.
+
+ Returns:
+ str: The result of concatenation.
+ """
+ return osp.join(filepath, *filepaths)
+
+ @contextmanager
+ def get_local_path(
+ self, filepath: Union[str, Path]) -> Iterable[Union[str, Path]]:
+ """Only for unified API and do nothing."""
+ yield filepath
+
+ def list_dir_or_file(self,
+ dir_path: Union[str, Path],
+ list_dir: bool = True,
+ list_file: bool = True,
+ suffix: Optional[Union[str, Tuple[str]]] = None,
+ recursive: bool = False) -> Iterator[str]:
+ """Scan a directory to find the interested directories or files in
+ arbitrary order.
+
+ Note:
+ :meth:`list_dir_or_file` returns the path relative to ``dir_path``.
+
+ Args:
+ dir_path (str | Path): Path of the directory.
+ list_dir (bool): List the directories. Default: True.
+ list_file (bool): List the path of files. Default: True.
+ suffix (str or tuple[str], optional): File suffix
+ that we are interested in. Default: None.
+ recursive (bool): If set to True, recursively scan the
+ directory. Default: False.
+
+ Yields:
+ Iterable[str]: A relative path to ``dir_path``.
+ """
+ if list_dir and suffix is not None:
+ raise TypeError('`suffix` should be None when `list_dir` is True')
+
+ if (suffix is not None) and not isinstance(suffix, (str, tuple)):
+ raise TypeError('`suffix` must be a string or tuple of strings')
+
+ root = dir_path
+
+ def _list_dir_or_file(dir_path, list_dir, list_file, suffix,
+ recursive):
+ for entry in os.scandir(dir_path):
+ if not entry.name.startswith('.') and entry.is_file():
+ rel_path = osp.relpath(entry.path, root)
+ if (suffix is None
+ or rel_path.endswith(suffix)) and list_file:
+ yield rel_path
+ elif osp.isdir(entry.path):
+ if list_dir:
+ rel_dir = osp.relpath(entry.path, root)
+ yield rel_dir
+ if recursive:
+ yield from _list_dir_or_file(entry.path, list_dir,
+ list_file, suffix,
+ recursive)
+
+ return _list_dir_or_file(dir_path, list_dir, list_file, suffix,
+ recursive)
+
+
+class HTTPBackend(BaseStorageBackend):
+ """HTTP and HTTPS storage bachend."""
+
+ def get(self, filepath):
+ value_buf = urlopen(filepath).read()
+ return value_buf
+
+ def get_text(self, filepath, encoding='utf-8'):
+ value_buf = urlopen(filepath).read()
+ return value_buf.decode(encoding)
+
+ @contextmanager
+ def get_local_path(self, filepath: str) -> Iterable[str]:
+ """Download a file from ``filepath``.
+
+ ``get_local_path`` is decorated by :meth:`contextlib.contextmanager`. It
+ can be called with a ``with`` statement, and when exiting the ``with``
+ statement, the temporary path will be released.
+
+ Args:
+ filepath (str): Download a file from ``filepath``.
+
+ Examples:
+ >>> client = HTTPBackend()
+ >>> # After exiting from the ``with`` clause,
+ >>> # the path will be removed
+ >>> with client.get_local_path('http://path/of/your/file') as path:
+ ... # do something here
+ """
+ try:
+ f = tempfile.NamedTemporaryFile(delete=False)
+ f.write(self.get(filepath))
+ f.close()
+ yield f.name
+ finally:
+ os.remove(f.name)
+
+
+class FileClient:
+ """A general file client to access files in different backends.
+
+ The client loads a file or text from a specified backend path and returns
+ it as a binary or text object. There are two ways to choose a backend:
+ by the backend name or by the path prefix. If both are given, ``backend``
+ takes priority and determines the storage backend; if both are `None`,
+ the disk backend is used. Additional backend accessors can also be
+ registered with a given name, prefixes, and backend class. In addition,
+ the singleton pattern is used to avoid repeated object creation: if the
+ arguments are the same, the same object will be returned.
+
+ Args:
+ backend (str, optional): The storage backend type. Options are "disk",
+ "ceph", "memcached", "lmdb", "http" and "petrel". Default: None.
+ prefix (str, optional): The prefix of the registered storage backend.
+ Options are "s3", "http", "https". Default: None.
+
+ Examples:
+ >>> # only set backend
+ >>> file_client = FileClient(backend='petrel')
+ >>> # only set prefix
+ >>> file_client = FileClient(prefix='s3')
+ >>> # set both backend and prefix but use backend to choose client
+ >>> file_client = FileClient(backend='petrel', prefix='s3')
+ >>> # if the arguments are the same, the same object is returned
+ >>> file_client1 = FileClient(backend='petrel')
+ >>> file_client1 is file_client
+ True
+
+ Attributes:
+ client (:obj:`BaseStorageBackend`): The backend object.
+ """
+
+ _backends = {
+ 'disk': HardDiskBackend,
+ 'ceph': CephBackend,
+ 'memcached': MemcachedBackend,
+ 'lmdb': LmdbBackend,
+ 'petrel': PetrelBackend,
+ 'http': HTTPBackend,
+ }
+ # This collection records the overridden backends. When a backend appears
+ # in the collection, the singleton pattern is disabled for it; otherwise
+ # the returned object would still be the backend from before the override.
+ _overridden_backends = set()
+ _prefix_to_backends = {
+ 's3': PetrelBackend,
+ 'http': HTTPBackend,
+ 'https': HTTPBackend,
+ }
+ _overridden_prefixes = set()
+
+ _instances = {}
+
+ def __new__(cls, backend=None, prefix=None, **kwargs):
+ if backend is None and prefix is None:
+ backend = 'disk'
+ if backend is not None and backend not in cls._backends:
+ raise ValueError(
+ f'Backend {backend} is not supported. Currently supported ones'
+ f' are {list(cls._backends.keys())}')
+ if prefix is not None and prefix not in cls._prefix_to_backends:
+ raise ValueError(
+ f'prefix {prefix} is not supported. Currently supported ones '
+ f'are {list(cls._prefix_to_backends.keys())}')
+
+ # concatenate the arguments to a unique key for determining whether
+ # objects with the same arguments were created
+ arg_key = f'{backend}:{prefix}'
+ for key, value in kwargs.items():
+ arg_key += f':{key}:{value}'
+
+ # if a backend was overridden, it will create a new object
+ if (arg_key in cls._instances
+ and backend not in cls._overridden_backends
+ and prefix not in cls._overridden_prefixes):
+ _instance = cls._instances[arg_key]
+ else:
+ # create a new object and put it into _instances
+ _instance = super().__new__(cls)
+ if backend is not None:
+ _instance.client = cls._backends[backend](**kwargs)
+ else:
+ _instance.client = cls._prefix_to_backends[prefix](**kwargs)
+
+ cls._instances[arg_key] = _instance
+
+ return _instance
+
+ @property
+ def name(self):
+ return self.client.name
+
+ @property
+ def allow_symlink(self):
+ return self.client.allow_symlink
+
+ @staticmethod
+ def parse_uri_prefix(uri: Union[str, Path]) -> Optional[str]:
+ """Parse the prefix of a uri.
+
+ Args:
+ uri (str | Path): Uri to be parsed that contains the file prefix.
+
+ Examples:
+ >>> FileClient.parse_uri_prefix('s3://path/of/your/file')
+ 's3'
+
+ Returns:
+ str | None: Return the prefix of uri if the uri contains '://'
+ else ``None``.
+ """
+ assert is_filepath(uri)
+ uri = str(uri)
+ if '://' not in uri:
+ return None
+ else:
+ prefix, _ = uri.split('://')
+ # In the case of PetrelBackend, the prefix may contain the cluster
+ # name like clusterName:s3
+ if ':' in prefix:
+ _, prefix = prefix.split(':')
+ return prefix
+
+ @classmethod
+ def infer_client(cls,
+ file_client_args: Optional[dict] = None,
+ uri: Optional[Union[str, Path]] = None) -> 'FileClient':
+ """Infer a suitable file client based on the URI and arguments.
+
+ Args:
+ file_client_args (dict, optional): Arguments to instantiate a
+ FileClient. Default: None.
+ uri (str | Path, optional): Uri to be parsed that contains the file
+ prefix. Default: None.
+
+ Examples:
+ >>> uri = 's3://path/of/your/file'
+ >>> file_client = FileClient.infer_client(uri=uri)
+ >>> file_client_args = {'backend': 'petrel'}
+ >>> file_client = FileClient.infer_client(file_client_args)
+
+ Returns:
+ FileClient: Instantiated FileClient object.
+ """
+ assert file_client_args is not None or uri is not None
+ if file_client_args is None:
+ file_prefix = cls.parse_uri_prefix(uri) # type: ignore
+ return cls(prefix=file_prefix)
+ else:
+ return cls(**file_client_args)
+
+ @classmethod
+ def _register_backend(cls, name, backend, force=False, prefixes=None):
+ if not isinstance(name, str):
+ raise TypeError('the backend name should be a string, '
+ f'but got {type(name)}')
+ if not inspect.isclass(backend):
+ raise TypeError(
+ f'backend should be a class but got {type(backend)}')
+ if not issubclass(backend, BaseStorageBackend):
+ raise TypeError(
+ f'backend {backend} is not a subclass of BaseStorageBackend')
+ if not force and name in cls._backends:
+ raise KeyError(
+ f'{name} is already registered as a storage backend, '
+ 'add "force=True" if you want to override it')
+
+ if name in cls._backends and force:
+ cls._overridden_backends.add(name)
+ cls._backends[name] = backend
+
+ if prefixes is not None:
+ if isinstance(prefixes, str):
+ prefixes = [prefixes]
+ else:
+ assert isinstance(prefixes, (list, tuple))
+ for prefix in prefixes:
+ if prefix not in cls._prefix_to_backends:
+ cls._prefix_to_backends[prefix] = backend
+ elif (prefix in cls._prefix_to_backends) and force:
+ cls._overridden_prefixes.add(prefix)
+ cls._prefix_to_backends[prefix] = backend
+ else:
+ raise KeyError(
+ f'{prefix} is already registered as a storage backend,'
+ ' add "force=True" if you want to override it')
+
+ @classmethod
+ def register_backend(cls, name, backend=None, force=False, prefixes=None):
+ """Register a backend to FileClient.
+
+ This method can be used as a normal class method or a decorator.
+
+ .. code-block:: python
+
+ class NewBackend(BaseStorageBackend):
+
+ def get(self, filepath):
+ return filepath
+
+ def get_text(self, filepath):
+ return filepath
+
+ FileClient.register_backend('new', NewBackend)
+
+ or
+
+ .. code-block:: python
+
+ @FileClient.register_backend('new')
+ class NewBackend(BaseStorageBackend):
+
+ def get(self, filepath):
+ return filepath
+
+ def get_text(self, filepath):
+ return filepath
+
+ Args:
+ name (str): The name of the registered backend.
+ backend (class, optional): The backend class to be registered,
+ which must be a subclass of :class:`BaseStorageBackend`.
+ When this method is used as a decorator, backend is None.
+ Defaults to None.
+ force (bool, optional): Whether to override the backend if the name
+ has already been registered. Defaults to False.
+ prefixes (str or list[str] or tuple[str], optional): The prefixes
+ of the registered storage backend. Default: None.
+ `New in version 1.3.15.`
+ """
+ if backend is not None:
+ cls._register_backend(
+ name, backend, force=force, prefixes=prefixes)
+ return
+
+ def _register(backend_cls):
+ cls._register_backend(
+ name, backend_cls, force=force, prefixes=prefixes)
+ return backend_cls
+
+ return _register
+
+ def get(self, filepath: Union[str, Path]) -> Union[bytes, memoryview]:
+ """Read data from a given ``filepath`` with 'rb' mode.
+
+ Note:
+ There are two types of return values for ``get``, one is ``bytes``
+ and the other is ``memoryview``. The advantage of using memoryview
+ is that you can avoid copying, and if you want to convert it to
+ ``bytes``, you can use ``.tobytes()``.
+
+ Args:
+ filepath (str or Path): Path to read data.
+
+ Returns:
+ bytes | memoryview: Expected bytes object or a memory view of the
+ bytes object.
+ """
+ return self.client.get(filepath)
+
+ def get_text(self, filepath: Union[str, Path], encoding='utf-8') -> str:
+ """Read data from a given ``filepath`` with 'r' mode.
+
+ Args:
+ filepath (str or Path): Path to read data.
+ encoding (str): The encoding format used to open the ``filepath``.
+ Default: 'utf-8'.
+
+ Returns:
+ str: Expected text reading from ``filepath``.
+ """
+ return self.client.get_text(filepath, encoding)
+
+ def put(self, obj: bytes, filepath: Union[str, Path]) -> None:
+ """Write data to a given ``filepath`` with 'wb' mode.
+
+ Note:
+ ``put`` should create a directory if the directory of ``filepath``
+ does not exist.
+
+ Args:
+ obj (bytes): Data to be written.
+ filepath (str or Path): Path to write data.
+ """
+ self.client.put(obj, filepath)
+
+ def put_text(self, obj: str, filepath: Union[str, Path]) -> None:
+ """Write data to a given ``filepath`` with 'w' mode.
+
+ Note:
+ ``put_text`` should create a directory if the directory of
+ ``filepath`` does not exist.
+
+ Args:
+ obj (str): Data to be written.
+ filepath (str or Path): Path to write data.
+ """
+ self.client.put_text(obj, filepath)
+
+ def remove(self, filepath: Union[str, Path]) -> None:
+ """Remove a file.
+
+ Args:
+ filepath (str, Path): Path to be removed.
+ """
+ self.client.remove(filepath)
+
+ def exists(self, filepath: Union[str, Path]) -> bool:
+ """Check whether a file path exists.
+
+ Args:
+ filepath (str or Path): Path to be checked whether exists.
+
+ Returns:
+ bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise.
+ """
+ return self.client.exists(filepath)
+
+ def isdir(self, filepath: Union[str, Path]) -> bool:
+ """Check whether a file path is a directory.
+
+ Args:
+ filepath (str or Path): Path to be checked whether it is a
+ directory.
+
+ Returns:
+ bool: Return ``True`` if ``filepath`` points to a directory,
+ ``False`` otherwise.
+ """
+ return self.client.isdir(filepath)
+
+ def isfile(self, filepath: Union[str, Path]) -> bool:
+ """Check whether a file path is a file.
+
+ Args:
+ filepath (str or Path): Path to be checked whether it is a file.
+
+ Returns:
+ bool: Return ``True`` if ``filepath`` points to a file, ``False``
+ otherwise.
+ """
+ return self.client.isfile(filepath)
+
+ def join_path(self, filepath: Union[str, Path],
+ *filepaths: Union[str, Path]) -> str:
+ """Concatenate all file paths.
+
+ Join one or more filepath components intelligently. The return value
+ is the concatenation of filepath and any members of *filepaths.
+
+ Args:
+ filepath (str or Path): Path to be concatenated.
+
+ Returns:
+ str: The result of concatenation.
+ """
+ return self.client.join_path(filepath, *filepaths)
+
+ @contextmanager
+ def get_local_path(self, filepath: Union[str, Path]) -> Iterable[str]:
+ """Download data from ``filepath`` and write the data to local path.
+
+ ``get_local_path`` is decorated by :meth:`contextlib.contextmanager`. It
+ can be called with a ``with`` statement, and when exiting the ``with``
+ statement, the temporary path will be released.
+
+ Note:
+ If the ``filepath`` is a local path, just return itself.
+
+ .. warning::
+ ``get_local_path`` is an experimental interface that may change in
+ the future.
+
+ Args:
+ filepath (str or Path): Path to be read data.
+
+ Examples:
+ >>> file_client = FileClient(prefix='s3')
+ >>> with file_client.get_local_path('s3://bucket/abc.jpg') as path:
+ ... # do something here
+
+ Yields:
+ Iterable[str]: Only yield one path.
+ """
+ with self.client.get_local_path(str(filepath)) as local_path:
+ yield local_path
+
+ def list_dir_or_file(self,
+ dir_path: Union[str, Path],
+ list_dir: bool = True,
+ list_file: bool = True,
+ suffix: Optional[Union[str, Tuple[str]]] = None,
+ recursive: bool = False) -> Iterator[str]:
+ """Scan a directory to find the interested directories or files in
+ arbitrary order.
+
+ Note:
+ :meth:`list_dir_or_file` returns the path relative to ``dir_path``.
+
+ Args:
+ dir_path (str | Path): Path of the directory.
+ list_dir (bool): List the directories. Default: True.
+ list_file (bool): List the path of files. Default: True.
+ suffix (str or tuple[str], optional): File suffix
+ that we are interested in. Default: None.
+ recursive (bool): If set to True, recursively scan the
+ directory. Default: False.
+
+ Yields:
+ Iterable[str]: A relative path to ``dir_path``.
+ """
+ yield from self.client.list_dir_or_file(dir_path, list_dir, list_file,
+ suffix, recursive)
diff --git a/annotator/uniformer/mmcv/fileio/handlers/__init__.py b/annotator/uniformer/mmcv/fileio/handlers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa24d91972837b8756b225f4879bac20436eb72a
--- /dev/null
+++ b/annotator/uniformer/mmcv/fileio/handlers/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .base import BaseFileHandler
+from .json_handler import JsonHandler
+from .pickle_handler import PickleHandler
+from .yaml_handler import YamlHandler
+
+__all__ = ['BaseFileHandler', 'JsonHandler', 'PickleHandler', 'YamlHandler']
diff --git a/annotator/uniformer/mmcv/fileio/handlers/base.py b/annotator/uniformer/mmcv/fileio/handlers/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..288878bc57282fbb2f12b32290152ca8e9d3cab0
--- /dev/null
+++ b/annotator/uniformer/mmcv/fileio/handlers/base.py
@@ -0,0 +1,30 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABCMeta, abstractmethod
+
+
+class BaseFileHandler(metaclass=ABCMeta):
+ # `str_like` is a flag to indicate whether the file object type is
+ # str-like or bytes-like. Pickle only processes bytes-like objects
+ # while json only processes str-like objects. If it is str-like,
+ # `StringIO` will be used to process the buffer.
+ str_like = True
+
+ @abstractmethod
+ def load_from_fileobj(self, file, **kwargs):
+ pass
+
+ @abstractmethod
+ def dump_to_fileobj(self, obj, file, **kwargs):
+ pass
+
+ @abstractmethod
+ def dump_to_str(self, obj, **kwargs):
+ pass
+
+ def load_from_path(self, filepath, mode='r', **kwargs):
+ with open(filepath, mode) as f:
+ return self.load_from_fileobj(f, **kwargs)
+
+ def dump_to_path(self, obj, filepath, mode='w', **kwargs):
+ with open(filepath, mode) as f:
+ self.dump_to_fileobj(obj, f, **kwargs)
diff --git a/annotator/uniformer/mmcv/fileio/handlers/json_handler.py b/annotator/uniformer/mmcv/fileio/handlers/json_handler.py
new file mode 100644
index 0000000000000000000000000000000000000000..18d4f15f74139d20adff18b20be5529c592a66b6
--- /dev/null
+++ b/annotator/uniformer/mmcv/fileio/handlers/json_handler.py
@@ -0,0 +1,36 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import json
+
+import numpy as np
+
+from .base import BaseFileHandler
+
+
+def set_default(obj):
+ """Set default json values for non-serializable values.
+
+ It helps convert ``set``, ``range`` and ``np.ndarray`` data types to list.
+ It also converts ``np.generic`` (including ``np.int32``, ``np.float32``,
+ etc.) into plain Python built-in numbers.
+ """
+ if isinstance(obj, (set, range)):
+ return list(obj)
+ elif isinstance(obj, np.ndarray):
+ return obj.tolist()
+ elif isinstance(obj, np.generic):
+ return obj.item()
+ raise TypeError(f'{type(obj)} is unsupported for json dump')
+
+
+class JsonHandler(BaseFileHandler):
+
+ def load_from_fileobj(self, file):
+ return json.load(file)
+
+ def dump_to_fileobj(self, obj, file, **kwargs):
+ kwargs.setdefault('default', set_default)
+ json.dump(obj, file, **kwargs)
+
+ def dump_to_str(self, obj, **kwargs):
+ kwargs.setdefault('default', set_default)
+ return json.dumps(obj, **kwargs)
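+
+
+if __name__ == '__main__':
+    # Editor's illustrative sketch, not part of upstream mmcv: ``set_default``
+    # lets ``json.dumps`` serialize numpy scalars/arrays as well as ``set``
+    # and ``range`` objects, which plain json would otherwise reject.
+    handler = JsonHandler()
+    print(handler.dump_to_str(
+        {'arr': np.arange(3), 'scalar': np.float32(0.5), 'rng': range(2)}))
+    # -> {"arr": [0, 1, 2], "scalar": 0.5, "rng": [0, 1]}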
diff --git a/annotator/uniformer/mmcv/fileio/handlers/pickle_handler.py b/annotator/uniformer/mmcv/fileio/handlers/pickle_handler.py
new file mode 100644
index 0000000000000000000000000000000000000000..b37c79bed4ef9fd8913715e62dbe3fc5cafdc3aa
--- /dev/null
+++ b/annotator/uniformer/mmcv/fileio/handlers/pickle_handler.py
@@ -0,0 +1,28 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import pickle
+
+from .base import BaseFileHandler
+
+
+class PickleHandler(BaseFileHandler):
+
+ str_like = False
+
+ def load_from_fileobj(self, file, **kwargs):
+ return pickle.load(file, **kwargs)
+
+ def load_from_path(self, filepath, **kwargs):
+ return super(PickleHandler, self).load_from_path(
+ filepath, mode='rb', **kwargs)
+
+ def dump_to_str(self, obj, **kwargs):
+ kwargs.setdefault('protocol', 2)
+ return pickle.dumps(obj, **kwargs)
+
+ def dump_to_fileobj(self, obj, file, **kwargs):
+ kwargs.setdefault('protocol', 2)
+ pickle.dump(obj, file, **kwargs)
+
+ def dump_to_path(self, obj, filepath, **kwargs):
+ super(PickleHandler, self).dump_to_path(
+ obj, filepath, mode='wb', **kwargs)
diff --git a/annotator/uniformer/mmcv/fileio/handlers/yaml_handler.py b/annotator/uniformer/mmcv/fileio/handlers/yaml_handler.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5aa2eea1e8c76f8baf753d1c8c959dee665e543
--- /dev/null
+++ b/annotator/uniformer/mmcv/fileio/handlers/yaml_handler.py
@@ -0,0 +1,24 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import yaml
+
+try:
+ from yaml import CLoader as Loader, CDumper as Dumper
+except ImportError:
+ from yaml import Loader, Dumper
+
+from .base import BaseFileHandler # isort:skip
+
+
+class YamlHandler(BaseFileHandler):
+
+ def load_from_fileobj(self, file, **kwargs):
+ kwargs.setdefault('Loader', Loader)
+ return yaml.load(file, **kwargs)
+
+ def dump_to_fileobj(self, obj, file, **kwargs):
+ kwargs.setdefault('Dumper', Dumper)
+ yaml.dump(obj, file, **kwargs)
+
+ def dump_to_str(self, obj, **kwargs):
+ kwargs.setdefault('Dumper', Dumper)
+ return yaml.dump(obj, **kwargs)
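+
+
+if __name__ == '__main__':
+    # Editor's illustrative sketch, not part of upstream mmcv: YamlHandler
+    # round-trips plain Python containers through a string, using the C
+    # Loader/Dumper when PyYAML was built with them.
+    from io import StringIO
+
+    handler = YamlHandler()
+    dumped = handler.dump_to_str({'a': [1, 2], 'b': 'text'})
+    assert handler.load_from_fileobj(StringIO(dumped)) == {'a': [1, 2], 'b': 'text'}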
diff --git a/annotator/uniformer/mmcv/fileio/io.py b/annotator/uniformer/mmcv/fileio/io.py
new file mode 100644
index 0000000000000000000000000000000000000000..aaefde58aa3ea5b58f86249ce7e1c40c186eb8dd
--- /dev/null
+++ b/annotator/uniformer/mmcv/fileio/io.py
@@ -0,0 +1,151 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from io import BytesIO, StringIO
+from pathlib import Path
+
+from ..utils import is_list_of, is_str
+from .file_client import FileClient
+from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler
+
+file_handlers = {
+ 'json': JsonHandler(),
+ 'yaml': YamlHandler(),
+ 'yml': YamlHandler(),
+ 'pickle': PickleHandler(),
+ 'pkl': PickleHandler()
+}
+
+
+def load(file, file_format=None, file_client_args=None, **kwargs):
+ """Load data from json/yaml/pickle files.
+
+ This method provides a unified api for loading data from serialized files.
+
+ Note:
+ In v1.3.16 and later, ``load`` supports loading data from serialized
+ files that can be stored in different backends.
+
+ Args:
+ file (str or :obj:`Path` or file-like object): Filename or a file-like
+ object.
+ file_format (str, optional): If not specified, the file format will be
+ inferred from the file extension, otherwise use the specified one.
+ Currently supported formats include "json", "yaml/yml" and
+ "pickle/pkl".
+ file_client_args (dict, optional): Arguments to instantiate a
+ FileClient. See :class:`mmcv.fileio.FileClient` for details.
+ Default: None.
+
+ Examples:
+ >>> load('/path/of/your/file') # file is stored on disk
+ >>> load('https://path/of/your/file') # file is stored on the Internet
+ >>> load('s3://path/of/your/file') # file is stored in petrel
+
+ Returns:
+ The content from the file.
+ """
+ if isinstance(file, Path):
+ file = str(file)
+ if file_format is None and is_str(file):
+ file_format = file.split('.')[-1]
+ if file_format not in file_handlers:
+ raise TypeError(f'Unsupported format: {file_format}')
+
+ handler = file_handlers[file_format]
+ if is_str(file):
+ file_client = FileClient.infer_client(file_client_args, file)
+ if handler.str_like:
+ with StringIO(file_client.get_text(file)) as f:
+ obj = handler.load_from_fileobj(f, **kwargs)
+ else:
+ with BytesIO(file_client.get(file)) as f:
+ obj = handler.load_from_fileobj(f, **kwargs)
+ elif hasattr(file, 'read'):
+ obj = handler.load_from_fileobj(file, **kwargs)
+ else:
+ raise TypeError('"file" must be a filepath str or a file-object')
+ return obj
+
+
+def dump(obj, file=None, file_format=None, file_client_args=None, **kwargs):
+ """Dump data to json/yaml/pickle strings or files.
+
+ This method provides a unified api for dumping data as strings or to files,
+ and also supports custom arguments for each file format.
+
+ Note:
+ In v1.3.16 and later, ``dump`` supports dumping data as strings or to
+ files which are saved to different backends.
+
+ Args:
+ obj (any): The python object to be dumped.
+ file (str or :obj:`Path` or file-like object, optional): If not
+ specified, then the object is dumped to a str, otherwise to a file
+ specified by the filename or file-like object.
+ file_format (str, optional): Same as :func:`load`.
+ file_client_args (dict, optional): Arguments to instantiate a
+ FileClient. See :class:`mmcv.fileio.FileClient` for details.
+ Default: None.
+
+ Examples:
+ >>> dump('hello world', '/path/of/your/file') # disk
+ >>> dump('hello world', 's3://path/of/your/file') # ceph or petrel
+
+ Returns:
+ str or None: The dumped string when ``file`` is None, otherwise ``None``.
+ """
+ if isinstance(file, Path):
+ file = str(file)
+ if file_format is None:
+ if is_str(file):
+ file_format = file.split('.')[-1]
+ elif file is None:
+ raise ValueError(
+ 'file_format must be specified since file is None')
+ if file_format not in file_handlers:
+ raise TypeError(f'Unsupported format: {file_format}')
+
+ handler = file_handlers[file_format]
+ if file is None:
+ return handler.dump_to_str(obj, **kwargs)
+ elif is_str(file):
+ file_client = FileClient.infer_client(file_client_args, file)
+ if handler.str_like:
+ with StringIO() as f:
+ handler.dump_to_fileobj(obj, f, **kwargs)
+ file_client.put_text(f.getvalue(), file)
+ else:
+ with BytesIO() as f:
+ handler.dump_to_fileobj(obj, f, **kwargs)
+ file_client.put(f.getvalue(), file)
+ elif hasattr(file, 'write'):
+ handler.dump_to_fileobj(obj, file, **kwargs)
+ else:
+ raise TypeError('"file" must be a filename str or a file-object')
+
+
+def _register_handler(handler, file_formats):
+ """Register a handler for some file extensions.
+
+ Args:
+ handler (:obj:`BaseFileHandler`): Handler to be registered.
+ file_formats (str or list[str]): File formats to be handled by this
+ handler.
+ """
+ if not isinstance(handler, BaseFileHandler):
+ raise TypeError(
+ f'handler must be a child of BaseFileHandler, not {type(handler)}')
+ if isinstance(file_formats, str):
+ file_formats = [file_formats]
+ if not is_list_of(file_formats, str):
+ raise TypeError('file_formats must be a str or a list of str')
+ for ext in file_formats:
+ file_handlers[ext] = handler
+
+
+def register_handler(file_formats, **kwargs):
+
+ def wrap(cls):
+ _register_handler(cls(**kwargs), file_formats)
+ return cls
+
+ return wrap
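+
+
+if __name__ == '__main__':
+    # Editor's illustrative sketch, not part of upstream mmcv: round-trip an
+    # object through the yaml handler and register a toy ``.txt`` handler via
+    # the decorator form of ``register_handler``.
+    print(load(StringIO(dump({'a': [1, 2]}, file_format='yaml')),
+               file_format='yaml'))  # -> {'a': [1, 2]}
+
+    @register_handler('txt')
+    class TxtHandler(BaseFileHandler):
+
+        def load_from_fileobj(self, file):
+            return file.read()
+
+        def dump_to_fileobj(self, obj, file):
+            file.write(str(obj))
+
+        def dump_to_str(self, obj):
+            return str(obj)
+
+    print(dump('hello world', file_format='txt'))  # -> hello world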
diff --git a/annotator/uniformer/mmcv/fileio/parse.py b/annotator/uniformer/mmcv/fileio/parse.py
new file mode 100644
index 0000000000000000000000000000000000000000..f60f0d611b8d75692221d0edd7dc993b0a6445c9
--- /dev/null
+++ b/annotator/uniformer/mmcv/fileio/parse.py
@@ -0,0 +1,97 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+from io import StringIO
+
+from .file_client import FileClient
+
+
+def list_from_file(filename,
+ prefix='',
+ offset=0,
+ max_num=0,
+ encoding='utf-8',
+ file_client_args=None):
+ """Load a text file and parse the content as a list of strings.
+
+ Note:
+ In v1.3.16 and later, ``list_from_file`` supports loading a text file
+ which can be stored in different backends and parsing the content as
+ a list of strings.
+
+ Args:
+ filename (str): Filename.
+ prefix (str): The prefix to be inserted to the beginning of each item.
+ offset (int): The offset of lines.
+ max_num (int): The maximum number of lines to be read;
+ zero or a negative value means no limit.
+ encoding (str): Encoding used to open the file. Default utf-8.
+ file_client_args (dict, optional): Arguments to instantiate a
+ FileClient. See :class:`mmcv.fileio.FileClient` for details.
+ Default: None.
+
+ Examples:
+ >>> list_from_file('/path/of/your/file') # disk
+ ['hello', 'world']
+ >>> list_from_file('s3://path/of/your/file') # ceph or petrel
+ ['hello', 'world']
+
+ Returns:
+ list[str]: A list of strings.
+ """
+ cnt = 0
+ item_list = []
+ file_client = FileClient.infer_client(file_client_args, filename)
+ with StringIO(file_client.get_text(filename, encoding)) as f:
+ for _ in range(offset):
+ f.readline()
+ for line in f:
+ if 0 < max_num <= cnt:
+ break
+ item_list.append(prefix + line.rstrip('\n\r'))
+ cnt += 1
+ return item_list
+
+
+def dict_from_file(filename,
+ key_type=str,
+ encoding='utf-8',
+ file_client_args=None):
+ """Load a text file and parse the content as a dict.
+
+ Each line of the text file will be two or more columns split by
+ whitespace or tabs. The first column will be parsed as dict keys, and
+ the following columns will be parsed as dict values.
+
+ Note:
+ In v1.3.16 and later, ``dict_from_file`` supports loading a text file
+ which can be stored in different backends and parsing the content as
+ a dict.
+
+ Args:
+ filename (str): Filename.
+ key_type (type): Type of the dict keys. str is used by default and
+ type conversion will be performed if specified.
+ encoding (str): Encoding used to open the file. Default utf-8.
+ file_client_args (dict, optional): Arguments to instantiate a
+ FileClient. See :class:`mmcv.fileio.FileClient` for details.
+ Default: None.
+
+ Examples:
+ >>> dict_from_file('/path/of/your/file') # disk
+ {'key1': 'value1', 'key2': 'value2'}
+ >>> dict_from_file('s3://path/of/your/file') # ceph or petrel
+ {'key1': 'value1', 'key2': 'value2'}
+
+ Returns:
+ dict: The parsed contents.
+ """
+ mapping = {}
+ file_client = FileClient.infer_client(file_client_args, filename)
+ with StringIO(file_client.get_text(filename, encoding)) as f:
+ for line in f:
+ items = line.rstrip('\n').split()
+ assert len(items) >= 2
+ key = key_type(items[0])
+ val = items[1:] if len(items) > 2 else items[1]
+ mapping[key] = val
+ return mapping
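+
+
+if __name__ == '__main__':
+    # Editor's illustrative sketch, not part of upstream mmcv: parse a small
+    # local text file with both helpers. The file is a throwaway temporary
+    # file created only for the demo.
+    import os
+    import tempfile
+
+    with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False) as f:
+        f.write('key1 value1\nkey2 value2a value2b\n')
+    print(list_from_file(f.name, prefix='line: '))
+    # -> ['line: key1 value1', 'line: key2 value2a value2b']
+    print(dict_from_file(f.name))
+    # -> {'key1': 'value1', 'key2': ['value2a', 'value2b']}
+    os.remove(f.name)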
diff --git a/annotator/uniformer/mmcv/image/__init__.py b/annotator/uniformer/mmcv/image/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0051d609d3de4e7562e3fe638335c66617c4d91
--- /dev/null
+++ b/annotator/uniformer/mmcv/image/__init__.py
@@ -0,0 +1,28 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .colorspace import (bgr2gray, bgr2hls, bgr2hsv, bgr2rgb, bgr2ycbcr,
+ gray2bgr, gray2rgb, hls2bgr, hsv2bgr, imconvert,
+ rgb2bgr, rgb2gray, rgb2ycbcr, ycbcr2bgr, ycbcr2rgb)
+from .geometric import (cutout, imcrop, imflip, imflip_, impad,
+ impad_to_multiple, imrescale, imresize, imresize_like,
+ imresize_to_multiple, imrotate, imshear, imtranslate,
+ rescale_size)
+from .io import imfrombytes, imread, imwrite, supported_backends, use_backend
+from .misc import tensor2imgs
+from .photometric import (adjust_brightness, adjust_color, adjust_contrast,
+ adjust_lighting, adjust_sharpness, auto_contrast,
+ clahe, imdenormalize, imequalize, iminvert,
+ imnormalize, imnormalize_, lut_transform, posterize,
+ solarize)
+
+__all__ = [
+ 'bgr2gray', 'bgr2hls', 'bgr2hsv', 'bgr2rgb', 'gray2bgr', 'gray2rgb',
+ 'hls2bgr', 'hsv2bgr', 'imconvert', 'rgb2bgr', 'rgb2gray', 'imrescale',
+ 'imresize', 'imresize_like', 'imresize_to_multiple', 'rescale_size',
+ 'imcrop', 'imflip', 'imflip_', 'impad', 'impad_to_multiple', 'imrotate',
+ 'imfrombytes', 'imread', 'imwrite', 'supported_backends', 'use_backend',
+ 'imdenormalize', 'imnormalize', 'imnormalize_', 'iminvert', 'posterize',
+ 'solarize', 'rgb2ycbcr', 'bgr2ycbcr', 'ycbcr2rgb', 'ycbcr2bgr',
+ 'tensor2imgs', 'imshear', 'imtranslate', 'adjust_color', 'imequalize',
+ 'adjust_brightness', 'adjust_contrast', 'lut_transform', 'clahe',
+ 'adjust_sharpness', 'auto_contrast', 'cutout', 'adjust_lighting'
+]
diff --git a/annotator/uniformer/mmcv/image/colorspace.py b/annotator/uniformer/mmcv/image/colorspace.py
new file mode 100644
index 0000000000000000000000000000000000000000..814533952fdfda23d67cb6a3073692d8c1156add
--- /dev/null
+++ b/annotator/uniformer/mmcv/image/colorspace.py
@@ -0,0 +1,306 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import cv2
+import numpy as np
+
+
+def imconvert(img, src, dst):
+ """Convert an image from the src colorspace to dst colorspace.
+
+ Args:
+ img (ndarray): The input image.
+ src (str): The source colorspace, e.g., 'rgb', 'hsv'.
+ dst (str): The destination colorspace, e.g., 'rgb', 'hsv'.
+
+ Returns:
+ ndarray: The converted image.
+ """
+ code = getattr(cv2, f'COLOR_{src.upper()}2{dst.upper()}')
+ out_img = cv2.cvtColor(img, code)
+ return out_img
+
+
+def bgr2gray(img, keepdim=False):
+ """Convert a BGR image to grayscale image.
+
+ Args:
+ img (ndarray): The input image.
+ keepdim (bool): If False (by default), then return the grayscale image
+ with 2 dims, otherwise 3 dims.
+
+ Returns:
+ ndarray: The converted grayscale image.
+ """
+ out_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+ if keepdim:
+ out_img = out_img[..., None]
+ return out_img
+
+
+def rgb2gray(img, keepdim=False):
+ """Convert a RGB image to grayscale image.
+
+ Args:
+ img (ndarray): The input image.
+ keepdim (bool): If False (by default), then return the grayscale image
+ with 2 dims, otherwise 3 dims.
+
+ Returns:
+ ndarray: The converted grayscale image.
+ """
+ out_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
+ if keepdim:
+ out_img = out_img[..., None]
+ return out_img
+
+
+def gray2bgr(img):
+ """Convert a grayscale image to BGR image.
+
+ Args:
+ img (ndarray): The input image.
+
+ Returns:
+ ndarray: The converted BGR image.
+ """
+ img = img[..., None] if img.ndim == 2 else img
+ out_img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+ return out_img
+
+
+def gray2rgb(img):
+ """Convert a grayscale image to RGB image.
+
+ Args:
+ img (ndarray): The input image.
+
+ Returns:
+ ndarray: The converted RGB image.
+ """
+ img = img[..., None] if img.ndim == 2 else img
+ out_img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
+ return out_img
+
+
+def _convert_input_type_range(img):
+ """Convert the type and range of the input image.
+
+ It converts the input image to np.float32 type and range of [0, 1].
+ It is mainly used for pre-processing the input image in colorspace
+ conversion functions such as rgb2ycbcr and ycbcr2rgb.
+
+ Args:
+ img (ndarray): The input image. It accepts:
+ 1. np.uint8 type with range [0, 255];
+ 2. np.float32 type with range [0, 1].
+
+ Returns:
+ (ndarray): The converted image with type of np.float32 and range of
+ [0, 1].
+ """
+ img_type = img.dtype
+ img = img.astype(np.float32)
+ if img_type == np.float32:
+ pass
+ elif img_type == np.uint8:
+ img /= 255.
+ else:
+ raise TypeError('The img type should be np.float32 or np.uint8, '
+ f'but got {img_type}')
+ return img
+
+
+def _convert_output_type_range(img, dst_type):
+ """Convert the type and range of the image according to dst_type.
+
+ It converts the image to desired type and range. If `dst_type` is np.uint8,
+ images will be converted to np.uint8 type with range [0, 255]. If
+ `dst_type` is np.float32, it converts the image to np.float32 type with
+ range [0, 1].
+ It is mainly used for post-processing images in colorspace conversion
+ functions such as rgb2ycbcr and ycbcr2rgb.
+
+ Args:
+ img (ndarray): The image to be converted with np.float32 type and
+ range [0, 255].
+ dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
+ converts the image to np.uint8 type with range [0, 255]. If
+ dst_type is np.float32, it converts the image to np.float32 type
+ with range [0, 1].
+
+ Returns:
+ (ndarray): The converted image with desired type and range.
+ """
+ if dst_type not in (np.uint8, np.float32):
+ raise TypeError('The dst_type should be np.float32 or np.uint8, '
+ f'but got {dst_type}')
+ if dst_type == np.uint8:
+ img = img.round()
+ else:
+ img /= 255.
+ return img.astype(dst_type)
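+
+
+# Illustrative sketch added by the editor (not part of upstream mmcv): shows
+# the intended round trip of the two range helpers above for a uint8 input.
+# The helper name is arbitrary and it is never executed on import.
+def _demo_type_range_helpers():
+    u8 = np.array([0, 128, 255], dtype=np.uint8)
+    as_float = _convert_input_type_range(u8)
+    assert as_float.dtype == np.float32 and as_float.max() <= 1.0
+    back = _convert_output_type_range(as_float * 255, np.uint8)
+    assert np.array_equal(back, u8)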
+
+
+def rgb2ycbcr(img, y_only=False):
+    """Convert an RGB image to a YCbCr image.
+
+ This function produces the same results as Matlab's `rgb2ycbcr` function.
+ It implements the ITU-R BT.601 conversion for standard-definition
+ television. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+ It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`.
+ In OpenCV, it implements a JPEG conversion. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+ Args:
+ img (ndarray): The input image. It accepts:
+ 1. np.uint8 type with range [0, 255];
+ 2. np.float32 type with range [0, 1].
+ y_only (bool): Whether to only return Y channel. Default: False.
+
+ Returns:
+ ndarray: The converted YCbCr image. The output image has the same type
+ and range as input image.
+ """
+ img_type = img.dtype
+ img = _convert_input_type_range(img)
+ if y_only:
+ out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0
+ else:
+ out_img = np.matmul(
+ img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
+ [24.966, 112.0, -18.214]]) + [16, 128, 128]
+ out_img = _convert_output_type_range(out_img, img_type)
+ return out_img
+
+
+def bgr2ycbcr(img, y_only=False):
+ """Convert a BGR image to YCbCr image.
+
+ The bgr version of rgb2ycbcr.
+ It implements the ITU-R BT.601 conversion for standard-definition
+ television. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+ It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`.
+ In OpenCV, it implements a JPEG conversion. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+ Args:
+ img (ndarray): The input image. It accepts:
+ 1. np.uint8 type with range [0, 255];
+ 2. np.float32 type with range [0, 1].
+ y_only (bool): Whether to only return Y channel. Default: False.
+
+ Returns:
+ ndarray: The converted YCbCr image. The output image has the same type
+ and range as input image.
+ """
+ img_type = img.dtype
+ img = _convert_input_type_range(img)
+ if y_only:
+ out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
+ else:
+ out_img = np.matmul(
+ img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
+ [65.481, -37.797, 112.0]]) + [16, 128, 128]
+ out_img = _convert_output_type_range(out_img, img_type)
+ return out_img
+
+
+def ycbcr2rgb(img):
+ """Convert a YCbCr image to RGB image.
+
+ This function produces the same results as Matlab's ycbcr2rgb function.
+ It implements the ITU-R BT.601 conversion for standard-definition
+ television. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+ It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`.
+ In OpenCV, it implements a JPEG conversion. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+ Args:
+ img (ndarray): The input image. It accepts:
+ 1. np.uint8 type with range [0, 255];
+ 2. np.float32 type with range [0, 1].
+
+ Returns:
+ ndarray: The converted RGB image. The output image has the same type
+ and range as input image.
+ """
+ img_type = img.dtype
+ img = _convert_input_type_range(img) * 255
+ out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621],
+ [0, -0.00153632, 0.00791071],
+ [0.00625893, -0.00318811, 0]]) * 255.0 + [
+ -222.921, 135.576, -276.836
+ ]
+ out_img = _convert_output_type_range(out_img, img_type)
+ return out_img
+
+
+def ycbcr2bgr(img):
+ """Convert a YCbCr image to BGR image.
+
+ The bgr version of ycbcr2rgb.
+ It implements the ITU-R BT.601 conversion for standard-definition
+ television. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+ It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`.
+ In OpenCV, it implements a JPEG conversion. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+ Args:
+ img (ndarray): The input image. It accepts:
+ 1. np.uint8 type with range [0, 255];
+ 2. np.float32 type with range [0, 1].
+
+ Returns:
+ ndarray: The converted BGR image. The output image has the same type
+ and range as input image.
+ """
+ img_type = img.dtype
+ img = _convert_input_type_range(img) * 255
+ out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621],
+ [0.00791071, -0.00153632, 0],
+ [0, -0.00318811, 0.00625893]]) * 255.0 + [
+ -276.836, 135.576, -222.921
+ ]
+ out_img = _convert_output_type_range(out_img, img_type)
+ return out_img
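+
+
+# Illustrative sketch added by the editor (not upstream mmcv code): a rough
+# check that the BT.601 conversions above round-trip uint8 data to within a
+# few grey levels of quantization error. Never executed on import.
+def _demo_ycbcr_roundtrip():
+    rgb = np.random.randint(16, 240, size=(4, 4, 3), dtype=np.uint8)
+    ycc = rgb2ycbcr(rgb)   # uint8 in -> uint8 out
+    back = ycbcr2rgb(ycc)
+    err = np.abs(back.astype(np.int16) - rgb.astype(np.int16)).max()
+    assert back.dtype == np.uint8 and err <= 3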
+
+
+def convert_color_factory(src, dst):
+
+ code = getattr(cv2, f'COLOR_{src.upper()}2{dst.upper()}')
+
+ def convert_color(img):
+ out_img = cv2.cvtColor(img, code)
+ return out_img
+
+ convert_color.__doc__ = f"""Convert a {src.upper()} image to {dst.upper()}
+ image.
+
+ Args:
+        img (ndarray): The input image.
+
+ Returns:
+ ndarray: The converted {dst.upper()} image.
+ """
+
+ return convert_color
+
+
+bgr2rgb = convert_color_factory('bgr', 'rgb')
+
+rgb2bgr = convert_color_factory('rgb', 'bgr')
+
+bgr2hsv = convert_color_factory('bgr', 'hsv')
+
+hsv2bgr = convert_color_factory('hsv', 'bgr')
+
+bgr2hls = convert_color_factory('bgr', 'hls')
+
+hls2bgr = convert_color_factory('hls', 'bgr')
diff --git a/annotator/uniformer/mmcv/image/geometric.py b/annotator/uniformer/mmcv/image/geometric.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf97c201cb4e43796c911919d03fb26a07ed817d
--- /dev/null
+++ b/annotator/uniformer/mmcv/image/geometric.py
@@ -0,0 +1,728 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numbers
+
+import cv2
+import numpy as np
+
+from ..utils import to_2tuple
+from .io import imread_backend
+
+try:
+ from PIL import Image
+except ImportError:
+ Image = None
+
+
+def _scale_size(size, scale):
+ """Rescale a size by a ratio.
+
+ Args:
+ size (tuple[int]): (w, h).
+ scale (float | tuple(float)): Scaling factor.
+
+ Returns:
+ tuple[int]: scaled size.
+ """
+ if isinstance(scale, (float, int)):
+ scale = (scale, scale)
+ w, h = size
+ return int(w * float(scale[0]) + 0.5), int(h * float(scale[1]) + 0.5)
+
+
+cv2_interp_codes = {
+ 'nearest': cv2.INTER_NEAREST,
+ 'bilinear': cv2.INTER_LINEAR,
+ 'bicubic': cv2.INTER_CUBIC,
+ 'area': cv2.INTER_AREA,
+ 'lanczos': cv2.INTER_LANCZOS4
+}
+
+if Image is not None:
+ pillow_interp_codes = {
+ 'nearest': Image.NEAREST,
+ 'bilinear': Image.BILINEAR,
+ 'bicubic': Image.BICUBIC,
+ 'box': Image.BOX,
+ 'lanczos': Image.LANCZOS,
+ 'hamming': Image.HAMMING
+ }
+
+
+def imresize(img,
+ size,
+ return_scale=False,
+ interpolation='bilinear',
+ out=None,
+ backend=None):
+ """Resize image to a given size.
+
+ Args:
+ img (ndarray): The input image.
+ size (tuple[int]): Target size (w, h).
+ return_scale (bool): Whether to return `w_scale` and `h_scale`.
+ interpolation (str): Interpolation method, accepted values are
+ "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
+ backend, "nearest", "bilinear" for 'pillow' backend.
+ out (ndarray): The output destination.
+ backend (str | None): The image resize backend type. Options are `cv2`,
+ `pillow`, `None`. If backend is None, the global imread_backend
+ specified by ``mmcv.use_backend()`` will be used. Default: None.
+
+ Returns:
+ tuple | ndarray: (`resized_img`, `w_scale`, `h_scale`) or
+ `resized_img`.
+ """
+ h, w = img.shape[:2]
+ if backend is None:
+ backend = imread_backend
+ if backend not in ['cv2', 'pillow']:
+        raise ValueError(f'backend: {backend} is not supported for resize. '
+ f"Supported backends are 'cv2', 'pillow'")
+
+ if backend == 'pillow':
+ assert img.dtype == np.uint8, 'Pillow backend only support uint8 type'
+ pil_image = Image.fromarray(img)
+ pil_image = pil_image.resize(size, pillow_interp_codes[interpolation])
+ resized_img = np.array(pil_image)
+ else:
+ resized_img = cv2.resize(
+ img, size, dst=out, interpolation=cv2_interp_codes[interpolation])
+ if not return_scale:
+ return resized_img
+ else:
+ w_scale = size[0] / w
+ h_scale = size[1] / h
+ return resized_img, w_scale, h_scale
+
+
+def imresize_to_multiple(img,
+ divisor,
+ size=None,
+ scale_factor=None,
+ keep_ratio=False,
+ return_scale=False,
+ interpolation='bilinear',
+ out=None,
+ backend=None):
+    """Resize an image according to a given size or scale factor and then
+    round up the resized or rescaled image size to the nearest value that can
+    be divided by the divisor.
+
+ Args:
+ img (ndarray): The input image.
+ divisor (int | tuple): Resized image size will be a multiple of
+ divisor. If divisor is a tuple, divisor should be
+ (w_divisor, h_divisor).
+ size (None | int | tuple[int]): Target size (w, h). Default: None.
+        scale_factor (None | float | tuple[float]): Multiplier for spatial
+            size. If given as a tuple, it should follow the 2D style
+            (w_scale_factor, h_scale_factor). Default: None.
+ keep_ratio (bool): Whether to keep the aspect ratio when resizing the
+ image. Default: False.
+ return_scale (bool): Whether to return `w_scale` and `h_scale`.
+ interpolation (str): Interpolation method, accepted values are
+ "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
+ backend, "nearest", "bilinear" for 'pillow' backend.
+ out (ndarray): The output destination.
+ backend (str | None): The image resize backend type. Options are `cv2`,
+ `pillow`, `None`. If backend is None, the global imread_backend
+ specified by ``mmcv.use_backend()`` will be used. Default: None.
+
+ Returns:
+ tuple | ndarray: (`resized_img`, `w_scale`, `h_scale`) or
+ `resized_img`.
+ """
+ h, w = img.shape[:2]
+ if size is not None and scale_factor is not None:
+ raise ValueError('only one of size or scale_factor should be defined')
+ elif size is None and scale_factor is None:
+ raise ValueError('one of size or scale_factor should be defined')
+ elif size is not None:
+ size = to_2tuple(size)
+ if keep_ratio:
+ size = rescale_size((w, h), size, return_scale=False)
+ else:
+ size = _scale_size((w, h), scale_factor)
+
+ divisor = to_2tuple(divisor)
+ size = tuple([int(np.ceil(s / d)) * d for s, d in zip(size, divisor)])
+ resized_img, w_scale, h_scale = imresize(
+ img,
+ size,
+ return_scale=True,
+ interpolation=interpolation,
+ out=out,
+ backend=backend)
+ if return_scale:
+ return resized_img, w_scale, h_scale
+ else:
+ return resized_img
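+
+
+# Illustrative sketch added by the editor (not upstream mmcv code): shows how
+# the divisor rounding above works for a 500x300 input and divisor=32. Never
+# executed on import.
+def _demo_imresize_to_multiple():
+    img = np.zeros((300, 500, 3), dtype=np.uint8)
+    out = imresize_to_multiple(img, divisor=32, scale_factor=1.0)
+    # 500 -> ceil(500 / 32) * 32 = 512 and 300 -> ceil(300 / 32) * 32 = 320
+    assert out.shape[:2] == (320, 512)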
+
+
+def imresize_like(img,
+ dst_img,
+ return_scale=False,
+ interpolation='bilinear',
+ backend=None):
+ """Resize image to the same size of a given image.
+
+ Args:
+ img (ndarray): The input image.
+ dst_img (ndarray): The target image.
+ return_scale (bool): Whether to return `w_scale` and `h_scale`.
+ interpolation (str): Same as :func:`resize`.
+ backend (str | None): Same as :func:`resize`.
+
+ Returns:
+ tuple or ndarray: (`resized_img`, `w_scale`, `h_scale`) or
+ `resized_img`.
+ """
+ h, w = dst_img.shape[:2]
+ return imresize(img, (w, h), return_scale, interpolation, backend=backend)
+
+
+def rescale_size(old_size, scale, return_scale=False):
+ """Calculate the new size to be rescaled to.
+
+ Args:
+ old_size (tuple[int]): The old size (w, h) of image.
+ scale (float | tuple[int]): The scaling factor or maximum size.
+ If it is a float number, then the image will be rescaled by this
+ factor, else if it is a tuple of 2 integers, then the image will
+ be rescaled as large as possible within the scale.
+ return_scale (bool): Whether to return the scaling factor besides the
+ rescaled image size.
+
+ Returns:
+ tuple[int]: The new rescaled image size.
+ """
+ w, h = old_size
+ if isinstance(scale, (float, int)):
+ if scale <= 0:
+ raise ValueError(f'Invalid scale {scale}, must be positive.')
+ scale_factor = scale
+ elif isinstance(scale, tuple):
+ max_long_edge = max(scale)
+ max_short_edge = min(scale)
+ scale_factor = min(max_long_edge / max(h, w),
+ max_short_edge / min(h, w))
+ else:
+ raise TypeError(
+ f'Scale must be a number or tuple of int, but got {type(scale)}')
+
+ new_size = _scale_size((w, h), scale_factor)
+
+ if return_scale:
+ return new_size, scale_factor
+ else:
+ return new_size
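+
+
+# Illustrative sketch added by the editor (not upstream mmcv code): contrasts
+# the two `scale` modes of rescale_size, a plain ratio versus a
+# (long_edge, short_edge) bound. Never executed on import.
+def _demo_rescale_size():
+    assert rescale_size((640, 480), 0.5) == (320, 240)
+    # scale_factor = min(1333 / 640, 800 / 480) = 800 / 480
+    new_size, factor = rescale_size((640, 480), (1333, 800), return_scale=True)
+    assert new_size == (1067, 800) and abs(factor - 800 / 480) < 1e-6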
+
+
+def imrescale(img,
+ scale,
+ return_scale=False,
+ interpolation='bilinear',
+ backend=None):
+ """Resize image while keeping the aspect ratio.
+
+ Args:
+ img (ndarray): The input image.
+ scale (float | tuple[int]): The scaling factor or maximum size.
+ If it is a float number, then the image will be rescaled by this
+ factor, else if it is a tuple of 2 integers, then the image will
+ be rescaled as large as possible within the scale.
+ return_scale (bool): Whether to return the scaling factor besides the
+ rescaled image.
+ interpolation (str): Same as :func:`resize`.
+ backend (str | None): Same as :func:`resize`.
+
+ Returns:
+ ndarray: The rescaled image.
+ """
+ h, w = img.shape[:2]
+ new_size, scale_factor = rescale_size((w, h), scale, return_scale=True)
+ rescaled_img = imresize(
+ img, new_size, interpolation=interpolation, backend=backend)
+ if return_scale:
+ return rescaled_img, scale_factor
+ else:
+ return rescaled_img
+
+
+def imflip(img, direction='horizontal'):
+ """Flip an image horizontally or vertically.
+
+ Args:
+ img (ndarray): Image to be flipped.
+ direction (str): The flip direction, either "horizontal" or
+ "vertical" or "diagonal".
+
+ Returns:
+ ndarray: The flipped image.
+ """
+ assert direction in ['horizontal', 'vertical', 'diagonal']
+ if direction == 'horizontal':
+ return np.flip(img, axis=1)
+ elif direction == 'vertical':
+ return np.flip(img, axis=0)
+ else:
+ return np.flip(img, axis=(0, 1))
+
+
+def imflip_(img, direction='horizontal'):
+ """Inplace flip an image horizontally or vertically.
+
+ Args:
+ img (ndarray): Image to be flipped.
+ direction (str): The flip direction, either "horizontal" or
+ "vertical" or "diagonal".
+
+ Returns:
+ ndarray: The flipped image (inplace).
+ """
+ assert direction in ['horizontal', 'vertical', 'diagonal']
+ if direction == 'horizontal':
+ return cv2.flip(img, 1, img)
+ elif direction == 'vertical':
+ return cv2.flip(img, 0, img)
+ else:
+ return cv2.flip(img, -1, img)
+
+
+def imrotate(img,
+ angle,
+ center=None,
+ scale=1.0,
+ border_value=0,
+ interpolation='bilinear',
+ auto_bound=False):
+ """Rotate an image.
+
+ Args:
+ img (ndarray): Image to be rotated.
+ angle (float): Rotation angle in degrees, positive values mean
+ clockwise rotation.
+ center (tuple[float], optional): Center point (w, h) of the rotation in
+ the source image. If not specified, the center of the image will be
+ used.
+ scale (float): Isotropic scale factor.
+ border_value (int): Border value.
+ interpolation (str): Same as :func:`resize`.
+ auto_bound (bool): Whether to adjust the image size to cover the whole
+ rotated image.
+
+ Returns:
+ ndarray: The rotated image.
+ """
+ if center is not None and auto_bound:
+ raise ValueError('`auto_bound` conflicts with `center`')
+ h, w = img.shape[:2]
+ if center is None:
+ center = ((w - 1) * 0.5, (h - 1) * 0.5)
+ assert isinstance(center, tuple)
+
+ matrix = cv2.getRotationMatrix2D(center, -angle, scale)
+ if auto_bound:
+ cos = np.abs(matrix[0, 0])
+ sin = np.abs(matrix[0, 1])
+ new_w = h * sin + w * cos
+ new_h = h * cos + w * sin
+ matrix[0, 2] += (new_w - w) * 0.5
+ matrix[1, 2] += (new_h - h) * 0.5
+ w = int(np.round(new_w))
+ h = int(np.round(new_h))
+ rotated = cv2.warpAffine(
+ img,
+ matrix, (w, h),
+ flags=cv2_interp_codes[interpolation],
+ borderValue=border_value)
+ return rotated
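+
+
+# Illustrative sketch added by the editor (not upstream mmcv code): shows the
+# effect of `auto_bound` on the output canvas of imrotate. Never executed on
+# import.
+def _demo_imrotate():
+    img = np.zeros((100, 200, 3), dtype=np.uint8)
+    # Without auto_bound the canvas keeps its original (h, w).
+    assert imrotate(img, 30).shape == (100, 200, 3)
+    # With auto_bound a 90 degree rotation swaps height and width.
+    assert imrotate(img, 90, auto_bound=True).shape[:2] == (200, 100)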
+
+
+def bbox_clip(bboxes, img_shape):
+ """Clip bboxes to fit the image shape.
+
+ Args:
+ bboxes (ndarray): Shape (..., 4*k)
+ img_shape (tuple[int]): (height, width) of the image.
+
+ Returns:
+ ndarray: Clipped bboxes.
+ """
+ assert bboxes.shape[-1] % 4 == 0
+ cmin = np.empty(bboxes.shape[-1], dtype=bboxes.dtype)
+ cmin[0::2] = img_shape[1] - 1
+ cmin[1::2] = img_shape[0] - 1
+ clipped_bboxes = np.maximum(np.minimum(bboxes, cmin), 0)
+ return clipped_bboxes
+
+
+def bbox_scaling(bboxes, scale, clip_shape=None):
+    """Scale bboxes w.r.t. the box center.
+
+ Args:
+ bboxes (ndarray): Shape(..., 4).
+ scale (float): Scaling factor.
+ clip_shape (tuple[int], optional): If specified, bboxes that exceed the
+ boundary will be clipped according to the given shape (h, w).
+
+ Returns:
+ ndarray: Scaled bboxes.
+ """
+ if float(scale) == 1.0:
+ scaled_bboxes = bboxes.copy()
+ else:
+ w = bboxes[..., 2] - bboxes[..., 0] + 1
+ h = bboxes[..., 3] - bboxes[..., 1] + 1
+ dw = (w * (scale - 1)) * 0.5
+ dh = (h * (scale - 1)) * 0.5
+ scaled_bboxes = bboxes + np.stack((-dw, -dh, dw, dh), axis=-1)
+ if clip_shape is not None:
+ return bbox_clip(scaled_bboxes, clip_shape)
+ else:
+ return scaled_bboxes
+
+
+def imcrop(img, bboxes, scale=1.0, pad_fill=None):
+ """Crop image patches.
+
+ 3 steps: scale the bboxes -> clip bboxes -> crop and pad.
+
+ Args:
+ img (ndarray): Image to be cropped.
+ bboxes (ndarray): Shape (k, 4) or (4, ), location of cropped bboxes.
+        scale (float, optional): Scale ratio of bboxes, the default value
+            1.0 means no scaling.
+ pad_fill (Number | list[Number]): Value to be filled for padding.
+ Default: None, which means no padding.
+
+ Returns:
+ list[ndarray] | ndarray: The cropped image patches.
+ """
+ chn = 1 if img.ndim == 2 else img.shape[2]
+ if pad_fill is not None:
+ if isinstance(pad_fill, (int, float)):
+ pad_fill = [pad_fill for _ in range(chn)]
+ assert len(pad_fill) == chn
+
+ _bboxes = bboxes[None, ...] if bboxes.ndim == 1 else bboxes
+ scaled_bboxes = bbox_scaling(_bboxes, scale).astype(np.int32)
+ clipped_bbox = bbox_clip(scaled_bboxes, img.shape)
+
+ patches = []
+ for i in range(clipped_bbox.shape[0]):
+ x1, y1, x2, y2 = tuple(clipped_bbox[i, :])
+ if pad_fill is None:
+ patch = img[y1:y2 + 1, x1:x2 + 1, ...]
+ else:
+ _x1, _y1, _x2, _y2 = tuple(scaled_bboxes[i, :])
+ if chn == 1:
+ patch_shape = (_y2 - _y1 + 1, _x2 - _x1 + 1)
+ else:
+ patch_shape = (_y2 - _y1 + 1, _x2 - _x1 + 1, chn)
+ patch = np.array(
+ pad_fill, dtype=img.dtype) * np.ones(
+ patch_shape, dtype=img.dtype)
+ x_start = 0 if _x1 >= 0 else -_x1
+ y_start = 0 if _y1 >= 0 else -_y1
+ w = x2 - x1 + 1
+ h = y2 - y1 + 1
+ patch[y_start:y_start + h, x_start:x_start + w,
+ ...] = img[y1:y1 + h, x1:x1 + w, ...]
+ patches.append(patch)
+
+ if bboxes.ndim == 1:
+ return patches[0]
+ else:
+ return patches
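+
+
+# Illustrative sketch added by the editor (not upstream mmcv code): crops a
+# single (x1, y1, x2, y2) bbox from a toy image. Never executed on import.
+def _demo_imcrop():
+    img = np.arange(100, dtype=np.uint8).reshape(10, 10)
+    patch = imcrop(img, np.array([2, 3, 5, 7]))
+    # The bbox is inclusive, so the patch is (y2 - y1 + 1, x2 - x1 + 1).
+    assert patch.shape == (5, 4)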
+
+
+def impad(img,
+ *,
+ shape=None,
+ padding=None,
+ pad_val=0,
+ padding_mode='constant'):
+ """Pad the given image to a certain shape or pad on all sides with
+ specified padding mode and padding value.
+
+ Args:
+ img (ndarray): Image to be padded.
+ shape (tuple[int]): Expected padding shape (h, w). Default: None.
+ padding (int or tuple[int]): Padding on each border. If a single int is
+ provided this is used to pad all borders. If tuple of length 2 is
+ provided this is the padding on left/right and top/bottom
+ respectively. If a tuple of length 4 is provided this is the
+ padding for the left, top, right and bottom borders respectively.
+ Default: None. Note that `shape` and `padding` can not be both
+            Default: None. Note that `shape` and `padding` cannot both be
+            set.
+ areas when padding_mode is 'constant'. Default: 0.
+ padding_mode (str): Type of padding. Should be: constant, edge,
+ reflect or symmetric. Default: constant.
+
+ - constant: pads with a constant value, this value is specified
+ with pad_val.
+ - edge: pads with the last value at the edge of the image.
+ - reflect: pads with reflection of image without repeating the
+ last value on the edge. For example, padding [1, 2, 3, 4]
+ with 2 elements on both sides in reflect mode will result
+ in [3, 2, 1, 2, 3, 4, 3, 2].
+ - symmetric: pads with reflection of image repeating the last
+ value on the edge. For example, padding [1, 2, 3, 4] with
+ 2 elements on both sides in symmetric mode will result in
+ [2, 1, 1, 2, 3, 4, 4, 3]
+
+ Returns:
+ ndarray: The padded image.
+ """
+
+ assert (shape is not None) ^ (padding is not None)
+ if shape is not None:
+ padding = (0, 0, shape[1] - img.shape[1], shape[0] - img.shape[0])
+
+ # check pad_val
+ if isinstance(pad_val, tuple):
+ assert len(pad_val) == img.shape[-1]
+ elif not isinstance(pad_val, numbers.Number):
+        raise TypeError('pad_val must be an int or a tuple. '
+ f'But received {type(pad_val)}')
+
+ # check padding
+ if isinstance(padding, tuple) and len(padding) in [2, 4]:
+ if len(padding) == 2:
+ padding = (padding[0], padding[1], padding[0], padding[1])
+ elif isinstance(padding, numbers.Number):
+ padding = (padding, padding, padding, padding)
+ else:
+        raise ValueError('Padding must be an int or a 2- or 4-element '
+                         f'tuple. But received {padding}')
+
+ # check padding mode
+ assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric']
+
+ border_type = {
+ 'constant': cv2.BORDER_CONSTANT,
+ 'edge': cv2.BORDER_REPLICATE,
+ 'reflect': cv2.BORDER_REFLECT_101,
+ 'symmetric': cv2.BORDER_REFLECT
+ }
+ img = cv2.copyMakeBorder(
+ img,
+ padding[1],
+ padding[3],
+ padding[0],
+ padding[2],
+ border_type[padding_mode],
+ value=pad_val)
+
+ return img
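+
+
+# Illustrative sketch added by the editor (not upstream mmcv code): shows the
+# two mutually exclusive ways of specifying the padding above. Never executed
+# on import.
+def _demo_impad():
+    img = np.ones((10, 12, 3), dtype=np.uint8)
+    # Pad on the bottom/right up to a target (h, w) shape.
+    assert impad(img, shape=(16, 16), pad_val=0).shape == (16, 16, 3)
+    # Or give explicit (left, top, right, bottom) padding.
+    padded = impad(img, padding=(1, 2, 3, 4), padding_mode='edge')
+    assert padded.shape == (10 + 2 + 4, 12 + 1 + 3, 3)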
+
+
+def impad_to_multiple(img, divisor, pad_val=0):
+    """Pad an image so that each edge is a multiple of some number.
+
+ Args:
+ img (ndarray): Image to be padded.
+        divisor (int): Padded image edges will be multiples of divisor.
+ pad_val (Number | Sequence[Number]): Same as :func:`impad`.
+
+ Returns:
+ ndarray: The padded image.
+ """
+ pad_h = int(np.ceil(img.shape[0] / divisor)) * divisor
+ pad_w = int(np.ceil(img.shape[1] / divisor)) * divisor
+ return impad(img, shape=(pad_h, pad_w), pad_val=pad_val)
+
+
+def cutout(img, shape, pad_val=0):
+ """Randomly cut out a rectangle from the original img.
+
+ Args:
+ img (ndarray): Image to be cutout.
+        shape (int | tuple[int]): Expected cutout shape (h, w). If given as
+            an int, the value will be used for both h and w.
+ pad_val (int | float | tuple[int | float]): Values to be filled in the
+ cut area. Defaults to 0.
+
+ Returns:
+ ndarray: The cutout image.
+ """
+
+ channels = 1 if img.ndim == 2 else img.shape[2]
+ if isinstance(shape, int):
+ cut_h, cut_w = shape, shape
+ else:
+ assert isinstance(shape, tuple) and len(shape) == 2, \
+            f'shape must be an int or a tuple with length 2, but got type ' \
+ f'{type(shape)} instead.'
+ cut_h, cut_w = shape
+ if isinstance(pad_val, (int, float)):
+ pad_val = tuple([pad_val] * channels)
+ elif isinstance(pad_val, tuple):
+ assert len(pad_val) == channels, \
+            'Expected the number of elements in the tuple to equal the ' \
+            'number of channels of the input image. Found {} vs {}'.format(
+ len(pad_val), channels)
+ else:
+ raise TypeError(f'Invalid type {type(pad_val)} for `pad_val`')
+
+ img_h, img_w = img.shape[:2]
+ y0 = np.random.uniform(img_h)
+ x0 = np.random.uniform(img_w)
+
+ y1 = int(max(0, y0 - cut_h / 2.))
+ x1 = int(max(0, x0 - cut_w / 2.))
+ y2 = min(img_h, y1 + cut_h)
+ x2 = min(img_w, x1 + cut_w)
+
+ if img.ndim == 2:
+ patch_shape = (y2 - y1, x2 - x1)
+ else:
+ patch_shape = (y2 - y1, x2 - x1, channels)
+
+ img_cutout = img.copy()
+ patch = np.array(
+ pad_val, dtype=img.dtype) * np.ones(
+ patch_shape, dtype=img.dtype)
+ img_cutout[y1:y2, x1:x2, ...] = patch
+
+ return img_cutout
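+
+
+# Illustrative sketch added by the editor (not upstream mmcv code): cuts a
+# random 8x8 rectangle out of a white image and fills it with zeros. Never
+# executed on import.
+def _demo_cutout():
+    img = np.full((32, 32, 3), 255, dtype=np.uint8)
+    out = cutout(img, 8, pad_val=0)
+    assert out.shape == img.shape and out.min() == 0 and out.max() == 255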
+
+
+def _get_shear_matrix(magnitude, direction='horizontal'):
+ """Generate the shear matrix for transformation.
+
+ Args:
+ magnitude (int | float): The magnitude used for shear.
+ direction (str): The flip direction, either "horizontal"
+ or "vertical".
+
+ Returns:
+ ndarray: The shear matrix with dtype float32.
+ """
+ if direction == 'horizontal':
+ shear_matrix = np.float32([[1, magnitude, 0], [0, 1, 0]])
+ elif direction == 'vertical':
+ shear_matrix = np.float32([[1, 0, 0], [magnitude, 1, 0]])
+ return shear_matrix
+
+
+def imshear(img,
+ magnitude,
+ direction='horizontal',
+ border_value=0,
+ interpolation='bilinear'):
+ """Shear an image.
+
+ Args:
+ img (ndarray): Image to be sheared with format (h, w)
+ or (h, w, c).
+ magnitude (int | float): The magnitude used for shear.
+ direction (str): The flip direction, either "horizontal"
+ or "vertical".
+ border_value (int | tuple[int]): Value used in case of a
+ constant border.
+ interpolation (str): Same as :func:`resize`.
+
+ Returns:
+ ndarray: The sheared image.
+ """
+ assert direction in ['horizontal',
+ 'vertical'], f'Invalid direction: {direction}'
+ height, width = img.shape[:2]
+ if img.ndim == 2:
+ channels = 1
+ elif img.ndim == 3:
+ channels = img.shape[-1]
+ if isinstance(border_value, int):
+ border_value = tuple([border_value] * channels)
+ elif isinstance(border_value, tuple):
+ assert len(border_value) == channels, \
+            'Expected the number of elements in the tuple to equal the ' \
+            'number of channels of the input image. Found {} vs {}'.format(
+ len(border_value), channels)
+ else:
+ raise ValueError(
+ f'Invalid type {type(border_value)} for `border_value`')
+ shear_matrix = _get_shear_matrix(magnitude, direction)
+ sheared = cv2.warpAffine(
+ img,
+ shear_matrix,
+ (width, height),
+        # Note: if `border_value` has more than 3 elements (e.g. when
+        # shearing masks with more than 3 channels), `cv2.warpAffine`
+        # raises a TypeError, so only the first 3 values in
+        # `border_value` are passed.
+ borderValue=border_value[:3],
+ flags=cv2_interp_codes[interpolation])
+ return sheared
+
+
+def _get_translate_matrix(offset, direction='horizontal'):
+ """Generate the translate matrix.
+
+ Args:
+ offset (int | float): The offset used for translate.
+ direction (str): The translate direction, either
+ "horizontal" or "vertical".
+
+ Returns:
+ ndarray: The translate matrix with dtype float32.
+ """
+ if direction == 'horizontal':
+ translate_matrix = np.float32([[1, 0, offset], [0, 1, 0]])
+ elif direction == 'vertical':
+ translate_matrix = np.float32([[1, 0, 0], [0, 1, offset]])
+ return translate_matrix
+
+
+def imtranslate(img,
+ offset,
+ direction='horizontal',
+ border_value=0,
+ interpolation='bilinear'):
+ """Translate an image.
+
+ Args:
+ img (ndarray): Image to be translated with format
+ (h, w) or (h, w, c).
+ offset (int | float): The offset used for translate.
+ direction (str): The translate direction, either "horizontal"
+ or "vertical".
+ border_value (int | tuple[int]): Value used in case of a
+ constant border.
+ interpolation (str): Same as :func:`resize`.
+
+ Returns:
+ ndarray: The translated image.
+ """
+ assert direction in ['horizontal',
+ 'vertical'], f'Invalid direction: {direction}'
+ height, width = img.shape[:2]
+ if img.ndim == 2:
+ channels = 1
+ elif img.ndim == 3:
+ channels = img.shape[-1]
+ if isinstance(border_value, int):
+ border_value = tuple([border_value] * channels)
+ elif isinstance(border_value, tuple):
+ assert len(border_value) == channels, \
+            'Expected the number of elements in the tuple to equal the ' \
+            'number of channels of the input image. Found {} vs {}'.format(
+ len(border_value), channels)
+ else:
+ raise ValueError(
+ f'Invalid type {type(border_value)} for `border_value`.')
+ translate_matrix = _get_translate_matrix(offset, direction)
+ translated = cv2.warpAffine(
+ img,
+ translate_matrix,
+ (width, height),
+        # Note: if `border_value` has more than 3 elements (e.g. when
+        # translating masks with more than 3 channels), `cv2.warpAffine`
+        # raises a TypeError, so only the first 3 values in
+        # `border_value` are passed.
+ borderValue=border_value[:3],
+ flags=cv2_interp_codes[interpolation])
+ return translated
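+
+
+# Illustrative sketch added by the editor (not upstream mmcv code): an integer
+# translation moves pixels exactly, with the border filled by `border_value`.
+# Never executed on import.
+def _demo_imtranslate():
+    img = np.zeros((5, 5), dtype=np.uint8)
+    img[2, 2] = 255
+    out = imtranslate(img, 1, direction='horizontal')
+    assert out[2, 3] == 255 and out[2, 2] == 0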
diff --git a/annotator/uniformer/mmcv/image/io.py b/annotator/uniformer/mmcv/image/io.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3fa2e8cc06b1a7b0b69de6406980b15d61a1e5d
--- /dev/null
+++ b/annotator/uniformer/mmcv/image/io.py
@@ -0,0 +1,258 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import io
+import os.path as osp
+from pathlib import Path
+
+import cv2
+import numpy as np
+from cv2 import (IMREAD_COLOR, IMREAD_GRAYSCALE, IMREAD_IGNORE_ORIENTATION,
+ IMREAD_UNCHANGED)
+
+from annotator.uniformer.mmcv.utils import check_file_exist, is_str, mkdir_or_exist
+
+try:
+ from turbojpeg import TJCS_RGB, TJPF_BGR, TJPF_GRAY, TurboJPEG
+except ImportError:
+ TJCS_RGB = TJPF_GRAY = TJPF_BGR = TurboJPEG = None
+
+try:
+ from PIL import Image, ImageOps
+except ImportError:
+ Image = None
+
+try:
+ import tifffile
+except ImportError:
+ tifffile = None
+
+jpeg = None
+supported_backends = ['cv2', 'turbojpeg', 'pillow', 'tifffile']
+
+imread_flags = {
+ 'color': IMREAD_COLOR,
+ 'grayscale': IMREAD_GRAYSCALE,
+ 'unchanged': IMREAD_UNCHANGED,
+ 'color_ignore_orientation': IMREAD_IGNORE_ORIENTATION | IMREAD_COLOR,
+ 'grayscale_ignore_orientation':
+ IMREAD_IGNORE_ORIENTATION | IMREAD_GRAYSCALE
+}
+
+imread_backend = 'cv2'
+
+
+def use_backend(backend):
+ """Select a backend for image decoding.
+
+ Args:
+ backend (str): The image decoding backend type. Options are `cv2`,
+ `pillow`, `turbojpeg` (see https://github.com/lilohuang/PyTurboJPEG)
+ and `tifffile`. `turbojpeg` is faster but it only supports `.jpeg`
+ file format.
+ """
+ assert backend in supported_backends
+ global imread_backend
+ imread_backend = backend
+ if imread_backend == 'turbojpeg':
+ if TurboJPEG is None:
+ raise ImportError('`PyTurboJPEG` is not installed')
+ global jpeg
+ if jpeg is None:
+ jpeg = TurboJPEG()
+ elif imread_backend == 'pillow':
+ if Image is None:
+ raise ImportError('`Pillow` is not installed')
+ elif imread_backend == 'tifffile':
+ if tifffile is None:
+ raise ImportError('`tifffile` is not installed')
+
+
+def _jpegflag(flag='color', channel_order='bgr'):
+ channel_order = channel_order.lower()
+ if channel_order not in ['rgb', 'bgr']:
+ raise ValueError('channel order must be either "rgb" or "bgr"')
+
+ if flag == 'color':
+ if channel_order == 'bgr':
+ return TJPF_BGR
+ elif channel_order == 'rgb':
+ return TJCS_RGB
+ elif flag == 'grayscale':
+ return TJPF_GRAY
+ else:
+ raise ValueError('flag must be "color" or "grayscale"')
+
+
+def _pillow2array(img, flag='color', channel_order='bgr'):
+ """Convert a pillow image to numpy array.
+
+ Args:
+ img (:obj:`PIL.Image.Image`): The image loaded using PIL
+ flag (str): Flags specifying the color type of a loaded image,
+ candidates are 'color', 'grayscale' and 'unchanged'.
+ Default to 'color'.
+ channel_order (str): The channel order of the output image array,
+ candidates are 'bgr' and 'rgb'. Default to 'bgr'.
+
+ Returns:
+ np.ndarray: The converted numpy array
+ """
+ channel_order = channel_order.lower()
+ if channel_order not in ['rgb', 'bgr']:
+ raise ValueError('channel order must be either "rgb" or "bgr"')
+
+ if flag == 'unchanged':
+ array = np.array(img)
+ if array.ndim >= 3 and array.shape[2] >= 3: # color image
+ array[:, :, :3] = array[:, :, (2, 1, 0)] # RGB to BGR
+ else:
+ # Handle exif orientation tag
+ if flag in ['color', 'grayscale']:
+ img = ImageOps.exif_transpose(img)
+ # If the image mode is not 'RGB', convert it to 'RGB' first.
+ if img.mode != 'RGB':
+ if img.mode != 'LA':
+ # Most formats except 'LA' can be directly converted to RGB
+ img = img.convert('RGB')
+ else:
+ # When the mode is 'LA', the default conversion will fill in
+ # the canvas with black, which sometimes shadows black objects
+ # in the foreground.
+ #
+                # Therefore, a fixed color (124, 117, 104) is used for
+                # the canvas
+ img_rgba = img.convert('RGBA')
+ img = Image.new('RGB', img_rgba.size, (124, 117, 104))
+ img.paste(img_rgba, mask=img_rgba.split()[3]) # 3 is alpha
+ if flag in ['color', 'color_ignore_orientation']:
+ array = np.array(img)
+ if channel_order != 'rgb':
+ array = array[:, :, ::-1] # RGB to BGR
+ elif flag in ['grayscale', 'grayscale_ignore_orientation']:
+ img = img.convert('L')
+ array = np.array(img)
+ else:
+ raise ValueError(
+ 'flag must be "color", "grayscale", "unchanged", '
+ f'"color_ignore_orientation" or "grayscale_ignore_orientation"'
+ f' but got {flag}')
+ return array
+
+
+def imread(img_or_path, flag='color', channel_order='bgr', backend=None):
+ """Read an image.
+
+ Args:
+ img_or_path (ndarray or str or Path): Either a numpy array or str or
+ pathlib.Path. If it is a numpy array (loaded image), then
+ it will be returned as is.
+ flag (str): Flags specifying the color type of a loaded image,
+ candidates are `color`, `grayscale`, `unchanged`,
+ `color_ignore_orientation` and `grayscale_ignore_orientation`.
+ By default, `cv2` and `pillow` backend would rotate the image
+ according to its EXIF info unless called with `unchanged` or
+ `*_ignore_orientation` flags. `turbojpeg` and `tifffile` backend
+ always ignore image's EXIF info regardless of the flag.
+ The `turbojpeg` backend only supports `color` and `grayscale`.
+ channel_order (str): Order of channel, candidates are `bgr` and `rgb`.
+ backend (str | None): The image decoding backend type. Options are
+ `cv2`, `pillow`, `turbojpeg`, `tifffile`, `None`.
+ If backend is None, the global imread_backend specified by
+ ``mmcv.use_backend()`` will be used. Default: None.
+
+ Returns:
+ ndarray: Loaded image array.
+ """
+
+ if backend is None:
+ backend = imread_backend
+ if backend not in supported_backends:
+        raise ValueError(
+            f'backend: {backend} is not supported. Supported backends '
+            "are 'cv2', 'turbojpeg', 'pillow', 'tifffile'")
+ if isinstance(img_or_path, Path):
+ img_or_path = str(img_or_path)
+
+ if isinstance(img_or_path, np.ndarray):
+ return img_or_path
+ elif is_str(img_or_path):
+ check_file_exist(img_or_path,
+ f'img file does not exist: {img_or_path}')
+ if backend == 'turbojpeg':
+ with open(img_or_path, 'rb') as in_file:
+ img = jpeg.decode(in_file.read(),
+ _jpegflag(flag, channel_order))
+ if img.shape[-1] == 1:
+ img = img[:, :, 0]
+ return img
+ elif backend == 'pillow':
+ img = Image.open(img_or_path)
+ img = _pillow2array(img, flag, channel_order)
+ return img
+ elif backend == 'tifffile':
+ img = tifffile.imread(img_or_path)
+ return img
+ else:
+ flag = imread_flags[flag] if is_str(flag) else flag
+ img = cv2.imread(img_or_path, flag)
+ if flag == IMREAD_COLOR and channel_order == 'rgb':
+ cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)
+ return img
+ else:
+ raise TypeError('"img" must be a numpy array or a str or '
+ 'a pathlib.Path object')
+
+
+def imfrombytes(content, flag='color', channel_order='bgr', backend=None):
+ """Read an image from bytes.
+
+ Args:
+ content (bytes): Image bytes got from files or other streams.
+ flag (str): Same as :func:`imread`.
+ backend (str | None): The image decoding backend type. Options are
+ `cv2`, `pillow`, `turbojpeg`, `None`. If backend is None, the
+ global imread_backend specified by ``mmcv.use_backend()`` will be
+ used. Default: None.
+
+ Returns:
+ ndarray: Loaded image array.
+ """
+
+ if backend is None:
+ backend = imread_backend
+ if backend not in supported_backends:
+ raise ValueError(f'backend: {backend} is not supported. Supported '
+ "backends are 'cv2', 'turbojpeg', 'pillow'")
+ if backend == 'turbojpeg':
+ img = jpeg.decode(content, _jpegflag(flag, channel_order))
+ if img.shape[-1] == 1:
+ img = img[:, :, 0]
+ return img
+ elif backend == 'pillow':
+ buff = io.BytesIO(content)
+ img = Image.open(buff)
+ img = _pillow2array(img, flag, channel_order)
+ return img
+ else:
+ img_np = np.frombuffer(content, np.uint8)
+ flag = imread_flags[flag] if is_str(flag) else flag
+ img = cv2.imdecode(img_np, flag)
+ if flag == IMREAD_COLOR and channel_order == 'rgb':
+ cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)
+ return img
+
+
+def imwrite(img, file_path, params=None, auto_mkdir=True):
+ """Write image to file.
+
+ Args:
+ img (ndarray): Image array to be written.
+ file_path (str): Image file path.
+ params (None or list): Same as opencv :func:`imwrite` interface.
+ auto_mkdir (bool): If the parent folder of `file_path` does not exist,
+ whether to create it automatically.
+
+ Returns:
+ bool: Successful or not.
+ """
+ if auto_mkdir:
+ dir_name = osp.abspath(osp.dirname(file_path))
+ mkdir_or_exist(dir_name)
+ return cv2.imwrite(file_path, img, params)
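+
+
+# Illustrative sketch added by the editor (not upstream mmcv code): a lossless
+# write/read round trip through a temporary PNG file. The file name is
+# arbitrary and the helper is never executed on import.
+def _demo_imwrite_imread_roundtrip():
+    import tempfile
+    img = np.random.randint(0, 256, (8, 8, 3), dtype=np.uint8)
+    path = osp.join(tempfile.mkdtemp(), 'demo.png')
+    assert imwrite(img, path)
+    # PNG is lossless, so the image read back matches the one written.
+    assert np.array_equal(imread(path), img)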
diff --git a/annotator/uniformer/mmcv/image/misc.py b/annotator/uniformer/mmcv/image/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e61f05e3b05e4c7b40de4eb6c8eb100e6da41d0
--- /dev/null
+++ b/annotator/uniformer/mmcv/image/misc.py
@@ -0,0 +1,44 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+
+import annotator.uniformer.mmcv as mmcv
+
+try:
+ import torch
+except ImportError:
+ torch = None
+
+
+def tensor2imgs(tensor, mean=(0, 0, 0), std=(1, 1, 1), to_rgb=True):
+ """Convert tensor to 3-channel images.
+
+ Args:
+ tensor (torch.Tensor): Tensor that contains multiple images, shape (
+ N, C, H, W).
+ mean (tuple[float], optional): Mean of images. Defaults to (0, 0, 0).
+ std (tuple[float], optional): Standard deviation of images.
+ Defaults to (1, 1, 1).
+ to_rgb (bool, optional): Whether the tensor was converted to RGB
+ format in the first place. If so, convert it back to BGR.
+ Defaults to True.
+
+ Returns:
+ list[np.ndarray]: A list that contains multiple images.
+ """
+
+ if torch is None:
+ raise RuntimeError('pytorch is not installed')
+ assert torch.is_tensor(tensor) and tensor.ndim == 4
+ assert len(mean) == 3
+ assert len(std) == 3
+
+ num_imgs = tensor.size(0)
+ mean = np.array(mean, dtype=np.float32)
+ std = np.array(std, dtype=np.float32)
+ imgs = []
+ for img_id in range(num_imgs):
+ img = tensor[img_id, ...].cpu().numpy().transpose(1, 2, 0)
+ img = mmcv.imdenormalize(
+ img, mean, std, to_bgr=to_rgb).astype(np.uint8)
+ imgs.append(np.ascontiguousarray(img))
+ return imgs
diff --git a/annotator/uniformer/mmcv/image/photometric.py b/annotator/uniformer/mmcv/image/photometric.py
new file mode 100644
index 0000000000000000000000000000000000000000..5085d012019c0cbf56f66f421a378278c1a058ae
--- /dev/null
+++ b/annotator/uniformer/mmcv/image/photometric.py
@@ -0,0 +1,428 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import cv2
+import numpy as np
+
+from ..utils import is_tuple_of
+from .colorspace import bgr2gray, gray2bgr
+
+
+def imnormalize(img, mean, std, to_rgb=True):
+ """Normalize an image with mean and std.
+
+ Args:
+ img (ndarray): Image to be normalized.
+ mean (ndarray): The mean to be used for normalize.
+ std (ndarray): The std to be used for normalize.
+ to_rgb (bool): Whether to convert to rgb.
+
+ Returns:
+ ndarray: The normalized image.
+ """
+ img = img.copy().astype(np.float32)
+ return imnormalize_(img, mean, std, to_rgb)
+
+
+def imnormalize_(img, mean, std, to_rgb=True):
+ """Inplace normalize an image with mean and std.
+
+ Args:
+ img (ndarray): Image to be normalized.
+ mean (ndarray): The mean to be used for normalize.
+ std (ndarray): The std to be used for normalize.
+ to_rgb (bool): Whether to convert to rgb.
+
+ Returns:
+ ndarray: The normalized image.
+ """
+ # cv2 inplace normalization does not accept uint8
+ assert img.dtype != np.uint8
+ mean = np.float64(mean.reshape(1, -1))
+ stdinv = 1 / np.float64(std.reshape(1, -1))
+ if to_rgb:
+ cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img) # inplace
+ cv2.subtract(img, mean, img) # inplace
+ cv2.multiply(img, stdinv, img) # inplace
+ return img
+
+
+def imdenormalize(img, mean, std, to_bgr=True):
+ assert img.dtype != np.uint8
+ mean = mean.reshape(1, -1).astype(np.float64)
+ std = std.reshape(1, -1).astype(np.float64)
+ img = cv2.multiply(img, std) # make a copy
+ cv2.add(img, mean, img) # inplace
+ if to_bgr:
+ cv2.cvtColor(img, cv2.COLOR_RGB2BGR, img) # inplace
+ return img
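+
+
+# Illustrative sketch added by the editor (not upstream mmcv code): normalize
+# followed by denormalize recovers the original pixel values. The mean/std
+# arrays are arbitrary example statistics. Never executed on import.
+def _demo_imnormalize_roundtrip():
+    img = np.random.randint(0, 256, (4, 4, 3), dtype=np.uint8)
+    mean = np.array([123.675, 116.28, 103.53], dtype=np.float32)
+    std = np.array([58.395, 57.12, 57.375], dtype=np.float32)
+    norm = imnormalize(img, mean, std, to_rgb=True)
+    back = imdenormalize(norm, mean, std, to_bgr=True)
+    assert np.allclose(back, img.astype(np.float32), atol=1e-2)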
+
+
+def iminvert(img):
+ """Invert (negate) an image.
+
+ Args:
+ img (ndarray): Image to be inverted.
+
+ Returns:
+ ndarray: The inverted image.
+ """
+ return np.full_like(img, 255) - img
+
+
+def solarize(img, thr=128):
+ """Solarize an image (invert all pixel values above a threshold)
+
+ Args:
+ img (ndarray): Image to be solarized.
+ thr (int): Threshold for solarizing (0 - 255).
+
+ Returns:
+ ndarray: The solarized image.
+ """
+ img = np.where(img < thr, img, 255 - img)
+ return img
+
+
+def posterize(img, bits):
+ """Posterize an image (reduce the number of bits for each color channel)
+
+ Args:
+ img (ndarray): Image to be posterized.
+ bits (int): Number of bits (1 to 8) to use for posterizing.
+
+ Returns:
+ ndarray: The posterized image.
+ """
+ shift = 8 - bits
+ img = np.left_shift(np.right_shift(img, shift), shift)
+ return img
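+
+
+# Illustrative sketch added by the editor (not upstream mmcv code): tiny worked
+# examples of the three point-wise ops above. Never executed on import.
+def _demo_point_ops():
+    img = np.array([[0, 100, 200, 255]], dtype=np.uint8)
+    assert np.array_equal(iminvert(img), [[255, 155, 55, 0]])
+    # Values >= thr are inverted, the rest are kept unchanged.
+    assert np.array_equal(solarize(img, thr=128), [[0, 100, 55, 0]])
+    # Keeping only the top 2 bits quantizes values to multiples of 64.
+    assert np.array_equal(posterize(img, bits=2), [[0, 64, 192, 192]])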
+
+
+def adjust_color(img, alpha=1, beta=None, gamma=0):
+ r"""It blends the source image and its gray image:
+
+ .. math::
+ output = img * alpha + gray\_img * beta + gamma
+
+ Args:
+ img (ndarray): The input source image.
+ alpha (int | float): Weight for the source image. Default 1.
+ beta (int | float): Weight for the converted gray image.
+ If None, it's assigned the value (1 - `alpha`).
+ gamma (int | float): Scalar added to each sum.
+ Same as :func:`cv2.addWeighted`. Default 0.
+
+ Returns:
+ ndarray: Colored image which has the same size and dtype as input.
+ """
+ gray_img = bgr2gray(img)
+ gray_img = np.tile(gray_img[..., None], [1, 1, 3])
+ if beta is None:
+ beta = 1 - alpha
+ colored_img = cv2.addWeighted(img, alpha, gray_img, beta, gamma)
+ if not colored_img.dtype == np.uint8:
+ # Note when the dtype of `img` is not the default `np.uint8`
+ # (e.g. np.float32), the value in `colored_img` got from cv2
+ # is not guaranteed to be in range [0, 255], so here clip
+ # is needed.
+ colored_img = np.clip(colored_img, 0, 255)
+ return colored_img
+
+
+def imequalize(img):
+ """Equalize the image histogram.
+
+ This function applies a non-linear mapping to the input image,
+ in order to create a uniform distribution of grayscale values
+ in the output image.
+
+ Args:
+ img (ndarray): Image to be equalized.
+
+ Returns:
+ ndarray: The equalized image.
+ """
+
+ def _scale_channel(im, c):
+ """Scale the data in the corresponding channel."""
+ im = im[:, :, c]
+ # Compute the histogram of the image channel.
+ histo = np.histogram(im, 256, (0, 255))[0]
+ # For computing the step, filter out the nonzeros.
+ nonzero_histo = histo[histo > 0]
+ step = (np.sum(nonzero_histo) - nonzero_histo[-1]) // 255
+ if not step:
+ lut = np.array(range(256))
+ else:
+ # Compute the cumulative sum, shifted by step // 2
+ # and then normalized by step.
+ lut = (np.cumsum(histo) + (step // 2)) // step
+ # Shift lut, prepending with 0.
+ lut = np.concatenate([[0], lut[:-1]], 0)
+ # handle potential integer overflow
+ lut[lut > 255] = 255
+ # If step is zero, return the original image.
+ # Otherwise, index from lut.
+ return np.where(np.equal(step, 0), im, lut[im])
+
+ # Scales each channel independently and then stacks
+ # the result.
+ s1 = _scale_channel(img, 0)
+ s2 = _scale_channel(img, 1)
+ s3 = _scale_channel(img, 2)
+ equalized_img = np.stack([s1, s2, s3], axis=-1)
+ return equalized_img.astype(img.dtype)
+
+
+def adjust_brightness(img, factor=1.):
+ """Adjust image brightness.
+
+ This function controls the brightness of an image. An
+ enhancement factor of 0.0 gives a black image.
+ A factor of 1.0 gives the original image. This function
+ blends the source image and the degenerated black image:
+
+ .. math::
+ output = img * factor + degenerated * (1 - factor)
+
+ Args:
+ img (ndarray): Image to be brightened.
+        factor (float): A value that controls the enhancement. Factor 1.0
+            returns the original image; lower factors mean less enhancement
+            and higher factors mean more. Default 1.
+
+ Returns:
+ ndarray: The brightened image.
+ """
+ degenerated = np.zeros_like(img)
+ # Note manually convert the dtype to np.float32, to
+ # achieve as close results as PIL.ImageEnhance.Brightness.
+ # Set beta=1-factor, and gamma=0
+ brightened_img = cv2.addWeighted(
+ img.astype(np.float32), factor, degenerated.astype(np.float32),
+ 1 - factor, 0)
+ brightened_img = np.clip(brightened_img, 0, 255)
+ return brightened_img.astype(img.dtype)
+
+
+def adjust_contrast(img, factor=1.):
+ """Adjust image contrast.
+
+ This function controls the contrast of an image. An
+ enhancement factor of 0.0 gives a solid grey
+ image. A factor of 1.0 gives the original image. It
+ blends the source image and the degenerated mean image:
+
+ .. math::
+ output = img * factor + degenerated * (1 - factor)
+
+ Args:
+ img (ndarray): Image to be contrasted. BGR order.
+ factor (float): Same as :func:`mmcv.adjust_brightness`.
+
+ Returns:
+ ndarray: The contrasted image.
+ """
+ gray_img = bgr2gray(img)
+ hist = np.histogram(gray_img, 256, (0, 255))[0]
+ mean = round(np.sum(gray_img) / np.sum(hist))
+ degenerated = (np.ones_like(img[..., 0]) * mean).astype(img.dtype)
+ degenerated = gray2bgr(degenerated)
+ contrasted_img = cv2.addWeighted(
+ img.astype(np.float32), factor, degenerated.astype(np.float32),
+ 1 - factor, 0)
+ contrasted_img = np.clip(contrasted_img, 0, 255)
+ return contrasted_img.astype(img.dtype)
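+
+
+# Illustrative sketch added by the editor (not upstream mmcv code): the factor
+# semantics shared by adjust_brightness and adjust_contrast. Never executed on
+# import.
+def _demo_brightness_contrast():
+    img = np.random.randint(0, 256, (8, 8, 3), dtype=np.uint8)
+    assert np.array_equal(adjust_brightness(img, 1.0), img)  # identity
+    assert np.all(adjust_brightness(img, 0.0) == 0)          # black image
+    # factor=0 collapses the image to its mean grey value.
+    grey = adjust_contrast(img, 0.0)
+    assert grey.min() == grey.max()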
+
+
+def auto_contrast(img, cutoff=0):
+ """Auto adjust image contrast.
+
+    This function maximizes (normalizes) image contrast by first removing the
+    cutoff percent of the lightest and darkest pixels from the histogram, and
+    then remapping the image so that the darkest pixel becomes black (0) and
+    the lightest becomes white (255).
+
+ Args:
+ img (ndarray): Image to be contrasted. BGR order.
+ cutoff (int | float | tuple): The cutoff percent of the lightest and
+ darkest pixels to be removed. If given as tuple, it shall be
+ (low, high). Otherwise, the single value will be used for both.
+ Defaults to 0.
+
+ Returns:
+ ndarray: The contrasted image.
+ """
+
+ def _auto_contrast_channel(im, c, cutoff):
+ im = im[:, :, c]
+ # Compute the histogram of the image channel.
+ histo = np.histogram(im, 256, (0, 255))[0]
+ # Remove cut-off percent pixels from histo
+ histo_sum = np.cumsum(histo)
+ cut_low = histo_sum[-1] * cutoff[0] // 100
+ cut_high = histo_sum[-1] - histo_sum[-1] * cutoff[1] // 100
+ histo_sum = np.clip(histo_sum, cut_low, cut_high) - cut_low
+ histo = np.concatenate([[histo_sum[0]], np.diff(histo_sum)], 0)
+
+ # Compute mapping
+ low, high = np.nonzero(histo)[0][0], np.nonzero(histo)[0][-1]
+ # If all the values have been cut off, return the origin img
+ if low >= high:
+ return im
+ scale = 255.0 / (high - low)
+ offset = -low * scale
+ lut = np.array(range(256))
+ lut = lut * scale + offset
+ lut = np.clip(lut, 0, 255)
+ return lut[im]
+
+ if isinstance(cutoff, (int, float)):
+ cutoff = (cutoff, cutoff)
+ else:
+ assert isinstance(cutoff, tuple), 'cutoff must be of type int, ' \
+ f'float or tuple, but got {type(cutoff)} instead.'
+ # Auto adjusts contrast for each channel independently and then stacks
+ # the result.
+ s1 = _auto_contrast_channel(img, 0, cutoff)
+ s2 = _auto_contrast_channel(img, 1, cutoff)
+ s3 = _auto_contrast_channel(img, 2, cutoff)
+ contrasted_img = np.stack([s1, s2, s3], axis=-1)
+ return contrasted_img.astype(img.dtype)
+
+
+def adjust_sharpness(img, factor=1., kernel=None):
+ """Adjust image sharpness.
+
+ This function controls the sharpness of an image. An
+ enhancement factor of 0.0 gives a blurred image. A
+ factor of 1.0 gives the original image. And a factor
+ of 2.0 gives a sharpened image. It blends the source
+ image and the degenerated mean image:
+
+ .. math::
+ output = img * factor + degenerated * (1 - factor)
+
+ Args:
+ img (ndarray): Image to be sharpened. BGR order.
+ factor (float): Same as :func:`mmcv.adjust_brightness`.
+ kernel (np.ndarray, optional): Filter kernel to be applied on the img
+ to obtain the degenerated img. Defaults to None.
+
+ Note:
+        No value sanity check is enforced on the kernel set by users. So with
+        an inappropriate kernel, ``adjust_sharpness`` may fail to sharpen the
+        image and will instead apply whatever transform the kernel
+        determines.
+
+ Returns:
+ ndarray: The sharpened image.
+ """
+
+ if kernel is None:
+ # adopted from PIL.ImageFilter.SMOOTH
+ kernel = np.array([[1., 1., 1.], [1., 5., 1.], [1., 1., 1.]]) / 13
+ assert isinstance(kernel, np.ndarray), \
+ f'kernel must be of type np.ndarray, but got {type(kernel)} instead.'
+ assert kernel.ndim == 2, \
+ f'kernel must have a dimension of 2, but got {kernel.ndim} instead.'
+
+ degenerated = cv2.filter2D(img, -1, kernel)
+ sharpened_img = cv2.addWeighted(
+ img.astype(np.float32), factor, degenerated.astype(np.float32),
+ 1 - factor, 0)
+ sharpened_img = np.clip(sharpened_img, 0, 255)
+ return sharpened_img.astype(img.dtype)
+
+
+def adjust_lighting(img, eigval, eigvec, alphastd=0.1, to_rgb=True):
+ """AlexNet-style PCA jitter.
+
+    This data augmentation is proposed in the AlexNet paper, *ImageNet
+    Classification with Deep Convolutional Neural Networks*.
+
+ Args:
+ img (ndarray): Image to be adjusted lighting. BGR order.
+        eigval (ndarray): the eigenvalues of the covariance matrix of pixel
+            values.
+        eigvec (ndarray): the eigenvectors of the covariance matrix of pixel
+            values.
+ alphastd (float): The standard deviation for distribution of alpha.
+ Defaults to 0.1
+ to_rgb (bool): Whether to convert img to rgb.
+
+ Returns:
+ ndarray: The adjusted image.
+ """
+ assert isinstance(eigval, np.ndarray) and isinstance(eigvec, np.ndarray), \
+ f'eigval and eigvec should both be of type np.ndarray, got ' \
+ f'{type(eigval)} and {type(eigvec)} instead.'
+
+ assert eigval.ndim == 1 and eigvec.ndim == 2
+ assert eigvec.shape == (3, eigval.shape[0])
+ n_eigval = eigval.shape[0]
+ assert isinstance(alphastd, float), 'alphastd should be of type float, ' \
+ f'got {type(alphastd)} instead.'
+
+ img = img.copy().astype(np.float32)
+ if to_rgb:
+ cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img) # inplace
+
+ alpha = np.random.normal(0, alphastd, n_eigval)
+ alter = eigvec \
+ * np.broadcast_to(alpha.reshape(1, n_eigval), (3, n_eigval)) \
+ * np.broadcast_to(eigval.reshape(1, n_eigval), (3, n_eigval))
+ alter = np.broadcast_to(alter.sum(axis=1).reshape(1, 1, 3), img.shape)
+ img_adjusted = img + alter
+ return img_adjusted
+
+
+def lut_transform(img, lut_table):
+ """Transform array by look-up table.
+
+ The function lut_transform fills the output array with values from the
+ look-up table. Indices of the entries are taken from the input array.
+
+ Args:
+ img (ndarray): Image to be transformed.
+ lut_table (ndarray): look-up table of 256 elements; in case of
+ multi-channel input array, the table should either have a single
+ channel (in this case the same table is used for all channels) or
+ the same number of channels as in the input array.
+
+ Returns:
+ ndarray: The transformed image.
+ """
+ assert isinstance(img, np.ndarray)
+ assert 0 <= np.min(img) and np.max(img) <= 255
+ assert isinstance(lut_table, np.ndarray)
+ assert lut_table.shape == (256, )
+
+ return cv2.LUT(np.array(img, dtype=np.uint8), lut_table)
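+
+
+# Illustrative sketch added by the editor (not upstream mmcv code): an
+# inverting look-up table reproduces iminvert. Never executed on import.
+def _demo_lut_transform():
+    img = np.array([[0, 64, 128, 255]], dtype=np.uint8)
+    inv_lut = (255 - np.arange(256)).astype(np.uint8)  # 255, 254, ..., 0
+    assert np.array_equal(lut_transform(img, inv_lut), iminvert(img))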
+
+
+def clahe(img, clip_limit=40.0, tile_grid_size=(8, 8)):
+ """Use CLAHE method to process the image.
+
+    See `Zuiderveld, K. Contrast Limited Adaptive Histogram Equalization.
+    Graphics Gems, 1994: 474-485.` for more information.
+
+ Args:
+ img (ndarray): Image to be processed.
+ clip_limit (float): Threshold for contrast limiting. Default: 40.0.
+ tile_grid_size (tuple[int]): Size of grid for histogram equalization.
+ Input image will be divided into equally sized rectangular tiles.
+ It defines the number of tiles in row and column. Default: (8, 8).
+
+ Returns:
+ ndarray: The processed image.
+ """
+ assert isinstance(img, np.ndarray)
+ assert img.ndim == 2
+ assert isinstance(clip_limit, (float, int))
+ assert is_tuple_of(tile_grid_size, int)
+ assert len(tile_grid_size) == 2
+
+ clahe = cv2.createCLAHE(clip_limit, tile_grid_size)
+ return clahe.apply(np.array(img, dtype=np.uint8))
diff --git a/annotator/uniformer/mmcv/model_zoo/deprecated.json b/annotator/uniformer/mmcv/model_zoo/deprecated.json
new file mode 100644
index 0000000000000000000000000000000000000000..25cf6f28caecc22a77e3136fefa6b8dfc0e6cb5b
--- /dev/null
+++ b/annotator/uniformer/mmcv/model_zoo/deprecated.json
@@ -0,0 +1,6 @@
+{
+ "resnet50_caffe": "detectron/resnet50_caffe",
+ "resnet50_caffe_bgr": "detectron2/resnet50_caffe_bgr",
+ "resnet101_caffe": "detectron/resnet101_caffe",
+ "resnet101_caffe_bgr": "detectron2/resnet101_caffe_bgr"
+}
diff --git a/annotator/uniformer/mmcv/model_zoo/mmcls.json b/annotator/uniformer/mmcv/model_zoo/mmcls.json
new file mode 100644
index 0000000000000000000000000000000000000000..bdb311d9fe6d9f317290feedc9e37236c6cf6e8f
--- /dev/null
+++ b/annotator/uniformer/mmcv/model_zoo/mmcls.json
@@ -0,0 +1,31 @@
+{
+ "vgg11": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg11_batch256_imagenet_20210208-4271cd6c.pth",
+ "vgg13": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg13_batch256_imagenet_20210208-4d1d6080.pth",
+ "vgg16": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg16_batch256_imagenet_20210208-db26f1a5.pth",
+ "vgg19": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg19_batch256_imagenet_20210208-e6920e4a.pth",
+ "vgg11_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg11_bn_batch256_imagenet_20210207-f244902c.pth",
+ "vgg13_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg13_bn_batch256_imagenet_20210207-1a8b7864.pth",
+ "vgg16_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg16_bn_batch256_imagenet_20210208-7e55cd29.pth",
+ "vgg19_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg19_bn_batch256_imagenet_20210208-da620c4f.pth",
+ "resnet18": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_batch256_imagenet_20200708-34ab8f90.pth",
+ "resnet34": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_batch256_imagenet_20200708-32ffb4f7.pth",
+ "resnet50": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_batch256_imagenet_20200708-cfb998bf.pth",
+ "resnet101": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_batch256_imagenet_20200708-753f3608.pth",
+ "resnet152": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet152_batch256_imagenet_20200708-ec25b1f9.pth",
+ "resnet50_v1d": "https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d50_batch256_imagenet_20200708-1ad0ce94.pth",
+ "resnet101_v1d": "https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d101_batch256_imagenet_20200708-9cb302ef.pth",
+ "resnet152_v1d": "https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d152_batch256_imagenet_20200708-e79cb6a2.pth",
+ "resnext50_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext50_32x4d_b32x8_imagenet_20210429-56066e27.pth",
+ "resnext101_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext101_32x4d_b32x8_imagenet_20210506-e0fa3dd5.pth",
+ "resnext101_32x8d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext101_32x8d_b32x8_imagenet_20210506-23a247d5.pth",
+ "resnext152_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext152_32x4d_b32x8_imagenet_20210524-927787be.pth",
+ "se-resnet50": "https://download.openmmlab.com/mmclassification/v0/se-resnet/se-resnet50_batch256_imagenet_20200804-ae206104.pth",
+ "se-resnet101": "https://download.openmmlab.com/mmclassification/v0/se-resnet/se-resnet101_batch256_imagenet_20200804-ba5b51d4.pth",
+ "resnest50": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest50_imagenet_converted-1ebf0afe.pth",
+ "resnest101": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest101_imagenet_converted-032caa52.pth",
+ "resnest200": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest200_imagenet_converted-581a60f2.pth",
+ "resnest269": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest269_imagenet_converted-59930960.pth",
+ "shufflenet_v1": "https://download.openmmlab.com/mmclassification/v0/shufflenet_v1/shufflenet_v1_batch1024_imagenet_20200804-5d6cec73.pth",
+ "shufflenet_v2": "https://download.openmmlab.com/mmclassification/v0/shufflenet_v2/shufflenet_v2_batch1024_imagenet_20200812-5bf4721e.pth",
+ "mobilenet_v2": "https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth"
+}
diff --git a/annotator/uniformer/mmcv/model_zoo/open_mmlab.json b/annotator/uniformer/mmcv/model_zoo/open_mmlab.json
new file mode 100644
index 0000000000000000000000000000000000000000..8311db4feef92faa0841c697d75efbee8430c3a0
--- /dev/null
+++ b/annotator/uniformer/mmcv/model_zoo/open_mmlab.json
@@ -0,0 +1,50 @@
+{
+ "vgg16_caffe": "https://download.openmmlab.com/pretrain/third_party/vgg16_caffe-292e1171.pth",
+ "detectron/resnet50_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet50_caffe-788b5fa3.pth",
+ "detectron2/resnet50_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet50_msra-5891d200.pth",
+ "detectron/resnet101_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet101_caffe-3ad79236.pth",
+ "detectron2/resnet101_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet101_msra-6cc46731.pth",
+ "detectron2/resnext101_32x8d": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x8d-1516f1aa.pth",
+ "resnext50_32x4d": "https://download.openmmlab.com/pretrain/third_party/resnext50-32x4d-0ab1a123.pth",
+ "resnext101_32x4d": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x4d-a5af3160.pth",
+ "resnext101_64x4d": "https://download.openmmlab.com/pretrain/third_party/resnext101_64x4d-ee2c6f71.pth",
+ "contrib/resnet50_gn": "https://download.openmmlab.com/pretrain/third_party/resnet50_gn_thangvubk-ad1730dd.pth",
+ "detectron/resnet50_gn": "https://download.openmmlab.com/pretrain/third_party/resnet50_gn-9186a21c.pth",
+ "detectron/resnet101_gn": "https://download.openmmlab.com/pretrain/third_party/resnet101_gn-cac0ab98.pth",
+ "jhu/resnet50_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnet50_gn_ws-15beedd8.pth",
+ "jhu/resnet101_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnet101_gn_ws-3e3c308c.pth",
+ "jhu/resnext50_32x4d_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnext50_32x4d_gn_ws-0d87ac85.pth",
+ "jhu/resnext101_32x4d_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x4d_gn_ws-34ac1a9e.pth",
+ "jhu/resnext50_32x4d_gn": "https://download.openmmlab.com/pretrain/third_party/resnext50_32x4d_gn-c7e8b754.pth",
+ "jhu/resnext101_32x4d_gn": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x4d_gn-ac3bb84e.pth",
+ "msra/hrnetv2_w18_small": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w18_small-b5a04e21.pth",
+ "msra/hrnetv2_w18": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w18-00eb2006.pth",
+ "msra/hrnetv2_w32": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w32-dc9eeb4f.pth",
+ "msra/hrnetv2_w40": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w40-ed0b031c.pth",
+ "msra/hrnetv2_w48": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w48-d2186c55.pth",
+ "bninception_caffe": "https://download.openmmlab.com/pretrain/third_party/bn_inception_caffe-ed2e8665.pth",
+ "kin400/i3d_r50_f32s2_k400": "https://download.openmmlab.com/pretrain/third_party/i3d_r50_f32s2_k400-2c57e077.pth",
+ "kin400/nl3d_r50_f32s2_k400": "https://download.openmmlab.com/pretrain/third_party/nl3d_r50_f32s2_k400-fa7e7caa.pth",
+ "res2net101_v1d_26w_4s": "https://download.openmmlab.com/pretrain/third_party/res2net101_v1d_26w_4s_mmdetv2-f0a600f9.pth",
+ "regnetx_400mf": "https://download.openmmlab.com/pretrain/third_party/regnetx_400mf-a5b10d96.pth",
+ "regnetx_800mf": "https://download.openmmlab.com/pretrain/third_party/regnetx_800mf-1f4be4c7.pth",
+ "regnetx_1.6gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_1.6gf-5791c176.pth",
+ "regnetx_3.2gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_3.2gf-c2599b0f.pth",
+ "regnetx_4.0gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_4.0gf-a88f671e.pth",
+ "regnetx_6.4gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_6.4gf-006af45d.pth",
+ "regnetx_8.0gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_8.0gf-3c68abe7.pth",
+ "regnetx_12gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_12gf-4c2a3350.pth",
+ "resnet18_v1c": "https://download.openmmlab.com/pretrain/third_party/resnet18_v1c-b5776b93.pth",
+ "resnet50_v1c": "https://download.openmmlab.com/pretrain/third_party/resnet50_v1c-2cccc1ad.pth",
+ "resnet101_v1c": "https://download.openmmlab.com/pretrain/third_party/resnet101_v1c-e67eebb6.pth",
+ "mmedit/vgg16": "https://download.openmmlab.com/mmediting/third_party/vgg_state_dict.pth",
+ "mmedit/res34_en_nomixup": "https://download.openmmlab.com/mmediting/third_party/model_best_resnet34_En_nomixup.pth",
+ "mmedit/mobilenet_v2": "https://download.openmmlab.com/mmediting/third_party/mobilenet_v2.pth",
+ "contrib/mobilenet_v3_large": "https://download.openmmlab.com/pretrain/third_party/mobilenet_v3_large-bc2c3fd3.pth",
+ "contrib/mobilenet_v3_small": "https://download.openmmlab.com/pretrain/third_party/mobilenet_v3_small-47085aa1.pth",
+ "resnest50": "https://download.openmmlab.com/pretrain/third_party/resnest50_d2-7497a55b.pth",
+ "resnest101": "https://download.openmmlab.com/pretrain/third_party/resnest101_d2-f3b931b2.pth",
+ "resnest200": "https://download.openmmlab.com/pretrain/third_party/resnest200_d2-ca88e41f.pth",
+ "darknet53": "https://download.openmmlab.com/pretrain/third_party/darknet53-a628ea1b.pth",
+ "mmdet/mobilenet_v2": "https://download.openmmlab.com/mmdetection/v2.0/third_party/mobilenet_v2_batch256_imagenet-ff34753d.pth"
+}
diff --git a/annotator/uniformer/mmcv/ops/__init__.py b/annotator/uniformer/mmcv/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..999e090a458ee148ceca0649f1e3806a40e909bd
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/__init__.py
@@ -0,0 +1,81 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .assign_score_withk import assign_score_withk
+from .ball_query import ball_query
+from .bbox import bbox_overlaps
+from .border_align import BorderAlign, border_align
+from .box_iou_rotated import box_iou_rotated
+from .carafe import CARAFE, CARAFENaive, CARAFEPack, carafe, carafe_naive
+from .cc_attention import CrissCrossAttention
+from .contour_expand import contour_expand
+from .corner_pool import CornerPool
+from .correlation import Correlation
+from .deform_conv import DeformConv2d, DeformConv2dPack, deform_conv2d
+from .deform_roi_pool import (DeformRoIPool, DeformRoIPoolPack,
+ ModulatedDeformRoIPoolPack, deform_roi_pool)
+from .deprecated_wrappers import Conv2d_deprecated as Conv2d
+from .deprecated_wrappers import ConvTranspose2d_deprecated as ConvTranspose2d
+from .deprecated_wrappers import Linear_deprecated as Linear
+from .deprecated_wrappers import MaxPool2d_deprecated as MaxPool2d
+from .focal_loss import (SigmoidFocalLoss, SoftmaxFocalLoss,
+ sigmoid_focal_loss, softmax_focal_loss)
+from .furthest_point_sample import (furthest_point_sample,
+ furthest_point_sample_with_dist)
+from .fused_bias_leakyrelu import FusedBiasLeakyReLU, fused_bias_leakyrelu
+from .gather_points import gather_points
+from .group_points import GroupAll, QueryAndGroup, grouping_operation
+from .info import (get_compiler_version, get_compiling_cuda_version,
+ get_onnxruntime_op_path)
+from .iou3d import boxes_iou_bev, nms_bev, nms_normal_bev
+from .knn import knn
+from .masked_conv import MaskedConv2d, masked_conv2d
+from .modulated_deform_conv import (ModulatedDeformConv2d,
+ ModulatedDeformConv2dPack,
+ modulated_deform_conv2d)
+from .multi_scale_deform_attn import MultiScaleDeformableAttention
+from .nms import batched_nms, nms, nms_match, nms_rotated, soft_nms
+from .pixel_group import pixel_group
+from .point_sample import (SimpleRoIAlign, point_sample,
+ rel_roi_point_to_rel_img_point)
+from .points_in_boxes import (points_in_boxes_all, points_in_boxes_cpu,
+ points_in_boxes_part)
+from .points_sampler import PointsSampler
+from .psa_mask import PSAMask
+from .roi_align import RoIAlign, roi_align
+from .roi_align_rotated import RoIAlignRotated, roi_align_rotated
+from .roi_pool import RoIPool, roi_pool
+from .roiaware_pool3d import RoIAwarePool3d
+from .roipoint_pool3d import RoIPointPool3d
+from .saconv import SAConv2d
+from .scatter_points import DynamicScatter, dynamic_scatter
+from .sync_bn import SyncBatchNorm
+from .three_interpolate import three_interpolate
+from .three_nn import three_nn
+from .tin_shift import TINShift, tin_shift
+from .upfirdn2d import upfirdn2d
+from .voxelize import Voxelization, voxelization
+
+__all__ = [
+ 'bbox_overlaps', 'CARAFE', 'CARAFENaive', 'CARAFEPack', 'carafe',
+ 'carafe_naive', 'CornerPool', 'DeformConv2d', 'DeformConv2dPack',
+ 'deform_conv2d', 'DeformRoIPool', 'DeformRoIPoolPack',
+ 'ModulatedDeformRoIPoolPack', 'deform_roi_pool', 'SigmoidFocalLoss',
+ 'SoftmaxFocalLoss', 'sigmoid_focal_loss', 'softmax_focal_loss',
+ 'get_compiler_version', 'get_compiling_cuda_version',
+ 'get_onnxruntime_op_path', 'MaskedConv2d', 'masked_conv2d',
+ 'ModulatedDeformConv2d', 'ModulatedDeformConv2dPack',
+ 'modulated_deform_conv2d', 'batched_nms', 'nms', 'soft_nms', 'nms_match',
+ 'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool', 'SyncBatchNorm', 'Conv2d',
+ 'ConvTranspose2d', 'Linear', 'MaxPool2d', 'CrissCrossAttention', 'PSAMask',
+ 'point_sample', 'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign',
+ 'SAConv2d', 'TINShift', 'tin_shift', 'assign_score_withk',
+ 'box_iou_rotated', 'RoIPointPool3d', 'nms_rotated', 'knn', 'ball_query',
+ 'upfirdn2d', 'FusedBiasLeakyReLU', 'fused_bias_leakyrelu',
+ 'RoIAlignRotated', 'roi_align_rotated', 'pixel_group', 'QueryAndGroup',
+ 'GroupAll', 'grouping_operation', 'contour_expand', 'three_nn',
+ 'three_interpolate', 'MultiScaleDeformableAttention', 'BorderAlign',
+ 'border_align', 'gather_points', 'furthest_point_sample',
+ 'furthest_point_sample_with_dist', 'PointsSampler', 'Correlation',
+ 'boxes_iou_bev', 'nms_bev', 'nms_normal_bev', 'Voxelization',
+ 'voxelization', 'dynamic_scatter', 'DynamicScatter', 'RoIAwarePool3d',
+ 'points_in_boxes_part', 'points_in_boxes_cpu', 'points_in_boxes_all'
+]
diff --git a/annotator/uniformer/mmcv/ops/assign_score_withk.py b/annotator/uniformer/mmcv/ops/assign_score_withk.py
new file mode 100644
index 0000000000000000000000000000000000000000..4906adaa2cffd1b46912fbe7d4f87ef2f9fa0012
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/assign_score_withk.py
@@ -0,0 +1,123 @@
+from torch.autograd import Function
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+ '_ext', ['assign_score_withk_forward', 'assign_score_withk_backward'])
+
+
+class AssignScoreWithK(Function):
+ r"""Perform weighted sum to generate output features according to scores.
+ Modified from `PAConv <https://github.com/CVMI-Lab/PAConv>`_.
+
+ This is a memory-efficient CUDA implementation of the assign_scores
+ operation, which first transforms all point features with the weight bank,
+ then assembles neighbor features with ``knn_idx`` and performs a weighted
+ sum using ``scores``.
+
+ See appendix Sec. D of the PAConv paper for more detailed descriptions.
+
+ Note:
+ This implementation assumes using ``neighbor`` kernel input, which is
+ (point_features - center_features, point_features).
+ See https://github.com/CVMI-Lab/PAConv/blob/main/scene_seg/model/
+ pointnet2/paconv.py#L128 for more details.
+ """
+
+ @staticmethod
+ def forward(ctx,
+ scores,
+ point_features,
+ center_features,
+ knn_idx,
+ aggregate='sum'):
+ """
+ Args:
+ scores (torch.Tensor): (B, npoint, K, M), predicted scores to
+ aggregate weight matrices in the weight bank.
+ ``npoint`` is the number of sampled centers.
+ ``K`` is the number of queried neighbors.
+ ``M`` is the number of weight matrices in the weight bank.
+ point_features (torch.Tensor): (B, N, M, out_dim)
+ Pre-computed point features to be aggregated.
+ center_features (torch.Tensor): (B, N, M, out_dim)
+ Pre-computed center features to be aggregated.
+ knn_idx (torch.Tensor): (B, npoint, K), index of sampled kNN.
+ We assume the first idx in each row is the idx of the center.
+ aggregate (str, optional): Aggregation method.
+ Can be 'sum', 'avg' or 'max'. Default: 'sum'.
+
+ Returns:
+ torch.Tensor: (B, out_dim, npoint, K), the aggregated features.
+ """
+ agg = {'sum': 0, 'avg': 1, 'max': 2}
+
+ B, N, M, out_dim = point_features.size()
+ _, npoint, K, _ = scores.size()
+
+ output = point_features.new_zeros((B, out_dim, npoint, K))
+ ext_module.assign_score_withk_forward(
+ point_features.contiguous(),
+ center_features.contiguous(),
+ scores.contiguous(),
+ knn_idx.contiguous(),
+ output,
+ B=B,
+ N0=N,
+ N1=npoint,
+ M=M,
+ K=K,
+ O=out_dim,
+ aggregate=agg[aggregate])
+
+ ctx.save_for_backward(output, point_features, center_features, scores,
+ knn_idx)
+ ctx.agg = agg[aggregate]
+
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_out):
+ """
+ Args:
+ grad_out (torch.Tensor): (B, out_dim, npoint, K)
+
+ Returns:
+ grad_scores (torch.Tensor): (B, npoint, K, M)
+ grad_point_features (torch.Tensor): (B, N, M, out_dim)
+ grad_center_features (torch.Tensor): (B, N, M, out_dim)
+ """
+ _, point_features, center_features, scores, knn_idx = ctx.saved_tensors
+
+ agg = ctx.agg
+
+ B, N, M, out_dim = point_features.size()
+ _, npoint, K, _ = scores.size()
+
+ grad_point_features = point_features.new_zeros(point_features.shape)
+ grad_center_features = center_features.new_zeros(center_features.shape)
+ grad_scores = scores.new_zeros(scores.shape)
+
+ ext_module.assign_score_withk_backward(
+ grad_out.contiguous(),
+ point_features.contiguous(),
+ center_features.contiguous(),
+ scores.contiguous(),
+ knn_idx.contiguous(),
+ grad_point_features,
+ grad_center_features,
+ grad_scores,
+ B=B,
+ N0=N,
+ N1=npoint,
+ M=M,
+ K=K,
+ O=out_dim,
+ aggregate=agg)
+
+ return grad_scores, grad_point_features, \
+ grad_center_features, None, None
+
+
+assign_score_withk = AssignScoreWithK.apply
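+
+
+# Illustrative usage sketch (not part of the original mmcv source). It assumes
+# the compiled `_ext` CUDA extension and a CUDA device are available; all
+# tensor shapes follow the docstring of `AssignScoreWithK.forward` above.
+if __name__ == '__main__':
+    import torch
+
+    B, N, npoint, K, M, out_dim = 2, 64, 16, 8, 4, 32
+    scores = torch.rand(B, npoint, K, M).cuda()
+    point_features = torch.rand(B, N, M, out_dim).cuda()
+    center_features = torch.rand(B, N, M, out_dim).cuda()
+    # random neighbor indices just for the shape check; in real use the first
+    # index of each row should point at the sampled center (see the Note above)
+    knn_idx = torch.randint(0, N, (B, npoint, K)).cuda()
+    out = assign_score_withk(scores, point_features, center_features,
+                             knn_idx, 'sum')
+    assert out.shape == (B, out_dim, npoint, K)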
diff --git a/annotator/uniformer/mmcv/ops/ball_query.py b/annotator/uniformer/mmcv/ops/ball_query.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0466847c6e5c1239e359a0397568413ebc1504a
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/ball_query.py
@@ -0,0 +1,55 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch.autograd import Function
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', ['ball_query_forward'])
+
+
+class BallQuery(Function):
+ """Find nearby points in spherical space."""
+
+ @staticmethod
+ def forward(ctx, min_radius: float, max_radius: float, sample_num: int,
+ xyz: torch.Tensor, center_xyz: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ min_radius (float): minimum radius of the balls.
+ max_radius (float): maximum radius of the balls.
+ sample_num (int): maximum number of features in the balls.
+ xyz (Tensor): (B, N, 3) xyz coordinates of the features.
+ center_xyz (Tensor): (B, npoint, 3) centers of the ball query.
+
+ Returns:
+ Tensor: (B, npoint, nsample) tensor with the indices of
+ the features that form the query balls.
+ """
+ assert center_xyz.is_contiguous()
+ assert xyz.is_contiguous()
+ assert min_radius < max_radius
+
+ B, N, _ = xyz.size()
+ npoint = center_xyz.size(1)
+ idx = xyz.new_zeros(B, npoint, sample_num, dtype=torch.int)
+
+ ext_module.ball_query_forward(
+ center_xyz,
+ xyz,
+ idx,
+ b=B,
+ n=N,
+ m=npoint,
+ min_radius=min_radius,
+ max_radius=max_radius,
+ nsample=sample_num)
+ if torch.__version__ != 'parrots':
+ ctx.mark_non_differentiable(idx)
+ return idx
+
+ @staticmethod
+ def backward(ctx, a=None):
+ return None, None, None, None
+
+
+ball_query = BallQuery.apply
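+
+
+# Illustrative usage sketch (not part of the original mmcv source). It assumes
+# the compiled `_ext` CUDA extension and a CUDA device are available.
+if __name__ == '__main__':
+    B, N, npoint, nsample = 2, 256, 32, 16
+    xyz = torch.rand(B, N, 3).cuda()
+    # here the query centers are simply the first `npoint` points
+    center_xyz = xyz[:, :npoint, :].contiguous()
+    idx = ball_query(0.0, 0.2, nsample, xyz, center_xyz)
+    assert idx.shape == (B, npoint, nsample)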
diff --git a/annotator/uniformer/mmcv/ops/bbox.py b/annotator/uniformer/mmcv/ops/bbox.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c4d58b6c91f652933974f519acd3403a833e906
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/bbox.py
@@ -0,0 +1,72 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', ['bbox_overlaps'])
+
+
+def bbox_overlaps(bboxes1, bboxes2, mode='iou', aligned=False, offset=0):
+ """Calculate overlap between two sets of bboxes.
+
+ If ``aligned`` is ``False``, then calculate the ious between each bbox
+ of bboxes1 and bboxes2, otherwise the ious between each aligned pair of
+ bboxes1 and bboxes2.
+
+ Args:
+ bboxes1 (Tensor): shape (m, 4) in (x1, y1, x2, y2) format or empty.
+ bboxes2 (Tensor): shape (n, 4) in (x1, y1, x2, y2) format or empty.
+ If aligned is ``True``, then m and n must be equal.
+ mode (str): "iou" (intersection over union) or "iof" (intersection over
+ foreground).
+
+ Returns:
+ ious(Tensor): shape (m, n) if aligned == False else shape (m, 1)
+
+ Example:
+ >>> bboxes1 = torch.FloatTensor([
+ >>> [0, 0, 10, 10],
+ >>> [10, 10, 20, 20],
+ >>> [32, 32, 38, 42],
+ >>> ])
+ >>> bboxes2 = torch.FloatTensor([
+ >>> [0, 0, 10, 20],
+ >>> [0, 10, 10, 19],
+ >>> [10, 10, 20, 20],
+ >>> ])
+ >>> bbox_overlaps(bboxes1, bboxes2)
+ tensor([[0.5000, 0.0000, 0.0000],
+ [0.0000, 0.0000, 1.0000],
+ [0.0000, 0.0000, 0.0000]])
+
+ Example:
+ >>> empty = torch.FloatTensor([])
+ >>> nonempty = torch.FloatTensor([
+ >>> [0, 0, 10, 9],
+ >>> ])
+ >>> assert tuple(bbox_overlaps(empty, nonempty).shape) == (0, 1)
+ >>> assert tuple(bbox_overlaps(nonempty, empty).shape) == (1, 0)
+ >>> assert tuple(bbox_overlaps(empty, empty).shape) == (0, 0)
+ """
+
+ mode_dict = {'iou': 0, 'iof': 1}
+ assert mode in mode_dict.keys()
+ mode_flag = mode_dict[mode]
+ # Either the boxes are empty or the length of boxes' last dimension is 4
+ assert (bboxes1.size(-1) == 4 or bboxes1.size(0) == 0)
+ assert (bboxes2.size(-1) == 4 or bboxes2.size(0) == 0)
+ assert offset == 1 or offset == 0
+
+ rows = bboxes1.size(0)
+ cols = bboxes2.size(0)
+ if aligned:
+ assert rows == cols
+
+ if rows * cols == 0:
+ return bboxes1.new(rows, 1) if aligned else bboxes1.new(rows, cols)
+
+ if aligned:
+ ious = bboxes1.new_zeros(rows)
+ else:
+ ious = bboxes1.new_zeros((rows, cols))
+ ext_module.bbox_overlaps(
+ bboxes1, bboxes2, ious, mode=mode_flag, aligned=aligned, offset=offset)
+ return ious
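+
+
+# Illustrative usage sketch (not part of the original mmcv source). It assumes
+# the compiled `_ext` extension is available. With ``aligned=True`` both inputs
+# must contain the same number of boxes and one IoU is returned per pair
+# (a 1-D tensor in the non-empty case, see the allocation above).
+if __name__ == '__main__':
+    import torch
+
+    bboxes1 = torch.FloatTensor([[0, 0, 10, 10], [10, 10, 20, 20]])
+    bboxes2 = torch.FloatTensor([[0, 0, 10, 20], [10, 10, 20, 20]])
+    ious = bbox_overlaps(bboxes1, bboxes2)                        # shape (2, 2)
+    ious_aligned = bbox_overlaps(bboxes1, bboxes2, aligned=True)  # shape (2,)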
diff --git a/annotator/uniformer/mmcv/ops/border_align.py b/annotator/uniformer/mmcv/ops/border_align.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff305be328e9b0a15e1bbb5e6b41beb940f55c81
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/border_align.py
@@ -0,0 +1,109 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# modified from
+# https://github.com/Megvii-BaseDetection/cvpods/blob/master/cvpods/layers/border_align.py
+
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+ '_ext', ['border_align_forward', 'border_align_backward'])
+
+
+class BorderAlignFunction(Function):
+
+ @staticmethod
+ def symbolic(g, input, boxes, pool_size):
+ return g.op(
+ 'mmcv::MMCVBorderAlign', input, boxes, pool_size_i=pool_size)
+
+ @staticmethod
+ def forward(ctx, input, boxes, pool_size):
+ ctx.pool_size = pool_size
+ ctx.input_shape = input.size()
+
+ assert boxes.ndim == 3, 'boxes must be with shape [B, H*W, 4]'
+ assert boxes.size(2) == 4, \
+ 'the last dimension of boxes must be (x1, y1, x2, y2)'
+ assert input.size(1) % 4 == 0, \
+ 'the channel for input feature must be divisible by factor 4'
+
+ # [B, C//4, H*W, 4]
+ output_shape = (input.size(0), input.size(1) // 4, boxes.size(1), 4)
+ output = input.new_zeros(output_shape)
+ # `argmax_idx` only used for backward
+ argmax_idx = input.new_zeros(output_shape).to(torch.int)
+
+ ext_module.border_align_forward(
+ input, boxes, output, argmax_idx, pool_size=ctx.pool_size)
+
+ ctx.save_for_backward(boxes, argmax_idx)
+ return output
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, grad_output):
+ boxes, argmax_idx = ctx.saved_tensors
+ grad_input = grad_output.new_zeros(ctx.input_shape)
+ # complex head architecture may cause grad_output uncontiguous
+ grad_output = grad_output.contiguous()
+ ext_module.border_align_backward(
+ grad_output,
+ boxes,
+ argmax_idx,
+ grad_input,
+ pool_size=ctx.pool_size)
+ return grad_input, None, None
+
+
+border_align = BorderAlignFunction.apply
+
+
+class BorderAlign(nn.Module):
+ r"""Border align pooling layer.
+
+ Applies border_align over the input feature based on predicted bboxes.
+ The details were described in the paper
+ `BorderDet: Border Feature for Dense Object Detection
+ <https://arxiv.org/abs/2007.11056>`_.
+
+ For each border line (e.g. top, left, bottom or right) of each box,
+ border_align does the following:
+ 1. uniformly samples `pool_size`+1 positions on this line, including \
+ the start and end points.
+ 2. the corresponding features at these positions are computed by \
+ bilinear interpolation.
+ 3. max pooling over all the `pool_size`+1 positions is used to \
+ compute the pooled feature.
+
+ Args:
+ pool_size (int): number of positions sampled over the boxes' borders
+ (e.g. top, bottom, left, right).
+
+ """
+
+ def __init__(self, pool_size):
+ super(BorderAlign, self).__init__()
+ self.pool_size = pool_size
+
+ def forward(self, input, boxes):
+ """
+ Args:
+ input: Features with shape [N,4C,H,W]. Channels ranged in [0,C),
+ [C,2C), [2C,3C), [3C,4C) represent the top, left, bottom,
+ right features respectively.
+ boxes: Boxes with shape [N,H*W,4]. Coordinate format (x1,y1,x2,y2).
+
+ Returns:
+ Tensor: Pooled features with shape [N,C,H*W,4]. The order is
+ (top,left,bottom,right) for the last dimension.
+ """
+ return border_align(input, boxes, self.pool_size)
+
+ def __repr__(self):
+ s = self.__class__.__name__
+ s += f'(pool_size={self.pool_size})'
+ return s
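+
+
+# Illustrative usage sketch (not part of the original mmcv source). It assumes
+# the compiled `_ext` CUDA extension and a CUDA device are available. The
+# feature channels are split into four groups (top/left/bottom/right) and one
+# box is provided per spatial location, as required by the forward docstring.
+if __name__ == '__main__':
+    N, C, H, W = 1, 4, 8, 8
+    feats = torch.rand(N, 4 * C, H, W).cuda()
+    boxes = torch.tensor([0., 0., 4., 4.]).repeat(N, H * W, 1).cuda()
+    pooled = BorderAlign(pool_size=10)(feats, boxes)
+    assert pooled.shape == (N, C, H * W, 4)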
diff --git a/annotator/uniformer/mmcv/ops/box_iou_rotated.py b/annotator/uniformer/mmcv/ops/box_iou_rotated.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d78015e9c2a9e7a52859b4e18f84a9aa63481a0
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/box_iou_rotated.py
@@ -0,0 +1,45 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', ['box_iou_rotated'])
+
+
+def box_iou_rotated(bboxes1, bboxes2, mode='iou', aligned=False):
+ """Return intersection-over-union (Jaccard index) of boxes.
+
+ Both sets of boxes are expected to be in
+ (x_center, y_center, width, height, angle) format.
+
+ If ``aligned`` is ``False``, then calculate the ious between each bbox
+ of bboxes1 and bboxes2, otherwise the ious between each aligned pair of
+ bboxes1 and bboxes2.
+
+ Args:
+ bboxes1 (Tensor): rotated bboxes 1. \
+ It has shape (N, 5), indicating (x, y, w, h, theta) for each row.
+ Note that theta is in radians.
+ bboxes2 (Tensor): rotated bboxes 2. \
+ It has shape (M, 5), indicating (x, y, w, h, theta) for each row.
+ Note that theta is in radians.
+ mode (str): "iou" (intersection over union) or "iof" (intersection over
+ foreground).
+
+ Returns:
+ ious(Tensor): shape (N, M) if aligned == False else shape (N,)
+ """
+ assert mode in ['iou', 'iof']
+ mode_dict = {'iou': 0, 'iof': 1}
+ mode_flag = mode_dict[mode]
+ rows = bboxes1.size(0)
+ cols = bboxes2.size(0)
+ if aligned:
+ ious = bboxes1.new_zeros(rows)
+ else:
+ ious = bboxes1.new_zeros((rows * cols))
+ bboxes1 = bboxes1.contiguous()
+ bboxes2 = bboxes2.contiguous()
+ ext_module.box_iou_rotated(
+ bboxes1, bboxes2, ious, mode_flag=mode_flag, aligned=aligned)
+ if not aligned:
+ ious = ious.view(rows, cols)
+ return ious
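+
+
+# Illustrative usage sketch (not part of the original mmcv source). It assumes
+# the compiled `_ext` extension is available. Boxes are given as
+# (x_center, y_center, width, height, theta) with theta in radians.
+if __name__ == '__main__':
+    import torch
+
+    b1 = torch.tensor([[10., 10., 4., 2., 0.0]])
+    b2 = torch.tensor([[10., 10., 4., 2., 0.0],
+                       [20., 20., 4., 2., 0.5]])
+    ious = box_iou_rotated(b1, b2)  # shape (1, 2); identical boxes -> ~1.0
+    assert ious.shape == (1, 2)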
diff --git a/annotator/uniformer/mmcv/ops/carafe.py b/annotator/uniformer/mmcv/ops/carafe.py
new file mode 100644
index 0000000000000000000000000000000000000000..5154cb3abfccfbbe0a1b2daa67018dbf80aaf6d2
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/carafe.py
@@ -0,0 +1,287 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Function
+from torch.nn.modules.module import Module
+
+from ..cnn import UPSAMPLE_LAYERS, normal_init, xavier_init
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', [
+ 'carafe_naive_forward', 'carafe_naive_backward', 'carafe_forward',
+ 'carafe_backward'
+])
+
+
+class CARAFENaiveFunction(Function):
+
+ @staticmethod
+ def symbolic(g, features, masks, kernel_size, group_size, scale_factor):
+ return g.op(
+ 'mmcv::MMCVCARAFENaive',
+ features,
+ masks,
+ kernel_size_i=kernel_size,
+ group_size_i=group_size,
+ scale_factor_f=scale_factor)
+
+ @staticmethod
+ def forward(ctx, features, masks, kernel_size, group_size, scale_factor):
+ assert scale_factor >= 1
+ assert masks.size(1) == kernel_size * kernel_size * group_size
+ assert masks.size(-1) == features.size(-1) * scale_factor
+ assert masks.size(-2) == features.size(-2) * scale_factor
+ assert features.size(1) % group_size == 0
+ assert (kernel_size - 1) % 2 == 0 and kernel_size >= 1
+ ctx.kernel_size = kernel_size
+ ctx.group_size = group_size
+ ctx.scale_factor = scale_factor
+ ctx.feature_size = features.size()
+ ctx.mask_size = masks.size()
+
+ n, c, h, w = features.size()
+ output = features.new_zeros((n, c, h * scale_factor, w * scale_factor))
+ ext_module.carafe_naive_forward(
+ features,
+ masks,
+ output,
+ kernel_size=kernel_size,
+ group_size=group_size,
+ scale_factor=scale_factor)
+
+ if features.requires_grad or masks.requires_grad:
+ ctx.save_for_backward(features, masks)
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ assert grad_output.is_cuda
+
+ features, masks = ctx.saved_tensors
+ kernel_size = ctx.kernel_size
+ group_size = ctx.group_size
+ scale_factor = ctx.scale_factor
+
+ grad_input = torch.zeros_like(features)
+ grad_masks = torch.zeros_like(masks)
+ ext_module.carafe_naive_backward(
+ grad_output.contiguous(),
+ features,
+ masks,
+ grad_input,
+ grad_masks,
+ kernel_size=kernel_size,
+ group_size=group_size,
+ scale_factor=scale_factor)
+
+ return grad_input, grad_masks, None, None, None
+
+
+carafe_naive = CARAFENaiveFunction.apply
+
+
+class CARAFENaive(Module):
+
+ def __init__(self, kernel_size, group_size, scale_factor):
+ super(CARAFENaive, self).__init__()
+
+ assert isinstance(kernel_size, int) and isinstance(
+ group_size, int) and isinstance(scale_factor, int)
+ self.kernel_size = kernel_size
+ self.group_size = group_size
+ self.scale_factor = scale_factor
+
+ def forward(self, features, masks):
+ return carafe_naive(features, masks, self.kernel_size, self.group_size,
+ self.scale_factor)
+
+
+class CARAFEFunction(Function):
+
+ @staticmethod
+ def symbolic(g, features, masks, kernel_size, group_size, scale_factor):
+ return g.op(
+ 'mmcv::MMCVCARAFE',
+ features,
+ masks,
+ kernel_size_i=kernel_size,
+ group_size_i=group_size,
+ scale_factor_f=scale_factor)
+
+ @staticmethod
+ def forward(ctx, features, masks, kernel_size, group_size, scale_factor):
+ assert scale_factor >= 1
+ assert masks.size(1) == kernel_size * kernel_size * group_size
+ assert masks.size(-1) == features.size(-1) * scale_factor
+ assert masks.size(-2) == features.size(-2) * scale_factor
+ assert features.size(1) % group_size == 0
+ assert (kernel_size - 1) % 2 == 0 and kernel_size >= 1
+ ctx.kernel_size = kernel_size
+ ctx.group_size = group_size
+ ctx.scale_factor = scale_factor
+ ctx.feature_size = features.size()
+ ctx.mask_size = masks.size()
+
+ n, c, h, w = features.size()
+ output = features.new_zeros((n, c, h * scale_factor, w * scale_factor))
+ routput = features.new_zeros(output.size(), requires_grad=False)
+ rfeatures = features.new_zeros(features.size(), requires_grad=False)
+ rmasks = masks.new_zeros(masks.size(), requires_grad=False)
+ ext_module.carafe_forward(
+ features,
+ masks,
+ rfeatures,
+ routput,
+ rmasks,
+ output,
+ kernel_size=kernel_size,
+ group_size=group_size,
+ scale_factor=scale_factor)
+
+ if features.requires_grad or masks.requires_grad:
+ ctx.save_for_backward(features, masks, rfeatures)
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ assert grad_output.is_cuda
+
+ features, masks, rfeatures = ctx.saved_tensors
+ kernel_size = ctx.kernel_size
+ group_size = ctx.group_size
+ scale_factor = ctx.scale_factor
+
+ rgrad_output = torch.zeros_like(grad_output, requires_grad=False)
+ rgrad_input_hs = torch.zeros_like(grad_output, requires_grad=False)
+ rgrad_input = torch.zeros_like(features, requires_grad=False)
+ rgrad_masks = torch.zeros_like(masks, requires_grad=False)
+ grad_input = torch.zeros_like(features, requires_grad=False)
+ grad_masks = torch.zeros_like(masks, requires_grad=False)
+ ext_module.carafe_backward(
+ grad_output.contiguous(),
+ rfeatures,
+ masks,
+ rgrad_output,
+ rgrad_input_hs,
+ rgrad_input,
+ rgrad_masks,
+ grad_input,
+ grad_masks,
+ kernel_size=kernel_size,
+ group_size=group_size,
+ scale_factor=scale_factor)
+ return grad_input, grad_masks, None, None, None
+
+
+carafe = CARAFEFunction.apply
+
+
+class CARAFE(Module):
+ """ CARAFE: Content-Aware ReAssembly of FEatures
+
+ Please refer to https://arxiv.org/abs/1905.02188 for more details.
+
+ Args:
+ kernel_size (int): reassemble kernel size
+ group_size (int): reassemble group size
+ scale_factor (int): upsample ratio
+
+ Returns:
+ upsampled feature map
+ """
+
+ def __init__(self, kernel_size, group_size, scale_factor):
+ super(CARAFE, self).__init__()
+
+ assert isinstance(kernel_size, int) and isinstance(
+ group_size, int) and isinstance(scale_factor, int)
+ self.kernel_size = kernel_size
+ self.group_size = group_size
+ self.scale_factor = scale_factor
+
+ def forward(self, features, masks):
+ return carafe(features, masks, self.kernel_size, self.group_size,
+ self.scale_factor)
+
+
+@UPSAMPLE_LAYERS.register_module(name='carafe')
+class CARAFEPack(nn.Module):
+ """A unified package of CARAFE upsampler that contains: 1) channel
+ compressor 2) content encoder 3) CARAFE op.
+
+ Official implementation of ICCV 2019 paper
+ CARAFE: Content-Aware ReAssembly of FEatures
+ Please refer to https://arxiv.org/abs/1905.02188 for more details.
+
+ Args:
+ channels (int): input feature channels
+ scale_factor (int): upsample ratio
+ up_kernel (int): kernel size of CARAFE op
+ up_group (int): group size of CARAFE op
+ encoder_kernel (int): kernel size of content encoder
+ encoder_dilation (int): dilation of content encoder
+ compressed_channels (int): output channels of the channel compressor
+
+ Returns:
+ upsampled feature map
+ """
+
+ def __init__(self,
+ channels,
+ scale_factor,
+ up_kernel=5,
+ up_group=1,
+ encoder_kernel=3,
+ encoder_dilation=1,
+ compressed_channels=64):
+ super(CARAFEPack, self).__init__()
+ self.channels = channels
+ self.scale_factor = scale_factor
+ self.up_kernel = up_kernel
+ self.up_group = up_group
+ self.encoder_kernel = encoder_kernel
+ self.encoder_dilation = encoder_dilation
+ self.compressed_channels = compressed_channels
+ self.channel_compressor = nn.Conv2d(channels, self.compressed_channels,
+ 1)
+ self.content_encoder = nn.Conv2d(
+ self.compressed_channels,
+ self.up_kernel * self.up_kernel * self.up_group *
+ self.scale_factor * self.scale_factor,
+ self.encoder_kernel,
+ padding=int((self.encoder_kernel - 1) * self.encoder_dilation / 2),
+ dilation=self.encoder_dilation,
+ groups=1)
+ self.init_weights()
+
+ def init_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ xavier_init(m, distribution='uniform')
+ normal_init(self.content_encoder, std=0.001)
+
+ def kernel_normalizer(self, mask):
+ mask = F.pixel_shuffle(mask, self.scale_factor)
+ n, mask_c, h, w = mask.size()
+ # use float division explicitly,
+ # to void inconsistency while exporting to onnx
+ mask_channel = int(mask_c / float(self.up_kernel**2))
+ mask = mask.view(n, mask_channel, -1, h, w)
+
+ mask = F.softmax(mask, dim=2, dtype=mask.dtype)
+ mask = mask.view(n, mask_c, h, w).contiguous()
+
+ return mask
+
+ def feature_reassemble(self, x, mask):
+ x = carafe(x, mask, self.up_kernel, self.up_group, self.scale_factor)
+ return x
+
+ def forward(self, x):
+ compressed_x = self.channel_compressor(x)
+ mask = self.content_encoder(compressed_x)
+ mask = self.kernel_normalizer(mask)
+
+ x = self.feature_reassemble(x, mask)
+ return x
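+
+
+# Illustrative usage sketch (not part of the original mmcv source). It assumes
+# the compiled `_ext` CUDA extension and a CUDA device are available.
+# `CARAFEPack` predicts and normalizes its own reassembly kernels, so it can
+# stand in for e.g. nearest/bilinear upsampling in an FPN-style neck.
+if __name__ == '__main__':
+    x = torch.rand(2, 64, 16, 16).cuda()
+    upsampler = CARAFEPack(channels=64, scale_factor=2).cuda()
+    y = upsampler(x)
+    assert y.shape == (2, 64, 32, 32)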
diff --git a/annotator/uniformer/mmcv/ops/cc_attention.py b/annotator/uniformer/mmcv/ops/cc_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..9207aa95e6730bd9b3362dee612059a5f0ce1c5e
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/cc_attention.py
@@ -0,0 +1,83 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from annotator.uniformer.mmcv.cnn import PLUGIN_LAYERS, Scale
+
+
+def NEG_INF_DIAG(n, device):
+ """Returns a diagonal matrix of size [n, n].
+
+ The diagonal entries are all "-inf". This avoids counting the
+ overlapped element in the criss-cross attention twice.
+ """
+ return torch.diag(torch.tensor(float('-inf')).to(device).repeat(n), 0)
+
+
+@PLUGIN_LAYERS.register_module()
+class CrissCrossAttention(nn.Module):
+ """Criss-Cross Attention Module.
+
+ .. note::
+ Before v1.3.13, we use a CUDA op. Since v1.3.13, we switch
+ to a pure PyTorch and equivalent implementation. For more
+ details, please refer to https://github.com/open-mmlab/mmcv/pull/1201.
+
+ Speed comparison for one forward pass
+
+ - Input size: [2,512,97,97]
+ - Device: 1 NVIDIA GeForce RTX 2080 Ti
+
+ +-----------------------+---------------+------------+---------------+
+ | |PyTorch version|CUDA version|Relative speed |
+ +=======================+===============+============+===============+
+ |with torch.no_grad() |0.00554402 s |0.0299619 s |5.4x |
+ +-----------------------+---------------+------------+---------------+
+ |without torch.no_grad()|0.00562803 s |0.0301349 s |5.4x |
+ +-----------------------+---------------+------------+---------------+
+
+ Args:
+ in_channels (int): Channels of the input feature map.
+ """
+
+ def __init__(self, in_channels):
+ super().__init__()
+ self.query_conv = nn.Conv2d(in_channels, in_channels // 8, 1)
+ self.key_conv = nn.Conv2d(in_channels, in_channels // 8, 1)
+ self.value_conv = nn.Conv2d(in_channels, in_channels, 1)
+ self.gamma = Scale(0.)
+ self.in_channels = in_channels
+
+ def forward(self, x):
+ """forward function of Criss-Cross Attention.
+
+ Args:
+ x (Tensor): Input feature. \
+ shape (batch_size, in_channels, height, width)
+ Returns:
+ Tensor: Output of the layer, with shape of \
+ (batch_size, in_channels, height, width)
+ """
+ B, C, H, W = x.size()
+ query = self.query_conv(x)
+ key = self.key_conv(x)
+ value = self.value_conv(x)
+ energy_H = torch.einsum('bchw,bciw->bwhi', query, key) + NEG_INF_DIAG(
+ H, query.device)
+ energy_H = energy_H.transpose(1, 2)
+ energy_W = torch.einsum('bchw,bchj->bhwj', query, key)
+ attn = F.softmax(
+ torch.cat([energy_H, energy_W], dim=-1), dim=-1) # [B,H,W,(H+W)]
+ out = torch.einsum('bciw,bhwi->bchw', value, attn[..., :H])
+ out += torch.einsum('bchj,bhwj->bchw', value, attn[..., H:])
+
+ out = self.gamma(out) + x
+ out = out.contiguous()
+
+ return out
+
+ def __repr__(self):
+ s = self.__class__.__name__
+ s += f'(in_channels={self.in_channels})'
+ return s
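+
+
+# Illustrative usage sketch (not part of the original mmcv source). Since the
+# module is implemented in pure PyTorch (see the note above), it also runs on
+# CPU without the compiled extension.
+if __name__ == '__main__':
+    x = torch.rand(2, 32, 17, 19)
+    cca = CrissCrossAttention(in_channels=32)
+    out = cca(x)
+    assert out.shape == x.shape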
diff --git a/annotator/uniformer/mmcv/ops/contour_expand.py b/annotator/uniformer/mmcv/ops/contour_expand.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea1111e1768b5f27e118bf7dbc0d9c70a7afd6d7
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/contour_expand.py
@@ -0,0 +1,49 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import torch
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', ['contour_expand'])
+
+
+def contour_expand(kernel_mask, internal_kernel_label, min_kernel_area,
+ kernel_num):
+ """Expand kernel contours so that foreground pixels are assigned into
+ instances.
+
+ Arguments:
+ kernel_mask (np.array or Tensor): The instance kernel mask with
+ size hxw.
+ internal_kernel_label (np.array or Tensor): The instance internal
+ kernel label with size hxw.
+ min_kernel_area (int): The minimum kernel area.
+ kernel_num (int): The instance kernel number.
+
+ Returns:
+ label (list): The instance index map with size hxw.
+ """
+ assert isinstance(kernel_mask, (torch.Tensor, np.ndarray))
+ assert isinstance(internal_kernel_label, (torch.Tensor, np.ndarray))
+ assert isinstance(min_kernel_area, int)
+ assert isinstance(kernel_num, int)
+
+ if isinstance(kernel_mask, np.ndarray):
+ kernel_mask = torch.from_numpy(kernel_mask)
+ if isinstance(internal_kernel_label, np.ndarray):
+ internal_kernel_label = torch.from_numpy(internal_kernel_label)
+
+ if torch.__version__ == 'parrots':
+ if kernel_mask.shape[0] == 0 or internal_kernel_label.shape[0] == 0:
+ label = []
+ else:
+ label = ext_module.contour_expand(
+ kernel_mask,
+ internal_kernel_label,
+ min_kernel_area=min_kernel_area,
+ kernel_num=kernel_num)
+ label = label.tolist()
+ else:
+ label = ext_module.contour_expand(kernel_mask, internal_kernel_label,
+ min_kernel_area, kernel_num)
+ return label
diff --git a/annotator/uniformer/mmcv/ops/corner_pool.py b/annotator/uniformer/mmcv/ops/corner_pool.py
new file mode 100644
index 0000000000000000000000000000000000000000..a33d798b43d405e4c86bee4cd6389be21ca9c637
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/corner_pool.py
@@ -0,0 +1,161 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch import nn
+from torch.autograd import Function
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', [
+ 'top_pool_forward', 'top_pool_backward', 'bottom_pool_forward',
+ 'bottom_pool_backward', 'left_pool_forward', 'left_pool_backward',
+ 'right_pool_forward', 'right_pool_backward'
+])
+
+_mode_dict = {'top': 0, 'bottom': 1, 'left': 2, 'right': 3}
+
+
+class TopPoolFunction(Function):
+
+ @staticmethod
+ def symbolic(g, input):
+ output = g.op(
+ 'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['top']))
+ return output
+
+ @staticmethod
+ def forward(ctx, input):
+ output = ext_module.top_pool_forward(input)
+ ctx.save_for_backward(input)
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input, = ctx.saved_tensors
+ output = ext_module.top_pool_backward(input, grad_output)
+ return output
+
+
+class BottomPoolFunction(Function):
+
+ @staticmethod
+ def symbolic(g, input):
+ output = g.op(
+ 'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['bottom']))
+ return output
+
+ @staticmethod
+ def forward(ctx, input):
+ output = ext_module.bottom_pool_forward(input)
+ ctx.save_for_backward(input)
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input, = ctx.saved_tensors
+ output = ext_module.bottom_pool_backward(input, grad_output)
+ return output
+
+
+class LeftPoolFunction(Function):
+
+ @staticmethod
+ def symbolic(g, input):
+ output = g.op(
+ 'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['left']))
+ return output
+
+ @staticmethod
+ def forward(ctx, input):
+ output = ext_module.left_pool_forward(input)
+ ctx.save_for_backward(input)
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input, = ctx.saved_tensors
+ output = ext_module.left_pool_backward(input, grad_output)
+ return output
+
+
+class RightPoolFunction(Function):
+
+ @staticmethod
+ def symbolic(g, input):
+ output = g.op(
+ 'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['right']))
+ return output
+
+ @staticmethod
+ def forward(ctx, input):
+ output = ext_module.right_pool_forward(input)
+ ctx.save_for_backward(input)
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input, = ctx.saved_tensors
+ output = ext_module.right_pool_backward(input, grad_output)
+ return output
+
+
+class CornerPool(nn.Module):
+ """Corner Pooling.
+
+ Corner Pooling is a new type of pooling layer that helps a
+ convolutional network better localize corners of bounding boxes.
+
+ Please refer to https://arxiv.org/abs/1808.01244 for more details.
+ Code is modified from https://github.com/princeton-vl/CornerNet-Lite.
+
+ Args:
+ mode(str): Pooling orientation for the pooling layer
+
+ - 'bottom': Bottom Pooling
+ - 'left': Left Pooling
+ - 'right': Right Pooling
+ - 'top': Top Pooling
+
+ Returns:
+ Feature map after pooling.
+ """
+
+ pool_functions = {
+ 'bottom': BottomPoolFunction,
+ 'left': LeftPoolFunction,
+ 'right': RightPoolFunction,
+ 'top': TopPoolFunction,
+ }
+
+ cummax_dim_flip = {
+ 'bottom': (2, False),
+ 'left': (3, True),
+ 'right': (3, False),
+ 'top': (2, True),
+ }
+
+ def __init__(self, mode):
+ super(CornerPool, self).__init__()
+ assert mode in self.pool_functions
+ self.mode = mode
+ self.corner_pool = self.pool_functions[mode]
+
+ def forward(self, x):
+ if torch.__version__ != 'parrots' and torch.__version__ >= '1.5.0':
+ if torch.onnx.is_in_onnx_export():
+ assert torch.__version__ >= '1.7.0', \
+ 'When `cummax` serves as an intermediate component whose '\
+ 'outputs is used as inputs for another modules, it\'s '\
+ 'expected that pytorch version must be >= 1.7.0, '\
+ 'otherwise Error appears like: `RuntimeError: tuple '\
+ 'appears in op that does not forward tuples, unsupported '\
+ 'kind: prim::PythonOp`.'
+
+ dim, flip = self.cummax_dim_flip[self.mode]
+ if flip:
+ x = x.flip(dim)
+ pool_tensor, _ = torch.cummax(x, dim=dim)
+ if flip:
+ pool_tensor = pool_tensor.flip(dim)
+ return pool_tensor
+ else:
+ return self.corner_pool.apply(x)
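+
+
+# Illustrative usage sketch (not part of the original mmcv source). With
+# torch >= 1.5.0 the forward pass above uses `torch.cummax`, so this also runs
+# on CPU without the compiled extension.
+if __name__ == '__main__':
+    x = torch.rand(2, 8, 5, 7)
+    top = CornerPool('top')(x)
+    left = CornerPool('left')(x)
+    assert top.shape == x.shape and left.shape == x.shape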
diff --git a/annotator/uniformer/mmcv/ops/correlation.py b/annotator/uniformer/mmcv/ops/correlation.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d0b79c301b29915dfaf4d2b1846c59be73127d3
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/correlation.py
@@ -0,0 +1,196 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch import Tensor, nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.nn.modules.utils import _pair
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+ '_ext', ['correlation_forward', 'correlation_backward'])
+
+
+class CorrelationFunction(Function):
+
+ @staticmethod
+ def forward(ctx,
+ input1,
+ input2,
+ kernel_size=1,
+ max_displacement=1,
+ stride=1,
+ padding=1,
+ dilation=1,
+ dilation_patch=1):
+
+ ctx.save_for_backward(input1, input2)
+
+ kH, kW = ctx.kernel_size = _pair(kernel_size)
+ patch_size = max_displacement * 2 + 1
+ ctx.patch_size = patch_size
+ dH, dW = ctx.stride = _pair(stride)
+ padH, padW = ctx.padding = _pair(padding)
+ dilationH, dilationW = ctx.dilation = _pair(dilation)
+ dilation_patchH, dilation_patchW = ctx.dilation_patch = _pair(
+ dilation_patch)
+
+ output_size = CorrelationFunction._output_size(ctx, input1)
+
+ output = input1.new_zeros(output_size)
+
+ ext_module.correlation_forward(
+ input1,
+ input2,
+ output,
+ kH=kH,
+ kW=kW,
+ patchH=patch_size,
+ patchW=patch_size,
+ padH=padH,
+ padW=padW,
+ dilationH=dilationH,
+ dilationW=dilationW,
+ dilation_patchH=dilation_patchH,
+ dilation_patchW=dilation_patchW,
+ dH=dH,
+ dW=dW)
+
+ return output
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, grad_output):
+ input1, input2 = ctx.saved_tensors
+
+ kH, kW = ctx.kernel_size
+ patch_size = ctx.patch_size
+ padH, padW = ctx.padding
+ dilationH, dilationW = ctx.dilation
+ dilation_patchH, dilation_patchW = ctx.dilation_patch
+ dH, dW = ctx.stride
+ grad_input1 = torch.zeros_like(input1)
+ grad_input2 = torch.zeros_like(input2)
+
+ ext_module.correlation_backward(
+ grad_output,
+ input1,
+ input2,
+ grad_input1,
+ grad_input2,
+ kH=kH,
+ kW=kW,
+ patchH=patch_size,
+ patchW=patch_size,
+ padH=padH,
+ padW=padW,
+ dilationH=dilationH,
+ dilationW=dilationW,
+ dilation_patchH=dilation_patchH,
+ dilation_patchW=dilation_patchW,
+ dH=dH,
+ dW=dW)
+ return grad_input1, grad_input2, None, None, None, None, None, None
+
+ @staticmethod
+ def _output_size(ctx, input1):
+ iH, iW = input1.size(2), input1.size(3)
+ batch_size = input1.size(0)
+ kH, kW = ctx.kernel_size
+ patch_size = ctx.patch_size
+ dH, dW = ctx.stride
+ padH, padW = ctx.padding
+ dilationH, dilationW = ctx.dilation
+ dilatedKH = (kH - 1) * dilationH + 1
+ dilatedKW = (kW - 1) * dilationW + 1
+
+ oH = int((iH + 2 * padH - dilatedKH) / dH + 1)
+ oW = int((iW + 2 * padW - dilatedKW) / dW + 1)
+
+ output_size = (batch_size, patch_size, patch_size, oH, oW)
+ return output_size
+
+
+class Correlation(nn.Module):
+ r"""Correlation operator
+
+ This correlation operator works for optical flow correlation computation.
+
+ There are two batched tensors with shape :math:`(N, C, H, W)`,
+ and the correlation output's shape is :math:`(N, max\_displacement \times
+ 2 + 1, max\_displacement \times 2 + 1, H_{out}, W_{out})`
+
+ where
+
+ .. math::
+ H_{out} = \left\lfloor\frac{H_{in} + 2 \times padding -
+ dilation \times (kernel\_size - 1) - 1}
+ {stride} + 1\right\rfloor
+
+ .. math::
+ W_{out} = \left\lfloor\frac{W_{in} + 2 \times padding - dilation
+ \times (kernel\_size - 1) - 1}
+ {stride} + 1\right\rfloor
+
+ the correlation item :math:`(N_i, dy, dx)` is formed by taking the sliding
+ window convolution between input1 and shifted input2,
+
+ .. math::
+ Corr(N_i, dx, dy) =
+ \sum_{c=0}^{C-1}
+ input1(N_i, c) \star
+ \mathcal{S}(input2(N_i, c), dy, dx)
+
+ where :math:`\star` is the valid 2d sliding window convolution operator,
+ and :math:`\mathcal{S}` means shifting the input features (auto-complete
+ zero marginal), and :math:`dx, dy` are shifting distance, :math:`dx, dy \in
+ [-max\_displacement \times dilation\_patch, max\_displacement \times
+ dilation\_patch]`.
+
+ Args:
+ kernel_size (int): The size of the sliding window, i.e. the local
+ neighborhood around the center points that is involved in the
+ correlation computation. Defaults to 1.
+ max_displacement (int): The radius for computing correlation volume,
+ but the actual working space can be dilated by dilation_patch.
+ Defaults to 1.
+ stride (int): The stride of the sliding blocks in the input spatial
+ dimensions. Defaults to 1.
+ padding (int): Zero padding added to all four sides of the input1.
+ Defaults to 0.
+ dilation (int): The spacing of the local neighborhood that will be
+ involved in the correlation. Defaults to 1.
+ dilation_patch (int): The spacing between positions at which the
+ correlation is computed. Defaults to 1.
+ """
+
+ def __init__(self,
+ kernel_size: int = 1,
+ max_displacement: int = 1,
+ stride: int = 1,
+ padding: int = 0,
+ dilation: int = 1,
+ dilation_patch: int = 1) -> None:
+ super().__init__()
+ self.kernel_size = kernel_size
+ self.max_displacement = max_displacement
+ self.stride = stride
+ self.padding = padding
+ self.dilation = dilation
+ self.dilation_patch = dilation_patch
+
+ def forward(self, input1: Tensor, input2: Tensor) -> Tensor:
+ return CorrelationFunction.apply(input1, input2, self.kernel_size,
+ self.max_displacement, self.stride,
+ self.padding, self.dilation,
+ self.dilation_patch)
+
+ def __repr__(self) -> str:
+ s = self.__class__.__name__
+ s += f'(kernel_size={self.kernel_size}, '
+ s += f'max_displacement={self.max_displacement}, '
+ s += f'stride={self.stride}, '
+ s += f'padding={self.padding}, '
+ s += f'dilation={self.dilation}, '
+ s += f'dilation_patch={self.dilation_patch})'
+ return s
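+
+
+# Illustrative usage sketch (not part of the original mmcv source). It assumes
+# the compiled `_ext` CUDA extension and a CUDA device are available. With the
+# defaults (kernel_size=1, stride=1, padding=0, dilation=1) the spatial size is
+# preserved and the patch dimension is 2 * max_displacement + 1.
+if __name__ == '__main__':
+    corr = Correlation(max_displacement=3)
+    input1 = torch.rand(2, 16, 24, 24).cuda()
+    input2 = torch.rand(2, 16, 24, 24).cuda()
+    out = corr(input1, input2)
+    assert out.shape == (2, 7, 7, 24, 24)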
diff --git a/annotator/uniformer/mmcv/ops/deform_conv.py b/annotator/uniformer/mmcv/ops/deform_conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3f8c75ee774823eea334e3b3732af6a18f55038
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/deform_conv.py
@@ -0,0 +1,405 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.nn.modules.utils import _pair, _single
+
+from annotator.uniformer.mmcv.utils import deprecated_api_warning
+from ..cnn import CONV_LAYERS
+from ..utils import ext_loader, print_log
+
+ext_module = ext_loader.load_ext('_ext', [
+ 'deform_conv_forward', 'deform_conv_backward_input',
+ 'deform_conv_backward_parameters'
+])
+
+
+class DeformConv2dFunction(Function):
+
+ @staticmethod
+ def symbolic(g,
+ input,
+ offset,
+ weight,
+ stride,
+ padding,
+ dilation,
+ groups,
+ deform_groups,
+ bias=False,
+ im2col_step=32):
+ return g.op(
+ 'mmcv::MMCVDeformConv2d',
+ input,
+ offset,
+ weight,
+ stride_i=stride,
+ padding_i=padding,
+ dilation_i=dilation,
+ groups_i=groups,
+ deform_groups_i=deform_groups,
+ bias_i=bias,
+ im2col_step_i=im2col_step)
+
+ @staticmethod
+ def forward(ctx,
+ input,
+ offset,
+ weight,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ deform_groups=1,
+ bias=False,
+ im2col_step=32):
+ if input is not None and input.dim() != 4:
+ raise ValueError(
+ f'Expected 4D tensor as input, got {input.dim()}D tensor \
+ instead.')
+ assert bias is False, 'Only support bias is False.'
+ ctx.stride = _pair(stride)
+ ctx.padding = _pair(padding)
+ ctx.dilation = _pair(dilation)
+ ctx.groups = groups
+ ctx.deform_groups = deform_groups
+ ctx.im2col_step = im2col_step
+
+ # When pytorch version >= 1.6.0, amp is adopted for fp16 mode;
+ # amp won't cast the type of model (float32), but "offset" is cast
+ # to float16 by nn.Conv2d automatically, leading to the type
+ # mismatch with input (when it is float32) or weight.
+ # The flag for whether to use fp16 or amp is the type of "offset",
+ # we cast weight and input to temporarily support fp16 and amp
+ # whatever the pytorch version is.
+ input = input.type_as(offset)
+ weight = weight.type_as(input)
+ ctx.save_for_backward(input, offset, weight)
+
+ output = input.new_empty(
+ DeformConv2dFunction._output_size(ctx, input, weight))
+
+ ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones
+
+ cur_im2col_step = min(ctx.im2col_step, input.size(0))
+ assert (input.size(0) %
+ cur_im2col_step) == 0, 'im2col step must divide batchsize'
+ ext_module.deform_conv_forward(
+ input,
+ weight,
+ offset,
+ output,
+ ctx.bufs_[0],
+ ctx.bufs_[1],
+ kW=weight.size(3),
+ kH=weight.size(2),
+ dW=ctx.stride[1],
+ dH=ctx.stride[0],
+ padW=ctx.padding[1],
+ padH=ctx.padding[0],
+ dilationW=ctx.dilation[1],
+ dilationH=ctx.dilation[0],
+ group=ctx.groups,
+ deformable_group=ctx.deform_groups,
+ im2col_step=cur_im2col_step)
+ return output
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, grad_output):
+ input, offset, weight = ctx.saved_tensors
+
+ grad_input = grad_offset = grad_weight = None
+
+ cur_im2col_step = min(ctx.im2col_step, input.size(0))
+ assert (input.size(0) % cur_im2col_step
+ ) == 0, 'batch size must be divisible by im2col_step'
+
+ grad_output = grad_output.contiguous()
+ if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
+ grad_input = torch.zeros_like(input)
+ grad_offset = torch.zeros_like(offset)
+ ext_module.deform_conv_backward_input(
+ input,
+ offset,
+ grad_output,
+ grad_input,
+ grad_offset,
+ weight,
+ ctx.bufs_[0],
+ kW=weight.size(3),
+ kH=weight.size(2),
+ dW=ctx.stride[1],
+ dH=ctx.stride[0],
+ padW=ctx.padding[1],
+ padH=ctx.padding[0],
+ dilationW=ctx.dilation[1],
+ dilationH=ctx.dilation[0],
+ group=ctx.groups,
+ deformable_group=ctx.deform_groups,
+ im2col_step=cur_im2col_step)
+
+ if ctx.needs_input_grad[2]:
+ grad_weight = torch.zeros_like(weight)
+ ext_module.deform_conv_backward_parameters(
+ input,
+ offset,
+ grad_output,
+ grad_weight,
+ ctx.bufs_[0],
+ ctx.bufs_[1],
+ kW=weight.size(3),
+ kH=weight.size(2),
+ dW=ctx.stride[1],
+ dH=ctx.stride[0],
+ padW=ctx.padding[1],
+ padH=ctx.padding[0],
+ dilationW=ctx.dilation[1],
+ dilationH=ctx.dilation[0],
+ group=ctx.groups,
+ deformable_group=ctx.deform_groups,
+ scale=1,
+ im2col_step=cur_im2col_step)
+
+ return grad_input, grad_offset, grad_weight, \
+ None, None, None, None, None, None, None
+
+ @staticmethod
+ def _output_size(ctx, input, weight):
+ channels = weight.size(0)
+ output_size = (input.size(0), channels)
+ for d in range(input.dim() - 2):
+ in_size = input.size(d + 2)
+ pad = ctx.padding[d]
+ kernel = ctx.dilation[d] * (weight.size(d + 2) - 1) + 1
+ stride_ = ctx.stride[d]
+ output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
+ if not all(map(lambda s: s > 0, output_size)):
+ raise ValueError(
+ 'convolution input is too small (output would be ' +
+ 'x'.join(map(str, output_size)) + ')')
+ return output_size
+
+
+deform_conv2d = DeformConv2dFunction.apply
+
+
+class DeformConv2d(nn.Module):
+ r"""Deformable 2D convolution.
+
+ Applies a deformable 2D convolution over an input signal composed of
+ several input planes. DeformConv2d was described in the paper
+ `Deformable Convolutional Networks
+ <https://arxiv.org/abs/1703.06211>`_.
+
+ Note:
+ The argument ``im2col_step`` was added in version 1.3.17; it denotes the
+ number of samples processed by ``im2col_cuda_kernel`` per call.
+ It enables users to define ``batch_size`` and ``im2col_step`` more
+ flexibly and solves `issue mmcv#1440
+ <https://github.com/open-mmlab/mmcv/issues/1440>`_.
+
+ Args:
+ in_channels (int): Number of channels in the input image.
+ out_channels (int): Number of channels produced by the convolution.
+ kernel_size(int, tuple): Size of the convolving kernel.
+ stride(int, tuple): Stride of the convolution. Default: 1.
+ padding (int or tuple): Zero-padding added to both sides of the input.
+ Default: 0.
+ dilation (int or tuple): Spacing between kernel elements. Default: 1.
+ groups (int): Number of blocked connections from input channels to
+ output channels. Default: 1.
+ deform_groups (int): Number of deformable group partitions.
+ bias (bool): If True, adds a learnable bias to the output.
+ Default: False.
+ im2col_step (int): Number of samples processed by im2col_cuda_kernel
+ per call. It will work when ``batch_size`` > ``im2col_step``, but
+ ``batch_size`` must be divisible by ``im2col_step``. Default: 32.
+ `New in version 1.3.17.`
+ """
+
+ @deprecated_api_warning({'deformable_groups': 'deform_groups'},
+ cls_name='DeformConv2d')
+ def __init__(self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int, ...]],
+ stride: Union[int, Tuple[int, ...]] = 1,
+ padding: Union[int, Tuple[int, ...]] = 0,
+ dilation: Union[int, Tuple[int, ...]] = 1,
+ groups: int = 1,
+ deform_groups: int = 1,
+ bias: bool = False,
+ im2col_step: int = 32) -> None:
+ super(DeformConv2d, self).__init__()
+
+ assert not bias, \
+ f'bias={bias} is not supported in DeformConv2d.'
+ assert in_channels % groups == 0, \
+ f'in_channels {in_channels} is not divisible by groups {groups}'
+ assert out_channels % groups == 0, \
+ f'out_channels {out_channels} is not divisible by groups \
+ {groups}'
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = _pair(kernel_size)
+ self.stride = _pair(stride)
+ self.padding = _pair(padding)
+ self.dilation = _pair(dilation)
+ self.groups = groups
+ self.deform_groups = deform_groups
+ self.im2col_step = im2col_step
+ # enable compatibility with nn.Conv2d
+ self.transposed = False
+ self.output_padding = _single(0)
+
+ # only weight, no bias
+ self.weight = nn.Parameter(
+ torch.Tensor(out_channels, in_channels // self.groups,
+ *self.kernel_size))
+
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ # switch the initialization of `self.weight` to the standard kaiming
+ # method described in `Delving deep into rectifiers: Surpassing
+ # human-level performance on ImageNet classification` - He, K. et al.
+ # (2015), using a uniform distribution
+ nn.init.kaiming_uniform_(self.weight, nonlinearity='relu')
+
+ def forward(self, x: Tensor, offset: Tensor) -> Tensor:
+ """Deformable Convolutional forward function.
+
+ Args:
+ x (Tensor): Input feature, shape (B, C_in, H_in, W_in)
+ offset (Tensor): Offset for deformable convolution, shape
+ (B, deform_groups*kernel_size[0]*kernel_size[1]*2,
+ H_out, W_out), H_out, W_out are equal to the output's.
+
+ An offset is like `[y0, x0, y1, x1, y2, x2, ..., y8, x8]`.
+ The spatial arrangement is like:
+
+ .. code:: text
+
+ (x0, y0) (x1, y1) (x2, y2)
+ (x3, y3) (x4, y4) (x5, y5)
+ (x6, y6) (x7, y7) (x8, y8)
+
+ Returns:
+ Tensor: Output of the layer.
+ """
+ # To fix an assert error in deform_conv_cuda.cpp:128
+ # input image is smaller than kernel
+ input_pad = (x.size(2) < self.kernel_size[0]) or (x.size(3) <
+ self.kernel_size[1])
+ if input_pad:
+ pad_h = max(self.kernel_size[0] - x.size(2), 0)
+ pad_w = max(self.kernel_size[1] - x.size(3), 0)
+ x = F.pad(x, (0, pad_w, 0, pad_h), 'constant', 0).contiguous()
+ offset = F.pad(offset, (0, pad_w, 0, pad_h), 'constant', 0)
+ offset = offset.contiguous()
+ out = deform_conv2d(x, offset, self.weight, self.stride, self.padding,
+ self.dilation, self.groups, self.deform_groups,
+ False, self.im2col_step)
+ if input_pad:
+ out = out[:, :, :out.size(2) - pad_h, :out.size(3) -
+ pad_w].contiguous()
+ return out
+
+ def __repr__(self):
+ s = self.__class__.__name__
+ s += f'(in_channels={self.in_channels},\n'
+ s += f'out_channels={self.out_channels},\n'
+ s += f'kernel_size={self.kernel_size},\n'
+ s += f'stride={self.stride},\n'
+ s += f'padding={self.padding},\n'
+ s += f'dilation={self.dilation},\n'
+ s += f'groups={self.groups},\n'
+ s += f'deform_groups={self.deform_groups},\n'
+ # bias is not supported in DeformConv2d.
+ s += 'bias=False)'
+ return s
+
+
+@CONV_LAYERS.register_module('DCN')
+class DeformConv2dPack(DeformConv2d):
+    """A Deformable Conv Encapsulation that acts as a normal Conv layer.
+
+ The offset tensor is like `[y0, x0, y1, x1, y2, x2, ..., y8, x8]`.
+ The spatial arrangement is like:
+
+ .. code:: text
+
+ (x0, y0) (x1, y1) (x2, y2)
+ (x3, y3) (x4, y4) (x5, y5)
+ (x6, y6) (x7, y7) (x8, y8)
+
+ Args:
+ in_channels (int): Same as nn.Conv2d.
+ out_channels (int): Same as nn.Conv2d.
+ kernel_size (int or tuple[int]): Same as nn.Conv2d.
+ stride (int or tuple[int]): Same as nn.Conv2d.
+ padding (int or tuple[int]): Same as nn.Conv2d.
+ dilation (int or tuple[int]): Same as nn.Conv2d.
+ groups (int): Same as nn.Conv2d.
+ bias (bool or str): If specified as `auto`, it will be decided by the
+ norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
+ False.
+ """
+
+ _version = 2
+
+ def __init__(self, *args, **kwargs):
+ super(DeformConv2dPack, self).__init__(*args, **kwargs)
+ self.conv_offset = nn.Conv2d(
+ self.in_channels,
+ self.deform_groups * 2 * self.kernel_size[0] * self.kernel_size[1],
+ kernel_size=self.kernel_size,
+ stride=_pair(self.stride),
+ padding=_pair(self.padding),
+ dilation=_pair(self.dilation),
+ bias=True)
+ self.init_offset()
+
+ def init_offset(self):
+ self.conv_offset.weight.data.zero_()
+ self.conv_offset.bias.data.zero_()
+
+ def forward(self, x):
+ offset = self.conv_offset(x)
+ return deform_conv2d(x, offset, self.weight, self.stride, self.padding,
+ self.dilation, self.groups, self.deform_groups,
+ False, self.im2col_step)
+
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+ missing_keys, unexpected_keys, error_msgs):
+ version = local_metadata.get('version', None)
+
+ if version is None or version < 2:
+ # the key is different in early versions
+ # In version < 2, DeformConvPack loads previous benchmark models.
+ if (prefix + 'conv_offset.weight' not in state_dict
+ and prefix[:-1] + '_offset.weight' in state_dict):
+ state_dict[prefix + 'conv_offset.weight'] = state_dict.pop(
+ prefix[:-1] + '_offset.weight')
+ if (prefix + 'conv_offset.bias' not in state_dict
+ and prefix[:-1] + '_offset.bias' in state_dict):
+ state_dict[prefix +
+ 'conv_offset.bias'] = state_dict.pop(prefix[:-1] +
+ '_offset.bias')
+
+ if version is not None and version > 1:
+ print_log(
+ f'DeformConv2dPack {prefix.rstrip(".")} is upgraded to '
+ 'version 2.',
+ logger='root')
+
+ super()._load_from_state_dict(state_dict, prefix, local_metadata,
+ strict, missing_keys, unexpected_keys,
+ error_msgs)
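+
+
+# Illustrative usage sketch (not part of the mmcv API): a hypothetical demo
+# helper assuming a CUDA build of the compiled `_ext` ops; shapes are
+# arbitrary examples.
+def _deform_conv2d_pack_demo():
+    """Run DeformConv2dPack as a drop-in replacement for nn.Conv2d."""
+    x = torch.randn(2, 16, 32, 32, device='cuda')
+    dcn = DeformConv2dPack(16, 32, kernel_size=3, padding=1).cuda()
+    return dcn(x)  # offsets are predicted internally; output: (2, 32, 32, 32)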
diff --git a/annotator/uniformer/mmcv/ops/deform_roi_pool.py b/annotator/uniformer/mmcv/ops/deform_roi_pool.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc245ba91fee252226ba22e76bb94a35db9a629b
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/deform_roi_pool.py
@@ -0,0 +1,204 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from torch import nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.nn.modules.utils import _pair
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+ '_ext', ['deform_roi_pool_forward', 'deform_roi_pool_backward'])
+
+
+class DeformRoIPoolFunction(Function):
+
+ @staticmethod
+ def symbolic(g, input, rois, offset, output_size, spatial_scale,
+ sampling_ratio, gamma):
+ return g.op(
+ 'mmcv::MMCVDeformRoIPool',
+ input,
+ rois,
+ offset,
+ pooled_height_i=output_size[0],
+ pooled_width_i=output_size[1],
+ spatial_scale_f=spatial_scale,
+ sampling_ratio_f=sampling_ratio,
+ gamma_f=gamma)
+
+ @staticmethod
+ def forward(ctx,
+ input,
+ rois,
+ offset,
+ output_size,
+ spatial_scale=1.0,
+ sampling_ratio=0,
+ gamma=0.1):
+ if offset is None:
+ offset = input.new_zeros(0)
+ ctx.output_size = _pair(output_size)
+ ctx.spatial_scale = float(spatial_scale)
+ ctx.sampling_ratio = int(sampling_ratio)
+ ctx.gamma = float(gamma)
+
+ assert rois.size(1) == 5, 'RoI must be (idx, x1, y1, x2, y2)!'
+
+ output_shape = (rois.size(0), input.size(1), ctx.output_size[0],
+ ctx.output_size[1])
+ output = input.new_zeros(output_shape)
+
+ ext_module.deform_roi_pool_forward(
+ input,
+ rois,
+ offset,
+ output,
+ pooled_height=ctx.output_size[0],
+ pooled_width=ctx.output_size[1],
+ spatial_scale=ctx.spatial_scale,
+ sampling_ratio=ctx.sampling_ratio,
+ gamma=ctx.gamma)
+
+ ctx.save_for_backward(input, rois, offset)
+ return output
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, grad_output):
+ input, rois, offset = ctx.saved_tensors
+ grad_input = grad_output.new_zeros(input.shape)
+ grad_offset = grad_output.new_zeros(offset.shape)
+
+ ext_module.deform_roi_pool_backward(
+ grad_output,
+ input,
+ rois,
+ offset,
+ grad_input,
+ grad_offset,
+ pooled_height=ctx.output_size[0],
+ pooled_width=ctx.output_size[1],
+ spatial_scale=ctx.spatial_scale,
+ sampling_ratio=ctx.sampling_ratio,
+ gamma=ctx.gamma)
+ if grad_offset.numel() == 0:
+ grad_offset = None
+ return grad_input, None, grad_offset, None, None, None, None
+
+
+deform_roi_pool = DeformRoIPoolFunction.apply
+
+
+class DeformRoIPool(nn.Module):
+
+ def __init__(self,
+ output_size,
+ spatial_scale=1.0,
+ sampling_ratio=0,
+ gamma=0.1):
+ super(DeformRoIPool, self).__init__()
+ self.output_size = _pair(output_size)
+ self.spatial_scale = float(spatial_scale)
+ self.sampling_ratio = int(sampling_ratio)
+ self.gamma = float(gamma)
+
+ def forward(self, input, rois, offset=None):
+ return deform_roi_pool(input, rois, offset, self.output_size,
+ self.spatial_scale, self.sampling_ratio,
+ self.gamma)
+
+
+class DeformRoIPoolPack(DeformRoIPool):
+
+ def __init__(self,
+ output_size,
+ output_channels,
+ deform_fc_channels=1024,
+ spatial_scale=1.0,
+ sampling_ratio=0,
+ gamma=0.1):
+ super(DeformRoIPoolPack, self).__init__(output_size, spatial_scale,
+ sampling_ratio, gamma)
+
+ self.output_channels = output_channels
+ self.deform_fc_channels = deform_fc_channels
+
+ self.offset_fc = nn.Sequential(
+ nn.Linear(
+ self.output_size[0] * self.output_size[1] *
+ self.output_channels, self.deform_fc_channels),
+ nn.ReLU(inplace=True),
+ nn.Linear(self.deform_fc_channels, self.deform_fc_channels),
+ nn.ReLU(inplace=True),
+ nn.Linear(self.deform_fc_channels,
+ self.output_size[0] * self.output_size[1] * 2))
+ self.offset_fc[-1].weight.data.zero_()
+ self.offset_fc[-1].bias.data.zero_()
+
+ def forward(self, input, rois):
+ assert input.size(1) == self.output_channels
+ x = deform_roi_pool(input, rois, None, self.output_size,
+ self.spatial_scale, self.sampling_ratio,
+ self.gamma)
+ rois_num = rois.size(0)
+ offset = self.offset_fc(x.view(rois_num, -1))
+ offset = offset.view(rois_num, 2, self.output_size[0],
+ self.output_size[1])
+ return deform_roi_pool(input, rois, offset, self.output_size,
+ self.spatial_scale, self.sampling_ratio,
+ self.gamma)
+
+
+class ModulatedDeformRoIPoolPack(DeformRoIPool):
+
+ def __init__(self,
+ output_size,
+ output_channels,
+ deform_fc_channels=1024,
+ spatial_scale=1.0,
+ sampling_ratio=0,
+ gamma=0.1):
+ super(ModulatedDeformRoIPoolPack,
+ self).__init__(output_size, spatial_scale, sampling_ratio, gamma)
+
+ self.output_channels = output_channels
+ self.deform_fc_channels = deform_fc_channels
+
+ self.offset_fc = nn.Sequential(
+ nn.Linear(
+ self.output_size[0] * self.output_size[1] *
+ self.output_channels, self.deform_fc_channels),
+ nn.ReLU(inplace=True),
+ nn.Linear(self.deform_fc_channels, self.deform_fc_channels),
+ nn.ReLU(inplace=True),
+ nn.Linear(self.deform_fc_channels,
+ self.output_size[0] * self.output_size[1] * 2))
+ self.offset_fc[-1].weight.data.zero_()
+ self.offset_fc[-1].bias.data.zero_()
+
+ self.mask_fc = nn.Sequential(
+ nn.Linear(
+ self.output_size[0] * self.output_size[1] *
+ self.output_channels, self.deform_fc_channels),
+ nn.ReLU(inplace=True),
+ nn.Linear(self.deform_fc_channels,
+ self.output_size[0] * self.output_size[1] * 1),
+ nn.Sigmoid())
+ self.mask_fc[2].weight.data.zero_()
+ self.mask_fc[2].bias.data.zero_()
+
+ def forward(self, input, rois):
+ assert input.size(1) == self.output_channels
+ x = deform_roi_pool(input, rois, None, self.output_size,
+ self.spatial_scale, self.sampling_ratio,
+ self.gamma)
+ rois_num = rois.size(0)
+ offset = self.offset_fc(x.view(rois_num, -1))
+ offset = offset.view(rois_num, 2, self.output_size[0],
+ self.output_size[1])
+ mask = self.mask_fc(x.view(rois_num, -1))
+ mask = mask.view(rois_num, 1, self.output_size[0], self.output_size[1])
+ d = deform_roi_pool(input, rois, offset, self.output_size,
+ self.spatial_scale, self.sampling_ratio,
+ self.gamma)
+ return d * mask
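+
+
+# Illustrative usage sketch (not part of the mmcv API): a hypothetical demo
+# helper assuming a CUDA build of the compiled `_ext` ops. RoIs use the
+# (batch_idx, x1, y1, x2, y2) format; values are arbitrary examples.
+def _deform_roi_pool_pack_demo():
+    """Pool one RoI from a feature map with learned offsets."""
+    import torch
+    feats = torch.randn(1, 16, 32, 32, device='cuda')
+    rois = feats.new_tensor([[0., 4., 4., 20., 20.]])
+    pool = DeformRoIPoolPack(output_size=7, output_channels=16).cuda()
+    return pool(feats, rois)  # output shape: (1, 16, 7, 7)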
diff --git a/annotator/uniformer/mmcv/ops/deprecated_wrappers.py b/annotator/uniformer/mmcv/ops/deprecated_wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2e593df9ee57637038683d7a1efaa347b2b69e7
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/deprecated_wrappers.py
@@ -0,0 +1,43 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# This file is for backward compatibility.
+# Module wrappers for empty tensor have been moved to mmcv.cnn.bricks.
+import warnings
+
+from ..cnn.bricks.wrappers import Conv2d, ConvTranspose2d, Linear, MaxPool2d
+
+
+class Conv2d_deprecated(Conv2d):
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ warnings.warn(
+ 'Importing Conv2d wrapper from "mmcv.ops" will be deprecated in'
+ ' the future. Please import them from "mmcv.cnn" instead')
+
+
+class ConvTranspose2d_deprecated(ConvTranspose2d):
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ warnings.warn(
+ 'Importing ConvTranspose2d wrapper from "mmcv.ops" will be '
+ 'deprecated in the future. Please import them from "mmcv.cnn" '
+ 'instead')
+
+
+class MaxPool2d_deprecated(MaxPool2d):
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ warnings.warn(
+ 'Importing MaxPool2d wrapper from "mmcv.ops" will be deprecated in'
+ ' the future. Please import them from "mmcv.cnn" instead')
+
+
+class Linear_deprecated(Linear):
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ warnings.warn(
+ 'Importing Linear wrapper from "mmcv.ops" will be deprecated in'
+ ' the future. Please import them from "mmcv.cnn" instead')
diff --git a/annotator/uniformer/mmcv/ops/focal_loss.py b/annotator/uniformer/mmcv/ops/focal_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..763bc93bd2575c49ca8ccf20996bbd92d1e0d1a4
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/focal_loss.py
@@ -0,0 +1,212 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', [
+ 'sigmoid_focal_loss_forward', 'sigmoid_focal_loss_backward',
+ 'softmax_focal_loss_forward', 'softmax_focal_loss_backward'
+])
+
+
+class SigmoidFocalLossFunction(Function):
+
+ @staticmethod
+ def symbolic(g, input, target, gamma, alpha, weight, reduction):
+ return g.op(
+ 'mmcv::MMCVSigmoidFocalLoss',
+ input,
+ target,
+ gamma_f=gamma,
+ alpha_f=alpha,
+ weight_f=weight,
+ reduction_s=reduction)
+
+ @staticmethod
+ def forward(ctx,
+ input,
+ target,
+ gamma=2.0,
+ alpha=0.25,
+ weight=None,
+ reduction='mean'):
+
+ assert isinstance(target, (torch.LongTensor, torch.cuda.LongTensor))
+ assert input.dim() == 2
+ assert target.dim() == 1
+ assert input.size(0) == target.size(0)
+ if weight is None:
+ weight = input.new_empty(0)
+ else:
+ assert weight.dim() == 1
+ assert input.size(1) == weight.size(0)
+ ctx.reduction_dict = {'none': 0, 'mean': 1, 'sum': 2}
+ assert reduction in ctx.reduction_dict.keys()
+
+ ctx.gamma = float(gamma)
+ ctx.alpha = float(alpha)
+ ctx.reduction = ctx.reduction_dict[reduction]
+
+ output = input.new_zeros(input.size())
+
+ ext_module.sigmoid_focal_loss_forward(
+ input, target, weight, output, gamma=ctx.gamma, alpha=ctx.alpha)
+ if ctx.reduction == ctx.reduction_dict['mean']:
+ output = output.sum() / input.size(0)
+ elif ctx.reduction == ctx.reduction_dict['sum']:
+ output = output.sum()
+ ctx.save_for_backward(input, target, weight)
+ return output
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, grad_output):
+ input, target, weight = ctx.saved_tensors
+
+ grad_input = input.new_zeros(input.size())
+
+ ext_module.sigmoid_focal_loss_backward(
+ input,
+ target,
+ weight,
+ grad_input,
+ gamma=ctx.gamma,
+ alpha=ctx.alpha)
+
+ grad_input *= grad_output
+ if ctx.reduction == ctx.reduction_dict['mean']:
+ grad_input /= input.size(0)
+ return grad_input, None, None, None, None, None
+
+
+sigmoid_focal_loss = SigmoidFocalLossFunction.apply
+
+
+class SigmoidFocalLoss(nn.Module):
+
+ def __init__(self, gamma, alpha, weight=None, reduction='mean'):
+ super(SigmoidFocalLoss, self).__init__()
+ self.gamma = gamma
+ self.alpha = alpha
+ self.register_buffer('weight', weight)
+ self.reduction = reduction
+
+ def forward(self, input, target):
+ return sigmoid_focal_loss(input, target, self.gamma, self.alpha,
+ self.weight, self.reduction)
+
+ def __repr__(self):
+ s = self.__class__.__name__
+ s += f'(gamma={self.gamma}, '
+ s += f'alpha={self.alpha}, '
+ s += f'reduction={self.reduction})'
+ return s
+
+
+class SoftmaxFocalLossFunction(Function):
+
+ @staticmethod
+ def symbolic(g, input, target, gamma, alpha, weight, reduction):
+ return g.op(
+ 'mmcv::MMCVSoftmaxFocalLoss',
+ input,
+ target,
+ gamma_f=gamma,
+ alpha_f=alpha,
+ weight_f=weight,
+ reduction_s=reduction)
+
+ @staticmethod
+ def forward(ctx,
+ input,
+ target,
+ gamma=2.0,
+ alpha=0.25,
+ weight=None,
+ reduction='mean'):
+
+ assert isinstance(target, (torch.LongTensor, torch.cuda.LongTensor))
+ assert input.dim() == 2
+ assert target.dim() == 1
+ assert input.size(0) == target.size(0)
+ if weight is None:
+ weight = input.new_empty(0)
+ else:
+ assert weight.dim() == 1
+ assert input.size(1) == weight.size(0)
+ ctx.reduction_dict = {'none': 0, 'mean': 1, 'sum': 2}
+ assert reduction in ctx.reduction_dict.keys()
+
+ ctx.gamma = float(gamma)
+ ctx.alpha = float(alpha)
+ ctx.reduction = ctx.reduction_dict[reduction]
+
+ channel_stats, _ = torch.max(input, dim=1)
+ input_softmax = input - channel_stats.unsqueeze(1).expand_as(input)
+ input_softmax.exp_()
+
+ channel_stats = input_softmax.sum(dim=1)
+ input_softmax /= channel_stats.unsqueeze(1).expand_as(input)
+
+ output = input.new_zeros(input.size(0))
+ ext_module.softmax_focal_loss_forward(
+ input_softmax,
+ target,
+ weight,
+ output,
+ gamma=ctx.gamma,
+ alpha=ctx.alpha)
+
+ if ctx.reduction == ctx.reduction_dict['mean']:
+ output = output.sum() / input.size(0)
+ elif ctx.reduction == ctx.reduction_dict['sum']:
+ output = output.sum()
+ ctx.save_for_backward(input_softmax, target, weight)
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input_softmax, target, weight = ctx.saved_tensors
+ buff = input_softmax.new_zeros(input_softmax.size(0))
+ grad_input = input_softmax.new_zeros(input_softmax.size())
+
+ ext_module.softmax_focal_loss_backward(
+ input_softmax,
+ target,
+ weight,
+ buff,
+ grad_input,
+ gamma=ctx.gamma,
+ alpha=ctx.alpha)
+
+ grad_input *= grad_output
+ if ctx.reduction == ctx.reduction_dict['mean']:
+ grad_input /= input_softmax.size(0)
+ return grad_input, None, None, None, None, None
+
+
+softmax_focal_loss = SoftmaxFocalLossFunction.apply
+
+
+class SoftmaxFocalLoss(nn.Module):
+
+ def __init__(self, gamma, alpha, weight=None, reduction='mean'):
+ super(SoftmaxFocalLoss, self).__init__()
+ self.gamma = gamma
+ self.alpha = alpha
+ self.register_buffer('weight', weight)
+ self.reduction = reduction
+
+ def forward(self, input, target):
+ return softmax_focal_loss(input, target, self.gamma, self.alpha,
+ self.weight, self.reduction)
+
+ def __repr__(self):
+ s = self.__class__.__name__
+ s += f'(gamma={self.gamma}, '
+ s += f'alpha={self.alpha}, '
+ s += f'reduction={self.reduction})'
+ return s
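+
+
+# Illustrative usage sketch (not part of the mmcv API): a hypothetical demo
+# helper assuming a CUDA build of the compiled `_ext` ops. Logits have shape
+# (N, num_classes) and labels shape (N,); values are arbitrary examples.
+def _sigmoid_focal_loss_demo():
+    """Compute a mean-reduced sigmoid focal loss on random logits."""
+    logits = torch.randn(8, 4, device='cuda')
+    labels = torch.randint(0, 4, (8, ), device='cuda')  # int64 class indices
+    return sigmoid_focal_loss(logits, labels, 2.0, 0.25, None, 'mean')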
diff --git a/annotator/uniformer/mmcv/ops/furthest_point_sample.py b/annotator/uniformer/mmcv/ops/furthest_point_sample.py
new file mode 100644
index 0000000000000000000000000000000000000000..374b7a878f1972c183941af28ba1df216ac1a60f
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/furthest_point_sample.py
@@ -0,0 +1,83 @@
+import torch
+from torch.autograd import Function
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', [
+ 'furthest_point_sampling_forward',
+ 'furthest_point_sampling_with_dist_forward'
+])
+
+
+class FurthestPointSampling(Function):
+ """Uses iterative furthest point sampling to select a set of features whose
+ corresponding points have the furthest distance."""
+
+ @staticmethod
+ def forward(ctx, points_xyz: torch.Tensor,
+ num_points: int) -> torch.Tensor:
+ """
+ Args:
+ points_xyz (Tensor): (B, N, 3) where N > num_points.
+ num_points (int): Number of points in the sampled set.
+
+ Returns:
+ Tensor: (B, num_points) indices of the sampled points.
+ """
+ assert points_xyz.is_contiguous()
+
+ B, N = points_xyz.size()[:2]
+ output = torch.cuda.IntTensor(B, num_points)
+ temp = torch.cuda.FloatTensor(B, N).fill_(1e10)
+
+ ext_module.furthest_point_sampling_forward(
+ points_xyz,
+ temp,
+ output,
+ b=B,
+ n=N,
+ m=num_points,
+ )
+ if torch.__version__ != 'parrots':
+ ctx.mark_non_differentiable(output)
+ return output
+
+ @staticmethod
+ def backward(xyz, a=None):
+ return None, None
+
+
+class FurthestPointSamplingWithDist(Function):
+ """Uses iterative furthest point sampling to select a set of features whose
+ corresponding points have the furthest distance."""
+
+ @staticmethod
+ def forward(ctx, points_dist: torch.Tensor,
+ num_points: int) -> torch.Tensor:
+ """
+ Args:
+ points_dist (Tensor): (B, N, N) Distance between each point pair.
+ num_points (int): Number of points in the sampled set.
+
+ Returns:
+ Tensor: (B, num_points) indices of the sampled points.
+ """
+ assert points_dist.is_contiguous()
+
+ B, N, _ = points_dist.size()
+ output = points_dist.new_zeros([B, num_points], dtype=torch.int32)
+ temp = points_dist.new_zeros([B, N]).fill_(1e10)
+
+ ext_module.furthest_point_sampling_with_dist_forward(
+ points_dist, temp, output, b=B, n=N, m=num_points)
+ if torch.__version__ != 'parrots':
+ ctx.mark_non_differentiable(output)
+ return output
+
+ @staticmethod
+ def backward(xyz, a=None):
+ return None, None
+
+
+furthest_point_sample = FurthestPointSampling.apply
+furthest_point_sample_with_dist = FurthestPointSamplingWithDist.apply
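+
+
+# Illustrative usage sketch (not part of the mmcv API): a hypothetical demo
+# helper assuming a CUDA build of the compiled `_ext` ops; shapes are
+# arbitrary examples.
+def _furthest_point_sample_demo():
+    """Sample 128 well-spread points from a 1024-point cloud."""
+    xyz = torch.randn(2, 1024, 3, device='cuda')
+    return furthest_point_sample(xyz, 128)  # (2, 128) int32 indices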
diff --git a/annotator/uniformer/mmcv/ops/fused_bias_leakyrelu.py b/annotator/uniformer/mmcv/ops/fused_bias_leakyrelu.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d12508469c6c8fa1884debece44c58d158cb6fa
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/fused_bias_leakyrelu.py
@@ -0,0 +1,268 @@
+# modified from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501
+
+# Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
+# NVIDIA Source Code License for StyleGAN2 with Adaptive Discriminator
+# Augmentation (ADA)
+# =======================================================================
+
+# 1. Definitions
+
+# "Licensor" means any person or entity that distributes its Work.
+
+# "Software" means the original work of authorship made available under
+# this License.
+
+# "Work" means the Software and any additions to or derivative works of
+# the Software that are made available under this License.
+
+# The terms "reproduce," "reproduction," "derivative works," and
+# "distribution" have the meaning as provided under U.S. copyright law;
+# provided, however, that for the purposes of this License, derivative
+# works shall not include works that remain separable from, or merely
+# link (or bind by name) to the interfaces of, the Work.
+
+# Works, including the Software, are "made available" under this License
+# by including in or with the Work either (a) a copyright notice
+# referencing the applicability of this License to the Work, or (b) a
+# copy of this License.
+
+# 2. License Grants
+
+# 2.1 Copyright Grant. Subject to the terms and conditions of this
+# License, each Licensor grants to you a perpetual, worldwide,
+# non-exclusive, royalty-free, copyright license to reproduce,
+# prepare derivative works of, publicly display, publicly perform,
+# sublicense and distribute its Work and any resulting derivative
+# works in any form.
+
+# 3. Limitations
+
+# 3.1 Redistribution. You may reproduce or distribute the Work only
+# if (a) you do so under this License, (b) you include a complete
+# copy of this License with your distribution, and (c) you retain
+# without modification any copyright, patent, trademark, or
+# attribution notices that are present in the Work.
+
+# 3.2 Derivative Works. You may specify that additional or different
+# terms apply to the use, reproduction, and distribution of your
+# derivative works of the Work ("Your Terms") only if (a) Your Terms
+# provide that the use limitation in Section 3.3 applies to your
+# derivative works, and (b) you identify the specific derivative
+# works that are subject to Your Terms. Notwithstanding Your Terms,
+# this License (including the redistribution requirements in Section
+# 3.1) will continue to apply to the Work itself.
+
+# 3.3 Use Limitation. The Work and any derivative works thereof only
+# may be used or intended for use non-commercially. Notwithstanding
+# the foregoing, NVIDIA and its affiliates may use the Work and any
+# derivative works commercially. As used herein, "non-commercially"
+# means for research or evaluation purposes only.
+
+# 3.4 Patent Claims. If you bring or threaten to bring a patent claim
+# against any Licensor (including any claim, cross-claim or
+# counterclaim in a lawsuit) to enforce any patents that you allege
+# are infringed by any Work, then your rights under this License from
+# such Licensor (including the grant in Section 2.1) will terminate
+# immediately.
+
+# 3.5 Trademarks. This License does not grant any rights to use any
+# Licensor’s or its affiliates’ names, logos, or trademarks, except
+# as necessary to reproduce the notices described in this License.
+
+# 3.6 Termination. If you violate any term of this License, then your
+# rights under this License (including the grant in Section 2.1) will
+# terminate immediately.
+
+# 4. Disclaimer of Warranty.
+
+# THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
+# NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
+# THIS LICENSE.
+
+# 5. Limitation of Liability.
+
+# EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
+# THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
+# SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
+# INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
+# OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
+# (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
+# LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
+# COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
+# THE POSSIBILITY OF SUCH DAMAGES.
+
+# =======================================================================
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.autograd import Function
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', ['fused_bias_leakyrelu'])
+
+
+class FusedBiasLeakyReLUFunctionBackward(Function):
+    """Calculate the second-order gradient.
+
+    This function computes the second-order gradient of the fused leaky
+    ReLU operation.
+ """
+
+ @staticmethod
+ def forward(ctx, grad_output, out, negative_slope, scale):
+ ctx.save_for_backward(out)
+ ctx.negative_slope = negative_slope
+ ctx.scale = scale
+
+ empty = grad_output.new_empty(0)
+
+ grad_input = ext_module.fused_bias_leakyrelu(
+ grad_output,
+ empty,
+ out,
+ act=3,
+ grad=1,
+ alpha=negative_slope,
+ scale=scale)
+
+ dim = [0]
+
+ if grad_input.ndim > 2:
+ dim += list(range(2, grad_input.ndim))
+
+ grad_bias = grad_input.sum(dim).detach()
+
+ return grad_input, grad_bias
+
+ @staticmethod
+ def backward(ctx, gradgrad_input, gradgrad_bias):
+ out, = ctx.saved_tensors
+
+        # The second-order gradient actually contains two parts, but the
+        # first part is zero. Thus we only consider the second part, which
+        # is computed in the same way as the first-order gradient.
+ gradgrad_out = ext_module.fused_bias_leakyrelu(
+ gradgrad_input,
+ gradgrad_bias.to(out.dtype),
+ out,
+ act=3,
+ grad=1,
+ alpha=ctx.negative_slope,
+ scale=ctx.scale)
+
+ return gradgrad_out, None, None, None
+
+
+class FusedBiasLeakyReLUFunction(Function):
+
+ @staticmethod
+ def forward(ctx, input, bias, negative_slope, scale):
+ empty = input.new_empty(0)
+
+ out = ext_module.fused_bias_leakyrelu(
+ input,
+ bias,
+ empty,
+ act=3,
+ grad=0,
+ alpha=negative_slope,
+ scale=scale)
+ ctx.save_for_backward(out)
+ ctx.negative_slope = negative_slope
+ ctx.scale = scale
+
+ return out
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ out, = ctx.saved_tensors
+
+ grad_input, grad_bias = FusedBiasLeakyReLUFunctionBackward.apply(
+ grad_output, out, ctx.negative_slope, ctx.scale)
+
+ return grad_input, grad_bias, None, None
+
+
+class FusedBiasLeakyReLU(nn.Module):
+ """Fused bias leaky ReLU.
+
+ This function is introduced in the StyleGAN2:
+ http://arxiv.org/abs/1912.04958
+
+    The bias term comes from the convolution operation. In addition, to keep
+    the variance of the feature map or gradients unchanged, a scale similar
+    to the one in Kaiming initialization is adopted. However, since
+    :math:`1 + \\alpha^2` is very close to 1, it can simply be ignored, so the
+    final scale is just :math:`\\sqrt{2}`. Of course, you may change it to
+    your own scale.
+
+ TODO: Implement the CPU version.
+
+ Args:
+ channel (int): The channel number of the feature map.
+ negative_slope (float, optional): Same as nn.LeakyRelu.
+ Defaults to 0.2.
+ scale (float, optional): A scalar to adjust the variance of the feature
+ map. Defaults to 2**0.5.
+ """
+
+ def __init__(self, num_channels, negative_slope=0.2, scale=2**0.5):
+ super(FusedBiasLeakyReLU, self).__init__()
+
+ self.bias = nn.Parameter(torch.zeros(num_channels))
+ self.negative_slope = negative_slope
+ self.scale = scale
+
+ def forward(self, input):
+ return fused_bias_leakyrelu(input, self.bias, self.negative_slope,
+ self.scale)
+
+
+def fused_bias_leakyrelu(input, bias, negative_slope=0.2, scale=2**0.5):
+ """Fused bias leaky ReLU function.
+
+ This function is introduced in the StyleGAN2:
+ http://arxiv.org/abs/1912.04958
+
+    The bias term comes from the convolution operation. In addition, to keep
+    the variance of the feature map or gradients unchanged, a scale similar
+    to the one in Kaiming initialization is adopted. However, since
+    :math:`1 + \\alpha^2` is very close to 1, it can simply be ignored, so the
+    final scale is just :math:`\\sqrt{2}`. Of course, you may change it to
+    your own scale.
+
+ Args:
+ input (torch.Tensor): Input feature map.
+ bias (nn.Parameter): The bias from convolution operation.
+ negative_slope (float, optional): Same as nn.LeakyRelu.
+ Defaults to 0.2.
+ scale (float, optional): A scalar to adjust the variance of the feature
+ map. Defaults to 2**0.5.
+
+ Returns:
+ torch.Tensor: Feature map after non-linear activation.
+ """
+
+ if not input.is_cuda:
+ return bias_leakyrelu_ref(input, bias, negative_slope, scale)
+
+ return FusedBiasLeakyReLUFunction.apply(input, bias.to(input.dtype),
+ negative_slope, scale)
+
+
+def bias_leakyrelu_ref(x, bias, negative_slope=0.2, scale=2**0.5):
+
+ if bias is not None:
+ assert bias.ndim == 1
+ assert bias.shape[0] == x.shape[1]
+ x = x + bias.reshape([-1 if i == 1 else 1 for i in range(x.ndim)])
+
+ x = F.leaky_relu(x, negative_slope)
+ if scale != 1:
+ x = x * scale
+
+ return x
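+
+
+# Illustrative usage sketch (not part of the mmcv API): a hypothetical demo
+# helper. On CPU tensors, fused_bias_leakyrelu falls back to the reference
+# implementation above, so no CUDA build is needed; shapes are arbitrary.
+def _fused_bias_leakyrelu_demo():
+    """Apply the bias + leaky ReLU + scale activation to a feature map."""
+    feat = torch.randn(2, 8, 16, 16)
+    act = FusedBiasLeakyReLU(num_channels=8)
+    return act(feat)  # LeakyReLU(feat + bias) * sqrt(2), same shape as feat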
diff --git a/annotator/uniformer/mmcv/ops/gather_points.py b/annotator/uniformer/mmcv/ops/gather_points.py
new file mode 100644
index 0000000000000000000000000000000000000000..f52f1677d8ea0facafc56a3672d37adb44677ff3
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/gather_points.py
@@ -0,0 +1,57 @@
+import torch
+from torch.autograd import Function
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+ '_ext', ['gather_points_forward', 'gather_points_backward'])
+
+
+class GatherPoints(Function):
+ """Gather points with given index."""
+
+ @staticmethod
+ def forward(ctx, features: torch.Tensor,
+ indices: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ features (Tensor): (B, C, N) features to gather.
+ indices (Tensor): (B, M) where M is the number of points.
+
+ Returns:
+ Tensor: (B, C, M) where M is the number of points.
+ """
+ assert features.is_contiguous()
+ assert indices.is_contiguous()
+
+ B, npoint = indices.size()
+ _, C, N = features.size()
+ output = torch.cuda.FloatTensor(B, C, npoint)
+
+ ext_module.gather_points_forward(
+ features, indices, output, b=B, c=C, n=N, npoints=npoint)
+
+ ctx.for_backwards = (indices, C, N)
+ if torch.__version__ != 'parrots':
+ ctx.mark_non_differentiable(indices)
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_out):
+ idx, C, N = ctx.for_backwards
+ B, npoint = idx.size()
+
+ grad_features = torch.cuda.FloatTensor(B, C, N).zero_()
+ grad_out_data = grad_out.data.contiguous()
+ ext_module.gather_points_backward(
+ grad_out_data,
+ idx,
+ grad_features.data,
+ b=B,
+ c=C,
+ n=N,
+ npoints=npoint)
+ return grad_features, None
+
+
+gather_points = GatherPoints.apply
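+
+
+# Illustrative usage sketch (not part of the mmcv API): a hypothetical demo
+# helper assuming a CUDA build of the compiled `_ext` ops; shapes are
+# arbitrary examples.
+def _gather_points_demo():
+    """Gather the feature columns of 128 selected points."""
+    feats = torch.randn(2, 32, 1024, device='cuda')              # (B, C, N)
+    idx = torch.randint(0, 1024, (2, 128), device='cuda').int()  # (B, M)
+    return gather_points(feats, idx)  # (B, C, M) = (2, 32, 128)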
diff --git a/annotator/uniformer/mmcv/ops/group_points.py b/annotator/uniformer/mmcv/ops/group_points.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c3ec9d758ebe4e1c2205882af4be154008253a5
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/group_points.py
@@ -0,0 +1,224 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Tuple
+
+import torch
+from torch import nn as nn
+from torch.autograd import Function
+
+from ..utils import ext_loader
+from .ball_query import ball_query
+from .knn import knn
+
+ext_module = ext_loader.load_ext(
+ '_ext', ['group_points_forward', 'group_points_backward'])
+
+
+class QueryAndGroup(nn.Module):
+    """Groups points with a ball query within a given radius.
+
+ Args:
+ max_radius (float): The maximum radius of the balls.
+ If None is given, we will use kNN sampling instead of ball query.
+ sample_num (int): Maximum number of features to gather in the ball.
+ min_radius (float, optional): The minimum radius of the balls.
+ Default: 0.
+ use_xyz (bool, optional): Whether to use xyz.
+ Default: True.
+ return_grouped_xyz (bool, optional): Whether to return grouped xyz.
+ Default: False.
+ normalize_xyz (bool, optional): Whether to normalize xyz.
+ Default: False.
+ uniform_sample (bool, optional): Whether to sample uniformly.
+ Default: False
+ return_unique_cnt (bool, optional): Whether to return the count of
+ unique samples. Default: False.
+ return_grouped_idx (bool, optional): Whether to return grouped idx.
+ Default: False.
+ """
+
+ def __init__(self,
+ max_radius,
+ sample_num,
+ min_radius=0,
+ use_xyz=True,
+ return_grouped_xyz=False,
+ normalize_xyz=False,
+ uniform_sample=False,
+ return_unique_cnt=False,
+ return_grouped_idx=False):
+ super().__init__()
+ self.max_radius = max_radius
+ self.min_radius = min_radius
+ self.sample_num = sample_num
+ self.use_xyz = use_xyz
+ self.return_grouped_xyz = return_grouped_xyz
+ self.normalize_xyz = normalize_xyz
+ self.uniform_sample = uniform_sample
+ self.return_unique_cnt = return_unique_cnt
+ self.return_grouped_idx = return_grouped_idx
+ if self.return_unique_cnt:
+ assert self.uniform_sample, \
+ 'uniform_sample should be True when ' \
+ 'returning the count of unique samples'
+ if self.max_radius is None:
+ assert not self.normalize_xyz, \
+ 'can not normalize grouped xyz when max_radius is None'
+
+ def forward(self, points_xyz, center_xyz, features=None):
+ """
+ Args:
+ points_xyz (Tensor): (B, N, 3) xyz coordinates of the features.
+            center_xyz (Tensor): (B, npoint, 3) coordinates of the centroids.
+ features (Tensor): (B, C, N) Descriptors of the features.
+
+ Returns:
+ Tensor: (B, 3 + C, npoint, sample_num) Grouped feature.
+ """
+ # if self.max_radius is None, we will perform kNN instead of ball query
+ # idx is of shape [B, npoint, sample_num]
+ if self.max_radius is None:
+ idx = knn(self.sample_num, points_xyz, center_xyz, False)
+ idx = idx.transpose(1, 2).contiguous()
+ else:
+ idx = ball_query(self.min_radius, self.max_radius, self.sample_num,
+ points_xyz, center_xyz)
+
+ if self.uniform_sample:
+ unique_cnt = torch.zeros((idx.shape[0], idx.shape[1]))
+ for i_batch in range(idx.shape[0]):
+ for i_region in range(idx.shape[1]):
+ unique_ind = torch.unique(idx[i_batch, i_region, :])
+ num_unique = unique_ind.shape[0]
+ unique_cnt[i_batch, i_region] = num_unique
+ sample_ind = torch.randint(
+ 0,
+ num_unique, (self.sample_num - num_unique, ),
+ dtype=torch.long)
+ all_ind = torch.cat((unique_ind, unique_ind[sample_ind]))
+ idx[i_batch, i_region, :] = all_ind
+
+ xyz_trans = points_xyz.transpose(1, 2).contiguous()
+ # (B, 3, npoint, sample_num)
+ grouped_xyz = grouping_operation(xyz_trans, idx)
+ grouped_xyz_diff = grouped_xyz - \
+ center_xyz.transpose(1, 2).unsqueeze(-1) # relative offsets
+ if self.normalize_xyz:
+ grouped_xyz_diff /= self.max_radius
+
+ if features is not None:
+ grouped_features = grouping_operation(features, idx)
+ if self.use_xyz:
+ # (B, C + 3, npoint, sample_num)
+ new_features = torch.cat([grouped_xyz_diff, grouped_features],
+ dim=1)
+ else:
+ new_features = grouped_features
+ else:
+            assert self.use_xyz, \
+                'use_xyz must be True when features is None'
+ new_features = grouped_xyz_diff
+
+ ret = [new_features]
+ if self.return_grouped_xyz:
+ ret.append(grouped_xyz)
+ if self.return_unique_cnt:
+ ret.append(unique_cnt)
+ if self.return_grouped_idx:
+ ret.append(idx)
+ if len(ret) == 1:
+ return ret[0]
+ else:
+ return tuple(ret)
+
+
+class GroupAll(nn.Module):
+ """Group xyz with feature.
+
+ Args:
+ use_xyz (bool): Whether to use xyz.
+ """
+
+ def __init__(self, use_xyz: bool = True):
+ super().__init__()
+ self.use_xyz = use_xyz
+
+ def forward(self,
+ xyz: torch.Tensor,
+ new_xyz: torch.Tensor,
+ features: torch.Tensor = None):
+ """
+ Args:
+ xyz (Tensor): (B, N, 3) xyz coordinates of the features.
+ new_xyz (Tensor): new xyz coordinates of the features.
+ features (Tensor): (B, C, N) features to group.
+
+ Returns:
+ Tensor: (B, C + 3, 1, N) Grouped feature.
+ """
+ grouped_xyz = xyz.transpose(1, 2).unsqueeze(2)
+ if features is not None:
+ grouped_features = features.unsqueeze(2)
+ if self.use_xyz:
+ # (B, 3 + C, 1, N)
+ new_features = torch.cat([grouped_xyz, grouped_features],
+ dim=1)
+ else:
+ new_features = grouped_features
+ else:
+ new_features = grouped_xyz
+
+ return new_features
+
+
+class GroupingOperation(Function):
+ """Group feature with given index."""
+
+ @staticmethod
+ def forward(ctx, features: torch.Tensor,
+ indices: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ features (Tensor): (B, C, N) tensor of features to group.
+ indices (Tensor): (B, npoint, nsample) the indices of
+ features to group with.
+
+ Returns:
+ Tensor: (B, C, npoint, nsample) Grouped features.
+ """
+ features = features.contiguous()
+ indices = indices.contiguous()
+
+ B, nfeatures, nsample = indices.size()
+ _, C, N = features.size()
+ output = torch.cuda.FloatTensor(B, C, nfeatures, nsample)
+
+ ext_module.group_points_forward(B, C, N, nfeatures, nsample, features,
+ indices, output)
+
+ ctx.for_backwards = (indices, N)
+ return output
+
+ @staticmethod
+ def backward(ctx,
+ grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Args:
+ grad_out (Tensor): (B, C, npoint, nsample) tensor of the gradients
+ of the output from forward.
+
+ Returns:
+ Tensor: (B, C, N) gradient of the features.
+ """
+ idx, N = ctx.for_backwards
+
+ B, C, npoint, nsample = grad_out.size()
+ grad_features = torch.cuda.FloatTensor(B, C, N).zero_()
+
+ grad_out_data = grad_out.data.contiguous()
+ ext_module.group_points_backward(B, C, N, npoint, nsample,
+ grad_out_data, idx,
+ grad_features.data)
+ return grad_features, None
+
+
+grouping_operation = GroupingOperation.apply
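+
+
+# Illustrative usage sketch (not part of the mmcv API): a hypothetical demo
+# helper assuming a CUDA build of the compiled `_ext` ops. Radii, sample
+# counts and shapes are arbitrary examples.
+def _query_and_group_demo():
+    """Group up to 16 neighbours within radius 0.2 around 128 query centers."""
+    points = torch.randn(2, 1024, 3, device='cuda')   # (B, N, 3)
+    centers = points[:, :128, :].contiguous()         # (B, npoint, 3)
+    feats = torch.randn(2, 32, 1024, device='cuda')   # (B, C, N)
+    grouper = QueryAndGroup(max_radius=0.2, sample_num=16)
+    return grouper(points, centers, feats)  # (B, 3 + C, 128, 16)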
diff --git a/annotator/uniformer/mmcv/ops/info.py b/annotator/uniformer/mmcv/ops/info.py
new file mode 100644
index 0000000000000000000000000000000000000000..29f2e5598ae2bb5866ccd15a7d3b4de33c0cd14d
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/info.py
@@ -0,0 +1,36 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import glob
+import os
+
+import torch
+
+if torch.__version__ == 'parrots':
+ import parrots
+
+ def get_compiler_version():
+ return 'GCC ' + parrots.version.compiler
+
+ def get_compiling_cuda_version():
+ return parrots.version.cuda
+else:
+ from ..utils import ext_loader
+ ext_module = ext_loader.load_ext(
+ '_ext', ['get_compiler_version', 'get_compiling_cuda_version'])
+
+ def get_compiler_version():
+ return ext_module.get_compiler_version()
+
+ def get_compiling_cuda_version():
+ return ext_module.get_compiling_cuda_version()
+
+
+def get_onnxruntime_op_path():
+ wildcard = os.path.join(
+ os.path.abspath(os.path.dirname(os.path.dirname(__file__))),
+ '_ext_ort.*.so')
+
+ paths = glob.glob(wildcard)
+ if len(paths) > 0:
+ return paths[0]
+ else:
+ return ''
diff --git a/annotator/uniformer/mmcv/ops/iou3d.py b/annotator/uniformer/mmcv/ops/iou3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..6fc71979190323f44c09f8b7e1761cf49cd2d76b
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/iou3d.py
@@ -0,0 +1,85 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', [
+ 'iou3d_boxes_iou_bev_forward', 'iou3d_nms_forward',
+ 'iou3d_nms_normal_forward'
+])
+
+
+def boxes_iou_bev(boxes_a, boxes_b):
+ """Calculate boxes IoU in the Bird's Eye View.
+
+ Args:
+ boxes_a (torch.Tensor): Input boxes a with shape (M, 5).
+ boxes_b (torch.Tensor): Input boxes b with shape (N, 5).
+
+ Returns:
+ ans_iou (torch.Tensor): IoU result with shape (M, N).
+ """
+ ans_iou = boxes_a.new_zeros(
+ torch.Size((boxes_a.shape[0], boxes_b.shape[0])))
+
+ ext_module.iou3d_boxes_iou_bev_forward(boxes_a.contiguous(),
+ boxes_b.contiguous(), ans_iou)
+
+ return ans_iou
+
+
+def nms_bev(boxes, scores, thresh, pre_max_size=None, post_max_size=None):
+ """NMS function GPU implementation (for BEV boxes). The overlap of two
+ boxes for IoU calculation is defined as the exact overlapping area of the
+ two boxes. In this function, one can also set ``pre_max_size`` and
+ ``post_max_size``.
+
+ Args:
+ boxes (torch.Tensor): Input boxes with the shape of [N, 5]
+ ([x1, y1, x2, y2, ry]).
+ scores (torch.Tensor): Scores of boxes with the shape of [N].
+ thresh (float): Overlap threshold of NMS.
+ pre_max_size (int, optional): Max size of boxes before NMS.
+ Default: None.
+ post_max_size (int, optional): Max size of boxes after NMS.
+ Default: None.
+
+ Returns:
+ torch.Tensor: Indexes after NMS.
+ """
+ assert boxes.size(1) == 5, 'Input boxes shape should be [N, 5]'
+ order = scores.sort(0, descending=True)[1]
+
+ if pre_max_size is not None:
+ order = order[:pre_max_size]
+ boxes = boxes[order].contiguous()
+
+ keep = torch.zeros(boxes.size(0), dtype=torch.long)
+ num_out = ext_module.iou3d_nms_forward(boxes, keep, thresh)
+ keep = order[keep[:num_out].cuda(boxes.device)].contiguous()
+ if post_max_size is not None:
+ keep = keep[:post_max_size]
+ return keep
+
+
+def nms_normal_bev(boxes, scores, thresh):
+ """Normal NMS function GPU implementation (for BEV boxes). The overlap of
+ two boxes for IoU calculation is defined as the exact overlapping area of
+ the two boxes WITH their yaw angle set to 0.
+
+ Args:
+ boxes (torch.Tensor): Input boxes with shape (N, 5).
+ scores (torch.Tensor): Scores of predicted boxes with shape (N).
+ thresh (float): Overlap threshold of NMS.
+
+ Returns:
+ torch.Tensor: Remaining indices with scores in descending order.
+ """
+ assert boxes.shape[1] == 5, 'Input boxes shape should be [N, 5]'
+ order = scores.sort(0, descending=True)[1]
+
+ boxes = boxes[order].contiguous()
+
+ keep = torch.zeros(boxes.size(0), dtype=torch.long)
+ num_out = ext_module.iou3d_nms_normal_forward(boxes, keep, thresh)
+ return order[keep[:num_out].cuda(boxes.device)].contiguous()
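+
+
+# Illustrative usage sketch (not part of the mmcv API): a hypothetical demo
+# helper assuming a CUDA build of the compiled `_ext` ops. Box coordinates
+# are arbitrary examples in (x1, y1, x2, y2, ry) format.
+def _boxes_iou_bev_demo():
+    """Compute the BEV IoU matrix between two small sets of boxes."""
+    boxes_a = torch.tensor([[0., 0., 2., 2., 0.]], device='cuda')
+    boxes_b = torch.tensor([[1., 1., 3., 3., 0.],
+                            [4., 4., 6., 6., 0.]], device='cuda')
+    return boxes_iou_bev(boxes_a, boxes_b)  # shape (1, 2)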
diff --git a/annotator/uniformer/mmcv/ops/knn.py b/annotator/uniformer/mmcv/ops/knn.py
new file mode 100644
index 0000000000000000000000000000000000000000..f335785036669fc19239825b0aae6dde3f73bf92
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/knn.py
@@ -0,0 +1,77 @@
+import torch
+from torch.autograd import Function
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', ['knn_forward'])
+
+
+class KNN(Function):
+ r"""KNN (CUDA) based on heap data structure.
+ Modified from `PAConv `_.
+
+ Find k-nearest points.
+ """
+
+ @staticmethod
+ def forward(ctx,
+ k: int,
+ xyz: torch.Tensor,
+ center_xyz: torch.Tensor = None,
+ transposed: bool = False) -> torch.Tensor:
+ """
+ Args:
+ k (int): number of nearest neighbors.
+ xyz (Tensor): (B, N, 3) if transposed == False, else (B, 3, N).
+ xyz coordinates of the features.
+ center_xyz (Tensor, optional): (B, npoint, 3) if transposed ==
+ False, else (B, 3, npoint). centers of the knn query.
+ Default: None.
+ transposed (bool, optional): whether the input tensors are
+ transposed. Should not explicitly use this keyword when
+ calling knn (=KNN.apply), just add the fourth param.
+ Default: False.
+
+ Returns:
+ Tensor: (B, k, npoint) tensor with the indices of
+ the features that form k-nearest neighbours.
+ """
+ assert (k > 0) & (k < 100), 'k should be in range(0, 100)'
+
+ if center_xyz is None:
+ center_xyz = xyz
+
+ if transposed:
+ xyz = xyz.transpose(2, 1).contiguous()
+ center_xyz = center_xyz.transpose(2, 1).contiguous()
+
+ assert xyz.is_contiguous() # [B, N, 3]
+ assert center_xyz.is_contiguous() # [B, npoint, 3]
+
+ center_xyz_device = center_xyz.get_device()
+ assert center_xyz_device == xyz.get_device(), \
+ 'center_xyz and xyz should be put on the same device'
+ if torch.cuda.current_device() != center_xyz_device:
+ torch.cuda.set_device(center_xyz_device)
+
+ B, npoint, _ = center_xyz.shape
+ N = xyz.shape[1]
+
+ idx = center_xyz.new_zeros((B, npoint, k)).int()
+ dist2 = center_xyz.new_zeros((B, npoint, k)).float()
+
+ ext_module.knn_forward(
+ xyz, center_xyz, idx, dist2, b=B, n=N, m=npoint, nsample=k)
+ # idx shape to [B, k, npoint]
+ idx = idx.transpose(2, 1).contiguous()
+ if torch.__version__ != 'parrots':
+ ctx.mark_non_differentiable(idx)
+ return idx
+
+ @staticmethod
+ def backward(ctx, a=None):
+ return None, None, None
+
+
+knn = KNN.apply
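+
+
+# Illustrative usage sketch (not part of the mmcv API): a hypothetical demo
+# helper assuming a CUDA build of the compiled `_ext` ops; shapes are
+# arbitrary examples.
+def _knn_demo():
+    """Find the 8 nearest points for each of 64 query centers."""
+    xyz = torch.randn(2, 1024, 3, device='cuda')
+    centers = torch.randn(2, 64, 3, device='cuda')
+    return knn(8, xyz, centers)  # indices of shape (B, k, npoint) = (2, 8, 64)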
diff --git a/annotator/uniformer/mmcv/ops/masked_conv.py b/annotator/uniformer/mmcv/ops/masked_conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd514cc204c1d571ea5dc7e74b038c0f477a008b
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/masked_conv.py
@@ -0,0 +1,111 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.nn.modules.utils import _pair
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+ '_ext', ['masked_im2col_forward', 'masked_col2im_forward'])
+
+
+class MaskedConv2dFunction(Function):
+
+ @staticmethod
+ def symbolic(g, features, mask, weight, bias, padding, stride):
+ return g.op(
+ 'mmcv::MMCVMaskedConv2d',
+ features,
+ mask,
+ weight,
+ bias,
+ padding_i=padding,
+ stride_i=stride)
+
+ @staticmethod
+ def forward(ctx, features, mask, weight, bias, padding=0, stride=1):
+ assert mask.dim() == 3 and mask.size(0) == 1
+ assert features.dim() == 4 and features.size(0) == 1
+ assert features.size()[2:] == mask.size()[1:]
+ pad_h, pad_w = _pair(padding)
+ stride_h, stride_w = _pair(stride)
+ if stride_h != 1 or stride_w != 1:
+            raise ValueError(
+                'Only stride 1 is supported in masked_conv2d currently.')
+ out_channel, in_channel, kernel_h, kernel_w = weight.size()
+
+ batch_size = features.size(0)
+ out_h = int(
+ math.floor((features.size(2) + 2 * pad_h -
+ (kernel_h - 1) - 1) / stride_h + 1))
+ out_w = int(
+ math.floor((features.size(3) + 2 * pad_w -
+                        (kernel_w - 1) - 1) / stride_w + 1))
+ mask_inds = torch.nonzero(mask[0] > 0, as_tuple=False)
+ output = features.new_zeros(batch_size, out_channel, out_h, out_w)
+ if mask_inds.numel() > 0:
+ mask_h_idx = mask_inds[:, 0].contiguous()
+ mask_w_idx = mask_inds[:, 1].contiguous()
+ data_col = features.new_zeros(in_channel * kernel_h * kernel_w,
+ mask_inds.size(0))
+ ext_module.masked_im2col_forward(
+ features,
+ mask_h_idx,
+ mask_w_idx,
+ data_col,
+ kernel_h=kernel_h,
+ kernel_w=kernel_w,
+ pad_h=pad_h,
+ pad_w=pad_w)
+
+            masked_output = torch.addmm(bias[:, None],
+                                        weight.view(out_channel, -1), data_col)
+ ext_module.masked_col2im_forward(
+ masked_output,
+ mask_h_idx,
+ mask_w_idx,
+ output,
+ height=out_h,
+ width=out_w,
+ channels=out_channel)
+ return output
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, grad_output):
+ return (None, ) * 5
+
+
+masked_conv2d = MaskedConv2dFunction.apply
+
+
+class MaskedConv2d(nn.Conv2d):
+ """A MaskedConv2d which inherits the official Conv2d.
+
+ The masked forward doesn't implement the backward function and only
+ supports the stride parameter to be 1 currently.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ bias=True):
+ super(MaskedConv2d,
+ self).__init__(in_channels, out_channels, kernel_size, stride,
+ padding, dilation, groups, bias)
+
+ def forward(self, input, mask=None):
+ if mask is None: # fallback to the normal Conv2d
+ return super(MaskedConv2d, self).forward(input)
+ else:
+ return masked_conv2d(input, mask, self.weight, self.bias,
+ self.padding)
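+
+
+# Illustrative usage sketch (not part of the mmcv API): a hypothetical demo
+# helper assuming a CUDA build of the compiled `_ext` ops. The convolution is
+# only evaluated where the mask is positive; batch size and stride must be 1.
+def _masked_conv2d_demo():
+    """Run MaskedConv2d on a random feature map with a random binary mask."""
+    feats = torch.randn(1, 8, 32, 32, device='cuda')
+    mask = (torch.rand(1, 32, 32, device='cuda') > 0.5).float()
+    conv = MaskedConv2d(8, 16, kernel_size=3, padding=1).cuda()
+    return conv(feats, mask)  # (1, 16, 32, 32)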
diff --git a/annotator/uniformer/mmcv/ops/merge_cells.py b/annotator/uniformer/mmcv/ops/merge_cells.py
new file mode 100644
index 0000000000000000000000000000000000000000..48ca8cc0a8aca8432835bd760c0403a3c35b34cf
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/merge_cells.py
@@ -0,0 +1,149 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import abstractmethod
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..cnn import ConvModule
+
+
+class BaseMergeCell(nn.Module):
+ """The basic class for cells used in NAS-FPN and NAS-FCOS.
+
+    BaseMergeCell takes two inputs. After applying convolution to them, they
+    are resized to the target size. They then go through binary_op, which
+    depends on the type of cell. If with_out_conv is True, the result goes
+    through another convolution layer.
+
+ Args:
+        fused_channels (int): number of input channels of the out_conv layer.
+        out_channels (int): number of output channels of the out_conv layer.
+ with_out_conv (bool): Whether to use out_conv layer
+ out_conv_cfg (dict): Config dict for convolution layer, which should
+ contain "groups", "kernel_size", "padding", "bias" to build
+ out_conv layer.
+ out_norm_cfg (dict): Config dict for normalization layer in out_conv.
+ out_conv_order (tuple): The order of conv/norm/activation layers in
+ out_conv.
+ with_input1_conv (bool): Whether to use convolution on input1.
+ with_input2_conv (bool): Whether to use convolution on input2.
+ input_conv_cfg (dict): Config dict for building input1_conv layer and
+ input2_conv layer, which is expected to contain the type of
+ convolution.
+ Default: None, which means using conv2d.
+ input_norm_cfg (dict): Config dict for normalization layer in
+ input1_conv and input2_conv layer. Default: None.
+ upsample_mode (str): Interpolation method used to resize the output
+ of input1_conv and input2_conv to target size. Currently, we
+ support ['nearest', 'bilinear']. Default: 'nearest'.
+ """
+
+ def __init__(self,
+ fused_channels=256,
+ out_channels=256,
+ with_out_conv=True,
+ out_conv_cfg=dict(
+ groups=1, kernel_size=3, padding=1, bias=True),
+ out_norm_cfg=None,
+ out_conv_order=('act', 'conv', 'norm'),
+ with_input1_conv=False,
+ with_input2_conv=False,
+ input_conv_cfg=None,
+ input_norm_cfg=None,
+ upsample_mode='nearest'):
+ super(BaseMergeCell, self).__init__()
+ assert upsample_mode in ['nearest', 'bilinear']
+ self.with_out_conv = with_out_conv
+ self.with_input1_conv = with_input1_conv
+ self.with_input2_conv = with_input2_conv
+ self.upsample_mode = upsample_mode
+
+ if self.with_out_conv:
+ self.out_conv = ConvModule(
+ fused_channels,
+ out_channels,
+ **out_conv_cfg,
+ norm_cfg=out_norm_cfg,
+ order=out_conv_order)
+
+ self.input1_conv = self._build_input_conv(
+ out_channels, input_conv_cfg,
+ input_norm_cfg) if with_input1_conv else nn.Sequential()
+ self.input2_conv = self._build_input_conv(
+ out_channels, input_conv_cfg,
+ input_norm_cfg) if with_input2_conv else nn.Sequential()
+
+ def _build_input_conv(self, channel, conv_cfg, norm_cfg):
+ return ConvModule(
+ channel,
+ channel,
+ 3,
+ padding=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ bias=True)
+
+ @abstractmethod
+ def _binary_op(self, x1, x2):
+ pass
+
+ def _resize(self, x, size):
+ if x.shape[-2:] == size:
+ return x
+ elif x.shape[-2:] < size:
+ return F.interpolate(x, size=size, mode=self.upsample_mode)
+ else:
+ assert x.shape[-2] % size[-2] == 0 and x.shape[-1] % size[-1] == 0
+ kernel_size = x.shape[-1] // size[-1]
+ x = F.max_pool2d(x, kernel_size=kernel_size, stride=kernel_size)
+ return x
+
+ def forward(self, x1, x2, out_size=None):
+ assert x1.shape[:2] == x2.shape[:2]
+ assert out_size is None or len(out_size) == 2
+ if out_size is None: # resize to larger one
+ out_size = max(x1.size()[2:], x2.size()[2:])
+
+ x1 = self.input1_conv(x1)
+ x2 = self.input2_conv(x2)
+
+ x1 = self._resize(x1, out_size)
+ x2 = self._resize(x2, out_size)
+
+ x = self._binary_op(x1, x2)
+ if self.with_out_conv:
+ x = self.out_conv(x)
+ return x
+
+
+class SumCell(BaseMergeCell):
+
+ def __init__(self, in_channels, out_channels, **kwargs):
+ super(SumCell, self).__init__(in_channels, out_channels, **kwargs)
+
+ def _binary_op(self, x1, x2):
+ return x1 + x2
+
+
+class ConcatCell(BaseMergeCell):
+
+ def __init__(self, in_channels, out_channels, **kwargs):
+ super(ConcatCell, self).__init__(in_channels * 2, out_channels,
+ **kwargs)
+
+ def _binary_op(self, x1, x2):
+ ret = torch.cat([x1, x2], dim=1)
+ return ret
+
+
+class GlobalPoolingCell(BaseMergeCell):
+
+ def __init__(self, in_channels=None, out_channels=None, **kwargs):
+ super().__init__(in_channels, out_channels, **kwargs)
+ self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
+
+ def _binary_op(self, x1, x2):
+ x2_att = self.global_pool(x2).sigmoid()
+ return x2 + x2_att * x1
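+
+
+# Illustrative usage sketch (not part of the mmcv API): a hypothetical demo
+# helper that runs on CPU. The smaller feature map is upsampled to the larger
+# size before the element-wise sum; shapes are arbitrary examples.
+def _sum_cell_demo():
+    """Fuse two feature maps of different resolutions with a SumCell."""
+    feat_a = torch.randn(1, 256, 32, 32)
+    feat_b = torch.randn(1, 256, 16, 16)
+    cell = SumCell(in_channels=256, out_channels=256)
+    return cell(feat_a, feat_b)  # (1, 256, 32, 32)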
diff --git a/annotator/uniformer/mmcv/ops/modulated_deform_conv.py b/annotator/uniformer/mmcv/ops/modulated_deform_conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..75559579cf053abcc99538606cbb88c723faf783
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/modulated_deform_conv.py
@@ -0,0 +1,282 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.nn.modules.utils import _pair, _single
+
+from annotator.uniformer.mmcv.utils import deprecated_api_warning
+from ..cnn import CONV_LAYERS
+from ..utils import ext_loader, print_log
+
+ext_module = ext_loader.load_ext(
+ '_ext',
+ ['modulated_deform_conv_forward', 'modulated_deform_conv_backward'])
+
+
+class ModulatedDeformConv2dFunction(Function):
+
+ @staticmethod
+ def symbolic(g, input, offset, mask, weight, bias, stride, padding,
+ dilation, groups, deform_groups):
+ input_tensors = [input, offset, mask, weight]
+ if bias is not None:
+ input_tensors.append(bias)
+ return g.op(
+ 'mmcv::MMCVModulatedDeformConv2d',
+ *input_tensors,
+ stride_i=stride,
+ padding_i=padding,
+ dilation_i=dilation,
+ groups_i=groups,
+ deform_groups_i=deform_groups)
+
+ @staticmethod
+ def forward(ctx,
+ input,
+ offset,
+ mask,
+ weight,
+ bias=None,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ deform_groups=1):
+ if input is not None and input.dim() != 4:
+            raise ValueError(
+                f'Expected 4D tensor as input, got {input.dim()}D tensor '
+                'instead.')
+ ctx.stride = _pair(stride)
+ ctx.padding = _pair(padding)
+ ctx.dilation = _pair(dilation)
+ ctx.groups = groups
+ ctx.deform_groups = deform_groups
+ ctx.with_bias = bias is not None
+ if not ctx.with_bias:
+ bias = input.new_empty(0) # fake tensor
+        # Since PyTorch 1.6.0, amp is used for fp16 mode. amp keeps the model
+        # weights in float32, but "offset" is automatically cast to float16
+        # by nn.Conv2d, leading to a dtype mismatch with the input (when it
+        # is float32) or the weight. The dtype of "offset" therefore tells us
+        # whether fp16/amp is in use, so we cast the weight and input to the
+        # same dtype to support fp16 and amp regardless of the PyTorch
+        # version.
+ input = input.type_as(offset)
+ weight = weight.type_as(input)
+ ctx.save_for_backward(input, offset, mask, weight, bias)
+ output = input.new_empty(
+ ModulatedDeformConv2dFunction._output_size(ctx, input, weight))
+ ctx._bufs = [input.new_empty(0), input.new_empty(0)]
+ ext_module.modulated_deform_conv_forward(
+ input,
+ weight,
+ bias,
+ ctx._bufs[0],
+ offset,
+ mask,
+ output,
+ ctx._bufs[1],
+ kernel_h=weight.size(2),
+ kernel_w=weight.size(3),
+ stride_h=ctx.stride[0],
+ stride_w=ctx.stride[1],
+ pad_h=ctx.padding[0],
+ pad_w=ctx.padding[1],
+ dilation_h=ctx.dilation[0],
+ dilation_w=ctx.dilation[1],
+ group=ctx.groups,
+ deformable_group=ctx.deform_groups,
+ with_bias=ctx.with_bias)
+ return output
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, grad_output):
+ input, offset, mask, weight, bias = ctx.saved_tensors
+ grad_input = torch.zeros_like(input)
+ grad_offset = torch.zeros_like(offset)
+ grad_mask = torch.zeros_like(mask)
+ grad_weight = torch.zeros_like(weight)
+ grad_bias = torch.zeros_like(bias)
+ grad_output = grad_output.contiguous()
+ ext_module.modulated_deform_conv_backward(
+ input,
+ weight,
+ bias,
+ ctx._bufs[0],
+ offset,
+ mask,
+ ctx._bufs[1],
+ grad_input,
+ grad_weight,
+ grad_bias,
+ grad_offset,
+ grad_mask,
+ grad_output,
+ kernel_h=weight.size(2),
+ kernel_w=weight.size(3),
+ stride_h=ctx.stride[0],
+ stride_w=ctx.stride[1],
+ pad_h=ctx.padding[0],
+ pad_w=ctx.padding[1],
+ dilation_h=ctx.dilation[0],
+ dilation_w=ctx.dilation[1],
+ group=ctx.groups,
+ deformable_group=ctx.deform_groups,
+ with_bias=ctx.with_bias)
+ if not ctx.with_bias:
+ grad_bias = None
+
+ return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias,
+ None, None, None, None, None)
+
+ @staticmethod
+ def _output_size(ctx, input, weight):
+ channels = weight.size(0)
+ output_size = (input.size(0), channels)
+ for d in range(input.dim() - 2):
+ in_size = input.size(d + 2)
+ pad = ctx.padding[d]
+ kernel = ctx.dilation[d] * (weight.size(d + 2) - 1) + 1
+ stride_ = ctx.stride[d]
+ output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
+ if not all(map(lambda s: s > 0, output_size)):
+ raise ValueError(
+ 'convolution input is too small (output would be ' +
+ 'x'.join(map(str, output_size)) + ')')
+ return output_size
+
+
+modulated_deform_conv2d = ModulatedDeformConv2dFunction.apply
+
+
+class ModulatedDeformConv2d(nn.Module):
+
+ @deprecated_api_warning({'deformable_groups': 'deform_groups'},
+ cls_name='ModulatedDeformConv2d')
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ deform_groups=1,
+ bias=True):
+ super(ModulatedDeformConv2d, self).__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = _pair(kernel_size)
+ self.stride = _pair(stride)
+ self.padding = _pair(padding)
+ self.dilation = _pair(dilation)
+ self.groups = groups
+ self.deform_groups = deform_groups
+ # enable compatibility with nn.Conv2d
+ self.transposed = False
+ self.output_padding = _single(0)
+
+ self.weight = nn.Parameter(
+ torch.Tensor(out_channels, in_channels // groups,
+ *self.kernel_size))
+ if bias:
+ self.bias = nn.Parameter(torch.Tensor(out_channels))
+ else:
+ self.register_parameter('bias', None)
+ self.init_weights()
+
+ def init_weights(self):
+ n = self.in_channels
+ for k in self.kernel_size:
+ n *= k
+ stdv = 1. / math.sqrt(n)
+ self.weight.data.uniform_(-stdv, stdv)
+ if self.bias is not None:
+ self.bias.data.zero_()
+
+ def forward(self, x, offset, mask):
+ return modulated_deform_conv2d(x, offset, mask, self.weight, self.bias,
+ self.stride, self.padding,
+ self.dilation, self.groups,
+ self.deform_groups)
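+
+# Illustrative usage sketch (assumes the compiled `_ext` op is available); the
+# caller must supply offset and mask tensors with 2 * deform_groups * kH * kW
+# and deform_groups * kH * kW channels respectively:
+# >>> dcn = ModulatedDeformConv2d(16, 32, kernel_size=3, padding=1)
+# >>> x = torch.randn(2, 16, 64, 64)
+# >>> offset = torch.randn(2, 2 * 1 * 3 * 3, 64, 64)
+# >>> mask = torch.sigmoid(torch.randn(2, 1 * 3 * 3, 64, 64))
+# >>> out = dcn(x, offset, mask)  # -> (2, 32, 64, 64)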
+
+
+@CONV_LAYERS.register_module('DCNv2')
+class ModulatedDeformConv2dPack(ModulatedDeformConv2d):
+ """A ModulatedDeformable Conv Encapsulation that acts as normal Conv
+ layers.
+
+ Args:
+ in_channels (int): Same as nn.Conv2d.
+ out_channels (int): Same as nn.Conv2d.
+ kernel_size (int or tuple[int]): Same as nn.Conv2d.
+ stride (int): Same as nn.Conv2d, but tuple is not supported.
+ padding (int): Same as nn.Conv2d, but tuple is not supported.
+ dilation (int): Same as nn.Conv2d, but tuple is not supported.
+ groups (int): Same as nn.Conv2d.
+ bias (bool or str): If specified as `auto`, it will be decided by
+ the norm_cfg. Bias will be set to True if norm_cfg is None,
+ otherwise False.
+ """
+
+ _version = 2
+
+ def __init__(self, *args, **kwargs):
+ super(ModulatedDeformConv2dPack, self).__init__(*args, **kwargs)
+ self.conv_offset = nn.Conv2d(
+ self.in_channels,
+ self.deform_groups * 3 * self.kernel_size[0] * self.kernel_size[1],
+ kernel_size=self.kernel_size,
+ stride=self.stride,
+ padding=self.padding,
+ dilation=self.dilation,
+ bias=True)
+ self.init_weights()
+
+ def init_weights(self):
+ super(ModulatedDeformConv2dPack, self).init_weights()
+ if hasattr(self, 'conv_offset'):
+ self.conv_offset.weight.data.zero_()
+ self.conv_offset.bias.data.zero_()
+
+ def forward(self, x):
+ out = self.conv_offset(x)
+ o1, o2, mask = torch.chunk(out, 3, dim=1)
+ offset = torch.cat((o1, o2), dim=1)
+ mask = torch.sigmoid(mask)
+ return modulated_deform_conv2d(x, offset, mask, self.weight, self.bias,
+ self.stride, self.padding,
+ self.dilation, self.groups,
+ self.deform_groups)
+
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+ missing_keys, unexpected_keys, error_msgs):
+ version = local_metadata.get('version', None)
+
+ if version is None or version < 2:
+ # the key is different in early versions
+ # In version < 2, ModulatedDeformConvPack
+ # loads previous benchmark models.
+ if (prefix + 'conv_offset.weight' not in state_dict
+ and prefix[:-1] + '_offset.weight' in state_dict):
+ state_dict[prefix + 'conv_offset.weight'] = state_dict.pop(
+ prefix[:-1] + '_offset.weight')
+ if (prefix + 'conv_offset.bias' not in state_dict
+ and prefix[:-1] + '_offset.bias' in state_dict):
+ state_dict[prefix +
+ 'conv_offset.bias'] = state_dict.pop(prefix[:-1] +
+ '_offset.bias')
+
+ if version is not None and version > 1:
+ print_log(
+ f'ModulatedDeformConvPack {prefix.rstrip(".")} is upgraded to '
+ 'version 2.',
+ logger='root')
+
+ super()._load_from_state_dict(state_dict, prefix, local_metadata,
+ strict, missing_keys, unexpected_keys,
+ error_msgs)
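+
+# Illustrative usage sketch of the packed variant registered as 'DCNv2': it
+# predicts offset and mask internally, so it is called like a plain nn.Conv2d
+# (assumes the compiled `_ext` op is available):
+# >>> dcn_pack = ModulatedDeformConv2dPack(16, 32, kernel_size=3, padding=1)
+# >>> x = torch.randn(2, 16, 64, 64)
+# >>> out = dcn_pack(x)  # -> (2, 32, 64, 64)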
diff --git a/annotator/uniformer/mmcv/ops/multi_scale_deform_attn.py b/annotator/uniformer/mmcv/ops/multi_scale_deform_attn.py
new file mode 100644
index 0000000000000000000000000000000000000000..c52dda18b41705705b47dd0e995b124048c16fba
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/multi_scale_deform_attn.py
@@ -0,0 +1,358 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+import warnings
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd.function import Function, once_differentiable
+
+from annotator.uniformer.mmcv import deprecated_api_warning
+from annotator.uniformer.mmcv.cnn import constant_init, xavier_init
+from annotator.uniformer.mmcv.cnn.bricks.registry import ATTENTION
+from annotator.uniformer.mmcv.runner import BaseModule
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+ '_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])
+
+
+class MultiScaleDeformableAttnFunction(Function):
+
+ @staticmethod
+ def forward(ctx, value, value_spatial_shapes, value_level_start_index,
+ sampling_locations, attention_weights, im2col_step):
+ """GPU version of multi-scale deformable attention.
+
+ Args:
+ value (Tensor): The value tensor with shape
+ (bs, num_keys, num_heads, embed_dims//num_heads).
+ value_spatial_shapes (Tensor): Spatial shape of
+ each feature map, has shape (num_levels, 2); the
+ last dimension 2 represents (h, w).
+ value_level_start_index (Tensor): The start index of each level,
+ has shape (num_levels, ).
+ sampling_locations (Tensor): The location of sampling points,
+ has shape
+ (bs, num_queries, num_heads, num_levels, num_points, 2);
+ the last dimension 2 represents (x, y).
+ attention_weights (Tensor): The weight of sampling points used
+ when calculating the attention, has shape
+ (bs, num_queries, num_heads, num_levels, num_points).
+ im2col_step (Tensor): The step used in image-to-column.
+
+ Returns:
+ Tensor: has shape (bs, num_queries, embed_dims)
+ """
+
+ ctx.im2col_step = im2col_step
+ output = ext_module.ms_deform_attn_forward(
+ value,
+ value_spatial_shapes,
+ value_level_start_index,
+ sampling_locations,
+ attention_weights,
+ im2col_step=ctx.im2col_step)
+ ctx.save_for_backward(value, value_spatial_shapes,
+ value_level_start_index, sampling_locations,
+ attention_weights)
+ return output
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, grad_output):
+ """GPU version of backward function.
+
+ Args:
+ grad_output (Tensor): Gradient
+ of output tensor of forward.
+
+ Returns:
+ Tuple[Tensor]: Gradient
+ of input tensors in forward.
+ """
+ value, value_spatial_shapes, value_level_start_index,\
+ sampling_locations, attention_weights = ctx.saved_tensors
+ grad_value = torch.zeros_like(value)
+ grad_sampling_loc = torch.zeros_like(sampling_locations)
+ grad_attn_weight = torch.zeros_like(attention_weights)
+
+ ext_module.ms_deform_attn_backward(
+ value,
+ value_spatial_shapes,
+ value_level_start_index,
+ sampling_locations,
+ attention_weights,
+ grad_output.contiguous(),
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight,
+ im2col_step=ctx.im2col_step)
+
+ return grad_value, None, None, \
+ grad_sampling_loc, grad_attn_weight, None
+
+
+def multi_scale_deformable_attn_pytorch(value, value_spatial_shapes,
+ sampling_locations, attention_weights):
+ """CPU version of multi-scale deformable attention.
+
+ Args:
+ value (Tensor): The value tensor with shape
+ (bs, num_keys, num_heads, embed_dims//num_heads).
+ value_spatial_shapes (Tensor): Spatial shape of
+ each feature map, has shape (num_levels, 2); the
+ last dimension 2 represents (h, w).
+ sampling_locations (Tensor): The location of sampling points,
+ has shape
+ (bs, num_queries, num_heads, num_levels, num_points, 2);
+ the last dimension 2 represents (x, y).
+ attention_weights (Tensor): The weight of sampling points used
+ when calculating the attention, has shape
+ (bs, num_queries, num_heads, num_levels, num_points).
+
+ Returns:
+ Tensor: has shape (bs, num_queries, embed_dims)
+ """
+
+ bs, _, num_heads, embed_dims = value.shape
+ _, num_queries, num_heads, num_levels, num_points, _ =\
+ sampling_locations.shape
+ value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes],
+ dim=1)
+ sampling_grids = 2 * sampling_locations - 1
+ sampling_value_list = []
+ for level, (H_, W_) in enumerate(value_spatial_shapes):
+ # bs, H_*W_, num_heads, embed_dims ->
+ # bs, H_*W_, num_heads*embed_dims ->
+ # bs, num_heads*embed_dims, H_*W_ ->
+ # bs*num_heads, embed_dims, H_, W_
+ value_l_ = value_list[level].flatten(2).transpose(1, 2).reshape(
+ bs * num_heads, embed_dims, H_, W_)
+ # bs, num_queries, num_heads, num_points, 2 ->
+ # bs, num_heads, num_queries, num_points, 2 ->
+ # bs*num_heads, num_queries, num_points, 2
+ sampling_grid_l_ = sampling_grids[:, :, :,
+ level].transpose(1, 2).flatten(0, 1)
+ # bs*num_heads, embed_dims, num_queries, num_points
+ sampling_value_l_ = F.grid_sample(
+ value_l_,
+ sampling_grid_l_,
+ mode='bilinear',
+ padding_mode='zeros',
+ align_corners=False)
+ sampling_value_list.append(sampling_value_l_)
+ # (bs, num_queries, num_heads, num_levels, num_points) ->
+ # (bs, num_heads, num_queries, num_levels, num_points) ->
+ # (bs, num_heads, 1, num_queries, num_levels*num_points)
+ attention_weights = attention_weights.transpose(1, 2).reshape(
+ bs * num_heads, 1, num_queries, num_levels * num_points)
+ output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) *
+ attention_weights).sum(-1).view(bs, num_heads * embed_dims,
+ num_queries)
+ return output.transpose(1, 2).contiguous()
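+
+# Illustrative shape sketch for this pure-PyTorch fallback (no compiled
+# extension required); the per-level sizes must sum to num_keys:
+# >>> bs, num_heads, head_dims, num_queries = 2, 8, 32, 100
+# >>> spatial_shapes = torch.tensor([[16, 16], [8, 8]])  # num_levels = 2
+# >>> num_keys = int((spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum())
+# >>> value = torch.randn(bs, num_keys, num_heads, head_dims)
+# >>> locs = torch.rand(bs, num_queries, num_heads, 2, 4, 2)  # 4 points
+# >>> weights = torch.rand(bs, num_queries, num_heads, 2, 4).softmax(-1)
+# >>> out = multi_scale_deformable_attn_pytorch(
+# ...     value, spatial_shapes, locs, weights)
+# >>> out.shape  # torch.Size([2, 100, 256]), i.e. num_heads * head_dims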
+
+
+@ATTENTION.register_module()
+class MultiScaleDeformableAttention(BaseModule):
+ """An attention module used in Deformable-Detr.
+
+ `Deformable DETR: Deformable Transformers for End-to-End Object Detection
+ <https://arxiv.org/abs/2010.04159>`_.
+
+ Args:
+ embed_dims (int): The embedding dimension of Attention.
+ Default: 256.
+ num_heads (int): Parallel attention heads. Default: 8.
+ num_levels (int): The number of feature map used in
+ Attention. Default: 4.
+ num_points (int): The number of sampling points for
+ each query in each head. Default: 4.
+ im2col_step (int): The step used in image_to_column.
+ Default: 64.
+ dropout (float): Dropout rate applied to the attention output
+ before the residual connection with `identity`. Default: 0.1.
+ batch_first (bool): Whether Key, Query and Value have shape
+ (batch, n, embed_dim)
+ instead of (n, batch, embed_dim). Default: False.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: None.
+ init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+ Default: None.
+ """
+
+ def __init__(self,
+ embed_dims=256,
+ num_heads=8,
+ num_levels=4,
+ num_points=4,
+ im2col_step=64,
+ dropout=0.1,
+ batch_first=False,
+ norm_cfg=None,
+ init_cfg=None):
+ super().__init__(init_cfg)
+ if embed_dims % num_heads != 0:
+ raise ValueError(f'embed_dims must be divisible by num_heads, '
+ f'but got {embed_dims} and {num_heads}')
+ dim_per_head = embed_dims // num_heads
+ self.norm_cfg = norm_cfg
+ self.dropout = nn.Dropout(dropout)
+ self.batch_first = batch_first
+
+ # you'd better set dim_per_head to a power of 2
+ # which is more efficient in the CUDA implementation
+ def _is_power_of_2(n):
+ if (not isinstance(n, int)) or (n < 0):
+ raise ValueError(
+ 'invalid input for _is_power_of_2: {} (type: {})'.format(
+ n, type(n)))
+ return (n & (n - 1) == 0) and n != 0
+
+ if not _is_power_of_2(dim_per_head):
+ warnings.warn(
+ "You'd better set embed_dims in "
+ 'MultiScaleDeformAttention to make '
+ 'the dimension of each attention head a power of 2 '
+ 'which is more efficient in our CUDA implementation.')
+
+ self.im2col_step = im2col_step
+ self.embed_dims = embed_dims
+ self.num_levels = num_levels
+ self.num_heads = num_heads
+ self.num_points = num_points
+ self.sampling_offsets = nn.Linear(
+ embed_dims, num_heads * num_levels * num_points * 2)
+ self.attention_weights = nn.Linear(embed_dims,
+ num_heads * num_levels * num_points)
+ self.value_proj = nn.Linear(embed_dims, embed_dims)
+ self.output_proj = nn.Linear(embed_dims, embed_dims)
+ self.init_weights()
+
+ def init_weights(self):
+ """Default initialization for Parameters of Module."""
+ constant_init(self.sampling_offsets, 0.)
+ thetas = torch.arange(
+ self.num_heads,
+ dtype=torch.float32) * (2.0 * math.pi / self.num_heads)
+ grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
+ grid_init = (grid_init /
+ grid_init.abs().max(-1, keepdim=True)[0]).view(
+ self.num_heads, 1, 1,
+ 2).repeat(1, self.num_levels, self.num_points, 1)
+ for i in range(self.num_points):
+ grid_init[:, :, i, :] *= i + 1
+
+ self.sampling_offsets.bias.data = grid_init.view(-1)
+ constant_init(self.attention_weights, val=0., bias=0.)
+ xavier_init(self.value_proj, distribution='uniform', bias=0.)
+ xavier_init(self.output_proj, distribution='uniform', bias=0.)
+ self._is_init = True
+
+ @deprecated_api_warning({'residual': 'identity'},
+ cls_name='MultiScaleDeformableAttention')
+ def forward(self,
+ query,
+ key=None,
+ value=None,
+ identity=None,
+ query_pos=None,
+ key_padding_mask=None,
+ reference_points=None,
+ spatial_shapes=None,
+ level_start_index=None,
+ **kwargs):
+ """Forward Function of MultiScaleDeformAttention.
+
+ Args:
+ query (Tensor): Query of Transformer with shape
+ (num_query, bs, embed_dims).
+ key (Tensor): The key tensor with shape
+ `(num_key, bs, embed_dims)`.
+ value (Tensor): The value tensor with shape
+ `(num_key, bs, embed_dims)`.
+ identity (Tensor): The tensor used for addition, with the
+ same shape as `query`. Default None. If None,
+ `query` will be used.
+ query_pos (Tensor): The positional encoding for `query`.
+ Default: None.
+ key_pos (Tensor): The positional encoding for `key`. Default
+ None.
+ reference_points (Tensor): The normalized reference
+ points with shape (bs, num_query, num_levels, 2);
+ all elements are in the range [0, 1], top-left (0, 0),
+ bottom-right (1, 1), including the padding area.
+ Or (N, Length_{query}, num_levels, 4), where the
+ additional two dimensions (w, h) form the
+ reference boxes.
+ key_padding_mask (Tensor): ByteTensor for `query`, with
+ shape [bs, num_key].
+ spatial_shapes (Tensor): Spatial shape of features in
+ different levels. With shape (num_levels, 2),
+ last dimension represents (h, w).
+ level_start_index (Tensor): The start index of each level.
+ A tensor has shape ``(num_levels, )`` and can be represented
+ as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
+
+ Returns:
+ Tensor: forwarded results with shape [num_query, bs, embed_dims].
+ """
+
+ if value is None:
+ value = query
+
+ if identity is None:
+ identity = query
+ if query_pos is not None:
+ query = query + query_pos
+ if not self.batch_first:
+ # change to (bs, num_query ,embed_dims)
+ query = query.permute(1, 0, 2)
+ value = value.permute(1, 0, 2)
+
+ bs, num_query, _ = query.shape
+ bs, num_value, _ = value.shape
+ assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value
+
+ value = self.value_proj(value)
+ if key_padding_mask is not None:
+ value = value.masked_fill(key_padding_mask[..., None], 0.0)
+ value = value.view(bs, num_value, self.num_heads, -1)
+ sampling_offsets = self.sampling_offsets(query).view(
+ bs, num_query, self.num_heads, self.num_levels, self.num_points, 2)
+ attention_weights = self.attention_weights(query).view(
+ bs, num_query, self.num_heads, self.num_levels * self.num_points)
+ attention_weights = attention_weights.softmax(-1)
+
+ attention_weights = attention_weights.view(bs, num_query,
+ self.num_heads,
+ self.num_levels,
+ self.num_points)
+ if reference_points.shape[-1] == 2:
+ offset_normalizer = torch.stack(
+ [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
+ sampling_locations = reference_points[:, :, None, :, None, :] \
+ + sampling_offsets \
+ / offset_normalizer[None, None, None, :, None, :]
+ elif reference_points.shape[-1] == 4:
+ sampling_locations = reference_points[:, :, None, :, None, :2] \
+ + sampling_offsets / self.num_points \
+ * reference_points[:, :, None, :, None, 2:] \
+ * 0.5
+ else:
+ raise ValueError(
+ f'Last dim of reference_points must be'
+ f' 2 or 4, but get {reference_points.shape[-1]} instead.')
+ if torch.cuda.is_available() and value.is_cuda:
+ output = MultiScaleDeformableAttnFunction.apply(
+ value, spatial_shapes, level_start_index, sampling_locations,
+ attention_weights, self.im2col_step)
+ else:
+ output = multi_scale_deformable_attn_pytorch(
+ value, spatial_shapes, sampling_locations, attention_weights)
+
+ output = self.output_proj(output)
+
+ if not self.batch_first:
+ # (num_query, bs ,embed_dims)
+ output = output.permute(1, 0, 2)
+
+ return self.dropout(output) + identity
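+
+# Illustrative usage sketch (shapes only; with the default batch_first=False
+# the query is (num_query, bs, embed_dims), and on CPU the pure-PyTorch path
+# above is used):
+# >>> attn = MultiScaleDeformableAttention(embed_dims=256, num_levels=2)
+# >>> spatial_shapes = torch.tensor([[16, 16], [8, 8]])
+# >>> level_start_index = torch.tensor([0, 256])
+# >>> query = torch.randn(100, 2, 256)
+# >>> value = torch.randn(16 * 16 + 8 * 8, 2, 256)
+# >>> reference_points = torch.rand(2, 100, 2, 2)
+# >>> out = attn(query, value=value, reference_points=reference_points,
+# ...            spatial_shapes=spatial_shapes,
+# ...            level_start_index=level_start_index)  # -> (100, 2, 256)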
diff --git a/annotator/uniformer/mmcv/ops/nms.py b/annotator/uniformer/mmcv/ops/nms.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d9634281f486ab284091786886854c451368052
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/nms.py
@@ -0,0 +1,417 @@
+import os
+
+import numpy as np
+import torch
+
+from annotator.uniformer.mmcv.utils import deprecated_api_warning
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+ '_ext', ['nms', 'softnms', 'nms_match', 'nms_rotated'])
+
+
+# This function is modified from: https://github.com/pytorch/vision/
+class NMSop(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, bboxes, scores, iou_threshold, offset, score_threshold,
+ max_num):
+ is_filtering_by_score = score_threshold > 0
+ if is_filtering_by_score:
+ valid_mask = scores > score_threshold
+ bboxes, scores = bboxes[valid_mask], scores[valid_mask]
+ valid_inds = torch.nonzero(
+ valid_mask, as_tuple=False).squeeze(dim=1)
+
+ inds = ext_module.nms(
+ bboxes, scores, iou_threshold=float(iou_threshold), offset=offset)
+
+ if max_num > 0:
+ inds = inds[:max_num]
+ if is_filtering_by_score:
+ inds = valid_inds[inds]
+ return inds
+
+ @staticmethod
+ def symbolic(g, bboxes, scores, iou_threshold, offset, score_threshold,
+ max_num):
+ from ..onnx import is_custom_op_loaded
+ has_custom_op = is_custom_op_loaded()
+ # TensorRT nms plugin is aligned with original nms in ONNXRuntime
+ is_trt_backend = os.environ.get('ONNX_BACKEND') == 'MMCVTensorRT'
+ if has_custom_op and (not is_trt_backend):
+ return g.op(
+ 'mmcv::NonMaxSuppression',
+ bboxes,
+ scores,
+ iou_threshold_f=float(iou_threshold),
+ offset_i=int(offset))
+ else:
+ from torch.onnx.symbolic_opset9 import select, squeeze, unsqueeze
+ from ..onnx.onnx_utils.symbolic_helper import _size_helper
+
+ boxes = unsqueeze(g, bboxes, 0)
+ scores = unsqueeze(g, unsqueeze(g, scores, 0), 0)
+
+ if max_num > 0:
+ max_num = g.op(
+ 'Constant',
+ value_t=torch.tensor(max_num, dtype=torch.long))
+ else:
+ dim = g.op('Constant', value_t=torch.tensor(0))
+ max_num = _size_helper(g, bboxes, dim)
+ max_output_per_class = max_num
+ iou_threshold = g.op(
+ 'Constant',
+ value_t=torch.tensor([iou_threshold], dtype=torch.float))
+ score_threshold = g.op(
+ 'Constant',
+ value_t=torch.tensor([score_threshold], dtype=torch.float))
+ nms_out = g.op('NonMaxSuppression', boxes, scores,
+ max_output_per_class, iou_threshold,
+ score_threshold)
+ return squeeze(
+ g,
+ select(
+ g, nms_out, 1,
+ g.op(
+ 'Constant',
+ value_t=torch.tensor([2], dtype=torch.long))), 1)
+
+
+class SoftNMSop(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, boxes, scores, iou_threshold, sigma, min_score, method,
+ offset):
+ dets = boxes.new_empty((boxes.size(0), 5), device='cpu')
+ inds = ext_module.softnms(
+ boxes.cpu(),
+ scores.cpu(),
+ dets.cpu(),
+ iou_threshold=float(iou_threshold),
+ sigma=float(sigma),
+ min_score=float(min_score),
+ method=int(method),
+ offset=int(offset))
+ return dets, inds
+
+ @staticmethod
+ def symbolic(g, boxes, scores, iou_threshold, sigma, min_score, method,
+ offset):
+ from packaging import version
+ assert version.parse(torch.__version__) >= version.parse('1.7.0')
+ nms_out = g.op(
+ 'mmcv::SoftNonMaxSuppression',
+ boxes,
+ scores,
+ iou_threshold_f=float(iou_threshold),
+ sigma_f=float(sigma),
+ min_score_f=float(min_score),
+ method_i=int(method),
+ offset_i=int(offset),
+ outputs=2)
+ return nms_out
+
+
+@deprecated_api_warning({'iou_thr': 'iou_threshold'})
+def nms(boxes, scores, iou_threshold, offset=0, score_threshold=0, max_num=-1):
+ """Dispatch to either CPU or GPU NMS implementations.
+
+ The input can be either a torch tensor or a numpy array. GPU NMS will be
+ used if the input is a gpu tensor, otherwise CPU NMS
+ will be used. The returned type will always be the same as the inputs.
+
+ Arguments:
+ boxes (torch.Tensor or np.ndarray): boxes in shape (N, 4).
+ scores (torch.Tensor or np.ndarray): scores in shape (N, ).
+ iou_threshold (float): IoU threshold for NMS.
+ offset (int, 0 or 1): boxes' width or height is (x2 - x1 + offset).
+ score_threshold (float): score threshold for NMS.
+ max_num (int): maximum number of boxes after NMS.
+
+ Returns:
+ tuple: kept dets (boxes and scores) and indices, which are always of \
+ the same data type as the input.
+
+ Example:
+ >>> boxes = np.array([[49.1, 32.4, 51.0, 35.9],
+ >>> [49.3, 32.9, 51.0, 35.3],
+ >>> [49.2, 31.8, 51.0, 35.4],
+ >>> [35.1, 11.5, 39.1, 15.7],
+ >>> [35.6, 11.8, 39.3, 14.2],
+ >>> [35.3, 11.5, 39.9, 14.5],
+ >>> [35.2, 11.7, 39.7, 15.7]], dtype=np.float32)
+ >>> scores = np.array([0.9, 0.9, 0.5, 0.5, 0.5, 0.4, 0.3],\
+ dtype=np.float32)
+ >>> iou_threshold = 0.6
+ >>> dets, inds = nms(boxes, scores, iou_threshold)
+ >>> assert len(inds) == len(dets) == 3
+ """
+ assert isinstance(boxes, (torch.Tensor, np.ndarray))
+ assert isinstance(scores, (torch.Tensor, np.ndarray))
+ is_numpy = False
+ if isinstance(boxes, np.ndarray):
+ is_numpy = True
+ boxes = torch.from_numpy(boxes)
+ if isinstance(scores, np.ndarray):
+ scores = torch.from_numpy(scores)
+ assert boxes.size(1) == 4
+ assert boxes.size(0) == scores.size(0)
+ assert offset in (0, 1)
+
+ if torch.__version__ == 'parrots':
+ indata_list = [boxes, scores]
+ indata_dict = {
+ 'iou_threshold': float(iou_threshold),
+ 'offset': int(offset)
+ }
+ inds = ext_module.nms(*indata_list, **indata_dict)
+ else:
+ inds = NMSop.apply(boxes, scores, iou_threshold, offset,
+ score_threshold, max_num)
+ dets = torch.cat((boxes[inds], scores[inds].reshape(-1, 1)), dim=1)
+ if is_numpy:
+ dets = dets.cpu().numpy()
+ inds = inds.cpu().numpy()
+ return dets, inds
+
+
+@deprecated_api_warning({'iou_thr': 'iou_threshold'})
+def soft_nms(boxes,
+ scores,
+ iou_threshold=0.3,
+ sigma=0.5,
+ min_score=1e-3,
+ method='linear',
+ offset=0):
+ """Dispatch to only CPU Soft NMS implementations.
+
+ The input can be either a torch tensor or numpy array.
+ The returned type will always be the same as inputs.
+
+ Arguments:
+ boxes (torch.Tensor or np.ndarray): boxes in shape (N, 4).
+ scores (torch.Tensor or np.ndarray): scores in shape (N, ).
+ iou_threshold (float): IoU threshold for NMS.
+ sigma (float): hyperparameter for gaussian method
+ min_score (float): score filter threshold
+ method (str): either 'linear' or 'gaussian'
+ offset (int, 0 or 1): boxes' width or height is (x2 - x1 + offset).
+
+ Returns:
+ tuple: kept dets (boxes and scores) and indices, which are always of \
+ the same data type as the input.
+
+ Example:
+ >>> boxes = np.array([[4., 3., 5., 3.],
+ >>> [4., 3., 5., 4.],
+ >>> [3., 1., 3., 1.],
+ >>> [3., 1., 3., 1.],
+ >>> [3., 1., 3., 1.],
+ >>> [3., 1., 3., 1.]], dtype=np.float32)
+ >>> scores = np.array([0.9, 0.9, 0.5, 0.5, 0.4, 0.0], dtype=np.float32)
+ >>> iou_threshold = 0.6
+ >>> dets, inds = soft_nms(boxes, scores, iou_threshold, sigma=0.5)
+ >>> assert len(inds) == len(dets) == 5
+ """
+
+ assert isinstance(boxes, (torch.Tensor, np.ndarray))
+ assert isinstance(scores, (torch.Tensor, np.ndarray))
+ is_numpy = False
+ if isinstance(boxes, np.ndarray):
+ is_numpy = True
+ boxes = torch.from_numpy(boxes)
+ if isinstance(scores, np.ndarray):
+ scores = torch.from_numpy(scores)
+ assert boxes.size(1) == 4
+ assert boxes.size(0) == scores.size(0)
+ assert offset in (0, 1)
+ method_dict = {'naive': 0, 'linear': 1, 'gaussian': 2}
+ assert method in method_dict.keys()
+
+ if torch.__version__ == 'parrots':
+ dets = boxes.new_empty((boxes.size(0), 5), device='cpu')
+ indata_list = [boxes.cpu(), scores.cpu(), dets.cpu()]
+ indata_dict = {
+ 'iou_threshold': float(iou_threshold),
+ 'sigma': float(sigma),
+ 'min_score': min_score,
+ 'method': method_dict[method],
+ 'offset': int(offset)
+ }
+ inds = ext_module.softnms(*indata_list, **indata_dict)
+ else:
+ dets, inds = SoftNMSop.apply(boxes.cpu(), scores.cpu(),
+ float(iou_threshold), float(sigma),
+ float(min_score), method_dict[method],
+ int(offset))
+
+ dets = dets[:inds.size(0)]
+
+ if is_numpy:
+ dets = dets.cpu().numpy()
+ inds = inds.cpu().numpy()
+ return dets, inds
+ else:
+ return dets.to(device=boxes.device), inds.to(device=boxes.device)
+
+
+def batched_nms(boxes, scores, idxs, nms_cfg, class_agnostic=False):
+ """Performs non-maximum suppression in a batched fashion.
+
+ Modified from https://github.com/pytorch/vision/blob
+ /505cd6957711af790211896d32b40291bea1bc21/torchvision/ops/boxes.py#L39.
+ In order to perform NMS independently per class, we add an offset to all
+ the boxes. The offset is dependent only on the class idx, and is large
+ enough so that boxes from different classes do not overlap.
+
+ Arguments:
+ boxes (torch.Tensor): boxes in shape (N, 4).
+ scores (torch.Tensor): scores in shape (N, ).
+ idxs (torch.Tensor): each index value corresponds to a bbox cluster,
+ and NMS will not be applied between elements of different idxs,
+ shape (N, ).
+ nms_cfg (dict): specify nms type and other parameters like iou_thr.
+ Possible keys includes the following.
+
+ - iou_thr (float): IoU threshold used for NMS.
+ - split_thr (float): threshold number of boxes. In some cases the
+ number of boxes is large (e.g., 200k). To avoid OOM during
+ training, the users could set `split_thr` to a small value.
+ If the number of boxes is greater than the threshold, it will
+ perform NMS on each group of boxes separately and sequentially.
+ Defaults to 10000.
+ class_agnostic (bool): if true, nms is class agnostic,
+ i.e. IoU thresholding happens over all boxes,
+ regardless of the predicted class.
+
+ Returns:
+ tuple: kept dets and indices.
+ """
+ nms_cfg_ = nms_cfg.copy()
+ class_agnostic = nms_cfg_.pop('class_agnostic', class_agnostic)
+ if class_agnostic:
+ boxes_for_nms = boxes
+ else:
+ max_coordinate = boxes.max()
+ offsets = idxs.to(boxes) * (max_coordinate + torch.tensor(1).to(boxes))
+ boxes_for_nms = boxes + offsets[:, None]
+
+ nms_type = nms_cfg_.pop('type', 'nms')
+ nms_op = eval(nms_type)
+
+ split_thr = nms_cfg_.pop('split_thr', 10000)
+ # Won't split to multiple nms nodes when exporting to onnx
+ if boxes_for_nms.shape[0] < split_thr or torch.onnx.is_in_onnx_export():
+ dets, keep = nms_op(boxes_for_nms, scores, **nms_cfg_)
+ boxes = boxes[keep]
+ # -1 indexing works abnormally in TensorRT, so index the score
+ # column explicitly. This assumes `dets` has shape (N, 5) with
+ # the score in the last column.
+ # TODO: a more elegant way to handle the dimension issue.
+ # Some types of nms reweight the scores, such as SoftNMS.
+ scores = dets[:, 4]
+ else:
+ max_num = nms_cfg_.pop('max_num', -1)
+ total_mask = scores.new_zeros(scores.size(), dtype=torch.bool)
+ # Some types of nms reweight the scores, such as SoftNMS.
+ scores_after_nms = scores.new_zeros(scores.size())
+ for id in torch.unique(idxs):
+ mask = (idxs == id).nonzero(as_tuple=False).view(-1)
+ dets, keep = nms_op(boxes_for_nms[mask], scores[mask], **nms_cfg_)
+ total_mask[mask[keep]] = True
+ scores_after_nms[mask[keep]] = dets[:, -1]
+ keep = total_mask.nonzero(as_tuple=False).view(-1)
+
+ scores, inds = scores_after_nms[keep].sort(descending=True)
+ keep = keep[inds]
+ boxes = boxes[keep]
+
+ if max_num > 0:
+ keep = keep[:max_num]
+ boxes = boxes[:max_num]
+ scores = scores[:max_num]
+
+ return torch.cat([boxes, scores[:, None]], -1), keep
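+
+# Illustrative usage sketch (assumes the compiled `_ext` extension):
+# class-aware NMS where the variant is picked through ``nms_cfg``:
+# >>> boxes = torch.rand(100, 4) * 100
+# >>> boxes[:, 2:] += boxes[:, :2]  # ensure x2 > x1 and y2 > y1
+# >>> scores = torch.rand(100)
+# >>> idxs = torch.randint(0, 3, (100,))  # class index of each box
+# >>> dets, keep = batched_nms(boxes, scores, idxs,
+# ...                          dict(type='nms', iou_threshold=0.5))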
+
+
+def nms_match(dets, iou_threshold):
+ """Matched dets into different groups by NMS.
+
+ NMS match is similar to NMS, but when a bbox is suppressed, nms match will
+ record the indices of the suppressed bboxes and form a group with the index
+ of the kept bbox. In each group, indices are sorted in score order.
+
+ Arguments:
+ dets (torch.Tensor | np.ndarray): Det boxes with scores, shape (N, 5).
+ iou_threshold (float): IoU threshold for NMS.
+
+ Returns:
+ List[torch.Tensor | np.ndarray]: The outer list corresponds to different
+ matched groups; the inner Tensor contains the indices for a group
+ in score order.
+ """
+ if dets.shape[0] == 0:
+ matched = []
+ else:
+ assert dets.shape[-1] == 5, 'inputs dets.shape should be (N, 5), ' \
+ f'but get {dets.shape}'
+ if isinstance(dets, torch.Tensor):
+ dets_t = dets.detach().cpu()
+ else:
+ dets_t = torch.from_numpy(dets)
+ indata_list = [dets_t]
+ indata_dict = {'iou_threshold': float(iou_threshold)}
+ matched = ext_module.nms_match(*indata_list, **indata_dict)
+ if torch.__version__ == 'parrots':
+ matched = matched.tolist()
+
+ if isinstance(dets, torch.Tensor):
+ return [dets.new_tensor(m, dtype=torch.long) for m in matched]
+ else:
+ # np.int was removed in recent NumPy releases; use the builtin int
+ return [np.array(m, dtype=int) for m in matched]
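+
+# Illustrative usage sketch (assumes the compiled `_ext` extension): group
+# overlapping detections instead of discarding the suppressed ones:
+# >>> dets = torch.tensor([[49.1, 32.4, 51.0, 35.9, 0.9],
+# ...                      [49.3, 32.9, 51.0, 35.3, 0.6],
+# ...                      [35.1, 11.5, 39.1, 15.7, 0.5]])
+# >>> groups = nms_match(dets, iou_threshold=0.6)
+# >>> # e.g. the two overlapping boxes form one group, the third its own group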
+
+
+def nms_rotated(dets, scores, iou_threshold, labels=None):
+ """Performs non-maximum suppression (NMS) on the rotated boxes according to
+ their intersection-over-union (IoU).
+
+ Rotated NMS iteratively removes lower scoring rotated boxes which have an
+ IoU greater than iou_threshold with another (higher scoring) rotated box.
+
+ Args:
+ dets (Tensor): Rotated boxes in shape (N, 5). They are expected to \
+ be in (x_ctr, y_ctr, width, height, angle_radian) format.
+ scores (Tensor): scores in shape (N, ).
+ iou_threshold (float): IoU threshold for NMS.
+ labels (Tensor): boxes' label in shape (N,).
+
+ Returns:
+ tuple: kept dets (boxes and scores) and indices, which are always of \
+ the same data type as the input.
+ """
+ if dets.shape[0] == 0:
+ return dets, None
+ multi_label = labels is not None
+ if multi_label:
+ dets_wl = torch.cat((dets, labels.unsqueeze(1)), 1)
+ else:
+ dets_wl = dets
+ _, order = scores.sort(0, descending=True)
+ dets_sorted = dets_wl.index_select(0, order)
+
+ if torch.__version__ == 'parrots':
+ keep_inds = ext_module.nms_rotated(
+ dets_wl,
+ scores,
+ order,
+ dets_sorted,
+ iou_threshold=iou_threshold,
+ multi_label=multi_label)
+ else:
+ keep_inds = ext_module.nms_rotated(dets_wl, scores, order, dets_sorted,
+ iou_threshold, multi_label)
+ dets = torch.cat((dets[keep_inds], scores[keep_inds].reshape(-1, 1)),
+ dim=1)
+ return dets, keep_inds
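+
+# Illustrative usage sketch (assumes the compiled `_ext` extension); boxes are
+# given as (x_ctr, y_ctr, width, height, angle_radian):
+# >>> dets = torch.tensor([[50., 50., 10., 20., 0.0],
+# ...                      [50., 50., 10., 20., 0.1],
+# ...                      [10., 10., 4., 4., 0.0]])
+# >>> scores = torch.tensor([0.9, 0.8, 0.7])
+# >>> kept, keep_inds = nms_rotated(dets, scores, iou_threshold=0.5)
+# >>> kept.shape[-1]  # 6: the five box parameters plus the score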
diff --git a/annotator/uniformer/mmcv/ops/pixel_group.py b/annotator/uniformer/mmcv/ops/pixel_group.py
new file mode 100644
index 0000000000000000000000000000000000000000..2143c75f835a467c802fc3c37ecd3ac0f85bcda4
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/pixel_group.py
@@ -0,0 +1,75 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import torch
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', ['pixel_group'])
+
+
+def pixel_group(score, mask, embedding, kernel_label, kernel_contour,
+ kernel_region_num, distance_threshold):
+ """Group pixels into text instances, which is widely used in text
+ detection methods.
+
+ Arguments:
+ score (np.array or Tensor): The foreground score with size hxw.
+ mask (np.array or Tensor): The foreground mask with size hxw.
+ embedding (np.array or Tensor): The embedding with size hxwxc to
+ distinguish instances.
+ kernel_label (np.array or Tensor): The instance kernel index with
+ size hxw.
+ kernel_contour (np.array or Tensor): The kernel contour with size hxw.
+ kernel_region_num (int): The instance kernel region number.
+ distance_threshold (float): The embedding distance threshold between
+ kernel and pixel in one instance.
+
+ Returns:
+ pixel_assignment (List[List[float]]): The instance coordinate list.
+ Each element consists of averaged confidence, pixel number, and
+ coordinates (x_i, y_i for all pixels) in order.
+ """
+ assert isinstance(score, (torch.Tensor, np.ndarray))
+ assert isinstance(mask, (torch.Tensor, np.ndarray))
+ assert isinstance(embedding, (torch.Tensor, np.ndarray))
+ assert isinstance(kernel_label, (torch.Tensor, np.ndarray))
+ assert isinstance(kernel_contour, (torch.Tensor, np.ndarray))
+ assert isinstance(kernel_region_num, int)
+ assert isinstance(distance_threshold, float)
+
+ if isinstance(score, np.ndarray):
+ score = torch.from_numpy(score)
+ if isinstance(mask, np.ndarray):
+ mask = torch.from_numpy(mask)
+ if isinstance(embedding, np.ndarray):
+ embedding = torch.from_numpy(embedding)
+ if isinstance(kernel_label, np.ndarray):
+ kernel_label = torch.from_numpy(kernel_label)
+ if isinstance(kernel_contour, np.ndarray):
+ kernel_contour = torch.from_numpy(kernel_contour)
+
+ if torch.__version__ == 'parrots':
+ label = ext_module.pixel_group(
+ score,
+ mask,
+ embedding,
+ kernel_label,
+ kernel_contour,
+ kernel_region_num=kernel_region_num,
+ distance_threshold=distance_threshold)
+ label = label.tolist()
+ label = label[0]
+ list_index = kernel_region_num
+ pixel_assignment = []
+ for x in range(kernel_region_num):
+ pixel_assignment.append(
+ np.array(
+ label[list_index:list_index + int(label[x])],
+ dtype=float))  # np.float is removed in recent NumPy
+ list_index = list_index + int(label[x])
+ else:
+ pixel_assignment = ext_module.pixel_group(score, mask, embedding,
+ kernel_label, kernel_contour,
+ kernel_region_num,
+ distance_threshold)
+ return pixel_assignment
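+
+# Illustrative usage sketch with hypothetical shapes (assumes the compiled
+# `_ext` extension); here h = w = 32 and the embedding has 4 channels:
+# >>> score = torch.rand(32, 32)
+# >>> mask = score > 0.5
+# >>> embedding = torch.rand(32, 32, 4)
+# >>> kernel_label = torch.randint(0, 2, (32, 32), dtype=torch.int32)
+# >>> kernel_contour = torch.zeros(32, 32, dtype=torch.uint8)
+# >>> assignment = pixel_group(score, mask, embedding, kernel_label,
+# ...                          kernel_contour, 2, 0.8)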
diff --git a/annotator/uniformer/mmcv/ops/point_sample.py b/annotator/uniformer/mmcv/ops/point_sample.py
new file mode 100644
index 0000000000000000000000000000000000000000..267f4b3c56630acd85f9bdc630b7be09abab0aba
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/point_sample.py
@@ -0,0 +1,336 @@
+# Modified from https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend # noqa
+
+from os import path as osp
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.modules.utils import _pair
+from torch.onnx.operators import shape_as_tensor
+
+
+def bilinear_grid_sample(im, grid, align_corners=False):
+ """Given an input and a flow-field grid, compute the output using input
+ values and pixel locations from the grid. Only the bilinear interpolation
+ method is supported for sampling the input pixels.
+
+ Args:
+ im (torch.Tensor): Input feature map, shape (N, C, H, W)
+ grid (torch.Tensor): Point coordinates, shape (N, Hg, Wg, 2)
+ align_corners (bool): If set to True, the extrema (-1 and 1) are
+ considered as referring to the center points of the input's
+ corner pixels. If set to False, they are instead considered as
+ referring to the corner points of the input's corner pixels,
+ making the sampling more resolution agnostic.
+ Returns:
+ torch.Tensor: A tensor with sampled points, shape (N, C, Hg, Wg)
+ """
+ n, c, h, w = im.shape
+ gn, gh, gw, _ = grid.shape
+ assert n == gn
+
+ x = grid[:, :, :, 0]
+ y = grid[:, :, :, 1]
+
+ if align_corners:
+ x = ((x + 1) / 2) * (w - 1)
+ y = ((y + 1) / 2) * (h - 1)
+ else:
+ x = ((x + 1) * w - 1) / 2
+ y = ((y + 1) * h - 1) / 2
+
+ x = x.view(n, -1)
+ y = y.view(n, -1)
+
+ x0 = torch.floor(x).long()
+ y0 = torch.floor(y).long()
+ x1 = x0 + 1
+ y1 = y0 + 1
+
+ wa = ((x1 - x) * (y1 - y)).unsqueeze(1)
+ wb = ((x1 - x) * (y - y0)).unsqueeze(1)
+ wc = ((x - x0) * (y1 - y)).unsqueeze(1)
+ wd = ((x - x0) * (y - y0)).unsqueeze(1)
+
+ # Apply default for grid_sample function zero padding
+ im_padded = F.pad(im, pad=[1, 1, 1, 1], mode='constant', value=0)
+ padded_h = h + 2
+ padded_w = w + 2
+ # save points positions after padding
+ x0, x1, y0, y1 = x0 + 1, x1 + 1, y0 + 1, y1 + 1
+
+ # Clip coordinates to padded image size
+ x0 = torch.where(x0 < 0, torch.tensor(0), x0)
+ x0 = torch.where(x0 > padded_w - 1, torch.tensor(padded_w - 1), x0)
+ x1 = torch.where(x1 < 0, torch.tensor(0), x1)
+ x1 = torch.where(x1 > padded_w - 1, torch.tensor(padded_w - 1), x1)
+ y0 = torch.where(y0 < 0, torch.tensor(0), y0)
+ y0 = torch.where(y0 > padded_h - 1, torch.tensor(padded_h - 1), y0)
+ y1 = torch.where(y1 < 0, torch.tensor(0), y1)
+ y1 = torch.where(y1 > padded_h - 1, torch.tensor(padded_h - 1), y1)
+
+ im_padded = im_padded.view(n, c, -1)
+
+ x0_y0 = (x0 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1)
+ x0_y1 = (x0 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1)
+ x1_y0 = (x1 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1)
+ x1_y1 = (x1 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1)
+
+ Ia = torch.gather(im_padded, 2, x0_y0)
+ Ib = torch.gather(im_padded, 2, x0_y1)
+ Ic = torch.gather(im_padded, 2, x1_y0)
+ Id = torch.gather(im_padded, 2, x1_y1)
+
+ return (Ia * wa + Ib * wb + Ic * wc + Id * wd).reshape(n, c, gh, gw)
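+
+# Sanity-check sketch: this pure-PyTorch implementation should closely match
+# F.grid_sample for in-range grids; it exists to keep ONNX export traceable
+# without custom ops:
+# >>> im = torch.randn(1, 3, 16, 16)
+# >>> grid = torch.rand(1, 8, 8, 2) * 2 - 1
+# >>> ref = F.grid_sample(im, grid, align_corners=False)
+# >>> out = bilinear_grid_sample(im, grid, align_corners=False)
+# >>> torch.allclose(ref, out, atol=1e-5)  # expected to be True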
+
+
+def is_in_onnx_export_without_custom_ops():
+ from annotator.uniformer.mmcv.ops import get_onnxruntime_op_path
+ ort_custom_op_path = get_onnxruntime_op_path()
+ return torch.onnx.is_in_onnx_export(
+ ) and not osp.exists(ort_custom_op_path)
+
+
+def normalize(grid):
+ """Normalize input grid from [-1, 1] to [0, 1]
+ Args:
+ grid (Tensor): The grid to be normalized, range [-1, 1].
+ Returns:
+ Tensor: Normalized grid, range [0, 1].
+ """
+
+ return (grid + 1.0) / 2.0
+
+
+def denormalize(grid):
+ """Denormalize input grid from range [0, 1] to [-1, 1]
+ Args:
+ grid (Tensor): The grid to be denormalized, range [0, 1].
+ Returns:
+ Tensor: Denormalized grid, range [-1, 1].
+ """
+
+ return grid * 2.0 - 1.0
+
+
+def generate_grid(num_grid, size, device):
+ """Generate regular square grid of points in [0, 1] x [0, 1] coordinate
+ space.
+
+ Args:
+ num_grid (int): The number of grids to sample, one for each region.
+ size (tuple(int, int)): The side size of the regular grid.
+ device (torch.device): Desired device of returned tensor.
+
+ Returns:
+ (torch.Tensor): A tensor of shape (num_grid, size[0]*size[1], 2) that
+ contains coordinates for the regular grids.
+ """
+
+ affine_trans = torch.tensor([[[1., 0., 0.], [0., 1., 0.]]], device=device)
+ grid = F.affine_grid(
+ affine_trans, torch.Size((1, 1, *size)), align_corners=False)
+ grid = normalize(grid)
+ return grid.view(1, -1, 2).expand(num_grid, -1, -1)
+
+
+def rel_roi_point_to_abs_img_point(rois, rel_roi_points):
+ """Convert roi based relative point coordinates to image based absolute
+ point coordinates.
+
+ Args:
+ rois (Tensor): RoIs or BBoxes, shape (N, 4) or (N, 5)
+ rel_roi_points (Tensor): Point coordinates inside the RoI, relative
+ to the RoI location, range (0, 1), shape (N, P, 2)
+ Returns:
+ Tensor: Image based absolute point coordinates, shape (N, P, 2)
+ """
+
+ with torch.no_grad():
+ assert rel_roi_points.size(0) == rois.size(0)
+ assert rois.dim() == 2
+ assert rel_roi_points.dim() == 3
+ assert rel_roi_points.size(2) == 2
+ # remove batch idx
+ if rois.size(1) == 5:
+ rois = rois[:, 1:]
+ abs_img_points = rel_roi_points.clone()
+ # To avoid an error during exporting to onnx use independent
+ # variables instead inplace computation
+ xs = abs_img_points[:, :, 0] * (rois[:, None, 2] - rois[:, None, 0])
+ ys = abs_img_points[:, :, 1] * (rois[:, None, 3] - rois[:, None, 1])
+ xs += rois[:, None, 0]
+ ys += rois[:, None, 1]
+ abs_img_points = torch.stack([xs, ys], dim=2)
+ return abs_img_points
+
+
+def get_shape_from_feature_map(x):
+ """Get spatial resolution of input feature map considering exporting to
+ onnx mode.
+
+ Args:
+ x (torch.Tensor): Input tensor, shape (N, C, H, W)
+ Returns:
+ torch.Tensor: Spatial resolution (width, height), shape (1, 1, 2)
+ """
+ if torch.onnx.is_in_onnx_export():
+ img_shape = shape_as_tensor(x)[2:].flip(0).view(1, 1, 2).to(
+ x.device).float()
+ else:
+ img_shape = torch.tensor(x.shape[2:]).flip(0).view(1, 1, 2).to(
+ x.device).float()
+ return img_shape
+
+
+def abs_img_point_to_rel_img_point(abs_img_points, img, spatial_scale=1.):
+ """Convert image based absolute point coordinates to image based relative
+ coordinates for sampling.
+
+ Args:
+ abs_img_points (Tensor): Image based absolute point coordinates,
+ shape (N, P, 2)
+ img (tuple/Tensor): (height, width) of image or feature map.
+ spatial_scale (float): Scale points by this factor. Default: 1.
+
+ Returns:
+ Tensor: Image based relative point coordinates for sampling,
+ shape (N, P, 2)
+ """
+
+ assert (isinstance(img, tuple) and len(img) == 2) or \
+ (isinstance(img, torch.Tensor) and len(img.shape) == 4)
+
+ if isinstance(img, tuple):
+ h, w = img
+ scale = torch.tensor([w, h],
+ dtype=torch.float,
+ device=abs_img_points.device)
+ scale = scale.view(1, 1, 2)
+ else:
+ scale = get_shape_from_feature_map(img)
+
+ return abs_img_points / scale * spatial_scale
+
+
+def rel_roi_point_to_rel_img_point(rois,
+ rel_roi_points,
+ img,
+ spatial_scale=1.):
+ """Convert roi based relative point coordinates to image based relative
+ point coordinates for sampling.
+
+ Args:
+ rois (Tensor): RoIs or BBoxes, shape (N, 4) or (N, 5)
+ rel_roi_points (Tensor): Point coordinates inside the RoI, relative
+ to the RoI location, range (0, 1), shape (N, P, 2)
+ img (tuple/Tensor): (height, width) of image or feature map.
+ spatial_scale (float): Scale points by this factor. Default: 1.
+
+ Returns:
+ Tensor: Image based relative point coordinates for sampling,
+ shape (N, P, 2)
+ """
+
+ abs_img_point = rel_roi_point_to_abs_img_point(rois, rel_roi_points)
+ rel_img_point = abs_img_point_to_rel_img_point(abs_img_point, img,
+ spatial_scale)
+
+ return rel_img_point
+
+
+def point_sample(input, points, align_corners=False, **kwargs):
+ """A wrapper around :func:`grid_sample` to support 3D point_coords tensors
+ Unlike :func:`torch.nn.functional.grid_sample` it assumes point_coords to
+ lie inside ``[0, 1] x [0, 1]`` square.
+
+ Args:
+ input (Tensor): Feature map, shape (N, C, H, W).
+ points (Tensor): Image based absolute point coordinates (normalized),
+ range [0, 1] x [0, 1], shape (N, P, 2) or (N, Hgrid, Wgrid, 2).
+ align_corners (bool): Whether align_corners. Default: False
+
+ Returns:
+ Tensor: Features of `point` on `input`, shape (N, C, P) or
+ (N, C, Hgrid, Wgrid).
+ """
+
+ add_dim = False
+ if points.dim() == 3:
+ add_dim = True
+ points = points.unsqueeze(2)
+ if is_in_onnx_export_without_custom_ops():
+ # If custom ops for onnx runtime not compiled use python
+ # implementation of grid_sample function to make onnx graph
+ # with supported nodes
+ output = bilinear_grid_sample(
+ input, denormalize(points), align_corners=align_corners)
+ else:
+ output = F.grid_sample(
+ input, denormalize(points), align_corners=align_corners, **kwargs)
+ if add_dim:
+ output = output.squeeze(3)
+ return output
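+
+# Illustrative usage sketch: sample per-point features from a feature map with
+# point coordinates normalized to [0, 1] x [0, 1]:
+# >>> feat = torch.randn(2, 256, 32, 32)
+# >>> points = torch.rand(2, 50, 2)  # 50 points per image in [0, 1]
+# >>> sampled = point_sample(feat, points)  # -> (2, 256, 50)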
+
+
+class SimpleRoIAlign(nn.Module):
+
+ def __init__(self, output_size, spatial_scale, aligned=True):
+ """Simple RoI align in PointRend, faster than standard RoIAlign.
+
+ Args:
+ output_size (tuple[int]): h, w
+ spatial_scale (float): scale the input boxes by this number
+ aligned (bool): If False, use the legacy implementation in
+ MMDetection and align_corners=True in F.grid_sample.
+ If True, align the results more precisely.
+ """
+
+ super(SimpleRoIAlign, self).__init__()
+ self.output_size = _pair(output_size)
+ self.spatial_scale = float(spatial_scale)
+ # to be consistent with other RoI ops
+ self.use_torchvision = False
+ self.aligned = aligned
+
+ def forward(self, features, rois):
+ num_imgs = features.size(0)
+ num_rois = rois.size(0)
+ rel_roi_points = generate_grid(
+ num_rois, self.output_size, device=rois.device)
+
+ if torch.onnx.is_in_onnx_export():
+ rel_img_points = rel_roi_point_to_rel_img_point(
+ rois, rel_roi_points, features, self.spatial_scale)
+ rel_img_points = rel_img_points.reshape(num_imgs, -1,
+ *rel_img_points.shape[1:])
+ point_feats = point_sample(
+ features, rel_img_points, align_corners=not self.aligned)
+ point_feats = point_feats.transpose(1, 2)
+ else:
+ point_feats = []
+ for batch_ind in range(num_imgs):
+ # unravel batch dim
+ feat = features[batch_ind].unsqueeze(0)
+ inds = (rois[:, 0].long() == batch_ind)
+ if inds.any():
+ rel_img_points = rel_roi_point_to_rel_img_point(
+ rois[inds], rel_roi_points[inds], feat,
+ self.spatial_scale).unsqueeze(0)
+ point_feat = point_sample(
+ feat, rel_img_points, align_corners=not self.aligned)
+ point_feat = point_feat.squeeze(0).transpose(0, 1)
+ point_feats.append(point_feat)
+
+ point_feats = torch.cat(point_feats, dim=0)
+
+ channels = features.size(1)
+ roi_feats = point_feats.reshape(num_rois, channels, *self.output_size)
+
+ return roi_feats
+
+ def __repr__(self):
+ format_str = self.__class__.__name__
+ format_str += '(output_size={}, spatial_scale={})'.format(
+ self.output_size, self.spatial_scale)
+ return format_str
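+
+# Illustrative usage sketch: rois are (batch_idx, x1, y1, x2, y2) in image
+# coordinates and are mapped to the feature map via ``spatial_scale``:
+# >>> roi_align = SimpleRoIAlign(output_size=7, spatial_scale=0.25)
+# >>> feats = torch.randn(2, 256, 56, 56)
+# >>> rois = torch.tensor([[0., 10., 10., 80., 120.],
+# ...                      [1., 4., 4., 40., 40.]])
+# >>> out = roi_align(feats, rois)  # -> (2, 256, 7, 7)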
diff --git a/annotator/uniformer/mmcv/ops/points_in_boxes.py b/annotator/uniformer/mmcv/ops/points_in_boxes.py
new file mode 100644
index 0000000000000000000000000000000000000000..4003173a53052161dbcd687a2fa1d755642fdab8
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/points_in_boxes.py
@@ -0,0 +1,133 @@
+import torch
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', [
+ 'points_in_boxes_part_forward', 'points_in_boxes_cpu_forward',
+ 'points_in_boxes_all_forward'
+])
+
+
+def points_in_boxes_part(points, boxes):
+ """Find the box in which each point is (CUDA).
+
+ Args:
+ points (torch.Tensor): [B, M, 3], [x, y, z] in LiDAR/DEPTH coordinate
+ boxes (torch.Tensor): [B, T, 7],
+ num_valid_boxes <= T, [x, y, z, x_size, y_size, z_size, rz] in
+ LiDAR/DEPTH coordinate, (x, y, z) is the bottom center
+
+ Returns:
+ box_idxs_of_pts (torch.Tensor): (B, M), default background = -1
+ """
+ assert points.shape[0] == boxes.shape[0], \
+ 'Points and boxes should have the same batch size, ' \
+ f'but got {points.shape[0]} and {boxes.shape[0]}'
+ assert boxes.shape[2] == 7, \
+ 'boxes dimension should be 7, ' \
+ f'but got unexpected shape {boxes.shape[2]}'
+ assert points.shape[2] == 3, \
+ 'points dimension should be 3, ' \
+ f'but got unexpected shape {points.shape[2]}'
+ batch_size, num_points, _ = points.shape
+
+ box_idxs_of_pts = points.new_zeros((batch_size, num_points),
+ dtype=torch.int).fill_(-1)
+
+ # If manually put the tensor 'points' or 'boxes' on a device
+ # which is not the current device, some temporary variables
+ # will be created on the current device in the cuda op,
+ # and the output will be incorrect.
+ # Therefore, we force the current device to be the same
+ # as the device of the tensors if it was not.
+ # Please refer to https://github.com/open-mmlab/mmdetection3d/issues/305
+ # for the incorrect output before the fix.
+ points_device = points.get_device()
+ assert points_device == boxes.get_device(), \
+ 'Points and boxes should be put on the same device'
+ if torch.cuda.current_device() != points_device:
+ torch.cuda.set_device(points_device)
+
+ ext_module.points_in_boxes_part_forward(boxes.contiguous(),
+ points.contiguous(),
+ box_idxs_of_pts)
+
+ return box_idxs_of_pts
+
+
+def points_in_boxes_cpu(points, boxes):
+ """Find all boxes in which each point is (CPU). The CPU version of
+ :meth:`points_in_boxes_all`.
+
+ Args:
+ points (torch.Tensor): [B, M, 3], [x, y, z] in
+ LiDAR/DEPTH coordinate
+ boxes (torch.Tensor): [B, T, 7],
+ num_valid_boxes <= T, [x, y, z, x_size, y_size, z_size, rz],
+ (x, y, z) is the bottom center.
+
+ Returns:
+ box_idxs_of_pts (torch.Tensor): (B, M, T), default background = 0.
+ """
+ assert points.shape[0] == boxes.shape[0], \
+ 'Points and boxes should have the same batch size, ' \
+ f'but got {points.shape[0]} and {boxes.shape[0]}'
+ assert boxes.shape[2] == 7, \
+ 'boxes dimension should be 7, ' \
+ f'but got unexpected shape {boxes.shape[2]}'
+ assert points.shape[2] == 3, \
+ 'points dimension should be 3, ' \
+ f'but got unexpected shape {points.shape[2]}'
+ batch_size, num_points, _ = points.shape
+ num_boxes = boxes.shape[1]
+
+ point_indices = points.new_zeros((batch_size, num_boxes, num_points),
+ dtype=torch.int)
+ for b in range(batch_size):
+ ext_module.points_in_boxes_cpu_forward(boxes[b].float().contiguous(),
+ points[b].float().contiguous(),
+ point_indices[b])
+ point_indices = point_indices.transpose(1, 2)
+
+ return point_indices
+
+
+def points_in_boxes_all(points, boxes):
+ """Find all boxes in which each point is (CUDA).
+
+ Args:
+ points (torch.Tensor): [B, M, 3], [x, y, z] in LiDAR/DEPTH coordinate
+ boxes (torch.Tensor): [B, T, 7],
+ num_valid_boxes <= T, [x, y, z, x_size, y_size, z_size, rz],
+ (x, y, z) is the bottom center.
+
+ Returns:
+ box_idxs_of_pts (torch.Tensor): (B, M, T), default background = 0.
+ """
+ assert boxes.shape[0] == points.shape[0], \
+ 'Points and boxes should have the same batch size, ' \
+ f'but got {points.shape[0]} and {boxes.shape[0]}'
+ assert boxes.shape[2] == 7, \
+ 'boxes dimension should be 7, ' \
+ f'but got unexpected shape {boxes.shape[2]}'
+ assert points.shape[2] == 3, \
+ 'points dimension should be 3, ' \
+ f'but got unexpected shape {points.shape[2]}'
+ batch_size, num_points, _ = points.shape
+ num_boxes = boxes.shape[1]
+
+ box_idxs_of_pts = points.new_zeros((batch_size, num_points, num_boxes),
+ dtype=torch.int).fill_(0)
+
+ # Same device handling as in points_in_boxes_part above
+ points_device = points.get_device()
+ assert points_device == boxes.get_device(), \
+ 'Points and boxes should be put on the same device'
+ if torch.cuda.current_device() != points_device:
+ torch.cuda.set_device(points_device)
+
+ ext_module.points_in_boxes_all_forward(boxes.contiguous(),
+ points.contiguous(),
+ box_idxs_of_pts)
+
+ return box_idxs_of_pts
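+
+# Illustrative usage sketch for the CPU variant (assumes the compiled `_ext`
+# extension; the *_part/*_all variants additionally require a CUDA build):
+# >>> points = torch.rand(1, 128, 3) * 20  # (B, M, 3)
+# >>> boxes = torch.rand(1, 4, 7)          # (B, T, 7)
+# >>> idx = points_in_boxes_cpu(points, boxes)  # -> (1, 128, 4)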
diff --git a/annotator/uniformer/mmcv/ops/points_sampler.py b/annotator/uniformer/mmcv/ops/points_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..a802a74fd6c3610d9ae178e6201f47423eca7ad1
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/points_sampler.py
@@ -0,0 +1,177 @@
+from typing import List
+
+import torch
+from torch import nn as nn
+
+from annotator.uniformer.mmcv.runner import force_fp32
+from .furthest_point_sample import (furthest_point_sample,
+ furthest_point_sample_with_dist)
+
+
+def calc_square_dist(point_feat_a, point_feat_b, norm=True):
+ """Calculate the squared distance between a and b.
+
+ Args:
+ point_feat_a (Tensor): (B, N, C) Feature vector of each point.
+ point_feat_b (Tensor): (B, M, C) Feature vector of each point.
+ norm (bool, optional): Whether to normalize the distance.
+ Default: True.
+
+ Returns:
+ Tensor: (B, N, M) Distance between each pair of points.
+ """
+ num_channel = point_feat_a.shape[-1]
+ # [bs, n, 1]
+ a_square = torch.sum(point_feat_a.unsqueeze(dim=2).pow(2), dim=-1)
+ # [bs, 1, m]
+ b_square = torch.sum(point_feat_b.unsqueeze(dim=1).pow(2), dim=-1)
+
+ corr_matrix = torch.matmul(point_feat_a, point_feat_b.transpose(1, 2))
+
+ dist = a_square + b_square - 2 * corr_matrix
+ if norm:
+ dist = torch.sqrt(dist) / num_channel
+ return dist
+
+
+def get_sampler_cls(sampler_type):
+ """Get the points sampler class by its type.
+
+ Args:
+ sampler_type (str): The type of points sampler.
+ Valid values are "D-FPS", "F-FPS", or "FS".
+
+ Returns:
+ class: Points sampler type.
+ """
+ sampler_mappings = {
+ 'D-FPS': DFPSSampler,
+ 'F-FPS': FFPSSampler,
+ 'FS': FSSampler,
+ }
+ try:
+ return sampler_mappings[sampler_type]
+ except KeyError:
+ raise KeyError(
+ f'Supported `sampler_type` are {sampler_mappings.keys()}, but got \
+ {sampler_type}')
+
+
+class PointsSampler(nn.Module):
+ """Points sampling.
+
+ Args:
+ num_point (list[int]): Number of sample points.
+ fps_mod_list (list[str], optional): Type of FPS method, valid modes
+ are ['F-FPS', 'D-FPS', 'FS']. Default: ['D-FPS'].
+ F-FPS: using feature distances for FPS.
+ D-FPS: using Euclidean distances of points for FPS.
+ FS: using F-FPS and D-FPS simultaneously.
+ fps_sample_range_list (list[int], optional):
+ Range of points to apply FPS. Default: [-1].
+ """
+
+ def __init__(self,
+ num_point: List[int],
+ fps_mod_list: List[str] = ['D-FPS'],
+ fps_sample_range_list: List[int] = [-1]):
+ super().__init__()
+ # FPS would be applied to different fps_mod in the list,
+ # so the length of the num_point should be equal to
+ # fps_mod_list and fps_sample_range_list.
+ assert len(num_point) == len(fps_mod_list) == len(
+ fps_sample_range_list)
+ self.num_point = num_point
+ self.fps_sample_range_list = fps_sample_range_list
+ self.samplers = nn.ModuleList()
+ for fps_mod in fps_mod_list:
+ self.samplers.append(get_sampler_cls(fps_mod)())
+ self.fp16_enabled = False
+
+ @force_fp32()
+ def forward(self, points_xyz, features):
+ """
+ Args:
+ points_xyz (Tensor): (B, N, 3) xyz coordinates of the features.
+ features (Tensor): (B, C, N) Descriptors of the features.
+
+ Returns:
+ Tensor: (B, sum(num_point)) Indices of sampled points.
+ """
+ indices = []
+ last_fps_end_index = 0
+
+ for fps_sample_range, sampler, npoint in zip(
+ self.fps_sample_range_list, self.samplers, self.num_point):
+ assert fps_sample_range < points_xyz.shape[1]
+
+ if fps_sample_range == -1:
+ sample_points_xyz = points_xyz[:, last_fps_end_index:]
+ if features is not None:
+ sample_features = features[:, :, last_fps_end_index:]
+ else:
+ sample_features = None
+ else:
+ sample_points_xyz = \
+ points_xyz[:, last_fps_end_index:fps_sample_range]
+ if features is not None:
+ sample_features = features[:, :, last_fps_end_index:
+ fps_sample_range]
+ else:
+ sample_features = None
+
+ fps_idx = sampler(sample_points_xyz.contiguous(), sample_features,
+ npoint)
+
+ indices.append(fps_idx + last_fps_end_index)
+ last_fps_end_index += fps_sample_range
+ indices = torch.cat(indices, dim=1)
+
+ return indices
+
+
+class DFPSSampler(nn.Module):
+ """Using Euclidean distances of points for FPS."""
+
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, points, features, npoint):
+ """Sampling points with D-FPS."""
+ fps_idx = furthest_point_sample(points.contiguous(), npoint)
+ return fps_idx
+
+
+class FFPSSampler(nn.Module):
+ """Using feature distances for FPS."""
+
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, points, features, npoint):
+ """Sampling points with F-FPS."""
+ assert features is not None, \
+ 'feature input to FFPS_Sampler should not be None'
+ features_for_fps = torch.cat([points, features.transpose(1, 2)], dim=2)
+ features_dist = calc_square_dist(
+ features_for_fps, features_for_fps, norm=False)
+ fps_idx = furthest_point_sample_with_dist(features_dist, npoint)
+ return fps_idx
+
+
+class FSSampler(nn.Module):
+ """Using F-FPS and D-FPS simultaneously."""
+
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, points, features, npoint):
+ """Sampling points with FS_Sampling."""
+ assert features is not None, \
+ 'feature input to FS_Sampler should not be None'
+ ffps_sampler = FFPSSampler()
+ dfps_sampler = DFPSSampler()
+ fps_idx_ffps = ffps_sampler(points, features, npoint)
+ fps_idx_dfps = dfps_sampler(points, features, npoint)
+ fps_idx = torch.cat([fps_idx_ffps, fps_idx_dfps], dim=1)
+ return fps_idx
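+
+# Illustrative usage sketch (D-FPS relies on furthest_point_sample, which
+# needs the compiled CUDA extension, hence the .cuda() calls):
+# >>> sampler = PointsSampler(num_point=[128], fps_mod_list=['D-FPS'],
+# ...                         fps_sample_range_list=[-1])
+# >>> xyz = torch.rand(2, 1024, 3).cuda()
+# >>> feats = torch.rand(2, 64, 1024).cuda()
+# >>> idx = sampler(xyz, feats)  # -> (2, 128) indices of sampled points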
diff --git a/annotator/uniformer/mmcv/ops/psa_mask.py b/annotator/uniformer/mmcv/ops/psa_mask.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdf14e62b50e8d4dd6856c94333c703bcc4c9ab6
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/psa_mask.py
@@ -0,0 +1,92 @@
+# Modified from https://github.com/hszhao/semseg/blob/master/lib/psa
+from torch import nn
+from torch.autograd import Function
+from torch.nn.modules.utils import _pair
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext',
+ ['psamask_forward', 'psamask_backward'])
+
+
+class PSAMaskFunction(Function):
+
+ @staticmethod
+ def symbolic(g, input, psa_type, mask_size):
+ return g.op(
+ 'mmcv::MMCVPSAMask',
+ input,
+ psa_type_i=psa_type,
+ mask_size_i=mask_size)
+
+ @staticmethod
+ def forward(ctx, input, psa_type, mask_size):
+ ctx.psa_type = psa_type
+ ctx.mask_size = _pair(mask_size)
+ ctx.save_for_backward(input)
+
+ h_mask, w_mask = ctx.mask_size
+ batch_size, channels, h_feature, w_feature = input.size()
+ assert channels == h_mask * w_mask
+ output = input.new_zeros(
+ (batch_size, h_feature * w_feature, h_feature, w_feature))
+
+ ext_module.psamask_forward(
+ input,
+ output,
+ psa_type=psa_type,
+ num_=batch_size,
+ h_feature=h_feature,
+ w_feature=w_feature,
+ h_mask=h_mask,
+ w_mask=w_mask,
+ half_h_mask=(h_mask - 1) // 2,
+ half_w_mask=(w_mask - 1) // 2)
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input = ctx.saved_tensors[0]
+ psa_type = ctx.psa_type
+ h_mask, w_mask = ctx.mask_size
+ batch_size, channels, h_feature, w_feature = input.size()
+ grad_input = grad_output.new_zeros(
+ (batch_size, channels, h_feature, w_feature))
+ ext_module.psamask_backward(
+ grad_output,
+ grad_input,
+ psa_type=psa_type,
+ num_=batch_size,
+ h_feature=h_feature,
+ w_feature=w_feature,
+ h_mask=h_mask,
+ w_mask=w_mask,
+ half_h_mask=(h_mask - 1) // 2,
+ half_w_mask=(w_mask - 1) // 2)
+ return grad_input, None, None, None
+
+
+psa_mask = PSAMaskFunction.apply
+
+
+class PSAMask(nn.Module):
+
+ def __init__(self, psa_type, mask_size=None):
+ super(PSAMask, self).__init__()
+ assert psa_type in ['collect', 'distribute']
+ if psa_type == 'collect':
+ psa_type_enum = 0
+ else:
+ psa_type_enum = 1
+ self.psa_type_enum = psa_type_enum
+ self.mask_size = mask_size
+ self.psa_type = psa_type
+
+ def forward(self, input):
+ return psa_mask(input, self.psa_type_enum, self.mask_size)
+
+ def __repr__(self):
+ s = self.__class__.__name__
+ s += f'(psa_type={self.psa_type}, '
+ s += f'mask_size={self.mask_size})'
+ return s
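+
+# Illustrative usage sketch (assumes the compiled `_ext` extension); the input
+# channel count must equal mask_h * mask_w:
+# >>> psa = PSAMask('collect', mask_size=(7, 7))
+# >>> x = torch.randn(1, 49, 32, 32)
+# >>> out = psa(x)  # -> (1, 32 * 32, 32, 32)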
diff --git a/annotator/uniformer/mmcv/ops/roi_align.py b/annotator/uniformer/mmcv/ops/roi_align.py
new file mode 100644
index 0000000000000000000000000000000000000000..0755aefc66e67233ceae0f4b77948301c443e9fb
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/roi_align.py
@@ -0,0 +1,223 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.nn.modules.utils import _pair
+
+from ..utils import deprecated_api_warning, ext_loader
+
+ext_module = ext_loader.load_ext('_ext',
+ ['roi_align_forward', 'roi_align_backward'])
+
+
+class RoIAlignFunction(Function):
+
+ @staticmethod
+ def symbolic(g, input, rois, output_size, spatial_scale, sampling_ratio,
+ pool_mode, aligned):
+ from ..onnx import is_custom_op_loaded
+ has_custom_op = is_custom_op_loaded()
+ if has_custom_op:
+ return g.op(
+ 'mmcv::MMCVRoiAlign',
+ input,
+ rois,
+ output_height_i=output_size[0],
+ output_width_i=output_size[1],
+ spatial_scale_f=spatial_scale,
+ sampling_ratio_i=sampling_ratio,
+ mode_s=pool_mode,
+ aligned_i=aligned)
+ else:
+ from torch.onnx.symbolic_opset9 import sub, squeeze
+ from torch.onnx.symbolic_helper import _slice_helper
+ from torch.onnx import TensorProtoDataType
+ # batch_indices = rois[:, 0].long()
+ batch_indices = _slice_helper(
+ g, rois, axes=[1], starts=[0], ends=[1])
+ batch_indices = squeeze(g, batch_indices, 1)
+ batch_indices = g.op(
+ 'Cast', batch_indices, to_i=TensorProtoDataType.INT64)
+ # rois = rois[:, 1:]
+ rois = _slice_helper(g, rois, axes=[1], starts=[1], ends=[5])
+ if aligned:
+ # rois -= 0.5/spatial_scale
+ aligned_offset = g.op(
+ 'Constant',
+ value_t=torch.tensor([0.5 / spatial_scale],
+ dtype=torch.float32))
+ rois = sub(g, rois, aligned_offset)
+ # roi align
+ return g.op(
+ 'RoiAlign',
+ input,
+ rois,
+ batch_indices,
+ output_height_i=output_size[0],
+ output_width_i=output_size[1],
+ spatial_scale_f=spatial_scale,
+ sampling_ratio_i=max(0, sampling_ratio),
+ mode_s=pool_mode)
+
+ @staticmethod
+ def forward(ctx,
+ input,
+ rois,
+ output_size,
+ spatial_scale=1.0,
+ sampling_ratio=0,
+ pool_mode='avg',
+ aligned=True):
+ ctx.output_size = _pair(output_size)
+ ctx.spatial_scale = spatial_scale
+ ctx.sampling_ratio = sampling_ratio
+ assert pool_mode in ('max', 'avg')
+ ctx.pool_mode = 0 if pool_mode == 'max' else 1
+ ctx.aligned = aligned
+ ctx.input_shape = input.size()
+
+ assert rois.size(1) == 5, 'RoI must be (idx, x1, y1, x2, y2)!'
+
+ output_shape = (rois.size(0), input.size(1), ctx.output_size[0],
+ ctx.output_size[1])
+ output = input.new_zeros(output_shape)
+ if ctx.pool_mode == 0:
+ argmax_y = input.new_zeros(output_shape)
+ argmax_x = input.new_zeros(output_shape)
+ else:
+ argmax_y = input.new_zeros(0)
+ argmax_x = input.new_zeros(0)
+
+ ext_module.roi_align_forward(
+ input,
+ rois,
+ output,
+ argmax_y,
+ argmax_x,
+ aligned_height=ctx.output_size[0],
+ aligned_width=ctx.output_size[1],
+ spatial_scale=ctx.spatial_scale,
+ sampling_ratio=ctx.sampling_ratio,
+ pool_mode=ctx.pool_mode,
+ aligned=ctx.aligned)
+
+ ctx.save_for_backward(rois, argmax_y, argmax_x)
+ return output
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, grad_output):
+ rois, argmax_y, argmax_x = ctx.saved_tensors
+ grad_input = grad_output.new_zeros(ctx.input_shape)
+ # complex head architecture may cause grad_output uncontiguous.
+ grad_output = grad_output.contiguous()
+ ext_module.roi_align_backward(
+ grad_output,
+ rois,
+ argmax_y,
+ argmax_x,
+ grad_input,
+ aligned_height=ctx.output_size[0],
+ aligned_width=ctx.output_size[1],
+ spatial_scale=ctx.spatial_scale,
+ sampling_ratio=ctx.sampling_ratio,
+ pool_mode=ctx.pool_mode,
+ aligned=ctx.aligned)
+ return grad_input, None, None, None, None, None, None
+
+
+roi_align = RoIAlignFunction.apply
+
+
+class RoIAlign(nn.Module):
+ """RoI align pooling layer.
+
+ Args:
+ output_size (tuple): h, w
+ spatial_scale (float): scale the input boxes by this number
+ sampling_ratio (int): number of input samples to take for each
+ output sample. 0 to take samples densely for current models.
+ pool_mode (str, 'avg' or 'max'): pooling mode in each bin.
+ aligned (bool): if False, use the legacy implementation in
+ MMDetection. If True, align the results more precisely.
+ use_torchvision (bool): whether to use roi_align from torchvision.
+
+ Note:
+ The implementation of RoIAlign when aligned=True is modified from
+ https://github.com/facebookresearch/detectron2/
+
+ The meaning of aligned=True:
+
+ Given a continuous coordinate c, its two neighboring pixel
+ indices (in our pixel model) are computed by floor(c - 0.5) and
+ ceil(c - 0.5). For example, c=1.3 has pixel neighbors with discrete
+ indices [0] and [1] (which are sampled from the underlying signal
+ at continuous coordinates 0.5 and 1.5). But the original roi_align
+ (aligned=False) does not subtract the 0.5 when computing
+ neighboring pixel indices and therefore it uses pixels with a
+ slightly incorrect alignment (relative to our pixel model) when
+ performing bilinear interpolation.
+
+ With `aligned=True`,
+ we first appropriately scale the ROI and then shift it by -0.5
+ prior to calling roi_align. This produces the correct neighbors;
+
+ In practice the difference has a negligible effect on the model's
+ performance when RoIAlign is used together with conv layers.
+ """
+
+ @deprecated_api_warning(
+ {
+ 'out_size': 'output_size',
+ 'sample_num': 'sampling_ratio'
+ },
+ cls_name='RoIAlign')
+ def __init__(self,
+ output_size,
+ spatial_scale=1.0,
+ sampling_ratio=0,
+ pool_mode='avg',
+ aligned=True,
+ use_torchvision=False):
+ super(RoIAlign, self).__init__()
+
+ self.output_size = _pair(output_size)
+ self.spatial_scale = float(spatial_scale)
+ self.sampling_ratio = int(sampling_ratio)
+ self.pool_mode = pool_mode
+ self.aligned = aligned
+ self.use_torchvision = use_torchvision
+
+ def forward(self, input, rois):
+ """
+ Args:
+ input: NCHW images
+ rois: Bx5 boxes. First column is the index into N.\
+ The other 4 columns are xyxy.
+ """
+ if self.use_torchvision:
+ from torchvision.ops import roi_align as tv_roi_align
+ if 'aligned' in tv_roi_align.__code__.co_varnames:
+ return tv_roi_align(input, rois, self.output_size,
+ self.spatial_scale, self.sampling_ratio,
+ self.aligned)
+ else:
+ if self.aligned:
+ rois -= rois.new_tensor([0.] +
+ [0.5 / self.spatial_scale] * 4)
+ return tv_roi_align(input, rois, self.output_size,
+ self.spatial_scale, self.sampling_ratio)
+ else:
+ return roi_align(input, rois, self.output_size, self.spatial_scale,
+ self.sampling_ratio, self.pool_mode, self.aligned)
+
+ def __repr__(self):
+ s = self.__class__.__name__
+ s += f'(output_size={self.output_size}, '
+ s += f'spatial_scale={self.spatial_scale}, '
+ s += f'sampling_ratio={self.sampling_ratio}, '
+ s += f'pool_mode={self.pool_mode}, '
+ s += f'aligned={self.aligned}, '
+ s += f'use_torchvision={self.use_torchvision})'
+ return s
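+# Usage sketch (illustrative comment only; feature map and RoI values are
+# assumed for demonstration):
+#
+#   layer = RoIAlign(output_size=7, spatial_scale=1 / 4., sampling_ratio=2)
+#   feats = torch.randn(1, 256, 64, 64).cuda()
+#   rois = torch.tensor([[0., 8., 8., 40., 40.]]).cuda()  # (idx, x1, y1, x2, y2)
+#   pooled = layer(feats, rois)                           # (1, 256, 7, 7)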
diff --git a/annotator/uniformer/mmcv/ops/roi_align_rotated.py b/annotator/uniformer/mmcv/ops/roi_align_rotated.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ce4961a3555d4da8bc3e32f1f7d5ad50036587d
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/roi_align_rotated.py
@@ -0,0 +1,177 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+from torch.autograd import Function
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+ '_ext', ['roi_align_rotated_forward', 'roi_align_rotated_backward'])
+
+
+class RoIAlignRotatedFunction(Function):
+
+ @staticmethod
+ def symbolic(g, features, rois, out_size, spatial_scale, sample_num,
+ aligned, clockwise):
+ if isinstance(out_size, int):
+ out_h = out_size
+ out_w = out_size
+ elif isinstance(out_size, tuple):
+ assert len(out_size) == 2
+ assert isinstance(out_size[0], int)
+ assert isinstance(out_size[1], int)
+ out_h, out_w = out_size
+ else:
+ raise TypeError(
+ '"out_size" must be an integer or tuple of integers')
+ return g.op(
+ 'mmcv::MMCVRoIAlignRotated',
+ features,
+ rois,
+ output_height_i=out_h,
+ output_width_i=out_w,
+ spatial_scale_f=spatial_scale,
+ sampling_ratio_i=sample_num,
+ aligned_i=aligned,
+ clockwise_i=clockwise)
+
+ @staticmethod
+ def forward(ctx,
+ features,
+ rois,
+ out_size,
+ spatial_scale,
+ sample_num=0,
+ aligned=True,
+ clockwise=False):
+ if isinstance(out_size, int):
+ out_h = out_size
+ out_w = out_size
+ elif isinstance(out_size, tuple):
+ assert len(out_size) == 2
+ assert isinstance(out_size[0], int)
+ assert isinstance(out_size[1], int)
+ out_h, out_w = out_size
+ else:
+ raise TypeError(
+ '"out_size" must be an integer or tuple of integers')
+ ctx.spatial_scale = spatial_scale
+ ctx.sample_num = sample_num
+ ctx.aligned = aligned
+ ctx.clockwise = clockwise
+ ctx.save_for_backward(rois)
+ ctx.feature_size = features.size()
+
+ batch_size, num_channels, data_height, data_width = features.size()
+ num_rois = rois.size(0)
+
+ output = features.new_zeros(num_rois, num_channels, out_h, out_w)
+ ext_module.roi_align_rotated_forward(
+ features,
+ rois,
+ output,
+ pooled_height=out_h,
+ pooled_width=out_w,
+ spatial_scale=spatial_scale,
+ sample_num=sample_num,
+ aligned=aligned,
+ clockwise=clockwise)
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ feature_size = ctx.feature_size
+ spatial_scale = ctx.spatial_scale
+ aligned = ctx.aligned
+ clockwise = ctx.clockwise
+ sample_num = ctx.sample_num
+ rois = ctx.saved_tensors[0]
+ assert feature_size is not None
+ batch_size, num_channels, data_height, data_width = feature_size
+
+ out_w = grad_output.size(3)
+ out_h = grad_output.size(2)
+
+ grad_input = grad_rois = None
+
+ if ctx.needs_input_grad[0]:
+ grad_input = rois.new_zeros(batch_size, num_channels, data_height,
+ data_width)
+ ext_module.roi_align_rotated_backward(
+ grad_output.contiguous(),
+ rois,
+ grad_input,
+ pooled_height=out_h,
+ pooled_width=out_w,
+ spatial_scale=spatial_scale,
+ sample_num=sample_num,
+ aligned=aligned,
+ clockwise=clockwise)
+ return grad_input, grad_rois, None, None, None, None, None
+
+
+roi_align_rotated = RoIAlignRotatedFunction.apply
+
+
+class RoIAlignRotated(nn.Module):
+ """RoI align pooling layer for rotated proposals.
+
+ It accepts a feature map of shape (N, C, H, W) and rois with shape
+ (n, 6) with each roi decoded as (batch_index, center_x, center_y,
+ w, h, angle). The angle is in radians.
+
+ Args:
+ out_size (tuple): h, w
+ spatial_scale (float): scale the input boxes by this number
+ sample_num (int): number of input samples to take for each
+ output sample. 0 to take samples densely for current models.
+ aligned (bool): if False, use the legacy implementation in
+ MMDetection. If True, align the results more precisely.
+ Default: True.
+ clockwise (bool): If True, the angle in each proposal follows a
+ clockwise fashion in image space, otherwise, the angle is
+ counterclockwise. Default: False.
+
+ Note:
+ The implementation of RoIAlign when aligned=True is modified from
+ https://github.com/facebookresearch/detectron2/
+
+ The meaning of aligned=True:
+
+ Given a continuous coordinate c, its two neighboring pixel
+ indices (in our pixel model) are computed by floor(c - 0.5) and
+ ceil(c - 0.5). For example, c=1.3 has pixel neighbors with discrete
+ indices [0] and [1] (which are sampled from the underlying signal
+ at continuous coordinates 0.5 and 1.5). But the original roi_align
+ (aligned=False) does not subtract the 0.5 when computing
+ neighboring pixel indices and therefore it uses pixels with a
+ slightly incorrect alignment (relative to our pixel model) when
+ performing bilinear interpolation.
+
+ With `aligned=True`,
+ we first appropriately scale the ROI and then shift it by -0.5
+ prior to calling roi_align. This produces the correct neighbors;
+
+ In practice the difference has a negligible effect on the model's
+ performance when RoIAlign is used together with conv layers.
+ """
+
+ def __init__(self,
+ out_size,
+ spatial_scale,
+ sample_num=0,
+ aligned=True,
+ clockwise=False):
+ super(RoIAlignRotated, self).__init__()
+
+ self.out_size = out_size
+ self.spatial_scale = float(spatial_scale)
+ self.sample_num = int(sample_num)
+ self.aligned = aligned
+ self.clockwise = clockwise
+
+ def forward(self, features, rois):
+ return RoIAlignRotatedFunction.apply(features, rois, self.out_size,
+ self.spatial_scale,
+ self.sample_num, self.aligned,
+ self.clockwise)
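+# Usage sketch (illustrative comment only; each roi is
+# (batch_idx, cx, cy, w, h, angle) with the angle in radians):
+#
+#   layer = RoIAlignRotated(out_size=7, spatial_scale=1 / 8., sample_num=2)
+#   feats = torch.randn(1, 256, 50, 50).cuda()
+#   rois = torch.tensor([[0., 20., 20., 10., 6., 0.3]]).cuda()
+#   pooled = layer(feats, rois)   # (1, 256, 7, 7)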
diff --git a/annotator/uniformer/mmcv/ops/roi_pool.py b/annotator/uniformer/mmcv/ops/roi_pool.py
new file mode 100644
index 0000000000000000000000000000000000000000..d339d8f2941eabc1cbe181a9c6c5ab5ff4ff4e5f
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/roi_pool.py
@@ -0,0 +1,86 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.nn.modules.utils import _pair
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext',
+ ['roi_pool_forward', 'roi_pool_backward'])
+
+
+class RoIPoolFunction(Function):
+
+ @staticmethod
+ def symbolic(g, input, rois, output_size, spatial_scale):
+ return g.op(
+ 'MaxRoiPool',
+ input,
+ rois,
+ pooled_shape_i=output_size,
+ spatial_scale_f=spatial_scale)
+
+ @staticmethod
+ def forward(ctx, input, rois, output_size, spatial_scale=1.0):
+ ctx.output_size = _pair(output_size)
+ ctx.spatial_scale = spatial_scale
+ ctx.input_shape = input.size()
+
+ assert rois.size(1) == 5, 'RoI must be (idx, x1, y1, x2, y2)!'
+
+ output_shape = (rois.size(0), input.size(1), ctx.output_size[0],
+ ctx.output_size[1])
+ output = input.new_zeros(output_shape)
+ argmax = input.new_zeros(output_shape, dtype=torch.int)
+
+ ext_module.roi_pool_forward(
+ input,
+ rois,
+ output,
+ argmax,
+ pooled_height=ctx.output_size[0],
+ pooled_width=ctx.output_size[1],
+ spatial_scale=ctx.spatial_scale)
+
+ ctx.save_for_backward(rois, argmax)
+ return output
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, grad_output):
+ rois, argmax = ctx.saved_tensors
+ grad_input = grad_output.new_zeros(ctx.input_shape)
+
+ ext_module.roi_pool_backward(
+ grad_output,
+ rois,
+ argmax,
+ grad_input,
+ pooled_height=ctx.output_size[0],
+ pooled_width=ctx.output_size[1],
+ spatial_scale=ctx.spatial_scale)
+
+ return grad_input, None, None, None
+
+
+roi_pool = RoIPoolFunction.apply
+
+
+class RoIPool(nn.Module):
+
+ def __init__(self, output_size, spatial_scale=1.0):
+ super(RoIPool, self).__init__()
+
+ self.output_size = _pair(output_size)
+ self.spatial_scale = float(spatial_scale)
+
+ def forward(self, input, rois):
+ return roi_pool(input, rois, self.output_size, self.spatial_scale)
+
+ def __repr__(self):
+ s = self.__class__.__name__
+ s += f'(output_size={self.output_size}, '
+ s += f'spatial_scale={self.spatial_scale})'
+ return s
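+# Usage sketch (illustrative comment only; values are assumed):
+#
+#   pool = RoIPool(output_size=(7, 7), spatial_scale=1 / 16.)
+#   feats = torch.randn(1, 512, 38, 38).cuda()
+#   rois = torch.tensor([[0., 16., 16., 320., 320.]]).cuda()
+#   pooled = pool(feats, rois)    # (1, 512, 7, 7)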
diff --git a/annotator/uniformer/mmcv/ops/roiaware_pool3d.py b/annotator/uniformer/mmcv/ops/roiaware_pool3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..291b0e5a9b692492c7d7e495ea639c46042e2f18
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/roiaware_pool3d.py
@@ -0,0 +1,114 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch import nn as nn
+from torch.autograd import Function
+
+import annotator.uniformer.mmcv as mmcv
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+ '_ext', ['roiaware_pool3d_forward', 'roiaware_pool3d_backward'])
+
+
+class RoIAwarePool3d(nn.Module):
+ """Encode the geometry-specific features of each 3D proposal.
+
+ Please refer to the PartA2 paper for more details.
+
+ Args:
+ out_size (int or tuple): The size of output features. n or
+ [n1, n2, n3].
+ max_pts_per_voxel (int, optional): The maximum number of points per
+ voxel. Default: 128.
+ mode (str, optional): Pooling method of RoIAware, 'max' or 'avg'.
+ Default: 'max'.
+ """
+
+ def __init__(self, out_size, max_pts_per_voxel=128, mode='max'):
+ super().__init__()
+
+ self.out_size = out_size
+ self.max_pts_per_voxel = max_pts_per_voxel
+ assert mode in ['max', 'avg']
+ pool_mapping = {'max': 0, 'avg': 1}
+ self.mode = pool_mapping[mode]
+
+ def forward(self, rois, pts, pts_feature):
+ """
+ Args:
+ rois (torch.Tensor): [N, 7], in LiDAR coordinates;
+ (x, y, z) is the bottom center of the rois.
+ pts (torch.Tensor): [npoints, 3], coordinates of input points.
+ pts_feature (torch.Tensor): [npoints, C], features of input points.
+
+ Returns:
+ pooled_features (torch.Tensor): [N, out_x, out_y, out_z, C]
+ """
+
+ return RoIAwarePool3dFunction.apply(rois, pts, pts_feature,
+ self.out_size,
+ self.max_pts_per_voxel, self.mode)
+
+
+class RoIAwarePool3dFunction(Function):
+
+ @staticmethod
+ def forward(ctx, rois, pts, pts_feature, out_size, max_pts_per_voxel,
+ mode):
+ """
+ Args:
+ rois (torch.Tensor): [N, 7], in LiDAR coordinates;
+ (x, y, z) is the bottom center of the rois.
+ pts (torch.Tensor): [npoints, 3], coordinates of input points.
+ pts_feature (torch.Tensor): [npoints, C], features of input points.
+ out_size (int or tuple): The size of output features. n or
+ [n1, n2, n3].
+ max_pts_per_voxel (int): The maximum number of points per voxel.
+ Default: 128.
+ mode (int): Pooling method of RoIAware, 0 (max pool) or 1 (average
+ pool).
+
+ Returns:
+ pooled_features (torch.Tensor): [N, out_x, out_y, out_z, C], output
+ pooled features.
+ """
+
+ if isinstance(out_size, int):
+ out_x = out_y = out_z = out_size
+ else:
+ assert len(out_size) == 3
+ assert mmcv.is_tuple_of(out_size, int)
+ out_x, out_y, out_z = out_size
+
+ num_rois = rois.shape[0]
+ num_channels = pts_feature.shape[-1]
+ num_pts = pts.shape[0]
+
+ pooled_features = pts_feature.new_zeros(
+ (num_rois, out_x, out_y, out_z, num_channels))
+ argmax = pts_feature.new_zeros(
+ (num_rois, out_x, out_y, out_z, num_channels), dtype=torch.int)
+ pts_idx_of_voxels = pts_feature.new_zeros(
+ (num_rois, out_x, out_y, out_z, max_pts_per_voxel),
+ dtype=torch.int)
+
+ ext_module.roiaware_pool3d_forward(rois, pts, pts_feature, argmax,
+ pts_idx_of_voxels, pooled_features,
+ mode)
+
+ ctx.roiaware_pool3d_for_backward = (pts_idx_of_voxels, argmax, mode,
+ num_pts, num_channels)
+ return pooled_features
+
+ @staticmethod
+ def backward(ctx, grad_out):
+ ret = ctx.roiaware_pool3d_for_backward
+ pts_idx_of_voxels, argmax, mode, num_pts, num_channels = ret
+
+ grad_in = grad_out.new_zeros((num_pts, num_channels))
+ ext_module.roiaware_pool3d_backward(pts_idx_of_voxels, argmax,
+ grad_out.contiguous(), grad_in,
+ mode)
+
+ return None, None, grad_in, None, None, None
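+# Usage sketch (illustrative comment only; the random values stand in for
+# real LiDAR boxes and point features):
+#
+#   pool = RoIAwarePool3d(out_size=14, max_pts_per_voxel=128, mode='max')
+#   rois = torch.rand(4, 7).cuda()            # (N, 7) boxes
+#   pts = torch.rand(2000, 3).cuda()          # (npoints, 3)
+#   pts_feature = torch.rand(2000, 16).cuda() # (npoints, C)
+#   pooled = pool(rois, pts, pts_feature)     # (4, 14, 14, 14, 16)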
diff --git a/annotator/uniformer/mmcv/ops/roipoint_pool3d.py b/annotator/uniformer/mmcv/ops/roipoint_pool3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a21412c0728431c04b84245bc2e3109eea9aefc
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/roipoint_pool3d.py
@@ -0,0 +1,77 @@
+from torch import nn as nn
+from torch.autograd import Function
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', ['roipoint_pool3d_forward'])
+
+
+class RoIPointPool3d(nn.Module):
+ """Encode the geometry-specific features of each 3D proposal.
+
+ Please refer to the PartA2 paper for more details.
+
+ Args:
+ num_sampled_points (int, optional): Number of samples in each roi.
+ Default: 512.
+ """
+
+ def __init__(self, num_sampled_points=512):
+ super().__init__()
+ self.num_sampled_points = num_sampled_points
+
+ def forward(self, points, point_features, boxes3d):
+ """
+ Args:
+ points (torch.Tensor): Input points whose shape is (B, N, C).
+ point_features (torch.Tensor): Features of input points whose shape
+ is (B, N, C).
+ boxes3d (torch.Tensor): Input bounding boxes whose shape is (B, M, 7).
+
+ Returns:
+ pooled_features (torch.Tensor): The output pooled features whose
+ shape is (B, M, 512, 3 + C).
+ pooled_empty_flag (torch.Tensor): Empty flag whose shape is (B, M).
+ """
+ return RoIPointPool3dFunction.apply(points, point_features, boxes3d,
+ self.num_sampled_points)
+
+
+class RoIPointPool3dFunction(Function):
+
+ @staticmethod
+ def forward(ctx, points, point_features, boxes3d, num_sampled_points=512):
+ """
+ Args:
+ points (torch.Tensor): Input points whose shape is (B, N, C).
+ point_features (torch.Tensor): Features of input points whose shape
+ is (B, N, C).
+ boxes3d (torch.Tensor): Input bounding boxes whose shape is (B, M, 7).
+ num_sampled_points (int, optional): The num of sampled points.
+ Default: 512.
+
+ Returns:
+ pooled_features (torch.Tensor): The output pooled features whose
+ shape is (B, M, 512, 3 + C).
+ pooled_empty_flag (torch.Tensor): Empty flag whose shape is (B, M).
+ """
+ assert len(points.shape) == 3 and points.shape[2] == 3
+ batch_size, boxes_num, feature_len = points.shape[0], boxes3d.shape[
+ 1], point_features.shape[2]
+ pooled_boxes3d = boxes3d.view(batch_size, -1, 7)
+ pooled_features = point_features.new_zeros(
+ (batch_size, boxes_num, num_sampled_points, 3 + feature_len))
+ pooled_empty_flag = point_features.new_zeros(
+ (batch_size, boxes_num)).int()
+
+ ext_module.roipoint_pool3d_forward(points.contiguous(),
+ pooled_boxes3d.contiguous(),
+ point_features.contiguous(),
+ pooled_features, pooled_empty_flag)
+
+ return pooled_features, pooled_empty_flag
+
+ @staticmethod
+ def backward(ctx, grad_out):
+ raise NotImplementedError
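+# Usage sketch (illustrative comment only; shapes follow the docstrings
+# above):
+#
+#   pool = RoIPointPool3d(num_sampled_points=512)
+#   points = torch.rand(2, 1024, 3).cuda()
+#   point_features = torch.rand(2, 1024, 4).cuda()
+#   boxes3d = torch.rand(2, 10, 7).cuda()
+#   pooled, empty_flag = pool(points, point_features, boxes3d)
+#   # pooled: (2, 10, 512, 3 + 4), empty_flag: (2, 10)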
diff --git a/annotator/uniformer/mmcv/ops/saconv.py b/annotator/uniformer/mmcv/ops/saconv.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4ee3978e097fca422805db4e31ae481006d7971
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/saconv.py
@@ -0,0 +1,145 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from annotator.uniformer.mmcv.cnn import CONV_LAYERS, ConvAWS2d, constant_init
+from annotator.uniformer.mmcv.ops.deform_conv import deform_conv2d
+from annotator.uniformer.mmcv.utils import TORCH_VERSION, digit_version
+
+
+@CONV_LAYERS.register_module(name='SAC')
+class SAConv2d(ConvAWS2d):
+ """SAC (Switchable Atrous Convolution)
+
+ This is an implementation of SAC in DetectoRS
+ (https://arxiv.org/pdf/2006.02334.pdf).
+
+ Args:
+ in_channels (int): Number of channels in the input image
+ out_channels (int): Number of channels produced by the convolution
+ kernel_size (int or tuple): Size of the convolving kernel
+ stride (int or tuple, optional): Stride of the convolution. Default: 1
+ padding (int or tuple, optional): Zero-padding added to both sides of
+ the input. Default: 0
+ padding_mode (string, optional): ``'zeros'``, ``'reflect'``,
+ ``'replicate'`` or ``'circular'``. Default: ``'zeros'``
+ dilation (int or tuple, optional): Spacing between kernel elements.
+ Default: 1
+ groups (int, optional): Number of blocked connections from input
+ channels to output channels. Default: 1
+ bias (bool, optional): If ``True``, adds a learnable bias to the
+ output. Default: ``True``
+ use_deform (bool, optional): If ``True``, replace convolution with
+ deformable convolution. Default: ``False``.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ bias=True,
+ use_deform=False):
+ super().__init__(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ bias=bias)
+ self.use_deform = use_deform
+ self.switch = nn.Conv2d(
+ self.in_channels, 1, kernel_size=1, stride=stride, bias=True)
+ self.weight_diff = nn.Parameter(torch.Tensor(self.weight.size()))
+ self.pre_context = nn.Conv2d(
+ self.in_channels, self.in_channels, kernel_size=1, bias=True)
+ self.post_context = nn.Conv2d(
+ self.out_channels, self.out_channels, kernel_size=1, bias=True)
+ if self.use_deform:
+ self.offset_s = nn.Conv2d(
+ self.in_channels,
+ 18,
+ kernel_size=3,
+ padding=1,
+ stride=stride,
+ bias=True)
+ self.offset_l = nn.Conv2d(
+ self.in_channels,
+ 18,
+ kernel_size=3,
+ padding=1,
+ stride=stride,
+ bias=True)
+ self.init_weights()
+
+ def init_weights(self):
+ constant_init(self.switch, 0, bias=1)
+ self.weight_diff.data.zero_()
+ constant_init(self.pre_context, 0)
+ constant_init(self.post_context, 0)
+ if self.use_deform:
+ constant_init(self.offset_s, 0)
+ constant_init(self.offset_l, 0)
+
+ def forward(self, x):
+ # pre-context
+ avg_x = F.adaptive_avg_pool2d(x, output_size=1)
+ avg_x = self.pre_context(avg_x)
+ avg_x = avg_x.expand_as(x)
+ x = x + avg_x
+ # switch
+ avg_x = F.pad(x, pad=(2, 2, 2, 2), mode='reflect')
+ avg_x = F.avg_pool2d(avg_x, kernel_size=5, stride=1, padding=0)
+ switch = self.switch(avg_x)
+ # sac
+ weight = self._get_weight(self.weight)
+ zero_bias = torch.zeros(
+ self.out_channels, device=weight.device, dtype=weight.dtype)
+
+ if self.use_deform:
+ offset = self.offset_s(avg_x)
+ out_s = deform_conv2d(x, offset, weight, self.stride, self.padding,
+ self.dilation, self.groups, 1)
+ else:
+ if (TORCH_VERSION == 'parrots'
+ or digit_version(TORCH_VERSION) < digit_version('1.5.0')):
+ out_s = super().conv2d_forward(x, weight)
+ elif digit_version(TORCH_VERSION) >= digit_version('1.8.0'):
+ # bias is a required argument of _conv_forward in torch 1.8.0
+ out_s = super()._conv_forward(x, weight, zero_bias)
+ else:
+ out_s = super()._conv_forward(x, weight)
+ ori_p = self.padding
+ ori_d = self.dilation
+ self.padding = tuple(3 * p for p in self.padding)
+ self.dilation = tuple(3 * d for d in self.dilation)
+ weight = weight + self.weight_diff
+ if self.use_deform:
+ offset = self.offset_l(avg_x)
+ out_l = deform_conv2d(x, offset, weight, self.stride, self.padding,
+ self.dilation, self.groups, 1)
+ else:
+ if (TORCH_VERSION == 'parrots'
+ or digit_version(TORCH_VERSION) < digit_version('1.5.0')):
+ out_l = super().conv2d_forward(x, weight)
+ elif digit_version(TORCH_VERSION) >= digit_version('1.8.0'):
+ # bias is a required argument of _conv_forward in torch 1.8.0
+ out_l = super()._conv_forward(x, weight, zero_bias)
+ else:
+ out_l = super()._conv_forward(x, weight)
+
+ out = switch * out_s + (1 - switch) * out_l
+ self.padding = ori_p
+ self.dilation = ori_d
+ # post-context
+ avg_x = F.adaptive_avg_pool2d(out, output_size=1)
+ avg_x = self.post_context(avg_x)
+ avg_x = avg_x.expand_as(out)
+ out = out + avg_x
+ return out
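+# Usage sketch (illustrative comment only; SAC is used as a drop-in
+# replacement for a regular 3x3 convolution):
+#
+#   conv = SAConv2d(256, 256, kernel_size=3, padding=1)
+#   x = torch.randn(2, 256, 32, 32)
+#   out = conv(x)   # (2, 256, 32, 32)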
diff --git a/annotator/uniformer/mmcv/ops/scatter_points.py b/annotator/uniformer/mmcv/ops/scatter_points.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b8aa4169e9f6ca4a6f845ce17d6d1e4db416bb8
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/scatter_points.py
@@ -0,0 +1,135 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch import nn
+from torch.autograd import Function
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+ '_ext',
+ ['dynamic_point_to_voxel_forward', 'dynamic_point_to_voxel_backward'])
+
+
+class _DynamicScatter(Function):
+
+ @staticmethod
+ def forward(ctx, feats, coors, reduce_type='max'):
+ """convert kitti points(N, >=3) to voxels.
+
+ Args:
+ feats (torch.Tensor): [N, C]. Points features to be reduced
+ into voxels.
+ coors (torch.Tensor): [N, ndim]. Corresponding voxel coordinates
+ (specifically multi-dim voxel index) of each point.
+ reduce_type (str, optional): Reduce op. Supports 'max', 'sum' and
+ 'mean'. Default: 'max'.
+
+ Returns:
+ voxel_feats (torch.Tensor): [M, C]. Reduced features; input
+ features that share the same voxel coordinates are reduced to
+ one row.
+ voxel_coors (torch.Tensor): [M, ndim]. Voxel coordinates.
+ """
+ results = ext_module.dynamic_point_to_voxel_forward(
+ feats, coors, reduce_type)
+ (voxel_feats, voxel_coors, point2voxel_map,
+ voxel_points_count) = results
+ ctx.reduce_type = reduce_type
+ ctx.save_for_backward(feats, voxel_feats, point2voxel_map,
+ voxel_points_count)
+ ctx.mark_non_differentiable(voxel_coors)
+ return voxel_feats, voxel_coors
+
+ @staticmethod
+ def backward(ctx, grad_voxel_feats, grad_voxel_coors=None):
+ (feats, voxel_feats, point2voxel_map,
+ voxel_points_count) = ctx.saved_tensors
+ grad_feats = torch.zeros_like(feats)
+ # TODO: whether to use index put or use cuda_backward
+ # To use index put, need point to voxel index
+ ext_module.dynamic_point_to_voxel_backward(
+ grad_feats, grad_voxel_feats.contiguous(), feats, voxel_feats,
+ point2voxel_map, voxel_points_count, ctx.reduce_type)
+ return grad_feats, None, None
+
+
+dynamic_scatter = _DynamicScatter.apply
+
+
+class DynamicScatter(nn.Module):
+ """Scatters points into voxels, used in the voxel encoder with dynamic
+ voxelization.
+
+ Note:
+ The CPU and GPU implementations produce the same output, but may show
+ small numerical differences after summation and division (e.g., 5e-7).
+
+ Args:
+ voxel_size (list): [x, y, z] size of the voxels along each dimension.
+ point_cloud_range (list): The coordinate range of points, [x_min,
+ y_min, z_min, x_max, y_max, z_max].
+ average_points (bool): whether to use avg pooling to scatter points
+ into voxels.
+ """
+
+ def __init__(self, voxel_size, point_cloud_range, average_points: bool):
+ super().__init__()
+
+ self.voxel_size = voxel_size
+ self.point_cloud_range = point_cloud_range
+ self.average_points = average_points
+
+ def forward_single(self, points, coors):
+ """Scatters points into voxels.
+
+ Args:
+ points (torch.Tensor): Points to be reduced into voxels.
+ coors (torch.Tensor): Corresponding voxel coordinates (specifically
+ multi-dim voxel index) of each point.
+
+ Returns:
+ voxel_feats (torch.Tensor): Reduced features; input features that
+ share the same voxel coordinates are reduced to one row.
+ voxel_coors (torch.Tensor): Voxel coordinates.
+ """
+ reduce = 'mean' if self.average_points else 'max'
+ return dynamic_scatter(points.contiguous(), coors.contiguous(), reduce)
+
+ def forward(self, points, coors):
+ """Scatters points/features into voxels.
+
+ Args:
+ points (torch.Tensor): Points to be reduced into voxels.
+ coors (torch.Tensor): Corresponding voxel coordinates (specifically
+ multi-dim voxel index) of each point.
+
+ Returns:
+ voxel_feats (torch.Tensor): Reduced features; input features that
+ share the same voxel coordinates are reduced to one row.
+ voxel_coors (torch.Tensor): Voxel coordinates.
+ """
+ if coors.size(-1) == 3:
+ return self.forward_single(points, coors)
+ else:
+ batch_size = coors[-1, 0] + 1
+ voxels, voxel_coors = [], []
+ for i in range(batch_size):
+ inds = torch.where(coors[:, 0] == i)
+ voxel, voxel_coor = self.forward_single(
+ points[inds], coors[inds][:, 1:])
+ coor_pad = nn.functional.pad(
+ voxel_coor, (1, 0), mode='constant', value=i)
+ voxel_coors.append(coor_pad)
+ voxels.append(voxel)
+ features = torch.cat(voxels, dim=0)
+ feature_coors = torch.cat(voxel_coors, dim=0)
+
+ return features, feature_coors
+
+ def __repr__(self):
+ s = self.__class__.__name__ + '('
+ s += 'voxel_size=' + str(self.voxel_size)
+ s += ', point_cloud_range=' + str(self.point_cloud_range)
+ s += ', average_points=' + str(self.average_points)
+ s += ')'
+ return s
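+# Usage sketch (illustrative comment only; a single-sample case with 3-column
+# voxel coordinates, so the `forward_single` path is taken; values are
+# assumed):
+#
+#   scatter = DynamicScatter(voxel_size=[0.32, 0.32, 6],
+#                            point_cloud_range=[0, -40, -3, 70.4, 40, 1],
+#                            average_points=False)
+#   feats = torch.rand(5000, 4).cuda()                     # per-point features
+#   coors = torch.randint(0, 40, (5000, 3)).int().cuda()   # (z, y, x)
+#   voxel_feats, voxel_coors = scatter(feats, coors)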
diff --git a/annotator/uniformer/mmcv/ops/sync_bn.py b/annotator/uniformer/mmcv/ops/sync_bn.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9b016fcbe860989c56cd1040034bcfa60e146d2
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/sync_bn.py
@@ -0,0 +1,279 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.nn.modules.module import Module
+from torch.nn.parameter import Parameter
+
+from annotator.uniformer.mmcv.cnn import NORM_LAYERS
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', [
+ 'sync_bn_forward_mean', 'sync_bn_forward_var', 'sync_bn_forward_output',
+ 'sync_bn_backward_param', 'sync_bn_backward_data'
+])
+
+
+class SyncBatchNormFunction(Function):
+
+ @staticmethod
+ def symbolic(g, input, running_mean, running_var, weight, bias, momentum,
+ eps, group, group_size, stats_mode):
+ return g.op(
+ 'mmcv::MMCVSyncBatchNorm',
+ input,
+ running_mean,
+ running_var,
+ weight,
+ bias,
+ momentum_f=momentum,
+ eps_f=eps,
+ group_i=group,
+ group_size_i=group_size,
+ stats_mode=stats_mode)
+
+ @staticmethod
+ def forward(self, input, running_mean, running_var, weight, bias, momentum,
+ eps, group, group_size, stats_mode):
+ self.momentum = momentum
+ self.eps = eps
+ self.group = group
+ self.group_size = group_size
+ self.stats_mode = stats_mode
+
+ assert isinstance(
+ input, (torch.HalfTensor, torch.FloatTensor,
+ torch.cuda.HalfTensor, torch.cuda.FloatTensor)), \
+ f'only supports Half or Float Tensors, but got {input.type()}'
+ output = torch.zeros_like(input)
+ input3d = input.flatten(start_dim=2)
+ output3d = output.view_as(input3d)
+ num_channels = input3d.size(1)
+
+ # ensure mean/var/norm/std are initialized as zeros
+ # ``torch.empty()`` does not guarantee that
+ mean = torch.zeros(
+ num_channels, dtype=torch.float, device=input3d.device)
+ var = torch.zeros(
+ num_channels, dtype=torch.float, device=input3d.device)
+ norm = torch.zeros_like(
+ input3d, dtype=torch.float, device=input3d.device)
+ std = torch.zeros(
+ num_channels, dtype=torch.float, device=input3d.device)
+
+ batch_size = input3d.size(0)
+ if batch_size > 0:
+ ext_module.sync_bn_forward_mean(input3d, mean)
+ batch_flag = torch.ones([1], device=mean.device, dtype=mean.dtype)
+ else:
+ # skip updating mean and leave it as zeros when the input is empty
+ batch_flag = torch.zeros([1], device=mean.device, dtype=mean.dtype)
+
+ # synchronize mean and the batch flag
+ vec = torch.cat([mean, batch_flag])
+ if self.stats_mode == 'N':
+ vec *= batch_size
+ if self.group_size > 1:
+ dist.all_reduce(vec, group=self.group)
+ total_batch = vec[-1].detach()
+ mean = vec[:num_channels]
+
+ if self.stats_mode == 'default':
+ mean = mean / self.group_size
+ elif self.stats_mode == 'N':
+ mean = mean / total_batch.clamp(min=1)
+ else:
+ raise NotImplementedError
+
+ # leave var as zeros when the input is empty
+ if batch_size > 0:
+ ext_module.sync_bn_forward_var(input3d, mean, var)
+
+ if self.stats_mode == 'N':
+ var *= batch_size
+ if self.group_size > 1:
+ dist.all_reduce(var, group=self.group)
+
+ if self.stats_mode == 'default':
+ var /= self.group_size
+ elif self.stats_mode == 'N':
+ var /= total_batch.clamp(min=1)
+ else:
+ raise NotImplementedError
+
+ # if the total batch size over all the ranks is zero,
+ # we should not update the statistics in the current batch
+ update_flag = total_batch.clamp(max=1)
+ momentum = update_flag * self.momentum
+ ext_module.sync_bn_forward_output(
+ input3d,
+ mean,
+ var,
+ weight,
+ bias,
+ running_mean,
+ running_var,
+ norm,
+ std,
+ output3d,
+ eps=self.eps,
+ momentum=momentum,
+ group_size=self.group_size)
+ self.save_for_backward(norm, std, weight)
+ return output
+
+ @staticmethod
+ @once_differentiable
+ def backward(self, grad_output):
+ norm, std, weight = self.saved_tensors
+ grad_weight = torch.zeros_like(weight)
+ grad_bias = torch.zeros_like(weight)
+ grad_input = torch.zeros_like(grad_output)
+ grad_output3d = grad_output.flatten(start_dim=2)
+ grad_input3d = grad_input.view_as(grad_output3d)
+
+ batch_size = grad_input3d.size(0)
+ if batch_size > 0:
+ ext_module.sync_bn_backward_param(grad_output3d, norm, grad_weight,
+ grad_bias)
+
+ # all reduce
+ if self.group_size > 1:
+ dist.all_reduce(grad_weight, group=self.group)
+ dist.all_reduce(grad_bias, group=self.group)
+ grad_weight /= self.group_size
+ grad_bias /= self.group_size
+
+ if batch_size > 0:
+ ext_module.sync_bn_backward_data(grad_output3d, weight,
+ grad_weight, grad_bias, norm, std,
+ grad_input3d)
+
+ return grad_input, None, None, grad_weight, grad_bias, \
+ None, None, None, None, None
+
+
+@NORM_LAYERS.register_module(name='MMSyncBN')
+class SyncBatchNorm(Module):
+ """Synchronized Batch Normalization.
+
+ Args:
+ num_features (int): number of features/channels in the input tensor
+ eps (float, optional): a value added to the denominator for numerical
+ stability. Defaults to 1e-5.
+ momentum (float, optional): the value used for the running_mean and
+ running_var computation. Defaults to 0.1.
+ affine (bool, optional): whether to use learnable affine parameters.
+ Defaults to True.
+ track_running_stats (bool, optional): whether to track the running
+ mean and variance during training. When set to False, this
+ module does not track such statistics, and initializes statistics
+ buffers ``running_mean`` and ``running_var`` as ``None``. When
+ these buffers are ``None``, this module always uses batch
+ statistics in both training and eval modes. Defaults to True.
+ group (int, optional): synchronization of stats happen within
+ each process group individually. By default it is synchronization
+ across the whole world. Defaults to None.
+ stats_mode (str, optional): The statistical mode. Available options
+ include ``'default'`` and ``'N'``. Defaults to 'default'.
+ When ``stats_mode=='default'``, it computes the overall statistics
+ using those from each worker with equal weight, i.e., the
+ statistics are synchronized and simply divided by ``group``. This
+ mode will produce inaccurate statistics when empty tensors occur.
+ When ``stats_mode=='N'``, it computes the overall statistics using
+ the total number of batches in each worker, ignoring the number of
+ workers, i.e., the statistics are synchronized and then divided by
+ the total batch size ``N``. This mode is beneficial when empty
+ tensors occur during training, as it averages the total mean over
+ the real number of batches.
+ """
+
+ def __init__(self,
+ num_features,
+ eps=1e-5,
+ momentum=0.1,
+ affine=True,
+ track_running_stats=True,
+ group=None,
+ stats_mode='default'):
+ super(SyncBatchNorm, self).__init__()
+ self.num_features = num_features
+ self.eps = eps
+ self.momentum = momentum
+ self.affine = affine
+ self.track_running_stats = track_running_stats
+ group = dist.group.WORLD if group is None else group
+ self.group = group
+ self.group_size = dist.get_world_size(group)
+ assert stats_mode in ['default', 'N'], \
+ f'"stats_mode" only accepts "default" and "N", got "{stats_mode}"'
+ self.stats_mode = stats_mode
+ if self.affine:
+ self.weight = Parameter(torch.Tensor(num_features))
+ self.bias = Parameter(torch.Tensor(num_features))
+ else:
+ self.register_parameter('weight', None)
+ self.register_parameter('bias', None)
+ if self.track_running_stats:
+ self.register_buffer('running_mean', torch.zeros(num_features))
+ self.register_buffer('running_var', torch.ones(num_features))
+ self.register_buffer('num_batches_tracked',
+ torch.tensor(0, dtype=torch.long))
+ else:
+ self.register_buffer('running_mean', None)
+ self.register_buffer('running_var', None)
+ self.register_buffer('num_batches_tracked', None)
+ self.reset_parameters()
+
+ def reset_running_stats(self):
+ if self.track_running_stats:
+ self.running_mean.zero_()
+ self.running_var.fill_(1)
+ self.num_batches_tracked.zero_()
+
+ def reset_parameters(self):
+ self.reset_running_stats()
+ if self.affine:
+ self.weight.data.uniform_() # pytorch use ones_()
+ self.bias.data.zero_()
+
+ def forward(self, input):
+ if input.dim() < 2:
+ raise ValueError(
+ f'expected at least 2D input, got {input.dim()}D input')
+ if self.momentum is None:
+ exponential_average_factor = 0.0
+ else:
+ exponential_average_factor = self.momentum
+
+ if self.training and self.track_running_stats:
+ if self.num_batches_tracked is not None:
+ self.num_batches_tracked += 1
+ if self.momentum is None: # use cumulative moving average
+ exponential_average_factor = 1.0 / float(
+ self.num_batches_tracked)
+ else: # use exponential moving average
+ exponential_average_factor = self.momentum
+
+ if self.training or not self.track_running_stats:
+ return SyncBatchNormFunction.apply(
+ input, self.running_mean, self.running_var, self.weight,
+ self.bias, exponential_average_factor, self.eps, self.group,
+ self.group_size, self.stats_mode)
+ else:
+ return F.batch_norm(input, self.running_mean, self.running_var,
+ self.weight, self.bias, False,
+ exponential_average_factor, self.eps)
+
+ def __repr__(self):
+ s = self.__class__.__name__
+ s += f'({self.num_features}, '
+ s += f'eps={self.eps}, '
+ s += f'momentum={self.momentum}, '
+ s += f'affine={self.affine}, '
+ s += f'track_running_stats={self.track_running_stats}, '
+ s += f'group_size={self.group_size}, '
+ s += f'stats_mode={self.stats_mode})'
+ return s
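+# Usage sketch (illustrative comment only; ``MMSyncBN`` needs an initialized
+# ``torch.distributed`` process group, e.g. inside a DDP training job):
+#
+#   bn = SyncBatchNorm(64).cuda()
+#   x = torch.randn(8, 64, 32, 32).cuda()
+#   y = bn(x)   # same shape as x; statistics are synced across the group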
diff --git a/annotator/uniformer/mmcv/ops/three_interpolate.py b/annotator/uniformer/mmcv/ops/three_interpolate.py
new file mode 100644
index 0000000000000000000000000000000000000000..203f47f05d58087e034fb3cd8cd6a09233947b4a
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/three_interpolate.py
@@ -0,0 +1,68 @@
+from typing import Tuple
+
+import torch
+from torch.autograd import Function
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+ '_ext', ['three_interpolate_forward', 'three_interpolate_backward'])
+
+
+class ThreeInterpolate(Function):
+ """Performs weighted linear interpolation on 3 features.
+
+ Please refer to the PointNet++ paper for more details.
+ """
+
+ @staticmethod
+ def forward(ctx, features: torch.Tensor, indices: torch.Tensor,
+ weight: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ features (Tensor): (B, C, M) feature descriptors to be
+ interpolated.
+ indices (Tensor): (B, n, 3) indices of the three nearest
+ neighbors of the target features in ``features``.
+ weight (Tensor): (B, n, 3) interpolation weights.
+
+ Returns:
+ Tensor: (B, C, N) tensor of the interpolated features
+ """
+ assert features.is_contiguous()
+ assert indices.is_contiguous()
+ assert weight.is_contiguous()
+
+ B, c, m = features.size()
+ n = indices.size(1)
+ ctx.three_interpolate_for_backward = (indices, weight, m)
+ output = torch.cuda.FloatTensor(B, c, n)
+
+ ext_module.three_interpolate_forward(
+ features, indices, weight, output, b=B, c=c, m=m, n=n)
+ return output
+
+ @staticmethod
+ def backward(
+ ctx, grad_out: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Args:
+ grad_out (Tensor): (B, C, N) tensor with gradients of outputs
+
+ Returns:
+ Tensor: (B, C, M) tensor with gradients of features
+ """
+ idx, weight, m = ctx.three_interpolate_for_backward
+ B, c, n = grad_out.size()
+
+ grad_features = torch.cuda.FloatTensor(B, c, m).zero_()
+ grad_out_data = grad_out.data.contiguous()
+
+ ext_module.three_interpolate_backward(
+ grad_out_data, idx, weight, grad_features.data, b=B, c=c, n=n, m=m)
+ return grad_features, None, None
+
+
+three_interpolate = ThreeInterpolate.apply
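+# Usage sketch (illustrative comment only; the op expects contiguous CUDA
+# tensors):
+#
+#   features = torch.rand(2, 16, 64).cuda()                   # (B, C, M)
+#   indices = torch.randint(0, 64, (2, 256, 3)).int().cuda()  # (B, n, 3)
+#   weight = torch.rand(2, 256, 3).cuda()                     # (B, n, 3)
+#   out = three_interpolate(features, indices, weight)        # (2, 16, 256)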
diff --git a/annotator/uniformer/mmcv/ops/three_nn.py b/annotator/uniformer/mmcv/ops/three_nn.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b01047a129989cd5545a0a86f23a487f4a13ce1
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/three_nn.py
@@ -0,0 +1,51 @@
+from typing import Tuple
+
+import torch
+from torch.autograd import Function
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', ['three_nn_forward'])
+
+
+class ThreeNN(Function):
+ """Find the top-3 nearest neighbors of the target set from the source set.
+
+ Please refer to the PointNet++ paper for more details.
+ """
+
+ @staticmethod
+ def forward(ctx, target: torch.Tensor,
+ source: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Args:
+ target (Tensor): shape (B, N, 3), the point set for which the
+ nearest neighbors are searched.
+ source (Tensor): shape (B, M, 3), the point set in which the
+ nearest neighbors of the target points are searched.
+
+ Returns:
+ Tensor: shape (B, N, 3), L2 distance of each target point to its
+ three nearest neighbors in the source set.
+ Tensor: shape (B, N, 3), indices of these neighbors in the
+ source set.
+ target = target.contiguous()
+ source = source.contiguous()
+
+ B, N, _ = target.size()
+ m = source.size(1)
+ dist2 = torch.cuda.FloatTensor(B, N, 3)
+ idx = torch.cuda.IntTensor(B, N, 3)
+
+ ext_module.three_nn_forward(target, source, dist2, idx, b=B, n=N, m=m)
+ if torch.__version__ != 'parrots':
+ ctx.mark_non_differentiable(idx)
+
+ return torch.sqrt(dist2), idx
+
+ @staticmethod
+ def backward(ctx, a=None, b=None):
+ return None, None
+
+
+three_nn = ThreeNN.apply
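+# Usage sketch (illustrative comment only; the op expects CUDA tensors):
+#
+#   target = torch.rand(2, 256, 3).cuda()   # (B, N, 3)
+#   source = torch.rand(2, 64, 3).cuda()    # (B, M, 3)
+#   dist, idx = three_nn(target, source)    # both (2, 256, 3)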
diff --git a/annotator/uniformer/mmcv/ops/tin_shift.py b/annotator/uniformer/mmcv/ops/tin_shift.py
new file mode 100644
index 0000000000000000000000000000000000000000..472c9fcfe45a124e819b7ed5653e585f94a8811e
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/tin_shift.py
@@ -0,0 +1,68 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# Code reference from "Temporal Interlacing Network"
+# https://github.com/deepcs233/TIN/blob/master/cuda_shift/rtc_wrap.py
+# Hao Shao, Shengju Qian, Yu Liu
+# shaoh19@mails.tsinghua.edu.cn, sjqian@cse.cuhk.edu.hk, yuliu@ee.cuhk.edu.hk
+
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext',
+ ['tin_shift_forward', 'tin_shift_backward'])
+
+
+class TINShiftFunction(Function):
+
+ @staticmethod
+ def forward(ctx, input, shift):
+ C = input.size(2)
+ num_segments = shift.size(1)
+ if C // num_segments <= 0 or C % num_segments != 0:
+ raise ValueError('C should be a multiple of num_segments, '
+ f'but got C={C} and num_segments={num_segments}.')
+
+ ctx.save_for_backward(shift)
+
+ out = torch.zeros_like(input)
+ ext_module.tin_shift_forward(input, shift, out)
+
+ return out
+
+ @staticmethod
+ def backward(ctx, grad_output):
+
+ shift = ctx.saved_tensors[0]
+ data_grad_input = grad_output.new(*grad_output.size()).zero_()
+ shift_grad_input = shift.new(*shift.size()).zero_()
+ ext_module.tin_shift_backward(grad_output, shift, data_grad_input)
+
+ return data_grad_input, shift_grad_input
+
+
+tin_shift = TINShiftFunction.apply
+
+
+class TINShift(nn.Module):
+ """Temporal Interlace Shift.
+
+ Temporal Interlace shift is a differentiable temporal-wise frame shifting
+ operation proposed in "Temporal Interlacing Network".
+
+ Please refer to https://arxiv.org/abs/2001.06499 for more details.
+ Code is modified from https://github.com/mit-han-lab/temporal-shift-module
+ """
+
+ def forward(self, input, shift):
+ """Perform temporal interlace shift.
+
+ Args:
+ input (Tensor): Feature map with shape [N, num_segments, C, H * W].
+ shift (Tensor): Shift tensor with shape [N, num_segments].
+
+ Returns:
+ Feature map after temporal interlace shift.
+ """
+ return tin_shift(input, shift)
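+# Usage sketch (illustrative comment only; C must be divisible by
+# num_segments):
+#
+#   shift_layer = TINShift()
+#   x = torch.rand(2, 8, 64, 56 * 56).cuda()   # (N, num_segments, C, H*W)
+#   shift = torch.rand(2, 8).cuda()            # (N, num_segments)
+#   out = shift_layer(x, shift)                # same shape as x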
diff --git a/annotator/uniformer/mmcv/ops/upfirdn2d.py b/annotator/uniformer/mmcv/ops/upfirdn2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8bb2c3c949eed38a6465ed369fa881538dca010
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/upfirdn2d.py
@@ -0,0 +1,330 @@
+# modified from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py # noqa:E501
+
+# Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
+# NVIDIA Source Code License for StyleGAN2 with Adaptive Discriminator
+# Augmentation (ADA)
+# =======================================================================
+
+# 1. Definitions
+
+# "Licensor" means any person or entity that distributes its Work.
+
+# "Software" means the original work of authorship made available under
+# this License.
+
+# "Work" means the Software and any additions to or derivative works of
+# the Software that are made available under this License.
+
+# The terms "reproduce," "reproduction," "derivative works," and
+# "distribution" have the meaning as provided under U.S. copyright law;
+# provided, however, that for the purposes of this License, derivative
+# works shall not include works that remain separable from, or merely
+# link (or bind by name) to the interfaces of, the Work.
+
+# Works, including the Software, are "made available" under this License
+# by including in or with the Work either (a) a copyright notice
+# referencing the applicability of this License to the Work, or (b) a
+# copy of this License.
+
+# 2. License Grants
+
+# 2.1 Copyright Grant. Subject to the terms and conditions of this
+# License, each Licensor grants to you a perpetual, worldwide,
+# non-exclusive, royalty-free, copyright license to reproduce,
+# prepare derivative works of, publicly display, publicly perform,
+# sublicense and distribute its Work and any resulting derivative
+# works in any form.
+
+# 3. Limitations
+
+# 3.1 Redistribution. You may reproduce or distribute the Work only
+# if (a) you do so under this License, (b) you include a complete
+# copy of this License with your distribution, and (c) you retain
+# without modification any copyright, patent, trademark, or
+# attribution notices that are present in the Work.
+
+# 3.2 Derivative Works. You may specify that additional or different
+# terms apply to the use, reproduction, and distribution of your
+# derivative works of the Work ("Your Terms") only if (a) Your Terms
+# provide that the use limitation in Section 3.3 applies to your
+# derivative works, and (b) you identify the specific derivative
+# works that are subject to Your Terms. Notwithstanding Your Terms,
+# this License (including the redistribution requirements in Section
+# 3.1) will continue to apply to the Work itself.
+
+# 3.3 Use Limitation. The Work and any derivative works thereof only
+# may be used or intended for use non-commercially. Notwithstanding
+# the foregoing, NVIDIA and its affiliates may use the Work and any
+# derivative works commercially. As used herein, "non-commercially"
+# means for research or evaluation purposes only.
+
+# 3.4 Patent Claims. If you bring or threaten to bring a patent claim
+# against any Licensor (including any claim, cross-claim or
+# counterclaim in a lawsuit) to enforce any patents that you allege
+# are infringed by any Work, then your rights under this License from
+# such Licensor (including the grant in Section 2.1) will terminate
+# immediately.
+
+# 3.5 Trademarks. This License does not grant any rights to use any
+# Licensor’s or its affiliates’ names, logos, or trademarks, except
+# as necessary to reproduce the notices described in this License.
+
+# 3.6 Termination. If you violate any term of this License, then your
+# rights under this License (including the grant in Section 2.1) will
+# terminate immediately.
+
+# 4. Disclaimer of Warranty.
+
+# THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
+# NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
+# THIS LICENSE.
+
+# 5. Limitation of Liability.
+
+# EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
+# THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
+# SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
+# INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
+# OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
+# (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
+# LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
+# COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
+# THE POSSIBILITY OF SUCH DAMAGES.
+
+# =======================================================================
+
+import torch
+from torch.autograd import Function
+from torch.nn import functional as F
+
+from annotator.uniformer.mmcv.utils import to_2tuple
+from ..utils import ext_loader
+
+upfirdn2d_ext = ext_loader.load_ext('_ext', ['upfirdn2d'])
+
+
+class UpFirDn2dBackward(Function):
+
+ @staticmethod
+ def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad,
+ in_size, out_size):
+
+ up_x, up_y = up
+ down_x, down_y = down
+ g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
+
+ grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
+
+ grad_input = upfirdn2d_ext.upfirdn2d(
+ grad_output,
+ grad_kernel,
+ up_x=down_x,
+ up_y=down_y,
+ down_x=up_x,
+ down_y=up_y,
+ pad_x0=g_pad_x0,
+ pad_x1=g_pad_x1,
+ pad_y0=g_pad_y0,
+ pad_y1=g_pad_y1)
+ grad_input = grad_input.view(in_size[0], in_size[1], in_size[2],
+ in_size[3])
+
+ ctx.save_for_backward(kernel)
+
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
+
+ ctx.up_x = up_x
+ ctx.up_y = up_y
+ ctx.down_x = down_x
+ ctx.down_y = down_y
+ ctx.pad_x0 = pad_x0
+ ctx.pad_x1 = pad_x1
+ ctx.pad_y0 = pad_y0
+ ctx.pad_y1 = pad_y1
+ ctx.in_size = in_size
+ ctx.out_size = out_size
+
+ return grad_input
+
+ @staticmethod
+ def backward(ctx, gradgrad_input):
+ kernel, = ctx.saved_tensors
+
+ gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2],
+ ctx.in_size[3], 1)
+
+ gradgrad_out = upfirdn2d_ext.upfirdn2d(
+ gradgrad_input,
+ kernel,
+ up_x=ctx.up_x,
+ up_y=ctx.up_y,
+ down_x=ctx.down_x,
+ down_y=ctx.down_y,
+ pad_x0=ctx.pad_x0,
+ pad_x1=ctx.pad_x1,
+ pad_y0=ctx.pad_y0,
+ pad_y1=ctx.pad_y1)
+ # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0],
+ # ctx.out_size[1], ctx.in_size[3])
+ gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1],
+ ctx.out_size[0], ctx.out_size[1])
+
+ return gradgrad_out, None, None, None, None, None, None, None, None
+
+
+class UpFirDn2d(Function):
+
+ @staticmethod
+ def forward(ctx, input, kernel, up, down, pad):
+ up_x, up_y = up
+ down_x, down_y = down
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
+
+ kernel_h, kernel_w = kernel.shape
+ batch, channel, in_h, in_w = input.shape
+ ctx.in_size = input.shape
+
+ input = input.reshape(-1, in_h, in_w, 1)
+
+ ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
+
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
+ ctx.out_size = (out_h, out_w)
+
+ ctx.up = (up_x, up_y)
+ ctx.down = (down_x, down_y)
+ ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
+
+ g_pad_x0 = kernel_w - pad_x0 - 1
+ g_pad_y0 = kernel_h - pad_y0 - 1
+ g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
+ g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
+
+ ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
+
+ out = upfirdn2d_ext.upfirdn2d(
+ input,
+ kernel,
+ up_x=up_x,
+ up_y=up_y,
+ down_x=down_x,
+ down_y=down_y,
+ pad_x0=pad_x0,
+ pad_x1=pad_x1,
+ pad_y0=pad_y0,
+ pad_y1=pad_y1)
+ # out = out.view(major, out_h, out_w, minor)
+ out = out.view(-1, channel, out_h, out_w)
+
+ return out
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ kernel, grad_kernel = ctx.saved_tensors
+
+ grad_input = UpFirDn2dBackward.apply(
+ grad_output,
+ kernel,
+ grad_kernel,
+ ctx.up,
+ ctx.down,
+ ctx.pad,
+ ctx.g_pad,
+ ctx.in_size,
+ ctx.out_size,
+ )
+
+ return grad_input, None, None, None, None
+
+
+def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
+ """UpFRIDn for 2d features.
+
+ UpFIRDn is short for upsample, apply FIR filter and downsample. More
+ details can be found in:
+ https://www.mathworks.com/help/signal/ref/upfirdn.html
+
+ Args:
+ input (Tensor): Tensor with shape of (n, c, h, w).
+ kernel (Tensor): Filter kernel.
+ up (int | tuple[int], optional): Upsampling factor. If given a number,
+ we will use this factor for both the height and width sides.
+ Defaults to 1.
+ down (int | tuple[int], optional): Downsampling factor. If given a
+ number, we will use this factor for both the height and width sides.
+ Defaults to 1.
+ pad (tuple[int], optional): Padding for tensors, (x_pad, y_pad) or
+ (x_pad_0, x_pad_1, y_pad_0, y_pad_1). Defaults to (0, 0).
+
+ Returns:
+ Tensor: Tensor after UpFIRDn.
+ """
+ if input.device.type == 'cpu':
+ if len(pad) == 2:
+ pad = (pad[0], pad[1], pad[0], pad[1])
+
+ up = to_2tuple(up)
+
+ down = to_2tuple(down)
+
+ out = upfirdn2d_native(input, kernel, up[0], up[1], down[0], down[1],
+ pad[0], pad[1], pad[2], pad[3])
+ else:
+ _up = to_2tuple(up)
+
+ _down = to_2tuple(down)
+
+ if len(pad) == 4:
+ _pad = pad
+ elif len(pad) == 2:
+ _pad = (pad[0], pad[1], pad[0], pad[1])
+
+ out = UpFirDn2d.apply(input, kernel, _up, _down, _pad)
+
+ return out
+
+
+def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,
+ pad_y0, pad_y1):
+ _, channel, in_h, in_w = input.shape
+ input = input.reshape(-1, in_h, in_w, 1)
+
+ _, in_h, in_w, minor = input.shape
+ kernel_h, kernel_w = kernel.shape
+
+ out = input.view(-1, in_h, 1, in_w, 1, minor)
+ out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
+ out = out.view(-1, in_h * up_y, in_w * up_x, minor)
+
+ out = F.pad(
+ out,
+ [0, 0,
+ max(pad_x0, 0),
+ max(pad_x1, 0),
+ max(pad_y0, 0),
+ max(pad_y1, 0)])
+ out = out[:,
+ max(-pad_y0, 0):out.shape[1] - max(-pad_y1, 0),
+ max(-pad_x0, 0):out.shape[2] - max(-pad_x1, 0), :, ]
+
+ out = out.permute(0, 3, 1, 2)
+ out = out.reshape(
+ [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
+ out = F.conv2d(out, w)
+ out = out.reshape(
+ -1,
+ minor,
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
+ )
+ out = out.permute(0, 2, 3, 1)
+ out = out[:, ::down_y, ::down_x, :]
+
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
+
+ return out.view(-1, channel, out_h, out_w)
diff --git a/annotator/uniformer/mmcv/ops/voxelize.py b/annotator/uniformer/mmcv/ops/voxelize.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca3226a4fbcbfe58490fa2ea8e1c16b531214121
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/voxelize.py
@@ -0,0 +1,132 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch import nn
+from torch.autograd import Function
+from torch.nn.modules.utils import _pair
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+ '_ext', ['dynamic_voxelize_forward', 'hard_voxelize_forward'])
+
+
+class _Voxelization(Function):
+
+ @staticmethod
+ def forward(ctx,
+ points,
+ voxel_size,
+ coors_range,
+ max_points=35,
+ max_voxels=20000):
+ """Convert kitti points(N, >=3) to voxels.
+
+ Args:
+ points (torch.Tensor): [N, ndim]. Points[:, :3] contain xyz points
+ and points[:, 3:] contain other information like reflectivity.
+ voxel_size (tuple or float): The size of voxel with the shape of
+ [3].
+ coors_range (tuple or float): The coordinate range of voxel with
+ the shape of [6].
+            max_points (int, optional): Maximum number of points contained in
+                a voxel. If max_points=-1, dynamic voxelization is used.
+                Default: 35.
+            max_voxels (int, optional): Maximum number of voxels this function
+                creates. For SECOND, 20000 is a good choice. Users should
+                shuffle points before calling this function because max_voxels
+                may drop points. Default: 20000.
+
+ Returns:
+ voxels_out (torch.Tensor): Output voxels with the shape of [M,
+ max_points, ndim]. Only contain points and returned when
+ max_points != -1.
+ coors_out (torch.Tensor): Output coordinates with the shape of
+ [M, 3].
+ num_points_per_voxel_out (torch.Tensor): Num points per voxel with
+ the shape of [M]. Only returned when max_points != -1.
+ """
+ if max_points == -1 or max_voxels == -1:
+ coors = points.new_zeros(size=(points.size(0), 3), dtype=torch.int)
+ ext_module.dynamic_voxelize_forward(points, coors, voxel_size,
+ coors_range, 3)
+ return coors
+ else:
+ voxels = points.new_zeros(
+ size=(max_voxels, max_points, points.size(1)))
+ coors = points.new_zeros(size=(max_voxels, 3), dtype=torch.int)
+ num_points_per_voxel = points.new_zeros(
+ size=(max_voxels, ), dtype=torch.int)
+ voxel_num = ext_module.hard_voxelize_forward(
+ points, voxels, coors, num_points_per_voxel, voxel_size,
+ coors_range, max_points, max_voxels, 3)
+ # select the valid voxels
+ voxels_out = voxels[:voxel_num]
+ coors_out = coors[:voxel_num]
+ num_points_per_voxel_out = num_points_per_voxel[:voxel_num]
+ return voxels_out, coors_out, num_points_per_voxel_out
+
+
+voxelization = _Voxelization.apply
+
+
+class Voxelization(nn.Module):
+ """Convert kitti points(N, >=3) to voxels.
+
+    Please refer to `PVCNN`_ for more details.
+
+ Args:
+ voxel_size (tuple or float): The size of voxel with the shape of [3].
+ point_cloud_range (tuple or float): The coordinate range of voxel with
+ the shape of [6].
+        max_num_points (int): Maximum number of points contained in a voxel.
+            If max_num_points=-1, dynamic voxelization is used.
+        max_voxels (int, optional): Maximum number of voxels this function
+            creates. For SECOND, 20000 is a good choice. Users should shuffle
+            points before calling this function because max_voxels may drop
+            points. Default: 20000.
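+
+    Example (a hedged sketch; it assumes the compiled ``_ext`` CUDA ops and a
+    CUDA device are available, and the sizes below are only illustrative)::
+
+        >>> import torch
+        >>> voxel_layer = Voxelization(
+        ...     voxel_size=[0.05, 0.05, 0.1],
+        ...     point_cloud_range=[0, -40, -3, 70.4, 40, 1],
+        ...     max_num_points=35,
+        ...     max_voxels=20000)
+        >>> points = torch.rand(1000, 4).cuda()  # (x, y, z, reflectivity)
+        >>> voxels, coors, num_points_per_voxel = voxel_layer(points)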
+ """
+
+ def __init__(self,
+ voxel_size,
+ point_cloud_range,
+ max_num_points,
+ max_voxels=20000):
+ super().__init__()
+
+ self.voxel_size = voxel_size
+ self.point_cloud_range = point_cloud_range
+ self.max_num_points = max_num_points
+ if isinstance(max_voxels, tuple):
+ self.max_voxels = max_voxels
+ else:
+ self.max_voxels = _pair(max_voxels)
+
+ point_cloud_range = torch.tensor(
+ point_cloud_range, dtype=torch.float32)
+ voxel_size = torch.tensor(voxel_size, dtype=torch.float32)
+ grid_size = (point_cloud_range[3:] -
+ point_cloud_range[:3]) / voxel_size
+ grid_size = torch.round(grid_size).long()
+ input_feat_shape = grid_size[:2]
+ self.grid_size = grid_size
+        # the original shape is [x-len, y-len, z-len]
+ # [w, h, d] -> [d, h, w]
+ self.pcd_shape = [*input_feat_shape, 1][::-1]
+
+ def forward(self, input):
+ if self.training:
+ max_voxels = self.max_voxels[0]
+ else:
+ max_voxels = self.max_voxels[1]
+
+ return voxelization(input, self.voxel_size, self.point_cloud_range,
+ self.max_num_points, max_voxels)
+
+ def __repr__(self):
+ s = self.__class__.__name__ + '('
+ s += 'voxel_size=' + str(self.voxel_size)
+ s += ', point_cloud_range=' + str(self.point_cloud_range)
+ s += ', max_num_points=' + str(self.max_num_points)
+ s += ', max_voxels=' + str(self.max_voxels)
+ s += ')'
+ return s
diff --git a/annotator/uniformer/mmcv/parallel/__init__.py b/annotator/uniformer/mmcv/parallel/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ed2c17ad357742e423beeaf4d35db03fe9af469
--- /dev/null
+++ b/annotator/uniformer/mmcv/parallel/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .collate import collate
+from .data_container import DataContainer
+from .data_parallel import MMDataParallel
+from .distributed import MMDistributedDataParallel
+from .registry import MODULE_WRAPPERS
+from .scatter_gather import scatter, scatter_kwargs
+from .utils import is_module_wrapper
+
+__all__ = [
+ 'collate', 'DataContainer', 'MMDataParallel', 'MMDistributedDataParallel',
+ 'scatter', 'scatter_kwargs', 'is_module_wrapper', 'MODULE_WRAPPERS'
+]
diff --git a/annotator/uniformer/mmcv/parallel/_functions.py b/annotator/uniformer/mmcv/parallel/_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b5a8a44483ab991411d07122b22a1d027e4be8e
--- /dev/null
+++ b/annotator/uniformer/mmcv/parallel/_functions.py
@@ -0,0 +1,79 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch.nn.parallel._functions import _get_stream
+
+
+def scatter(input, devices, streams=None):
+ """Scatters tensor across multiple GPUs."""
+ if streams is None:
+ streams = [None] * len(devices)
+
+ if isinstance(input, list):
+ chunk_size = (len(input) - 1) // len(devices) + 1
+ outputs = [
+ scatter(input[i], [devices[i // chunk_size]],
+ [streams[i // chunk_size]]) for i in range(len(input))
+ ]
+ return outputs
+ elif isinstance(input, torch.Tensor):
+ output = input.contiguous()
+ # TODO: copy to a pinned buffer first (if copying from CPU)
+ stream = streams[0] if output.numel() > 0 else None
+ if devices != [-1]:
+ with torch.cuda.device(devices[0]), torch.cuda.stream(stream):
+ output = output.cuda(devices[0], non_blocking=True)
+ else:
+ # unsqueeze the first dimension thus the tensor's shape is the
+ # same as those scattered with GPU.
+ output = output.unsqueeze(0)
+ return output
+ else:
+ raise Exception(f'Unknown type {type(input)}.')
+
+
+def synchronize_stream(output, devices, streams):
+ if isinstance(output, list):
+ chunk_size = len(output) // len(devices)
+ for i in range(len(devices)):
+ for j in range(chunk_size):
+ synchronize_stream(output[i * chunk_size + j], [devices[i]],
+ [streams[i]])
+ elif isinstance(output, torch.Tensor):
+ if output.numel() != 0:
+ with torch.cuda.device(devices[0]):
+ main_stream = torch.cuda.current_stream()
+ main_stream.wait_stream(streams[0])
+ output.record_stream(main_stream)
+ else:
+ raise Exception(f'Unknown type {type(output)}.')
+
+
+def get_input_device(input):
+ if isinstance(input, list):
+ for item in input:
+ input_device = get_input_device(item)
+ if input_device != -1:
+ return input_device
+ return -1
+ elif isinstance(input, torch.Tensor):
+ return input.get_device() if input.is_cuda else -1
+ else:
+ raise Exception(f'Unknown type {type(input)}.')
+
+
+class Scatter:
+
+ @staticmethod
+ def forward(target_gpus, input):
+ input_device = get_input_device(input)
+ streams = None
+ if input_device == -1 and target_gpus != [-1]:
+ # Perform CPU to GPU copies in a background stream
+ streams = [_get_stream(device) for device in target_gpus]
+
+ outputs = scatter(input, target_gpus, streams)
+ # Synchronize with the copy stream
+ if streams is not None:
+ synchronize_stream(outputs, target_gpus, streams)
+
+ return tuple(outputs)
diff --git a/annotator/uniformer/mmcv/parallel/collate.py b/annotator/uniformer/mmcv/parallel/collate.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad749197df21b0d74297548be5f66a696adebf7f
--- /dev/null
+++ b/annotator/uniformer/mmcv/parallel/collate.py
@@ -0,0 +1,84 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from collections.abc import Mapping, Sequence
+
+import torch
+import torch.nn.functional as F
+from torch.utils.data.dataloader import default_collate
+
+from .data_container import DataContainer
+
+
+def collate(batch, samples_per_gpu=1):
+ """Puts each data field into a tensor/DataContainer with outer dimension
+ batch size.
+
+ Extend default_collate to add support for
+ :type:`~mmcv.parallel.DataContainer`. There are 3 cases.
+
+ 1. cpu_only = True, e.g., meta data
+    2. cpu_only = False, stack = True, e.g., image tensors
+ 3. cpu_only = False, stack = False, e.g., gt bboxes
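+
+    Example (a small CPU-only sketch with made-up tensors)::
+
+        >>> import torch
+        >>> batch = [dict(img=DataContainer(torch.rand(3, 4, 4), stack=True))
+        ...          for _ in range(2)]
+        >>> out = collate(batch, samples_per_gpu=2)
+        >>> out['img'].data[0].shape  # one stacked tensor per GPU
+        torch.Size([2, 3, 4, 4])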
+ """
+
+ if not isinstance(batch, Sequence):
+ raise TypeError(f'{batch.dtype} is not supported.')
+
+ if isinstance(batch[0], DataContainer):
+ stacked = []
+ if batch[0].cpu_only:
+ for i in range(0, len(batch), samples_per_gpu):
+ stacked.append(
+ [sample.data for sample in batch[i:i + samples_per_gpu]])
+ return DataContainer(
+ stacked, batch[0].stack, batch[0].padding_value, cpu_only=True)
+ elif batch[0].stack:
+ for i in range(0, len(batch), samples_per_gpu):
+ assert isinstance(batch[i].data, torch.Tensor)
+
+ if batch[i].pad_dims is not None:
+ ndim = batch[i].dim()
+ assert ndim > batch[i].pad_dims
+ max_shape = [0 for _ in range(batch[i].pad_dims)]
+ for dim in range(1, batch[i].pad_dims + 1):
+ max_shape[dim - 1] = batch[i].size(-dim)
+ for sample in batch[i:i + samples_per_gpu]:
+ for dim in range(0, ndim - batch[i].pad_dims):
+ assert batch[i].size(dim) == sample.size(dim)
+ for dim in range(1, batch[i].pad_dims + 1):
+ max_shape[dim - 1] = max(max_shape[dim - 1],
+ sample.size(-dim))
+ padded_samples = []
+ for sample in batch[i:i + samples_per_gpu]:
+ pad = [0 for _ in range(batch[i].pad_dims * 2)]
+ for dim in range(1, batch[i].pad_dims + 1):
+ pad[2 * dim -
+ 1] = max_shape[dim - 1] - sample.size(-dim)
+ padded_samples.append(
+ F.pad(
+ sample.data, pad, value=sample.padding_value))
+ stacked.append(default_collate(padded_samples))
+ elif batch[i].pad_dims is None:
+ stacked.append(
+ default_collate([
+ sample.data
+ for sample in batch[i:i + samples_per_gpu]
+ ]))
+ else:
+ raise ValueError(
+ 'pad_dims should be either None or integers (1-3)')
+
+ else:
+ for i in range(0, len(batch), samples_per_gpu):
+ stacked.append(
+ [sample.data for sample in batch[i:i + samples_per_gpu]])
+ return DataContainer(stacked, batch[0].stack, batch[0].padding_value)
+ elif isinstance(batch[0], Sequence):
+ transposed = zip(*batch)
+ return [collate(samples, samples_per_gpu) for samples in transposed]
+ elif isinstance(batch[0], Mapping):
+ return {
+ key: collate([d[key] for d in batch], samples_per_gpu)
+ for key in batch[0]
+ }
+ else:
+ return default_collate(batch)
diff --git a/annotator/uniformer/mmcv/parallel/data_container.py b/annotator/uniformer/mmcv/parallel/data_container.py
new file mode 100644
index 0000000000000000000000000000000000000000..cedb0d32a51a1f575a622b38de2cee3ab4757821
--- /dev/null
+++ b/annotator/uniformer/mmcv/parallel/data_container.py
@@ -0,0 +1,89 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import functools
+
+import torch
+
+
+def assert_tensor_type(func):
+
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ if not isinstance(args[0].data, torch.Tensor):
+ raise AttributeError(
+ f'{args[0].__class__.__name__} has no attribute '
+ f'{func.__name__} for type {args[0].datatype}')
+ return func(*args, **kwargs)
+
+ return wrapper
+
+
+class DataContainer:
+ """A container for any type of objects.
+
+ Typically tensors will be stacked in the collate function and sliced along
+ some dimension in the scatter function. This behavior has some limitations.
+ 1. All tensors have to be the same size.
+ 2. Types are limited (numpy array or Tensor).
+
+ We design `DataContainer` and `MMDataParallel` to overcome these
+ limitations. The behavior can be either of the following.
+
+ - copy to GPU, pad all tensors to the same size and stack them
+ - copy to GPU without stacking
+    - leave the objects as-is and pass them to the model
+ - pad_dims specifies the number of last few dimensions to do padding
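+
+    Example (illustrative values only)::
+
+        >>> import torch
+        >>> img = DataContainer(torch.rand(3, 4, 5), stack=True)
+        >>> img.stack, img.cpu_only, img.pad_dims
+        (True, False, 2)
+        >>> meta = DataContainer(dict(filename='x.jpg'), cpu_only=True)
+        >>> meta.datatype
+        <class 'dict'>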
+ """
+
+ def __init__(self,
+ data,
+ stack=False,
+ padding_value=0,
+ cpu_only=False,
+ pad_dims=2):
+ self._data = data
+ self._cpu_only = cpu_only
+ self._stack = stack
+ self._padding_value = padding_value
+ assert pad_dims in [None, 1, 2, 3]
+ self._pad_dims = pad_dims
+
+ def __repr__(self):
+ return f'{self.__class__.__name__}({repr(self.data)})'
+
+ def __len__(self):
+ return len(self._data)
+
+ @property
+ def data(self):
+ return self._data
+
+ @property
+ def datatype(self):
+ if isinstance(self.data, torch.Tensor):
+ return self.data.type()
+ else:
+ return type(self.data)
+
+ @property
+ def cpu_only(self):
+ return self._cpu_only
+
+ @property
+ def stack(self):
+ return self._stack
+
+ @property
+ def padding_value(self):
+ return self._padding_value
+
+ @property
+ def pad_dims(self):
+ return self._pad_dims
+
+ @assert_tensor_type
+ def size(self, *args, **kwargs):
+ return self.data.size(*args, **kwargs)
+
+ @assert_tensor_type
+ def dim(self):
+ return self.data.dim()
diff --git a/annotator/uniformer/mmcv/parallel/data_parallel.py b/annotator/uniformer/mmcv/parallel/data_parallel.py
new file mode 100644
index 0000000000000000000000000000000000000000..79b5f69b654cf647dc7ae9174223781ab5c607d2
--- /dev/null
+++ b/annotator/uniformer/mmcv/parallel/data_parallel.py
@@ -0,0 +1,89 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from itertools import chain
+
+from torch.nn.parallel import DataParallel
+
+from .scatter_gather import scatter_kwargs
+
+
+class MMDataParallel(DataParallel):
+ """The DataParallel module that supports DataContainer.
+
+    MMDataParallel has two main differences from PyTorch DataParallel:
+
+    - It supports a custom type :class:`DataContainer` which allows more
+      flexible control of input data during both GPU and CPU inference.
+    - It implements two more APIs: ``train_step()`` and ``val_step()``.
+
+ Args:
+ module (:class:`nn.Module`): Module to be encapsulated.
+        device_ids (list[int]): Device IDs of modules to be scattered to.
+ Defaults to None when GPU is not available.
+ output_device (str | int): Device ID for output. Defaults to None.
+ dim (int): Dimension used to scatter the data. Defaults to 0.
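+
+    Example (a hedged sketch; ``my_model`` and ``data_batch`` are
+    placeholders and a single CUDA device is assumed)::
+
+        >>> model = MMDataParallel(my_model.cuda(0), device_ids=[0])
+        >>> # DataContainers inside ``data_batch`` are scattered to GPU 0
+        >>> outputs = model.train_step(data_batch, None)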
+ """
+
+ def __init__(self, *args, dim=0, **kwargs):
+ super(MMDataParallel, self).__init__(*args, dim=dim, **kwargs)
+ self.dim = dim
+
+ def forward(self, *inputs, **kwargs):
+ """Override the original forward function.
+
+ The main difference lies in the CPU inference where the data in
+ :class:`DataContainers` will still be gathered.
+ """
+ if not self.device_ids:
+            # We add the following line so that the module can gather and
+            # convert data containers as is done in GPU inference
+ inputs, kwargs = self.scatter(inputs, kwargs, [-1])
+ return self.module(*inputs[0], **kwargs[0])
+ else:
+ return super().forward(*inputs, **kwargs)
+
+ def scatter(self, inputs, kwargs, device_ids):
+ return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
+
+ def train_step(self, *inputs, **kwargs):
+ if not self.device_ids:
+            # We add the following line so that the module can gather and
+            # convert data containers as is done in GPU inference
+ inputs, kwargs = self.scatter(inputs, kwargs, [-1])
+ return self.module.train_step(*inputs[0], **kwargs[0])
+
+        assert len(self.device_ids) == 1, \
+            ('MMDataParallel only supports single GPU training, if you need to'
+             ' train with multiple GPUs, please use MMDistributedDataParallel'
+             ' instead.')
+
+ for t in chain(self.module.parameters(), self.module.buffers()):
+ if t.device != self.src_device_obj:
+ raise RuntimeError(
+ 'module must have its parameters and buffers '
+ f'on device {self.src_device_obj} (device_ids[0]) but '
+ f'found one of them on device: {t.device}')
+
+ inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
+ return self.module.train_step(*inputs[0], **kwargs[0])
+
+ def val_step(self, *inputs, **kwargs):
+ if not self.device_ids:
+            # We add the following line so that the module can gather and
+            # convert data containers as is done in GPU inference
+ inputs, kwargs = self.scatter(inputs, kwargs, [-1])
+ return self.module.val_step(*inputs[0], **kwargs[0])
+
+ assert len(self.device_ids) == 1, \
+ ('MMDataParallel only supports single GPU training, if you need to'
+ ' train with multiple GPUs, please use MMDistributedDataParallel'
+ ' instead.')
+
+ for t in chain(self.module.parameters(), self.module.buffers()):
+ if t.device != self.src_device_obj:
+ raise RuntimeError(
+ 'module must have its parameters and buffers '
+ f'on device {self.src_device_obj} (device_ids[0]) but '
+ f'found one of them on device: {t.device}')
+
+ inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
+ return self.module.val_step(*inputs[0], **kwargs[0])
diff --git a/annotator/uniformer/mmcv/parallel/distributed.py b/annotator/uniformer/mmcv/parallel/distributed.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e4c27903db58a54d37ea1ed9ec0104098b486f2
--- /dev/null
+++ b/annotator/uniformer/mmcv/parallel/distributed.py
@@ -0,0 +1,112 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch.nn.parallel.distributed import (DistributedDataParallel,
+ _find_tensors)
+
+from annotator.uniformer.mmcv import print_log
+from annotator.uniformer.mmcv.utils import TORCH_VERSION, digit_version
+from .scatter_gather import scatter_kwargs
+
+
+class MMDistributedDataParallel(DistributedDataParallel):
+ """The DDP module that supports DataContainer.
+
+    MMDDP has two main differences from PyTorch DDP:
+
+    - It supports a custom type :class:`DataContainer` which allows more
+      flexible control of input data.
+    - It implements two APIs: ``train_step()`` and ``val_step()``.
+ """
+
+ def to_kwargs(self, inputs, kwargs, device_id):
+ # Use `self.to_kwargs` instead of `self.scatter` in pytorch1.8
+ # to move all tensors to device_id
+ return scatter_kwargs(inputs, kwargs, [device_id], dim=self.dim)
+
+ def scatter(self, inputs, kwargs, device_ids):
+ return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
+
+ def train_step(self, *inputs, **kwargs):
+ """train_step() API for module wrapped by DistributedDataParallel.
+
+ This method is basically the same as
+ ``DistributedDataParallel.forward()``, while replacing
+ ``self.module.forward()`` with ``self.module.train_step()``.
+ It is compatible with PyTorch 1.1 - 1.5.
+ """
+
+ # In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the
+ # end of backward to the beginning of forward.
+ if ('parrots' not in TORCH_VERSION
+ and digit_version(TORCH_VERSION) >= digit_version('1.7')
+ and self.reducer._rebuild_buckets()):
+ print_log(
+ 'Reducer buckets have been rebuilt in this iteration.',
+ logger='mmcv')
+
+ if getattr(self, 'require_forward_param_sync', True):
+ self._sync_params()
+ if self.device_ids:
+ inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
+ if len(self.device_ids) == 1:
+ output = self.module.train_step(*inputs[0], **kwargs[0])
+ else:
+ outputs = self.parallel_apply(
+ self._module_copies[:len(inputs)], inputs, kwargs)
+ output = self.gather(outputs, self.output_device)
+ else:
+ output = self.module.train_step(*inputs, **kwargs)
+
+ if torch.is_grad_enabled() and getattr(
+ self, 'require_backward_grad_sync', True):
+ if self.find_unused_parameters:
+ self.reducer.prepare_for_backward(list(_find_tensors(output)))
+ else:
+ self.reducer.prepare_for_backward([])
+ else:
+ if ('parrots' not in TORCH_VERSION
+ and digit_version(TORCH_VERSION) > digit_version('1.2')):
+ self.require_forward_param_sync = False
+ return output
+
+ def val_step(self, *inputs, **kwargs):
+ """val_step() API for module wrapped by DistributedDataParallel.
+
+ This method is basically the same as
+ ``DistributedDataParallel.forward()``, while replacing
+ ``self.module.forward()`` with ``self.module.val_step()``.
+ It is compatible with PyTorch 1.1 - 1.5.
+ """
+ # In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the
+ # end of backward to the beginning of forward.
+ if ('parrots' not in TORCH_VERSION
+ and digit_version(TORCH_VERSION) >= digit_version('1.7')
+ and self.reducer._rebuild_buckets()):
+ print_log(
+ 'Reducer buckets have been rebuilt in this iteration.',
+ logger='mmcv')
+
+ if getattr(self, 'require_forward_param_sync', True):
+ self._sync_params()
+ if self.device_ids:
+ inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
+ if len(self.device_ids) == 1:
+ output = self.module.val_step(*inputs[0], **kwargs[0])
+ else:
+ outputs = self.parallel_apply(
+ self._module_copies[:len(inputs)], inputs, kwargs)
+ output = self.gather(outputs, self.output_device)
+ else:
+ output = self.module.val_step(*inputs, **kwargs)
+
+ if torch.is_grad_enabled() and getattr(
+ self, 'require_backward_grad_sync', True):
+ if self.find_unused_parameters:
+ self.reducer.prepare_for_backward(list(_find_tensors(output)))
+ else:
+ self.reducer.prepare_for_backward([])
+ else:
+ if ('parrots' not in TORCH_VERSION
+ and digit_version(TORCH_VERSION) > digit_version('1.2')):
+ self.require_forward_param_sync = False
+ return output
diff --git a/annotator/uniformer/mmcv/parallel/distributed_deprecated.py b/annotator/uniformer/mmcv/parallel/distributed_deprecated.py
new file mode 100644
index 0000000000000000000000000000000000000000..676937a2085d4da20fa87923041a200fca6214eb
--- /dev/null
+++ b/annotator/uniformer/mmcv/parallel/distributed_deprecated.py
@@ -0,0 +1,70 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch._utils import (_flatten_dense_tensors, _take_tensors,
+ _unflatten_dense_tensors)
+
+from annotator.uniformer.mmcv.utils import TORCH_VERSION, digit_version
+from .registry import MODULE_WRAPPERS
+from .scatter_gather import scatter_kwargs
+
+
+@MODULE_WRAPPERS.register_module()
+class MMDistributedDataParallel(nn.Module):
+
+ def __init__(self,
+ module,
+ dim=0,
+ broadcast_buffers=True,
+ bucket_cap_mb=25):
+ super(MMDistributedDataParallel, self).__init__()
+ self.module = module
+ self.dim = dim
+ self.broadcast_buffers = broadcast_buffers
+
+ self.broadcast_bucket_size = bucket_cap_mb * 1024 * 1024
+ self._sync_params()
+
+ def _dist_broadcast_coalesced(self, tensors, buffer_size):
+ for tensors in _take_tensors(tensors, buffer_size):
+ flat_tensors = _flatten_dense_tensors(tensors)
+ dist.broadcast(flat_tensors, 0)
+ for tensor, synced in zip(
+ tensors, _unflatten_dense_tensors(flat_tensors, tensors)):
+ tensor.copy_(synced)
+
+ def _sync_params(self):
+ module_states = list(self.module.state_dict().values())
+ if len(module_states) > 0:
+ self._dist_broadcast_coalesced(module_states,
+ self.broadcast_bucket_size)
+ if self.broadcast_buffers:
+ if (TORCH_VERSION != 'parrots'
+ and digit_version(TORCH_VERSION) < digit_version('1.0')):
+ buffers = [b.data for b in self.module._all_buffers()]
+ else:
+ buffers = [b.data for b in self.module.buffers()]
+ if len(buffers) > 0:
+ self._dist_broadcast_coalesced(buffers,
+ self.broadcast_bucket_size)
+
+ def scatter(self, inputs, kwargs, device_ids):
+ return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
+
+ def forward(self, *inputs, **kwargs):
+ inputs, kwargs = self.scatter(inputs, kwargs,
+ [torch.cuda.current_device()])
+ return self.module(*inputs[0], **kwargs[0])
+
+ def train_step(self, *inputs, **kwargs):
+ inputs, kwargs = self.scatter(inputs, kwargs,
+ [torch.cuda.current_device()])
+ output = self.module.train_step(*inputs[0], **kwargs[0])
+ return output
+
+ def val_step(self, *inputs, **kwargs):
+ inputs, kwargs = self.scatter(inputs, kwargs,
+ [torch.cuda.current_device()])
+ output = self.module.val_step(*inputs[0], **kwargs[0])
+ return output
diff --git a/annotator/uniformer/mmcv/parallel/registry.py b/annotator/uniformer/mmcv/parallel/registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..a204a07fba10e614223f090d1a57cf9c4d74d4a1
--- /dev/null
+++ b/annotator/uniformer/mmcv/parallel/registry.py
@@ -0,0 +1,8 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from torch.nn.parallel import DataParallel, DistributedDataParallel
+
+from annotator.uniformer.mmcv.utils import Registry
+
+MODULE_WRAPPERS = Registry('module wrapper')
+MODULE_WRAPPERS.register_module(module=DataParallel)
+MODULE_WRAPPERS.register_module(module=DistributedDataParallel)
diff --git a/annotator/uniformer/mmcv/parallel/scatter_gather.py b/annotator/uniformer/mmcv/parallel/scatter_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..900ff88566f8f14830590459dc4fd16d4b382e47
--- /dev/null
+++ b/annotator/uniformer/mmcv/parallel/scatter_gather.py
@@ -0,0 +1,59 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch.nn.parallel._functions import Scatter as OrigScatter
+
+from ._functions import Scatter
+from .data_container import DataContainer
+
+
+def scatter(inputs, target_gpus, dim=0):
+ """Scatter inputs to target gpus.
+
+ The only difference from original :func:`scatter` is to add support for
+ :type:`~mmcv.parallel.DataContainer`.
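+
+    Example (a minimal CPU sketch; ``target_gpus=[-1]`` keeps the tensor on
+    CPU and wraps it in a one-element tuple)::
+
+        >>> import torch
+        >>> out = scatter(torch.zeros(2, 3), [-1])
+        >>> isinstance(out, tuple), out[0].shape
+        (True, torch.Size([2, 3]))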
+ """
+
+ def scatter_map(obj):
+ if isinstance(obj, torch.Tensor):
+ if target_gpus != [-1]:
+ return OrigScatter.apply(target_gpus, None, dim, obj)
+ else:
+ # for CPU inference we use self-implemented scatter
+ return Scatter.forward(target_gpus, obj)
+ if isinstance(obj, DataContainer):
+ if obj.cpu_only:
+ return obj.data
+ else:
+ return Scatter.forward(target_gpus, obj.data)
+ if isinstance(obj, tuple) and len(obj) > 0:
+ return list(zip(*map(scatter_map, obj)))
+ if isinstance(obj, list) and len(obj) > 0:
+ out = list(map(list, zip(*map(scatter_map, obj))))
+ return out
+ if isinstance(obj, dict) and len(obj) > 0:
+ out = list(map(type(obj), zip(*map(scatter_map, obj.items()))))
+ return out
+ return [obj for targets in target_gpus]
+
+ # After scatter_map is called, a scatter_map cell will exist. This cell
+ # has a reference to the actual function scatter_map, which has references
+ # to a closure that has a reference to the scatter_map cell (because the
+ # fn is recursive). To avoid this reference cycle, we set the function to
+ # None, clearing the cell
+ try:
+ return scatter_map(inputs)
+ finally:
+ scatter_map = None
+
+
+def scatter_kwargs(inputs, kwargs, target_gpus, dim=0):
+ """Scatter with support for kwargs dictionary."""
+ inputs = scatter(inputs, target_gpus, dim) if inputs else []
+ kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
+ if len(inputs) < len(kwargs):
+ inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
+ elif len(kwargs) < len(inputs):
+ kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
+ inputs = tuple(inputs)
+ kwargs = tuple(kwargs)
+ return inputs, kwargs
diff --git a/annotator/uniformer/mmcv/parallel/utils.py b/annotator/uniformer/mmcv/parallel/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f5712cb42c38a2e8563bf563efb6681383cab9b
--- /dev/null
+++ b/annotator/uniformer/mmcv/parallel/utils.py
@@ -0,0 +1,20 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .registry import MODULE_WRAPPERS
+
+
+def is_module_wrapper(module):
+ """Check if a module is a module wrapper.
+
+ The following 3 modules in MMCV (and their subclasses) are regarded as
+ module wrappers: DataParallel, DistributedDataParallel,
+    MMDistributedDataParallel (the deprecated version). You may add your own
+ module wrapper by registering it to mmcv.parallel.MODULE_WRAPPERS.
+
+ Args:
+ module (nn.Module): The module to be checked.
+
+ Returns:
+ bool: True if the input module is a module wrapper.
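+
+    Example (``DataParallel`` is registered as a wrapper in
+    ``mmcv.parallel.registry``)::
+
+        >>> import torch.nn as nn
+        >>> from torch.nn.parallel import DataParallel
+        >>> is_module_wrapper(DataParallel(nn.Linear(2, 2)))
+        True
+        >>> is_module_wrapper(nn.Linear(2, 2))
+        False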
+ """
+ module_wrappers = tuple(MODULE_WRAPPERS.module_dict.values())
+ return isinstance(module, module_wrappers)
diff --git a/annotator/uniformer/mmcv/runner/__init__.py b/annotator/uniformer/mmcv/runner/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..52e4b48d383a84a055dcd7f6236f6e8e58eab924
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/__init__.py
@@ -0,0 +1,47 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .base_module import BaseModule, ModuleList, Sequential
+from .base_runner import BaseRunner
+from .builder import RUNNERS, build_runner
+from .checkpoint import (CheckpointLoader, _load_checkpoint,
+ _load_checkpoint_with_prefix, load_checkpoint,
+ load_state_dict, save_checkpoint, weights_to_cpu)
+from .default_constructor import DefaultRunnerConstructor
+from .dist_utils import (allreduce_grads, allreduce_params, get_dist_info,
+ init_dist, master_only)
+from .epoch_based_runner import EpochBasedRunner, Runner
+from .fp16_utils import LossScaler, auto_fp16, force_fp32, wrap_fp16_model
+from .hooks import (HOOKS, CheckpointHook, ClosureHook, DistEvalHook,
+ DistSamplerSeedHook, DvcliveLoggerHook, EMAHook, EvalHook,
+ Fp16OptimizerHook, GradientCumulativeFp16OptimizerHook,
+ GradientCumulativeOptimizerHook, Hook, IterTimerHook,
+ LoggerHook, LrUpdaterHook, MlflowLoggerHook,
+ NeptuneLoggerHook, OptimizerHook, PaviLoggerHook,
+ SyncBuffersHook, TensorboardLoggerHook, TextLoggerHook,
+ WandbLoggerHook)
+from .iter_based_runner import IterBasedRunner, IterLoader
+from .log_buffer import LogBuffer
+from .optimizer import (OPTIMIZER_BUILDERS, OPTIMIZERS,
+ DefaultOptimizerConstructor, build_optimizer,
+ build_optimizer_constructor)
+from .priority import Priority, get_priority
+from .utils import get_host_info, get_time_str, obj_from_dict, set_random_seed
+
+__all__ = [
+ 'BaseRunner', 'Runner', 'EpochBasedRunner', 'IterBasedRunner', 'LogBuffer',
+ 'HOOKS', 'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook',
+ 'OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook', 'LoggerHook',
+ 'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook',
+ 'NeptuneLoggerHook', 'WandbLoggerHook', 'MlflowLoggerHook',
+ 'DvcliveLoggerHook', '_load_checkpoint', 'load_state_dict',
+ 'load_checkpoint', 'weights_to_cpu', 'save_checkpoint', 'Priority',
+ 'get_priority', 'get_host_info', 'get_time_str', 'obj_from_dict',
+ 'init_dist', 'get_dist_info', 'master_only', 'OPTIMIZER_BUILDERS',
+ 'OPTIMIZERS', 'DefaultOptimizerConstructor', 'build_optimizer',
+ 'build_optimizer_constructor', 'IterLoader', 'set_random_seed',
+ 'auto_fp16', 'force_fp32', 'wrap_fp16_model', 'Fp16OptimizerHook',
+ 'SyncBuffersHook', 'EMAHook', 'build_runner', 'RUNNERS', 'allreduce_grads',
+ 'allreduce_params', 'LossScaler', 'CheckpointLoader', 'BaseModule',
+ '_load_checkpoint_with_prefix', 'EvalHook', 'DistEvalHook', 'Sequential',
+ 'ModuleList', 'GradientCumulativeOptimizerHook',
+ 'GradientCumulativeFp16OptimizerHook', 'DefaultRunnerConstructor'
+]
diff --git a/annotator/uniformer/mmcv/runner/base_module.py b/annotator/uniformer/mmcv/runner/base_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..617fad9bb89f10a9a0911d962dfb3bc8f3a3628c
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/base_module.py
@@ -0,0 +1,195 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import warnings
+from abc import ABCMeta
+from collections import defaultdict
+from logging import FileHandler
+
+import torch.nn as nn
+
+from annotator.uniformer.mmcv.runner.dist_utils import master_only
+from annotator.uniformer.mmcv.utils.logging import get_logger, logger_initialized, print_log
+
+
+class BaseModule(nn.Module, metaclass=ABCMeta):
+ """Base module for all modules in openmmlab.
+
+ ``BaseModule`` is a wrapper of ``torch.nn.Module`` with additional
+ functionality of parameter initialization. Compared with
+ ``torch.nn.Module``, ``BaseModule`` mainly adds three attributes.
+
+ - ``init_cfg``: the config to control the initialization.
+ - ``init_weights``: The function of parameter
+ initialization and recording initialization
+ information.
+ - ``_params_init_info``: Used to track the parameter
+ initialization information. This attribute only
+ exists during executing the ``init_weights``.
+
+ Args:
+ init_cfg (dict, optional): Initialization config dict.
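+
+    Example (a hedged sketch; ``ToyNet`` and the ``Constant`` init below are
+    illustrative, not components defined in this repo)::
+
+        >>> import torch.nn as nn
+        >>> class ToyNet(BaseModule):
+        ...     def __init__(self, init_cfg=None):
+        ...         super().__init__(init_cfg)
+        ...         self.linear = nn.Linear(2, 2)
+        >>> net = ToyNet(
+        ...     init_cfg=dict(type='Constant', val=1., layer='Linear'))
+        >>> net.init_weights()  # constant-fills the Linear parameters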
+ """
+
+ def __init__(self, init_cfg=None):
+ """Initialize BaseModule, inherited from `torch.nn.Module`"""
+
+        # NOTE init_cfg can be defined at different levels, but init_cfg
+        # at lower levels has a higher priority.
+
+ super(BaseModule, self).__init__()
+ # define default value of init_cfg instead of hard code
+ # in init_weights() function
+ self._is_init = False
+
+ self.init_cfg = copy.deepcopy(init_cfg)
+
+ # Backward compatibility in derived classes
+ # if pretrained is not None:
+ # warnings.warn('DeprecationWarning: pretrained is a deprecated \
+ # key, please consider using init_cfg')
+ # self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
+
+ @property
+ def is_init(self):
+ return self._is_init
+
+ def init_weights(self):
+ """Initialize the weights."""
+
+ is_top_level_module = False
+ # check if it is top-level module
+ if not hasattr(self, '_params_init_info'):
+            # The `_params_init_info` is used to record the initialization
+            # information of the parameters.
+            # The key should be the obj:`nn.Parameter` of the model and the
+            # value should be a dict containing
+            # - init_info (str): The string that describes the initialization.
+            # - tmp_mean_value (FloatTensor): The mean of the parameter,
+            #       which indicates whether the parameter has been modified.
+            # This attribute will be deleted after all parameters are
+            # initialized.
+ self._params_init_info = defaultdict(dict)
+ is_top_level_module = True
+
+            # Initialize the `_params_init_info`.
+            # When the `tmp_mean_value` of the corresponding parameter is
+            # detected to have changed, the related initialization
+            # information is updated.
+ for name, param in self.named_parameters():
+ self._params_init_info[param][
+ 'init_info'] = f'The value is the same before and ' \
+ f'after calling `init_weights` ' \
+ f'of {self.__class__.__name__} '
+ self._params_init_info[param][
+ 'tmp_mean_value'] = param.data.mean()
+
+ # pass `params_init_info` to all submodules
+ # All submodules share the same `params_init_info`,
+ # so it will be updated when parameters are
+ # modified at any level of the model.
+ for sub_module in self.modules():
+ sub_module._params_init_info = self._params_init_info
+
+ # Get the initialized logger, if not exist,
+ # create a logger named `mmcv`
+ logger_names = list(logger_initialized.keys())
+ logger_name = logger_names[0] if logger_names else 'mmcv'
+
+ from ..cnn import initialize
+ from ..cnn.utils.weight_init import update_init_info
+ module_name = self.__class__.__name__
+ if not self._is_init:
+ if self.init_cfg:
+ print_log(
+ f'initialize {module_name} with init_cfg {self.init_cfg}',
+ logger=logger_name)
+ initialize(self, self.init_cfg)
+ if isinstance(self.init_cfg, dict):
+ # prevent the parameters of
+ # the pre-trained model
+ # from being overwritten by
+ # the `init_weights`
+ if self.init_cfg['type'] == 'Pretrained':
+ return
+
+ for m in self.children():
+ if hasattr(m, 'init_weights'):
+ m.init_weights()
+ # users may overload the `init_weights`
+ update_init_info(
+ m,
+ init_info=f'Initialized by '
+ f'user-defined `init_weights`'
+ f' in {m.__class__.__name__} ')
+
+ self._is_init = True
+ else:
+ warnings.warn(f'init_weights of {self.__class__.__name__} has '
+ f'been called more than once.')
+
+ if is_top_level_module:
+ self._dump_init_info(logger_name)
+
+ for sub_module in self.modules():
+ del sub_module._params_init_info
+
+ @master_only
+ def _dump_init_info(self, logger_name):
+ """Dump the initialization information to a file named
+ `initialization.log.json` in workdir.
+
+ Args:
+ logger_name (str): The name of logger.
+ """
+
+ logger = get_logger(logger_name)
+
+ with_file_handler = False
+ # dump the information to the logger file if there is a `FileHandler`
+ for handler in logger.handlers:
+ if isinstance(handler, FileHandler):
+ handler.stream.write(
+ 'Name of parameter - Initialization information\n')
+ for name, param in self.named_parameters():
+ handler.stream.write(
+ f'\n{name} - {param.shape}: '
+ f"\n{self._params_init_info[param]['init_info']} \n")
+ handler.stream.flush()
+ with_file_handler = True
+ if not with_file_handler:
+ for name, param in self.named_parameters():
+ print_log(
+ f'\n{name} - {param.shape}: '
+ f"\n{self._params_init_info[param]['init_info']} \n ",
+ logger=logger_name)
+
+ def __repr__(self):
+ s = super().__repr__()
+ if self.init_cfg:
+ s += f'\ninit_cfg={self.init_cfg}'
+ return s
+
+
+class Sequential(BaseModule, nn.Sequential):
+ """Sequential module in openmmlab.
+
+ Args:
+ init_cfg (dict, optional): Initialization config dict.
+ """
+
+ def __init__(self, *args, init_cfg=None):
+ BaseModule.__init__(self, init_cfg)
+ nn.Sequential.__init__(self, *args)
+
+
+class ModuleList(BaseModule, nn.ModuleList):
+ """ModuleList in openmmlab.
+
+ Args:
+ modules (iterable, optional): an iterable of modules to add.
+ init_cfg (dict, optional): Initialization config dict.
+ """
+
+ def __init__(self, modules=None, init_cfg=None):
+ BaseModule.__init__(self, init_cfg)
+ nn.ModuleList.__init__(self, modules)
diff --git a/annotator/uniformer/mmcv/runner/base_runner.py b/annotator/uniformer/mmcv/runner/base_runner.py
new file mode 100644
index 0000000000000000000000000000000000000000..4928db0a73b56fe0218a4bf66ec4ffa082d31ccc
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/base_runner.py
@@ -0,0 +1,542 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import logging
+import os.path as osp
+import warnings
+from abc import ABCMeta, abstractmethod
+
+import torch
+from torch.optim import Optimizer
+
+import annotator.uniformer.mmcv as mmcv
+from ..parallel import is_module_wrapper
+from .checkpoint import load_checkpoint
+from .dist_utils import get_dist_info
+from .hooks import HOOKS, Hook
+from .log_buffer import LogBuffer
+from .priority import Priority, get_priority
+from .utils import get_time_str
+
+
+class BaseRunner(metaclass=ABCMeta):
+ """The base class of Runner, a training helper for PyTorch.
+
+ All subclasses should implement the following APIs:
+
+ - ``run()``
+ - ``train()``
+ - ``val()``
+ - ``save_checkpoint()``
+
+ Args:
+ model (:obj:`torch.nn.Module`): The model to be run.
+        batch_processor (callable): A callable method that processes a data
+            batch. The interface of this method should be
+            `batch_processor(model, data, train_mode) -> dict`
+        optimizer (dict or :obj:`torch.optim.Optimizer`): It can be either an
+            optimizer (in most cases) or a dict of optimizers (in models that
+            require more than one optimizer, e.g., GAN).
+ work_dir (str, optional): The working directory to save checkpoints
+ and logs. Defaults to None.
+ logger (:obj:`logging.Logger`): Logger used during training.
+ Defaults to None. (The default value is just for backward
+ compatibility)
+        meta (dict | None): A dict that records some important information,
+            such as environment info and seed, which will be logged by the
+            logger hook. Defaults to None.
+ max_epochs (int, optional): Total training epochs.
+ max_iters (int, optional): Total training iterations.
+ """
+
+ def __init__(self,
+ model,
+ batch_processor=None,
+ optimizer=None,
+ work_dir=None,
+ logger=None,
+ meta=None,
+ max_iters=None,
+ max_epochs=None):
+ if batch_processor is not None:
+ if not callable(batch_processor):
+ raise TypeError('batch_processor must be callable, '
+ f'but got {type(batch_processor)}')
+ warnings.warn('batch_processor is deprecated, please implement '
+ 'train_step() and val_step() in the model instead.')
+            # raise an error if `batch_processor` is not None and
+            # `model.train_step()` exists.
+ if is_module_wrapper(model):
+ _model = model.module
+ else:
+ _model = model
+ if hasattr(_model, 'train_step') or hasattr(_model, 'val_step'):
+ raise RuntimeError(
+ 'batch_processor and model.train_step()/model.val_step() '
+ 'cannot be both available.')
+ else:
+ assert hasattr(model, 'train_step')
+
+ # check the type of `optimizer`
+ if isinstance(optimizer, dict):
+ for name, optim in optimizer.items():
+ if not isinstance(optim, Optimizer):
+ raise TypeError(
+ f'optimizer must be a dict of torch.optim.Optimizers, '
+ f'but optimizer["{name}"] is a {type(optim)}')
+ elif not isinstance(optimizer, Optimizer) and optimizer is not None:
+ raise TypeError(
+ f'optimizer must be a torch.optim.Optimizer object '
+ f'or dict or None, but got {type(optimizer)}')
+
+ # check the type of `logger`
+ if not isinstance(logger, logging.Logger):
+ raise TypeError(f'logger must be a logging.Logger object, '
+ f'but got {type(logger)}')
+
+ # check the type of `meta`
+ if meta is not None and not isinstance(meta, dict):
+ raise TypeError(
+ f'meta must be a dict or None, but got {type(meta)}')
+
+ self.model = model
+ self.batch_processor = batch_processor
+ self.optimizer = optimizer
+ self.logger = logger
+ self.meta = meta
+ # create work_dir
+ if mmcv.is_str(work_dir):
+ self.work_dir = osp.abspath(work_dir)
+ mmcv.mkdir_or_exist(self.work_dir)
+ elif work_dir is None:
+ self.work_dir = None
+ else:
+ raise TypeError('"work_dir" must be a str or None')
+
+ # get model name from the model class
+ if hasattr(self.model, 'module'):
+ self._model_name = self.model.module.__class__.__name__
+ else:
+ self._model_name = self.model.__class__.__name__
+
+ self._rank, self._world_size = get_dist_info()
+ self.timestamp = get_time_str()
+ self.mode = None
+ self._hooks = []
+ self._epoch = 0
+ self._iter = 0
+ self._inner_iter = 0
+
+ if max_epochs is not None and max_iters is not None:
+ raise ValueError(
+ 'Only one of `max_epochs` or `max_iters` can be set.')
+
+ self._max_epochs = max_epochs
+ self._max_iters = max_iters
+ # TODO: Redesign LogBuffer, it is not flexible and elegant enough
+ self.log_buffer = LogBuffer()
+
+ @property
+ def model_name(self):
+ """str: Name of the model, usually the module class name."""
+ return self._model_name
+
+ @property
+ def rank(self):
+ """int: Rank of current process. (distributed training)"""
+ return self._rank
+
+ @property
+ def world_size(self):
+ """int: Number of processes participating in the job.
+ (distributed training)"""
+ return self._world_size
+
+ @property
+ def hooks(self):
+ """list[:obj:`Hook`]: A list of registered hooks."""
+ return self._hooks
+
+ @property
+ def epoch(self):
+ """int: Current epoch."""
+ return self._epoch
+
+ @property
+ def iter(self):
+ """int: Current iteration."""
+ return self._iter
+
+ @property
+ def inner_iter(self):
+ """int: Iteration in an epoch."""
+ return self._inner_iter
+
+ @property
+ def max_epochs(self):
+ """int: Maximum training epochs."""
+ return self._max_epochs
+
+ @property
+ def max_iters(self):
+ """int: Maximum training iterations."""
+ return self._max_iters
+
+ @abstractmethod
+ def train(self):
+ pass
+
+ @abstractmethod
+ def val(self):
+ pass
+
+ @abstractmethod
+ def run(self, data_loaders, workflow, **kwargs):
+ pass
+
+ @abstractmethod
+ def save_checkpoint(self,
+ out_dir,
+ filename_tmpl,
+ save_optimizer=True,
+ meta=None,
+ create_symlink=True):
+ pass
+
+ def current_lr(self):
+ """Get current learning rates.
+
+ Returns:
+ list[float] | dict[str, list[float]]: Current learning rates of all
+ param groups. If the runner has a dict of optimizers, this
+ method will return a dict.
+ """
+ if isinstance(self.optimizer, torch.optim.Optimizer):
+ lr = [group['lr'] for group in self.optimizer.param_groups]
+ elif isinstance(self.optimizer, dict):
+ lr = dict()
+ for name, optim in self.optimizer.items():
+ lr[name] = [group['lr'] for group in optim.param_groups]
+ else:
+ raise RuntimeError(
+ 'lr is not applicable because optimizer does not exist.')
+ return lr
+
+ def current_momentum(self):
+ """Get current momentums.
+
+ Returns:
+ list[float] | dict[str, list[float]]: Current momentums of all
+ param groups. If the runner has a dict of optimizers, this
+ method will return a dict.
+ """
+
+ def _get_momentum(optimizer):
+ momentums = []
+ for group in optimizer.param_groups:
+ if 'momentum' in group.keys():
+ momentums.append(group['momentum'])
+ elif 'betas' in group.keys():
+ momentums.append(group['betas'][0])
+ else:
+ momentums.append(0)
+ return momentums
+
+ if self.optimizer is None:
+ raise RuntimeError(
+ 'momentum is not applicable because optimizer does not exist.')
+ elif isinstance(self.optimizer, torch.optim.Optimizer):
+ momentums = _get_momentum(self.optimizer)
+ elif isinstance(self.optimizer, dict):
+ momentums = dict()
+ for name, optim in self.optimizer.items():
+ momentums[name] = _get_momentum(optim)
+ return momentums
+
+ def register_hook(self, hook, priority='NORMAL'):
+ """Register a hook into the hook list.
+
+ The hook will be inserted into a priority queue, with the specified
+ priority (See :class:`Priority` for details of priorities).
+ For hooks with the same priority, they will be triggered in the same
+ order as they are registered.
+
+ Args:
+ hook (:obj:`Hook`): The hook to be registered.
+ priority (int or str or :obj:`Priority`): Hook priority.
+ Lower value means higher priority.
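+
+        Example (``runner`` stands for any concrete runner instance)::
+
+            >>> from annotator.uniformer.mmcv.runner.hooks import IterTimerHook
+            >>> runner.register_hook(IterTimerHook(), priority='LOW')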
+ """
+ assert isinstance(hook, Hook)
+ if hasattr(hook, 'priority'):
+ raise ValueError('"priority" is a reserved attribute for hooks')
+ priority = get_priority(priority)
+ hook.priority = priority
+ # insert the hook to a sorted list
+ inserted = False
+ for i in range(len(self._hooks) - 1, -1, -1):
+ if priority >= self._hooks[i].priority:
+ self._hooks.insert(i + 1, hook)
+ inserted = True
+ break
+ if not inserted:
+ self._hooks.insert(0, hook)
+
+ def register_hook_from_cfg(self, hook_cfg):
+ """Register a hook from its cfg.
+
+ Args:
+ hook_cfg (dict): Hook config. It should have at least keys 'type'
+ and 'priority' indicating its type and priority.
+
+ Notes:
+ The specific hook class to register should not use 'type' and
+ 'priority' arguments during initialization.
+ """
+ hook_cfg = hook_cfg.copy()
+ priority = hook_cfg.pop('priority', 'NORMAL')
+ hook = mmcv.build_from_cfg(hook_cfg, HOOKS)
+ self.register_hook(hook, priority=priority)
+
+ def call_hook(self, fn_name):
+ """Call all hooks.
+
+ Args:
+ fn_name (str): The function name in each hook to be called, such as
+ "before_train_epoch".
+ """
+ for hook in self._hooks:
+ getattr(hook, fn_name)(self)
+
+ def get_hook_info(self):
+ # Get hooks info in each stage
+ stage_hook_map = {stage: [] for stage in Hook.stages}
+ for hook in self.hooks:
+ try:
+ priority = Priority(hook.priority).name
+ except ValueError:
+ priority = hook.priority
+ classname = hook.__class__.__name__
+ hook_info = f'({priority:<12}) {classname:<35}'
+ for trigger_stage in hook.get_triggered_stages():
+ stage_hook_map[trigger_stage].append(hook_info)
+
+ stage_hook_infos = []
+ for stage in Hook.stages:
+ hook_infos = stage_hook_map[stage]
+ if len(hook_infos) > 0:
+ info = f'{stage}:\n'
+ info += '\n'.join(hook_infos)
+ info += '\n -------------------- '
+ stage_hook_infos.append(info)
+ return '\n'.join(stage_hook_infos)
+
+ def load_checkpoint(self,
+ filename,
+ map_location='cpu',
+ strict=False,
+ revise_keys=[(r'^module.', '')]):
+ return load_checkpoint(
+ self.model,
+ filename,
+ map_location,
+ strict,
+ self.logger,
+ revise_keys=revise_keys)
+
+ def resume(self,
+ checkpoint,
+ resume_optimizer=True,
+ map_location='default'):
+ if map_location == 'default':
+ if torch.cuda.is_available():
+ device_id = torch.cuda.current_device()
+ checkpoint = self.load_checkpoint(
+ checkpoint,
+ map_location=lambda storage, loc: storage.cuda(device_id))
+ else:
+ checkpoint = self.load_checkpoint(checkpoint)
+ else:
+ checkpoint = self.load_checkpoint(
+ checkpoint, map_location=map_location)
+
+ self._epoch = checkpoint['meta']['epoch']
+ self._iter = checkpoint['meta']['iter']
+ if self.meta is None:
+ self.meta = {}
+ self.meta.setdefault('hook_msgs', {})
+ # load `last_ckpt`, `best_score`, `best_ckpt`, etc. for hook messages
+ self.meta['hook_msgs'].update(checkpoint['meta'].get('hook_msgs', {}))
+
+ # Re-calculate the number of iterations when resuming
+ # models with different number of GPUs
+ if 'config' in checkpoint['meta']:
+ config = mmcv.Config.fromstring(
+ checkpoint['meta']['config'], file_format='.py')
+ previous_gpu_ids = config.get('gpu_ids', None)
+ if previous_gpu_ids and len(previous_gpu_ids) > 0 and len(
+ previous_gpu_ids) != self.world_size:
+ self._iter = int(self._iter * len(previous_gpu_ids) /
+ self.world_size)
+ self.logger.info('the iteration number is changed due to '
+ 'change of GPU number')
+
+ # resume meta information meta
+ self.meta = checkpoint['meta']
+
+ if 'optimizer' in checkpoint and resume_optimizer:
+ if isinstance(self.optimizer, Optimizer):
+ self.optimizer.load_state_dict(checkpoint['optimizer'])
+ elif isinstance(self.optimizer, dict):
+ for k in self.optimizer.keys():
+ self.optimizer[k].load_state_dict(
+ checkpoint['optimizer'][k])
+ else:
+ raise TypeError(
+ 'Optimizer should be dict or torch.optim.Optimizer '
+ f'but got {type(self.optimizer)}')
+
+ self.logger.info('resumed epoch %d, iter %d', self.epoch, self.iter)
+
+ def register_lr_hook(self, lr_config):
+ if lr_config is None:
+ return
+ elif isinstance(lr_config, dict):
+ assert 'policy' in lr_config
+ policy_type = lr_config.pop('policy')
+            # If the type of policy is all in lower case, e.g., 'cyclic',
+            # then its first letter will be capitalized, e.g., to 'Cyclic'.
+            # This is for the convenient usage of the LR updater.
+            # Since this is not applicable for
+            # `CosineAnnealingLrUpdater`, the string will not be changed
+            # if it contains capital letters.
+ if policy_type == policy_type.lower():
+ policy_type = policy_type.title()
+ hook_type = policy_type + 'LrUpdaterHook'
+ lr_config['type'] = hook_type
+ hook = mmcv.build_from_cfg(lr_config, HOOKS)
+ else:
+ hook = lr_config
+ self.register_hook(hook, priority='VERY_HIGH')
+
+ def register_momentum_hook(self, momentum_config):
+ if momentum_config is None:
+ return
+ if isinstance(momentum_config, dict):
+ assert 'policy' in momentum_config
+ policy_type = momentum_config.pop('policy')
+ # If the type of policy is all in lower case, e.g., 'cyclic',
+ # then its first letter will be capitalized, e.g., to be 'Cyclic'.
+ # This is for the convenient usage of momentum updater.
+ # Since this is not applicable for
+ # `CosineAnnealingMomentumUpdater`,
+ # the string will not be changed if it contains capital letters.
+ if policy_type == policy_type.lower():
+ policy_type = policy_type.title()
+ hook_type = policy_type + 'MomentumUpdaterHook'
+ momentum_config['type'] = hook_type
+ hook = mmcv.build_from_cfg(momentum_config, HOOKS)
+ else:
+ hook = momentum_config
+ self.register_hook(hook, priority='HIGH')
+
+ def register_optimizer_hook(self, optimizer_config):
+ if optimizer_config is None:
+ return
+ if isinstance(optimizer_config, dict):
+ optimizer_config.setdefault('type', 'OptimizerHook')
+ hook = mmcv.build_from_cfg(optimizer_config, HOOKS)
+ else:
+ hook = optimizer_config
+ self.register_hook(hook, priority='ABOVE_NORMAL')
+
+ def register_checkpoint_hook(self, checkpoint_config):
+ if checkpoint_config is None:
+ return
+ if isinstance(checkpoint_config, dict):
+ checkpoint_config.setdefault('type', 'CheckpointHook')
+ hook = mmcv.build_from_cfg(checkpoint_config, HOOKS)
+ else:
+ hook = checkpoint_config
+ self.register_hook(hook, priority='NORMAL')
+
+ def register_logger_hooks(self, log_config):
+ if log_config is None:
+ return
+ log_interval = log_config['interval']
+ for info in log_config['hooks']:
+ logger_hook = mmcv.build_from_cfg(
+ info, HOOKS, default_args=dict(interval=log_interval))
+ self.register_hook(logger_hook, priority='VERY_LOW')
+
+ def register_timer_hook(self, timer_config):
+ if timer_config is None:
+ return
+ if isinstance(timer_config, dict):
+ timer_config_ = copy.deepcopy(timer_config)
+ hook = mmcv.build_from_cfg(timer_config_, HOOKS)
+ else:
+ hook = timer_config
+ self.register_hook(hook, priority='LOW')
+
+ def register_custom_hooks(self, custom_config):
+ if custom_config is None:
+ return
+
+ if not isinstance(custom_config, list):
+ custom_config = [custom_config]
+
+ for item in custom_config:
+ if isinstance(item, dict):
+ self.register_hook_from_cfg(item)
+ else:
+ self.register_hook(item, priority='NORMAL')
+
+ def register_profiler_hook(self, profiler_config):
+ if profiler_config is None:
+ return
+ if isinstance(profiler_config, dict):
+ profiler_config.setdefault('type', 'ProfilerHook')
+ hook = mmcv.build_from_cfg(profiler_config, HOOKS)
+ else:
+ hook = profiler_config
+ self.register_hook(hook)
+
+ def register_training_hooks(self,
+ lr_config,
+ optimizer_config=None,
+ checkpoint_config=None,
+ log_config=None,
+ momentum_config=None,
+ timer_config=dict(type='IterTimerHook'),
+ custom_hooks_config=None):
+ """Register default and custom hooks for training.
+
+ Default and custom hooks include:
+
+ +----------------------+-------------------------+
+ | Hooks | Priority |
+ +======================+=========================+
+ | LrUpdaterHook | VERY_HIGH (10) |
+ +----------------------+-------------------------+
+ | MomentumUpdaterHook | HIGH (30) |
+ +----------------------+-------------------------+
+ | OptimizerStepperHook | ABOVE_NORMAL (40) |
+ +----------------------+-------------------------+
+ | CheckpointSaverHook | NORMAL (50) |
+ +----------------------+-------------------------+
+ | IterTimerHook | LOW (70) |
+ +----------------------+-------------------------+
+ | LoggerHook(s) | VERY_LOW (90) |
+ +----------------------+-------------------------+
+ | CustomHook(s) | defaults to NORMAL (50) |
+ +----------------------+-------------------------+
+
+        If custom hooks have the same priority as default hooks, the custom
+        hooks will be triggered after the default hooks.
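+
+        Example (a hedged sketch; the config dicts below are minimal
+        placeholders rather than settings used elsewhere in this repo)::
+
+            >>> runner.register_training_hooks(
+            ...     lr_config=dict(policy='step', step=[8, 11]),
+            ...     optimizer_config=dict(grad_clip=None),
+            ...     checkpoint_config=dict(interval=1),
+            ...     log_config=dict(
+            ...         interval=50, hooks=[dict(type='TextLoggerHook')]))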
+ """
+ self.register_lr_hook(lr_config)
+ self.register_momentum_hook(momentum_config)
+ self.register_optimizer_hook(optimizer_config)
+ self.register_checkpoint_hook(checkpoint_config)
+ self.register_timer_hook(timer_config)
+ self.register_logger_hooks(log_config)
+ self.register_custom_hooks(custom_hooks_config)
diff --git a/annotator/uniformer/mmcv/runner/builder.py b/annotator/uniformer/mmcv/runner/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..77c96ba0b2f30ead9da23f293c5dc84dd3e4a74f
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/builder.py
@@ -0,0 +1,24 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+
+from ..utils import Registry
+
+RUNNERS = Registry('runner')
+RUNNER_BUILDERS = Registry('runner builder')
+
+
+def build_runner_constructor(cfg):
+ return RUNNER_BUILDERS.build(cfg)
+
+
+def build_runner(cfg, default_args=None):
+ runner_cfg = copy.deepcopy(cfg)
+ constructor_type = runner_cfg.pop('constructor',
+ 'DefaultRunnerConstructor')
+ runner_constructor = build_runner_constructor(
+ dict(
+ type=constructor_type,
+ runner_cfg=runner_cfg,
+ default_args=default_args))
+ runner = runner_constructor()
+ return runner
diff --git a/annotator/uniformer/mmcv/runner/checkpoint.py b/annotator/uniformer/mmcv/runner/checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..b29ca320679164432f446adad893e33fb2b4b29e
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/checkpoint.py
@@ -0,0 +1,707 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import io
+import os
+import os.path as osp
+import pkgutil
+import re
+import time
+import warnings
+from collections import OrderedDict
+from importlib import import_module
+from tempfile import TemporaryDirectory
+
+import torch
+import torchvision
+from torch.optim import Optimizer
+from torch.utils import model_zoo
+
+import annotator.uniformer.mmcv as mmcv
+from ..fileio import FileClient
+from ..fileio import load as load_file
+from ..parallel import is_module_wrapper
+from ..utils import mkdir_or_exist
+from .dist_utils import get_dist_info
+
+ENV_MMCV_HOME = 'MMCV_HOME'
+ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
+DEFAULT_CACHE_DIR = '~/.cache'
+
+
+def _get_mmcv_home():
+ mmcv_home = os.path.expanduser(
+ os.getenv(
+ ENV_MMCV_HOME,
+ os.path.join(
+ os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'mmcv')))
+
+ mkdir_or_exist(mmcv_home)
+ return mmcv_home
+
+
+def load_state_dict(module, state_dict, strict=False, logger=None):
+ """Load state_dict to a module.
+
+ This method is modified from :meth:`torch.nn.Module.load_state_dict`.
+ Default value for ``strict`` is set to ``False`` and the message for
+ param mismatch will be shown even if strict is False.
+
+ Args:
+ module (Module): Module that receives the state_dict.
+ state_dict (OrderedDict): Weights.
+ strict (bool): whether to strictly enforce that the keys
+ in :attr:`state_dict` match the keys returned by this module's
+ :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
+ logger (:obj:`logging.Logger`, optional): Logger to log the error
+ message. If not specified, print function will be used.
+ """
+ unexpected_keys = []
+ all_missing_keys = []
+ err_msg = []
+
+ metadata = getattr(state_dict, '_metadata', None)
+ state_dict = state_dict.copy()
+ if metadata is not None:
+ state_dict._metadata = metadata
+
+ # use _load_from_state_dict to enable checkpoint version control
+ def load(module, prefix=''):
+ # recursively check parallel module in case that the model has a
+ # complicated structure, e.g., nn.Module(nn.Module(DDP))
+ if is_module_wrapper(module):
+ module = module.module
+ local_metadata = {} if metadata is None else metadata.get(
+ prefix[:-1], {})
+ module._load_from_state_dict(state_dict, prefix, local_metadata, True,
+ all_missing_keys, unexpected_keys,
+ err_msg)
+ for name, child in module._modules.items():
+ if child is not None:
+ load(child, prefix + name + '.')
+
+ load(module)
+ load = None # break load->load reference cycle
+
+ # ignore "num_batches_tracked" of BN layers
+ missing_keys = [
+ key for key in all_missing_keys if 'num_batches_tracked' not in key
+ ]
+
+ if unexpected_keys:
+ err_msg.append('unexpected key in source '
+ f'state_dict: {", ".join(unexpected_keys)}\n')
+ if missing_keys:
+ err_msg.append(
+ f'missing keys in source state_dict: {", ".join(missing_keys)}\n')
+
+ rank, _ = get_dist_info()
+ if len(err_msg) > 0 and rank == 0:
+ err_msg.insert(
+ 0, 'The model and loaded state dict do not match exactly\n')
+ err_msg = '\n'.join(err_msg)
+ if strict:
+ raise RuntimeError(err_msg)
+ elif logger is not None:
+ logger.warning(err_msg)
+ else:
+ print(err_msg)
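
A small self-contained illustration of the non-strict behaviour (a missing key is reported rather than raising):

    >>> import torch
    >>> import torch.nn as nn
    >>> model = nn.Linear(2, 2)
    >>> partial = {'weight': torch.zeros(2, 2)}  # 'bias' deliberately omitted
    >>> load_state_dict(model, partial, strict=False)  # prints the missing 'bias' key
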
+
+
+def get_torchvision_models():
+ model_urls = dict()
+ for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):
+ if ispkg:
+ continue
+ _zoo = import_module(f'torchvision.models.{name}')
+ if hasattr(_zoo, 'model_urls'):
+ _urls = getattr(_zoo, 'model_urls')
+ model_urls.update(_urls)
+ return model_urls
+
+
+def get_external_models():
+ mmcv_home = _get_mmcv_home()
+ default_json_path = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json')
+ default_urls = load_file(default_json_path)
+ assert isinstance(default_urls, dict)
+ external_json_path = osp.join(mmcv_home, 'open_mmlab.json')
+ if osp.exists(external_json_path):
+ external_urls = load_file(external_json_path)
+ assert isinstance(external_urls, dict)
+ default_urls.update(external_urls)
+
+ return default_urls
+
+
+def get_mmcls_models():
+ mmcls_json_path = osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json')
+ mmcls_urls = load_file(mmcls_json_path)
+
+ return mmcls_urls
+
+
+def get_deprecated_model_names():
+ deprecate_json_path = osp.join(mmcv.__path__[0],
+ 'model_zoo/deprecated.json')
+ deprecate_urls = load_file(deprecate_json_path)
+ assert isinstance(deprecate_urls, dict)
+
+ return deprecate_urls
+
+
+def _process_mmcls_checkpoint(checkpoint):
+ state_dict = checkpoint['state_dict']
+ new_state_dict = OrderedDict()
+ for k, v in state_dict.items():
+ if k.startswith('backbone.'):
+ new_state_dict[k[9:]] = v
+ new_checkpoint = dict(state_dict=new_state_dict)
+
+ return new_checkpoint
+
+
+class CheckpointLoader:
+ """A general checkpoint loader to manage all schemes."""
+
+ _schemes = {}
+
+ @classmethod
+ def _register_scheme(cls, prefixes, loader, force=False):
+ if isinstance(prefixes, str):
+ prefixes = [prefixes]
+ else:
+ assert isinstance(prefixes, (list, tuple))
+ for prefix in prefixes:
+ if (prefix not in cls._schemes) or force:
+ cls._schemes[prefix] = loader
+ else:
+ raise KeyError(
+ f'{prefix} is already registered as a loader backend, '
+ 'add "force=True" if you want to override it')
+ # sort, longer prefixes take priority
+ cls._schemes = OrderedDict(
+ sorted(cls._schemes.items(), key=lambda t: t[0], reverse=True))
+
+ @classmethod
+ def register_scheme(cls, prefixes, loader=None, force=False):
+ """Register a loader to CheckpointLoader.
+
+ This method can be used as a normal class method or a decorator.
+
+ Args:
+ prefixes (str or list[str] or tuple[str]):
+ The prefix of the registered loader.
+ loader (function, optional): The loader function to be registered.
+ When this method is used as a decorator, loader is None.
+ Defaults to None.
+ force (bool, optional): Whether to override the loader
+ if the prefix has already been registered. Defaults to False.
+ """
+
+ if loader is not None:
+ cls._register_scheme(prefixes, loader, force=force)
+ return
+
+ def _register(loader_cls):
+ cls._register_scheme(prefixes, loader_cls, force=force)
+ return loader_cls
+
+ return _register
+
+ @classmethod
+ def _get_checkpoint_loader(cls, path):
+ """Finds a loader that supports the given path. Falls back to the local
+ loader if no other loader is found.
+
+ Args:
+ path (str): checkpoint path
+
+ Returns:
+ loader (function): checkpoint loader
+ """
+
+ for p in cls._schemes:
+ if path.startswith(p):
+ return cls._schemes[p]
+
+ @classmethod
+ def load_checkpoint(cls, filename, map_location=None, logger=None):
+ """load checkpoint through URL scheme path.
+
+ Args:
+ filename (str): checkpoint file name with given prefix
+ map_location (str, optional): Same as :func:`torch.load`.
+ Default: None
+ logger (:mod:`logging.Logger`, optional): The logger for message.
+ Default: None
+
+ Returns:
+ dict or OrderedDict: The loaded checkpoint.
+ """
+
+ checkpoint_loader = cls._get_checkpoint_loader(filename)
+ class_name = checkpoint_loader.__name__
+ mmcv.print_log(
+ f'load checkpoint from {class_name[10:]} path: {filename}', logger)
+ return checkpoint_loader(filename, map_location)
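
As a sketch of extending this dispatcher, a hypothetical ``myhub://`` prefix could be registered as follows; the scheme name and the resolution logic are made up for illustration:

    >>> @CheckpointLoader.register_scheme(prefixes='myhub://')
    ... def load_from_myhub(filename, map_location=None):
    ...     # resolve 'myhub://<name>' to an HTTP URL and reuse the http loader
    ...     return load_from_http('https://example.com/' + filename[8:],
    ...                           map_location=map_location)
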
+
+
+@CheckpointLoader.register_scheme(prefixes='')
+def load_from_local(filename, map_location):
+ """load checkpoint by local file path.
+
+ Args:
+ filename (str): local checkpoint file path
+ map_location (str, optional): Same as :func:`torch.load`.
+
+ Returns:
+ dict or OrderedDict: The loaded checkpoint.
+ """
+
+ if not osp.isfile(filename):
+ raise IOError(f'{filename} is not a checkpoint file')
+ checkpoint = torch.load(filename, map_location=map_location)
+ return checkpoint
+
+
+@CheckpointLoader.register_scheme(prefixes=('http://', 'https://'))
+def load_from_http(filename, map_location=None, model_dir=None):
+ """load checkpoint through HTTP or HTTPS scheme path. In distributed
+ setting, this function only download checkpoint at local rank 0.
+
+ Args:
+ filename (str): checkpoint file path with modelzoo or
+ torchvision prefix
+ map_location (str, optional): Same as :func:`torch.load`.
+ model_dir (string, optional): directory in which to save the object,
+ Default: None
+
+ Returns:
+ dict or OrderedDict: The loaded checkpoint.
+ """
+ rank, world_size = get_dist_info()
+ rank = int(os.environ.get('LOCAL_RANK', rank))
+ if rank == 0:
+ checkpoint = model_zoo.load_url(
+ filename, model_dir=model_dir, map_location=map_location)
+ if world_size > 1:
+ torch.distributed.barrier()
+ if rank > 0:
+ checkpoint = model_zoo.load_url(
+ filename, model_dir=model_dir, map_location=map_location)
+ return checkpoint
+
+
+@CheckpointLoader.register_scheme(prefixes='pavi://')
+def load_from_pavi(filename, map_location=None):
+ """load checkpoint through the file path prefixed with pavi. In distributed
+ setting, this function download ckpt at all ranks to different temporary
+ directories.
+
+ Args:
+ filename (str): checkpoint file path with pavi prefix
+ map_location (str, optional): Same as :func:`torch.load`.
+ Default: None
+
+ Returns:
+ dict or OrderedDict: The loaded checkpoint.
+ """
+ assert filename.startswith('pavi://'), \
+ f'Expected filename to start with `pavi://`, but got {filename}'
+ model_path = filename[7:]
+
+ try:
+ from pavi import modelcloud
+ except ImportError:
+ raise ImportError(
+ 'Please install pavi to load checkpoint from modelcloud.')
+
+ model = modelcloud.get(model_path)
+ with TemporaryDirectory() as tmp_dir:
+ downloaded_file = osp.join(tmp_dir, model.name)
+ model.download(downloaded_file)
+ checkpoint = torch.load(downloaded_file, map_location=map_location)
+ return checkpoint
+
+
+@CheckpointLoader.register_scheme(prefixes='s3://')
+def load_from_ceph(filename, map_location=None, backend='petrel'):
+ """load checkpoint through the file path prefixed with s3. In distributed
+ setting, this function download ckpt at all ranks to different temporary
+ directories.
+
+ Args:
+ filename (str): checkpoint file path with s3 prefix
+ map_location (str, optional): Same as :func:`torch.load`.
+ backend (str, optional): The storage backend type. Options are 'ceph',
+ 'petrel'. Default: 'petrel'.
+
+ .. warning::
+ :class:`mmcv.fileio.file_client.CephBackend` will be deprecated,
+ please use :class:`mmcv.fileio.file_client.PetrelBackend` instead.
+
+ Returns:
+ dict or OrderedDict: The loaded checkpoint.
+ """
+ allowed_backends = ['ceph', 'petrel']
+ if backend not in allowed_backends:
+ raise ValueError(f'Load from Backend {backend} is not supported.')
+
+ if backend == 'ceph':
+ warnings.warn(
+ 'CephBackend will be deprecated, please use PetrelBackend instead')
+
+ # CephClient and PetrelBackend have the same prefix 's3://' and the latter
+ # will be chosen as default. If PetrelBackend can not be instantiated
+ # successfully, the CephClient will be chosen.
+ try:
+ file_client = FileClient(backend=backend)
+ except ImportError:
+ allowed_backends.remove(backend)
+ file_client = FileClient(backend=allowed_backends[0])
+
+ with io.BytesIO(file_client.get(filename)) as buffer:
+ checkpoint = torch.load(buffer, map_location=map_location)
+ return checkpoint
+
+
+@CheckpointLoader.register_scheme(prefixes=('modelzoo://', 'torchvision://'))
+def load_from_torchvision(filename, map_location=None):
+ """load checkpoint through the file path prefixed with modelzoo or
+ torchvision.
+
+ Args:
+ filename (str): checkpoint file path with modelzoo or
+ torchvision prefix
+ map_location (str, optional): Same as :func:`torch.load`.
+
+ Returns:
+ dict or OrderedDict: The loaded checkpoint.
+ """
+ model_urls = get_torchvision_models()
+ if filename.startswith('modelzoo://'):
+ warnings.warn('The URL scheme of "modelzoo://" is deprecated, please '
+ 'use "torchvision://" instead')
+ model_name = filename[11:]
+ else:
+ model_name = filename[14:]
+ return load_from_http(model_urls[model_name], map_location=map_location)
+
+
+@CheckpointLoader.register_scheme(prefixes=('open-mmlab://', 'openmmlab://'))
+def load_from_openmmlab(filename, map_location=None):
+ """load checkpoint through the file path prefixed with open-mmlab or
+ openmmlab.
+
+ Args:
+ filename (str): checkpoint file path with open-mmlab or
+ openmmlab prefix
+ map_location (str, optional): Same as :func:`torch.load`.
+ Default: None
+
+ Returns:
+ dict or OrderedDict: The loaded checkpoint.
+ """
+
+ model_urls = get_external_models()
+ prefix_str = 'open-mmlab://'
+ if filename.startswith(prefix_str):
+ model_name = filename[13:]
+ else:
+ model_name = filename[12:]
+ prefix_str = 'openmmlab://'
+
+ deprecated_urls = get_deprecated_model_names()
+ if model_name in deprecated_urls:
+ warnings.warn(f'{prefix_str}{model_name} is deprecated in favor '
+ f'of {prefix_str}{deprecated_urls[model_name]}')
+ model_name = deprecated_urls[model_name]
+ model_url = model_urls[model_name]
+ # check if is url
+ if model_url.startswith(('http://', 'https://')):
+ checkpoint = load_from_http(model_url, map_location=map_location)
+ else:
+ filename = osp.join(_get_mmcv_home(), model_url)
+ if not osp.isfile(filename):
+ raise IOError(f'{filename} is not a checkpoint file')
+ checkpoint = torch.load(filename, map_location=map_location)
+ return checkpoint
+
+
+@CheckpointLoader.register_scheme(prefixes='mmcls://')
+def load_from_mmcls(filename, map_location=None):
+ """load checkpoint through the file path prefixed with mmcls.
+
+ Args:
+ filename (str): checkpoint file path with mmcls prefix
+ map_location (str, optional): Same as :func:`torch.load`.
+
+ Returns:
+ dict or OrderedDict: The loaded checkpoint.
+ """
+
+ model_urls = get_mmcls_models()
+ model_name = filename[8:]
+ checkpoint = load_from_http(
+ model_urls[model_name], map_location=map_location)
+ checkpoint = _process_mmcls_checkpoint(checkpoint)
+ return checkpoint
+
+
+def _load_checkpoint(filename, map_location=None, logger=None):
+ """Load checkpoint from somewhere (modelzoo, file, url).
+
+ Args:
+ filename (str): Accept local filepath, URL, ``torchvision://xxx``,
+ ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
+ details.
+ map_location (str, optional): Same as :func:`torch.load`.
+ Default: None.
+ logger (:mod:`logging.Logger`, optional): The logger for error message.
+ Default: None
+
+ Returns:
+ dict or OrderedDict: The loaded checkpoint. It can be either an
+ OrderedDict storing model weights or a dict containing other
+ information, which depends on the checkpoint.
+ """
+ return CheckpointLoader.load_checkpoint(filename, map_location, logger)
+
+
+def _load_checkpoint_with_prefix(prefix, filename, map_location=None):
+ """Load partial pretrained model with specific prefix.
+
+ Args:
+ prefix (str): The prefix of sub-module.
+ filename (str): Accept local filepath, URL, ``torchvision://xxx``,
+ ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
+ details.
+ map_location (str | None): Same as :func:`torch.load`. Default: None.
+
+ Returns:
+ dict or OrderedDict: The loaded checkpoint.
+ """
+
+ checkpoint = _load_checkpoint(filename, map_location=map_location)
+
+ if 'state_dict' in checkpoint:
+ state_dict = checkpoint['state_dict']
+ else:
+ state_dict = checkpoint
+ if not prefix.endswith('.'):
+ prefix += '.'
+ prefix_len = len(prefix)
+
+ state_dict = {
+ k[prefix_len:]: v
+ for k, v in state_dict.items() if k.startswith(prefix)
+ }
+
+ assert state_dict, f'{prefix} is not in the pretrained model'
+ return state_dict
+
+
+def load_checkpoint(model,
+ filename,
+ map_location=None,
+ strict=False,
+ logger=None,
+ revise_keys=[(r'^module\.', '')]):
+ """Load checkpoint from a file or URI.
+
+ Args:
+ model (Module): Module to load checkpoint.
+ filename (str): Accept local filepath, URL, ``torchvision://xxx``,
+ ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
+ details.
+ map_location (str): Same as :func:`torch.load`.
+ strict (bool): Whether to allow different params for the model and
+ checkpoint.
+ logger (:mod:`logging.Logger` or None): The logger for error message.
+ revise_keys (list): A list of customized keywords to modify the
+ state_dict in checkpoint. Each item is a (pattern, replacement)
+ pair of the regular expression operations. Default: strip
+ the prefix 'module.' by [(r'^module\\.', '')].
+
+ Returns:
+ dict or OrderedDict: The loaded checkpoint.
+ """
+ checkpoint = _load_checkpoint(filename, map_location, logger)
+ # OrderedDict is a subclass of dict
+ if not isinstance(checkpoint, dict):
+ raise RuntimeError(
+ f'No state_dict found in checkpoint file {filename}')
+ # get state_dict from checkpoint
+ if 'state_dict' in checkpoint:
+ state_dict = checkpoint['state_dict']
+ else:
+ state_dict = checkpoint
+
+ # strip prefix of state_dict
+ metadata = getattr(state_dict, '_metadata', OrderedDict())
+ for p, r in revise_keys:
+ state_dict = OrderedDict(
+ {re.sub(p, r, k): v
+ for k, v in state_dict.items()})
+ # Keep metadata in state_dict
+ state_dict._metadata = metadata
+
+ # load state_dict
+ load_state_dict(model, state_dict, strict, logger)
+ return checkpoint
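
A typical call might look like the following; the URL is a placeholder and ``model`` is assumed to be an ``nn.Module`` built elsewhere:

    >>> checkpoint = load_checkpoint(
    ...     model,
    ...     'https://example.com/resnet50.pth',  # placeholder location
    ...     map_location='cpu',
    ...     revise_keys=[(r'^module\.', '')])  # strip the DDP 'module.' prefix
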
+
+
+def weights_to_cpu(state_dict):
+ """Copy a model state_dict to cpu.
+
+ Args:
+ state_dict (OrderedDict): Model weights on GPU.
+
+ Returns:
+ OrderedDict: Model weights on CPU.
+ """
+ state_dict_cpu = OrderedDict()
+ for key, val in state_dict.items():
+ state_dict_cpu[key] = val.cpu()
+ # Keep metadata in state_dict
+ state_dict_cpu._metadata = getattr(state_dict, '_metadata', OrderedDict())
+ return state_dict_cpu
+
+
+def _save_to_state_dict(module, destination, prefix, keep_vars):
+ """Saves module state to `destination` dictionary.
+
+ This method is modified from :meth:`torch.nn.Module._save_to_state_dict`.
+
+ Args:
+ module (nn.Module): The module to generate state_dict.
+ destination (dict): A dict where state will be stored.
+ prefix (str): The prefix for parameters and buffers used in this
+ module.
+ keep_vars (bool): Whether to keep the variable property of the
+ parameters (i.e., not detach them). Default: False.
+ """
+ for name, param in module._parameters.items():
+ if param is not None:
+ destination[prefix + name] = param if keep_vars else param.detach()
+ for name, buf in module._buffers.items():
+ # remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d
+ if buf is not None:
+ destination[prefix + name] = buf if keep_vars else buf.detach()
+
+
+def get_state_dict(module, destination=None, prefix='', keep_vars=False):
+ """Returns a dictionary containing a whole state of the module.
+
+ Both parameters and persistent buffers (e.g. running averages) are
+ included. Keys are corresponding parameter and buffer names.
+
+ This method is modified from :meth:`torch.nn.Module.state_dict` to
+ recursively check parallel module in case that the model has a complicated
+ structure, e.g., nn.Module(nn.Module(DDP)).
+
+ Args:
+ module (nn.Module): The module to generate state_dict.
+ destination (OrderedDict): Returned dict for the state of the
+ module.
+ prefix (str): Prefix of the key.
+ keep_vars (bool): Whether to keep the variable property of the
+ parameters. Default: False.
+
+ Returns:
+ dict: A dictionary containing a whole state of the module.
+ """
+ # recursively check parallel module in case that the model has a
+ # complicated structure, e.g., nn.Module(nn.Module(DDP))
+ if is_module_wrapper(module):
+ module = module.module
+
+ # below is the same as torch.nn.Module.state_dict()
+ if destination is None:
+ destination = OrderedDict()
+ destination._metadata = OrderedDict()
+ destination._metadata[prefix[:-1]] = local_metadata = dict(
+ version=module._version)
+ _save_to_state_dict(module, destination, prefix, keep_vars)
+ for name, child in module._modules.items():
+ if child is not None:
+ get_state_dict(
+ child, destination, prefix + name + '.', keep_vars=keep_vars)
+ for hook in module._state_dict_hooks.values():
+ hook_result = hook(module, destination, prefix, local_metadata)
+ if hook_result is not None:
+ destination = hook_result
+ return destination
+
+
+def save_checkpoint(model,
+ filename,
+ optimizer=None,
+ meta=None,
+ file_client_args=None):
+ """Save checkpoint to file.
+
+ The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
+ ``optimizer``. By default ``meta`` will contain version and time info.
+
+ Args:
+ model (Module): Module whose params are to be saved.
+ filename (str): Checkpoint filename.
+ optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
+ meta (dict, optional): Metadata to be saved in checkpoint.
+ file_client_args (dict, optional): Arguments to instantiate a
+ FileClient. See :class:`mmcv.fileio.FileClient` for details.
+ Default: None.
+ `New in version 1.3.16.`
+ """
+ if meta is None:
+ meta = {}
+ elif not isinstance(meta, dict):
+ raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
+ meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
+
+ if is_module_wrapper(model):
+ model = model.module
+
+ if hasattr(model, 'CLASSES') and model.CLASSES is not None:
+ # save class name to the meta
+ meta.update(CLASSES=model.CLASSES)
+
+ checkpoint = {
+ 'meta': meta,
+ 'state_dict': weights_to_cpu(get_state_dict(model))
+ }
+ # save optimizer state dict in the checkpoint
+ if isinstance(optimizer, Optimizer):
+ checkpoint['optimizer'] = optimizer.state_dict()
+ elif isinstance(optimizer, dict):
+ checkpoint['optimizer'] = {}
+ for name, optim in optimizer.items():
+ checkpoint['optimizer'][name] = optim.state_dict()
+
+ if filename.startswith('pavi://'):
+ if file_client_args is not None:
+ raise ValueError(
+ 'file_client_args should be "None" if filename starts with'
+ f'"pavi://", but got {file_client_args}')
+ try:
+ from pavi import modelcloud
+ from pavi import exception
+ except ImportError:
+ raise ImportError(
+ 'Please install pavi to load checkpoint from modelcloud.')
+ model_path = filename[7:]
+ root = modelcloud.Folder()
+ model_dir, model_name = osp.split(model_path)
+ try:
+ model = modelcloud.get(model_dir)
+ except exception.NodeNotFoundError:
+ model = root.create_training_model(model_dir)
+ with TemporaryDirectory() as tmp_dir:
+ checkpoint_file = osp.join(tmp_dir, model_name)
+ with open(checkpoint_file, 'wb') as f:
+ torch.save(checkpoint, f)
+ f.flush()
+ model.create_file(checkpoint_file, name=model_name)
+ else:
+ file_client = FileClient.infer_client(file_client_args, filename)
+ with io.BytesIO() as f:
+ torch.save(checkpoint, f)
+ file_client.put(f.getvalue(), filename)
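
For completeness, a minimal sketch of saving to a local path; the path and ``meta`` values are illustrative, and ``model``/``optimizer`` are assumed to exist:

    >>> save_checkpoint(model, './work_dir/epoch_12.pth',
    ...                 optimizer=optimizer,
    ...                 meta=dict(epoch=12, iter=48000))
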
diff --git a/annotator/uniformer/mmcv/runner/default_constructor.py b/annotator/uniformer/mmcv/runner/default_constructor.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f1f5b44168768dfda3947393a63a6cf9cf50b41
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/default_constructor.py
@@ -0,0 +1,44 @@
+from .builder import RUNNER_BUILDERS, RUNNERS
+
+
+@RUNNER_BUILDERS.register_module()
+class DefaultRunnerConstructor:
+ """Default constructor for runners.
+
+ Customize an existing `Runner`, e.g. `EpochBasedRunner`, through a `RunnerConstructor`.
+ For example, we can inject new properties and functions into the `Runner`.
+
+ Example:
+ >>> from annotator.uniformer.mmcv.runner import RUNNER_BUILDERS, build_runner
+ >>> # Define a new RunnerConstructor
+ >>> @RUNNER_BUILDERS.register_module()
+ >>> class MyRunnerConstructor:
+ ... def __init__(self, runner_cfg, default_args=None):
+ ... if not isinstance(runner_cfg, dict):
+ ... raise TypeError('runner_cfg should be a dict',
+ ... f'but got {type(runner_cfg)}')
+ ... self.runner_cfg = runner_cfg
+ ... self.default_args = default_args
+ ...
+ ... def __call__(self):
+ ... runner = RUNNERS.build(self.runner_cfg,
+ ... default_args=self.default_args)
+ ... # Add new properties for existing runner
+ ... runner.my_name = 'my_runner'
+ ... runner.my_function = lambda self: print(self.my_name)
+ ... ...
+ >>> # build your runner
+ >>> runner_cfg = dict(type='EpochBasedRunner', max_epochs=40,
+ ... constructor='MyRunnerConstructor')
+ >>> runner = build_runner(runner_cfg)
+ """
+
+ def __init__(self, runner_cfg, default_args=None):
+ if not isinstance(runner_cfg, dict):
+ raise TypeError('runner_cfg should be a dict',
+ f'but got {type(runner_cfg)}')
+ self.runner_cfg = runner_cfg
+ self.default_args = default_args
+
+ def __call__(self):
+ return RUNNERS.build(self.runner_cfg, default_args=self.default_args)
diff --git a/annotator/uniformer/mmcv/runner/dist_utils.py b/annotator/uniformer/mmcv/runner/dist_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3a1ef3fda5ceeb31bf15a73779da1b1903ab0fe
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/dist_utils.py
@@ -0,0 +1,164 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import functools
+import os
+import subprocess
+from collections import OrderedDict
+
+import torch
+import torch.multiprocessing as mp
+from torch import distributed as dist
+from torch._utils import (_flatten_dense_tensors, _take_tensors,
+ _unflatten_dense_tensors)
+
+
+def init_dist(launcher, backend='nccl', **kwargs):
+ if mp.get_start_method(allow_none=True) is None:
+ mp.set_start_method('spawn')
+ if launcher == 'pytorch':
+ _init_dist_pytorch(backend, **kwargs)
+ elif launcher == 'mpi':
+ _init_dist_mpi(backend, **kwargs)
+ elif launcher == 'slurm':
+ _init_dist_slurm(backend, **kwargs)
+ else:
+ raise ValueError(f'Invalid launcher type: {launcher}')
+
+
+def _init_dist_pytorch(backend, **kwargs):
+ # TODO: use local_rank instead of rank % num_gpus
+ rank = int(os.environ['RANK'])
+ num_gpus = torch.cuda.device_count()
+ torch.cuda.set_device(rank % num_gpus)
+ dist.init_process_group(backend=backend, **kwargs)
+
+
+def _init_dist_mpi(backend, **kwargs):
+ # TODO: use local_rank instead of rank % num_gpus
+ rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
+ num_gpus = torch.cuda.device_count()
+ torch.cuda.set_device(rank % num_gpus)
+ dist.init_process_group(backend=backend, **kwargs)
+
+
+def _init_dist_slurm(backend, port=None):
+ """Initialize slurm distributed training environment.
+
+ If argument ``port`` is not specified, then the master port will be taken
+ from the system environment variable ``MASTER_PORT``. If ``MASTER_PORT``
+ is not set, the default port ``29500`` will be used.
+
+ Args:
+ backend (str): Backend of torch.distributed.
+ port (int, optional): Master port. Defaults to None.
+ """
+ proc_id = int(os.environ['SLURM_PROCID'])
+ ntasks = int(os.environ['SLURM_NTASKS'])
+ node_list = os.environ['SLURM_NODELIST']
+ num_gpus = torch.cuda.device_count()
+ torch.cuda.set_device(proc_id % num_gpus)
+ addr = subprocess.getoutput(
+ f'scontrol show hostname {node_list} | head -n1')
+ # specify master port
+ if port is not None:
+ os.environ['MASTER_PORT'] = str(port)
+ elif 'MASTER_PORT' in os.environ:
+ pass # use MASTER_PORT in the environment variable
+ else:
+ # 29500 is torch.distributed default port
+ os.environ['MASTER_PORT'] = '29500'
+ # use MASTER_ADDR in the environment variable if it already exists
+ if 'MASTER_ADDR' not in os.environ:
+ os.environ['MASTER_ADDR'] = addr
+ os.environ['WORLD_SIZE'] = str(ntasks)
+ os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
+ os.environ['RANK'] = str(proc_id)
+ dist.init_process_group(backend=backend)
+
+
+def get_dist_info():
+ if dist.is_available() and dist.is_initialized():
+ rank = dist.get_rank()
+ world_size = dist.get_world_size()
+ else:
+ rank = 0
+ world_size = 1
+ return rank, world_size
+
+
+def master_only(func):
+
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ rank, _ = get_dist_info()
+ if rank == 0:
+ return func(*args, **kwargs)
+
+ return wrapper
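
A self-contained illustration: in a non-distributed run ``get_dist_info()`` returns rank 0, so the wrapped function executes; on any other rank it would silently return ``None``:

    >>> @master_only
    ... def log_once(msg):
    ...     print(f'[rank 0] {msg}')
    >>> log_once('saving checkpoint')
    [rank 0] saving checkpoint
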
+
+
+def allreduce_params(params, coalesce=True, bucket_size_mb=-1):
+ """Allreduce parameters.
+
+ Args:
+ params (list[torch.nn.Parameter]): List of parameters or buffers of a
+ model.
+ coalesce (bool, optional): Whether allreduce parameters as a whole.
+ Defaults to True.
+ bucket_size_mb (int, optional): Size of bucket, the unit is MB.
+ Defaults to -1.
+ """
+ _, world_size = get_dist_info()
+ if world_size == 1:
+ return
+ params = [param.data for param in params]
+ if coalesce:
+ _allreduce_coalesced(params, world_size, bucket_size_mb)
+ else:
+ for tensor in params:
+ dist.all_reduce(tensor.div_(world_size))
+
+
+def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
+ """Allreduce gradients.
+
+ Args:
+ params (list[torch.nn.Parameter]): List of parameters of a model.
+ coalesce (bool, optional): Whether allreduce parameters as a whole.
+ Defaults to True.
+ bucket_size_mb (int, optional): Size of bucket, the unit is MB.
+ Defaults to -1.
+ """
+ grads = [
+ param.grad.data for param in params
+ if param.requires_grad and param.grad is not None
+ ]
+ _, world_size = get_dist_info()
+ if world_size == 1:
+ return
+ if coalesce:
+ _allreduce_coalesced(grads, world_size, bucket_size_mb)
+ else:
+ for tensor in grads:
+ dist.all_reduce(tensor.div_(world_size))
+
+
+def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1):
+ if bucket_size_mb > 0:
+ bucket_size_bytes = bucket_size_mb * 1024 * 1024
+ buckets = _take_tensors(tensors, bucket_size_bytes)
+ else:
+ buckets = OrderedDict()
+ for tensor in tensors:
+ tp = tensor.type()
+ if tp not in buckets:
+ buckets[tp] = []
+ buckets[tp].append(tensor)
+ buckets = buckets.values()
+
+ for bucket in buckets:
+ flat_tensors = _flatten_dense_tensors(bucket)
+ dist.all_reduce(flat_tensors)
+ flat_tensors.div_(world_size)
+ for tensor, synced in zip(
+ bucket, _unflatten_dense_tensors(flat_tensors, bucket)):
+ tensor.copy_(synced)
diff --git a/annotator/uniformer/mmcv/runner/epoch_based_runner.py b/annotator/uniformer/mmcv/runner/epoch_based_runner.py
new file mode 100644
index 0000000000000000000000000000000000000000..766a9ce6afdf09cd11b1b15005f5132583011348
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/epoch_based_runner.py
@@ -0,0 +1,187 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import platform
+import shutil
+import time
+import warnings
+
+import torch
+
+import annotator.uniformer.mmcv as mmcv
+from .base_runner import BaseRunner
+from .builder import RUNNERS
+from .checkpoint import save_checkpoint
+from .utils import get_host_info
+
+
+@RUNNERS.register_module()
+class EpochBasedRunner(BaseRunner):
+ """Epoch-based Runner.
+
+ This runner trains models epoch by epoch.
+ """
+
+ def run_iter(self, data_batch, train_mode, **kwargs):
+ if self.batch_processor is not None:
+ outputs = self.batch_processor(
+ self.model, data_batch, train_mode=train_mode, **kwargs)
+ elif train_mode:
+ outputs = self.model.train_step(data_batch, self.optimizer,
+ **kwargs)
+ else:
+ outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)
+ if not isinstance(outputs, dict):
+ raise TypeError('"batch_processor()" or "model.train_step()"'
+ 'and "model.val_step()" must return a dict')
+ if 'log_vars' in outputs:
+ self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
+ self.outputs = outputs
+
+ def train(self, data_loader, **kwargs):
+ self.model.train()
+ self.mode = 'train'
+ self.data_loader = data_loader
+ self._max_iters = self._max_epochs * len(self.data_loader)
+ self.call_hook('before_train_epoch')
+ time.sleep(2) # Prevent possible deadlock during epoch transition
+ for i, data_batch in enumerate(self.data_loader):
+ self._inner_iter = i
+ self.call_hook('before_train_iter')
+ self.run_iter(data_batch, train_mode=True, **kwargs)
+ self.call_hook('after_train_iter')
+ self._iter += 1
+
+ self.call_hook('after_train_epoch')
+ self._epoch += 1
+
+ @torch.no_grad()
+ def val(self, data_loader, **kwargs):
+ self.model.eval()
+ self.mode = 'val'
+ self.data_loader = data_loader
+ self.call_hook('before_val_epoch')
+ time.sleep(2) # Prevent possible deadlock during epoch transition
+ for i, data_batch in enumerate(self.data_loader):
+ self._inner_iter = i
+ self.call_hook('before_val_iter')
+ self.run_iter(data_batch, train_mode=False)
+ self.call_hook('after_val_iter')
+
+ self.call_hook('after_val_epoch')
+
+ def run(self, data_loaders, workflow, max_epochs=None, **kwargs):
+ """Start running.
+
+ Args:
+ data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
+ and validation.
+ workflow (list[tuple]): A list of (phase, epochs) to specify the
+ running order and epochs. E.g., [('train', 2), ('val', 1)] means
+ running 2 epochs for training and 1 epoch for validation,
+ iteratively.
+ """
+ assert isinstance(data_loaders, list)
+ assert mmcv.is_list_of(workflow, tuple)
+ assert len(data_loaders) == len(workflow)
+ if max_epochs is not None:
+ warnings.warn(
+ 'setting max_epochs in run is deprecated, '
+ 'please set max_epochs in runner_config', DeprecationWarning)
+ self._max_epochs = max_epochs
+
+ assert self._max_epochs is not None, (
+ 'max_epochs must be specified during instantiation')
+
+ for i, flow in enumerate(workflow):
+ mode, epochs = flow
+ if mode == 'train':
+ self._max_iters = self._max_epochs * len(data_loaders[i])
+ break
+
+ work_dir = self.work_dir if self.work_dir is not None else 'NONE'
+ self.logger.info('Start running, host: %s, work_dir: %s',
+ get_host_info(), work_dir)
+ self.logger.info('Hooks will be executed in the following order:\n%s',
+ self.get_hook_info())
+ self.logger.info('workflow: %s, max: %d epochs', workflow,
+ self._max_epochs)
+ self.call_hook('before_run')
+
+ while self.epoch < self._max_epochs:
+ for i, flow in enumerate(workflow):
+ mode, epochs = flow
+ if isinstance(mode, str): # self.train()
+ if not hasattr(self, mode):
+ raise ValueError(
+ f'runner has no method named "{mode}" to run an '
+ 'epoch')
+ epoch_runner = getattr(self, mode)
+ else:
+ raise TypeError(
+ 'mode in workflow must be a str, but got {}'.format(
+ type(mode)))
+
+ for _ in range(epochs):
+ if mode == 'train' and self.epoch >= self._max_epochs:
+ break
+ epoch_runner(data_loaders[i], **kwargs)
+
+ time.sleep(1) # wait for some hooks like loggers to finish
+ self.call_hook('after_run')
+
+ def save_checkpoint(self,
+ out_dir,
+ filename_tmpl='epoch_{}.pth',
+ save_optimizer=True,
+ meta=None,
+ create_symlink=True):
+ """Save the checkpoint.
+
+ Args:
+ out_dir (str): The directory that checkpoints are saved.
+ filename_tmpl (str, optional): The checkpoint filename template,
+ which contains a placeholder for the epoch number.
+ Defaults to 'epoch_{}.pth'.
+ save_optimizer (bool, optional): Whether to save the optimizer to
+ the checkpoint. Defaults to True.
+ meta (dict, optional): The meta information to be saved in the
+ checkpoint. Defaults to None.
+ create_symlink (bool, optional): Whether to create a symlink
+ "latest.pth" to point to the latest checkpoint.
+ Defaults to True.
+ """
+ if meta is None:
+ meta = {}
+ elif not isinstance(meta, dict):
+ raise TypeError(
+ f'meta should be a dict or None, but got {type(meta)}')
+ if self.meta is not None:
+ meta.update(self.meta)
+ # Note: meta.update(self.meta) should be done before
+ # meta.update(epoch=self.epoch + 1, iter=self.iter) otherwise
+ # there will be problems with resumed checkpoints.
+ # More details in https://github.com/open-mmlab/mmcv/pull/1108
+ meta.update(epoch=self.epoch + 1, iter=self.iter)
+
+ filename = filename_tmpl.format(self.epoch + 1)
+ filepath = osp.join(out_dir, filename)
+ optimizer = self.optimizer if save_optimizer else None
+ save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
+ # in some environments, `os.symlink` is not supported, you may need to
+ # set `create_symlink` to False
+ if create_symlink:
+ dst_file = osp.join(out_dir, 'latest.pth')
+ if platform.system() != 'Windows':
+ mmcv.symlink(filename, dst_file)
+ else:
+ shutil.copy(filepath, dst_file)
+
+
+@RUNNERS.register_module()
+class Runner(EpochBasedRunner):
+ """Deprecated name of EpochBasedRunner."""
+
+ def __init__(self, *args, **kwargs):
+ warnings.warn(
+ 'Runner was deprecated, please use EpochBasedRunner instead')
+ super().__init__(*args, **kwargs)
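
A sketch of wiring up an ``EpochBasedRunner`` directly; ``model``, ``optimizer``, ``logger`` and the two dataloaders are assumed to be constructed elsewhere:

    >>> runner = EpochBasedRunner(
    ...     model, optimizer=optimizer, work_dir='./work_dir',
    ...     logger=logger, max_epochs=12)
    >>> # one training epoch followed by one validation epoch, until max_epochs
    >>> runner.run([train_loader, val_loader], [('train', 1), ('val', 1)])
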
diff --git a/annotator/uniformer/mmcv/runner/fp16_utils.py b/annotator/uniformer/mmcv/runner/fp16_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1981011d6859192e3e663e29d13500d56ba47f6c
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/fp16_utils.py
@@ -0,0 +1,410 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import functools
+import warnings
+from collections import abc
+from inspect import getfullargspec
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from annotator.uniformer.mmcv.utils import TORCH_VERSION, digit_version
+from .dist_utils import allreduce_grads as _allreduce_grads
+
+try:
+ # If PyTorch version >= 1.6.0, torch.cuda.amp.autocast would be imported
+ # and used; otherwise, auto fp16 will adopt mmcv's implementation.
+ # Note that when PyTorch >= 1.6.0, we still cast tensor types to fp16
+ # manually, so the behavior may not be consistent with real amp.
+ from torch.cuda.amp import autocast
+except ImportError:
+ pass
+
+
+def cast_tensor_type(inputs, src_type, dst_type):
+ """Recursively convert Tensor in inputs from src_type to dst_type.
+
+ Args:
+ inputs: Inputs to be cast.
+ src_type (torch.dtype): Source type.
+ dst_type (torch.dtype): Destination type.
+
+ Returns:
+ The same type as inputs, with all contained tensors cast to dst_type.
+ """
+ if isinstance(inputs, nn.Module):
+ return inputs
+ elif isinstance(inputs, torch.Tensor):
+ return inputs.to(dst_type)
+ elif isinstance(inputs, str):
+ return inputs
+ elif isinstance(inputs, np.ndarray):
+ return inputs
+ elif isinstance(inputs, abc.Mapping):
+ return type(inputs)({
+ k: cast_tensor_type(v, src_type, dst_type)
+ for k, v in inputs.items()
+ })
+ elif isinstance(inputs, abc.Iterable):
+ return type(inputs)(
+ cast_tensor_type(item, src_type, dst_type) for item in inputs)
+ else:
+ return inputs
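
A runnable illustration of the recursive casting (mappings are rebuilt, non-tensor leaves pass through unchanged):

    >>> import torch
    >>> batch = {'img': torch.ones(1, 3, 4, 4), 'label': 3, 'name': 'cat.png'}
    >>> out = cast_tensor_type(batch, torch.float, torch.half)
    >>> out['img'].dtype
    torch.float16
    >>> (out['label'], out['name'])
    (3, 'cat.png')
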
+
+
+def auto_fp16(apply_to=None, out_fp32=False):
+ """Decorator to enable fp16 training automatically.
+
+ This decorator is useful when you write custom modules and want to support
+ mixed precision training. If input arguments are fp32 tensors, they will
+ be converted to fp16 automatically. Arguments other than fp32 tensors are
+ ignored. If you are using PyTorch >= 1.6, torch.cuda.amp is used as the
+ backend, otherwise, original mmcv implementation will be adopted.
+
+ Args:
+ apply_to (Iterable, optional): The argument names to be converted.
+ `None` indicates all arguments.
+ out_fp32 (bool): Whether to convert the output back to fp32.
+
+ Example:
+
+ >>> import torch.nn as nn
+ >>> class MyModule1(nn.Module):
+ >>>
+ >>> # Convert x and y to fp16
+ >>> @auto_fp16()
+ >>> def forward(self, x, y):
+ >>> pass
+
+ >>> import torch.nn as nn
+ >>> class MyModule2(nn.Module):
+ >>>
+ >>> # convert pred to fp16
+ >>> @auto_fp16(apply_to=('pred', ))
+ >>> def do_something(self, pred, others):
+ >>> pass
+ """
+
+ def auto_fp16_wrapper(old_func):
+
+ @functools.wraps(old_func)
+ def new_func(*args, **kwargs):
+ # check if the module has set the attribute `fp16_enabled`; if not,
+ # just fall back to the original method.
+ if not isinstance(args[0], torch.nn.Module):
+ raise TypeError('@auto_fp16 can only be used to decorate the '
+ 'method of nn.Module')
+ if not (hasattr(args[0], 'fp16_enabled') and args[0].fp16_enabled):
+ return old_func(*args, **kwargs)
+
+ # get the arg spec of the decorated method
+ args_info = getfullargspec(old_func)
+ # get the argument names to be casted
+ args_to_cast = args_info.args if apply_to is None else apply_to
+ # convert the args that need to be processed
+ new_args = []
+ # NOTE: default args are not taken into consideration
+ if args:
+ arg_names = args_info.args[:len(args)]
+ for i, arg_name in enumerate(arg_names):
+ if arg_name in args_to_cast:
+ new_args.append(
+ cast_tensor_type(args[i], torch.float, torch.half))
+ else:
+ new_args.append(args[i])
+ # convert the kwargs that need to be processed
+ new_kwargs = {}
+ if kwargs:
+ for arg_name, arg_value in kwargs.items():
+ if arg_name in args_to_cast:
+ new_kwargs[arg_name] = cast_tensor_type(
+ arg_value, torch.float, torch.half)
+ else:
+ new_kwargs[arg_name] = arg_value
+ # apply converted arguments to the decorated method
+ if (TORCH_VERSION != 'parrots' and
+ digit_version(TORCH_VERSION) >= digit_version('1.6.0')):
+ with autocast(enabled=True):
+ output = old_func(*new_args, **new_kwargs)
+ else:
+ output = old_func(*new_args, **new_kwargs)
+ # cast the results back to fp32 if necessary
+ if out_fp32:
+ output = cast_tensor_type(output, torch.half, torch.float)
+ return output
+
+ return new_func
+
+ return auto_fp16_wrapper
+
+
+def force_fp32(apply_to=None, out_fp16=False):
+ """Decorator to convert input arguments to fp32 in force.
+
+ This decorator is useful when you write custom modules and want to support
+ mixed precision training. If there are some inputs that must be processed
+ in fp32 mode, then this decorator can handle it. If input arguments are
+ fp16 tensors, they will be converted to fp32 automatically. Arguments other
+ than fp16 tensors are ignored. If you are using PyTorch >= 1.6,
+ torch.cuda.amp is used as the backend, otherwise, original mmcv
+ implementation will be adopted.
+
+ Args:
+ apply_to (Iterable, optional): The argument names to be converted.
+ `None` indicates all arguments.
+ out_fp16 (bool): Whether to convert the output back to fp16.
+
+ Example:
+
+ >>> import torch.nn as nn
+ >>> class MyModule1(nn.Module):
+ >>>
+ >>> # Convert x and y to fp32
+ >>> @force_fp32()
+ >>> def loss(self, x, y):
+ >>> pass
+
+ >>> import torch.nn as nn
+ >>> class MyModule2(nn.Module):
+ >>>
+ >>> # convert pred to fp32
+ >>> @force_fp32(apply_to=('pred', ))
+ >>> def post_process(self, pred, others):
+ >>> pass
+ """
+
+ def force_fp32_wrapper(old_func):
+
+ @functools.wraps(old_func)
+ def new_func(*args, **kwargs):
+ # check if the module has set the attribute `fp16_enabled`; if not,
+ # just fall back to the original method.
+ if not isinstance(args[0], torch.nn.Module):
+ raise TypeError('@force_fp32 can only be used to decorate the '
+ 'method of nn.Module')
+ if not (hasattr(args[0], 'fp16_enabled') and args[0].fp16_enabled):
+ return old_func(*args, **kwargs)
+ # get the arg spec of the decorated method
+ args_info = getfullargspec(old_func)
+ # get the argument names to be casted
+ args_to_cast = args_info.args if apply_to is None else apply_to
+ # convert the args that need to be processed
+ new_args = []
+ if args:
+ arg_names = args_info.args[:len(args)]
+ for i, arg_name in enumerate(arg_names):
+ if arg_name in args_to_cast:
+ new_args.append(
+ cast_tensor_type(args[i], torch.half, torch.float))
+ else:
+ new_args.append(args[i])
+ # convert the kwargs that need to be processed
+ new_kwargs = dict()
+ if kwargs:
+ for arg_name, arg_value in kwargs.items():
+ if arg_name in args_to_cast:
+ new_kwargs[arg_name] = cast_tensor_type(
+ arg_value, torch.half, torch.float)
+ else:
+ new_kwargs[arg_name] = arg_value
+ # apply converted arguments to the decorated method
+ if (TORCH_VERSION != 'parrots' and
+ digit_version(TORCH_VERSION) >= digit_version('1.6.0')):
+ with autocast(enabled=False):
+ output = old_func(*new_args, **new_kwargs)
+ else:
+ output = old_func(*new_args, **new_kwargs)
+ # cast the results back to fp16 if necessary
+ if out_fp16:
+ output = cast_tensor_type(output, torch.float, torch.half)
+ return output
+
+ return new_func
+
+ return force_fp32_wrapper
+
+
+def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
+ warnings.warn(
+ '"mmcv.runner.fp16_utils.allreduce_grads" is deprecated, and will be '
+ 'removed in v2.8. Please switch to "mmcv.runner.allreduce_grads".')
+ _allreduce_grads(params, coalesce=coalesce, bucket_size_mb=bucket_size_mb)
+
+
+def wrap_fp16_model(model):
+ """Wrap the FP32 model to FP16.
+
+ If you are using PyTorch >= 1.6, torch.cuda.amp is used as the
+ backend, otherwise, original mmcv implementation will be adopted.
+
+ For PyTorch >= 1.6, this function will
+ 1. Set fp16 flag inside the model to True.
+
+ Otherwise:
+ 1. Convert FP32 model to FP16.
+ 2. Keep some necessary layers in FP32, e.g., normalization layers.
+ 3. Set `fp16_enabled` flag inside the model to True.
+
+ Args:
+ model (nn.Module): Model in FP32.
+ """
+ if (TORCH_VERSION == 'parrots'
+ or digit_version(TORCH_VERSION) < digit_version('1.6.0')):
+ # convert model to fp16
+ model.half()
+ # patch the normalization layers to make it work in fp32 mode
+ patch_norm_fp32(model)
+ # set `fp16_enabled` flag
+ for m in model.modules():
+ if hasattr(m, 'fp16_enabled'):
+ m.fp16_enabled = True
+
+
+def patch_norm_fp32(module):
+ """Recursively convert normalization layers from FP16 to FP32.
+
+ Args:
+ module (nn.Module): The FP16 module whose normalization layers will be converted to FP32.
+
+ Returns:
+ nn.Module: The converted module, the normalization layers have been
+ converted to FP32.
+ """
+ if isinstance(module, (nn.modules.batchnorm._BatchNorm, nn.GroupNorm)):
+ module.float()
+ if isinstance(module, nn.GroupNorm) or torch.__version__ < '1.3':
+ module.forward = patch_forward_method(module.forward, torch.half,
+ torch.float)
+ for child in module.children():
+ patch_norm_fp32(child)
+ return module
+
+
+def patch_forward_method(func, src_type, dst_type, convert_output=True):
+ """Patch the forward method of a module.
+
+ Args:
+ func (callable): The original forward method.
+ src_type (torch.dtype): Type of input arguments to be converted from.
+ dst_type (torch.dtype): Type of input arguments to be converted to.
+ convert_output (bool): Whether to convert the output back to src_type.
+
+ Returns:
+ callable: The patched forward method.
+ """
+
+ def new_forward(*args, **kwargs):
+ output = func(*cast_tensor_type(args, src_type, dst_type),
+ **cast_tensor_type(kwargs, src_type, dst_type))
+ if convert_output:
+ output = cast_tensor_type(output, dst_type, src_type)
+ return output
+
+ return new_forward
+
+
+class LossScaler:
+ """Class that manages loss scaling in mixed precision training which
+ supports both dynamic or static mode.
+
+ The implementation refers to
+ https://github.com/NVIDIA/apex/blob/master/apex/fp16_utils/loss_scaler.py.
+ Dynamic loss scaling is enabled by supplying ``mode='dynamic'``.
+ It's important to understand how :class:`LossScaler` operates.
+ Loss scaling is designed to combat the problem of gradients underflowing
+ late in the training of fp16 networks.
+ Dynamic loss scaling begins by attempting a very high loss
+ scale. Ironically, this may result in OVERflowing gradients.
+ If overflowing gradients are encountered, :class:`Fp16OptimizerHook` then
+ skips the update step for this particular iteration/minibatch,
+ and :class:`LossScaler` adjusts the loss scale to a lower value.
+ If a certain number of iterations occur without overflowing gradients
+ being detected, :class:`LossScaler` increases the loss scale once more.
+ In this way :class:`LossScaler` attempts to "ride the edge" of always
+ using the highest loss scale possible without incurring overflow.
+
+ Args:
+ init_scale (float): Initial loss scale value, default: 2**32.
+ scale_factor (float): Factor used when adjusting the loss scale.
+ Default: 2.
+ mode (str): Loss scaling mode, 'dynamic' or 'static'. Default: 'dynamic'.
+ scale_window (int): Number of consecutive iterations without an
+ overflow to wait before increasing the loss scale. Default: 1000.
+ """
+
+ def __init__(self,
+ init_scale=2**32,
+ mode='dynamic',
+ scale_factor=2.,
+ scale_window=1000):
+ self.cur_scale = init_scale
+ self.cur_iter = 0
+ assert mode in ('dynamic',
+ 'static'), 'mode can only be dynamic or static'
+ self.mode = mode
+ self.last_overflow_iter = -1
+ self.scale_factor = scale_factor
+ self.scale_window = scale_window
+
+ def has_overflow(self, params):
+ """Check if params contain overflow."""
+ if self.mode != 'dynamic':
+ return False
+ for p in params:
+ if p.grad is not None and LossScaler._has_inf_or_nan(p.grad.data):
+ return True
+ return False
+
+ @staticmethod
+ def _has_inf_or_nan(x):
+ """Check whether ``x`` contains inf or NaN values."""
+ try:
+ cpu_sum = float(x.float().sum())
+ except RuntimeError as instance:
+ if 'value cannot be converted' not in instance.args[0]:
+ raise
+ return True
+ else:
+ if cpu_sum == float('inf') or cpu_sum == -float('inf') \
+ or cpu_sum != cpu_sum:
+ return True
+ return False
+
+ def update_scale(self, overflow):
+ """update the current loss scale value when overflow happens."""
+ if self.mode != 'dynamic':
+ return
+ if overflow:
+ self.cur_scale = max(self.cur_scale / self.scale_factor, 1)
+ self.last_overflow_iter = self.cur_iter
+ else:
+ if (self.cur_iter - self.last_overflow_iter) % \
+ self.scale_window == 0:
+ self.cur_scale *= self.scale_factor
+ self.cur_iter += 1
+
+ def state_dict(self):
+ """Returns the state of the scaler as a :class:`dict`."""
+ return dict(
+ cur_scale=self.cur_scale,
+ cur_iter=self.cur_iter,
+ mode=self.mode,
+ last_overflow_iter=self.last_overflow_iter,
+ scale_factor=self.scale_factor,
+ scale_window=self.scale_window)
+
+ def load_state_dict(self, state_dict):
+ """Loads the loss_scaler state dict.
+
+ Args:
+ state_dict (dict): scaler state.
+ """
+ self.cur_scale = state_dict['cur_scale']
+ self.cur_iter = state_dict['cur_iter']
+ self.mode = state_dict['mode']
+ self.last_overflow_iter = state_dict['last_overflow_iter']
+ self.scale_factor = state_dict['scale_factor']
+ self.scale_window = state_dict['scale_window']
+
+ @property
+ def loss_scale(self):
+ return self.cur_scale
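
A small walk-through of the dynamic policy; the numbers follow directly from ``update_scale`` above:

    >>> scaler = LossScaler(init_scale=2.**16, mode='dynamic', scale_window=2)
    >>> scaler.update_scale(overflow=True)    # overflow: the scale is halved
    >>> scaler.loss_scale
    32768.0
    >>> scaler.update_scale(overflow=False)
    >>> scaler.update_scale(overflow=False)   # scale_window clean iters: doubled
    >>> scaler.loss_scale
    65536.0
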
diff --git a/annotator/uniformer/mmcv/runner/hooks/__init__.py b/annotator/uniformer/mmcv/runner/hooks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..915af28cefab14a14c1188ed861161080fd138a3
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/hooks/__init__.py
@@ -0,0 +1,29 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .checkpoint import CheckpointHook
+from .closure import ClosureHook
+from .ema import EMAHook
+from .evaluation import DistEvalHook, EvalHook
+from .hook import HOOKS, Hook
+from .iter_timer import IterTimerHook
+from .logger import (DvcliveLoggerHook, LoggerHook, MlflowLoggerHook,
+ NeptuneLoggerHook, PaviLoggerHook, TensorboardLoggerHook,
+ TextLoggerHook, WandbLoggerHook)
+from .lr_updater import LrUpdaterHook
+from .memory import EmptyCacheHook
+from .momentum_updater import MomentumUpdaterHook
+from .optimizer import (Fp16OptimizerHook, GradientCumulativeFp16OptimizerHook,
+ GradientCumulativeOptimizerHook, OptimizerHook)
+from .profiler import ProfilerHook
+from .sampler_seed import DistSamplerSeedHook
+from .sync_buffer import SyncBuffersHook
+
+__all__ = [
+ 'HOOKS', 'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook',
+ 'OptimizerHook', 'Fp16OptimizerHook', 'IterTimerHook',
+ 'DistSamplerSeedHook', 'EmptyCacheHook', 'LoggerHook', 'MlflowLoggerHook',
+ 'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook',
+ 'NeptuneLoggerHook', 'WandbLoggerHook', 'DvcliveLoggerHook',
+ 'MomentumUpdaterHook', 'SyncBuffersHook', 'EMAHook', 'EvalHook',
+ 'DistEvalHook', 'ProfilerHook', 'GradientCumulativeOptimizerHook',
+ 'GradientCumulativeFp16OptimizerHook'
+]
diff --git a/annotator/uniformer/mmcv/runner/hooks/checkpoint.py b/annotator/uniformer/mmcv/runner/hooks/checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..6af3fae43ac4b35532641a81eb13557edfc7dfba
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/hooks/checkpoint.py
@@ -0,0 +1,167 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import warnings
+
+from annotator.uniformer.mmcv.fileio import FileClient
+from ..dist_utils import allreduce_params, master_only
+from .hook import HOOKS, Hook
+
+
+@HOOKS.register_module()
+class CheckpointHook(Hook):
+ """Save checkpoints periodically.
+
+ Args:
+ interval (int): The saving period. If ``by_epoch=True``, interval
+ indicates epochs, otherwise it indicates iterations.
+ Default: -1, which means "never".
+ by_epoch (bool): Saving checkpoints by epoch or by iteration.
+ Default: True.
+ save_optimizer (bool): Whether to save optimizer state_dict in the
+ checkpoint. It is usually used for resuming experiments.
+ Default: True.
+ out_dir (str, optional): The root directory to save checkpoints. If not
+ specified, ``runner.work_dir`` will be used by default. If
+ specified, the ``out_dir`` will be the concatenation of ``out_dir``
+ and the last level directory of ``runner.work_dir``.
+ `Changed in version 1.3.16.`
+ max_keep_ckpts (int, optional): The maximum checkpoints to keep.
+ In some cases we want only the latest few checkpoints and would
+ like to delete old ones to save the disk space.
+ Default: -1, which means unlimited.
+ save_last (bool, optional): Whether to force the last checkpoint to be
+ saved regardless of interval. Default: True.
+ sync_buffer (bool, optional): Whether to synchronize buffers in
+ different gpus. Default: False.
+ file_client_args (dict, optional): Arguments to instantiate a
+ FileClient. See :class:`mmcv.fileio.FileClient` for details.
+ Default: None.
+ `New in version 1.3.16.`
+
+ .. warning::
+ Before v1.3.16, the ``out_dir`` argument indicates the path where the
+ checkpoint is stored. However, since v1.3.16, ``out_dir`` indicates the
+ root directory and the final path to save checkpoint is the
+ concatenation of ``out_dir`` and the last level directory of
+ ``runner.work_dir``. Suppose the value of ``out_dir`` is "/path/of/A"
+ and the value of ``runner.work_dir`` is "/path/of/B", then the final
+ path will be "/path/of/A/B".
+ """
+
+ def __init__(self,
+ interval=-1,
+ by_epoch=True,
+ save_optimizer=True,
+ out_dir=None,
+ max_keep_ckpts=-1,
+ save_last=True,
+ sync_buffer=False,
+ file_client_args=None,
+ **kwargs):
+ self.interval = interval
+ self.by_epoch = by_epoch
+ self.save_optimizer = save_optimizer
+ self.out_dir = out_dir
+ self.max_keep_ckpts = max_keep_ckpts
+ self.save_last = save_last
+ self.args = kwargs
+ self.sync_buffer = sync_buffer
+ self.file_client_args = file_client_args
+
+ def before_run(self, runner):
+ if not self.out_dir:
+ self.out_dir = runner.work_dir
+
+ self.file_client = FileClient.infer_client(self.file_client_args,
+ self.out_dir)
+
+ # if `self.out_dir` is not equal to `runner.work_dir`, it means that
+ # `self.out_dir` is set so the final `self.out_dir` is the
+ # concatenation of `self.out_dir` and the last level directory of
+ # `runner.work_dir`
+ if self.out_dir != runner.work_dir:
+ basename = osp.basename(runner.work_dir.rstrip(osp.sep))
+ self.out_dir = self.file_client.join_path(self.out_dir, basename)
+
+ runner.logger.info((f'Checkpoints will be saved to {self.out_dir} by '
+ f'{self.file_client.name}.'))
+
+ # disable the create_symlink option because some file backends do not
+ # allow to create a symlink
+ if 'create_symlink' in self.args:
+ if self.args[
+ 'create_symlink'] and not self.file_client.allow_symlink:
+ self.args['create_symlink'] = False
+ warnings.warn(
+ ('create_symlink is set as True by the user but is changed '
+ 'to be False because creating symbolic link is not '
+ f'allowed in {self.file_client.name}'))
+ else:
+ self.args['create_symlink'] = self.file_client.allow_symlink
+
+ def after_train_epoch(self, runner):
+ if not self.by_epoch:
+ return
+
+ # save checkpoint for following cases:
+ # 1. every ``self.interval`` epochs
+ # 2. reach the last epoch of training
+ if self.every_n_epochs(
+ runner, self.interval) or (self.save_last
+ and self.is_last_epoch(runner)):
+ runner.logger.info(
+ f'Saving checkpoint at {runner.epoch + 1} epochs')
+ if self.sync_buffer:
+ allreduce_params(runner.model.buffers())
+ self._save_checkpoint(runner)
+
+ @master_only
+ def _save_checkpoint(self, runner):
+ """Save the current checkpoint and delete unwanted checkpoint."""
+ runner.save_checkpoint(
+ self.out_dir, save_optimizer=self.save_optimizer, **self.args)
+ if runner.meta is not None:
+ if self.by_epoch:
+ cur_ckpt_filename = self.args.get(
+ 'filename_tmpl', 'epoch_{}.pth').format(runner.epoch + 1)
+ else:
+ cur_ckpt_filename = self.args.get(
+ 'filename_tmpl', 'iter_{}.pth').format(runner.iter + 1)
+ runner.meta.setdefault('hook_msgs', dict())
+ runner.meta['hook_msgs']['last_ckpt'] = self.file_client.join_path(
+ self.out_dir, cur_ckpt_filename)
+ # remove other checkpoints
+ if self.max_keep_ckpts > 0:
+ if self.by_epoch:
+ name = 'epoch_{}.pth'
+ current_ckpt = runner.epoch + 1
+ else:
+ name = 'iter_{}.pth'
+ current_ckpt = runner.iter + 1
+ redundant_ckpts = range(
+ current_ckpt - self.max_keep_ckpts * self.interval, 0,
+ -self.interval)
+ filename_tmpl = self.args.get('filename_tmpl', name)
+ for _step in redundant_ckpts:
+ ckpt_path = self.file_client.join_path(
+ self.out_dir, filename_tmpl.format(_step))
+ if self.file_client.isfile(ckpt_path):
+ self.file_client.remove(ckpt_path)
+ else:
+ break
+
+ def after_train_iter(self, runner):
+ if self.by_epoch:
+ return
+
+ # save checkpoint for following cases:
+ # 1. every ``self.interval`` iterations
+ # 2. reach the last iteration of training
+ if self.every_n_iters(
+ runner, self.interval) or (self.save_last
+ and self.is_last_iter(runner)):
+ runner.logger.info(
+ f'Saving checkpoint at {runner.iter + 1} iterations')
+ if self.sync_buffer:
+ allreduce_params(runner.model.buffers())
+ self._save_checkpoint(runner)
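
In config-driven training this hook is usually enabled through a ``checkpoint_config`` dict; the values below are illustrative, and the hook can equally be constructed directly:

    >>> checkpoint_config = dict(interval=1, max_keep_ckpts=3)  # keep last 3 epochs
    >>> hook = CheckpointHook(interval=1, max_keep_ckpts=3, save_optimizer=True)
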
diff --git a/annotator/uniformer/mmcv/runner/hooks/closure.py b/annotator/uniformer/mmcv/runner/hooks/closure.py
new file mode 100644
index 0000000000000000000000000000000000000000..b955f81f425be4ac3e6bb3f4aac653887989e872
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/hooks/closure.py
@@ -0,0 +1,11 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .hook import HOOKS, Hook
+
+
+@HOOKS.register_module()
+class ClosureHook(Hook):
+
+ def __init__(self, fn_name, fn):
+ assert hasattr(self, fn_name)
+ assert callable(fn)
+ setattr(self, fn_name, fn)
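
A minimal sketch of plugging an ad-hoc callback into a hook stage; the callback body is illustrative:

    >>> def dump_lr(runner):
    ...     runner.logger.info(f'current lr: {runner.current_lr()}')
    >>> hook = ClosureHook('after_train_epoch', dump_lr)
    >>> # runner.register_hook(hook) would then run dump_lr after every training epoch
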
diff --git a/annotator/uniformer/mmcv/runner/hooks/ema.py b/annotator/uniformer/mmcv/runner/hooks/ema.py
new file mode 100644
index 0000000000000000000000000000000000000000..15c7e68088f019802a59e7ae41cc1fe0c7f28f96
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/hooks/ema.py
@@ -0,0 +1,89 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ...parallel import is_module_wrapper
+from ..hooks.hook import HOOKS, Hook
+
+
+@HOOKS.register_module()
+class EMAHook(Hook):
+ r"""Exponential Moving Average Hook.
+
+    Use Exponential Moving Average on all parameters of the model during
+    training. Every parameter has an ema backup, which is updated by the
+    formula below. EMAHook takes priority over EvalHook and
+    CheckpointSaverHook.
+
+    .. math::
+
+        \text{Xema}_{t+1} = (1 - \text{momentum}) \times
+            \text{Xema}_{t} + \text{momentum} \times X_t
+
+ Args:
+ momentum (float): The momentum used for updating ema parameter.
+ Defaults to 0.0002.
+ interval (int): Update ema parameter every interval iteration.
+ Defaults to 1.
+        warm_up (int): During the first warm_up steps, a smaller momentum is
+            used to update the ema parameters more slowly. Defaults to 100.
+ resume_from (str): The checkpoint path. Defaults to None.
+ """
+
+ def __init__(self,
+ momentum=0.0002,
+ interval=1,
+ warm_up=100,
+ resume_from=None):
+ assert isinstance(interval, int) and interval > 0
+ self.warm_up = warm_up
+ self.interval = interval
+ assert momentum > 0 and momentum < 1
+ self.momentum = momentum**interval
+ self.checkpoint = resume_from
+
+ def before_run(self, runner):
+ """To resume model with it's ema parameters more friendly.
+
+ Register ema parameter as ``named_buffer`` to model
+ """
+ model = runner.model
+ if is_module_wrapper(model):
+ model = model.module
+ self.param_ema_buffer = {}
+ self.model_parameters = dict(model.named_parameters(recurse=True))
+ for name, value in self.model_parameters.items():
+ # "." is not allowed in module's buffer name
+ buffer_name = f"ema_{name.replace('.', '_')}"
+ self.param_ema_buffer[name] = buffer_name
+ model.register_buffer(buffer_name, value.data.clone())
+ self.model_buffers = dict(model.named_buffers(recurse=True))
+ if self.checkpoint is not None:
+ runner.resume(self.checkpoint)
+
+ def after_train_iter(self, runner):
+ """Update ema parameter every self.interval iterations."""
+ curr_step = runner.iter
+        # Warm up the momentum to account for instability at the beginning
+ momentum = min(self.momentum,
+ (1 + curr_step) / (self.warm_up + curr_step))
+ if curr_step % self.interval != 0:
+ return
+ for name, parameter in self.model_parameters.items():
+ buffer_name = self.param_ema_buffer[name]
+ buffer_parameter = self.model_buffers[buffer_name]
+ buffer_parameter.mul_(1 - momentum).add_(momentum, parameter.data)
+
+ def after_train_epoch(self, runner):
+ """We load parameter values from ema backup to model before the
+ EvalHook."""
+ self._swap_ema_parameters()
+
+ def before_train_epoch(self, runner):
+ """We recover model's parameter from ema backup after last epoch's
+ EvalHook."""
+ self._swap_ema_parameters()
+
+ def _swap_ema_parameters(self):
+ """Swap the parameter of model with parameter in ema_buffer."""
+ for name, value in self.model_parameters.items():
+ temp = value.data.clone()
+ ema_buffer = self.model_buffers[self.param_ema_buffer[name]]
+ value.data.copy_(ema_buffer.data)
+ ema_buffer.data.copy_(temp)
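
A quick numeric sanity check of the update rule used in `after_train_iter` (illustrative only, plain tensors instead of registered buffers; the values are made up):

    import torch

    # Same formula as EMAHook: ema <- (1 - momentum) * ema + momentum * param
    momentum = 0.0002
    ema = torch.tensor([1.0, 2.0])        # buffer initialised from the param
    new_param = torch.tensor([3.0, 4.0])  # pretend one optimiser step happened

    ema.mul_(1 - momentum).add_(momentum * new_param)
    print(ema)  # tensor([1.0004, 2.0004]) -- the ema drifts slowly toward the param
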
diff --git a/annotator/uniformer/mmcv/runner/hooks/evaluation.py b/annotator/uniformer/mmcv/runner/hooks/evaluation.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d00999ce5665c53bded8de9e084943eee2d230d
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/hooks/evaluation.py
@@ -0,0 +1,509 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import warnings
+from math import inf
+
+import torch.distributed as dist
+from torch.nn.modules.batchnorm import _BatchNorm
+from torch.utils.data import DataLoader
+
+from annotator.uniformer.mmcv.fileio import FileClient
+from annotator.uniformer.mmcv.utils import is_seq_of
+from .hook import Hook
+from .logger import LoggerHook
+
+
+class EvalHook(Hook):
+ """Non-Distributed evaluation hook.
+
+ This hook will regularly perform evaluation in a given interval when
+ performing in non-distributed environment.
+
+ Args:
+ dataloader (DataLoader): A PyTorch dataloader, whose dataset has
+ implemented ``evaluate`` function.
+ start (int | None, optional): Evaluation starting epoch. It enables
+ evaluation before the training starts if ``start`` <= the resuming
+ epoch. If None, whether to evaluate is merely decided by
+ ``interval``. Default: None.
+ interval (int): Evaluation interval. Default: 1.
+        by_epoch (bool): Determines whether to perform evaluation by epoch or
+            by iteration. If set to True, it will perform by epoch.
+            Otherwise, by iteration. Default: True.
+ save_best (str, optional): If a metric is specified, it would measure
+ the best checkpoint during evaluation. The information about best
+ checkpoint would be saved in ``runner.meta['hook_msgs']`` to keep
+ best score value and best checkpoint path, which will be also
+ loaded when resume checkpoint. Options are the evaluation metrics
+ on the test dataset. e.g., ``bbox_mAP``, ``segm_mAP`` for bbox
+ detection and instance segmentation. ``AR@100`` for proposal
+ recall. If ``save_best`` is ``auto``, the first key of the returned
+ ``OrderedDict`` result will be used. Default: None.
+ rule (str | None, optional): Comparison rule for best score. If set to
+            None, it will infer a reasonable rule. Keys such as 'acc', 'top',
+            etc. will be inferred by the 'greater' rule. Keys containing
+            'loss' will be inferred by the 'less' rule. Options are 'greater',
+            'less', None.
+ Default: None.
+ test_fn (callable, optional): test a model with samples from a
+ dataloader, and return the test results. If ``None``, the default
+ test function ``mmcv.engine.single_gpu_test`` will be used.
+ (default: ``None``)
+ greater_keys (List[str] | None, optional): Metric keys that will be
+ inferred by 'greater' comparison rule. If ``None``,
+ _default_greater_keys will be used. (default: ``None``)
+ less_keys (List[str] | None, optional): Metric keys that will be
+ inferred by 'less' comparison rule. If ``None``, _default_less_keys
+ will be used. (default: ``None``)
+ out_dir (str, optional): The root directory to save checkpoints. If not
+ specified, `runner.work_dir` will be used by default. If specified,
+ the `out_dir` will be the concatenation of `out_dir` and the last
+ level directory of `runner.work_dir`.
+ `New in version 1.3.16.`
+ file_client_args (dict): Arguments to instantiate a FileClient.
+ See :class:`mmcv.fileio.FileClient` for details. Default: None.
+ `New in version 1.3.16.`
+ **eval_kwargs: Evaluation arguments fed into the evaluate function of
+ the dataset.
+
+ Notes:
+ If new arguments are added for EvalHook, tools/test.py,
+ tools/eval_metric.py may be affected.
+ """
+
+    # Since the key for determining greater or less is related to the
+    # downstream tasks, downstream repos may need to overwrite the following
+    # inner variables accordingly.
+
+ rule_map = {'greater': lambda x, y: x > y, 'less': lambda x, y: x < y}
+ init_value_map = {'greater': -inf, 'less': inf}
+ _default_greater_keys = [
+ 'acc', 'top', 'AR@', 'auc', 'precision', 'mAP', 'mDice', 'mIoU',
+ 'mAcc', 'aAcc'
+ ]
+ _default_less_keys = ['loss']
+
+ def __init__(self,
+ dataloader,
+ start=None,
+ interval=1,
+ by_epoch=True,
+ save_best=None,
+ rule=None,
+ test_fn=None,
+ greater_keys=None,
+ less_keys=None,
+ out_dir=None,
+ file_client_args=None,
+ **eval_kwargs):
+ if not isinstance(dataloader, DataLoader):
+ raise TypeError(f'dataloader must be a pytorch DataLoader, '
+ f'but got {type(dataloader)}')
+
+ if interval <= 0:
+ raise ValueError(f'interval must be a positive number, '
+ f'but got {interval}')
+
+ assert isinstance(by_epoch, bool), '``by_epoch`` should be a boolean'
+
+ if start is not None and start < 0:
+ raise ValueError(f'The evaluation start epoch {start} is smaller '
+ f'than 0')
+
+ self.dataloader = dataloader
+ self.interval = interval
+ self.start = start
+ self.by_epoch = by_epoch
+
+ assert isinstance(save_best, str) or save_best is None, \
+ '""save_best"" should be a str or None ' \
+ f'rather than {type(save_best)}'
+ self.save_best = save_best
+ self.eval_kwargs = eval_kwargs
+ self.initial_flag = True
+
+ if test_fn is None:
+ from annotator.uniformer.mmcv.engine import single_gpu_test
+ self.test_fn = single_gpu_test
+ else:
+ self.test_fn = test_fn
+
+ if greater_keys is None:
+ self.greater_keys = self._default_greater_keys
+ else:
+ if not isinstance(greater_keys, (list, tuple)):
+ greater_keys = (greater_keys, )
+ assert is_seq_of(greater_keys, str)
+ self.greater_keys = greater_keys
+
+ if less_keys is None:
+ self.less_keys = self._default_less_keys
+ else:
+ if not isinstance(less_keys, (list, tuple)):
+ less_keys = (less_keys, )
+ assert is_seq_of(less_keys, str)
+ self.less_keys = less_keys
+
+ if self.save_best is not None:
+ self.best_ckpt_path = None
+ self._init_rule(rule, self.save_best)
+
+ self.out_dir = out_dir
+ self.file_client_args = file_client_args
+
+ def _init_rule(self, rule, key_indicator):
+ """Initialize rule, key_indicator, comparison_func, and best score.
+
+        Here is the rule to determine which comparison rule is used for the
+        key indicator when the rule is not specified (note that the key
+        indicator matching is case-insensitive):
+        1. If the key indicator is in ``self.greater_keys``, the rule will be
+        specified as 'greater'.
+        2. Or if the key indicator is in ``self.less_keys``, the rule will be
+        specified as 'less'.
+        3. Or if any item in ``self.greater_keys`` is a substring of the key
+        indicator, the rule will be specified as 'greater'.
+        4. Or if any item in ``self.less_keys`` is a substring of the key
+        indicator, the rule will be specified as 'less'.
+
+ Args:
+ rule (str | None): Comparison rule for best score.
+ key_indicator (str | None): Key indicator to determine the
+ comparison rule.
+ """
+ if rule not in self.rule_map and rule is not None:
+ raise KeyError(f'rule must be greater, less or None, '
+ f'but got {rule}.')
+
+ if rule is None:
+ if key_indicator != 'auto':
+ # `_lc` here means we use the lower case of keys for
+ # case-insensitive matching
+ key_indicator_lc = key_indicator.lower()
+ greater_keys = [key.lower() for key in self.greater_keys]
+ less_keys = [key.lower() for key in self.less_keys]
+
+ if key_indicator_lc in greater_keys:
+ rule = 'greater'
+ elif key_indicator_lc in less_keys:
+ rule = 'less'
+ elif any(key in key_indicator_lc for key in greater_keys):
+ rule = 'greater'
+ elif any(key in key_indicator_lc for key in less_keys):
+ rule = 'less'
+ else:
+ raise ValueError(f'Cannot infer the rule for key '
+ f'{key_indicator}, thus a specific rule '
+ f'must be specified.')
+ self.rule = rule
+ self.key_indicator = key_indicator
+ if self.rule is not None:
+ self.compare_func = self.rule_map[self.rule]
+
+ def before_run(self, runner):
+ if not self.out_dir:
+ self.out_dir = runner.work_dir
+
+ self.file_client = FileClient.infer_client(self.file_client_args,
+ self.out_dir)
+
+ # if `self.out_dir` is not equal to `runner.work_dir`, it means that
+ # `self.out_dir` is set so the final `self.out_dir` is the
+ # concatenation of `self.out_dir` and the last level directory of
+ # `runner.work_dir`
+ if self.out_dir != runner.work_dir:
+ basename = osp.basename(runner.work_dir.rstrip(osp.sep))
+ self.out_dir = self.file_client.join_path(self.out_dir, basename)
+ runner.logger.info(
+ (f'The best checkpoint will be saved to {self.out_dir} by '
+ f'{self.file_client.name}'))
+
+ if self.save_best is not None:
+ if runner.meta is None:
+ warnings.warn('runner.meta is None. Creating an empty one.')
+ runner.meta = dict()
+ runner.meta.setdefault('hook_msgs', dict())
+ self.best_ckpt_path = runner.meta['hook_msgs'].get(
+ 'best_ckpt', None)
+
+ def before_train_iter(self, runner):
+ """Evaluate the model only at the start of training by iteration."""
+ if self.by_epoch or not self.initial_flag:
+ return
+ if self.start is not None and runner.iter >= self.start:
+ self.after_train_iter(runner)
+ self.initial_flag = False
+
+ def before_train_epoch(self, runner):
+ """Evaluate the model only at the start of training by epoch."""
+ if not (self.by_epoch and self.initial_flag):
+ return
+ if self.start is not None and runner.epoch >= self.start:
+ self.after_train_epoch(runner)
+ self.initial_flag = False
+
+ def after_train_iter(self, runner):
+ """Called after every training iter to evaluate the results."""
+ if not self.by_epoch and self._should_evaluate(runner):
+ # Because the priority of EvalHook is higher than LoggerHook, the
+ # training log and the evaluating log are mixed. Therefore,
+ # we need to dump the training log and clear it before evaluating
+ # log is generated. In addition, this problem will only appear in
+ # `IterBasedRunner` whose `self.by_epoch` is False, because
+ # `EpochBasedRunner` whose `self.by_epoch` is True calls
+ # `_do_evaluate` in `after_train_epoch` stage, and at this stage
+ # the training log has been printed, so it will not cause any
+ # problem. more details at
+ # https://github.com/open-mmlab/mmsegmentation/issues/694
+ for hook in runner._hooks:
+ if isinstance(hook, LoggerHook):
+ hook.after_train_iter(runner)
+ runner.log_buffer.clear()
+
+ self._do_evaluate(runner)
+
+ def after_train_epoch(self, runner):
+ """Called after every training epoch to evaluate the results."""
+ if self.by_epoch and self._should_evaluate(runner):
+ self._do_evaluate(runner)
+
+ def _do_evaluate(self, runner):
+ """perform evaluation and save ckpt."""
+ results = self.test_fn(runner.model, self.dataloader)
+ runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
+ key_score = self.evaluate(runner, results)
+ # the key_score may be `None` so it needs to skip the action to save
+ # the best checkpoint
+ if self.save_best and key_score:
+ self._save_ckpt(runner, key_score)
+
+ def _should_evaluate(self, runner):
+ """Judge whether to perform evaluation.
+
+        Here is the rule to judge whether to perform evaluation:
+        1. It will not perform evaluation during the epoch/iteration interval,
+           which is determined by ``self.interval``.
+        2. It will not perform evaluation if the start time is larger than the
+           current time.
+        3. It will not perform evaluation when the current time is larger than
+           the start time but is not at the evaluation interval.
+
+ Returns:
+ bool: The flag indicating whether to perform evaluation.
+ """
+ if self.by_epoch:
+ current = runner.epoch
+ check_time = self.every_n_epochs
+ else:
+ current = runner.iter
+ check_time = self.every_n_iters
+
+ if self.start is None:
+ if not check_time(runner, self.interval):
+ # No evaluation during the interval.
+ return False
+ elif (current + 1) < self.start:
+ # No evaluation if start is larger than the current time.
+ return False
+ else:
+ # Evaluation only at epochs/iters 3, 5, 7...
+ # if start==3 and interval==2
+ if (current + 1 - self.start) % self.interval:
+ return False
+ return True
+
+ def _save_ckpt(self, runner, key_score):
+ """Save the best checkpoint.
+
+ It will compare the score according to the compare function, write
+ related information (best score, best checkpoint path) and save the
+ best checkpoint into ``work_dir``.
+ """
+ if self.by_epoch:
+ current = f'epoch_{runner.epoch + 1}'
+ cur_type, cur_time = 'epoch', runner.epoch + 1
+ else:
+ current = f'iter_{runner.iter + 1}'
+ cur_type, cur_time = 'iter', runner.iter + 1
+
+ best_score = runner.meta['hook_msgs'].get(
+ 'best_score', self.init_value_map[self.rule])
+ if self.compare_func(key_score, best_score):
+ best_score = key_score
+ runner.meta['hook_msgs']['best_score'] = best_score
+
+ if self.best_ckpt_path and self.file_client.isfile(
+ self.best_ckpt_path):
+ self.file_client.remove(self.best_ckpt_path)
+ runner.logger.info(
+ (f'The previous best checkpoint {self.best_ckpt_path} was '
+ 'removed'))
+
+ best_ckpt_name = f'best_{self.key_indicator}_{current}.pth'
+ self.best_ckpt_path = self.file_client.join_path(
+ self.out_dir, best_ckpt_name)
+ runner.meta['hook_msgs']['best_ckpt'] = self.best_ckpt_path
+
+ runner.save_checkpoint(
+ self.out_dir, best_ckpt_name, create_symlink=False)
+ runner.logger.info(
+ f'Now best checkpoint is saved as {best_ckpt_name}.')
+ runner.logger.info(
+ f'Best {self.key_indicator} is {best_score:0.4f} '
+ f'at {cur_time} {cur_type}.')
+
+ def evaluate(self, runner, results):
+ """Evaluate the results.
+
+ Args:
+            runner (:obj:`mmcv.Runner`): The underlying training runner.
+ results (list): Output results.
+ """
+ eval_res = self.dataloader.dataset.evaluate(
+ results, logger=runner.logger, **self.eval_kwargs)
+
+ for name, val in eval_res.items():
+ runner.log_buffer.output[name] = val
+ runner.log_buffer.ready = True
+
+ if self.save_best is not None:
+            # If the performance of the model is poor, `eval_res` may be an
+ # empty dict and it will raise exception when `self.save_best` is
+ # not None. More details at
+ # https://github.com/open-mmlab/mmdetection/issues/6265.
+ if not eval_res:
+ warnings.warn(
+ 'Since `eval_res` is an empty dict, the behavior to save '
+ 'the best checkpoint will be skipped in this evaluation.')
+ return None
+
+ if self.key_indicator == 'auto':
+ # infer from eval_results
+ self._init_rule(self.rule, list(eval_res.keys())[0])
+ return eval_res[self.key_indicator]
+
+ return None
+
+
+class DistEvalHook(EvalHook):
+ """Distributed evaluation hook.
+
+ This hook will regularly perform evaluation in a given interval when
+ performing in distributed environment.
+
+ Args:
+ dataloader (DataLoader): A PyTorch dataloader, whose dataset has
+ implemented ``evaluate`` function.
+ start (int | None, optional): Evaluation starting epoch. It enables
+ evaluation before the training starts if ``start`` <= the resuming
+ epoch. If None, whether to evaluate is merely decided by
+ ``interval``. Default: None.
+ interval (int): Evaluation interval. Default: 1.
+        by_epoch (bool): Determines whether to perform evaluation by epoch or
+            by iteration. If set to True, it will perform by epoch.
+            Otherwise, by iteration. Default: True.
+ save_best (str, optional): If a metric is specified, it would measure
+ the best checkpoint during evaluation. The information about best
+ checkpoint would be saved in ``runner.meta['hook_msgs']`` to keep
+ best score value and best checkpoint path, which will be also
+ loaded when resume checkpoint. Options are the evaluation metrics
+ on the test dataset. e.g., ``bbox_mAP``, ``segm_mAP`` for bbox
+ detection and instance segmentation. ``AR@100`` for proposal
+ recall. If ``save_best`` is ``auto``, the first key of the returned
+ ``OrderedDict`` result will be used. Default: None.
+ rule (str | None, optional): Comparison rule for best score. If set to
+            None, it will infer a reasonable rule. Keys such as 'acc', 'top',
+            etc. will be inferred by the 'greater' rule. Keys containing
+            'loss' will be inferred by the 'less' rule. Options are 'greater',
+            'less', None.
+ Default: None.
+ test_fn (callable, optional): test a model with samples from a
+ dataloader in a multi-gpu manner, and return the test results. If
+ ``None``, the default test function ``mmcv.engine.multi_gpu_test``
+ will be used. (default: ``None``)
+ tmpdir (str | None): Temporary directory to save the results of all
+ processes. Default: None.
+ gpu_collect (bool): Whether to use gpu or cpu to collect results.
+ Default: False.
+ broadcast_bn_buffer (bool): Whether to broadcast the
+            buffers (running_mean and running_var) of rank 0 to other ranks
+            before evaluation. Default: True.
+ out_dir (str, optional): The root directory to save checkpoints. If not
+ specified, `runner.work_dir` will be used by default. If specified,
+ the `out_dir` will be the concatenation of `out_dir` and the last
+ level directory of `runner.work_dir`.
+ file_client_args (dict): Arguments to instantiate a FileClient.
+ See :class:`mmcv.fileio.FileClient` for details. Default: None.
+ **eval_kwargs: Evaluation arguments fed into the evaluate function of
+ the dataset.
+ """
+
+ def __init__(self,
+ dataloader,
+ start=None,
+ interval=1,
+ by_epoch=True,
+ save_best=None,
+ rule=None,
+ test_fn=None,
+ greater_keys=None,
+ less_keys=None,
+ broadcast_bn_buffer=True,
+ tmpdir=None,
+ gpu_collect=False,
+ out_dir=None,
+ file_client_args=None,
+ **eval_kwargs):
+
+ if test_fn is None:
+ from annotator.uniformer.mmcv.engine import multi_gpu_test
+ test_fn = multi_gpu_test
+
+ super().__init__(
+ dataloader,
+ start=start,
+ interval=interval,
+ by_epoch=by_epoch,
+ save_best=save_best,
+ rule=rule,
+ test_fn=test_fn,
+ greater_keys=greater_keys,
+ less_keys=less_keys,
+ out_dir=out_dir,
+ file_client_args=file_client_args,
+ **eval_kwargs)
+
+ self.broadcast_bn_buffer = broadcast_bn_buffer
+ self.tmpdir = tmpdir
+ self.gpu_collect = gpu_collect
+
+ def _do_evaluate(self, runner):
+ """perform evaluation and save ckpt."""
+ # Synchronization of BatchNorm's buffer (running_mean
+ # and running_var) is not supported in the DDP of pytorch,
+ # which may cause the inconsistent performance of models in
+ # different ranks, so we broadcast BatchNorm's buffers
+ # of rank 0 to other ranks to avoid this.
+ if self.broadcast_bn_buffer:
+ model = runner.model
+ for name, module in model.named_modules():
+ if isinstance(module,
+ _BatchNorm) and module.track_running_stats:
+ dist.broadcast(module.running_var, 0)
+ dist.broadcast(module.running_mean, 0)
+
+ tmpdir = self.tmpdir
+ if tmpdir is None:
+ tmpdir = osp.join(runner.work_dir, '.eval_hook')
+
+ results = self.test_fn(
+ runner.model,
+ self.dataloader,
+ tmpdir=tmpdir,
+ gpu_collect=self.gpu_collect)
+ if runner.rank == 0:
+ print('\n')
+ runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
+ key_score = self.evaluate(runner, results)
+ # the key_score may be `None` so it needs to skip the action to
+ # save the best checkpoint
+ if self.save_best and key_score:
+ self._save_ckpt(runner, key_score)
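
For readers checking the scheduling rule in `_should_evaluate`, here is a standalone restatement with the made-up values from its inline comment (start=3, interval=2). It is a sketch, not part of the patch:

    # Re-statement of the _should_evaluate rule for epoch-based evaluation.
    def should_evaluate(current, start, interval):
        # `current` is runner.epoch; the epoch being finished is current + 1.
        if start is None:
            return (current + 1) % interval == 0
        if (current + 1) < start:
            return False
        return (current + 1 - start) % interval == 0

    evals = [e + 1 for e in range(10) if should_evaluate(e, start=3, interval=2)]
    print(evals)  # [3, 5, 7, 9]
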
diff --git a/annotator/uniformer/mmcv/runner/hooks/hook.py b/annotator/uniformer/mmcv/runner/hooks/hook.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8855c107727ecf85b917c890fc8b7f6359238a4
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/hooks/hook.py
@@ -0,0 +1,92 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from annotator.uniformer.mmcv.utils import Registry, is_method_overridden
+
+HOOKS = Registry('hook')
+
+
+class Hook:
+ stages = ('before_run', 'before_train_epoch', 'before_train_iter',
+ 'after_train_iter', 'after_train_epoch', 'before_val_epoch',
+ 'before_val_iter', 'after_val_iter', 'after_val_epoch',
+ 'after_run')
+
+ def before_run(self, runner):
+ pass
+
+ def after_run(self, runner):
+ pass
+
+ def before_epoch(self, runner):
+ pass
+
+ def after_epoch(self, runner):
+ pass
+
+ def before_iter(self, runner):
+ pass
+
+ def after_iter(self, runner):
+ pass
+
+ def before_train_epoch(self, runner):
+ self.before_epoch(runner)
+
+ def before_val_epoch(self, runner):
+ self.before_epoch(runner)
+
+ def after_train_epoch(self, runner):
+ self.after_epoch(runner)
+
+ def after_val_epoch(self, runner):
+ self.after_epoch(runner)
+
+ def before_train_iter(self, runner):
+ self.before_iter(runner)
+
+ def before_val_iter(self, runner):
+ self.before_iter(runner)
+
+ def after_train_iter(self, runner):
+ self.after_iter(runner)
+
+ def after_val_iter(self, runner):
+ self.after_iter(runner)
+
+ def every_n_epochs(self, runner, n):
+ return (runner.epoch + 1) % n == 0 if n > 0 else False
+
+ def every_n_inner_iters(self, runner, n):
+ return (runner.inner_iter + 1) % n == 0 if n > 0 else False
+
+ def every_n_iters(self, runner, n):
+ return (runner.iter + 1) % n == 0 if n > 0 else False
+
+ def end_of_epoch(self, runner):
+ return runner.inner_iter + 1 == len(runner.data_loader)
+
+ def is_last_epoch(self, runner):
+ return runner.epoch + 1 == runner._max_epochs
+
+ def is_last_iter(self, runner):
+ return runner.iter + 1 == runner._max_iters
+
+ def get_triggered_stages(self):
+ trigger_stages = set()
+ for stage in Hook.stages:
+ if is_method_overridden(stage, Hook, self):
+ trigger_stages.add(stage)
+
+ # some methods will be triggered in multi stages
+ # use this dict to map method to stages.
+ method_stages_map = {
+ 'before_epoch': ['before_train_epoch', 'before_val_epoch'],
+ 'after_epoch': ['after_train_epoch', 'after_val_epoch'],
+ 'before_iter': ['before_train_iter', 'before_val_iter'],
+ 'after_iter': ['after_train_iter', 'after_val_iter'],
+ }
+
+ for method, map_stages in method_stages_map.items():
+ if is_method_overridden(method, Hook, self):
+ trigger_stages.update(map_stages)
+
+ return [stage for stage in Hook.stages if stage in trigger_stages]
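
For context, a new hook is registered against the `HOOKS` registry above roughly as follows. The hook name and body are hypothetical, and the import path assumes the repo layout shown in this diff:

    from annotator.uniformer.mmcv.runner.hooks.hook import HOOKS, Hook


    @HOOKS.register_module()
    class PrintEpochHook(Hook):
        """Toy hook that only overrides the stage it needs."""

        def after_train_epoch(self, runner):
            print(f'epoch {runner.epoch + 1} finished')

    # get_triggered_stages() reports ['after_train_epoch'] for this hook, so a
    # runner only invokes it at that stage.
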
diff --git a/annotator/uniformer/mmcv/runner/hooks/iter_timer.py b/annotator/uniformer/mmcv/runner/hooks/iter_timer.py
new file mode 100644
index 0000000000000000000000000000000000000000..cfd5002fe85ffc6992155ac01003878064a1d9be
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/hooks/iter_timer.py
@@ -0,0 +1,18 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import time
+
+from .hook import HOOKS, Hook
+
+
+@HOOKS.register_module()
+class IterTimerHook(Hook):
+
+ def before_epoch(self, runner):
+ self.t = time.time()
+
+ def before_iter(self, runner):
+ runner.log_buffer.update({'data_time': time.time() - self.t})
+
+ def after_iter(self, runner):
+ runner.log_buffer.update({'time': time.time() - self.t})
+ self.t = time.time()
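
A small sketch of what the two measurements mean (not part of the patch): `data_time` covers the gap between the end of the previous iteration and the start of the current one, while `time` covers the whole iteration:

    import time

    t = time.time()              # set in before_epoch / end of after_iter
    time.sleep(0.05)             # pretend the dataloader takes 50 ms
    data_time = time.time() - t  # logged as 'data_time' in before_iter
    time.sleep(0.10)             # pretend forward/backward takes 100 ms
    iter_time = time.time() - t  # logged as 'time' in after_iter
    print(f'data_time={data_time:.2f}s, time={iter_time:.2f}s')
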
diff --git a/annotator/uniformer/mmcv/runner/hooks/logger/__init__.py b/annotator/uniformer/mmcv/runner/hooks/logger/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0b6b345640a895368ac8a647afef6f24333d90e
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/hooks/logger/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .base import LoggerHook
+from .dvclive import DvcliveLoggerHook
+from .mlflow import MlflowLoggerHook
+from .neptune import NeptuneLoggerHook
+from .pavi import PaviLoggerHook
+from .tensorboard import TensorboardLoggerHook
+from .text import TextLoggerHook
+from .wandb import WandbLoggerHook
+
+__all__ = [
+ 'LoggerHook', 'MlflowLoggerHook', 'PaviLoggerHook',
+ 'TensorboardLoggerHook', 'TextLoggerHook', 'WandbLoggerHook',
+ 'NeptuneLoggerHook', 'DvcliveLoggerHook'
+]
diff --git a/annotator/uniformer/mmcv/runner/hooks/logger/base.py b/annotator/uniformer/mmcv/runner/hooks/logger/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..f845256729458ced821762a1b8ef881e17ff9955
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/hooks/logger/base.py
@@ -0,0 +1,166 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numbers
+from abc import ABCMeta, abstractmethod
+
+import numpy as np
+import torch
+
+from ..hook import Hook
+
+
+class LoggerHook(Hook):
+ """Base class for logger hooks.
+
+ Args:
+ interval (int): Logging interval (every k iterations).
+        ignore_last (bool): Ignore the log of the last iterations in each
+            epoch if they are fewer than `interval`.
+ reset_flag (bool): Whether to clear the output buffer after logging.
+ by_epoch (bool): Whether EpochBasedRunner is used.
+ """
+
+ __metaclass__ = ABCMeta
+
+ def __init__(self,
+ interval=10,
+ ignore_last=True,
+ reset_flag=False,
+ by_epoch=True):
+ self.interval = interval
+ self.ignore_last = ignore_last
+ self.reset_flag = reset_flag
+ self.by_epoch = by_epoch
+
+ @abstractmethod
+ def log(self, runner):
+ pass
+
+ @staticmethod
+ def is_scalar(val, include_np=True, include_torch=True):
+ """Tell the input variable is a scalar or not.
+
+ Args:
+ val: Input variable.
+ include_np (bool): Whether include 0-d np.ndarray as a scalar.
+ include_torch (bool): Whether include 0-d torch.Tensor as a scalar.
+
+ Returns:
+ bool: True or False.
+ """
+ if isinstance(val, numbers.Number):
+ return True
+ elif include_np and isinstance(val, np.ndarray) and val.ndim == 0:
+ return True
+ elif include_torch and isinstance(val, torch.Tensor) and len(val) == 1:
+ return True
+ else:
+ return False
+
+ def get_mode(self, runner):
+ if runner.mode == 'train':
+ if 'time' in runner.log_buffer.output:
+ mode = 'train'
+ else:
+ mode = 'val'
+ elif runner.mode == 'val':
+ mode = 'val'
+ else:
+ raise ValueError(f"runner mode should be 'train' or 'val', "
+ f'but got {runner.mode}')
+ return mode
+
+ def get_epoch(self, runner):
+ if runner.mode == 'train':
+ epoch = runner.epoch + 1
+ elif runner.mode == 'val':
+ # normal val mode
+ # runner.epoch += 1 has been done before val workflow
+ epoch = runner.epoch
+ else:
+ raise ValueError(f"runner mode should be 'train' or 'val', "
+ f'but got {runner.mode}')
+ return epoch
+
+ def get_iter(self, runner, inner_iter=False):
+ """Get the current training iteration step."""
+ if self.by_epoch and inner_iter:
+ current_iter = runner.inner_iter + 1
+ else:
+ current_iter = runner.iter + 1
+ return current_iter
+
+ def get_lr_tags(self, runner):
+ tags = {}
+ lrs = runner.current_lr()
+ if isinstance(lrs, dict):
+ for name, value in lrs.items():
+ tags[f'learning_rate/{name}'] = value[0]
+ else:
+ tags['learning_rate'] = lrs[0]
+ return tags
+
+ def get_momentum_tags(self, runner):
+ tags = {}
+ momentums = runner.current_momentum()
+ if isinstance(momentums, dict):
+ for name, value in momentums.items():
+ tags[f'momentum/{name}'] = value[0]
+ else:
+ tags['momentum'] = momentums[0]
+ return tags
+
+ def get_loggable_tags(self,
+ runner,
+ allow_scalar=True,
+ allow_text=False,
+ add_mode=True,
+ tags_to_skip=('time', 'data_time')):
+ tags = {}
+ for var, val in runner.log_buffer.output.items():
+ if var in tags_to_skip:
+ continue
+ if self.is_scalar(val) and not allow_scalar:
+ continue
+ if isinstance(val, str) and not allow_text:
+ continue
+ if add_mode:
+ var = f'{self.get_mode(runner)}/{var}'
+ tags[var] = val
+ tags.update(self.get_lr_tags(runner))
+ tags.update(self.get_momentum_tags(runner))
+ return tags
+
+ def before_run(self, runner):
+ for hook in runner.hooks[::-1]:
+ if isinstance(hook, LoggerHook):
+ hook.reset_flag = True
+ break
+
+ def before_epoch(self, runner):
+ runner.log_buffer.clear() # clear logs of last epoch
+
+ def after_train_iter(self, runner):
+ if self.by_epoch and self.every_n_inner_iters(runner, self.interval):
+ runner.log_buffer.average(self.interval)
+ elif not self.by_epoch and self.every_n_iters(runner, self.interval):
+ runner.log_buffer.average(self.interval)
+ elif self.end_of_epoch(runner) and not self.ignore_last:
+ # not precise but more stable
+ runner.log_buffer.average(self.interval)
+
+ if runner.log_buffer.ready:
+ self.log(runner)
+ if self.reset_flag:
+ runner.log_buffer.clear_output()
+
+ def after_train_epoch(self, runner):
+ if runner.log_buffer.ready:
+ self.log(runner)
+ if self.reset_flag:
+ runner.log_buffer.clear_output()
+
+ def after_val_epoch(self, runner):
+ runner.log_buffer.average()
+ self.log(runner)
+ if self.reset_flag:
+ runner.log_buffer.clear_output()
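
Each backend above only has to implement ``log``; a minimal hypothetical backend (not part of the patch, import paths assume the layout in this diff):

    from annotator.uniformer.mmcv.runner.hooks.hook import HOOKS
    from annotator.uniformer.mmcv.runner.hooks.logger.base import LoggerHook


    @HOOKS.register_module()
    class StdoutLoggerHook(LoggerHook):
        """Toy backend that dumps the loggable tags to stdout."""

        def log(self, runner):
            tags = self.get_loggable_tags(runner)  # lr, momentum, losses, ...
            if tags:
                print(f'[iter {self.get_iter(runner)}] {tags}')
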
diff --git a/annotator/uniformer/mmcv/runner/hooks/logger/dvclive.py b/annotator/uniformer/mmcv/runner/hooks/logger/dvclive.py
new file mode 100644
index 0000000000000000000000000000000000000000..687cdc58c0336c92b1e4f9a410ba67ebaab2bc7a
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/hooks/logger/dvclive.py
@@ -0,0 +1,58 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ...dist_utils import master_only
+from ..hook import HOOKS
+from .base import LoggerHook
+
+
+@HOOKS.register_module()
+class DvcliveLoggerHook(LoggerHook):
+ """Class to log metrics with dvclive.
+
+ It requires `dvclive`_ to be installed.
+
+ Args:
+ path (str): Directory where dvclive will write TSV log files.
+ interval (int): Logging interval (every k iterations).
+ Default 10.
+ ignore_last (bool): Ignore the log of last iterations in each epoch
+ if less than `interval`.
+ Default: True.
+ reset_flag (bool): Whether to clear the output buffer after logging.
+ Default: True.
+ by_epoch (bool): Whether EpochBasedRunner is used.
+ Default: True.
+
+ .. _dvclive:
+ https://dvc.org/doc/dvclive
+ """
+
+ def __init__(self,
+ path,
+ interval=10,
+ ignore_last=True,
+ reset_flag=True,
+ by_epoch=True):
+
+ super(DvcliveLoggerHook, self).__init__(interval, ignore_last,
+ reset_flag, by_epoch)
+ self.path = path
+ self.import_dvclive()
+
+ def import_dvclive(self):
+ try:
+ import dvclive
+ except ImportError:
+ raise ImportError(
+ 'Please run "pip install dvclive" to install dvclive')
+ self.dvclive = dvclive
+
+ @master_only
+ def before_run(self, runner):
+ self.dvclive.init(self.path)
+
+ @master_only
+ def log(self, runner):
+ tags = self.get_loggable_tags(runner)
+ if tags:
+ for k, v in tags.items():
+ self.dvclive.log(k, v, step=self.get_iter(runner))
diff --git a/annotator/uniformer/mmcv/runner/hooks/logger/mlflow.py b/annotator/uniformer/mmcv/runner/hooks/logger/mlflow.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9a72592be47b534ce22573775fd5a7e8e86d72d
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/hooks/logger/mlflow.py
@@ -0,0 +1,78 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ...dist_utils import master_only
+from ..hook import HOOKS
+from .base import LoggerHook
+
+
+@HOOKS.register_module()
+class MlflowLoggerHook(LoggerHook):
+
+ def __init__(self,
+ exp_name=None,
+ tags=None,
+ log_model=True,
+ interval=10,
+ ignore_last=True,
+ reset_flag=False,
+ by_epoch=True):
+ """Class to log metrics and (optionally) a trained model to MLflow.
+
+ It requires `MLflow`_ to be installed.
+
+ Args:
+ exp_name (str, optional): Name of the experiment to be used.
+ Default None.
+ If not None, set the active experiment.
+ If experiment does not exist, an experiment with provided name
+ will be created.
+ tags (dict of str: str, optional): Tags for the current run.
+ Default None.
+ If not None, set tags for the current run.
+ log_model (bool, optional): Whether to log an MLflow artifact.
+ Default True.
+ If True, log runner.model as an MLflow artifact
+ for the current run.
+ interval (int): Logging interval (every k iterations).
+ ignore_last (bool): Ignore the log of last iterations in each epoch
+ if less than `interval`.
+ reset_flag (bool): Whether to clear the output buffer after logging
+ by_epoch (bool): Whether EpochBasedRunner is used.
+
+ .. _MLflow:
+ https://www.mlflow.org/docs/latest/index.html
+ """
+ super(MlflowLoggerHook, self).__init__(interval, ignore_last,
+ reset_flag, by_epoch)
+ self.import_mlflow()
+ self.exp_name = exp_name
+ self.tags = tags
+ self.log_model = log_model
+
+ def import_mlflow(self):
+ try:
+ import mlflow
+ import mlflow.pytorch as mlflow_pytorch
+ except ImportError:
+ raise ImportError(
+ 'Please run "pip install mlflow" to install mlflow')
+ self.mlflow = mlflow
+ self.mlflow_pytorch = mlflow_pytorch
+
+ @master_only
+ def before_run(self, runner):
+ super(MlflowLoggerHook, self).before_run(runner)
+ if self.exp_name is not None:
+ self.mlflow.set_experiment(self.exp_name)
+ if self.tags is not None:
+ self.mlflow.set_tags(self.tags)
+
+ @master_only
+ def log(self, runner):
+ tags = self.get_loggable_tags(runner)
+ if tags:
+ self.mlflow.log_metrics(tags, step=self.get_iter(runner))
+
+ @master_only
+ def after_run(self, runner):
+ if self.log_model:
+ self.mlflow_pytorch.log_model(runner.model, 'models')
diff --git a/annotator/uniformer/mmcv/runner/hooks/logger/neptune.py b/annotator/uniformer/mmcv/runner/hooks/logger/neptune.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a38772b0c93a8608f32c6357b8616e77c139dc9
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/hooks/logger/neptune.py
@@ -0,0 +1,82 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ...dist_utils import master_only
+from ..hook import HOOKS
+from .base import LoggerHook
+
+
+@HOOKS.register_module()
+class NeptuneLoggerHook(LoggerHook):
+ """Class to log metrics to NeptuneAI.
+
+ It requires `neptune-client` to be installed.
+
+ Args:
+        init_kwargs (dict): a dict containing the initialization keys below:
+ - project (str): Name of a project in a form of
+ namespace/project_name. If None, the value of
+ NEPTUNE_PROJECT environment variable will be taken.
+ - api_token (str): User’s API token.
+ If None, the value of NEPTUNE_API_TOKEN environment
+ variable will be taken. Note: It is strongly recommended
+ to use NEPTUNE_API_TOKEN environment variable rather than
+ placing your API token in plain text in your source code.
+ - name (str, optional, default is 'Untitled'): Editable name of
+ the run. Name is displayed in the run's Details and in
+ Runs table as a column.
+ Check https://docs.neptune.ai/api-reference/neptune#init for
+ more init arguments.
+ interval (int): Logging interval (every k iterations).
+ ignore_last (bool): Ignore the log of last iterations in each epoch
+ if less than `interval`.
+ reset_flag (bool): Whether to clear the output buffer after logging
+ by_epoch (bool): Whether EpochBasedRunner is used.
+
+ .. _NeptuneAI:
+ https://docs.neptune.ai/you-should-know/logging-metadata
+ """
+
+ def __init__(self,
+ init_kwargs=None,
+ interval=10,
+ ignore_last=True,
+ reset_flag=True,
+ with_step=True,
+ by_epoch=True):
+
+ super(NeptuneLoggerHook, self).__init__(interval, ignore_last,
+ reset_flag, by_epoch)
+ self.import_neptune()
+ self.init_kwargs = init_kwargs
+ self.with_step = with_step
+
+ def import_neptune(self):
+ try:
+ import neptune.new as neptune
+ except ImportError:
+ raise ImportError(
+ 'Please run "pip install neptune-client" to install neptune')
+ self.neptune = neptune
+ self.run = None
+
+ @master_only
+ def before_run(self, runner):
+ if self.init_kwargs:
+ self.run = self.neptune.init(**self.init_kwargs)
+ else:
+ self.run = self.neptune.init()
+
+ @master_only
+ def log(self, runner):
+ tags = self.get_loggable_tags(runner)
+ if tags:
+ for tag_name, tag_value in tags.items():
+ if self.with_step:
+ self.run[tag_name].log(
+ tag_value, step=self.get_iter(runner))
+ else:
+ tags['global_step'] = self.get_iter(runner)
+ self.run[tag_name].log(tags)
+
+ @master_only
+ def after_run(self, runner):
+ self.run.stop()
diff --git a/annotator/uniformer/mmcv/runner/hooks/logger/pavi.py b/annotator/uniformer/mmcv/runner/hooks/logger/pavi.py
new file mode 100644
index 0000000000000000000000000000000000000000..1dcf146d8163aff1363e9764999b0a74d674a595
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/hooks/logger/pavi.py
@@ -0,0 +1,117 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import json
+import os
+import os.path as osp
+
+import torch
+import yaml
+
+import annotator.uniformer.mmcv as mmcv
+from ....parallel.utils import is_module_wrapper
+from ...dist_utils import master_only
+from ..hook import HOOKS
+from .base import LoggerHook
+
+
+@HOOKS.register_module()
+class PaviLoggerHook(LoggerHook):
+
+ def __init__(self,
+ init_kwargs=None,
+ add_graph=False,
+ add_last_ckpt=False,
+ interval=10,
+ ignore_last=True,
+ reset_flag=False,
+ by_epoch=True,
+ img_key='img_info'):
+ super(PaviLoggerHook, self).__init__(interval, ignore_last, reset_flag,
+ by_epoch)
+ self.init_kwargs = init_kwargs
+ self.add_graph = add_graph
+ self.add_last_ckpt = add_last_ckpt
+ self.img_key = img_key
+
+ @master_only
+ def before_run(self, runner):
+ super(PaviLoggerHook, self).before_run(runner)
+ try:
+ from pavi import SummaryWriter
+ except ImportError:
+ raise ImportError('Please run "pip install pavi" to install pavi.')
+
+ self.run_name = runner.work_dir.split('/')[-1]
+
+ if not self.init_kwargs:
+ self.init_kwargs = dict()
+ self.init_kwargs['name'] = self.run_name
+ self.init_kwargs['model'] = runner._model_name
+ if runner.meta is not None:
+ if 'config_dict' in runner.meta:
+ config_dict = runner.meta['config_dict']
+ assert isinstance(
+ config_dict,
+                    dict), ('meta["config_dict"] has to be a dict, '
+ f'but got {type(config_dict)}')
+ elif 'config_file' in runner.meta:
+ config_file = runner.meta['config_file']
+ config_dict = dict(mmcv.Config.fromfile(config_file))
+ else:
+ config_dict = None
+ if config_dict is not None:
+ # 'max_.*iter' is parsed in pavi sdk as the maximum iterations
+ # to properly set up the progress bar.
+ config_dict = config_dict.copy()
+ config_dict.setdefault('max_iter', runner.max_iters)
+ # non-serializable values are first converted in
+ # mmcv.dump to json
+ config_dict = json.loads(
+ mmcv.dump(config_dict, file_format='json'))
+ session_text = yaml.dump(config_dict)
+ self.init_kwargs['session_text'] = session_text
+ self.writer = SummaryWriter(**self.init_kwargs)
+
+ def get_step(self, runner):
+ """Get the total training step/epoch."""
+ if self.get_mode(runner) == 'val' and self.by_epoch:
+ return self.get_epoch(runner)
+ else:
+ return self.get_iter(runner)
+
+ @master_only
+ def log(self, runner):
+ tags = self.get_loggable_tags(runner, add_mode=False)
+ if tags:
+ self.writer.add_scalars(
+ self.get_mode(runner), tags, self.get_step(runner))
+
+ @master_only
+ def after_run(self, runner):
+ if self.add_last_ckpt:
+ ckpt_path = osp.join(runner.work_dir, 'latest.pth')
+ if osp.islink(ckpt_path):
+ ckpt_path = osp.join(runner.work_dir, os.readlink(ckpt_path))
+
+ if osp.isfile(ckpt_path):
+ # runner.epoch += 1 has been done before `after_run`.
+ iteration = runner.epoch if self.by_epoch else runner.iter
+ return self.writer.add_snapshot_file(
+ tag=self.run_name,
+ snapshot_file_path=ckpt_path,
+ iteration=iteration)
+
+ # flush the buffer and send a task ending signal to Pavi
+ self.writer.close()
+
+ @master_only
+ def before_epoch(self, runner):
+ if runner.epoch == 0 and self.add_graph:
+ if is_module_wrapper(runner.model):
+ _model = runner.model.module
+ else:
+ _model = runner.model
+ device = next(_model.parameters()).device
+ data = next(iter(runner.data_loader))
+ image = data[self.img_key][0:1].to(device)
+ with torch.no_grad():
+ self.writer.add_graph(_model, image)
diff --git a/annotator/uniformer/mmcv/runner/hooks/logger/tensorboard.py b/annotator/uniformer/mmcv/runner/hooks/logger/tensorboard.py
new file mode 100644
index 0000000000000000000000000000000000000000..4dd5011dc08def6c09eef86d3ce5b124c9fc5372
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/hooks/logger/tensorboard.py
@@ -0,0 +1,57 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+
+from annotator.uniformer.mmcv.utils import TORCH_VERSION, digit_version
+from ...dist_utils import master_only
+from ..hook import HOOKS
+from .base import LoggerHook
+
+
+@HOOKS.register_module()
+class TensorboardLoggerHook(LoggerHook):
+
+ def __init__(self,
+ log_dir=None,
+ interval=10,
+ ignore_last=True,
+ reset_flag=False,
+ by_epoch=True):
+ super(TensorboardLoggerHook, self).__init__(interval, ignore_last,
+ reset_flag, by_epoch)
+ self.log_dir = log_dir
+
+ @master_only
+ def before_run(self, runner):
+ super(TensorboardLoggerHook, self).before_run(runner)
+ if (TORCH_VERSION == 'parrots'
+ or digit_version(TORCH_VERSION) < digit_version('1.1')):
+ try:
+ from tensorboardX import SummaryWriter
+ except ImportError:
+ raise ImportError('Please install tensorboardX to use '
+ 'TensorboardLoggerHook.')
+ else:
+ try:
+ from torch.utils.tensorboard import SummaryWriter
+ except ImportError:
+ raise ImportError(
+ 'Please run "pip install future tensorboard" to install '
+ 'the dependencies to use torch.utils.tensorboard '
+ '(applicable to PyTorch 1.1 or higher)')
+
+ if self.log_dir is None:
+ self.log_dir = osp.join(runner.work_dir, 'tf_logs')
+ self.writer = SummaryWriter(self.log_dir)
+
+ @master_only
+ def log(self, runner):
+ tags = self.get_loggable_tags(runner, allow_text=True)
+ for tag, val in tags.items():
+ if isinstance(val, str):
+ self.writer.add_text(tag, val, self.get_iter(runner))
+ else:
+ self.writer.add_scalar(tag, val, self.get_iter(runner))
+
+ @master_only
+ def after_run(self, runner):
+ self.writer.close()
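
Logger hooks like this one are usually attached through a runner's log configuration; the snippet below follows the common mmcv `log_config` convention, and whether this repo wires its hooks that way is an assumption:

    # Hedged sketch of a typical mmcv-style log_config entry.
    log_config = dict(
        interval=10,  # LoggerHook.interval: log every 10 iterations
        hooks=[
            dict(type='TextLoggerHook', by_epoch=True),
            dict(type='TensorboardLoggerHook'),  # log_dir defaults to work_dir/tf_logs
        ])
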
diff --git a/annotator/uniformer/mmcv/runner/hooks/logger/text.py b/annotator/uniformer/mmcv/runner/hooks/logger/text.py
new file mode 100644
index 0000000000000000000000000000000000000000..87b1a3eca9595a130121526f8b4c29915387ab35
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/hooks/logger/text.py
@@ -0,0 +1,256 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import datetime
+import os
+import os.path as osp
+from collections import OrderedDict
+
+import torch
+import torch.distributed as dist
+
+import annotator.uniformer.mmcv as mmcv
+from annotator.uniformer.mmcv.fileio.file_client import FileClient
+from annotator.uniformer.mmcv.utils import is_tuple_of, scandir
+from ..hook import HOOKS
+from .base import LoggerHook
+
+
+@HOOKS.register_module()
+class TextLoggerHook(LoggerHook):
+ """Logger hook in text.
+
+    In this logger hook, the information will be printed on the terminal and
+    saved in a JSON file.
+
+ Args:
+ by_epoch (bool, optional): Whether EpochBasedRunner is used.
+ Default: True.
+ interval (int, optional): Logging interval (every k iterations).
+ Default: 10.
+ ignore_last (bool, optional): Ignore the log of last iterations in each
+ epoch if less than :attr:`interval`. Default: True.
+ reset_flag (bool, optional): Whether to clear the output buffer after
+ logging. Default: False.
+ interval_exp_name (int, optional): Logging interval for experiment
+ name. This feature is to help users conveniently get the experiment
+ information from screen or log file. Default: 1000.
+ out_dir (str, optional): Logs are saved in ``runner.work_dir`` default.
+ If ``out_dir`` is specified, logs will be copied to a new directory
+ which is the concatenation of ``out_dir`` and the last level
+ directory of ``runner.work_dir``. Default: None.
+ `New in version 1.3.16.`
+ out_suffix (str or tuple[str], optional): Those filenames ending with
+ ``out_suffix`` will be copied to ``out_dir``.
+ Default: ('.log.json', '.log', '.py').
+ `New in version 1.3.16.`
+ keep_local (bool, optional): Whether to keep local log when
+ :attr:`out_dir` is specified. If False, the local log will be
+ removed. Default: True.
+ `New in version 1.3.16.`
+ file_client_args (dict, optional): Arguments to instantiate a
+ FileClient. See :class:`mmcv.fileio.FileClient` for details.
+ Default: None.
+ `New in version 1.3.16.`
+ """
+
+ def __init__(self,
+ by_epoch=True,
+ interval=10,
+ ignore_last=True,
+ reset_flag=False,
+ interval_exp_name=1000,
+ out_dir=None,
+ out_suffix=('.log.json', '.log', '.py'),
+ keep_local=True,
+ file_client_args=None):
+ super(TextLoggerHook, self).__init__(interval, ignore_last, reset_flag,
+ by_epoch)
+ self.by_epoch = by_epoch
+ self.time_sec_tot = 0
+ self.interval_exp_name = interval_exp_name
+
+ if out_dir is None and file_client_args is not None:
+ raise ValueError(
+                'file_client_args should be "None" when `out_dir` is not '
+                'specified.')
+ self.out_dir = out_dir
+
+ if not (out_dir is None or isinstance(out_dir, str)
+ or is_tuple_of(out_dir, str)):
+            raise TypeError('out_dir should be "None" or string or tuple of '
+                            f'string, but got {out_dir}')
+ self.out_suffix = out_suffix
+
+ self.keep_local = keep_local
+ self.file_client_args = file_client_args
+ if self.out_dir is not None:
+ self.file_client = FileClient.infer_client(file_client_args,
+ self.out_dir)
+
+ def before_run(self, runner):
+ super(TextLoggerHook, self).before_run(runner)
+
+ if self.out_dir is not None:
+ self.file_client = FileClient.infer_client(self.file_client_args,
+ self.out_dir)
+ # The final `self.out_dir` is the concatenation of `self.out_dir`
+ # and the last level directory of `runner.work_dir`
+ basename = osp.basename(runner.work_dir.rstrip(osp.sep))
+ self.out_dir = self.file_client.join_path(self.out_dir, basename)
+ runner.logger.info(
+ (f'Text logs will be saved to {self.out_dir} by '
+ f'{self.file_client.name} after the training process.'))
+
+ self.start_iter = runner.iter
+ self.json_log_path = osp.join(runner.work_dir,
+ f'{runner.timestamp}.log.json')
+ if runner.meta is not None:
+ self._dump_log(runner.meta, runner)
+
+ def _get_max_memory(self, runner):
+ device = getattr(runner.model, 'output_device', None)
+ mem = torch.cuda.max_memory_allocated(device=device)
+ mem_mb = torch.tensor([mem / (1024 * 1024)],
+ dtype=torch.int,
+ device=device)
+ if runner.world_size > 1:
+ dist.reduce(mem_mb, 0, op=dist.ReduceOp.MAX)
+ return mem_mb.item()
+
+ def _log_info(self, log_dict, runner):
+ # print exp name for users to distinguish experiments
+ # at every ``interval_exp_name`` iterations and the end of each epoch
+ if runner.meta is not None and 'exp_name' in runner.meta:
+ if (self.every_n_iters(runner, self.interval_exp_name)) or (
+ self.by_epoch and self.end_of_epoch(runner)):
+ exp_info = f'Exp name: {runner.meta["exp_name"]}'
+ runner.logger.info(exp_info)
+
+ if log_dict['mode'] == 'train':
+ if isinstance(log_dict['lr'], dict):
+ lr_str = []
+ for k, val in log_dict['lr'].items():
+ lr_str.append(f'lr_{k}: {val:.3e}')
+ lr_str = ' '.join(lr_str)
+ else:
+ lr_str = f'lr: {log_dict["lr"]:.3e}'
+
+ # by epoch: Epoch [4][100/1000]
+ # by iter: Iter [100/100000]
+ if self.by_epoch:
+ log_str = f'Epoch [{log_dict["epoch"]}]' \
+ f'[{log_dict["iter"]}/{len(runner.data_loader)}]\t'
+ else:
+ log_str = f'Iter [{log_dict["iter"]}/{runner.max_iters}]\t'
+ log_str += f'{lr_str}, '
+
+ if 'time' in log_dict.keys():
+ self.time_sec_tot += (log_dict['time'] * self.interval)
+ time_sec_avg = self.time_sec_tot / (
+ runner.iter - self.start_iter + 1)
+ eta_sec = time_sec_avg * (runner.max_iters - runner.iter - 1)
+ eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
+ log_str += f'eta: {eta_str}, '
+ log_str += f'time: {log_dict["time"]:.3f}, ' \
+ f'data_time: {log_dict["data_time"]:.3f}, '
+ # statistic memory
+ if torch.cuda.is_available():
+ log_str += f'memory: {log_dict["memory"]}, '
+ else:
+ # val/test time
+ # here 1000 is the length of the val dataloader
+ # by epoch: Epoch[val] [4][1000]
+ # by iter: Iter[val] [1000]
+ if self.by_epoch:
+ log_str = f'Epoch({log_dict["mode"]}) ' \
+ f'[{log_dict["epoch"]}][{log_dict["iter"]}]\t'
+ else:
+ log_str = f'Iter({log_dict["mode"]}) [{log_dict["iter"]}]\t'
+
+ log_items = []
+ for name, val in log_dict.items():
+ # TODO: resolve this hack
+ # these items have been in log_str
+ if name in [
+ 'mode', 'Epoch', 'iter', 'lr', 'time', 'data_time',
+ 'memory', 'epoch'
+ ]:
+ continue
+ if isinstance(val, float):
+ val = f'{val:.4f}'
+ log_items.append(f'{name}: {val}')
+ log_str += ', '.join(log_items)
+
+ runner.logger.info(log_str)
+
+ def _dump_log(self, log_dict, runner):
+ # dump log in json format
+ json_log = OrderedDict()
+ for k, v in log_dict.items():
+ json_log[k] = self._round_float(v)
+ # only append log at last line
+ if runner.rank == 0:
+ with open(self.json_log_path, 'a+') as f:
+ mmcv.dump(json_log, f, file_format='json')
+ f.write('\n')
+
+ def _round_float(self, items):
+ if isinstance(items, list):
+ return [self._round_float(item) for item in items]
+ elif isinstance(items, float):
+ return round(items, 5)
+ else:
+ return items
+
+ def log(self, runner):
+ if 'eval_iter_num' in runner.log_buffer.output:
+            # this doesn't modify runner.iter and is independent of by_epoch
+ cur_iter = runner.log_buffer.output.pop('eval_iter_num')
+ else:
+ cur_iter = self.get_iter(runner, inner_iter=True)
+
+ log_dict = OrderedDict(
+ mode=self.get_mode(runner),
+ epoch=self.get_epoch(runner),
+ iter=cur_iter)
+
+ # only record lr of the first param group
+ cur_lr = runner.current_lr()
+ if isinstance(cur_lr, list):
+ log_dict['lr'] = cur_lr[0]
+ else:
+ assert isinstance(cur_lr, dict)
+ log_dict['lr'] = {}
+ for k, lr_ in cur_lr.items():
+ assert isinstance(lr_, list)
+ log_dict['lr'].update({k: lr_[0]})
+
+ if 'time' in runner.log_buffer.output:
+ # statistic memory
+ if torch.cuda.is_available():
+ log_dict['memory'] = self._get_max_memory(runner)
+
+ log_dict = dict(log_dict, **runner.log_buffer.output)
+
+ self._log_info(log_dict, runner)
+ self._dump_log(log_dict, runner)
+ return log_dict
+
+ def after_run(self, runner):
+ # copy or upload logs to self.out_dir
+ if self.out_dir is not None:
+ for filename in scandir(runner.work_dir, self.out_suffix, True):
+ local_filepath = osp.join(runner.work_dir, filename)
+ out_filepath = self.file_client.join_path(
+ self.out_dir, filename)
+ with open(local_filepath, 'r') as f:
+ self.file_client.put_text(f.read(), out_filepath)
+
+ runner.logger.info(
+ (f'The file {local_filepath} has been uploaded to '
+ f'{out_filepath}.'))
+
+ if not self.keep_local:
+ os.remove(local_filepath)
+ runner.logger.info(
+ (f'{local_filepath} was removed due to the '
+ '`self.keep_local=False`'))
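
`_dump_log` above appends one JSON object per line to `<timestamp>.log.json`, so the file can be read back line by line; a small sketch (the path below is made up):

    import json

    with open('work_dirs/example/20230101_000000.log.json') as f:
        records = [json.loads(line) for line in f if line.strip()]

    # The first record is runner.meta; training records carry mode/epoch/iter.
    train_losses = [r['loss'] for r in records
                    if r.get('mode') == 'train' and 'loss' in r]
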
diff --git a/annotator/uniformer/mmcv/runner/hooks/logger/wandb.py b/annotator/uniformer/mmcv/runner/hooks/logger/wandb.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f6808462eb79ab2b04806a5d9f0d3dd079b5ea9
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/hooks/logger/wandb.py
@@ -0,0 +1,56 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ...dist_utils import master_only
+from ..hook import HOOKS
+from .base import LoggerHook
+
+
+@HOOKS.register_module()
+class WandbLoggerHook(LoggerHook):
+
+ def __init__(self,
+ init_kwargs=None,
+ interval=10,
+ ignore_last=True,
+ reset_flag=False,
+ commit=True,
+ by_epoch=True,
+ with_step=True):
+ super(WandbLoggerHook, self).__init__(interval, ignore_last,
+ reset_flag, by_epoch)
+ self.import_wandb()
+ self.init_kwargs = init_kwargs
+ self.commit = commit
+ self.with_step = with_step
+
+ def import_wandb(self):
+ try:
+ import wandb
+ except ImportError:
+ raise ImportError(
+ 'Please run "pip install wandb" to install wandb')
+ self.wandb = wandb
+
+ @master_only
+ def before_run(self, runner):
+ super(WandbLoggerHook, self).before_run(runner)
+ if self.wandb is None:
+ self.import_wandb()
+ if self.init_kwargs:
+ self.wandb.init(**self.init_kwargs)
+ else:
+ self.wandb.init()
+
+ @master_only
+ def log(self, runner):
+ tags = self.get_loggable_tags(runner)
+ if tags:
+ if self.with_step:
+ self.wandb.log(
+ tags, step=self.get_iter(runner), commit=self.commit)
+ else:
+ tags['global_step'] = self.get_iter(runner)
+ self.wandb.log(tags, commit=self.commit)
+
+ @master_only
+ def after_run(self, runner):
+ self.wandb.join()
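
The `with_step` flag above only changes how the iteration number reaches wandb; a short sketch of the two resulting calls (values are made up, and `wandb` must be installed for this to run):

    import wandb

    wandb.init(mode='offline')  # offline so the sketch runs without an account
    tags = {'train/loss': 0.42, 'learning_rate': 1e-4}

    # with_step=True: the iteration is passed as wandb's own step argument.
    wandb.log(tags, step=1000, commit=True)
    # with_step=False: the iteration is logged as an ordinary metric instead.
    wandb.log({**tags, 'global_step': 1001}, commit=True)
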
diff --git a/annotator/uniformer/mmcv/runner/hooks/lr_updater.py b/annotator/uniformer/mmcv/runner/hooks/lr_updater.py
new file mode 100644
index 0000000000000000000000000000000000000000..6365908ddf6070086de2ffc0afada46ed2f32256
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/hooks/lr_updater.py
@@ -0,0 +1,670 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numbers
+from math import cos, pi
+
+import annotator.uniformer.mmcv as mmcv
+from .hook import HOOKS, Hook
+
+
+class LrUpdaterHook(Hook):
+ """LR Scheduler in MMCV.
+
+ Args:
+        by_epoch (bool): Whether the LR changes epoch by epoch.
+        warmup (string): Type of warmup used. It can be None (no warmup),
+            'constant', 'linear' or 'exp'.
+        warmup_iters (int): The number of iterations or epochs that warmup
+            lasts.
+        warmup_ratio (float): The LR used at the beginning of warmup equals
+            warmup_ratio * initial_lr.
+        warmup_by_epoch (bool): When warmup_by_epoch == True, warmup_iters
+            means the number of epochs that warmup lasts; otherwise it means
+            the number of iterations that warmup lasts.
+ """
+
+ def __init__(self,
+ by_epoch=True,
+ warmup=None,
+ warmup_iters=0,
+ warmup_ratio=0.1,
+ warmup_by_epoch=False):
+ # validate the "warmup" argument
+ if warmup is not None:
+ if warmup not in ['constant', 'linear', 'exp']:
+ raise ValueError(
+ f'"{warmup}" is not a supported type for warming up, valid'
+ ' types are "constant" and "linear"')
+ if warmup is not None:
+ assert warmup_iters > 0, \
+ '"warmup_iters" must be a positive integer'
+ assert 0 < warmup_ratio <= 1.0, \
+ '"warmup_ratio" must be in range (0,1]'
+
+ self.by_epoch = by_epoch
+ self.warmup = warmup
+ self.warmup_iters = warmup_iters
+ self.warmup_ratio = warmup_ratio
+ self.warmup_by_epoch = warmup_by_epoch
+
+ if self.warmup_by_epoch:
+ self.warmup_epochs = self.warmup_iters
+ self.warmup_iters = None
+ else:
+ self.warmup_epochs = None
+
+ self.base_lr = [] # initial lr for all param groups
+ self.regular_lr = [] # expected lr if no warming up is performed
+
+ def _set_lr(self, runner, lr_groups):
+ if isinstance(runner.optimizer, dict):
+ for k, optim in runner.optimizer.items():
+ for param_group, lr in zip(optim.param_groups, lr_groups[k]):
+ param_group['lr'] = lr
+ else:
+ for param_group, lr in zip(runner.optimizer.param_groups,
+ lr_groups):
+ param_group['lr'] = lr
+
+ def get_lr(self, runner, base_lr):
+ raise NotImplementedError
+
+ def get_regular_lr(self, runner):
+ if isinstance(runner.optimizer, dict):
+ lr_groups = {}
+ for k in runner.optimizer.keys():
+ _lr_group = [
+ self.get_lr(runner, _base_lr)
+ for _base_lr in self.base_lr[k]
+ ]
+ lr_groups.update({k: _lr_group})
+
+ return lr_groups
+ else:
+ return [self.get_lr(runner, _base_lr) for _base_lr in self.base_lr]
+
+ def get_warmup_lr(self, cur_iters):
+
+ def _get_warmup_lr(cur_iters, regular_lr):
+ if self.warmup == 'constant':
+ warmup_lr = [_lr * self.warmup_ratio for _lr in regular_lr]
+ elif self.warmup == 'linear':
+ k = (1 - cur_iters / self.warmup_iters) * (1 -
+ self.warmup_ratio)
+ warmup_lr = [_lr * (1 - k) for _lr in regular_lr]
+ elif self.warmup == 'exp':
+ k = self.warmup_ratio**(1 - cur_iters / self.warmup_iters)
+ warmup_lr = [_lr * k for _lr in regular_lr]
+ return warmup_lr
+
+ if isinstance(self.regular_lr, dict):
+ lr_groups = {}
+ for key, regular_lr in self.regular_lr.items():
+ lr_groups[key] = _get_warmup_lr(cur_iters, regular_lr)
+ return lr_groups
+ else:
+ return _get_warmup_lr(cur_iters, self.regular_lr)
+
+ def before_run(self, runner):
+ # NOTE: when resuming from a checkpoint, if 'initial_lr' is not saved,
+ # it will be set according to the optimizer params
+ if isinstance(runner.optimizer, dict):
+ self.base_lr = {}
+ for k, optim in runner.optimizer.items():
+ for group in optim.param_groups:
+ group.setdefault('initial_lr', group['lr'])
+ _base_lr = [
+ group['initial_lr'] for group in optim.param_groups
+ ]
+ self.base_lr.update({k: _base_lr})
+ else:
+ for group in runner.optimizer.param_groups:
+ group.setdefault('initial_lr', group['lr'])
+ self.base_lr = [
+ group['initial_lr'] for group in runner.optimizer.param_groups
+ ]
+
+ def before_train_epoch(self, runner):
+ if self.warmup_iters is None:
+ epoch_len = len(runner.data_loader)
+ self.warmup_iters = self.warmup_epochs * epoch_len
+
+ if not self.by_epoch:
+ return
+
+ self.regular_lr = self.get_regular_lr(runner)
+ self._set_lr(runner, self.regular_lr)
+
+ def before_train_iter(self, runner):
+ cur_iter = runner.iter
+ if not self.by_epoch:
+ self.regular_lr = self.get_regular_lr(runner)
+ if self.warmup is None or cur_iter >= self.warmup_iters:
+ self._set_lr(runner, self.regular_lr)
+ else:
+ warmup_lr = self.get_warmup_lr(cur_iter)
+ self._set_lr(runner, warmup_lr)
+ elif self.by_epoch:
+ if self.warmup is None or cur_iter > self.warmup_iters:
+ return
+ elif cur_iter == self.warmup_iters:
+ self._set_lr(runner, self.regular_lr)
+ else:
+ warmup_lr = self.get_warmup_lr(cur_iter)
+ self._set_lr(runner, warmup_lr)
+
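+# Editorial sketch, not part of upstream mmcv: a worked example of the warmup
+# formulas in get_warmup_lr(). With warmup='linear', warmup_iters=500 and
+# warmup_ratio=0.5, the lr at iteration 0 is 0.5 * regular_lr, at iteration
+# 250 it is regular_lr * (1 - (1 - 250 / 500) * (1 - 0.5)) = 0.75 * regular_lr,
+# and at iteration 500 warmup ends and the regular lr is used.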
+
+@HOOKS.register_module()
+class FixedLrUpdaterHook(LrUpdaterHook):
+
+ def __init__(self, **kwargs):
+ super(FixedLrUpdaterHook, self).__init__(**kwargs)
+
+ def get_lr(self, runner, base_lr):
+ return base_lr
+
+
+@HOOKS.register_module()
+class StepLrUpdaterHook(LrUpdaterHook):
+ """Step LR scheduler with min_lr clipping.
+
+ Args:
+ step (int | list[int]): Step to decay the LR. If an int value is given,
+ regard it as the decay interval. If a list is given, decay LR at
+ these steps.
+ gamma (float, optional): Decay LR ratio. Default: 0.1.
+ min_lr (float, optional): Minimum LR value to keep. If LR after decay
+ is lower than `min_lr`, it will be clipped to this value. If None
+ is given, we don't perform lr clipping. Default: None.
+ """
+
+ def __init__(self, step, gamma=0.1, min_lr=None, **kwargs):
+ if isinstance(step, list):
+ assert mmcv.is_list_of(step, int)
+ assert all([s > 0 for s in step])
+ elif isinstance(step, int):
+ assert step > 0
+ else:
+ raise TypeError('"step" must be a list or integer')
+ self.step = step
+ self.gamma = gamma
+ self.min_lr = min_lr
+ super(StepLrUpdaterHook, self).__init__(**kwargs)
+
+ def get_lr(self, runner, base_lr):
+ progress = runner.epoch if self.by_epoch else runner.iter
+
+ # calculate exponential term
+ if isinstance(self.step, int):
+ exp = progress // self.step
+ else:
+ exp = len(self.step)
+ for i, s in enumerate(self.step):
+ if progress < s:
+ exp = i
+ break
+
+ lr = base_lr * (self.gamma**exp)
+ if self.min_lr is not None:
+ # clip to a minimum value
+ lr = max(lr, self.min_lr)
+ return lr
+
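+# Editorial sketch, not part of upstream mmcv: with step=[8, 11] and
+# gamma=0.1, get_lr() returns base_lr for epochs 0-7, 0.1 * base_lr for
+# epochs 8-10 and 0.01 * base_lr from epoch 11 on. In an mmcv-style config
+# this is commonly written as (the 'policy' key is assumed config glue, not
+# defined in this file):
+#
+#     lr_config = dict(policy='step', step=[8, 11], gamma=0.1)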
+
+@HOOKS.register_module()
+class ExpLrUpdaterHook(LrUpdaterHook):
+
+ def __init__(self, gamma, **kwargs):
+ self.gamma = gamma
+ super(ExpLrUpdaterHook, self).__init__(**kwargs)
+
+ def get_lr(self, runner, base_lr):
+ progress = runner.epoch if self.by_epoch else runner.iter
+ return base_lr * self.gamma**progress
+
+
+@HOOKS.register_module()
+class PolyLrUpdaterHook(LrUpdaterHook):
+
+ def __init__(self, power=1., min_lr=0., **kwargs):
+ self.power = power
+ self.min_lr = min_lr
+ super(PolyLrUpdaterHook, self).__init__(**kwargs)
+
+ def get_lr(self, runner, base_lr):
+ if self.by_epoch:
+ progress = runner.epoch
+ max_progress = runner.max_epochs
+ else:
+ progress = runner.iter
+ max_progress = runner.max_iters
+ coeff = (1 - progress / max_progress)**self.power
+ return (base_lr - self.min_lr) * coeff + self.min_lr
+
+
+@HOOKS.register_module()
+class InvLrUpdaterHook(LrUpdaterHook):
+
+ def __init__(self, gamma, power=1., **kwargs):
+ self.gamma = gamma
+ self.power = power
+ super(InvLrUpdaterHook, self).__init__(**kwargs)
+
+ def get_lr(self, runner, base_lr):
+ progress = runner.epoch if self.by_epoch else runner.iter
+ return base_lr * (1 + self.gamma * progress)**(-self.power)
+
+
+@HOOKS.register_module()
+class CosineAnnealingLrUpdaterHook(LrUpdaterHook):
+
+ def __init__(self, min_lr=None, min_lr_ratio=None, **kwargs):
+ assert (min_lr is None) ^ (min_lr_ratio is None)
+ self.min_lr = min_lr
+ self.min_lr_ratio = min_lr_ratio
+ super(CosineAnnealingLrUpdaterHook, self).__init__(**kwargs)
+
+ def get_lr(self, runner, base_lr):
+ if self.by_epoch:
+ progress = runner.epoch
+ max_progress = runner.max_epochs
+ else:
+ progress = runner.iter
+ max_progress = runner.max_iters
+
+ if self.min_lr_ratio is not None:
+ target_lr = base_lr * self.min_lr_ratio
+ else:
+ target_lr = self.min_lr
+ return annealing_cos(base_lr, target_lr, progress / max_progress)
+
+
+@HOOKS.register_module()
+class FlatCosineAnnealingLrUpdaterHook(LrUpdaterHook):
+ """Flat + Cosine lr schedule.
+
+ Modified from https://github.com/fastai/fastai/blob/master/fastai/callback/schedule.py#L128 # noqa: E501
+
+ Args:
+ start_percent (float): When to start annealing the learning rate
+ after the percentage of the total training steps.
+ The value should be in range [0, 1).
+ Default: 0.75
+ min_lr (float, optional): The minimum lr. Default: None.
+ min_lr_ratio (float, optional): The ratio of minimum lr to the base lr.
+ Either `min_lr` or `min_lr_ratio` should be specified.
+ Default: None.
+ """
+
+ def __init__(self,
+ start_percent=0.75,
+ min_lr=None,
+ min_lr_ratio=None,
+ **kwargs):
+ assert (min_lr is None) ^ (min_lr_ratio is None)
+ if start_percent < 0 or start_percent > 1 or not isinstance(
+ start_percent, float):
+ raise ValueError(
+                'start_percent must be a float between 0 and 1, but '
+ f'got {start_percent}')
+ self.start_percent = start_percent
+ self.min_lr = min_lr
+ self.min_lr_ratio = min_lr_ratio
+ super(FlatCosineAnnealingLrUpdaterHook, self).__init__(**kwargs)
+
+ def get_lr(self, runner, base_lr):
+ if self.by_epoch:
+ start = round(runner.max_epochs * self.start_percent)
+ progress = runner.epoch - start
+ max_progress = runner.max_epochs - start
+ else:
+ start = round(runner.max_iters * self.start_percent)
+ progress = runner.iter - start
+ max_progress = runner.max_iters - start
+
+ if self.min_lr_ratio is not None:
+ target_lr = base_lr * self.min_lr_ratio
+ else:
+ target_lr = self.min_lr
+
+ if progress < 0:
+ return base_lr
+ else:
+ return annealing_cos(base_lr, target_lr, progress / max_progress)
+
+
+@HOOKS.register_module()
+class CosineRestartLrUpdaterHook(LrUpdaterHook):
+ """Cosine annealing with restarts learning rate scheme.
+
+ Args:
+        periods (list[int]): Periods for each cosine annealing cycle.
+ restart_weights (list[float], optional): Restart weights at each
+ restart iteration. Default: [1].
+ min_lr (float, optional): The minimum lr. Default: None.
+ min_lr_ratio (float, optional): The ratio of minimum lr to the base lr.
+ Either `min_lr` or `min_lr_ratio` should be specified.
+ Default: None.
+ """
+
+ def __init__(self,
+ periods,
+ restart_weights=[1],
+ min_lr=None,
+ min_lr_ratio=None,
+ **kwargs):
+ assert (min_lr is None) ^ (min_lr_ratio is None)
+ self.periods = periods
+ self.min_lr = min_lr
+ self.min_lr_ratio = min_lr_ratio
+ self.restart_weights = restart_weights
+ assert (len(self.periods) == len(self.restart_weights)
+ ), 'periods and restart_weights should have the same length.'
+ super(CosineRestartLrUpdaterHook, self).__init__(**kwargs)
+
+ self.cumulative_periods = [
+ sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
+ ]
+
+ def get_lr(self, runner, base_lr):
+ if self.by_epoch:
+ progress = runner.epoch
+ else:
+ progress = runner.iter
+
+ if self.min_lr_ratio is not None:
+ target_lr = base_lr * self.min_lr_ratio
+ else:
+ target_lr = self.min_lr
+
+ idx = get_position_from_periods(progress, self.cumulative_periods)
+ current_weight = self.restart_weights[idx]
+ nearest_restart = 0 if idx == 0 else self.cumulative_periods[idx - 1]
+ current_periods = self.periods[idx]
+
+ alpha = min((progress - nearest_restart) / current_periods, 1)
+ return annealing_cos(base_lr, target_lr, alpha, current_weight)
+
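+# Editorial sketch, not part of upstream mmcv: with periods=[10, 20] and
+# restart_weights=[1, 0.5], cumulative_periods becomes [10, 30]. At iteration
+# 12, get_position_from_periods() (defined below) returns index 1, the
+# nearest restart is iteration 10 and alpha = (12 - 10) / 20 = 0.1, so the
+# second cycle cosine-anneals from 0.5 * base_lr + 0.5 * target_lr down to
+# target_lr.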
+
+def get_position_from_periods(iteration, cumulative_periods):
+ """Get the position from a period list.
+
+ It will return the index of the right-closest number in the period list.
+ For example, the cumulative_periods = [100, 200, 300, 400],
+ if iteration == 50, return 0;
+ if iteration == 210, return 2;
+ if iteration == 300, return 3.
+
+ Args:
+ iteration (int): Current iteration.
+ cumulative_periods (list[int]): Cumulative period list.
+
+ Returns:
+ int: The position of the right-closest number in the period list.
+ """
+ for i, period in enumerate(cumulative_periods):
+ if iteration < period:
+ return i
+ raise ValueError(f'Current iteration {iteration} exceeds '
+ f'cumulative_periods {cumulative_periods}')
+
+
+@HOOKS.register_module()
+class CyclicLrUpdaterHook(LrUpdaterHook):
+ """Cyclic LR Scheduler.
+
+ Implement the cyclical learning rate policy (CLR) described in
+ https://arxiv.org/pdf/1506.01186.pdf
+
+ Different from the original paper, we use cosine annealing rather than
+ triangular policy inside a cycle. This improves the performance in the
+ 3D detection area.
+
+ Args:
+ by_epoch (bool): Whether to update LR by epoch.
+ target_ratio (tuple[float]): Relative ratio of the highest LR and the
+ lowest LR to the initial LR.
+ cyclic_times (int): Number of cycles during training
+ step_ratio_up (float): The ratio of the increasing process of LR in
+ the total cycle.
+ anneal_strategy (str): {'cos', 'linear'}
+ Specifies the annealing strategy: 'cos' for cosine annealing,
+ 'linear' for linear annealing. Default: 'cos'.
+ """
+
+ def __init__(self,
+ by_epoch=False,
+ target_ratio=(10, 1e-4),
+ cyclic_times=1,
+ step_ratio_up=0.4,
+ anneal_strategy='cos',
+ **kwargs):
+ if isinstance(target_ratio, float):
+ target_ratio = (target_ratio, target_ratio / 1e5)
+ elif isinstance(target_ratio, tuple):
+ target_ratio = (target_ratio[0], target_ratio[0] / 1e5) \
+ if len(target_ratio) == 1 else target_ratio
+ else:
+ raise ValueError('target_ratio should be either float '
+ f'or tuple, got {type(target_ratio)}')
+
+ assert len(target_ratio) == 2, \
+ '"target_ratio" must be list or tuple of two floats'
+ assert 0 <= step_ratio_up < 1.0, \
+ '"step_ratio_up" must be in range [0,1)'
+
+ self.target_ratio = target_ratio
+ self.cyclic_times = cyclic_times
+ self.step_ratio_up = step_ratio_up
+ self.lr_phases = [] # init lr_phases
+ # validate anneal_strategy
+ if anneal_strategy not in ['cos', 'linear']:
+ raise ValueError('anneal_strategy must be one of "cos" or '
+ f'"linear", instead got {anneal_strategy}')
+ elif anneal_strategy == 'cos':
+ self.anneal_func = annealing_cos
+ elif anneal_strategy == 'linear':
+ self.anneal_func = annealing_linear
+
+ assert not by_epoch, \
+ 'currently only support "by_epoch" = False'
+ super(CyclicLrUpdaterHook, self).__init__(by_epoch, **kwargs)
+
+ def before_run(self, runner):
+ super(CyclicLrUpdaterHook, self).before_run(runner)
+ # initiate lr_phases
+ # total lr_phases are separated as up and down
+ max_iter_per_phase = runner.max_iters // self.cyclic_times
+ iter_up_phase = int(self.step_ratio_up * max_iter_per_phase)
+ self.lr_phases.append(
+ [0, iter_up_phase, max_iter_per_phase, 1, self.target_ratio[0]])
+ self.lr_phases.append([
+ iter_up_phase, max_iter_per_phase, max_iter_per_phase,
+ self.target_ratio[0], self.target_ratio[1]
+ ])
+
+ def get_lr(self, runner, base_lr):
+ curr_iter = runner.iter
+ for (start_iter, end_iter, max_iter_per_phase, start_ratio,
+ end_ratio) in self.lr_phases:
+ curr_iter %= max_iter_per_phase
+ if start_iter <= curr_iter < end_iter:
+ progress = curr_iter - start_iter
+ return self.anneal_func(base_lr * start_ratio,
+ base_lr * end_ratio,
+ progress / (end_iter - start_iter))
+
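+# Editorial sketch, not part of upstream mmcv: with runner.max_iters=1000,
+# cyclic_times=2 and step_ratio_up=0.4, each cycle spans 500 iterations with
+# a 200-iteration up phase, so before_run() builds
+# lr_phases = [[0, 200, 500, 1, 10], [200, 500, 500, 10, 1e-4]] for the
+# default target_ratio=(10, 1e-4); get_lr() then cosine-anneals between the
+# two ratios within whichever phase the (per-cycle) iteration falls into.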
+
+@HOOKS.register_module()
+class OneCycleLrUpdaterHook(LrUpdaterHook):
+ """One Cycle LR Scheduler.
+
+ The 1cycle learning rate policy changes the learning rate after every
+ batch. The one cycle learning rate policy is described in
+ https://arxiv.org/pdf/1708.07120.pdf
+
+ Args:
+ max_lr (float or list): Upper learning rate boundaries in the cycle
+ for each parameter group.
+ total_steps (int, optional): The total number of steps in the cycle.
+ Note that if a value is not provided here, it will be the max_iter
+ of runner. Default: None.
+ pct_start (float): The percentage of the cycle (in number of steps)
+ spent increasing the learning rate.
+ Default: 0.3
+ anneal_strategy (str): {'cos', 'linear'}
+ Specifies the annealing strategy: 'cos' for cosine annealing,
+ 'linear' for linear annealing.
+ Default: 'cos'
+ div_factor (float): Determines the initial learning rate via
+ initial_lr = max_lr/div_factor
+ Default: 25
+ final_div_factor (float): Determines the minimum learning rate via
+ min_lr = initial_lr/final_div_factor
+ Default: 1e4
+ three_phase (bool): If three_phase is True, use a third phase of the
+ schedule to annihilate the learning rate according to
+ final_div_factor instead of modifying the second phase (the first
+ two phases will be symmetrical about the step indicated by
+ pct_start).
+ Default: False
+ """
+
+ def __init__(self,
+ max_lr,
+ total_steps=None,
+ pct_start=0.3,
+ anneal_strategy='cos',
+ div_factor=25,
+ final_div_factor=1e4,
+ three_phase=False,
+ **kwargs):
+ # validate by_epoch, currently only support by_epoch = False
+ if 'by_epoch' not in kwargs:
+ kwargs['by_epoch'] = False
+ else:
+ assert not kwargs['by_epoch'], \
+ 'currently only support "by_epoch" = False'
+ if not isinstance(max_lr, (numbers.Number, list, dict)):
+            raise ValueError('max_lr must be a number, list or dict, '
+                             f'but got {type(max_lr)}')
+ self._max_lr = max_lr
+ if total_steps is not None:
+ if not isinstance(total_steps, int):
+                raise ValueError('the type of total_steps must be int, but '
+                                 f'got {type(total_steps)}')
+ self.total_steps = total_steps
+ # validate pct_start
+ if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float):
+            raise ValueError('pct_start must be a float between 0 and 1, '
+                             f'but got {pct_start}')
+ self.pct_start = pct_start
+ # validate anneal_strategy
+ if anneal_strategy not in ['cos', 'linear']:
+ raise ValueError('anneal_strategy must be one of "cos" or '
+ f'"linear", instead got {anneal_strategy}')
+ elif anneal_strategy == 'cos':
+ self.anneal_func = annealing_cos
+ elif anneal_strategy == 'linear':
+ self.anneal_func = annealing_linear
+ self.div_factor = div_factor
+ self.final_div_factor = final_div_factor
+ self.three_phase = three_phase
+ self.lr_phases = [] # init lr_phases
+ super(OneCycleLrUpdaterHook, self).__init__(**kwargs)
+
+ def before_run(self, runner):
+ if hasattr(self, 'total_steps'):
+ total_steps = self.total_steps
+ else:
+ total_steps = runner.max_iters
+ if total_steps < runner.max_iters:
+ raise ValueError(
+ 'The total steps must be greater than or equal to max '
+ f'iterations {runner.max_iters} of runner, but total steps '
+ f'is {total_steps}.')
+
+ if isinstance(runner.optimizer, dict):
+ self.base_lr = {}
+ for k, optim in runner.optimizer.items():
+ _max_lr = format_param(k, optim, self._max_lr)
+ self.base_lr[k] = [lr / self.div_factor for lr in _max_lr]
+ for group, lr in zip(optim.param_groups, self.base_lr[k]):
+ group.setdefault('initial_lr', lr)
+ else:
+ k = type(runner.optimizer).__name__
+ _max_lr = format_param(k, runner.optimizer, self._max_lr)
+ self.base_lr = [lr / self.div_factor for lr in _max_lr]
+ for group, lr in zip(runner.optimizer.param_groups, self.base_lr):
+ group.setdefault('initial_lr', lr)
+
+ if self.three_phase:
+ self.lr_phases.append(
+ [float(self.pct_start * total_steps) - 1, 1, self.div_factor])
+ self.lr_phases.append([
+ float(2 * self.pct_start * total_steps) - 2, self.div_factor, 1
+ ])
+ self.lr_phases.append(
+ [total_steps - 1, 1, 1 / self.final_div_factor])
+ else:
+ self.lr_phases.append(
+ [float(self.pct_start * total_steps) - 1, 1, self.div_factor])
+ self.lr_phases.append(
+ [total_steps - 1, self.div_factor, 1 / self.final_div_factor])
+
+ def get_lr(self, runner, base_lr):
+ curr_iter = runner.iter
+ start_iter = 0
+ for i, (end_iter, start_lr, end_lr) in enumerate(self.lr_phases):
+ if curr_iter <= end_iter:
+ pct = (curr_iter - start_iter) / (end_iter - start_iter)
+ lr = self.anneal_func(base_lr * start_lr, base_lr * end_lr,
+ pct)
+ break
+ start_iter = end_iter
+ return lr
+
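+# Editorial sketch, not part of upstream mmcv: with max_lr=0.01,
+# div_factor=25 and final_div_factor=1e4, before_run() sets
+# initial_lr = 0.01 / 25 = 4e-4 and the final lr bound to 4e-4 / 1e4 = 4e-8;
+# the lr rises from initial_lr to max_lr over the first pct_start fraction of
+# total_steps and is then annealed back down to the final bound.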
+
+def annealing_cos(start, end, factor, weight=1):
+ """Calculate annealing cos learning rate.
+
+ Cosine anneal from `weight * start + (1 - weight) * end` to `end` as
+ percentage goes from 0.0 to 1.0.
+
+ Args:
+ start (float): The starting learning rate of the cosine annealing.
+        end (float): The ending learning rate of the cosine annealing.
+ factor (float): The coefficient of `pi` when calculating the current
+ percentage. Range from 0.0 to 1.0.
+ weight (float, optional): The combination factor of `start` and `end`
+ when calculating the actual starting learning rate. Default to 1.
+ """
+ cos_out = cos(pi * factor) + 1
+ return end + 0.5 * weight * (start - end) * cos_out
+
+
+def annealing_linear(start, end, factor):
+ """Calculate annealing linear learning rate.
+
+ Linear anneal from `start` to `end` as percentage goes from 0.0 to 1.0.
+
+ Args:
+ start (float): The starting learning rate of the linear annealing.
+        end (float): The ending learning rate of the linear annealing.
+        factor (float): The fraction of the annealing process that has
+            elapsed. Range from 0.0 to 1.0.
+ """
+ return start + (end - start) * factor
+
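+# Editorial sketch, not part of upstream mmcv: a worked example of the two
+# helpers above. annealing_cos(1.0, 0.0, 0.5) = 0.0 + 0.5 * 1 * (cos(pi / 2)
+# + 1) = 0.5 and annealing_linear(1.0, 0.0, 0.5) = 1.0 + (0.0 - 1.0) * 0.5
+# = 0.5; both interpolate between `start` and `end`, differing only in the
+# shape of the path.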
+
+def format_param(name, optim, param):
+ if isinstance(param, numbers.Number):
+ return [param] * len(optim.param_groups)
+ elif isinstance(param, (list, tuple)): # multi param groups
+ if len(param) != len(optim.param_groups):
+ raise ValueError(f'expected {len(optim.param_groups)} '
+ f'values for {name}, got {len(param)}')
+ return param
+ else: # multi optimizers
+ if name not in param:
+ raise KeyError(f'{name} is not found in {param.keys()}')
+ return param[name]
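+
+# Editorial sketch, not part of upstream mmcv: format_param() broadcasts a
+# scalar to every param group, keeps a matching-length list as-is and indexes
+# a dict by optimizer name. For an optimizer ``optim`` with two param groups:
+#
+#     format_param('SGD', optim, 0.1)           # -> [0.1, 0.1]
+#     format_param('SGD', optim, [0.1, 0.01])   # -> [0.1, 0.01]
+#     format_param('SGD', optim, {'SGD': 0.1})  # -> 0.1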
diff --git a/annotator/uniformer/mmcv/runner/hooks/memory.py b/annotator/uniformer/mmcv/runner/hooks/memory.py
new file mode 100644
index 0000000000000000000000000000000000000000..70cf9a838fb314e3bd3c07aadbc00921a81e83ed
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/hooks/memory.py
@@ -0,0 +1,25 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from .hook import HOOKS, Hook
+
+
+@HOOKS.register_module()
+class EmptyCacheHook(Hook):
+
+ def __init__(self, before_epoch=False, after_epoch=True, after_iter=False):
+ self._before_epoch = before_epoch
+ self._after_epoch = after_epoch
+ self._after_iter = after_iter
+
+ def after_iter(self, runner):
+ if self._after_iter:
+ torch.cuda.empty_cache()
+
+ def before_epoch(self, runner):
+ if self._before_epoch:
+ torch.cuda.empty_cache()
+
+ def after_epoch(self, runner):
+ if self._after_epoch:
+ torch.cuda.empty_cache()
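+
+# Editorial sketch, not part of upstream mmcv: the hook can be attached to a
+# runner directly via BaseRunner.register_hook, e.g.
+#
+#     runner.register_hook(
+#         EmptyCacheHook(before_epoch=False, after_epoch=True,
+#                        after_iter=False))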
diff --git a/annotator/uniformer/mmcv/runner/hooks/momentum_updater.py b/annotator/uniformer/mmcv/runner/hooks/momentum_updater.py
new file mode 100644
index 0000000000000000000000000000000000000000..60437756ceedf06055ec349df69a25465738d3f0
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/hooks/momentum_updater.py
@@ -0,0 +1,493 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import annotator.uniformer.mmcv as mmcv
+from .hook import HOOKS, Hook
+from .lr_updater import annealing_cos, annealing_linear, format_param
+
+
+class MomentumUpdaterHook(Hook):
+
+ def __init__(self,
+ by_epoch=True,
+ warmup=None,
+ warmup_iters=0,
+ warmup_ratio=0.9):
+ # validate the "warmup" argument
+ if warmup is not None:
+ if warmup not in ['constant', 'linear', 'exp']:
+ raise ValueError(
+ f'"{warmup}" is not a supported type for warming up, valid'
+                    ' types are "constant", "linear" and "exp"')
+ if warmup is not None:
+ assert warmup_iters > 0, \
+ '"warmup_iters" must be a positive integer'
+ assert 0 < warmup_ratio <= 1.0, \
+                '"warmup_ratio" must be in range (0,1]'
+
+ self.by_epoch = by_epoch
+ self.warmup = warmup
+ self.warmup_iters = warmup_iters
+ self.warmup_ratio = warmup_ratio
+
+ self.base_momentum = [] # initial momentum for all param groups
+ self.regular_momentum = [
+ ] # expected momentum if no warming up is performed
+
+ def _set_momentum(self, runner, momentum_groups):
+ if isinstance(runner.optimizer, dict):
+ for k, optim in runner.optimizer.items():
+ for param_group, mom in zip(optim.param_groups,
+ momentum_groups[k]):
+ if 'momentum' in param_group.keys():
+ param_group['momentum'] = mom
+ elif 'betas' in param_group.keys():
+ param_group['betas'] = (mom, param_group['betas'][1])
+ else:
+ for param_group, mom in zip(runner.optimizer.param_groups,
+ momentum_groups):
+ if 'momentum' in param_group.keys():
+ param_group['momentum'] = mom
+ elif 'betas' in param_group.keys():
+ param_group['betas'] = (mom, param_group['betas'][1])
+
+ def get_momentum(self, runner, base_momentum):
+ raise NotImplementedError
+
+ def get_regular_momentum(self, runner):
+ if isinstance(runner.optimizer, dict):
+ momentum_groups = {}
+ for k in runner.optimizer.keys():
+ _momentum_group = [
+ self.get_momentum(runner, _base_momentum)
+ for _base_momentum in self.base_momentum[k]
+ ]
+ momentum_groups.update({k: _momentum_group})
+ return momentum_groups
+ else:
+ return [
+ self.get_momentum(runner, _base_momentum)
+ for _base_momentum in self.base_momentum
+ ]
+
+ def get_warmup_momentum(self, cur_iters):
+
+ def _get_warmup_momentum(cur_iters, regular_momentum):
+            if self.warmup == 'constant':
+                warmup_momentum = [
+                    _momentum / self.warmup_ratio
+                    for _momentum in regular_momentum
+                ]
+            elif self.warmup == 'linear':
+                k = (1 - cur_iters / self.warmup_iters) * (1 -
+                                                           self.warmup_ratio)
+                warmup_momentum = [
+                    _momentum / (1 - k) for _momentum in regular_momentum
+                ]
+            elif self.warmup == 'exp':
+                k = self.warmup_ratio**(1 - cur_iters / self.warmup_iters)
+                warmup_momentum = [
+                    _momentum / k for _momentum in regular_momentum
+                ]
+            return warmup_momentum
+
+ if isinstance(self.regular_momentum, dict):
+ momentum_groups = {}
+ for key, regular_momentum in self.regular_momentum.items():
+ momentum_groups[key] = _get_warmup_momentum(
+ cur_iters, regular_momentum)
+ return momentum_groups
+ else:
+ return _get_warmup_momentum(cur_iters, self.regular_momentum)
+
+ def before_run(self, runner):
+ # NOTE: when resuming from a checkpoint,
+ # if 'initial_momentum' is not saved,
+ # it will be set according to the optimizer params
+ if isinstance(runner.optimizer, dict):
+ self.base_momentum = {}
+ for k, optim in runner.optimizer.items():
+ for group in optim.param_groups:
+ if 'momentum' in group.keys():
+ group.setdefault('initial_momentum', group['momentum'])
+ else:
+ group.setdefault('initial_momentum', group['betas'][0])
+ _base_momentum = [
+ group['initial_momentum'] for group in optim.param_groups
+ ]
+ self.base_momentum.update({k: _base_momentum})
+ else:
+ for group in runner.optimizer.param_groups:
+ if 'momentum' in group.keys():
+ group.setdefault('initial_momentum', group['momentum'])
+ else:
+ group.setdefault('initial_momentum', group['betas'][0])
+ self.base_momentum = [
+ group['initial_momentum']
+ for group in runner.optimizer.param_groups
+ ]
+
+ def before_train_epoch(self, runner):
+ if not self.by_epoch:
+ return
+        self.regular_momentum = self.get_regular_momentum(runner)
+        self._set_momentum(runner, self.regular_momentum)
+
+ def before_train_iter(self, runner):
+ cur_iter = runner.iter
+ if not self.by_epoch:
+            self.regular_momentum = self.get_regular_momentum(runner)
+            if self.warmup is None or cur_iter >= self.warmup_iters:
+                self._set_momentum(runner, self.regular_momentum)
+            else:
+                warmup_momentum = self.get_warmup_momentum(cur_iter)
+                self._set_momentum(runner, warmup_momentum)
+        elif self.by_epoch:
+            if self.warmup is None or cur_iter > self.warmup_iters:
+                return
+            elif cur_iter == self.warmup_iters:
+                self._set_momentum(runner, self.regular_momentum)
+ else:
+ warmup_momentum = self.get_warmup_momentum(cur_iter)
+ self._set_momentum(runner, warmup_momentum)
+
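+# Editorial note, not part of upstream mmcv: for Adam-style optimizers that
+# expose ``betas`` instead of ``momentum``, _set_momentum() above rewrites
+# only the first beta, e.g. a param group with betas=(0.9, 0.999) updated
+# with mom=0.85 ends up with betas=(0.85, 0.999); beta2 is never touched.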
+
+@HOOKS.register_module()
+class StepMomentumUpdaterHook(MomentumUpdaterHook):
+ """Step momentum scheduler with min value clipping.
+
+ Args:
+ step (int | list[int]): Step to decay the momentum. If an int value is
+ given, regard it as the decay interval. If a list is given, decay
+ momentum at these steps.
+ gamma (float, optional): Decay momentum ratio. Default: 0.5.
+ min_momentum (float, optional): Minimum momentum value to keep. If
+ momentum after decay is lower than this value, it will be clipped
+            accordingly. If None is given, we don't perform momentum clipping.
+ Default: None.
+ """
+
+ def __init__(self, step, gamma=0.5, min_momentum=None, **kwargs):
+ if isinstance(step, list):
+ assert mmcv.is_list_of(step, int)
+ assert all([s > 0 for s in step])
+ elif isinstance(step, int):
+ assert step > 0
+ else:
+ raise TypeError('"step" must be a list or integer')
+ self.step = step
+ self.gamma = gamma
+ self.min_momentum = min_momentum
+ super(StepMomentumUpdaterHook, self).__init__(**kwargs)
+
+ def get_momentum(self, runner, base_momentum):
+ progress = runner.epoch if self.by_epoch else runner.iter
+
+ # calculate exponential term
+ if isinstance(self.step, int):
+ exp = progress // self.step
+ else:
+ exp = len(self.step)
+ for i, s in enumerate(self.step):
+ if progress < s:
+ exp = i
+ break
+
+ momentum = base_momentum * (self.gamma**exp)
+ if self.min_momentum is not None:
+ # clip to a minimum value
+ momentum = max(momentum, self.min_momentum)
+ return momentum
+
+
+@HOOKS.register_module()
+class CosineAnnealingMomentumUpdaterHook(MomentumUpdaterHook):
+
+ def __init__(self, min_momentum=None, min_momentum_ratio=None, **kwargs):
+ assert (min_momentum is None) ^ (min_momentum_ratio is None)
+ self.min_momentum = min_momentum
+ self.min_momentum_ratio = min_momentum_ratio
+ super(CosineAnnealingMomentumUpdaterHook, self).__init__(**kwargs)
+
+ def get_momentum(self, runner, base_momentum):
+ if self.by_epoch:
+ progress = runner.epoch
+ max_progress = runner.max_epochs
+ else:
+ progress = runner.iter
+ max_progress = runner.max_iters
+ if self.min_momentum_ratio is not None:
+ target_momentum = base_momentum * self.min_momentum_ratio
+ else:
+ target_momentum = self.min_momentum
+ return annealing_cos(base_momentum, target_momentum,
+ progress / max_progress)
+
+
+@HOOKS.register_module()
+class CyclicMomentumUpdaterHook(MomentumUpdaterHook):
+ """Cyclic momentum Scheduler.
+
+ Implement the cyclical momentum scheduler policy described in
+ https://arxiv.org/pdf/1708.07120.pdf
+
+    This momentum scheduler is usually used together with CyclicLrUpdaterHook
+    to improve performance in the 3D detection area.
+
+ Attributes:
+ target_ratio (tuple[float]): Relative ratio of the lowest momentum and
+ the highest momentum to the initial momentum.
+ cyclic_times (int): Number of cycles during training
+ step_ratio_up (float): The ratio of the increasing process of momentum
+ in the total cycle.
+ by_epoch (bool): Whether to update momentum by epoch.
+ """
+
+ def __init__(self,
+ by_epoch=False,
+ target_ratio=(0.85 / 0.95, 1),
+ cyclic_times=1,
+ step_ratio_up=0.4,
+ **kwargs):
+ if isinstance(target_ratio, float):
+ target_ratio = (target_ratio, target_ratio / 1e5)
+ elif isinstance(target_ratio, tuple):
+ target_ratio = (target_ratio[0], target_ratio[0] / 1e5) \
+ if len(target_ratio) == 1 else target_ratio
+ else:
+ raise ValueError('target_ratio should be either float '
+ f'or tuple, got {type(target_ratio)}')
+
+ assert len(target_ratio) == 2, \
+ '"target_ratio" must be list or tuple of two floats'
+ assert 0 <= step_ratio_up < 1.0, \
+ '"step_ratio_up" must be in range [0,1)'
+
+ self.target_ratio = target_ratio
+ self.cyclic_times = cyclic_times
+ self.step_ratio_up = step_ratio_up
+ self.momentum_phases = [] # init momentum_phases
+ # currently only support by_epoch=False
+ assert not by_epoch, \
+ 'currently only support "by_epoch" = False'
+ super(CyclicMomentumUpdaterHook, self).__init__(by_epoch, **kwargs)
+
+ def before_run(self, runner):
+ super(CyclicMomentumUpdaterHook, self).before_run(runner)
+ # initiate momentum_phases
+ # total momentum_phases are separated as up and down
+ max_iter_per_phase = runner.max_iters // self.cyclic_times
+ iter_up_phase = int(self.step_ratio_up * max_iter_per_phase)
+ self.momentum_phases.append(
+ [0, iter_up_phase, max_iter_per_phase, 1, self.target_ratio[0]])
+ self.momentum_phases.append([
+ iter_up_phase, max_iter_per_phase, max_iter_per_phase,
+ self.target_ratio[0], self.target_ratio[1]
+ ])
+
+ def get_momentum(self, runner, base_momentum):
+ curr_iter = runner.iter
+ for (start_iter, end_iter, max_iter_per_phase, start_ratio,
+ end_ratio) in self.momentum_phases:
+ curr_iter %= max_iter_per_phase
+ if start_iter <= curr_iter < end_iter:
+ progress = curr_iter - start_iter
+ return annealing_cos(base_momentum * start_ratio,
+ base_momentum * end_ratio,
+ progress / (end_iter - start_iter))
+
+
+@HOOKS.register_module()
+class OneCycleMomentumUpdaterHook(MomentumUpdaterHook):
+ """OneCycle momentum Scheduler.
+
+    This momentum scheduler is usually used together with the
+    OneCycleLrUpdaterHook to improve performance.
+
+ Args:
+ base_momentum (float or list): Lower momentum boundaries in the cycle
+ for each parameter group. Note that momentum is cycled inversely
+ to learning rate; at the peak of a cycle, momentum is
+ 'base_momentum' and learning rate is 'max_lr'.
+ Default: 0.85
+ max_momentum (float or list): Upper momentum boundaries in the cycle
+ for each parameter group. Functionally,
+ it defines the cycle amplitude (max_momentum - base_momentum).
+ Note that momentum is cycled inversely
+ to learning rate; at the start of a cycle, momentum is
+ 'max_momentum' and learning rate is 'base_lr'
+ Default: 0.95
+ pct_start (float): The percentage of the cycle (in number of steps)
+ spent increasing the learning rate.
+ Default: 0.3
+ anneal_strategy (str): {'cos', 'linear'}
+ Specifies the annealing strategy: 'cos' for cosine annealing,
+ 'linear' for linear annealing.
+ Default: 'cos'
+ three_phase (bool): If three_phase is True, use a third phase of the
+ schedule to annihilate the learning rate according to
+ final_div_factor instead of modifying the second phase (the first
+ two phases will be symmetrical about the step indicated by
+ pct_start).
+ Default: False
+ """
+
+ def __init__(self,
+ base_momentum=0.85,
+ max_momentum=0.95,
+ pct_start=0.3,
+ anneal_strategy='cos',
+ three_phase=False,
+ **kwargs):
+ # validate by_epoch, currently only support by_epoch=False
+ if 'by_epoch' not in kwargs:
+ kwargs['by_epoch'] = False
+ else:
+ assert not kwargs['by_epoch'], \
+ 'currently only support "by_epoch" = False'
+ if not isinstance(base_momentum, (float, list, dict)):
+            raise ValueError('base_momentum must be of type float, '
+                             'list or dict.')
+ self._base_momentum = base_momentum
+ if not isinstance(max_momentum, (float, list, dict)):
+            raise ValueError('max_momentum must be of type float, '
+                             'list or dict.')
+ self._max_momentum = max_momentum
+ # validate pct_start
+ if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float):
+            raise ValueError('pct_start must be a float between 0 and 1, '
+                             f'but got {pct_start}')
+ self.pct_start = pct_start
+ # validate anneal_strategy
+ if anneal_strategy not in ['cos', 'linear']:
+            raise ValueError('anneal_strategy must be one of "cos" or '
+ f'"linear", instead got {anneal_strategy}')
+ elif anneal_strategy == 'cos':
+ self.anneal_func = annealing_cos
+ elif anneal_strategy == 'linear':
+ self.anneal_func = annealing_linear
+ self.three_phase = three_phase
+ self.momentum_phases = [] # init momentum_phases
+ super(OneCycleMomentumUpdaterHook, self).__init__(**kwargs)
+
+ def before_run(self, runner):
+ if isinstance(runner.optimizer, dict):
+ for k, optim in runner.optimizer.items():
+ if ('momentum' not in optim.defaults
+ and 'betas' not in optim.defaults):
+                    raise ValueError('optimizer must support momentum with '
+                                     'the "momentum" or "betas" option')
+ self.use_beta1 = 'betas' in optim.defaults
+ _base_momentum = format_param(k, optim, self._base_momentum)
+ _max_momentum = format_param(k, optim, self._max_momentum)
+ for group, b_momentum, m_momentum in zip(
+ optim.param_groups, _base_momentum, _max_momentum):
+ if self.use_beta1:
+ _, beta2 = group['betas']
+ group['betas'] = (m_momentum, beta2)
+ else:
+ group['momentum'] = m_momentum
+ group['base_momentum'] = b_momentum
+ group['max_momentum'] = m_momentum
+ else:
+ optim = runner.optimizer
+ if ('momentum' not in optim.defaults
+ and 'betas' not in optim.defaults):
+                raise ValueError('optimizer must support momentum with '
+                                 'the "momentum" or "betas" option')
+ self.use_beta1 = 'betas' in optim.defaults
+ k = type(optim).__name__
+ _base_momentum = format_param(k, optim, self._base_momentum)
+ _max_momentum = format_param(k, optim, self._max_momentum)
+ for group, b_momentum, m_momentum in zip(optim.param_groups,
+ _base_momentum,
+ _max_momentum):
+ if self.use_beta1:
+ _, beta2 = group['betas']
+ group['betas'] = (m_momentum, beta2)
+ else:
+ group['momentum'] = m_momentum
+ group['base_momentum'] = b_momentum
+ group['max_momentum'] = m_momentum
+
+ if self.three_phase:
+ self.momentum_phases.append({
+ 'end_iter':
+ float(self.pct_start * runner.max_iters) - 1,
+ 'start_momentum':
+ 'max_momentum',
+ 'end_momentum':
+ 'base_momentum'
+ })
+ self.momentum_phases.append({
+ 'end_iter':
+ float(2 * self.pct_start * runner.max_iters) - 2,
+ 'start_momentum':
+ 'base_momentum',
+ 'end_momentum':
+ 'max_momentum'
+ })
+ self.momentum_phases.append({
+ 'end_iter': runner.max_iters - 1,
+ 'start_momentum': 'max_momentum',
+ 'end_momentum': 'max_momentum'
+ })
+ else:
+ self.momentum_phases.append({
+ 'end_iter':
+ float(self.pct_start * runner.max_iters) - 1,
+ 'start_momentum':
+ 'max_momentum',
+ 'end_momentum':
+ 'base_momentum'
+ })
+ self.momentum_phases.append({
+ 'end_iter': runner.max_iters - 1,
+ 'start_momentum': 'base_momentum',
+ 'end_momentum': 'max_momentum'
+ })
+
+ def _set_momentum(self, runner, momentum_groups):
+ if isinstance(runner.optimizer, dict):
+ for k, optim in runner.optimizer.items():
+ for param_group, mom in zip(optim.param_groups,
+ momentum_groups[k]):
+ if 'momentum' in param_group.keys():
+ param_group['momentum'] = mom
+ elif 'betas' in param_group.keys():
+ param_group['betas'] = (mom, param_group['betas'][1])
+ else:
+ for param_group, mom in zip(runner.optimizer.param_groups,
+ momentum_groups):
+ if 'momentum' in param_group.keys():
+ param_group['momentum'] = mom
+ elif 'betas' in param_group.keys():
+ param_group['betas'] = (mom, param_group['betas'][1])
+
+ def get_momentum(self, runner, param_group):
+ curr_iter = runner.iter
+ start_iter = 0
+ for i, phase in enumerate(self.momentum_phases):
+ end_iter = phase['end_iter']
+ if curr_iter <= end_iter or i == len(self.momentum_phases) - 1:
+ pct = (curr_iter - start_iter) / (end_iter - start_iter)
+ momentum = self.anneal_func(
+ param_group[phase['start_momentum']],
+ param_group[phase['end_momentum']], pct)
+ break
+ start_iter = end_iter
+ return momentum
+
+ def get_regular_momentum(self, runner):
+ if isinstance(runner.optimizer, dict):
+ momentum_groups = {}
+ for k, optim in runner.optimizer.items():
+ _momentum_group = [
+ self.get_momentum(runner, param_group)
+ for param_group in optim.param_groups
+ ]
+ momentum_groups.update({k: _momentum_group})
+ return momentum_groups
+ else:
+ momentum_groups = []
+ for param_group in runner.optimizer.param_groups:
+ momentum_groups.append(self.get_momentum(runner, param_group))
+ return momentum_groups
diff --git a/annotator/uniformer/mmcv/runner/hooks/optimizer.py b/annotator/uniformer/mmcv/runner/hooks/optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ef3e9ff8f9c6926e32bdf027612267b64ed80df
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/hooks/optimizer.py
@@ -0,0 +1,508 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+from collections import defaultdict
+from itertools import chain
+
+from torch.nn.utils import clip_grad
+
+from annotator.uniformer.mmcv.utils import TORCH_VERSION, _BatchNorm, digit_version
+from ..dist_utils import allreduce_grads
+from ..fp16_utils import LossScaler, wrap_fp16_model
+from .hook import HOOKS, Hook
+
+try:
+ # If PyTorch version >= 1.6.0, torch.cuda.amp.GradScaler would be imported
+ # and used; otherwise, auto fp16 will adopt mmcv's implementation.
+ from torch.cuda.amp import GradScaler
+except ImportError:
+ pass
+
+
+@HOOKS.register_module()
+class OptimizerHook(Hook):
+
+ def __init__(self, grad_clip=None):
+ self.grad_clip = grad_clip
+
+ def clip_grads(self, params):
+ params = list(
+ filter(lambda p: p.requires_grad and p.grad is not None, params))
+ if len(params) > 0:
+ return clip_grad.clip_grad_norm_(params, **self.grad_clip)
+
+ def after_train_iter(self, runner):
+ runner.optimizer.zero_grad()
+ runner.outputs['loss'].backward()
+ if self.grad_clip is not None:
+ grad_norm = self.clip_grads(runner.model.parameters())
+ if grad_norm is not None:
+ # Add grad norm to the logger
+ runner.log_buffer.update({'grad_norm': float(grad_norm)},
+ runner.outputs['num_samples'])
+ runner.optimizer.step()
+
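+# Editorial sketch, not part of upstream mmcv: ``grad_clip`` is forwarded
+# verbatim to torch.nn.utils.clip_grad_norm_, so an mmcv-style config such as
+#
+#     optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
+#
+# clips the total gradient norm to 35 before every optimizer step, while
+# grad_clip=None leaves the hook to just zero grads, backprop and step.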
+
+@HOOKS.register_module()
+class GradientCumulativeOptimizerHook(OptimizerHook):
+ """Optimizer Hook implements multi-iters gradient cumulating.
+
+ Args:
+ cumulative_iters (int, optional): Num of gradient cumulative iters.
+ The optimizer will step every `cumulative_iters` iters.
+ Defaults to 1.
+
+ Examples:
+ >>> # Use cumulative_iters to simulate a large batch size
+ >>> # It is helpful when the hardware cannot handle a large batch size.
+ >>> loader = DataLoader(data, batch_size=64)
+ >>> optim_hook = GradientCumulativeOptimizerHook(cumulative_iters=4)
+ >>> # almost equals to
+ >>> loader = DataLoader(data, batch_size=256)
+ >>> optim_hook = OptimizerHook()
+ """
+
+ def __init__(self, cumulative_iters=1, **kwargs):
+ super(GradientCumulativeOptimizerHook, self).__init__(**kwargs)
+
+ assert isinstance(cumulative_iters, int) and cumulative_iters > 0, \
+ f'cumulative_iters only accepts positive int, but got ' \
+ f'{type(cumulative_iters)} instead.'
+
+ self.cumulative_iters = cumulative_iters
+ self.divisible_iters = 0
+ self.remainder_iters = 0
+ self.initialized = False
+
+ def has_batch_norm(self, module):
+ if isinstance(module, _BatchNorm):
+ return True
+ for m in module.children():
+ if self.has_batch_norm(m):
+ return True
+ return False
+
+ def _init(self, runner):
+ if runner.iter % self.cumulative_iters != 0:
+ runner.logger.warning(
+ 'Resume iter number is not divisible by cumulative_iters in '
+ 'GradientCumulativeOptimizerHook, which means the gradient of '
+ 'some iters is lost and the result may be influenced slightly.'
+ )
+
+ if self.has_batch_norm(runner.model) and self.cumulative_iters > 1:
+ runner.logger.warning(
+ 'GradientCumulativeOptimizerHook may slightly decrease '
+ 'performance if the model has BatchNorm layers.')
+
+ residual_iters = runner.max_iters - runner.iter
+
+ self.divisible_iters = (
+ residual_iters // self.cumulative_iters * self.cumulative_iters)
+ self.remainder_iters = residual_iters - self.divisible_iters
+
+ self.initialized = True
+
+ def after_train_iter(self, runner):
+ if not self.initialized:
+ self._init(runner)
+
+ if runner.iter < self.divisible_iters:
+ loss_factor = self.cumulative_iters
+ else:
+ loss_factor = self.remainder_iters
+ loss = runner.outputs['loss']
+ loss = loss / loss_factor
+ loss.backward()
+
+ if (self.every_n_iters(runner, self.cumulative_iters)
+ or self.is_last_iter(runner)):
+
+ if self.grad_clip is not None:
+ grad_norm = self.clip_grads(runner.model.parameters())
+ if grad_norm is not None:
+ # Add grad norm to the logger
+ runner.log_buffer.update({'grad_norm': float(grad_norm)},
+ runner.outputs['num_samples'])
+ runner.optimizer.step()
+ runner.optimizer.zero_grad()
+
+
+if (TORCH_VERSION != 'parrots'
+ and digit_version(TORCH_VERSION) >= digit_version('1.6.0')):
+
+ @HOOKS.register_module()
+ class Fp16OptimizerHook(OptimizerHook):
+ """FP16 optimizer hook (using PyTorch's implementation).
+
+ If you are using PyTorch >= 1.6, torch.cuda.amp is used as the backend,
+ to take care of the optimization procedure.
+
+ Args:
+ loss_scale (float | str | dict): Scale factor configuration.
+ If loss_scale is a float, static loss scaling will be used with
+ the specified scale. If loss_scale is a string, it must be
+ 'dynamic', then dynamic loss scaling will be used.
+ It can also be a dict containing arguments of GradScalar.
+ Defaults to 512. For Pytorch >= 1.6, mmcv uses official
+ implementation of GradScaler. If you use a dict version of
+ loss_scale to create GradScaler, please refer to:
+ https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler
+ for the parameters.
+
+ Examples:
+ >>> loss_scale = dict(
+ ... init_scale=65536.0,
+ ... growth_factor=2.0,
+ ... backoff_factor=0.5,
+ ... growth_interval=2000
+ ... )
+ >>> optimizer_hook = Fp16OptimizerHook(loss_scale=loss_scale)
+ """
+
+ def __init__(self,
+ grad_clip=None,
+ coalesce=True,
+ bucket_size_mb=-1,
+ loss_scale=512.,
+ distributed=True):
+ self.grad_clip = grad_clip
+ self.coalesce = coalesce
+ self.bucket_size_mb = bucket_size_mb
+ self.distributed = distributed
+ self._scale_update_param = None
+ if loss_scale == 'dynamic':
+ self.loss_scaler = GradScaler()
+ elif isinstance(loss_scale, float):
+ self._scale_update_param = loss_scale
+ self.loss_scaler = GradScaler(init_scale=loss_scale)
+ elif isinstance(loss_scale, dict):
+ self.loss_scaler = GradScaler(**loss_scale)
+ else:
+ raise ValueError('loss_scale must be of type float, dict, or '
+ f'"dynamic", got {loss_scale}')
+
+ def before_run(self, runner):
+ """Preparing steps before Mixed Precision Training."""
+ # wrap model mode to fp16
+ wrap_fp16_model(runner.model)
+ # resume from state dict
+ if 'fp16' in runner.meta and 'loss_scaler' in runner.meta['fp16']:
+ scaler_state_dict = runner.meta['fp16']['loss_scaler']
+ self.loss_scaler.load_state_dict(scaler_state_dict)
+
+ def copy_grads_to_fp32(self, fp16_net, fp32_weights):
+ """Copy gradients from fp16 model to fp32 weight copy."""
+ for fp32_param, fp16_param in zip(fp32_weights,
+ fp16_net.parameters()):
+ if fp16_param.grad is not None:
+ if fp32_param.grad is None:
+ fp32_param.grad = fp32_param.data.new(
+ fp32_param.size())
+ fp32_param.grad.copy_(fp16_param.grad)
+
+ def copy_params_to_fp16(self, fp16_net, fp32_weights):
+ """Copy updated params from fp32 weight copy to fp16 model."""
+ for fp16_param, fp32_param in zip(fp16_net.parameters(),
+ fp32_weights):
+ fp16_param.data.copy_(fp32_param.data)
+
+ def after_train_iter(self, runner):
+ """Backward optimization steps for Mixed Precision Training. For
+ dynamic loss scaling, please refer to
+ https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler.
+
+ 1. Scale the loss by a scale factor.
+ 2. Backward the loss to obtain the gradients.
+ 3. Unscale the optimizer’s gradient tensors.
+ 4. Call optimizer.step() and update scale factor.
+ 5. Save loss_scaler state_dict for resume purpose.
+ """
+ # clear grads of last iteration
+ runner.model.zero_grad()
+ runner.optimizer.zero_grad()
+
+ self.loss_scaler.scale(runner.outputs['loss']).backward()
+ self.loss_scaler.unscale_(runner.optimizer)
+ # grad clip
+ if self.grad_clip is not None:
+ grad_norm = self.clip_grads(runner.model.parameters())
+ if grad_norm is not None:
+ # Add grad norm to the logger
+ runner.log_buffer.update({'grad_norm': float(grad_norm)},
+ runner.outputs['num_samples'])
+ # backward and update scaler
+ self.loss_scaler.step(runner.optimizer)
+ self.loss_scaler.update(self._scale_update_param)
+
+ # save state_dict of loss_scaler
+ runner.meta.setdefault(
+ 'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict()
+
+ @HOOKS.register_module()
+ class GradientCumulativeFp16OptimizerHook(GradientCumulativeOptimizerHook,
+ Fp16OptimizerHook):
+ """Fp16 optimizer Hook (using PyTorch's implementation) implements
+ multi-iters gradient cumulating.
+
+ If you are using PyTorch >= 1.6, torch.cuda.amp is used as the backend,
+ to take care of the optimization procedure.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super(GradientCumulativeFp16OptimizerHook,
+ self).__init__(*args, **kwargs)
+
+ def after_train_iter(self, runner):
+ if not self.initialized:
+ self._init(runner)
+
+ if runner.iter < self.divisible_iters:
+ loss_factor = self.cumulative_iters
+ else:
+ loss_factor = self.remainder_iters
+ loss = runner.outputs['loss']
+ loss = loss / loss_factor
+
+ self.loss_scaler.scale(loss).backward()
+
+ if (self.every_n_iters(runner, self.cumulative_iters)
+ or self.is_last_iter(runner)):
+
+ # copy fp16 grads in the model to fp32 params in the optimizer
+ self.loss_scaler.unscale_(runner.optimizer)
+
+ if self.grad_clip is not None:
+ grad_norm = self.clip_grads(runner.model.parameters())
+ if grad_norm is not None:
+ # Add grad norm to the logger
+ runner.log_buffer.update(
+ {'grad_norm': float(grad_norm)},
+ runner.outputs['num_samples'])
+
+ # backward and update scaler
+ self.loss_scaler.step(runner.optimizer)
+ self.loss_scaler.update(self._scale_update_param)
+
+ # save state_dict of loss_scaler
+ runner.meta.setdefault(
+ 'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict()
+
+ # clear grads
+ runner.model.zero_grad()
+ runner.optimizer.zero_grad()
+
+else:
+
+ @HOOKS.register_module()
+ class Fp16OptimizerHook(OptimizerHook):
+ """FP16 optimizer hook (mmcv's implementation).
+
+        The steps of the fp16 optimizer are as follows.
+        1. Scale the loss value.
+        2. BP in the fp16 model.
+        3. Copy gradients from the fp16 model to fp32 weights.
+        4. Update fp32 weights.
+        5. Copy updated parameters from fp32 weights to the fp16 model.
+
+ Refer to https://arxiv.org/abs/1710.03740 for more details.
+
+ Args:
+ loss_scale (float | str | dict): Scale factor configuration.
+ If loss_scale is a float, static loss scaling will be used with
+ the specified scale. If loss_scale is a string, it must be
+ 'dynamic', then dynamic loss scaling will be used.
+ It can also be a dict containing arguments of LossScaler.
+ Defaults to 512.
+ """
+
+ def __init__(self,
+ grad_clip=None,
+ coalesce=True,
+ bucket_size_mb=-1,
+ loss_scale=512.,
+ distributed=True):
+ self.grad_clip = grad_clip
+ self.coalesce = coalesce
+ self.bucket_size_mb = bucket_size_mb
+ self.distributed = distributed
+ if loss_scale == 'dynamic':
+ self.loss_scaler = LossScaler(mode='dynamic')
+ elif isinstance(loss_scale, float):
+ self.loss_scaler = LossScaler(
+ init_scale=loss_scale, mode='static')
+ elif isinstance(loss_scale, dict):
+ self.loss_scaler = LossScaler(**loss_scale)
+ else:
+ raise ValueError('loss_scale must be of type float, dict, or '
+ f'"dynamic", got {loss_scale}')
+
+ def before_run(self, runner):
+ """Preparing steps before Mixed Precision Training.
+
+ 1. Make a master copy of fp32 weights for optimization.
+ 2. Convert the main model from fp32 to fp16.
+ """
+ # keep a copy of fp32 weights
+ old_groups = runner.optimizer.param_groups
+ runner.optimizer.param_groups = copy.deepcopy(
+ runner.optimizer.param_groups)
+ state = defaultdict(dict)
+ p_map = {
+ old_p: p
+ for old_p, p in zip(
+ chain(*(g['params'] for g in old_groups)),
+ chain(*(g['params']
+ for g in runner.optimizer.param_groups)))
+ }
+ for k, v in runner.optimizer.state.items():
+ state[p_map[k]] = v
+ runner.optimizer.state = state
+ # convert model to fp16
+ wrap_fp16_model(runner.model)
+ # resume from state dict
+ if 'fp16' in runner.meta and 'loss_scaler' in runner.meta['fp16']:
+ scaler_state_dict = runner.meta['fp16']['loss_scaler']
+ self.loss_scaler.load_state_dict(scaler_state_dict)
+
+ def copy_grads_to_fp32(self, fp16_net, fp32_weights):
+ """Copy gradients from fp16 model to fp32 weight copy."""
+ for fp32_param, fp16_param in zip(fp32_weights,
+ fp16_net.parameters()):
+ if fp16_param.grad is not None:
+ if fp32_param.grad is None:
+ fp32_param.grad = fp32_param.data.new(
+ fp32_param.size())
+ fp32_param.grad.copy_(fp16_param.grad)
+
+ def copy_params_to_fp16(self, fp16_net, fp32_weights):
+ """Copy updated params from fp32 weight copy to fp16 model."""
+ for fp16_param, fp32_param in zip(fp16_net.parameters(),
+ fp32_weights):
+ fp16_param.data.copy_(fp32_param.data)
+
+ def after_train_iter(self, runner):
+ """Backward optimization steps for Mixed Precision Training. For
+            dynamic loss scaling, please refer to ``LossScaler`` in
+            ``fp16_utils``.
+
+ 1. Scale the loss by a scale factor.
+ 2. Backward the loss to obtain the gradients (fp16).
+ 3. Copy gradients from the model to the fp32 weight copy.
+ 4. Scale the gradients back and update the fp32 weight copy.
+ 5. Copy back the params from fp32 weight copy to the fp16 model.
+ 6. Save loss_scaler state_dict for resume purpose.
+ """
+ # clear grads of last iteration
+ runner.model.zero_grad()
+ runner.optimizer.zero_grad()
+ # scale the loss value
+ scaled_loss = runner.outputs['loss'] * self.loss_scaler.loss_scale
+ scaled_loss.backward()
+ # copy fp16 grads in the model to fp32 params in the optimizer
+
+ fp32_weights = []
+ for param_group in runner.optimizer.param_groups:
+ fp32_weights += param_group['params']
+ self.copy_grads_to_fp32(runner.model, fp32_weights)
+ # allreduce grads
+ if self.distributed:
+ allreduce_grads(fp32_weights, self.coalesce,
+ self.bucket_size_mb)
+
+ has_overflow = self.loss_scaler.has_overflow(fp32_weights)
+ # if has overflow, skip this iteration
+ if not has_overflow:
+ # scale the gradients back
+ for param in fp32_weights:
+ if param.grad is not None:
+ param.grad.div_(self.loss_scaler.loss_scale)
+ if self.grad_clip is not None:
+ grad_norm = self.clip_grads(fp32_weights)
+ if grad_norm is not None:
+ # Add grad norm to the logger
+ runner.log_buffer.update(
+ {'grad_norm': float(grad_norm)},
+ runner.outputs['num_samples'])
+ # update fp32 params
+ runner.optimizer.step()
+ # copy fp32 params to the fp16 model
+ self.copy_params_to_fp16(runner.model, fp32_weights)
+ self.loss_scaler.update_scale(has_overflow)
+ if has_overflow:
+                runner.logger.warning(
+                    'Gradient overflow, downscale loss scale '
+                    f'to {self.loss_scaler.cur_scale}')
+
+ # save state_dict of loss_scaler
+ runner.meta.setdefault(
+ 'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict()
+
+ @HOOKS.register_module()
+ class GradientCumulativeFp16OptimizerHook(GradientCumulativeOptimizerHook,
+ Fp16OptimizerHook):
+ """Fp16 optimizer Hook (using mmcv implementation) implements multi-
+ iters gradient cumulating."""
+
+ def __init__(self, *args, **kwargs):
+ super(GradientCumulativeFp16OptimizerHook,
+ self).__init__(*args, **kwargs)
+
+ def after_train_iter(self, runner):
+ if not self.initialized:
+ self._init(runner)
+
+ if runner.iter < self.divisible_iters:
+ loss_factor = self.cumulative_iters
+ else:
+ loss_factor = self.remainder_iters
+
+ loss = runner.outputs['loss']
+ loss = loss / loss_factor
+
+ # scale the loss value
+ scaled_loss = loss * self.loss_scaler.loss_scale
+ scaled_loss.backward()
+
+ if (self.every_n_iters(runner, self.cumulative_iters)
+ or self.is_last_iter(runner)):
+
+ # copy fp16 grads in the model to fp32 params in the optimizer
+ fp32_weights = []
+ for param_group in runner.optimizer.param_groups:
+ fp32_weights += param_group['params']
+ self.copy_grads_to_fp32(runner.model, fp32_weights)
+ # allreduce grads
+ if self.distributed:
+ allreduce_grads(fp32_weights, self.coalesce,
+ self.bucket_size_mb)
+
+ has_overflow = self.loss_scaler.has_overflow(fp32_weights)
+ # if has overflow, skip this iteration
+ if not has_overflow:
+ # scale the gradients back
+ for param in fp32_weights:
+ if param.grad is not None:
+ param.grad.div_(self.loss_scaler.loss_scale)
+ if self.grad_clip is not None:
+ grad_norm = self.clip_grads(fp32_weights)
+ if grad_norm is not None:
+ # Add grad norm to the logger
+ runner.log_buffer.update(
+ {'grad_norm': float(grad_norm)},
+ runner.outputs['num_samples'])
+ # update fp32 params
+ runner.optimizer.step()
+ # copy fp32 params to the fp16 model
+ self.copy_params_to_fp16(runner.model, fp32_weights)
+ else:
+ runner.logger.warning(
+                        'Gradient overflow, downscale loss scale '
+ f'to {self.loss_scaler.cur_scale}')
+
+ self.loss_scaler.update_scale(has_overflow)
+
+ # save state_dict of loss_scaler
+ runner.meta.setdefault(
+ 'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict()
+
+ # clear grads
+ runner.model.zero_grad()
+ runner.optimizer.zero_grad()
diff --git a/annotator/uniformer/mmcv/runner/hooks/profiler.py b/annotator/uniformer/mmcv/runner/hooks/profiler.py
new file mode 100644
index 0000000000000000000000000000000000000000..b70236997eec59c2209ef351ae38863b4112d0ec
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/hooks/profiler.py
@@ -0,0 +1,180 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+from typing import Callable, List, Optional, Union
+
+import torch
+
+from ..dist_utils import master_only
+from .hook import HOOKS, Hook
+
+
+@HOOKS.register_module()
+class ProfilerHook(Hook):
+ """Profiler to analyze performance during training.
+
+ PyTorch Profiler is a tool that allows the collection of the performance
+ metrics during the training. More details on Profiler can be found at
+ https://pytorch.org/docs/1.8.1/profiler.html#torch.profiler.profile
+
+ Args:
+ by_epoch (bool): Profile performance by epoch or by iteration.
+ Default: True.
+ profile_iters (int): Number of iterations for profiling.
+            If ``by_epoch=True``, profile_iters indicates the number of
+            epochs to profile at the beginning of training; otherwise it
+            indicates the number of iterations to profile. Default: 1.
+ activities (list[str]): List of activity groups (CPU, CUDA) to use in
+ profiling. Default: ['cpu', 'cuda'].
+ schedule (dict, optional): Config of generating the callable schedule.
+ if schedule is None, profiler will not add step markers into the
+ trace and table view. Default: None.
+        on_trace_ready (callable, dict): Either a handler or a dict used to
+            generate a handler. Default: None.
+ record_shapes (bool): Save information about operator's input shapes.
+ Default: False.
+ profile_memory (bool): Track tensor memory allocation/deallocation.
+ Default: False.
+ with_stack (bool): Record source information (file and line number)
+ for the ops. Default: False.
+ with_flops (bool): Use formula to estimate the FLOPS of specific
+ operators (matrix multiplication and 2D convolution).
+ Default: False.
+ json_trace_path (str, optional): Exports the collected trace in Chrome
+ JSON format. Default: None.
+
+ Example:
+ >>> runner = ... # instantiate a Runner
+ >>> # tensorboard trace
+ >>> trace_config = dict(type='tb_trace', dir_name='work_dir')
+ >>> profiler_config = dict(on_trace_ready=trace_config)
+ >>> runner.register_profiler_hook(profiler_config)
+ >>> runner.run(data_loaders=[trainloader], workflow=[('train', 1)])
+ """
+
+ def __init__(self,
+ by_epoch: bool = True,
+ profile_iters: int = 1,
+ activities: List[str] = ['cpu', 'cuda'],
+ schedule: Optional[dict] = None,
+ on_trace_ready: Optional[Union[Callable, dict]] = None,
+ record_shapes: bool = False,
+ profile_memory: bool = False,
+ with_stack: bool = False,
+ with_flops: bool = False,
+ json_trace_path: Optional[str] = None) -> None:
+ try:
+ from torch import profiler # torch version >= 1.8.1
+ except ImportError:
+            raise ImportError('profiler is a new feature of torch 1.8.1, '
+ f'but your version is {torch.__version__}')
+
+ assert isinstance(by_epoch, bool), '``by_epoch`` should be a boolean.'
+ self.by_epoch = by_epoch
+
+ if profile_iters < 1:
+ raise ValueError('profile_iters should be greater than 0, but got '
+ f'{profile_iters}')
+ self.profile_iters = profile_iters
+
+ if not isinstance(activities, list):
+ raise ValueError(
+ f'activities should be list, but got {type(activities)}')
+ self.activities = []
+ for activity in activities:
+ activity = activity.lower()
+ if activity == 'cpu':
+ self.activities.append(profiler.ProfilerActivity.CPU)
+ elif activity == 'cuda':
+ self.activities.append(profiler.ProfilerActivity.CUDA)
+ else:
+ raise ValueError(
+ f'activity should be "cpu" or "cuda", but got {activity}')
+
+ if schedule is not None:
+ self.schedule = profiler.schedule(**schedule)
+ else:
+ self.schedule = None
+
+ self.on_trace_ready = on_trace_ready
+ self.record_shapes = record_shapes
+ self.profile_memory = profile_memory
+ self.with_stack = with_stack
+ self.with_flops = with_flops
+ self.json_trace_path = json_trace_path
+
+ @master_only
+ def before_run(self, runner):
+ if self.by_epoch and runner.max_epochs < self.profile_iters:
+ raise ValueError('self.profile_iters should not be greater than '
+ f'{runner.max_epochs}')
+
+ if not self.by_epoch and runner.max_iters < self.profile_iters:
+ raise ValueError('self.profile_iters should not be greater than '
+ f'{runner.max_iters}')
+
+ if callable(self.on_trace_ready): # handler
+ _on_trace_ready = self.on_trace_ready
+ elif isinstance(self.on_trace_ready, dict): # config of handler
+ trace_cfg = self.on_trace_ready.copy()
+ trace_type = trace_cfg.pop('type') # log_trace handler
+ if trace_type == 'log_trace':
+
+ def _log_handler(prof):
+ print(prof.key_averages().table(**trace_cfg))
+
+ _on_trace_ready = _log_handler
+ elif trace_type == 'tb_trace': # tensorboard_trace handler
+ try:
+ import torch_tb_profiler # noqa: F401
+ except ImportError:
+ raise ImportError('please run "pip install '
+ 'torch-tb-profiler" to install '
+ 'torch_tb_profiler')
+ _on_trace_ready = torch.profiler.tensorboard_trace_handler(
+ **trace_cfg)
+ else:
+ raise ValueError('trace_type should be "log_trace" or '
+ f'"tb_trace", but got {trace_type}')
+ elif self.on_trace_ready is None:
+ _on_trace_ready = None # type: ignore
+ else:
+ raise ValueError('on_trace_ready should be handler, dict or None, '
+ f'but got {type(self.on_trace_ready)}')
+
+ if runner.max_epochs > 1:
+ warnings.warn(f'profiler will profile {runner.max_epochs} epochs '
+ 'instead of 1 epoch. Since profiler will slow down '
+ 'the training, it is recommended to train 1 epoch '
+ 'with ProfilerHook and adjust your setting according'
+ ' to the profiler summary. During normal training '
+ '(epoch > 1), you may disable the ProfilerHook.')
+
+ self.profiler = torch.profiler.profile(
+ activities=self.activities,
+ schedule=self.schedule,
+ on_trace_ready=_on_trace_ready,
+ record_shapes=self.record_shapes,
+ profile_memory=self.profile_memory,
+ with_stack=self.with_stack,
+ with_flops=self.with_flops)
+
+ self.profiler.__enter__()
+ runner.logger.info('profiler is profiling...')
+
+ @master_only
+ def after_train_epoch(self, runner):
+ if self.by_epoch and runner.epoch == self.profile_iters - 1:
+ runner.logger.info('profiler may take a few minutes...')
+ self.profiler.__exit__(None, None, None)
+ if self.json_trace_path is not None:
+ self.profiler.export_chrome_trace(self.json_trace_path)
+
+ @master_only
+ def after_train_iter(self, runner):
+ self.profiler.step()
+ if not self.by_epoch and runner.iter == self.profile_iters - 1:
+ runner.logger.info('profiler may take a few minutes...')
+ self.profiler.__exit__(None, None, None)
+ if self.json_trace_path is not None:
+ self.profiler.export_chrome_trace(self.json_trace_path)
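+
+
+if __name__ == '__main__':
+ # A minimal sketch of the dict-style ``schedule`` and ``on_trace_ready``
+ # options described in the docstring above. All values are illustrative,
+ # not recommended settings; ``sort_by`` and ``row_limit`` are forwarded to
+ # ``prof.key_averages().table()``. Assumes torch >= 1.8.1 and that this
+ # module is imported as part of the package (relative imports).
+ hook = ProfilerHook(
+ by_epoch=False,
+ profile_iters=10,
+ schedule=dict(wait=1, warmup=1, active=2, repeat=1),
+ on_trace_ready=dict(type='log_trace', sort_by='cuda_time_total',
+ row_limit=10))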
diff --git a/annotator/uniformer/mmcv/runner/hooks/sampler_seed.py b/annotator/uniformer/mmcv/runner/hooks/sampler_seed.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee0dc6bdd8df5775857028aaed5444c0f59caf80
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/hooks/sampler_seed.py
@@ -0,0 +1,20 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .hook import HOOKS, Hook
+
+
+@HOOKS.register_module()
+class DistSamplerSeedHook(Hook):
+ """Data-loading sampler for distributed training.
+
+ When distributed training, it is only useful in conjunction with
+ :obj:`EpochBasedRunner`, while :obj:`IterBasedRunner` achieves the same
+ purpose with :obj:`IterLoader`.
+ """
+
+ def before_epoch(self, runner):
+ if hasattr(runner.data_loader.sampler, 'set_epoch'):
+ # in case the data loader uses `SequentialSampler` in PyTorch
+ runner.data_loader.sampler.set_epoch(runner.epoch)
+ elif hasattr(runner.data_loader.batch_sampler.sampler, 'set_epoch'):
+ # the batch sampler in PyTorch wraps the sampler as one of its attributes
+ runner.data_loader.batch_sampler.sampler.set_epoch(runner.epoch)
diff --git a/annotator/uniformer/mmcv/runner/hooks/sync_buffer.py b/annotator/uniformer/mmcv/runner/hooks/sync_buffer.py
new file mode 100644
index 0000000000000000000000000000000000000000..6376b7ff894280cb2782243b25e8973650591577
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/hooks/sync_buffer.py
@@ -0,0 +1,22 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..dist_utils import allreduce_params
+from .hook import HOOKS, Hook
+
+
+@HOOKS.register_module()
+class SyncBuffersHook(Hook):
+ """Synchronize model buffers such as running_mean and running_var in BN at
+ the end of each epoch.
+
+ Args:
+ distributed (bool): Whether distributed training is used; the hook is
+ effective only in distributed training. Defaults to True.
+ """
+
+ def __init__(self, distributed=True):
+ self.distributed = distributed
+
+ def after_epoch(self, runner):
+ """All-reduce model buffers at the end of each epoch."""
+ if self.distributed:
+ allreduce_params(runner.model.buffers())
diff --git a/annotator/uniformer/mmcv/runner/iter_based_runner.py b/annotator/uniformer/mmcv/runner/iter_based_runner.py
new file mode 100644
index 0000000000000000000000000000000000000000..1df4de8c0285669dec9b014dfd1f3dd1600f0831
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/iter_based_runner.py
@@ -0,0 +1,273 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import platform
+import shutil
+import time
+import warnings
+
+import torch
+from torch.optim import Optimizer
+
+import annotator.uniformer.mmcv as mmcv
+from .base_runner import BaseRunner
+from .builder import RUNNERS
+from .checkpoint import save_checkpoint
+from .hooks import IterTimerHook
+from .utils import get_host_info
+
+
+class IterLoader:
+
+ def __init__(self, dataloader):
+ self._dataloader = dataloader
+ self.iter_loader = iter(self._dataloader)
+ self._epoch = 0
+
+ @property
+ def epoch(self):
+ return self._epoch
+
+ def __next__(self):
+ try:
+ data = next(self.iter_loader)
+ except StopIteration:
+ self._epoch += 1
+ if hasattr(self._dataloader.sampler, 'set_epoch'):
+ self._dataloader.sampler.set_epoch(self._epoch)
+ time.sleep(2) # Prevent possible deadlock during epoch transition
+ self.iter_loader = iter(self._dataloader)
+ data = next(self.iter_loader)
+
+ return data
+
+ def __len__(self):
+ return len(self._dataloader)
+
+
+@RUNNERS.register_module()
+class IterBasedRunner(BaseRunner):
+ """Iteration-based Runner.
+
+ This runner trains models iteration by iteration.
+ """
+
+ def train(self, data_loader, **kwargs):
+ self.model.train()
+ self.mode = 'train'
+ self.data_loader = data_loader
+ self._epoch = data_loader.epoch
+ data_batch = next(data_loader)
+ self.call_hook('before_train_iter')
+ outputs = self.model.train_step(data_batch, self.optimizer, **kwargs)
+ if not isinstance(outputs, dict):
+ raise TypeError('model.train_step() must return a dict')
+ if 'log_vars' in outputs:
+ self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
+ self.outputs = outputs
+ self.call_hook('after_train_iter')
+ self._inner_iter += 1
+ self._iter += 1
+
+ @torch.no_grad()
+ def val(self, data_loader, **kwargs):
+ self.model.eval()
+ self.mode = 'val'
+ self.data_loader = data_loader
+ data_batch = next(data_loader)
+ self.call_hook('before_val_iter')
+ outputs = self.model.val_step(data_batch, **kwargs)
+ if not isinstance(outputs, dict):
+ raise TypeError('model.val_step() must return a dict')
+ if 'log_vars' in outputs:
+ self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
+ self.outputs = outputs
+ self.call_hook('after_val_iter')
+ self._inner_iter += 1
+
+ def run(self, data_loaders, workflow, max_iters=None, **kwargs):
+ """Start running.
+
+ Args:
+ data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
+ and validation.
+ workflow (list[tuple]): A list of (phase, iters) to specify the
+ running order and iterations. E.g., [('train', 10000),
+ ('val', 1000)] means running 10000 iterations for training and
+ 1000 iterations for validation, iteratively.
+ """
+ assert isinstance(data_loaders, list)
+ assert mmcv.is_list_of(workflow, tuple)
+ assert len(data_loaders) == len(workflow)
+ if max_iters is not None:
+ warnings.warn(
+ 'setting max_iters in run is deprecated, '
+ 'please set max_iters in runner_config', DeprecationWarning)
+ self._max_iters = max_iters
+ assert self._max_iters is not None, (
+ 'max_iters must be specified during instantiation')
+
+ work_dir = self.work_dir if self.work_dir is not None else 'NONE'
+ self.logger.info('Start running, host: %s, work_dir: %s',
+ get_host_info(), work_dir)
+ self.logger.info('Hooks will be executed in the following order:\n%s',
+ self.get_hook_info())
+ self.logger.info('workflow: %s, max: %d iters', workflow,
+ self._max_iters)
+ self.call_hook('before_run')
+
+ iter_loaders = [IterLoader(x) for x in data_loaders]
+
+ self.call_hook('before_epoch')
+
+ while self.iter < self._max_iters:
+ for i, flow in enumerate(workflow):
+ self._inner_iter = 0
+ mode, iters = flow
+ if not isinstance(mode, str) or not hasattr(self, mode):
+ raise ValueError(
+ 'runner has no method named "{}" to run a workflow'.
+ format(mode))
+ iter_runner = getattr(self, mode)
+ for _ in range(iters):
+ if mode == 'train' and self.iter >= self._max_iters:
+ break
+ iter_runner(iter_loaders[i], **kwargs)
+
+ time.sleep(1) # wait for some hooks like loggers to finish
+ self.call_hook('after_epoch')
+ self.call_hook('after_run')
+
+ def resume(self,
+ checkpoint,
+ resume_optimizer=True,
+ map_location='default'):
+ """Resume model from checkpoint.
+
+ Args:
+ checkpoint (str): Checkpoint to resume from.
+ resume_optimizer (bool, optional): Whether to resume the optimizer(s)
+ if the checkpoint file includes optimizer(s). Defaults to True.
+ map_location (str, optional): Same as :func:`torch.load`.
+ Defaults to 'default'.
+ """
+ if map_location == 'default':
+ device_id = torch.cuda.current_device()
+ checkpoint = self.load_checkpoint(
+ checkpoint,
+ map_location=lambda storage, loc: storage.cuda(device_id))
+ else:
+ checkpoint = self.load_checkpoint(
+ checkpoint, map_location=map_location)
+
+ self._epoch = checkpoint['meta']['epoch']
+ self._iter = checkpoint['meta']['iter']
+ self._inner_iter = checkpoint['meta']['iter']
+ if 'optimizer' in checkpoint and resume_optimizer:
+ if isinstance(self.optimizer, Optimizer):
+ self.optimizer.load_state_dict(checkpoint['optimizer'])
+ elif isinstance(self.optimizer, dict):
+ for k in self.optimizer.keys():
+ self.optimizer[k].load_state_dict(
+ checkpoint['optimizer'][k])
+ else:
+ raise TypeError(
+ 'Optimizer should be dict or torch.optim.Optimizer '
+ f'but got {type(self.optimizer)}')
+
+ self.logger.info(f'resumed from epoch: {self.epoch}, iter {self.iter}')
+
+ def save_checkpoint(self,
+ out_dir,
+ filename_tmpl='iter_{}.pth',
+ meta=None,
+ save_optimizer=True,
+ create_symlink=True):
+ """Save checkpoint to file.
+
+ Args:
+ out_dir (str): Directory to save checkpoint files.
+ filename_tmpl (str, optional): Checkpoint file template.
+ Defaults to 'iter_{}.pth'.
+ meta (dict, optional): Metadata to be saved in checkpoint.
+ Defaults to None.
+ save_optimizer (bool, optional): Whether to save the optimizer.
+ Defaults to True.
+ create_symlink (bool, optional): Whether to create a symlink to the
+ latest checkpoint file. Defaults to True.
+ """
+ if meta is None:
+ meta = {}
+ elif not isinstance(meta, dict):
+ raise TypeError(
+ f'meta should be a dict or None, but got {type(meta)}')
+ if self.meta is not None:
+ meta.update(self.meta)
+ # Note: meta.update(self.meta) should be done before
+ # meta.update(epoch=self.epoch + 1, iter=self.iter) otherwise
+ # there will be problems with resumed checkpoints.
+ # More details in https://github.com/open-mmlab/mmcv/pull/1108
+ meta.update(epoch=self.epoch + 1, iter=self.iter)
+
+ filename = filename_tmpl.format(self.iter + 1)
+ filepath = osp.join(out_dir, filename)
+ optimizer = self.optimizer if save_optimizer else None
+ save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
+ # in some environments, `os.symlink` is not supported, you may need to
+ # set `create_symlink` to False
+ if create_symlink:
+ dst_file = osp.join(out_dir, 'latest.pth')
+ if platform.system() != 'Windows':
+ mmcv.symlink(filename, dst_file)
+ else:
+ shutil.copy(filepath, dst_file)
+
+ def register_training_hooks(self,
+ lr_config,
+ optimizer_config=None,
+ checkpoint_config=None,
+ log_config=None,
+ momentum_config=None,
+ custom_hooks_config=None):
+ """Register default hooks for iter-based training.
+
+ Checkpoint hook, optimizer stepper hook and logger hooks will be set to
+ `by_epoch=False` by default.
+
+ Default hooks include:
+
+ +----------------------+-------------------------+
+ | Hooks | Priority |
+ +======================+=========================+
+ | LrUpdaterHook | VERY_HIGH (10) |
+ +----------------------+-------------------------+
+ | MomentumUpdaterHook | HIGH (30) |
+ +----------------------+-------------------------+
+ | OptimizerStepperHook | ABOVE_NORMAL (40) |
+ +----------------------+-------------------------+
+ | CheckpointSaverHook | NORMAL (50) |
+ +----------------------+-------------------------+
+ | IterTimerHook | LOW (70) |
+ +----------------------+-------------------------+
+ | LoggerHook(s) | VERY_LOW (90) |
+ +----------------------+-------------------------+
+ | CustomHook(s) | defaults to NORMAL (50) |
+ +----------------------+-------------------------+
+
+ If custom hooks have the same priority as default hooks, the custom
+ hooks will be triggered after the default hooks.
+ """
+ if checkpoint_config is not None:
+ checkpoint_config.setdefault('by_epoch', False)
+ if lr_config is not None:
+ lr_config.setdefault('by_epoch', False)
+ if log_config is not None:
+ for info in log_config['hooks']:
+ info.setdefault('by_epoch', False)
+ super(IterBasedRunner, self).register_training_hooks(
+ lr_config=lr_config,
+ momentum_config=momentum_config,
+ optimizer_config=optimizer_config,
+ checkpoint_config=checkpoint_config,
+ log_config=log_config,
+ timer_config=IterTimerHook(),
+ custom_hooks_config=custom_hooks_config)
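+
+
+if __name__ == '__main__':
+ # A minimal sketch of the workflow format expected by ``run`` above:
+ # alternate 10000 training iterations with 1000 validation iterations
+ # until ``max_iters`` is reached. The numbers are illustrative only and
+ # assume this module is imported as part of the package.
+ workflow = [('train', 10000), ('val', 1000)]
+ assert mmcv.is_list_of(workflow, tuple)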
diff --git a/annotator/uniformer/mmcv/runner/log_buffer.py b/annotator/uniformer/mmcv/runner/log_buffer.py
new file mode 100644
index 0000000000000000000000000000000000000000..d949e2941c5400088c7cd8a1dc893d8b233ae785
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/log_buffer.py
@@ -0,0 +1,41 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from collections import OrderedDict
+
+import numpy as np
+
+
+class LogBuffer:
+
+ def __init__(self):
+ self.val_history = OrderedDict()
+ self.n_history = OrderedDict()
+ self.output = OrderedDict()
+ self.ready = False
+
+ def clear(self):
+ self.val_history.clear()
+ self.n_history.clear()
+ self.clear_output()
+
+ def clear_output(self):
+ self.output.clear()
+ self.ready = False
+
+ def update(self, vars, count=1):
+ assert isinstance(vars, dict)
+ for key, var in vars.items():
+ if key not in self.val_history:
+ self.val_history[key] = []
+ self.n_history[key] = []
+ self.val_history[key].append(var)
+ self.n_history[key].append(count)
+
+ def average(self, n=0):
+ """Average latest n values or all values."""
+ assert n >= 0
+ for key in self.val_history:
+ values = np.array(self.val_history[key][-n:])
+ nums = np.array(self.n_history[key][-n:])
+ avg = np.sum(values * nums) / np.sum(nums)
+ self.output[key] = avg
+ self.ready = True
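+
+
+if __name__ == '__main__':
+ # A minimal sketch of how LogBuffer computes a count-weighted average;
+ # the values below are illustrative only.
+ buf = LogBuffer()
+ buf.update({'loss': 1.0}, count=2) # two samples with loss 1.0
+ buf.update({'loss': 4.0}, count=1) # one sample with loss 4.0
+ buf.average() # (1.0 * 2 + 4.0 * 1) / 3 == 2.0
+ assert abs(buf.output['loss'] - 2.0) < 1e-6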
diff --git a/annotator/uniformer/mmcv/runner/optimizer/__init__.py b/annotator/uniformer/mmcv/runner/optimizer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..53c34d0470992cbc374f29681fdd00dc0e57968d
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/optimizer/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .builder import (OPTIMIZER_BUILDERS, OPTIMIZERS, build_optimizer,
+ build_optimizer_constructor)
+from .default_constructor import DefaultOptimizerConstructor
+
+__all__ = [
+ 'OPTIMIZER_BUILDERS', 'OPTIMIZERS', 'DefaultOptimizerConstructor',
+ 'build_optimizer', 'build_optimizer_constructor'
+]
diff --git a/annotator/uniformer/mmcv/runner/optimizer/builder.py b/annotator/uniformer/mmcv/runner/optimizer/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9234eed8f1f186d9d8dfda34562157ee39bdb3a
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/optimizer/builder.py
@@ -0,0 +1,44 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import inspect
+
+import torch
+
+from ...utils import Registry, build_from_cfg
+
+OPTIMIZERS = Registry('optimizer')
+OPTIMIZER_BUILDERS = Registry('optimizer builder')
+
+
+def register_torch_optimizers():
+ torch_optimizers = []
+ for module_name in dir(torch.optim):
+ if module_name.startswith('__'):
+ continue
+ _optim = getattr(torch.optim, module_name)
+ if inspect.isclass(_optim) and issubclass(_optim,
+ torch.optim.Optimizer):
+ OPTIMIZERS.register_module()(_optim)
+ torch_optimizers.append(module_name)
+ return torch_optimizers
+
+
+TORCH_OPTIMIZERS = register_torch_optimizers()
+
+
+def build_optimizer_constructor(cfg):
+ return build_from_cfg(cfg, OPTIMIZER_BUILDERS)
+
+
+def build_optimizer(model, cfg):
+ optimizer_cfg = copy.deepcopy(cfg)
+ constructor_type = optimizer_cfg.pop('constructor',
+ 'DefaultOptimizerConstructor')
+ paramwise_cfg = optimizer_cfg.pop('paramwise_cfg', None)
+ optim_constructor = build_optimizer_constructor(
+ dict(
+ type=constructor_type,
+ optimizer_cfg=optimizer_cfg,
+ paramwise_cfg=paramwise_cfg))
+ optimizer = optim_constructor(model)
+ return optimizer
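+
+
+if __name__ == '__main__':
+ # A minimal sketch of ``build_optimizer`` with a plain config dict; the
+ # model and hyper-parameters are illustrative only, and the sketch assumes
+ # the optimizer constructors of this package have been imported so that
+ # 'DefaultOptimizerConstructor' is registered.
+ model = torch.nn.Linear(4, 2)
+ optimizer = build_optimizer(model, dict(type='SGD', lr=0.01, momentum=0.9))
+ assert isinstance(optimizer, torch.optim.SGD)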
diff --git a/annotator/uniformer/mmcv/runner/optimizer/default_constructor.py b/annotator/uniformer/mmcv/runner/optimizer/default_constructor.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c0da3503b75441738efe38d70352b55a210a34a
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/optimizer/default_constructor.py
@@ -0,0 +1,249 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import torch
+from torch.nn import GroupNorm, LayerNorm
+
+from annotator.uniformer.mmcv.utils import _BatchNorm, _InstanceNorm, build_from_cfg, is_list_of
+from annotator.uniformer.mmcv.utils.ext_loader import check_ops_exist
+from .builder import OPTIMIZER_BUILDERS, OPTIMIZERS
+
+
+@OPTIMIZER_BUILDERS.register_module()
+class DefaultOptimizerConstructor:
+ """Default constructor for optimizers.
+
+ By default each parameter shares the same optimizer settings, and we
+ provide an argument ``paramwise_cfg`` to specify parameter-wise settings.
+ It is a dict and may contain the following fields:
+
+ - ``custom_keys`` (dict): Specified parameters-wise settings by keys. If
+ one of the keys in ``custom_keys`` is a substring of the name of one
+ parameter, then the setting of the parameter will be specified by
+ ``custom_keys[key]`` and other setting like ``bias_lr_mult`` etc. will
+ be ignored. It should be noted that the aforementioned ``key`` is the
+ longest key that is a substring of the name of the parameter. If there
+ are multiple matched keys with the same length, then the key with the
+ lower alphabetical order will be chosen.
+ ``custom_keys[key]`` should be a dict and may contain fields ``lr_mult``
+ and ``decay_mult``. See Example 2 below.
+ - ``bias_lr_mult`` (float): It will be multiplied to the learning
+ rate for all bias parameters (except for those in normalization
+ layers and offset layers of DCN).
+ - ``bias_decay_mult`` (float): It will be multiplied to the weight
+ decay for all bias parameters (except for those in
+ normalization layers, depthwise conv layers, offset layers of DCN).
+ - ``norm_decay_mult`` (float): It will be multiplied to the weight
+ decay for all weight and bias parameters of normalization
+ layers.
+ - ``dwconv_decay_mult`` (float): It will be multiplied to the weight
+ decay for all weight and bias parameters of depthwise conv
+ layers.
+ - ``dcn_offset_lr_mult`` (float): It will be multiplied to the learning
+ rate for parameters of offset layer in the deformable convs
+ of a model.
+ - ``bypass_duplicate`` (bool): If true, the duplicate parameters
+ would not be added into optimizer. Default: False.
+
+ Note:
+ 1. If the option ``dcn_offset_lr_mult`` is used, the constructor will
+ override the effect of ``bias_lr_mult`` in the bias of offset
+ layer. So be careful when using both ``bias_lr_mult`` and
+ ``dcn_offset_lr_mult``. If you wish to apply both of them to the
+ offset layer in deformable convs, set ``dcn_offset_lr_mult``
+ to the original ``dcn_offset_lr_mult`` * ``bias_lr_mult``.
+ 2. If the option ``dcn_offset_lr_mult`` is used, the constructor will
+ apply it to all the DCN layers in the model. So be careful when
+ the model contains multiple DCN layers in places other than
+ backbone.
+
+ Args:
+ model (:obj:`nn.Module`): The model with parameters to be optimized.
+ optimizer_cfg (dict): The config dict of the optimizer.
+ Positional fields are
+
+ - `type`: class name of the optimizer.
+
+ Optional fields are
+
+ - any arguments of the corresponding optimizer type, e.g.,
+ lr, weight_decay, momentum, etc.
+ paramwise_cfg (dict, optional): Parameter-wise options.
+
+ Example 1:
+ >>> model = torch.nn.modules.Conv1d(1, 1, 1)
+ >>> optimizer_cfg = dict(type='SGD', lr=0.01, momentum=0.9,
+ >>> weight_decay=0.0001)
+ >>> paramwise_cfg = dict(norm_decay_mult=0.)
+ >>> optim_builder = DefaultOptimizerConstructor(
+ >>> optimizer_cfg, paramwise_cfg)
+ >>> optimizer = optim_builder(model)
+
+ Example 2:
+ >>> # assume model have attribute model.backbone and model.cls_head
+ >>> optimizer_cfg = dict(type='SGD', lr=0.01, weight_decay=0.95)
+ >>> paramwise_cfg = dict(custom_keys={
+ '.backbone': dict(lr_mult=0.1, decay_mult=0.9)})
+ >>> optim_builder = DefaultOptimizerConstructor(
+ >>> optimizer_cfg, paramwise_cfg)
+ >>> optimizer = optim_builder(model)
+ >>> # Then the `lr` and `weight_decay` for model.backbone is
+ >>> # (0.01 * 0.1, 0.95 * 0.9). `lr` and `weight_decay` for
+ >>> # model.cls_head is (0.01, 0.95).
+ """
+
+ def __init__(self, optimizer_cfg, paramwise_cfg=None):
+ if not isinstance(optimizer_cfg, dict):
+ raise TypeError('optimizer_cfg should be a dict, '
+ f'but got {type(optimizer_cfg)}')
+ self.optimizer_cfg = optimizer_cfg
+ self.paramwise_cfg = {} if paramwise_cfg is None else paramwise_cfg
+ self.base_lr = optimizer_cfg.get('lr', None)
+ self.base_wd = optimizer_cfg.get('weight_decay', None)
+ self._validate_cfg()
+
+ def _validate_cfg(self):
+ if not isinstance(self.paramwise_cfg, dict):
+ raise TypeError('paramwise_cfg should be None or a dict, '
+ f'but got {type(self.paramwise_cfg)}')
+
+ if 'custom_keys' in self.paramwise_cfg:
+ if not isinstance(self.paramwise_cfg['custom_keys'], dict):
+ raise TypeError(
+ 'If specified, custom_keys must be a dict, '
+ f'but got {type(self.paramwise_cfg["custom_keys"])}')
+ if self.base_wd is None:
+ for key in self.paramwise_cfg['custom_keys']:
+ if 'decay_mult' in self.paramwise_cfg['custom_keys'][key]:
+ raise ValueError('base_wd should not be None')
+
+ # get base lr and weight decay
+ # weight_decay must be explicitly specified if mult is specified
+ if ('bias_decay_mult' in self.paramwise_cfg
+ or 'norm_decay_mult' in self.paramwise_cfg
+ or 'dwconv_decay_mult' in self.paramwise_cfg):
+ if self.base_wd is None:
+ raise ValueError('base_wd should not be None')
+
+ def _is_in(self, param_group, param_group_list):
+ assert is_list_of(param_group_list, dict)
+ param = set(param_group['params'])
+ param_set = set()
+ for group in param_group_list:
+ param_set.update(set(group['params']))
+
+ return not param.isdisjoint(param_set)
+
+ def add_params(self, params, module, prefix='', is_dcn_module=None):
+ """Add all parameters of module to the params list.
+
+ The parameters of the given module will be added to the list of param
+ groups, with specific rules defined by paramwise_cfg.
+
+ Args:
+ params (list[dict]): A list of param groups, it will be modified
+ in place.
+ module (nn.Module): The module to be added.
+ prefix (str): The prefix of the module
+ is_dcn_module (int|float|None): If the current module is a
+ submodule of DCN, `is_dcn_module` will be passed to
+ control conv_offset layer's learning rate. Defaults to None.
+ """
+ # get param-wise options
+ custom_keys = self.paramwise_cfg.get('custom_keys', {})
+ # first sort with alphabet order and then sort with reversed len of str
+ sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True)
+
+ bias_lr_mult = self.paramwise_cfg.get('bias_lr_mult', 1.)
+ bias_decay_mult = self.paramwise_cfg.get('bias_decay_mult', 1.)
+ norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', 1.)
+ dwconv_decay_mult = self.paramwise_cfg.get('dwconv_decay_mult', 1.)
+ bypass_duplicate = self.paramwise_cfg.get('bypass_duplicate', False)
+ dcn_offset_lr_mult = self.paramwise_cfg.get('dcn_offset_lr_mult', 1.)
+
+ # special rules for norm layers and depth-wise conv layers
+ is_norm = isinstance(module,
+ (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm))
+ is_dwconv = (
+ isinstance(module, torch.nn.Conv2d)
+ and module.in_channels == module.groups)
+
+ for name, param in module.named_parameters(recurse=False):
+ param_group = {'params': [param]}
+ if not param.requires_grad:
+ params.append(param_group)
+ continue
+ if bypass_duplicate and self._is_in(param_group, params):
+ warnings.warn(f'{prefix} is duplicate. It is skipped since '
+ f'bypass_duplicate={bypass_duplicate}')
+ continue
+ # if the parameter matches one of the custom keys, ignore other rules
+ is_custom = False
+ for key in sorted_keys:
+ if key in f'{prefix}.{name}':
+ is_custom = True
+ lr_mult = custom_keys[key].get('lr_mult', 1.)
+ param_group['lr'] = self.base_lr * lr_mult
+ if self.base_wd is not None:
+ decay_mult = custom_keys[key].get('decay_mult', 1.)
+ param_group['weight_decay'] = self.base_wd * decay_mult
+ break
+
+ if not is_custom:
+ # bias_lr_mult affects all bias parameters
+ # except for norm.bias and dcn.conv_offset.bias
+ if name == 'bias' and not (is_norm or is_dcn_module):
+ param_group['lr'] = self.base_lr * bias_lr_mult
+
+ if (prefix.find('conv_offset') != -1 and is_dcn_module
+ and isinstance(module, torch.nn.Conv2d)):
+ # deal with both dcn_offset's bias & weight
+ param_group['lr'] = self.base_lr * dcn_offset_lr_mult
+
+ # apply weight decay policies
+ if self.base_wd is not None:
+ # norm decay
+ if is_norm:
+ param_group[
+ 'weight_decay'] = self.base_wd * norm_decay_mult
+ # depth-wise conv
+ elif is_dwconv:
+ param_group[
+ 'weight_decay'] = self.base_wd * dwconv_decay_mult
+ # bias lr and decay
+ elif name == 'bias' and not is_dcn_module:
+ # TODO: currently, bias_decay_mult also affects DCN bias parameters
+ param_group[
+ 'weight_decay'] = self.base_wd * bias_decay_mult
+ params.append(param_group)
+
+ if check_ops_exist():
+ from annotator.uniformer.mmcv.ops import DeformConv2d, ModulatedDeformConv2d
+ is_dcn_module = isinstance(module,
+ (DeformConv2d, ModulatedDeformConv2d))
+ else:
+ is_dcn_module = False
+ for child_name, child_mod in module.named_children():
+ child_prefix = f'{prefix}.{child_name}' if prefix else child_name
+ self.add_params(
+ params,
+ child_mod,
+ prefix=child_prefix,
+ is_dcn_module=is_dcn_module)
+
+ def __call__(self, model):
+ if hasattr(model, 'module'):
+ model = model.module
+
+ optimizer_cfg = self.optimizer_cfg.copy()
+ # if no paramwise option is specified, just use the global setting
+ if not self.paramwise_cfg:
+ optimizer_cfg['params'] = model.parameters()
+ return build_from_cfg(optimizer_cfg, OPTIMIZERS)
+
+ # set param-wise lr and weight decay recursively
+ params = []
+ self.add_params(params, model)
+ optimizer_cfg['params'] = params
+
+ return build_from_cfg(optimizer_cfg, OPTIMIZERS)
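+
+
+if __name__ == '__main__':
+ # A minimal sketch of the paramwise_cfg behaviour documented above:
+ # bias parameters get a doubled lr and norm layers get no weight decay.
+ # The model and values are illustrative only.
+ model = torch.nn.Sequential(
+ torch.nn.Conv2d(3, 8, 3), torch.nn.BatchNorm2d(8))
+ constructor = DefaultOptimizerConstructor(
+ dict(type='SGD', lr=0.1, weight_decay=1e-4),
+ dict(bias_lr_mult=2., norm_decay_mult=0.))
+ optimizer = constructor(model)
+ # the conv bias group now has lr=0.2; BN groups have weight_decay=0.0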
diff --git a/annotator/uniformer/mmcv/runner/priority.py b/annotator/uniformer/mmcv/runner/priority.py
new file mode 100644
index 0000000000000000000000000000000000000000..64cc4e3a05f8d5b89ab6eb32461e6e80f1d62e67
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/priority.py
@@ -0,0 +1,60 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from enum import Enum
+
+
+class Priority(Enum):
+ """Hook priority levels.
+
+ +--------------+------------+
+ | Level | Value |
+ +==============+============+
+ | HIGHEST | 0 |
+ +--------------+------------+
+ | VERY_HIGH | 10 |
+ +--------------+------------+
+ | HIGH | 30 |
+ +--------------+------------+
+ | ABOVE_NORMAL | 40 |
+ +--------------+------------+
+ | NORMAL | 50 |
+ +--------------+------------+
+ | BELOW_NORMAL | 60 |
+ +--------------+------------+
+ | LOW | 70 |
+ +--------------+------------+
+ | VERY_LOW | 90 |
+ +--------------+------------+
+ | LOWEST | 100 |
+ +--------------+------------+
+ """
+
+ HIGHEST = 0
+ VERY_HIGH = 10
+ HIGH = 30
+ ABOVE_NORMAL = 40
+ NORMAL = 50
+ BELOW_NORMAL = 60
+ LOW = 70
+ VERY_LOW = 90
+ LOWEST = 100
+
+
+def get_priority(priority):
+ """Get priority value.
+
+ Args:
+ priority (int or str or :obj:`Priority`): Priority.
+
+ Returns:
+ int: The priority value.
+ """
+ if isinstance(priority, int):
+ if priority < 0 or priority > 100:
+ raise ValueError('priority must be between 0 and 100')
+ return priority
+ elif isinstance(priority, Priority):
+ return priority.value
+ elif isinstance(priority, str):
+ return Priority[priority.upper()].value
+ else:
+ raise TypeError('priority must be an integer or Priority enum value')
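+
+
+if __name__ == '__main__':
+ # A minimal sketch of ``get_priority``: ints, strings and Priority
+ # members are accepted interchangeably (values from the table above).
+ assert get_priority(50) == 50
+ assert get_priority('VERY_HIGH') == 10
+ assert get_priority(Priority.LOW) == 70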
diff --git a/annotator/uniformer/mmcv/runner/utils.py b/annotator/uniformer/mmcv/runner/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5befb8e56ece50b5fecfd007b26f8a29124c0bd
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/utils.py
@@ -0,0 +1,93 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+import random
+import sys
+import time
+import warnings
+from getpass import getuser
+from socket import gethostname
+
+import numpy as np
+import torch
+
+import annotator.uniformer.mmcv as mmcv
+
+
+def get_host_info():
+ """Get hostname and username.
+
+ Return an empty string if an exception is raised, e.g.
+ ``getpass.getuser()`` may fail inside a docker container.
+ """
+ host = ''
+ try:
+ host = f'{getuser()}@{gethostname()}'
+ except Exception as e:
+ warnings.warn(f'Host or user not found: {str(e)}')
+ finally:
+ return host
+
+
+def get_time_str():
+ return time.strftime('%Y%m%d_%H%M%S', time.localtime())
+
+
+def obj_from_dict(info, parent=None, default_args=None):
+ """Initialize an object from dict.
+
+ The dict must contain the key "type", which indicates the object type; it
+ can be either a string or a type, such as "list" or ``list``. Remaining
+ fields are treated as the arguments for constructing the object.
+
+ Args:
+ info (dict): Object types and arguments.
+ parent (:class:`module`): Module which may contain the expected object
+ classes.
+ default_args (dict, optional): Default arguments for initializing the
+ object.
+
+ Returns:
+ any type: Object built from the dict.
+ """
+ assert isinstance(info, dict) and 'type' in info
+ assert isinstance(default_args, dict) or default_args is None
+ args = info.copy()
+ obj_type = args.pop('type')
+ if mmcv.is_str(obj_type):
+ if parent is not None:
+ obj_type = getattr(parent, obj_type)
+ else:
+ obj_type = sys.modules[obj_type]
+ elif not isinstance(obj_type, type):
+ raise TypeError('type must be a str or valid type, but '
+ f'got {type(obj_type)}')
+ if default_args is not None:
+ for name, value in default_args.items():
+ args.setdefault(name, value)
+ return obj_type(**args)
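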
+
+
+def set_random_seed(seed, deterministic=False, use_rank_shift=False):
+ """Set random seed.
+
+ Args:
+ seed (int): Seed to be used.
+ deterministic (bool): Whether to set the deterministic option for
+ CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
+ to True and `torch.backends.cudnn.benchmark` to False.
+ Default: False.
+ use_rank_shift (bool): Whether to add the rank number to the random
+ seed so that different ranks use different random seeds.
+ Default: False.
+ """
+ if use_rank_shift:
+ rank, _ = mmcv.runner.get_dist_info()
+ seed += rank
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ os.environ['PYTHONHASHSEED'] = str(seed)
+ if deterministic:
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
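+
+
+if __name__ == '__main__':
+ # A minimal sketch of ``obj_from_dict``: build a torch optimizer by name
+ # from the ``torch.optim`` module. The parameter and lr are illustrative
+ # only; the sketch assumes the mmcv package is importable.
+ weight = torch.nn.Parameter(torch.zeros(2))
+ adam = obj_from_dict(
+ dict(type='Adam', lr=1e-3), parent=torch.optim,
+ default_args=dict(params=[weight]))
+ assert isinstance(adam, torch.optim.Adam)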
diff --git a/annotator/uniformer/mmcv/utils/__init__.py b/annotator/uniformer/mmcv/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..378a0068432a371af364de9d73785901c0f83383
--- /dev/null
+++ b/annotator/uniformer/mmcv/utils/__init__.py
@@ -0,0 +1,69 @@
+# flake8: noqa
+# Copyright (c) OpenMMLab. All rights reserved.
+from .config import Config, ConfigDict, DictAction
+from .misc import (check_prerequisites, concat_list, deprecated_api_warning,
+ has_method, import_modules_from_strings, is_list_of,
+ is_method_overridden, is_seq_of, is_str, is_tuple_of,
+ iter_cast, list_cast, requires_executable, requires_package,
+ slice_list, to_1tuple, to_2tuple, to_3tuple, to_4tuple,
+ to_ntuple, tuple_cast)
+from .path import (check_file_exist, fopen, is_filepath, mkdir_or_exist,
+ scandir, symlink)
+from .progressbar import (ProgressBar, track_iter_progress,
+ track_parallel_progress, track_progress)
+from .testing import (assert_attrs_equal, assert_dict_contains_subset,
+ assert_dict_has_keys, assert_is_norm_layer,
+ assert_keys_equal, assert_params_all_zeros,
+ check_python_script)
+from .timer import Timer, TimerError, check_time
+from .version_utils import digit_version, get_git_hash
+
+try:
+ import torch
+except ImportError:
+ __all__ = [
+ 'Config', 'ConfigDict', 'DictAction', 'is_str', 'iter_cast',
+ 'list_cast', 'tuple_cast', 'is_seq_of', 'is_list_of', 'is_tuple_of',
+ 'slice_list', 'concat_list', 'check_prerequisites', 'requires_package',
+ 'requires_executable', 'is_filepath', 'fopen', 'check_file_exist',
+ 'mkdir_or_exist', 'symlink', 'scandir', 'ProgressBar',
+ 'track_progress', 'track_iter_progress', 'track_parallel_progress',
+ 'Timer', 'TimerError', 'check_time', 'deprecated_api_warning',
+ 'digit_version', 'get_git_hash', 'import_modules_from_strings',
+ 'assert_dict_contains_subset', 'assert_attrs_equal',
+ 'assert_dict_has_keys', 'assert_keys_equal', 'check_python_script',
+ 'to_1tuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'to_ntuple',
+ 'is_method_overridden', 'has_method'
+ ]
+else:
+ from .env import collect_env
+ from .logging import get_logger, print_log
+ from .parrots_jit import jit, skip_no_elena
+ from .parrots_wrapper import (
+ TORCH_VERSION, BuildExtension, CppExtension, CUDAExtension, DataLoader,
+ PoolDataLoader, SyncBatchNorm, _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd,
+ _AvgPoolNd, _BatchNorm, _ConvNd, _ConvTransposeMixin, _InstanceNorm,
+ _MaxPoolNd, get_build_config, is_rocm_pytorch, _get_cuda_home)
+ from .registry import Registry, build_from_cfg
+ from .trace import is_jit_tracing
+ __all__ = [
+ 'Config', 'ConfigDict', 'DictAction', 'collect_env', 'get_logger',
+ 'print_log', 'is_str', 'iter_cast', 'list_cast', 'tuple_cast',
+ 'is_seq_of', 'is_list_of', 'is_tuple_of', 'slice_list', 'concat_list',
+ 'check_prerequisites', 'requires_package', 'requires_executable',
+ 'is_filepath', 'fopen', 'check_file_exist', 'mkdir_or_exist',
+ 'symlink', 'scandir', 'ProgressBar', 'track_progress',
+ 'track_iter_progress', 'track_parallel_progress', 'Registry',
+ 'build_from_cfg', 'Timer', 'TimerError', 'check_time', 'SyncBatchNorm',
+ '_AdaptiveAvgPoolNd', '_AdaptiveMaxPoolNd', '_AvgPoolNd', '_BatchNorm',
+ '_ConvNd', '_ConvTransposeMixin', '_InstanceNorm', '_MaxPoolNd',
+ 'get_build_config', 'BuildExtension', 'CppExtension', 'CUDAExtension',
+ 'DataLoader', 'PoolDataLoader', 'TORCH_VERSION',
+ 'deprecated_api_warning', 'digit_version', 'get_git_hash',
+ 'import_modules_from_strings', 'jit', 'skip_no_elena',
+ 'assert_dict_contains_subset', 'assert_attrs_equal',
+ 'assert_dict_has_keys', 'assert_keys_equal', 'assert_is_norm_layer',
+ 'assert_params_all_zeros', 'check_python_script',
+ 'is_method_overridden', 'is_jit_tracing', 'is_rocm_pytorch',
+ '_get_cuda_home', 'has_method'
+ ]
diff --git a/annotator/uniformer/mmcv/utils/config.py b/annotator/uniformer/mmcv/utils/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..17149353aefac6d737c67bb2f35a3a6cd2147b0a
--- /dev/null
+++ b/annotator/uniformer/mmcv/utils/config.py
@@ -0,0 +1,688 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import ast
+import copy
+import os
+import os.path as osp
+import platform
+import shutil
+import sys
+import tempfile
+import uuid
+import warnings
+from argparse import Action, ArgumentParser
+from collections import abc
+from importlib import import_module
+
+from addict import Dict
+from yapf.yapflib.yapf_api import FormatCode
+
+from .misc import import_modules_from_strings
+from .path import check_file_exist
+
+if platform.system() == 'Windows':
+ import regex as re
+else:
+ import re
+
+BASE_KEY = '_base_'
+DELETE_KEY = '_delete_'
+DEPRECATION_KEY = '_deprecation_'
+RESERVED_KEYS = ['filename', 'text', 'pretty_text']
+
+
+class ConfigDict(Dict):
+
+ def __missing__(self, name):
+ raise KeyError(name)
+
+ def __getattr__(self, name):
+ try:
+ value = super(ConfigDict, self).__getattr__(name)
+ except KeyError:
+ ex = AttributeError(f"'{self.__class__.__name__}' object has no "
+ f"attribute '{name}'")
+ except Exception as e:
+ ex = e
+ else:
+ return value
+ raise ex
+
+
+def add_args(parser, cfg, prefix=''):
+ for k, v in cfg.items():
+ if isinstance(v, str):
+ parser.add_argument('--' + prefix + k)
+ elif isinstance(v, int):
+ parser.add_argument('--' + prefix + k, type=int)
+ elif isinstance(v, float):
+ parser.add_argument('--' + prefix + k, type=float)
+ elif isinstance(v, bool):
+ parser.add_argument('--' + prefix + k, action='store_true')
+ elif isinstance(v, dict):
+ add_args(parser, v, prefix + k + '.')
+ elif isinstance(v, abc.Iterable):
+ parser.add_argument('--' + prefix + k, type=type(v[0]), nargs='+')
+ else:
+ print(f'cannot parse key {prefix + k} of type {type(v)}')
+ return parser
+
+
+class Config:
+ """A facility for config and config files.
+
+ It supports common file formats as configs: python/json/yaml. The interface
+ is the same as a dict object and also allows accessing config values as
+ attributes.
+
+ Example:
+ >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
+ >>> cfg.a
+ 1
+ >>> cfg.b
+ {'b1': [0, 1]}
+ >>> cfg.b.b1
+ [0, 1]
+ >>> cfg = Config.fromfile('tests/data/config/a.py')
+ >>> cfg.filename
+ "/home/kchen/projects/mmcv/tests/data/config/a.py"
+ >>> cfg.item4
+ 'test'
+ >>> cfg
+ "Config [path: /home/kchen/projects/mmcv/tests/data/config/a.py]: "
+ "{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}"
+ """
+
+ @staticmethod
+ def _validate_py_syntax(filename):
+ with open(filename, 'r', encoding='utf-8') as f:
+ # Setting encoding explicitly to resolve coding issue on windows
+ content = f.read()
+ try:
+ ast.parse(content)
+ except SyntaxError as e:
+ raise SyntaxError('There are syntax errors in config '
+ f'file {filename}: {e}')
+
+ @staticmethod
+ def _substitute_predefined_vars(filename, temp_config_name):
+ file_dirname = osp.dirname(filename)
+ file_basename = osp.basename(filename)
+ file_basename_no_extension = osp.splitext(file_basename)[0]
+ file_extname = osp.splitext(filename)[1]
+ support_templates = dict(
+ fileDirname=file_dirname,
+ fileBasename=file_basename,
+ fileBasenameNoExtension=file_basename_no_extension,
+ fileExtname=file_extname)
+ with open(filename, 'r', encoding='utf-8') as f:
+ # Setting encoding explicitly to resolve coding issue on windows
+ config_file = f.read()
+ for key, value in support_templates.items():
+ regexp = r'\{\{\s*' + str(key) + r'\s*\}\}'
+ value = value.replace('\\', '/')
+ config_file = re.sub(regexp, value, config_file)
+ with open(temp_config_name, 'w', encoding='utf-8') as tmp_config_file:
+ tmp_config_file.write(config_file)
+
+ @staticmethod
+ def _pre_substitute_base_vars(filename, temp_config_name):
+ """Substitute base variable placehoders to string, so that parsing
+ would work."""
+ with open(filename, 'r', encoding='utf-8') as f:
+ # Setting encoding explicitly to resolve coding issue on windows
+ config_file = f.read()
+ base_var_dict = {}
+ regexp = r'\{\{\s*' + BASE_KEY + r'\.([\w\.]+)\s*\}\}'
+ base_vars = set(re.findall(regexp, config_file))
+ for base_var in base_vars:
+ randstr = f'_{base_var}_{uuid.uuid4().hex.lower()[:6]}'
+ base_var_dict[randstr] = base_var
+ regexp = r'\{\{\s*' + BASE_KEY + r'\.' + base_var + r'\s*\}\}'
+ config_file = re.sub(regexp, f'"{randstr}"', config_file)
+ with open(temp_config_name, 'w', encoding='utf-8') as tmp_config_file:
+ tmp_config_file.write(config_file)
+ return base_var_dict
+
+ @staticmethod
+ def _substitute_base_vars(cfg, base_var_dict, base_cfg):
+ """Substitute variable strings to their actual values."""
+ cfg = copy.deepcopy(cfg)
+
+ if isinstance(cfg, dict):
+ for k, v in cfg.items():
+ if isinstance(v, str) and v in base_var_dict:
+ new_v = base_cfg
+ for new_k in base_var_dict[v].split('.'):
+ new_v = new_v[new_k]
+ cfg[k] = new_v
+ elif isinstance(v, (list, tuple, dict)):
+ cfg[k] = Config._substitute_base_vars(
+ v, base_var_dict, base_cfg)
+ elif isinstance(cfg, tuple):
+ cfg = tuple(
+ Config._substitute_base_vars(c, base_var_dict, base_cfg)
+ for c in cfg)
+ elif isinstance(cfg, list):
+ cfg = [
+ Config._substitute_base_vars(c, base_var_dict, base_cfg)
+ for c in cfg
+ ]
+ elif isinstance(cfg, str) and cfg in base_var_dict:
+ new_v = base_cfg
+ for new_k in base_var_dict[cfg].split('.'):
+ new_v = new_v[new_k]
+ cfg = new_v
+
+ return cfg
+
+ @staticmethod
+ def _file2dict(filename, use_predefined_variables=True):
+ filename = osp.abspath(osp.expanduser(filename))
+ check_file_exist(filename)
+ fileExtname = osp.splitext(filename)[1]
+ if fileExtname not in ['.py', '.json', '.yaml', '.yml']:
+ raise IOError('Only py/yml/yaml/json types are supported now!')
+
+ with tempfile.TemporaryDirectory() as temp_config_dir:
+ temp_config_file = tempfile.NamedTemporaryFile(
+ dir=temp_config_dir, suffix=fileExtname)
+ if platform.system() == 'Windows':
+ temp_config_file.close()
+ temp_config_name = osp.basename(temp_config_file.name)
+ # Substitute predefined variables
+ if use_predefined_variables:
+ Config._substitute_predefined_vars(filename,
+ temp_config_file.name)
+ else:
+ shutil.copyfile(filename, temp_config_file.name)
+ # Substitute base variables from placeholders to strings
+ base_var_dict = Config._pre_substitute_base_vars(
+ temp_config_file.name, temp_config_file.name)
+
+ if filename.endswith('.py'):
+ temp_module_name = osp.splitext(temp_config_name)[0]
+ sys.path.insert(0, temp_config_dir)
+ Config._validate_py_syntax(filename)
+ mod = import_module(temp_module_name)
+ sys.path.pop(0)
+ cfg_dict = {
+ name: value
+ for name, value in mod.__dict__.items()
+ if not name.startswith('__')
+ }
+ # delete imported module
+ del sys.modules[temp_module_name]
+ elif filename.endswith(('.yml', '.yaml', '.json')):
+ import annotator.uniformer.mmcv as mmcv
+ cfg_dict = mmcv.load(temp_config_file.name)
+ # close temp file
+ temp_config_file.close()
+
+ # check deprecation information
+ if DEPRECATION_KEY in cfg_dict:
+ deprecation_info = cfg_dict.pop(DEPRECATION_KEY)
+ warning_msg = f'The config file {filename} will be deprecated ' \
+ 'in the future.'
+ if 'expected' in deprecation_info:
+ warning_msg += f' Please use {deprecation_info["expected"]} ' \
+ 'instead.'
+ if 'reference' in deprecation_info:
+ warning_msg += ' More information can be found at ' \
+ f'{deprecation_info["reference"]}'
+ warnings.warn(warning_msg)
+
+ cfg_text = filename + '\n'
+ with open(filename, 'r', encoding='utf-8') as f:
+ # Setting encoding explicitly to resolve coding issue on windows
+ cfg_text += f.read()
+
+ if BASE_KEY in cfg_dict:
+ cfg_dir = osp.dirname(filename)
+ base_filename = cfg_dict.pop(BASE_KEY)
+ base_filename = base_filename if isinstance(
+ base_filename, list) else [base_filename]
+
+ cfg_dict_list = list()
+ cfg_text_list = list()
+ for f in base_filename:
+ _cfg_dict, _cfg_text = Config._file2dict(osp.join(cfg_dir, f))
+ cfg_dict_list.append(_cfg_dict)
+ cfg_text_list.append(_cfg_text)
+
+ base_cfg_dict = dict()
+ for c in cfg_dict_list:
+ duplicate_keys = base_cfg_dict.keys() & c.keys()
+ if len(duplicate_keys) > 0:
+ raise KeyError('Duplicate key is not allowed among bases. '
+ f'Duplicate keys: {duplicate_keys}')
+ base_cfg_dict.update(c)
+
+ # Substitute base variables from strings to their actual values
+ cfg_dict = Config._substitute_base_vars(cfg_dict, base_var_dict,
+ base_cfg_dict)
+
+ base_cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict)
+ cfg_dict = base_cfg_dict
+
+ # merge cfg_text
+ cfg_text_list.append(cfg_text)
+ cfg_text = '\n'.join(cfg_text_list)
+
+ return cfg_dict, cfg_text
+
+ @staticmethod
+ def _merge_a_into_b(a, b, allow_list_keys=False):
+ """merge dict ``a`` into dict ``b`` (non-inplace).
+
+ Values in ``a`` will overwrite ``b``. ``b`` is copied first to avoid
+ in-place modifications.
+
+ Args:
+ a (dict): The source dict to be merged into ``b``.
+ b (dict): The origin dict into which keys from ``a`` are merged.
+ allow_list_keys (bool): If True, int string keys (e.g. '0', '1')
+ are allowed in source ``a`` and will replace the element of the
+ corresponding index in b if b is a list. Default: False.
+
+ Returns:
+ dict: The modified dict of ``b`` using ``a``.
+
+ Examples:
+ # Normally merge a into b.
+ >>> Config._merge_a_into_b(
+ ... dict(obj=dict(a=2)), dict(obj=dict(a=1)))
+ {'obj': {'a': 2}}
+
+ # Delete b first and merge a into b.
+ >>> Config._merge_a_into_b(
+ ... dict(obj=dict(_delete_=True, a=2)), dict(obj=dict(a=1)))
+ {'obj': {'a': 2}}
+
+ # b is a list
+ >>> Config._merge_a_into_b(
+ ... {'0': dict(a=2)}, [dict(a=1), dict(b=2)], True)
+ [{'a': 2}, {'b': 2}]
+ """
+ b = b.copy()
+ for k, v in a.items():
+ if allow_list_keys and k.isdigit() and isinstance(b, list):
+ k = int(k)
+ if len(b) <= k:
+ raise KeyError(f'Index {k} exceeds the length of list {b}')
+ b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys)
+ elif isinstance(v,
+ dict) and k in b and not v.pop(DELETE_KEY, False):
+ allowed_types = (dict, list) if allow_list_keys else dict
+ if not isinstance(b[k], allowed_types):
+ raise TypeError(
+ f'{k}={v} in child config cannot inherit from base '
+ f'because {k} is a dict in the child config but is of '
+ f'type {type(b[k])} in base config. You may set '
+ f'`{DELETE_KEY}=True` to ignore the base config')
+ b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys)
+ else:
+ b[k] = v
+ return b
+
+ @staticmethod
+ def fromfile(filename,
+ use_predefined_variables=True,
+ import_custom_modules=True):
+ cfg_dict, cfg_text = Config._file2dict(filename,
+ use_predefined_variables)
+ if import_custom_modules and cfg_dict.get('custom_imports', None):
+ import_modules_from_strings(**cfg_dict['custom_imports'])
+ return Config(cfg_dict, cfg_text=cfg_text, filename=filename)
+
+ @staticmethod
+ def fromstring(cfg_str, file_format):
+ """Generate config from config str.
+
+ Args:
+ cfg_str (str): Config str.
+ file_format (str): Config file format corresponding to the
+ config str. Only py/yml/yaml/json types are supported now!
+
+ Returns:
+ obj:`Config`: Config obj.
+ """
+ if file_format not in ['.py', '.json', '.yaml', '.yml']:
+ raise IOError('Only py/yml/yaml/json types are supported now!')
+ if file_format != '.py' and 'dict(' in cfg_str:
+ # check if users specify a wrong suffix for python
+ warnings.warn(
+ 'Please check "file_format", the file format may be .py')
+ with tempfile.NamedTemporaryFile(
+ 'w', encoding='utf-8', suffix=file_format,
+ delete=False) as temp_file:
+ temp_file.write(cfg_str)
+ # on Windows, the previous implementation caused an error,
+ # see PR 1077 for details
+ cfg = Config.fromfile(temp_file.name)
+ os.remove(temp_file.name)
+ return cfg
+
+ @staticmethod
+ def auto_argparser(description=None):
+ """Generate argparser from config file automatically (experimental)"""
+ partial_parser = ArgumentParser(description=description)
+ partial_parser.add_argument('config', help='config file path')
+ cfg_file = partial_parser.parse_known_args()[0].config
+ cfg = Config.fromfile(cfg_file)
+ parser = ArgumentParser(description=description)
+ parser.add_argument('config', help='config file path')
+ add_args(parser, cfg)
+ return parser, cfg
+
+ def __init__(self, cfg_dict=None, cfg_text=None, filename=None):
+ if cfg_dict is None:
+ cfg_dict = dict()
+ elif not isinstance(cfg_dict, dict):
+ raise TypeError('cfg_dict must be a dict, but '
+ f'got {type(cfg_dict)}')
+ for key in cfg_dict:
+ if key in RESERVED_KEYS:
+ raise KeyError(f'{key} is reserved for config file')
+
+ super(Config, self).__setattr__('_cfg_dict', ConfigDict(cfg_dict))
+ super(Config, self).__setattr__('_filename', filename)
+ if cfg_text:
+ text = cfg_text
+ elif filename:
+ with open(filename, 'r') as f:
+ text = f.read()
+ else:
+ text = ''
+ super(Config, self).__setattr__('_text', text)
+
+ @property
+ def filename(self):
+ return self._filename
+
+ @property
+ def text(self):
+ return self._text
+
+ @property
+ def pretty_text(self):
+
+ indent = 4
+
+ def _indent(s_, num_spaces):
+ s = s_.split('\n')
+ if len(s) == 1:
+ return s_
+ first = s.pop(0)
+ s = [(num_spaces * ' ') + line for line in s]
+ s = '\n'.join(s)
+ s = first + '\n' + s
+ return s
+
+ def _format_basic_types(k, v, use_mapping=False):
+ if isinstance(v, str):
+ v_str = f"'{v}'"
+ else:
+ v_str = str(v)
+
+ if use_mapping:
+ k_str = f"'{k}'" if isinstance(k, str) else str(k)
+ attr_str = f'{k_str}: {v_str}'
+ else:
+ attr_str = f'{str(k)}={v_str}'
+ attr_str = _indent(attr_str, indent)
+
+ return attr_str
+
+ def _format_list(k, v, use_mapping=False):
+ # check if all items in the list are dict
+ if all(isinstance(_, dict) for _ in v):
+ v_str = '[\n'
+ v_str += '\n'.join(
+ f'dict({_indent(_format_dict(v_), indent)}),'
+ for v_ in v).rstrip(',')
+ if use_mapping:
+ k_str = f"'{k}'" if isinstance(k, str) else str(k)
+ attr_str = f'{k_str}: {v_str}'
+ else:
+ attr_str = f'{str(k)}={v_str}'
+ attr_str = _indent(attr_str, indent) + ']'
+ else:
+ attr_str = _format_basic_types(k, v, use_mapping)
+ return attr_str
+
+ def _contain_invalid_identifier(dict_str):
+ contain_invalid_identifier = False
+ for key_name in dict_str:
+ contain_invalid_identifier |= \
+ (not str(key_name).isidentifier())
+ return contain_invalid_identifier
+
+ def _format_dict(input_dict, outest_level=False):
+ r = ''
+ s = []
+
+ use_mapping = _contain_invalid_identifier(input_dict)
+ if use_mapping:
+ r += '{'
+ for idx, (k, v) in enumerate(input_dict.items()):
+ is_last = idx >= len(input_dict) - 1
+ end = '' if outest_level or is_last else ','
+ if isinstance(v, dict):
+ v_str = '\n' + _format_dict(v)
+ if use_mapping:
+ k_str = f"'{k}'" if isinstance(k, str) else str(k)
+ attr_str = f'{k_str}: dict({v_str}'
+ else:
+ attr_str = f'{str(k)}=dict({v_str}'
+ attr_str = _indent(attr_str, indent) + ')' + end
+ elif isinstance(v, list):
+ attr_str = _format_list(k, v, use_mapping) + end
+ else:
+ attr_str = _format_basic_types(k, v, use_mapping) + end
+
+ s.append(attr_str)
+ r += '\n'.join(s)
+ if use_mapping:
+ r += '}'
+ return r
+
+ cfg_dict = self._cfg_dict.to_dict()
+ text = _format_dict(cfg_dict, outest_level=True)
+ # copied from setup.cfg
+ yapf_style = dict(
+ based_on_style='pep8',
+ blank_line_before_nested_class_or_def=True,
+ split_before_expression_after_opening_paren=True)
+ text, _ = FormatCode(text, style_config=yapf_style, verify=True)
+
+ return text
+
+ def __repr__(self):
+ return f'Config (path: {self.filename}): {self._cfg_dict.__repr__()}'
+
+ def __len__(self):
+ return len(self._cfg_dict)
+
+ def __getattr__(self, name):
+ return getattr(self._cfg_dict, name)
+
+ def __getitem__(self, name):
+ return self._cfg_dict.__getitem__(name)
+
+ def __setattr__(self, name, value):
+ if isinstance(value, dict):
+ value = ConfigDict(value)
+ self._cfg_dict.__setattr__(name, value)
+
+ def __setitem__(self, name, value):
+ if isinstance(value, dict):
+ value = ConfigDict(value)
+ self._cfg_dict.__setitem__(name, value)
+
+ def __iter__(self):
+ return iter(self._cfg_dict)
+
+ def __getstate__(self):
+ return (self._cfg_dict, self._filename, self._text)
+
+ def __setstate__(self, state):
+ _cfg_dict, _filename, _text = state
+ super(Config, self).__setattr__('_cfg_dict', _cfg_dict)
+ super(Config, self).__setattr__('_filename', _filename)
+ super(Config, self).__setattr__('_text', _text)
+
+ def dump(self, file=None):
+ cfg_dict = super(Config, self).__getattribute__('_cfg_dict').to_dict()
+ if self.filename.endswith('.py'):
+ if file is None:
+ return self.pretty_text
+ else:
+ with open(file, 'w', encoding='utf-8') as f:
+ f.write(self.pretty_text)
+ else:
+ import annotator.uniformer.mmcv as mmcv
+ if file is None:
+ file_format = self.filename.split('.')[-1]
+ return mmcv.dump(cfg_dict, file_format=file_format)
+ else:
+ mmcv.dump(cfg_dict, file)
+
+ def merge_from_dict(self, options, allow_list_keys=True):
+ """Merge list into cfg_dict.
+
+ Merge the dict parsed by MultipleKVAction into this cfg.
+
+ Examples:
+ >>> options = {'model.backbone.depth': 50,
+ ... 'model.backbone.with_cp':True}
+ >>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet'))))
+ >>> cfg.merge_from_dict(options)
+ >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
+ >>> assert cfg_dict == dict(
+ ... model=dict(backbone=dict(depth=50, with_cp=True)))
+
+ # Merge list element
+ >>> cfg = Config(dict(pipeline=[
+ ... dict(type='LoadImage'), dict(type='LoadAnnotations')]))
+ >>> options = dict(pipeline={'0': dict(type='SelfLoadImage')})
+ >>> cfg.merge_from_dict(options, allow_list_keys=True)
+ >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
+ >>> assert cfg_dict == dict(pipeline=[
+ ... dict(type='SelfLoadImage'), dict(type='LoadAnnotations')])
+
+ Args:
+ options (dict): dict of configs to merge from.
+ allow_list_keys (bool): If True, int string keys (e.g. '0', '1')
+ are allowed in ``options`` and will replace the element of the
+ corresponding index in the config if the config is a list.
+ Default: True.
+ """
+ option_cfg_dict = {}
+ for full_key, v in options.items():
+ d = option_cfg_dict
+ key_list = full_key.split('.')
+ for subkey in key_list[:-1]:
+ d.setdefault(subkey, ConfigDict())
+ d = d[subkey]
+ subkey = key_list[-1]
+ d[subkey] = v
+
+ cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
+ super(Config, self).__setattr__(
+ '_cfg_dict',
+ Config._merge_a_into_b(
+ option_cfg_dict, cfg_dict, allow_list_keys=allow_list_keys))
+
+
+class DictAction(Action):
+ """
+ argparse action to split an argument into KEY=VALUE form
+ on the first = and append to a dictionary. List options can
+ be passed as comma separated values, i.e. 'KEY=V1,V2,V3', or with explicit
+ brackets, i.e. 'KEY=[V1,V2,V3]'. It also supports nested brackets to build
+ list/tuple values, e.g. 'KEY=[(V1,V2),(V3,V4)]'
+ """
+
+ @staticmethod
+ def _parse_int_float_bool(val):
+ try:
+ return int(val)
+ except ValueError:
+ pass
+ try:
+ return float(val)
+ except ValueError:
+ pass
+ if val.lower() in ['true', 'false']:
+ return True if val.lower() == 'true' else False
+ return val
+
+ @staticmethod
+ def _parse_iterable(val):
+ """Parse iterable values in the string.
+
+ All elements inside '()' or '[]' are treated as iterable values.
+
+ Args:
+ val (str): Value string.
+
+ Returns:
+ list | tuple: The expanded list or tuple from the string.
+
+ Examples:
+ >>> DictAction._parse_iterable('1,2,3')
+ [1, 2, 3]
+ >>> DictAction._parse_iterable('[a, b, c]')
+ ['a', 'b', 'c']
+ >>> DictAction._parse_iterable('[(1, 2, 3), [a, b], c]')
+ [(1, 2, 3), ['a', 'b'], 'c']
+ """
+
+ def find_next_comma(string):
+ """Find the position of next comma in the string.
+
+ If no ',' is found in the string, return the string length. All
+            chars inside '()' and '[]' are treated as one element, and thus
+            commas inside these brackets are ignored.
+ """
+ assert (string.count('(') == string.count(')')) and (
+ string.count('[') == string.count(']')), \
+ f'Imbalanced brackets exist in {string}'
+ end = len(string)
+ for idx, char in enumerate(string):
+ pre = string[:idx]
+ # The string before this ',' is balanced
+ if ((char == ',') and (pre.count('(') == pre.count(')'))
+ and (pre.count('[') == pre.count(']'))):
+ end = idx
+ break
+ return end
+
+ # Strip ' and " characters and replace whitespace.
+ val = val.strip('\'\"').replace(' ', '')
+ is_tuple = False
+ if val.startswith('(') and val.endswith(')'):
+ is_tuple = True
+ val = val[1:-1]
+ elif val.startswith('[') and val.endswith(']'):
+ val = val[1:-1]
+ elif ',' not in val:
+ # val is a single value
+ return DictAction._parse_int_float_bool(val)
+
+ values = []
+ while len(val) > 0:
+ comma_idx = find_next_comma(val)
+ element = DictAction._parse_iterable(val[:comma_idx])
+ values.append(element)
+ val = val[comma_idx + 1:]
+ if is_tuple:
+ values = tuple(values)
+ return values
+
+ def __call__(self, parser, namespace, values, option_string=None):
+ options = {}
+ for kv in values:
+ key, val = kv.split('=', maxsplit=1)
+ options[key] = self._parse_iterable(val)
+ setattr(namespace, self.dest, options)
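+
+
+# Usage sketch (comment added for readability; not upstream mmcv code). A
+# typical way to wire ``DictAction`` into an argparse parser, assuming a
+# hypothetical ``--cfg-options`` flag and an existing ``Config`` instance:
+#
+#     parser = argparse.ArgumentParser()
+#     parser.add_argument('--cfg-options', nargs='+', action=DictAction)
+#     args = parser.parse_args(['--cfg-options', 'model.backbone.depth=50'])
+#     cfg.merge_from_dict(args.cfg_options)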
diff --git a/annotator/uniformer/mmcv/utils/env.py b/annotator/uniformer/mmcv/utils/env.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3f0d92529e193e6d8339419bcd9bed7901a7769
--- /dev/null
+++ b/annotator/uniformer/mmcv/utils/env.py
@@ -0,0 +1,95 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+"""This file holding some environment constant for sharing by other files."""
+
+import os.path as osp
+import subprocess
+import sys
+from collections import defaultdict
+
+import cv2
+import torch
+
+import annotator.uniformer.mmcv as mmcv
+from .parrots_wrapper import get_build_config
+
+
+def collect_env():
+ """Collect the information of the running environments.
+
+ Returns:
+ dict: The environment information. The following fields are contained.
+
+ - sys.platform: The variable of ``sys.platform``.
+ - Python: Python version.
+ - CUDA available: Bool, indicating if CUDA is available.
+ - GPU devices: Device type of each GPU.
+ - CUDA_HOME (optional): The env var ``CUDA_HOME``.
+ - NVCC (optional): NVCC version.
+ - GCC: GCC version, "n/a" if GCC is not installed.
+ - PyTorch: PyTorch version.
+ - PyTorch compiling details: The output of \
+ ``torch.__config__.show()``.
+ - TorchVision (optional): TorchVision version.
+ - OpenCV: OpenCV version.
+ - MMCV: MMCV version.
+ - MMCV Compiler: The GCC version for compiling MMCV ops.
+ - MMCV CUDA Compiler: The CUDA version for compiling MMCV ops.
+ """
+ env_info = {}
+ env_info['sys.platform'] = sys.platform
+ env_info['Python'] = sys.version.replace('\n', '')
+
+ cuda_available = torch.cuda.is_available()
+ env_info['CUDA available'] = cuda_available
+
+ if cuda_available:
+ devices = defaultdict(list)
+ for k in range(torch.cuda.device_count()):
+ devices[torch.cuda.get_device_name(k)].append(str(k))
+ for name, device_ids in devices.items():
+ env_info['GPU ' + ','.join(device_ids)] = name
+
+ from annotator.uniformer.mmcv.utils.parrots_wrapper import _get_cuda_home
+ CUDA_HOME = _get_cuda_home()
+ env_info['CUDA_HOME'] = CUDA_HOME
+
+ if CUDA_HOME is not None and osp.isdir(CUDA_HOME):
+ try:
+ nvcc = osp.join(CUDA_HOME, 'bin/nvcc')
+ nvcc = subprocess.check_output(
+ f'"{nvcc}" -V | tail -n1', shell=True)
+ nvcc = nvcc.decode('utf-8').strip()
+ except subprocess.SubprocessError:
+ nvcc = 'Not Available'
+ env_info['NVCC'] = nvcc
+
+ try:
+ gcc = subprocess.check_output('gcc --version | head -n1', shell=True)
+ gcc = gcc.decode('utf-8').strip()
+ env_info['GCC'] = gcc
+ except subprocess.CalledProcessError: # gcc is unavailable
+ env_info['GCC'] = 'n/a'
+
+ env_info['PyTorch'] = torch.__version__
+ env_info['PyTorch compiling details'] = get_build_config()
+
+ try:
+ import torchvision
+ env_info['TorchVision'] = torchvision.__version__
+ except ModuleNotFoundError:
+ pass
+
+ env_info['OpenCV'] = cv2.__version__
+
+ env_info['MMCV'] = mmcv.__version__
+
+ try:
+ from annotator.uniformer.mmcv.ops import get_compiler_version, get_compiling_cuda_version
+ except ModuleNotFoundError:
+ env_info['MMCV Compiler'] = 'n/a'
+ env_info['MMCV CUDA Compiler'] = 'n/a'
+ else:
+ env_info['MMCV Compiler'] = get_compiler_version()
+ env_info['MMCV CUDA Compiler'] = get_compiling_cuda_version()
+
+ return env_info
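+
+
+# Usage sketch (comment only, added for readability): the returned dict can be
+# printed directly, e.g.
+#     for name, val in collect_env().items():
+#         print(f'{name}: {val}')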
diff --git a/annotator/uniformer/mmcv/utils/ext_loader.py b/annotator/uniformer/mmcv/utils/ext_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..08132d2c1b9a1c28880e4bab4d4fa1ba39d9d083
--- /dev/null
+++ b/annotator/uniformer/mmcv/utils/ext_loader.py
@@ -0,0 +1,71 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import importlib
+import os
+import pkgutil
+import warnings
+from collections import namedtuple
+
+import torch
+
+if torch.__version__ != 'parrots':
+
+ def load_ext(name, funcs):
+ ext = importlib.import_module('mmcv.' + name)
+ for fun in funcs:
+            assert hasattr(ext, fun), f'{fun} is missing from module {name}'
+ return ext
+else:
+ from parrots import extension
+ from parrots.base import ParrotsException
+
+ has_return_value_ops = [
+ 'nms',
+ 'softnms',
+ 'nms_match',
+ 'nms_rotated',
+ 'top_pool_forward',
+ 'top_pool_backward',
+ 'bottom_pool_forward',
+ 'bottom_pool_backward',
+ 'left_pool_forward',
+ 'left_pool_backward',
+ 'right_pool_forward',
+ 'right_pool_backward',
+ 'fused_bias_leakyrelu',
+ 'upfirdn2d',
+ 'ms_deform_attn_forward',
+ 'pixel_group',
+ 'contour_expand',
+ ]
+
+ def get_fake_func(name, e):
+
+ def fake_func(*args, **kwargs):
+ warnings.warn(f'{name} is not supported in parrots now')
+ raise e
+
+ return fake_func
+
+ def load_ext(name, funcs):
+ ExtModule = namedtuple('ExtModule', funcs)
+ ext_list = []
+ lib_root = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+ for fun in funcs:
+ try:
+ ext_fun = extension.load(fun, name, lib_dir=lib_root)
+ except ParrotsException as e:
+ if 'No element registered' not in e.message:
+ warnings.warn(e.message)
+ ext_fun = get_fake_func(fun, e)
+ ext_list.append(ext_fun)
+ else:
+ if fun in has_return_value_ops:
+ ext_list.append(ext_fun.op)
+ else:
+ ext_list.append(ext_fun.op_)
+ return ExtModule(*ext_list)
+
+
+def check_ops_exist():
+ ext_loader = pkgutil.find_loader('mmcv._ext')
+ return ext_loader is not None
diff --git a/annotator/uniformer/mmcv/utils/logging.py b/annotator/uniformer/mmcv/utils/logging.py
new file mode 100644
index 0000000000000000000000000000000000000000..4aa0e04bb9b3ab2a4bfbc4def50404ccbac2c6e6
--- /dev/null
+++ b/annotator/uniformer/mmcv/utils/logging.py
@@ -0,0 +1,110 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import logging
+
+import torch.distributed as dist
+
+logger_initialized = {}
+
+
+def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'):
+ """Initialize and get a logger by name.
+
+ If the logger has not been initialized, this method will initialize the
+ logger by adding one or two handlers, otherwise the initialized logger will
+ be directly returned. During initialization, a StreamHandler will always be
+ added. If `log_file` is specified and the process rank is 0, a FileHandler
+ will also be added.
+
+ Args:
+ name (str): Logger name.
+ log_file (str | None): The log filename. If specified, a FileHandler
+ will be added to the logger.
+ log_level (int): The logger level. Note that only the process of
+ rank 0 is affected, and other processes will set the level to
+ "Error" thus be silent most of the time.
+ file_mode (str): The file mode used in opening log file.
+ Defaults to 'w'.
+
+ Returns:
+ logging.Logger: The expected logger.
+ """
+ logger = logging.getLogger(name)
+ if name in logger_initialized:
+ return logger
+ # handle hierarchical names
+ # e.g., logger "a" is initialized, then logger "a.b" will skip the
+ # initialization since it is a child of "a".
+ for logger_name in logger_initialized:
+ if name.startswith(logger_name):
+ return logger
+
+ # handle duplicate logs to the console
+ # Starting in 1.8.0, PyTorch DDP attaches a StreamHandler (NOTSET)
+ # to the root logger. As logger.propagate is True by default, this root
+ # level handler causes logging messages from rank>0 processes to
+ # unexpectedly show up on the console, creating much unwanted clutter.
+ # To fix this issue, we set the root logger's StreamHandler, if any, to log
+ # at the ERROR level.
+ for handler in logger.root.handlers:
+ if type(handler) is logging.StreamHandler:
+ handler.setLevel(logging.ERROR)
+
+ stream_handler = logging.StreamHandler()
+ handlers = [stream_handler]
+
+ if dist.is_available() and dist.is_initialized():
+ rank = dist.get_rank()
+ else:
+ rank = 0
+
+ # only rank 0 will add a FileHandler
+ if rank == 0 and log_file is not None:
+        # The default file mode of the official FileHandler is 'a'. Here we
+        # expose ``file_mode`` so that the caller can change it; this function
+        # defaults to 'w'.
+ file_handler = logging.FileHandler(log_file, file_mode)
+ handlers.append(file_handler)
+
+ formatter = logging.Formatter(
+ '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+ for handler in handlers:
+ handler.setFormatter(formatter)
+ handler.setLevel(log_level)
+ logger.addHandler(handler)
+
+ if rank == 0:
+ logger.setLevel(log_level)
+ else:
+ logger.setLevel(logging.ERROR)
+
+ logger_initialized[name] = True
+
+ return logger
+
+
+def print_log(msg, logger=None, level=logging.INFO):
+ """Print a log message.
+
+ Args:
+ msg (str): The message to be logged.
+ logger (logging.Logger | str | None): The logger to be used.
+ Some special loggers are:
+ - "silent": no message will be printed.
+ - other str: the logger obtained with `get_root_logger(logger)`.
+ - None: The `print()` method will be used to print log messages.
+ level (int): Logging level. Only available when `logger` is a Logger
+ object or "root".
+ """
+ if logger is None:
+ print(msg)
+ elif isinstance(logger, logging.Logger):
+ logger.log(level, msg)
+ elif logger == 'silent':
+ pass
+ elif isinstance(logger, str):
+ _logger = get_logger(logger)
+ _logger.log(level, msg)
+ else:
+ raise TypeError(
+ 'logger should be either a logging.Logger object, str, '
+ f'"silent" or None, but got {type(logger)}')
diff --git a/annotator/uniformer/mmcv/utils/misc.py b/annotator/uniformer/mmcv/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c58d0d7fee9fe3d4519270ad8c1e998d0d8a18c
--- /dev/null
+++ b/annotator/uniformer/mmcv/utils/misc.py
@@ -0,0 +1,377 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import collections.abc
+import functools
+import itertools
+import subprocess
+import warnings
+from collections import abc
+from importlib import import_module
+from inspect import getfullargspec
+from itertools import repeat
+
+
+# From PyTorch internals
+def _ntuple(n):
+
+ def parse(x):
+ if isinstance(x, collections.abc.Iterable):
+ return x
+ return tuple(repeat(x, n))
+
+ return parse
+
+
+to_1tuple = _ntuple(1)
+to_2tuple = _ntuple(2)
+to_3tuple = _ntuple(3)
+to_4tuple = _ntuple(4)
+to_ntuple = _ntuple
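+
+
+# Illustrative behaviour (comment added for readability): ``to_2tuple(3)``
+# returns ``(3, 3)``, while an iterable input such as ``to_2tuple((3, 4))``
+# is returned unchanged as ``(3, 4)``.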
+
+
+def is_str(x):
+ """Whether the input is an string instance.
+
+ Note: This method is deprecated since python 2 is no longer supported.
+ """
+ return isinstance(x, str)
+
+
+def import_modules_from_strings(imports, allow_failed_imports=False):
+ """Import modules from the given list of strings.
+
+ Args:
+ imports (list | str | None): The given module names to be imported.
+ allow_failed_imports (bool): If True, the failed imports will return
+            None. Otherwise, an ImportError is raised. Default: False.
+
+ Returns:
+ list[module] | module | None: The imported modules.
+
+ Examples:
+ >>> osp, sys = import_modules_from_strings(
+ ... ['os.path', 'sys'])
+ >>> import os.path as osp_
+ >>> import sys as sys_
+ >>> assert osp == osp_
+ >>> assert sys == sys_
+ """
+ if not imports:
+ return
+ single_import = False
+ if isinstance(imports, str):
+ single_import = True
+ imports = [imports]
+ if not isinstance(imports, list):
+ raise TypeError(
+ f'custom_imports must be a list but got type {type(imports)}')
+ imported = []
+ for imp in imports:
+ if not isinstance(imp, str):
+ raise TypeError(
+ f'{imp} is of type {type(imp)} and cannot be imported.')
+ try:
+ imported_tmp = import_module(imp)
+ except ImportError:
+ if allow_failed_imports:
+ warnings.warn(f'{imp} failed to import and is ignored.',
+ UserWarning)
+ imported_tmp = None
+ else:
+ raise ImportError
+ imported.append(imported_tmp)
+ if single_import:
+ imported = imported[0]
+ return imported
+
+
+def iter_cast(inputs, dst_type, return_type=None):
+ """Cast elements of an iterable object into some type.
+
+ Args:
+ inputs (Iterable): The input object.
+ dst_type (type): Destination type.
+ return_type (type, optional): If specified, the output object will be
+ converted to this type, otherwise an iterator.
+
+ Returns:
+ iterator or specified type: The converted object.
+ """
+ if not isinstance(inputs, abc.Iterable):
+ raise TypeError('inputs must be an iterable object')
+ if not isinstance(dst_type, type):
+ raise TypeError('"dst_type" must be a valid type')
+
+ out_iterable = map(dst_type, inputs)
+
+ if return_type is None:
+ return out_iterable
+ else:
+ return return_type(out_iterable)
+
+
+def list_cast(inputs, dst_type):
+ """Cast elements of an iterable object into a list of some type.
+
+ A partial method of :func:`iter_cast`.
+ """
+ return iter_cast(inputs, dst_type, return_type=list)
+
+
+def tuple_cast(inputs, dst_type):
+ """Cast elements of an iterable object into a tuple of some type.
+
+ A partial method of :func:`iter_cast`.
+ """
+ return iter_cast(inputs, dst_type, return_type=tuple)
+
+
+def is_seq_of(seq, expected_type, seq_type=None):
+ """Check whether it is a sequence of some type.
+
+ Args:
+ seq (Sequence): The sequence to be checked.
+ expected_type (type): Expected type of sequence items.
+ seq_type (type, optional): Expected sequence type.
+
+ Returns:
+ bool: Whether the sequence is valid.
+ """
+ if seq_type is None:
+ exp_seq_type = abc.Sequence
+ else:
+ assert isinstance(seq_type, type)
+ exp_seq_type = seq_type
+ if not isinstance(seq, exp_seq_type):
+ return False
+ for item in seq:
+ if not isinstance(item, expected_type):
+ return False
+ return True
+
+
+def is_list_of(seq, expected_type):
+ """Check whether it is a list of some type.
+
+ A partial method of :func:`is_seq_of`.
+ """
+ return is_seq_of(seq, expected_type, seq_type=list)
+
+
+def is_tuple_of(seq, expected_type):
+ """Check whether it is a tuple of some type.
+
+ A partial method of :func:`is_seq_of`.
+ """
+ return is_seq_of(seq, expected_type, seq_type=tuple)
+
+
+def slice_list(in_list, lens):
+ """Slice a list into several sub lists by a list of given length.
+
+ Args:
+ in_list (list): The list to be sliced.
+        lens (int or list): The expected length of each output sub-list.
+
+ Returns:
+ list: A list of sliced list.
+ """
+ if isinstance(lens, int):
+ assert len(in_list) % lens == 0
+ lens = [lens] * int(len(in_list) / lens)
+ if not isinstance(lens, list):
+ raise TypeError('"indices" must be an integer or a list of integers')
+ elif sum(lens) != len(in_list):
+ raise ValueError('sum of lens and list length does not '
+ f'match: {sum(lens)} != {len(in_list)}')
+ out_list = []
+ idx = 0
+ for i in range(len(lens)):
+ out_list.append(in_list[idx:idx + lens[i]])
+ idx += lens[i]
+ return out_list
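+
+
+# Illustrative behaviour (comment added for readability):
+#     slice_list([1, 2, 3, 4, 5, 6], [2, 4])  -> [[1, 2], [3, 4, 5, 6]]
+#     slice_list([1, 2, 3, 4, 5, 6], 3)       -> [[1, 2, 3], [4, 5, 6]]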
+
+
+def concat_list(in_list):
+ """Concatenate a list of list into a single list.
+
+ Args:
+        in_list (list): The list of lists to be merged.
+
+ Returns:
+ list: The concatenated flat list.
+ """
+ return list(itertools.chain(*in_list))
+
+
+def check_prerequisites(
+ prerequisites,
+ checker,
+ msg_tmpl='Prerequisites "{}" are required in method "{}" but not '
+ 'found, please install them first.'): # yapf: disable
+ """A decorator factory to check if prerequisites are satisfied.
+
+ Args:
+        prerequisites (str or list[str]): Prerequisites to be checked.
+        checker (callable): The checker method that returns True if a
+            prerequisite is met, False otherwise.
+ msg_tmpl (str): The message template with two variables.
+
+ Returns:
+ decorator: A specific decorator.
+ """
+
+ def wrap(func):
+
+ @functools.wraps(func)
+ def wrapped_func(*args, **kwargs):
+ requirements = [prerequisites] if isinstance(
+ prerequisites, str) else prerequisites
+ missing = []
+ for item in requirements:
+ if not checker(item):
+ missing.append(item)
+ if missing:
+ print(msg_tmpl.format(', '.join(missing), func.__name__))
+                raise RuntimeError('Prerequisites not met.')
+ else:
+ return func(*args, **kwargs)
+
+ return wrapped_func
+
+ return wrap
+
+
+def _check_py_package(package):
+ try:
+ import_module(package)
+ except ImportError:
+ return False
+ else:
+ return True
+
+
+def _check_executable(cmd):
+ if subprocess.call(f'which {cmd}', shell=True) != 0:
+ return False
+ else:
+ return True
+
+
+def requires_package(prerequisites):
+ """A decorator to check if some python packages are installed.
+
+ Example:
+ >>> @requires_package('numpy')
+ >>> func(arg1, args):
+ >>> return numpy.zeros(1)
+ array([0.])
+ >>> @requires_package(['numpy', 'non_package'])
+ >>> func(arg1, args):
+ >>> return numpy.zeros(1)
+ ImportError
+ """
+ return check_prerequisites(prerequisites, checker=_check_py_package)
+
+
+def requires_executable(prerequisites):
+ """A decorator to check if some executable files are installed.
+
+ Example:
+ >>> @requires_executable('ffmpeg')
+ >>> func(arg1, args):
+ >>> print(1)
+ 1
+ """
+ return check_prerequisites(prerequisites, checker=_check_executable)
+
+
+def deprecated_api_warning(name_dict, cls_name=None):
+ """A decorator to check if some arguments are deprecate and try to replace
+ deprecate src_arg_name to dst_arg_name.
+
+ Args:
+        name_dict (dict):
+            key (str): Deprecated argument names.
+            val (str): Expected argument names.
+
+ Returns:
+ func: New function.
+ """
+
+ def api_warning_wrapper(old_func):
+
+ @functools.wraps(old_func)
+ def new_func(*args, **kwargs):
+ # get the arg spec of the decorated method
+ args_info = getfullargspec(old_func)
+ # get name of the function
+ func_name = old_func.__name__
+ if cls_name is not None:
+ func_name = f'{cls_name}.{func_name}'
+ if args:
+ arg_names = args_info.args[:len(args)]
+ for src_arg_name, dst_arg_name in name_dict.items():
+ if src_arg_name in arg_names:
+ warnings.warn(
+ f'"{src_arg_name}" is deprecated in '
+ f'`{func_name}`, please use "{dst_arg_name}" '
+ 'instead')
+ arg_names[arg_names.index(src_arg_name)] = dst_arg_name
+ if kwargs:
+ for src_arg_name, dst_arg_name in name_dict.items():
+ if src_arg_name in kwargs:
+
+ assert dst_arg_name not in kwargs, (
+ f'The expected behavior is to replace '
+ f'the deprecated key `{src_arg_name}` to '
+ f'new key `{dst_arg_name}`, but got them '
+ f'in the arguments at the same time, which '
+                            f'is confusing. `{src_arg_name}` will be '
+ f'deprecated in the future, please '
+ f'use `{dst_arg_name}` instead.')
+
+ warnings.warn(
+ f'"{src_arg_name}" is deprecated in '
+ f'`{func_name}`, please use "{dst_arg_name}" '
+ 'instead')
+ kwargs[dst_arg_name] = kwargs.pop(src_arg_name)
+
+ # apply converted arguments to the decorated method
+ output = old_func(*args, **kwargs)
+ return output
+
+ return new_func
+
+ return api_warning_wrapper
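+
+
+# Usage sketch (comment only, added for readability; the names are
+# illustrative):
+#     @deprecated_api_warning({'old_arg': 'new_arg'}, cls_name='MyClass')
+#     def method(self, new_arg=None):
+#         ...
+# Calling ``method(old_arg=1)`` would warn and forward the value as
+# ``new_arg=1``.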
+
+
+def is_method_overridden(method, base_class, derived_class):
+ """Check if a method of base class is overridden in derived class.
+
+ Args:
+ method (str): the method name to check.
+ base_class (type): the class of the base class.
+ derived_class (type | Any): the class or instance of the derived class.
+ """
+ assert isinstance(base_class, type), \
+ "base_class doesn't accept instance, Please pass class instead."
+
+ if not isinstance(derived_class, type):
+ derived_class = derived_class.__class__
+
+ base_method = getattr(base_class, method)
+ derived_method = getattr(derived_class, method)
+ return derived_method != base_method
+
+
+def has_method(obj: object, method: str) -> bool:
+ """Check whether the object has a method.
+
+ Args:
+ method (str): The method name to check.
+ obj (object): The object to check.
+
+ Returns:
+ bool: True if the object has the method else False.
+ """
+ return hasattr(obj, method) and callable(getattr(obj, method))
diff --git a/annotator/uniformer/mmcv/utils/parrots_jit.py b/annotator/uniformer/mmcv/utils/parrots_jit.py
new file mode 100644
index 0000000000000000000000000000000000000000..61873f6dbb9b10ed972c90aa8faa321e3cb3249e
--- /dev/null
+++ b/annotator/uniformer/mmcv/utils/parrots_jit.py
@@ -0,0 +1,41 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+
+from .parrots_wrapper import TORCH_VERSION
+
+parrots_jit_option = os.getenv('PARROTS_JIT_OPTION')
+
+if TORCH_VERSION == 'parrots' and parrots_jit_option == 'ON':
+ from parrots.jit import pat as jit
+else:
+
+ def jit(func=None,
+ check_input=None,
+ full_shape=True,
+ derivate=False,
+ coderize=False,
+ optimize=False):
+
+ def wrapper(func):
+
+ def wrapper_inner(*args, **kargs):
+ return func(*args, **kargs)
+
+ return wrapper_inner
+
+ if func is None:
+ return wrapper
+ else:
+ return func
+
+
+if TORCH_VERSION == 'parrots':
+ from parrots.utils.tester import skip_no_elena
+else:
+
+ def skip_no_elena(func):
+
+ def wrapper(*args, **kargs):
+ return func(*args, **kargs)
+
+ return wrapper
diff --git a/annotator/uniformer/mmcv/utils/parrots_wrapper.py b/annotator/uniformer/mmcv/utils/parrots_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..93c97640d4b9ed088ca82cfe03e6efebfcfa9dbf
--- /dev/null
+++ b/annotator/uniformer/mmcv/utils/parrots_wrapper.py
@@ -0,0 +1,107 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from functools import partial
+
+import torch
+
+TORCH_VERSION = torch.__version__
+
+
+def is_rocm_pytorch() -> bool:
+ is_rocm = False
+ if TORCH_VERSION != 'parrots':
+ try:
+ from torch.utils.cpp_extension import ROCM_HOME
+ is_rocm = True if ((torch.version.hip is not None) and
+ (ROCM_HOME is not None)) else False
+ except ImportError:
+ pass
+ return is_rocm
+
+
+def _get_cuda_home():
+ if TORCH_VERSION == 'parrots':
+ from parrots.utils.build_extension import CUDA_HOME
+ else:
+ if is_rocm_pytorch():
+ from torch.utils.cpp_extension import ROCM_HOME
+ CUDA_HOME = ROCM_HOME
+ else:
+ from torch.utils.cpp_extension import CUDA_HOME
+ return CUDA_HOME
+
+
+def get_build_config():
+ if TORCH_VERSION == 'parrots':
+ from parrots.config import get_build_info
+ return get_build_info()
+ else:
+ return torch.__config__.show()
+
+
+def _get_conv():
+ if TORCH_VERSION == 'parrots':
+ from parrots.nn.modules.conv import _ConvNd, _ConvTransposeMixin
+ else:
+ from torch.nn.modules.conv import _ConvNd, _ConvTransposeMixin
+ return _ConvNd, _ConvTransposeMixin
+
+
+def _get_dataloader():
+ if TORCH_VERSION == 'parrots':
+ from torch.utils.data import DataLoader, PoolDataLoader
+ else:
+ from torch.utils.data import DataLoader
+ PoolDataLoader = DataLoader
+ return DataLoader, PoolDataLoader
+
+
+def _get_extension():
+ if TORCH_VERSION == 'parrots':
+ from parrots.utils.build_extension import BuildExtension, Extension
+ CppExtension = partial(Extension, cuda=False)
+ CUDAExtension = partial(Extension, cuda=True)
+ else:
+ from torch.utils.cpp_extension import (BuildExtension, CppExtension,
+ CUDAExtension)
+ return BuildExtension, CppExtension, CUDAExtension
+
+
+def _get_pool():
+ if TORCH_VERSION == 'parrots':
+ from parrots.nn.modules.pool import (_AdaptiveAvgPoolNd,
+ _AdaptiveMaxPoolNd, _AvgPoolNd,
+ _MaxPoolNd)
+ else:
+ from torch.nn.modules.pooling import (_AdaptiveAvgPoolNd,
+ _AdaptiveMaxPoolNd, _AvgPoolNd,
+ _MaxPoolNd)
+ return _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd
+
+
+def _get_norm():
+ if TORCH_VERSION == 'parrots':
+ from parrots.nn.modules.batchnorm import _BatchNorm, _InstanceNorm
+ SyncBatchNorm_ = torch.nn.SyncBatchNorm2d
+ else:
+ from torch.nn.modules.instancenorm import _InstanceNorm
+ from torch.nn.modules.batchnorm import _BatchNorm
+ SyncBatchNorm_ = torch.nn.SyncBatchNorm
+ return _BatchNorm, _InstanceNorm, SyncBatchNorm_
+
+
+_ConvNd, _ConvTransposeMixin = _get_conv()
+DataLoader, PoolDataLoader = _get_dataloader()
+BuildExtension, CppExtension, CUDAExtension = _get_extension()
+_BatchNorm, _InstanceNorm, SyncBatchNorm_ = _get_norm()
+_AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd = _get_pool()
+
+
+class SyncBatchNorm(SyncBatchNorm_):
+
+ def _check_input_dim(self, input):
+ if TORCH_VERSION == 'parrots':
+ if input.dim() < 2:
+ raise ValueError(
+ f'expected at least 2D input (got {input.dim()}D input)')
+ else:
+ super()._check_input_dim(input)
diff --git a/annotator/uniformer/mmcv/utils/path.py b/annotator/uniformer/mmcv/utils/path.py
new file mode 100644
index 0000000000000000000000000000000000000000..7dab4b3041413b1432b0f434b8b14783097d33c6
--- /dev/null
+++ b/annotator/uniformer/mmcv/utils/path.py
@@ -0,0 +1,101 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+import os.path as osp
+from pathlib import Path
+
+from .misc import is_str
+
+
+def is_filepath(x):
+ return is_str(x) or isinstance(x, Path)
+
+
+def fopen(filepath, *args, **kwargs):
+ if is_str(filepath):
+ return open(filepath, *args, **kwargs)
+ elif isinstance(filepath, Path):
+ return filepath.open(*args, **kwargs)
+ raise ValueError('`filepath` should be a string or a Path')
+
+
+def check_file_exist(filename, msg_tmpl='file "{}" does not exist'):
+ if not osp.isfile(filename):
+ raise FileNotFoundError(msg_tmpl.format(filename))
+
+
+def mkdir_or_exist(dir_name, mode=0o777):
+ if dir_name == '':
+ return
+ dir_name = osp.expanduser(dir_name)
+ os.makedirs(dir_name, mode=mode, exist_ok=True)
+
+
+def symlink(src, dst, overwrite=True, **kwargs):
+ if os.path.lexists(dst) and overwrite:
+ os.remove(dst)
+ os.symlink(src, dst, **kwargs)
+
+
+def scandir(dir_path, suffix=None, recursive=False, case_sensitive=True):
+ """Scan a directory to find the interested files.
+
+ Args:
+ dir_path (str | obj:`Path`): Path of the directory.
+ suffix (str | tuple(str), optional): File suffix that we are
+ interested in. Default: None.
+ recursive (bool, optional): If set to True, recursively scan the
+ directory. Default: False.
+        case_sensitive (bool, optional): If set to False, ignore the case of
+ suffix. Default: True.
+
+ Returns:
+        A generator for all the files of interest, yielded as relative paths.
+ """
+ if isinstance(dir_path, (str, Path)):
+ dir_path = str(dir_path)
+ else:
+ raise TypeError('"dir_path" must be a string or Path object')
+
+ if (suffix is not None) and not isinstance(suffix, (str, tuple)):
+ raise TypeError('"suffix" must be a string or tuple of strings')
+
+ if suffix is not None and not case_sensitive:
+ suffix = suffix.lower() if isinstance(suffix, str) else tuple(
+ item.lower() for item in suffix)
+
+ root = dir_path
+
+ def _scandir(dir_path, suffix, recursive, case_sensitive):
+ for entry in os.scandir(dir_path):
+ if not entry.name.startswith('.') and entry.is_file():
+ rel_path = osp.relpath(entry.path, root)
+ _rel_path = rel_path if case_sensitive else rel_path.lower()
+ if suffix is None or _rel_path.endswith(suffix):
+ yield rel_path
+ elif recursive and os.path.isdir(entry.path):
+ # scan recursively if entry.path is a directory
+ yield from _scandir(entry.path, suffix, recursive,
+ case_sensitive)
+
+ return _scandir(dir_path, suffix, recursive, case_sensitive)
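+
+
+# Usage sketch (comment only, added for readability; the directory name is
+# illustrative):
+#     for rel_path in scandir('data/images', suffix=('.jpg', '.png'),
+#                             recursive=True):
+#         print(rel_path)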
+
+
+def find_vcs_root(path, markers=('.git', )):
+ """Finds the root directory (including itself) of specified markers.
+
+ Args:
+ path (str): Path of directory or file.
+ markers (list[str], optional): List of file or directory names.
+
+ Returns:
+        The directory containing one of the markers, or None if not found.
+ """
+ if osp.isfile(path):
+ path = osp.dirname(path)
+
+ prev, cur = None, osp.abspath(osp.expanduser(path))
+ while cur != prev:
+ if any(osp.exists(osp.join(cur, marker)) for marker in markers):
+ return cur
+ prev, cur = cur, osp.split(cur)[0]
+ return None
diff --git a/annotator/uniformer/mmcv/utils/progressbar.py b/annotator/uniformer/mmcv/utils/progressbar.py
new file mode 100644
index 0000000000000000000000000000000000000000..0062f670dd94fa9da559ab26ef85517dcf5211c7
--- /dev/null
+++ b/annotator/uniformer/mmcv/utils/progressbar.py
@@ -0,0 +1,208 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import sys
+from collections.abc import Iterable
+from multiprocessing import Pool
+from shutil import get_terminal_size
+
+from .timer import Timer
+
+
+class ProgressBar:
+ """A progress bar which can print the progress."""
+
+ def __init__(self, task_num=0, bar_width=50, start=True, file=sys.stdout):
+ self.task_num = task_num
+ self.bar_width = bar_width
+ self.completed = 0
+ self.file = file
+ if start:
+ self.start()
+
+ @property
+ def terminal_width(self):
+ width, _ = get_terminal_size()
+ return width
+
+ def start(self):
+ if self.task_num > 0:
+ self.file.write(f'[{" " * self.bar_width}] 0/{self.task_num}, '
+ 'elapsed: 0s, ETA:')
+ else:
+ self.file.write('completed: 0, elapsed: 0s')
+ self.file.flush()
+ self.timer = Timer()
+
+ def update(self, num_tasks=1):
+ assert num_tasks > 0
+ self.completed += num_tasks
+ elapsed = self.timer.since_start()
+ if elapsed > 0:
+ fps = self.completed / elapsed
+ else:
+ fps = float('inf')
+ if self.task_num > 0:
+ percentage = self.completed / float(self.task_num)
+ eta = int(elapsed * (1 - percentage) / percentage + 0.5)
+ msg = f'\r[{{}}] {self.completed}/{self.task_num}, ' \
+ f'{fps:.1f} task/s, elapsed: {int(elapsed + 0.5)}s, ' \
+ f'ETA: {eta:5}s'
+
+ bar_width = min(self.bar_width,
+ int(self.terminal_width - len(msg)) + 2,
+ int(self.terminal_width * 0.6))
+ bar_width = max(2, bar_width)
+ mark_width = int(bar_width * percentage)
+ bar_chars = '>' * mark_width + ' ' * (bar_width - mark_width)
+ self.file.write(msg.format(bar_chars))
+ else:
+ self.file.write(
+ f'completed: {self.completed}, elapsed: {int(elapsed + 0.5)}s,'
+ f' {fps:.1f} tasks/s')
+ self.file.flush()
+
+
+def track_progress(func, tasks, bar_width=50, file=sys.stdout, **kwargs):
+ """Track the progress of tasks execution with a progress bar.
+
+ Tasks are done with a simple for-loop.
+
+ Args:
+ func (callable): The function to be applied to each task.
+ tasks (list or tuple[Iterable, int]): A list of tasks or
+ (tasks, total num).
+ bar_width (int): Width of progress bar.
+
+ Returns:
+ list: The task results.
+ """
+ if isinstance(tasks, tuple):
+ assert len(tasks) == 2
+ assert isinstance(tasks[0], Iterable)
+ assert isinstance(tasks[1], int)
+ task_num = tasks[1]
+ tasks = tasks[0]
+ elif isinstance(tasks, Iterable):
+ task_num = len(tasks)
+ else:
+ raise TypeError(
+ '"tasks" must be an iterable object or a (iterator, int) tuple')
+ prog_bar = ProgressBar(task_num, bar_width, file=file)
+ results = []
+ for task in tasks:
+ results.append(func(task, **kwargs))
+ prog_bar.update()
+ prog_bar.file.write('\n')
+ return results
+
+
+def init_pool(process_num, initializer=None, initargs=None):
+ if initializer is None:
+ return Pool(process_num)
+ elif initargs is None:
+ return Pool(process_num, initializer)
+ else:
+ if not isinstance(initargs, tuple):
+ raise TypeError('"initargs" must be a tuple')
+ return Pool(process_num, initializer, initargs)
+
+
+def track_parallel_progress(func,
+ tasks,
+ nproc,
+ initializer=None,
+ initargs=None,
+ bar_width=50,
+ chunksize=1,
+ skip_first=False,
+ keep_order=True,
+ file=sys.stdout):
+ """Track the progress of parallel task execution with a progress bar.
+
+ The built-in :mod:`multiprocessing` module is used for process pools and
+ tasks are done with :func:`Pool.map` or :func:`Pool.imap_unordered`.
+
+ Args:
+ func (callable): The function to be applied to each task.
+ tasks (list or tuple[Iterable, int]): A list of tasks or
+ (tasks, total num).
+ nproc (int): Process (worker) number.
+ initializer (None or callable): Refer to :class:`multiprocessing.Pool`
+ for details.
+ initargs (None or tuple): Refer to :class:`multiprocessing.Pool` for
+ details.
+ chunksize (int): Refer to :class:`multiprocessing.Pool` for details.
+ bar_width (int): Width of progress bar.
+ skip_first (bool): Whether to skip the first sample for each worker
+            when estimating fps, since the initialization step may take
+            longer.
+ keep_order (bool): If True, :func:`Pool.imap` is used, otherwise
+ :func:`Pool.imap_unordered` is used.
+
+ Returns:
+ list: The task results.
+ """
+ if isinstance(tasks, tuple):
+ assert len(tasks) == 2
+ assert isinstance(tasks[0], Iterable)
+ assert isinstance(tasks[1], int)
+ task_num = tasks[1]
+ tasks = tasks[0]
+ elif isinstance(tasks, Iterable):
+ task_num = len(tasks)
+ else:
+ raise TypeError(
+ '"tasks" must be an iterable object or a (iterator, int) tuple')
+ pool = init_pool(nproc, initializer, initargs)
+ start = not skip_first
+ task_num -= nproc * chunksize * int(skip_first)
+ prog_bar = ProgressBar(task_num, bar_width, start, file=file)
+ results = []
+ if keep_order:
+ gen = pool.imap(func, tasks, chunksize)
+ else:
+ gen = pool.imap_unordered(func, tasks, chunksize)
+ for result in gen:
+ results.append(result)
+ if skip_first:
+ if len(results) < nproc * chunksize:
+ continue
+ elif len(results) == nproc * chunksize:
+ prog_bar.start()
+ continue
+ prog_bar.update()
+ prog_bar.file.write('\n')
+ pool.close()
+ pool.join()
+ return results
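+
+
+# Usage sketch (comment only, added for readability; ``process_one`` and the
+# task list are illustrative):
+#     results = track_parallel_progress(process_one, tasks, nproc=8)
+# which is the multi-process counterpart of
+#     results = track_progress(process_one, tasks)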
+
+
+def track_iter_progress(tasks, bar_width=50, file=sys.stdout):
+ """Track the progress of tasks iteration or enumeration with a progress
+ bar.
+
+ Tasks are yielded with a simple for-loop.
+
+ Args:
+ tasks (list or tuple[Iterable, int]): A list of tasks or
+ (tasks, total num).
+ bar_width (int): Width of progress bar.
+
+ Yields:
+ list: The task results.
+ """
+ if isinstance(tasks, tuple):
+ assert len(tasks) == 2
+ assert isinstance(tasks[0], Iterable)
+ assert isinstance(tasks[1], int)
+ task_num = tasks[1]
+ tasks = tasks[0]
+ elif isinstance(tasks, Iterable):
+ task_num = len(tasks)
+ else:
+ raise TypeError(
+ '"tasks" must be an iterable object or a (iterator, int) tuple')
+ prog_bar = ProgressBar(task_num, bar_width, file=file)
+ for task in tasks:
+ yield task
+ prog_bar.update()
+ prog_bar.file.write('\n')
diff --git a/annotator/uniformer/mmcv/utils/registry.py b/annotator/uniformer/mmcv/utils/registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa9df39bc9f3d8d568361e7250ab35468f2b74e0
--- /dev/null
+++ b/annotator/uniformer/mmcv/utils/registry.py
@@ -0,0 +1,315 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import inspect
+import warnings
+from functools import partial
+
+from .misc import is_seq_of
+
+
+def build_from_cfg(cfg, registry, default_args=None):
+ """Build a module from config dict.
+
+ Args:
+ cfg (dict): Config dict. It should at least contain the key "type".
+ registry (:obj:`Registry`): The registry to search the type from.
+ default_args (dict, optional): Default initialization arguments.
+
+ Returns:
+ object: The constructed object.
+ """
+ if not isinstance(cfg, dict):
+ raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
+ if 'type' not in cfg:
+ if default_args is None or 'type' not in default_args:
+ raise KeyError(
+ '`cfg` or `default_args` must contain the key "type", '
+ f'but got {cfg}\n{default_args}')
+ if not isinstance(registry, Registry):
+ raise TypeError('registry must be an mmcv.Registry object, '
+ f'but got {type(registry)}')
+ if not (isinstance(default_args, dict) or default_args is None):
+ raise TypeError('default_args must be a dict or None, '
+ f'but got {type(default_args)}')
+
+ args = cfg.copy()
+
+ if default_args is not None:
+ for name, value in default_args.items():
+ args.setdefault(name, value)
+
+ obj_type = args.pop('type')
+ if isinstance(obj_type, str):
+ obj_cls = registry.get(obj_type)
+ if obj_cls is None:
+ raise KeyError(
+ f'{obj_type} is not in the {registry.name} registry')
+ elif inspect.isclass(obj_type):
+ obj_cls = obj_type
+ else:
+ raise TypeError(
+ f'type must be a str or valid type, but got {type(obj_type)}')
+ try:
+ return obj_cls(**args)
+ except Exception as e:
+ # Normal TypeError does not print class name.
+ raise type(e)(f'{obj_cls.__name__}: {e}')
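+
+
+# Usage sketch (comment only, added for readability; ``MODELS`` and ``ResNet``
+# are illustrative and rely on the ``Registry`` class defined below):
+#     MODELS = Registry('models')
+#     @MODELS.register_module()
+#     class ResNet:
+#         def __init__(self, depth):
+#             self.depth = depth
+#     model = build_from_cfg(dict(type='ResNet', depth=50), MODELS)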
+
+
+class Registry:
+ """A registry to map strings to classes.
+
+ Registered object could be built from registry.
+ Example:
+ >>> MODELS = Registry('models')
+ >>> @MODELS.register_module()
+ >>> class ResNet:
+ >>> pass
+ >>> resnet = MODELS.build(dict(type='ResNet'))
+
+ Please refer to
+ https://mmcv.readthedocs.io/en/latest/understand_mmcv/registry.html for
+ advanced usage.
+
+ Args:
+ name (str): Registry name.
+ build_func(func, optional): Build function to construct instance from
+ Registry, func:`build_from_cfg` is used if neither ``parent`` or
+ ``build_func`` is specified. If ``parent`` is specified and
+ ``build_func`` is not given, ``build_func`` will be inherited
+ from ``parent``. Default: None.
+ parent (Registry, optional): Parent registry. The class registered in
+ children registry could be built from parent. Default: None.
+ scope (str, optional): The scope of registry. It is the key to search
+ for children registry. If not specified, scope will be the name of
+ the package where class is defined, e.g. mmdet, mmcls, mmseg.
+ Default: None.
+ """
+
+ def __init__(self, name, build_func=None, parent=None, scope=None):
+ self._name = name
+ self._module_dict = dict()
+ self._children = dict()
+ self._scope = self.infer_scope() if scope is None else scope
+
+ # self.build_func will be set with the following priority:
+ # 1. build_func
+ # 2. parent.build_func
+ # 3. build_from_cfg
+ if build_func is None:
+ if parent is not None:
+ self.build_func = parent.build_func
+ else:
+ self.build_func = build_from_cfg
+ else:
+ self.build_func = build_func
+ if parent is not None:
+ assert isinstance(parent, Registry)
+ parent._add_children(self)
+ self.parent = parent
+ else:
+ self.parent = None
+
+ def __len__(self):
+ return len(self._module_dict)
+
+ def __contains__(self, key):
+ return self.get(key) is not None
+
+ def __repr__(self):
+ format_str = self.__class__.__name__ + \
+ f'(name={self._name}, ' \
+ f'items={self._module_dict})'
+ return format_str
+
+ @staticmethod
+ def infer_scope():
+ """Infer the scope of registry.
+
+ The name of the package where registry is defined will be returned.
+
+ Example:
+ # in mmdet/models/backbone/resnet.py
+ >>> MODELS = Registry('models')
+ >>> @MODELS.register_module()
+ >>> class ResNet:
+ >>> pass
+ The scope of ``ResNet`` will be ``mmdet``.
+
+
+ Returns:
+ scope (str): The inferred scope name.
+ """
+        # inspect.stack() traces where this function is called; index 2 is
+        # the frame that instantiates the Registry, i.e. where
+        # `infer_scope()` is ultimately triggered.
+ filename = inspect.getmodule(inspect.stack()[2][0]).__name__
+ split_filename = filename.split('.')
+ return split_filename[0]
+
+ @staticmethod
+ def split_scope_key(key):
+ """Split scope and key.
+
+ The first scope will be split from key.
+
+ Examples:
+ >>> Registry.split_scope_key('mmdet.ResNet')
+ 'mmdet', 'ResNet'
+ >>> Registry.split_scope_key('ResNet')
+ None, 'ResNet'
+
+ Return:
+ scope (str, None): The first scope.
+ key (str): The remaining key.
+ """
+ split_index = key.find('.')
+ if split_index != -1:
+ return key[:split_index], key[split_index + 1:]
+ else:
+ return None, key
+
+ @property
+ def name(self):
+ return self._name
+
+ @property
+ def scope(self):
+ return self._scope
+
+ @property
+ def module_dict(self):
+ return self._module_dict
+
+ @property
+ def children(self):
+ return self._children
+
+ def get(self, key):
+ """Get the registry record.
+
+ Args:
+ key (str): The class name in string format.
+
+ Returns:
+ class: The corresponding class.
+ """
+ scope, real_key = self.split_scope_key(key)
+ if scope is None or scope == self._scope:
+ # get from self
+ if real_key in self._module_dict:
+ return self._module_dict[real_key]
+ else:
+ # get from self._children
+ if scope in self._children:
+ return self._children[scope].get(real_key)
+ else:
+ # goto root
+ parent = self.parent
+ while parent.parent is not None:
+ parent = parent.parent
+ return parent.get(key)
+
+ def build(self, *args, **kwargs):
+ return self.build_func(*args, **kwargs, registry=self)
+
+ def _add_children(self, registry):
+ """Add children for a registry.
+
+ The ``registry`` will be added as children based on its scope.
+ The parent registry could build objects from children registry.
+
+ Example:
+ >>> models = Registry('models')
+ >>> mmdet_models = Registry('models', parent=models)
+ >>> @mmdet_models.register_module()
+ >>> class ResNet:
+ >>> pass
+ >>> resnet = models.build(dict(type='mmdet.ResNet'))
+ """
+
+ assert isinstance(registry, Registry)
+ assert registry.scope is not None
+ assert registry.scope not in self.children, \
+ f'scope {registry.scope} exists in {self.name} registry'
+ self.children[registry.scope] = registry
+
+ def _register_module(self, module_class, module_name=None, force=False):
+ if not inspect.isclass(module_class):
+ raise TypeError('module must be a class, '
+ f'but got {type(module_class)}')
+
+ if module_name is None:
+ module_name = module_class.__name__
+ if isinstance(module_name, str):
+ module_name = [module_name]
+ for name in module_name:
+ if not force and name in self._module_dict:
+ raise KeyError(f'{name} is already registered '
+ f'in {self.name}')
+ self._module_dict[name] = module_class
+
+ def deprecated_register_module(self, cls=None, force=False):
+ warnings.warn(
+ 'The old API of register_module(module, force=False) '
+ 'is deprecated and will be removed, please use the new API '
+ 'register_module(name=None, force=False, module=None) instead.')
+ if cls is None:
+ return partial(self.deprecated_register_module, force=force)
+ self._register_module(cls, force=force)
+ return cls
+
+ def register_module(self, name=None, force=False, module=None):
+ """Register a module.
+
+ A record will be added to `self._module_dict`, whose key is the class
+ name or the specified name, and value is the class itself.
+ It can be used as a decorator or a normal function.
+
+ Example:
+ >>> backbones = Registry('backbone')
+ >>> @backbones.register_module()
+ >>> class ResNet:
+ >>> pass
+
+ >>> backbones = Registry('backbone')
+ >>> @backbones.register_module(name='mnet')
+ >>> class MobileNet:
+ >>> pass
+
+ >>> backbones = Registry('backbone')
+ >>> class ResNet:
+ >>> pass
+ >>> backbones.register_module(ResNet)
+
+ Args:
+ name (str | None): The module name to be registered. If not
+ specified, the class name will be used.
+ force (bool, optional): Whether to override an existing class with
+ the same name. Default: False.
+ module (type): Module class to be registered.
+ """
+ if not isinstance(force, bool):
+ raise TypeError(f'force must be a boolean, but got {type(force)}')
+        # NOTE: This is a workaround to be compatible with the old api,
+        # though it may introduce unexpected bugs.
+ if isinstance(name, type):
+ return self.deprecated_register_module(name, force=force)
+
+ # raise the error ahead of time
+ if not (name is None or isinstance(name, str) or is_seq_of(name, str)):
+ raise TypeError(
+ 'name must be either of None, an instance of str or a sequence'
+ f' of str, but got {type(name)}')
+
+ # use it as a normal method: x.register_module(module=SomeClass)
+ if module is not None:
+ self._register_module(
+ module_class=module, module_name=name, force=force)
+ return module
+
+ # use it as a decorator: @x.register_module()
+ def _register(cls):
+ self._register_module(
+ module_class=cls, module_name=name, force=force)
+ return cls
+
+ return _register
diff --git a/annotator/uniformer/mmcv/utils/testing.py b/annotator/uniformer/mmcv/utils/testing.py
new file mode 100644
index 0000000000000000000000000000000000000000..a27f936da8ec14bac18562ede0a79d476d82f797
--- /dev/null
+++ b/annotator/uniformer/mmcv/utils/testing.py
@@ -0,0 +1,140 @@
+# Copyright (c) Open-MMLab.
+import sys
+from collections.abc import Iterable
+from runpy import run_path
+from shlex import split
+from typing import Any, Dict, List
+from unittest.mock import patch
+
+
+def check_python_script(cmd):
+ """Run the python cmd script with `__main__`. The difference between
+ `os.system` is that, this function exectues code in the current process, so
+ that it can be tracked by coverage tools. Currently it supports two forms:
+
+ - ./tests/data/scripts/hello.py zz
+ - python tests/data/scripts/hello.py zz
+ """
+ args = split(cmd)
+ if args[0] == 'python':
+ args = args[1:]
+ with patch.object(sys, 'argv', args):
+ run_path(args[0], run_name='__main__')
+
+
+def _any(judge_result):
+ """Since built-in ``any`` works only when the element of iterable is not
+ iterable, implement the function."""
+ if not isinstance(judge_result, Iterable):
+ return judge_result
+
+ try:
+ for element in judge_result:
+ if _any(element):
+ return True
+ except TypeError:
+ # Maybe encounter the case: torch.tensor(True) | torch.tensor(False)
+ if judge_result:
+ return True
+ return False
+
+
+def assert_dict_contains_subset(dict_obj: Dict[Any, Any],
+ expected_subset: Dict[Any, Any]) -> bool:
+ """Check if the dict_obj contains the expected_subset.
+
+ Args:
+ dict_obj (Dict[Any, Any]): Dict object to be checked.
+ expected_subset (Dict[Any, Any]): Subset expected to be contained in
+ dict_obj.
+
+ Returns:
+ bool: Whether the dict_obj contains the expected_subset.
+ """
+
+ for key, value in expected_subset.items():
+ if key not in dict_obj.keys() or _any(dict_obj[key] != value):
+ return False
+ return True
+
+
+def assert_attrs_equal(obj: Any, expected_attrs: Dict[str, Any]) -> bool:
+ """Check if attribute of class object is correct.
+
+ Args:
+ obj (object): Class object to be checked.
+ expected_attrs (Dict[str, Any]): Dict of the expected attrs.
+
+ Returns:
+ bool: Whether the attribute of class object is correct.
+ """
+ for attr, value in expected_attrs.items():
+ if not hasattr(obj, attr) or _any(getattr(obj, attr) != value):
+ return False
+ return True
+
+
+def assert_dict_has_keys(obj: Dict[str, Any],
+ expected_keys: List[str]) -> bool:
+ """Check if the obj has all the expected_keys.
+
+ Args:
+ obj (Dict[str, Any]): Object to be checked.
+ expected_keys (List[str]): Keys expected to contained in the keys of
+ the obj.
+
+ Returns:
+ bool: Whether the obj has the expected keys.
+ """
+ return set(expected_keys).issubset(set(obj.keys()))
+
+
+def assert_keys_equal(result_keys: List[str], target_keys: List[str]) -> bool:
+ """Check if target_keys is equal to result_keys.
+
+ Args:
+ result_keys (List[str]): Result keys to be checked.
+ target_keys (List[str]): Target keys to be checked.
+
+ Returns:
+ bool: Whether target_keys is equal to result_keys.
+ """
+ return set(result_keys) == set(target_keys)
+
+
+def assert_is_norm_layer(module) -> bool:
+ """Check if the module is a norm layer.
+
+ Args:
+ module (nn.Module): The module to be checked.
+
+ Returns:
+ bool: Whether the module is a norm layer.
+ """
+ from .parrots_wrapper import _BatchNorm, _InstanceNorm
+ from torch.nn import GroupNorm, LayerNorm
+ norm_layer_candidates = (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm)
+ return isinstance(module, norm_layer_candidates)
+
+
+def assert_params_all_zeros(module) -> bool:
+ """Check if the parameters of the module is all zeros.
+
+ Args:
+ module (nn.Module): The module to be checked.
+
+ Returns:
+        bool: Whether the parameters of the module are all zeros.
+ """
+ weight_data = module.weight.data
+ is_weight_zero = weight_data.allclose(
+ weight_data.new_zeros(weight_data.size()))
+
+ if hasattr(module, 'bias') and module.bias is not None:
+ bias_data = module.bias.data
+ is_bias_zero = bias_data.allclose(
+ bias_data.new_zeros(bias_data.size()))
+ else:
+ is_bias_zero = True
+
+ return is_weight_zero and is_bias_zero
diff --git a/annotator/uniformer/mmcv/utils/timer.py b/annotator/uniformer/mmcv/utils/timer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3db7d497d8b374e18b5297e0a1d6eb186fd8cba
--- /dev/null
+++ b/annotator/uniformer/mmcv/utils/timer.py
@@ -0,0 +1,118 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from time import time
+
+
+class TimerError(Exception):
+
+ def __init__(self, message):
+ self.message = message
+ super(TimerError, self).__init__(message)
+
+
+class Timer:
+ """A flexible Timer class.
+
+ :Example:
+
+ >>> import time
+ >>> import annotator.uniformer.mmcv as mmcv
+ >>> with mmcv.Timer():
+ >>> # simulate a code block that will run for 1s
+ >>> time.sleep(1)
+ 1.000
+ >>> with mmcv.Timer(print_tmpl='it takes {:.1f} seconds'):
+ >>> # simulate a code block that will run for 1s
+ >>> time.sleep(1)
+ it takes 1.0 seconds
+ >>> timer = mmcv.Timer()
+ >>> time.sleep(0.5)
+ >>> print(timer.since_start())
+ 0.500
+ >>> time.sleep(0.5)
+ >>> print(timer.since_last_check())
+ 0.500
+ >>> print(timer.since_start())
+ 1.000
+ """
+
+ def __init__(self, start=True, print_tmpl=None):
+ self._is_running = False
+ self.print_tmpl = print_tmpl if print_tmpl else '{:.3f}'
+ if start:
+ self.start()
+
+ @property
+ def is_running(self):
+ """bool: indicate whether the timer is running"""
+ return self._is_running
+
+ def __enter__(self):
+ self.start()
+ return self
+
+ def __exit__(self, type, value, traceback):
+ print(self.print_tmpl.format(self.since_last_check()))
+ self._is_running = False
+
+ def start(self):
+ """Start the timer."""
+ if not self._is_running:
+ self._t_start = time()
+ self._is_running = True
+ self._t_last = time()
+
+ def since_start(self):
+ """Total time since the timer is started.
+
+ Returns (float): Time in seconds.
+ """
+ if not self._is_running:
+ raise TimerError('timer is not running')
+ self._t_last = time()
+ return self._t_last - self._t_start
+
+ def since_last_check(self):
+ """Time since the last checking.
+
+ Either :func:`since_start` or :func:`since_last_check` is a checking
+ operation.
+
+ Returns (float): Time in seconds.
+ """
+ if not self._is_running:
+ raise TimerError('timer is not running')
+ dur = time() - self._t_last
+ self._t_last = time()
+ return dur
+
+
+_g_timers = {} # global timers
+
+
+def check_time(timer_id):
+ """Add check points in a single line.
+
+ This method is suitable for running a task on a list of items. A timer will
+ be registered when the method is called for the first time.
+
+ :Example:
+
+ >>> import time
+ >>> import annotator.uniformer.mmcv as mmcv
+ >>> for i in range(1, 6):
+ >>> # simulate a code block
+ >>> time.sleep(i)
+ >>> mmcv.check_time('task1')
+ 2.000
+ 3.000
+ 4.000
+ 5.000
+
+ Args:
+ timer_id (str): Timer identifier.
+ """
+ if timer_id not in _g_timers:
+ _g_timers[timer_id] = Timer()
+ return 0
+ else:
+ return _g_timers[timer_id].since_last_check()
diff --git a/annotator/uniformer/mmcv/utils/trace.py b/annotator/uniformer/mmcv/utils/trace.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ca99dc3eda05ef980d9a4249b50deca8273b6cc
--- /dev/null
+++ b/annotator/uniformer/mmcv/utils/trace.py
@@ -0,0 +1,23 @@
+import warnings
+
+import torch
+
+from annotator.uniformer.mmcv.utils import digit_version
+
+
+def is_jit_tracing() -> bool:
+ if (torch.__version__ != 'parrots'
+ and digit_version(torch.__version__) >= digit_version('1.6.0')):
+ on_trace = torch.jit.is_tracing()
+ # In PyTorch 1.6, torch.jit.is_tracing has a bug.
+ # Refers to https://github.com/pytorch/pytorch/issues/42448
+ if isinstance(on_trace, bool):
+ return on_trace
+ else:
+ return torch._C._is_tracing()
+ else:
+ warnings.warn(
+ 'torch.jit.is_tracing is only supported after v1.6.0. '
+ 'Therefore is_tracing returns False automatically. Please '
+ 'set on_trace manually if you are using trace.', UserWarning)
+ return False
diff --git a/annotator/uniformer/mmcv/utils/version_utils.py b/annotator/uniformer/mmcv/utils/version_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..963c45a2e8a86a88413ab6c18c22481fb9831985
--- /dev/null
+++ b/annotator/uniformer/mmcv/utils/version_utils.py
@@ -0,0 +1,90 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+import subprocess
+import warnings
+
+from packaging.version import parse
+
+
+def digit_version(version_str: str, length: int = 4):
+ """Convert a version string into a tuple of integers.
+
+ This method is usually used for comparing two versions. For pre-release
+ versions: alpha < beta < rc.
+
+ Args:
+ version_str (str): The version string.
+ length (int): The maximum number of version levels. Default: 4.
+
+ Returns:
+ tuple[int]: The version info in digits (integers).
+ """
+ assert 'parrots' not in version_str
+ version = parse(version_str)
+ assert version.release, f'failed to parse version {version_str}'
+ release = list(version.release)
+ release = release[:length]
+ if len(release) < length:
+ release = release + [0] * (length - len(release))
+ if version.is_prerelease:
+ mapping = {'a': -3, 'b': -2, 'rc': -1}
+ val = -4
+ # version.pre can be None
+ if version.pre:
+ if version.pre[0] not in mapping:
+ warnings.warn(f'unknown prerelease version {version.pre[0]}, '
+ 'version checking may go wrong')
+ else:
+ val = mapping[version.pre[0]]
+ release.extend([val, version.pre[-1]])
+ else:
+ release.extend([val, 0])
+
+ elif version.is_postrelease:
+ release.extend([1, version.post])
+ else:
+ release.extend([0, 0])
+ return tuple(release)
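+
+
+# Illustrative behaviour (comment added for readability): with the default
+# length of 4,
+#     digit_version('1.6.0')    -> (1, 6, 0, 0, 0, 0)
+#     digit_version('1.6.0rc1') -> (1, 6, 0, 0, -1, 1)
+# so release candidates compare as smaller than the final release.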
+
+
+def _minimal_ext_cmd(cmd):
+ # construct minimal environment
+ env = {}
+ for k in ['SYSTEMROOT', 'PATH', 'HOME']:
+ v = os.environ.get(k)
+ if v is not None:
+ env[k] = v
+ # LANGUAGE is used on win32
+ env['LANGUAGE'] = 'C'
+ env['LANG'] = 'C'
+ env['LC_ALL'] = 'C'
+ out = subprocess.Popen(
+ cmd, stdout=subprocess.PIPE, env=env).communicate()[0]
+ return out
+
+
+def get_git_hash(fallback='unknown', digits=None):
+ """Get the git hash of the current repo.
+
+ Args:
+ fallback (str, optional): The fallback string when git hash is
+ unavailable. Defaults to 'unknown'.
+ digits (int, optional): kept digits of the hash. Defaults to None,
+ meaning all digits are kept.
+
+ Returns:
+ str: Git commit hash.
+ """
+
+ if digits is not None and not isinstance(digits, int):
+ raise TypeError('digits must be None or an integer')
+
+ try:
+ out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
+ sha = out.strip().decode('ascii')
+ if digits is not None:
+ sha = sha[:digits]
+ except OSError:
+ sha = fallback
+
+ return sha
diff --git a/annotator/uniformer/mmcv/version.py b/annotator/uniformer/mmcv/version.py
new file mode 100644
index 0000000000000000000000000000000000000000..1cce4e50bd692d4002e3cac3c545a3fb2efe95d0
--- /dev/null
+++ b/annotator/uniformer/mmcv/version.py
@@ -0,0 +1,35 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+__version__ = '1.3.17'
+
+
+def parse_version_info(version_str: str, length: int = 4) -> tuple:
+ """Parse a version string into a tuple.
+
+ Args:
+ version_str (str): The version string.
+ length (int): The maximum number of version levels. Default: 4.
+
+ Returns:
+ tuple[int | str]: The version info, e.g., "1.3.0" is parsed into
+ (1, 3, 0, 0, 0, 0), and "2.0.0rc1" is parsed into
+ (2, 0, 0, 0, 'rc', 1) (when length is set to 4).
+ """
+ from packaging.version import parse
+ version = parse(version_str)
+ assert version.release, f'failed to parse version {version_str}'
+ release = list(version.release)
+ release = release[:length]
+ if len(release) < length:
+ release = release + [0] * (length - len(release))
+ if version.is_prerelease:
+ release.extend(list(version.pre))
+ elif version.is_postrelease:
+ release.extend(list(version.post))
+ else:
+ release.extend([0, 0])
+ return tuple(release)
+
+
+version_info = tuple(int(x) for x in __version__.split('.')[:3])
+
+__all__ = ['__version__', 'version_info', 'parse_version_info']
diff --git a/annotator/uniformer/mmcv/video/__init__.py b/annotator/uniformer/mmcv/video/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..73199b01dec52820dc6ca0139903536344d5a1eb
--- /dev/null
+++ b/annotator/uniformer/mmcv/video/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .io import Cache, VideoReader, frames2video
+from .optflow import (dequantize_flow, flow_from_bytes, flow_warp, flowread,
+ flowwrite, quantize_flow, sparse_flow_from_bytes)
+from .processing import concat_video, convert_video, cut_video, resize_video
+
+__all__ = [
+ 'Cache', 'VideoReader', 'frames2video', 'convert_video', 'resize_video',
+ 'cut_video', 'concat_video', 'flowread', 'flowwrite', 'quantize_flow',
+ 'dequantize_flow', 'flow_warp', 'flow_from_bytes', 'sparse_flow_from_bytes'
+]
diff --git a/annotator/uniformer/mmcv/video/io.py b/annotator/uniformer/mmcv/video/io.py
new file mode 100644
index 0000000000000000000000000000000000000000..9879154227f640c262853b92c219461c6f67ee8e
--- /dev/null
+++ b/annotator/uniformer/mmcv/video/io.py
@@ -0,0 +1,318 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+from collections import OrderedDict
+
+import cv2
+from cv2 import (CAP_PROP_FOURCC, CAP_PROP_FPS, CAP_PROP_FRAME_COUNT,
+ CAP_PROP_FRAME_HEIGHT, CAP_PROP_FRAME_WIDTH,
+ CAP_PROP_POS_FRAMES, VideoWriter_fourcc)
+
+from annotator.uniformer.mmcv.utils import (check_file_exist, mkdir_or_exist, scandir,
+ track_progress)
+
+
+class Cache:
+
+ def __init__(self, capacity):
+        if capacity <= 0:
+            raise ValueError('capacity must be a positive integer')
+        self._cache = OrderedDict()
+        self._capacity = int(capacity)
+
+ @property
+ def capacity(self):
+ return self._capacity
+
+ @property
+ def size(self):
+ return len(self._cache)
+
+ def put(self, key, val):
+ if key in self._cache:
+ return
+ if len(self._cache) >= self.capacity:
+ self._cache.popitem(last=False)
+ self._cache[key] = val
+
+ def get(self, key, default=None):
+ val = self._cache[key] if key in self._cache else default
+ return val
+
+
+class VideoReader:
+ """Video class with similar usage to a list object.
+
+    This video wrapper class provides convenient APIs to access frames.
+    OpenCV's VideoCapture can be inaccurate when jumping to a certain frame;
+    this class works around the issue by checking the actual position after
+    each jump.
+    A cache is used when decoding frames, so a frame that has been visited
+    before does not need to be decoded again as long as it is still cached.
+
+ :Example:
+
+ >>> import annotator.uniformer.mmcv as mmcv
+ >>> v = mmcv.VideoReader('sample.mp4')
+ >>> len(v) # get the total frame number with `len()`
+ 120
+ >>> for img in v: # v is iterable
+ >>> mmcv.imshow(img)
+ >>> v[5] # get the 6th frame
+ """
+
+ def __init__(self, filename, cache_capacity=10):
+ # Check whether the video path is a url
+ if not filename.startswith(('https://', 'http://')):
+ check_file_exist(filename, 'Video file not found: ' + filename)
+ self._vcap = cv2.VideoCapture(filename)
+ assert cache_capacity > 0
+ self._cache = Cache(cache_capacity)
+ self._position = 0
+ # get basic info
+ self._width = int(self._vcap.get(CAP_PROP_FRAME_WIDTH))
+ self._height = int(self._vcap.get(CAP_PROP_FRAME_HEIGHT))
+ self._fps = self._vcap.get(CAP_PROP_FPS)
+ self._frame_cnt = int(self._vcap.get(CAP_PROP_FRAME_COUNT))
+ self._fourcc = self._vcap.get(CAP_PROP_FOURCC)
+
+ @property
+ def vcap(self):
+ """:obj:`cv2.VideoCapture`: The raw VideoCapture object."""
+ return self._vcap
+
+ @property
+ def opened(self):
+ """bool: Indicate whether the video is opened."""
+ return self._vcap.isOpened()
+
+ @property
+ def width(self):
+ """int: Width of video frames."""
+ return self._width
+
+ @property
+ def height(self):
+ """int: Height of video frames."""
+ return self._height
+
+ @property
+ def resolution(self):
+ """tuple: Video resolution (width, height)."""
+ return (self._width, self._height)
+
+ @property
+ def fps(self):
+ """float: FPS of the video."""
+ return self._fps
+
+ @property
+ def frame_cnt(self):
+ """int: Total frames of the video."""
+ return self._frame_cnt
+
+ @property
+ def fourcc(self):
+ """str: "Four character code" of the video."""
+ return self._fourcc
+
+ @property
+ def position(self):
+ """int: Current cursor position, indicating frame decoded."""
+ return self._position
+
+ def _get_real_position(self):
+ return int(round(self._vcap.get(CAP_PROP_POS_FRAMES)))
+
+ def _set_real_position(self, frame_id):
+ self._vcap.set(CAP_PROP_POS_FRAMES, frame_id)
+ pos = self._get_real_position()
+ for _ in range(frame_id - pos):
+ self._vcap.read()
+ self._position = frame_id
+
+ def read(self):
+ """Read the next frame.
+
+        If the next frame has been decoded before and is still in the cache,
+        it is returned directly; otherwise it is decoded, cached and returned.
+
+ Returns:
+ ndarray or None: Return the frame if successful, otherwise None.
+ """
+ # pos = self._position
+ if self._cache:
+ img = self._cache.get(self._position)
+ if img is not None:
+ ret = True
+ else:
+ if self._position != self._get_real_position():
+ self._set_real_position(self._position)
+ ret, img = self._vcap.read()
+ if ret:
+ self._cache.put(self._position, img)
+ else:
+ ret, img = self._vcap.read()
+ if ret:
+ self._position += 1
+ return img
+
+ def get_frame(self, frame_id):
+ """Get frame by index.
+
+ Args:
+ frame_id (int): Index of the expected frame, 0-based.
+
+ Returns:
+ ndarray or None: Return the frame if successful, otherwise None.
+ """
+ if frame_id < 0 or frame_id >= self._frame_cnt:
+ raise IndexError(
+ f'"frame_id" must be between 0 and {self._frame_cnt - 1}')
+ if frame_id == self._position:
+ return self.read()
+ if self._cache:
+ img = self._cache.get(frame_id)
+ if img is not None:
+ self._position = frame_id + 1
+ return img
+ self._set_real_position(frame_id)
+ ret, img = self._vcap.read()
+ if ret:
+ if self._cache:
+ self._cache.put(self._position, img)
+ self._position += 1
+ return img
+
+ def current_frame(self):
+ """Get the current frame (frame that is just visited).
+
+ Returns:
+            ndarray or None: None if no frame has been read yet, otherwise
+                the most recently read frame.
+ """
+ if self._position == 0:
+ return None
+ return self._cache.get(self._position - 1)
+
+ def cvt2frames(self,
+ frame_dir,
+ file_start=0,
+ filename_tmpl='{:06d}.jpg',
+ start=0,
+ max_num=0,
+ show_progress=True):
+ """Convert a video to frame images.
+
+ Args:
+ frame_dir (str): Output directory to store all the frame images.
+ file_start (int): Filenames will start from the specified number.
+ filename_tmpl (str): Filename template with the index as the
+ placeholder.
+ start (int): The starting frame index.
+ max_num (int): Maximum number of frames to be written.
+ show_progress (bool): Whether to show a progress bar.
+ """
+ mkdir_or_exist(frame_dir)
+ if max_num == 0:
+ task_num = self.frame_cnt - start
+ else:
+ task_num = min(self.frame_cnt - start, max_num)
+ if task_num <= 0:
+ raise ValueError('start must be less than total frame number')
+ if start > 0:
+ self._set_real_position(start)
+
+ def write_frame(file_idx):
+ img = self.read()
+ if img is None:
+ return
+ filename = osp.join(frame_dir, filename_tmpl.format(file_idx))
+ cv2.imwrite(filename, img)
+
+ if show_progress:
+ track_progress(write_frame, range(file_start,
+ file_start + task_num))
+ else:
+ for i in range(task_num):
+ write_frame(file_start + i)
+
+ def __len__(self):
+ return self.frame_cnt
+
+ def __getitem__(self, index):
+ if isinstance(index, slice):
+ return [
+ self.get_frame(i)
+ for i in range(*index.indices(self.frame_cnt))
+ ]
+ # support negative indexing
+ if index < 0:
+ index += self.frame_cnt
+ if index < 0:
+ raise IndexError('index out of range')
+ return self.get_frame(index)
+
+ def __iter__(self):
+ self._set_real_position(0)
+ return self
+
+ def __next__(self):
+ img = self.read()
+ if img is not None:
+ return img
+ else:
+ raise StopIteration
+
+ next = __next__
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ self._vcap.release()
+
+
+def frames2video(frame_dir,
+ video_file,
+ fps=30,
+ fourcc='XVID',
+ filename_tmpl='{:06d}.jpg',
+ start=0,
+ end=0,
+ show_progress=True):
+ """Read the frame images from a directory and join them as a video.
+
+ Args:
+ frame_dir (str): The directory containing video frames.
+ video_file (str): Output filename.
+ fps (float): FPS of the output video.
+ fourcc (str): Fourcc of the output video, this should be compatible
+ with the output file type.
+ filename_tmpl (str): Filename template with the index as the variable.
+ start (int): Starting frame index.
+ end (int): Ending frame index.
+ show_progress (bool): Whether to show a progress bar.
+ """
+ if end == 0:
+ ext = filename_tmpl.split('.')[-1]
+ end = len([name for name in scandir(frame_dir, ext)])
+ first_file = osp.join(frame_dir, filename_tmpl.format(start))
+ check_file_exist(first_file, 'The start frame not found: ' + first_file)
+ img = cv2.imread(first_file)
+ height, width = img.shape[:2]
+ resolution = (width, height)
+ vwriter = cv2.VideoWriter(video_file, VideoWriter_fourcc(*fourcc), fps,
+ resolution)
+
+ def write_frame(file_idx):
+ filename = osp.join(frame_dir, filename_tmpl.format(file_idx))
+ img = cv2.imread(filename)
+ vwriter.write(img)
+
+ if show_progress:
+ track_progress(write_frame, range(start, end))
+ else:
+ for i in range(start, end):
+ write_frame(i)
+ vwriter.release()
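+
+
+# Illustrative usage sketch (comment only, so nothing runs at import time);
+# 'sample.mp4' and 'out_dir' are placeholder names:
+#
+#     video = VideoReader('sample.mp4')
+#     print(len(video), video.fps, video.resolution)
+#     frame_0 = video[0]                  # random access with caching
+#     video.cvt2frames('out_dir')         # dumps out_dir/000000.jpg, ...
+#     frames2video('out_dir', 'rebuilt.avi', fps=video.fps)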
diff --git a/annotator/uniformer/mmcv/video/optflow.py b/annotator/uniformer/mmcv/video/optflow.py
new file mode 100644
index 0000000000000000000000000000000000000000..84160f8d6ef9fceb5a2f89e7481593109fc1905d
--- /dev/null
+++ b/annotator/uniformer/mmcv/video/optflow.py
@@ -0,0 +1,254 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import cv2
+import numpy as np
+
+from annotator.uniformer.mmcv.arraymisc import dequantize, quantize
+from annotator.uniformer.mmcv.image import imread, imwrite
+from annotator.uniformer.mmcv.utils import is_str
+
+
+def flowread(flow_or_path, quantize=False, concat_axis=0, *args, **kwargs):
+ """Read an optical flow map.
+
+ Args:
+ flow_or_path (ndarray or str): A flow map or filepath.
+        quantize (bool): Whether to read a quantized pair. If set to True,
+            remaining args will be passed to :func:`dequantize_flow`.
+ concat_axis (int): The axis that dx and dy are concatenated,
+ can be either 0 or 1. Ignored if quantize is False.
+
+ Returns:
+ ndarray: Optical flow represented as a (h, w, 2) numpy array
+ """
+ if isinstance(flow_or_path, np.ndarray):
+ if (flow_or_path.ndim != 3) or (flow_or_path.shape[-1] != 2):
+ raise ValueError(f'Invalid flow with shape {flow_or_path.shape}')
+ return flow_or_path
+ elif not is_str(flow_or_path):
+ raise TypeError(f'"flow_or_path" must be a filename or numpy array, '
+ f'not {type(flow_or_path)}')
+
+ if not quantize:
+ with open(flow_or_path, 'rb') as f:
+ try:
+ header = f.read(4).decode('utf-8')
+ except Exception:
+ raise IOError(f'Invalid flow file: {flow_or_path}')
+ else:
+ if header != 'PIEH':
+ raise IOError(f'Invalid flow file: {flow_or_path}, '
+ 'header does not contain PIEH')
+
+ w = np.fromfile(f, np.int32, 1).squeeze()
+ h = np.fromfile(f, np.int32, 1).squeeze()
+ flow = np.fromfile(f, np.float32, w * h * 2).reshape((h, w, 2))
+ else:
+ assert concat_axis in [0, 1]
+ cat_flow = imread(flow_or_path, flag='unchanged')
+ if cat_flow.ndim != 2:
+ raise IOError(
+ f'{flow_or_path} is not a valid quantized flow file, '
+ f'its dimension is {cat_flow.ndim}.')
+ assert cat_flow.shape[concat_axis] % 2 == 0
+ dx, dy = np.split(cat_flow, 2, axis=concat_axis)
+ flow = dequantize_flow(dx, dy, *args, **kwargs)
+
+ return flow.astype(np.float32)
+
+
+def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs):
+ """Write optical flow to file.
+
+    If the flow is not quantized, it will be saved as a .flo file losslessly;
+    otherwise it is saved as a jpeg image, which is lossy but much smaller.
+    (dx and dy will be concatenated along ``concat_axis`` into a single image
+    if quantize is True.)
+
+ Args:
+ flow (ndarray): (h, w, 2) array of optical flow.
+ filename (str): Output filepath.
+        quantize (bool): Whether to quantize the flow and save it as a jpeg
+            image. If set to True, remaining args will be passed to
+            :func:`quantize_flow`.
+ concat_axis (int): The axis that dx and dy are concatenated,
+ can be either 0 or 1. Ignored if quantize is False.
+ """
+ if not quantize:
+ with open(filename, 'wb') as f:
+ f.write('PIEH'.encode('utf-8'))
+ np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)
+ flow = flow.astype(np.float32)
+ flow.tofile(f)
+ f.flush()
+ else:
+ assert concat_axis in [0, 1]
+ dx, dy = quantize_flow(flow, *args, **kwargs)
+ dxdy = np.concatenate((dx, dy), axis=concat_axis)
+ imwrite(dxdy, filename)
+
+
+def quantize_flow(flow, max_val=0.02, norm=True):
+ """Quantize flow to [0, 255].
+
+ After this step, the size of flow will be much smaller, and can be
+ dumped as jpeg images.
+
+ Args:
+ flow (ndarray): (h, w, 2) array of optical flow.
+ max_val (float): Maximum value of flow, values beyond
+ [-max_val, max_val] will be truncated.
+ norm (bool): Whether to divide flow values by image width/height.
+
+ Returns:
+ tuple[ndarray]: Quantized dx and dy.
+ """
+ h, w, _ = flow.shape
+ dx = flow[..., 0]
+ dy = flow[..., 1]
+ if norm:
+ dx = dx / w # avoid inplace operations
+ dy = dy / h
+ # use 255 levels instead of 256 to make sure 0 is 0 after dequantization.
+ flow_comps = [
+ quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy]
+ ]
+ return tuple(flow_comps)
+
+
+def dequantize_flow(dx, dy, max_val=0.02, denorm=True):
+ """Recover from quantized flow.
+
+ Args:
+ dx (ndarray): Quantized dx.
+ dy (ndarray): Quantized dy.
+ max_val (float): Maximum value used when quantizing.
+ denorm (bool): Whether to multiply flow values with width/height.
+
+ Returns:
+ ndarray: Dequantized flow.
+ """
+ assert dx.shape == dy.shape
+ assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1)
+
+ dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]]
+
+ if denorm:
+ dx *= dx.shape[1]
+        dy *= dy.shape[0]
+ flow = np.dstack((dx, dy))
+ return flow
+
+
+def flow_warp(img, flow, filling_value=0, interpolate_mode='nearest'):
+ """Use flow to warp img.
+
+ Args:
+ img (ndarray, float or uint8): Image to be warped.
+ flow (ndarray, float): Optical Flow.
+ filling_value (int): The missing pixels will be set with filling_value.
+ interpolate_mode (str): bilinear -> Bilinear Interpolation;
+ nearest -> Nearest Neighbor.
+
+ Returns:
+ ndarray: Warped image with the same shape of img
+ """
+ warnings.warn('This function is just for prototyping and cannot '
+ 'guarantee the computational efficiency.')
+ assert flow.ndim == 3, 'Flow must be in 3D arrays.'
+ height = flow.shape[0]
+ width = flow.shape[1]
+ channels = img.shape[2]
+
+ output = np.ones(
+ (height, width, channels), dtype=img.dtype) * filling_value
+
+ grid = np.indices((height, width)).swapaxes(0, 1).swapaxes(1, 2)
+ dx = grid[:, :, 0] + flow[:, :, 1]
+ dy = grid[:, :, 1] + flow[:, :, 0]
+ sx = np.floor(dx).astype(int)
+ sy = np.floor(dy).astype(int)
+ valid = (sx >= 0) & (sx < height - 1) & (sy >= 0) & (sy < width - 1)
+
+ if interpolate_mode == 'nearest':
+ output[valid, :] = img[dx[valid].round().astype(int),
+ dy[valid].round().astype(int), :]
+ elif interpolate_mode == 'bilinear':
+        # dirty workaround for integer positions
+ eps_ = 1e-6
+ dx, dy = dx + eps_, dy + eps_
+ left_top_ = img[np.floor(dx[valid]).astype(int),
+ np.floor(dy[valid]).astype(int), :] * (
+ np.ceil(dx[valid]) - dx[valid])[:, None] * (
+ np.ceil(dy[valid]) - dy[valid])[:, None]
+ left_down_ = img[np.ceil(dx[valid]).astype(int),
+ np.floor(dy[valid]).astype(int), :] * (
+ dx[valid] - np.floor(dx[valid]))[:, None] * (
+ np.ceil(dy[valid]) - dy[valid])[:, None]
+ right_top_ = img[np.floor(dx[valid]).astype(int),
+ np.ceil(dy[valid]).astype(int), :] * (
+ np.ceil(dx[valid]) - dx[valid])[:, None] * (
+ dy[valid] - np.floor(dy[valid]))[:, None]
+ right_down_ = img[np.ceil(dx[valid]).astype(int),
+ np.ceil(dy[valid]).astype(int), :] * (
+ dx[valid] - np.floor(dx[valid]))[:, None] * (
+ dy[valid] - np.floor(dy[valid]))[:, None]
+ output[valid, :] = left_top_ + left_down_ + right_top_ + right_down_
+ else:
+ raise NotImplementedError(
+ 'We only support interpolation modes of nearest and bilinear, '
+ f'but got {interpolate_mode}.')
+ return output.astype(img.dtype)
+
+
+def flow_from_bytes(content):
+ """Read dense optical flow from bytes.
+
+ .. note::
+        This function works for the FlyingChairs, FlyingThings3D, Sintel and
+        FlyingChairsOcc datasets, but cannot load data from ChairsSDHom.
+
+ Args:
+ content (bytes): Optical flow bytes got from files or other streams.
+
+ Returns:
+ ndarray: Loaded optical flow with the shape (H, W, 2).
+ """
+
+ # header in first 4 bytes
+ header = content[:4]
+ if header.decode('utf-8') != 'PIEH':
+ raise Exception('Flow file header does not contain PIEH')
+ # width in second 4 bytes
+ width = np.frombuffer(content[4:], np.int32, 1).squeeze()
+ # height in third 4 bytes
+ height = np.frombuffer(content[8:], np.int32, 1).squeeze()
+ # after first 12 bytes, all bytes are flow
+ flow = np.frombuffer(content[12:], np.float32, width * height * 2).reshape(
+ (height, width, 2))
+
+ return flow
+
+
+def sparse_flow_from_bytes(content):
+ """Read the optical flow in KITTI datasets from bytes.
+
+    This function is modified from the KITTI data-loading code in RAFT.
+
+ Args:
+ content (bytes): Optical flow bytes got from files or other streams.
+
+ Returns:
+ Tuple(ndarray, ndarray): Loaded optical flow with the shape (H, W, 2)
+ and flow valid mask with the shape (H, W).
+    """  # noqa
+
+ content = np.frombuffer(content, np.uint8)
+ flow = cv2.imdecode(content, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR)
+ flow = flow[:, :, ::-1].astype(np.float32)
+ # flow shape (H, W, 2) valid shape (H, W)
+ flow, valid = flow[:, :, :2], flow[:, :, 2]
+ flow = (flow - 2**15) / 64.0
+ return flow, valid
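+
+
+# Illustrative round-trip sketch (comment only); `flow` is a placeholder array:
+#
+#     flow = np.random.randn(240, 320, 2).astype(np.float32) * 3
+#     dx, dy = quantize_flow(flow, max_val=0.02, norm=True)   # two uint8 maps
+#     restored = dequantize_flow(dx, dy, max_val=0.02, denorm=True)
+#     # `restored` approximates `flow`; values outside [-max_val, max_val]
+#     # (after normalization) are clipped by the quantization step.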
diff --git a/annotator/uniformer/mmcv/video/processing.py b/annotator/uniformer/mmcv/video/processing.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d90b96e0823d5f116755e7f498d25d17017224a
--- /dev/null
+++ b/annotator/uniformer/mmcv/video/processing.py
@@ -0,0 +1,160 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+import os.path as osp
+import subprocess
+import tempfile
+
+from annotator.uniformer.mmcv.utils import requires_executable
+
+
+@requires_executable('ffmpeg')
+def convert_video(in_file,
+ out_file,
+ print_cmd=False,
+ pre_options='',
+ **kwargs):
+ """Convert a video with ffmpeg.
+
+    This provides a general API to ffmpeg; the executed command is::
+
+        ffmpeg -y <pre_options> -i <in_file> <options> <out_file>
+
+ Options(kwargs) are mapped to ffmpeg commands with the following rules:
+
+ - key=val: "-key val"
+ - key=True: "-key"
+ - key=False: ""
+
+ Args:
+ in_file (str): Input video filename.
+ out_file (str): Output video filename.
+        pre_options (str): Options that appear before "-i <in_file>".
+ print_cmd (bool): Whether to print the final ffmpeg command.
+ """
+ options = []
+ for k, v in kwargs.items():
+ if isinstance(v, bool):
+ if v:
+ options.append(f'-{k}')
+ elif k == 'log_level':
+ assert v in [
+ 'quiet', 'panic', 'fatal', 'error', 'warning', 'info',
+ 'verbose', 'debug', 'trace'
+ ]
+ options.append(f'-loglevel {v}')
+ else:
+ options.append(f'-{k} {v}')
+ cmd = f'ffmpeg -y {pre_options} -i {in_file} {" ".join(options)} ' \
+ f'{out_file}'
+ if print_cmd:
+ print(cmd)
+ subprocess.call(cmd, shell=True)
+
+
+@requires_executable('ffmpeg')
+def resize_video(in_file,
+ out_file,
+ size=None,
+ ratio=None,
+ keep_ar=False,
+ log_level='info',
+ print_cmd=False):
+ """Resize a video.
+
+ Args:
+ in_file (str): Input video filename.
+ out_file (str): Output video filename.
+        size (tuple): Expected size (w, h), e.g., (320, 240) or (320, -1).
+ ratio (tuple or float): Expected resize ratio, (2, 0.5) means
+ (w*2, h*0.5).
+ keep_ar (bool): Whether to keep original aspect ratio.
+ log_level (str): Logging level of ffmpeg.
+ print_cmd (bool): Whether to print the final ffmpeg command.
+ """
+ if size is None and ratio is None:
+ raise ValueError('expected size or ratio must be specified')
+ if size is not None and ratio is not None:
+ raise ValueError('size and ratio cannot be specified at the same time')
+ options = {'log_level': log_level}
+ if size:
+ if not keep_ar:
+ options['vf'] = f'scale={size[0]}:{size[1]}'
+ else:
+ options['vf'] = f'scale=w={size[0]}:h={size[1]}:' \
+ 'force_original_aspect_ratio=decrease'
+ else:
+ if not isinstance(ratio, tuple):
+ ratio = (ratio, ratio)
+ options['vf'] = f'scale="trunc(iw*{ratio[0]}):trunc(ih*{ratio[1]})"'
+ convert_video(in_file, out_file, print_cmd, **options)
+
+
+@requires_executable('ffmpeg')
+def cut_video(in_file,
+ out_file,
+ start=None,
+ end=None,
+ vcodec=None,
+ acodec=None,
+ log_level='info',
+ print_cmd=False):
+ """Cut a clip from a video.
+
+ Args:
+ in_file (str): Input video filename.
+ out_file (str): Output video filename.
+ start (None or float): Start time (in seconds).
+ end (None or float): End time (in seconds).
+ vcodec (None or str): Output video codec, None for unchanged.
+ acodec (None or str): Output audio codec, None for unchanged.
+ log_level (str): Logging level of ffmpeg.
+ print_cmd (bool): Whether to print the final ffmpeg command.
+ """
+ options = {'log_level': log_level}
+ if vcodec is None:
+ options['vcodec'] = 'copy'
+ if acodec is None:
+ options['acodec'] = 'copy'
+ if start:
+ options['ss'] = start
+ else:
+ start = 0
+ if end:
+ options['t'] = end - start
+ convert_video(in_file, out_file, print_cmd, **options)
+
+
+@requires_executable('ffmpeg')
+def concat_video(video_list,
+ out_file,
+ vcodec=None,
+ acodec=None,
+ log_level='info',
+ print_cmd=False):
+ """Concatenate multiple videos into a single one.
+
+ Args:
+        video_list (list): A list of video filenames.
+        out_file (str): Output video filename.
+        vcodec (None or str): Output video codec, None for unchanged.
+        acodec (None or str): Output audio codec, None for unchanged.
+ log_level (str): Logging level of ffmpeg.
+ print_cmd (bool): Whether to print the final ffmpeg command.
+ """
+ tmp_filehandler, tmp_filename = tempfile.mkstemp(suffix='.txt', text=True)
+ with open(tmp_filename, 'w') as f:
+ for filename in video_list:
+ f.write(f'file {osp.abspath(filename)}\n')
+ options = {'log_level': log_level}
+ if vcodec is None:
+ options['vcodec'] = 'copy'
+ if acodec is None:
+ options['acodec'] = 'copy'
+ convert_video(
+ tmp_filename,
+ out_file,
+ print_cmd,
+ pre_options='-f concat -safe 0',
+ **options)
+ os.close(tmp_filehandler)
+ os.remove(tmp_filename)
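+
+
+# Illustrative usage sketch (comment only); file names are placeholders and a
+# working `ffmpeg` binary is assumed to be on PATH:
+#
+#     resize_video('in.mp4', 'small.mp4', size=(320, -1), log_level='error')
+#     cut_video('in.mp4', 'clip.mp4', start=3.0, end=10.0)
+#     concat_video(['clip.mp4', 'clip.mp4'], 'joined.mp4')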
diff --git a/annotator/uniformer/mmcv/visualization/__init__.py b/annotator/uniformer/mmcv/visualization/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..835df136bdcf69348281d22914d41aa84cdf92b1
--- /dev/null
+++ b/annotator/uniformer/mmcv/visualization/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .color import Color, color_val
+from .image import imshow, imshow_bboxes, imshow_det_bboxes
+from .optflow import flow2rgb, flowshow, make_color_wheel
+
+__all__ = [
+ 'Color', 'color_val', 'imshow', 'imshow_bboxes', 'imshow_det_bboxes',
+ 'flowshow', 'flow2rgb', 'make_color_wheel'
+]
diff --git a/annotator/uniformer/mmcv/visualization/color.py b/annotator/uniformer/mmcv/visualization/color.py
new file mode 100644
index 0000000000000000000000000000000000000000..9041e0e6b7581c3356795d6a3c5e84667c88f025
--- /dev/null
+++ b/annotator/uniformer/mmcv/visualization/color.py
@@ -0,0 +1,51 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from enum import Enum
+
+import numpy as np
+
+from annotator.uniformer.mmcv.utils import is_str
+
+
+class Color(Enum):
+ """An enum that defines common colors.
+
+ Contains red, green, blue, cyan, yellow, magenta, white and black.
+ """
+ red = (0, 0, 255)
+ green = (0, 255, 0)
+ blue = (255, 0, 0)
+ cyan = (255, 255, 0)
+ yellow = (0, 255, 255)
+ magenta = (255, 0, 255)
+ white = (255, 255, 255)
+ black = (0, 0, 0)
+
+
+def color_val(color):
+    """Convert various inputs to a BGR color tuple.
+
+ Args:
+ color (:obj:`Color`/str/tuple/int/ndarray): Color inputs
+
+ Returns:
+ tuple[int]: A tuple of 3 integers indicating BGR channels.
+ """
+ if is_str(color):
+ return Color[color].value
+ elif isinstance(color, Color):
+ return color.value
+ elif isinstance(color, tuple):
+ assert len(color) == 3
+ for channel in color:
+ assert 0 <= channel <= 255
+ return color
+ elif isinstance(color, int):
+ assert 0 <= color <= 255
+ return color, color, color
+ elif isinstance(color, np.ndarray):
+ assert color.ndim == 1 and color.size == 3
+ assert np.all((color >= 0) & (color <= 255))
+ color = color.astype(np.uint8)
+ return tuple(color)
+ else:
+ raise TypeError(f'Invalid type for color: {type(color)}')
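+
+
+# Illustrative behaviour sketch (comment only):
+#
+#     color_val('green')                -> (0, 255, 0)
+#     color_val(Color.red)              -> (0, 0, 255)        # BGR order
+#     color_val(128)                    -> (128, 128, 128)
+#     color_val(np.array([0, 0, 255]))  -> (0, 0, 255)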
diff --git a/annotator/uniformer/mmcv/visualization/image.py b/annotator/uniformer/mmcv/visualization/image.py
new file mode 100644
index 0000000000000000000000000000000000000000..61a56c75b67f593c298408462c63c0468be8e276
--- /dev/null
+++ b/annotator/uniformer/mmcv/visualization/image.py
@@ -0,0 +1,152 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import cv2
+import numpy as np
+
+from annotator.uniformer.mmcv.image import imread, imwrite
+from .color import color_val
+
+
+def imshow(img, win_name='', wait_time=0):
+ """Show an image.
+
+ Args:
+ img (str or ndarray): The image to be displayed.
+ win_name (str): The window name.
+ wait_time (int): Value of waitKey param.
+ """
+ cv2.imshow(win_name, imread(img))
+    if wait_time == 0:  # prevent hanging if the window was closed
+ while True:
+ ret = cv2.waitKey(1)
+
+ closed = cv2.getWindowProperty(win_name, cv2.WND_PROP_VISIBLE) < 1
+ # if user closed window or if some key pressed
+ if closed or ret != -1:
+ break
+ else:
+ ret = cv2.waitKey(wait_time)
+
+
+def imshow_bboxes(img,
+ bboxes,
+ colors='green',
+ top_k=-1,
+ thickness=1,
+ show=True,
+ win_name='',
+ wait_time=0,
+ out_file=None):
+ """Draw bboxes on an image.
+
+ Args:
+ img (str or ndarray): The image to be displayed.
+ bboxes (list or ndarray): A list of ndarray of shape (k, 4).
+        colors (str or tuple or Color or list): A single color applied to all
+            groups, or a list with one color per element of ``bboxes``.
+ top_k (int): Plot the first k bboxes only if set positive.
+ thickness (int): Thickness of lines.
+ show (bool): Whether to show the image.
+ win_name (str): The window name.
+ wait_time (int): Value of waitKey param.
+ out_file (str, optional): The filename to write the image.
+
+ Returns:
+ ndarray: The image with bboxes drawn on it.
+ """
+ img = imread(img)
+ img = np.ascontiguousarray(img)
+
+ if isinstance(bboxes, np.ndarray):
+ bboxes = [bboxes]
+ if not isinstance(colors, list):
+ colors = [colors for _ in range(len(bboxes))]
+ colors = [color_val(c) for c in colors]
+ assert len(bboxes) == len(colors)
+
+ for i, _bboxes in enumerate(bboxes):
+ _bboxes = _bboxes.astype(np.int32)
+ if top_k <= 0:
+ _top_k = _bboxes.shape[0]
+ else:
+ _top_k = min(top_k, _bboxes.shape[0])
+ for j in range(_top_k):
+ left_top = (_bboxes[j, 0], _bboxes[j, 1])
+ right_bottom = (_bboxes[j, 2], _bboxes[j, 3])
+ cv2.rectangle(
+ img, left_top, right_bottom, colors[i], thickness=thickness)
+
+ if show:
+ imshow(img, win_name, wait_time)
+ if out_file is not None:
+ imwrite(img, out_file)
+ return img
+
+
+def imshow_det_bboxes(img,
+ bboxes,
+ labels,
+ class_names=None,
+ score_thr=0,
+ bbox_color='green',
+ text_color='green',
+ thickness=1,
+ font_scale=0.5,
+ show=True,
+ win_name='',
+ wait_time=0,
+ out_file=None):
+ """Draw bboxes and class labels (with scores) on an image.
+
+ Args:
+ img (str or ndarray): The image to be displayed.
+ bboxes (ndarray): Bounding boxes (with scores), shaped (n, 4) or
+ (n, 5).
+ labels (ndarray): Labels of bboxes.
+        class_names (list[str]): Names of each class.
+ score_thr (float): Minimum score of bboxes to be shown.
+ bbox_color (str or tuple or :obj:`Color`): Color of bbox lines.
+ text_color (str or tuple or :obj:`Color`): Color of texts.
+ thickness (int): Thickness of lines.
+ font_scale (float): Font scales of texts.
+ show (bool): Whether to show the image.
+ win_name (str): The window name.
+ wait_time (int): Value of waitKey param.
+ out_file (str or None): The filename to write the image.
+
+ Returns:
+ ndarray: The image with bboxes drawn on it.
+ """
+ assert bboxes.ndim == 2
+ assert labels.ndim == 1
+ assert bboxes.shape[0] == labels.shape[0]
+ assert bboxes.shape[1] == 4 or bboxes.shape[1] == 5
+ img = imread(img)
+ img = np.ascontiguousarray(img)
+
+ if score_thr > 0:
+ assert bboxes.shape[1] == 5
+ scores = bboxes[:, -1]
+ inds = scores > score_thr
+ bboxes = bboxes[inds, :]
+ labels = labels[inds]
+
+ bbox_color = color_val(bbox_color)
+ text_color = color_val(text_color)
+
+ for bbox, label in zip(bboxes, labels):
+ bbox_int = bbox.astype(np.int32)
+ left_top = (bbox_int[0], bbox_int[1])
+ right_bottom = (bbox_int[2], bbox_int[3])
+ cv2.rectangle(
+ img, left_top, right_bottom, bbox_color, thickness=thickness)
+ label_text = class_names[
+ label] if class_names is not None else f'cls {label}'
+ if len(bbox) > 4:
+ label_text += f'|{bbox[-1]:.02f}'
+ cv2.putText(img, label_text, (bbox_int[0], bbox_int[1] - 2),
+ cv2.FONT_HERSHEY_COMPLEX, font_scale, text_color)
+
+ if show:
+ imshow(img, win_name, wait_time)
+ if out_file is not None:
+ imwrite(img, out_file)
+ return img
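+
+
+# Illustrative usage sketch (comment only); 'img.png' and the boxes are
+# placeholders:
+#
+#     bboxes = np.array([[10, 10, 100, 120, 0.9]], dtype=np.float32)
+#     labels = np.array([0])
+#     imshow_det_bboxes('img.png', bboxes, labels, class_names=['cat'],
+#                       score_thr=0.3, show=False, out_file='vis.png')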
diff --git a/annotator/uniformer/mmcv/visualization/optflow.py b/annotator/uniformer/mmcv/visualization/optflow.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3870c700f7c946177ee5d536ce3f6c814a77ce7
--- /dev/null
+++ b/annotator/uniformer/mmcv/visualization/optflow.py
@@ -0,0 +1,112 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from __future__ import division
+
+import numpy as np
+
+from annotator.uniformer.mmcv.image import rgb2bgr
+from annotator.uniformer.mmcv.video import flowread
+from .image import imshow
+
+
+def flowshow(flow, win_name='', wait_time=0):
+ """Show optical flow.
+
+ Args:
+ flow (ndarray or str): The optical flow to be displayed.
+ win_name (str): The window name.
+ wait_time (int): Value of waitKey param.
+ """
+ flow = flowread(flow)
+ flow_img = flow2rgb(flow)
+ imshow(rgb2bgr(flow_img), win_name, wait_time)
+
+
+def flow2rgb(flow, color_wheel=None, unknown_thr=1e6):
+ """Convert flow map to RGB image.
+
+ Args:
+ flow (ndarray): Array of optical flow.
+ color_wheel (ndarray or None): Color wheel used to map flow field to
+ RGB colorspace. Default color wheel will be used if not specified.
+        unknown_thr (float): Values above this threshold will be marked as
+            unknown and thus ignored.
+
+ Returns:
+ ndarray: RGB image that can be visualized.
+ """
+ assert flow.ndim == 3 and flow.shape[-1] == 2
+ if color_wheel is None:
+ color_wheel = make_color_wheel()
+ assert color_wheel.ndim == 2 and color_wheel.shape[1] == 3
+ num_bins = color_wheel.shape[0]
+
+ dx = flow[:, :, 0].copy()
+ dy = flow[:, :, 1].copy()
+
+ ignore_inds = (
+ np.isnan(dx) | np.isnan(dy) | (np.abs(dx) > unknown_thr) |
+ (np.abs(dy) > unknown_thr))
+ dx[ignore_inds] = 0
+ dy[ignore_inds] = 0
+
+ rad = np.sqrt(dx**2 + dy**2)
+ if np.any(rad > np.finfo(float).eps):
+ max_rad = np.max(rad)
+ dx /= max_rad
+ dy /= max_rad
+
+ rad = np.sqrt(dx**2 + dy**2)
+ angle = np.arctan2(-dy, -dx) / np.pi
+
+ bin_real = (angle + 1) / 2 * (num_bins - 1)
+ bin_left = np.floor(bin_real).astype(int)
+ bin_right = (bin_left + 1) % num_bins
+ w = (bin_real - bin_left.astype(np.float32))[..., None]
+ flow_img = (1 -
+ w) * color_wheel[bin_left, :] + w * color_wheel[bin_right, :]
+ small_ind = rad <= 1
+ flow_img[small_ind] = 1 - rad[small_ind, None] * (1 - flow_img[small_ind])
+ flow_img[np.logical_not(small_ind)] *= 0.75
+
+ flow_img[ignore_inds, :] = 0
+
+ return flow_img
+
+
+def make_color_wheel(bins=None):
+ """Build a color wheel.
+
+ Args:
+ bins(list or tuple, optional): Specify the number of bins for each
+ color range, corresponding to six ranges: red -> yellow,
+ yellow -> green, green -> cyan, cyan -> blue, blue -> magenta,
+ magenta -> red. [15, 6, 4, 11, 13, 6] is used for default
+ (see Middlebury).
+
+ Returns:
+ ndarray: Color wheel of shape (total_bins, 3).
+ """
+ if bins is None:
+ bins = [15, 6, 4, 11, 13, 6]
+ assert len(bins) == 6
+
+ RY, YG, GC, CB, BM, MR = tuple(bins)
+
+ ry = [1, np.arange(RY) / RY, 0]
+ yg = [1 - np.arange(YG) / YG, 1, 0]
+ gc = [0, 1, np.arange(GC) / GC]
+ cb = [0, 1 - np.arange(CB) / CB, 1]
+ bm = [np.arange(BM) / BM, 0, 1]
+ mr = [1, 0, 1 - np.arange(MR) / MR]
+
+ num_bins = RY + YG + GC + CB + BM + MR
+
+ color_wheel = np.zeros((3, num_bins), dtype=np.float32)
+
+ col = 0
+ for i, color in enumerate([ry, yg, gc, cb, bm, mr]):
+ for j in range(3):
+ color_wheel[j, col:col + bins[i]] = color[j]
+ col += bins[i]
+
+ return color_wheel.T
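+
+
+# Illustrative usage sketch (comment only); `flow` is a placeholder array:
+#
+#     flow = np.zeros((240, 320, 2), dtype=np.float32)
+#     flow[..., 0] = 2.0                 # uniform horizontal motion
+#     rgb = flow2rgb(flow)               # float RGB image with values in [0, 1]
+#     # flowshow(flow) would additionally display it via mmcv's imshow.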
diff --git a/annotator/uniformer/mmcv_custom/__init__.py b/annotator/uniformer/mmcv_custom/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b958738b9fd93bfcec239c550df1d9a44b8c536
--- /dev/null
+++ b/annotator/uniformer/mmcv_custom/__init__.py
@@ -0,0 +1,5 @@
+# -*- coding: utf-8 -*-
+
+from .checkpoint import load_checkpoint
+
+__all__ = ['load_checkpoint']
\ No newline at end of file
diff --git a/annotator/uniformer/mmcv_custom/checkpoint.py b/annotator/uniformer/mmcv_custom/checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..19b87fef0a52d31babcdb3edb8f3089b6420173f
--- /dev/null
+++ b/annotator/uniformer/mmcv_custom/checkpoint.py
@@ -0,0 +1,500 @@
+# Copyright (c) Open-MMLab. All rights reserved.
+import io
+import os
+import os.path as osp
+import pkgutil
+import time
+import warnings
+from collections import OrderedDict
+from importlib import import_module
+from tempfile import TemporaryDirectory
+
+import torch
+import torchvision
+from torch.optim import Optimizer
+from torch.utils import model_zoo
+from torch.nn import functional as F
+
+import annotator.uniformer.mmcv as mmcv
+from annotator.uniformer.mmcv.fileio import FileClient
+from annotator.uniformer.mmcv.fileio import load as load_file
+from annotator.uniformer.mmcv.parallel import is_module_wrapper
+from annotator.uniformer.mmcv.utils import mkdir_or_exist
+from annotator.uniformer.mmcv.runner import get_dist_info
+
+ENV_MMCV_HOME = 'MMCV_HOME'
+ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
+DEFAULT_CACHE_DIR = '~/.cache'
+
+
+def _get_mmcv_home():
+ mmcv_home = os.path.expanduser(
+ os.getenv(
+ ENV_MMCV_HOME,
+ os.path.join(
+ os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'mmcv')))
+
+ mkdir_or_exist(mmcv_home)
+ return mmcv_home
+
+
+def load_state_dict(module, state_dict, strict=False, logger=None):
+ """Load state_dict to a module.
+
+ This method is modified from :meth:`torch.nn.Module.load_state_dict`.
+ Default value for ``strict`` is set to ``False`` and the message for
+ param mismatch will be shown even if strict is False.
+
+ Args:
+ module (Module): Module that receives the state_dict.
+ state_dict (OrderedDict): Weights.
+ strict (bool): whether to strictly enforce that the keys
+ in :attr:`state_dict` match the keys returned by this module's
+ :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
+ logger (:obj:`logging.Logger`, optional): Logger to log the error
+ message. If not specified, print function will be used.
+ """
+ unexpected_keys = []
+ all_missing_keys = []
+ err_msg = []
+
+ metadata = getattr(state_dict, '_metadata', None)
+ state_dict = state_dict.copy()
+ if metadata is not None:
+ state_dict._metadata = metadata
+
+ # use _load_from_state_dict to enable checkpoint version control
+ def load(module, prefix=''):
+ # recursively check parallel module in case that the model has a
+ # complicated structure, e.g., nn.Module(nn.Module(DDP))
+ if is_module_wrapper(module):
+ module = module.module
+ local_metadata = {} if metadata is None else metadata.get(
+ prefix[:-1], {})
+ module._load_from_state_dict(state_dict, prefix, local_metadata, True,
+ all_missing_keys, unexpected_keys,
+ err_msg)
+ for name, child in module._modules.items():
+ if child is not None:
+ load(child, prefix + name + '.')
+
+ load(module)
+ load = None # break load->load reference cycle
+
+ # ignore "num_batches_tracked" of BN layers
+ missing_keys = [
+ key for key in all_missing_keys if 'num_batches_tracked' not in key
+ ]
+
+ if unexpected_keys:
+ err_msg.append('unexpected key in source '
+ f'state_dict: {", ".join(unexpected_keys)}\n')
+ if missing_keys:
+ err_msg.append(
+ f'missing keys in source state_dict: {", ".join(missing_keys)}\n')
+
+ rank, _ = get_dist_info()
+ if len(err_msg) > 0 and rank == 0:
+ err_msg.insert(
+ 0, 'The model and loaded state dict do not match exactly\n')
+ err_msg = '\n'.join(err_msg)
+ if strict:
+ raise RuntimeError(err_msg)
+ elif logger is not None:
+ logger.warning(err_msg)
+ else:
+ print(err_msg)
+
+
+def load_url_dist(url, model_dir=None):
+    """In distributed settings, this function downloads the checkpoint only
+    at local rank 0."""
+ rank, world_size = get_dist_info()
+ rank = int(os.environ.get('LOCAL_RANK', rank))
+ if rank == 0:
+ checkpoint = model_zoo.load_url(url, model_dir=model_dir)
+ if world_size > 1:
+ torch.distributed.barrier()
+ if rank > 0:
+ checkpoint = model_zoo.load_url(url, model_dir=model_dir)
+ return checkpoint
+
+
+def load_pavimodel_dist(model_path, map_location=None):
+    """In distributed settings, this function downloads the checkpoint only
+    at local rank 0."""
+ try:
+ from pavi import modelcloud
+ except ImportError:
+ raise ImportError(
+ 'Please install pavi to load checkpoint from modelcloud.')
+ rank, world_size = get_dist_info()
+ rank = int(os.environ.get('LOCAL_RANK', rank))
+ if rank == 0:
+ model = modelcloud.get(model_path)
+ with TemporaryDirectory() as tmp_dir:
+ downloaded_file = osp.join(tmp_dir, model.name)
+ model.download(downloaded_file)
+ checkpoint = torch.load(downloaded_file, map_location=map_location)
+ if world_size > 1:
+ torch.distributed.barrier()
+ if rank > 0:
+ model = modelcloud.get(model_path)
+ with TemporaryDirectory() as tmp_dir:
+ downloaded_file = osp.join(tmp_dir, model.name)
+ model.download(downloaded_file)
+ checkpoint = torch.load(
+ downloaded_file, map_location=map_location)
+ return checkpoint
+
+
+def load_fileclient_dist(filename, backend, map_location):
+    """In distributed settings, this function downloads the checkpoint only
+    at local rank 0."""
+ rank, world_size = get_dist_info()
+ rank = int(os.environ.get('LOCAL_RANK', rank))
+ allowed_backends = ['ceph']
+ if backend not in allowed_backends:
+ raise ValueError(f'Load from Backend {backend} is not supported.')
+ if rank == 0:
+ fileclient = FileClient(backend=backend)
+ buffer = io.BytesIO(fileclient.get(filename))
+ checkpoint = torch.load(buffer, map_location=map_location)
+ if world_size > 1:
+ torch.distributed.barrier()
+ if rank > 0:
+ fileclient = FileClient(backend=backend)
+ buffer = io.BytesIO(fileclient.get(filename))
+ checkpoint = torch.load(buffer, map_location=map_location)
+ return checkpoint
+
+
+def get_torchvision_models():
+ model_urls = dict()
+ for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):
+ if ispkg:
+ continue
+ _zoo = import_module(f'torchvision.models.{name}')
+ if hasattr(_zoo, 'model_urls'):
+ _urls = getattr(_zoo, 'model_urls')
+ model_urls.update(_urls)
+ return model_urls
+
+
+def get_external_models():
+ mmcv_home = _get_mmcv_home()
+ default_json_path = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json')
+ default_urls = load_file(default_json_path)
+ assert isinstance(default_urls, dict)
+ external_json_path = osp.join(mmcv_home, 'open_mmlab.json')
+ if osp.exists(external_json_path):
+ external_urls = load_file(external_json_path)
+ assert isinstance(external_urls, dict)
+ default_urls.update(external_urls)
+
+ return default_urls
+
+
+def get_mmcls_models():
+ mmcls_json_path = osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json')
+ mmcls_urls = load_file(mmcls_json_path)
+
+ return mmcls_urls
+
+
+def get_deprecated_model_names():
+ deprecate_json_path = osp.join(mmcv.__path__[0],
+ 'model_zoo/deprecated.json')
+ deprecate_urls = load_file(deprecate_json_path)
+ assert isinstance(deprecate_urls, dict)
+
+ return deprecate_urls
+
+
+def _process_mmcls_checkpoint(checkpoint):
+ state_dict = checkpoint['state_dict']
+ new_state_dict = OrderedDict()
+ for k, v in state_dict.items():
+ if k.startswith('backbone.'):
+ new_state_dict[k[9:]] = v
+ new_checkpoint = dict(state_dict=new_state_dict)
+
+ return new_checkpoint
+
+
+def _load_checkpoint(filename, map_location=None):
+ """Load checkpoint from somewhere (modelzoo, file, url).
+
+ Args:
+ filename (str): Accept local filepath, URL, ``torchvision://xxx``,
+ ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
+ details.
+ map_location (str | None): Same as :func:`torch.load`. Default: None.
+
+ Returns:
+ dict | OrderedDict: The loaded checkpoint. It can be either an
+ OrderedDict storing model weights or a dict containing other
+ information, which depends on the checkpoint.
+ """
+ if filename.startswith('modelzoo://'):
+ warnings.warn('The URL scheme of "modelzoo://" is deprecated, please '
+ 'use "torchvision://" instead')
+ model_urls = get_torchvision_models()
+ model_name = filename[11:]
+ checkpoint = load_url_dist(model_urls[model_name])
+ elif filename.startswith('torchvision://'):
+ model_urls = get_torchvision_models()
+ model_name = filename[14:]
+ checkpoint = load_url_dist(model_urls[model_name])
+ elif filename.startswith('open-mmlab://'):
+ model_urls = get_external_models()
+ model_name = filename[13:]
+ deprecated_urls = get_deprecated_model_names()
+ if model_name in deprecated_urls:
+ warnings.warn(f'open-mmlab://{model_name} is deprecated in favor '
+ f'of open-mmlab://{deprecated_urls[model_name]}')
+ model_name = deprecated_urls[model_name]
+ model_url = model_urls[model_name]
+ # check if is url
+ if model_url.startswith(('http://', 'https://')):
+ checkpoint = load_url_dist(model_url)
+ else:
+ filename = osp.join(_get_mmcv_home(), model_url)
+ if not osp.isfile(filename):
+ raise IOError(f'{filename} is not a checkpoint file')
+ checkpoint = torch.load(filename, map_location=map_location)
+ elif filename.startswith('mmcls://'):
+ model_urls = get_mmcls_models()
+ model_name = filename[8:]
+ checkpoint = load_url_dist(model_urls[model_name])
+ checkpoint = _process_mmcls_checkpoint(checkpoint)
+ elif filename.startswith(('http://', 'https://')):
+ checkpoint = load_url_dist(filename)
+ elif filename.startswith('pavi://'):
+ model_path = filename[7:]
+ checkpoint = load_pavimodel_dist(model_path, map_location=map_location)
+ elif filename.startswith('s3://'):
+ checkpoint = load_fileclient_dist(
+ filename, backend='ceph', map_location=map_location)
+ else:
+ if not osp.isfile(filename):
+ raise IOError(f'{filename} is not a checkpoint file')
+ checkpoint = torch.load(filename, map_location=map_location)
+ return checkpoint
+
+
+def load_checkpoint(model,
+ filename,
+ map_location='cpu',
+ strict=False,
+ logger=None):
+ """Load checkpoint from a file or URI.
+
+ Args:
+        model (Module): The module to load the checkpoint into.
+ filename (str): Accept local filepath, URL, ``torchvision://xxx``,
+ ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
+ details.
+ map_location (str): Same as :func:`torch.load`.
+        strict (bool): Whether to strictly enforce that the keys in the
+            checkpoint match the keys returned by the model's state_dict.
+ logger (:mod:`logging.Logger` or None): The logger for error message.
+
+ Returns:
+ dict or OrderedDict: The loaded checkpoint.
+ """
+ checkpoint = _load_checkpoint(filename, map_location)
+ # OrderedDict is a subclass of dict
+ if not isinstance(checkpoint, dict):
+ raise RuntimeError(
+ f'No state_dict found in checkpoint file {filename}')
+ # get state_dict from checkpoint
+ if 'state_dict' in checkpoint:
+ state_dict = checkpoint['state_dict']
+ elif 'model' in checkpoint:
+ state_dict = checkpoint['model']
+ else:
+ state_dict = checkpoint
+ # strip prefix of state_dict
+ if list(state_dict.keys())[0].startswith('module.'):
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
+
+ # for MoBY, load model of online branch
+ if sorted(list(state_dict.keys()))[0].startswith('encoder'):
+ state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')}
+
+ # reshape absolute position embedding
+ if state_dict.get('absolute_pos_embed') is not None:
+ absolute_pos_embed = state_dict['absolute_pos_embed']
+ N1, L, C1 = absolute_pos_embed.size()
+ N2, C2, H, W = model.absolute_pos_embed.size()
+ if N1 != N2 or C1 != C2 or L != H*W:
+ logger.warning("Error in loading absolute_pos_embed, pass")
+ else:
+ state_dict['absolute_pos_embed'] = absolute_pos_embed.view(N2, H, W, C2).permute(0, 3, 1, 2)
+
+ # interpolate position bias table if needed
+ relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k]
+ for table_key in relative_position_bias_table_keys:
+ table_pretrained = state_dict[table_key]
+ table_current = model.state_dict()[table_key]
+ L1, nH1 = table_pretrained.size()
+ L2, nH2 = table_current.size()
+ if nH1 != nH2:
+ logger.warning(f"Error in loading {table_key}, pass")
+ else:
+ if L1 != L2:
+ S1 = int(L1 ** 0.5)
+ S2 = int(L2 ** 0.5)
+ table_pretrained_resized = F.interpolate(
+ table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
+ size=(S2, S2), mode='bicubic')
+ state_dict[table_key] = table_pretrained_resized.view(nH2, L2).permute(1, 0)
+
+ # load state_dict
+ load_state_dict(model, state_dict, strict, logger)
+ return checkpoint
+
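+# Illustrative usage sketch (comment only); the checkpoint path is a
+# placeholder and `model` is assumed to be an already constructed nn.Module:
+#
+#     checkpoint = load_checkpoint(model, 'path/to/uniformer_checkpoint.pth',
+#                                  map_location='cpu', strict=False)
+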
+
+def weights_to_cpu(state_dict):
+ """Copy a model state_dict to cpu.
+
+ Args:
+ state_dict (OrderedDict): Model weights on GPU.
+
+ Returns:
+        OrderedDict: Model weights on CPU.
+ """
+ state_dict_cpu = OrderedDict()
+ for key, val in state_dict.items():
+ state_dict_cpu[key] = val.cpu()
+ return state_dict_cpu
+
+
+def _save_to_state_dict(module, destination, prefix, keep_vars):
+ """Saves module state to `destination` dictionary.
+
+ This method is modified from :meth:`torch.nn.Module._save_to_state_dict`.
+
+ Args:
+ module (nn.Module): The module to generate state_dict.
+ destination (dict): A dict where state will be stored.
+        prefix (str): The prefix for parameters and buffers used in this
+            module.
+        keep_vars (bool): Whether to keep the variable property of the
+            parameters.
+ """
+ for name, param in module._parameters.items():
+ if param is not None:
+ destination[prefix + name] = param if keep_vars else param.detach()
+ for name, buf in module._buffers.items():
+ # remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d
+ if buf is not None:
+ destination[prefix + name] = buf if keep_vars else buf.detach()
+
+
+def get_state_dict(module, destination=None, prefix='', keep_vars=False):
+ """Returns a dictionary containing a whole state of the module.
+
+ Both parameters and persistent buffers (e.g. running averages) are
+ included. Keys are corresponding parameter and buffer names.
+
+ This method is modified from :meth:`torch.nn.Module.state_dict` to
+ recursively check parallel module in case that the model has a complicated
+ structure, e.g., nn.Module(nn.Module(DDP)).
+
+ Args:
+ module (nn.Module): The module to generate state_dict.
+ destination (OrderedDict): Returned dict for the state of the
+ module.
+ prefix (str): Prefix of the key.
+ keep_vars (bool): Whether to keep the variable property of the
+ parameters. Default: False.
+
+ Returns:
+ dict: A dictionary containing a whole state of the module.
+ """
+ # recursively check parallel module in case that the model has a
+ # complicated structure, e.g., nn.Module(nn.Module(DDP))
+ if is_module_wrapper(module):
+ module = module.module
+
+ # below is the same as torch.nn.Module.state_dict()
+ if destination is None:
+ destination = OrderedDict()
+ destination._metadata = OrderedDict()
+ destination._metadata[prefix[:-1]] = local_metadata = dict(
+ version=module._version)
+ _save_to_state_dict(module, destination, prefix, keep_vars)
+ for name, child in module._modules.items():
+ if child is not None:
+ get_state_dict(
+ child, destination, prefix + name + '.', keep_vars=keep_vars)
+ for hook in module._state_dict_hooks.values():
+ hook_result = hook(module, destination, prefix, local_metadata)
+ if hook_result is not None:
+ destination = hook_result
+ return destination
+
+
+def save_checkpoint(model, filename, optimizer=None, meta=None):
+ """Save checkpoint to file.
+
+ The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
+ ``optimizer``. By default ``meta`` will contain version and time info.
+
+ Args:
+ model (Module): Module whose params are to be saved.
+ filename (str): Checkpoint filename.
+ optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
+ meta (dict, optional): Metadata to be saved in checkpoint.
+ """
+ if meta is None:
+ meta = {}
+ elif not isinstance(meta, dict):
+ raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
+ meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
+
+ if is_module_wrapper(model):
+ model = model.module
+
+ if hasattr(model, 'CLASSES') and model.CLASSES is not None:
+ # save class name to the meta
+ meta.update(CLASSES=model.CLASSES)
+
+ checkpoint = {
+ 'meta': meta,
+ 'state_dict': weights_to_cpu(get_state_dict(model))
+ }
+ # save optimizer state dict in the checkpoint
+ if isinstance(optimizer, Optimizer):
+ checkpoint['optimizer'] = optimizer.state_dict()
+ elif isinstance(optimizer, dict):
+ checkpoint['optimizer'] = {}
+ for name, optim in optimizer.items():
+ checkpoint['optimizer'][name] = optim.state_dict()
+
+ if filename.startswith('pavi://'):
+ try:
+ from pavi import modelcloud
+ from pavi.exception import NodeNotFoundError
+ except ImportError:
+ raise ImportError(
+ 'Please install pavi to load checkpoint from modelcloud.')
+ model_path = filename[7:]
+ root = modelcloud.Folder()
+ model_dir, model_name = osp.split(model_path)
+ try:
+ model = modelcloud.get(model_dir)
+ except NodeNotFoundError:
+ model = root.create_training_model(model_dir)
+ with TemporaryDirectory() as tmp_dir:
+ checkpoint_file = osp.join(tmp_dir, model_name)
+ with open(checkpoint_file, 'wb') as f:
+ torch.save(checkpoint, f)
+ f.flush()
+ model.create_file(checkpoint_file, name=model_name)
+ else:
+ mmcv.mkdir_or_exist(osp.dirname(filename))
+ # immediately flush buffer
+ with open(filename, 'wb') as f:
+ torch.save(checkpoint, f)
+ f.flush()
\ No newline at end of file
diff --git a/annotator/uniformer/mmseg/apis/__init__.py b/annotator/uniformer/mmseg/apis/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..170724be38de42daf2bc1a1910e181d68818f165
--- /dev/null
+++ b/annotator/uniformer/mmseg/apis/__init__.py
@@ -0,0 +1,9 @@
+from .inference import inference_segmentor, init_segmentor, show_result_pyplot
+from .test import multi_gpu_test, single_gpu_test
+from .train import get_root_logger, set_random_seed, train_segmentor
+
+__all__ = [
+ 'get_root_logger', 'set_random_seed', 'train_segmentor', 'init_segmentor',
+ 'inference_segmentor', 'multi_gpu_test', 'single_gpu_test',
+ 'show_result_pyplot'
+]
diff --git a/annotator/uniformer/mmseg/apis/inference.py b/annotator/uniformer/mmseg/apis/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..90bc1c0c68525734bd6793f07c15fe97d3c8342c
--- /dev/null
+++ b/annotator/uniformer/mmseg/apis/inference.py
@@ -0,0 +1,136 @@
+import matplotlib.pyplot as plt
+import annotator.uniformer.mmcv as mmcv
+import torch
+from annotator.uniformer.mmcv.parallel import collate, scatter
+from annotator.uniformer.mmcv.runner import load_checkpoint
+
+from annotator.uniformer.mmseg.datasets.pipelines import Compose
+from annotator.uniformer.mmseg.models import build_segmentor
+
+
+def init_segmentor(config, checkpoint=None, device='cuda:0'):
+ """Initialize a segmentor from config file.
+
+ Args:
+ config (str or :obj:`mmcv.Config`): Config file path or the config
+ object.
+ checkpoint (str, optional): Checkpoint path. If left as None, the model
+ will not load any weights.
+        device (str, optional): CPU/CUDA device option. Default 'cuda:0'.
+            Use 'cpu' for loading the model on CPU.
+ Returns:
+ nn.Module: The constructed segmentor.
+ """
+ if isinstance(config, str):
+ config = mmcv.Config.fromfile(config)
+ elif not isinstance(config, mmcv.Config):
+ raise TypeError('config must be a filename or Config object, '
+ 'but got {}'.format(type(config)))
+ config.model.pretrained = None
+ config.model.train_cfg = None
+ model = build_segmentor(config.model, test_cfg=config.get('test_cfg'))
+ if checkpoint is not None:
+ checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
+ model.CLASSES = checkpoint['meta']['CLASSES']
+ model.PALETTE = checkpoint['meta']['PALETTE']
+ model.cfg = config # save the config in the model for convenience
+ model.to(device)
+ model.eval()
+ return model
+
+
+class LoadImage:
+ """A simple pipeline to load image."""
+
+ def __call__(self, results):
+ """Call function to load images into results.
+
+ Args:
+            results (dict): A result dict that contains the file name of the
+                image to be read.
+
+ Returns:
+ dict: ``results`` will be returned containing loaded image.
+ """
+
+ if isinstance(results['img'], str):
+ results['filename'] = results['img']
+ results['ori_filename'] = results['img']
+ else:
+ results['filename'] = None
+ results['ori_filename'] = None
+ img = mmcv.imread(results['img'])
+ results['img'] = img
+ results['img_shape'] = img.shape
+ results['ori_shape'] = img.shape
+ return results
+
+
+def inference_segmentor(model, img):
+ """Inference image(s) with the segmentor.
+
+ Args:
+ model (nn.Module): The loaded segmentor.
+        img (str/ndarray or list[str/ndarray]): Either image files or loaded
+            images.
+
+ Returns:
+ (list[Tensor]): The segmentation result.
+ """
+ cfg = model.cfg
+ device = next(model.parameters()).device # model device
+ # build the data pipeline
+ test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:]
+ test_pipeline = Compose(test_pipeline)
+ # prepare data
+ data = dict(img=img)
+ data = test_pipeline(data)
+ data = collate([data], samples_per_gpu=1)
+ if next(model.parameters()).is_cuda:
+ # scatter to specified GPU
+ data = scatter(data, [device])[0]
+ else:
+ data['img_metas'] = [i.data[0] for i in data['img_metas']]
+
+ # forward the model
+ with torch.no_grad():
+ result = model(return_loss=False, rescale=True, **data)
+ return result
+
+
+def show_result_pyplot(model,
+ img,
+ result,
+ palette=None,
+ fig_size=(15, 10),
+ opacity=0.5,
+ title='',
+ block=True):
+ """Visualize the segmentation results on the image.
+
+ Args:
+ model (nn.Module): The loaded segmentor.
+ img (str or np.ndarray): Image filename or loaded image.
+ result (list): The segmentation result.
+ palette (list[list[int]]] | None): The palette of segmentation
+ map. If None is given, random palette will be generated.
+ Default: None
+ fig_size (tuple): Figure size of the pyplot figure.
+ opacity(float): Opacity of painted segmentation map.
+ Default 0.5.
+ Must be in (0, 1] range.
+ title (str): The title of pyplot figure.
+ Default is ''.
+ block (bool): Whether to block the pyplot figure.
+ Default is True.
+
+    Returns:
+        ndarray: The RGB image with the segmentation result drawn on it. (In
+            this copy the pyplot display is disabled and the rendered image
+            is returned instead.)
+    """
+ if hasattr(model, 'module'):
+ model = model.module
+ img = model.show_result(
+ img, result, palette=palette, show=False, opacity=opacity)
+ # plt.figure(figsize=fig_size)
+ # plt.imshow(mmcv.bgr2rgb(img))
+ # plt.title(title)
+ # plt.tight_layout()
+ # plt.show(block=block)
+ return mmcv.bgr2rgb(img)
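+
+
+# Illustrative end-to-end sketch (comment only); the config and checkpoint
+# paths are placeholders:
+#
+#     model = init_segmentor('configs/upernet_uniformer.py',
+#                            'upernet_uniformer.pth', device='cuda:0')
+#     result = inference_segmentor(model, 'demo.png')
+#     vis = show_result_pyplot(model, 'demo.png', result, opacity=0.5)
+#     # `vis` is an RGB ndarray that the caller can save or display.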
diff --git a/annotator/uniformer/mmseg/apis/test.py b/annotator/uniformer/mmseg/apis/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..e574eb7da04f09a59cf99ff953c36468ae87a326
--- /dev/null
+++ b/annotator/uniformer/mmseg/apis/test.py
@@ -0,0 +1,238 @@
+import os.path as osp
+import pickle
+import shutil
+import tempfile
+
+import annotator.uniformer.mmcv as mmcv
+import numpy as np
+import torch
+import torch.distributed as dist
+from annotator.uniformer.mmcv.image import tensor2imgs
+from annotator.uniformer.mmcv.runner import get_dist_info
+
+
+def np2tmp(array, temp_file_name=None):
+ """Save ndarray to local numpy file.
+
+ Args:
+ array (ndarray): Ndarray to save.
+        temp_file_name (str): Numpy file name. If None, this function will
+            generate a file name with tempfile.NamedTemporaryFile to save the
+            ndarray. Default: None.
+
+ Returns:
+ str: The numpy file name.
+ """
+
+ if temp_file_name is None:
+ temp_file_name = tempfile.NamedTemporaryFile(
+ suffix='.npy', delete=False).name
+ np.save(temp_file_name, array)
+ return temp_file_name
+
+
+def single_gpu_test(model,
+ data_loader,
+ show=False,
+ out_dir=None,
+ efficient_test=False,
+ opacity=0.5):
+ """Test with single GPU.
+
+ Args:
+ model (nn.Module): Model to be tested.
+ data_loader (utils.data.Dataloader): Pytorch data loader.
+ show (bool): Whether show results during inference. Default: False.
+        out_dir (str, optional): If specified, the visualization results will
+            be dumped into this directory.
+        efficient_test (bool): Whether to save the results as local numpy
+            files to save CPU memory during evaluation. Default: False.
+ opacity(float): Opacity of painted segmentation map.
+ Default 0.5.
+ Must be in (0, 1] range.
+ Returns:
+ list: The prediction results.
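+
+ Example (illustrative sketch; assumes ``model`` and ``data_loader``
+ were already built with the mmseg builders and the model is wrapped
+ in MMDataParallel):
+ >>> results = single_gpu_test(model, data_loader, show=False)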
+ """
+
+ model.eval()
+ results = []
+ dataset = data_loader.dataset
+ prog_bar = mmcv.ProgressBar(len(dataset))
+ for i, data in enumerate(data_loader):
+ with torch.no_grad():
+ result = model(return_loss=False, **data)
+
+ if show or out_dir:
+ img_tensor = data['img'][0]
+ img_metas = data['img_metas'][0].data[0]
+ imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
+ assert len(imgs) == len(img_metas)
+
+ for img, img_meta in zip(imgs, img_metas):
+ h, w, _ = img_meta['img_shape']
+ img_show = img[:h, :w, :]
+
+ ori_h, ori_w = img_meta['ori_shape'][:-1]
+ img_show = mmcv.imresize(img_show, (ori_w, ori_h))
+
+ if out_dir:
+ out_file = osp.join(out_dir, img_meta['ori_filename'])
+ else:
+ out_file = None
+
+ model.module.show_result(
+ img_show,
+ result,
+ palette=dataset.PALETTE,
+ show=show,
+ out_file=out_file,
+ opacity=opacity)
+
+ if isinstance(result, list):
+ if efficient_test:
+ result = [np2tmp(_) for _ in result]
+ results.extend(result)
+ else:
+ if efficient_test:
+ result = np2tmp(result)
+ results.append(result)
+
+ batch_size = len(result)
+ for _ in range(batch_size):
+ prog_bar.update()
+ return results
+
+
+def multi_gpu_test(model,
+ data_loader,
+ tmpdir=None,
+ gpu_collect=False,
+ efficient_test=False):
+ """Test model with multiple gpus.
+
+ This method tests the model with multiple gpus and collects the results
+ under two different modes: gpu and cpu. Setting 'gpu_collect=True'
+ encodes results to gpu tensors and uses gpu communication to collect
+ them. In cpu mode the results on different gpus are saved to 'tmpdir'
+ and collected by the rank 0 worker.
+
+ Args:
+ model (nn.Module): Model to be tested.
+ data_loader (utils.data.Dataloader): Pytorch data loader.
+ tmpdir (str): Path of directory to save the temporary results from
+ different gpus under cpu mode.
+ gpu_collect (bool): Option to use either gpu or cpu to collect results.
+ efficient_test (bool): Whether to save the results as local numpy
+ files to save CPU memory during evaluation. Default: False.
+
+ Returns:
+ list: The prediction results.
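+
+ Example (illustrative sketch for a distributed run; assumes the model
+ is already wrapped in MMDistributedDataParallel):
+ >>> results = multi_gpu_test(model, data_loader, gpu_collect=True)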
+ """
+
+ model.eval()
+ results = []
+ dataset = data_loader.dataset
+ rank, world_size = get_dist_info()
+ if rank == 0:
+ prog_bar = mmcv.ProgressBar(len(dataset))
+ for i, data in enumerate(data_loader):
+ with torch.no_grad():
+ result = model(return_loss=False, rescale=True, **data)
+
+ if isinstance(result, list):
+ if efficient_test:
+ result = [np2tmp(_) for _ in result]
+ results.extend(result)
+ else:
+ if efficient_test:
+ result = np2tmp(result)
+ results.append(result)
+
+ if rank == 0:
+ batch_size = data['img'][0].size(0)
+ for _ in range(batch_size * world_size):
+ prog_bar.update()
+
+ # collect results from all ranks
+ if gpu_collect:
+ results = collect_results_gpu(results, len(dataset))
+ else:
+ results = collect_results_cpu(results, len(dataset), tmpdir)
+ return results
+
+
+def collect_results_cpu(result_part, size, tmpdir=None):
+ """Collect results with CPU."""
+ rank, world_size = get_dist_info()
+ # create a tmp dir if it is not specified
+ if tmpdir is None:
+ MAX_LEN = 512
+ # 32 is the ASCII code of a space, used here as padding
+ dir_tensor = torch.full((MAX_LEN, ),
+ 32,
+ dtype=torch.uint8,
+ device='cuda')
+ if rank == 0:
+ tmpdir = tempfile.mkdtemp()
+ tmpdir = torch.tensor(
+ bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
+ dir_tensor[:len(tmpdir)] = tmpdir
+ dist.broadcast(dir_tensor, 0)
+ tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
+ else:
+ mmcv.mkdir_or_exist(tmpdir)
+ # dump the part result to the dir
+ mmcv.dump(result_part, osp.join(tmpdir, 'part_{}.pkl'.format(rank)))
+ dist.barrier()
+ # collect all parts
+ if rank != 0:
+ return None
+ else:
+ # load results of all parts from tmp dir
+ part_list = []
+ for i in range(world_size):
+ part_file = osp.join(tmpdir, 'part_{}.pkl'.format(i))
+ part_list.append(mmcv.load(part_file))
+ # sort the results
+ ordered_results = []
+ for res in zip(*part_list):
+ ordered_results.extend(list(res))
+ # the dataloader may pad some samples
+ ordered_results = ordered_results[:size]
+ # remove tmp dir
+ shutil.rmtree(tmpdir)
+ return ordered_results
+
+
+def collect_results_gpu(result_part, size):
+ """Collect results with GPU."""
+ rank, world_size = get_dist_info()
+ # dump result part to tensor with pickle
+ part_tensor = torch.tensor(
+ bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
+ # gather all result part tensor shape
+ shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
+ shape_list = [shape_tensor.clone() for _ in range(world_size)]
+ dist.all_gather(shape_list, shape_tensor)
+ # padding result part tensor to max length
+ shape_max = torch.tensor(shape_list).max()
+ part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
+ part_send[:shape_tensor[0]] = part_tensor
+ part_recv_list = [
+ part_tensor.new_zeros(shape_max) for _ in range(world_size)
+ ]
+ # gather all result part
+ dist.all_gather(part_recv_list, part_send)
+
+ if rank == 0:
+ part_list = []
+ for recv, shape in zip(part_recv_list, shape_list):
+ part_list.append(
+ pickle.loads(recv[:shape[0]].cpu().numpy().tobytes()))
+ # sort the results
+ ordered_results = []
+ for res in zip(*part_list):
+ ordered_results.extend(list(res))
+ # the dataloader may pad some samples
+ ordered_results = ordered_results[:size]
+ return ordered_results
diff --git a/annotator/uniformer/mmseg/apis/train.py b/annotator/uniformer/mmseg/apis/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..63f319a919ff023931a6a663e668f27dd1a07a2e
--- /dev/null
+++ b/annotator/uniformer/mmseg/apis/train.py
@@ -0,0 +1,116 @@
+import random
+import warnings
+
+import numpy as np
+import torch
+from annotator.uniformer.mmcv.parallel import MMDataParallel, MMDistributedDataParallel
+from annotator.uniformer.mmcv.runner import build_optimizer, build_runner
+
+from annotator.uniformer.mmseg.core import DistEvalHook, EvalHook
+from annotator.uniformer.mmseg.datasets import build_dataloader, build_dataset
+from annotator.uniformer.mmseg.utils import get_root_logger
+
+
+def set_random_seed(seed, deterministic=False):
+ """Set random seed.
+
+ Args:
+ seed (int): Seed to be used.
+ deterministic (bool): Whether to set the deterministic option for
+ CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
+ to True and `torch.backends.cudnn.benchmark` to False.
+ Default: False.
+ """
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ if deterministic:
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+
+
+def train_segmentor(model,
+ dataset,
+ cfg,
+ distributed=False,
+ validate=False,
+ timestamp=None,
+ meta=None):
+ """Launch segmentor training."""
+ logger = get_root_logger(cfg.log_level)
+
+ # prepare data loaders
+ dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
+ data_loaders = [
+ build_dataloader(
+ ds,
+ cfg.data.samples_per_gpu,
+ cfg.data.workers_per_gpu,
+ # cfg.gpus will be ignored if distributed
+ len(cfg.gpu_ids),
+ dist=distributed,
+ seed=cfg.seed,
+ drop_last=True) for ds in dataset
+ ]
+
+ # put model on gpus
+ if distributed:
+ find_unused_parameters = cfg.get('find_unused_parameters', False)
+ # Sets the `find_unused_parameters` parameter in
+ # torch.nn.parallel.DistributedDataParallel
+ model = MMDistributedDataParallel(
+ model.cuda(),
+ device_ids=[torch.cuda.current_device()],
+ broadcast_buffers=False,
+ find_unused_parameters=find_unused_parameters)
+ else:
+ model = MMDataParallel(
+ model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
+
+ # build runner
+ optimizer = build_optimizer(model, cfg.optimizer)
+
+ if cfg.get('runner') is None:
+ cfg.runner = {'type': 'IterBasedRunner', 'max_iters': cfg.total_iters}
+ warnings.warn(
+ 'config is now expected to have a `runner` section, '
+ 'please set `runner` in your config.', UserWarning)
+
+ runner = build_runner(
+ cfg.runner,
+ default_args=dict(
+ model=model,
+ batch_processor=None,
+ optimizer=optimizer,
+ work_dir=cfg.work_dir,
+ logger=logger,
+ meta=meta))
+
+ # register hooks
+ runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config,
+ cfg.checkpoint_config, cfg.log_config,
+ cfg.get('momentum_config', None))
+
+ # an ugly workaround to make the .log and .log.json filenames the same
+ runner.timestamp = timestamp
+
+ # register eval hooks
+ if validate:
+ val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
+ val_dataloader = build_dataloader(
+ val_dataset,
+ samples_per_gpu=1,
+ workers_per_gpu=cfg.data.workers_per_gpu,
+ dist=distributed,
+ shuffle=False)
+ eval_cfg = cfg.get('evaluation', {})
+ eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
+ eval_hook = DistEvalHook if distributed else EvalHook
+ runner.register_hook(eval_hook(val_dataloader, **eval_cfg), priority='LOW')
+
+ if cfg.resume_from:
+ runner.resume(cfg.resume_from)
+ elif cfg.load_from:
+ runner.load_checkpoint(cfg.load_from)
+ runner.run(data_loaders, cfg.workflow)
diff --git a/annotator/uniformer/mmseg/core/__init__.py b/annotator/uniformer/mmseg/core/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..965605587211b7bf0bd6bc3acdbb33dd49cab023
--- /dev/null
+++ b/annotator/uniformer/mmseg/core/__init__.py
@@ -0,0 +1,3 @@
+from .evaluation import * # noqa: F401, F403
+from .seg import * # noqa: F401, F403
+from .utils import * # noqa: F401, F403
diff --git a/annotator/uniformer/mmseg/core/evaluation/__init__.py b/annotator/uniformer/mmseg/core/evaluation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7cc4b23413a0639e9de00eeb0bf600632d2c6cd
--- /dev/null
+++ b/annotator/uniformer/mmseg/core/evaluation/__init__.py
@@ -0,0 +1,8 @@
+from .class_names import get_classes, get_palette
+from .eval_hooks import DistEvalHook, EvalHook
+from .metrics import eval_metrics, mean_dice, mean_fscore, mean_iou
+
+__all__ = [
+ 'EvalHook', 'DistEvalHook', 'mean_dice', 'mean_iou', 'mean_fscore',
+ 'eval_metrics', 'get_classes', 'get_palette'
+]
diff --git a/annotator/uniformer/mmseg/core/evaluation/class_names.py b/annotator/uniformer/mmseg/core/evaluation/class_names.py
new file mode 100644
index 0000000000000000000000000000000000000000..ffae816cf980ce4b03e491cc0c4298cb823797e6
--- /dev/null
+++ b/annotator/uniformer/mmseg/core/evaluation/class_names.py
@@ -0,0 +1,152 @@
+import annotator.uniformer.mmcv as mmcv
+
+
+def cityscapes_classes():
+ """Cityscapes class names for external use."""
+ return [
+ 'road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
+ 'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
+ 'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
+ 'bicycle'
+ ]
+
+
+def ade_classes():
+ """ADE20K class names for external use."""
+ return [
+ 'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ',
+ 'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth',
+ 'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car',
+ 'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug',
+ 'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe',
+ 'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column',
+ 'signboard', 'chest of drawers', 'counter', 'sand', 'sink',
+ 'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path',
+ 'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door',
+ 'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table',
+ 'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove',
+ 'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar',
+ 'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower',
+ 'chandelier', 'awning', 'streetlight', 'booth', 'television receiver',
+ 'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister',
+ 'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van',
+ 'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything',
+ 'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent',
+ 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank',
+ 'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake',
+ 'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce',
+ 'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen',
+ 'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass',
+ 'clock', 'flag'
+ ]
+
+
+def voc_classes():
+ """Pascal VOC class names for external use."""
+ return [
+ 'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
+ 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
+ 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',
+ 'tvmonitor'
+ ]
+
+
+def cityscapes_palette():
+ """Cityscapes palette for external use."""
+ return [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
+ [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0],
+ [107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60],
+ [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100],
+ [0, 0, 230], [119, 11, 32]]
+
+
+def ade_palette():
+ """ADE20K palette for external use."""
+ return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
+ [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
+ [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
+ [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
+ [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
+ [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
+ [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
+ [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
+ [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
+ [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
+ [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
+ [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
+ [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
+ [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
+ [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
+ [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
+ [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
+ [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
+ [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
+ [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
+ [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
+ [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
+ [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
+ [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
+ [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
+ [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
+ [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
+ [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
+ [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
+ [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
+ [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
+ [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
+ [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
+ [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
+ [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
+ [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
+ [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
+ [102, 255, 0], [92, 0, 255]]
+
+
+def voc_palette():
+ """Pascal VOC palette for external use."""
+ return [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],
+ [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0],
+ [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128],
+ [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0],
+ [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]]
+
+
+dataset_aliases = {
+ 'cityscapes': ['cityscapes'],
+ 'ade': ['ade', 'ade20k'],
+ 'voc': ['voc', 'pascal_voc', 'voc12', 'voc12aug']
+}
+
+
+def get_classes(dataset):
+ """Get class names of a dataset."""
+ alias2name = {}
+ for name, aliases in dataset_aliases.items():
+ for alias in aliases:
+ alias2name[alias] = name
+
+ if mmcv.is_str(dataset):
+ if dataset in alias2name:
+ labels = eval(alias2name[dataset] + '_classes()')
+ else:
+ raise ValueError(f'Unrecognized dataset: {dataset}')
+ else:
+ raise TypeError(f'dataset must be a str, but got {type(dataset)}')
+ return labels
+
+
+def get_palette(dataset):
+ """Get class palette (RGB) of a dataset."""
+ alias2name = {}
+ for name, aliases in dataset_aliases.items():
+ for alias in aliases:
+ alias2name[alias] = name
+
+ if mmcv.is_str(dataset):
+ if dataset in alias2name:
+ labels = eval(alias2name[dataset] + '_palette()')
+ else:
+ raise ValueError(f'Unrecognized dataset: {dataset}')
+ else:
+ raise TypeError(f'dataset must be a str, but got {type(dataset)}')
+ return labels
diff --git a/annotator/uniformer/mmseg/core/evaluation/eval_hooks.py b/annotator/uniformer/mmseg/core/evaluation/eval_hooks.py
new file mode 100644
index 0000000000000000000000000000000000000000..6fc100c8f96e817a6ed2666f7c9f762af2463b48
--- /dev/null
+++ b/annotator/uniformer/mmseg/core/evaluation/eval_hooks.py
@@ -0,0 +1,109 @@
+import os.path as osp
+
+from annotator.uniformer.mmcv.runner import DistEvalHook as _DistEvalHook
+from annotator.uniformer.mmcv.runner import EvalHook as _EvalHook
+
+
+class EvalHook(_EvalHook):
+ """Single GPU EvalHook, with efficient test support.
+
+ Args:
+ by_epoch (bool): Whether to perform evaluation by epoch or by
+ iteration. If set to True, it will be performed by epoch;
+ otherwise by iteration. Default: False.
+ efficient_test (bool): Whether to save the results as local numpy
+ files to save CPU memory during evaluation. Default: False.
+ """
+
+ greater_keys = ['mIoU', 'mAcc', 'aAcc']
+
+ def __init__(self, *args, by_epoch=False, efficient_test=False, **kwargs):
+ super().__init__(*args, by_epoch=by_epoch, **kwargs)
+ self.efficient_test = efficient_test
+
+ def after_train_iter(self, runner):
+ """After train epoch hook.
+
+ Override default ``single_gpu_test``.
+ """
+ if self.by_epoch or not self.every_n_iters(runner, self.interval):
+ return
+ from annotator.uniformer.mmseg.apis import single_gpu_test
+ runner.log_buffer.clear()
+ results = single_gpu_test(
+ runner.model,
+ self.dataloader,
+ show=False,
+ efficient_test=self.efficient_test)
+ self.evaluate(runner, results)
+
+ def after_train_epoch(self, runner):
+ """After train epoch hook.
+
+ Override default ``single_gpu_test``.
+ """
+ if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
+ return
+ from annotator.uniformer.mmseg.apis import single_gpu_test
+ runner.log_buffer.clear()
+ results = single_gpu_test(runner.model, self.dataloader, show=False)
+ self.evaluate(runner, results)
+
+
+class DistEvalHook(_DistEvalHook):
+ """Distributed EvalHook, with efficient test support.
+
+ Args:
+ by_epoch (bool): Whether to perform evaluation by epoch or by
+ iteration. If set to True, it will be performed by epoch;
+ otherwise by iteration. Default: False.
+ efficient_test (bool): Whether to save the results as local numpy
+ files to save CPU memory during evaluation. Default: False.
+ """
+
+ greater_keys = ['mIoU', 'mAcc', 'aAcc']
+
+ def __init__(self, *args, by_epoch=False, efficient_test=False, **kwargs):
+ super().__init__(*args, by_epoch=by_epoch, **kwargs)
+ self.efficient_test = efficient_test
+
+ def after_train_iter(self, runner):
+ """After train epoch hook.
+
+ Override default ``multi_gpu_test``.
+ """
+ if self.by_epoch or not self.every_n_iters(runner, self.interval):
+ return
+ from annotator.uniformer.mmseg.apis import multi_gpu_test
+ runner.log_buffer.clear()
+ results = multi_gpu_test(
+ runner.model,
+ self.dataloader,
+ tmpdir=osp.join(runner.work_dir, '.eval_hook'),
+ gpu_collect=self.gpu_collect,
+ efficient_test=self.efficient_test)
+ if runner.rank == 0:
+ print('\n')
+ self.evaluate(runner, results)
+
+ def after_train_epoch(self, runner):
+ """After train epoch hook.
+
+ Override default ``multi_gpu_test``.
+ """
+ if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
+ return
+ from annotator.uniformer.mmseg.apis import multi_gpu_test
+ runner.log_buffer.clear()
+ results = multi_gpu_test(
+ runner.model,
+ self.dataloader,
+ tmpdir=osp.join(runner.work_dir, '.eval_hook'),
+ gpu_collect=self.gpu_collect)
+ if runner.rank == 0:
+ print('\n')
+ self.evaluate(runner, results)
diff --git a/annotator/uniformer/mmseg/core/evaluation/metrics.py b/annotator/uniformer/mmseg/core/evaluation/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..16c7dd47cadd53cf1caaa194e28a343f2aacc599
--- /dev/null
+++ b/annotator/uniformer/mmseg/core/evaluation/metrics.py
@@ -0,0 +1,326 @@
+from collections import OrderedDict
+
+import annotator.uniformer.mmcv as mmcv
+import numpy as np
+import torch
+
+
+def f_score(precision, recall, beta=1):
+ """calcuate the f-score value.
+
+ Args:
+ precision (float | torch.Tensor): The precision value.
+ recall (float | torch.Tensor): The recall value.
+ beta (int): Determines the weight of recall in the combined score.
+ Default: 1.
+
+ Returns:
+ [torch.tensor]: The f-score value.
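+
+ Example (illustrative doctest, not part of the upstream docstring):
+ >>> f_score(0.5, 0.5, beta=1)  # (1 + 1) * 0.25 / (0.5 + 0.5)
+ 0.5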
+ """
+ score = (1 + beta**2) * (precision * recall) / (
+ (beta**2 * precision) + recall)
+ return score
+
+
+def intersect_and_union(pred_label,
+ label,
+ num_classes,
+ ignore_index,
+ label_map=dict(),
+ reduce_zero_label=False):
+ """Calculate intersection and Union.
+
+ Args:
+ pred_label (ndarray | str): Prediction segmentation map
+ or predict result filename.
+ label (ndarray | str): Ground truth segmentation map
+ or label filename.
+ num_classes (int): Number of categories.
+ ignore_index (int): Index that will be ignored in evaluation.
+ label_map (dict): Mapping old labels to new labels. The parameter will
+ work only when label is str. Default: dict().
+ reduce_zero_label (bool): Whether to ignore the zero label. The parameter will
+ work only when label is str. Default: False.
+
+ Returns:
+ torch.Tensor: The intersection of prediction and ground truth
+ histogram on all classes.
+ torch.Tensor: The union of prediction and ground truth histogram on
+ all classes.
+ torch.Tensor: The prediction histogram on all classes.
+ torch.Tensor: The ground truth histogram on all classes.
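+
+ Example (illustrative doctest with a 3-pixel map and 2 classes):
+ >>> pred = np.array([0, 1, 1])
+ >>> gt = np.array([0, 1, 0])
+ >>> inter, union, _, _ = intersect_and_union(pred, gt, 2, 255)
+ >>> inter.tolist(), union.tolist()
+ ([1.0, 1.0], [2.0, 2.0])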
+ """
+
+ if isinstance(pred_label, str):
+ pred_label = torch.from_numpy(np.load(pred_label))
+ else:
+ pred_label = torch.from_numpy((pred_label))
+
+ if isinstance(label, str):
+ label = torch.from_numpy(
+ mmcv.imread(label, flag='unchanged', backend='pillow'))
+ else:
+ label = torch.from_numpy(label)
+
+ if label_map is not None:
+ for old_id, new_id in label_map.items():
+ label[label == old_id] = new_id
+ if reduce_zero_label:
+ label[label == 0] = 255
+ label = label - 1
+ label[label == 254] = 255
+
+ mask = (label != ignore_index)
+ pred_label = pred_label[mask]
+ label = label[mask]
+
+ intersect = pred_label[pred_label == label]
+ area_intersect = torch.histc(
+ intersect.float(), bins=(num_classes), min=0, max=num_classes - 1)
+ area_pred_label = torch.histc(
+ pred_label.float(), bins=(num_classes), min=0, max=num_classes - 1)
+ area_label = torch.histc(
+ label.float(), bins=(num_classes), min=0, max=num_classes - 1)
+ area_union = area_pred_label + area_label - area_intersect
+ return area_intersect, area_union, area_pred_label, area_label
+
+
+def total_intersect_and_union(results,
+ gt_seg_maps,
+ num_classes,
+ ignore_index,
+ label_map=dict(),
+ reduce_zero_label=False):
+ """Calculate Total Intersection and Union.
+
+ Args:
+ results (list[ndarray] | list[str]): List of prediction segmentation
+ maps or list of prediction result filenames.
+ gt_seg_maps (list[ndarray] | list[str]): list of ground truth
+ segmentation maps or list of label filenames.
+ num_classes (int): Number of categories.
+ ignore_index (int): Index that will be ignored in evaluation.
+ label_map (dict): Mapping old labels to new labels. Default: dict().
+ reduce_zero_label (bool): Whether to ignore the zero label. Default: False.
+
+ Returns:
+ ndarray: The intersection of prediction and ground truth histogram
+ on all classes.
+ ndarray: The union of prediction and ground truth histogram on all
+ classes.
+ ndarray: The prediction histogram on all classes.
+ ndarray: The ground truth histogram on all classes.
+ """
+ num_imgs = len(results)
+ assert len(gt_seg_maps) == num_imgs
+ total_area_intersect = torch.zeros((num_classes, ), dtype=torch.float64)
+ total_area_union = torch.zeros((num_classes, ), dtype=torch.float64)
+ total_area_pred_label = torch.zeros((num_classes, ), dtype=torch.float64)
+ total_area_label = torch.zeros((num_classes, ), dtype=torch.float64)
+ for i in range(num_imgs):
+ area_intersect, area_union, area_pred_label, area_label = \
+ intersect_and_union(
+ results[i], gt_seg_maps[i], num_classes, ignore_index,
+ label_map, reduce_zero_label)
+ total_area_intersect += area_intersect
+ total_area_union += area_union
+ total_area_pred_label += area_pred_label
+ total_area_label += area_label
+ return total_area_intersect, total_area_union, total_area_pred_label, \
+ total_area_label
+
+
+def mean_iou(results,
+ gt_seg_maps,
+ num_classes,
+ ignore_index,
+ nan_to_num=None,
+ label_map=dict(),
+ reduce_zero_label=False):
+ """Calculate Mean Intersection and Union (mIoU)
+
+ Args:
+ results (list[ndarray] | list[str]): List of prediction segmentation
+ maps or list of prediction result filenames.
+ gt_seg_maps (list[ndarray] | list[str]): list of ground truth
+ segmentation maps or list of label filenames.
+ num_classes (int): Number of categories.
+ ignore_index (int): Index that will be ignored in evaluation.
+ nan_to_num (int, optional): If specified, NaN values will be replaced
+ by the numbers defined by the user. Default: None.
+ label_map (dict): Mapping old labels to new labels. Default: dict().
+ reduce_zero_label (bool): Whether to ignore the zero label. Default: False.
+
+ Returns:
+ dict[str, float | ndarray]:
+ float: Overall accuracy on all images.
+ ndarray: Per category accuracy, shape (num_classes, ).
+ ndarray: Per category IoU, shape (num_classes, ).
+ """
+ iou_result = eval_metrics(
+ results=results,
+ gt_seg_maps=gt_seg_maps,
+ num_classes=num_classes,
+ ignore_index=ignore_index,
+ metrics=['mIoU'],
+ nan_to_num=nan_to_num,
+ label_map=label_map,
+ reduce_zero_label=reduce_zero_label)
+ return iou_result
+
+
+def mean_dice(results,
+ gt_seg_maps,
+ num_classes,
+ ignore_index,
+ nan_to_num=None,
+ label_map=dict(),
+ reduce_zero_label=False):
+ """Calculate Mean Dice (mDice)
+
+ Args:
+ results (list[ndarray] | list[str]): List of prediction segmentation
+ maps or list of prediction result filenames.
+ gt_seg_maps (list[ndarray] | list[str]): list of ground truth
+ segmentation maps or list of label filenames.
+ num_classes (int): Number of categories.
+ ignore_index (int): Index that will be ignored in evaluation.
+ nan_to_num (int, optional): If specified, NaN values will be replaced
+ by the numbers defined by the user. Default: None.
+ label_map (dict): Mapping old labels to new labels. Default: dict().
+ reduce_zero_label (bool): Whether to ignore the zero label. Default: False.
+
+ Returns:
+ dict[str, float | ndarray]: Default metrics.
+ float: Overall accuracy on all images.
+ ndarray: Per category accuracy, shape (num_classes, ).
+ ndarray: Per category dice, shape (num_classes, ).
+ """
+
+ dice_result = eval_metrics(
+ results=results,
+ gt_seg_maps=gt_seg_maps,
+ num_classes=num_classes,
+ ignore_index=ignore_index,
+ metrics=['mDice'],
+ nan_to_num=nan_to_num,
+ label_map=label_map,
+ reduce_zero_label=reduce_zero_label)
+ return dice_result
+
+
+def mean_fscore(results,
+ gt_seg_maps,
+ num_classes,
+ ignore_index,
+ nan_to_num=None,
+ label_map=dict(),
+ reduce_zero_label=False,
+ beta=1):
+ """Calculate Mean Intersection and Union (mIoU)
+
+ Args:
+ results (list[ndarray] | list[str]): List of prediction segmentation
+ maps or list of prediction result filenames.
+ gt_seg_maps (list[ndarray] | list[str]): list of ground truth
+ segmentation maps or list of label filenames.
+ num_classes (int): Number of categories.
+ ignore_index (int): Index that will be ignored in evaluation.
+ nan_to_num (int, optional): If specified, NaN values will be replaced
+ by the numbers defined by the user. Default: None.
+ label_map (dict): Mapping old labels to new labels. Default: dict().
+ reduce_zero_label (bool): Whether to ignore the zero label. Default: False.
+ beta (int): Determines the weight of recall in the combined score.
+ Default: 1.
+
+ Returns:
+ dict[str, float | ndarray]: Default metrics.
+ float: Overall accuracy on all images.
+ ndarray: Per category recall, shape (num_classes, ).
+ ndarray: Per category precision, shape (num_classes, ).
+ ndarray: Per category f-score, shape (num_classes, ).
+ """
+ fscore_result = eval_metrics(
+ results=results,
+ gt_seg_maps=gt_seg_maps,
+ num_classes=num_classes,
+ ignore_index=ignore_index,
+ metrics=['mFscore'],
+ nan_to_num=nan_to_num,
+ label_map=label_map,
+ reduce_zero_label=reduce_zero_label,
+ beta=beta)
+ return fscore_result
+
+
+def eval_metrics(results,
+ gt_seg_maps,
+ num_classes,
+ ignore_index,
+ metrics=['mIoU'],
+ nan_to_num=None,
+ label_map=dict(),
+ reduce_zero_label=False,
+ beta=1):
+ """Calculate evaluation metrics
+ Args:
+ results (list[ndarray] | list[str]): List of prediction segmentation
+ maps or list of prediction result filenames.
+ gt_seg_maps (list[ndarray] | list[str]): list of ground truth
+ segmentation maps or list of label filenames.
+ num_classes (int): Number of categories.
+ ignore_index (int): Index that will be ignored in evaluation.
+ metrics (list[str] | str): Metrics to be evaluated. Options are
+ 'mIoU', 'mDice' and 'mFscore'.
+ nan_to_num (int, optional): If specified, NaN values will be replaced
+ by the numbers defined by the user. Default: None.
+ label_map (dict): Mapping old labels to new labels. Default: dict().
+ reduce_zero_label (bool): Whether to ignore the zero label. Default: False.
+ Returns:
+ float: Overall accuracy on all images.
+ ndarray: Per category accuracy, shape (num_classes, ).
+ ndarray: Per category evaluation metrics, shape (num_classes, ).
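+ Example (illustrative sketch; assumes ``results`` and ``gt_seg_maps``
+ are lists of HxW label arrays for a 19-class problem):
+ >>> ret = eval_metrics(results, gt_seg_maps, num_classes=19,
+ ...                    ignore_index=255, metrics=['mIoU'])
+ >>> sorted(ret.keys())
+ ['Acc', 'IoU', 'aAcc']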
+ """
+ if isinstance(metrics, str):
+ metrics = [metrics]
+ allowed_metrics = ['mIoU', 'mDice', 'mFscore']
+ if not set(metrics).issubset(set(allowed_metrics)):
+ raise KeyError('metrics {} is not supported'.format(metrics))
+
+ total_area_intersect, total_area_union, total_area_pred_label, \
+ total_area_label = total_intersect_and_union(
+ results, gt_seg_maps, num_classes, ignore_index, label_map,
+ reduce_zero_label)
+ all_acc = total_area_intersect.sum() / total_area_label.sum()
+ ret_metrics = OrderedDict({'aAcc': all_acc})
+ for metric in metrics:
+ if metric == 'mIoU':
+ iou = total_area_intersect / total_area_union
+ acc = total_area_intersect / total_area_label
+ ret_metrics['IoU'] = iou
+ ret_metrics['Acc'] = acc
+ elif metric == 'mDice':
+ dice = 2 * total_area_intersect / (
+ total_area_pred_label + total_area_label)
+ acc = total_area_intersect / total_area_label
+ ret_metrics['Dice'] = dice
+ ret_metrics['Acc'] = acc
+ elif metric == 'mFscore':
+ precision = total_area_intersect / total_area_pred_label
+ recall = total_area_intersect / total_area_label
+ f_value = torch.tensor(
+ [f_score(x[0], x[1], beta) for x in zip(precision, recall)])
+ ret_metrics['Fscore'] = f_value
+ ret_metrics['Precision'] = precision
+ ret_metrics['Recall'] = recall
+
+ ret_metrics = {
+ metric: value.numpy()
+ for metric, value in ret_metrics.items()
+ }
+ if nan_to_num is not None:
+ ret_metrics = OrderedDict({
+ metric: np.nan_to_num(metric_value, nan=nan_to_num)
+ for metric, metric_value in ret_metrics.items()
+ })
+ return ret_metrics
diff --git a/annotator/uniformer/mmseg/core/seg/__init__.py b/annotator/uniformer/mmseg/core/seg/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..93bc129b685e4a3efca2cc891729981b2865900d
--- /dev/null
+++ b/annotator/uniformer/mmseg/core/seg/__init__.py
@@ -0,0 +1,4 @@
+from .builder import build_pixel_sampler
+from .sampler import BasePixelSampler, OHEMPixelSampler
+
+__all__ = ['build_pixel_sampler', 'BasePixelSampler', 'OHEMPixelSampler']
diff --git a/annotator/uniformer/mmseg/core/seg/builder.py b/annotator/uniformer/mmseg/core/seg/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..db61f03d4abb2072f2532ce4429c0842495e015b
--- /dev/null
+++ b/annotator/uniformer/mmseg/core/seg/builder.py
@@ -0,0 +1,8 @@
+from annotator.uniformer.mmcv.utils import Registry, build_from_cfg
+
+PIXEL_SAMPLERS = Registry('pixel sampler')
+
+
+def build_pixel_sampler(cfg, **default_args):
+ """Build pixel sampler for segmentation map."""
+ return build_from_cfg(cfg, PIXEL_SAMPLERS, default_args)
diff --git a/annotator/uniformer/mmseg/core/seg/sampler/__init__.py b/annotator/uniformer/mmseg/core/seg/sampler/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..332b242c03d1c5e80d4577df442a9a037b1816e1
--- /dev/null
+++ b/annotator/uniformer/mmseg/core/seg/sampler/__init__.py
@@ -0,0 +1,4 @@
+from .base_pixel_sampler import BasePixelSampler
+from .ohem_pixel_sampler import OHEMPixelSampler
+
+__all__ = ['BasePixelSampler', 'OHEMPixelSampler']
diff --git a/annotator/uniformer/mmseg/core/seg/sampler/base_pixel_sampler.py b/annotator/uniformer/mmseg/core/seg/sampler/base_pixel_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..b75b1566c9f18169cee51d4b55d75e0357b69c57
--- /dev/null
+++ b/annotator/uniformer/mmseg/core/seg/sampler/base_pixel_sampler.py
@@ -0,0 +1,12 @@
+from abc import ABCMeta, abstractmethod
+
+
+class BasePixelSampler(metaclass=ABCMeta):
+ """Base class of pixel sampler."""
+
+ def __init__(self, **kwargs):
+ pass
+
+ @abstractmethod
+ def sample(self, seg_logit, seg_label):
+ """Placeholder for sample function."""
diff --git a/annotator/uniformer/mmseg/core/seg/sampler/ohem_pixel_sampler.py b/annotator/uniformer/mmseg/core/seg/sampler/ohem_pixel_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..88bb10d44026ba9f21756eaea9e550841cd59b9f
--- /dev/null
+++ b/annotator/uniformer/mmseg/core/seg/sampler/ohem_pixel_sampler.py
@@ -0,0 +1,76 @@
+import torch
+import torch.nn.functional as F
+
+from ..builder import PIXEL_SAMPLERS
+from .base_pixel_sampler import BasePixelSampler
+
+
+@PIXEL_SAMPLERS.register_module()
+class OHEMPixelSampler(BasePixelSampler):
+ """Online Hard Example Mining Sampler for segmentation.
+
+ Args:
+ context (nn.Module): The context of sampler, subclass of
+ :obj:`BaseDecodeHead`.
+ thresh (float, optional): The threshold for hard example selection.
+ Predictions below it are considered low confidence. If not
+ specified, the hard examples will be the pixels with the top
+ ``min_kept`` losses. Default: None.
+ min_kept (int, optional): The minimum number of predictions to keep.
+ Default: 100000.
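+
+ Example (illustrative config snippet; the ``sampler`` field of a decode
+ head config is an assumption here, not taken from this file):
+ sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=100000)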
+ """
+
+ def __init__(self, context, thresh=None, min_kept=100000):
+ super(OHEMPixelSampler, self).__init__()
+ self.context = context
+ assert min_kept > 1
+ self.thresh = thresh
+ self.min_kept = min_kept
+
+ def sample(self, seg_logit, seg_label):
+ """Sample pixels that have high loss or with low prediction confidence.
+
+ Args:
+ seg_logit (torch.Tensor): segmentation logits, shape (N, C, H, W)
+ seg_label (torch.Tensor): segmentation label, shape (N, 1, H, W)
+
+ Returns:
+ torch.Tensor: segmentation weight, shape (N, H, W)
+ """
+ with torch.no_grad():
+ assert seg_logit.shape[2:] == seg_label.shape[2:]
+ assert seg_label.shape[1] == 1
+ seg_label = seg_label.squeeze(1).long()
+ batch_kept = self.min_kept * seg_label.size(0)
+ valid_mask = seg_label != self.context.ignore_index
+ seg_weight = seg_logit.new_zeros(size=seg_label.size())
+ valid_seg_weight = seg_weight[valid_mask]
+ if self.thresh is not None:
+ seg_prob = F.softmax(seg_logit, dim=1)
+
+ tmp_seg_label = seg_label.clone().unsqueeze(1)
+ tmp_seg_label[tmp_seg_label == self.context.ignore_index] = 0
+ seg_prob = seg_prob.gather(1, tmp_seg_label).squeeze(1)
+ sort_prob, sort_indices = seg_prob[valid_mask].sort()
+
+ if sort_prob.numel() > 0:
+ min_threshold = sort_prob[min(batch_kept,
+ sort_prob.numel() - 1)]
+ else:
+ min_threshold = 0.0
+ threshold = max(min_threshold, self.thresh)
+ valid_seg_weight[seg_prob[valid_mask] < threshold] = 1.
+ else:
+ losses = self.context.loss_decode(
+ seg_logit,
+ seg_label,
+ weight=None,
+ ignore_index=self.context.ignore_index,
+ reduction_override='none')
+ # faster than topk according to https://github.com/pytorch/pytorch/issues/22812 # noqa
+ _, sort_indices = losses[valid_mask].sort(descending=True)
+ valid_seg_weight[sort_indices[:batch_kept]] = 1.
+
+ seg_weight[valid_mask] = valid_seg_weight
+
+ return seg_weight
diff --git a/annotator/uniformer/mmseg/core/utils/__init__.py b/annotator/uniformer/mmseg/core/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2678b321c295bcceaef945111ac3524be19d6e4
--- /dev/null
+++ b/annotator/uniformer/mmseg/core/utils/__init__.py
@@ -0,0 +1,3 @@
+from .misc import add_prefix
+
+__all__ = ['add_prefix']
diff --git a/annotator/uniformer/mmseg/core/utils/misc.py b/annotator/uniformer/mmseg/core/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb862a82bd47c8624db3dd5c6fb6ad8a03b62466
--- /dev/null
+++ b/annotator/uniformer/mmseg/core/utils/misc.py
@@ -0,0 +1,17 @@
+def add_prefix(inputs, prefix):
+ """Add prefix for dict.
+
+ Args:
+ inputs (dict): The input dict with str keys.
+ prefix (str): The prefix to add.
+
+ Returns:
+
+ dict: The dict with keys updated with ``prefix``.
+ """
+
+ outputs = dict()
+ for name, value in inputs.items():
+ outputs[f'{prefix}.{name}'] = value
+
+ return outputs
diff --git a/annotator/uniformer/mmseg/datasets/__init__.py b/annotator/uniformer/mmseg/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebeaef4a28ef655e43578552a8aef6b77f13a636
--- /dev/null
+++ b/annotator/uniformer/mmseg/datasets/__init__.py
@@ -0,0 +1,19 @@
+from .ade import ADE20KDataset
+from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset
+from .chase_db1 import ChaseDB1Dataset
+from .cityscapes import CityscapesDataset
+from .custom import CustomDataset
+from .dataset_wrappers import ConcatDataset, RepeatDataset
+from .drive import DRIVEDataset
+from .hrf import HRFDataset
+from .pascal_context import PascalContextDataset, PascalContextDataset59
+from .stare import STAREDataset
+from .voc import PascalVOCDataset
+
+__all__ = [
+ 'CustomDataset', 'build_dataloader', 'ConcatDataset', 'RepeatDataset',
+ 'DATASETS', 'build_dataset', 'PIPELINES', 'CityscapesDataset',
+ 'PascalVOCDataset', 'ADE20KDataset', 'PascalContextDataset',
+ 'PascalContextDataset59', 'ChaseDB1Dataset', 'DRIVEDataset', 'HRFDataset',
+ 'STAREDataset'
+]
diff --git a/annotator/uniformer/mmseg/datasets/ade.py b/annotator/uniformer/mmseg/datasets/ade.py
new file mode 100644
index 0000000000000000000000000000000000000000..5913e43775ed4920b6934c855eb5a37c54218ebf
--- /dev/null
+++ b/annotator/uniformer/mmseg/datasets/ade.py
@@ -0,0 +1,84 @@
+from .builder import DATASETS
+from .custom import CustomDataset
+
+
+@DATASETS.register_module()
+class ADE20KDataset(CustomDataset):
+ """ADE20K dataset.
+
+ In segmentation map annotation for ADE20K, 0 stands for background, which
+ is not included in 150 categories. ``reduce_zero_label`` is fixed to True.
+ The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is fixed to
+ '.png'.
+ """
+ CLASSES = (
+ 'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ',
+ 'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth',
+ 'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car',
+ 'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug',
+ 'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe',
+ 'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column',
+ 'signboard', 'chest of drawers', 'counter', 'sand', 'sink',
+ 'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path',
+ 'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door',
+ 'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table',
+ 'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove',
+ 'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar',
+ 'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower',
+ 'chandelier', 'awning', 'streetlight', 'booth', 'television receiver',
+ 'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister',
+ 'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van',
+ 'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything',
+ 'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent',
+ 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank',
+ 'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake',
+ 'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce',
+ 'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen',
+ 'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass',
+ 'clock', 'flag')
+
+ PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
+ [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
+ [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
+ [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
+ [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
+ [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
+ [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
+ [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
+ [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
+ [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
+ [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
+ [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
+ [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
+ [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
+ [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
+ [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
+ [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
+ [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
+ [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
+ [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
+ [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
+ [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
+ [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
+ [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
+ [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
+ [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
+ [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
+ [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
+ [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
+ [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
+ [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
+ [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
+ [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
+ [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
+ [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
+ [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
+ [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
+ [102, 255, 0], [92, 0, 255]]
+
+ def __init__(self, **kwargs):
+ super(ADE20KDataset, self).__init__(
+ img_suffix='.jpg',
+ seg_map_suffix='.png',
+ reduce_zero_label=True,
+ **kwargs)
diff --git a/annotator/uniformer/mmseg/datasets/builder.py b/annotator/uniformer/mmseg/datasets/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..0798b14cd8b39fc58d8f2a4930f1e079b5bf8b55
--- /dev/null
+++ b/annotator/uniformer/mmseg/datasets/builder.py
@@ -0,0 +1,169 @@
+import copy
+import platform
+import random
+from functools import partial
+
+import numpy as np
+from annotator.uniformer.mmcv.parallel import collate
+from annotator.uniformer.mmcv.runner import get_dist_info
+from annotator.uniformer.mmcv.utils import Registry, build_from_cfg
+from annotator.uniformer.mmcv.utils.parrots_wrapper import DataLoader, PoolDataLoader
+from torch.utils.data import DistributedSampler
+
+if platform.system() != 'Windows':
+ # https://github.com/pytorch/pytorch/issues/973
+ import resource
+ rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
+ hard_limit = rlimit[1]
+ soft_limit = min(4096, hard_limit)
+ resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit))
+
+DATASETS = Registry('dataset')
+PIPELINES = Registry('pipeline')
+
+
+def _concat_dataset(cfg, default_args=None):
+ """Build :obj:`ConcatDataset by."""
+ from .dataset_wrappers import ConcatDataset
+ img_dir = cfg['img_dir']
+ ann_dir = cfg.get('ann_dir', None)
+ split = cfg.get('split', None)
+ num_img_dir = len(img_dir) if isinstance(img_dir, (list, tuple)) else 1
+ if ann_dir is not None:
+ num_ann_dir = len(ann_dir) if isinstance(ann_dir, (list, tuple)) else 1
+ else:
+ num_ann_dir = 0
+ if split is not None:
+ num_split = len(split) if isinstance(split, (list, tuple)) else 1
+ else:
+ num_split = 0
+ if num_img_dir > 1:
+ assert num_img_dir == num_ann_dir or num_ann_dir == 0
+ assert num_img_dir == num_split or num_split == 0
+ else:
+ assert num_split == num_ann_dir or num_ann_dir <= 1
+ num_dset = max(num_split, num_img_dir)
+
+ datasets = []
+ for i in range(num_dset):
+ data_cfg = copy.deepcopy(cfg)
+ if isinstance(img_dir, (list, tuple)):
+ data_cfg['img_dir'] = img_dir[i]
+ if isinstance(ann_dir, (list, tuple)):
+ data_cfg['ann_dir'] = ann_dir[i]
+ if isinstance(split, (list, tuple)):
+ data_cfg['split'] = split[i]
+ datasets.append(build_dataset(data_cfg, default_args))
+
+ return ConcatDataset(datasets)
+
+
+def build_dataset(cfg, default_args=None):
+ """Build datasets."""
+ from .dataset_wrappers import ConcatDataset, RepeatDataset
+ if isinstance(cfg, (list, tuple)):
+ dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg])
+ elif cfg['type'] == 'RepeatDataset':
+ dataset = RepeatDataset(
+ build_dataset(cfg['dataset'], default_args), cfg['times'])
+ elif isinstance(cfg.get('img_dir'), (list, tuple)) or isinstance(
+ cfg.get('split', None), (list, tuple)):
+ dataset = _concat_dataset(cfg, default_args)
+ else:
+ dataset = build_from_cfg(cfg, DATASETS, default_args)
+
+ return dataset
+
+
+def build_dataloader(dataset,
+ samples_per_gpu,
+ workers_per_gpu,
+ num_gpus=1,
+ dist=True,
+ shuffle=True,
+ seed=None,
+ drop_last=False,
+ pin_memory=True,
+ dataloader_type='PoolDataLoader',
+ **kwargs):
+ """Build PyTorch DataLoader.
+
+ In distributed training, each GPU/process has a dataloader.
+ In non-distributed training, there is only one dataloader for all GPUs.
+
+ Args:
+ dataset (Dataset): A PyTorch dataset.
+ samples_per_gpu (int): Number of training samples on each GPU, i.e.,
+ batch size of each GPU.
+ workers_per_gpu (int): How many subprocesses to use for data loading
+ for each GPU.
+ num_gpus (int): Number of GPUs. Only used in non-distributed training.
+ dist (bool): Distributed training/test or not. Default: True.
+ shuffle (bool): Whether to shuffle the data at every epoch.
+ Default: True.
+ seed (int | None): Seed to be used. Default: None.
+ drop_last (bool): Whether to drop the last incomplete batch in epoch.
+ Default: False
+ pin_memory (bool): Whether to use pin_memory in DataLoader.
+ Default: True
+ dataloader_type (str): Type of dataloader. Default: 'PoolDataLoader'
+ kwargs: any keyword argument to be used to initialize DataLoader
+
+ Returns:
+ DataLoader: A PyTorch dataloader.
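+
+ Example (illustrative, non-distributed sketch; assumes ``dataset`` was
+ created with ``build_dataset``):
+ >>> loader = build_dataloader(dataset, samples_per_gpu=2,
+ ...                           workers_per_gpu=2, num_gpus=1,
+ ...                           dist=False, shuffle=True)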
+ """
+ rank, world_size = get_dist_info()
+ if dist:
+ sampler = DistributedSampler(
+ dataset, world_size, rank, shuffle=shuffle)
+ shuffle = False
+ batch_size = samples_per_gpu
+ num_workers = workers_per_gpu
+ else:
+ sampler = None
+ batch_size = num_gpus * samples_per_gpu
+ num_workers = num_gpus * workers_per_gpu
+
+ init_fn = partial(
+ worker_init_fn, num_workers=num_workers, rank=rank,
+ seed=seed) if seed is not None else None
+
+ assert dataloader_type in (
+ 'DataLoader',
+ 'PoolDataLoader'), f'unsupported dataloader {dataloader_type}'
+
+ if dataloader_type == 'PoolDataLoader':
+ dataloader = PoolDataLoader
+ elif dataloader_type == 'DataLoader':
+ dataloader = DataLoader
+
+ data_loader = dataloader(
+ dataset,
+ batch_size=batch_size,
+ sampler=sampler,
+ num_workers=num_workers,
+ collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
+ pin_memory=pin_memory,
+ shuffle=shuffle,
+ worker_init_fn=init_fn,
+ drop_last=drop_last,
+ **kwargs)
+
+ return data_loader
+
+
+def worker_init_fn(worker_id, num_workers, rank, seed):
+ """Worker init func for dataloader.
+
+ The seed of each worker equals ``num_workers * rank + worker_id + seed``.
+
+ Args:
+ worker_id (int): Worker id.
+ num_workers (int): Number of workers.
+ rank (int): The rank of current process.
+ seed (int): The random seed to use.
+ """
+
+ worker_seed = num_workers * rank + worker_id + seed
+ np.random.seed(worker_seed)
+ random.seed(worker_seed)
diff --git a/annotator/uniformer/mmseg/datasets/chase_db1.py b/annotator/uniformer/mmseg/datasets/chase_db1.py
new file mode 100644
index 0000000000000000000000000000000000000000..8bc29bea14704a4407f83474610cbc3bef32c708
--- /dev/null
+++ b/annotator/uniformer/mmseg/datasets/chase_db1.py
@@ -0,0 +1,27 @@
+import os.path as osp
+
+from .builder import DATASETS
+from .custom import CustomDataset
+
+
+@DATASETS.register_module()
+class ChaseDB1Dataset(CustomDataset):
+ """Chase_db1 dataset.
+
+ In segmentation map annotation for Chase_db1, 0 stands for background,
+ which is included in 2 categories. ``reduce_zero_label`` is fixed to False.
+ The ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to
+ '_1stHO.png'.
+ """
+
+ CLASSES = ('background', 'vessel')
+
+ PALETTE = [[120, 120, 120], [6, 230, 230]]
+
+ def __init__(self, **kwargs):
+ super(ChaseDB1Dataset, self).__init__(
+ img_suffix='.png',
+ seg_map_suffix='_1stHO.png',
+ reduce_zero_label=False,
+ **kwargs)
+ assert osp.exists(self.img_dir)
diff --git a/annotator/uniformer/mmseg/datasets/cityscapes.py b/annotator/uniformer/mmseg/datasets/cityscapes.py
new file mode 100644
index 0000000000000000000000000000000000000000..81e47a914a1aa2e5458e18669d65ffb742f46fc6
--- /dev/null
+++ b/annotator/uniformer/mmseg/datasets/cityscapes.py
@@ -0,0 +1,217 @@
+import os.path as osp
+import tempfile
+
+import annotator.uniformer.mmcv as mmcv
+import numpy as np
+from annotator.uniformer.mmcv.utils import print_log
+from PIL import Image
+
+from .builder import DATASETS
+from .custom import CustomDataset
+
+
+@DATASETS.register_module()
+class CityscapesDataset(CustomDataset):
+ """Cityscapes dataset.
+
+ The ``img_suffix`` is fixed to '_leftImg8bit.png' and ``seg_map_suffix`` is
+ fixed to '_gtFine_labelTrainIds.png' for Cityscapes dataset.
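+
+ Example (illustrative config snippet; the directory names are
+ placeholder assumptions, not taken from this file):
+ dict(type='CityscapesDataset',
+ data_root='data/cityscapes',
+ img_dir='leftImg8bit/train',
+ ann_dir='gtFine/train',
+ pipeline=train_pipeline)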
+ """
+
+ CLASSES = ('road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
+ 'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
+ 'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
+ 'bicycle')
+
+ PALETTE = [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
+ [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0],
+ [107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60],
+ [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100],
+ [0, 80, 100], [0, 0, 230], [119, 11, 32]]
+
+ def __init__(self, **kwargs):
+ super(CityscapesDataset, self).__init__(
+ img_suffix='_leftImg8bit.png',
+ seg_map_suffix='_gtFine_labelTrainIds.png',
+ **kwargs)
+
+ @staticmethod
+ def _convert_to_label_id(result):
+ """Convert trainId to id for cityscapes."""
+ if isinstance(result, str):
+ result = np.load(result)
+ import cityscapesscripts.helpers.labels as CSLabels
+ result_copy = result.copy()
+ for trainId, label in CSLabels.trainId2label.items():
+ result_copy[result == trainId] = label.id
+
+ return result_copy
+
+ def results2img(self, results, imgfile_prefix, to_label_id):
+ """Write the segmentation results to images.
+
+ Args:
+ results (list[list | tuple | ndarray]): Testing results of the
+ dataset.
+ imgfile_prefix (str): The filename prefix of the png files.
+ If the prefix is "somepath/xxx",
+ the png files will be named "somepath/xxx.png".
+ to_label_id (bool): Whether to convert output to label_id for
+ submission.
+
+ Returns:
+ list[str]: Paths of the saved png files that contain the
+ corresponding semantic segmentation results.
+ """
+ mmcv.mkdir_or_exist(imgfile_prefix)
+ result_files = []
+ prog_bar = mmcv.ProgressBar(len(self))
+ for idx in range(len(self)):
+ result = results[idx]
+ if to_label_id:
+ result = self._convert_to_label_id(result)
+ filename = self.img_infos[idx]['filename']
+ basename = osp.splitext(osp.basename(filename))[0]
+
+ png_filename = osp.join(imgfile_prefix, f'{basename}.png')
+
+ output = Image.fromarray(result.astype(np.uint8)).convert('P')
+ import cityscapesscripts.helpers.labels as CSLabels
+ palette = np.zeros((len(CSLabels.id2label), 3), dtype=np.uint8)
+ for label_id, label in CSLabels.id2label.items():
+ palette[label_id] = label.color
+
+ output.putpalette(palette)
+ output.save(png_filename)
+ result_files.append(png_filename)
+ prog_bar.update()
+
+ return result_files
+
+ def format_results(self, results, imgfile_prefix=None, to_label_id=True):
+ """Format the results into dir (standard format for Cityscapes
+ evaluation).
+
+ Args:
+ results (list): Testing results of the dataset.
+ imgfile_prefix (str | None): The prefix of images files. It
+ includes the file path and the prefix of filename, e.g.,
+ "a/b/prefix". If not specified, a temp file will be created.
+ Default: None.
+ to_label_id (bool): Whether to convert output to label_id for
+ submission. Default: True.
+
+ Returns:
+ tuple: (result_files, tmp_dir), result_files is a list containing
+ the image paths, tmp_dir is the temporary directory created
+ for saving json/png files when img_prefix is not specified.
+ """
+
+ assert isinstance(results, list), 'results must be a list'
+ assert len(results) == len(self), (
+ 'The length of results is not equal to the dataset len: '
+ f'{len(results)} != {len(self)}')
+
+ if imgfile_prefix is None:
+ tmp_dir = tempfile.TemporaryDirectory()
+ imgfile_prefix = tmp_dir.name
+ else:
+ tmp_dir = None
+ result_files = self.results2img(results, imgfile_prefix, to_label_id)
+
+ return result_files, tmp_dir
+
+ def evaluate(self,
+ results,
+ metric='mIoU',
+ logger=None,
+ imgfile_prefix=None,
+ efficient_test=False):
+ """Evaluation in Cityscapes/default protocol.
+
+ Args:
+ results (list): Testing results of the dataset.
+ metric (str | list[str]): Metrics to be evaluated.
+ logger (logging.Logger | None | str): Logger used for printing
+ related information during evaluation. Default: None.
+ imgfile_prefix (str | None): The prefix of output image file,
+ for cityscapes evaluation only. It includes the file path and
+ the prefix of filename, e.g., "a/b/prefix".
+ If results are evaluated with cityscapes protocol, it would be
+ the prefix of output png files. The output files would be
+ png images under folder "a/b/prefix/xxx.png", where "xxx" is
+ the image name of cityscapes. If not specified, a temp file
+ will be created for evaluation.
+ Default: None.
+
+ Returns:
+ dict[str, float]: Cityscapes/default metrics.
+ """
+
+ eval_results = dict()
+ metrics = metric.copy() if isinstance(metric, list) else [metric]
+ if 'cityscapes' in metrics:
+ eval_results.update(
+ self._evaluate_cityscapes(results, logger, imgfile_prefix))
+ metrics.remove('cityscapes')
+ if len(metrics) > 0:
+ eval_results.update(
+ super(CityscapesDataset,
+ self).evaluate(results, metrics, logger, efficient_test))
+
+ return eval_results
+
+ def _evaluate_cityscapes(self, results, logger, imgfile_prefix):
+ """Evaluation in Cityscapes protocol.
+
+ Args:
+ results (list): Testing results of the dataset.
+ logger (logging.Logger | str | None): Logger used for printing
+ related information during evaluation. Default: None.
+ imgfile_prefix (str | None): The prefix of output image file
+
+ Returns:
+ dict[str: float]: Cityscapes evaluation results.
+ """
+ try:
+ import cityscapesscripts.evaluation.evalPixelLevelSemanticLabeling as CSEval # noqa
+ except ImportError:
+ raise ImportError('Please run "pip install cityscapesscripts" to '
+ 'install cityscapesscripts first.')
+ msg = 'Evaluating in Cityscapes style'
+ if logger is None:
+ msg = '\n' + msg
+ print_log(msg, logger=logger)
+
+ result_files, tmp_dir = self.format_results(results, imgfile_prefix)
+
+ if tmp_dir is None:
+ result_dir = imgfile_prefix
+ else:
+ result_dir = tmp_dir.name
+
+ eval_results = dict()
+ print_log(f'Evaluating results under {result_dir} ...', logger=logger)
+
+ CSEval.args.evalInstLevelScore = True
+ CSEval.args.predictionPath = osp.abspath(result_dir)
+ CSEval.args.evalPixelAccuracy = True
+ CSEval.args.JSONOutput = False
+
+ seg_map_list = []
+ pred_list = []
+
+ # when evaluating with official cityscapesscripts,
+ # **_gtFine_labelIds.png is used
+ for seg_map in mmcv.scandir(
+ self.ann_dir, 'gtFine_labelIds.png', recursive=True):
+ seg_map_list.append(osp.join(self.ann_dir, seg_map))
+ pred_list.append(CSEval.getPrediction(CSEval.args, seg_map))
+
+ eval_results.update(
+ CSEval.evaluateImgLists(pred_list, seg_map_list, CSEval.args))
+
+ if tmp_dir is not None:
+ tmp_dir.cleanup()
+
+ return eval_results
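+
+# Usage sketch (illustrative only; assumes a built CityscapesDataset
+# `dataset` and per-image label-id results `results`):
+#
+#     metrics = dataset.evaluate(results, metric=['mIoU', 'cityscapes'])
+#
+# 'cityscapes' is handled by _evaluate_cityscapes via the official
+# cityscapesscripts package, while the remaining metrics fall back to
+# CustomDataset.evaluate.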
diff --git a/annotator/uniformer/mmseg/datasets/custom.py b/annotator/uniformer/mmseg/datasets/custom.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8eb2a709cc7a3a68fc6a1e3a1ad98faef4c5b7b
--- /dev/null
+++ b/annotator/uniformer/mmseg/datasets/custom.py
@@ -0,0 +1,400 @@
+import os
+import os.path as osp
+from collections import OrderedDict
+from functools import reduce
+
+import annotator.uniformer.mmcv as mmcv
+import numpy as np
+from annotator.uniformer.mmcv.utils import print_log
+from prettytable import PrettyTable
+from torch.utils.data import Dataset
+
+from annotator.uniformer.mmseg.core import eval_metrics
+from annotator.uniformer.mmseg.utils import get_root_logger
+from .builder import DATASETS
+from .pipelines import Compose
+
+
+@DATASETS.register_module()
+class CustomDataset(Dataset):
+ """Custom dataset for semantic segmentation. An example of file structure
+ is as follows.
+
+ .. code-block:: none
+
+ ├── data
+ │ ├── my_dataset
+ │ │ ├── img_dir
+ │ │ │ ├── train
+ │ │ │ │ ├── xxx{img_suffix}
+ │ │ │ │ ├── yyy{img_suffix}
+ │ │ │ │ ├── zzz{img_suffix}
+ │ │ │ ├── val
+ │ │ ├── ann_dir
+ │ │ │ ├── train
+ │ │ │ │ ├── xxx{seg_map_suffix}
+ │ │ │ │ ├── yyy{seg_map_suffix}
+ │ │ │ │ ├── zzz{seg_map_suffix}
+ │ │ │ ├── val
+
+ The img/gt_semantic_seg pair of CustomDataset should share the same name
+ except for the suffix. A valid img/gt_semantic_seg filename pair should be
+ like ``xxx{img_suffix}`` and ``xxx{seg_map_suffix}`` (the extension is also
+ included in the suffix). If split is given, then ``xxx`` is specified in
+ the txt file. Otherwise, all files in ``img_dir/`` and ``ann_dir`` will be loaded.
+ Please refer to ``docs/tutorials/new_dataset.md`` for more details.
+
+
+ Args:
+ pipeline (list[dict]): Processing pipeline
+ img_dir (str): Path to image directory
+ img_suffix (str): Suffix of images. Default: '.jpg'
+ ann_dir (str, optional): Path to annotation directory. Default: None
+ seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
+ split (str, optional): Split txt file. If split is specified, only
+ file with suffix in the splits will be loaded. Otherwise, all
+ images in img_dir/ann_dir will be loaded. Default: None
+ data_root (str, optional): Data root for img_dir/ann_dir. Default:
+ None.
+ test_mode (bool): If test_mode=True, gt wouldn't be loaded.
+ ignore_index (int): The label index to be ignored. Default: 255
+ reduce_zero_label (bool): Whether to mark label zero as ignored.
+ Default: False
+ classes (str | Sequence[str], optional): Specify classes to load.
+ If is None, ``cls.CLASSES`` will be used. Default: None.
+ palette (Sequence[Sequence[int]]] | np.ndarray | None):
+ The palette of segmentation map. If None is given, and
+ self.PALETTE is None, random palette will be generated.
+ Default: None
+ """
+
+ CLASSES = None
+
+ PALETTE = None
+
+ def __init__(self,
+ pipeline,
+ img_dir,
+ img_suffix='.jpg',
+ ann_dir=None,
+ seg_map_suffix='.png',
+ split=None,
+ data_root=None,
+ test_mode=False,
+ ignore_index=255,
+ reduce_zero_label=False,
+ classes=None,
+ palette=None):
+ self.pipeline = Compose(pipeline)
+ self.img_dir = img_dir
+ self.img_suffix = img_suffix
+ self.ann_dir = ann_dir
+ self.seg_map_suffix = seg_map_suffix
+ self.split = split
+ self.data_root = data_root
+ self.test_mode = test_mode
+ self.ignore_index = ignore_index
+ self.reduce_zero_label = reduce_zero_label
+ self.label_map = None
+ self.CLASSES, self.PALETTE = self.get_classes_and_palette(
+ classes, palette)
+
+ # join paths if data_root is specified
+ if self.data_root is not None:
+ if not osp.isabs(self.img_dir):
+ self.img_dir = osp.join(self.data_root, self.img_dir)
+ if not (self.ann_dir is None or osp.isabs(self.ann_dir)):
+ self.ann_dir = osp.join(self.data_root, self.ann_dir)
+ if not (self.split is None or osp.isabs(self.split)):
+ self.split = osp.join(self.data_root, self.split)
+
+ # load annotations
+ self.img_infos = self.load_annotations(self.img_dir, self.img_suffix,
+ self.ann_dir,
+ self.seg_map_suffix, self.split)
+
+ def __len__(self):
+ """Total number of samples of data."""
+ return len(self.img_infos)
+
+ def load_annotations(self, img_dir, img_suffix, ann_dir, seg_map_suffix,
+ split):
+ """Load annotation from directory.
+
+ Args:
+ img_dir (str): Path to image directory
+ img_suffix (str): Suffix of images.
+ ann_dir (str|None): Path to annotation directory.
+ seg_map_suffix (str|None): Suffix of segmentation maps.
+ split (str|None): Split txt file. If split is specified, only file
+ with suffix in the splits will be loaded. Otherwise, all images
+ in img_dir/ann_dir will be loaded. Default: None
+
+ Returns:
+ list[dict]: All image info of dataset.
+ """
+
+ img_infos = []
+ if split is not None:
+ with open(split) as f:
+ for line in f:
+ img_name = line.strip()
+ img_info = dict(filename=img_name + img_suffix)
+ if ann_dir is not None:
+ seg_map = img_name + seg_map_suffix
+ img_info['ann'] = dict(seg_map=seg_map)
+ img_infos.append(img_info)
+ else:
+ for img in mmcv.scandir(img_dir, img_suffix, recursive=True):
+ img_info = dict(filename=img)
+ if ann_dir is not None:
+ seg_map = img.replace(img_suffix, seg_map_suffix)
+ img_info['ann'] = dict(seg_map=seg_map)
+ img_infos.append(img_info)
+
+ print_log(f'Loaded {len(img_infos)} images', logger=get_root_logger())
+ return img_infos
+
+ def get_ann_info(self, idx):
+ """Get annotation by index.
+
+ Args:
+ idx (int): Index of data.
+
+ Returns:
+ dict: Annotation info of specified index.
+ """
+
+ return self.img_infos[idx]['ann']
+
+ def pre_pipeline(self, results):
+ """Prepare results dict for pipeline."""
+ results['seg_fields'] = []
+ results['img_prefix'] = self.img_dir
+ results['seg_prefix'] = self.ann_dir
+ if self.custom_classes:
+ results['label_map'] = self.label_map
+
+ def __getitem__(self, idx):
+ """Get training/test data after pipeline.
+
+ Args:
+ idx (int): Index of data.
+
+ Returns:
+ dict: Training/test data (with annotation if `test_mode` is set
+ False).
+ """
+
+ if self.test_mode:
+ return self.prepare_test_img(idx)
+ else:
+ return self.prepare_train_img(idx)
+
+ def prepare_train_img(self, idx):
+ """Get training data and annotations after pipeline.
+
+ Args:
+ idx (int): Index of data.
+
+ Returns:
+ dict: Training data and annotation after pipeline with new keys
+ introduced by pipeline.
+ """
+
+ img_info = self.img_infos[idx]
+ ann_info = self.get_ann_info(idx)
+ results = dict(img_info=img_info, ann_info=ann_info)
+ self.pre_pipeline(results)
+ return self.pipeline(results)
+
+ def prepare_test_img(self, idx):
+ """Get testing data after pipeline.
+
+ Args:
+ idx (int): Index of data.
+
+ Returns:
+ dict: Testing data after pipeline with new keys introduced by
+ pipeline.
+ """
+
+ img_info = self.img_infos[idx]
+ results = dict(img_info=img_info)
+ self.pre_pipeline(results)
+ return self.pipeline(results)
+
+ def format_results(self, results, **kwargs):
+ """Place holder to format result to dataset specific output."""
+
+ def get_gt_seg_maps(self, efficient_test=False):
+ """Get ground truth segmentation maps for evaluation."""
+ gt_seg_maps = []
+ for img_info in self.img_infos:
+ seg_map = osp.join(self.ann_dir, img_info['ann']['seg_map'])
+ if efficient_test:
+ gt_seg_map = seg_map
+ else:
+ gt_seg_map = mmcv.imread(
+ seg_map, flag='unchanged', backend='pillow')
+ gt_seg_maps.append(gt_seg_map)
+ return gt_seg_maps
+
+ def get_classes_and_palette(self, classes=None, palette=None):
+ """Get class names of current dataset.
+
+ Args:
+ classes (Sequence[str] | str | None): If classes is None, use
+ default CLASSES defined by builtin dataset. If classes is a
+ string, take it as a file name. The file contains the name of
+ classes where each line contains one class name. If classes is
+ a tuple or list, override the CLASSES defined by the dataset.
+ palette (Sequence[Sequence[int]]] | np.ndarray | None):
+ The palette of segmentation map. If None is given, random
+ palette will be generated. Default: None
+ """
+ if classes is None:
+ self.custom_classes = False
+ return self.CLASSES, self.PALETTE
+
+ self.custom_classes = True
+ if isinstance(classes, str):
+ # take it as a file path
+ class_names = mmcv.list_from_file(classes)
+ elif isinstance(classes, (tuple, list)):
+ class_names = classes
+ else:
+ raise ValueError(f'Unsupported type {type(classes)} of classes.')
+
+ if self.CLASSES:
+ if not set(class_names).issubset(self.CLASSES):
+ raise ValueError('classes is not a subset of CLASSES.')
+
+ # dictionary, its keys are the old label ids and its values
+ # are the new label ids.
+ # used for changing pixel labels in load_annotations.
+ self.label_map = {}
+ for i, c in enumerate(self.CLASSES):
+ if c not in class_names:
+ self.label_map[i] = -1
+ else:
+ self.label_map[i] = class_names.index(c)
+
+ palette = self.get_palette_for_custom_classes(class_names, palette)
+
+ return class_names, palette
+
+ def get_palette_for_custom_classes(self, class_names, palette=None):
+
+ if self.label_map is not None:
+ # return subset of palette
+ palette = []
+ for old_id, new_id in sorted(
+ self.label_map.items(), key=lambda x: x[1]):
+ if new_id != -1:
+ palette.append(self.PALETTE[old_id])
+ palette = type(self.PALETTE)(palette)
+
+ elif palette is None:
+ if self.PALETTE is None:
+ palette = np.random.randint(0, 255, size=(len(class_names), 3))
+ else:
+ palette = self.PALETTE
+
+ return palette
+
+ def evaluate(self,
+ results,
+ metric='mIoU',
+ logger=None,
+ efficient_test=False,
+ **kwargs):
+ """Evaluate the dataset.
+
+ Args:
+ results (list): Testing results of the dataset.
+ metric (str | list[str]): Metrics to be evaluated. 'mIoU',
+ 'mDice' and 'mFscore' are supported.
+ logger (logging.Logger | None | str): Logger used for printing
+ related information during evaluation. Default: None.
+
+ Returns:
+ dict[str, float]: Default metrics.
+ """
+
+ if isinstance(metric, str):
+ metric = [metric]
+ allowed_metrics = ['mIoU', 'mDice', 'mFscore']
+ if not set(metric).issubset(set(allowed_metrics)):
+ raise KeyError('metric {} is not supported'.format(metric))
+ eval_results = {}
+ gt_seg_maps = self.get_gt_seg_maps(efficient_test)
+ if self.CLASSES is None:
+ num_classes = len(
+ reduce(np.union1d, [np.unique(_) for _ in gt_seg_maps]))
+ else:
+ num_classes = len(self.CLASSES)
+ ret_metrics = eval_metrics(
+ results,
+ gt_seg_maps,
+ num_classes,
+ self.ignore_index,
+ metric,
+ label_map=self.label_map,
+ reduce_zero_label=self.reduce_zero_label)
+
+ if self.CLASSES is None:
+ class_names = tuple(range(num_classes))
+ else:
+ class_names = self.CLASSES
+
+ # summary table
+ ret_metrics_summary = OrderedDict({
+ ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2)
+ for ret_metric, ret_metric_value in ret_metrics.items()
+ })
+
+ # each class table
+ ret_metrics.pop('aAcc', None)
+ ret_metrics_class = OrderedDict({
+ ret_metric: np.round(ret_metric_value * 100, 2)
+ for ret_metric, ret_metric_value in ret_metrics.items()
+ })
+ ret_metrics_class.update({'Class': class_names})
+ ret_metrics_class.move_to_end('Class', last=False)
+
+ # for logger
+ class_table_data = PrettyTable()
+ for key, val in ret_metrics_class.items():
+ class_table_data.add_column(key, val)
+
+ summary_table_data = PrettyTable()
+ for key, val in ret_metrics_summary.items():
+ if key == 'aAcc':
+ summary_table_data.add_column(key, [val])
+ else:
+ summary_table_data.add_column('m' + key, [val])
+
+ print_log('per class results:', logger)
+ print_log('\n' + class_table_data.get_string(), logger=logger)
+ print_log('Summary:', logger)
+ print_log('\n' + summary_table_data.get_string(), logger=logger)
+
+ # each metric dict
+ for key, value in ret_metrics_summary.items():
+ if key == 'aAcc':
+ eval_results[key] = value / 100.0
+ else:
+ eval_results['m' + key] = value / 100.0
+
+ ret_metrics_class.pop('Class', None)
+ for key, value in ret_metrics_class.items():
+ eval_results.update({
+ key + '.' + str(name): value[idx] / 100.0
+ for idx, name in enumerate(class_names)
+ })
+
+ if mmcv.is_list_of(results, str):
+ for file_name in results:
+ os.remove(file_name)
+ return eval_results
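+
+# Usage sketch (illustrative only; assumes `pipeline_cfg`, `img_dir` and
+# `ann_dir` describe an existing dataset on disk):
+#
+#     dataset = CustomDataset(pipeline_cfg, img_dir='img_dir/train',
+#                             ann_dir='ann_dir/train')
+#     results = [...]  # one predicted seg map (ndarray) per image
+#     print(dataset.evaluate(results, metric=['mIoU', 'mFscore']))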
diff --git a/annotator/uniformer/mmseg/datasets/dataset_wrappers.py b/annotator/uniformer/mmseg/datasets/dataset_wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6a5e957ec3b44465432617cf6e8f0b86a8a5efa
--- /dev/null
+++ b/annotator/uniformer/mmseg/datasets/dataset_wrappers.py
@@ -0,0 +1,50 @@
+from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
+
+from .builder import DATASETS
+
+
+@DATASETS.register_module()
+class ConcatDataset(_ConcatDataset):
+ """A wrapper of concatenated dataset.
+
+ Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but it also keeps
+ the ``CLASSES`` and ``PALETTE`` of the first dataset.
+
+ Args:
+ datasets (list[:obj:`Dataset`]): A list of datasets.
+ """
+
+ def __init__(self, datasets):
+ super(ConcatDataset, self).__init__(datasets)
+ self.CLASSES = datasets[0].CLASSES
+ self.PALETTE = datasets[0].PALETTE
+
+
+@DATASETS.register_module()
+class RepeatDataset(object):
+ """A wrapper of repeated dataset.
+
+ The length of the repeated dataset will be ``times`` times the length of
+ the original dataset. This is useful when the data loading time is long
+ but the dataset is small. Using RepeatDataset can reduce the data loading
+ time between epochs.
+
+ Args:
+ dataset (:obj:`Dataset`): The dataset to be repeated.
+ times (int): Repeat times.
+ """
+
+ def __init__(self, dataset, times):
+ self.dataset = dataset
+ self.times = times
+ self.CLASSES = dataset.CLASSES
+ self.PALETTE = dataset.PALETTE
+ self._ori_len = len(self.dataset)
+
+ def __getitem__(self, idx):
+ """Get item from original dataset."""
+ return self.dataset[idx % self._ori_len]
+
+ def __len__(self):
+ """The length is multiplied by ``times``"""
+ return self.times * self._ori_len
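+
+# Config sketch (illustrative only): RepeatDataset is normally built from a
+# config rather than instantiated directly, e.g.
+#
+#     data = dict(
+#         train=dict(
+#             type='RepeatDataset',
+#             times=40,
+#             dataset=dict(type='CustomDataset', ...)))
+#
+# so one epoch iterates the wrapped dataset 40 times.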
diff --git a/annotator/uniformer/mmseg/datasets/drive.py b/annotator/uniformer/mmseg/datasets/drive.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cbfda8ae74bdf26c5aef197ff2866a7c7ad0cfd
--- /dev/null
+++ b/annotator/uniformer/mmseg/datasets/drive.py
@@ -0,0 +1,27 @@
+import os.path as osp
+
+from .builder import DATASETS
+from .custom import CustomDataset
+
+
+@DATASETS.register_module()
+class DRIVEDataset(CustomDataset):
+ """DRIVE dataset.
+
+ In segmentation map annotation for DRIVE, 0 stands for background, which is
+ one of the 2 categories. ``reduce_zero_label`` is fixed to False. The
+ ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to
+ '_manual1.png'.
+ """
+
+ CLASSES = ('background', 'vessel')
+
+ PALETTE = [[120, 120, 120], [6, 230, 230]]
+
+ def __init__(self, **kwargs):
+ super(DRIVEDataset, self).__init__(
+ img_suffix='.png',
+ seg_map_suffix='_manual1.png',
+ reduce_zero_label=False,
+ **kwargs)
+ assert osp.exists(self.img_dir)
diff --git a/annotator/uniformer/mmseg/datasets/hrf.py b/annotator/uniformer/mmseg/datasets/hrf.py
new file mode 100644
index 0000000000000000000000000000000000000000..923203b51377f9344277fc561803d7a78bd2c684
--- /dev/null
+++ b/annotator/uniformer/mmseg/datasets/hrf.py
@@ -0,0 +1,27 @@
+import os.path as osp
+
+from .builder import DATASETS
+from .custom import CustomDataset
+
+
+@DATASETS.register_module()
+class HRFDataset(CustomDataset):
+ """HRF dataset.
+
+ In segmentation map annotation for HRF, 0 stands for background, which is
+ one of the 2 categories. ``reduce_zero_label`` is fixed to False. The
+ ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to
+ '.png'.
+ """
+
+ CLASSES = ('background', 'vessel')
+
+ PALETTE = [[120, 120, 120], [6, 230, 230]]
+
+ def __init__(self, **kwargs):
+ super(HRFDataset, self).__init__(
+ img_suffix='.png',
+ seg_map_suffix='.png',
+ reduce_zero_label=False,
+ **kwargs)
+ assert osp.exists(self.img_dir)
diff --git a/annotator/uniformer/mmseg/datasets/pascal_context.py b/annotator/uniformer/mmseg/datasets/pascal_context.py
new file mode 100644
index 0000000000000000000000000000000000000000..541a63c66a13fb16fd52921e755715ad8d078fdd
--- /dev/null
+++ b/annotator/uniformer/mmseg/datasets/pascal_context.py
@@ -0,0 +1,103 @@
+import os.path as osp
+
+from .builder import DATASETS
+from .custom import CustomDataset
+
+
+@DATASETS.register_module()
+class PascalContextDataset(CustomDataset):
+ """PascalContext dataset.
+
+ In segmentation map annotation for PascalContext, 0 stands for background,
+ which is one of the 60 categories. ``reduce_zero_label`` is fixed to
+ False. The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is
+ fixed to '.png'.
+
+ Args:
+ split (str): Split txt file for PascalContext.
+ """
+
+ CLASSES = ('background', 'aeroplane', 'bag', 'bed', 'bedclothes', 'bench',
+ 'bicycle', 'bird', 'boat', 'book', 'bottle', 'building', 'bus',
+ 'cabinet', 'car', 'cat', 'ceiling', 'chair', 'cloth',
+ 'computer', 'cow', 'cup', 'curtain', 'dog', 'door', 'fence',
+ 'floor', 'flower', 'food', 'grass', 'ground', 'horse',
+ 'keyboard', 'light', 'motorbike', 'mountain', 'mouse', 'person',
+ 'plate', 'platform', 'pottedplant', 'road', 'rock', 'sheep',
+ 'shelves', 'sidewalk', 'sign', 'sky', 'snow', 'sofa', 'table',
+ 'track', 'train', 'tree', 'truck', 'tvmonitor', 'wall', 'water',
+ 'window', 'wood')
+
+ PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
+ [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
+ [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
+ [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
+ [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
+ [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
+ [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
+ [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
+ [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
+ [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
+ [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
+ [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
+ [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
+ [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
+ [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255]]
+
+ def __init__(self, split, **kwargs):
+ super(PascalContextDataset, self).__init__(
+ img_suffix='.jpg',
+ seg_map_suffix='.png',
+ split=split,
+ reduce_zero_label=False,
+ **kwargs)
+ assert osp.exists(self.img_dir) and self.split is not None
+
+
+@DATASETS.register_module()
+class PascalContextDataset59(CustomDataset):
+ """PascalContext dataset.
+
+ In segmentation map annotation for PascalContext, 0 stands for background,
+ which is not included in the 59 categories. ``reduce_zero_label`` is fixed
+ to True. The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is
+ fixed to '.png'.
+
+ Args:
+ split (str): Split txt file for PascalContext.
+ """
+
+ CLASSES = ('aeroplane', 'bag', 'bed', 'bedclothes', 'bench', 'bicycle',
+ 'bird', 'boat', 'book', 'bottle', 'building', 'bus', 'cabinet',
+ 'car', 'cat', 'ceiling', 'chair', 'cloth', 'computer', 'cow',
+ 'cup', 'curtain', 'dog', 'door', 'fence', 'floor', 'flower',
+ 'food', 'grass', 'ground', 'horse', 'keyboard', 'light',
+ 'motorbike', 'mountain', 'mouse', 'person', 'plate', 'platform',
+ 'pottedplant', 'road', 'rock', 'sheep', 'shelves', 'sidewalk',
+ 'sign', 'sky', 'snow', 'sofa', 'table', 'track', 'train',
+ 'tree', 'truck', 'tvmonitor', 'wall', 'water', 'window', 'wood')
+
+ PALETTE = [[180, 120, 120], [6, 230, 230], [80, 50, 50], [4, 200, 3],
+ [120, 120, 80], [140, 140, 140], [204, 5, 255], [230, 230, 230],
+ [4, 250, 7], [224, 5, 255], [235, 255, 7], [150, 5, 61],
+ [120, 120, 70], [8, 255, 51], [255, 6, 82], [143, 255, 140],
+ [204, 255, 4], [255, 51, 7], [204, 70, 3], [0, 102, 200],
+ [61, 230, 250], [255, 6, 51], [11, 102, 255], [255, 7, 71],
+ [255, 9, 224], [9, 7, 230], [220, 220, 220], [255, 9, 92],
+ [112, 9, 255], [8, 255, 214], [7, 255, 224], [255, 184, 6],
+ [10, 255, 71], [255, 41, 10], [7, 255, 255], [224, 255, 8],
+ [102, 8, 255], [255, 61, 6], [255, 194, 7], [255, 122, 8],
+ [0, 255, 20], [255, 8, 41], [255, 5, 153], [6, 51, 255],
+ [235, 12, 255], [160, 150, 20], [0, 163, 255], [140, 140, 140],
+ [250, 10, 15], [20, 255, 0], [31, 255, 0], [255, 31, 0],
+ [255, 224, 0], [153, 255, 0], [0, 0, 255], [255, 71, 0],
+ [0, 235, 255], [0, 173, 255], [31, 0, 255]]
+
+ def __init__(self, split, **kwargs):
+ super(PascalContextDataset59, self).__init__(
+ img_suffix='.jpg',
+ seg_map_suffix='.png',
+ split=split,
+ reduce_zero_label=True,
+ **kwargs)
+ assert osp.exists(self.img_dir) and self.split is not None
diff --git a/annotator/uniformer/mmseg/datasets/pipelines/__init__.py b/annotator/uniformer/mmseg/datasets/pipelines/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b9046b07bb4ddea7a707a392b42e72db7c9df67
--- /dev/null
+++ b/annotator/uniformer/mmseg/datasets/pipelines/__init__.py
@@ -0,0 +1,16 @@
+from .compose import Compose
+from .formating import (Collect, ImageToTensor, ToDataContainer, ToTensor,
+ Transpose, to_tensor)
+from .loading import LoadAnnotations, LoadImageFromFile
+from .test_time_aug import MultiScaleFlipAug
+from .transforms import (CLAHE, AdjustGamma, Normalize, Pad,
+ PhotoMetricDistortion, RandomCrop, RandomFlip,
+ RandomRotate, Rerange, Resize, RGB2Gray, SegRescale)
+
+__all__ = [
+ 'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer',
+ 'Transpose', 'Collect', 'LoadAnnotations', 'LoadImageFromFile',
+ 'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad', 'RandomCrop',
+ 'Normalize', 'SegRescale', 'PhotoMetricDistortion', 'RandomRotate',
+ 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray'
+]
diff --git a/annotator/uniformer/mmseg/datasets/pipelines/compose.py b/annotator/uniformer/mmseg/datasets/pipelines/compose.py
new file mode 100644
index 0000000000000000000000000000000000000000..cbfcbb925c6d4ebf849328b9f94ef6fc24359bf5
--- /dev/null
+++ b/annotator/uniformer/mmseg/datasets/pipelines/compose.py
@@ -0,0 +1,51 @@
+import collections
+
+from annotator.uniformer.mmcv.utils import build_from_cfg
+
+from ..builder import PIPELINES
+
+
+@PIPELINES.register_module()
+class Compose(object):
+ """Compose multiple transforms sequentially.
+
+ Args:
+ transforms (Sequence[dict | callable]): Sequence of transform object or
+ config dict to be composed.
+ """
+
+ def __init__(self, transforms):
+ assert isinstance(transforms, collections.abc.Sequence)
+ self.transforms = []
+ for transform in transforms:
+ if isinstance(transform, dict):
+ transform = build_from_cfg(transform, PIPELINES)
+ self.transforms.append(transform)
+ elif callable(transform):
+ self.transforms.append(transform)
+ else:
+ raise TypeError('transform must be callable or a dict')
+
+ def __call__(self, data):
+ """Call function to apply transforms sequentially.
+
+ Args:
+ data (dict): A result dict contains the data to transform.
+
+ Returns:
+ dict: Transformed data.
+ """
+
+ for t in self.transforms:
+ data = t(data)
+ if data is None:
+ return None
+ return data
+
+ def __repr__(self):
+ format_string = self.__class__.__name__ + '('
+ for t in self.transforms:
+ format_string += '\n'
+ format_string += f' {t}'
+ format_string += '\n)'
+ return format_string
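+
+# Usage sketch (illustrative only): Compose accepts a mix of registry config
+# dicts and plain callables, e.g.
+#
+#     pipeline = Compose([
+#         dict(type='LoadImageFromFile'),
+#         dict(type='LoadAnnotations'),
+#         lambda results: results,  # any callable passes through unchanged
+#     ])
+#
+# If any transform returns None, Compose short-circuits and returns None.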
diff --git a/annotator/uniformer/mmseg/datasets/pipelines/formating.py b/annotator/uniformer/mmseg/datasets/pipelines/formating.py
new file mode 100644
index 0000000000000000000000000000000000000000..97db85f4f9db39fb86ba77ead7d1a8407d810adb
--- /dev/null
+++ b/annotator/uniformer/mmseg/datasets/pipelines/formating.py
@@ -0,0 +1,288 @@
+from collections.abc import Sequence
+
+import annotator.uniformer.mmcv as mmcv
+import numpy as np
+import torch
+from annotator.uniformer.mmcv.parallel import DataContainer as DC
+
+from ..builder import PIPELINES
+
+
+def to_tensor(data):
+ """Convert objects of various python types to :obj:`torch.Tensor`.
+
+ Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
+ :class:`Sequence`, :class:`int` and :class:`float`.
+
+ Args:
+ data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to
+ be converted.
+ """
+
+ if isinstance(data, torch.Tensor):
+ return data
+ elif isinstance(data, np.ndarray):
+ return torch.from_numpy(data)
+ elif isinstance(data, Sequence) and not mmcv.is_str(data):
+ return torch.tensor(data)
+ elif isinstance(data, int):
+ return torch.LongTensor([data])
+ elif isinstance(data, float):
+ return torch.FloatTensor([data])
+ else:
+ raise TypeError(f'type {type(data)} cannot be converted to tensor.')
+
+
+@PIPELINES.register_module()
+class ToTensor(object):
+ """Convert some results to :obj:`torch.Tensor` by given keys.
+
+ Args:
+ keys (Sequence[str]): Keys that need to be converted to Tensor.
+ """
+
+ def __init__(self, keys):
+ self.keys = keys
+
+ def __call__(self, results):
+ """Call function to convert data in results to :obj:`torch.Tensor`.
+
+ Args:
+ results (dict): Result dict contains the data to convert.
+
+ Returns:
+ dict: The result dict contains the data converted
+ to :obj:`torch.Tensor`.
+ """
+
+ for key in self.keys:
+ results[key] = to_tensor(results[key])
+ return results
+
+ def __repr__(self):
+ return self.__class__.__name__ + f'(keys={self.keys})'
+
+
+@PIPELINES.register_module()
+class ImageToTensor(object):
+ """Convert image to :obj:`torch.Tensor` by given keys.
+
+ The dimension order of input image is (H, W, C). The pipeline will convert
+ it to (C, H, W). If only 2 dimensions (H, W) are given, the output would be
+ (1, H, W).
+
+ Args:
+ keys (Sequence[str]): Key of images to be converted to Tensor.
+ """
+
+ def __init__(self, keys):
+ self.keys = keys
+
+ def __call__(self, results):
+ """Call function to convert image in results to :obj:`torch.Tensor` and
+ transpose the channel order.
+
+ Args:
+ results (dict): Result dict contains the image data to convert.
+
+ Returns:
+ dict: The result dict contains the image converted
+ to :obj:`torch.Tensor` and transposed to (C, H, W) order.
+ """
+
+ for key in self.keys:
+ img = results[key]
+ if len(img.shape) < 3:
+ img = np.expand_dims(img, -1)
+ results[key] = to_tensor(img.transpose(2, 0, 1))
+ return results
+
+ def __repr__(self):
+ return self.__class__.__name__ + f'(keys={self.keys})'
+
+
+@PIPELINES.register_module()
+class Transpose(object):
+ """Transpose some results by given keys.
+
+ Args:
+ keys (Sequence[str]): Keys of results to be transposed.
+ order (Sequence[int]): Order of transpose.
+ """
+
+ def __init__(self, keys, order):
+ self.keys = keys
+ self.order = order
+
+ def __call__(self, results):
+ """Call function to convert image in results to :obj:`torch.Tensor` and
+ transpose the channel order.
+
+ Args:
+ results (dict): Result dict contains the image data to convert.
+
+ Returns:
+ dict: The result dict contains the image converted
+ to :obj:`torch.Tensor` and transposed to (C, H, W) order.
+ """
+
+ for key in self.keys:
+ results[key] = results[key].transpose(self.order)
+ return results
+
+ def __repr__(self):
+ return self.__class__.__name__ + \
+ f'(keys={self.keys}, order={self.order})'
+
+
+@PIPELINES.register_module()
+class ToDataContainer(object):
+ """Convert results to :obj:`mmcv.DataContainer` by given fields.
+
+ Args:
+ fields (Sequence[dict]): Each field is a dict like
+ ``dict(key='xxx', **kwargs)``. The ``key`` in result will
+ be converted to :obj:`mmcv.DataContainer` with ``**kwargs``.
+ Default: ``(dict(key='img', stack=True),
+ dict(key='gt_semantic_seg'))``.
+ """
+
+ def __init__(self,
+ fields=(dict(key='img',
+ stack=True), dict(key='gt_semantic_seg'))):
+ self.fields = fields
+
+ def __call__(self, results):
+ """Call function to convert data in results to
+ :obj:`mmcv.DataContainer`.
+
+ Args:
+ results (dict): Result dict contains the data to convert.
+
+ Returns:
+ dict: The result dict contains the data converted to
+ :obj:`mmcv.DataContainer`.
+ """
+
+ for field in self.fields:
+ field = field.copy()
+ key = field.pop('key')
+ results[key] = DC(results[key], **field)
+ return results
+
+ def __repr__(self):
+ return self.__class__.__name__ + f'(fields={self.fields})'
+
+
+@PIPELINES.register_module()
+class DefaultFormatBundle(object):
+ """Default formatting bundle.
+
+ It simplifies the pipeline of formatting common fields, including "img"
+ and "gt_semantic_seg". These fields are formatted as follows.
+
+ - img: (1)transpose, (2)to tensor, (3)to DataContainer (stack=True)
+ - gt_semantic_seg: (1)unsqueeze dim-0 (2)to tensor,
+ (3)to DataContainer (stack=True)
+ """
+
+ def __call__(self, results):
+ """Call function to transform and format common fields in results.
+
+ Args:
+ results (dict): Result dict contains the data to convert.
+
+ Returns:
+ dict: The result dict contains the data that is formatted with
+ default bundle.
+ """
+
+ if 'img' in results:
+ img = results['img']
+ if len(img.shape) < 3:
+ img = np.expand_dims(img, -1)
+ img = np.ascontiguousarray(img.transpose(2, 0, 1))
+ results['img'] = DC(to_tensor(img), stack=True)
+ if 'gt_semantic_seg' in results:
+ # convert to long
+ results['gt_semantic_seg'] = DC(
+ to_tensor(results['gt_semantic_seg'][None,
+ ...].astype(np.int64)),
+ stack=True)
+ return results
+
+ def __repr__(self):
+ return self.__class__.__name__
+
+
+@PIPELINES.register_module()
+class Collect(object):
+ """Collect data from the loader relevant to the specific task.
+
+ This is usually the last stage of the data loader pipeline. Typically keys
+ is set to some subset of "img", "gt_semantic_seg".
+
+ The "img_meta" item is always populated. The contents of the "img_meta"
+ dictionary depend on "meta_keys". By default this includes:
+
+ - "img_shape": shape of the image input to the network as a tuple
+ (h, w, c). Note that images may be zero padded on the bottom/right
+ if the batch tensor is larger than this shape.
+
+ - "scale_factor": a float indicating the preprocessing scale
+
+ - "flip": a boolean indicating if image flip transform was used
+
+ - "filename": path to the image file
+
+ - "ori_shape": original shape of the image as a tuple (h, w, c)
+
+ - "pad_shape": image shape after padding
+
+ - "img_norm_cfg": a dict of normalization information:
+ - mean - per channel mean subtraction
+ - std - per channel std divisor
+ - to_rgb - bool indicating if bgr was converted to rgb
+
+ Args:
+ keys (Sequence[str]): Keys of results to be collected in ``data``.
+ meta_keys (Sequence[str], optional): Meta keys to be converted to
+ ``mmcv.DataContainer`` and collected in ``data[img_metas]``.
+ Default: ``('filename', 'ori_filename', 'ori_shape', 'img_shape',
+ 'pad_shape', 'scale_factor', 'flip', 'flip_direction',
+ 'img_norm_cfg')``
+ """
+
+ def __init__(self,
+ keys,
+ meta_keys=('filename', 'ori_filename', 'ori_shape',
+ 'img_shape', 'pad_shape', 'scale_factor', 'flip',
+ 'flip_direction', 'img_norm_cfg')):
+ self.keys = keys
+ self.meta_keys = meta_keys
+
+ def __call__(self, results):
+ """Call function to collect keys in results. The keys in ``meta_keys``
+ will be converted to :obj:mmcv.DataContainer.
+
+ Args:
+ results (dict): Result dict contains the data to collect.
+
+ Returns:
+ dict: The result dict contains the following keys:
+ - keys in ``self.keys``
+ - ``img_metas``
+ """
+
+ data = {}
+ img_meta = {}
+ for key in self.meta_keys:
+ img_meta[key] = results[key]
+ data['img_metas'] = DC(img_meta, cpu_only=True)
+ for key in self.keys:
+ data[key] = results[key]
+ return data
+
+ def __repr__(self):
+ return self.__class__.__name__ + \
+ f'(keys={self.keys}, meta_keys={self.meta_keys})'
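+
+# Config sketch (illustrative only): a training pipeline usually ends with
+#
+#     dict(type='DefaultFormatBundle'),
+#     dict(type='Collect', keys=['img', 'gt_semantic_seg']),
+#
+# so 'img' and 'gt_semantic_seg' become DataContainer-wrapped tensors and the
+# remaining meta information is gathered under 'img_metas'.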
diff --git a/annotator/uniformer/mmseg/datasets/pipelines/loading.py b/annotator/uniformer/mmseg/datasets/pipelines/loading.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3692ae91f19b9c7ccf6023168788ff42c9e93e3
--- /dev/null
+++ b/annotator/uniformer/mmseg/datasets/pipelines/loading.py
@@ -0,0 +1,153 @@
+import os.path as osp
+
+import annotator.uniformer.mmcv as mmcv
+import numpy as np
+
+from ..builder import PIPELINES
+
+
+@PIPELINES.register_module()
+class LoadImageFromFile(object):
+ """Load an image from file.
+
+ Required keys are "img_prefix" and "img_info" (a dict that must contain the
+ key "filename"). Added or updated keys are "filename", "img", "img_shape",
+ "ori_shape" (same as `img_shape`), "pad_shape" (same as `img_shape`),
+ "scale_factor" (1.0) and "img_norm_cfg" (means=0 and stds=1).
+
+ Args:
+ to_float32 (bool): Whether to convert the loaded image to a float32
+ numpy array. If set to False, the loaded image is a uint8 array.
+ Defaults to False.
+ color_type (str): The flag argument for :func:`mmcv.imfrombytes`.
+ Defaults to 'color'.
+ file_client_args (dict): Arguments to instantiate a FileClient.
+ See :class:`mmcv.fileio.FileClient` for details.
+ Defaults to ``dict(backend='disk')``.
+ imdecode_backend (str): Backend for :func:`mmcv.imdecode`. Default:
+ 'cv2'
+ """
+
+ def __init__(self,
+ to_float32=False,
+ color_type='color',
+ file_client_args=dict(backend='disk'),
+ imdecode_backend='cv2'):
+ self.to_float32 = to_float32
+ self.color_type = color_type
+ self.file_client_args = file_client_args.copy()
+ self.file_client = None
+ self.imdecode_backend = imdecode_backend
+
+ def __call__(self, results):
+ """Call functions to load image and get image meta information.
+
+ Args:
+ results (dict): Result dict from :obj:`mmseg.CustomDataset`.
+
+ Returns:
+ dict: The dict contains loaded image and meta information.
+ """
+
+ if self.file_client is None:
+ self.file_client = mmcv.FileClient(**self.file_client_args)
+
+ if results.get('img_prefix') is not None:
+ filename = osp.join(results['img_prefix'],
+ results['img_info']['filename'])
+ else:
+ filename = results['img_info']['filename']
+ img_bytes = self.file_client.get(filename)
+ img = mmcv.imfrombytes(
+ img_bytes, flag=self.color_type, backend=self.imdecode_backend)
+ if self.to_float32:
+ img = img.astype(np.float32)
+
+ results['filename'] = filename
+ results['ori_filename'] = results['img_info']['filename']
+ results['img'] = img
+ results['img_shape'] = img.shape
+ results['ori_shape'] = img.shape
+ # Set initial values for default meta_keys
+ results['pad_shape'] = img.shape
+ results['scale_factor'] = 1.0
+ num_channels = 1 if len(img.shape) < 3 else img.shape[2]
+ results['img_norm_cfg'] = dict(
+ mean=np.zeros(num_channels, dtype=np.float32),
+ std=np.ones(num_channels, dtype=np.float32),
+ to_rgb=False)
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(to_float32={self.to_float32},'
+ repr_str += f"color_type='{self.color_type}',"
+ repr_str += f"imdecode_backend='{self.imdecode_backend}')"
+ return repr_str
+
+
+@PIPELINES.register_module()
+class LoadAnnotations(object):
+ """Load annotations for semantic segmentation.
+
+ Args:
+ reduce_zero_label (bool): Whether to reduce all label values by 1.
+ Usually used for datasets where 0 is background label.
+ Default: False.
+ file_client_args (dict): Arguments to instantiate a FileClient.
+ See :class:`mmcv.fileio.FileClient` for details.
+ Defaults to ``dict(backend='disk')``.
+ imdecode_backend (str): Backend for :func:`mmcv.imdecode`. Default:
+ 'pillow'
+ """
+
+ def __init__(self,
+ reduce_zero_label=False,
+ file_client_args=dict(backend='disk'),
+ imdecode_backend='pillow'):
+ self.reduce_zero_label = reduce_zero_label
+ self.file_client_args = file_client_args.copy()
+ self.file_client = None
+ self.imdecode_backend = imdecode_backend
+
+ def __call__(self, results):
+ """Call function to load multiple types annotations.
+
+ Args:
+ results (dict): Result dict from :obj:`mmseg.CustomDataset`.
+
+ Returns:
+ dict: The dict contains loaded semantic segmentation annotations.
+ """
+
+ if self.file_client is None:
+ self.file_client = mmcv.FileClient(**self.file_client_args)
+
+ if results.get('seg_prefix', None) is not None:
+ filename = osp.join(results['seg_prefix'],
+ results['ann_info']['seg_map'])
+ else:
+ filename = results['ann_info']['seg_map']
+ img_bytes = self.file_client.get(filename)
+ gt_semantic_seg = mmcv.imfrombytes(
+ img_bytes, flag='unchanged',
+ backend=self.imdecode_backend).squeeze().astype(np.uint8)
+ # modify if custom classes
+ if results.get('label_map', None) is not None:
+ for old_id, new_id in results['label_map'].items():
+ gt_semantic_seg[gt_semantic_seg == old_id] = new_id
+ # reduce zero_label
+ if self.reduce_zero_label:
+ # avoid using underflow conversion
+ gt_semantic_seg[gt_semantic_seg == 0] = 255
+ gt_semantic_seg = gt_semantic_seg - 1
+ gt_semantic_seg[gt_semantic_seg == 254] = 255
+ results['gt_semantic_seg'] = gt_semantic_seg
+ results['seg_fields'].append('gt_semantic_seg')
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(reduce_zero_label={self.reduce_zero_label},'
+ repr_str += f"imdecode_backend='{self.imdecode_backend}')"
+ return repr_str
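+
+# Config sketch (illustrative only): the two loaders above are normally the
+# first steps of a pipeline, e.g.
+#
+#     train_pipeline = [
+#         dict(type='LoadImageFromFile'),
+#         dict(type='LoadAnnotations', reduce_zero_label=True),
+#         ...
+#     ]
+#
+# LoadImageFromFile fills 'img' and its shape/meta keys; LoadAnnotations adds
+# 'gt_semantic_seg' and registers it in 'seg_fields' for later transforms.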
diff --git a/annotator/uniformer/mmseg/datasets/pipelines/test_time_aug.py b/annotator/uniformer/mmseg/datasets/pipelines/test_time_aug.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a1611a04d9d927223c9afbe5bf68af04d62937a
--- /dev/null
+++ b/annotator/uniformer/mmseg/datasets/pipelines/test_time_aug.py
@@ -0,0 +1,133 @@
+import warnings
+
+import annotator.uniformer.mmcv as mmcv
+
+from ..builder import PIPELINES
+from .compose import Compose
+
+
+@PIPELINES.register_module()
+class MultiScaleFlipAug(object):
+ """Test-time augmentation with multiple scales and flipping.
+
+ An example configuration is as follows:
+
+ .. code-block::
+
+ img_scale=(2048, 1024),
+ img_ratios=[0.5, 1.0],
+ flip=True,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=32),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img']),
+ ]
+
+ After MultiScaleFlipAug with the above configuration, the results are
+ wrapped into lists of the same length as follows:
+
+ .. code-block::
+
+ dict(
+ img=[...],
+ img_shape=[...],
+ scale=[(1024, 512), (1024, 512), (2048, 1024), (2048, 1024)]
+ flip=[False, True, False, True]
+ ...
+ )
+
+ Args:
+ transforms (list[dict]): Transforms to apply in each augmentation.
+ img_scale (None | tuple | list[tuple]): Images scales for resizing.
+ img_ratios (float | list[float]): Image ratios for resizing.
+ flip (bool): Whether to apply flip augmentation. Default: False.
+ flip_direction (str | list[str]): Flip augmentation directions,
+ options are "horizontal" and "vertical". If flip_direction is list,
+ multiple flip augmentations will be applied.
+ It has no effect when flip == False. Default: "horizontal".
+ """
+
+ def __init__(self,
+ transforms,
+ img_scale,
+ img_ratios=None,
+ flip=False,
+ flip_direction='horizontal'):
+ self.transforms = Compose(transforms)
+ if img_ratios is not None:
+ img_ratios = img_ratios if isinstance(img_ratios,
+ list) else [img_ratios]
+ assert mmcv.is_list_of(img_ratios, float)
+ if img_scale is None:
+ # mode 1: given img_scale=None and a range of image ratio
+ self.img_scale = None
+ assert mmcv.is_list_of(img_ratios, float)
+ elif isinstance(img_scale, tuple) and mmcv.is_list_of(
+ img_ratios, float):
+ assert len(img_scale) == 2
+ # mode 2: given a scale and a range of image ratio
+ self.img_scale = [(int(img_scale[0] * ratio),
+ int(img_scale[1] * ratio))
+ for ratio in img_ratios]
+ else:
+ # mode 3: given multiple scales
+ self.img_scale = img_scale if isinstance(img_scale,
+ list) else [img_scale]
+ assert mmcv.is_list_of(self.img_scale, tuple) or self.img_scale is None
+ self.flip = flip
+ self.img_ratios = img_ratios
+ self.flip_direction = flip_direction if isinstance(
+ flip_direction, list) else [flip_direction]
+ assert mmcv.is_list_of(self.flip_direction, str)
+ if not self.flip and self.flip_direction != ['horizontal']:
+ warnings.warn(
+ 'flip_direction has no effect when flip is set to False')
+ if (self.flip
+ and not any([t['type'] == 'RandomFlip' for t in transforms])):
+ warnings.warn(
+ 'flip has no effect when RandomFlip is not in transforms')
+
+ def __call__(self, results):
+ """Call function to apply test time augment transforms on results.
+
+ Args:
+ results (dict): Result dict contains the data to transform.
+
+ Returns:
+ dict[str, list]: The augmented data, where each value is wrapped
+ into a list.
+ """
+
+ aug_data = []
+ if self.img_scale is None and mmcv.is_list_of(self.img_ratios, float):
+ h, w = results['img'].shape[:2]
+ img_scale = [(int(w * ratio), int(h * ratio))
+ for ratio in self.img_ratios]
+ else:
+ img_scale = self.img_scale
+ flip_aug = [False, True] if self.flip else [False]
+ for scale in img_scale:
+ for flip in flip_aug:
+ for direction in self.flip_direction:
+ _results = results.copy()
+ _results['scale'] = scale
+ _results['flip'] = flip
+ _results['flip_direction'] = direction
+ data = self.transforms(_results)
+ aug_data.append(data)
+ # list of dict to dict of list
+ aug_data_dict = {key: [] for key in aug_data[0]}
+ for data in aug_data:
+ for key, val in data.items():
+ aug_data_dict[key].append(val)
+ return aug_data_dict
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(transforms={self.transforms}, '
+ repr_str += f'img_scale={self.img_scale}, flip={self.flip}, '
+ repr_str += f'flip_direction={self.flip_direction})'
+ return repr_str
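+
+# Config sketch (illustrative only): a common test-time setup pairs one base
+# scale with several ratios and flipping, e.g.
+#
+#     dict(type='MultiScaleFlipAug',
+#          img_scale=(2048, 1024),
+#          img_ratios=[0.5, 0.75, 1.0],
+#          flip=True,
+#          transforms=[
+#              dict(type='Resize', keep_ratio=True),
+#              dict(type='RandomFlip'),
+#              dict(type='ImageToTensor', keys=['img']),
+#              dict(type='Collect', keys=['img']),
+#          ])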
diff --git a/annotator/uniformer/mmseg/datasets/pipelines/transforms.py b/annotator/uniformer/mmseg/datasets/pipelines/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..94e869b252ef6d8b43604add2bbc02f034614bfb
--- /dev/null
+++ b/annotator/uniformer/mmseg/datasets/pipelines/transforms.py
@@ -0,0 +1,889 @@
+import annotator.uniformer.mmcv as mmcv
+import numpy as np
+from annotator.uniformer.mmcv.utils import deprecated_api_warning, is_tuple_of
+from numpy import random
+
+from ..builder import PIPELINES
+
+
+@PIPELINES.register_module()
+class Resize(object):
+ """Resize images & seg.
+
+ This transform resizes the input image to some scale. If the input dict
+ contains the key "scale", then the scale in the input dict is used,
+ otherwise the specified scale in the init method is used.
+
+ ``img_scale`` can be None, a tuple (single-scale) or a list of tuples
+ (multi-scale). There are 4 multiscale modes:
+
+ - ``ratio_range is not None``:
+ 1. When img_scale is None, img_scale is the shape of image in results
+ (img_scale = results['img'].shape[:2]) and the image is resized based
+ on the original size. (mode 1)
+ 2. When img_scale is a tuple (single-scale), randomly sample a ratio from
+ the ratio range and multiply it with the image scale. (mode 2)
+
+ - ``ratio_range is None and multiscale_mode == "range"``: randomly sample a
+ scale from a range. (mode 3)
+
+ - ``ratio_range is None and multiscale_mode == "value"``: randomly sample a
+ scale from multiple scales. (mode 4)
+
+ Args:
+ img_scale (tuple or list[tuple]): Images scales for resizing.
+ multiscale_mode (str): Either "range" or "value".
+ ratio_range (tuple[float]): (min_ratio, max_ratio)
+ keep_ratio (bool): Whether to keep the aspect ratio when resizing the
+ image.
+ """
+
+ def __init__(self,
+ img_scale=None,
+ multiscale_mode='range',
+ ratio_range=None,
+ keep_ratio=True):
+ if img_scale is None:
+ self.img_scale = None
+ else:
+ if isinstance(img_scale, list):
+ self.img_scale = img_scale
+ else:
+ self.img_scale = [img_scale]
+ assert mmcv.is_list_of(self.img_scale, tuple)
+
+ if ratio_range is not None:
+ # mode 1: given img_scale=None and a range of image ratio
+ # mode 2: given a scale and a range of image ratio
+ assert self.img_scale is None or len(self.img_scale) == 1
+ else:
+ # mode 3 and 4: given multiple scales or a range of scales
+ assert multiscale_mode in ['value', 'range']
+
+ self.multiscale_mode = multiscale_mode
+ self.ratio_range = ratio_range
+ self.keep_ratio = keep_ratio
+
+ @staticmethod
+ def random_select(img_scales):
+ """Randomly select an img_scale from given candidates.
+
+ Args:
+ img_scales (list[tuple]): Images scales for selection.
+
+ Returns:
+ (tuple, int): Returns a tuple ``(img_scale, scale_idx)``,
+ where ``img_scale`` is the selected image scale and
+ ``scale_idx`` is the selected index in the given candidates.
+ """
+
+ assert mmcv.is_list_of(img_scales, tuple)
+ scale_idx = np.random.randint(len(img_scales))
+ img_scale = img_scales[scale_idx]
+ return img_scale, scale_idx
+
+ @staticmethod
+ def random_sample(img_scales):
+ """Randomly sample an img_scale when ``multiscale_mode=='range'``.
+
+ Args:
+ img_scales (list[tuple]): Images scale range for sampling.
+ There must be two tuples in img_scales, which specify the lower
+ and upper bound of image scales.
+
+ Returns:
+ (tuple, None): Returns a tuple ``(img_scale, None)``, where
+ ``img_scale`` is sampled scale and None is just a placeholder
+ to be consistent with :func:`random_select`.
+ """
+
+ assert mmcv.is_list_of(img_scales, tuple) and len(img_scales) == 2
+ img_scale_long = [max(s) for s in img_scales]
+ img_scale_short = [min(s) for s in img_scales]
+ long_edge = np.random.randint(
+ min(img_scale_long),
+ max(img_scale_long) + 1)
+ short_edge = np.random.randint(
+ min(img_scale_short),
+ max(img_scale_short) + 1)
+ img_scale = (long_edge, short_edge)
+ return img_scale, None
+
+ @staticmethod
+ def random_sample_ratio(img_scale, ratio_range):
+ """Randomly sample an img_scale when ``ratio_range`` is specified.
+
+ A ratio will be randomly sampled from the range specified by
+ ``ratio_range``. Then it would be multiplied with ``img_scale`` to
+ generate sampled scale.
+
+ Args:
+ img_scale (tuple): Images scale base to multiply with ratio.
+ ratio_range (tuple[float]): The minimum and maximum ratio to scale
+ the ``img_scale``.
+
+ Returns:
+ (tuple, None): Returns a tuple ``(scale, None)``, where
+ ``scale`` is sampled ratio multiplied with ``img_scale`` and
+ None is just a placeholder to be consistent with
+ :func:`random_select`.
+ """
+
+ assert isinstance(img_scale, tuple) and len(img_scale) == 2
+ min_ratio, max_ratio = ratio_range
+ assert min_ratio <= max_ratio
+ ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio
+ scale = int(img_scale[0] * ratio), int(img_scale[1] * ratio)
+ return scale, None
+
+ def _random_scale(self, results):
+ """Randomly sample an img_scale according to ``ratio_range`` and
+ ``multiscale_mode``.
+
+ If ``ratio_range`` is specified, a ratio will be sampled and be
+ multiplied with ``img_scale``.
+ If multiple scales are specified by ``img_scale``, a scale will be
+ sampled according to ``multiscale_mode``.
+ Otherwise, single scale will be used.
+
+ Args:
+ results (dict): Result dict from :obj:`dataset`.
+
+ Returns:
+ dict: Two new keys ``scale`` and ``scale_idx`` are added into
+ ``results``, which would be used by subsequent pipelines.
+ """
+
+ if self.ratio_range is not None:
+ if self.img_scale is None:
+ h, w = results['img'].shape[:2]
+ scale, scale_idx = self.random_sample_ratio((w, h),
+ self.ratio_range)
+ else:
+ scale, scale_idx = self.random_sample_ratio(
+ self.img_scale[0], self.ratio_range)
+ elif len(self.img_scale) == 1:
+ scale, scale_idx = self.img_scale[0], 0
+ elif self.multiscale_mode == 'range':
+ scale, scale_idx = self.random_sample(self.img_scale)
+ elif self.multiscale_mode == 'value':
+ scale, scale_idx = self.random_select(self.img_scale)
+ else:
+ raise NotImplementedError
+
+ results['scale'] = scale
+ results['scale_idx'] = scale_idx
+
+ def _resize_img(self, results):
+ """Resize images with ``results['scale']``."""
+ if self.keep_ratio:
+ img, scale_factor = mmcv.imrescale(
+ results['img'], results['scale'], return_scale=True)
+ # the w_scale and h_scale have a minor difference;
+ # a real fix should be done in mmcv.imrescale in the future
+ new_h, new_w = img.shape[:2]
+ h, w = results['img'].shape[:2]
+ w_scale = new_w / w
+ h_scale = new_h / h
+ else:
+ img, w_scale, h_scale = mmcv.imresize(
+ results['img'], results['scale'], return_scale=True)
+ scale_factor = np.array([w_scale, h_scale, w_scale, h_scale],
+ dtype=np.float32)
+ results['img'] = img
+ results['img_shape'] = img.shape
+ results['pad_shape'] = img.shape # in case that there is no padding
+ results['scale_factor'] = scale_factor
+ results['keep_ratio'] = self.keep_ratio
+
+ def _resize_seg(self, results):
+ """Resize semantic segmentation map with ``results['scale']``."""
+ for key in results.get('seg_fields', []):
+ if self.keep_ratio:
+ gt_seg = mmcv.imrescale(
+ results[key], results['scale'], interpolation='nearest')
+ else:
+ gt_seg = mmcv.imresize(
+ results[key], results['scale'], interpolation='nearest')
+ results[key] = gt_seg
+
+ def __call__(self, results):
+ """Call function to resize images, bounding boxes, masks, semantic
+ segmentation map.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Resized results, 'img_shape', 'pad_shape', 'scale_factor',
+ 'keep_ratio' keys are added into result dict.
+ """
+
+ if 'scale' not in results:
+ self._random_scale(results)
+ self._resize_img(results)
+ self._resize_seg(results)
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += (f'(img_scale={self.img_scale}, '
+ f'multiscale_mode={self.multiscale_mode}, '
+ f'ratio_range={self.ratio_range}, '
+ f'keep_ratio={self.keep_ratio})')
+ return repr_str
+
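+
+# Config sketch (illustrative only): the four Resize modes described above
+# map onto configs roughly as
+#
+#     dict(type='Resize', ratio_range=(0.5, 2.0))                  # mode 1
+#     dict(type='Resize', img_scale=(2048, 512),
+#          ratio_range=(0.5, 2.0))                                 # mode 2
+#     dict(type='Resize', img_scale=[(1333, 640), (1333, 800)],
+#          multiscale_mode='range')                                # mode 3
+#     dict(type='Resize', img_scale=[(1333, 640), (1333, 800)],
+#          multiscale_mode='value')                                # mode 4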
+
+@PIPELINES.register_module()
+class RandomFlip(object):
+ """Flip the image & seg.
+
+ If the input dict contains the key "flip", then the flag will be used,
+ otherwise it will be randomly decided by a ratio specified in the init
+ method.
+
+ Args:
+ prob (float, optional): The flipping probability. Default: None.
+ direction (str, optional): The flipping direction. Options are
+ 'horizontal' and 'vertical'. Default: 'horizontal'.
+ """
+
+ @deprecated_api_warning({'flip_ratio': 'prob'}, cls_name='RandomFlip')
+ def __init__(self, prob=None, direction='horizontal'):
+ self.prob = prob
+ self.direction = direction
+ if prob is not None:
+ assert prob >= 0 and prob <= 1
+ assert direction in ['horizontal', 'vertical']
+
+ def __call__(self, results):
+ """Call function to flip bounding boxes, masks, semantic segmentation
+ maps.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Flipped results, 'flip', 'flip_direction' keys are added into
+ result dict.
+ """
+
+ if 'flip' not in results:
+ flip = True if np.random.rand() < self.prob else False
+ results['flip'] = flip
+ if 'flip_direction' not in results:
+ results['flip_direction'] = self.direction
+ if results['flip']:
+ # flip image
+ results['img'] = mmcv.imflip(
+ results['img'], direction=results['flip_direction'])
+
+ # flip segs
+ for key in results.get('seg_fields', []):
+ # use copy() to make numpy stride positive
+ results[key] = mmcv.imflip(
+ results[key], direction=results['flip_direction']).copy()
+ return results
+
+ def __repr__(self):
+ return self.__class__.__name__ + f'(prob={self.prob})'
+
+
+@PIPELINES.register_module()
+class Pad(object):
+ """Pad the image & mask.
+
+ There are two padding modes: (1) pad to a fixed size and (2) pad to the
+ minimum size that is divisible by some number.
+ Added keys are "pad_shape", "pad_fixed_size", "pad_size_divisor",
+
+ Args:
+ size (tuple, optional): Fixed padding size.
+ size_divisor (int, optional): The divisor of padded size.
+ pad_val (float, optional): Padding value. Default: 0.
+ seg_pad_val (float, optional): Padding value of segmentation map.
+ Default: 255.
+ """
+
+ def __init__(self,
+ size=None,
+ size_divisor=None,
+ pad_val=0,
+ seg_pad_val=255):
+ self.size = size
+ self.size_divisor = size_divisor
+ self.pad_val = pad_val
+ self.seg_pad_val = seg_pad_val
+ # only one of size and size_divisor should be valid
+ assert size is not None or size_divisor is not None
+ assert size is None or size_divisor is None
+
+ def _pad_img(self, results):
+ """Pad images according to ``self.size``."""
+ if self.size is not None:
+ padded_img = mmcv.impad(
+ results['img'], shape=self.size, pad_val=self.pad_val)
+ elif self.size_divisor is not None:
+ padded_img = mmcv.impad_to_multiple(
+ results['img'], self.size_divisor, pad_val=self.pad_val)
+ results['img'] = padded_img
+ results['pad_shape'] = padded_img.shape
+ results['pad_fixed_size'] = self.size
+ results['pad_size_divisor'] = self.size_divisor
+
+ def _pad_seg(self, results):
+ """Pad masks according to ``results['pad_shape']``."""
+ for key in results.get('seg_fields', []):
+ results[key] = mmcv.impad(
+ results[key],
+ shape=results['pad_shape'][:2],
+ pad_val=self.seg_pad_val)
+
+ def __call__(self, results):
+ """Call function to pad images, masks, semantic segmentation maps.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Updated result dict.
+ """
+
+ self._pad_img(results)
+ self._pad_seg(results)
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(size={self.size}, size_divisor={self.size_divisor}, ' \
+ f'pad_val={self.pad_val})'
+ return repr_str
+
+
+@PIPELINES.register_module()
+class Normalize(object):
+ """Normalize the image.
+
+ Added key is "img_norm_cfg".
+
+ Args:
+ mean (sequence): Mean values of 3 channels.
+ std (sequence): Std values of 3 channels.
+ to_rgb (bool): Whether to convert the image from BGR to RGB.
+ Default: True.
+ """
+
+ def __init__(self, mean, std, to_rgb=True):
+ self.mean = np.array(mean, dtype=np.float32)
+ self.std = np.array(std, dtype=np.float32)
+ self.to_rgb = to_rgb
+
+ def __call__(self, results):
+ """Call function to normalize images.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Normalized results, 'img_norm_cfg' key is added into
+ result dict.
+ """
+
+ results['img'] = mmcv.imnormalize(results['img'], self.mean, self.std,
+ self.to_rgb)
+ results['img_norm_cfg'] = dict(
+ mean=self.mean, std=self.std, to_rgb=self.to_rgb)
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(mean={self.mean}, std={self.std}, to_rgb=' \
+ f'{self.to_rgb})'
+ return repr_str
+
+
+@PIPELINES.register_module()
+class Rerange(object):
+ """Rerange the image pixel value.
+
+ Args:
+ min_value (float or int): Minimum value of the reranged image.
+ Default: 0.
+ max_value (float or int): Maximum value of the reranged image.
+ Default: 255.
+ """
+
+ def __init__(self, min_value=0, max_value=255):
+ assert isinstance(min_value, float) or isinstance(min_value, int)
+ assert isinstance(max_value, float) or isinstance(max_value, int)
+ assert min_value < max_value
+ self.min_value = min_value
+ self.max_value = max_value
+
+ def __call__(self, results):
+ """Call function to rerange images.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Reranged results.
+ """
+
+ img = results['img']
+ img_min_value = np.min(img)
+ img_max_value = np.max(img)
+
+ assert img_min_value < img_max_value
+ # rerange to [0, 1]
+ img = (img - img_min_value) / (img_max_value - img_min_value)
+ # rerange to [min_value, max_value]
+ img = img * (self.max_value - self.min_value) + self.min_value
+ results['img'] = img
+
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(min_value={self.min_value}, max_value={self.max_value})'
+ return repr_str
+
+
+@PIPELINES.register_module()
+class CLAHE(object):
+ """Use CLAHE method to process the image.
+
+ See `ZUIDERVELD,K. Contrast Limited Adaptive Histogram Equalization[J].
+ Graphics Gems, 1994:474-485.` for more information.
+
+ Args:
+ clip_limit (float): Threshold for contrast limiting. Default: 40.0.
+ tile_grid_size (tuple[int]): Size of grid for histogram equalization.
+ Input image will be divided into equally sized rectangular tiles.
+ It defines the number of tiles in row and column. Default: (8, 8).
+ """
+
+ def __init__(self, clip_limit=40.0, tile_grid_size=(8, 8)):
+ assert isinstance(clip_limit, (float, int))
+ self.clip_limit = clip_limit
+ assert is_tuple_of(tile_grid_size, int)
+ assert len(tile_grid_size) == 2
+ self.tile_grid_size = tile_grid_size
+
+ def __call__(self, results):
+ """Call function to Use CLAHE method process images.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Processed results.
+ """
+
+ for i in range(results['img'].shape[2]):
+ results['img'][:, :, i] = mmcv.clahe(
+ np.array(results['img'][:, :, i], dtype=np.uint8),
+ self.clip_limit, self.tile_grid_size)
+
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(clip_limit={self.clip_limit}, '\
+ f'tile_grid_size={self.tile_grid_size})'
+ return repr_str
+
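+
+# --- Illustrative usage (editor's sketch, not part of upstream mmseg) ---
+# A hypothetical pipeline entry for ``CLAHE`` with the default parameters
+# written out explicitly; as ``__call__`` above shows, the equalization is
+# applied channel by channel via ``mmcv.clahe``.
+def _example_clahe_pipeline_entry():
+    return dict(type='CLAHE', clip_limit=40.0, tile_grid_size=(8, 8))
+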
+
+@PIPELINES.register_module()
+class RandomCrop(object):
+ """Random crop the image & seg.
+
+ Args:
+ crop_size (tuple): Expected size after cropping, (h, w).
+ cat_max_ratio (float): The maximum ratio that single category could
+ occupy.
+ """
+
+ def __init__(self, crop_size, cat_max_ratio=1., ignore_index=255):
+ assert crop_size[0] > 0 and crop_size[1] > 0
+ self.crop_size = crop_size
+ self.cat_max_ratio = cat_max_ratio
+ self.ignore_index = ignore_index
+
+ def get_crop_bbox(self, img):
+ """Randomly get a crop bounding box."""
+ margin_h = max(img.shape[0] - self.crop_size[0], 0)
+ margin_w = max(img.shape[1] - self.crop_size[1], 0)
+ offset_h = np.random.randint(0, margin_h + 1)
+ offset_w = np.random.randint(0, margin_w + 1)
+ crop_y1, crop_y2 = offset_h, offset_h + self.crop_size[0]
+ crop_x1, crop_x2 = offset_w, offset_w + self.crop_size[1]
+
+ return crop_y1, crop_y2, crop_x1, crop_x2
+
+ def crop(self, img, crop_bbox):
+ """Crop from ``img``"""
+ crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox
+ img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...]
+ return img
+
+ def __call__(self, results):
+ """Call function to randomly crop images, semantic segmentation maps.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Randomly cropped results, 'img_shape' key in result dict is
+ updated according to crop size.
+ """
+
+ img = results['img']
+ crop_bbox = self.get_crop_bbox(img)
+ if self.cat_max_ratio < 1.:
+ # Repeat 10 times
+ for _ in range(10):
+ seg_temp = self.crop(results['gt_semantic_seg'], crop_bbox)
+ labels, cnt = np.unique(seg_temp, return_counts=True)
+ cnt = cnt[labels != self.ignore_index]
+ if len(cnt) > 1 and np.max(cnt) / np.sum(
+ cnt) < self.cat_max_ratio:
+ break
+ crop_bbox = self.get_crop_bbox(img)
+
+ # crop the image
+ img = self.crop(img, crop_bbox)
+ img_shape = img.shape
+ results['img'] = img
+ results['img_shape'] = img_shape
+
+ # crop semantic seg
+ for key in results.get('seg_fields', []):
+ results[key] = self.crop(results[key], crop_bbox)
+
+ return results
+
+ def __repr__(self):
+ return self.__class__.__name__ + f'(crop_size={self.crop_size})'
+
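+
+# --- Illustrative usage (editor's sketch, not part of upstream mmseg) ---
+# A hypothetical training-pipeline entry for ``RandomCrop``.  With
+# ``cat_max_ratio=0.75`` the crop box is re-drawn (up to 10 times, as in
+# ``__call__`` above) whenever a single class would cover more than 75% of
+# the cropped ``gt_semantic_seg``.  The crop size is an assumption.
+def _example_random_crop_pipeline_entry():
+    return dict(type='RandomCrop', crop_size=(512, 512), cat_max_ratio=0.75)
+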
+
+@PIPELINES.register_module()
+class RandomRotate(object):
+ """Rotate the image & seg.
+
+ Args:
+ prob (float): The rotation probability.
+ degree (float, tuple[float]): Range of degrees to select from. If
+ degree is a number instead of tuple like (min, max),
+ the range of degree will be (``-degree``, ``+degree``)
+ pad_val (float, optional): Padding value of image. Default: 0.
+ seg_pad_val (float, optional): Padding value of segmentation map.
+ Default: 255.
+ center (tuple[float], optional): Center point (w, h) of the rotation in
+ the source image. If not specified, the center of the image will be
+ used. Default: None.
+ auto_bound (bool): Whether to adjust the image size to cover the whole
+ rotated image. Default: False
+ """
+
+ def __init__(self,
+ prob,
+ degree,
+ pad_val=0,
+ seg_pad_val=255,
+ center=None,
+ auto_bound=False):
+ self.prob = prob
+ assert prob >= 0 and prob <= 1
+ if isinstance(degree, (float, int)):
+ assert degree > 0, f'degree {degree} should be positive'
+ self.degree = (-degree, degree)
+ else:
+ self.degree = degree
+ assert len(self.degree) == 2, f'degree {self.degree} should be a ' \
+ f'tuple of (min, max)'
+ self.pad_val = pad_val
+ self.seg_pad_val = seg_pad_val
+ self.center = center
+ self.auto_bound = auto_bound
+
+ def __call__(self, results):
+ """Call function to rotate image, semantic segmentation maps.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Rotated results.
+ """
+
+ rotate = True if np.random.rand() < self.prob else False
+ degree = np.random.uniform(min(*self.degree), max(*self.degree))
+ if rotate:
+ # rotate image
+ results['img'] = mmcv.imrotate(
+ results['img'],
+ angle=degree,
+ border_value=self.pad_val,
+ center=self.center,
+ auto_bound=self.auto_bound)
+
+ # rotate segs
+ for key in results.get('seg_fields', []):
+ results[key] = mmcv.imrotate(
+ results[key],
+ angle=degree,
+ border_value=self.seg_pad_val,
+ center=self.center,
+ auto_bound=self.auto_bound,
+ interpolation='nearest')
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(prob={self.prob}, ' \
+ f'degree={self.degree}, ' \
+ f'pad_val={self.pad_val}, ' \
+ f'seg_pad_val={self.seg_pad_val}, ' \
+ f'center={self.center}, ' \
+ f'auto_bound={self.auto_bound})'
+ return repr_str
+
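+
+# --- Illustrative usage (editor's sketch, not part of upstream mmseg) ---
+# A hypothetical pipeline entry for ``RandomRotate``: with probability 0.5
+# the image is rotated by an angle drawn uniformly from (-20, 20) degrees,
+# while segmentation maps are rotated with nearest-neighbour interpolation
+# and padded with the ignore value 255 (see ``__call__`` above).
+def _example_random_rotate_pipeline_entry():
+    return dict(type='RandomRotate', prob=0.5, degree=20, seg_pad_val=255)
+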
+
+@PIPELINES.register_module()
+class RGB2Gray(object):
+ """Convert RGB image to grayscale image.
+
+ This transform calculates the weighted mean of the input image channels
+ with ``weights`` and then expands the channels to ``out_channels``. When
+ ``out_channels`` is None, the number of output channels is the same as
+ input channels.
+
+ Args:
+ out_channels (int): Expected number of output channels after
+ transforming. Default: None.
+ weights (tuple[float]): The weights to calculate the weighted mean.
+ Default: (0.299, 0.587, 0.114).
+ """
+
+ def __init__(self, out_channels=None, weights=(0.299, 0.587, 0.114)):
+ assert out_channels is None or out_channels > 0
+ self.out_channels = out_channels
+ assert isinstance(weights, tuple)
+ for item in weights:
+ assert isinstance(item, (float, int))
+ self.weights = weights
+
+ def __call__(self, results):
+ """Call function to convert RGB image to grayscale image.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Result dict with grayscale image.
+ """
+ img = results['img']
+ assert len(img.shape) == 3
+ assert img.shape[2] == len(self.weights)
+ weights = np.array(self.weights).reshape((1, 1, -1))
+ img = (img * weights).sum(2, keepdims=True)
+ if self.out_channels is None:
+ img = img.repeat(weights.shape[2], axis=2)
+ else:
+ img = img.repeat(self.out_channels, axis=2)
+
+ results['img'] = img
+ results['img_shape'] = img.shape
+
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(out_channels={self.out_channels}, ' \
+ f'weights={self.weights})'
+ return repr_str
+
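+
+# --- Illustrative usage (editor's sketch, not part of upstream mmseg) ---
+# Demonstrates the weighted mean computed by ``RGB2Gray`` on a single
+# made-up pixel; with ``out_channels=None`` the grayscale value is
+# replicated back to the original number of channels.
+def _example_rgb2gray():
+    import numpy as np
+    img = np.array([[[100., 150., 200.]]])  # one pixel, three channels
+    results = RGB2Gray()(dict(img=img))
+    # 0.299 * 100 + 0.587 * 150 + 0.114 * 200 = 140.75, repeated 3 times
+    return results['img']
+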
+
+@PIPELINES.register_module()
+class AdjustGamma(object):
+ """Using gamma correction to process the image.
+
+ Args:
+ gamma (float or int): Gamma value used in gamma correction.
+ Default: 1.0.
+ """
+
+ def __init__(self, gamma=1.0):
+ assert isinstance(gamma, float) or isinstance(gamma, int)
+ assert gamma > 0
+ self.gamma = gamma
+ inv_gamma = 1.0 / gamma
+ self.table = np.array([(i / 255.0)**inv_gamma * 255
+ for i in np.arange(256)]).astype('uint8')
+
+ def __call__(self, results):
+ """Call function to process the image with gamma correction.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Processed results.
+ """
+
+ results['img'] = mmcv.lut_transform(
+ np.array(results['img'], dtype=np.uint8), self.table)
+
+ return results
+
+ def __repr__(self):
+ return self.__class__.__name__ + f'(gamma={self.gamma})'
+
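+
+# --- Illustrative usage (editor's sketch, not part of upstream mmseg) ---
+# Demonstrates the lookup table built in ``__init__``: every intensity i is
+# mapped to (i / 255) ** (1 / gamma) * 255, so gamma > 1 brightens mid-tones
+# and gamma < 1 darkens them.  The gamma value below is an assumption.
+def _example_adjust_gamma():
+    import numpy as np
+    img = np.full((2, 2, 3), 128, dtype=np.uint8)
+    results = AdjustGamma(gamma=2.2)(dict(img=img))
+    return results['img']  # every pixel becomes (128 / 255) ** (1 / 2.2) * 255
+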
+
+@PIPELINES.register_module()
+class SegRescale(object):
+ """Rescale semantic segmentation maps.
+
+ Args:
+ scale_factor (float): The scale factor of the final output.
+ """
+
+ def __init__(self, scale_factor=1):
+ self.scale_factor = scale_factor
+
+ def __call__(self, results):
+ """Call function to scale the semantic segmentation map.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Result dict with semantic segmentation map scaled.
+ """
+ for key in results.get('seg_fields', []):
+ if self.scale_factor != 1:
+ results[key] = mmcv.imrescale(
+ results[key], self.scale_factor, interpolation='nearest')
+ return results
+
+ def __repr__(self):
+ return self.__class__.__name__ + f'(scale_factor={self.scale_factor})'
+
+
+@PIPELINES.register_module()
+class PhotoMetricDistortion(object):
+ """Apply photometric distortion to image sequentially, every transformation
+ is applied with a probability of 0.5. The position of random contrast is in
+ second or second to last.
+
+ 1. random brightness
+ 2. random contrast (mode 0)
+ 3. convert color from BGR to HSV
+ 4. random saturation
+ 5. random hue
+ 6. convert color from HSV to BGR
+ 7. random contrast (mode 1)
+
+ Args:
+ brightness_delta (int): delta of brightness.
+ contrast_range (tuple): range of contrast.
+ saturation_range (tuple): range of saturation.
+ hue_delta (int): delta of hue.
+ """
+
+ def __init__(self,
+ brightness_delta=32,
+ contrast_range=(0.5, 1.5),
+ saturation_range=(0.5, 1.5),
+ hue_delta=18):
+ self.brightness_delta = brightness_delta
+ self.contrast_lower, self.contrast_upper = contrast_range
+ self.saturation_lower, self.saturation_upper = saturation_range
+ self.hue_delta = hue_delta
+
+ def convert(self, img, alpha=1, beta=0):
+ """Multiple with alpha and add beat with clip."""
+ img = img.astype(np.float32) * alpha + beta
+ img = np.clip(img, 0, 255)
+ return img.astype(np.uint8)
+
+ def brightness(self, img):
+ """Brightness distortion."""
+ if random.randint(2):
+ return self.convert(
+ img,
+ beta=random.uniform(-self.brightness_delta,
+ self.brightness_delta))
+ return img
+
+ def contrast(self, img):
+ """Contrast distortion."""
+ if random.randint(2):
+ return self.convert(
+ img,
+ alpha=random.uniform(self.contrast_lower, self.contrast_upper))
+ return img
+
+ def saturation(self, img):
+ """Saturation distortion."""
+ if random.randint(2):
+ img = mmcv.bgr2hsv(img)
+ img[:, :, 1] = self.convert(
+ img[:, :, 1],
+ alpha=random.uniform(self.saturation_lower,
+ self.saturation_upper))
+ img = mmcv.hsv2bgr(img)
+ return img
+
+ def hue(self, img):
+ """Hue distortion."""
+ if random.randint(2):
+ img = mmcv.bgr2hsv(img)
+ img[:, :,
+ 0] = (img[:, :, 0].astype(int) +
+ random.randint(-self.hue_delta, self.hue_delta)) % 180
+ img = mmcv.hsv2bgr(img)
+ return img
+
+ def __call__(self, results):
+ """Call function to perform photometric distortion on images.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Result dict with images distorted.
+ """
+
+ img = results['img']
+ # random brightness
+ img = self.brightness(img)
+
+ # mode == 1 --> do random contrast first (right after brightness)
+ # mode == 0 --> do random contrast last
+ mode = random.randint(2)
+ if mode == 1:
+ img = self.contrast(img)
+
+ # random saturation
+ img = self.saturation(img)
+
+ # random hue
+ img = self.hue(img)
+
+ # random contrast
+ if mode == 0:
+ img = self.contrast(img)
+
+ results['img'] = img
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += (f'(brightness_delta={self.brightness_delta}, '
+ f'contrast_range=({self.contrast_lower}, '
+ f'{self.contrast_upper}), '
+ f'saturation_range=({self.saturation_lower}, '
+ f'{self.saturation_upper}), '
+ f'hue_delta={self.hue_delta})')
+ return repr_str
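+
+
+# --- Illustrative usage (editor's sketch, not part of upstream mmseg) ---
+# A hypothetical training-pipeline entry for ``PhotoMetricDistortion`` with
+# the default ranges written out.  Each sub-transform fires with probability
+# 0.5, and random contrast runs either right after brightness or at the very
+# end, depending on the randomly chosen ``mode`` in ``__call__`` above.
+def _example_photo_metric_distortion_entry():
+    return dict(
+        type='PhotoMetricDistortion',
+        brightness_delta=32,
+        contrast_range=(0.5, 1.5),
+        saturation_range=(0.5, 1.5),
+        hue_delta=18)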
diff --git a/annotator/uniformer/mmseg/datasets/stare.py b/annotator/uniformer/mmseg/datasets/stare.py
new file mode 100644
index 0000000000000000000000000000000000000000..cbd14e0920e7f6a73baff1432e5a32ccfdb0dfae
--- /dev/null
+++ b/annotator/uniformer/mmseg/datasets/stare.py
@@ -0,0 +1,27 @@
+import os.path as osp
+
+from .builder import DATASETS
+from .custom import CustomDataset
+
+
+@DATASETS.register_module()
+class STAREDataset(CustomDataset):
+ """STARE dataset.
+
+ In segmentation map annotation for STARE, 0 stands for background, which is
+ included in 2 categories. ``reduce_zero_label`` is fixed to False. The
+ ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to
+ '.ah.png'.
+ """
+
+ CLASSES = ('background', 'vessel')
+
+ PALETTE = [[120, 120, 120], [6, 230, 230]]
+
+ def __init__(self, **kwargs):
+ super(STAREDataset, self).__init__(
+ img_suffix='.png',
+ seg_map_suffix='.ah.png',
+ reduce_zero_label=False,
+ **kwargs)
+ assert osp.exists(self.img_dir)
diff --git a/annotator/uniformer/mmseg/datasets/voc.py b/annotator/uniformer/mmseg/datasets/voc.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8855203b14ee0dc4da9099a2945d4aedcffbcd6
--- /dev/null
+++ b/annotator/uniformer/mmseg/datasets/voc.py
@@ -0,0 +1,29 @@
+import os.path as osp
+
+from .builder import DATASETS
+from .custom import CustomDataset
+
+
+@DATASETS.register_module()
+class PascalVOCDataset(CustomDataset):
+ """Pascal VOC dataset.
+
+ Args:
+ split (str): Split txt file for Pascal VOC.
+ """
+
+ CLASSES = ('background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
+ 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog',
+ 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa',
+ 'train', 'tvmonitor')
+
+ PALETTE = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],
+ [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0],
+ [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128],
+ [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0],
+ [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]]
+
+ def __init__(self, split, **kwargs):
+ super(PascalVOCDataset, self).__init__(
+ img_suffix='.jpg', seg_map_suffix='.png', split=split, **kwargs)
+ assert osp.exists(self.img_dir) and self.split is not None
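+
+
+# --- Illustrative usage (editor's sketch, not part of upstream mmseg) ---
+# A hypothetical dataset config for ``PascalVOCDataset``.  The directory
+# layout and split file follow the common VOC2012 convention but are
+# assumptions here; ``train_pipeline`` would be a list of transform dicts.
+def _example_voc_dataset_cfg(train_pipeline):
+    return dict(
+        type='PascalVOCDataset',
+        data_root='data/VOCdevkit/VOC2012',
+        img_dir='JPEGImages',
+        ann_dir='SegmentationClass',
+        split='ImageSets/Segmentation/train.txt',
+        pipeline=train_pipeline)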
diff --git a/annotator/uniformer/mmseg/models/__init__.py b/annotator/uniformer/mmseg/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cf93f8bec9cf0cef0a3bd76ca3ca92eb188f535
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/__init__.py
@@ -0,0 +1,12 @@
+from .backbones import * # noqa: F401,F403
+from .builder import (BACKBONES, HEADS, LOSSES, SEGMENTORS, build_backbone,
+ build_head, build_loss, build_segmentor)
+from .decode_heads import * # noqa: F401,F403
+from .losses import * # noqa: F401,F403
+from .necks import * # noqa: F401,F403
+from .segmentors import * # noqa: F401,F403
+
+__all__ = [
+ 'BACKBONES', 'HEADS', 'LOSSES', 'SEGMENTORS', 'build_backbone',
+ 'build_head', 'build_loss', 'build_segmentor'
+]
diff --git a/annotator/uniformer/mmseg/models/backbones/__init__.py b/annotator/uniformer/mmseg/models/backbones/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8339983905fb5d20bae42ba6f76fea75d278b1aa
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/backbones/__init__.py
@@ -0,0 +1,17 @@
+from .cgnet import CGNet
+# from .fast_scnn import FastSCNN
+from .hrnet import HRNet
+from .mobilenet_v2 import MobileNetV2
+from .mobilenet_v3 import MobileNetV3
+from .resnest import ResNeSt
+from .resnet import ResNet, ResNetV1c, ResNetV1d
+from .resnext import ResNeXt
+from .unet import UNet
+from .vit import VisionTransformer
+from .uniformer import UniFormer
+
+__all__ = [
+ 'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet',
+ 'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3',
+ 'VisionTransformer', 'UniFormer'
+]
diff --git a/annotator/uniformer/mmseg/models/backbones/cgnet.py b/annotator/uniformer/mmseg/models/backbones/cgnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8bca442c8f18179f217e40c298fb5ef39df77c4
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/backbones/cgnet.py
@@ -0,0 +1,367 @@
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+from annotator.uniformer.mmcv.cnn import (ConvModule, build_conv_layer, build_norm_layer,
+ constant_init, kaiming_init)
+from annotator.uniformer.mmcv.runner import load_checkpoint
+from annotator.uniformer.mmcv.utils.parrots_wrapper import _BatchNorm
+
+from annotator.uniformer.mmseg.utils import get_root_logger
+from ..builder import BACKBONES
+
+
+class GlobalContextExtractor(nn.Module):
+ """Global Context Extractor for CGNet.
+
+ This class is employed to refine the joint feature of both local feature
+ and surrounding context.
+
+ Args:
+ channel (int): Number of input feature channels.
+ reduction (int): Reductions for global context extractor. Default: 16.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ """
+
+ def __init__(self, channel, reduction=16, with_cp=False):
+ super(GlobalContextExtractor, self).__init__()
+ self.channel = channel
+ self.reduction = reduction
+ assert reduction >= 1 and channel >= reduction
+ self.with_cp = with_cp
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
+ self.fc = nn.Sequential(
+ nn.Linear(channel, channel // reduction), nn.ReLU(inplace=True),
+ nn.Linear(channel // reduction, channel), nn.Sigmoid())
+
+ def forward(self, x):
+
+ def _inner_forward(x):
+ num_batch, num_channel = x.size()[:2]
+ y = self.avg_pool(x).view(num_batch, num_channel)
+ y = self.fc(y).view(num_batch, num_channel, 1, 1)
+ return x * y
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ return out
+
+
+class ContextGuidedBlock(nn.Module):
+ """Context Guided Block for CGNet.
+
+ This class consists of four components: local feature extractor,
+ surrounding feature extractor, joint feature extractor and global
+ context extractor.
+
+ Args:
+ in_channels (int): Number of input feature channels.
+ out_channels (int): Number of output feature channels.
+ dilation (int): Dilation rate for surrounding context extractor.
+ Default: 2.
+ reduction (int): Reduction for global context extractor. Default: 16.
+ skip_connect (bool): Add input to output or not. Default: True.
+ downsample (bool): Downsample the input to 1/2 or not. Default: False.
+ conv_cfg (dict): Config dict for convolution layer.
+ Default: None, which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='BN', requires_grad=True).
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='PReLU').
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ dilation=2,
+ reduction=16,
+ skip_connect=True,
+ downsample=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ act_cfg=dict(type='PReLU'),
+ with_cp=False):
+ super(ContextGuidedBlock, self).__init__()
+ self.with_cp = with_cp
+ self.downsample = downsample
+
+ channels = out_channels if downsample else out_channels // 2
+ if 'type' in act_cfg and act_cfg['type'] == 'PReLU':
+ act_cfg['num_parameters'] = channels
+ kernel_size = 3 if downsample else 1
+ stride = 2 if downsample else 1
+ padding = (kernel_size - 1) // 2
+
+ self.conv1x1 = ConvModule(
+ in_channels,
+ channels,
+ kernel_size,
+ stride,
+ padding,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+
+ self.f_loc = build_conv_layer(
+ conv_cfg,
+ channels,
+ channels,
+ kernel_size=3,
+ padding=1,
+ groups=channels,
+ bias=False)
+ self.f_sur = build_conv_layer(
+ conv_cfg,
+ channels,
+ channels,
+ kernel_size=3,
+ padding=dilation,
+ groups=channels,
+ dilation=dilation,
+ bias=False)
+
+ self.bn = build_norm_layer(norm_cfg, 2 * channels)[1]
+ self.activate = nn.PReLU(2 * channels)
+
+ if downsample:
+ self.bottleneck = build_conv_layer(
+ conv_cfg,
+ 2 * channels,
+ out_channels,
+ kernel_size=1,
+ bias=False)
+
+ self.skip_connect = skip_connect and not downsample
+ self.f_glo = GlobalContextExtractor(out_channels, reduction, with_cp)
+
+ def forward(self, x):
+
+ def _inner_forward(x):
+ out = self.conv1x1(x)
+ loc = self.f_loc(out)
+ sur = self.f_sur(out)
+
+ joi_feat = torch.cat([loc, sur], 1) # the joint feature
+ joi_feat = self.bn(joi_feat)
+ joi_feat = self.activate(joi_feat)
+ if self.downsample:
+ joi_feat = self.bottleneck(joi_feat) # channel = out_channels
+ # f_glo is employed to refine the joint feature
+ out = self.f_glo(joi_feat)
+
+ if self.skip_connect:
+ return x + out
+ else:
+ return out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ return out
+
+
+class InputInjection(nn.Module):
+ """Downsampling module for CGNet."""
+
+ def __init__(self, num_downsampling):
+ super(InputInjection, self).__init__()
+ self.pool = nn.ModuleList()
+ for i in range(num_downsampling):
+ self.pool.append(nn.AvgPool2d(3, stride=2, padding=1))
+
+ def forward(self, x):
+ for pool in self.pool:
+ x = pool(x)
+ return x
+
+
+@BACKBONES.register_module()
+class CGNet(nn.Module):
+ """CGNet backbone.
+
+ A Light-weight Context Guided Network for Semantic Segmentation
+ arXiv: https://arxiv.org/abs/1811.08201
+
+ Args:
+ in_channels (int): Number of input image channels. Normally 3.
+ num_channels (tuple[int]): Numbers of feature channels at each stage.
+ Default: (32, 64, 128).
+ num_blocks (tuple[int]): Numbers of CG blocks at stage 1 and stage 2.
+ Default: (3, 21).
+ dilations (tuple[int]): Dilation rate for surrounding context
+ extractors at stage 1 and stage 2. Default: (2, 4).
+ reductions (tuple[int]): Reductions for global context extractors at
+ stage 1 and stage 2. Default: (8, 16).
+ conv_cfg (dict): Config dict for convolution layer.
+ Default: None, which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='BN', requires_grad=True).
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='PReLU').
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Default: False.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ """
+
+ def __init__(self,
+ in_channels=3,
+ num_channels=(32, 64, 128),
+ num_blocks=(3, 21),
+ dilations=(2, 4),
+ reductions=(8, 16),
+ conv_cfg=None,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ act_cfg=dict(type='PReLU'),
+ norm_eval=False,
+ with_cp=False):
+
+ super(CGNet, self).__init__()
+ self.in_channels = in_channels
+ self.num_channels = num_channels
+ assert isinstance(self.num_channels, tuple) and len(
+ self.num_channels) == 3
+ self.num_blocks = num_blocks
+ assert isinstance(self.num_blocks, tuple) and len(self.num_blocks) == 2
+ self.dilations = dilations
+ assert isinstance(self.dilations, tuple) and len(self.dilations) == 2
+ self.reductions = reductions
+ assert isinstance(self.reductions, tuple) and len(self.reductions) == 2
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ if 'type' in self.act_cfg and self.act_cfg['type'] == 'PReLU':
+ self.act_cfg['num_parameters'] = num_channels[0]
+ self.norm_eval = norm_eval
+ self.with_cp = with_cp
+
+ cur_channels = in_channels
+ self.stem = nn.ModuleList()
+ for i in range(3):
+ self.stem.append(
+ ConvModule(
+ cur_channels,
+ num_channels[0],
+ 3,
+ 2 if i == 0 else 1,
+ padding=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg))
+ cur_channels = num_channels[0]
+
+ self.inject_2x = InputInjection(1) # down-sample for Input, factor=2
+ self.inject_4x = InputInjection(2) # down-sample for Input, factor=4
+
+ cur_channels += in_channels
+ self.norm_prelu_0 = nn.Sequential(
+ build_norm_layer(norm_cfg, cur_channels)[1],
+ nn.PReLU(cur_channels))
+
+ # stage 1
+ self.level1 = nn.ModuleList()
+ for i in range(num_blocks[0]):
+ self.level1.append(
+ ContextGuidedBlock(
+ cur_channels if i == 0 else num_channels[1],
+ num_channels[1],
+ dilations[0],
+ reductions[0],
+ downsample=(i == 0),
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ with_cp=with_cp)) # CG block
+
+ cur_channels = 2 * num_channels[1] + in_channels
+ self.norm_prelu_1 = nn.Sequential(
+ build_norm_layer(norm_cfg, cur_channels)[1],
+ nn.PReLU(cur_channels))
+
+ # stage 2
+ self.level2 = nn.ModuleList()
+ for i in range(num_blocks[1]):
+ self.level2.append(
+ ContextGuidedBlock(
+ cur_channels if i == 0 else num_channels[2],
+ num_channels[2],
+ dilations[1],
+ reductions[1],
+ downsample=(i == 0),
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ with_cp=with_cp)) # CG block
+
+ cur_channels = 2 * num_channels[2]
+ self.norm_prelu_2 = nn.Sequential(
+ build_norm_layer(norm_cfg, cur_channels)[1],
+ nn.PReLU(cur_channels))
+
+ def forward(self, x):
+ output = []
+
+ # stage 0
+ inp_2x = self.inject_2x(x)
+ inp_4x = self.inject_4x(x)
+ for layer in self.stem:
+ x = layer(x)
+ x = self.norm_prelu_0(torch.cat([x, inp_2x], 1))
+ output.append(x)
+
+ # stage 1
+ for i, layer in enumerate(self.level1):
+ x = layer(x)
+ if i == 0:
+ down1 = x
+ x = self.norm_prelu_1(torch.cat([x, down1, inp_4x], 1))
+ output.append(x)
+
+ # stage 2
+ for i, layer in enumerate(self.level2):
+ x = layer(x)
+ if i == 0:
+ down2 = x
+ x = self.norm_prelu_2(torch.cat([down2, x], 1))
+ output.append(x)
+
+ return output
+
+ def init_weights(self, pretrained=None):
+ """Initialize the weights in backbone.
+
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+ if isinstance(pretrained, str):
+ logger = get_root_logger()
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
+ elif pretrained is None:
+ for m in self.modules():
+ if isinstance(m, (nn.Conv2d, nn.Linear)):
+ kaiming_init(m)
+ elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
+ constant_init(m, 1)
+ elif isinstance(m, nn.PReLU):
+ constant_init(m, 0)
+ else:
+ raise TypeError('pretrained must be a str or None')
+
+ def train(self, mode=True):
+ """Convert the model into training mode will keeping the normalization
+ layer freezed."""
+ super(CGNet, self).train(mode)
+ if mode and self.norm_eval:
+ for m in self.modules():
+ # trick: eval() has effect on BatchNorm only
+ if isinstance(m, _BatchNorm):
+ m.eval()
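+
+
+# --- Illustrative usage (editor's sketch, not part of upstream mmseg) ---
+# A minimal smoke test for ``CGNet``: with the default settings the stem,
+# stage 1 and stage 2 each downsample once, so the three returned feature
+# maps sit at 1/2, 1/4 and 1/8 of the input resolution.  The input size is
+# arbitrary and chosen only for illustration.
+def _example_cgnet_forward():
+    model = CGNet()
+    model.init_weights()
+    model.eval()
+    x = torch.rand(1, 3, 64, 64)
+    with torch.no_grad():
+        outs = model(x)
+    return [o.shape for o in outs]  # spatial sizes 32x32, 16x16, 8x8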
diff --git a/annotator/uniformer/mmseg/models/backbones/fast_scnn.py b/annotator/uniformer/mmseg/models/backbones/fast_scnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..38c2350177cbc2066f45add568d30eb6041f74f3
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/backbones/fast_scnn.py
@@ -0,0 +1,375 @@
+import torch
+import torch.nn as nn
+from annotator.uniformer.mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule, constant_init,
+ kaiming_init)
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from annotator.uniformer.mmseg.models.decode_heads.psp_head import PPM
+from annotator.uniformer.mmseg.ops import resize
+from ..builder import BACKBONES
+from ..utils.inverted_residual import InvertedResidual
+
+
+class LearningToDownsample(nn.Module):
+ """Learning to downsample module.
+
+ Args:
+ in_channels (int): Number of input channels.
+ dw_channels (tuple[int]): Number of output channels of the first and
+ the second depthwise conv (dwconv) layers.
+ out_channels (int): Number of output channels of the whole
+ 'learning to downsample' module.
+ conv_cfg (dict | None): Config of conv layers. Default: None
+ norm_cfg (dict | None): Config of norm layers. Default:
+ dict(type='BN')
+ act_cfg (dict): Config of activation layers. Default:
+ dict(type='ReLU')
+ """
+
+ def __init__(self,
+ in_channels,
+ dw_channels,
+ out_channels,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU')):
+ super(LearningToDownsample, self).__init__()
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ dw_channels1 = dw_channels[0]
+ dw_channels2 = dw_channels[1]
+
+ self.conv = ConvModule(
+ in_channels,
+ dw_channels1,
+ 3,
+ stride=2,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.dsconv1 = DepthwiseSeparableConvModule(
+ dw_channels1,
+ dw_channels2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ norm_cfg=self.norm_cfg)
+ self.dsconv2 = DepthwiseSeparableConvModule(
+ dw_channels2,
+ out_channels,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ norm_cfg=self.norm_cfg)
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.dsconv1(x)
+ x = self.dsconv2(x)
+ return x
+
+
+class GlobalFeatureExtractor(nn.Module):
+ """Global feature extractor module.
+
+ Args:
+ in_channels (int): Number of input channels of the GFE module.
+ Default: 64
+ block_channels (tuple[int]): Tuple of ints. Each int specifies the
+ number of output channels of each Inverted Residual module.
+ Default: (64, 96, 128)
+ out_channels(int): Number of output channels of the GFE module.
+ Default: 128
+ expand_ratio (int): Adjusts number of channels of the hidden layer
+ in InvertedResidual by this amount.
+ Default: 6
+ num_blocks (tuple[int]): Tuple of ints. Each int specifies the
+ number of times each Inverted Residual module is repeated.
+ The repeated Inverted Residual modules are called a 'group'.
+ Default: (3, 3, 3)
+ strides (tuple[int]): Tuple of ints. Each int specifies
+ the downsampling factor of each 'group'.
+ Default: (2, 2, 1)
+ pool_scales (tuple[int]): Tuple of ints. Each int specifies
+ the parameter required in 'global average pooling' within PPM.
+ Default: (1, 2, 3, 6)
+ conv_cfg (dict | None): Config of conv layers. Default: None
+ norm_cfg (dict | None): Config of norm layers. Default:
+ dict(type='BN')
+ act_cfg (dict): Config of activation layers. Default:
+ dict(type='ReLU')
+ align_corners (bool): align_corners argument of F.interpolate.
+ Default: False
+ """
+
+ def __init__(self,
+ in_channels=64,
+ block_channels=(64, 96, 128),
+ out_channels=128,
+ expand_ratio=6,
+ num_blocks=(3, 3, 3),
+ strides=(2, 2, 1),
+ pool_scales=(1, 2, 3, 6),
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ align_corners=False):
+ super(GlobalFeatureExtractor, self).__init__()
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ assert len(block_channels) == len(num_blocks) == 3
+ self.bottleneck1 = self._make_layer(in_channels, block_channels[0],
+ num_blocks[0], strides[0],
+ expand_ratio)
+ self.bottleneck2 = self._make_layer(block_channels[0],
+ block_channels[1], num_blocks[1],
+ strides[1], expand_ratio)
+ self.bottleneck3 = self._make_layer(block_channels[1],
+ block_channels[2], num_blocks[2],
+ strides[2], expand_ratio)
+ self.ppm = PPM(
+ pool_scales,
+ block_channels[2],
+ block_channels[2] // 4,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg,
+ align_corners=align_corners)
+ self.out = ConvModule(
+ block_channels[2] * 2,
+ out_channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ def _make_layer(self,
+ in_channels,
+ out_channels,
+ blocks,
+ stride=1,
+ expand_ratio=6):
+ layers = [
+ InvertedResidual(
+ in_channels,
+ out_channels,
+ stride,
+ expand_ratio,
+ norm_cfg=self.norm_cfg)
+ ]
+ for i in range(1, blocks):
+ layers.append(
+ InvertedResidual(
+ out_channels,
+ out_channels,
+ 1,
+ expand_ratio,
+ norm_cfg=self.norm_cfg))
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x = self.bottleneck1(x)
+ x = self.bottleneck2(x)
+ x = self.bottleneck3(x)
+ x = torch.cat([x, *self.ppm(x)], dim=1)
+ x = self.out(x)
+ return x
+
+
+class FeatureFusionModule(nn.Module):
+ """Feature fusion module.
+
+ Args:
+ higher_in_channels (int): Number of input channels of the
+ higher-resolution branch.
+ lower_in_channels (int): Number of input channels of the
+ lower-resolution branch.
+ out_channels (int): Number of output channels.
+ conv_cfg (dict | None): Config of conv layers. Default: None
+ norm_cfg (dict | None): Config of norm layers. Default:
+ dict(type='BN')
+ act_cfg (dict): Config of activation layers. Default:
+ dict(type='ReLU')
+ align_corners (bool): align_corners argument of F.interpolate.
+ Default: False
+ """
+
+ def __init__(self,
+ higher_in_channels,
+ lower_in_channels,
+ out_channels,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ align_corners=False):
+ super(FeatureFusionModule, self).__init__()
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ self.align_corners = align_corners
+ self.dwconv = ConvModule(
+ lower_in_channels,
+ out_channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.conv_lower_res = ConvModule(
+ out_channels,
+ out_channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=None)
+ self.conv_higher_res = ConvModule(
+ higher_in_channels,
+ out_channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=None)
+ self.relu = nn.ReLU(True)
+
+ def forward(self, higher_res_feature, lower_res_feature):
+ lower_res_feature = resize(
+ lower_res_feature,
+ size=higher_res_feature.size()[2:],
+ mode='bilinear',
+ align_corners=self.align_corners)
+ lower_res_feature = self.dwconv(lower_res_feature)
+ lower_res_feature = self.conv_lower_res(lower_res_feature)
+
+ higher_res_feature = self.conv_higher_res(higher_res_feature)
+ out = higher_res_feature + lower_res_feature
+ return self.relu(out)
+
+
+@BACKBONES.register_module()
+class FastSCNN(nn.Module):
+ """Fast-SCNN Backbone.
+
+ Args:
+ in_channels (int): Number of input image channels. Default: 3.
+ downsample_dw_channels (tuple[int]): Number of output channels after
+ the first conv layer & the second conv layer in
+ Learning-To-Downsample (LTD) module.
+ Default: (32, 48).
+ global_in_channels (int): Number of input channels of
+ Global Feature Extractor(GFE).
+ Equal to number of output channels of LTD.
+ Default: 64.
+ global_block_channels (tuple[int]): Tuple of integers that describe
+ the output channels for each of the MobileNet-v2 bottleneck
+ residual blocks in GFE.
+ Default: (64, 96, 128).
+ global_block_strides (tuple[int]): Tuple of integers
+ that describe the strides (downsampling factors) for each of the
+ MobileNet-v2 bottleneck residual blocks in GFE.
+ Default: (2, 2, 1).
+ global_out_channels (int): Number of output channels of GFE.
+ Default: 128.
+ higher_in_channels (int): Number of input channels of the higher
+ resolution branch in FFM.
+ Equal to global_in_channels.
+ Default: 64.
+ lower_in_channels (int): Number of input channels of the lower
+ resolution branch in FFM.
+ Equal to global_out_channels.
+ Default: 128.
+ fusion_out_channels (int): Number of output channels of FFM.
+ Default: 128.
+ out_indices (tuple): Tuple of indices of list
+ [higher_res_features, lower_res_features, fusion_output].
+ Often set to (0,1,2) to enable aux. heads.
+ Default: (0, 1, 2).
+ conv_cfg (dict | None): Config of conv layers. Default: None
+ norm_cfg (dict | None): Config of norm layers. Default:
+ dict(type='BN')
+ act_cfg (dict): Config of activation layers. Default:
+ dict(type='ReLU')
+ align_corners (bool): align_corners argument of F.interpolate.
+ Default: False
+ """
+
+ def __init__(self,
+ in_channels=3,
+ downsample_dw_channels=(32, 48),
+ global_in_channels=64,
+ global_block_channels=(64, 96, 128),
+ global_block_strides=(2, 2, 1),
+ global_out_channels=128,
+ higher_in_channels=64,
+ lower_in_channels=128,
+ fusion_out_channels=128,
+ out_indices=(0, 1, 2),
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ align_corners=False):
+
+ super(FastSCNN, self).__init__()
+ if global_in_channels != higher_in_channels:
+ raise AssertionError('Global Input Channels must be the same \
+ with Higher Input Channels!')
+ elif global_out_channels != lower_in_channels:
+ raise AssertionError('Global Output Channels must be the same \
+ with Lower Input Channels!')
+
+ self.in_channels = in_channels
+ self.downsample_dw_channels1 = downsample_dw_channels[0]
+ self.downsample_dw_channels2 = downsample_dw_channels[1]
+ self.global_in_channels = global_in_channels
+ self.global_block_channels = global_block_channels
+ self.global_block_strides = global_block_strides
+ self.global_out_channels = global_out_channels
+ self.higher_in_channels = higher_in_channels
+ self.lower_in_channels = lower_in_channels
+ self.fusion_out_channels = fusion_out_channels
+ self.out_indices = out_indices
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ self.align_corners = align_corners
+ self.learning_to_downsample = LearningToDownsample(
+ in_channels,
+ downsample_dw_channels,
+ global_in_channels,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.global_feature_extractor = GlobalFeatureExtractor(
+ global_in_channels,
+ global_block_channels,
+ global_out_channels,
+ strides=self.global_block_strides,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg,
+ align_corners=self.align_corners)
+ self.feature_fusion = FeatureFusionModule(
+ higher_in_channels,
+ lower_in_channels,
+ fusion_out_channels,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg,
+ align_corners=self.align_corners)
+
+ def init_weights(self, pretrained=None):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ kaiming_init(m)
+ elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
+ constant_init(m, 1)
+
+ def forward(self, x):
+ higher_res_features = self.learning_to_downsample(x)
+ lower_res_features = self.global_feature_extractor(higher_res_features)
+ fusion_output = self.feature_fusion(higher_res_features,
+ lower_res_features)
+
+ outs = [higher_res_features, lower_res_features, fusion_output]
+ outs = [outs[i] for i in self.out_indices]
+ return tuple(outs)
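+
+
+# --- Illustrative usage (editor's sketch, not part of upstream mmseg) ---
+# A minimal smoke test for ``FastSCNN``.  Note that in this vendored copy the
+# class is not re-exported by ``..backbones`` (its import there is commented
+# out), so it would have to be imported from this module directly.  The input
+# size is arbitrary; it only needs to survive the repeated stride-2 stages.
+def _example_fast_scnn_forward():
+    model = FastSCNN()
+    model.init_weights()
+    model.eval()
+    x = torch.rand(1, 3, 128, 128)
+    with torch.no_grad():
+        outs = model(x)
+    # (higher_res_features, lower_res_features, fusion_output)
+    return [o.shape for o in outs]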
diff --git a/annotator/uniformer/mmseg/models/backbones/hrnet.py b/annotator/uniformer/mmseg/models/backbones/hrnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..331ebf3ccb8597b3f507670753789073fc3c946d
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/backbones/hrnet.py
@@ -0,0 +1,555 @@
+import torch.nn as nn
+from annotator.uniformer.mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init,
+ kaiming_init)
+from annotator.uniformer.mmcv.runner import load_checkpoint
+from annotator.uniformer.mmcv.utils.parrots_wrapper import _BatchNorm
+
+from annotator.uniformer.mmseg.ops import Upsample, resize
+from annotator.uniformer.mmseg.utils import get_root_logger
+from ..builder import BACKBONES
+from .resnet import BasicBlock, Bottleneck
+
+
+class HRModule(nn.Module):
+ """High-Resolution Module for HRNet.
+
+ In this module, every branch has 4 BasicBlocks/Bottlenecks. Fusion/Exchange
+ is in this module.
+ """
+
+ def __init__(self,
+ num_branches,
+ blocks,
+ num_blocks,
+ in_channels,
+ num_channels,
+ multiscale_output=True,
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN', requires_grad=True)):
+ super(HRModule, self).__init__()
+ self._check_branches(num_branches, num_blocks, in_channels,
+ num_channels)
+
+ self.in_channels = in_channels
+ self.num_branches = num_branches
+
+ self.multiscale_output = multiscale_output
+ self.norm_cfg = norm_cfg
+ self.conv_cfg = conv_cfg
+ self.with_cp = with_cp
+ self.branches = self._make_branches(num_branches, blocks, num_blocks,
+ num_channels)
+ self.fuse_layers = self._make_fuse_layers()
+ self.relu = nn.ReLU(inplace=False)
+
+ def _check_branches(self, num_branches, num_blocks, in_channels,
+ num_channels):
+ """Check branches configuration."""
+ if num_branches != len(num_blocks):
+ error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_BLOCKS(' \
+ f'{len(num_blocks)})'
+ raise ValueError(error_msg)
+
+ if num_branches != len(num_channels):
+ error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_CHANNELS(' \
+ f'{len(num_channels)})'
+ raise ValueError(error_msg)
+
+ if num_branches != len(in_channels):
+ error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_INCHANNELS(' \
+ f'{len(in_channels)})'
+ raise ValueError(error_msg)
+
+ def _make_one_branch(self,
+ branch_index,
+ block,
+ num_blocks,
+ num_channels,
+ stride=1):
+ """Build one branch."""
+ downsample = None
+ if stride != 1 or \
+ self.in_channels[branch_index] != \
+ num_channels[branch_index] * block.expansion:
+ downsample = nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ self.in_channels[branch_index],
+ num_channels[branch_index] * block.expansion,
+ kernel_size=1,
+ stride=stride,
+ bias=False),
+ build_norm_layer(self.norm_cfg, num_channels[branch_index] *
+ block.expansion)[1])
+
+ layers = []
+ layers.append(
+ block(
+ self.in_channels[branch_index],
+ num_channels[branch_index],
+ stride,
+ downsample=downsample,
+ with_cp=self.with_cp,
+ norm_cfg=self.norm_cfg,
+ conv_cfg=self.conv_cfg))
+ self.in_channels[branch_index] = \
+ num_channels[branch_index] * block.expansion
+ for i in range(1, num_blocks[branch_index]):
+ layers.append(
+ block(
+ self.in_channels[branch_index],
+ num_channels[branch_index],
+ with_cp=self.with_cp,
+ norm_cfg=self.norm_cfg,
+ conv_cfg=self.conv_cfg))
+
+ return nn.Sequential(*layers)
+
+ def _make_branches(self, num_branches, block, num_blocks, num_channels):
+ """Build multiple branch."""
+ branches = []
+
+ for i in range(num_branches):
+ branches.append(
+ self._make_one_branch(i, block, num_blocks, num_channels))
+
+ return nn.ModuleList(branches)
+
+ def _make_fuse_layers(self):
+ """Build fuse layer."""
+ if self.num_branches == 1:
+ return None
+
+ num_branches = self.num_branches
+ in_channels = self.in_channels
+ fuse_layers = []
+ num_out_branches = num_branches if self.multiscale_output else 1
+ for i in range(num_out_branches):
+ fuse_layer = []
+ for j in range(num_branches):
+ if j > i:
+ fuse_layer.append(
+ nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ in_channels[j],
+ in_channels[i],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=False),
+ build_norm_layer(self.norm_cfg, in_channels[i])[1],
+ # we set align_corners=False for HRNet
+ Upsample(
+ scale_factor=2**(j - i),
+ mode='bilinear',
+ align_corners=False)))
+ elif j == i:
+ fuse_layer.append(None)
+ else:
+ conv_downsamples = []
+ for k in range(i - j):
+ if k == i - j - 1:
+ conv_downsamples.append(
+ nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ in_channels[j],
+ in_channels[i],
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False),
+ build_norm_layer(self.norm_cfg,
+ in_channels[i])[1]))
+ else:
+ conv_downsamples.append(
+ nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ in_channels[j],
+ in_channels[j],
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False),
+ build_norm_layer(self.norm_cfg,
+ in_channels[j])[1],
+ nn.ReLU(inplace=False)))
+ fuse_layer.append(nn.Sequential(*conv_downsamples))
+ fuse_layers.append(nn.ModuleList(fuse_layer))
+
+ return nn.ModuleList(fuse_layers)
+
+ def forward(self, x):
+ """Forward function."""
+ if self.num_branches == 1:
+ return [self.branches[0](x[0])]
+
+ for i in range(self.num_branches):
+ x[i] = self.branches[i](x[i])
+
+ x_fuse = []
+ for i in range(len(self.fuse_layers)):
+ y = 0
+ for j in range(self.num_branches):
+ if i == j:
+ y += x[j]
+ elif j > i:
+ y = y + resize(
+ self.fuse_layers[i][j](x[j]),
+ size=x[i].shape[2:],
+ mode='bilinear',
+ align_corners=False)
+ else:
+ y += self.fuse_layers[i][j](x[j])
+ x_fuse.append(self.relu(y))
+ return x_fuse
+
+
+@BACKBONES.register_module()
+class HRNet(nn.Module):
+ """HRNet backbone.
+
+ High-Resolution Representations for Labeling Pixels and Regions
+ arXiv: https://arxiv.org/abs/1904.04514
+
+ Args:
+ extra (dict): detailed configuration for each stage of HRNet.
+ in_channels (int): Number of input image channels. Normally 3.
+ conv_cfg (dict): dictionary to construct and config conv layer.
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed.
+ zero_init_residual (bool): whether to use zero init for last norm layer
+ in resblocks to let them behave as identity.
+
+ Example:
+ >>> from annotator.uniformer.mmseg.models import HRNet
+ >>> import torch
+ >>> extra = dict(
+ >>> stage1=dict(
+ >>> num_modules=1,
+ >>> num_branches=1,
+ >>> block='BOTTLENECK',
+ >>> num_blocks=(4, ),
+ >>> num_channels=(64, )),
+ >>> stage2=dict(
+ >>> num_modules=1,
+ >>> num_branches=2,
+ >>> block='BASIC',
+ >>> num_blocks=(4, 4),
+ >>> num_channels=(32, 64)),
+ >>> stage3=dict(
+ >>> num_modules=4,
+ >>> num_branches=3,
+ >>> block='BASIC',
+ >>> num_blocks=(4, 4, 4),
+ >>> num_channels=(32, 64, 128)),
+ >>> stage4=dict(
+ >>> num_modules=3,
+ >>> num_branches=4,
+ >>> block='BASIC',
+ >>> num_blocks=(4, 4, 4, 4),
+ >>> num_channels=(32, 64, 128, 256)))
+ >>> self = HRNet(extra, in_channels=1)
+ >>> self.eval()
+ >>> inputs = torch.rand(1, 1, 32, 32)
+ >>> level_outputs = self.forward(inputs)
+ >>> for level_out in level_outputs:
+ ... print(tuple(level_out.shape))
+ (1, 32, 8, 8)
+ (1, 64, 4, 4)
+ (1, 128, 2, 2)
+ (1, 256, 1, 1)
+ """
+
+ blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck}
+
+ def __init__(self,
+ extra,
+ in_channels=3,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ norm_eval=False,
+ with_cp=False,
+ zero_init_residual=False):
+ super(HRNet, self).__init__()
+ self.extra = extra
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.norm_eval = norm_eval
+ self.with_cp = with_cp
+ self.zero_init_residual = zero_init_residual
+
+ # stem net
+ self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
+ self.norm2_name, norm2 = build_norm_layer(self.norm_cfg, 64, postfix=2)
+
+ self.conv1 = build_conv_layer(
+ self.conv_cfg,
+ in_channels,
+ 64,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False)
+
+ self.add_module(self.norm1_name, norm1)
+ self.conv2 = build_conv_layer(
+ self.conv_cfg,
+ 64,
+ 64,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False)
+
+ self.add_module(self.norm2_name, norm2)
+ self.relu = nn.ReLU(inplace=True)
+
+ # stage 1
+ self.stage1_cfg = self.extra['stage1']
+ num_channels = self.stage1_cfg['num_channels'][0]
+ block_type = self.stage1_cfg['block']
+ num_blocks = self.stage1_cfg['num_blocks'][0]
+
+ block = self.blocks_dict[block_type]
+ stage1_out_channels = num_channels * block.expansion
+ self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)
+
+ # stage 2
+ self.stage2_cfg = self.extra['stage2']
+ num_channels = self.stage2_cfg['num_channels']
+ block_type = self.stage2_cfg['block']
+
+ block = self.blocks_dict[block_type]
+ num_channels = [channel * block.expansion for channel in num_channels]
+ self.transition1 = self._make_transition_layer([stage1_out_channels],
+ num_channels)
+ self.stage2, pre_stage_channels = self._make_stage(
+ self.stage2_cfg, num_channels)
+
+ # stage 3
+ self.stage3_cfg = self.extra['stage3']
+ num_channels = self.stage3_cfg['num_channels']
+ block_type = self.stage3_cfg['block']
+
+ block = self.blocks_dict[block_type]
+ num_channels = [channel * block.expansion for channel in num_channels]
+ self.transition2 = self._make_transition_layer(pre_stage_channels,
+ num_channels)
+ self.stage3, pre_stage_channels = self._make_stage(
+ self.stage3_cfg, num_channels)
+
+ # stage 4
+ self.stage4_cfg = self.extra['stage4']
+ num_channels = self.stage4_cfg['num_channels']
+ block_type = self.stage4_cfg['block']
+
+ block = self.blocks_dict[block_type]
+ num_channels = [channel * block.expansion for channel in num_channels]
+ self.transition3 = self._make_transition_layer(pre_stage_channels,
+ num_channels)
+ self.stage4, pre_stage_channels = self._make_stage(
+ self.stage4_cfg, num_channels)
+
+ @property
+ def norm1(self):
+ """nn.Module: the normalization layer named "norm1" """
+ return getattr(self, self.norm1_name)
+
+ @property
+ def norm2(self):
+ """nn.Module: the normalization layer named "norm2" """
+ return getattr(self, self.norm2_name)
+
+ def _make_transition_layer(self, num_channels_pre_layer,
+ num_channels_cur_layer):
+ """Make transition layer."""
+ num_branches_cur = len(num_channels_cur_layer)
+ num_branches_pre = len(num_channels_pre_layer)
+
+ transition_layers = []
+ for i in range(num_branches_cur):
+ if i < num_branches_pre:
+ if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
+ transition_layers.append(
+ nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ num_channels_pre_layer[i],
+ num_channels_cur_layer[i],
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False),
+ build_norm_layer(self.norm_cfg,
+ num_channels_cur_layer[i])[1],
+ nn.ReLU(inplace=True)))
+ else:
+ transition_layers.append(None)
+ else:
+ conv_downsamples = []
+ for j in range(i + 1 - num_branches_pre):
+ in_channels = num_channels_pre_layer[-1]
+ out_channels = num_channels_cur_layer[i] \
+ if j == i - num_branches_pre else in_channels
+ conv_downsamples.append(
+ nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False),
+ build_norm_layer(self.norm_cfg, out_channels)[1],
+ nn.ReLU(inplace=True)))
+ transition_layers.append(nn.Sequential(*conv_downsamples))
+
+ return nn.ModuleList(transition_layers)
+
+ def _make_layer(self, block, inplanes, planes, blocks, stride=1):
+ """Make each layer."""
+ downsample = None
+ if stride != 1 or inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ inplanes,
+ planes * block.expansion,
+ kernel_size=1,
+ stride=stride,
+ bias=False),
+ build_norm_layer(self.norm_cfg, planes * block.expansion)[1])
+
+ layers = []
+ layers.append(
+ block(
+ inplanes,
+ planes,
+ stride,
+ downsample=downsample,
+ with_cp=self.with_cp,
+ norm_cfg=self.norm_cfg,
+ conv_cfg=self.conv_cfg))
+ inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(
+ block(
+ inplanes,
+ planes,
+ with_cp=self.with_cp,
+ norm_cfg=self.norm_cfg,
+ conv_cfg=self.conv_cfg))
+
+ return nn.Sequential(*layers)
+
+ def _make_stage(self, layer_config, in_channels, multiscale_output=True):
+ """Make each stage."""
+ num_modules = layer_config['num_modules']
+ num_branches = layer_config['num_branches']
+ num_blocks = layer_config['num_blocks']
+ num_channels = layer_config['num_channels']
+ block = self.blocks_dict[layer_config['block']]
+
+ hr_modules = []
+ for i in range(num_modules):
+ # multi_scale_output is only used for the last module
+ if not multiscale_output and i == num_modules - 1:
+ reset_multiscale_output = False
+ else:
+ reset_multiscale_output = True
+
+ hr_modules.append(
+ HRModule(
+ num_branches,
+ block,
+ num_blocks,
+ in_channels,
+ num_channels,
+ reset_multiscale_output,
+ with_cp=self.with_cp,
+ norm_cfg=self.norm_cfg,
+ conv_cfg=self.conv_cfg))
+
+ return nn.Sequential(*hr_modules), in_channels
+
+ def init_weights(self, pretrained=None):
+ """Initialize the weights in backbone.
+
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+ if isinstance(pretrained, str):
+ logger = get_root_logger()
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
+ elif pretrained is None:
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ kaiming_init(m)
+ elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
+ constant_init(m, 1)
+
+ if self.zero_init_residual:
+ for m in self.modules():
+ if isinstance(m, Bottleneck):
+ constant_init(m.norm3, 0)
+ elif isinstance(m, BasicBlock):
+ constant_init(m.norm2, 0)
+ else:
+ raise TypeError('pretrained must be a str or None')
+
+ def forward(self, x):
+ """Forward function."""
+
+ x = self.conv1(x)
+ x = self.norm1(x)
+ x = self.relu(x)
+ x = self.conv2(x)
+ x = self.norm2(x)
+ x = self.relu(x)
+ x = self.layer1(x)
+
+ x_list = []
+ for i in range(self.stage2_cfg['num_branches']):
+ if self.transition1[i] is not None:
+ x_list.append(self.transition1[i](x))
+ else:
+ x_list.append(x)
+ y_list = self.stage2(x_list)
+
+ x_list = []
+ for i in range(self.stage3_cfg['num_branches']):
+ if self.transition2[i] is not None:
+ x_list.append(self.transition2[i](y_list[-1]))
+ else:
+ x_list.append(y_list[i])
+ y_list = self.stage3(x_list)
+
+ x_list = []
+ for i in range(self.stage4_cfg['num_branches']):
+ if self.transition3[i] is not None:
+ x_list.append(self.transition3[i](y_list[-1]))
+ else:
+ x_list.append(y_list[i])
+ y_list = self.stage4(x_list)
+
+ return y_list
+
+ def train(self, mode=True):
+ """Convert the model into training mode will keeping the normalization
+ layer freezed."""
+ super(HRNet, self).train(mode)
+ if mode and self.norm_eval:
+ for m in self.modules():
+ # trick: eval() has effect on BatchNorm only
+ if isinstance(m, _BatchNorm):
+ m.eval()
diff --git a/annotator/uniformer/mmseg/models/backbones/mobilenet_v2.py b/annotator/uniformer/mmseg/models/backbones/mobilenet_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab6b3791692a0d1b5da3601875711710b7bd01ba
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/backbones/mobilenet_v2.py
@@ -0,0 +1,180 @@
+import logging
+
+import torch.nn as nn
+from annotator.uniformer.mmcv.cnn import ConvModule, constant_init, kaiming_init
+from annotator.uniformer.mmcv.runner import load_checkpoint
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from ..builder import BACKBONES
+from ..utils import InvertedResidual, make_divisible
+
+
+@BACKBONES.register_module()
+class MobileNetV2(nn.Module):
+ """MobileNetV2 backbone.
+
+ Args:
+ widen_factor (float): Width multiplier, multiply number of
+ channels in each layer by this amount. Default: 1.0.
+ strides (Sequence[int], optional): Strides of the first block of each
+ layer. If not specified, default config in ``arch_setting`` will
+ be used.
+ dilations (Sequence[int]): Dilation of each layer.
+ out_indices (None or Sequence[int]): Output from which stages.
+ Default: (7, ).
+ frozen_stages (int): Stages to be frozen (all param fixed).
+ Default: -1, which means not freezing any parameters.
+ conv_cfg (dict): Config dict for convolution layer.
+ Default: None, which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='BN').
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='ReLU6').
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Default: False.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ """
+
+ # Parameters to build layers. 3 parameters are needed to construct a
+ # layer, from left to right: expand_ratio, channel, num_blocks.
+ arch_settings = [[1, 16, 1], [6, 24, 2], [6, 32, 3], [6, 64, 4],
+ [6, 96, 3], [6, 160, 3], [6, 320, 1]]
+
+ def __init__(self,
+ widen_factor=1.,
+ strides=(1, 2, 2, 2, 1, 2, 1),
+ dilations=(1, 1, 1, 1, 1, 1, 1),
+ out_indices=(1, 2, 4, 6),
+ frozen_stages=-1,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU6'),
+ norm_eval=False,
+ with_cp=False):
+ super(MobileNetV2, self).__init__()
+ self.widen_factor = widen_factor
+ self.strides = strides
+ self.dilations = dilations
+ assert len(strides) == len(dilations) == len(self.arch_settings)
+ self.out_indices = out_indices
+ for index in out_indices:
+ if index not in range(0, 7):
+ raise ValueError('the item in out_indices must be in '
+ f'range(0, 7). But received {index}')
+
+ if frozen_stages not in range(-1, 7):
+ raise ValueError('frozen_stages must be in range(-1, 7). '
+ f'But received {frozen_stages}')
+ self.out_indices = out_indices
+ self.frozen_stages = frozen_stages
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ self.norm_eval = norm_eval
+ self.with_cp = with_cp
+
+ self.in_channels = make_divisible(32 * widen_factor, 8)
+
+ self.conv1 = ConvModule(
+ in_channels=3,
+ out_channels=self.in_channels,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ self.layers = []
+
+ for i, layer_cfg in enumerate(self.arch_settings):
+ expand_ratio, channel, num_blocks = layer_cfg
+ stride = self.strides[i]
+ dilation = self.dilations[i]
+ out_channels = make_divisible(channel * widen_factor, 8)
+ inverted_res_layer = self.make_layer(
+ out_channels=out_channels,
+ num_blocks=num_blocks,
+ stride=stride,
+ dilation=dilation,
+ expand_ratio=expand_ratio)
+ layer_name = f'layer{i + 1}'
+ self.add_module(layer_name, inverted_res_layer)
+ self.layers.append(layer_name)
+
+ def make_layer(self, out_channels, num_blocks, stride, dilation,
+ expand_ratio):
+ """Stack InvertedResidual blocks to build a layer for MobileNetV2.
+
+ Args:
+ out_channels (int): out_channels of block.
+ num_blocks (int): Number of blocks.
+ stride (int): Stride of the first block.
+ dilation (int): Dilation of the first block.
+ expand_ratio (int): Expand the number of channels of the
+ hidden layer in InvertedResidual by this ratio.
+ """
+ layers = []
+ for i in range(num_blocks):
+ layers.append(
+ InvertedResidual(
+ self.in_channels,
+ out_channels,
+ stride if i == 0 else 1,
+ expand_ratio=expand_ratio,
+ dilation=dilation if i == 0 else 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg,
+ with_cp=self.with_cp))
+ self.in_channels = out_channels
+
+ return nn.Sequential(*layers)
+
+ def init_weights(self, pretrained=None):
+ if isinstance(pretrained, str):
+ logger = logging.getLogger()
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
+ elif pretrained is None:
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ kaiming_init(m)
+ elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
+ constant_init(m, 1)
+ else:
+ raise TypeError('pretrained must be a str or None')
+
+ def forward(self, x):
+ x = self.conv1(x)
+
+ outs = []
+ for i, layer_name in enumerate(self.layers):
+ layer = getattr(self, layer_name)
+ x = layer(x)
+ if i in self.out_indices:
+ outs.append(x)
+
+ if len(outs) == 1:
+ return outs[0]
+ else:
+ return tuple(outs)
+
+ def _freeze_stages(self):
+ if self.frozen_stages >= 0:
+ for param in self.conv1.parameters():
+ param.requires_grad = False
+ for i in range(1, self.frozen_stages + 1):
+ layer = getattr(self, f'layer{i}')
+ layer.eval()
+ for param in layer.parameters():
+ param.requires_grad = False
+
+ def train(self, mode=True):
+ super(MobileNetV2, self).train(mode)
+ self._freeze_stages()
+ if mode and self.norm_eval:
+ for m in self.modules():
+ if isinstance(m, _BatchNorm):
+ m.eval()
diff --git a/annotator/uniformer/mmseg/models/backbones/mobilenet_v3.py b/annotator/uniformer/mmseg/models/backbones/mobilenet_v3.py
new file mode 100644
index 0000000000000000000000000000000000000000..16817400b4102899794fe64c9644713a4e54e2f9
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/backbones/mobilenet_v3.py
@@ -0,0 +1,255 @@
+import logging
+
+import annotator.uniformer.mmcv as mmcv
+import torch.nn as nn
+from annotator.uniformer.mmcv.cnn import ConvModule, constant_init, kaiming_init
+from annotator.uniformer.mmcv.cnn.bricks import Conv2dAdaptivePadding
+from annotator.uniformer.mmcv.runner import load_checkpoint
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from ..builder import BACKBONES
+from ..utils import InvertedResidualV3 as InvertedResidual
+
+
+@BACKBONES.register_module()
+class MobileNetV3(nn.Module):
+ """MobileNetV3 backbone.
+
+    This backbone is the improved implementation of `Searching for MobileNetV3
+    <https://arxiv.org/abs/1905.02244>`_.
+
+ Args:
+        arch (str): Architecture of MobileNetV3, from {'small', 'large'}.
+ Default: 'small'.
+ conv_cfg (dict): Config dict for convolution layer.
+ Default: None, which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='BN').
+ out_indices (tuple[int]): Output from which layer.
+ Default: (0, 1, 12).
+ frozen_stages (int): Stages to be frozen (all param fixed).
+ Default: -1, which means not freezing any parameters.
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Default: False.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save
+ some memory while slowing down the training speed.
+ Default: False.
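+
+    Example:
+        An illustrative sketch added for documentation, not from the original
+        implementation. It assumes the class is exported from
+        ``annotator.uniformer.mmseg.models`` like the other backbones and
+        uses the default ``arch='small'`` with ``out_indices=(0, 1, 12)``.
+
+        >>> from annotator.uniformer.mmseg.models import MobileNetV3
+        >>> import torch
+        >>> self = MobileNetV3(arch='small')
+        >>> self.eval()
+        >>> inputs = torch.rand(1, 3, 224, 224)
+        >>> outs = self.forward(inputs)
+        >>> # one feature map per entry of the default out_indices
+        >>> len(outs)
+        3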
+ """
+ # Parameters to build each block:
+ # [kernel size, mid channels, out channels, with_se, act type, stride]
+ arch_settings = {
+ 'small': [[3, 16, 16, True, 'ReLU', 2], # block0 layer1 os=4
+ [3, 72, 24, False, 'ReLU', 2], # block1 layer2 os=8
+ [3, 88, 24, False, 'ReLU', 1],
+ [5, 96, 40, True, 'HSwish', 2], # block2 layer4 os=16
+ [5, 240, 40, True, 'HSwish', 1],
+ [5, 240, 40, True, 'HSwish', 1],
+ [5, 120, 48, True, 'HSwish', 1], # block3 layer7 os=16
+ [5, 144, 48, True, 'HSwish', 1],
+ [5, 288, 96, True, 'HSwish', 2], # block4 layer9 os=32
+ [5, 576, 96, True, 'HSwish', 1],
+ [5, 576, 96, True, 'HSwish', 1]],
+ 'large': [[3, 16, 16, False, 'ReLU', 1], # block0 layer1 os=2
+ [3, 64, 24, False, 'ReLU', 2], # block1 layer2 os=4
+ [3, 72, 24, False, 'ReLU', 1],
+ [5, 72, 40, True, 'ReLU', 2], # block2 layer4 os=8
+ [5, 120, 40, True, 'ReLU', 1],
+ [5, 120, 40, True, 'ReLU', 1],
+ [3, 240, 80, False, 'HSwish', 2], # block3 layer7 os=16
+ [3, 200, 80, False, 'HSwish', 1],
+ [3, 184, 80, False, 'HSwish', 1],
+ [3, 184, 80, False, 'HSwish', 1],
+ [3, 480, 112, True, 'HSwish', 1], # block4 layer11 os=16
+ [3, 672, 112, True, 'HSwish', 1],
+ [5, 672, 160, True, 'HSwish', 2], # block5 layer13 os=32
+ [5, 960, 160, True, 'HSwish', 1],
+ [5, 960, 160, True, 'HSwish', 1]]
+ } # yapf: disable
+
+ def __init__(self,
+ arch='small',
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ out_indices=(0, 1, 12),
+ frozen_stages=-1,
+ reduction_factor=1,
+ norm_eval=False,
+ with_cp=False):
+ super(MobileNetV3, self).__init__()
+ assert arch in self.arch_settings
+ assert isinstance(reduction_factor, int) and reduction_factor > 0
+ assert mmcv.is_tuple_of(out_indices, int)
+ for index in out_indices:
+ if index not in range(0, len(self.arch_settings[arch]) + 2):
+ raise ValueError(
+                    'the item in out_indices must be in '
+ f'range(0, {len(self.arch_settings[arch])+2}). '
+ f'But received {index}')
+
+ if frozen_stages not in range(-1, len(self.arch_settings[arch]) + 2):
+ raise ValueError('frozen_stages must be in range(-1, '
+ f'{len(self.arch_settings[arch])+2}). '
+ f'But received {frozen_stages}')
+ self.arch = arch
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.out_indices = out_indices
+ self.frozen_stages = frozen_stages
+ self.reduction_factor = reduction_factor
+ self.norm_eval = norm_eval
+ self.with_cp = with_cp
+ self.layers = self._make_layer()
+
+ def _make_layer(self):
+ layers = []
+
+ # build the first layer (layer0)
+ in_channels = 16
+ layer = ConvModule(
+ in_channels=3,
+ out_channels=in_channels,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ conv_cfg=dict(type='Conv2dAdaptivePadding'),
+ norm_cfg=self.norm_cfg,
+ act_cfg=dict(type='HSwish'))
+ self.add_module('layer0', layer)
+ layers.append('layer0')
+
+ layer_setting = self.arch_settings[self.arch]
+ for i, params in enumerate(layer_setting):
+ (kernel_size, mid_channels, out_channels, with_se, act,
+ stride) = params
+
+ if self.arch == 'large' and i >= 12 or self.arch == 'small' and \
+ i >= 8:
+ mid_channels = mid_channels // self.reduction_factor
+ out_channels = out_channels // self.reduction_factor
+
+ if with_se:
+ se_cfg = dict(
+ channels=mid_channels,
+ ratio=4,
+ act_cfg=(dict(type='ReLU'),
+ dict(type='HSigmoid', bias=3.0, divisor=6.0)))
+ else:
+ se_cfg = None
+
+ layer = InvertedResidual(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ mid_channels=mid_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ se_cfg=se_cfg,
+ with_expand_conv=(in_channels != mid_channels),
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=dict(type=act),
+ with_cp=self.with_cp)
+ in_channels = out_channels
+ layer_name = 'layer{}'.format(i + 1)
+ self.add_module(layer_name, layer)
+ layers.append(layer_name)
+
+ # build the last layer
+ # block5 layer12 os=32 for small model
+ # block6 layer16 os=32 for large model
+ layer = ConvModule(
+ in_channels=in_channels,
+ out_channels=576 if self.arch == 'small' else 960,
+ kernel_size=1,
+ stride=1,
+ dilation=4,
+ padding=0,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=dict(type='HSwish'))
+ layer_name = 'layer{}'.format(len(layer_setting) + 1)
+ self.add_module(layer_name, layer)
+ layers.append(layer_name)
+
+ # next, convert backbone MobileNetV3 to a semantic segmentation version
+ if self.arch == 'small':
+ self.layer4.depthwise_conv.conv.stride = (1, 1)
+ self.layer9.depthwise_conv.conv.stride = (1, 1)
+ for i in range(4, len(layers)):
+ layer = getattr(self, layers[i])
+ if isinstance(layer, InvertedResidual):
+ modified_module = layer.depthwise_conv.conv
+ else:
+ modified_module = layer.conv
+
+ if i < 9:
+ modified_module.dilation = (2, 2)
+ pad = 2
+ else:
+ modified_module.dilation = (4, 4)
+ pad = 4
+
+ if not isinstance(modified_module, Conv2dAdaptivePadding):
+ # Adjust padding
+ pad *= (modified_module.kernel_size[0] - 1) // 2
+ modified_module.padding = (pad, pad)
+ else:
+ self.layer7.depthwise_conv.conv.stride = (1, 1)
+ self.layer13.depthwise_conv.conv.stride = (1, 1)
+ for i in range(7, len(layers)):
+ layer = getattr(self, layers[i])
+ if isinstance(layer, InvertedResidual):
+ modified_module = layer.depthwise_conv.conv
+ else:
+ modified_module = layer.conv
+
+ if i < 13:
+ modified_module.dilation = (2, 2)
+ pad = 2
+ else:
+ modified_module.dilation = (4, 4)
+ pad = 4
+
+ if not isinstance(modified_module, Conv2dAdaptivePadding):
+ # Adjust padding
+ pad *= (modified_module.kernel_size[0] - 1) // 2
+ modified_module.padding = (pad, pad)
+
+ return layers
+
+ def init_weights(self, pretrained=None):
+ if isinstance(pretrained, str):
+ logger = logging.getLogger()
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
+ elif pretrained is None:
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ kaiming_init(m)
+ elif isinstance(m, nn.BatchNorm2d):
+ constant_init(m, 1)
+ else:
+ raise TypeError('pretrained must be a str or None')
+
+ def forward(self, x):
+ outs = []
+ for i, layer_name in enumerate(self.layers):
+ layer = getattr(self, layer_name)
+ x = layer(x)
+ if i in self.out_indices:
+ outs.append(x)
+ return outs
+
+ def _freeze_stages(self):
+ for i in range(self.frozen_stages + 1):
+ layer = getattr(self, f'layer{i}')
+ layer.eval()
+ for param in layer.parameters():
+ param.requires_grad = False
+
+ def train(self, mode=True):
+ super(MobileNetV3, self).train(mode)
+ self._freeze_stages()
+ if mode and self.norm_eval:
+ for m in self.modules():
+ if isinstance(m, _BatchNorm):
+ m.eval()
diff --git a/annotator/uniformer/mmseg/models/backbones/resnest.py b/annotator/uniformer/mmseg/models/backbones/resnest.py
new file mode 100644
index 0000000000000000000000000000000000000000..b45a837f395230029e9d4194ff9f7f2f8f7067b0
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/backbones/resnest.py
@@ -0,0 +1,314 @@
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as cp
+from annotator.uniformer.mmcv.cnn import build_conv_layer, build_norm_layer
+
+from ..builder import BACKBONES
+from ..utils import ResLayer
+from .resnet import Bottleneck as _Bottleneck
+from .resnet import ResNetV1d
+
+
+class RSoftmax(nn.Module):
+ """Radix Softmax module in ``SplitAttentionConv2d``.
+
+ Args:
+ radix (int): Radix of input.
+ groups (int): Groups of input.
+ """
+
+ def __init__(self, radix, groups):
+ super().__init__()
+ self.radix = radix
+ self.groups = groups
+
+ def forward(self, x):
+ batch = x.size(0)
+ if self.radix > 1:
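+            # reshape to (batch, groups, radix, -1) and swap axes so that
+            # dim=1 is the radix axis; the softmax below then normalizes
+            # across the radix splits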
+ x = x.view(batch, self.groups, self.radix, -1).transpose(1, 2)
+ x = F.softmax(x, dim=1)
+ x = x.reshape(batch, -1)
+ else:
+ x = torch.sigmoid(x)
+ return x
+
+
+class SplitAttentionConv2d(nn.Module):
+ """Split-Attention Conv2d in ResNeSt.
+
+ Args:
+ in_channels (int): Same as nn.Conv2d.
+        channels (int): Same as the out_channels of nn.Conv2d.
+ kernel_size (int | tuple[int]): Same as nn.Conv2d.
+ stride (int | tuple[int]): Same as nn.Conv2d.
+ padding (int | tuple[int]): Same as nn.Conv2d.
+ dilation (int | tuple[int]): Same as nn.Conv2d.
+ groups (int): Same as nn.Conv2d.
+        radix (int): Radix of SplitAttentionConv2d. Default: 2.
+ reduction_factor (int): Reduction factor of inter_channels. Default: 4.
+ conv_cfg (dict): Config dict for convolution layer. Default: None,
+ which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer. Default: None.
+ dcn (dict): Config dict for DCN. Default: None.
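+
+    Example:
+        An illustrative sketch added for documentation; the channel and
+        spatial sizes are arbitrary and the module is assumed to be used as
+        defined in this file.
+
+        >>> import torch
+        >>> conv = SplitAttentionConv2d(64, 64, kernel_size=3, padding=1)
+        >>> conv(torch.rand(2, 64, 56, 56)).shape
+        torch.Size([2, 64, 56, 56])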
+ """
+
+ def __init__(self,
+ in_channels,
+ channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ radix=2,
+ reduction_factor=4,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ dcn=None):
+ super(SplitAttentionConv2d, self).__init__()
+ inter_channels = max(in_channels * radix // reduction_factor, 32)
+ self.radix = radix
+ self.groups = groups
+ self.channels = channels
+ self.with_dcn = dcn is not None
+ self.dcn = dcn
+ fallback_on_stride = False
+ if self.with_dcn:
+ fallback_on_stride = self.dcn.pop('fallback_on_stride', False)
+ if self.with_dcn and not fallback_on_stride:
+ assert conv_cfg is None, 'conv_cfg must be None for DCN'
+ conv_cfg = dcn
+ self.conv = build_conv_layer(
+ conv_cfg,
+ in_channels,
+ channels * radix,
+ kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups * radix,
+ bias=False)
+ self.norm0_name, norm0 = build_norm_layer(
+ norm_cfg, channels * radix, postfix=0)
+ self.add_module(self.norm0_name, norm0)
+ self.relu = nn.ReLU(inplace=True)
+ self.fc1 = build_conv_layer(
+ None, channels, inter_channels, 1, groups=self.groups)
+ self.norm1_name, norm1 = build_norm_layer(
+ norm_cfg, inter_channels, postfix=1)
+ self.add_module(self.norm1_name, norm1)
+ self.fc2 = build_conv_layer(
+ None, inter_channels, channels * radix, 1, groups=self.groups)
+ self.rsoftmax = RSoftmax(radix, groups)
+
+ @property
+ def norm0(self):
+ """nn.Module: the normalization layer named "norm0" """
+ return getattr(self, self.norm0_name)
+
+ @property
+ def norm1(self):
+ """nn.Module: the normalization layer named "norm1" """
+ return getattr(self, self.norm1_name)
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.norm0(x)
+ x = self.relu(x)
+
+        batch = x.size(0)
+ if self.radix > 1:
+ splits = x.view(batch, self.radix, -1, *x.shape[2:])
+ gap = splits.sum(dim=1)
+ else:
+ gap = x
+ gap = F.adaptive_avg_pool2d(gap, 1)
+ gap = self.fc1(gap)
+
+ gap = self.norm1(gap)
+ gap = self.relu(gap)
+
+ atten = self.fc2(gap)
+ atten = self.rsoftmax(atten).view(batch, -1, 1, 1)
+
+ if self.radix > 1:
+ attens = atten.view(batch, self.radix, -1, *atten.shape[2:])
+ out = torch.sum(attens * splits, dim=1)
+ else:
+ out = atten * x
+ return out.contiguous()
+
+
+class Bottleneck(_Bottleneck):
+ """Bottleneck block for ResNeSt.
+
+ Args:
+        inplanes (int): Input planes of this block.
+ planes (int): Middle planes of this block.
+ groups (int): Groups of conv2.
+        base_width (int): Base width per group of conv2. 64x4d indicates
+            ``groups=64, base_width=4`` and 32x8d indicates
+            ``groups=32, base_width=8``.
+        radix (int): Radix of SplitAttentionConv2d. Default: 2.
+ reduction_factor (int): Reduction factor of inter_channels in
+ SplitAttentionConv2d. Default: 4.
+ avg_down_stride (bool): Whether to use average pool for stride in
+ Bottleneck. Default: True.
+ kwargs (dict): Key word arguments for base class.
+ """
+ expansion = 4
+
+ def __init__(self,
+ inplanes,
+ planes,
+ groups=1,
+ base_width=4,
+ base_channels=64,
+ radix=2,
+ reduction_factor=4,
+ avg_down_stride=True,
+ **kwargs):
+ """Bottleneck block for ResNeSt."""
+ super(Bottleneck, self).__init__(inplanes, planes, **kwargs)
+
+ if groups == 1:
+ width = self.planes
+ else:
+ width = math.floor(self.planes *
+ (base_width / base_channels)) * groups
+
+ self.avg_down_stride = avg_down_stride and self.conv2_stride > 1
+
+ self.norm1_name, norm1 = build_norm_layer(
+ self.norm_cfg, width, postfix=1)
+ self.norm3_name, norm3 = build_norm_layer(
+ self.norm_cfg, self.planes * self.expansion, postfix=3)
+
+ self.conv1 = build_conv_layer(
+ self.conv_cfg,
+ self.inplanes,
+ width,
+ kernel_size=1,
+ stride=self.conv1_stride,
+ bias=False)
+ self.add_module(self.norm1_name, norm1)
+ self.with_modulated_dcn = False
+ self.conv2 = SplitAttentionConv2d(
+ width,
+ width,
+ kernel_size=3,
+ stride=1 if self.avg_down_stride else self.conv2_stride,
+ padding=self.dilation,
+ dilation=self.dilation,
+ groups=groups,
+ radix=radix,
+ reduction_factor=reduction_factor,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ dcn=self.dcn)
+ delattr(self, self.norm2_name)
+
+ if self.avg_down_stride:
+ self.avd_layer = nn.AvgPool2d(3, self.conv2_stride, padding=1)
+
+ self.conv3 = build_conv_layer(
+ self.conv_cfg,
+ width,
+ self.planes * self.expansion,
+ kernel_size=1,
+ bias=False)
+ self.add_module(self.norm3_name, norm3)
+
+ def forward(self, x):
+
+ def _inner_forward(x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.norm1(out)
+ out = self.relu(out)
+
+ if self.with_plugins:
+ out = self.forward_plugin(out, self.after_conv1_plugin_names)
+
+ out = self.conv2(out)
+
+ if self.avg_down_stride:
+ out = self.avd_layer(out)
+
+ if self.with_plugins:
+ out = self.forward_plugin(out, self.after_conv2_plugin_names)
+
+ out = self.conv3(out)
+ out = self.norm3(out)
+
+ if self.with_plugins:
+ out = self.forward_plugin(out, self.after_conv3_plugin_names)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+
+ return out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ out = self.relu(out)
+
+ return out
+
+
+@BACKBONES.register_module()
+class ResNeSt(ResNetV1d):
+ """ResNeSt backbone.
+
+ Args:
+ groups (int): Number of groups of Bottleneck. Default: 1
+ base_width (int): Base width of Bottleneck. Default: 4
+        radix (int): Radix of SplitAttentionConv2d. Default: 2.
+ reduction_factor (int): Reduction factor of inter_channels in
+ SplitAttentionConv2d. Default: 4.
+ avg_down_stride (bool): Whether to use average pool for stride in
+ Bottleneck. Default: True.
+ kwargs (dict): Keyword arguments for ResNet.
+ """
+
+ arch_settings = {
+ 50: (Bottleneck, (3, 4, 6, 3)),
+ 101: (Bottleneck, (3, 4, 23, 3)),
+ 152: (Bottleneck, (3, 8, 36, 3)),
+ 200: (Bottleneck, (3, 24, 36, 3))
+ }
+
+ def __init__(self,
+ groups=1,
+ base_width=4,
+ radix=2,
+ reduction_factor=4,
+ avg_down_stride=True,
+ **kwargs):
+ self.groups = groups
+ self.base_width = base_width
+ self.radix = radix
+ self.reduction_factor = reduction_factor
+ self.avg_down_stride = avg_down_stride
+ super(ResNeSt, self).__init__(**kwargs)
+
+ def make_res_layer(self, **kwargs):
+ """Pack all blocks in a stage into a ``ResLayer``."""
+ return ResLayer(
+ groups=self.groups,
+ base_width=self.base_width,
+ base_channels=self.base_channels,
+ radix=self.radix,
+ reduction_factor=self.reduction_factor,
+ avg_down_stride=self.avg_down_stride,
+ **kwargs)
diff --git a/annotator/uniformer/mmseg/models/backbones/resnet.py b/annotator/uniformer/mmseg/models/backbones/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e52bf048d28ecb069db4728e5f05ad85ac53198
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/backbones/resnet.py
@@ -0,0 +1,688 @@
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+from annotator.uniformer.mmcv.cnn import (build_conv_layer, build_norm_layer, build_plugin_layer,
+ constant_init, kaiming_init)
+from annotator.uniformer.mmcv.runner import load_checkpoint
+from annotator.uniformer.mmcv.utils.parrots_wrapper import _BatchNorm
+
+from annotator.uniformer.mmseg.utils import get_root_logger
+from ..builder import BACKBONES
+from ..utils import ResLayer
+
+
+class BasicBlock(nn.Module):
+ """Basic block for ResNet."""
+
+ expansion = 1
+
+ def __init__(self,
+ inplanes,
+ planes,
+ stride=1,
+ dilation=1,
+ downsample=None,
+ style='pytorch',
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ dcn=None,
+ plugins=None):
+ super(BasicBlock, self).__init__()
+ assert dcn is None, 'Not implemented yet.'
+ assert plugins is None, 'Not implemented yet.'
+
+ self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
+ self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)
+
+ self.conv1 = build_conv_layer(
+ conv_cfg,
+ inplanes,
+ planes,
+ 3,
+ stride=stride,
+ padding=dilation,
+ dilation=dilation,
+ bias=False)
+ self.add_module(self.norm1_name, norm1)
+ self.conv2 = build_conv_layer(
+ conv_cfg, planes, planes, 3, padding=1, bias=False)
+ self.add_module(self.norm2_name, norm2)
+
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+ self.dilation = dilation
+ self.with_cp = with_cp
+
+ @property
+ def norm1(self):
+ """nn.Module: normalization layer after the first convolution layer"""
+ return getattr(self, self.norm1_name)
+
+ @property
+ def norm2(self):
+ """nn.Module: normalization layer after the second convolution layer"""
+ return getattr(self, self.norm2_name)
+
+ def forward(self, x):
+ """Forward function."""
+
+ def _inner_forward(x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.norm1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.norm2(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+
+ return out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ """Bottleneck block for ResNet.
+
+ If style is "pytorch", the stride-two layer is the 3x3 conv layer, if it is
+ "caffe", the stride-two layer is the first 1x1 conv layer.
+ """
+
+ expansion = 4
+
+ def __init__(self,
+ inplanes,
+ planes,
+ stride=1,
+ dilation=1,
+ downsample=None,
+ style='pytorch',
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ dcn=None,
+ plugins=None):
+ super(Bottleneck, self).__init__()
+ assert style in ['pytorch', 'caffe']
+ assert dcn is None or isinstance(dcn, dict)
+ assert plugins is None or isinstance(plugins, list)
+ if plugins is not None:
+ allowed_position = ['after_conv1', 'after_conv2', 'after_conv3']
+ assert all(p['position'] in allowed_position for p in plugins)
+
+ self.inplanes = inplanes
+ self.planes = planes
+ self.stride = stride
+ self.dilation = dilation
+ self.style = style
+ self.with_cp = with_cp
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.dcn = dcn
+ self.with_dcn = dcn is not None
+ self.plugins = plugins
+ self.with_plugins = plugins is not None
+
+ if self.with_plugins:
+ # collect plugins for conv1/conv2/conv3
+ self.after_conv1_plugins = [
+ plugin['cfg'] for plugin in plugins
+ if plugin['position'] == 'after_conv1'
+ ]
+ self.after_conv2_plugins = [
+ plugin['cfg'] for plugin in plugins
+ if plugin['position'] == 'after_conv2'
+ ]
+ self.after_conv3_plugins = [
+ plugin['cfg'] for plugin in plugins
+ if plugin['position'] == 'after_conv3'
+ ]
+
+ if self.style == 'pytorch':
+ self.conv1_stride = 1
+ self.conv2_stride = stride
+ else:
+ self.conv1_stride = stride
+ self.conv2_stride = 1
+
+ self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
+ self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)
+ self.norm3_name, norm3 = build_norm_layer(
+ norm_cfg, planes * self.expansion, postfix=3)
+
+ self.conv1 = build_conv_layer(
+ conv_cfg,
+ inplanes,
+ planes,
+ kernel_size=1,
+ stride=self.conv1_stride,
+ bias=False)
+ self.add_module(self.norm1_name, norm1)
+ fallback_on_stride = False
+ if self.with_dcn:
+ fallback_on_stride = dcn.pop('fallback_on_stride', False)
+ if not self.with_dcn or fallback_on_stride:
+ self.conv2 = build_conv_layer(
+ conv_cfg,
+ planes,
+ planes,
+ kernel_size=3,
+ stride=self.conv2_stride,
+ padding=dilation,
+ dilation=dilation,
+ bias=False)
+ else:
+ assert self.conv_cfg is None, 'conv_cfg must be None for DCN'
+ self.conv2 = build_conv_layer(
+ dcn,
+ planes,
+ planes,
+ kernel_size=3,
+ stride=self.conv2_stride,
+ padding=dilation,
+ dilation=dilation,
+ bias=False)
+
+ self.add_module(self.norm2_name, norm2)
+ self.conv3 = build_conv_layer(
+ conv_cfg,
+ planes,
+ planes * self.expansion,
+ kernel_size=1,
+ bias=False)
+ self.add_module(self.norm3_name, norm3)
+
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+
+ if self.with_plugins:
+ self.after_conv1_plugin_names = self.make_block_plugins(
+ planes, self.after_conv1_plugins)
+ self.after_conv2_plugin_names = self.make_block_plugins(
+ planes, self.after_conv2_plugins)
+ self.after_conv3_plugin_names = self.make_block_plugins(
+ planes * self.expansion, self.after_conv3_plugins)
+
+ def make_block_plugins(self, in_channels, plugins):
+ """make plugins for block.
+
+ Args:
+ in_channels (int): Input channels of plugin.
+ plugins (list[dict]): List of plugins cfg to build.
+
+ Returns:
+ list[str]: List of the names of plugin.
+ """
+ assert isinstance(plugins, list)
+ plugin_names = []
+ for plugin in plugins:
+ plugin = plugin.copy()
+ name, layer = build_plugin_layer(
+ plugin,
+ in_channels=in_channels,
+ postfix=plugin.pop('postfix', ''))
+ assert not hasattr(self, name), f'duplicate plugin {name}'
+ self.add_module(name, layer)
+ plugin_names.append(name)
+ return plugin_names
+
+ def forward_plugin(self, x, plugin_names):
+ """Forward function for plugins."""
+ out = x
+ for name in plugin_names:
+            out = getattr(self, name)(out)
+ return out
+
+ @property
+ def norm1(self):
+ """nn.Module: normalization layer after the first convolution layer"""
+ return getattr(self, self.norm1_name)
+
+ @property
+ def norm2(self):
+ """nn.Module: normalization layer after the second convolution layer"""
+ return getattr(self, self.norm2_name)
+
+ @property
+ def norm3(self):
+ """nn.Module: normalization layer after the third convolution layer"""
+ return getattr(self, self.norm3_name)
+
+ def forward(self, x):
+ """Forward function."""
+
+ def _inner_forward(x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.norm1(out)
+ out = self.relu(out)
+
+ if self.with_plugins:
+ out = self.forward_plugin(out, self.after_conv1_plugin_names)
+
+ out = self.conv2(out)
+ out = self.norm2(out)
+ out = self.relu(out)
+
+ if self.with_plugins:
+ out = self.forward_plugin(out, self.after_conv2_plugin_names)
+
+ out = self.conv3(out)
+ out = self.norm3(out)
+
+ if self.with_plugins:
+ out = self.forward_plugin(out, self.after_conv3_plugin_names)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+
+ return out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ out = self.relu(out)
+
+ return out
+
+
+@BACKBONES.register_module()
+class ResNet(nn.Module):
+ """ResNet backbone.
+
+ Args:
+ depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
+ in_channels (int): Number of input image channels. Default" 3.
+ stem_channels (int): Number of stem channels. Default: 64.
+ base_channels (int): Number of base channels of res layer. Default: 64.
+ num_stages (int): Resnet stages, normally 4.
+ strides (Sequence[int]): Strides of the first block of each stage.
+ dilations (Sequence[int]): Dilation of each stage.
+ out_indices (Sequence[int]): Output from which stages.
+ style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
+ layer is the 3x3 conv layer, otherwise the stride-two layer is
+ the first 1x1 conv layer.
+        deep_stem (bool): Replace the 7x7 conv in the input stem with three
+            3x3 convs.
+ avg_down (bool): Use AvgPool instead of stride conv when
+ downsampling in the bottleneck.
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+ -1 means not freezing any parameters.
+ norm_cfg (dict): Dictionary to construct and config norm layer.
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only.
+ plugins (list[dict]): List of plugins for stages, each dict contains:
+
+ - cfg (dict, required): Cfg dict to build plugin.
+
+ - position (str, required): Position inside block to insert plugin,
+ options: 'after_conv1', 'after_conv2', 'after_conv3'.
+
+        - stages (tuple[bool], optional): Stages to apply plugin, length
+          should be same as 'num_stages'.
+        multi_grid (Sequence[int]|None): Multi grid dilation rates of last
+            stage. Default: None.
+        contract_dilation (bool): Whether to contract the first dilation of
+            each layer. Default: False.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed.
+ zero_init_residual (bool): Whether to use zero init for last norm layer
+ in resblocks to let them behave as identity.
+
+ Example:
+ >>> from annotator.uniformer.mmseg.models import ResNet
+ >>> import torch
+ >>> self = ResNet(depth=18)
+ >>> self.eval()
+ >>> inputs = torch.rand(1, 3, 32, 32)
+ >>> level_outputs = self.forward(inputs)
+ >>> for level_out in level_outputs:
+ ... print(tuple(level_out.shape))
+ (1, 64, 8, 8)
+ (1, 128, 4, 4)
+ (1, 256, 2, 2)
+ (1, 512, 1, 1)
+ """
+
+ arch_settings = {
+ 18: (BasicBlock, (2, 2, 2, 2)),
+ 34: (BasicBlock, (3, 4, 6, 3)),
+ 50: (Bottleneck, (3, 4, 6, 3)),
+ 101: (Bottleneck, (3, 4, 23, 3)),
+ 152: (Bottleneck, (3, 8, 36, 3))
+ }
+
+ def __init__(self,
+ depth,
+ in_channels=3,
+ stem_channels=64,
+ base_channels=64,
+ num_stages=4,
+ strides=(1, 2, 2, 2),
+ dilations=(1, 1, 1, 1),
+ out_indices=(0, 1, 2, 3),
+ style='pytorch',
+ deep_stem=False,
+ avg_down=False,
+ frozen_stages=-1,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ norm_eval=False,
+ dcn=None,
+ stage_with_dcn=(False, False, False, False),
+ plugins=None,
+ multi_grid=None,
+ contract_dilation=False,
+ with_cp=False,
+ zero_init_residual=True):
+ super(ResNet, self).__init__()
+ if depth not in self.arch_settings:
+ raise KeyError(f'invalid depth {depth} for resnet')
+ self.depth = depth
+ self.stem_channels = stem_channels
+ self.base_channels = base_channels
+ self.num_stages = num_stages
+ assert num_stages >= 1 and num_stages <= 4
+ self.strides = strides
+ self.dilations = dilations
+ assert len(strides) == len(dilations) == num_stages
+ self.out_indices = out_indices
+ assert max(out_indices) < num_stages
+ self.style = style
+ self.deep_stem = deep_stem
+ self.avg_down = avg_down
+ self.frozen_stages = frozen_stages
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.with_cp = with_cp
+ self.norm_eval = norm_eval
+ self.dcn = dcn
+ self.stage_with_dcn = stage_with_dcn
+ if dcn is not None:
+ assert len(stage_with_dcn) == num_stages
+ self.plugins = plugins
+ self.multi_grid = multi_grid
+ self.contract_dilation = contract_dilation
+ self.zero_init_residual = zero_init_residual
+ self.block, stage_blocks = self.arch_settings[depth]
+ self.stage_blocks = stage_blocks[:num_stages]
+ self.inplanes = stem_channels
+
+ self._make_stem_layer(in_channels, stem_channels)
+
+ self.res_layers = []
+ for i, num_blocks in enumerate(self.stage_blocks):
+ stride = strides[i]
+ dilation = dilations[i]
+ dcn = self.dcn if self.stage_with_dcn[i] else None
+ if plugins is not None:
+ stage_plugins = self.make_stage_plugins(plugins, i)
+ else:
+ stage_plugins = None
+ # multi grid is applied to last layer only
+ stage_multi_grid = multi_grid if i == len(
+ self.stage_blocks) - 1 else None
+ planes = base_channels * 2**i
+ res_layer = self.make_res_layer(
+ block=self.block,
+ inplanes=self.inplanes,
+ planes=planes,
+ num_blocks=num_blocks,
+ stride=stride,
+ dilation=dilation,
+ style=self.style,
+ avg_down=self.avg_down,
+ with_cp=with_cp,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ dcn=dcn,
+ plugins=stage_plugins,
+ multi_grid=stage_multi_grid,
+ contract_dilation=contract_dilation)
+ self.inplanes = planes * self.block.expansion
+ layer_name = f'layer{i+1}'
+ self.add_module(layer_name, res_layer)
+ self.res_layers.append(layer_name)
+
+ self._freeze_stages()
+
+ self.feat_dim = self.block.expansion * base_channels * 2**(
+ len(self.stage_blocks) - 1)
+
+ def make_stage_plugins(self, plugins, stage_idx):
+ """make plugins for ResNet 'stage_idx'th stage .
+
+ Currently we support to insert 'context_block',
+ 'empirical_attention_block', 'nonlocal_block' into the backbone like
+ ResNet/ResNeXt. They could be inserted after conv1/conv2/conv3 of
+ Bottleneck.
+
+ An example of plugins format could be :
+ >>> plugins=[
+ ... dict(cfg=dict(type='xxx', arg1='xxx'),
+ ... stages=(False, True, True, True),
+ ... position='after_conv2'),
+ ... dict(cfg=dict(type='yyy'),
+ ... stages=(True, True, True, True),
+ ... position='after_conv3'),
+ ... dict(cfg=dict(type='zzz', postfix='1'),
+ ... stages=(True, True, True, True),
+ ... position='after_conv3'),
+ ... dict(cfg=dict(type='zzz', postfix='2'),
+ ... stages=(True, True, True, True),
+ ... position='after_conv3')
+ ... ]
+ >>> self = ResNet(depth=18)
+ >>> stage_plugins = self.make_stage_plugins(plugins, 0)
+ >>> assert len(stage_plugins) == 3
+
+        Suppose 'stage_idx=0', the structure of blocks in the stage would be:
+        conv1 -> conv2 -> conv3 -> yyy -> zzz1 -> zzz2
+        Suppose 'stage_idx=1', the structure of blocks in the stage would be:
+        conv1 -> conv2 -> xxx -> conv3 -> yyy -> zzz1 -> zzz2
+
+ If stages is missing, the plugin would be applied to all stages.
+
+ Args:
+ plugins (list[dict]): List of plugins cfg to build. The postfix is
+ required if multiple same type plugins are inserted.
+ stage_idx (int): Index of stage to build
+
+ Returns:
+ list[dict]: Plugins for current stage
+ """
+ stage_plugins = []
+ for plugin in plugins:
+ plugin = plugin.copy()
+ stages = plugin.pop('stages', None)
+ assert stages is None or len(stages) == self.num_stages
+ # whether to insert plugin into current stage
+ if stages is None or stages[stage_idx]:
+ stage_plugins.append(plugin)
+
+ return stage_plugins
+
+ def make_res_layer(self, **kwargs):
+ """Pack all blocks in a stage into a ``ResLayer``."""
+ return ResLayer(**kwargs)
+
+ @property
+ def norm1(self):
+ """nn.Module: the normalization layer named "norm1" """
+ return getattr(self, self.norm1_name)
+
+ def _make_stem_layer(self, in_channels, stem_channels):
+ """Make stem layer for ResNet."""
+ if self.deep_stem:
+ self.stem = nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ in_channels,
+ stem_channels // 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False),
+ build_norm_layer(self.norm_cfg, stem_channels // 2)[1],
+ nn.ReLU(inplace=True),
+ build_conv_layer(
+ self.conv_cfg,
+ stem_channels // 2,
+ stem_channels // 2,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False),
+ build_norm_layer(self.norm_cfg, stem_channels // 2)[1],
+ nn.ReLU(inplace=True),
+ build_conv_layer(
+ self.conv_cfg,
+ stem_channels // 2,
+ stem_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False),
+ build_norm_layer(self.norm_cfg, stem_channels)[1],
+ nn.ReLU(inplace=True))
+ else:
+ self.conv1 = build_conv_layer(
+ self.conv_cfg,
+ in_channels,
+ stem_channels,
+ kernel_size=7,
+ stride=2,
+ padding=3,
+ bias=False)
+ self.norm1_name, norm1 = build_norm_layer(
+ self.norm_cfg, stem_channels, postfix=1)
+ self.add_module(self.norm1_name, norm1)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+ def _freeze_stages(self):
+ """Freeze stages param and norm stats."""
+ if self.frozen_stages >= 0:
+ if self.deep_stem:
+ self.stem.eval()
+ for param in self.stem.parameters():
+ param.requires_grad = False
+ else:
+ self.norm1.eval()
+ for m in [self.conv1, self.norm1]:
+ for param in m.parameters():
+ param.requires_grad = False
+
+ for i in range(1, self.frozen_stages + 1):
+ m = getattr(self, f'layer{i}')
+ m.eval()
+ for param in m.parameters():
+ param.requires_grad = False
+
+ def init_weights(self, pretrained=None):
+ """Initialize the weights in backbone.
+
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+ if isinstance(pretrained, str):
+ logger = get_root_logger()
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
+ elif pretrained is None:
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ kaiming_init(m)
+ elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
+ constant_init(m, 1)
+
+ if self.dcn is not None:
+ for m in self.modules():
+ if isinstance(m, Bottleneck) and hasattr(
+ m, 'conv2_offset'):
+ constant_init(m.conv2_offset, 0)
+
+ if self.zero_init_residual:
+ for m in self.modules():
+ if isinstance(m, Bottleneck):
+ constant_init(m.norm3, 0)
+ elif isinstance(m, BasicBlock):
+ constant_init(m.norm2, 0)
+ else:
+ raise TypeError('pretrained must be a str or None')
+
+ def forward(self, x):
+ """Forward function."""
+ if self.deep_stem:
+ x = self.stem(x)
+ else:
+ x = self.conv1(x)
+ x = self.norm1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+ outs = []
+ for i, layer_name in enumerate(self.res_layers):
+ res_layer = getattr(self, layer_name)
+ x = res_layer(x)
+ if i in self.out_indices:
+ outs.append(x)
+ return tuple(outs)
+
+ def train(self, mode=True):
+ """Convert the model into training mode while keep normalization layer
+ freezed."""
+ super(ResNet, self).train(mode)
+ self._freeze_stages()
+ if mode and self.norm_eval:
+ for m in self.modules():
+                # trick: eval() only has an effect on BatchNorm layers
+ if isinstance(m, _BatchNorm):
+ m.eval()
+
+
+@BACKBONES.register_module()
+class ResNetV1c(ResNet):
+ """ResNetV1c variant described in [1]_.
+
+    Compared with the default ResNet (ResNetV1b), ResNetV1c replaces the 7x7 conv
+ in the input stem with three 3x3 convs.
+
+ References:
+ .. [1] https://arxiv.org/pdf/1812.01187.pdf
+ """
+
+ def __init__(self, **kwargs):
+ super(ResNetV1c, self).__init__(
+ deep_stem=True, avg_down=False, **kwargs)
+
+
+@BACKBONES.register_module()
+class ResNetV1d(ResNet):
+ """ResNetV1d variant described in [1]_.
+
+    Compared with the default ResNet (ResNetV1b), ResNetV1d replaces the 7x7 conv in
+ the input stem with three 3x3 convs. And in the downsampling block, a 2x2
+ avg_pool with stride 2 is added before conv, whose stride is changed to 1.
+ """
+
+ def __init__(self, **kwargs):
+ super(ResNetV1d, self).__init__(
+ deep_stem=True, avg_down=True, **kwargs)
diff --git a/annotator/uniformer/mmseg/models/backbones/resnext.py b/annotator/uniformer/mmseg/models/backbones/resnext.py
new file mode 100644
index 0000000000000000000000000000000000000000..962249ad6fd9b50960ad6426f7ce3cac6ed8c5bc
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/backbones/resnext.py
@@ -0,0 +1,145 @@
+import math
+
+from annotator.uniformer.mmcv.cnn import build_conv_layer, build_norm_layer
+
+from ..builder import BACKBONES
+from ..utils import ResLayer
+from .resnet import Bottleneck as _Bottleneck
+from .resnet import ResNet
+
+
+class Bottleneck(_Bottleneck):
+ """Bottleneck block for ResNeXt.
+
+ If style is "pytorch", the stride-two layer is the 3x3 conv layer, if it is
+ "caffe", the stride-two layer is the first 1x1 conv layer.
+ """
+
+ def __init__(self,
+ inplanes,
+ planes,
+ groups=1,
+ base_width=4,
+ base_channels=64,
+ **kwargs):
+ super(Bottleneck, self).__init__(inplanes, planes, **kwargs)
+
+ if groups == 1:
+ width = self.planes
+ else:
+ width = math.floor(self.planes *
+ (base_width / base_channels)) * groups
+
+ self.norm1_name, norm1 = build_norm_layer(
+ self.norm_cfg, width, postfix=1)
+ self.norm2_name, norm2 = build_norm_layer(
+ self.norm_cfg, width, postfix=2)
+ self.norm3_name, norm3 = build_norm_layer(
+ self.norm_cfg, self.planes * self.expansion, postfix=3)
+
+ self.conv1 = build_conv_layer(
+ self.conv_cfg,
+ self.inplanes,
+ width,
+ kernel_size=1,
+ stride=self.conv1_stride,
+ bias=False)
+ self.add_module(self.norm1_name, norm1)
+ fallback_on_stride = False
+ self.with_modulated_dcn = False
+ if self.with_dcn:
+ fallback_on_stride = self.dcn.pop('fallback_on_stride', False)
+ if not self.with_dcn or fallback_on_stride:
+ self.conv2 = build_conv_layer(
+ self.conv_cfg,
+ width,
+ width,
+ kernel_size=3,
+ stride=self.conv2_stride,
+ padding=self.dilation,
+ dilation=self.dilation,
+ groups=groups,
+ bias=False)
+ else:
+ assert self.conv_cfg is None, 'conv_cfg must be None for DCN'
+ self.conv2 = build_conv_layer(
+ self.dcn,
+ width,
+ width,
+ kernel_size=3,
+ stride=self.conv2_stride,
+ padding=self.dilation,
+ dilation=self.dilation,
+ groups=groups,
+ bias=False)
+
+ self.add_module(self.norm2_name, norm2)
+ self.conv3 = build_conv_layer(
+ self.conv_cfg,
+ width,
+ self.planes * self.expansion,
+ kernel_size=1,
+ bias=False)
+ self.add_module(self.norm3_name, norm3)
+
+
+@BACKBONES.register_module()
+class ResNeXt(ResNet):
+ """ResNeXt backbone.
+
+ Args:
+ depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
+ in_channels (int): Number of input image channels. Normally 3.
+ num_stages (int): Resnet stages, normally 4.
+ groups (int): Group of resnext.
+ base_width (int): Base width of resnext.
+ strides (Sequence[int]): Strides of the first block of each stage.
+ dilations (Sequence[int]): Dilation of each stage.
+ out_indices (Sequence[int]): Output from which stages.
+ style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
+ layer is the 3x3 conv layer, otherwise the stride-two layer is
+ the first 1x1 conv layer.
+ frozen_stages (int): Stages to be frozen (all param fixed). -1 means
+ not freezing any parameters.
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed.
+ zero_init_residual (bool): whether to use zero init for last norm layer
+ in resblocks to let them behave as identity.
+
+ Example:
+ >>> from annotator.uniformer.mmseg.models import ResNeXt
+ >>> import torch
+ >>> self = ResNeXt(depth=50)
+ >>> self.eval()
+ >>> inputs = torch.rand(1, 3, 32, 32)
+ >>> level_outputs = self.forward(inputs)
+ >>> for level_out in level_outputs:
+ ... print(tuple(level_out.shape))
+ (1, 256, 8, 8)
+ (1, 512, 4, 4)
+ (1, 1024, 2, 2)
+ (1, 2048, 1, 1)
+ """
+
+ arch_settings = {
+ 50: (Bottleneck, (3, 4, 6, 3)),
+ 101: (Bottleneck, (3, 4, 23, 3)),
+ 152: (Bottleneck, (3, 8, 36, 3))
+ }
+
+ def __init__(self, groups=1, base_width=4, **kwargs):
+ self.groups = groups
+ self.base_width = base_width
+ super(ResNeXt, self).__init__(**kwargs)
+
+ def make_res_layer(self, **kwargs):
+ """Pack all blocks in a stage into a ``ResLayer``"""
+ return ResLayer(
+ groups=self.groups,
+ base_width=self.base_width,
+ base_channels=self.base_channels,
+ **kwargs)
diff --git a/annotator/uniformer/mmseg/models/backbones/unet.py b/annotator/uniformer/mmseg/models/backbones/unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..82caa16a94c195c192a2a920fb7bc7e60f0f3ce3
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/backbones/unet.py
@@ -0,0 +1,429 @@
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+from annotator.uniformer.mmcv.cnn import (UPSAMPLE_LAYERS, ConvModule, build_activation_layer,
+ build_norm_layer, constant_init, kaiming_init)
+from annotator.uniformer.mmcv.runner import load_checkpoint
+from annotator.uniformer.mmcv.utils.parrots_wrapper import _BatchNorm
+
+from annotator.uniformer.mmseg.utils import get_root_logger
+from ..builder import BACKBONES
+from ..utils import UpConvBlock
+
+
+class BasicConvBlock(nn.Module):
+ """Basic convolutional block for UNet.
+
+ This module consists of several plain convolutional layers.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ num_convs (int): Number of convolutional layers. Default: 2.
+        stride (int): Whether to use stride convolution to downsample the
+            input feature map. If stride=2, only the first convolutional
+            layer uses stride convolution to downsample the input feature
+            map. Options are 1 or 2. Default: 1.
+        dilation (int): Whether to use dilated convolution to enlarge the
+            receptive field. It sets the dilation rate of every convolutional
+            layer except the first one, whose dilation rate is always 1.
+            Default: 1.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ conv_cfg (dict | None): Config dict for convolution layer.
+ Default: None.
+ norm_cfg (dict | None): Config dict for normalization layer.
+ Default: dict(type='BN').
+ act_cfg (dict | None): Config dict for activation layer in ConvModule.
+ Default: dict(type='ReLU').
+ dcn (bool): Use deformable convolution in convolutional layer or not.
+ Default: None.
+ plugins (dict): plugins for convolutional layers. Default: None.
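+
+    Example:
+        An illustrative sketch added for documentation; the channel counts
+        are arbitrary and the block is assumed to be used as defined in this
+        file.
+
+        >>> import torch
+        >>> block = BasicConvBlock(64, 128, num_convs=2, stride=2)
+        >>> block(torch.rand(1, 64, 32, 32)).shape
+        torch.Size([1, 128, 16, 16])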
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ num_convs=2,
+ stride=1,
+ dilation=1,
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ dcn=None,
+ plugins=None):
+ super(BasicConvBlock, self).__init__()
+ assert dcn is None, 'Not implemented yet.'
+ assert plugins is None, 'Not implemented yet.'
+
+ self.with_cp = with_cp
+ convs = []
+ for i in range(num_convs):
+ convs.append(
+ ConvModule(
+ in_channels=in_channels if i == 0 else out_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ stride=stride if i == 0 else 1,
+ dilation=1 if i == 0 else dilation,
+ padding=1 if i == 0 else dilation,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg))
+
+ self.convs = nn.Sequential(*convs)
+
+ def forward(self, x):
+ """Forward function."""
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(self.convs, x)
+ else:
+ out = self.convs(x)
+ return out
+
+
+@UPSAMPLE_LAYERS.register_module()
+class DeconvModule(nn.Module):
+ """Deconvolution upsample module in decoder for UNet (2X upsample).
+
+ This module uses deconvolution to upsample feature map in the decoder
+ of UNet.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ norm_cfg (dict | None): Config dict for normalization layer.
+ Default: dict(type='BN').
+ act_cfg (dict | None): Config dict for activation layer in ConvModule.
+ Default: dict(type='ReLU').
+ kernel_size (int): Kernel size of the convolutional layer. Default: 4.
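+
+    Example:
+        An illustrative sketch added for documentation; the channel counts
+        are arbitrary and the defaults above (kernel_size=4, scale_factor=2)
+        are assumed.
+
+        >>> import torch
+        >>> upsample = DeconvModule(64, 32)
+        >>> upsample(torch.rand(1, 64, 16, 16)).shape
+        torch.Size([1, 32, 32, 32])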
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ with_cp=False,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ *,
+ kernel_size=4,
+ scale_factor=2):
+ super(DeconvModule, self).__init__()
+
+ assert (kernel_size - scale_factor >= 0) and\
+ (kernel_size - scale_factor) % 2 == 0,\
+ f'kernel_size should be greater than or equal to scale_factor '\
+ f'and (kernel_size - scale_factor) should be even numbers, '\
+ f'while the kernel size is {kernel_size} and scale_factor is '\
+ f'{scale_factor}.'
+
+ stride = scale_factor
+ padding = (kernel_size - scale_factor) // 2
+ self.with_cp = with_cp
+ deconv = nn.ConvTranspose2d(
+ in_channels,
+ out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding)
+
+ norm_name, norm = build_norm_layer(norm_cfg, out_channels)
+ activate = build_activation_layer(act_cfg)
+ self.deconv_upsamping = nn.Sequential(deconv, norm, activate)
+
+ def forward(self, x):
+ """Forward function."""
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(self.deconv_upsamping, x)
+ else:
+ out = self.deconv_upsamping(x)
+ return out
+
+
+@UPSAMPLE_LAYERS.register_module()
+class InterpConv(nn.Module):
+ """Interpolation upsample module in decoder for UNet.
+
+ This module uses interpolation to upsample feature map in the decoder
+ of UNet. It consists of one interpolation upsample layer and one
+ convolutional layer. It can be one interpolation upsample layer followed
+ by one convolutional layer (conv_first=False) or one convolutional layer
+ followed by one interpolation upsample layer (conv_first=True).
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ norm_cfg (dict | None): Config dict for normalization layer.
+ Default: dict(type='BN').
+ act_cfg (dict | None): Config dict for activation layer in ConvModule.
+ Default: dict(type='ReLU').
+ conv_cfg (dict | None): Config dict for convolution layer.
+ Default: None.
+        conv_first (bool): Whether the convolutional layer comes before the
+            interpolation upsample layer. Default: False, which means the
+            interpolation upsample layer is followed by the convolutional
+            layer.
+        kernel_size (int): Kernel size of the convolutional layer. Default: 1.
+        stride (int): Stride of the convolutional layer. Default: 1.
+        padding (int): Padding of the convolutional layer. Default: 0.
+ upsample_cfg (dict): Interpolation config of the upsample layer.
+ Default: dict(
+ scale_factor=2, mode='bilinear', align_corners=False).
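+
+    Example:
+        An illustrative sketch added for documentation; the channel counts
+        are arbitrary and the default bilinear 2x upsample is assumed.
+
+        >>> import torch
+        >>> upsample = InterpConv(64, 32)
+        >>> upsample(torch.rand(1, 64, 16, 16)).shape
+        torch.Size([1, 32, 32, 32])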
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ with_cp=False,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ *,
+ conv_cfg=None,
+ conv_first=False,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ upsample_cfg=dict(
+ scale_factor=2, mode='bilinear', align_corners=False)):
+ super(InterpConv, self).__init__()
+
+ self.with_cp = with_cp
+ conv = ConvModule(
+ in_channels,
+ out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ upsample = nn.Upsample(**upsample_cfg)
+ if conv_first:
+ self.interp_upsample = nn.Sequential(conv, upsample)
+ else:
+ self.interp_upsample = nn.Sequential(upsample, conv)
+
+ def forward(self, x):
+ """Forward function."""
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(self.interp_upsample, x)
+ else:
+ out = self.interp_upsample(x)
+ return out
+
+
+@BACKBONES.register_module()
+class UNet(nn.Module):
+ """UNet backbone.
+ U-Net: Convolutional Networks for Biomedical Image Segmentation.
+ https://arxiv.org/pdf/1505.04597.pdf
+
+ Args:
+ in_channels (int): Number of input image channels. Default" 3.
+ base_channels (int): Number of base channels of each stage.
+ The output channels of the first stage. Default: 64.
+ num_stages (int): Number of stages in encoder, normally 5. Default: 5.
+ strides (Sequence[int 1 | 2]): Strides of each stage in encoder.
+ len(strides) is equal to num_stages. Normally the stride of the
+ first stage in encoder is 1. If strides[i]=2, it uses stride
+            convolution to downsample in the corresponding encoder stage.
+            Default: (1, 1, 1, 1, 1).
+        enc_num_convs (Sequence[int]): Number of convolutional layers in the
+            convolution block of the corresponding encoder stage.
+            Default: (2, 2, 2, 2, 2).
+        dec_num_convs (Sequence[int]): Number of convolutional layers in the
+            convolution block of the corresponding decoder stage.
+            Default: (2, 2, 2, 2).
+        downsamples (Sequence[int]): Whether to use MaxPool to downsample the
+            feature map after the first encoder stage
+            (stages: [1, num_stages)). If the corresponding encoder stage uses
+            stride convolution (strides[i]=2), MaxPool is never used to
+            downsample, even if downsamples[i-1]=True.
+ Default: (True, True, True, True).
+ enc_dilations (Sequence[int]): Dilation rate of each stage in encoder.
+ Default: (1, 1, 1, 1, 1).
+ dec_dilations (Sequence[int]): Dilation rate of each stage in decoder.
+ Default: (1, 1, 1, 1).
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ conv_cfg (dict | None): Config dict for convolution layer.
+ Default: None.
+ norm_cfg (dict | None): Config dict for normalization layer.
+ Default: dict(type='BN').
+ act_cfg (dict | None): Config dict for activation layer in ConvModule.
+ Default: dict(type='ReLU').
+ upsample_cfg (dict): The upsample config of the upsample module in
+ decoder. Default: dict(type='InterpConv').
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Default: False.
+ dcn (bool): Use deformable convolution in convolutional layer or not.
+ Default: None.
+ plugins (dict): plugins for convolutional layers. Default: None.
+
+ Notice:
+ The input image size should be divisible by the whole downsample rate
+ of the encoder. More detail of the whole downsample rate can be found
+ in UNet._check_input_divisible.
+
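+    Example:
+        An illustrative sketch added for documentation. It assumes the class
+        is exported from ``annotator.uniformer.mmseg.models`` like the other
+        backbones and uses the default configuration, whose whole downsample
+        rate is 16, with a 64x64 input.
+
+        >>> from annotator.uniformer.mmseg.models import UNet
+        >>> import torch
+        >>> self = UNet(in_channels=3, base_channels=64)
+        >>> self.eval()
+        >>> inputs = torch.rand(1, 3, 64, 64)
+        >>> dec_outs = self.forward(inputs)
+        >>> for dec_out in dec_outs:
+        ...     print(tuple(dec_out.shape))
+        (1, 1024, 4, 4)
+        (1, 512, 8, 8)
+        (1, 256, 16, 16)
+        (1, 128, 32, 32)
+        (1, 64, 64, 64)
+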
+ """
+
+ def __init__(self,
+ in_channels=3,
+ base_channels=64,
+ num_stages=5,
+ strides=(1, 1, 1, 1, 1),
+ enc_num_convs=(2, 2, 2, 2, 2),
+ dec_num_convs=(2, 2, 2, 2),
+ downsamples=(True, True, True, True),
+ enc_dilations=(1, 1, 1, 1, 1),
+ dec_dilations=(1, 1, 1, 1),
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ upsample_cfg=dict(type='InterpConv'),
+ norm_eval=False,
+ dcn=None,
+ plugins=None):
+ super(UNet, self).__init__()
+ assert dcn is None, 'Not implemented yet.'
+ assert plugins is None, 'Not implemented yet.'
+ assert len(strides) == num_stages, \
+ 'The length of strides should be equal to num_stages, '\
+ f'while the strides is {strides}, the length of '\
+ f'strides is {len(strides)}, and the num_stages is '\
+ f'{num_stages}.'
+ assert len(enc_num_convs) == num_stages, \
+ 'The length of enc_num_convs should be equal to num_stages, '\
+ f'while the enc_num_convs is {enc_num_convs}, the length of '\
+ f'enc_num_convs is {len(enc_num_convs)}, and the num_stages is '\
+ f'{num_stages}.'
+ assert len(dec_num_convs) == (num_stages-1), \
+ 'The length of dec_num_convs should be equal to (num_stages-1), '\
+ f'while the dec_num_convs is {dec_num_convs}, the length of '\
+ f'dec_num_convs is {len(dec_num_convs)}, and the num_stages is '\
+ f'{num_stages}.'
+ assert len(downsamples) == (num_stages-1), \
+ 'The length of downsamples should be equal to (num_stages-1), '\
+ f'while the downsamples is {downsamples}, the length of '\
+ f'downsamples is {len(downsamples)}, and the num_stages is '\
+ f'{num_stages}.'
+ assert len(enc_dilations) == num_stages, \
+ 'The length of enc_dilations should be equal to num_stages, '\
+ f'while the enc_dilations is {enc_dilations}, the length of '\
+ f'enc_dilations is {len(enc_dilations)}, and the num_stages is '\
+ f'{num_stages}.'
+ assert len(dec_dilations) == (num_stages-1), \
+ 'The length of dec_dilations should be equal to (num_stages-1), '\
+ f'while the dec_dilations is {dec_dilations}, the length of '\
+ f'dec_dilations is {len(dec_dilations)}, and the num_stages is '\
+ f'{num_stages}.'
+ self.num_stages = num_stages
+ self.strides = strides
+ self.downsamples = downsamples
+ self.norm_eval = norm_eval
+ self.base_channels = base_channels
+
+ self.encoder = nn.ModuleList()
+ self.decoder = nn.ModuleList()
+
+ for i in range(num_stages):
+ enc_conv_block = []
+ if i != 0:
+ if strides[i] == 1 and downsamples[i - 1]:
+ enc_conv_block.append(nn.MaxPool2d(kernel_size=2))
+ upsample = (strides[i] != 1 or downsamples[i - 1])
+ self.decoder.append(
+ UpConvBlock(
+ conv_block=BasicConvBlock,
+ in_channels=base_channels * 2**i,
+ skip_channels=base_channels * 2**(i - 1),
+ out_channels=base_channels * 2**(i - 1),
+ num_convs=dec_num_convs[i - 1],
+ stride=1,
+ dilation=dec_dilations[i - 1],
+ with_cp=with_cp,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ upsample_cfg=upsample_cfg if upsample else None,
+ dcn=None,
+ plugins=None))
+
+ enc_conv_block.append(
+ BasicConvBlock(
+ in_channels=in_channels,
+ out_channels=base_channels * 2**i,
+ num_convs=enc_num_convs[i],
+ stride=strides[i],
+ dilation=enc_dilations[i],
+ with_cp=with_cp,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ dcn=None,
+ plugins=None))
+ self.encoder.append((nn.Sequential(*enc_conv_block)))
+ in_channels = base_channels * 2**i
+
+ def forward(self, x):
+ self._check_input_divisible(x)
+ enc_outs = []
+ for enc in self.encoder:
+ x = enc(x)
+ enc_outs.append(x)
+ dec_outs = [x]
+ for i in reversed(range(len(self.decoder))):
+ x = self.decoder[i](enc_outs[i], x)
+ dec_outs.append(x)
+
+ return dec_outs
+
+ def train(self, mode=True):
+ """Convert the model into training mode while keep normalization layer
+ freezed."""
+ super(UNet, self).train(mode)
+ if mode and self.norm_eval:
+ for m in self.modules():
+                # trick: eval() only has an effect on BatchNorm layers
+ if isinstance(m, _BatchNorm):
+ m.eval()
+
+ def _check_input_divisible(self, x):
+ h, w = x.shape[-2:]
+ whole_downsample_rate = 1
+ for i in range(1, self.num_stages):
+ if self.strides[i] == 2 or self.downsamples[i - 1]:
+ whole_downsample_rate *= 2
+ assert (h % whole_downsample_rate == 0) \
+ and (w % whole_downsample_rate == 0),\
+ f'The input image size {(h, w)} should be divisible by the whole '\
+ f'downsample rate {whole_downsample_rate}, when num_stages is '\
+ f'{self.num_stages}, strides is {self.strides}, and downsamples '\
+ f'is {self.downsamples}.'
+
+ def init_weights(self, pretrained=None):
+ """Initialize the weights in backbone.
+
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+ if isinstance(pretrained, str):
+ logger = get_root_logger()
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
+ elif pretrained is None:
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ kaiming_init(m)
+ elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
+ constant_init(m, 1)
+ else:
+ raise TypeError('pretrained must be a str or None')
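A quick, hedged illustration of the input-size constraint enforced by _check_input_divisible above: the backbone halves the resolution once for every stage whose stride is 2 or whose downsample flag is set, so H and W must be divisible by the product of those factors. The helper name below is illustrative, not part of the patch.

def whole_downsample_rate(strides, downsamples):
    """Overall factor by which this UNet reduces spatial resolution."""
    rate = 1
    for stride, down in zip(strides[1:], downsamples):
        if stride == 2 or down:
            rate *= 2
    return rate

# With the usual 5-stage configuration the input must be divisible by 16,
# e.g. 480x640 passes the check while 500x640 would trigger the assertion.
assert whole_downsample_rate((1, 1, 1, 1, 1), (True, True, True, True)) == 16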
diff --git a/annotator/uniformer/mmseg/models/backbones/uniformer.py b/annotator/uniformer/mmseg/models/backbones/uniformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c4bb88e4c928540cca9ab609988b916520f5b7a
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/backbones/uniformer.py
@@ -0,0 +1,422 @@
+# --------------------------------------------------------
+# UniFormer
+# Copyright (c) 2022 SenseTime X-Lab
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Kunchang Li
+# --------------------------------------------------------
+
+from collections import OrderedDict
+import math
+
+from functools import partial
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+import numpy as np
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+
+from annotator.uniformer.mmcv_custom import load_checkpoint
+from annotator.uniformer.mmseg.utils import get_root_logger
+from ..builder import BACKBONES
+
+
+class Mlp(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class CMlp(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
+ self.act = act_layer()
+ self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class CBlock(nn.Module):
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
+ self.norm1 = nn.BatchNorm2d(dim)
+ self.conv1 = nn.Conv2d(dim, dim, 1)
+ self.conv2 = nn.Conv2d(dim, dim, 1)
+ self.attn = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = nn.BatchNorm2d(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = CMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ def forward(self, x):
+ x = x + self.pos_embed(x)
+ x = x + self.drop_path(self.conv2(self.attn(self.conv1(self.norm1(x)))))
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ return x
+
+
+class Attention(nn.Module):
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
+ self.scale = qk_scale or head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x):
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
+
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class SABlock(nn.Module):
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(
+ dim,
+ num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ attn_drop=attn_drop, proj_drop=drop)
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ def forward(self, x):
+ x = x + self.pos_embed(x)
+ B, N, H, W = x.shape
+ x = x.flatten(2).transpose(1, 2)
+ x = x + self.drop_path(self.attn(self.norm1(x)))
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ x = x.transpose(1, 2).reshape(B, N, H, W)
+ return x
+
+
+def window_partition(x, window_size):
+ """
+ Args:
+ x: (B, H, W, C)
+ window_size (int): window size
+ Returns:
+ windows: (num_windows*B, window_size, window_size, C)
+ """
+ B, H, W, C = x.shape
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+ return windows
+
+
+def window_reverse(windows, window_size, H, W):
+ """
+ Args:
+ windows: (num_windows*B, window_size, window_size, C)
+ window_size (int): Window size
+ H (int): Height of image
+ W (int): Width of image
+ Returns:
+ x: (B, H, W, C)
+ """
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+ return x
+
+
+class SABlock_Windows(nn.Module):
+ def __init__(self, dim, num_heads, window_size=14, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.window_size=window_size
+ self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(
+ dim,
+ num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ attn_drop=attn_drop, proj_drop=drop)
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ def forward(self, x):
+ x = x + self.pos_embed(x)
+ x = x.permute(0, 2, 3, 1)
+ B, H, W, C = x.shape
+ shortcut = x
+ x = self.norm1(x)
+
+ pad_l = pad_t = 0
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
+ _, Hp, Wp, _ = x.shape
+
+ x_windows = window_partition(x, self.window_size) # nW*B, window_size, window_size, C
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
+
+ # window-based multi-head self-attention (no shifted windows in this block)
+ attn_windows = self.attn(x_windows) # nW*B, window_size*window_size, C
+
+ # merge windows
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+ x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
+
+ # remove the padding added before window partition
+ if pad_r > 0 or pad_b > 0:
+ x = x[:, :H, :W, :].contiguous()
+
+ x = shortcut + self.drop_path(x)
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ x = x.permute(0, 3, 1, 2).reshape(B, C, H, W)
+ return x
+
+
+class PatchEmbed(nn.Module):
+ """ Image to Patch Embedding
+ """
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.num_patches = num_patches
+ self.norm = nn.LayerNorm(embed_dim)
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+ def forward(self, x):
+ B, _, H, W = x.shape
+ x = self.proj(x)
+ B, _, H, W = x.shape
+ x = x.flatten(2).transpose(1, 2)
+ x = self.norm(x)
+ x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
+ return x
+
+
+@BACKBONES.register_module()
+class UniFormer(nn.Module):
+ """ Vision Transformer
+ A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
+ https://arxiv.org/abs/2010.11929
+ """
+ def __init__(self, layers=[3, 4, 8, 3], img_size=224, in_chans=3, num_classes=80, embed_dim=[64, 128, 320, 512],
+ head_dim=64, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ pretrained_path=None, use_checkpoint=False, checkpoint_num=[0, 0, 0, 0],
+ windows=False, hybrid=False, window_size=14):
+ """
+ Args:
+ layers (list): number of blocks in each stage
+ img_size (int, tuple): input image size
+ in_chans (int): number of input channels
+ num_classes (int): number of classes for classification head
+ embed_dim (list): embedding dimension of each stage
+ head_dim (int): dimension of attention heads
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+ qkv_bias (bool): enable bias for qkv if True
+ qk_scale (float): override default qk scale of head_dim ** -0.5 if set
+ representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
+ drop_rate (float): dropout rate
+ attn_drop_rate (float): attention dropout rate
+ drop_path_rate (float): stochastic depth rate
+ norm_layer (nn.Module): normalization layer
+ pretrained_path (str): path of pretrained model
+ use_checkpoint (bool): whether use checkpoint
+ checkpoint_num (list): index for using checkpoint in every stage
+ windows (bool): whether use window MHRA
+ hybrid (bool): whether use hybrid MHRA
+ window_size (int): size of window (>14)
+ """
+ super().__init__()
+ self.num_classes = num_classes
+ self.use_checkpoint = use_checkpoint
+ self.checkpoint_num = checkpoint_num
+ self.windows = windows
+ print(f'Use Checkpoint: {self.use_checkpoint}')
+ print(f'Checkpoint Number: {self.checkpoint_num}')
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
+
+ self.patch_embed1 = PatchEmbed(
+ img_size=img_size, patch_size=4, in_chans=in_chans, embed_dim=embed_dim[0])
+ self.patch_embed2 = PatchEmbed(
+ img_size=img_size // 4, patch_size=2, in_chans=embed_dim[0], embed_dim=embed_dim[1])
+ self.patch_embed3 = PatchEmbed(
+ img_size=img_size // 8, patch_size=2, in_chans=embed_dim[1], embed_dim=embed_dim[2])
+ self.patch_embed4 = PatchEmbed(
+ img_size=img_size // 16, patch_size=2, in_chans=embed_dim[2], embed_dim=embed_dim[3])
+
+ self.pos_drop = nn.Dropout(p=drop_rate)
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(layers))] # stochastic depth decay rule
+ num_heads = [dim // head_dim for dim in embed_dim]
+ self.blocks1 = nn.ModuleList([
+ CBlock(
+ dim=embed_dim[0], num_heads=num_heads[0], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
+ for i in range(layers[0])])
+ self.norm1=norm_layer(embed_dim[0])
+ self.blocks2 = nn.ModuleList([
+ CBlock(
+ dim=embed_dim[1], num_heads=num_heads[1], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+layers[0]], norm_layer=norm_layer)
+ for i in range(layers[1])])
+ self.norm2 = norm_layer(embed_dim[1])
+ if self.windows:
+ print('Use local window for all blocks in stage3')
+ self.blocks3 = nn.ModuleList([
+ SABlock_Windows(
+ dim=embed_dim[2], num_heads=num_heads[2], window_size=window_size, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+layers[0]+layers[1]], norm_layer=norm_layer)
+ for i in range(layers[2])])
+ elif hybrid:
+ print('Use hybrid window for blocks in stage3')
+ block3 = []
+ for i in range(layers[2]):
+ if (i + 1) % 4 == 0:
+ block3.append(SABlock(
+ dim=embed_dim[2], num_heads=num_heads[2], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+layers[0]+layers[1]], norm_layer=norm_layer))
+ else:
+ block3.append(SABlock_Windows(
+ dim=embed_dim[2], num_heads=num_heads[2], window_size=window_size, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+layers[0]+layers[1]], norm_layer=norm_layer))
+ self.blocks3 = nn.ModuleList(block3)
+ else:
+ print('Use global window for all blocks in stage3')
+ self.blocks3 = nn.ModuleList([
+ SABlock(
+ dim=embed_dim[2], num_heads=num_heads[2], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+layers[0]+layers[1]], norm_layer=norm_layer)
+ for i in range(layers[2])])
+ self.norm3 = norm_layer(embed_dim[2])
+ self.blocks4 = nn.ModuleList([
+ SABlock(
+ dim=embed_dim[3], num_heads=num_heads[3], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+layers[0]+layers[1]+layers[2]], norm_layer=norm_layer)
+ for i in range(layers[3])])
+ self.norm4 = norm_layer(embed_dim[3])
+
+ # Representation layer
+ if representation_size:
+ self.num_features = representation_size
+ self.pre_logits = nn.Sequential(OrderedDict([
+ ('fc', nn.Linear(embed_dim[-1], representation_size)),  # embed_dim is a per-stage list
+ ('act', nn.Tanh())
+ ]))
+ else:
+ self.pre_logits = nn.Identity()
+
+ self.apply(self._init_weights)
+ self.init_weights(pretrained=pretrained_path)
+
+ def init_weights(self, pretrained):
+ if isinstance(pretrained, str):
+ logger = get_root_logger()
+ load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger)
+ print(f'Load pretrained model from {pretrained}')
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'pos_embed', 'cls_token'}
+
+ def get_classifier(self):
+ return self.head
+
+ def reset_classifier(self, num_classes, global_pool=''):
+ self.num_classes = num_classes
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+ def forward_features(self, x):
+ out = []
+ x = self.patch_embed1(x)
+ x = self.pos_drop(x)
+ for i, blk in enumerate(self.blocks1):
+ if self.use_checkpoint and i < self.checkpoint_num[0]:
+ x = checkpoint.checkpoint(blk, x)
+ else:
+ x = blk(x)
+ x_out = self.norm1(x.permute(0, 2, 3, 1))
+ out.append(x_out.permute(0, 3, 1, 2).contiguous())
+ x = self.patch_embed2(x)
+ for i, blk in enumerate(self.blocks2):
+ if self.use_checkpoint and i < self.checkpoint_num[1]:
+ x = checkpoint.checkpoint(blk, x)
+ else:
+ x = blk(x)
+ x_out = self.norm2(x.permute(0, 2, 3, 1))
+ out.append(x_out.permute(0, 3, 1, 2).contiguous())
+ x = self.patch_embed3(x)
+ for i, blk in enumerate(self.blocks3):
+ if self.use_checkpoint and i < self.checkpoint_num[2]:
+ x = checkpoint.checkpoint(blk, x)
+ else:
+ x = blk(x)
+ x_out = self.norm3(x.permute(0, 2, 3, 1))
+ out.append(x_out.permute(0, 3, 1, 2).contiguous())
+ x = self.patch_embed4(x)
+ for i, blk in enumerate(self.blocks4):
+ if self.use_checkpoint and i < self.checkpoint_num[3]:
+ x = checkpoint.checkpoint(blk, x)
+ else:
+ x = blk(x)
+ x_out = self.norm4(x.permute(0, 2, 3, 1))
+ out.append(x_out.permute(0, 3, 1, 2).contiguous())
+ return tuple(out)
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ return x
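Hedged sanity check for the two window helpers defined in this file: window_partition and window_reverse are exact inverses whenever H and W are multiples of the window size, which is why SABlock_Windows pads the feature map before partitioning. The import path assumes the annotator package is importable as laid out in this patch.

import torch
from annotator.uniformer.mmseg.models.backbones.uniformer import (
    window_partition, window_reverse)

x = torch.randn(2, 28, 28, 64)            # (B, H, W, C), H and W divisible by 14
windows = window_partition(x, 14)         # -> (2 * 2 * 2, 14, 14, 64)
assert windows.shape == (8, 14, 14, 64)
assert torch.equal(window_reverse(windows, 14, 28, 28), x)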
diff --git a/annotator/uniformer/mmseg/models/backbones/vit.py b/annotator/uniformer/mmseg/models/backbones/vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..59e4479650690e08cbc4cab9427aefda47c2116d
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/backbones/vit.py
@@ -0,0 +1,459 @@
+"""Modified from https://github.com/rwightman/pytorch-image-
+models/blob/master/timm/models/vision_transformer.py."""
+
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as cp
+from annotator.uniformer.mmcv.cnn import (Conv2d, Linear, build_activation_layer, build_norm_layer,
+ constant_init, kaiming_init, normal_init)
+from annotator.uniformer.mmcv.runner import _load_checkpoint
+from annotator.uniformer.mmcv.utils.parrots_wrapper import _BatchNorm
+
+from annotator.uniformer.mmseg.utils import get_root_logger
+from ..builder import BACKBONES
+from ..utils import DropPath, trunc_normal_
+
+
+class Mlp(nn.Module):
+ """MLP layer for Encoder block.
+
+ Args:
+ in_features(int): Input dimension for the first fully
+ connected layer.
+ hidden_features(int): Output dimension for the first fully
+ connected layer.
+ out_features(int): Output dimension for the second fully
+ connected layer.
+ act_cfg(dict): Config dict for activation layer.
+ Default: dict(type='GELU').
+ drop(float): Drop rate for the dropout layer. Dropout rate has
+ to be between 0 and 1. Default: 0.
+ """
+
+ def __init__(self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ act_cfg=dict(type='GELU'),
+ drop=0.):
+ super(Mlp, self).__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = Linear(in_features, hidden_features)
+ self.act = build_activation_layer(act_cfg)
+ self.fc2 = Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class Attention(nn.Module):
+ """Attention layer for Encoder block.
+
+ Args:
+ dim (int): Dimension for the input vector.
+ num_heads (int): Number of parallel attention heads.
+ qkv_bias (bool): Enable bias for qkv if True. Default: False.
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
+ attn_drop (float): Drop rate for attention output weights.
+ Default: 0.
+ proj_drop (float): Drop rate for output weights. Default: 0.
+ """
+
+ def __init__(self,
+ dim,
+ num_heads=8,
+ qkv_bias=False,
+ qk_scale=None,
+ attn_drop=0.,
+ proj_drop=0.):
+ super(Attention, self).__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x):
+ b, n, c = x.shape
+ qkv = self.qkv(x).reshape(b, n, 3, self.num_heads,
+ c // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2]
+
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(b, n, c)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class Block(nn.Module):
+ """Implements encoder block with residual connection.
+
+ Args:
+ dim (int): The feature dimension.
+ num_heads (int): Number of parallel attention heads.
+ mlp_ratio (int): Ratio of mlp hidden dim to embedding dim.
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
+ drop (float): Drop rate for mlp output weights. Default: 0.
+ attn_drop (float): Drop rate for attention output weights.
+ Default: 0.
+ proj_drop (float): Drop rate for attn layer output weights.
+ Default: 0.
+ drop_path (float): Drop rate for paths of model.
+ Default: 0.
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='GELU').
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='LN', requires_grad=True).
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ """
+
+ def __init__(self,
+ dim,
+ num_heads,
+ mlp_ratio=4,
+ qkv_bias=False,
+ qk_scale=None,
+ drop=0.,
+ attn_drop=0.,
+ proj_drop=0.,
+ drop_path=0.,
+ act_cfg=dict(type='GELU'),
+ norm_cfg=dict(type='LN', eps=1e-6),
+ with_cp=False):
+ super(Block, self).__init__()
+ self.with_cp = with_cp
+ _, self.norm1 = build_norm_layer(norm_cfg, dim)
+ self.attn = Attention(dim, num_heads, qkv_bias, qk_scale, attn_drop,
+ proj_drop)
+ self.drop_path = DropPath(
+ drop_path) if drop_path > 0. else nn.Identity()
+ _, self.norm2 = build_norm_layer(norm_cfg, dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_cfg=act_cfg,
+ drop=drop)
+
+ def forward(self, x):
+
+ def _inner_forward(x):
+ out = x + self.drop_path(self.attn(self.norm1(x)))
+ out = out + self.drop_path(self.mlp(self.norm2(out)))
+ return out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ return out
+
+
+class PatchEmbed(nn.Module):
+ """Image to Patch Embedding.
+
+ Args:
+ img_size (int | tuple): Input image size.
+ Default: 224.
+ patch_size (int): Width and height for a patch.
+ Default: 16.
+ in_channels (int): Input channels for images. Default: 3.
+ embed_dim (int): The embedding dimension. Default: 768.
+ """
+
+ def __init__(self,
+ img_size=224,
+ patch_size=16,
+ in_channels=3,
+ embed_dim=768):
+ super(PatchEmbed, self).__init__()
+ if isinstance(img_size, int):
+ self.img_size = (img_size, img_size)
+ elif isinstance(img_size, tuple):
+ self.img_size = img_size
+ else:
+ raise TypeError('img_size must be type of int or tuple')
+ h, w = self.img_size
+ self.patch_size = (patch_size, patch_size)
+ self.num_patches = (h // patch_size) * (w // patch_size)
+ self.proj = Conv2d(
+ in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+ def forward(self, x):
+ return self.proj(x).flatten(2).transpose(1, 2)
+
+
+@BACKBONES.register_module()
+class VisionTransformer(nn.Module):
+ """Vision transformer backbone.
+
+ A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for
+ Image Recognition at Scale` - https://arxiv.org/abs/2010.11929
+
+ Args:
+ img_size (tuple): input image size. Default: (224, 224).
+ patch_size (int, tuple): patch size. Default: 16.
+ in_channels (int): number of input channels. Default: 3.
+ embed_dim (int): embedding dimension. Default: 768.
+ depth (int): depth of transformer. Default: 12.
+ num_heads (int): number of attention heads. Default: 12.
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim.
+ Default: 4.
+ out_indices (list | tuple | int): Output from which stages.
+ Default: -1.
+ qkv_bias (bool): enable bias for qkv if True. Default: True.
+ qk_scale (float): override default qk scale of head_dim ** -0.5 if set.
+ drop_rate (float): dropout rate. Default: 0.
+ attn_drop_rate (float): attention dropout rate. Default: 0.
+ drop_path_rate (float): Rate of DropPath. Default: 0.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='LN', eps=1e-6, requires_grad=True).
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='GELU').
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Default: False.
+ final_norm (bool): Whether to add an additional layer to normalize
+ final feature map. Default: False.
+ interpolate_mode (str): Select the interpolate mode for position
+ embedding vector resize. Default: bicubic.
+ with_cls_token (bool): Whether to concatenate the class token with the
+ image tokens as transformer input. Default: True.
+ with_cp (bool): Use checkpoint or not. Using checkpoint
+ will save some memory while slowing down the training speed.
+ Default: False.
+ """
+
+ def __init__(self,
+ img_size=(224, 224),
+ patch_size=16,
+ in_channels=3,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4,
+ out_indices=11,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.,
+ norm_cfg=dict(type='LN', eps=1e-6, requires_grad=True),
+ act_cfg=dict(type='GELU'),
+ norm_eval=False,
+ final_norm=False,
+ with_cls_token=True,
+ interpolate_mode='bicubic',
+ with_cp=False):
+ super(VisionTransformer, self).__init__()
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.features = self.embed_dim = embed_dim
+ self.patch_embed = PatchEmbed(
+ img_size=img_size,
+ patch_size=patch_size,
+ in_channels=in_channels,
+ embed_dim=embed_dim)
+
+ self.with_cls_token = with_cls_token
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
+ self.pos_embed = nn.Parameter(
+ torch.zeros(1, self.patch_embed.num_patches + 1, embed_dim))
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ if isinstance(out_indices, int):
+ self.out_indices = [out_indices]
+ elif isinstance(out_indices, list) or isinstance(out_indices, tuple):
+ self.out_indices = out_indices
+ else:
+ raise TypeError('out_indices must be type of int, list or tuple')
+
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)
+ ] # stochastic depth decay rule
+ self.blocks = nn.ModuleList([
+ Block(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=dpr[i],
+ attn_drop=attn_drop_rate,
+ act_cfg=act_cfg,
+ norm_cfg=norm_cfg,
+ with_cp=with_cp) for i in range(depth)
+ ])
+
+ self.interpolate_mode = interpolate_mode
+ self.final_norm = final_norm
+ if final_norm:
+ _, self.norm = build_norm_layer(norm_cfg, embed_dim)
+
+ self.norm_eval = norm_eval
+ self.with_cp = with_cp
+
+ def init_weights(self, pretrained=None):
+ if isinstance(pretrained, str):
+ logger = get_root_logger()
+ checkpoint = _load_checkpoint(pretrained, logger=logger)
+ if 'state_dict' in checkpoint:
+ state_dict = checkpoint['state_dict']
+ else:
+ state_dict = checkpoint
+
+ if 'pos_embed' in state_dict.keys():
+ if self.pos_embed.shape != state_dict['pos_embed'].shape:
+ logger.info(msg=f'Resize the pos_embed shape from \
+{state_dict["pos_embed"].shape} to {self.pos_embed.shape}')
+ h, w = self.img_size
+ pos_size = int(
+ math.sqrt(state_dict['pos_embed'].shape[1] - 1))
+ state_dict['pos_embed'] = self.resize_pos_embed(
+ state_dict['pos_embed'], (h, w), (pos_size, pos_size),
+ self.patch_size, self.interpolate_mode)
+
+ self.load_state_dict(state_dict, False)
+
+ elif pretrained is None:
+ # We only implement the 'jax_impl' initialization implemented at
+ # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501
+ trunc_normal_(self.pos_embed, std=.02)
+ trunc_normal_(self.cls_token, std=.02)
+ for n, m in self.named_modules():
+ if isinstance(m, Linear):
+ trunc_normal_(m.weight, std=.02)
+ if m.bias is not None:
+ if 'mlp' in n:
+ normal_init(m.bias, std=1e-6)
+ else:
+ constant_init(m.bias, 0)
+ elif isinstance(m, Conv2d):
+ kaiming_init(m.weight, mode='fan_in')
+ if m.bias is not None:
+ constant_init(m.bias, 0)
+ elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
+ constant_init(m.bias, 0)
+ constant_init(m.weight, 1.0)
+ else:
+ raise TypeError('pretrained must be a str or None')
+
+ def _pos_embeding(self, img, patched_img, pos_embed):
+ """Positiong embeding method.
+
+ Resize the pos_embed, if the input image size doesn't match
+ the training size.
+ Args:
+ img (torch.Tensor): The inference image tensor, the shape
+ must be [B, C, H, W].
+ patched_img (torch.Tensor): The patched image, it should be
+ shape of [B, L1, C].
+ pos_embed (torch.Tensor): The pos_embed weights; it should be of
+ shape [B, L2, C].
+ Return:
+ torch.Tensor: The pos encoded image feature.
+ """
+ assert patched_img.ndim == 3 and pos_embed.ndim == 3, \
+ 'the shapes of patched_img and pos_embed must be [B, L, C]'
+ x_len, pos_len = patched_img.shape[1], pos_embed.shape[1]
+ if x_len != pos_len:
+ if pos_len == (self.img_size[0] // self.patch_size) * (
+ self.img_size[1] // self.patch_size) + 1:
+ pos_h = self.img_size[0] // self.patch_size
+ pos_w = self.img_size[1] // self.patch_size
+ else:
+ raise ValueError(
+ 'Unexpected shape of pos_embed, got {}.'.format(
+ pos_embed.shape))
+ pos_embed = self.resize_pos_embed(pos_embed, img.shape[2:],
+ (pos_h, pos_w), self.patch_size,
+ self.interpolate_mode)
+ return self.pos_drop(patched_img + pos_embed)
+
+ @staticmethod
+ def resize_pos_embed(pos_embed, input_shape, pos_shape, patch_size, mode):
+ """Resize pos_embed weights.
+
+ Resize pos_embed using bicubic interpolate method.
+ Args:
+ pos_embed (torch.Tensor): pos_embed weights.
+ input_shape (tuple): Tuple for (input_h, input_w).
+ pos_shape (tuple): Tuple for (pos_h, pos_w).
+ patch_size (int): Patch size.
+ Return:
+ torch.Tensor: The resized pos_embed of shape [B, L_new, C]
+ """
+ assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
+ input_h, input_w = input_shape
+ pos_h, pos_w = pos_shape
+ cls_token_weight = pos_embed[:, 0]
+ pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
+ pos_embed_weight = pos_embed_weight.reshape(
+ 1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
+ pos_embed_weight = F.interpolate(
+ pos_embed_weight,
+ size=[input_h // patch_size, input_w // patch_size],
+ align_corners=False,
+ mode=mode)
+ cls_token_weight = cls_token_weight.unsqueeze(1)
+ pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
+ pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1)
+ return pos_embed
+
+ def forward(self, inputs):
+ B = inputs.shape[0]
+
+ x = self.patch_embed(inputs)
+
+ cls_tokens = self.cls_token.expand(B, -1, -1)
+ x = torch.cat((cls_tokens, x), dim=1)
+ x = self._pos_embeding(inputs, x, self.pos_embed)
+
+ if not self.with_cls_token:
+ # Remove class token for transformer input
+ x = x[:, 1:]
+
+ outs = []
+ for i, blk in enumerate(self.blocks):
+ x = blk(x)
+ if i == len(self.blocks) - 1:
+ if self.final_norm:
+ x = self.norm(x)
+ if i in self.out_indices:
+ if self.with_cls_token:
+ # Remove class token and reshape token for decoder head
+ out = x[:, 1:]
+ else:
+ out = x
+ B, _, C = out.shape
+ out = out.reshape(B, inputs.shape[2] // self.patch_size,
+ inputs.shape[3] // self.patch_size,
+ C).permute(0, 3, 1, 2)
+ outs.append(out)
+
+ return tuple(outs)
+
+ def train(self, mode=True):
+ super(VisionTransformer, self).train(mode)
+ if mode and self.norm_eval:
+ for m in self.modules():
+ if isinstance(m, nn.LayerNorm):
+ m.eval()
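A hedged usage sketch for the static resize_pos_embed above: a position embedding trained at 224x224 with patch size 16 (a 14x14 grid plus the class token) is interpolated to fit a 384x384 input (24x24 grid). Shapes are illustrative; the import assumes the package layout introduced by this patch.

import torch
from annotator.uniformer.mmseg.models.backbones.vit import VisionTransformer

pos_embed = torch.randn(1, 1 + 14 * 14, 768)
resized = VisionTransformer.resize_pos_embed(
    pos_embed, (384, 384), (14, 14), 16, 'bicubic')
assert resized.shape == (1, 1 + 24 * 24, 768)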
diff --git a/annotator/uniformer/mmseg/models/builder.py b/annotator/uniformer/mmseg/models/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f5b971252bfc971c3ffbaa27746d69b1d3ea9fd
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/builder.py
@@ -0,0 +1,46 @@
+import warnings
+
+from annotator.uniformer.mmcv.cnn import MODELS as MMCV_MODELS
+from annotator.uniformer.mmcv.utils import Registry
+
+MODELS = Registry('models', parent=MMCV_MODELS)
+
+BACKBONES = MODELS
+NECKS = MODELS
+HEADS = MODELS
+LOSSES = MODELS
+SEGMENTORS = MODELS
+
+
+def build_backbone(cfg):
+ """Build backbone."""
+ return BACKBONES.build(cfg)
+
+
+def build_neck(cfg):
+ """Build neck."""
+ return NECKS.build(cfg)
+
+
+def build_head(cfg):
+ """Build head."""
+ return HEADS.build(cfg)
+
+
+def build_loss(cfg):
+ """Build loss."""
+ return LOSSES.build(cfg)
+
+
+def build_segmentor(cfg, train_cfg=None, test_cfg=None):
+ """Build segmentor."""
+ if train_cfg is not None or test_cfg is not None:
+ warnings.warn(
+ 'train_cfg and test_cfg are deprecated, '
+ 'please specify them in model', UserWarning)
+ assert cfg.get('train_cfg') is None or train_cfg is None, \
+ 'train_cfg specified in both outer field and model field '
+ assert cfg.get('test_cfg') is None or test_cfg is None, \
+ 'test_cfg specified in both outer field and model field '
+ return SEGMENTORS.build(
+ cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))
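Hedged sketch of how these helpers are normally used: a model component is described by a plain config dict whose 'type' names a class registered in the shared MODELS registry. The config values below are illustrative, and the import goes through the models package so that the backbone classes are registered first.

from annotator.uniformer.mmseg.models import build_backbone

backbone = build_backbone(dict(
    type='UniFormer',
    layers=[3, 4, 8, 3],
    embed_dim=[64, 128, 320, 512],
    head_dim=64,
    drop_path_rate=0.1))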
diff --git a/annotator/uniformer/mmseg/models/decode_heads/__init__.py b/annotator/uniformer/mmseg/models/decode_heads/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac66d3cfe0ea04af45c0f3594bf135841c3812e3
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/decode_heads/__init__.py
@@ -0,0 +1,28 @@
+from .ann_head import ANNHead
+from .apc_head import APCHead
+from .aspp_head import ASPPHead
+from .cc_head import CCHead
+from .da_head import DAHead
+from .dm_head import DMHead
+from .dnl_head import DNLHead
+from .ema_head import EMAHead
+from .enc_head import EncHead
+from .fcn_head import FCNHead
+from .fpn_head import FPNHead
+from .gc_head import GCHead
+from .lraspp_head import LRASPPHead
+from .nl_head import NLHead
+from .ocr_head import OCRHead
+# from .point_head import PointHead
+from .psa_head import PSAHead
+from .psp_head import PSPHead
+from .sep_aspp_head import DepthwiseSeparableASPPHead
+from .sep_fcn_head import DepthwiseSeparableFCNHead
+from .uper_head import UPerHead
+
+__all__ = [
+ 'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead',
+ 'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead',
+ 'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead', 'DNLHead',
+ 'APCHead', 'DMHead', 'LRASPPHead'
+]
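Hedged note: every head re-exported here is also registered in the shared HEADS registry, so configs normally reference heads by name rather than importing these classes directly. A quick illustrative check (import paths as introduced by this patch):

from annotator.uniformer.mmseg.models import decode_heads  # noqa: F401  (triggers registration)
from annotator.uniformer.mmseg.models.builder import HEADS

assert HEADS.get('PSPHead') is not None
assert HEADS.get('UPerHead') is not None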
diff --git a/annotator/uniformer/mmseg/models/decode_heads/ann_head.py b/annotator/uniformer/mmseg/models/decode_heads/ann_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..30aaacc2cafc568d3de71d1477b4de0dc0fea9d3
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/decode_heads/ann_head.py
@@ -0,0 +1,245 @@
+import torch
+import torch.nn as nn
+from annotator.uniformer.mmcv.cnn import ConvModule
+
+from ..builder import HEADS
+from ..utils import SelfAttentionBlock as _SelfAttentionBlock
+from .decode_head import BaseDecodeHead
+
+
+class PPMConcat(nn.ModuleList):
+ """Pyramid Pooling Module that only concat the features of each layer.
+
+ Args:
+ pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
+ Module.
+ """
+
+ def __init__(self, pool_scales=(1, 3, 6, 8)):
+ super(PPMConcat, self).__init__(
+ [nn.AdaptiveAvgPool2d(pool_scale) for pool_scale in pool_scales])
+
+ def forward(self, feats):
+ """Forward function."""
+ ppm_outs = []
+ for ppm in self:
+ ppm_out = ppm(feats)
+ ppm_outs.append(ppm_out.view(*feats.shape[:2], -1))
+ concat_outs = torch.cat(ppm_outs, dim=2)
+ return concat_outs
+
+
+class SelfAttentionBlock(_SelfAttentionBlock):
+ """Make a ANN used SelfAttentionBlock.
+
+ Args:
+ low_in_channels (int): Input channels of lower level feature,
+ which is the key feature for self-attention.
+ high_in_channels (int): Input channels of higher level feature,
+ which is the query feature for self-attention.
+ channels (int): Output channels of key/query transform.
+ out_channels (int): Output channels.
+ share_key_query (bool): Whether share projection weight between key
+ and query projection.
+ query_scale (int): The scale of query feature map.
+ key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
+ Module of key feature.
+ conv_cfg (dict|None): Config of conv layers.
+ norm_cfg (dict|None): Config of norm layers.
+ act_cfg (dict|None): Config of activation layers.
+ """
+
+ def __init__(self, low_in_channels, high_in_channels, channels,
+ out_channels, share_key_query, query_scale, key_pool_scales,
+ conv_cfg, norm_cfg, act_cfg):
+ key_psp = PPMConcat(key_pool_scales)
+ if query_scale > 1:
+ query_downsample = nn.MaxPool2d(kernel_size=query_scale)
+ else:
+ query_downsample = None
+ super(SelfAttentionBlock, self).__init__(
+ key_in_channels=low_in_channels,
+ query_in_channels=high_in_channels,
+ channels=channels,
+ out_channels=out_channels,
+ share_key_query=share_key_query,
+ query_downsample=query_downsample,
+ key_downsample=key_psp,
+ key_query_num_convs=1,
+ key_query_norm=True,
+ value_out_num_convs=1,
+ value_out_norm=False,
+ matmul_norm=True,
+ with_out=True,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+
+
+class AFNB(nn.Module):
+ """Asymmetric Fusion Non-local Block(AFNB)
+
+ Args:
+ low_in_channels (int): Input channels of lower level feature,
+ which is the key feature for self-attention.
+ high_in_channels (int): Input channels of higher level feature,
+ which is the query feature for self-attention.
+ channels (int): Output channels of key/query transform.
+ out_channels (int): Output channels.
+ query_scales (tuple[int]): The scales of query feature map.
+ Default: (1,)
+ key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
+ Module of key feature.
+ conv_cfg (dict|None): Config of conv layers.
+ norm_cfg (dict|None): Config of norm layers.
+ act_cfg (dict|None): Config of activation layers.
+ """
+
+ def __init__(self, low_in_channels, high_in_channels, channels,
+ out_channels, query_scales, key_pool_scales, conv_cfg,
+ norm_cfg, act_cfg):
+ super(AFNB, self).__init__()
+ self.stages = nn.ModuleList()
+ for query_scale in query_scales:
+ self.stages.append(
+ SelfAttentionBlock(
+ low_in_channels=low_in_channels,
+ high_in_channels=high_in_channels,
+ channels=channels,
+ out_channels=out_channels,
+ share_key_query=False,
+ query_scale=query_scale,
+ key_pool_scales=key_pool_scales,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg))
+ self.bottleneck = ConvModule(
+ out_channels + high_in_channels,
+ out_channels,
+ 1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=None)
+
+ def forward(self, low_feats, high_feats):
+ """Forward function."""
+ priors = [stage(high_feats, low_feats) for stage in self.stages]
+ context = torch.stack(priors, dim=0).sum(dim=0)
+ output = self.bottleneck(torch.cat([context, high_feats], 1))
+ return output
+
+
+class APNB(nn.Module):
+ """Asymmetric Pyramid Non-local Block (APNB)
+
+ Args:
+ in_channels (int): Input channels of key/query feature,
+ which is the key feature for self-attention.
+ channels (int): Output channels of key/query transform.
+ out_channels (int): Output channels.
+ query_scales (tuple[int]): The scales of query feature map.
+ Default: (1,)
+ key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
+ Module of key feature.
+ conv_cfg (dict|None): Config of conv layers.
+ norm_cfg (dict|None): Config of norm layers.
+ act_cfg (dict|None): Config of activation layers.
+ """
+
+ def __init__(self, in_channels, channels, out_channels, query_scales,
+ key_pool_scales, conv_cfg, norm_cfg, act_cfg):
+ super(APNB, self).__init__()
+ self.stages = nn.ModuleList()
+ for query_scale in query_scales:
+ self.stages.append(
+ SelfAttentionBlock(
+ low_in_channels=in_channels,
+ high_in_channels=in_channels,
+ channels=channels,
+ out_channels=out_channels,
+ share_key_query=True,
+ query_scale=query_scale,
+ key_pool_scales=key_pool_scales,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg))
+ self.bottleneck = ConvModule(
+ 2 * in_channels,
+ out_channels,
+ 1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+
+ def forward(self, feats):
+ """Forward function."""
+ priors = [stage(feats, feats) for stage in self.stages]
+ context = torch.stack(priors, dim=0).sum(dim=0)
+ output = self.bottleneck(torch.cat([context, feats], 1))
+ return output
+
+
+@HEADS.register_module()
+class ANNHead(BaseDecodeHead):
+ """Asymmetric Non-local Neural Networks for Semantic Segmentation.
+
+ This head is the implementation of `ANNNet
+ <https://arxiv.org/abs/1908.07678>`_.
+
+ Args:
+ project_channels (int): Projection channels for Nonlocal.
+ query_scales (tuple[int]): The scales of query feature map.
+ Default: (1,)
+ key_pool_scales (tuple[int]): The pooling scales of key feature map.
+ Default: (1, 3, 6, 8).
+ """
+
+ def __init__(self,
+ project_channels,
+ query_scales=(1, ),
+ key_pool_scales=(1, 3, 6, 8),
+ **kwargs):
+ super(ANNHead, self).__init__(
+ input_transform='multiple_select', **kwargs)
+ assert len(self.in_channels) == 2
+ low_in_channels, high_in_channels = self.in_channels
+ self.project_channels = project_channels
+ self.fusion = AFNB(
+ low_in_channels=low_in_channels,
+ high_in_channels=high_in_channels,
+ out_channels=high_in_channels,
+ channels=project_channels,
+ query_scales=query_scales,
+ key_pool_scales=key_pool_scales,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.bottleneck = ConvModule(
+ high_in_channels,
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.context = APNB(
+ in_channels=self.channels,
+ out_channels=self.channels,
+ channels=project_channels,
+ query_scales=query_scales,
+ key_pool_scales=key_pool_scales,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ def forward(self, inputs):
+ """Forward function."""
+ low_feats, high_feats = self._transform_inputs(inputs)
+ output = self.fusion(low_feats, high_feats)
+ output = self.dropout(output)
+ output = self.bottleneck(output)
+ output = self.context(output)
+ output = self.cls_seg(output)
+
+ return output
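Hedged shape walk-through for PPMConcat above: each adaptive-pooled map is flattened and the results are concatenated along the spatial axis, so a (B, C, H, W) key feature becomes (B, C, 1 + 9 + 36 + 64) with the default pool scales. Tensor sizes are illustrative.

import torch
from annotator.uniformer.mmseg.models.decode_heads.ann_head import PPMConcat

feats = torch.randn(2, 256, 32, 32)
pooled = PPMConcat(pool_scales=(1, 3, 6, 8))(feats)
assert pooled.shape == (2, 256, 110)   # 1 + 9 + 36 + 64 pooled positions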
diff --git a/annotator/uniformer/mmseg/models/decode_heads/apc_head.py b/annotator/uniformer/mmseg/models/decode_heads/apc_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7038bdbe0edf2a1f184b6899486d2d190dda076
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/decode_heads/apc_head.py
@@ -0,0 +1,158 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from annotator.uniformer.mmcv.cnn import ConvModule
+
+from annotator.uniformer.mmseg.ops import resize
+from ..builder import HEADS
+from .decode_head import BaseDecodeHead
+
+
+class ACM(nn.Module):
+ """Adaptive Context Module used in APCNet.
+
+ Args:
+ pool_scale (int): Pooling scale used in Adaptive Context
+ Module to extract region features.
+ fusion (bool): Add one conv to fuse residual feature.
+ in_channels (int): Input channels.
+ channels (int): Channels after modules, before conv_seg.
+ conv_cfg (dict | None): Config of conv layers.
+ norm_cfg (dict | None): Config of norm layers.
+ act_cfg (dict): Config of activation layers.
+ """
+
+ def __init__(self, pool_scale, fusion, in_channels, channels, conv_cfg,
+ norm_cfg, act_cfg):
+ super(ACM, self).__init__()
+ self.pool_scale = pool_scale
+ self.fusion = fusion
+ self.in_channels = in_channels
+ self.channels = channels
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ self.pooled_redu_conv = ConvModule(
+ self.in_channels,
+ self.channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ self.input_redu_conv = ConvModule(
+ self.in_channels,
+ self.channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ self.global_info = ConvModule(
+ self.channels,
+ self.channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ self.gla = nn.Conv2d(self.channels, self.pool_scale**2, 1, 1, 0)
+
+ self.residual_conv = ConvModule(
+ self.channels,
+ self.channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ if self.fusion:
+ self.fusion_conv = ConvModule(
+ self.channels,
+ self.channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ def forward(self, x):
+ """Forward function."""
+ pooled_x = F.adaptive_avg_pool2d(x, self.pool_scale)
+ # [batch_size, channels, h, w]
+ x = self.input_redu_conv(x)
+ # [batch_size, channels, pool_scale, pool_scale]
+ pooled_x = self.pooled_redu_conv(pooled_x)
+ batch_size = x.size(0)
+ # [batch_size, pool_scale * pool_scale, channels]
+ pooled_x = pooled_x.view(batch_size, self.channels,
+ -1).permute(0, 2, 1).contiguous()
+ # [batch_size, h * w, pool_scale * pool_scale]
+ affinity_matrix = self.gla(x + resize(
+ self.global_info(F.adaptive_avg_pool2d(x, 1)), size=x.shape[2:])
+ ).permute(0, 2, 3, 1).reshape(
+ batch_size, -1, self.pool_scale**2)
+ affinity_matrix = torch.sigmoid(affinity_matrix)
+ # [batch_size, h * w, channels]
+ z_out = torch.matmul(affinity_matrix, pooled_x)
+ # [batch_size, channels, h * w]
+ z_out = z_out.permute(0, 2, 1).contiguous()
+ # [batch_size, channels, h, w]
+ z_out = z_out.view(batch_size, self.channels, x.size(2), x.size(3))
+ z_out = self.residual_conv(z_out)
+ z_out = F.relu(z_out + x)
+ if self.fusion:
+ z_out = self.fusion_conv(z_out)
+
+ return z_out
+
+
+@HEADS.register_module()
+class APCHead(BaseDecodeHead):
+ """Adaptive Pyramid Context Network for Semantic Segmentation.
+
+ This head is the implementation of
+ `APCNet `_.
+
+ Args:
+ pool_scales (tuple[int]): Pooling scales used in Adaptive Context
+ Module. Default: (1, 2, 3, 6).
+ fusion (bool): Add one conv to fuse residual feature.
+ """
+
+ def __init__(self, pool_scales=(1, 2, 3, 6), fusion=True, **kwargs):
+ super(APCHead, self).__init__(**kwargs)
+ assert isinstance(pool_scales, (list, tuple))
+ self.pool_scales = pool_scales
+ self.fusion = fusion
+ acm_modules = []
+ for pool_scale in self.pool_scales:
+ acm_modules.append(
+ ACM(pool_scale,
+ self.fusion,
+ self.in_channels,
+ self.channels,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg))
+ self.acm_modules = nn.ModuleList(acm_modules)
+ self.bottleneck = ConvModule(
+ self.in_channels + len(pool_scales) * self.channels,
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ def forward(self, inputs):
+ """Forward function."""
+ x = self._transform_inputs(inputs)
+ acm_outs = [x]
+ for acm_module in self.acm_modules:
+ acm_outs.append(acm_module(x))
+ acm_outs = torch.cat(acm_outs, dim=1)
+ output = self.bottleneck(acm_outs)
+ output = self.cls_seg(output)
+ return output
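Hedged shape walk-through for a single Adaptive Context Module: the input is reduced to `channels`, pooled into pool_scale x pool_scale regions, and the regions are mixed back per pixel through a learned (H*W) x (pool_scale^2) affinity matrix. The config values are illustrative, not taken from this patch.

import torch
from annotator.uniformer.mmseg.models.decode_heads.apc_head import ACM

acm = ACM(pool_scale=2, fusion=True, in_channels=512, channels=128,
          conv_cfg=None, norm_cfg=None, act_cfg=dict(type='ReLU'))
out = acm(torch.randn(1, 512, 16, 16))
assert out.shape == (1, 128, 16, 16)   # affinity matrix is (1, 16 * 16, 4) internally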
diff --git a/annotator/uniformer/mmseg/models/decode_heads/aspp_head.py b/annotator/uniformer/mmseg/models/decode_heads/aspp_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa914b5bb25124d1ff199553d96713d6a80484c0
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/decode_heads/aspp_head.py
@@ -0,0 +1,107 @@
+import torch
+import torch.nn as nn
+from annotator.uniformer.mmcv.cnn import ConvModule
+
+from annotator.uniformer.mmseg.ops import resize
+from ..builder import HEADS
+from .decode_head import BaseDecodeHead
+
+
+class ASPPModule(nn.ModuleList):
+ """Atrous Spatial Pyramid Pooling (ASPP) Module.
+
+ Args:
+ dilations (tuple[int]): Dilation rate of each layer.
+ in_channels (int): Input channels.
+ channels (int): Channels after modules, before conv_seg.
+ conv_cfg (dict|None): Config of conv layers.
+ norm_cfg (dict|None): Config of norm layers.
+ act_cfg (dict): Config of activation layers.
+ """
+
+ def __init__(self, dilations, in_channels, channels, conv_cfg, norm_cfg,
+ act_cfg):
+ super(ASPPModule, self).__init__()
+ self.dilations = dilations
+ self.in_channels = in_channels
+ self.channels = channels
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ for dilation in dilations:
+ self.append(
+ ConvModule(
+ self.in_channels,
+ self.channels,
+ 1 if dilation == 1 else 3,
+ dilation=dilation,
+ padding=0 if dilation == 1 else dilation,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg))
+
+ def forward(self, x):
+ """Forward function."""
+ aspp_outs = []
+ for aspp_module in self:
+ aspp_outs.append(aspp_module(x))
+
+ return aspp_outs
+
+
+@HEADS.register_module()
+class ASPPHead(BaseDecodeHead):
+ """Rethinking Atrous Convolution for Semantic Image Segmentation.
+
+ This head is the implementation of `DeepLabV3
+ <https://arxiv.org/abs/1706.05587>`_.
+
+ Args:
+ dilations (tuple[int]): Dilation rates for ASPP module.
+ Default: (1, 6, 12, 18).
+ """
+
+ def __init__(self, dilations=(1, 6, 12, 18), **kwargs):
+ super(ASPPHead, self).__init__(**kwargs)
+ assert isinstance(dilations, (list, tuple))
+ self.dilations = dilations
+ self.image_pool = nn.Sequential(
+ nn.AdaptiveAvgPool2d(1),
+ ConvModule(
+ self.in_channels,
+ self.channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg))
+ self.aspp_modules = ASPPModule(
+ dilations,
+ self.in_channels,
+ self.channels,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.bottleneck = ConvModule(
+ (len(dilations) + 1) * self.channels,
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ def forward(self, inputs):
+ """Forward function."""
+ x = self._transform_inputs(inputs)
+ aspp_outs = [
+ resize(
+ self.image_pool(x),
+ size=x.size()[2:],
+ mode='bilinear',
+ align_corners=self.align_corners)
+ ]
+ aspp_outs.extend(self.aspp_modules(x))
+ aspp_outs = torch.cat(aspp_outs, dim=1)
+ output = self.bottleneck(aspp_outs)
+ output = self.cls_seg(output)
+ return output
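Hedged sketch: ASPPModule picks a 1x1 conv for dilation 1 and 3x3 convs with padding equal to the dilation otherwise, so every branch keeps the input resolution and the branch outputs can be concatenated channel-wise by ASPPHead. Channel counts below are illustrative.

import torch
from annotator.uniformer.mmseg.models.decode_heads.aspp_head import ASPPModule

aspp = ASPPModule(dilations=(1, 6, 12, 18), in_channels=2048, channels=512,
                  conv_cfg=None, norm_cfg=None, act_cfg=dict(type='ReLU'))
outs = aspp(torch.randn(1, 2048, 33, 33))
assert len(outs) == 4
assert all(o.shape == (1, 512, 33, 33) for o in outs)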
diff --git a/annotator/uniformer/mmseg/models/decode_heads/cascade_decode_head.py b/annotator/uniformer/mmseg/models/decode_heads/cascade_decode_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..d02122ca0e68743b1bf7a893afae96042f23838c
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/decode_heads/cascade_decode_head.py
@@ -0,0 +1,57 @@
+from abc import ABCMeta, abstractmethod
+
+from .decode_head import BaseDecodeHead
+
+
+class BaseCascadeDecodeHead(BaseDecodeHead, metaclass=ABCMeta):
+ """Base class for cascade decode head used in
+ :class:`CascadeEncoderDecoder`."""
+
+ def __init__(self, *args, **kwargs):
+ super(BaseCascadeDecodeHead, self).__init__(*args, **kwargs)
+
+ @abstractmethod
+ def forward(self, inputs, prev_output):
+ """Placeholder of forward function."""
+ pass
+
+ def forward_train(self, inputs, prev_output, img_metas, gt_semantic_seg,
+ train_cfg):
+ """Forward function for training.
+ Args:
+ inputs (list[Tensor]): List of multi-level img features.
+ prev_output (Tensor): The output of previous decode head.
+ img_metas (list[dict]): List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmseg/datasets/pipelines/formatting.py:Collect`.
+ gt_semantic_seg (Tensor): Semantic segmentation masks
+ used if the architecture supports semantic segmentation task.
+ train_cfg (dict): The training config.
+
+ Returns:
+ dict[str, Tensor]: a dictionary of loss components
+ """
+ seg_logits = self.forward(inputs, prev_output)
+ losses = self.losses(seg_logits, gt_semantic_seg)
+
+ return losses
+
+ def forward_test(self, inputs, prev_output, img_metas, test_cfg):
+ """Forward function for testing.
+
+ Args:
+ inputs (list[Tensor]): List of multi-level img features.
+ prev_output (Tensor): The output of previous decode head.
+ img_metas (list[dict]): List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmseg/datasets/pipelines/formatting.py:Collect`.
+ test_cfg (dict): The testing config.
+
+ Returns:
+ Tensor: Output segmentation map.
+ """
+ return self.forward(inputs, prev_output)
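Hedged minimal sketch of the contract above: a cascade head only needs to implement forward(inputs, prev_output); forward_train and forward_test then reuse the loss computation and config handling inherited from BaseDecodeHead. The subclass below is purely illustrative.

import torch
from annotator.uniformer.mmseg.models.decode_heads.cascade_decode_head import \
    BaseCascadeDecodeHead

class IdentityCascadeHead(BaseCascadeDecodeHead):
    def forward(self, inputs, prev_output):
        x = self._transform_inputs(inputs)      # picks inputs[-1] by default
        return self.cls_seg(x + prev_output)    # fuse with the previous stage

head = IdentityCascadeHead(in_channels=64, channels=64, num_classes=19)
logits = head.forward([torch.randn(2, 64, 32, 32)], torch.randn(2, 64, 32, 32))
assert logits.shape == (2, 19, 32, 32)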
diff --git a/annotator/uniformer/mmseg/models/decode_heads/cc_head.py b/annotator/uniformer/mmseg/models/decode_heads/cc_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b9abb4e747f92657f4220b29788539340986c00
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/decode_heads/cc_head.py
@@ -0,0 +1,42 @@
+import torch
+
+from ..builder import HEADS
+from .fcn_head import FCNHead
+
+try:
+ from annotator.uniformer.mmcv.ops import CrissCrossAttention
+except ModuleNotFoundError:
+ CrissCrossAttention = None
+
+
+@HEADS.register_module()
+class CCHead(FCNHead):
+ """CCNet: Criss-Cross Attention for Semantic Segmentation.
+
+ This head is the implementation of `CCNet
+ <https://arxiv.org/abs/1811.11721>`_.
+
+ Args:
+ recurrence (int): Number of recurrence of Criss Cross Attention
+ module. Default: 2.
+ """
+
+ def __init__(self, recurrence=2, **kwargs):
+ if CrissCrossAttention is None:
+ raise RuntimeError('Please install mmcv-full for '
+ 'CrissCrossAttention ops')
+ super(CCHead, self).__init__(num_convs=2, **kwargs)
+ self.recurrence = recurrence
+ self.cca = CrissCrossAttention(self.channels)
+
+ def forward(self, inputs):
+ """Forward function."""
+ x = self._transform_inputs(inputs)
+ output = self.convs[0](x)
+ for _ in range(self.recurrence):
+ output = self.cca(output)
+ output = self.convs[1](output)
+ if self.concat_input:
+ output = self.conv_cat(torch.cat([x, output], dim=1))
+ output = self.cls_seg(output)
+ return output
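Hedged note: CCHead depends on the compiled CrissCrossAttention op shipped with mmcv-full; with a plain or CPU-only mmcv install the constructor raises the RuntimeError above. A typical config (values illustrative) is built through the HEADS registry rather than instantiated directly:

from annotator.uniformer.mmseg.models import build_head

cc_head_cfg = dict(
    type='CCHead',
    in_channels=2048,
    in_index=3,
    channels=512,
    recurrence=2,
    num_classes=19,
    norm_cfg=dict(type='BN', requires_grad=True))
# head = build_head(cc_head_cfg)  # only succeeds when the mmcv-full ops are available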
diff --git a/annotator/uniformer/mmseg/models/decode_heads/da_head.py b/annotator/uniformer/mmseg/models/decode_heads/da_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cd49fcfdc7c0a70f9485cc71843dcf3e0cb1774
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/decode_heads/da_head.py
@@ -0,0 +1,178 @@
+import torch
+import torch.nn.functional as F
+from annotator.uniformer.mmcv.cnn import ConvModule, Scale
+from torch import nn
+
+from annotator.uniformer.mmseg.core import add_prefix
+from ..builder import HEADS
+from ..utils import SelfAttentionBlock as _SelfAttentionBlock
+from .decode_head import BaseDecodeHead
+
+
+class PAM(_SelfAttentionBlock):
+ """Position Attention Module (PAM)
+
+ Args:
+ in_channels (int): Input channels of key/query feature.
+ channels (int): Output channels of key/query transform.
+ """
+
+ def __init__(self, in_channels, channels):
+ super(PAM, self).__init__(
+ key_in_channels=in_channels,
+ query_in_channels=in_channels,
+ channels=channels,
+ out_channels=in_channels,
+ share_key_query=False,
+ query_downsample=None,
+ key_downsample=None,
+ key_query_num_convs=1,
+ key_query_norm=False,
+ value_out_num_convs=1,
+ value_out_norm=False,
+ matmul_norm=False,
+ with_out=False,
+ conv_cfg=None,
+ norm_cfg=None,
+ act_cfg=None)
+
+ self.gamma = Scale(0)
+
+ def forward(self, x):
+ """Forward function."""
+ out = super(PAM, self).forward(x, x)
+
+ out = self.gamma(out) + x
+ return out
+
+
+class CAM(nn.Module):
+ """Channel Attention Module (CAM)"""
+
+ def __init__(self):
+ super(CAM, self).__init__()
+ self.gamma = Scale(0)
+
+ def forward(self, x):
+ """Forward function."""
+ batch_size, channels, height, width = x.size()
+ proj_query = x.view(batch_size, channels, -1)
+ proj_key = x.view(batch_size, channels, -1).permute(0, 2, 1)
+ energy = torch.bmm(proj_query, proj_key)
+ energy_new = torch.max(
+ energy, -1, keepdim=True)[0].expand_as(energy) - energy
+ attention = F.softmax(energy_new, dim=-1)
+ proj_value = x.view(batch_size, channels, -1)
+
+ out = torch.bmm(attention, proj_value)
+ out = out.view(batch_size, channels, height, width)
+
+ out = self.gamma(out) + x
+ return out
+
+
+@HEADS.register_module()
+class DAHead(BaseDecodeHead):
+ """Dual Attention Network for Scene Segmentation.
+
+ This head is the implementation of `DANet
+ <https://arxiv.org/abs/1809.02983>`_.
+
+ Args:
+ pam_channels (int): The channels of Position Attention Module(PAM).
+ """
+
+ def __init__(self, pam_channels, **kwargs):
+ super(DAHead, self).__init__(**kwargs)
+ self.pam_channels = pam_channels
+ self.pam_in_conv = ConvModule(
+ self.in_channels,
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.pam = PAM(self.channels, pam_channels)
+ self.pam_out_conv = ConvModule(
+ self.channels,
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.pam_conv_seg = nn.Conv2d(
+ self.channels, self.num_classes, kernel_size=1)
+
+ self.cam_in_conv = ConvModule(
+ self.in_channels,
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.cam = CAM()
+ self.cam_out_conv = ConvModule(
+ self.channels,
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.cam_conv_seg = nn.Conv2d(
+ self.channels, self.num_classes, kernel_size=1)
+
+ def pam_cls_seg(self, feat):
+ """PAM feature classification."""
+ if self.dropout is not None:
+ feat = self.dropout(feat)
+ output = self.pam_conv_seg(feat)
+ return output
+
+ def cam_cls_seg(self, feat):
+ """CAM feature classification."""
+ if self.dropout is not None:
+ feat = self.dropout(feat)
+ output = self.cam_conv_seg(feat)
+ return output
+
+ def forward(self, inputs):
+ """Forward function."""
+ x = self._transform_inputs(inputs)
+ pam_feat = self.pam_in_conv(x)
+ pam_feat = self.pam(pam_feat)
+ pam_feat = self.pam_out_conv(pam_feat)
+ pam_out = self.pam_cls_seg(pam_feat)
+
+ cam_feat = self.cam_in_conv(x)
+ cam_feat = self.cam(cam_feat)
+ cam_feat = self.cam_out_conv(cam_feat)
+ cam_out = self.cam_cls_seg(cam_feat)
+
+ feat_sum = pam_feat + cam_feat
+ pam_cam_out = self.cls_seg(feat_sum)
+
+ return pam_cam_out, pam_out, cam_out
+
+ def forward_test(self, inputs, img_metas, test_cfg):
+ """Forward function for testing, only ``pam_cam`` is used."""
+ return self.forward(inputs)[0]
+
+ def losses(self, seg_logit, seg_label):
+ """Compute ``pam_cam``, ``pam``, ``cam`` loss."""
+ pam_cam_seg_logit, pam_seg_logit, cam_seg_logit = seg_logit
+ loss = dict()
+ loss.update(
+ add_prefix(
+ super(DAHead, self).losses(pam_cam_seg_logit, seg_label),
+ 'pam_cam'))
+ loss.update(
+ add_prefix(
+ super(DAHead, self).losses(pam_seg_logit, seg_label), 'pam'))
+ loss.update(
+ add_prefix(
+ super(DAHead, self).losses(cam_seg_logit, seg_label), 'cam'))
+ return loss
diff --git a/annotator/uniformer/mmseg/models/decode_heads/decode_head.py b/annotator/uniformer/mmseg/models/decode_heads/decode_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..88a661b8f6fec5d4c031d3d85e80777ee63951a6
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/decode_heads/decode_head.py
@@ -0,0 +1,234 @@
+from abc import ABCMeta, abstractmethod
+
+import torch
+import torch.nn as nn
+from annotator.uniformer.mmcv.cnn import normal_init
+from annotator.uniformer.mmcv.runner import auto_fp16, force_fp32
+
+from annotator.uniformer.mmseg.core import build_pixel_sampler
+from annotator.uniformer.mmseg.ops import resize
+from ..builder import build_loss
+from ..losses import accuracy
+
+
+class BaseDecodeHead(nn.Module, metaclass=ABCMeta):
+    """Base class for decode heads.
+
+ Args:
+ in_channels (int|Sequence[int]): Input channels.
+ channels (int): Channels after modules, before conv_seg.
+ num_classes (int): Number of classes.
+ dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
+ conv_cfg (dict|None): Config of conv layers. Default: None.
+ norm_cfg (dict|None): Config of norm layers. Default: None.
+ act_cfg (dict): Config of activation layers.
+ Default: dict(type='ReLU')
+ in_index (int|Sequence[int]): Input feature index. Default: -1
+ input_transform (str|None): Transformation type of input features.
+ Options: 'resize_concat', 'multiple_select', None.
+            'resize_concat': Multiple feature maps will be resized to the
+                same size as the first one and then concatenated together.
+                Usually used in the FCN head of HRNet.
+            'multiple_select': Multiple feature maps will be bundled into
+                a list and passed into the decode head.
+            None: Only one feature map will be selected.
+ Default: None.
+ loss_decode (dict): Config of decode loss.
+ Default: dict(type='CrossEntropyLoss').
+ ignore_index (int | None): The label index to be ignored. When using
+ masked BCE loss, ignore_index should be set to None. Default: 255
+ sampler (dict|None): The config of segmentation map sampler.
+ Default: None.
+ align_corners (bool): align_corners argument of F.interpolate.
+ Default: False.
+ """
+
+ def __init__(self,
+ in_channels,
+ channels,
+ *,
+ num_classes,
+ dropout_ratio=0.1,
+ conv_cfg=None,
+ norm_cfg=None,
+ act_cfg=dict(type='ReLU'),
+ in_index=-1,
+ input_transform=None,
+ loss_decode=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=False,
+ loss_weight=1.0),
+ ignore_index=255,
+ sampler=None,
+ align_corners=False):
+ super(BaseDecodeHead, self).__init__()
+ self._init_inputs(in_channels, in_index, input_transform)
+ self.channels = channels
+ self.num_classes = num_classes
+ self.dropout_ratio = dropout_ratio
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ self.in_index = in_index
+ self.loss_decode = build_loss(loss_decode)
+ self.ignore_index = ignore_index
+ self.align_corners = align_corners
+ if sampler is not None:
+ self.sampler = build_pixel_sampler(sampler, context=self)
+ else:
+ self.sampler = None
+
+ self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
+ if dropout_ratio > 0:
+ self.dropout = nn.Dropout2d(dropout_ratio)
+ else:
+ self.dropout = None
+ self.fp16_enabled = False
+
+ def extra_repr(self):
+ """Extra repr."""
+ s = f'input_transform={self.input_transform}, ' \
+ f'ignore_index={self.ignore_index}, ' \
+ f'align_corners={self.align_corners}'
+ return s
+
+ def _init_inputs(self, in_channels, in_index, input_transform):
+ """Check and initialize input transforms.
+
+ The in_channels, in_index and input_transform must match.
+        Specifically, when input_transform is None, only a single feature map
+        will be selected, so in_channels and in_index must be of type int.
+        When input_transform is not None, in_channels and in_index must be
+        sequences (list or tuple) of the same length.
+
+ Args:
+ in_channels (int|Sequence[int]): Input channels.
+ in_index (int|Sequence[int]): Input feature index.
+ input_transform (str|None): Transformation type of input features.
+ Options: 'resize_concat', 'multiple_select', None.
+                'resize_concat': Multiple feature maps will be resized to the
+                    same size as the first one and then concatenated together.
+                    Usually used in the FCN head of HRNet.
+                'multiple_select': Multiple feature maps will be bundled into
+                    a list and passed into the decode head.
+                None: Only one feature map will be selected.
+ """
+
+ if input_transform is not None:
+ assert input_transform in ['resize_concat', 'multiple_select']
+ self.input_transform = input_transform
+ self.in_index = in_index
+ if input_transform is not None:
+ assert isinstance(in_channels, (list, tuple))
+ assert isinstance(in_index, (list, tuple))
+ assert len(in_channels) == len(in_index)
+ if input_transform == 'resize_concat':
+ self.in_channels = sum(in_channels)
+ else:
+ self.in_channels = in_channels
+ else:
+ assert isinstance(in_channels, int)
+ assert isinstance(in_index, int)
+ self.in_channels = in_channels
+
+ def init_weights(self):
+ """Initialize weights of classification layer."""
+ normal_init(self.conv_seg, mean=0, std=0.01)
+
+ def _transform_inputs(self, inputs):
+ """Transform inputs for decoder.
+
+ Args:
+ inputs (list[Tensor]): List of multi-level img features.
+
+ Returns:
+ Tensor: The transformed inputs
+ """
+
+ if self.input_transform == 'resize_concat':
+ inputs = [inputs[i] for i in self.in_index]
+ upsampled_inputs = [
+ resize(
+ input=x,
+ size=inputs[0].shape[2:],
+ mode='bilinear',
+ align_corners=self.align_corners) for x in inputs
+ ]
+ inputs = torch.cat(upsampled_inputs, dim=1)
+ elif self.input_transform == 'multiple_select':
+ inputs = [inputs[i] for i in self.in_index]
+ else:
+ inputs = inputs[self.in_index]
+
+ return inputs
+
+ @auto_fp16()
+ @abstractmethod
+ def forward(self, inputs):
+ """Placeholder of forward function."""
+ pass
+
+ def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg):
+ """Forward function for training.
+ Args:
+ inputs (list[Tensor]): List of multi-level img features.
+ img_metas (list[dict]): List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmseg/datasets/pipelines/formatting.py:Collect`.
+ gt_semantic_seg (Tensor): Semantic segmentation masks
+ used if the architecture supports semantic segmentation task.
+ train_cfg (dict): The training config.
+
+ Returns:
+ dict[str, Tensor]: a dictionary of loss components
+ """
+ seg_logits = self.forward(inputs)
+ losses = self.losses(seg_logits, gt_semantic_seg)
+ return losses
+
+ def forward_test(self, inputs, img_metas, test_cfg):
+ """Forward function for testing.
+
+ Args:
+ inputs (list[Tensor]): List of multi-level img features.
+ img_metas (list[dict]): List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmseg/datasets/pipelines/formatting.py:Collect`.
+ test_cfg (dict): The testing config.
+
+ Returns:
+ Tensor: Output segmentation map.
+ """
+ return self.forward(inputs)
+
+ def cls_seg(self, feat):
+ """Classify each pixel."""
+ if self.dropout is not None:
+ feat = self.dropout(feat)
+ output = self.conv_seg(feat)
+ return output
+
+ @force_fp32(apply_to=('seg_logit', ))
+ def losses(self, seg_logit, seg_label):
+ """Compute segmentation loss."""
+ loss = dict()
+ seg_logit = resize(
+ input=seg_logit,
+ size=seg_label.shape[2:],
+ mode='bilinear',
+ align_corners=self.align_corners)
+ if self.sampler is not None:
+ seg_weight = self.sampler.sample(seg_logit, seg_label)
+ else:
+ seg_weight = None
+ seg_label = seg_label.squeeze(1)
+ loss['loss_seg'] = self.loss_decode(
+ seg_logit,
+ seg_label,
+ weight=seg_weight,
+ ignore_index=self.ignore_index)
+ loss['acc_seg'] = accuracy(seg_logit, seg_label)
+ return loss
diff --git a/annotator/uniformer/mmseg/models/decode_heads/dm_head.py b/annotator/uniformer/mmseg/models/decode_heads/dm_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..19c963923126b53ce22f60813540a35badf24b3d
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/decode_heads/dm_head.py
@@ -0,0 +1,140 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from annotator.uniformer.mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
+
+from ..builder import HEADS
+from .decode_head import BaseDecodeHead
+
+
+class DCM(nn.Module):
+ """Dynamic Convolutional Module used in DMNet.
+
+ Args:
+ filter_size (int): The filter size of generated convolution kernel
+ used in Dynamic Convolutional Module.
+ fusion (bool): Add one conv to fuse DCM output feature.
+ in_channels (int): Input channels.
+ channels (int): Channels after modules, before conv_seg.
+ conv_cfg (dict | None): Config of conv layers.
+ norm_cfg (dict | None): Config of norm layers.
+ act_cfg (dict): Config of activation layers.
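+
+    Example (illustrative sketch; the channel counts, filter size and input
+        shape are assumptions chosen for demonstration)::
+
+        >>> import torch
+        >>> dcm = DCM(filter_size=3, fusion=False, in_channels=512,
+        ...           channels=128, conv_cfg=None, norm_cfg=None,
+        ...           act_cfg=dict(type='ReLU'))
+        >>> x = torch.rand(2, 512, 32, 32)
+        >>> # a per-sample 3x3 filter is generated and applied as a
+        >>> # grouped (depthwise) convolution over the reduced features
+        >>> assert dcm(x).shape == (2, 128, 32, 32)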
+ """
+
+ def __init__(self, filter_size, fusion, in_channels, channels, conv_cfg,
+ norm_cfg, act_cfg):
+ super(DCM, self).__init__()
+ self.filter_size = filter_size
+ self.fusion = fusion
+ self.in_channels = in_channels
+ self.channels = channels
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ self.filter_gen_conv = nn.Conv2d(self.in_channels, self.channels, 1, 1,
+ 0)
+
+ self.input_redu_conv = ConvModule(
+ self.in_channels,
+ self.channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ if self.norm_cfg is not None:
+ self.norm = build_norm_layer(self.norm_cfg, self.channels)[1]
+ else:
+ self.norm = None
+ self.activate = build_activation_layer(self.act_cfg)
+
+ if self.fusion:
+ self.fusion_conv = ConvModule(
+ self.channels,
+ self.channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ def forward(self, x):
+ """Forward function."""
+ generated_filter = self.filter_gen_conv(
+ F.adaptive_avg_pool2d(x, self.filter_size))
+ x = self.input_redu_conv(x)
+ b, c, h, w = x.shape
+ # [1, b * c, h, w], c = self.channels
+ x = x.view(1, b * c, h, w)
+ # [b * c, 1, filter_size, filter_size]
+ generated_filter = generated_filter.view(b * c, 1, self.filter_size,
+ self.filter_size)
+ pad = (self.filter_size - 1) // 2
+ if (self.filter_size - 1) % 2 == 0:
+ p2d = (pad, pad, pad, pad)
+ else:
+ p2d = (pad + 1, pad, pad + 1, pad)
+ x = F.pad(input=x, pad=p2d, mode='constant', value=0)
+ # [1, b * c, h, w]
+ output = F.conv2d(input=x, weight=generated_filter, groups=b * c)
+ # [b, c, h, w]
+ output = output.view(b, c, h, w)
+ if self.norm is not None:
+ output = self.norm(output)
+ output = self.activate(output)
+
+ if self.fusion:
+ output = self.fusion_conv(output)
+
+ return output
+
+
+@HEADS.register_module()
+class DMHead(BaseDecodeHead):
+ """Dynamic Multi-scale Filters for Semantic Segmentation.
+
+ This head is the implementation of
+ `DMNet `_.
+
+ Args:
+ filter_sizes (tuple[int]): The size of generated convolutional filters
+ used in Dynamic Convolutional Module. Default: (1, 3, 5, 7).
+ fusion (bool): Add one conv to fuse DCM output feature.
+ """
+
+ def __init__(self, filter_sizes=(1, 3, 5, 7), fusion=False, **kwargs):
+ super(DMHead, self).__init__(**kwargs)
+ assert isinstance(filter_sizes, (list, tuple))
+ self.filter_sizes = filter_sizes
+ self.fusion = fusion
+ dcm_modules = []
+ for filter_size in self.filter_sizes:
+ dcm_modules.append(
+ DCM(filter_size,
+ self.fusion,
+ self.in_channels,
+ self.channels,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg))
+ self.dcm_modules = nn.ModuleList(dcm_modules)
+ self.bottleneck = ConvModule(
+ self.in_channels + len(filter_sizes) * self.channels,
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ def forward(self, inputs):
+ """Forward function."""
+ x = self._transform_inputs(inputs)
+ dcm_outs = [x]
+ for dcm_module in self.dcm_modules:
+ dcm_outs.append(dcm_module(x))
+ dcm_outs = torch.cat(dcm_outs, dim=1)
+ output = self.bottleneck(dcm_outs)
+ output = self.cls_seg(output)
+ return output
diff --git a/annotator/uniformer/mmseg/models/decode_heads/dnl_head.py b/annotator/uniformer/mmseg/models/decode_heads/dnl_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..333280c5947066fd3c7ebcfe302a0e7ad65480d5
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/decode_heads/dnl_head.py
@@ -0,0 +1,131 @@
+import torch
+from annotator.uniformer.mmcv.cnn import NonLocal2d
+from torch import nn
+
+from ..builder import HEADS
+from .fcn_head import FCNHead
+
+
+class DisentangledNonLocal2d(NonLocal2d):
+ """Disentangled Non-Local Blocks.
+
+ Args:
+ temperature (float): Temperature to adjust attention. Default: 0.05
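+
+    Example (illustrative sketch; the channel count and input shape are
+        assumptions, and the remaining NonLocal2d arguments keep their
+        defaults)::
+
+        >>> import torch
+        >>> dnl = DisentangledNonLocal2d(in_channels=32, temperature=0.05)
+        >>> x = torch.rand(2, 32, 16, 16)
+        >>> # whitened pairwise term plus unary term, added as a residual
+        >>> assert dnl(x).shape == x.shape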
+ """
+
+ def __init__(self, *arg, temperature, **kwargs):
+ super().__init__(*arg, **kwargs)
+ self.temperature = temperature
+ self.conv_mask = nn.Conv2d(self.in_channels, 1, kernel_size=1)
+
+ def embedded_gaussian(self, theta_x, phi_x):
+ """Embedded gaussian with temperature."""
+
+ # NonLocal2d pairwise_weight: [N, HxW, HxW]
+ pairwise_weight = torch.matmul(theta_x, phi_x)
+ if self.use_scale:
+ # theta_x.shape[-1] is `self.inter_channels`
+ pairwise_weight /= theta_x.shape[-1]**0.5
+ pairwise_weight /= self.temperature
+ pairwise_weight = pairwise_weight.softmax(dim=-1)
+ return pairwise_weight
+
+ def forward(self, x):
+ # x: [N, C, H, W]
+ n = x.size(0)
+
+ # g_x: [N, HxW, C]
+ g_x = self.g(x).view(n, self.inter_channels, -1)
+ g_x = g_x.permute(0, 2, 1)
+
+ # theta_x: [N, HxW, C], phi_x: [N, C, HxW]
+ if self.mode == 'gaussian':
+ theta_x = x.view(n, self.in_channels, -1)
+ theta_x = theta_x.permute(0, 2, 1)
+ if self.sub_sample:
+ phi_x = self.phi(x).view(n, self.in_channels, -1)
+ else:
+ phi_x = x.view(n, self.in_channels, -1)
+ elif self.mode == 'concatenation':
+ theta_x = self.theta(x).view(n, self.inter_channels, -1, 1)
+ phi_x = self.phi(x).view(n, self.inter_channels, 1, -1)
+ else:
+ theta_x = self.theta(x).view(n, self.inter_channels, -1)
+ theta_x = theta_x.permute(0, 2, 1)
+ phi_x = self.phi(x).view(n, self.inter_channels, -1)
+
+ # subtract mean
+ theta_x -= theta_x.mean(dim=-2, keepdim=True)
+ phi_x -= phi_x.mean(dim=-1, keepdim=True)
+
+ pairwise_func = getattr(self, self.mode)
+ # pairwise_weight: [N, HxW, HxW]
+ pairwise_weight = pairwise_func(theta_x, phi_x)
+
+ # y: [N, HxW, C]
+ y = torch.matmul(pairwise_weight, g_x)
+ # y: [N, C, H, W]
+ y = y.permute(0, 2, 1).contiguous().reshape(n, self.inter_channels,
+ *x.size()[2:])
+
+ # unary_mask: [N, 1, HxW]
+ unary_mask = self.conv_mask(x)
+ unary_mask = unary_mask.view(n, 1, -1)
+ unary_mask = unary_mask.softmax(dim=-1)
+ # unary_x: [N, 1, C]
+ unary_x = torch.matmul(unary_mask, g_x)
+ # unary_x: [N, C, 1, 1]
+ unary_x = unary_x.permute(0, 2, 1).contiguous().reshape(
+ n, self.inter_channels, 1, 1)
+
+ output = x + self.conv_out(y + unary_x)
+
+ return output
+
+
+@HEADS.register_module()
+class DNLHead(FCNHead):
+ """Disentangled Non-Local Neural Networks.
+
+ This head is the implementation of `DNLNet
+    <https://arxiv.org/abs/2006.06668>`_.
+
+ Args:
+ reduction (int): Reduction factor of projection transform. Default: 2.
+ use_scale (bool): Whether to scale pairwise_weight by
+            sqrt(1/inter_channels). Default: True.
+        mode (str): The nonlocal mode. Options are 'embedded_gaussian',
+            'dot_product'. Default: 'embedded_gaussian'.
+ temperature (float): Temperature to adjust attention. Default: 0.05
+ """
+
+ def __init__(self,
+ reduction=2,
+ use_scale=True,
+ mode='embedded_gaussian',
+ temperature=0.05,
+ **kwargs):
+ super(DNLHead, self).__init__(num_convs=2, **kwargs)
+ self.reduction = reduction
+ self.use_scale = use_scale
+ self.mode = mode
+ self.temperature = temperature
+ self.dnl_block = DisentangledNonLocal2d(
+ in_channels=self.channels,
+ reduction=self.reduction,
+ use_scale=self.use_scale,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ mode=self.mode,
+ temperature=self.temperature)
+
+ def forward(self, inputs):
+ """Forward function."""
+ x = self._transform_inputs(inputs)
+ output = self.convs[0](x)
+ output = self.dnl_block(output)
+ output = self.convs[1](output)
+ if self.concat_input:
+ output = self.conv_cat(torch.cat([x, output], dim=1))
+ output = self.cls_seg(output)
+ return output
diff --git a/annotator/uniformer/mmseg/models/decode_heads/ema_head.py b/annotator/uniformer/mmseg/models/decode_heads/ema_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..12267cb40569d2b5a4a2955a6dc2671377ff5e0a
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/decode_heads/ema_head.py
@@ -0,0 +1,168 @@
+import math
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+import torch.nn.functional as F
+from annotator.uniformer.mmcv.cnn import ConvModule
+
+from ..builder import HEADS
+from .decode_head import BaseDecodeHead
+
+
+def reduce_mean(tensor):
+ """Reduce mean when distributed training."""
+ if not (dist.is_available() and dist.is_initialized()):
+ return tensor
+ tensor = tensor.clone()
+ dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM)
+ return tensor
+
+
+class EMAModule(nn.Module):
+ """Expectation Maximization Attention Module used in EMANet.
+
+ Args:
+ channels (int): Channels of the whole module.
+ num_bases (int): Number of bases.
+        num_stages (int): Number of the EM iterations.
+        momentum (float): Momentum used to update the bases.
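+
+    Example (illustrative sketch; channel count, number of bases and input
+        shape are arbitrary assumptions)::
+
+        >>> import torch
+        >>> ema = EMAModule(channels=64, num_bases=16, num_stages=3,
+        ...                 momentum=0.1)
+        >>> feats = torch.rand(2, 64, 32, 32)
+        >>> # the EM iterations reconstruct the features from the bases
+        >>> assert ema(feats).shape == (2, 64, 32, 32)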
+ """
+
+ def __init__(self, channels, num_bases, num_stages, momentum):
+ super(EMAModule, self).__init__()
+ assert num_stages >= 1, 'num_stages must be at least 1!'
+ self.num_bases = num_bases
+ self.num_stages = num_stages
+ self.momentum = momentum
+
+ bases = torch.zeros(1, channels, self.num_bases)
+ bases.normal_(0, math.sqrt(2. / self.num_bases))
+ # [1, channels, num_bases]
+ bases = F.normalize(bases, dim=1, p=2)
+ self.register_buffer('bases', bases)
+
+ def forward(self, feats):
+ """Forward function."""
+ batch_size, channels, height, width = feats.size()
+ # [batch_size, channels, height*width]
+ feats = feats.view(batch_size, channels, height * width)
+ # [batch_size, channels, num_bases]
+ bases = self.bases.repeat(batch_size, 1, 1)
+
+ with torch.no_grad():
+ for i in range(self.num_stages):
+ # [batch_size, height*width, num_bases]
+ attention = torch.einsum('bcn,bck->bnk', feats, bases)
+ attention = F.softmax(attention, dim=2)
+ # l1 norm
+ attention_normed = F.normalize(attention, dim=1, p=1)
+ # [batch_size, channels, num_bases]
+ bases = torch.einsum('bcn,bnk->bck', feats, attention_normed)
+ # l2 norm
+ bases = F.normalize(bases, dim=1, p=2)
+
+ feats_recon = torch.einsum('bck,bnk->bcn', bases, attention)
+ feats_recon = feats_recon.view(batch_size, channels, height, width)
+
+ if self.training:
+ bases = bases.mean(dim=0, keepdim=True)
+ bases = reduce_mean(bases)
+ # l2 norm
+ bases = F.normalize(bases, dim=1, p=2)
+ self.bases = (1 -
+ self.momentum) * self.bases + self.momentum * bases
+
+ return feats_recon
+
+
+@HEADS.register_module()
+class EMAHead(BaseDecodeHead):
+ """Expectation Maximization Attention Networks for Semantic Segmentation.
+
+ This head is the implementation of `EMANet
+    <https://arxiv.org/abs/1907.13426>`_.
+
+ Args:
+ ema_channels (int): EMA module channels
+ num_bases (int): Number of bases.
+ num_stages (int): Number of the EM iterations.
+        concat_input (bool): Whether to concatenate the input and output of
+            convs before the classification layer. Default: True.
+ momentum (float): Momentum to update the base. Default: 0.1.
+ """
+
+ def __init__(self,
+ ema_channels,
+ num_bases,
+ num_stages,
+ concat_input=True,
+ momentum=0.1,
+ **kwargs):
+ super(EMAHead, self).__init__(**kwargs)
+ self.ema_channels = ema_channels
+ self.num_bases = num_bases
+ self.num_stages = num_stages
+ self.concat_input = concat_input
+ self.momentum = momentum
+ self.ema_module = EMAModule(self.ema_channels, self.num_bases,
+ self.num_stages, self.momentum)
+
+ self.ema_in_conv = ConvModule(
+ self.in_channels,
+ self.ema_channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ # project (0, inf) -> (-inf, inf)
+ self.ema_mid_conv = ConvModule(
+ self.ema_channels,
+ self.ema_channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=None,
+ act_cfg=None)
+ for param in self.ema_mid_conv.parameters():
+ param.requires_grad = False
+
+ self.ema_out_conv = ConvModule(
+ self.ema_channels,
+ self.ema_channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=None)
+ self.bottleneck = ConvModule(
+ self.ema_channels,
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ if self.concat_input:
+ self.conv_cat = ConvModule(
+ self.in_channels + self.channels,
+ self.channels,
+ kernel_size=3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ def forward(self, inputs):
+ """Forward function."""
+ x = self._transform_inputs(inputs)
+ feats = self.ema_in_conv(x)
+ identity = feats
+ feats = self.ema_mid_conv(feats)
+ recon = self.ema_module(feats)
+ recon = F.relu(recon, inplace=True)
+ recon = self.ema_out_conv(recon)
+ output = F.relu(identity + recon, inplace=True)
+ output = self.bottleneck(output)
+ if self.concat_input:
+ output = self.conv_cat(torch.cat([x, output], dim=1))
+ output = self.cls_seg(output)
+ return output
diff --git a/annotator/uniformer/mmseg/models/decode_heads/enc_head.py b/annotator/uniformer/mmseg/models/decode_heads/enc_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..da57af617e05d41761628fd2d6d232655b32d905
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/decode_heads/enc_head.py
@@ -0,0 +1,187 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from annotator.uniformer.mmcv.cnn import ConvModule, build_norm_layer
+
+from annotator.uniformer.mmseg.ops import Encoding, resize
+from ..builder import HEADS, build_loss
+from .decode_head import BaseDecodeHead
+
+
+class EncModule(nn.Module):
+ """Encoding Module used in EncNet.
+
+ Args:
+ in_channels (int): Input channels.
+ num_codes (int): Number of code words.
+ conv_cfg (dict|None): Config of conv layers.
+ norm_cfg (dict|None): Config of norm layers.
+ act_cfg (dict): Config of activation layers.
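+
+    Example (illustrative sketch; the channel count, number of code words
+        and input shape are assumptions)::
+
+        >>> import torch
+        >>> enc = EncModule(in_channels=512, num_codes=32, conv_cfg=None,
+        ...                 norm_cfg=None, act_cfg=dict(type='ReLU'))
+        >>> x = torch.rand(2, 512, 32, 32)
+        >>> encoding_feat, output = enc(x)
+        >>> # a global codeword descriptor plus a channel-rescaled feature map
+        >>> assert encoding_feat.shape == (2, 512)
+        >>> assert output.shape == x.shape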
+ """
+
+ def __init__(self, in_channels, num_codes, conv_cfg, norm_cfg, act_cfg):
+ super(EncModule, self).__init__()
+ self.encoding_project = ConvModule(
+ in_channels,
+ in_channels,
+ 1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ # TODO: resolve this hack
+ # change to 1d
+ if norm_cfg is not None:
+ encoding_norm_cfg = norm_cfg.copy()
+ if encoding_norm_cfg['type'] in ['BN', 'IN']:
+ encoding_norm_cfg['type'] += '1d'
+ else:
+ encoding_norm_cfg['type'] = encoding_norm_cfg['type'].replace(
+ '2d', '1d')
+ else:
+ # fallback to BN1d
+ encoding_norm_cfg = dict(type='BN1d')
+ self.encoding = nn.Sequential(
+ Encoding(channels=in_channels, num_codes=num_codes),
+ build_norm_layer(encoding_norm_cfg, num_codes)[1],
+ nn.ReLU(inplace=True))
+ self.fc = nn.Sequential(
+ nn.Linear(in_channels, in_channels), nn.Sigmoid())
+
+ def forward(self, x):
+ """Forward function."""
+ encoding_projection = self.encoding_project(x)
+ encoding_feat = self.encoding(encoding_projection).mean(dim=1)
+ batch_size, channels, _, _ = x.size()
+ gamma = self.fc(encoding_feat)
+ y = gamma.view(batch_size, channels, 1, 1)
+ output = F.relu_(x + x * y)
+ return encoding_feat, output
+
+
+@HEADS.register_module()
+class EncHead(BaseDecodeHead):
+ """Context Encoding for Semantic Segmentation.
+
+ This head is the implementation of `EncNet
+    <https://arxiv.org/abs/1803.08904>`_.
+
+ Args:
+ num_codes (int): Number of code words. Default: 32.
+        use_se_loss (bool): Whether to use the Semantic Encoding Loss
+            (SE-loss) to regularize training. Default: True.
+        add_lateral (bool): Whether to use lateral connections to fuse
+            features. Default: False.
+ loss_se_decode (dict): Config of decode loss.
+ Default: dict(type='CrossEntropyLoss', use_sigmoid=True).
+ """
+
+ def __init__(self,
+ num_codes=32,
+ use_se_loss=True,
+ add_lateral=False,
+ loss_se_decode=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=True,
+ loss_weight=0.2),
+ **kwargs):
+ super(EncHead, self).__init__(
+ input_transform='multiple_select', **kwargs)
+ self.use_se_loss = use_se_loss
+ self.add_lateral = add_lateral
+ self.num_codes = num_codes
+ self.bottleneck = ConvModule(
+ self.in_channels[-1],
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ if add_lateral:
+ self.lateral_convs = nn.ModuleList()
+ for in_channels in self.in_channels[:-1]: # skip the last one
+ self.lateral_convs.append(
+ ConvModule(
+ in_channels,
+ self.channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg))
+ self.fusion = ConvModule(
+ len(self.in_channels) * self.channels,
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.enc_module = EncModule(
+ self.channels,
+ num_codes=num_codes,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ if self.use_se_loss:
+ self.loss_se_decode = build_loss(loss_se_decode)
+ self.se_layer = nn.Linear(self.channels, self.num_classes)
+
+ def forward(self, inputs):
+ """Forward function."""
+ inputs = self._transform_inputs(inputs)
+ feat = self.bottleneck(inputs[-1])
+ if self.add_lateral:
+ laterals = [
+ resize(
+ lateral_conv(inputs[i]),
+ size=feat.shape[2:],
+ mode='bilinear',
+ align_corners=self.align_corners)
+ for i, lateral_conv in enumerate(self.lateral_convs)
+ ]
+ feat = self.fusion(torch.cat([feat, *laterals], 1))
+ encode_feat, output = self.enc_module(feat)
+ output = self.cls_seg(output)
+ if self.use_se_loss:
+ se_output = self.se_layer(encode_feat)
+ return output, se_output
+ else:
+ return output
+
+ def forward_test(self, inputs, img_metas, test_cfg):
+ """Forward function for testing, ignore se_loss."""
+ if self.use_se_loss:
+ return self.forward(inputs)[0]
+ else:
+ return self.forward(inputs)
+
+ @staticmethod
+ def _convert_to_onehot_labels(seg_label, num_classes):
+ """Convert segmentation label to onehot.
+
+ Args:
+ seg_label (Tensor): Segmentation label of shape (N, H, W).
+ num_classes (int): Number of classes.
+
+ Returns:
+ Tensor: Onehot labels of shape (N, num_classes).
+ """
+
+ batch_size = seg_label.size(0)
+ onehot_labels = seg_label.new_zeros((batch_size, num_classes))
+ for i in range(batch_size):
+ hist = seg_label[i].float().histc(
+ bins=num_classes, min=0, max=num_classes - 1)
+ onehot_labels[i] = hist > 0
+ return onehot_labels
+
+ def losses(self, seg_logit, seg_label):
+ """Compute segmentation and semantic encoding loss."""
+ seg_logit, se_seg_logit = seg_logit
+ loss = dict()
+ loss.update(super(EncHead, self).losses(seg_logit, seg_label))
+ se_loss = self.loss_se_decode(
+ se_seg_logit,
+ self._convert_to_onehot_labels(seg_label, self.num_classes))
+ loss['loss_se'] = se_loss
+ return loss
diff --git a/annotator/uniformer/mmseg/models/decode_heads/fcn_head.py b/annotator/uniformer/mmseg/models/decode_heads/fcn_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..edb32c283fa4baada6b4a0bf3f7540c3580c3468
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/decode_heads/fcn_head.py
@@ -0,0 +1,81 @@
+import torch
+import torch.nn as nn
+from annotator.uniformer.mmcv.cnn import ConvModule
+
+from ..builder import HEADS
+from .decode_head import BaseDecodeHead
+
+
+@HEADS.register_module()
+class FCNHead(BaseDecodeHead):
+ """Fully Convolution Networks for Semantic Segmentation.
+
+    This head is the implementation of
+    `FCN <https://arxiv.org/abs/1411.4038>`_.
+
+ Args:
+ num_convs (int): Number of convs in the head. Default: 2.
+ kernel_size (int): The kernel size for convs in the head. Default: 3.
+        concat_input (bool): Whether to concatenate the input and output of
+            convs before the classification layer. Default: True.
+ dilation (int): The dilation rate for convs in the head. Default: 1.
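+
+    Example (illustrative sketch; the channel counts, class count and input
+        shape are assumptions, not values taken from any config in this
+        repo)::
+
+        >>> import torch
+        >>> head = FCNHead(in_channels=64, channels=32, num_classes=19,
+        ...                num_convs=2, concat_input=True)
+        >>> inputs = [torch.rand(2, 64, 45, 45)]
+        >>> # per-pixel class logits at the input feature resolution
+        >>> assert head(inputs).shape == (2, 19, 45, 45)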
+ """
+
+ def __init__(self,
+ num_convs=2,
+ kernel_size=3,
+ concat_input=True,
+ dilation=1,
+ **kwargs):
+ assert num_convs >= 0 and dilation > 0 and isinstance(dilation, int)
+ self.num_convs = num_convs
+ self.concat_input = concat_input
+ self.kernel_size = kernel_size
+ super(FCNHead, self).__init__(**kwargs)
+ if num_convs == 0:
+ assert self.in_channels == self.channels
+
+ conv_padding = (kernel_size // 2) * dilation
+ convs = []
+ convs.append(
+ ConvModule(
+ self.in_channels,
+ self.channels,
+ kernel_size=kernel_size,
+ padding=conv_padding,
+ dilation=dilation,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg))
+ for i in range(num_convs - 1):
+ convs.append(
+ ConvModule(
+ self.channels,
+ self.channels,
+ kernel_size=kernel_size,
+ padding=conv_padding,
+ dilation=dilation,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg))
+ if num_convs == 0:
+ self.convs = nn.Identity()
+ else:
+ self.convs = nn.Sequential(*convs)
+ if self.concat_input:
+ self.conv_cat = ConvModule(
+ self.in_channels + self.channels,
+ self.channels,
+ kernel_size=kernel_size,
+ padding=kernel_size // 2,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ def forward(self, inputs):
+ """Forward function."""
+ x = self._transform_inputs(inputs)
+ output = self.convs(x)
+ if self.concat_input:
+ output = self.conv_cat(torch.cat([x, output], dim=1))
+ output = self.cls_seg(output)
+ return output
diff --git a/annotator/uniformer/mmseg/models/decode_heads/fpn_head.py b/annotator/uniformer/mmseg/models/decode_heads/fpn_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..1241c55b0813d1ecdddf1e66e7c5031fbf78ed50
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/decode_heads/fpn_head.py
@@ -0,0 +1,68 @@
+import numpy as np
+import torch.nn as nn
+from annotator.uniformer.mmcv.cnn import ConvModule
+
+from annotator.uniformer.mmseg.ops import resize
+from ..builder import HEADS
+from .decode_head import BaseDecodeHead
+
+
+@HEADS.register_module()
+class FPNHead(BaseDecodeHead):
+ """Panoptic Feature Pyramid Networks.
+
+ This head is the implementation of `Semantic FPN
+    <https://arxiv.org/abs/1901.02446>`_.
+
+ Args:
+        feature_strides (tuple[int]): The strides for input feature maps,
+            used to build the stacked lateral heads. All strides are supposed
+            to be powers of 2, and the first one has the largest resolution.
+ """
+
+ def __init__(self, feature_strides, **kwargs):
+ super(FPNHead, self).__init__(
+ input_transform='multiple_select', **kwargs)
+ assert len(feature_strides) == len(self.in_channels)
+ assert min(feature_strides) == feature_strides[0]
+ self.feature_strides = feature_strides
+
+ self.scale_heads = nn.ModuleList()
+ for i in range(len(feature_strides)):
+ head_length = max(
+ 1,
+ int(np.log2(feature_strides[i]) - np.log2(feature_strides[0])))
+ scale_head = []
+ for k in range(head_length):
+ scale_head.append(
+ ConvModule(
+ self.in_channels[i] if k == 0 else self.channels,
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg))
+ if feature_strides[i] != feature_strides[0]:
+ scale_head.append(
+ nn.Upsample(
+ scale_factor=2,
+ mode='bilinear',
+ align_corners=self.align_corners))
+ self.scale_heads.append(nn.Sequential(*scale_head))
+
+ def forward(self, inputs):
+        """Forward function."""
+ x = self._transform_inputs(inputs)
+
+ output = self.scale_heads[0](x[0])
+ for i in range(1, len(self.feature_strides)):
+ # non inplace
+ output = output + resize(
+ self.scale_heads[i](x[i]),
+ size=output.shape[2:],
+ mode='bilinear',
+ align_corners=self.align_corners)
+
+ output = self.cls_seg(output)
+ return output
diff --git a/annotator/uniformer/mmseg/models/decode_heads/gc_head.py b/annotator/uniformer/mmseg/models/decode_heads/gc_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..70741245af975800840709911bd18d72247e3e04
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/decode_heads/gc_head.py
@@ -0,0 +1,47 @@
+import torch
+from annotator.uniformer.mmcv.cnn import ContextBlock
+
+from ..builder import HEADS
+from .fcn_head import FCNHead
+
+
+@HEADS.register_module()
+class GCHead(FCNHead):
+ """GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond.
+
+ This head is the implementation of `GCNet
+    <https://arxiv.org/abs/1904.11492>`_.
+
+ Args:
+ ratio (float): Multiplier of channels ratio. Default: 1/4.
+ pooling_type (str): The pooling type of context aggregation.
+ Options are 'att', 'avg'. Default: 'avg'.
+ fusion_types (tuple[str]): The fusion type for feature fusion.
+ Options are 'channel_add', 'channel_mul'. Default: ('channel_add',)
+ """
+
+ def __init__(self,
+ ratio=1 / 4.,
+ pooling_type='att',
+ fusion_types=('channel_add', ),
+ **kwargs):
+ super(GCHead, self).__init__(num_convs=2, **kwargs)
+ self.ratio = ratio
+ self.pooling_type = pooling_type
+ self.fusion_types = fusion_types
+ self.gc_block = ContextBlock(
+ in_channels=self.channels,
+ ratio=self.ratio,
+ pooling_type=self.pooling_type,
+ fusion_types=self.fusion_types)
+
+ def forward(self, inputs):
+ """Forward function."""
+ x = self._transform_inputs(inputs)
+ output = self.convs[0](x)
+ output = self.gc_block(output)
+ output = self.convs[1](output)
+ if self.concat_input:
+ output = self.conv_cat(torch.cat([x, output], dim=1))
+ output = self.cls_seg(output)
+ return output
diff --git a/annotator/uniformer/mmseg/models/decode_heads/lraspp_head.py b/annotator/uniformer/mmseg/models/decode_heads/lraspp_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..69bf320934d787aaa11984a0c4effe9ad8015b22
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/decode_heads/lraspp_head.py
@@ -0,0 +1,90 @@
+import torch
+import torch.nn as nn
+from annotator.uniformer.mmcv import is_tuple_of
+from annotator.uniformer.mmcv.cnn import ConvModule
+
+from annotator.uniformer.mmseg.ops import resize
+from ..builder import HEADS
+from .decode_head import BaseDecodeHead
+
+
+@HEADS.register_module()
+class LRASPPHead(BaseDecodeHead):
+ """Lite R-ASPP (LRASPP) head is proposed in Searching for MobileNetV3.
+
+ This head is the improved implementation of `Searching for MobileNetV3
+    <https://arxiv.org/abs/1905.02244>`_.
+
+ Args:
+        branch_channels (tuple[int]): The number of output channels in each
+            branch. Default: (32, 64).
+ """
+
+ def __init__(self, branch_channels=(32, 64), **kwargs):
+ super(LRASPPHead, self).__init__(**kwargs)
+ if self.input_transform != 'multiple_select':
+ raise ValueError('in Lite R-ASPP (LRASPP) head, input_transform '
+ f'must be \'multiple_select\'. But received '
+ f'\'{self.input_transform}\'')
+ assert is_tuple_of(branch_channels, int)
+ assert len(branch_channels) == len(self.in_channels) - 1
+ self.branch_channels = branch_channels
+
+ self.convs = nn.Sequential()
+ self.conv_ups = nn.Sequential()
+ for i in range(len(branch_channels)):
+ self.convs.add_module(
+ f'conv{i}',
+ nn.Conv2d(
+ self.in_channels[i], branch_channels[i], 1, bias=False))
+ self.conv_ups.add_module(
+ f'conv_up{i}',
+ ConvModule(
+ self.channels + branch_channels[i],
+ self.channels,
+ 1,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg,
+ bias=False))
+
+ self.conv_up_input = nn.Conv2d(self.channels, self.channels, 1)
+
+ self.aspp_conv = ConvModule(
+ self.in_channels[-1],
+ self.channels,
+ 1,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg,
+ bias=False)
+ self.image_pool = nn.Sequential(
+ nn.AvgPool2d(kernel_size=49, stride=(16, 20)),
+ ConvModule(
+ self.in_channels[2],
+ self.channels,
+ 1,
+ act_cfg=dict(type='Sigmoid'),
+ bias=False))
+
+ def forward(self, inputs):
+ """Forward function."""
+ inputs = self._transform_inputs(inputs)
+
+ x = inputs[-1]
+
+ x = self.aspp_conv(x) * resize(
+ self.image_pool(x),
+ size=x.size()[2:],
+ mode='bilinear',
+ align_corners=self.align_corners)
+ x = self.conv_up_input(x)
+
+ for i in range(len(self.branch_channels) - 1, -1, -1):
+ x = resize(
+ x,
+ size=inputs[i].size()[2:],
+ mode='bilinear',
+ align_corners=self.align_corners)
+ x = torch.cat([x, self.convs[i](inputs[i])], 1)
+ x = self.conv_ups[i](x)
+
+ return self.cls_seg(x)
diff --git a/annotator/uniformer/mmseg/models/decode_heads/nl_head.py b/annotator/uniformer/mmseg/models/decode_heads/nl_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..3eee424199e6aa363b564e2a3340a070db04db86
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/decode_heads/nl_head.py
@@ -0,0 +1,49 @@
+import torch
+from annotator.uniformer.mmcv.cnn import NonLocal2d
+
+from ..builder import HEADS
+from .fcn_head import FCNHead
+
+
+@HEADS.register_module()
+class NLHead(FCNHead):
+ """Non-local Neural Networks.
+
+ This head is the implementation of `NLNet
+    <https://arxiv.org/abs/1711.07971>`_.
+
+ Args:
+ reduction (int): Reduction factor of projection transform. Default: 2.
+ use_scale (bool): Whether to scale pairwise_weight by
+ sqrt(1/inter_channels). Default: True.
+ mode (str): The nonlocal mode. Options are 'embedded_gaussian',
+            'dot_product'. Default: 'embedded_gaussian'.
+ """
+
+ def __init__(self,
+ reduction=2,
+ use_scale=True,
+ mode='embedded_gaussian',
+ **kwargs):
+ super(NLHead, self).__init__(num_convs=2, **kwargs)
+ self.reduction = reduction
+ self.use_scale = use_scale
+ self.mode = mode
+ self.nl_block = NonLocal2d(
+ in_channels=self.channels,
+ reduction=self.reduction,
+ use_scale=self.use_scale,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ mode=self.mode)
+
+ def forward(self, inputs):
+ """Forward function."""
+ x = self._transform_inputs(inputs)
+ output = self.convs[0](x)
+ output = self.nl_block(output)
+ output = self.convs[1](output)
+ if self.concat_input:
+ output = self.conv_cat(torch.cat([x, output], dim=1))
+ output = self.cls_seg(output)
+ return output
diff --git a/annotator/uniformer/mmseg/models/decode_heads/ocr_head.py b/annotator/uniformer/mmseg/models/decode_heads/ocr_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..715852e94e81dc46623972748285d2d19237a341
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/decode_heads/ocr_head.py
@@ -0,0 +1,127 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from annotator.uniformer.mmcv.cnn import ConvModule
+
+from annotator.uniformer.mmseg.ops import resize
+from ..builder import HEADS
+from ..utils import SelfAttentionBlock as _SelfAttentionBlock
+from .cascade_decode_head import BaseCascadeDecodeHead
+
+
+class SpatialGatherModule(nn.Module):
+ """Aggregate the context features according to the initial predicted
+ probability distribution.
+
+ Employ the soft-weighted method to aggregate the context.
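+
+    Example (illustrative sketch; feature channels, class count and spatial
+        size are assumptions)::
+
+        >>> import torch
+        >>> gather = SpatialGatherModule(scale=1)
+        >>> feats = torch.rand(2, 512, 32, 32)
+        >>> probs = torch.rand(2, 19, 32, 32)
+        >>> # one aggregated feature vector per (soft) object class
+        >>> assert gather(feats, probs).shape == (2, 512, 19, 1)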
+ """
+
+ def __init__(self, scale):
+ super(SpatialGatherModule, self).__init__()
+ self.scale = scale
+
+ def forward(self, feats, probs):
+ """Forward function."""
+ batch_size, num_classes, height, width = probs.size()
+ channels = feats.size(1)
+ probs = probs.view(batch_size, num_classes, -1)
+ feats = feats.view(batch_size, channels, -1)
+ # [batch_size, height*width, num_classes]
+ feats = feats.permute(0, 2, 1)
+ # [batch_size, channels, height*width]
+ probs = F.softmax(self.scale * probs, dim=2)
+ # [batch_size, channels, num_classes]
+ ocr_context = torch.matmul(probs, feats)
+ ocr_context = ocr_context.permute(0, 2, 1).contiguous().unsqueeze(3)
+ return ocr_context
+
+
+class ObjectAttentionBlock(_SelfAttentionBlock):
+    """SelfAttentionBlock variant used by the OCR head."""
+
+ def __init__(self, in_channels, channels, scale, conv_cfg, norm_cfg,
+ act_cfg):
+ if scale > 1:
+ query_downsample = nn.MaxPool2d(kernel_size=scale)
+ else:
+ query_downsample = None
+ super(ObjectAttentionBlock, self).__init__(
+ key_in_channels=in_channels,
+ query_in_channels=in_channels,
+ channels=channels,
+ out_channels=in_channels,
+ share_key_query=False,
+ query_downsample=query_downsample,
+ key_downsample=None,
+ key_query_num_convs=2,
+ key_query_norm=True,
+ value_out_num_convs=1,
+ value_out_norm=True,
+ matmul_norm=True,
+ with_out=True,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ self.bottleneck = ConvModule(
+ in_channels * 2,
+ in_channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ def forward(self, query_feats, key_feats):
+ """Forward function."""
+ context = super(ObjectAttentionBlock,
+ self).forward(query_feats, key_feats)
+ output = self.bottleneck(torch.cat([context, query_feats], dim=1))
+ if self.query_downsample is not None:
+ output = resize(query_feats)
+
+ return output
+
+
+@HEADS.register_module()
+class OCRHead(BaseCascadeDecodeHead):
+ """Object-Contextual Representations for Semantic Segmentation.
+
+ This head is the implementation of `OCRNet
+    <https://arxiv.org/abs/1909.11065>`_.
+
+ Args:
+ ocr_channels (int): The intermediate channels of OCR block.
+        scale (int): The scale of the probability map in SpatialGatherModule.
+            Default: 1.
+ """
+
+ def __init__(self, ocr_channels, scale=1, **kwargs):
+ super(OCRHead, self).__init__(**kwargs)
+ self.ocr_channels = ocr_channels
+ self.scale = scale
+ self.object_context_block = ObjectAttentionBlock(
+ self.channels,
+ self.ocr_channels,
+ self.scale,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.spatial_gather_module = SpatialGatherModule(self.scale)
+
+ self.bottleneck = ConvModule(
+ self.in_channels,
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ def forward(self, inputs, prev_output):
+ """Forward function."""
+ x = self._transform_inputs(inputs)
+ feats = self.bottleneck(x)
+ context = self.spatial_gather_module(feats, prev_output)
+ object_context = self.object_context_block(feats, context)
+ output = self.cls_seg(object_context)
+
+ return output
diff --git a/annotator/uniformer/mmseg/models/decode_heads/point_head.py b/annotator/uniformer/mmseg/models/decode_heads/point_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..3342aa28bb8d264b2c3d01cbf5098d145943c193
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/decode_heads/point_head.py
@@ -0,0 +1,349 @@
+# Modified from https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend/point_head/point_head.py # noqa
+
+import torch
+import torch.nn as nn
+from annotator.uniformer.mmcv.cnn import ConvModule, normal_init
+from annotator.uniformer.mmcv.ops import point_sample
+
+from annotator.uniformer.mmseg.models.builder import HEADS
+from annotator.uniformer.mmseg.ops import resize
+from ..losses import accuracy
+from .cascade_decode_head import BaseCascadeDecodeHead
+
+
+def calculate_uncertainty(seg_logits):
+ """Estimate uncertainty based on seg logits.
+
+ For each location of the prediction ``seg_logits`` we estimate
+ uncertainty as the difference between top first and top second
+ predicted logits.
+
+ Args:
+ seg_logits (Tensor): Semantic segmentation logits,
+ shape (batch_size, num_classes, height, width).
+
+ Returns:
+        scores (Tensor): Uncertainty scores, where the most uncertain
+            locations have the highest score, with shape
+            (batch_size, 1, height, width).
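+
+    Example (illustrative sketch; batch size, class count and spatial size
+        are assumptions)::
+
+        >>> import torch
+        >>> seg_logits = torch.rand(2, 19, 64, 64)
+        >>> # scores are (2nd-best - best) logit, so they are <= 0 and the
+        >>> # least negative values mark the most ambiguous pixels
+        >>> assert calculate_uncertainty(seg_logits).shape == (2, 1, 64, 64)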
+ """
+ top2_scores = torch.topk(seg_logits, k=2, dim=1)[0]
+ return (top2_scores[:, 1] - top2_scores[:, 0]).unsqueeze(1)
+
+
+@HEADS.register_module()
+class PointHead(BaseCascadeDecodeHead):
+    """A mask point head used in PointRend.
+
+    ``PointHead`` uses a shared multi-layer perceptron (equivalent to
+    nn.Conv1d) to predict the logits of input points. The fine-grained
+    features and coarse features are concatenated together for prediction.
+
+ Args:
+ num_fcs (int): Number of fc layers in the head. Default: 3.
+ in_channels (int): Number of input channels. Default: 256.
+ fc_channels (int): Number of fc channels. Default: 256.
+ num_classes (int): Number of classes for logits. Default: 80.
+        class_agnostic (bool): Whether to use class-agnostic classification.
+            If so, the output channels of logits will be 1. Default: False.
+        coarse_pred_each_layer (bool): Whether to concatenate the coarse
+            feature with the output of each fc layer. Default: True.
+ conv_cfg (dict|None): Dictionary to construct and config conv layer.
+ Default: dict(type='Conv1d'))
+ norm_cfg (dict|None): Dictionary to construct and config norm layer.
+ Default: None.
+ loss_point (dict): Dictionary to construct and config loss layer of
+ point head. Default: dict(type='CrossEntropyLoss', use_mask=True,
+ loss_weight=1.0).
+ """
+
+ def __init__(self,
+ num_fcs=3,
+ coarse_pred_each_layer=True,
+ conv_cfg=dict(type='Conv1d'),
+ norm_cfg=None,
+ act_cfg=dict(type='ReLU', inplace=False),
+ **kwargs):
+ super(PointHead, self).__init__(
+ input_transform='multiple_select',
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ **kwargs)
+
+ self.num_fcs = num_fcs
+ self.coarse_pred_each_layer = coarse_pred_each_layer
+
+ fc_in_channels = sum(self.in_channels) + self.num_classes
+ fc_channels = self.channels
+ self.fcs = nn.ModuleList()
+ for k in range(num_fcs):
+ fc = ConvModule(
+ fc_in_channels,
+ fc_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ self.fcs.append(fc)
+ fc_in_channels = fc_channels
+ fc_in_channels += self.num_classes if self.coarse_pred_each_layer \
+ else 0
+ self.fc_seg = nn.Conv1d(
+ fc_in_channels,
+ self.num_classes,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ if self.dropout_ratio > 0:
+ self.dropout = nn.Dropout(self.dropout_ratio)
+ delattr(self, 'conv_seg')
+
+ def init_weights(self):
+ """Initialize weights of classification layer."""
+ normal_init(self.fc_seg, std=0.001)
+
+ def cls_seg(self, feat):
+ """Classify each pixel with fc."""
+ if self.dropout is not None:
+ feat = self.dropout(feat)
+ output = self.fc_seg(feat)
+ return output
+
+ def forward(self, fine_grained_point_feats, coarse_point_feats):
+ x = torch.cat([fine_grained_point_feats, coarse_point_feats], dim=1)
+ for fc in self.fcs:
+ x = fc(x)
+ if self.coarse_pred_each_layer:
+ x = torch.cat((x, coarse_point_feats), dim=1)
+ return self.cls_seg(x)
+
+ def _get_fine_grained_point_feats(self, x, points):
+ """Sample from fine grained features.
+
+ Args:
+            x (list[Tensor]): Feature pyramid from the neck or backbone.
+ points (Tensor): Point coordinates, shape (batch_size,
+ num_points, 2).
+
+ Returns:
+ fine_grained_feats (Tensor): Sampled fine grained feature,
+ shape (batch_size, sum(channels of x), num_points).
+ """
+
+ fine_grained_feats_list = [
+ point_sample(_, points, align_corners=self.align_corners)
+ for _ in x
+ ]
+ if len(fine_grained_feats_list) > 1:
+ fine_grained_feats = torch.cat(fine_grained_feats_list, dim=1)
+ else:
+ fine_grained_feats = fine_grained_feats_list[0]
+
+ return fine_grained_feats
+
+ def _get_coarse_point_feats(self, prev_output, points):
+        """Sample from the coarse prediction.
+
+ Args:
+ prev_output (list[Tensor]): Prediction of previous decode head.
+ points (Tensor): Point coordinates, shape (batch_size,
+ num_points, 2).
+
+ Returns:
+ coarse_feats (Tensor): Sampled coarse feature, shape (batch_size,
+ num_classes, num_points).
+ """
+
+ coarse_feats = point_sample(
+ prev_output, points, align_corners=self.align_corners)
+
+ return coarse_feats
+
+ def forward_train(self, inputs, prev_output, img_metas, gt_semantic_seg,
+ train_cfg):
+ """Forward function for training.
+ Args:
+ inputs (list[Tensor]): List of multi-level img features.
+ prev_output (Tensor): The output of previous decode head.
+ img_metas (list[dict]): List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmseg/datasets/pipelines/formatting.py:Collect`.
+ gt_semantic_seg (Tensor): Semantic segmentation masks
+ used if the architecture supports semantic segmentation task.
+ train_cfg (dict): The training config.
+
+ Returns:
+ dict[str, Tensor]: a dictionary of loss components
+ """
+ x = self._transform_inputs(inputs)
+ with torch.no_grad():
+ points = self.get_points_train(
+ prev_output, calculate_uncertainty, cfg=train_cfg)
+ fine_grained_point_feats = self._get_fine_grained_point_feats(
+ x, points)
+ coarse_point_feats = self._get_coarse_point_feats(prev_output, points)
+ point_logits = self.forward(fine_grained_point_feats,
+ coarse_point_feats)
+ point_label = point_sample(
+ gt_semantic_seg.float(),
+ points,
+ mode='nearest',
+ align_corners=self.align_corners)
+ point_label = point_label.squeeze(1).long()
+
+ losses = self.losses(point_logits, point_label)
+
+ return losses
+
+ def forward_test(self, inputs, prev_output, img_metas, test_cfg):
+ """Forward function for testing.
+
+ Args:
+ inputs (list[Tensor]): List of multi-level img features.
+ prev_output (Tensor): The output of previous decode head.
+ img_metas (list[dict]): List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmseg/datasets/pipelines/formatting.py:Collect`.
+ test_cfg (dict): The testing config.
+
+ Returns:
+ Tensor: Output segmentation map.
+ """
+
+ x = self._transform_inputs(inputs)
+ refined_seg_logits = prev_output.clone()
+ for _ in range(test_cfg.subdivision_steps):
+ refined_seg_logits = resize(
+ refined_seg_logits,
+ scale_factor=test_cfg.scale_factor,
+ mode='bilinear',
+ align_corners=self.align_corners)
+ batch_size, channels, height, width = refined_seg_logits.shape
+ point_indices, points = self.get_points_test(
+ refined_seg_logits, calculate_uncertainty, cfg=test_cfg)
+ fine_grained_point_feats = self._get_fine_grained_point_feats(
+ x, points)
+ coarse_point_feats = self._get_coarse_point_feats(
+ prev_output, points)
+ point_logits = self.forward(fine_grained_point_feats,
+ coarse_point_feats)
+
+ point_indices = point_indices.unsqueeze(1).expand(-1, channels, -1)
+ refined_seg_logits = refined_seg_logits.reshape(
+ batch_size, channels, height * width)
+ refined_seg_logits = refined_seg_logits.scatter_(
+ 2, point_indices, point_logits)
+ refined_seg_logits = refined_seg_logits.view(
+ batch_size, channels, height, width)
+
+ return refined_seg_logits
+
+ def losses(self, point_logits, point_label):
+ """Compute segmentation loss."""
+ loss = dict()
+ loss['loss_point'] = self.loss_decode(
+ point_logits, point_label, ignore_index=self.ignore_index)
+ loss['acc_point'] = accuracy(point_logits, point_label)
+ return loss
+
+ def get_points_train(self, seg_logits, uncertainty_func, cfg):
+ """Sample points for training.
+
+ Sample points in [0, 1] x [0, 1] coordinate space based on their
+ uncertainty. The uncertainties are calculated for each point using
+ 'uncertainty_func' function that takes point's logit prediction as
+ input.
+
+ Args:
+ seg_logits (Tensor): Semantic segmentation logits, shape (
+ batch_size, num_classes, height, width).
+ uncertainty_func (func): uncertainty calculation function.
+ cfg (dict): Training config of point head.
+
+ Returns:
+ point_coords (Tensor): A tensor of shape (batch_size, num_points,
+ 2) that contains the coordinates of ``num_points`` sampled
+ points.
+ """
+ num_points = cfg.num_points
+ oversample_ratio = cfg.oversample_ratio
+ importance_sample_ratio = cfg.importance_sample_ratio
+ assert oversample_ratio >= 1
+ assert 0 <= importance_sample_ratio <= 1
+ batch_size = seg_logits.shape[0]
+ num_sampled = int(num_points * oversample_ratio)
+ point_coords = torch.rand(
+ batch_size, num_sampled, 2, device=seg_logits.device)
+ point_logits = point_sample(seg_logits, point_coords)
+ # It is crucial to calculate uncertainty based on the sampled
+ # prediction value for the points. Calculating uncertainties of the
+ # coarse predictions first and sampling them for points leads to
+ # incorrect results. To illustrate this: assume uncertainty func(
+ # logits)=-abs(logits), a sampled point between two coarse
+ # predictions with -1 and 1 logits has 0 logits, and therefore 0
+ # uncertainty value. However, if we calculate uncertainties for the
+ # coarse predictions first, both will have -1 uncertainty,
+ # and sampled point will get -1 uncertainty.
+ point_uncertainties = uncertainty_func(point_logits)
+ num_uncertain_points = int(importance_sample_ratio * num_points)
+ num_random_points = num_points - num_uncertain_points
+ idx = torch.topk(
+ point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
+ shift = num_sampled * torch.arange(
+ batch_size, dtype=torch.long, device=seg_logits.device)
+ idx += shift[:, None]
+ point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(
+ batch_size, num_uncertain_points, 2)
+ if num_random_points > 0:
+ rand_point_coords = torch.rand(
+ batch_size, num_random_points, 2, device=seg_logits.device)
+ point_coords = torch.cat((point_coords, rand_point_coords), dim=1)
+ return point_coords
+
+ def get_points_test(self, seg_logits, uncertainty_func, cfg):
+ """Sample points for testing.
+
+ Find ``num_points`` most uncertain points from ``uncertainty_map``.
+
+ Args:
+ seg_logits (Tensor): A tensor of shape (batch_size, num_classes,
+ height, width) for class-specific or class-agnostic prediction.
+ uncertainty_func (func): uncertainty calculation function.
+ cfg (dict): Testing config of point head.
+
+ Returns:
+ point_indices (Tensor): A tensor of shape (batch_size, num_points)
+ that contains indices from [0, height x width) of the most
+ uncertain points.
+ point_coords (Tensor): A tensor of shape (batch_size, num_points,
+ 2) that contains [0, 1] x [0, 1] normalized coordinates of the
+                most uncertain points from the ``height x width`` grid.
+ """
+
+ num_points = cfg.subdivision_num_points
+ uncertainty_map = uncertainty_func(seg_logits)
+ batch_size, _, height, width = uncertainty_map.shape
+ h_step = 1.0 / height
+ w_step = 1.0 / width
+
+ uncertainty_map = uncertainty_map.view(batch_size, height * width)
+ num_points = min(height * width, num_points)
+ point_indices = uncertainty_map.topk(num_points, dim=1)[1]
+ point_coords = torch.zeros(
+ batch_size,
+ num_points,
+ 2,
+ dtype=torch.float,
+ device=seg_logits.device)
+ point_coords[:, :, 0] = w_step / 2.0 + (point_indices %
+ width).float() * w_step
+ point_coords[:, :, 1] = h_step / 2.0 + (point_indices //
+ width).float() * h_step
+ return point_indices, point_coords
diff --git a/annotator/uniformer/mmseg/models/decode_heads/psa_head.py b/annotator/uniformer/mmseg/models/decode_heads/psa_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..480dbd1a081262e45bf87e32c4a339ac8f8b4ffb
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/decode_heads/psa_head.py
@@ -0,0 +1,196 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from annotator.uniformer.mmcv.cnn import ConvModule
+
+from annotator.uniformer.mmseg.ops import resize
+from ..builder import HEADS
+from .decode_head import BaseDecodeHead
+
+try:
+ from annotator.uniformer.mmcv.ops import PSAMask
+except ModuleNotFoundError:
+ PSAMask = None
+
+
+@HEADS.register_module()
+class PSAHead(BaseDecodeHead):
+ """Point-wise Spatial Attention Network for Scene Parsing.
+
+ This head is the implementation of `PSANet
+ <https://hszhao.github.io/papers/eccv18_psanet.pdf>`_.
+
+ Args:
+ mask_size (tuple[int]): The PSA mask size. It usually equals input
+ size.
+ psa_type (str): The type of psa module. Options are 'collect',
+ 'distribute', 'bi-direction'. Default: 'bi-direction'.
+ compact (bool): Whether to use a compact map for 'collect' mode.
+ Default: False.
+ shrink_factor (int): The downsample factor of the psa mask. Default: 2.
+ normalization_factor (float): The normalization factor of the attention
+ map. Default: 1.0.
+ psa_softmax (bool): Whether to apply softmax to the attention map.
+ Default: True.
+ """
+
+ def __init__(self,
+ mask_size,
+ psa_type='bi-direction',
+ compact=False,
+ shrink_factor=2,
+ normalization_factor=1.0,
+ psa_softmax=True,
+ **kwargs):
+ if PSAMask is None:
+ raise RuntimeError('Please install mmcv-full for PSAMask ops')
+ super(PSAHead, self).__init__(**kwargs)
+ assert psa_type in ['collect', 'distribute', 'bi-direction']
+ self.psa_type = psa_type
+ self.compact = compact
+ self.shrink_factor = shrink_factor
+ self.mask_size = mask_size
+ mask_h, mask_w = mask_size
+ self.psa_softmax = psa_softmax
+ if normalization_factor is None:
+ normalization_factor = mask_h * mask_w
+ self.normalization_factor = normalization_factor
+
+ self.reduce = ConvModule(
+ self.in_channels,
+ self.channels,
+ kernel_size=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.attention = nn.Sequential(
+ ConvModule(
+ self.channels,
+ self.channels,
+ kernel_size=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg),
+ nn.Conv2d(
+ self.channels, mask_h * mask_w, kernel_size=1, bias=False))
+ if psa_type == 'bi-direction':
+ self.reduce_p = ConvModule(
+ self.in_channels,
+ self.channels,
+ kernel_size=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.attention_p = nn.Sequential(
+ ConvModule(
+ self.channels,
+ self.channels,
+ kernel_size=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg),
+ nn.Conv2d(
+ self.channels, mask_h * mask_w, kernel_size=1, bias=False))
+ self.psamask_collect = PSAMask('collect', mask_size)
+ self.psamask_distribute = PSAMask('distribute', mask_size)
+ else:
+ self.psamask = PSAMask(psa_type, mask_size)
+ self.proj = ConvModule(
+ self.channels * (2 if psa_type == 'bi-direction' else 1),
+ self.in_channels,
+ kernel_size=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.bottleneck = ConvModule(
+ self.in_channels * 2,
+ self.channels,
+ kernel_size=3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ def forward(self, inputs):
+ """Forward function."""
+ x = self._transform_inputs(inputs)
+ identity = x
+ align_corners = self.align_corners
+ if self.psa_type in ['collect', 'distribute']:
+ out = self.reduce(x)
+ n, c, h, w = out.size()
+ if self.shrink_factor != 1:
+ if h % self.shrink_factor and w % self.shrink_factor:
+ h = (h - 1) // self.shrink_factor + 1
+ w = (w - 1) // self.shrink_factor + 1
+ align_corners = True
+ else:
+ h = h // self.shrink_factor
+ w = w // self.shrink_factor
+ align_corners = False
+ out = resize(
+ out,
+ size=(h, w),
+ mode='bilinear',
+ align_corners=align_corners)
+ y = self.attention(out)
+ if self.compact:
+ if self.psa_type == 'collect':
+ y = y.view(n, h * w,
+ h * w).transpose(1, 2).view(n, h * w, h, w)
+ else:
+ y = self.psamask(y)
+ if self.psa_softmax:
+ y = F.softmax(y, dim=1)
+ out = torch.bmm(
+ out.view(n, c, h * w), y.view(n, h * w, h * w)).view(
+ n, c, h, w) * (1.0 / self.normalization_factor)
+ else:
+ x_col = self.reduce(x)
+ x_dis = self.reduce_p(x)
+ n, c, h, w = x_col.size()
+ if self.shrink_factor != 1:
+ if h % self.shrink_factor and w % self.shrink_factor:
+ h = (h - 1) // self.shrink_factor + 1
+ w = (w - 1) // self.shrink_factor + 1
+ align_corners = True
+ else:
+ h = h // self.shrink_factor
+ w = w // self.shrink_factor
+ align_corners = False
+ x_col = resize(
+ x_col,
+ size=(h, w),
+ mode='bilinear',
+ align_corners=align_corners)
+ x_dis = resize(
+ x_dis,
+ size=(h, w),
+ mode='bilinear',
+ align_corners=align_corners)
+ y_col = self.attention(x_col)
+ y_dis = self.attention_p(x_dis)
+ if self.compact:
+ y_dis = y_dis.view(n, h * w,
+ h * w).transpose(1, 2).view(n, h * w, h, w)
+ else:
+ y_col = self.psamask_collect(y_col)
+ y_dis = self.psamask_distribute(y_dis)
+ if self.psa_softmax:
+ y_col = F.softmax(y_col, dim=1)
+ y_dis = F.softmax(y_dis, dim=1)
+ x_col = torch.bmm(
+ x_col.view(n, c, h * w), y_col.view(n, h * w, h * w)).view(
+ n, c, h, w) * (1.0 / self.normalization_factor)
+ x_dis = torch.bmm(
+ x_dis.view(n, c, h * w), y_dis.view(n, h * w, h * w)).view(
+ n, c, h, w) * (1.0 / self.normalization_factor)
+ out = torch.cat([x_col, x_dis], 1)
+ out = self.proj(out)
+ out = resize(
+ out,
+ size=identity.shape[2:],
+ mode='bilinear',
+ align_corners=align_corners)
+ out = self.bottleneck(torch.cat((identity, out), dim=1))
+ out = self.cls_seg(out)
+ return out
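The core of the forward pass above is a point-wise attention aggregation: the per-pixel attention map is flattened and applied to the reduced feature with one batched matmul. A minimal sketch with toy shapes (PSAMask needs mmcv-full and is not reproduced; the random attention tensor is only an illustration):

import torch
import torch.nn.functional as F

n, c, h, w = 1, 8, 6, 6
feat = torch.randn(n, c, h, w)                       # reduced feature
attn = torch.randn(n, h * w, h, w)                   # one attention map per query position

attn = F.softmax(attn.view(n, h * w, h * w), dim=1)  # normalize over source pixels
out = torch.bmm(feat.view(n, c, h * w), attn)        # (n, c, h*w)
out = out.view(n, c, h, w) / float(h * w)            # e.g. normalization by mask_h * mask_w
print(out.shape)                                     # torch.Size([1, 8, 6, 6])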
diff --git a/annotator/uniformer/mmseg/models/decode_heads/psp_head.py b/annotator/uniformer/mmseg/models/decode_heads/psp_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5f1e71c70c3a20f4007c263ec471a87bb214a48
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/decode_heads/psp_head.py
@@ -0,0 +1,101 @@
+import torch
+import torch.nn as nn
+from annotator.uniformer.mmcv.cnn import ConvModule
+
+from annotator.uniformer.mmseg.ops import resize
+from ..builder import HEADS
+from .decode_head import BaseDecodeHead
+
+
+class PPM(nn.ModuleList):
+ """Pooling Pyramid Module used in PSPNet.
+
+ Args:
+ pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
+ Module.
+ in_channels (int): Input channels.
+ channels (int): Channels after modules, before conv_seg.
+ conv_cfg (dict|None): Config of conv layers.
+ norm_cfg (dict|None): Config of norm layers.
+ act_cfg (dict): Config of activation layers.
+ align_corners (bool): align_corners argument of F.interpolate.
+ """
+
+ def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg,
+ act_cfg, align_corners):
+ super(PPM, self).__init__()
+ self.pool_scales = pool_scales
+ self.align_corners = align_corners
+ self.in_channels = in_channels
+ self.channels = channels
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ for pool_scale in pool_scales:
+ self.append(
+ nn.Sequential(
+ nn.AdaptiveAvgPool2d(pool_scale),
+ ConvModule(
+ self.in_channels,
+ self.channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)))
+
+ def forward(self, x):
+ """Forward function."""
+ ppm_outs = []
+ for ppm in self:
+ ppm_out = ppm(x)
+ upsampled_ppm_out = resize(
+ ppm_out,
+ size=x.size()[2:],
+ mode='bilinear',
+ align_corners=self.align_corners)
+ ppm_outs.append(upsampled_ppm_out)
+ return ppm_outs
+
+
+@HEADS.register_module()
+class PSPHead(BaseDecodeHead):
+ """Pyramid Scene Parsing Network.
+
+ This head is the implementation of
+ `PSPNet <https://arxiv.org/abs/1612.01105>`_.
+
+ Args:
+ pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
+ Module. Default: (1, 2, 3, 6).
+ """
+
+ def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
+ super(PSPHead, self).__init__(**kwargs)
+ assert isinstance(pool_scales, (list, tuple))
+ self.pool_scales = pool_scales
+ self.psp_modules = PPM(
+ self.pool_scales,
+ self.in_channels,
+ self.channels,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg,
+ align_corners=self.align_corners)
+ self.bottleneck = ConvModule(
+ self.in_channels + len(pool_scales) * self.channels,
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ def forward(self, inputs):
+ """Forward function."""
+ x = self._transform_inputs(inputs)
+ psp_outs = [x]
+ psp_outs.extend(self.psp_modules(x))
+ psp_outs = torch.cat(psp_outs, dim=1)
+ output = self.bottleneck(psp_outs)
+ output = self.cls_seg(output)
+ return output
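What the PPM block computes can be traced with plain torch in place of ConvModule and resize; the channel sizes below are toy assumptions:

import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(1, 16, 32, 32)
pool_scales, channels = (1, 2, 3, 6), 4

branches = [nn.Sequential(nn.AdaptiveAvgPool2d(s), nn.Conv2d(16, channels, 1))
            for s in pool_scales]
ppm_outs = [F.interpolate(b(x), size=x.shape[2:], mode='bilinear',
                          align_corners=False) for b in branches]

# PSPHead concatenates the input with the upsampled pyramid features
fused = torch.cat([x] + ppm_outs, dim=1)
print(fused.shape)  # torch.Size([1, 32, 32, 32]): 16 + 4 * len(pool_scales) channels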
diff --git a/annotator/uniformer/mmseg/models/decode_heads/sep_aspp_head.py b/annotator/uniformer/mmseg/models/decode_heads/sep_aspp_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..3339a7ac56e77dfc638e9bffb557d4699148686b
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/decode_heads/sep_aspp_head.py
@@ -0,0 +1,101 @@
+import torch
+import torch.nn as nn
+from annotator.uniformer.mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
+
+from annotator.uniformer.mmseg.ops import resize
+from ..builder import HEADS
+from .aspp_head import ASPPHead, ASPPModule
+
+
+class DepthwiseSeparableASPPModule(ASPPModule):
+ """Atrous Spatial Pyramid Pooling (ASPP) Module with depthwise separable
+ conv."""
+
+ def __init__(self, **kwargs):
+ super(DepthwiseSeparableASPPModule, self).__init__(**kwargs)
+ for i, dilation in enumerate(self.dilations):
+ if dilation > 1:
+ self[i] = DepthwiseSeparableConvModule(
+ self.in_channels,
+ self.channels,
+ 3,
+ dilation=dilation,
+ padding=dilation,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+
+@HEADS.register_module()
+class DepthwiseSeparableASPPHead(ASPPHead):
+ """Encoder-Decoder with Atrous Separable Convolution for Semantic Image
+ Segmentation.
+
+ This head is the implementation of `DeepLabV3+
+ <https://arxiv.org/abs/1802.02611>`_.
+
+ Args:
+ c1_in_channels (int): The input channels of c1 decoder. If it is 0,
+ no c1 decoder will be used.
+ c1_channels (int): The intermediate channels of c1 decoder.
+ """
+
+ def __init__(self, c1_in_channels, c1_channels, **kwargs):
+ super(DepthwiseSeparableASPPHead, self).__init__(**kwargs)
+ assert c1_in_channels >= 0
+ self.aspp_modules = DepthwiseSeparableASPPModule(
+ dilations=self.dilations,
+ in_channels=self.in_channels,
+ channels=self.channels,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ if c1_in_channels > 0:
+ self.c1_bottleneck = ConvModule(
+ c1_in_channels,
+ c1_channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ else:
+ self.c1_bottleneck = None
+ self.sep_bottleneck = nn.Sequential(
+ DepthwiseSeparableConvModule(
+ self.channels + c1_channels,
+ self.channels,
+ 3,
+ padding=1,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg),
+ DepthwiseSeparableConvModule(
+ self.channels,
+ self.channels,
+ 3,
+ padding=1,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg))
+
+ def forward(self, inputs):
+ """Forward function."""
+ x = self._transform_inputs(inputs)
+ aspp_outs = [
+ resize(
+ self.image_pool(x),
+ size=x.size()[2:],
+ mode='bilinear',
+ align_corners=self.align_corners)
+ ]
+ aspp_outs.extend(self.aspp_modules(x))
+ aspp_outs = torch.cat(aspp_outs, dim=1)
+ output = self.bottleneck(aspp_outs)
+ if self.c1_bottleneck is not None:
+ c1_output = self.c1_bottleneck(inputs[0])
+ output = resize(
+ input=output,
+ size=c1_output.shape[2:],
+ mode='bilinear',
+ align_corners=self.align_corners)
+ output = torch.cat([output, c1_output], dim=1)
+ output = self.sep_bottleneck(output)
+ output = self.cls_seg(output)
+ return output
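DepthwiseSeparableConvModule (from mmcv) roughly factors each 3x3 dilated conv of the ASPP branch into a depthwise conv followed by a 1x1 pointwise conv. A minimal sketch of that factorization with illustrative channel sizes and dilation (norm and activation layers omitted):

import torch
import torch.nn as nn

in_ch, out_ch, dilation = 16, 32, 12
depthwise = nn.Conv2d(in_ch, in_ch, 3, padding=dilation, dilation=dilation,
                      groups=in_ch, bias=False)      # one spatial filter per channel
pointwise = nn.Conv2d(in_ch, out_ch, 1, bias=False)  # mixes channels

x = torch.randn(1, in_ch, 64, 64)
y = pointwise(depthwise(x))
print(y.shape)  # torch.Size([1, 32, 64, 64]), spatial size preserved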
diff --git a/annotator/uniformer/mmseg/models/decode_heads/sep_fcn_head.py b/annotator/uniformer/mmseg/models/decode_heads/sep_fcn_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0986143fa4f2bd36f5271354fe5f843f35b9e6f
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/decode_heads/sep_fcn_head.py
@@ -0,0 +1,51 @@
+from annotator.uniformer.mmcv.cnn import DepthwiseSeparableConvModule
+
+from ..builder import HEADS
+from .fcn_head import FCNHead
+
+
+@HEADS.register_module()
+class DepthwiseSeparableFCNHead(FCNHead):
+ """Depthwise-Separable Fully Convolutional Network for Semantic
+ Segmentation.
+
+ This head is implemented according to the Fast-SCNN paper.
+ Args:
+ in_channels(int): Number of output channels of FFM.
+ channels(int): Number of middle-stage channels in the decode head.
+ concat_input(bool): Whether to concatenate original decode input into
+ the result of several consecutive convolution layers.
+ Default: True.
+ num_classes(int): Used to determine the dimension of
+ final prediction tensor.
+ in_index(int): Corresponds to 'out_indices' in the FastSCNN backbone.
+ norm_cfg (dict | None): Config of norm layers.
+ align_corners (bool): align_corners argument of F.interpolate.
+ Default: False.
+ loss_decode(dict): Config of loss type and some
+ relevant additional options.
+ """
+
+ def __init__(self, **kwargs):
+ super(DepthwiseSeparableFCNHead, self).__init__(**kwargs)
+ self.convs[0] = DepthwiseSeparableConvModule(
+ self.in_channels,
+ self.channels,
+ kernel_size=self.kernel_size,
+ padding=self.kernel_size // 2,
+ norm_cfg=self.norm_cfg)
+ for i in range(1, self.num_convs):
+ self.convs[i] = DepthwiseSeparableConvModule(
+ self.channels,
+ self.channels,
+ kernel_size=self.kernel_size,
+ padding=self.kernel_size // 2,
+ norm_cfg=self.norm_cfg)
+
+ if self.concat_input:
+ self.conv_cat = DepthwiseSeparableConvModule(
+ self.in_channels + self.channels,
+ self.channels,
+ kernel_size=self.kernel_size,
+ padding=self.kernel_size // 2,
+ norm_cfg=self.norm_cfg)
diff --git a/annotator/uniformer/mmseg/models/decode_heads/uper_head.py b/annotator/uniformer/mmseg/models/decode_heads/uper_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e1301b706b0d83ed714bbdee8ee24693f150455
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/decode_heads/uper_head.py
@@ -0,0 +1,126 @@
+import torch
+import torch.nn as nn
+from annotator.uniformer.mmcv.cnn import ConvModule
+
+from annotator.uniformer.mmseg.ops import resize
+from ..builder import HEADS
+from .decode_head import BaseDecodeHead
+from .psp_head import PPM
+
+
+@HEADS.register_module()
+class UPerHead(BaseDecodeHead):
+ """Unified Perceptual Parsing for Scene Understanding.
+
+ This head is the implementation of `UPerNet
+ <https://arxiv.org/abs/1807.10221>`_.
+
+ Args:
+ pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
+ Module applied on the last feature. Default: (1, 2, 3, 6).
+ """
+
+ def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
+ super(UPerHead, self).__init__(
+ input_transform='multiple_select', **kwargs)
+ # PSP Module
+ self.psp_modules = PPM(
+ pool_scales,
+ self.in_channels[-1],
+ self.channels,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg,
+ align_corners=self.align_corners)
+ self.bottleneck = ConvModule(
+ self.in_channels[-1] + len(pool_scales) * self.channels,
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ # FPN Module
+ self.lateral_convs = nn.ModuleList()
+ self.fpn_convs = nn.ModuleList()
+ for in_channels in self.in_channels[:-1]: # skip the top layer
+ l_conv = ConvModule(
+ in_channels,
+ self.channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg,
+ inplace=False)
+ fpn_conv = ConvModule(
+ self.channels,
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg,
+ inplace=False)
+ self.lateral_convs.append(l_conv)
+ self.fpn_convs.append(fpn_conv)
+
+ self.fpn_bottleneck = ConvModule(
+ len(self.in_channels) * self.channels,
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ def psp_forward(self, inputs):
+ """Forward function of PSP module."""
+ x = inputs[-1]
+ psp_outs = [x]
+ psp_outs.extend(self.psp_modules(x))
+ psp_outs = torch.cat(psp_outs, dim=1)
+ output = self.bottleneck(psp_outs)
+
+ return output
+
+ def forward(self, inputs):
+ """Forward function."""
+
+ inputs = self._transform_inputs(inputs)
+
+ # build laterals
+ laterals = [
+ lateral_conv(inputs[i])
+ for i, lateral_conv in enumerate(self.lateral_convs)
+ ]
+
+ laterals.append(self.psp_forward(inputs))
+
+ # build top-down path
+ used_backbone_levels = len(laterals)
+ for i in range(used_backbone_levels - 1, 0, -1):
+ prev_shape = laterals[i - 1].shape[2:]
+ laterals[i - 1] += resize(
+ laterals[i],
+ size=prev_shape,
+ mode='bilinear',
+ align_corners=self.align_corners)
+
+ # build outputs
+ fpn_outs = [
+ self.fpn_convs[i](laterals[i])
+ for i in range(used_backbone_levels - 1)
+ ]
+ # append psp feature
+ fpn_outs.append(laterals[-1])
+
+ for i in range(used_backbone_levels - 1, 0, -1):
+ fpn_outs[i] = resize(
+ fpn_outs[i],
+ size=fpn_outs[0].shape[2:],
+ mode='bilinear',
+ align_corners=self.align_corners)
+ fpn_outs = torch.cat(fpn_outs, dim=1)
+ output = self.fpn_bottleneck(fpn_outs)
+ output = self.cls_seg(output)
+ return output
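The top-down fusion loop in the forward pass above can be traced with plain tensors standing in for the lateral features; the channel count and spatial sizes are toy assumptions:

import torch
import torch.nn.functional as F

channels = 8
laterals = [torch.randn(1, channels, s, s) for s in (32, 16, 8, 4)]

# coarser levels are upsampled and added into the next finer level
for i in range(len(laterals) - 1, 0, -1):
    prev_shape = laterals[i - 1].shape[2:]
    laterals[i - 1] = laterals[i - 1] + F.interpolate(
        laterals[i], size=prev_shape, mode='bilinear', align_corners=False)

# every level is then resized to the finest resolution and concatenated
fpn_outs = [F.interpolate(l, size=laterals[0].shape[2:], mode='bilinear',
                          align_corners=False) for l in laterals]
print(torch.cat(fpn_outs, dim=1).shape)  # torch.Size([1, 32, 32, 32])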
diff --git a/annotator/uniformer/mmseg/models/losses/__init__.py b/annotator/uniformer/mmseg/models/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..beca72045694273d63465bac2f27dbc6672271db
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/losses/__init__.py
@@ -0,0 +1,12 @@
+from .accuracy import Accuracy, accuracy
+from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy,
+ cross_entropy, mask_cross_entropy)
+from .dice_loss import DiceLoss
+from .lovasz_loss import LovaszLoss
+from .utils import reduce_loss, weight_reduce_loss, weighted_loss
+
+__all__ = [
+ 'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy',
+ 'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss',
+ 'weight_reduce_loss', 'weighted_loss', 'LovaszLoss', 'DiceLoss'
+]
diff --git a/annotator/uniformer/mmseg/models/losses/accuracy.py b/annotator/uniformer/mmseg/models/losses/accuracy.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0fd2e7e74a0f721c4a814c09d6e453e5956bb38
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/losses/accuracy.py
@@ -0,0 +1,78 @@
+import torch.nn as nn
+
+
+def accuracy(pred, target, topk=1, thresh=None):
+ """Calculate accuracy according to the prediction and target.
+
+ Args:
+ pred (torch.Tensor): The model prediction, shape (N, num_class, ...)
+ target (torch.Tensor): The target of each prediction, shape (N, ...)
+ topk (int | tuple[int], optional): If the predictions in ``topk``
+ matches the target, the predictions will be regarded as
+ correct ones. Defaults to 1.
+ thresh (float, optional): If not None, predictions with scores under
+ this threshold are considered incorrect. Default to None.
+
+ Returns:
+ float | tuple[float]: If the input ``topk`` is a single integer,
+ the function will return a single float as accuracy. If
+ ``topk`` is a tuple containing multiple integers, the
+ function will return a tuple containing accuracies of
+ each ``topk`` number.
+ """
+ assert isinstance(topk, (int, tuple))
+ if isinstance(topk, int):
+ topk = (topk, )
+ return_single = True
+ else:
+ return_single = False
+
+ maxk = max(topk)
+ if pred.size(0) == 0:
+ accu = [pred.new_tensor(0.) for i in range(len(topk))]
+ return accu[0] if return_single else accu
+ assert pred.ndim == target.ndim + 1
+ assert pred.size(0) == target.size(0)
+ assert maxk <= pred.size(1), \
+ f'maxk {maxk} exceeds pred dimension {pred.size(1)}'
+ pred_value, pred_label = pred.topk(maxk, dim=1)
+ # transpose to shape (maxk, N, ...)
+ pred_label = pred_label.transpose(0, 1)
+ correct = pred_label.eq(target.unsqueeze(0).expand_as(pred_label))
+ if thresh is not None:
+ # Only prediction values larger than thresh are counted as correct
+ correct = correct & (pred_value > thresh).t()
+ res = []
+ for k in topk:
+ correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
+ res.append(correct_k.mul_(100.0 / target.numel()))
+ return res[0] if return_single else res
+
+
+class Accuracy(nn.Module):
+ """Accuracy calculation module."""
+
+ def __init__(self, topk=(1, ), thresh=None):
+ """Module to calculate the accuracy.
+
+ Args:
+ topk (tuple, optional): The criterion used to calculate the
+ accuracy. Defaults to (1,).
+ thresh (float, optional): If not None, predictions with scores
+ under this threshold are considered incorrect. Default to None.
+ """
+ super().__init__()
+ self.topk = topk
+ self.thresh = thresh
+
+ def forward(self, pred, target):
+ """Forward function to calculate accuracy.
+
+ Args:
+ pred (torch.Tensor): Prediction of models.
+ target (torch.Tensor): Target for each prediction.
+
+ Returns:
+ tuple[float]: The accuracies under different topk criterions.
+ """
+ return accuracy(pred, target, self.topk, self.thresh)
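A small usage sketch for the function above; the import path matches the package layout added in this diff, and the numbers are purely illustrative:

import torch
from annotator.uniformer.mmseg.models.losses import accuracy

pred = torch.tensor([[0.1, 0.7, 0.2],    # top-1 is class 1
                     [0.5, 0.1, 0.4],    # top-1 is class 0, top-2 adds class 2
                     [0.2, 0.3, 0.5]])   # top-1 is class 2
target = torch.tensor([1, 2, 2])

top1, top2 = accuracy(pred, target, topk=(1, 2))
print(top1.item(), top2.item())  # 66.66..., 100.0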
diff --git a/annotator/uniformer/mmseg/models/losses/cross_entropy_loss.py b/annotator/uniformer/mmseg/models/losses/cross_entropy_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..42c0790c98616bb69621deed55547fc04c7392ef
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/losses/cross_entropy_loss.py
@@ -0,0 +1,198 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..builder import LOSSES
+from .utils import get_class_weight, weight_reduce_loss
+
+
+def cross_entropy(pred,
+ label,
+ weight=None,
+ class_weight=None,
+ reduction='mean',
+ avg_factor=None,
+ ignore_index=-100):
+ """The wrapper function for :func:`F.cross_entropy`"""
+ # class_weight is a manual rescaling weight given to each class.
+ # If given, has to be a Tensor of size C element-wise losses
+ loss = F.cross_entropy(
+ pred,
+ label,
+ weight=class_weight,
+ reduction='none',
+ ignore_index=ignore_index)
+
+ # apply weights and do the reduction
+ if weight is not None:
+ weight = weight.float()
+ loss = weight_reduce_loss(
+ loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
+
+ return loss
+
+
+def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index):
+ """Expand onehot labels to match the size of prediction."""
+ bin_labels = labels.new_zeros(target_shape)
+ valid_mask = (labels >= 0) & (labels != ignore_index)
+ inds = torch.nonzero(valid_mask, as_tuple=True)
+
+ if inds[0].numel() > 0:
+ if labels.dim() == 3:
+ bin_labels[inds[0], labels[valid_mask], inds[1], inds[2]] = 1
+ else:
+ bin_labels[inds[0], labels[valid_mask]] = 1
+
+ valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float()
+ if label_weights is None:
+ bin_label_weights = valid_mask
+ else:
+ bin_label_weights = label_weights.unsqueeze(1).expand(target_shape)
+ bin_label_weights *= valid_mask
+
+ return bin_labels, bin_label_weights
+
+
+def binary_cross_entropy(pred,
+ label,
+ weight=None,
+ reduction='mean',
+ avg_factor=None,
+ class_weight=None,
+ ignore_index=255):
+ """Calculate the binary CrossEntropy loss.
+
+ Args:
+ pred (torch.Tensor): The prediction with shape (N, 1).
+ label (torch.Tensor): The learning label of the prediction.
+ weight (torch.Tensor, optional): Sample-wise loss weight.
+ reduction (str, optional): The method used to reduce the loss.
+ Options are "none", "mean" and "sum".
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ class_weight (list[float], optional): The weight for each class.
+ ignore_index (int | None): The label index to be ignored. Default: 255
+
+ Returns:
+ torch.Tensor: The calculated loss
+ """
+ if pred.dim() != label.dim():
+ assert (pred.dim() == 2 and label.dim() == 1) or (
+ pred.dim() == 4 and label.dim() == 3), \
+ 'Only pred shape [N, C], label shape [N] or pred shape [N, C, ' \
+ 'H, W], label shape [N, H, W] are supported'
+ label, weight = _expand_onehot_labels(label, weight, pred.shape,
+ ignore_index)
+
+ # weighted element-wise losses
+ if weight is not None:
+ weight = weight.float()
+ loss = F.binary_cross_entropy_with_logits(
+ pred, label.float(), pos_weight=class_weight, reduction='none')
+ # do the reduction for the weighted loss
+ loss = weight_reduce_loss(
+ loss, weight, reduction=reduction, avg_factor=avg_factor)
+
+ return loss
+
+
+def mask_cross_entropy(pred,
+ target,
+ label,
+ reduction='mean',
+ avg_factor=None,
+ class_weight=None,
+ ignore_index=None):
+ """Calculate the CrossEntropy loss for masks.
+
+ Args:
+ pred (torch.Tensor): The prediction with shape (N, C), C is the number
+ of classes.
+ target (torch.Tensor): The learning label of the prediction.
+ label (torch.Tensor): ``label`` indicates the class label of the mask's
+ corresponding object. It is used to select the mask of the class
+ which the object belongs to when the mask prediction is not
+ class-agnostic.
+ reduction (str, optional): The method used to reduce the loss.
+ Options are "none", "mean" and "sum".
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ class_weight (list[float], optional): The weight for each class.
+ ignore_index (None): Placeholder, to be consistent with other loss.
+ Default: None.
+
+ Returns:
+ torch.Tensor: The calculated loss
+ """
+ assert ignore_index is None, 'BCE loss does not support ignore_index'
+ # TODO: handle these two reserved arguments
+ assert reduction == 'mean' and avg_factor is None
+ num_rois = pred.size()[0]
+ inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
+ pred_slice = pred[inds, label].squeeze(1)
+ return F.binary_cross_entropy_with_logits(
+ pred_slice, target, weight=class_weight, reduction='mean')[None]
+
+
+@LOSSES.register_module()
+class CrossEntropyLoss(nn.Module):
+ """CrossEntropyLoss.
+
+ Args:
+ use_sigmoid (bool, optional): Whether the prediction uses sigmoid
+ or softmax. Defaults to False.
+ use_mask (bool, optional): Whether to use mask cross entropy loss.
+ Defaults to False.
+ reduction (str, optional): The method used to reduce the loss.
+ Options are "none", "mean" and "sum". Defaults to 'mean'.
+ class_weight (list[float] | str, optional): Weight of each class. If in
+ str format, read them from a file. Defaults to None.
+ loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
+ """
+
+ def __init__(self,
+ use_sigmoid=False,
+ use_mask=False,
+ reduction='mean',
+ class_weight=None,
+ loss_weight=1.0):
+ super(CrossEntropyLoss, self).__init__()
+ assert (use_sigmoid is False) or (use_mask is False)
+ self.use_sigmoid = use_sigmoid
+ self.use_mask = use_mask
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+ self.class_weight = get_class_weight(class_weight)
+
+ if self.use_sigmoid:
+ self.cls_criterion = binary_cross_entropy
+ elif self.use_mask:
+ self.cls_criterion = mask_cross_entropy
+ else:
+ self.cls_criterion = cross_entropy
+
+ def forward(self,
+ cls_score,
+ label,
+ weight=None,
+ avg_factor=None,
+ reduction_override=None,
+ **kwargs):
+ """Forward function."""
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (
+ reduction_override if reduction_override else self.reduction)
+ if self.class_weight is not None:
+ class_weight = cls_score.new_tensor(self.class_weight)
+ else:
+ class_weight = None
+ loss_cls = self.loss_weight * self.cls_criterion(
+ cls_score,
+ label,
+ weight,
+ class_weight=class_weight,
+ reduction=reduction,
+ avg_factor=avg_factor,
+ **kwargs)
+ return loss_cls
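A minimal sketch of calling the registered loss directly on toy segmentation logits (shapes are illustrative; the import path matches the files added in this diff):

import torch
from annotator.uniformer.mmseg.models.losses import CrossEntropyLoss

criterion = CrossEntropyLoss(use_sigmoid=False, loss_weight=1.0)
logits = torch.randn(2, 4, 8, 8)                 # (N, num_classes, H, W)
labels = torch.randint(0, 4, (2, 8, 8))          # (N, H, W), class indices
print(criterion(logits, labels).item())          # scalar mean cross-entropy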
diff --git a/annotator/uniformer/mmseg/models/losses/dice_loss.py b/annotator/uniformer/mmseg/models/losses/dice_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..27a77b962d7d8b3079c7d6cd9db52280c6fb4970
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/losses/dice_loss.py
@@ -0,0 +1,119 @@
+"""Modified from https://github.com/LikeLy-Journey/SegmenTron/blob/master/
+segmentron/solver/loss.py (Apache-2.0 License)"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..builder import LOSSES
+from .utils import get_class_weight, weighted_loss
+
+
+@weighted_loss
+def dice_loss(pred,
+ target,
+ valid_mask,
+ smooth=1,
+ exponent=2,
+ class_weight=None,
+ ignore_index=255):
+ assert pred.shape[0] == target.shape[0]
+ total_loss = 0
+ num_classes = pred.shape[1]
+ for i in range(num_classes):
+ if i != ignore_index:
+ dice_loss = binary_dice_loss(
+ pred[:, i],
+ target[..., i],
+ valid_mask=valid_mask,
+ smooth=smooth,
+ exponent=exponent)
+ if class_weight is not None:
+ dice_loss *= class_weight[i]
+ total_loss += dice_loss
+ return total_loss / num_classes
+
+
+@weighted_loss
+def binary_dice_loss(pred, target, valid_mask, smooth=1, exponent=2, **kwargs):
+ assert pred.shape[0] == target.shape[0]
+ pred = pred.reshape(pred.shape[0], -1)
+ target = target.reshape(target.shape[0], -1)
+ valid_mask = valid_mask.reshape(valid_mask.shape[0], -1)
+
+ num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + smooth
+ den = torch.sum(pred.pow(exponent) + target.pow(exponent), dim=1) + smooth
+
+ return 1 - num / den
+
+
+@LOSSES.register_module()
+class DiceLoss(nn.Module):
+ """DiceLoss.
+
+ This loss is proposed in `V-Net: Fully Convolutional Neural Networks for
+ Volumetric Medical Image Segmentation <https://arxiv.org/abs/1606.04797>`_.
+
+ Args:
+ smooth (float): A float number to smooth the loss and avoid NaN errors.
+ Default: 1.
+ exponent (float): A float number used in the denominator:
+ \\sum{x^exponent} + \\sum{y^exponent}. Default: 2.
+ reduction (str, optional): The method used to reduce the loss. Options
+ are "none", "mean" and "sum". Default: 'mean'.
+ class_weight (list[float] | str, optional): Weight of each class. If in
+ str format, read them from a file. Defaults to None.
+ loss_weight (float, optional): Weight of the loss. Default to 1.0.
+ ignore_index (int | None): The label index to be ignored. Default: 255.
+ """
+
+ def __init__(self,
+ smooth=1,
+ exponent=2,
+ reduction='mean',
+ class_weight=None,
+ loss_weight=1.0,
+ ignore_index=255,
+ **kwargs):
+ super(DiceLoss, self).__init__()
+ self.smooth = smooth
+ self.exponent = exponent
+ self.reduction = reduction
+ self.class_weight = get_class_weight(class_weight)
+ self.loss_weight = loss_weight
+ self.ignore_index = ignore_index
+
+ def forward(self,
+ pred,
+ target,
+ avg_factor=None,
+ reduction_override=None,
+ **kwargs):
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (
+ reduction_override if reduction_override else self.reduction)
+ if self.class_weight is not None:
+ class_weight = pred.new_tensor(self.class_weight)
+ else:
+ class_weight = None
+
+ pred = F.softmax(pred, dim=1)
+ num_classes = pred.shape[1]
+ one_hot_target = F.one_hot(
+ torch.clamp(target.long(), 0, num_classes - 1),
+ num_classes=num_classes)
+ valid_mask = (target != self.ignore_index).long()
+
+ loss = self.loss_weight * dice_loss(
+ pred,
+ one_hot_target,
+ valid_mask=valid_mask,
+ reduction=reduction,
+ avg_factor=avg_factor,
+ smooth=self.smooth,
+ exponent=self.exponent,
+ class_weight=class_weight,
+ ignore_index=self.ignore_index)
+ return loss
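For a single class, the Dice term computed by ``binary_dice_loss`` above reduces to the following plain-torch arithmetic (the probabilities and targets below are toy values, and the valid mask is assumed to be all ones):

import torch

pred = torch.tensor([[0.9, 0.1, 0.8, 0.2]])      # per-pixel probabilities, (N, P)
target = torch.tensor([[1.0, 0.0, 1.0, 0.0]])
smooth, exponent = 1, 2

num = 2 * (pred * target).sum(dim=1) + smooth
den = (pred.pow(exponent) + target.pow(exponent)).sum(dim=1) + smooth
print((1 - num / den).item())                    # ~0.022: prediction matches target well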
diff --git a/annotator/uniformer/mmseg/models/losses/lovasz_loss.py b/annotator/uniformer/mmseg/models/losses/lovasz_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..6badb67f6d987b59fb07aa97caaaf89896e27a8d
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/losses/lovasz_loss.py
@@ -0,0 +1,303 @@
+"""Modified from https://github.com/bermanmaxim/LovaszSoftmax/blob/master/pytor
+ch/lovasz_losses.py Lovasz-Softmax and Jaccard hinge loss in PyTorch Maxim
+Berman 2018 ESAT-PSI KU Leuven (MIT License)"""
+
+import annotator.uniformer.mmcv as mmcv
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..builder import LOSSES
+from .utils import get_class_weight, weight_reduce_loss
+
+
+def lovasz_grad(gt_sorted):
+ """Computes gradient of the Lovasz extension w.r.t sorted errors.
+
+ See Alg. 1 in paper.
+ """
+ p = len(gt_sorted)
+ gts = gt_sorted.sum()
+ intersection = gts - gt_sorted.float().cumsum(0)
+ union = gts + (1 - gt_sorted).float().cumsum(0)
+ jaccard = 1. - intersection / union
+ if p > 1: # cover 1-pixel case
+ jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
+ return jaccard
+
+
+def flatten_binary_logits(logits, labels, ignore_index=None):
+ """Flattens predictions in the batch (binary case) Remove labels equal to
+ 'ignore_index'."""
+ logits = logits.view(-1)
+ labels = labels.view(-1)
+ if ignore_index is None:
+ return logits, labels
+ valid = (labels != ignore_index)
+ vlogits = logits[valid]
+ vlabels = labels[valid]
+ return vlogits, vlabels
+
+
+def flatten_probs(probs, labels, ignore_index=None):
+ """Flattens predictions in the batch."""
+ if probs.dim() == 3:
+ # assumes output of a sigmoid layer
+ B, H, W = probs.size()
+ probs = probs.view(B, 1, H, W)
+ B, C, H, W = probs.size()
+ probs = probs.permute(0, 2, 3, 1).contiguous().view(-1, C) # B*H*W, C=P,C
+ labels = labels.view(-1)
+ if ignore_index is None:
+ return probs, labels
+ valid = (labels != ignore_index)
+ vprobs = probs[valid.nonzero().squeeze()]
+ vlabels = labels[valid]
+ return vprobs, vlabels
+
+
+def lovasz_hinge_flat(logits, labels):
+ """Binary Lovasz hinge loss.
+
+ Args:
+ logits (torch.Tensor): [P], logits at each prediction
+ (between -infty and +infty).
+ labels (torch.Tensor): [P], binary ground truth labels (0 or 1).
+
+ Returns:
+ torch.Tensor: The calculated loss.
+ """
+ if len(labels) == 0:
+ # only void pixels, the gradients should be 0
+ return logits.sum() * 0.
+ signs = 2. * labels.float() - 1.
+ errors = (1. - logits * signs)
+ errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
+ perm = perm.data
+ gt_sorted = labels[perm]
+ grad = lovasz_grad(gt_sorted)
+ loss = torch.dot(F.relu(errors_sorted), grad)
+ return loss
+
+
+def lovasz_hinge(logits,
+ labels,
+ classes='present',
+ per_image=False,
+ class_weight=None,
+ reduction='mean',
+ avg_factor=None,
+ ignore_index=255):
+ """Binary Lovasz hinge loss.
+
+ Args:
+ logits (torch.Tensor): [B, H, W], logits at each pixel
+ (between -infty and +infty).
+ labels (torch.Tensor): [B, H, W], binary ground truth masks (0 or 1).
+ classes (str | list[int], optional): Placeholder, to be consistent with
+ other loss. Default: None.
+ per_image (bool, optional): If per_image is True, compute the loss per
+ image instead of per batch. Default: False.
+ class_weight (list[float], optional): Placeholder, to be consistent
+ with other loss. Default: None.
+ reduction (str, optional): The method used to reduce the loss. Options
+ are "none", "mean" and "sum". This parameter only works when
+ per_image is True. Default: 'mean'.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. This parameter only works when per_image is True.
+ Default: None.
+ ignore_index (int | None): The label index to be ignored. Default: 255.
+
+ Returns:
+ torch.Tensor: The calculated loss.
+ """
+ if per_image:
+ loss = [
+ lovasz_hinge_flat(*flatten_binary_logits(
+ logit.unsqueeze(0), label.unsqueeze(0), ignore_index))
+ for logit, label in zip(logits, labels)
+ ]
+ loss = weight_reduce_loss(
+ torch.stack(loss), None, reduction, avg_factor)
+ else:
+ loss = lovasz_hinge_flat(
+ *flatten_binary_logits(logits, labels, ignore_index))
+ return loss
+
+
+def lovasz_softmax_flat(probs, labels, classes='present', class_weight=None):
+ """Multi-class Lovasz-Softmax loss.
+
+ Args:
+ probs (torch.Tensor): [P, C], class probabilities at each prediction
+ (between 0 and 1).
+ labels (torch.Tensor): [P], ground truth labels (between 0 and C - 1).
+ classes (str | list[int], optional): Classes chosen to calculate loss.
+ 'all' for all classes, 'present' for classes present in labels, or
+ a list of classes to average. Default: 'present'.
+ class_weight (list[float], optional): The weight for each class.
+ Default: None.
+
+ Returns:
+ torch.Tensor: The calculated loss.
+ """
+ if probs.numel() == 0:
+ # only void pixels, the gradients should be 0
+ return probs * 0.
+ C = probs.size(1)
+ losses = []
+ class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
+ for c in class_to_sum:
+ fg = (labels == c).float() # foreground for class c
+ if (classes == 'present' and fg.sum() == 0):
+ continue
+ if C == 1:
+ if len(classes) > 1:
+ raise ValueError('Sigmoid output possible only with 1 class')
+ class_pred = probs[:, 0]
+ else:
+ class_pred = probs[:, c]
+ errors = (fg - class_pred).abs()
+ errors_sorted, perm = torch.sort(errors, 0, descending=True)
+ perm = perm.data
+ fg_sorted = fg[perm]
+ loss = torch.dot(errors_sorted, lovasz_grad(fg_sorted))
+ if class_weight is not None:
+ loss *= class_weight[c]
+ losses.append(loss)
+ return torch.stack(losses).mean()
+
+
+def lovasz_softmax(probs,
+ labels,
+ classes='present',
+ per_image=False,
+ class_weight=None,
+ reduction='mean',
+ avg_factor=None,
+ ignore_index=255):
+ """Multi-class Lovasz-Softmax loss.
+
+ Args:
+ probs (torch.Tensor): [B, C, H, W], class probabilities at each
+ prediction (between 0 and 1).
+ labels (torch.Tensor): [B, H, W], ground truth labels (between 0 and
+ C - 1).
+ classes (str | list[int], optional): Classes chosen to calculate loss.
+ 'all' for all classes, 'present' for classes present in labels, or
+ a list of classes to average. Default: 'present'.
+ per_image (bool, optional): If per_image is True, compute the loss per
+ image instead of per batch. Default: False.
+ class_weight (list[float], optional): The weight for each class.
+ Default: None.
+ reduction (str, optional): The method used to reduce the loss. Options
+ are "none", "mean" and "sum". This parameter only works when
+ per_image is True. Default: 'mean'.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. This parameter only works when per_image is True.
+ Default: None.
+ ignore_index (int | None): The label index to be ignored. Default: 255.
+
+ Returns:
+ torch.Tensor: The calculated loss.
+ """
+
+ if per_image:
+ loss = [
+ lovasz_softmax_flat(
+ *flatten_probs(
+ prob.unsqueeze(0), label.unsqueeze(0), ignore_index),
+ classes=classes,
+ class_weight=class_weight)
+ for prob, label in zip(probs, labels)
+ ]
+ loss = weight_reduce_loss(
+ torch.stack(loss), None, reduction, avg_factor)
+ else:
+ loss = lovasz_softmax_flat(
+ *flatten_probs(probs, labels, ignore_index),
+ classes=classes,
+ class_weight=class_weight)
+ return loss
+
+
+@LOSSES.register_module()
+class LovaszLoss(nn.Module):
+ """LovaszLoss.
+
+ This loss is proposed in `The Lovasz-Softmax loss: A tractable surrogate
+ for the optimization of the intersection-over-union measure in neural
+ networks <https://arxiv.org/abs/1705.08790>`_.
+
+ Args:
+ loss_type (str, optional): Binary or multi-class loss.
+ Default: 'multi_class'. Options are "binary" and "multi_class".
+ classes (str | list[int], optional): Classes chosen to calculate loss.
+ 'all' for all classes, 'present' for classes present in labels, or
+ a list of classes to average. Default: 'present'.
+ per_image (bool, optional): If per_image is True, compute the loss per
+ image instead of per batch. Default: False.
+ reduction (str, optional): The method used to reduce the loss. Options
+ are "none", "mean" and "sum". This parameter only works when
+ per_image is True. Default: 'mean'.
+ class_weight (list[float] | str, optional): Weight of each class. If in
+ str format, read them from a file. Defaults to None.
+ loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
+ """
+
+ def __init__(self,
+ loss_type='multi_class',
+ classes='present',
+ per_image=False,
+ reduction='mean',
+ class_weight=None,
+ loss_weight=1.0):
+ super(LovaszLoss, self).__init__()
+ assert loss_type in ('binary', 'multi_class'), "loss_type should be \
+ 'binary' or 'multi_class'."
+
+ if loss_type == 'binary':
+ self.cls_criterion = lovasz_hinge
+ else:
+ self.cls_criterion = lovasz_softmax
+ assert classes in ('all', 'present') or mmcv.is_list_of(classes, int)
+ if not per_image:
+ assert reduction == 'none', "reduction should be 'none' when \
+ per_image is False."
+
+ self.classes = classes
+ self.per_image = per_image
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+ self.class_weight = get_class_weight(class_weight)
+
+ def forward(self,
+ cls_score,
+ label,
+ weight=None,
+ avg_factor=None,
+ reduction_override=None,
+ **kwargs):
+ """Forward function."""
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (
+ reduction_override if reduction_override else self.reduction)
+ if self.class_weight is not None:
+ class_weight = cls_score.new_tensor(self.class_weight)
+ else:
+ class_weight = None
+
+ # if multi-class loss, transform logits to probs
+ if self.cls_criterion == lovasz_softmax:
+ cls_score = F.softmax(cls_score, dim=1)
+
+ loss_cls = self.loss_weight * self.cls_criterion(
+ cls_score,
+ label,
+ self.classes,
+ self.per_image,
+ class_weight=class_weight,
+ reduction=reduction,
+ avg_factor=avg_factor,
+ **kwargs)
+ return loss_cls
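A usage sketch for the multi-class variant; note that with per_image=False the constructor requires reduction='none'. Shapes are toy assumptions and the import path matches this diff:

import torch
from annotator.uniformer.mmseg.models.losses import LovaszLoss

criterion = LovaszLoss(loss_type='multi_class', per_image=False, reduction='none')
logits = torch.randn(2, 4, 8, 8)                 # raw scores; softmax is applied inside
labels = torch.randint(0, 4, (2, 8, 8))
print(criterion(logits, labels).item())          # scalar Lovasz-Softmax loss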
diff --git a/annotator/uniformer/mmseg/models/losses/utils.py b/annotator/uniformer/mmseg/models/losses/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..85aec9f3045240c3de96a928324ae8f5c3aebe8b
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/losses/utils.py
@@ -0,0 +1,121 @@
+import functools
+
+import annotator.uniformer.mmcv as mmcv
+import numpy as np
+import torch.nn.functional as F
+
+
+def get_class_weight(class_weight):
+ """Get class weight for loss function.
+
+ Args:
+ class_weight (list[float] | str | None): If class_weight is a str,
+ take it as a file name and read from it.
+ """
+ if isinstance(class_weight, str):
+ # take it as a file path
+ if class_weight.endswith('.npy'):
+ class_weight = np.load(class_weight)
+ else:
+ # pkl, json or yaml
+ class_weight = mmcv.load(class_weight)
+
+ return class_weight
+
+
+def reduce_loss(loss, reduction):
+ """Reduce loss as specified.
+
+ Args:
+ loss (Tensor): Elementwise loss tensor.
+ reduction (str): Options are "none", "mean" and "sum".
+
+ Return:
+ Tensor: Reduced loss tensor.
+ """
+ reduction_enum = F._Reduction.get_enum(reduction)
+ # none: 0, elementwise_mean:1, sum: 2
+ if reduction_enum == 0:
+ return loss
+ elif reduction_enum == 1:
+ return loss.mean()
+ elif reduction_enum == 2:
+ return loss.sum()
+
+
+def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
+ """Apply element-wise weight and reduce loss.
+
+ Args:
+ loss (Tensor): Element-wise loss.
+ weight (Tensor): Element-wise weights.
+ reduction (str): Same as built-in losses of PyTorch.
+ avg_factor (float): Avarage factor when computing the mean of losses.
+
+ Returns:
+ Tensor: Processed loss values.
+ """
+ # if weight is specified, apply element-wise weight
+ if weight is not None:
+ assert weight.dim() == loss.dim()
+ if weight.dim() > 1:
+ assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
+ loss = loss * weight
+
+ # if avg_factor is not specified, just reduce the loss
+ if avg_factor is None:
+ loss = reduce_loss(loss, reduction)
+ else:
+ # if reduction is mean, then average the loss by avg_factor
+ if reduction == 'mean':
+ loss = loss.sum() / avg_factor
+ # if reduction is 'none', then do nothing, otherwise raise an error
+ elif reduction != 'none':
+ raise ValueError('avg_factor can not be used with reduction="sum"')
+ return loss
+
+
+def weighted_loss(loss_func):
+ """Create a weighted version of a given loss function.
+
+ To use this decorator, the loss function must have the signature like
+ `loss_func(pred, target, **kwargs)`. The function only needs to compute
+ element-wise loss without any reduction. This decorator will add weight
+ and reduction arguments to the function. The decorated function will have
+ the signature like `loss_func(pred, target, weight=None, reduction='mean',
+ avg_factor=None, **kwargs)`.
+
+ :Example:
+
+ >>> import torch
+ >>> @weighted_loss
+ >>> def l1_loss(pred, target):
+ >>> return (pred - target).abs()
+
+ >>> pred = torch.Tensor([0, 2, 3])
+ >>> target = torch.Tensor([1, 1, 1])
+ >>> weight = torch.Tensor([1, 0, 1])
+
+ >>> l1_loss(pred, target)
+ tensor(1.3333)
+ >>> l1_loss(pred, target, weight)
+ tensor(1.)
+ >>> l1_loss(pred, target, reduction='none')
+ tensor([1., 1., 2.])
+ >>> l1_loss(pred, target, weight, avg_factor=2)
+ tensor(1.5000)
+ """
+
+ @functools.wraps(loss_func)
+ def wrapper(pred,
+ target,
+ weight=None,
+ reduction='mean',
+ avg_factor=None,
+ **kwargs):
+ # get element-wise loss
+ loss = loss_func(pred, target, **kwargs)
+ loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
+ return loss
+
+ return wrapper
diff --git a/annotator/uniformer/mmseg/models/necks/__init__.py b/annotator/uniformer/mmseg/models/necks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b9d3d5b3fe80247642d962edd6fb787537d01d6
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/necks/__init__.py
@@ -0,0 +1,4 @@
+from .fpn import FPN
+from .multilevel_neck import MultiLevelNeck
+
+__all__ = ['FPN', 'MultiLevelNeck']
diff --git a/annotator/uniformer/mmseg/models/necks/fpn.py b/annotator/uniformer/mmseg/models/necks/fpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..a53b2a69500f8c2edb835abc3ff0ccc2173d1fb1
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/necks/fpn.py
@@ -0,0 +1,212 @@
+import torch.nn as nn
+import torch.nn.functional as F
+from annotator.uniformer.mmcv.cnn import ConvModule, xavier_init
+
+from ..builder import NECKS
+
+
+@NECKS.register_module()
+class FPN(nn.Module):
+ """Feature Pyramid Network.
+
+ This is an implementation of - Feature Pyramid Networks for Object
+ Detection (https://arxiv.org/abs/1612.03144)
+
+ Args:
+ in_channels (List[int]): Number of input channels per scale.
+ out_channels (int): Number of output channels (used at each scale)
+ num_outs (int): Number of output scales.
+ start_level (int): Index of the start input backbone level used to
+ build the feature pyramid. Default: 0.
+ end_level (int): Index of the end input backbone level (exclusive) to
+ build the feature pyramid. Default: -1, which means the last level.
+ add_extra_convs (bool | str): If bool, it decides whether to add conv
+ layers on top of the original feature maps. Default to False.
+ If True, its actual mode is specified by `extra_convs_on_inputs`.
+ If str, it specifies the source feature map of the extra convs.
+ Only the following options are allowed
+
+ - 'on_input': Last feat map of neck inputs (i.e. backbone feature).
+ - 'on_lateral': Last feature map after lateral convs.
+ - 'on_output': The last output feature map after fpn convs.
+ extra_convs_on_inputs (bool, deprecated): Whether to apply extra convs
+ on the original feature from the backbone. If True,
+ it is equivalent to `add_extra_convs='on_input'`. If False, it is
+ equivalent to set `add_extra_convs='on_output'`. Defaults to False.
+ relu_before_extra_convs (bool): Whether to apply relu before the extra
+ conv. Default: False.
+ no_norm_on_lateral (bool): Whether to apply norm on lateral.
+ Default: False.
+ conv_cfg (dict): Config dict for convolution layer. Default: None.
+ norm_cfg (dict): Config dict for normalization layer. Default: None.
+ act_cfg (str): Config dict for activation layer in ConvModule.
+ Default: None.
+ upsample_cfg (dict): Config dict for interpolate layer.
+ Default: `dict(mode='nearest')`
+
+ Example:
+ >>> import torch
+ >>> in_channels = [2, 3, 5, 7]
+ >>> scales = [340, 170, 84, 43]
+ >>> inputs = [torch.rand(1, c, s, s)
+ ... for c, s in zip(in_channels, scales)]
+ >>> self = FPN(in_channels, 11, len(in_channels)).eval()
+ >>> outputs = self.forward(inputs)
+ >>> for i in range(len(outputs)):
+ ... print(f'outputs[{i}].shape = {outputs[i].shape}')
+ outputs[0].shape = torch.Size([1, 11, 340, 340])
+ outputs[1].shape = torch.Size([1, 11, 170, 170])
+ outputs[2].shape = torch.Size([1, 11, 84, 84])
+ outputs[3].shape = torch.Size([1, 11, 43, 43])
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ num_outs,
+ start_level=0,
+ end_level=-1,
+ add_extra_convs=False,
+ extra_convs_on_inputs=False,
+ relu_before_extra_convs=False,
+ no_norm_on_lateral=False,
+ conv_cfg=None,
+ norm_cfg=None,
+ act_cfg=None,
+ upsample_cfg=dict(mode='nearest')):
+ super(FPN, self).__init__()
+ assert isinstance(in_channels, list)
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.num_ins = len(in_channels)
+ self.num_outs = num_outs
+ self.relu_before_extra_convs = relu_before_extra_convs
+ self.no_norm_on_lateral = no_norm_on_lateral
+ self.fp16_enabled = False
+ self.upsample_cfg = upsample_cfg.copy()
+
+ if end_level == -1:
+ self.backbone_end_level = self.num_ins
+ assert num_outs >= self.num_ins - start_level
+ else:
+ # if end_level < inputs, no extra level is allowed
+ self.backbone_end_level = end_level
+ assert end_level <= len(in_channels)
+ assert num_outs == end_level - start_level
+ self.start_level = start_level
+ self.end_level = end_level
+ self.add_extra_convs = add_extra_convs
+ assert isinstance(add_extra_convs, (str, bool))
+ if isinstance(add_extra_convs, str):
+ # Extra_convs_source choices: 'on_input', 'on_lateral', 'on_output'
+ assert add_extra_convs in ('on_input', 'on_lateral', 'on_output')
+ elif add_extra_convs: # True
+ if extra_convs_on_inputs:
+ # For compatibility with previous release
+ # TODO: deprecate `extra_convs_on_inputs`
+ self.add_extra_convs = 'on_input'
+ else:
+ self.add_extra_convs = 'on_output'
+
+ self.lateral_convs = nn.ModuleList()
+ self.fpn_convs = nn.ModuleList()
+
+ for i in range(self.start_level, self.backbone_end_level):
+ l_conv = ConvModule(
+ in_channels[i],
+ out_channels,
+ 1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg if not self.no_norm_on_lateral else None,
+ act_cfg=act_cfg,
+ inplace=False)
+ fpn_conv = ConvModule(
+ out_channels,
+ out_channels,
+ 3,
+ padding=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ inplace=False)
+
+ self.lateral_convs.append(l_conv)
+ self.fpn_convs.append(fpn_conv)
+
+ # add extra conv layers (e.g., RetinaNet)
+ extra_levels = num_outs - self.backbone_end_level + self.start_level
+ if self.add_extra_convs and extra_levels >= 1:
+ for i in range(extra_levels):
+ if i == 0 and self.add_extra_convs == 'on_input':
+ in_channels = self.in_channels[self.backbone_end_level - 1]
+ else:
+ in_channels = out_channels
+ extra_fpn_conv = ConvModule(
+ in_channels,
+ out_channels,
+ 3,
+ stride=2,
+ padding=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ inplace=False)
+ self.fpn_convs.append(extra_fpn_conv)
+
+ # default init_weights for conv(msra) and norm in ConvModule
+ def init_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ xavier_init(m, distribution='uniform')
+
+ def forward(self, inputs):
+ assert len(inputs) == len(self.in_channels)
+
+ # build laterals
+ laterals = [
+ lateral_conv(inputs[i + self.start_level])
+ for i, lateral_conv in enumerate(self.lateral_convs)
+ ]
+
+ # build top-down path
+ used_backbone_levels = len(laterals)
+ for i in range(used_backbone_levels - 1, 0, -1):
+ # In some cases, fixing `scale factor` (e.g. 2) is preferred, but
+ # it cannot co-exist with `size` in `F.interpolate`.
+ if 'scale_factor' in self.upsample_cfg:
+ laterals[i - 1] += F.interpolate(laterals[i],
+ **self.upsample_cfg)
+ else:
+ prev_shape = laterals[i - 1].shape[2:]
+ laterals[i - 1] += F.interpolate(
+ laterals[i], size=prev_shape, **self.upsample_cfg)
+
+ # build outputs
+ # part 1: from original levels
+ outs = [
+ self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels)
+ ]
+ # part 2: add extra levels
+ if self.num_outs > len(outs):
+ # use max pool to get more levels on top of outputs
+ # (e.g., Faster R-CNN, Mask R-CNN)
+ if not self.add_extra_convs:
+ for i in range(self.num_outs - used_backbone_levels):
+ outs.append(F.max_pool2d(outs[-1], 1, stride=2))
+ # add conv layers on top of original feature maps (RetinaNet)
+ else:
+ if self.add_extra_convs == 'on_input':
+ extra_source = inputs[self.backbone_end_level - 1]
+ elif self.add_extra_convs == 'on_lateral':
+ extra_source = laterals[-1]
+ elif self.add_extra_convs == 'on_output':
+ extra_source = outs[-1]
+ else:
+ raise NotImplementedError
+ outs.append(self.fpn_convs[used_backbone_levels](extra_source))
+ for i in range(used_backbone_levels + 1, self.num_outs):
+ if self.relu_before_extra_convs:
+ outs.append(self.fpn_convs[i](F.relu(outs[-1])))
+ else:
+ outs.append(self.fpn_convs[i](outs[-1]))
+ return tuple(outs)
diff --git a/annotator/uniformer/mmseg/models/necks/multilevel_neck.py b/annotator/uniformer/mmseg/models/necks/multilevel_neck.py
new file mode 100644
index 0000000000000000000000000000000000000000..766144d8136326a1fab5906a153a0c0df69b6b60
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/necks/multilevel_neck.py
@@ -0,0 +1,70 @@
+import torch.nn as nn
+import torch.nn.functional as F
+from annotator.uniformer.mmcv.cnn import ConvModule
+
+from ..builder import NECKS
+
+
+@NECKS.register_module()
+class MultiLevelNeck(nn.Module):
+ """MultiLevelNeck.
+
+ A neck structure connecting the ViT backbone and decode heads.
+ Args:
+ in_channels (List[int]): Number of input channels per scale.
+ out_channels (int): Number of output channels (used at each scale).
+ scales (List[int]): Scale factors for each input feature map.
+ norm_cfg (dict): Config dict for normalization layer. Default: None.
+ act_cfg (dict): Config dict for activation layer in ConvModule.
+ Default: None.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ scales=[0.5, 1, 2, 4],
+ norm_cfg=None,
+ act_cfg=None):
+ super(MultiLevelNeck, self).__init__()
+ assert isinstance(in_channels, list)
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.scales = scales
+ self.num_outs = len(scales)
+ self.lateral_convs = nn.ModuleList()
+ self.convs = nn.ModuleList()
+ for in_channel in in_channels:
+ self.lateral_convs.append(
+ ConvModule(
+ in_channel,
+ out_channels,
+ kernel_size=1,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg))
+ for _ in range(self.num_outs):
+ self.convs.append(
+ ConvModule(
+ out_channels,
+ out_channels,
+ kernel_size=3,
+ padding=1,
+ stride=1,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg))
+
+ def forward(self, inputs):
+ assert len(inputs) == len(self.in_channels)
+ inputs = [
+ lateral_conv(inputs[i])
+ for i, lateral_conv in enumerate(self.lateral_convs)
+ ]
+ # for len(inputs) not equal to self.num_outs
+ if len(inputs) == 1:
+ inputs = [inputs[0] for _ in range(self.num_outs)]
+ outs = []
+ for i in range(self.num_outs):
+ x_resize = F.interpolate(
+ inputs[i], scale_factor=self.scales[i], mode='bilinear')
+ outs.append(self.convs[i](x_resize))
+ return tuple(outs)
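The scale list above simply resamples one feature map into a small pyramid before the 3x3 convs; a minimal sketch with toy shapes:

import torch
import torch.nn.functional as F

x = torch.randn(1, 8, 16, 16)                    # single ViT-style feature map
outs = [F.interpolate(x, scale_factor=s, mode='bilinear', align_corners=False)
        for s in [0.5, 1, 2, 4]]
print([tuple(o.shape[2:]) for o in outs])        # [(8, 8), (16, 16), (32, 32), (64, 64)]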
diff --git a/annotator/uniformer/mmseg/models/segmentors/__init__.py b/annotator/uniformer/mmseg/models/segmentors/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dca2f09405330743c476e190896bee39c45498ea
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/segmentors/__init__.py
@@ -0,0 +1,5 @@
+from .base import BaseSegmentor
+from .cascade_encoder_decoder import CascadeEncoderDecoder
+from .encoder_decoder import EncoderDecoder
+
+__all__ = ['BaseSegmentor', 'EncoderDecoder', 'CascadeEncoderDecoder']
diff --git a/annotator/uniformer/mmseg/models/segmentors/base.py b/annotator/uniformer/mmseg/models/segmentors/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..172fc63b736c4f13be1cd909433bc260760a1eaa
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/segmentors/base.py
@@ -0,0 +1,273 @@
+import logging
+import warnings
+from abc import ABCMeta, abstractmethod
+from collections import OrderedDict
+
+import annotator.uniformer.mmcv as mmcv
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from annotator.uniformer.mmcv.runner import auto_fp16
+
+
+class BaseSegmentor(nn.Module):
+ """Base class for segmentors."""
+
+ __metaclass__ = ABCMeta
+
+ def __init__(self):
+ super(BaseSegmentor, self).__init__()
+ self.fp16_enabled = False
+
+ @property
+ def with_neck(self):
+ """bool: whether the segmentor has neck"""
+ return hasattr(self, 'neck') and self.neck is not None
+
+ @property
+ def with_auxiliary_head(self):
+ """bool: whether the segmentor has auxiliary head"""
+ return hasattr(self,
+ 'auxiliary_head') and self.auxiliary_head is not None
+
+ @property
+ def with_decode_head(self):
+ """bool: whether the segmentor has decode head"""
+ return hasattr(self, 'decode_head') and self.decode_head is not None
+
+ @abstractmethod
+ def extract_feat(self, imgs):
+ """Placeholder for extract features from images."""
+ pass
+
+ @abstractmethod
+ def encode_decode(self, img, img_metas):
+ """Placeholder for encode images with backbone and decode into a
+ semantic segmentation map of the same size as input."""
+ pass
+
+ @abstractmethod
+ def forward_train(self, imgs, img_metas, **kwargs):
+ """Placeholder for Forward function for training."""
+ pass
+
+ @abstractmethod
+ def simple_test(self, img, img_meta, **kwargs):
+ """Placeholder for single image test."""
+ pass
+
+ @abstractmethod
+ def aug_test(self, imgs, img_metas, **kwargs):
+ """Placeholder for augmentation test."""
+ pass
+
+ def init_weights(self, pretrained=None):
+ """Initialize the weights in segmentor.
+
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+ if pretrained is not None:
+ logger = logging.getLogger()
+ logger.info(f'load model from: {pretrained}')
+
+ def forward_test(self, imgs, img_metas, **kwargs):
+ """
+ Args:
+ imgs (List[Tensor]): the outer list indicates test-time
+ augmentations and inner Tensor should have a shape NxCxHxW,
+ which contains all images in the batch.
+ img_metas (List[List[dict]]): the outer list indicates test-time
+ augs (multiscale, flip, etc.) and the inner list indicates
+ images in a batch.
+ """
+ for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]:
+ if not isinstance(var, list):
+ raise TypeError(f'{name} must be a list, but got '
+ f'{type(var)}')
+
+ num_augs = len(imgs)
+ if num_augs != len(img_metas):
+ raise ValueError(f'num of augmentations ({len(imgs)}) != '
+ f'num of image meta ({len(img_metas)})')
+ # all images in the same aug batch must share the same ori_shape and
+ # pad_shape
+ for img_meta in img_metas:
+ ori_shapes = [_['ori_shape'] for _ in img_meta]
+ assert all(shape == ori_shapes[0] for shape in ori_shapes)
+ img_shapes = [_['img_shape'] for _ in img_meta]
+ assert all(shape == img_shapes[0] for shape in img_shapes)
+ pad_shapes = [_['pad_shape'] for _ in img_meta]
+ assert all(shape == pad_shapes[0] for shape in pad_shapes)
+
+ if num_augs == 1:
+ return self.simple_test(imgs[0], img_metas[0], **kwargs)
+ else:
+ return self.aug_test(imgs, img_metas, **kwargs)
+
+ @auto_fp16(apply_to=('img', ))
+ def forward(self, img, img_metas, return_loss=True, **kwargs):
+ """Calls either :func:`forward_train` or :func:`forward_test` depending
+ on whether ``return_loss`` is ``True``.
+
+ Note this setting will change the expected inputs. When
+ ``return_loss=True``, img and img_meta are single-nested (i.e. Tensor
+ and List[dict]), and when ``return_loss=False``, img and img_meta
+ should be double nested (i.e. List[Tensor], List[List[dict]]), with
+ the outer list indicating test time augmentations.
+ """
+ if return_loss:
+ return self.forward_train(img, img_metas, **kwargs)
+ else:
+ return self.forward_test(img, img_metas, **kwargs)
+
+ def train_step(self, data_batch, optimizer, **kwargs):
+ """The iteration step during training.
+
+ This method defines an iteration step during training, except for the
+ back propagation and optimizer updating, which are done in an optimizer
+ hook. Note that in some complicated cases or models, the whole process
+ including back propagation and optimizer updating is also defined in
+ this method, such as GAN.
+
+ Args:
+ data_batch (dict): The output of dataloader.
+ optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
+ runner is passed to ``train_step()``. This argument is unused
+ and reserved.
+
+ Returns:
+ dict: It should contain at least 3 keys: ``loss``, ``log_vars``,
+ ``num_samples``.
+ ``loss`` is a tensor for back propagation, which can be a
+ weighted sum of multiple losses.
+ ``log_vars`` contains all the variables to be sent to the
+ logger.
+ ``num_samples`` indicates the batch size (when the model is
+ DDP, it means the batch size on each GPU), which is used for
+ averaging the logs.
+ """
+ losses = self(**data_batch)
+ loss, log_vars = self._parse_losses(losses)
+
+ outputs = dict(
+ loss=loss,
+ log_vars=log_vars,
+ num_samples=len(data_batch['img_metas']))
+
+ return outputs
+
+ def val_step(self, data_batch, **kwargs):
+ """The iteration step during validation.
+
+ This method shares the same signature as :func:`train_step`, but is used
+ during val epochs. Note that the evaluation after training epochs is
+ not implemented with this method, but an evaluation hook.
+ """
+ output = self(**data_batch, **kwargs)
+ return output
+
+ @staticmethod
+ def _parse_losses(losses):
+ """Parse the raw outputs (losses) of the network.
+
+ Args:
+ losses (dict): Raw output of the network, which usually contain
+ losses and other necessary information.
+
+ Returns:
+ tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor
+ which may be a weighted sum of all losses, log_vars contains
+ all the variables to be sent to the logger.
+ """
+ log_vars = OrderedDict()
+ for loss_name, loss_value in losses.items():
+ if isinstance(loss_value, torch.Tensor):
+ log_vars[loss_name] = loss_value.mean()
+ elif isinstance(loss_value, list):
+ log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
+ else:
+ raise TypeError(
+ f'{loss_name} is not a tensor or list of tensors')
+
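+ # Only entries whose key contains 'loss' contribute to the total loss;
+ # other log_vars (e.g. accuracies) are logged but not back-propagated.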
+ loss = sum(_value for _key, _value in log_vars.items()
+ if 'loss' in _key)
+
+ log_vars['loss'] = loss
+ for loss_name, loss_value in log_vars.items():
+ # reduce loss when distributed training
+ if dist.is_available() and dist.is_initialized():
+ loss_value = loss_value.data.clone()
+ dist.all_reduce(loss_value.div_(dist.get_world_size()))
+ log_vars[loss_name] = loss_value.item()
+
+ return loss, log_vars
+
+ def show_result(self,
+ img,
+ result,
+ palette=None,
+ win_name='',
+ show=False,
+ wait_time=0,
+ out_file=None,
+ opacity=0.5):
+ """Draw `result` over `img`.
+
+ Args:
+ img (str or Tensor): The image to be displayed.
+ result (Tensor): The semantic segmentation results to draw over
+ `img`.
+ palette (list[list[int]]] | np.ndarray | None): The palette of
+ segmentation map. If None is given, random palette will be
+ generated. Default: None
+ win_name (str): The window name.
+ wait_time (int): Value of waitKey param.
+ Default: 0.
+ show (bool): Whether to show the image.
+ Default: False.
+ out_file (str or None): The filename to write the image.
+ Default: None.
+ opacity(float): Opacity of painted segmentation map.
+ Default 0.5.
+ Must be in (0, 1] range.
+ Returns:
+ img (Tensor): Only returned if neither `show` nor `out_file` is set.
+ """
+ img = mmcv.imread(img)
+ img = img.copy()
+ seg = result[0]
+ if palette is None:
+ if self.PALETTE is None:
+ palette = np.random.randint(
+ 0, 255, size=(len(self.CLASSES), 3))
+ else:
+ palette = self.PALETTE
+ palette = np.array(palette)
+ assert palette.shape[0] == len(self.CLASSES)
+ assert palette.shape[1] == 3
+ assert len(palette.shape) == 2
+ assert 0 < opacity <= 1.0
+ color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
+ for label, color in enumerate(palette):
+ color_seg[seg == label, :] = color
+ # convert to BGR
+ color_seg = color_seg[..., ::-1]
+
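+ # Alpha-blend the color-coded segmentation map over the input image.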
+ img = img * (1 - opacity) + color_seg * opacity
+ img = img.astype(np.uint8)
+ # if out_file specified, do not show image in window
+ if out_file is not None:
+ show = False
+
+ if show:
+ mmcv.imshow(img, win_name, wait_time)
+ if out_file is not None:
+ mmcv.imwrite(img, out_file)
+
+ if not (show or out_file):
+ warnings.warn('show==False and out_file is not specified, only '
+ 'result image will be returned')
+ return img
diff --git a/annotator/uniformer/mmseg/models/segmentors/cascade_encoder_decoder.py b/annotator/uniformer/mmseg/models/segmentors/cascade_encoder_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..873957d8d6468147c994493d92ff5c1b15bfb703
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/segmentors/cascade_encoder_decoder.py
@@ -0,0 +1,98 @@
+from torch import nn
+
+from annotator.uniformer.mmseg.core import add_prefix
+from annotator.uniformer.mmseg.ops import resize
+from .. import builder
+from ..builder import SEGMENTORS
+from .encoder_decoder import EncoderDecoder
+
+
+@SEGMENTORS.register_module()
+class CascadeEncoderDecoder(EncoderDecoder):
+ """Cascade Encoder Decoder segmentors.
+
+ CascadeEncoderDecoder is almost the same as EncoderDecoder, except that its
+ decode heads are cascaded: the output of the previous decode_head is fed as
+ an input to the next decode_head.
+ """
+
+ def __init__(self,
+ num_stages,
+ backbone,
+ decode_head,
+ neck=None,
+ auxiliary_head=None,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None):
+ self.num_stages = num_stages
+ super(CascadeEncoderDecoder, self).__init__(
+ backbone=backbone,
+ decode_head=decode_head,
+ neck=neck,
+ auxiliary_head=auxiliary_head,
+ train_cfg=train_cfg,
+ test_cfg=test_cfg,
+ pretrained=pretrained)
+
+ def _init_decode_head(self, decode_head):
+ """Initialize ``decode_head``"""
+ assert isinstance(decode_head, list)
+ assert len(decode_head) == self.num_stages
+ self.decode_head = nn.ModuleList()
+ for i in range(self.num_stages):
+ self.decode_head.append(builder.build_head(decode_head[i]))
+ self.align_corners = self.decode_head[-1].align_corners
+ self.num_classes = self.decode_head[-1].num_classes
+
+ def init_weights(self, pretrained=None):
+ """Initialize the weights in backbone and heads.
+
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+ self.backbone.init_weights(pretrained=pretrained)
+ for i in range(self.num_stages):
+ self.decode_head[i].init_weights()
+ if self.with_auxiliary_head:
+ if isinstance(self.auxiliary_head, nn.ModuleList):
+ for aux_head in self.auxiliary_head:
+ aux_head.init_weights()
+ else:
+ self.auxiliary_head.init_weights()
+
+ def encode_decode(self, img, img_metas):
+ """Encode images with backbone and decode into a semantic segmentation
+ map of the same size as input."""
+ x = self.extract_feat(img)
+ out = self.decode_head[0].forward_test(x, img_metas, self.test_cfg)
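+ # Each later stage refines the previous prediction: it receives both the
+ # backbone features and the output of the preceding decode head.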
+ for i in range(1, self.num_stages):
+ out = self.decode_head[i].forward_test(x, out, img_metas,
+ self.test_cfg)
+ out = resize(
+ input=out,
+ size=img.shape[2:],
+ mode='bilinear',
+ align_corners=self.align_corners)
+ return out
+
+ def _decode_head_forward_train(self, x, img_metas, gt_semantic_seg):
+ """Run forward function and calculate loss for decode head in
+ training."""
+ losses = dict()
+
+ loss_decode = self.decode_head[0].forward_train(
+ x, img_metas, gt_semantic_seg, self.train_cfg)
+
+ losses.update(add_prefix(loss_decode, 'decode_0'))
+
+ for i in range(1, self.num_stages):
+ # forward test again, maybe unnecessary for most methods.
+ prev_outputs = self.decode_head[i - 1].forward_test(
+ x, img_metas, self.test_cfg)
+ loss_decode = self.decode_head[i].forward_train(
+ x, prev_outputs, img_metas, gt_semantic_seg, self.train_cfg)
+ losses.update(add_prefix(loss_decode, f'decode_{i}'))
+
+ return losses
diff --git a/annotator/uniformer/mmseg/models/segmentors/encoder_decoder.py b/annotator/uniformer/mmseg/models/segmentors/encoder_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..98392ac04c4c44a7f4e7b1c0808266875877dd1f
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/segmentors/encoder_decoder.py
@@ -0,0 +1,298 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from annotator.uniformer.mmseg.core import add_prefix
+from annotator.uniformer.mmseg.ops import resize
+from .. import builder
+from ..builder import SEGMENTORS
+from .base import BaseSegmentor
+
+
+@SEGMENTORS.register_module()
+class EncoderDecoder(BaseSegmentor):
+ """Encoder Decoder segmentors.
+
+ EncoderDecoder typically consists of backbone, decode_head, auxiliary_head.
+ Note that auxiliary_head is only used for deep supervision during training,
+ which could be dumped during inference.
+ """
+
+ def __init__(self,
+ backbone,
+ decode_head,
+ neck=None,
+ auxiliary_head=None,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None):
+ super(EncoderDecoder, self).__init__()
+ self.backbone = builder.build_backbone(backbone)
+ if neck is not None:
+ self.neck = builder.build_neck(neck)
+ self._init_decode_head(decode_head)
+ self._init_auxiliary_head(auxiliary_head)
+
+ self.train_cfg = train_cfg
+ self.test_cfg = test_cfg
+
+ self.init_weights(pretrained=pretrained)
+
+ assert self.with_decode_head
+
+ def _init_decode_head(self, decode_head):
+ """Initialize ``decode_head``"""
+ self.decode_head = builder.build_head(decode_head)
+ self.align_corners = self.decode_head.align_corners
+ self.num_classes = self.decode_head.num_classes
+
+ def _init_auxiliary_head(self, auxiliary_head):
+ """Initialize ``auxiliary_head``"""
+ if auxiliary_head is not None:
+ if isinstance(auxiliary_head, list):
+ self.auxiliary_head = nn.ModuleList()
+ for head_cfg in auxiliary_head:
+ self.auxiliary_head.append(builder.build_head(head_cfg))
+ else:
+ self.auxiliary_head = builder.build_head(auxiliary_head)
+
+ def init_weights(self, pretrained=None):
+ """Initialize the weights in backbone and heads.
+
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+
+ super(EncoderDecoder, self).init_weights(pretrained)
+ self.backbone.init_weights(pretrained=pretrained)
+ self.decode_head.init_weights()
+ if self.with_auxiliary_head:
+ if isinstance(self.auxiliary_head, nn.ModuleList):
+ for aux_head in self.auxiliary_head:
+ aux_head.init_weights()
+ else:
+ self.auxiliary_head.init_weights()
+
+ def extract_feat(self, img):
+ """Extract features from images."""
+ x = self.backbone(img)
+ if self.with_neck:
+ x = self.neck(x)
+ return x
+
+ def encode_decode(self, img, img_metas):
+ """Encode images with backbone and decode into a semantic segmentation
+ map of the same size as input."""
+ x = self.extract_feat(img)
+ out = self._decode_head_forward_test(x, img_metas)
+ out = resize(
+ input=out,
+ size=img.shape[2:],
+ mode='bilinear',
+ align_corners=self.align_corners)
+ return out
+
+ def _decode_head_forward_train(self, x, img_metas, gt_semantic_seg):
+ """Run forward function and calculate loss for decode head in
+ training."""
+ losses = dict()
+ loss_decode = self.decode_head.forward_train(x, img_metas,
+ gt_semantic_seg,
+ self.train_cfg)
+
+ losses.update(add_prefix(loss_decode, 'decode'))
+ return losses
+
+ def _decode_head_forward_test(self, x, img_metas):
+ """Run forward function and calculate loss for decode head in
+ inference."""
+ seg_logits = self.decode_head.forward_test(x, img_metas, self.test_cfg)
+ return seg_logits
+
+ def _auxiliary_head_forward_train(self, x, img_metas, gt_semantic_seg):
+ """Run forward function and calculate loss for auxiliary head in
+ training."""
+ losses = dict()
+ if isinstance(self.auxiliary_head, nn.ModuleList):
+ for idx, aux_head in enumerate(self.auxiliary_head):
+ loss_aux = aux_head.forward_train(x, img_metas,
+ gt_semantic_seg,
+ self.train_cfg)
+ losses.update(add_prefix(loss_aux, f'aux_{idx}'))
+ else:
+ loss_aux = self.auxiliary_head.forward_train(
+ x, img_metas, gt_semantic_seg, self.train_cfg)
+ losses.update(add_prefix(loss_aux, 'aux'))
+
+ return losses
+
+ def forward_dummy(self, img):
+ """Dummy forward function."""
+ seg_logit = self.encode_decode(img, None)
+
+ return seg_logit
+
+ def forward_train(self, img, img_metas, gt_semantic_seg):
+ """Forward function for training.
+
+ Args:
+ img (Tensor): Input images.
+ img_metas (list[dict]): List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmseg/datasets/pipelines/formatting.py:Collect`.
+ gt_semantic_seg (Tensor): Semantic segmentation masks
+ used if the architecture supports semantic segmentation task.
+
+ Returns:
+ dict[str, Tensor]: a dictionary of loss components
+ """
+
+ x = self.extract_feat(img)
+
+ losses = dict()
+
+ loss_decode = self._decode_head_forward_train(x, img_metas,
+ gt_semantic_seg)
+ losses.update(loss_decode)
+
+ if self.with_auxiliary_head:
+ loss_aux = self._auxiliary_head_forward_train(
+ x, img_metas, gt_semantic_seg)
+ losses.update(loss_aux)
+
+ return losses
+
+ # TODO refactor
+ def slide_inference(self, img, img_meta, rescale):
+ """Inference by sliding-window with overlap.
+
+ If h_crop > h_img or w_crop > w_img, the small patch will be used to
+ decode without padding.
+ """
+
+ h_stride, w_stride = self.test_cfg.stride
+ h_crop, w_crop = self.test_cfg.crop_size
+ batch_size, _, h_img, w_img = img.size()
+ num_classes = self.num_classes
+ h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
+ w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
+ preds = img.new_zeros((batch_size, num_classes, h_img, w_img))
+ count_mat = img.new_zeros((batch_size, 1, h_img, w_img))
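+ # preds accumulates logits from overlapping crops, while count_mat records
+ # how many crops covered each pixel so the sum can be averaged afterwards.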
+ for h_idx in range(h_grids):
+ for w_idx in range(w_grids):
+ y1 = h_idx * h_stride
+ x1 = w_idx * w_stride
+ y2 = min(y1 + h_crop, h_img)
+ x2 = min(x1 + w_crop, w_img)
+ y1 = max(y2 - h_crop, 0)
+ x1 = max(x2 - w_crop, 0)
+ crop_img = img[:, :, y1:y2, x1:x2]
+ crop_seg_logit = self.encode_decode(crop_img, img_meta)
+ preds += F.pad(crop_seg_logit,
+ (int(x1), int(preds.shape[3] - x2), int(y1),
+ int(preds.shape[2] - y2)))
+
+ count_mat[:, :, y1:y2, x1:x2] += 1
+ assert (count_mat == 0).sum() == 0
+ if torch.onnx.is_in_onnx_export():
+ # cast count_mat to constant while exporting to ONNX
+ count_mat = torch.from_numpy(
+ count_mat.cpu().detach().numpy()).to(device=img.device)
+ preds = preds / count_mat
+ if rescale:
+ preds = resize(
+ preds,
+ size=img_meta[0]['ori_shape'][:2],
+ mode='bilinear',
+ align_corners=self.align_corners,
+ warning=False)
+ return preds
+
+ def whole_inference(self, img, img_meta, rescale):
+ """Inference with full image."""
+
+ seg_logit = self.encode_decode(img, img_meta)
+ if rescale:
+ # support dynamic shape for onnx
+ if torch.onnx.is_in_onnx_export():
+ size = img.shape[2:]
+ else:
+ size = img_meta[0]['ori_shape'][:2]
+ seg_logit = resize(
+ seg_logit,
+ size=size,
+ mode='bilinear',
+ align_corners=self.align_corners,
+ warning=False)
+
+ return seg_logit
+
+ def inference(self, img, img_meta, rescale):
+ """Inference with slide/whole style.
+
+ Args:
+ img (Tensor): The input image of shape (N, 3, H, W).
+ img_meta (dict): Image info dict where each dict has: 'img_shape',
+ 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmseg/datasets/pipelines/formatting.py:Collect`.
+ rescale (bool): Whether rescale back to original shape.
+
+ Returns:
+ Tensor: The output segmentation map.
+ """
+
+ assert self.test_cfg.mode in ['slide', 'whole']
+ ori_shape = img_meta[0]['ori_shape']
+ assert all(_['ori_shape'] == ori_shape for _ in img_meta)
+ if self.test_cfg.mode == 'slide':
+ seg_logit = self.slide_inference(img, img_meta, rescale)
+ else:
+ seg_logit = self.whole_inference(img, img_meta, rescale)
+ output = F.softmax(seg_logit, dim=1)
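+ # If the image was flipped for test-time augmentation, flip the prediction
+ # back so it aligns with the original orientation.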
+ flip = img_meta[0]['flip']
+ if flip:
+ flip_direction = img_meta[0]['flip_direction']
+ assert flip_direction in ['horizontal', 'vertical']
+ if flip_direction == 'horizontal':
+ output = output.flip(dims=(3, ))
+ elif flip_direction == 'vertical':
+ output = output.flip(dims=(2, ))
+
+ return output
+
+ def simple_test(self, img, img_meta, rescale=True):
+ """Simple test with single image."""
+ seg_logit = self.inference(img, img_meta, rescale)
+ seg_pred = seg_logit.argmax(dim=1)
+ if torch.onnx.is_in_onnx_export():
+ # our inference backend only support 4D output
+ seg_pred = seg_pred.unsqueeze(0)
+ return seg_pred
+ seg_pred = seg_pred.cpu().numpy()
+ # unravel batch dim
+ seg_pred = list(seg_pred)
+ return seg_pred
+
+ def aug_test(self, imgs, img_metas, rescale=True):
+ """Test with augmentations.
+
+ Only rescale=True is supported.
+ """
+ # aug_test rescale all imgs back to ori_shape for now
+ assert rescale
+ # to save memory, we get augmented seg logit inplace
+ seg_logit = self.inference(imgs[0], img_metas[0], rescale)
+ for i in range(1, len(imgs)):
+ cur_seg_logit = self.inference(imgs[i], img_metas[i], rescale)
+ seg_logit += cur_seg_logit
+ seg_logit /= len(imgs)
+ seg_pred = seg_logit.argmax(dim=1)
+ seg_pred = seg_pred.cpu().numpy()
+ # unravel batch dim
+ seg_pred = list(seg_pred)
+ return seg_pred
diff --git a/annotator/uniformer/mmseg/models/utils/__init__.py b/annotator/uniformer/mmseg/models/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d3bdd349b9f2ae499a2fcb2ac1d2e3c77befebe
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/utils/__init__.py
@@ -0,0 +1,13 @@
+from .drop import DropPath
+from .inverted_residual import InvertedResidual, InvertedResidualV3
+from .make_divisible import make_divisible
+from .res_layer import ResLayer
+from .se_layer import SELayer
+from .self_attention_block import SelfAttentionBlock
+from .up_conv_block import UpConvBlock
+from .weight_init import trunc_normal_
+
+__all__ = [
+ 'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual',
+ 'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'DropPath', 'trunc_normal_'
+]
diff --git a/annotator/uniformer/mmseg/models/utils/drop.py b/annotator/uniformer/mmseg/models/utils/drop.py
new file mode 100644
index 0000000000000000000000000000000000000000..4520b0ff407d2a95a864086bdbca0065f222aa63
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/utils/drop.py
@@ -0,0 +1,31 @@
+"""Modified from https://github.com/rwightman/pytorch-image-
+models/blob/master/timm/models/layers/drop.py."""
+
+import torch
+from torch import nn
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of
+ residual blocks).
+
+ Args:
+ drop_prob (float): Drop rate for paths of model. Dropout rate has
+ to be between 0 and 1. Default: 0.
+ """
+
+ def __init__(self, drop_prob=0.):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+ self.keep_prob = 1 - drop_prob
+
+ def forward(self, x):
+ if self.drop_prob == 0. or not self.training:
+ return x
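+ # Keep each sample with probability keep_prob and rescale the survivors by
+ # 1 / keep_prob so the expected activation is unchanged (stochastic depth).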
+ shape = (x.shape[0], ) + (1, ) * (
+ x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = self.keep_prob + torch.rand(
+ shape, dtype=x.dtype, device=x.device)
+ random_tensor.floor_() # binarize
+ output = x.div(self.keep_prob) * random_tensor
+ return output
diff --git a/annotator/uniformer/mmseg/models/utils/inverted_residual.py b/annotator/uniformer/mmseg/models/utils/inverted_residual.py
new file mode 100644
index 0000000000000000000000000000000000000000..53b8fcd41f71d814738f1ac3f5acd3c3d701bf96
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/utils/inverted_residual.py
@@ -0,0 +1,208 @@
+from annotator.uniformer.mmcv.cnn import ConvModule
+from torch import nn
+from torch.utils import checkpoint as cp
+
+from .se_layer import SELayer
+
+
+class InvertedResidual(nn.Module):
+ """InvertedResidual block for MobileNetV2.
+
+ Args:
+ in_channels (int): The input channels of the InvertedResidual block.
+ out_channels (int): The output channels of the InvertedResidual block.
+ stride (int): Stride of the middle (first) 3x3 convolution.
+ expand_ratio (int): Adjusts number of channels of the hidden layer
+ in InvertedResidual by this amount.
+ dilation (int): Dilation rate of depthwise conv. Default: 1
+ conv_cfg (dict): Config dict for convolution layer.
+ Default: None, which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='BN').
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='ReLU6').
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+
+ Returns:
+ Tensor: The output tensor.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ stride,
+ expand_ratio,
+ dilation=1,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU6'),
+ with_cp=False):
+ super(InvertedResidual, self).__init__()
+ self.stride = stride
+ assert stride in [1, 2], f'stride must in [1, 2]. ' \
+ f'But received {stride}.'
+ self.with_cp = with_cp
+ self.use_res_connect = self.stride == 1 and in_channels == out_channels
+ hidden_dim = int(round(in_channels * expand_ratio))
+
+ layers = []
+ if expand_ratio != 1:
+ layers.append(
+ ConvModule(
+ in_channels=in_channels,
+ out_channels=hidden_dim,
+ kernel_size=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg))
+ layers.extend([
+ ConvModule(
+ in_channels=hidden_dim,
+ out_channels=hidden_dim,
+ kernel_size=3,
+ stride=stride,
+ padding=dilation,
+ dilation=dilation,
+ groups=hidden_dim,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg),
+ ConvModule(
+ in_channels=hidden_dim,
+ out_channels=out_channels,
+ kernel_size=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=None)
+ ])
+ self.conv = nn.Sequential(*layers)
+
+ def forward(self, x):
+
+ def _inner_forward(x):
+ if self.use_res_connect:
+ return x + self.conv(x)
+ else:
+ return self.conv(x)
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ return out
+
+
+class InvertedResidualV3(nn.Module):
+ """Inverted Residual Block for MobileNetV3.
+
+ Args:
+ in_channels (int): The input channels of this Module.
+ out_channels (int): The output channels of this Module.
+ mid_channels (int): The input channels of the depthwise convolution.
+ kernel_size (int): The kernel size of the depthwise convolution.
+ Default: 3.
+ stride (int): The stride of the depthwise convolution. Default: 1.
+ se_cfg (dict): Config dict for se layer. Default: None, which means no
+ se layer.
+ with_expand_conv (bool): Use expand conv or not. If set False,
+ mid_channels must be the same with in_channels. Default: True.
+ conv_cfg (dict): Config dict for convolution layer. Default: None,
+ which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='BN').
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='ReLU').
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+
+ Returns:
+ Tensor: The output tensor.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ mid_channels,
+ kernel_size=3,
+ stride=1,
+ se_cfg=None,
+ with_expand_conv=True,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ with_cp=False):
+ super(InvertedResidualV3, self).__init__()
+ self.with_res_shortcut = (stride == 1 and in_channels == out_channels)
+ assert stride in [1, 2]
+ self.with_cp = with_cp
+ self.with_se = se_cfg is not None
+ self.with_expand_conv = with_expand_conv
+
+ if self.with_se:
+ assert isinstance(se_cfg, dict)
+ if not self.with_expand_conv:
+ assert mid_channels == in_channels
+
+ if self.with_expand_conv:
+ self.expand_conv = ConvModule(
+ in_channels=in_channels,
+ out_channels=mid_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ self.depthwise_conv = ConvModule(
+ in_channels=mid_channels,
+ out_channels=mid_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=kernel_size // 2,
+ groups=mid_channels,
+ conv_cfg=dict(
+ type='Conv2dAdaptivePadding') if stride == 2 else conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+
+ if self.with_se:
+ self.se = SELayer(**se_cfg)
+
+ self.linear_conv = ConvModule(
+ in_channels=mid_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=None)
+
+ def forward(self, x):
+
+ def _inner_forward(x):
+ out = x
+
+ if self.with_expand_conv:
+ out = self.expand_conv(out)
+
+ out = self.depthwise_conv(out)
+
+ if self.with_se:
+ out = self.se(out)
+
+ out = self.linear_conv(out)
+
+ if self.with_res_shortcut:
+ return x + out
+ else:
+ return out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ return out
diff --git a/annotator/uniformer/mmseg/models/utils/make_divisible.py b/annotator/uniformer/mmseg/models/utils/make_divisible.py
new file mode 100644
index 0000000000000000000000000000000000000000..75ad756052529f52fe83bb95dd1f0ecfc9a13078
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/utils/make_divisible.py
@@ -0,0 +1,27 @@
+def make_divisible(value, divisor, min_value=None, min_ratio=0.9):
+ """Make divisible function.
+
+ This function rounds the channel number to the nearest value that can be
+ divisible by the divisor. It is taken from the original tf repo. It ensures
+ that all layers have a channel number that is divisible by divisor. It can
+ be seen here: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py # noqa
+
+ Args:
+ value (int): The original channel number.
+ divisor (int): The divisor to fully divide the channel number.
+ min_value (int): The minimum value of the output channel.
+ Default: None, means that the minimum value equal to the divisor.
+ min_ratio (float): The minimum ratio of the rounded channel number to
+ the original channel number. Default: 0.9.
+
+ Returns:
+ int: The modified output channel number.
+ """
+
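+ # e.g. make_divisible(30, 8) -> 32, while make_divisible(10, 8) -> 16
+ # because rounding down to 8 would fall below min_ratio * 10 = 9.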
+ if min_value is None:
+ min_value = divisor
+ new_value = max(min_value, int(value + divisor / 2) // divisor * divisor)
+ # Make sure that rounding down does not reduce the value by more than
+ # (1 - min_ratio).
+ if new_value < min_ratio * value:
+ new_value += divisor
+ return new_value
diff --git a/annotator/uniformer/mmseg/models/utils/res_layer.py b/annotator/uniformer/mmseg/models/utils/res_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2c07b47007e92e4c3945b989e79f9d50306f5fe
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/utils/res_layer.py
@@ -0,0 +1,94 @@
+from annotator.uniformer.mmcv.cnn import build_conv_layer, build_norm_layer
+from torch import nn as nn
+
+
+class ResLayer(nn.Sequential):
+ """ResLayer to build ResNet style backbone.
+
+ Args:
+ block (nn.Module): block used to build ResLayer.
+ inplanes (int): inplanes of block.
+ planes (int): planes of block.
+ num_blocks (int): number of blocks.
+ stride (int): stride of the first block. Default: 1
+ avg_down (bool): Use AvgPool instead of stride conv when
+ downsampling in the bottleneck. Default: False
+ conv_cfg (dict): dictionary to construct and config conv layer.
+ Default: None
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: dict(type='BN')
+ multi_grid (tuple[int] | None): Multi grid dilation rates of last
+ stage. Default: None
+ contract_dilation (bool): Whether to contract the first dilation of
+ each layer. Default: False
+ """
+
+ def __init__(self,
+ block,
+ inplanes,
+ planes,
+ num_blocks,
+ stride=1,
+ dilation=1,
+ avg_down=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ multi_grid=None,
+ contract_dilation=False,
+ **kwargs):
+ self.block = block
+
+ downsample = None
+ if stride != 1 or inplanes != planes * block.expansion:
+ downsample = []
+ conv_stride = stride
+ if avg_down:
+ conv_stride = 1
+ downsample.append(
+ nn.AvgPool2d(
+ kernel_size=stride,
+ stride=stride,
+ ceil_mode=True,
+ count_include_pad=False))
+ downsample.extend([
+ build_conv_layer(
+ conv_cfg,
+ inplanes,
+ planes * block.expansion,
+ kernel_size=1,
+ stride=conv_stride,
+ bias=False),
+ build_norm_layer(norm_cfg, planes * block.expansion)[1]
+ ])
+ downsample = nn.Sequential(*downsample)
+
+ layers = []
+ if multi_grid is None:
+ if dilation > 1 and contract_dilation:
+ first_dilation = dilation // 2
+ else:
+ first_dilation = dilation
+ else:
+ first_dilation = multi_grid[0]
+ layers.append(
+ block(
+ inplanes=inplanes,
+ planes=planes,
+ stride=stride,
+ dilation=first_dilation,
+ downsample=downsample,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ **kwargs))
+ inplanes = planes * block.expansion
+ for i in range(1, num_blocks):
+ layers.append(
+ block(
+ inplanes=inplanes,
+ planes=planes,
+ stride=1,
+ dilation=dilation if multi_grid is None else multi_grid[i],
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ **kwargs))
+ super(ResLayer, self).__init__(*layers)
diff --git a/annotator/uniformer/mmseg/models/utils/se_layer.py b/annotator/uniformer/mmseg/models/utils/se_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..083bd7d1ccee909c900c7aed2cc928bf14727f3e
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/utils/se_layer.py
@@ -0,0 +1,57 @@
+import annotator.uniformer.mmcv as mmcv
+import torch.nn as nn
+from annotator.uniformer.mmcv.cnn import ConvModule
+
+from .make_divisible import make_divisible
+
+
+class SELayer(nn.Module):
+ """Squeeze-and-Excitation Module.
+
+ Args:
+ channels (int): The input (and output) channels of the SE layer.
+ ratio (int): Squeeze ratio in SELayer; the intermediate channels are
+ ``make_divisible(channels // ratio, 8)``. Default: 16.
+ conv_cfg (None or dict): Config dict for convolution layer.
+ Default: None, which means using conv2d.
+ act_cfg (dict or Sequence[dict]): Config dict for activation layer.
+ If act_cfg is a dict, two activation layers will be configured
+ by this dict. If act_cfg is a sequence of dicts, the first
+ activation layer will be configured by the first dict and the
+ second activation layer will be configured by the second dict.
+ Default: (dict(type='ReLU'), dict(type='HSigmoid', bias=3.0,
+ divisor=6.0)).
+ """
+
+ def __init__(self,
+ channels,
+ ratio=16,
+ conv_cfg=None,
+ act_cfg=(dict(type='ReLU'),
+ dict(type='HSigmoid', bias=3.0, divisor=6.0))):
+ super(SELayer, self).__init__()
+ if isinstance(act_cfg, dict):
+ act_cfg = (act_cfg, act_cfg)
+ assert len(act_cfg) == 2
+ assert mmcv.is_tuple_of(act_cfg, dict)
+ self.global_avgpool = nn.AdaptiveAvgPool2d(1)
+ self.conv1 = ConvModule(
+ in_channels=channels,
+ out_channels=make_divisible(channels // ratio, 8),
+ kernel_size=1,
+ stride=1,
+ conv_cfg=conv_cfg,
+ act_cfg=act_cfg[0])
+ self.conv2 = ConvModule(
+ in_channels=make_divisible(channels // ratio, 8),
+ out_channels=channels,
+ kernel_size=1,
+ stride=1,
+ conv_cfg=conv_cfg,
+ act_cfg=act_cfg[1])
+
+ def forward(self, x):
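+ # Squeeze: global average pooling to (N, C, 1, 1). Excite: two 1x1 convs
+ # produce per-channel gates (HSigmoid-bounded by default) that rescale x.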
+ out = self.global_avgpool(x)
+ out = self.conv1(out)
+ out = self.conv2(out)
+ return x * out
diff --git a/annotator/uniformer/mmseg/models/utils/self_attention_block.py b/annotator/uniformer/mmseg/models/utils/self_attention_block.py
new file mode 100644
index 0000000000000000000000000000000000000000..440c7b73ee4706fde555595926d63a18d7574acc
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/utils/self_attention_block.py
@@ -0,0 +1,159 @@
+import torch
+from annotator.uniformer.mmcv.cnn import ConvModule, constant_init
+from torch import nn as nn
+from torch.nn import functional as F
+
+
+class SelfAttentionBlock(nn.Module):
+ """General self-attention block/non-local block.
+
+ Please refer to https://arxiv.org/abs/1706.03762 for details about key,
+ query and value.
+
+ Args:
+ key_in_channels (int): Input channels of key feature.
+ query_in_channels (int): Input channels of query feature.
+ channels (int): Output channels of key/query transform.
+ out_channels (int): Output channels.
+ share_key_query (bool): Whether share projection weight between key
+ and query projection.
+ query_downsample (nn.Module): Query downsample module.
+ key_downsample (nn.Module): Key downsample module.
+ key_query_num_convs (int): Number of convs for key/query projection.
+ value_num_convs (int): Number of convs for value projection.
+ matmul_norm (bool): Whether to normalize the attention map by the
+ square root of ``channels``.
+ with_out (bool): Whether to use an output projection.
+ conv_cfg (dict|None): Config of conv layers.
+ norm_cfg (dict|None): Config of norm layers.
+ act_cfg (dict|None): Config of activation layers.
+ """
+
+ def __init__(self, key_in_channels, query_in_channels, channels,
+ out_channels, share_key_query, query_downsample,
+ key_downsample, key_query_num_convs, value_out_num_convs,
+ key_query_norm, value_out_norm, matmul_norm, with_out,
+ conv_cfg, norm_cfg, act_cfg):
+ super(SelfAttentionBlock, self).__init__()
+ if share_key_query:
+ assert key_in_channels == query_in_channels
+ self.key_in_channels = key_in_channels
+ self.query_in_channels = query_in_channels
+ self.out_channels = out_channels
+ self.channels = channels
+ self.share_key_query = share_key_query
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ self.key_project = self.build_project(
+ key_in_channels,
+ channels,
+ num_convs=key_query_num_convs,
+ use_conv_module=key_query_norm,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ if share_key_query:
+ self.query_project = self.key_project
+ else:
+ self.query_project = self.build_project(
+ query_in_channels,
+ channels,
+ num_convs=key_query_num_convs,
+ use_conv_module=key_query_norm,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ self.value_project = self.build_project(
+ key_in_channels,
+ channels if with_out else out_channels,
+ num_convs=value_out_num_convs,
+ use_conv_module=value_out_norm,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ if with_out:
+ self.out_project = self.build_project(
+ channels,
+ out_channels,
+ num_convs=value_out_num_convs,
+ use_conv_module=value_out_norm,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ else:
+ self.out_project = None
+
+ self.query_downsample = query_downsample
+ self.key_downsample = key_downsample
+ self.matmul_norm = matmul_norm
+
+ self.init_weights()
+
+ def init_weights(self):
+ """Initialize weight of later layer."""
+ if self.out_project is not None:
+ if not isinstance(self.out_project, ConvModule):
+ constant_init(self.out_project, 0)
+
+ def build_project(self, in_channels, channels, num_convs, use_conv_module,
+ conv_cfg, norm_cfg, act_cfg):
+ """Build projection layer for key/query/value/out."""
+ if use_conv_module:
+ convs = [
+ ConvModule(
+ in_channels,
+ channels,
+ 1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ ]
+ for _ in range(num_convs - 1):
+ convs.append(
+ ConvModule(
+ channels,
+ channels,
+ 1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg))
+ else:
+ convs = [nn.Conv2d(in_channels, channels, 1)]
+ for _ in range(num_convs - 1):
+ convs.append(nn.Conv2d(channels, channels, 1))
+ if len(convs) > 1:
+ convs = nn.Sequential(*convs)
+ else:
+ convs = convs[0]
+ return convs
+
+ def forward(self, query_feats, key_feats):
+ """Forward function."""
+ batch_size = query_feats.size(0)
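+ # Project query/key/value, flatten the spatial dimensions, and apply
+ # scaled dot-product attention; the context is reshaped back to the
+ # query's spatial size at the end.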
+ query = self.query_project(query_feats)
+ if self.query_downsample is not None:
+ query = self.query_downsample(query)
+ query = query.reshape(*query.shape[:2], -1)
+ query = query.permute(0, 2, 1).contiguous()
+
+ key = self.key_project(key_feats)
+ value = self.value_project(key_feats)
+ if self.key_downsample is not None:
+ key = self.key_downsample(key)
+ value = self.key_downsample(value)
+ key = key.reshape(*key.shape[:2], -1)
+ value = value.reshape(*value.shape[:2], -1)
+ value = value.permute(0, 2, 1).contiguous()
+
+ sim_map = torch.matmul(query, key)
+ if self.matmul_norm:
+ sim_map = (self.channels**-.5) * sim_map
+ sim_map = F.softmax(sim_map, dim=-1)
+
+ context = torch.matmul(sim_map, value)
+ context = context.permute(0, 2, 1).contiguous()
+ context = context.reshape(batch_size, -1, *query_feats.shape[2:])
+ if self.out_project is not None:
+ context = self.out_project(context)
+ return context
diff --git a/annotator/uniformer/mmseg/models/utils/up_conv_block.py b/annotator/uniformer/mmseg/models/utils/up_conv_block.py
new file mode 100644
index 0000000000000000000000000000000000000000..378469da76cb7bff6a639e7877b3c275d50490fb
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/utils/up_conv_block.py
@@ -0,0 +1,101 @@
+import torch
+import torch.nn as nn
+from annotator.uniformer.mmcv.cnn import ConvModule, build_upsample_layer
+
+
+class UpConvBlock(nn.Module):
+ """Upsample convolution block in decoder for UNet.
+
+ This upsample convolution block consists of one upsample module
+ followed by one convolution block. The upsample module expands the
+ high-level low-resolution feature map and the convolution block fuses
+ the upsampled high-level low-resolution feature map and the low-level
+ high-resolution feature map from encoder.
+
+ Args:
+ conv_block (nn.Sequential): Sequential of convolutional layers.
+ in_channels (int): Number of input channels of the high-level
+ low-resolution feature map from the decoder.
+ skip_channels (int): Number of input channels of the low-level
+ high-resolution feature map from encoder.
+ out_channels (int): Number of output channels.
+ num_convs (int): Number of convolutional layers in the conv_block.
+ Default: 2.
+ stride (int): Stride of convolutional layer in conv_block. Default: 1.
+ dilation (int): Dilation rate of convolutional layer in conv_block.
+ Default: 1.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ conv_cfg (dict | None): Config dict for convolution layer.
+ Default: None.
+ norm_cfg (dict | None): Config dict for normalization layer.
+ Default: dict(type='BN').
+ act_cfg (dict | None): Config dict for activation layer in ConvModule.
+ Default: dict(type='ReLU').
+ upsample_cfg (dict): The upsample config of the upsample module in
+ decoder. Default: dict(type='InterpConv'). If the high-level
+ feature map already has the same size as the skip feature map
+ (the low-level feature map from the encoder), no upsampling is
+ needed and upsample_cfg should be None.
+ dcn (bool): Use deformable convolution in convolutional layer or not.
+ Default: None.
+ plugins (dict): plugins for convolutional layers. Default: None.
+ """
+
+ def __init__(self,
+ conv_block,
+ in_channels,
+ skip_channels,
+ out_channels,
+ num_convs=2,
+ stride=1,
+ dilation=1,
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ upsample_cfg=dict(type='InterpConv'),
+ dcn=None,
+ plugins=None):
+ super(UpConvBlock, self).__init__()
+ assert dcn is None, 'Not implemented yet.'
+ assert plugins is None, 'Not implemented yet.'
+
+ self.conv_block = conv_block(
+ in_channels=2 * skip_channels,
+ out_channels=out_channels,
+ num_convs=num_convs,
+ stride=stride,
+ dilation=dilation,
+ with_cp=with_cp,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ dcn=None,
+ plugins=None)
+ if upsample_cfg is not None:
+ self.upsample = build_upsample_layer(
+ cfg=upsample_cfg,
+ in_channels=in_channels,
+ out_channels=skip_channels,
+ with_cp=with_cp,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ else:
+ self.upsample = ConvModule(
+ in_channels,
+ skip_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+
+ def forward(self, skip, x):
+ """Forward function."""
+
+ x = self.upsample(x)
+ out = torch.cat([skip, x], dim=1)
+ out = self.conv_block(out)
+
+ return out
diff --git a/annotator/uniformer/mmseg/models/utils/weight_init.py b/annotator/uniformer/mmseg/models/utils/weight_init.py
new file mode 100644
index 0000000000000000000000000000000000000000..38141ba3d61f64ddfc0a31574b4648cbad96d7dd
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/utils/weight_init.py
@@ -0,0 +1,62 @@
+"""Modified from https://github.com/rwightman/pytorch-image-
+models/blob/master/timm/models/layers/drop.py."""
+
+import math
+import warnings
+
+import torch
+
+
+def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+ """Reference: https://people.sc.fsu.edu/~jburkardt/presentations
+ /truncated_normal.pdf"""
+
+ def norm_cdf(x):
+ # Computes standard normal cumulative distribution function
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
+ warnings.warn(
+ 'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
+ 'The distribution of values may be incorrect.',
+ stacklevel=2)
+
+ with torch.no_grad():
+ # Values are generated by using a truncated uniform distribution and
+ # then using the inverse CDF for the normal distribution.
+ # Get upper and lower cdf values
+ lower_bound = norm_cdf((a - mean) / std)
+ upper_bound = norm_cdf((b - mean) / std)
+
+ # Uniformly fill tensor with values from [l, u], then translate to
+ # [2l-1, 2u-1].
+ tensor.uniform_(2 * lower_bound - 1, 2 * upper_bound - 1)
+
+ # Use inverse cdf transform for normal distribution to get truncated
+ # standard normal
+ tensor.erfinv_()
+
+ # Transform to proper mean, std
+ tensor.mul_(std * math.sqrt(2.))
+ tensor.add_(mean)
+
+ # Clamp to ensure it's in the proper range
+ tensor.clamp_(min=a, max=b)
+ return tensor
+
+
+def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
+ r"""Fills the input Tensor with values drawn from a truncated
+ normal distribution. The values are effectively drawn from the
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
+ with values outside :math:`[a, b]` redrawn until they are within
+ the bounds. The method used for generating the random values works
+ best when :math:`a \leq \text{mean} \leq b`.
+ Args:
+ tensor (``torch.Tensor``): an n-dimensional `torch.Tensor`
+ mean (float): the mean of the normal distribution
+ std (float): the standard deviation of the normal distribution
+ a (float): the minimum cutoff value
+ b (float): the maximum cutoff value
+ """
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
diff --git a/annotator/uniformer/mmseg/ops/__init__.py b/annotator/uniformer/mmseg/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bec51c75b9363a9a19e9fb5c35f4e7dbd6f7751c
--- /dev/null
+++ b/annotator/uniformer/mmseg/ops/__init__.py
@@ -0,0 +1,4 @@
+from .encoding import Encoding
+from .wrappers import Upsample, resize
+
+__all__ = ['Upsample', 'resize', 'Encoding']
diff --git a/annotator/uniformer/mmseg/ops/encoding.py b/annotator/uniformer/mmseg/ops/encoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..7eb3629a6426550b8e4c537ee1ff4341893e489e
--- /dev/null
+++ b/annotator/uniformer/mmseg/ops/encoding.py
@@ -0,0 +1,74 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+
+class Encoding(nn.Module):
+ """Encoding Layer: a learnable residual encoder.
+
+ Input is of shape (batch_size, channels, height, width).
+ Output is of shape (batch_size, num_codes, channels).
+
+ Args:
+ channels: dimension of the features or feature channels
+ num_codes: number of code words
+ """
+
+ def __init__(self, channels, num_codes):
+ super(Encoding, self).__init__()
+ # init codewords and smoothing factor
+ self.channels, self.num_codes = channels, num_codes
+ std = 1. / ((num_codes * channels)**0.5)
+ # [num_codes, channels]
+ self.codewords = nn.Parameter(
+ torch.empty(num_codes, channels,
+ dtype=torch.float).uniform_(-std, std),
+ requires_grad=True)
+ # [num_codes]
+ self.scale = nn.Parameter(
+ torch.empty(num_codes, dtype=torch.float).uniform_(-1, 0),
+ requires_grad=True)
+
+ @staticmethod
+ def scaled_l2(x, codewords, scale):
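+ # Scaled squared L2 distance between every feature in x [B, HxW, C] and
+ # every codeword [K, C], returning a map of shape [B, HxW, K].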
+ num_codes, channels = codewords.size()
+ batch_size = x.size(0)
+ reshaped_scale = scale.view((1, 1, num_codes))
+ expanded_x = x.unsqueeze(2).expand(
+ (batch_size, x.size(1), num_codes, channels))
+ reshaped_codewords = codewords.view((1, 1, num_codes, channels))
+
+ scaled_l2_norm = reshaped_scale * (
+ expanded_x - reshaped_codewords).pow(2).sum(dim=3)
+ return scaled_l2_norm
+
+ @staticmethod
+ def aggregate(assignment_weights, x, codewords):
+ num_codes, channels = codewords.size()
+ reshaped_codewords = codewords.view((1, 1, num_codes, channels))
+ batch_size = x.size(0)
+
+ expanded_x = x.unsqueeze(2).expand(
+ (batch_size, x.size(1), num_codes, channels))
+ encoded_feat = (assignment_weights.unsqueeze(3) *
+ (expanded_x - reshaped_codewords)).sum(dim=1)
+ return encoded_feat
+
+ def forward(self, x):
+ assert x.dim() == 4 and x.size(1) == self.channels
+ # [batch_size, channels, height, width]
+ batch_size = x.size(0)
+ # [batch_size, height x width, channels]
+ x = x.view(batch_size, self.channels, -1).transpose(1, 2).contiguous()
+ # assignment_weights: [batch_size, height x width, num_codes]
+ assignment_weights = F.softmax(
+ self.scaled_l2(x, self.codewords, self.scale), dim=2)
+ # aggregate
+ encoded_feat = self.aggregate(assignment_weights, x, self.codewords)
+ return encoded_feat
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(Nx{self.channels}xHxW =>Nx{self.num_codes}' \
+ f'x{self.channels})'
+ return repr_str
diff --git a/annotator/uniformer/mmseg/ops/wrappers.py b/annotator/uniformer/mmseg/ops/wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ed9a0cb8d7c0e0ec2748dd89c652756653cac78
--- /dev/null
+++ b/annotator/uniformer/mmseg/ops/wrappers.py
@@ -0,0 +1,50 @@
+import warnings
+
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def resize(input,
+ size=None,
+ scale_factor=None,
+ mode='nearest',
+ align_corners=None,
+ warning=True):
+ if warning:
+ if size is not None and align_corners:
+ input_h, input_w = tuple(int(x) for x in input.shape[2:])
+ output_h, output_w = tuple(int(x) for x in size)
+ if output_h > input_h or output_w > input_w:
+ if ((output_h > 1 and output_w > 1 and input_h > 1
+ and input_w > 1) and (output_h - 1) % (input_h - 1)
+ and (output_w - 1) % (input_w - 1)):
+ warnings.warn(
+ f'When align_corners={align_corners}, '
+ 'the output would be more aligned if '
+ f'input size {(input_h, input_w)} is `x+1` and '
+ f'out size {(output_h, output_w)} is `nx+1`')
+ return F.interpolate(input, size, scale_factor, mode, align_corners)
+
+
+class Upsample(nn.Module):
+
+ def __init__(self,
+ size=None,
+ scale_factor=None,
+ mode='nearest',
+ align_corners=None):
+ super(Upsample, self).__init__()
+ self.size = size
+ if isinstance(scale_factor, tuple):
+ self.scale_factor = tuple(float(factor) for factor in scale_factor)
+ else:
+ self.scale_factor = float(scale_factor) if scale_factor else None
+ self.mode = mode
+ self.align_corners = align_corners
+
+ def forward(self, x):
+ if not self.size:
+ size = [int(t * self.scale_factor) for t in x.shape[-2:]]
+ else:
+ size = self.size
+ return resize(x, size, None, self.mode, self.align_corners)
diff --git a/annotator/uniformer/mmseg/utils/__init__.py b/annotator/uniformer/mmseg/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac489e2dbbc0e6fa87f5088b4edcc20f8cadc1a6
--- /dev/null
+++ b/annotator/uniformer/mmseg/utils/__init__.py
@@ -0,0 +1,4 @@
+from .collect_env import collect_env
+from .logger import get_root_logger
+
+__all__ = ['get_root_logger', 'collect_env']
diff --git a/annotator/uniformer/mmseg/utils/collect_env.py b/annotator/uniformer/mmseg/utils/collect_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..65c2134ddbee9655161237dd0894d38c768c2624
--- /dev/null
+++ b/annotator/uniformer/mmseg/utils/collect_env.py
@@ -0,0 +1,17 @@
+from annotator.uniformer.mmcv.utils import collect_env as collect_base_env
+from annotator.uniformer.mmcv.utils import get_git_hash
+
+import annotator.uniformer.mmseg as mmseg
+
+
+def collect_env():
+ """Collect the information of the running environments."""
+ env_info = collect_base_env()
+ env_info['MMSegmentation'] = f'{mmseg.__version__}+{get_git_hash()[:7]}'
+
+ return env_info
+
+
+if __name__ == '__main__':
+ for name, val in collect_env().items():
+ print('{}: {}'.format(name, val))
diff --git a/annotator/uniformer/mmseg/utils/logger.py b/annotator/uniformer/mmseg/utils/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..4149d9eda3dfef07490352d22ac40c42460315e4
--- /dev/null
+++ b/annotator/uniformer/mmseg/utils/logger.py
@@ -0,0 +1,27 @@
+import logging
+
+from annotator.uniformer.mmcv.utils import get_logger
+
+
+def get_root_logger(log_file=None, log_level=logging.INFO):
+ """Get the root logger.
+
+ The logger will be initialized if it has not been initialized. By default a
+ StreamHandler will be added. If `log_file` is specified, a FileHandler will
+ also be added. The name of the root logger is the top-level package name,
+ e.g., "mmseg".
+
+ Args:
+ log_file (str | None): The log filename. If specified, a FileHandler
+ will be added to the root logger.
+ log_level (int): The root logger level. Note that only the process of
+ rank 0 is affected, while other processes will set the level to
+ "Error" and be silent most of the time.
+
+ Returns:
+ logging.Logger: The root logger.
+ """
+
+ logger = get_logger(name='mmseg', log_file=log_file, log_level=log_level)
+
+ return logger
diff --git a/annotator/util.py b/annotator/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..90831643d19cc1b9b0940df3d4fd4d846ba74a05
--- /dev/null
+++ b/annotator/util.py
@@ -0,0 +1,38 @@
+import numpy as np
+import cv2
+import os
+
+
+annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts')
+
+
+def HWC3(x):
+ assert x.dtype == np.uint8
+ if x.ndim == 2:
+ x = x[:, :, None]
+ assert x.ndim == 3
+ H, W, C = x.shape
+ assert C == 1 or C == 3 or C == 4
+ if C == 3:
+ return x
+ if C == 1:
+ return np.concatenate([x, x, x], axis=2)
+ if C == 4:
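+ # Composite RGBA onto a white background so callers always receive a
+ # 3-channel uint8 image.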
+ color = x[:, :, 0:3].astype(np.float32)
+ alpha = x[:, :, 3:4].astype(np.float32) / 255.0
+ y = color * alpha + 255.0 * (1.0 - alpha)
+ y = y.clip(0, 255).astype(np.uint8)
+ return y
+
+
+def resize_image(input_image, resolution):
+ H, W, C = input_image.shape
+ H = float(H)
+ W = float(W)
+ k = float(resolution) / min(H, W)
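+ # Scale the shorter side to `resolution`, then round both sides to the
+ # nearest multiple of 64 so they divide evenly through the model's
+ # downsampling stages.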
+ H *= k
+ W *= k
+ H = int(np.round(H / 64.0)) * 64
+ W = int(np.round(W / 64.0)) * 64
+ img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
+ return img
diff --git a/cldm/__pycache__/cldm.cpython-38.pyc b/cldm/__pycache__/cldm.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c9b5ba5876096c4095e1d75a18aafda94cbac922
Binary files /dev/null and b/cldm/__pycache__/cldm.cpython-38.pyc differ
diff --git a/cldm/__pycache__/ddim_hacked.cpython-38.pyc b/cldm/__pycache__/ddim_hacked.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8223207531735b0f5358e75db151b1008c637a83
Binary files /dev/null and b/cldm/__pycache__/ddim_hacked.cpython-38.pyc differ
diff --git a/cldm/__pycache__/hack.cpython-38.pyc b/cldm/__pycache__/hack.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e4ee0f5cd38fe089a75cc0402e0cf3ea026e1eca
Binary files /dev/null and b/cldm/__pycache__/hack.cpython-38.pyc differ
diff --git a/cldm/__pycache__/model.cpython-38.pyc b/cldm/__pycache__/model.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e359881618fa8dc1672b808a5944c31e7d10bb85
Binary files /dev/null and b/cldm/__pycache__/model.cpython-38.pyc differ
diff --git a/cldm/cldm.py b/cldm/cldm.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b3ac7a575cf4933fc14dfc15dd3cca41cb3f3e8
--- /dev/null
+++ b/cldm/cldm.py
@@ -0,0 +1,435 @@
+import einops
+import torch
+import torch as th
+import torch.nn as nn
+
+from ldm.modules.diffusionmodules.util import (
+ conv_nd,
+ linear,
+ zero_module,
+ timestep_embedding,
+)
+
+from einops import rearrange, repeat
+from torchvision.utils import make_grid
+from ldm.modules.attention import SpatialTransformer
+from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
+from ldm.models.diffusion.ddpm import LatentDiffusion
+from ldm.util import log_txt_as_img, exists, instantiate_from_config
+from ldm.models.diffusion.ddim import DDIMSampler
+
+
+class ControlledUnetModel(UNetModel):
+ def forward(self, x, timesteps=None, context=None, control=None, only_mid_control=False, **kwargs):
+ hs = []
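+ # The UNet input blocks and middle block run under torch.no_grad(), so no
+ # gradients flow through them here; ControlNet residuals are added to the
+ # middle-block output and to each skip connection used by the decoder.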
+ with torch.no_grad():
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+ emb = self.time_embed(t_emb)
+ h = x.type(self.dtype)
+ for module in self.input_blocks:
+ h = module(h, emb, context)
+ hs.append(h)
+ h = self.middle_block(h, emb, context)
+
+ if control is not None:
+ h += control.pop()
+
+ for i, module in enumerate(self.output_blocks):
+ if only_mid_control or control is None:
+ h = torch.cat([h, hs.pop()], dim=1)
+ else:
+ h = torch.cat([h, hs.pop() + control.pop()], dim=1)
+ h = module(h, emb, context)
+
+ h = h.type(x.dtype)
+ return self.out(h)
+
+
+class ControlNet(nn.Module):
+ def __init__(
+ self,
+ image_size,
+ in_channels,
+ model_channels,
+ hint_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ use_checkpoint=False,
+ use_fp16=False,
+ num_heads=-1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ use_new_attention_order=False,
+ use_spatial_transformer=False, # custom transformer support
+ transformer_depth=1, # custom transformer support
+ context_dim=None, # custom transformer support
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
+ legacy=True,
+ disable_self_attentions=None,
+ num_attention_blocks=None,
+ disable_middle_self_attn=False,
+ use_linear_in_transformer=False,
+ ):
+ super().__init__()
+ if use_spatial_transformer:
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
+
+ if context_dim is not None:
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
+ from omegaconf.listconfig import ListConfig
+ if type(context_dim) == ListConfig:
+ context_dim = list(context_dim)
+
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ if num_heads == -1:
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
+
+ if num_head_channels == -1:
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
+
+ self.dims = dims
+ self.image_size = image_size
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ if isinstance(num_res_blocks, int):
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
+ else:
+ if len(num_res_blocks) != len(channel_mult):
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
+ "as a list/tuple (per-level) with the same length as channel_mult")
+ self.num_res_blocks = num_res_blocks
+ if disable_self_attentions is not None:
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
+ assert len(disable_self_attentions) == len(channel_mult)
+ if num_attention_blocks is not None:
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
+ assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
+ print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
+ f"attention will still not be set.")
+
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.use_checkpoint = use_checkpoint
+ self.dtype = th.float16 if use_fp16 else th.float32
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+ self.predict_codebook_ids = n_embed is not None
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+
+ self.input_blocks = nn.ModuleList(
+ [
+ TimestepEmbedSequential(
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
+ )
+ ]
+ )
+ self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
+
+ self.input_hint_block = TimestepEmbedSequential(
+ conv_nd(dims, hint_channels, 16, 3, padding=1),
+ nn.SiLU(),
+ conv_nd(dims, 16, 16, 3, padding=1),
+ nn.SiLU(),
+ conv_nd(dims, 16, 32, 3, padding=1, stride=2),
+ nn.SiLU(),
+ conv_nd(dims, 32, 32, 3, padding=1),
+ nn.SiLU(),
+ conv_nd(dims, 32, 96, 3, padding=1, stride=2),
+ nn.SiLU(),
+ conv_nd(dims, 96, 96, 3, padding=1),
+ nn.SiLU(),
+ conv_nd(dims, 96, 256, 3, padding=1, stride=2),
+ nn.SiLU(),
+ zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
+ )
+
+ self._feature_size = model_channels
+ input_block_chans = [model_channels]
+ ch = model_channels
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for nr in range(self.num_res_blocks[level]):
+ layers = [
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=mult * model_channels,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = mult * model_channels
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ # num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ if exists(disable_self_attentions):
+ disabled_sa = disable_self_attentions[level]
+ else:
+ disabled_sa = False
+
+ if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else SpatialTransformer(
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
+ use_checkpoint=use_checkpoint
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self.zero_convs.append(self.make_zero_conv(ch))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ )
+ if resblock_updown
+ else Downsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch
+ )
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ self.zero_convs.append(self.make_zero_conv(ch))
+ ds *= 2
+ self._feature_size += ch
+
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ # num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ self.middle_block = TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+ disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
+ use_checkpoint=use_checkpoint
+ ),
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ )
+ self.middle_block_out = self.make_zero_conv(ch)
+ self._feature_size += ch
+
+ def make_zero_conv(self, channels):
+ return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))
+
+ def forward(self, x, hint, timesteps, context, **kwargs):
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+ emb = self.time_embed(t_emb)
+
+ guided_hint = self.input_hint_block(hint, emb, context)
+
+ outs = []
+
+ h = x.type(self.dtype)
+ for module, zero_conv in zip(self.input_blocks, self.zero_convs):
+ if guided_hint is not None:
+ h = module(h, emb, context)
+ h += guided_hint
+ guided_hint = None
+ else:
+ h = module(h, emb, context)
+ outs.append(zero_conv(h, emb, context))
+
+ h = self.middle_block(h, emb, context)
+ outs.append(self.middle_block_out(h, emb, context))
+
+ return outs
+
+
+class ControlLDM(LatentDiffusion):
+
+ def __init__(self, control_stage_config, control_key, only_mid_control, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.control_model = instantiate_from_config(control_stage_config)
+ self.control_key = control_key
+ self.only_mid_control = only_mid_control
+ self.control_scales = [1.0] * 13
+
+ @torch.no_grad()
+ def get_input(self, batch, k, bs=None, *args, **kwargs):
+ x, c = super().get_input(batch, self.first_stage_key, *args, **kwargs)
+ control = batch[self.control_key]
+ if bs is not None:
+ control = control[:bs]
+ control = control.to(self.device)
+ control = einops.rearrange(control, 'b h w c -> b c h w')
+ control = control.to(memory_format=torch.contiguous_format).float()
+ return x, dict(c_crossattn=[c], c_concat=[control])
+
+ def apply_model(self, x_noisy, t, cond, *args, **kwargs):
+ assert isinstance(cond, dict)
+ diffusion_model = self.model.diffusion_model
+
+ cond_txt = torch.cat(cond['c_crossattn'], 1)
+
+ if cond['c_concat'] is None:
+ eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=None, only_mid_control=self.only_mid_control)
+ else:
+ control = self.control_model(x=x_noisy, hint=torch.cat(cond['c_concat'], 1), timesteps=t, context=cond_txt)
+ control = [c * scale for c, scale in zip(control, self.control_scales)]
+ eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control)
+
+ return eps
+
+ @torch.no_grad()
+ def get_unconditional_conditioning(self, N):
+ return self.get_learned_conditioning([""] * N)
+
+ @torch.no_grad()
+ def log_images(self, batch, N=4, n_row=2, sample=False, ddim_steps=50, ddim_eta=0.0, return_keys=None,
+ quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
+ plot_diffusion_rows=False, unconditional_guidance_scale=9.0, unconditional_guidance_label=None,
+ use_ema_scope=True,
+ **kwargs):
+ use_ddim = ddim_steps is not None
+
+ log = dict()
+ z, c = self.get_input(batch, self.first_stage_key, bs=N)
+ c_cat, c = c["c_concat"][0][:N], c["c_crossattn"][0][:N]
+ N = min(z.shape[0], N)
+ n_row = min(z.shape[0], n_row)
+ log["reconstruction"] = self.decode_first_stage(z)
+ log["control"] = c_cat * 2.0 - 1.0
+ log["conditioning"] = log_txt_as_img((512, 512), batch[self.cond_stage_key], size=16)
+
+ if plot_diffusion_rows:
+ # get diffusion row
+ diffusion_row = list()
+ z_start = z[:n_row]
+ for t in range(self.num_timesteps):
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+ t = t.to(self.device).long()
+ noise = torch.randn_like(z_start)
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
+ diffusion_row.append(self.decode_first_stage(z_noisy))
+
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
+ diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
+ diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
+ log["diffusion_row"] = diffusion_grid
+
+ if sample:
+ # get denoise row
+ samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
+ batch_size=N, ddim=use_ddim,
+ ddim_steps=ddim_steps, eta=ddim_eta)
+ x_samples = self.decode_first_stage(samples)
+ log["samples"] = x_samples
+ if plot_denoise_rows:
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+ log["denoise_row"] = denoise_grid
+
+ if unconditional_guidance_scale > 1.0:
+ uc_cross = self.get_unconditional_conditioning(N)
+ uc_cat = c_cat # torch.zeros_like(c_cat)
+ uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
+ samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
+ batch_size=N, ddim=use_ddim,
+ ddim_steps=ddim_steps, eta=ddim_eta,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=uc_full,
+ )
+ x_samples_cfg = self.decode_first_stage(samples_cfg)
+ log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
+
+ return log
+
+ @torch.no_grad()
+ def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
+ ddim_sampler = DDIMSampler(self)
+ b, c, h, w = cond["c_concat"][0].shape
+ shape = (self.channels, h // 8, w // 8)
+ samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, shape, cond, verbose=False, **kwargs)
+ return samples, intermediates
+
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ params = list(self.control_model.parameters())
+ if not self.sd_locked:
+ params += list(self.model.diffusion_model.output_blocks.parameters())
+ params += list(self.model.diffusion_model.out.parameters())
+ opt = torch.optim.AdamW(params, lr=lr)
+ return opt
+
+ def low_vram_shift(self, is_diffusing):
+ if is_diffusing:
+ self.model = self.model.cuda()
+ self.control_model = self.control_model.cuda()
+ self.first_stage_model = self.first_stage_model.cpu()
+ self.cond_stage_model = self.cond_stage_model.cpu()
+ else:
+ self.model = self.model.cpu()
+ self.control_model = self.control_model.cpu()
+ self.first_stage_model = self.first_stage_model.cuda()
+ self.cond_stage_model = self.cond_stage_model.cuda()
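The two halves above meet in ControlLDM.apply_model: the ControlNet emits one residual per zero convolution (twelve from the encoder plus one from the middle block, hence the thirteen entries of control_scales), and ControlledUnetModel adds them to its skip connections. A minimal sketch of a single denoising call, using the loading helpers from cldm/model.py later in this patch; the YAML path, checkpoint name, and prompt are placeholders, not files shipped by this commit.

import torch

from cldm.model import create_model, load_state_dict  # defined later in this patch

model = create_model('./models/cldm_v15.yaml')  # placeholder config path
model.load_state_dict(load_state_dict('./models/control_sd15_canny.pth', location='cpu'))  # placeholder checkpoint
model = model.cuda().eval()

x_noisy = torch.randn(1, 4, 64, 64).cuda()          # latent for a 512x512 image
t = torch.full((1,), 500, dtype=torch.long).cuda()  # diffusion timestep
hint = torch.rand(1, 3, 512, 512).cuda()            # control image, b c h w in [0, 1]
txt = model.get_learned_conditioning(['a bird'])    # frozen CLIP text embedding

with torch.no_grad():
    cond = {'c_concat': [hint], 'c_crossattn': [txt]}
    eps = model.apply_model(x_noisy, t, cond)  # 13 control residuals, scaled by control_scales
print(eps.shape)  # same shape as x_noisy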
diff --git a/cldm/ddim_hacked.py b/cldm/ddim_hacked.py
new file mode 100644
index 0000000000000000000000000000000000000000..25b1bc947272ad14d7f7e5e4d1809005253b63d0
--- /dev/null
+++ b/cldm/ddim_hacked.py
@@ -0,0 +1,317 @@
+"""SAMPLING ONLY."""
+
+import torch
+import numpy as np
+from tqdm import tqdm
+
+from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
+
+
+class DDIMSampler(object):
+ def __init__(self, model, schedule="linear", **kwargs):
+ super().__init__()
+ self.model = model
+ self.ddpm_num_timesteps = model.num_timesteps
+ self.schedule = schedule
+
+ def register_buffer(self, name, attr):
+ if type(attr) == torch.Tensor:
+ if attr.device != torch.device("cuda"):
+ attr = attr.to(torch.device("cuda"))
+ setattr(self, name, attr)
+
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
+ alphas_cumprod = self.model.alphas_cumprod
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
+
+ self.register_buffer('betas', to_torch(self.model.betas))
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
+
+ # ddim sampling parameters
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
+ ddim_timesteps=self.ddim_timesteps,
+ eta=ddim_eta,verbose=verbose)
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
+ self.register_buffer('ddim_alphas', ddim_alphas)
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
+
+ @torch.no_grad()
+ def sample(self,
+ S,
+ batch_size,
+ shape,
+ conditioning=None,
+ callback=None,
+ normals_sequence=None,
+ img_callback=None,
+ quantize_x0=False,
+ eta=0.,
+ mask=None,
+ x0=None,
+ temperature=1.,
+ noise_dropout=0.,
+ score_corrector=None,
+ corrector_kwargs=None,
+ verbose=True,
+ x_T=None,
+ log_every_t=100,
+ unconditional_guidance_scale=1.,
+ unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+ dynamic_threshold=None,
+ ucg_schedule=None,
+ **kwargs
+ ):
+ if conditioning is not None:
+ if isinstance(conditioning, dict):
+ ctmp = conditioning[list(conditioning.keys())[0]]
+ while isinstance(ctmp, list): ctmp = ctmp[0]
+ cbs = ctmp.shape[0]
+ if cbs != batch_size:
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+
+ elif isinstance(conditioning, list):
+ for ctmp in conditioning:
+ if ctmp.shape[0] != batch_size:
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+
+ else:
+ if conditioning.shape[0] != batch_size:
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+ # sampling
+ C, H, W = shape
+ size = (batch_size, C, H, W)
+ print(f'Data shape for DDIM sampling is {size}, eta {eta}')
+
+ samples, intermediates = self.ddim_sampling(conditioning, size,
+ callback=callback,
+ img_callback=img_callback,
+ quantize_denoised=quantize_x0,
+ mask=mask, x0=x0,
+ ddim_use_original_steps=False,
+ noise_dropout=noise_dropout,
+ temperature=temperature,
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ x_T=x_T,
+ log_every_t=log_every_t,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ dynamic_threshold=dynamic_threshold,
+ ucg_schedule=ucg_schedule
+ )
+ return samples, intermediates
+
+ @torch.no_grad()
+ def ddim_sampling(self, cond, shape,
+ x_T=None, ddim_use_original_steps=False,
+ callback=None, timesteps=None, quantize_denoised=False,
+ mask=None, x0=None, img_callback=None, log_every_t=100,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
+ ucg_schedule=None):
+ device = self.model.betas.device
+ b = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=device)
+ else:
+ img = x_T
+
+ if timesteps is None:
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
+ elif timesteps is not None and not ddim_use_original_steps:
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
+ timesteps = self.ddim_timesteps[:subset_end]
+
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
+ time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+ iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
+
+ for i, step in enumerate(iterator):
+ index = total_steps - i - 1
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
+
+ if mask is not None:
+ assert x0 is not None
+ img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
+ img = img_orig * mask + (1. - mask) * img
+
+ if ucg_schedule is not None:
+ assert len(ucg_schedule) == len(time_range)
+ unconditional_guidance_scale = ucg_schedule[i]
+
+ outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
+ quantize_denoised=quantize_denoised, temperature=temperature,
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ dynamic_threshold=dynamic_threshold)
+ img, pred_x0 = outs
+ if callback: callback(i)
+ if img_callback: img_callback(pred_x0, i)
+
+ if index % log_every_t == 0 or index == total_steps - 1:
+ intermediates['x_inter'].append(img)
+ intermediates['pred_x0'].append(pred_x0)
+
+ return img, intermediates
+
+ @torch.no_grad()
+ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None,
+ dynamic_threshold=None):
+ b, *_, device = *x.shape, x.device
+
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+ model_output = self.model.apply_model(x, t, c)
+ else:
+ model_t = self.model.apply_model(x, t, c)
+ model_uncond = self.model.apply_model(x, t, unconditional_conditioning)
+ model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
+
+ if self.model.parameterization == "v":
+ e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
+ else:
+ e_t = model_output
+
+ if score_corrector is not None:
+ assert self.model.parameterization == "eps", 'not implemented'
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+ # select parameters corresponding to the currently considered timestep
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
+
+ # current prediction for x_0
+ if self.model.parameterization != "v":
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+ else:
+ pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
+
+ if quantize_denoised:
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+
+ if dynamic_threshold is not None:
+ raise NotImplementedError()
+
+ # direction pointing to x_t
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+ return x_prev, pred_x0
+
+ @torch.no_grad()
+ def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
+ unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
+ timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
+ num_reference_steps = timesteps.shape[0]
+
+ assert t_enc <= num_reference_steps
+ num_steps = t_enc
+
+ if use_original_steps:
+ alphas_next = self.alphas_cumprod[:num_steps]
+ alphas = self.alphas_cumprod_prev[:num_steps]
+ else:
+ alphas_next = self.ddim_alphas[:num_steps]
+ alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
+
+ x_next = x0
+ intermediates = []
+ inter_steps = []
+ for i in tqdm(range(num_steps), desc='Encoding Image'):
+ t = torch.full((x0.shape[0],), timesteps[i], device=self.model.device, dtype=torch.long)
+ if unconditional_guidance_scale == 1.:
+ noise_pred = self.model.apply_model(x_next, t, c)
+ else:
+ assert unconditional_conditioning is not None
+ e_t_uncond, noise_pred = torch.chunk(
+ self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
+ torch.cat((unconditional_conditioning, c))), 2)
+ noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
+
+ xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
+ weighted_noise_pred = alphas_next[i].sqrt() * (
+ (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
+ x_next = xt_weighted + weighted_noise_pred
+ if return_intermediates and i % (
+ num_steps // return_intermediates) == 0 and i < num_steps - 1:
+ intermediates.append(x_next)
+ inter_steps.append(i)
+ elif return_intermediates and i >= num_steps - 2:
+ intermediates.append(x_next)
+ inter_steps.append(i)
+ if callback: callback(i)
+
+ out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
+ if return_intermediates:
+ out.update({'intermediates': intermediates})
+ return x_next, out
+
+ @torch.no_grad()
+ def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
+ # fast, but does not allow for exact reconstruction
+ # t serves as an index to gather the correct alphas
+ if use_original_steps:
+ sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
+ sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
+ else:
+ sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
+ sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
+
+ if noise is None:
+ noise = torch.randn_like(x0)
+ return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
+
+ @torch.no_grad()
+ def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
+ use_original_steps=False, callback=None):
+
+ timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
+ timesteps = timesteps[:t_start]
+
+ time_range = np.flip(timesteps)
+ total_steps = timesteps.shape[0]
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+ iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
+ x_dec = x_latent
+ for i, step in enumerate(iterator):
+ index = total_steps - i - 1
+ ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
+ x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning)
+ if callback: callback(i)
+ return x_dec
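ControlLDM.sample_log above wraps this sampler; calling it directly follows the same pattern. A sketch reusing model, hint and txt from the sketch after cldm/cldm.py, with the step count, eta and guidance scale picked arbitrarily:

from cldm.ddim_hacked import DDIMSampler

ddim_sampler = DDIMSampler(model)
cond = {'c_concat': [hint], 'c_crossattn': [txt]}
un_cond = {'c_concat': [hint], 'c_crossattn': [model.get_unconditional_conditioning(1)]}
shape = (4, 64, 64)  # (channels, H // 8, W // 8) for a 512x512 output

samples, intermediates = ddim_sampler.sample(
    20, 1, shape, cond, verbose=False, eta=0.0,
    unconditional_guidance_scale=9.0,
    unconditional_conditioning=un_cond)
x_samples = model.decode_first_stage(samples)  # latents back to image space, roughly in [-1, 1]

Keeping the hint in un_cond (as log_images does) means classifier-free guidance only drops the text prompt, not the spatial condition.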
diff --git a/cldm/hack.py b/cldm/hack.py
new file mode 100644
index 0000000000000000000000000000000000000000..454361e9d036cd1a6a79122c2fd16b489e4767b1
--- /dev/null
+++ b/cldm/hack.py
@@ -0,0 +1,111 @@
+import torch
+import einops
+
+import ldm.modules.encoders.modules
+import ldm.modules.attention
+
+from transformers import logging
+from ldm.modules.attention import default
+
+
+def disable_verbosity():
+ logging.set_verbosity_error()
+ print('logging improved.')
+ return
+
+
+def enable_sliced_attention():
+ ldm.modules.attention.CrossAttention.forward = _hacked_sliced_attention_forward
+ print('Enabled sliced_attention.')
+ return
+
+
+def hack_everything(clip_skip=0):
+ disable_verbosity()
+ ldm.modules.encoders.modules.FrozenCLIPEmbedder.forward = _hacked_clip_forward
+ ldm.modules.encoders.modules.FrozenCLIPEmbedder.clip_skip = clip_skip
+ print('Enabled clip hacks.')
+ return
+
+
+# Written by Lvmin
+def _hacked_clip_forward(self, text):
+ PAD = self.tokenizer.pad_token_id
+ EOS = self.tokenizer.eos_token_id
+ BOS = self.tokenizer.bos_token_id
+
+ def tokenize(t):
+ return self.tokenizer(t, truncation=False, add_special_tokens=False)["input_ids"]
+
+ def transformer_encode(t):
+ if self.clip_skip > 1:
+ rt = self.transformer(input_ids=t, output_hidden_states=True)
+ return self.transformer.text_model.final_layer_norm(rt.hidden_states[-self.clip_skip])
+ else:
+ return self.transformer(input_ids=t, output_hidden_states=False).last_hidden_state
+
+ def split(x):
+ return x[75 * 0: 75 * 1], x[75 * 1: 75 * 2], x[75 * 2: 75 * 3]
+
+ def pad(x, p, i):
+ return x[:i] if len(x) >= i else x + [p] * (i - len(x))
+
+ raw_tokens_list = tokenize(text)
+ tokens_list = []
+
+ for raw_tokens in raw_tokens_list:
+ raw_tokens_123 = split(raw_tokens)
+ raw_tokens_123 = [[BOS] + raw_tokens_i + [EOS] for raw_tokens_i in raw_tokens_123]
+ raw_tokens_123 = [pad(raw_tokens_i, PAD, 77) for raw_tokens_i in raw_tokens_123]
+ tokens_list.append(raw_tokens_123)
+
+ tokens_list = torch.IntTensor(tokens_list).to(self.device)
+
+ feed = einops.rearrange(tokens_list, 'b f i -> (b f) i')
+ y = transformer_encode(feed)
+ z = einops.rearrange(y, '(b f) i c -> b (f i) c', f=3)
+
+ return z
+
+
+# Stolen from https://github.com/basujindal/stable-diffusion/blob/main/optimizedSD/splitAttention.py
+def _hacked_sliced_attention_forward(self, x, context=None, mask=None):
+ h = self.heads
+
+ q = self.to_q(x)
+ context = default(context, x)
+ k = self.to_k(context)
+ v = self.to_v(context)
+ del context, x
+
+ q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+ limit = k.shape[0]
+ att_step = 1
+ q_chunks = list(torch.tensor_split(q, limit // att_step, dim=0))
+ k_chunks = list(torch.tensor_split(k, limit // att_step, dim=0))
+ v_chunks = list(torch.tensor_split(v, limit // att_step, dim=0))
+
+ q_chunks.reverse()
+ k_chunks.reverse()
+ v_chunks.reverse()
+ sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
+ del k, q, v
+ for i in range(0, limit, att_step):
+ q_buffer = q_chunks.pop()
+ k_buffer = k_chunks.pop()
+ v_buffer = v_chunks.pop()
+ sim_buffer = torch.einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale
+
+ del k_buffer, q_buffer
+ # attention, what we cannot get enough of, by chunks
+
+ sim_buffer = sim_buffer.softmax(dim=-1)
+
+ sim_buffer = torch.einsum('b i j, b j d -> b i d', sim_buffer, v_buffer)
+ del v_buffer
+ sim[i:i + att_step, :, :] = sim_buffer
+
+ del sim_buffer
+ sim = einops.rearrange(sim, '(b h) n d -> b n (h d)', h=h)
+ return self.to_out(sim)
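The CLIP hack above lifts the 77-token limit by encoding the prompt as three 75-token windows (each wrapped with BOS/EOS and padded to 77) and concatenating the three embeddings into a single b x 231 x c context. All three functions are monkey patches on the ldm classes, so one call at start-up is enough, and already-constructed models pick up the new behaviour too; clip_skip=2 below is just an example value:

from cldm.hack import disable_verbosity, enable_sliced_attention, hack_everything

disable_verbosity()           # silence transformers warnings
enable_sliced_attention()     # chunked CrossAttention.forward, trades speed for peak VRAM
hack_everything(clip_skip=2)  # long-prompt CLIP forward; clip_skip=2 takes the penultimate hidden layer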
diff --git a/cldm/logger.py b/cldm/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a8803846f2a8979f87f3cf9ea5b12869439e62f
--- /dev/null
+++ b/cldm/logger.py
@@ -0,0 +1,76 @@
+import os
+
+import numpy as np
+import torch
+import torchvision
+from PIL import Image
+from pytorch_lightning.callbacks import Callback
+from pytorch_lightning.utilities.distributed import rank_zero_only
+
+
+class ImageLogger(Callback):
+ def __init__(self, batch_frequency=2000, max_images=4, clamp=True, increase_log_steps=True,
+ rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False,
+ log_images_kwargs=None):
+ super().__init__()
+ self.rescale = rescale
+ self.batch_freq = batch_frequency
+ self.max_images = max_images
+ if not increase_log_steps:
+ self.log_steps = [self.batch_freq]
+ self.clamp = clamp
+ self.disabled = disabled
+ self.log_on_batch_idx = log_on_batch_idx
+ self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
+ self.log_first_step = log_first_step
+
+ @rank_zero_only
+ def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx):
+ root = os.path.join(save_dir, "image_log", split)
+ for k in images:
+ grid = torchvision.utils.make_grid(images[k], nrow=4)
+ if self.rescale:
+ grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
+ grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
+ grid = grid.numpy()
+ grid = (grid * 255).astype(np.uint8)
+ filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx)
+ path = os.path.join(root, filename)
+ os.makedirs(os.path.split(path)[0], exist_ok=True)
+ Image.fromarray(grid).save(path)
+
+ def log_img(self, pl_module, batch, batch_idx, split="train"):
+ check_idx = batch_idx # if self.log_on_batch_idx else pl_module.global_step
+ if (self.check_frequency(check_idx) and # batch_idx % self.batch_freq == 0
+ hasattr(pl_module, "log_images") and
+ callable(pl_module.log_images) and
+ self.max_images > 0):
+ logger = type(pl_module.logger)
+
+ is_train = pl_module.training
+ if is_train:
+ pl_module.eval()
+
+ with torch.no_grad():
+ images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)
+
+ for k in images:
+ N = min(images[k].shape[0], self.max_images)
+ images[k] = images[k][:N]
+ if isinstance(images[k], torch.Tensor):
+ images[k] = images[k].detach().cpu()
+ if self.clamp:
+ images[k] = torch.clamp(images[k], -1., 1.)
+
+ self.log_local(pl_module.logger.save_dir, split, images,
+ pl_module.global_step, pl_module.current_epoch, batch_idx)
+
+ if is_train:
+ pl_module.train()
+
+ def check_frequency(self, check_idx):
+ return check_idx % self.batch_freq == 0
+
+ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+ if not self.disabled:
+ self.log_img(pl_module, batch, batch_idx, split="train")
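ImageLogger is an ordinary Lightning callback: passed to the Trainer, it calls the model's log_images (defined in ControlLDM above) every batch_frequency steps and writes the grids under <save_dir>/image_log/train. A hedged sketch of a training entry point; model is assumed to be a ControlLDM prepared for training, my_dataset is a placeholder Dataset yielding the image/text/hint keys the model's get_input expects, and the hyperparameters are illustrative:

import pytorch_lightning as pl
from torch.utils.data import DataLoader

from cldm.logger import ImageLogger

model.learning_rate = 1e-5
model.sd_locked = True          # train only the ControlNet; see ControlLDM.configure_optimizers
model.only_mid_control = False

dataloader = DataLoader(my_dataset, num_workers=0, batch_size=4, shuffle=True)  # my_dataset: placeholder
trainer = pl.Trainer(gpus=1, precision=32, callbacks=[ImageLogger(batch_frequency=300)])
trainer.fit(model, dataloader)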
diff --git a/cldm/model.py b/cldm/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..fed3c31ac145b78907c7f771d1d8db6fb32d92ed
--- /dev/null
+++ b/cldm/model.py
@@ -0,0 +1,28 @@
+import os
+import torch
+
+from omegaconf import OmegaConf
+from ldm.util import instantiate_from_config
+
+
+def get_state_dict(d):
+ return d.get('state_dict', d)
+
+
+def load_state_dict(ckpt_path, location='cpu'):
+ _, extension = os.path.splitext(ckpt_path)
+ if extension.lower() == ".safetensors":
+ import safetensors.torch
+ state_dict = safetensors.torch.load_file(ckpt_path, device=location)
+ else:
+ state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location)))
+ state_dict = get_state_dict(state_dict)
+ print(f'Loaded state_dict from [{ckpt_path}]')
+ return state_dict
+
+
+def create_model(config_path):
+ config = OmegaConf.load(config_path)
+ model = instantiate_from_config(config.model).cpu()
+ print(f'Loaded model config from [{config_path}]')
+ return model
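load_state_dict accepts either a raw state dict or a Lightning-style checkpoint that wraps one under a 'state_dict' key, and picks the loader purely by file extension. A small sketch of that behaviour; the checkpoint path is a placeholder:

import torch

from cldm.model import get_state_dict, load_state_dict

# get_state_dict is a no-op for plain state dicts and unwraps Lightning checkpoints.
plain = {'weight': torch.zeros(1)}
wrapped = {'state_dict': plain, 'epoch': 3}
assert get_state_dict(plain) is plain
assert get_state_dict(wrapped) is plain

# .safetensors files go through safetensors.torch.load_file, everything else through
# torch.load with the given map_location.
sd = load_state_dict('./models/control_sd15_canny.pth', location='cpu')  # placeholder path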
diff --git a/config.py b/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0c738d8cbad66bbe1666284aef926c326849701
--- /dev/null
+++ b/config.py
@@ -0,0 +1 @@
+save_memory = False
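save_memory is the single switch the inference scripts read; when it is on, the usual pattern is to shuttle sub-models between CPU and GPU with ControlLDM.low_vram_shift around each stage of sampling. A hedged sketch of that pattern, reusing model, hint and ddim_sampler from the earlier sketches:

import config

config.save_memory = True  # or edit config.py directly

if config.save_memory:
    model.low_vram_shift(is_diffusing=False)  # encoders on GPU, UNet + ControlNet on CPU
cond = {'c_concat': [hint], 'c_crossattn': [model.get_learned_conditioning(['a bird'])]}

if config.save_memory:
    model.low_vram_shift(is_diffusing=True)   # UNet + ControlNet on GPU for the DDIM loop
samples, _ = ddim_sampler.sample(20, 1, (4, 64, 64), cond, verbose=False, eta=0.0)

if config.save_memory:
    model.low_vram_shift(is_diffusing=False)  # swap back before decoding the latents
x_samples = model.decode_first_stage(samples)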
diff --git a/diffusers/__init__.py b/diffusers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..445e63b4d406f3a32f2ee04556a6613d2646893e
--- /dev/null
+++ b/diffusers/__init__.py
@@ -0,0 +1,289 @@
+__version__ = "0.19.0.dev0"
+
+from .configuration_utils import ConfigMixin
+from .utils import (
+ OptionalDependencyNotAvailable,
+ is_flax_available,
+ is_inflect_available,
+ is_invisible_watermark_available,
+ is_k_diffusion_available,
+ is_k_diffusion_version,
+ is_librosa_available,
+ is_note_seq_available,
+ is_onnx_available,
+ is_scipy_available,
+ is_torch_available,
+ is_torchsde_available,
+ is_transformers_available,
+ is_transformers_version,
+ is_unidecode_available,
+ logging,
+)
+
+
+try:
+ if not is_onnx_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils.dummy_onnx_objects import * # noqa F403
+else:
+ from .pipelines import OnnxRuntimeModel
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils.dummy_pt_objects import * # noqa F403
+else:
+ from .models import (
+ AutoencoderKL,
+ ControlNetModel,
+ ModelMixin,
+ MultiAdapter,
+ PriorTransformer,
+ T2IAdapter,
+ T5FilmDecoder,
+ Transformer2DModel,
+ UNet1DModel,
+ UNet2DConditionModel,
+ UNet2DModel,
+ UNet3DConditionModel,
+ VQModel,
+ )
+ from .optimization import (
+ get_constant_schedule,
+ get_constant_schedule_with_warmup,
+ get_cosine_schedule_with_warmup,
+ get_cosine_with_hard_restarts_schedule_with_warmup,
+ get_linear_schedule_with_warmup,
+ get_polynomial_decay_schedule_with_warmup,
+ get_scheduler,
+ )
+ from .pipelines import (
+ AudioPipelineOutput,
+ ConsistencyModelPipeline,
+ DanceDiffusionPipeline,
+ DDIMPipeline,
+ DDPMPipeline,
+ DiffusionPipeline,
+ DiTPipeline,
+ ImagePipelineOutput,
+ KarrasVePipeline,
+ LDMPipeline,
+ LDMSuperResolutionPipeline,
+ PNDMPipeline,
+ RePaintPipeline,
+ ScoreSdeVePipeline,
+ )
+ from .schedulers import (
+ CMStochasticIterativeScheduler,
+ DDIMInverseScheduler,
+ DDIMParallelScheduler,
+ DDIMScheduler,
+ DDPMParallelScheduler,
+ DDPMScheduler,
+ DEISMultistepScheduler,
+ DPMSolverMultistepInverseScheduler,
+ DPMSolverMultistepScheduler,
+ DPMSolverSinglestepScheduler,
+ EulerAncestralDiscreteScheduler,
+ EulerDiscreteScheduler,
+ HeunDiscreteScheduler,
+ IPNDMScheduler,
+ KarrasVeScheduler,
+ KDPM2AncestralDiscreteScheduler,
+ KDPM2DiscreteScheduler,
+ PNDMScheduler,
+ RePaintScheduler,
+ SchedulerMixin,
+ ScoreSdeVeScheduler,
+ UnCLIPScheduler,
+ UniPCMultistepScheduler,
+ VQDiffusionScheduler,
+ )
+ from .training_utils import EMAModel
+
+try:
+ if not (is_torch_available() and is_scipy_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils.dummy_torch_and_scipy_objects import * # noqa F403
+else:
+ from .schedulers import LMSDiscreteScheduler
+
+try:
+ if not (is_torch_available() and is_torchsde_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils.dummy_torch_and_torchsde_objects import * # noqa F403
+else:
+ from .schedulers import DPMSolverSDEScheduler
+
+try:
+ if not (is_torch_available() and is_transformers_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils.dummy_torch_and_transformers_objects import * # noqa F403
+else:
+ from .pipelines import (
+ AltDiffusionImg2ImgPipeline,
+ AltDiffusionPipeline,
+ AudioLDMPipeline,
+ CycleDiffusionPipeline,
+ IFImg2ImgPipeline,
+ IFImg2ImgSuperResolutionPipeline,
+ IFInpaintingPipeline,
+ IFInpaintingSuperResolutionPipeline,
+ IFPipeline,
+ IFSuperResolutionPipeline,
+ ImageTextPipelineOutput,
+ KandinskyImg2ImgPipeline,
+ KandinskyInpaintPipeline,
+ KandinskyPipeline,
+ KandinskyPriorPipeline,
+ KandinskyV22ControlnetImg2ImgPipeline,
+ KandinskyV22ControlnetPipeline,
+ KandinskyV22Img2ImgPipeline,
+ KandinskyV22InpaintPipeline,
+ KandinskyV22Pipeline,
+ KandinskyV22PriorEmb2EmbPipeline,
+ KandinskyV22PriorPipeline,
+ LDMTextToImagePipeline,
+ PaintByExamplePipeline,
+ SemanticStableDiffusionPipeline,
+ ShapEImg2ImgPipeline,
+ ShapEPipeline,
+ StableDiffusionAdapterPipeline,
+ StableDiffusionAttendAndExcitePipeline,
+ StableDiffusionControlNetImg2ImgPipeline,
+ StableDiffusionControlNetInpaintPipeline,
+ StableDiffusionControlNetPipeline,
+ StableDiffusionDepth2ImgPipeline,
+ StableDiffusionDiffEditPipeline,
+ StableDiffusionImageVariationPipeline,
+ StableDiffusionImg2ImgPipeline,
+ StableDiffusionInpaintPipeline,
+ StableDiffusionInpaintPipelineLegacy,
+ StableDiffusionInstructPix2PixPipeline,
+ StableDiffusionLatentUpscalePipeline,
+ StableDiffusionLDM3DPipeline,
+ StableDiffusionModelEditingPipeline,
+ StableDiffusionPanoramaPipeline,
+ StableDiffusionParadigmsPipeline,
+ StableDiffusionPipeline,
+ StableDiffusionPipelineSafe,
+ StableDiffusionPix2PixZeroPipeline,
+ StableDiffusionSAGPipeline,
+ StableDiffusionUpscalePipeline,
+ StableUnCLIPImg2ImgPipeline,
+ StableUnCLIPPipeline,
+ TextToVideoSDPipeline,
+ TextToVideoZeroPipeline,
+ UnCLIPImageVariationPipeline,
+ UnCLIPPipeline,
+ UniDiffuserModel,
+ UniDiffuserPipeline,
+ UniDiffuserTextDecoder,
+ VersatileDiffusionDualGuidedPipeline,
+ VersatileDiffusionImageVariationPipeline,
+ VersatileDiffusionPipeline,
+ VersatileDiffusionTextToImagePipeline,
+ VideoToVideoSDPipeline,
+ VQDiffusionPipeline,
+ )
+
+try:
+ if not (is_torch_available() and is_transformers_available() and is_invisible_watermark_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils.dummy_torch_and_transformers_and_invisible_watermark_objects import * # noqa F403
+else:
+ from .pipelines import (
+ StableDiffusionXLControlNetPipeline,
+ StableDiffusionXLImg2ImgPipeline,
+ StableDiffusionXLInpaintPipeline,
+ StableDiffusionXLPipeline,
+ )
+
+try:
+ if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403
+else:
+ from .pipelines import StableDiffusionKDiffusionPipeline
+
+try:
+ if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403
+else:
+ from .pipelines import (
+ OnnxStableDiffusionImg2ImgPipeline,
+ OnnxStableDiffusionInpaintPipeline,
+ OnnxStableDiffusionInpaintPipelineLegacy,
+ OnnxStableDiffusionPipeline,
+ OnnxStableDiffusionUpscalePipeline,
+ StableDiffusionOnnxPipeline,
+ )
+
+try:
+ if not (is_torch_available() and is_librosa_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils.dummy_torch_and_librosa_objects import * # noqa F403
+else:
+ from .pipelines import AudioDiffusionPipeline, Mel
+
+try:
+ if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403
+else:
+ from .pipelines import SpectrogramDiffusionPipeline
+
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils.dummy_flax_objects import * # noqa F403
+else:
+ from .models.controlnet_flax import FlaxControlNetModel
+ from .models.modeling_flax_utils import FlaxModelMixin
+ from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
+ from .models.vae_flax import FlaxAutoencoderKL
+ from .pipelines import FlaxDiffusionPipeline
+ from .schedulers import (
+ FlaxDDIMScheduler,
+ FlaxDDPMScheduler,
+ FlaxDPMSolverMultistepScheduler,
+ FlaxKarrasVeScheduler,
+ FlaxLMSDiscreteScheduler,
+ FlaxPNDMScheduler,
+ FlaxSchedulerMixin,
+ FlaxScoreSdeVeScheduler,
+ )
+
+
+try:
+ if not (is_flax_available() and is_transformers_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils.dummy_flax_and_transformers_objects import * # noqa F403
+else:
+ from .pipelines import (
+ FlaxStableDiffusionControlNetPipeline,
+ FlaxStableDiffusionImg2ImgPipeline,
+ FlaxStableDiffusionInpaintPipeline,
+ FlaxStableDiffusionPipeline,
+ )
+
+try:
+ if not (is_note_seq_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils.dummy_note_seq_objects import * # noqa F403
+else:
+ from .pipelines import MidiProcessor
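Every optional backend above is gated by the same guard: probe the dependency, and on failure import dummy objects that raise only when they are actually used, so `import diffusers` itself never fails. A stripped-down sketch of the pattern; foo, is_foo_available and FooPipeline are all made-up names, not diffusers internals:

def is_foo_available() -> bool:
    try:
        import foo  # noqa: F401  (hypothetical optional dependency)
        return True
    except ImportError:
        return False


class OptionalDependencyNotAvailable(ImportError):
    pass


class _DummyFooPipeline:
    """Stands in for FooPipeline and fails loudly only when someone tries to use it."""
    def __init__(self, *args, **kwargs):
        raise ImportError("FooPipeline requires `foo`; install it with `pip install foo`.")


try:
    if not is_foo_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    FooPipeline = _DummyFooPipeline  # cheap import-time fallback
else:
    from foo import FooPipeline      # hypothetical real implementation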
diff --git a/diffusers/__pycache__/__init__.cpython-310.pyc b/diffusers/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..17db3d7b74eb1088c064f01e3ad8d93976b5cafb
Binary files /dev/null and b/diffusers/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/__pycache__/__init__.cpython-38.pyc b/diffusers/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..27ee15d5a0c624e28fb039c9866b4c7b8fdba2ac
Binary files /dev/null and b/diffusers/__pycache__/__init__.cpython-38.pyc differ
diff --git a/diffusers/__pycache__/configuration_utils.cpython-310.pyc b/diffusers/__pycache__/configuration_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..05f83a659005e09d5efacdb9f6e30232e5727e1d
Binary files /dev/null and b/diffusers/__pycache__/configuration_utils.cpython-310.pyc differ
diff --git a/diffusers/__pycache__/configuration_utils.cpython-38.pyc b/diffusers/__pycache__/configuration_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2b24cab046e3e49de39f21b88ef15b8a96e09ef8
Binary files /dev/null and b/diffusers/__pycache__/configuration_utils.cpython-38.pyc differ
diff --git a/diffusers/__pycache__/image_processor.cpython-310.pyc b/diffusers/__pycache__/image_processor.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b4ce9c9c22a7e849d546cdeed1f7c94cb83d41a9
Binary files /dev/null and b/diffusers/__pycache__/image_processor.cpython-310.pyc differ
diff --git a/diffusers/__pycache__/image_processor.cpython-38.pyc b/diffusers/__pycache__/image_processor.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3598363605f0c1ec2bdba08dd1f14d173517cc2d
Binary files /dev/null and b/diffusers/__pycache__/image_processor.cpython-38.pyc differ
diff --git a/diffusers/__pycache__/loaders.cpython-310.pyc b/diffusers/__pycache__/loaders.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1318e3ce33889b7df3f51d97bd67ab7fc9e29c98
Binary files /dev/null and b/diffusers/__pycache__/loaders.cpython-310.pyc differ
diff --git a/diffusers/__pycache__/loaders.cpython-38.pyc b/diffusers/__pycache__/loaders.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..15cde8a510b2473cad8a4b01fe63c3a92b6e891d
Binary files /dev/null and b/diffusers/__pycache__/loaders.cpython-38.pyc differ
diff --git a/diffusers/__pycache__/optimization.cpython-310.pyc b/diffusers/__pycache__/optimization.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bb1d462d18b8d02a7b61f8ca1ca9141546a0c8bc
Binary files /dev/null and b/diffusers/__pycache__/optimization.cpython-310.pyc differ
diff --git a/diffusers/__pycache__/optimization.cpython-38.pyc b/diffusers/__pycache__/optimization.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..49b58142d3b7fe20008979656b90386f07f27aae
Binary files /dev/null and b/diffusers/__pycache__/optimization.cpython-38.pyc differ
diff --git a/diffusers/__pycache__/training_utils.cpython-310.pyc b/diffusers/__pycache__/training_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..74fe61cfa35f11e3848689a8c493b8d0a4bc750a
Binary files /dev/null and b/diffusers/__pycache__/training_utils.cpython-310.pyc differ
diff --git a/diffusers/__pycache__/training_utils.cpython-38.pyc b/diffusers/__pycache__/training_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..13347a7ba998f3c1dfbfe0b1c97beae04a1b4d7c
Binary files /dev/null and b/diffusers/__pycache__/training_utils.cpython-38.pyc differ
diff --git a/diffusers/commands/__init__.py b/diffusers/commands/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ad4af9199bbe297dbc6679fd9ecb46baa976053
--- /dev/null
+++ b/diffusers/commands/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from argparse import ArgumentParser
+
+
+class BaseDiffusersCLICommand(ABC):
+ @staticmethod
+ @abstractmethod
+ def register_subcommand(parser: ArgumentParser):
+ raise NotImplementedError()
+
+ @abstractmethod
+ def run(self):
+ raise NotImplementedError()
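A new CLI command is just a subclass of BaseDiffusersCLICommand that registers its own sub-parser and returns an instance from the parser's func default, exactly as the env and fp16_safetensors commands do in the following files. A hedged sketch of a hypothetical hello command:

from argparse import ArgumentParser, Namespace

from . import BaseDiffusersCLICommand


def hello_command_factory(args: Namespace):
    return HelloCommand(args.name)


class HelloCommand(BaseDiffusersCLICommand):
    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        hello_parser = parser.add_parser("hello")
        hello_parser.add_argument("--name", type=str, default="diffusers")
        hello_parser.set_defaults(func=hello_command_factory)

    def __init__(self, name: str):
        self.name = name

    def run(self):
        print(f"Hello, {self.name}!")

It would still need to be wired into main() in diffusers_cli.py (next file) with HelloCommand.register_subcommand(commands_parser).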
diff --git a/diffusers/commands/diffusers_cli.py b/diffusers/commands/diffusers_cli.py
new file mode 100644
index 0000000000000000000000000000000000000000..2016fc19f557fd539782ca2181ec2fe74026340a
--- /dev/null
+++ b/diffusers/commands/diffusers_cli.py
@@ -0,0 +1,43 @@
+#!/usr/bin/env python
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from argparse import ArgumentParser
+
+from .env import EnvironmentCommand
+from .fp16_safetensors import FP16SafetensorsCommand
+
+
+def main():
+ parser = ArgumentParser("Diffusers CLI tool", usage="diffusers-cli []")
+ commands_parser = parser.add_subparsers(help="diffusers-cli command helpers")
+
+ # Register commands
+ EnvironmentCommand.register_subcommand(commands_parser)
+ FP16SafetensorsCommand.register_subcommand(commands_parser)
+
+ # Let's go
+ args = parser.parse_args()
+
+ if not hasattr(args, "func"):
+ parser.print_help()
+ exit(1)
+
+ # Run
+ service = args.func(args)
+ service.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diffusers/commands/env.py b/diffusers/commands/env.py
new file mode 100644
index 0000000000000000000000000000000000000000..db9de720942b5efcff921d7e2503e3ae8813561e
--- /dev/null
+++ b/diffusers/commands/env.py
@@ -0,0 +1,84 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import platform
+from argparse import ArgumentParser
+
+import huggingface_hub
+
+from .. import __version__ as version
+from ..utils import is_accelerate_available, is_torch_available, is_transformers_available, is_xformers_available
+from . import BaseDiffusersCLICommand
+
+
+def info_command_factory(_):
+ return EnvironmentCommand()
+
+
+class EnvironmentCommand(BaseDiffusersCLICommand):
+ @staticmethod
+ def register_subcommand(parser: ArgumentParser):
+ download_parser = parser.add_parser("env")
+ download_parser.set_defaults(func=info_command_factory)
+
+ def run(self):
+ hub_version = huggingface_hub.__version__
+
+ pt_version = "not installed"
+ pt_cuda_available = "NA"
+ if is_torch_available():
+ import torch
+
+ pt_version = torch.__version__
+ pt_cuda_available = torch.cuda.is_available()
+
+ transformers_version = "not installed"
+ if is_transformers_available():
+ import transformers
+
+ transformers_version = transformers.__version__
+
+ accelerate_version = "not installed"
+ if is_accelerate_available():
+ import accelerate
+
+ accelerate_version = accelerate.__version__
+
+ xformers_version = "not installed"
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = xformers.__version__
+
+ info = {
+ "`diffusers` version": version,
+ "Platform": platform.platform(),
+ "Python version": platform.python_version(),
+ "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})",
+ "Huggingface_hub version": hub_version,
+ "Transformers version": transformers_version,
+ "Accelerate version": accelerate_version,
+ "xFormers version": xformers_version,
+ "Using GPU in script?": "",
+ "Using distributed or parallel set-up in script?": "",
+ }
+
+ print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n")
+ print(self.format_dict(info))
+
+ return info
+
+ @staticmethod
+ def format_dict(d):
+ return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n"
diff --git a/diffusers/commands/fp16_safetensors.py b/diffusers/commands/fp16_safetensors.py
new file mode 100644
index 0000000000000000000000000000000000000000..19553c752dce116d01f9816f90ddd3275d8cc302
--- /dev/null
+++ b/diffusers/commands/fp16_safetensors.py
@@ -0,0 +1,138 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Usage example:
+ diffusers-cli fp16_safetensors --ckpt_id=openai/shap-e --fp16 --use_safetensors
+"""
+
+import glob
+import json
+from argparse import ArgumentParser, Namespace
+from importlib import import_module
+
+import huggingface_hub
+import torch
+from huggingface_hub import hf_hub_download
+from packaging import version
+
+from ..utils import is_safetensors_available, logging
+from . import BaseDiffusersCLICommand
+
+
+def conversion_command_factory(args: Namespace):
+ return FP16SafetensorsCommand(
+ args.ckpt_id,
+ args.fp16,
+ args.use_safetensors,
+ args.use_auth_token,
+ )
+
+
+class FP16SafetensorsCommand(BaseDiffusersCLICommand):
+ @staticmethod
+ def register_subcommand(parser: ArgumentParser):
+ conversion_parser = parser.add_parser("fp16_safetensors")
+ conversion_parser.add_argument(
+ "--ckpt_id",
+ type=str,
+ help="Repo id of the checkpoints on which to run the conversion. Example: 'openai/shap-e'.",
+ )
+ conversion_parser.add_argument(
+ "--fp16", action="store_true", help="If serializing the variables in FP16 precision."
+ )
+ conversion_parser.add_argument(
+ "--use_safetensors", action="store_true", help="If serializing in the safetensors format."
+ )
+ conversion_parser.add_argument(
+ "--use_auth_token",
+ action="store_true",
+ help="When working with checkpoints having private visibility. When used `huggingface-cli login` needs to be run beforehand.",
+ )
+ conversion_parser.set_defaults(func=conversion_command_factory)
+
+ def __init__(self, ckpt_id: str, fp16: bool, use_safetensors: bool, use_auth_token: bool):
+ self.logger = logging.get_logger("diffusers-cli/fp16_safetensors")
+ self.ckpt_id = ckpt_id
+ self.local_ckpt_dir = f"/tmp/{ckpt_id}"
+ self.fp16 = fp16
+
+ if is_safetensors_available():
+ self.use_safetensors = use_safetensors
+ else:
+ raise ImportError(
+ "When `use_safetensors` is set to True, the `safetensors` library needs to be installed. Install it via `pip install safetensors`."
+ )
+
+ if not self.use_safetensors and not self.fp16:
+ raise NotImplementedError(
+ "When `use_safetensors` and `fp16` both are False, then this command is of no use."
+ )
+
+ self.use_auth_token = use_auth_token
+
+ def run(self):
+ if version.parse(huggingface_hub.__version__) < version.parse("0.9.0"):
+ raise ImportError(
+ "The huggingface_hub version must be >= 0.9.0 to use this command. Please update your huggingface_hub"
+ " installation."
+ )
+ else:
+ from huggingface_hub import create_commit
+ from huggingface_hub._commit_api import CommitOperationAdd
+
+ model_index = hf_hub_download(repo_id=self.ckpt_id, filename="model_index.json", token=self.use_auth_token)
+ with open(model_index, "r") as f:
+ pipeline_class_name = json.load(f)["_class_name"]
+ pipeline_class = getattr(import_module("diffusers"), pipeline_class_name)
+ self.logger.info(f"Pipeline class imported: {pipeline_class_name}.")
+
+ # Load the appropriate pipeline. We could have use `DiffusionPipeline`
+ # here, but just to avoid any rough edge cases.
+ pipeline = pipeline_class.from_pretrained(
+ self.ckpt_id, torch_dtype=torch.float16 if self.fp16 else torch.float32, use_auth_token=self.use_auth_token
+ )
+ pipeline.save_pretrained(
+ self.local_ckpt_dir,
+ safe_serialization=True if self.use_safetensors else False,
+ variant="fp16" if self.fp16 else None,
+ )
+ self.logger.info(f"Pipeline locally saved to {self.local_ckpt_dir}.")
+
+ # Fetch all the paths.
+ if self.fp16:
+ modified_paths = glob.glob(f"{self.local_ckpt_dir}/*/*.fp16.*")
+ elif self.use_safetensors:
+ modified_paths = glob.glob(f"{self.local_ckpt_dir}/*/*.safetensors")
+
+ # Prepare for the PR.
+ commit_message = f"Serialize variables with FP16: {self.fp16} and safetensors: {self.use_safetensors}."
+ operations = []
+ for path in modified_paths:
+ operations.append(CommitOperationAdd(path_in_repo="/".join(path.split("/")[4:]), path_or_fileobj=path))
+
+ # Open the PR.
+ commit_description = (
+ "Variables converted by the [`diffusers`' `fp16_safetensors`"
+ " CLI](https://github.com/huggingface/diffusers/blob/main/src/diffusers/commands/fp16_safetensors.py)."
+ )
+ hub_pr_url = create_commit(
+ repo_id=self.ckpt_id,
+ operations=operations,
+ commit_message=commit_message,
+ commit_description=commit_description,
+ repo_type="model",
+ create_pr=True,
+ ).pr_url
+ self.logger.info(f"PR created here: {hub_pr_url}.")
diff --git a/diffusers/configuration_utils.py b/diffusers/configuration_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5c8e8919c9fcd48de5a89e0664bd6c00643f515
--- /dev/null
+++ b/diffusers/configuration_utils.py
@@ -0,0 +1,664 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" ConfigMixin base class and utilities."""
+import dataclasses
+import functools
+import importlib
+import inspect
+import json
+import os
+import re
+from collections import OrderedDict
+from pathlib import PosixPath
+from typing import Any, Dict, Tuple, Union
+
+import numpy as np
+from huggingface_hub import hf_hub_download
+from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
+from requests import HTTPError
+
+from . import __version__
+from .utils import (
+ DIFFUSERS_CACHE,
+ HUGGINGFACE_CO_RESOLVE_ENDPOINT,
+ DummyObject,
+ deprecate,
+ extract_commit_hash,
+ http_user_agent,
+ logging,
+)
+
+
+logger = logging.get_logger(__name__)
+
+_re_configuration_file = re.compile(r"config\.(.*)\.json")
+
+
+class FrozenDict(OrderedDict):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ for key, value in self.items():
+ setattr(self, key, value)
+
+ self.__frozen = True
+
+ def __delitem__(self, *args, **kwargs):
+ raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
+
+ def setdefault(self, *args, **kwargs):
+ raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
+
+ def pop(self, *args, **kwargs):
+ raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
+
+ def update(self, *args, **kwargs):
+ raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
+
+ def __setattr__(self, name, value):
+ if hasattr(self, "__frozen") and self.__frozen:
+ raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
+ super().__setattr__(name, value)
+
+ def __setitem__(self, name, value):
+ if hasattr(self, "__frozen") and self.__frozen:
+ raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
+ super().__setitem__(name, value)
+
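+# A minimal behavior sketch (illustrative, not part of the original file): values registered in a
+# FrozenDict can be read as items or attributes, but entries cannot be removed or replaced.
+#
+#   cfg = FrozenDict({"beta_start": 0.0001})
+#   cfg["beta_start"]        # 0.0001, also readable as cfg.beta_start
+#   cfg.pop("beta_start")    # raises: entries cannot be removed
+#   cfg.update({"a": 1})     # raises as well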
+
+class ConfigMixin:
+ r"""
+ Base class for all configuration classes. All configuration parameters are stored under `self.config`. Also
+ provides the [`~ConfigMixin.from_config`] and [`~ConfigMixin.save_config`] methods for loading, downloading, and
+ saving classes that inherit from [`ConfigMixin`].
+
+ Class attributes:
+        - **config_name** (`str`) -- A filename under which the config should be stored when calling
+          [`~ConfigMixin.save_config`] (should be overridden by subclass).
+ - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
+ overridden by subclass).
+ - **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
+ - **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the `init` function
+ should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
+ subclass).
+ """
+ config_name = None
+ ignore_for_config = []
+ has_compatibles = False
+
+ _deprecated_kwargs = []
+
+ def register_to_config(self, **kwargs):
+ if self.config_name is None:
+ raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
+ # Special case for `kwargs` used in deprecation warning added to schedulers
+ # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
+ # or solve in a more general way.
+ kwargs.pop("kwargs", None)
+
+ if not hasattr(self, "_internal_dict"):
+ internal_dict = kwargs
+ else:
+ previous_dict = dict(self._internal_dict)
+ internal_dict = {**self._internal_dict, **kwargs}
+ logger.debug(f"Updating config from {previous_dict} to {internal_dict}")
+
+ self._internal_dict = FrozenDict(internal_dict)
+
+ def __getattr__(self, name: str) -> Any:
+ """The only reason we overwrite `getattr` here is to gracefully deprecate accessing
+ config attributes directly. See https://github.com/huggingface/diffusers/pull/3129
+
+        This function is mostly copied from PyTorch's __getattr__ overwrite:
+ https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
+ """
+
+ is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
+ is_attribute = name in self.__dict__
+
+ if is_in_config and not is_attribute:
+ deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'scheduler.config.{name}'."
+ deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)
+ return self._internal_dict[name]
+
+ raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
+
+ def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
+ """
+ Save a configuration object to the directory specified in `save_directory` so that it can be reloaded using the
+ [`~ConfigMixin.from_config`] class method.
+
+ Args:
+ save_directory (`str` or `os.PathLike`):
+ Directory where the configuration JSON file is saved (will be created if it does not exist).
+ """
+ if os.path.isfile(save_directory):
+ raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
+
+ os.makedirs(save_directory, exist_ok=True)
+
+ # If we save using the predefined names, we can load using `from_config`
+ output_config_file = os.path.join(save_directory, self.config_name)
+
+ self.to_json_file(output_config_file)
+ logger.info(f"Configuration saved in {output_config_file}")
+
+ @classmethod
+ def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
+ r"""
+ Instantiate a Python class from a config dictionary.
+
+ Parameters:
+ config (`Dict[str, Any]`):
+ A config dictionary from which the Python class is instantiated. Make sure to only load configuration
+ files of compatible classes.
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
+ Whether kwargs that are not consumed by the Python class should be returned or not.
+ kwargs (remaining dictionary of keyword arguments, *optional*):
+                Can be used to update the configuration object (after it is loaded) and initialize the Python class.
+ `**kwargs` are passed directly to the underlying scheduler/model's `__init__` method and eventually
+ overwrite the same named arguments in `config`.
+
+ Returns:
+ [`ModelMixin`] or [`SchedulerMixin`]:
+ A model or scheduler object instantiated from a config dictionary.
+
+ Examples:
+
+ ```python
+ >>> from diffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler
+
+ >>> # Download scheduler from huggingface.co and cache.
+ >>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32")
+
+ >>> # Instantiate DDIM scheduler class with same config as DDPM
+ >>> scheduler = DDIMScheduler.from_config(scheduler.config)
+
+ >>> # Instantiate PNDM scheduler class with same config as DDPM
+ >>> scheduler = PNDMScheduler.from_config(scheduler.config)
+ ```
+ """
+ # <===== TO BE REMOVED WITH DEPRECATION
+ # TODO(Patrick) - make sure to remove the following lines when config=="model_path" is deprecated
+ if "pretrained_model_name_or_path" in kwargs:
+ config = kwargs.pop("pretrained_model_name_or_path")
+
+ if config is None:
+ raise ValueError("Please make sure to provide a config as the first positional argument.")
+ # ======>
+
+ if not isinstance(config, dict):
+ deprecation_message = "It is deprecated to pass a pretrained model name or path to `from_config`."
+ if "Scheduler" in cls.__name__:
+ deprecation_message += (
+ f"If you were trying to load a scheduler, please use {cls}.from_pretrained(...) instead."
+ " Otherwise, please make sure to pass a configuration dictionary instead. This functionality will"
+ " be removed in v1.0.0."
+ )
+ elif "Model" in cls.__name__:
+ deprecation_message += (
+ f"If you were trying to load a model, please use {cls}.load_config(...) followed by"
+ f" {cls}.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary"
+ " instead. This functionality will be removed in v1.0.0."
+ )
+ deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)
+ config, kwargs = cls.load_config(pretrained_model_name_or_path=config, return_unused_kwargs=True, **kwargs)
+
+ init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs)
+
+ # Allow dtype to be specified on initialization
+ if "dtype" in unused_kwargs:
+ init_dict["dtype"] = unused_kwargs.pop("dtype")
+
+ # add possible deprecated kwargs
+ for deprecated_kwarg in cls._deprecated_kwargs:
+ if deprecated_kwarg in unused_kwargs:
+ init_dict[deprecated_kwarg] = unused_kwargs.pop(deprecated_kwarg)
+
+ # Return model and optionally state and/or unused_kwargs
+ model = cls(**init_dict)
+
+ # make sure to also save config parameters that might be used for compatible classes
+ model.register_to_config(**hidden_dict)
+
+ # add hidden kwargs of compatible classes to unused_kwargs
+ unused_kwargs = {**unused_kwargs, **hidden_dict}
+
+ if return_unused_kwargs:
+ return (model, unused_kwargs)
+ else:
+ return model
+
+ @classmethod
+ def get_config_dict(cls, *args, **kwargs):
+ deprecation_message = (
+ f" The function get_config_dict is deprecated. Please use {cls}.load_config instead. This function will be"
+ " removed in version v1.0.0"
+ )
+ deprecate("get_config_dict", "1.0.0", deprecation_message, standard_warn=False)
+ return cls.load_config(*args, **kwargs)
+
+ @classmethod
+ def load_config(
+ cls,
+ pretrained_model_name_or_path: Union[str, os.PathLike],
+ return_unused_kwargs=False,
+ return_commit_hash=False,
+ **kwargs,
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+ r"""
+ Load a model or scheduler configuration.
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+ Can be either:
+
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing model weights saved with
+ [`~ConfigMixin.save_config`].
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
+ incompletely downloaded files are deleted.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ output_loading_info(`bool`, *optional*, defaults to `False`):
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
+                Whether unused keyword arguments of the config are returned.
+            return_commit_hash (`bool`, *optional*, defaults to `False`):
+                Whether the `commit_hash` of the loaded configuration is returned.
+
+ Returns:
+ `dict`:
+ A dictionary of all the parameters stored in a JSON configuration file.
+
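+        Examples:
+
+        ```python
+        >>> from diffusers import DDPMScheduler
+
+        >>> # Illustrative sketch: reuses the repo from the `from_config` example and assumes it
+        >>> # exposes the scheduler configuration file.
+        >>> config = DDPMScheduler.load_config("google/ddpm-cifar10-32")
+        ```
+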
+ """
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+ force_download = kwargs.pop("force_download", False)
+ resume_download = kwargs.pop("resume_download", False)
+ proxies = kwargs.pop("proxies", None)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ local_files_only = kwargs.pop("local_files_only", False)
+ revision = kwargs.pop("revision", None)
+ _ = kwargs.pop("mirror", None)
+ subfolder = kwargs.pop("subfolder", None)
+ user_agent = kwargs.pop("user_agent", {})
+
+ user_agent = {**user_agent, "file_type": "config"}
+ user_agent = http_user_agent(user_agent)
+
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+
+ if cls.config_name is None:
+ raise ValueError(
+ "`self.config_name` is not defined. Note that one should not load a config from "
+ "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
+ )
+
+ if os.path.isfile(pretrained_model_name_or_path):
+ config_file = pretrained_model_name_or_path
+ elif os.path.isdir(pretrained_model_name_or_path):
+ if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
+                # Load the config from a local directory
+ config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
+ elif subfolder is not None and os.path.isfile(
+ os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
+ ):
+ config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
+ else:
+ raise EnvironmentError(
+ f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
+ )
+ else:
+ try:
+ # Load from URL or cache if already cached
+ config_file = hf_hub_download(
+ pretrained_model_name_or_path,
+ filename=cls.config_name,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ user_agent=user_agent,
+ subfolder=subfolder,
+ revision=revision,
+ )
+ except RepositoryNotFoundError:
+ raise EnvironmentError(
+ f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
+ " listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
+ " token having permission to this repo with `use_auth_token` or log in with `huggingface-cli"
+ " login`."
+ )
+ except RevisionNotFoundError:
+ raise EnvironmentError(
+ f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
+ " this model name. Check the model page at"
+ f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
+ )
+ except EntryNotFoundError:
+ raise EnvironmentError(
+ f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
+ )
+ except HTTPError as err:
+ raise EnvironmentError(
+ "There was a specific connection error when trying to load"
+ f" {pretrained_model_name_or_path}:\n{err}"
+ )
+ except ValueError:
+ raise EnvironmentError(
+ f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
+ f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
+ f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to"
+ " run the library in offline mode at"
+ " 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
+ )
+ except EnvironmentError:
+ raise EnvironmentError(
+ f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
+ "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
+ f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
+ f"containing a {cls.config_name} file"
+ )
+
+ try:
+ # Load config dict
+ config_dict = cls._dict_from_json_file(config_file)
+
+ commit_hash = extract_commit_hash(config_file)
+ except (json.JSONDecodeError, UnicodeDecodeError):
+ raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
+
+ if not (return_unused_kwargs or return_commit_hash):
+ return config_dict
+
+ outputs = (config_dict,)
+
+ if return_unused_kwargs:
+ outputs += (kwargs,)
+
+ if return_commit_hash:
+ outputs += (commit_hash,)
+
+ return outputs
+
+ @staticmethod
+ def _get_init_keys(cls):
+ return set(dict(inspect.signature(cls.__init__).parameters).keys())
+
+ @classmethod
+ def extract_init_dict(cls, config_dict, **kwargs):
+ # Skip keys that were not present in the original config, so default __init__ values were used
+ used_defaults = config_dict.get("_use_default_values", [])
+ config_dict = {k: v for k, v in config_dict.items() if k not in used_defaults and k != "_use_default_values"}
+
+ # 0. Copy origin config dict
+ original_dict = dict(config_dict.items())
+
+ # 1. Retrieve expected config attributes from __init__ signature
+ expected_keys = cls._get_init_keys(cls)
+ expected_keys.remove("self")
+ # remove general kwargs if present in dict
+ if "kwargs" in expected_keys:
+ expected_keys.remove("kwargs")
+ # remove flax internal keys
+ if hasattr(cls, "_flax_internal_args"):
+ for arg in cls._flax_internal_args:
+ expected_keys.remove(arg)
+
+ # 2. Remove attributes that cannot be expected from expected config attributes
+ # remove keys to be ignored
+ if len(cls.ignore_for_config) > 0:
+ expected_keys = expected_keys - set(cls.ignore_for_config)
+
+ # load diffusers library to import compatible and original scheduler
+ diffusers_library = importlib.import_module(__name__.split(".")[0])
+
+ if cls.has_compatibles:
+ compatible_classes = [c for c in cls._get_compatibles() if not isinstance(c, DummyObject)]
+ else:
+ compatible_classes = []
+
+ expected_keys_comp_cls = set()
+ for c in compatible_classes:
+ expected_keys_c = cls._get_init_keys(c)
+ expected_keys_comp_cls = expected_keys_comp_cls.union(expected_keys_c)
+ expected_keys_comp_cls = expected_keys_comp_cls - cls._get_init_keys(cls)
+ config_dict = {k: v for k, v in config_dict.items() if k not in expected_keys_comp_cls}
+
+ # remove attributes from orig class that cannot be expected
+ orig_cls_name = config_dict.pop("_class_name", cls.__name__)
+ if orig_cls_name != cls.__name__ and hasattr(diffusers_library, orig_cls_name):
+ orig_cls = getattr(diffusers_library, orig_cls_name)
+ unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys
+ config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig}
+
+ # remove private attributes
+ config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}
+
+ # 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments
+ init_dict = {}
+ for key in expected_keys:
+ # if config param is passed to kwarg and is present in config dict
+ # it should overwrite existing config dict key
+ if key in kwargs and key in config_dict:
+ config_dict[key] = kwargs.pop(key)
+
+ if key in kwargs:
+ # overwrite key
+ init_dict[key] = kwargs.pop(key)
+ elif key in config_dict:
+ # use value from config dict
+ init_dict[key] = config_dict.pop(key)
+
+ # 4. Give nice warning if unexpected values have been passed
+ if len(config_dict) > 0:
+ logger.warning(
+ f"The config attributes {config_dict} were passed to {cls.__name__}, "
+ "but are not expected and will be ignored. Please verify your "
+ f"{cls.config_name} configuration file."
+ )
+
+        # 5. Give nice info if config attributes are initialized to default because they have not been passed
+ passed_keys = set(init_dict.keys())
+ if len(expected_keys - passed_keys) > 0:
+ logger.info(
+ f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
+ )
+
+ # 6. Define unused keyword arguments
+ unused_kwargs = {**config_dict, **kwargs}
+
+ # 7. Define "hidden" config parameters that were saved for compatible classes
+ hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict}
+
+ return init_dict, unused_kwargs, hidden_config_dict
+
+ @classmethod
+ def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
+ with open(json_file, "r", encoding="utf-8") as reader:
+ text = reader.read()
+ return json.loads(text)
+
+ def __repr__(self):
+ return f"{self.__class__.__name__} {self.to_json_string()}"
+
+ @property
+ def config(self) -> Dict[str, Any]:
+ """
+ Returns the config of the class as a frozen dictionary
+
+ Returns:
+ `Dict[str, Any]`: Config of the class.
+ """
+ return self._internal_dict
+
+ def to_json_string(self) -> str:
+ """
+ Serializes the configuration instance to a JSON string.
+
+ Returns:
+ `str`:
+ String containing all the attributes that make up the configuration instance in JSON format.
+ """
+ config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
+ config_dict["_class_name"] = self.__class__.__name__
+ config_dict["_diffusers_version"] = __version__
+
+ def to_json_saveable(value):
+ if isinstance(value, np.ndarray):
+ value = value.tolist()
+ elif isinstance(value, PosixPath):
+ value = str(value)
+ return value
+
+ config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
+ # Don't save "_ignore_files" or "_use_default_values"
+ config_dict.pop("_ignore_files", None)
+ config_dict.pop("_use_default_values", None)
+
+ return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
+
+ def to_json_file(self, json_file_path: Union[str, os.PathLike]):
+ """
+ Save the configuration instance's parameters to a JSON file.
+
+ Args:
+ json_file_path (`str` or `os.PathLike`):
+ Path to the JSON file to save a configuration instance's parameters.
+ """
+ with open(json_file_path, "w", encoding="utf-8") as writer:
+ writer.write(self.to_json_string())
+
+
+def register_to_config(init):
+ r"""
+    Decorator to apply to the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
+    automatically sent to `self.register_to_config`. To ignore a specific argument accepted by the init but that
+    shouldn't be registered in the config, use the `ignore_for_config` class variable.
+
+    Warning: Once decorated, all private arguments (beginning with an underscore) are discarded and not passed to the init!
+ """
+
+ @functools.wraps(init)
+ def inner_init(self, *args, **kwargs):
+ # Ignore private kwargs in the init.
+ init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
+ config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")}
+ if not isinstance(self, ConfigMixin):
+ raise RuntimeError(
+ f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
+ "not inherit from `ConfigMixin`."
+ )
+
+ ignore = getattr(self, "ignore_for_config", [])
+ # Get positional arguments aligned with kwargs
+ new_kwargs = {}
+ signature = inspect.signature(init)
+ parameters = {
+ name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore
+ }
+ for arg, name in zip(args, parameters.keys()):
+ new_kwargs[name] = arg
+
+ # Then add all kwargs
+ new_kwargs.update(
+ {
+ k: init_kwargs.get(k, default)
+ for k, default in parameters.items()
+ if k not in ignore and k not in new_kwargs
+ }
+ )
+
+ # Take note of the parameters that were not present in the loaded config
+ if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
+ new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs))
+
+ new_kwargs = {**config_init_kwargs, **new_kwargs}
+ getattr(self, "register_to_config")(**new_kwargs)
+ init(self, *args, **init_kwargs)
+
+ return inner_init
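+
+
+# A minimal usage sketch (illustrative, not part of the original file): a ConfigMixin subclass
+# decorates its __init__ so that the arguments land in `self.config`. The class name, config
+# filename and arguments below are made up for illustration.
+#
+#   class MyScheduler(ConfigMixin):
+#       config_name = "scheduler_config.json"
+#
+#       @register_to_config
+#       def __init__(self, num_train_timesteps: int = 1000, beta_start: float = 0.0001):
+#           pass
+#
+#   MyScheduler(num_train_timesteps=500).config.num_train_timesteps  # -> 500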
+
+
+def flax_register_to_config(cls):
+ original_init = cls.__init__
+
+ @functools.wraps(original_init)
+ def init(self, *args, **kwargs):
+ if not isinstance(self, ConfigMixin):
+ raise RuntimeError(
+ f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
+ "not inherit from `ConfigMixin`."
+ )
+
+ # Ignore private kwargs in the init. Retrieve all passed attributes
+ init_kwargs = dict(kwargs.items())
+
+ # Retrieve default values
+ fields = dataclasses.fields(self)
+ default_kwargs = {}
+ for field in fields:
+ # ignore flax specific attributes
+ if field.name in self._flax_internal_args:
+ continue
+            if field.default is dataclasses.MISSING:
+ default_kwargs[field.name] = None
+ else:
+ default_kwargs[field.name] = getattr(self, field.name)
+
+ # Make sure init_kwargs override default kwargs
+ new_kwargs = {**default_kwargs, **init_kwargs}
+ # dtype should be part of `init_kwargs`, but not `new_kwargs`
+ if "dtype" in new_kwargs:
+ new_kwargs.pop("dtype")
+
+ # Get positional arguments aligned with kwargs
+ for i, arg in enumerate(args):
+ name = fields[i].name
+ new_kwargs[name] = arg
+
+ # Take note of the parameters that were not present in the loaded config
+ if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
+ new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs))
+
+ getattr(self, "register_to_config")(**new_kwargs)
+ original_init(self, *args, **kwargs)
+
+ cls.__init__ = init
+ return cls
diff --git a/diffusers/dependency_versions_check.py b/diffusers/dependency_versions_check.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f8578c52957bf6c06decb0d97d3139437f0078f
--- /dev/null
+++ b/diffusers/dependency_versions_check.py
@@ -0,0 +1,47 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+
+from .dependency_versions_table import deps
+from .utils.versions import require_version, require_version_core
+
+
+# define which module versions we always want to check at run time
+# (usually the ones defined in `install_requires` in setup.py)
+#
+# order specific notes:
+# - tqdm must be checked before tokenizers
+
+pkgs_to_check_at_runtime = "python tqdm regex requests packaging filelock numpy tokenizers".split()
+if sys.version_info < (3, 7):
+ pkgs_to_check_at_runtime.append("dataclasses")
+if sys.version_info < (3, 8):
+ pkgs_to_check_at_runtime.append("importlib_metadata")
+
+for pkg in pkgs_to_check_at_runtime:
+ if pkg in deps:
+ if pkg == "tokenizers":
+ # must be loaded here, or else tqdm check may fail
+ from .utils import is_tokenizers_available
+
+ if not is_tokenizers_available():
+ continue # not required, check version only if installed
+
+ require_version_core(deps[pkg])
+ else:
+ raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py")
+
+
+def dep_version_check(pkg, hint=None):
+ require_version(deps[pkg], hint)
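+
+
+# Usage sketch (illustrative): pins come from `dependency_versions_table.py`, so e.g.
+# `dep_version_check("torch")` raises if the installed torch does not satisfy "torch>=1.4".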
diff --git a/diffusers/dependency_versions_table.py b/diffusers/dependency_versions_table.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e6cb9e816ccf9ac5f558aeae4e7ce75a3c6feeb
--- /dev/null
+++ b/diffusers/dependency_versions_table.py
@@ -0,0 +1,44 @@
+# THIS FILE HAS BEEN AUTOGENERATED. To update:
+# 1. modify the `_deps` dict in setup.py
+# 2. run `make deps_table_update`
+deps = {
+ "Pillow": "Pillow",
+ "accelerate": "accelerate>=0.11.0",
+ "compel": "compel==0.1.8",
+ "black": "black~=23.1",
+ "datasets": "datasets",
+ "filelock": "filelock",
+ "flax": "flax>=0.4.1",
+ "hf-doc-builder": "hf-doc-builder>=0.3.0",
+ "huggingface-hub": "huggingface-hub>=0.13.2",
+ "requests-mock": "requests-mock==1.10.0",
+ "importlib_metadata": "importlib_metadata",
+ "invisible-watermark": "invisible-watermark>=0.2.0",
+ "isort": "isort>=5.5.4",
+ "jax": "jax>=0.2.8,!=0.3.2",
+ "jaxlib": "jaxlib>=0.1.65",
+ "Jinja2": "Jinja2",
+ "k-diffusion": "k-diffusion>=0.0.12",
+ "torchsde": "torchsde",
+ "note_seq": "note_seq",
+ "librosa": "librosa",
+ "numpy": "numpy",
+ "omegaconf": "omegaconf",
+ "parameterized": "parameterized",
+ "protobuf": "protobuf>=3.20.3,<4",
+ "pytest": "pytest",
+ "pytest-timeout": "pytest-timeout",
+ "pytest-xdist": "pytest-xdist",
+ "ruff": "ruff>=0.0.241",
+ "safetensors": "safetensors",
+ "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
+ "scipy": "scipy",
+ "onnx": "onnx",
+ "regex": "regex!=2019.12.17",
+ "requests": "requests",
+ "tensorboard": "tensorboard",
+ "torch": "torch>=1.4",
+ "torchvision": "torchvision",
+ "transformers": "transformers>=4.25.1",
+ "urllib3": "urllib3<=2.0.0",
+}
diff --git a/diffusers/experimental/README.md b/diffusers/experimental/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..81a9de81c73728ea41eb6e8617a5429c3c9645ff
--- /dev/null
+++ b/diffusers/experimental/README.md
@@ -0,0 +1,5 @@
+# 🧨 Diffusers Experimental
+
+We are adding experimental code to support novel applications and usages of the Diffusers library.
+Currently, the following experiments are supported:
+* Reinforcement learning via an implementation of the [Diffuser](https://arxiv.org/abs/2205.09991) model.
\ No newline at end of file
diff --git a/diffusers/experimental/__init__.py b/diffusers/experimental/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebc8155403016dfd8ad7fb78d246f9da9098ac50
--- /dev/null
+++ b/diffusers/experimental/__init__.py
@@ -0,0 +1 @@
+from .rl import ValueGuidedRLPipeline
diff --git a/diffusers/experimental/rl/__init__.py b/diffusers/experimental/rl/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b338d3173e12d478b6b6d6fd0e50650a0ab5a4c
--- /dev/null
+++ b/diffusers/experimental/rl/__init__.py
@@ -0,0 +1 @@
+from .value_guided_sampling import ValueGuidedRLPipeline
diff --git a/diffusers/experimental/rl/value_guided_sampling.py b/diffusers/experimental/rl/value_guided_sampling.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4af4986faad9c1e81a5cf4ee76138f3db00ab44
--- /dev/null
+++ b/diffusers/experimental/rl/value_guided_sampling.py
@@ -0,0 +1,152 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import torch
+import tqdm
+
+from ...models.unet_1d import UNet1DModel
+from ...pipelines import DiffusionPipeline
+from ...utils import randn_tensor
+from ...utils.dummy_pt_objects import DDPMScheduler
+
+
+class ValueGuidedRLPipeline(DiffusionPipeline):
+ r"""
+    Pipeline for sampling actions from a diffusion model trained to predict sequences of states.
+
+    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+    library implements for all pipelines (such as downloading or saving, running on a particular device, etc.).
+
+ Original implementation inspired by this repository: https://github.com/jannerm/diffuser.
+
+ Parameters:
+        value_function ([`UNet1DModel`]): A specialized UNet for fine-tuning trajectories based on reward.
+ unet ([`UNet1DModel`]): U-Net architecture to denoise the encoded trajectories.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded trajectories. Default for this
+ application is [`DDPMScheduler`].
+ env: An environment following the OpenAI gym API to act in. For now only Hopper has pretrained models.
+ """
+
+ def __init__(
+ self,
+ value_function: UNet1DModel,
+ unet: UNet1DModel,
+ scheduler: DDPMScheduler,
+ env,
+ ):
+ super().__init__()
+ self.value_function = value_function
+ self.unet = unet
+ self.scheduler = scheduler
+ self.env = env
+ self.data = env.get_dataset()
+ self.means = {}
+ for key in self.data.keys():
+ try:
+ self.means[key] = self.data[key].mean()
+ except: # noqa: E722
+ pass
+ self.stds = {}
+ for key in self.data.keys():
+ try:
+ self.stds[key] = self.data[key].std()
+ except: # noqa: E722
+ pass
+ self.state_dim = env.observation_space.shape[0]
+ self.action_dim = env.action_space.shape[0]
+
+ def normalize(self, x_in, key):
+ return (x_in - self.means[key]) / self.stds[key]
+
+ def de_normalize(self, x_in, key):
+ return x_in * self.stds[key] + self.means[key]
+
+ def to_torch(self, x_in):
+        if isinstance(x_in, dict):
+ return {k: self.to_torch(v) for k, v in x_in.items()}
+ elif torch.is_tensor(x_in):
+ return x_in.to(self.unet.device)
+ return torch.tensor(x_in, device=self.unet.device)
+
+ def reset_x0(self, x_in, cond, act_dim):
+ for key, val in cond.items():
+ x_in[:, key, act_dim:] = val.clone()
+ return x_in
+
+ def run_diffusion(self, x, conditions, n_guide_steps, scale):
+ batch_size = x.shape[0]
+ y = None
+ for i in tqdm.tqdm(self.scheduler.timesteps):
+ # create batch of timesteps to pass into model
+ timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long)
+ for _ in range(n_guide_steps):
+ with torch.enable_grad():
+ x.requires_grad_()
+
+ # permute to match dimension for pre-trained models
+ y = self.value_function(x.permute(0, 2, 1), timesteps).sample
+ grad = torch.autograd.grad([y.sum()], [x])[0]
+
+ posterior_variance = self.scheduler._get_variance(i)
+ model_std = torch.exp(0.5 * posterior_variance)
+ grad = model_std * grad
+
+ grad[timesteps < 2] = 0
+ x = x.detach()
+ x = x + scale * grad
+ x = self.reset_x0(x, conditions, self.action_dim)
+
+ prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)
+
+ # TODO: verify deprecation of this kwarg
+ x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"]
+
+ # apply conditions to the trajectory (set the initial state)
+ x = self.reset_x0(x, conditions, self.action_dim)
+ x = self.to_torch(x)
+ return x, y
+
+ def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1):
+ # normalize the observations and create batch dimension
+ obs = self.normalize(obs, "observations")
+ obs = obs[None].repeat(batch_size, axis=0)
+
+ conditions = {0: self.to_torch(obs)}
+ shape = (batch_size, planning_horizon, self.state_dim + self.action_dim)
+
+ # generate initial noise and apply our conditions (to make the trajectories start at current state)
+ x1 = randn_tensor(shape, device=self.unet.device)
+ x = self.reset_x0(x1, conditions, self.action_dim)
+ x = self.to_torch(x)
+
+ # run the diffusion process
+ x, y = self.run_diffusion(x, conditions, n_guide_steps, scale)
+
+ # sort output trajectories by value
+ sorted_idx = y.argsort(0, descending=True).squeeze()
+ sorted_values = x[sorted_idx]
+ actions = sorted_values[:, :, : self.action_dim]
+ actions = actions.detach().cpu().numpy()
+ denorm_actions = self.de_normalize(actions, key="actions")
+
+ # select the action with the highest value
+ if y is not None:
+ selected_index = 0
+ else:
+ # if we didn't run value guiding, select a random action
+ selected_index = np.random.randint(0, batch_size)
+
+ denorm_actions = denorm_actions[selected_index, 0]
+ return denorm_actions
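+
+
+# Rough usage sketch (illustrative; the checkpoint id is a placeholder and the snippet assumes
+# a d4rl-style environment with the classic `gym` API):
+#
+#   env = gym.make("hopper-medium-v2")
+#   pipeline = ValueGuidedRLPipeline.from_pretrained("<value-guided-hopper-checkpoint>", env=env)
+#   obs = env.reset()
+#   action = pipeline(obs, planning_horizon=32, n_guide_steps=2, scale=0.1)
+#   obs, reward, done, info = env.step(action)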
diff --git a/diffusers/image_processor.py b/diffusers/image_processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ccf9b465ebd4cd6ce48a40dfe45bbc70d1f3416
--- /dev/null
+++ b/diffusers/image_processor.py
@@ -0,0 +1,366 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import warnings
+from typing import List, Optional, Union
+
+import numpy as np
+import PIL
+import torch
+from PIL import Image
+
+from .configuration_utils import ConfigMixin, register_to_config
+from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
+
+
+class VaeImageProcessor(ConfigMixin):
+ """
+ Image processor for VAE.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+            Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
+            `height` and `width` arguments from the [`image_processor.VaeImageProcessor.preprocess`] method.
+ vae_scale_factor (`int`, *optional*, defaults to `8`):
+ VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
+ resample (`str`, *optional*, defaults to `lanczos`):
+ Resampling filter to use when resizing the image.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the image to [-1,1].
+        do_convert_rgb (`bool`, *optional*, defaults to `False`):
+ Whether to convert the images to RGB format.
+ """
+
+ config_name = CONFIG_NAME
+
+ @register_to_config
+ def __init__(
+ self,
+ do_resize: bool = True,
+ vae_scale_factor: int = 8,
+ resample: str = "lanczos",
+ do_normalize: bool = True,
+ do_convert_rgb: bool = False,
+ ):
+ super().__init__()
+
+ @staticmethod
+    def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
+        """
+        Convert a NumPy image or a batch of images to a list of PIL images.
+        """
+ if images.ndim == 3:
+ images = images[None, ...]
+ images = (images * 255).round().astype("uint8")
+ if images.shape[-1] == 1:
+ # special case for grayscale (single channel) images
+ pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
+ else:
+ pil_images = [Image.fromarray(image) for image in images]
+
+ return pil_images
+
+ @staticmethod
+ def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
+ """
+ Convert a PIL image or a list of PIL images to NumPy arrays.
+ """
+ if not isinstance(images, list):
+ images = [images]
+ images = [np.array(image).astype(np.float32) / 255.0 for image in images]
+ images = np.stack(images, axis=0)
+
+ return images
+
+ @staticmethod
+ def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor:
+ """
+ Convert a NumPy image to a PyTorch tensor.
+ """
+ if images.ndim == 3:
+ images = images[..., None]
+
+ images = torch.from_numpy(images.transpose(0, 3, 1, 2))
+ return images
+
+ @staticmethod
+ def pt_to_numpy(images: torch.FloatTensor) -> np.ndarray:
+ """
+ Convert a PyTorch tensor to a NumPy image.
+ """
+ images = images.cpu().permute(0, 2, 3, 1).float().numpy()
+ return images
+
+ @staticmethod
+ def normalize(images):
+ """
+ Normalize an image array to [-1,1].
+ """
+ return 2.0 * images - 1.0
+
+ @staticmethod
+ def denormalize(images):
+ """
+ Denormalize an image array to [0,1].
+ """
+ return (images / 2 + 0.5).clamp(0, 1)
+
+ @staticmethod
+ def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
+ """
+ Converts an image to RGB format.
+ """
+ image = image.convert("RGB")
+ return image
+
+ def resize(
+ self,
+ image: PIL.Image.Image,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ ) -> PIL.Image.Image:
+ """
+ Resize a PIL image. Both height and width are downscaled to the next integer multiple of `vae_scale_factor`.
+ """
+ if height is None:
+ height = image.height
+ if width is None:
+ width = image.width
+
+ width, height = (
+ x - x % self.config.vae_scale_factor for x in (width, height)
+ ) # resize to integer multiple of vae_scale_factor
+ image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
+ return image
+
+ def preprocess(
+ self,
+ image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ ) -> torch.Tensor:
+ """
+ Preprocess the image input. Accepted formats are PIL images, NumPy arrays or PyTorch tensors.
+ """
+ supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
+ if isinstance(image, supported_formats):
+ image = [image]
+ elif not (isinstance(image, list) and all(isinstance(i, supported_formats) for i in image)):
+ raise ValueError(
+ f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support {', '.join(supported_formats)}"
+ )
+
+ if isinstance(image[0], PIL.Image.Image):
+ if self.config.do_convert_rgb:
+ image = [self.convert_to_rgb(i) for i in image]
+ if self.config.do_resize:
+ image = [self.resize(i, height, width) for i in image]
+ image = self.pil_to_numpy(image) # to np
+ image = self.numpy_to_pt(image) # to pt
+
+ elif isinstance(image[0], np.ndarray):
+ image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
+ image = self.numpy_to_pt(image)
+ _, _, height, width = image.shape
+ if self.config.do_resize and (
+ height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0
+ ):
+ raise ValueError(
+ f"Currently we only support resizing for PIL image - please resize your numpy array to be divisible by {self.config.vae_scale_factor}"
+ f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use resize option in VAEImageProcessor"
+ )
+
+ elif isinstance(image[0], torch.Tensor):
+ image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
+ _, channel, height, width = image.shape
+
+ # don't need any preprocess if the image is latents
+ if channel == 4:
+ return image
+
+ if self.config.do_resize and (
+ height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0
+ ):
+ raise ValueError(
+ f"Currently we only support resizing for PIL image - please resize your pytorch tensor to be divisible by {self.config.vae_scale_factor}"
+ f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use resize option in VAEImageProcessor"
+ )
+
+ # expected range [0,1], normalize to [-1,1]
+ do_normalize = self.config.do_normalize
+ if image.min() < 0:
+ warnings.warn(
+ "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
+ f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
+ FutureWarning,
+ )
+ do_normalize = False
+
+ if do_normalize:
+ image = self.normalize(image)
+
+ return image
+
+ def postprocess(
+ self,
+ image: torch.FloatTensor,
+ output_type: str = "pil",
+ do_denormalize: Optional[List[bool]] = None,
+ ):
+ if not isinstance(image, torch.Tensor):
+ raise ValueError(
+ f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
+ )
+ if output_type not in ["latent", "pt", "np", "pil"]:
+ deprecation_message = (
+ f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
+ "`pil`, `np`, `pt`, `latent`"
+ )
+ deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
+ output_type = "np"
+
+ if output_type == "latent":
+ return image
+
+ if do_denormalize is None:
+ do_denormalize = [self.config.do_normalize] * image.shape[0]
+
+ image = torch.stack(
+ [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
+ )
+
+ if output_type == "pt":
+ return image
+
+ image = self.pt_to_numpy(image)
+
+ if output_type == "np":
+ return image
+
+ if output_type == "pil":
+ return self.numpy_to_pil(image)
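+
+
+# A minimal round-trip sketch (illustrative; "input.png" is a placeholder path and an RGB input
+# is assumed):
+#
+#   from PIL import Image
+#   processor = VaeImageProcessor(vae_scale_factor=8)
+#   tensor = processor.preprocess(Image.open("input.png"))      # (1, 3, H, W), values in [-1, 1]
+#   images = processor.postprocess(tensor, output_type="pil")   # list with one PIL image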
+
+
+class VaeImageProcessorLDM3D(VaeImageProcessor):
+ """
+ Image processor for VAE LDM3D.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
+ vae_scale_factor (`int`, *optional*, defaults to `8`):
+ VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
+ resample (`str`, *optional*, defaults to `lanczos`):
+ Resampling filter to use when resizing the image.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the image to [-1,1].
+ """
+
+ config_name = CONFIG_NAME
+
+ @register_to_config
+ def __init__(
+ self,
+ do_resize: bool = True,
+ vae_scale_factor: int = 8,
+ resample: str = "lanczos",
+ do_normalize: bool = True,
+ ):
+ super().__init__()
+
+ @staticmethod
+ def numpy_to_pil(images):
+ """
+ Convert a NumPy image or a batch of images to a PIL image.
+ """
+ if images.ndim == 3:
+ images = images[None, ...]
+ images = (images * 255).round().astype("uint8")
+ if images.shape[-1] == 1:
+ # special case for grayscale (single channel) images
+ pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
+ else:
+ pil_images = [Image.fromarray(image[:, :, :3]) for image in images]
+
+ return pil_images
+
+ @staticmethod
+ def rgblike_to_depthmap(image):
+ """
+ Args:
+ image: RGB-like depth image
+
+ Returns: depth map
+
+ """
+ return image[:, :, 1] * 2**8 + image[:, :, 2]
+
+ def numpy_to_depth(self, images):
+ """
+ Convert a NumPy depth image or a batch of images to a PIL image.
+ """
+ if images.ndim == 3:
+ images = images[None, ...]
+ images_depth = images[:, :, :, 3:]
+ if images.shape[-1] == 6:
+ images_depth = (images_depth * 255).round().astype("uint8")
+ pil_images = [
+ Image.fromarray(self.rgblike_to_depthmap(image_depth), mode="I;16") for image_depth in images_depth
+ ]
+ elif images.shape[-1] == 4:
+ images_depth = (images_depth * 65535.0).astype(np.uint16)
+ pil_images = [Image.fromarray(image_depth, mode="I;16") for image_depth in images_depth]
+ else:
+ raise Exception("Not supported")
+
+ return pil_images
+
+ def postprocess(
+ self,
+ image: torch.FloatTensor,
+ output_type: str = "pil",
+ do_denormalize: Optional[List[bool]] = None,
+ ):
+ if not isinstance(image, torch.Tensor):
+ raise ValueError(
+ f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
+ )
+ if output_type not in ["latent", "pt", "np", "pil"]:
+ deprecation_message = (
+ f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
+ "`pil`, `np`, `pt`, `latent`"
+ )
+ deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
+ output_type = "np"
+
+ if do_denormalize is None:
+ do_denormalize = [self.config.do_normalize] * image.shape[0]
+
+ image = torch.stack(
+ [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
+ )
+
+ image = self.pt_to_numpy(image)
+
+ if output_type == "np":
+ if image.shape[-1] == 6:
+ image_depth = np.stack([self.rgblike_to_depthmap(im[:, :, 3:]) for im in image], axis=0)
+ else:
+ image_depth = image[:, :, :, 3:]
+ return image[:, :, :, :3], image_depth
+
+ if output_type == "pil":
+ return self.numpy_to_pil(image), self.numpy_to_depth(image)
+ else:
+ raise Exception(f"This type {output_type} is not supported")
diff --git a/diffusers/loaders.py b/diffusers/loaders.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ce5989b5f49504954a9f4157548376a46b3630b
--- /dev/null
+++ b/diffusers/loaders.py
@@ -0,0 +1,1874 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import warnings
+from collections import defaultdict
+from contextlib import nullcontext
+from io import BytesIO
+from pathlib import Path
+from typing import Callable, Dict, List, Optional, Union
+
+import requests
+import torch
+import torch.nn.functional as F
+from huggingface_hub import hf_hub_download
+from torch import nn
+
+from .models.attention_processor import (
+ AttnAddedKVProcessor,
+ AttnAddedKVProcessor2_0,
+ AttnProcessor,
+ AttnProcessor2_0,
+ CustomDiffusionAttnProcessor,
+ CustomDiffusionXFormersAttnProcessor,
+ LoRAAttnAddedKVProcessor,
+ LoRAAttnProcessor,
+ LoRAAttnProcessor2_0,
+ LoRALinearLayer,
+ LoRAXFormersAttnProcessor,
+ SlicedAttnAddedKVProcessor,
+ XFormersAttnProcessor,
+)
+from .utils import (
+ DIFFUSERS_CACHE,
+ HF_HUB_OFFLINE,
+ _get_model_file,
+ deprecate,
+ is_accelerate_available,
+ is_omegaconf_available,
+ is_safetensors_available,
+ is_transformers_available,
+ logging,
+)
+from .utils.import_utils import BACKENDS_MAPPING
+
+
+if is_safetensors_available():
+ import safetensors
+
+if is_transformers_available():
+ from transformers import CLIPTextModel, PreTrainedModel, PreTrainedTokenizer
+
+if is_accelerate_available():
+ from accelerate import init_empty_weights
+ from accelerate.utils import set_module_tensor_to_device
+
+logger = logging.get_logger(__name__)
+
+TEXT_ENCODER_NAME = "text_encoder"
+UNET_NAME = "unet"
+
+LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
+LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
+
+TEXT_INVERSION_NAME = "learned_embeds.bin"
+TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"
+
+CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin"
+CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"
+
+
+class PatchedLoraProjection(nn.Module):
+ def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank=4, dtype=None):
+ super().__init__()
+ self.regular_linear_layer = regular_linear_layer
+
+ device = self.regular_linear_layer.weight.device
+
+ if dtype is None:
+ dtype = self.regular_linear_layer.weight.dtype
+
+ self.lora_linear_layer = LoRALinearLayer(
+ self.regular_linear_layer.in_features,
+ self.regular_linear_layer.out_features,
+ network_alpha=network_alpha,
+ device=device,
+ dtype=dtype,
+ rank=rank,
+ )
+
+ self.lora_scale = lora_scale
+
+ def forward(self, input):
+ return self.regular_linear_layer(input) + self.lora_scale * self.lora_linear_layer(input)
+
+
+def text_encoder_attn_modules(text_encoder):
+ attn_modules = []
+
+ if isinstance(text_encoder, CLIPTextModel):
+ for i, layer in enumerate(text_encoder.text_model.encoder.layers):
+ name = f"text_model.encoder.layers.{i}.self_attn"
+ mod = layer.self_attn
+ attn_modules.append((name, mod))
+ else:
+ raise ValueError(f"do not know how to get attention modules for: {text_encoder.__class__.__name__}")
+
+ return attn_modules
+
+
+def text_encoder_lora_state_dict(text_encoder):
+ state_dict = {}
+
+ for name, module in text_encoder_attn_modules(text_encoder):
+ for k, v in module.q_proj.lora_linear_layer.state_dict().items():
+ state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
+
+ for k, v in module.k_proj.lora_linear_layer.state_dict().items():
+ state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
+
+ for k, v in module.v_proj.lora_linear_layer.state_dict().items():
+ state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
+
+ for k, v in module.out_proj.lora_linear_layer.state_dict().items():
+ state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
+
+ return state_dict
+
+
+class AttnProcsLayers(torch.nn.Module):
+ def __init__(self, state_dict: Dict[str, torch.Tensor]):
+ super().__init__()
+ self.layers = torch.nn.ModuleList(state_dict.values())
+ self.mapping = dict(enumerate(state_dict.keys()))
+ self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}
+
+ # .processor for unet, .self_attn for text encoder
+ self.split_keys = [".processor", ".self_attn"]
+
+ # we add a hook to state_dict() and load_state_dict() so that the
+ # naming fits with `unet.attn_processors`
+ def map_to(module, state_dict, *args, **kwargs):
+ new_state_dict = {}
+ for key, value in state_dict.items():
+ num = int(key.split(".")[1]) # 0 is always "layers"
+ new_key = key.replace(f"layers.{num}", module.mapping[num])
+ new_state_dict[new_key] = value
+
+ return new_state_dict
+
+ def remap_key(key, state_dict):
+ for k in self.split_keys:
+ if k in key:
+ return key.split(k)[0] + k
+
+ raise ValueError(
+ f"There seems to be a problem with the state_dict: {set(state_dict.keys())}. {key} has to have one of {self.split_keys}."
+ )
+
+ def map_from(module, state_dict, *args, **kwargs):
+ all_keys = list(state_dict.keys())
+ for key in all_keys:
+ replace_key = remap_key(key, state_dict)
+ new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}")
+ state_dict[new_key] = state_dict[key]
+ del state_dict[key]
+
+ self._register_state_dict_hook(map_to)
+ self._register_load_state_dict_pre_hook(map_from, with_module=True)
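+
+
+# Naming sketch (illustrative): a UNet attention-processor key such as
+#   "down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor"
+# is stored internally under "layers.<index>", and the two hooks above translate between the
+# namings when state_dict() / load_state_dict() are called.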
+
+
+class UNet2DConditionLoadersMixin:
+ text_encoder_name = TEXT_ENCODER_NAME
+ unet_name = UNET_NAME
+
+ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
+ r"""
+ Load pretrained attention processor layers into [`UNet2DConditionModel`]. Attention processor layers have to be
+ defined in
+ [`cross_attention.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py)
+ and be a `torch.nn.Module` class.
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ Can be either:
+
+ - A string, the model id (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a directory (for example `./my_model_directory`) containing the model weights saved
+ with [`ModelMixin.save_pretrained`].
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
+ incompletely downloaded files are deleted.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+ mirror (`str`, *optional*):
+ Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not
+ guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
+ information.
+
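+ Example:
+
+ A minimal usage sketch (the checkpoint path is illustrative and assumes attention processor weights that
+ were previously saved with [`~loaders.UNet2DConditionLoadersMixin.save_attn_procs`]):
+
+ ```py
+ from diffusers import StableDiffusionPipeline
+ import torch
+
+ pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
+
+ # load LoRA (or Custom Diffusion) attention processor weights into the UNet
+ pipe.unet.load_attn_procs("path/to/attn_procs")
+ ```
+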
+ """
+
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+ force_download = kwargs.pop("force_download", False)
+ resume_download = kwargs.pop("resume_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+ # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
+ # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
+ network_alpha = kwargs.pop("network_alpha", None)
+
+ if use_safetensors and not is_safetensors_available():
+ raise ValueError(
+ "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors`."
+ )
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = is_safetensors_available()
+ allow_pickle = True
+
+ user_agent = {
+ "file_type": "attn_procs_weights",
+ "framework": "pytorch",
+ }
+
+ model_file = None
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+ # Let's first try to load .safetensors weights
+ if (use_safetensors and weight_name is None) or (
+ weight_name is not None and weight_name.endswith(".safetensors")
+ ):
+ try:
+ model_file = _get_model_file(
+ pretrained_model_name_or_path_or_dict,
+ weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ )
+ state_dict = safetensors.torch.load_file(model_file, device="cpu")
+ except IOError as e:
+ if not allow_pickle:
+ raise e
+ # try loading non-safetensors weights
+ pass
+ if model_file is None:
+ model_file = _get_model_file(
+ pretrained_model_name_or_path_or_dict,
+ weights_name=weight_name or LORA_WEIGHT_NAME,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ )
+ state_dict = torch.load(model_file, map_location="cpu")
+ else:
+ state_dict = pretrained_model_name_or_path_or_dict
+
+ # fill attn processors
+ attn_processors = {}
+
+ is_lora = all("lora" in k for k in state_dict.keys())
+ is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())
+
+ if is_lora:
+ is_new_lora_format = all(
+ key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
+ )
+ if is_new_lora_format:
+ # Strip the `"unet"` prefix.
+ is_text_encoder_present = any(key.startswith(self.text_encoder_name) for key in state_dict.keys())
+ if is_text_encoder_present:
+ warn_message = "The state_dict contains LoRA params corresponding to the text encoder which are not being used here. To use both UNet and text encoder related LoRA params, use [`pipe.load_lora_weights()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.load_lora_weights)."
+ warnings.warn(warn_message)
+ unet_keys = [k for k in state_dict.keys() if k.startswith(self.unet_name)]
+ state_dict = {k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}
+
+ lora_grouped_dict = defaultdict(dict)
+ for key, value in state_dict.items():
+ attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
+ lora_grouped_dict[attn_processor_key][sub_key] = value
+
+ for key, value_dict in lora_grouped_dict.items():
+ rank = value_dict["to_k_lora.down.weight"].shape[0]
+ hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
+
+ attn_processor = self
+ for sub_key in key.split("."):
+ attn_processor = getattr(attn_processor, sub_key)
+
+ if isinstance(
+ attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)
+ ):
+ cross_attention_dim = value_dict["add_k_proj_lora.down.weight"].shape[1]
+ attn_processor_class = LoRAAttnAddedKVProcessor
+ else:
+ cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
+ if isinstance(attn_processor, (XFormersAttnProcessor, LoRAXFormersAttnProcessor)):
+ attn_processor_class = LoRAXFormersAttnProcessor
+ else:
+ attn_processor_class = (
+ LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
+ )
+
+ attn_processors[key] = attn_processor_class(
+ hidden_size=hidden_size,
+ cross_attention_dim=cross_attention_dim,
+ rank=rank,
+ network_alpha=network_alpha,
+ )
+ attn_processors[key].load_state_dict(value_dict)
+ elif is_custom_diffusion:
+ custom_diffusion_grouped_dict = defaultdict(dict)
+ for key, value in state_dict.items():
+ if len(value) == 0:
+ custom_diffusion_grouped_dict[key] = {}
+ else:
+ if "to_out" in key:
+ attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
+ else:
+ attn_processor_key, sub_key = ".".join(key.split(".")[:-2]), ".".join(key.split(".")[-2:])
+ custom_diffusion_grouped_dict[attn_processor_key][sub_key] = value
+
+ for key, value_dict in custom_diffusion_grouped_dict.items():
+ if len(value_dict) == 0:
+ attn_processors[key] = CustomDiffusionAttnProcessor(
+ train_kv=False, train_q_out=False, hidden_size=None, cross_attention_dim=None
+ )
+ else:
+ cross_attention_dim = value_dict["to_k_custom_diffusion.weight"].shape[1]
+ hidden_size = value_dict["to_k_custom_diffusion.weight"].shape[0]
+ train_q_out = True if "to_q_custom_diffusion.weight" in value_dict else False
+ attn_processors[key] = CustomDiffusionAttnProcessor(
+ train_kv=True,
+ train_q_out=train_q_out,
+ hidden_size=hidden_size,
+ cross_attention_dim=cross_attention_dim,
+ )
+ attn_processors[key].load_state_dict(value_dict)
+ else:
+ raise ValueError(
+ f"{model_file} does not seem to be in the correct format expected by LoRA or Custom Diffusion training."
+ )
+
+ # set correct dtype & device
+ attn_processors = {k: v.to(device=self.device, dtype=self.dtype) for k, v in attn_processors.items()}
+
+ # set layers
+ self.set_attn_processor(attn_processors)
+
+ def save_attn_procs(
+ self,
+ save_directory: Union[str, os.PathLike],
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = False,
+ **kwargs,
+ ):
+ r"""
+ Save an attention processor to a directory so that it can be reloaded using the
+ [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`] method.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to save an attention processor to. Will be created if it doesn't exist.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+ Whether the process calling this is the main process or not. Useful during distributed training when
+ you need to call this function on all processes. In this case, set `is_main_process=True` only on the
+ main process to avoid race conditions.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful during distributed training when you need to
+ replace `torch.save` with another method. Can be configured with the environment variable
+ `DIFFUSERS_SAVE_MODE`.
+
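+ Example:
+
+ A minimal sketch (assumes `pipe.unet` already has custom attention processors, e.g. LoRA layers, attached;
+ the directory name is illustrative):
+
+ ```py
+ # saves either LoRA or Custom Diffusion attention processor weights to the given directory
+ pipe.unet.save_attn_procs("./attn-procs", safe_serialization=True)
+ ```
+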
+ """
+ weight_name = weight_name or deprecate(
+ "weights_name",
+ "0.20.0",
+ "`weights_name` is deprecated, please use `weight_name` instead.",
+ take_from=kwargs,
+ )
+ if os.path.isfile(save_directory):
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
+ return
+
+ if save_function is None:
+ if safe_serialization:
+
+ def save_function(weights, filename):
+ return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
+
+ else:
+ save_function = torch.save
+
+ os.makedirs(save_directory, exist_ok=True)
+
+ is_custom_diffusion = any(
+ isinstance(x, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor))
+ for (_, x) in self.attn_processors.items()
+ )
+ if is_custom_diffusion:
+ model_to_save = AttnProcsLayers(
+ {
+ y: x
+ for (y, x) in self.attn_processors.items()
+ if isinstance(x, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor))
+ }
+ )
+ state_dict = model_to_save.state_dict()
+ for name, attn in self.attn_processors.items():
+ if len(attn.state_dict()) == 0:
+ state_dict[name] = {}
+ else:
+ model_to_save = AttnProcsLayers(self.attn_processors)
+ state_dict = model_to_save.state_dict()
+
+ if weight_name is None:
+ if safe_serialization:
+ weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE if is_custom_diffusion else LORA_WEIGHT_NAME_SAFE
+ else:
+ weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME if is_custom_diffusion else LORA_WEIGHT_NAME
+
+ # Save the model
+ save_function(state_dict, os.path.join(save_directory, weight_name))
+ logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
+
+
+class TextualInversionLoaderMixin:
+ r"""
+ Load textual inversion tokens and embeddings into the tokenizer and text encoder.
+ """
+
+ def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"):
+ r"""
+ Processes prompts that include a special token corresponding to a multi-vector textual inversion embedding to
+ be replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual
+ inversion token or if the textual inversion token is a single vector, the input prompt is returned.
+
+ Parameters:
+ prompt (`str` or list of `str`):
+ The prompt or prompts to guide the image generation.
+ tokenizer (`PreTrainedTokenizer`):
+ The tokenizer responsible for encoding the prompt into input tokens.
+
+ Returns:
+ `str` or list of `str`: The converted prompt
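+
+ Example:
+
+ A minimal sketch (assumes a multi-vector embedding was loaded for the hypothetical token `<cat-toy>`):
+
+ ```py
+ # if "<cat-toy>" was loaded with e.g. 3 vectors, the prompt is expanded to
+ # "a photo of <cat-toy> <cat-toy>_1 <cat-toy>_2"
+ prompt = pipe.maybe_convert_prompt("a photo of <cat-toy>", pipe.tokenizer)
+ ```
+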
+ """
+ if not isinstance(prompt, List):
+ prompts = [prompt]
+ else:
+ prompts = prompt
+
+ prompts = [self._maybe_convert_prompt(p, tokenizer) for p in prompts]
+
+ if not isinstance(prompt, List):
+ return prompts[0]
+
+ return prompts
+
+ def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"):
+ r"""
+ Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that corresponds
+ to a multi-vector textual inversion embedding, this function will process the prompt so that the special token
+ is replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual
+ inversion token or a textual inversion token that is a single vector, the input prompt is simply returned.
+
+ Parameters:
+ prompt (`str`):
+ The prompt to guide the image generation.
+ tokenizer (`PreTrainedTokenizer`):
+ The tokenizer responsible for encoding the prompt into input tokens.
+
+ Returns:
+ `str`: The converted prompt
+ """
+ tokens = tokenizer.tokenize(prompt)
+ unique_tokens = set(tokens)
+ for token in unique_tokens:
+ if token in tokenizer.added_tokens_encoder:
+ replacement = token
+ i = 1
+ while f"{token}_{i}" in tokenizer.added_tokens_encoder:
+ replacement += f" {token}_{i}"
+ i += 1
+
+ prompt = prompt.replace(token, replacement)
+
+ return prompt
+
+ def load_textual_inversion(
+ self,
+ pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]],
+ token: Optional[Union[str, List[str]]] = None,
+ **kwargs,
+ ):
+ r"""
+ Load textual inversion embeddings into the text encoder of [`StableDiffusionPipeline`] (both 🤗 Diffusers and
+ Automatic1111 formats are supported).
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike` or `List[str or os.PathLike]` or `Dict` or `List[Dict]`):
+ Can be either one of the following or a list of them:
+
+ - A string, the *model id* (for example `sd-concepts-library/low-poly-hd-logos-icons`) of a
+ pretrained model hosted on the Hub.
+ - A path to a *directory* (for example `./my_text_inversion_directory/`) containing the textual
+ inversion weights.
+ - A path to a *file* (for example `./my_text_inversions.pt`) containing textual inversion weights.
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+ token (`str` or `List[str]`, *optional*):
+ Override the token to use for the textual inversion weights. If `pretrained_model_name_or_path` is a
+ list, then `token` must also be a list of equal length.
+ weight_name (`str`, *optional*):
+ Name of a custom weight file. This should be used when:
+
+ - The saved textual inversion file is in 🤗 Diffusers format, but was saved under a specific weight
+ name such as `text_inv.bin`.
+ - The saved textual inversion file is in the Automatic1111 format.
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
+ incompletely downloaded files are deleted.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+ mirror (`str`, *optional*):
+ Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
+ guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
+ information.
+
+ Example:
+
+ To load a textual inversion embedding vector in 🤗 Diffusers format:
+
+ ```py
+ from diffusers import StableDiffusionPipeline
+ import torch
+
+ model_id = "runwayml/stable-diffusion-v1-5"
+ pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
+
+ pipe.load_textual_inversion("sd-concepts-library/cat-toy")
+
+ prompt = "A backpack"
+
+ image = pipe(prompt, num_inference_steps=50).images[0]
+ image.save("cat-backpack.png")
+ ```
+
+ To load a textual inversion embedding vector in Automatic1111 format, make sure to download the vector first
+ (for example from [civitAI](https://civitai.com/models/3036?modelVersionId=9857)) and then load the vector
+ locally:
+
+ ```py
+ from diffusers import StableDiffusionPipeline
+ import torch
+
+ model_id = "runwayml/stable-diffusion-v1-5"
+ pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
+
+ pipe.load_textual_inversion("./charturnerv2.pt", token="charturnerv2")
+
+ prompt = "charturnerv2, multiple views of the same character in the same outfit, a character turnaround of a woman wearing a black jacket and red shirt, best quality, intricate details."
+
+ image = pipe(prompt, num_inference_steps=50).images[0]
+ image.save("character.png")
+ ```
+
+ """
+ if not hasattr(self, "tokenizer") or not isinstance(self.tokenizer, PreTrainedTokenizer):
+ raise ValueError(
+ f"{self.__class__.__name__} requires `self.tokenizer` of type `PreTrainedTokenizer` for calling"
+ f" `{self.load_textual_inversion.__name__}`"
+ )
+
+ if not hasattr(self, "text_encoder") or not isinstance(self.text_encoder, PreTrainedModel):
+ raise ValueError(
+ f"{self.__class__.__name__} requires `self.text_encoder` of type `PreTrainedModel` for calling"
+ f" `{self.load_textual_inversion.__name__}`"
+ )
+
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+ force_download = kwargs.pop("force_download", False)
+ resume_download = kwargs.pop("resume_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+
+ if use_safetensors and not is_safetensors_available():
+ raise ValueError(
+ "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors`."
+ )
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = is_safetensors_available()
+ allow_pickle = True
+
+ user_agent = {
+ "file_type": "text_inversion",
+ "framework": "pytorch",
+ }
+
+ if not isinstance(pretrained_model_name_or_path, list):
+ pretrained_model_name_or_paths = [pretrained_model_name_or_path]
+ else:
+ pretrained_model_name_or_paths = pretrained_model_name_or_path
+
+ if isinstance(token, str):
+ tokens = [token]
+ elif token is None:
+ tokens = [None] * len(pretrained_model_name_or_paths)
+ else:
+ tokens = token
+
+ if len(pretrained_model_name_or_paths) != len(tokens):
+ raise ValueError(
+ f"You have passed a list of models of length {len(pretrained_model_name_or_paths)}, and a list of tokens of length {len(tokens)}. "
+ f"Make sure both lists have the same length."
+ )
+
+ valid_tokens = [t for t in tokens if t is not None]
+ if len(set(valid_tokens)) < len(valid_tokens):
+ raise ValueError(f"You have passed a list of tokens that contains duplicates: {tokens}")
+
+ token_ids_and_embeddings = []
+
+ for pretrained_model_name_or_path, token in zip(pretrained_model_name_or_paths, tokens):
+ if not isinstance(pretrained_model_name_or_path, dict):
+ # 1. Load textual inversion file
+ model_file = None
+ # Let's first try to load .safetensors weights
+ if (use_safetensors and weight_name is None) or (
+ weight_name is not None and weight_name.endswith(".safetensors")
+ ):
+ try:
+ model_file = _get_model_file(
+ pretrained_model_name_or_path,
+ weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ )
+ state_dict = safetensors.torch.load_file(model_file, device="cpu")
+ except Exception as e:
+ if not allow_pickle:
+ raise e
+
+ model_file = None
+
+ if model_file is None:
+ model_file = _get_model_file(
+ pretrained_model_name_or_path,
+ weights_name=weight_name or TEXT_INVERSION_NAME,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ )
+ state_dict = torch.load(model_file, map_location="cpu")
+ else:
+ state_dict = pretrained_model_name_or_path
+
+ # 2. Load token and embedding correctly from file
+ loaded_token = None
+ if isinstance(state_dict, torch.Tensor):
+ if token is None:
+ raise ValueError(
+ "You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`."
+ )
+ embedding = state_dict
+ elif len(state_dict) == 1:
+ # diffusers
+ loaded_token, embedding = next(iter(state_dict.items()))
+ elif "string_to_param" in state_dict:
+ # A1111
+ loaded_token = state_dict["name"]
+ embedding = state_dict["string_to_param"]["*"]
+
+ if token is not None and loaded_token != token:
+ logger.info(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.")
+ else:
+ token = loaded_token
+
+ embedding = embedding.to(dtype=self.text_encoder.dtype, device=self.text_encoder.device)
+
+ # 3. Make sure we don't mess up the tokenizer or text encoder
+ vocab = self.tokenizer.get_vocab()
+ if token in vocab:
+ raise ValueError(
+ f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder."
+ )
+ elif f"{token}_1" in vocab:
+ multi_vector_tokens = [token]
+ i = 1
+ while f"{token}_{i}" in self.tokenizer.added_tokens_encoder:
+ multi_vector_tokens.append(f"{token}_{i}")
+ i += 1
+
+ raise ValueError(
+ f"Multi-vector Token {multi_vector_tokens} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder."
+ )
+
+ is_multi_vector = len(embedding.shape) > 1 and embedding.shape[0] > 1
+
+ if is_multi_vector:
+ tokens = [token] + [f"{token}_{i}" for i in range(1, embedding.shape[0])]
+ embeddings = [e for e in embedding] # noqa: C416
+ else:
+ tokens = [token]
+ embeddings = [embedding[0]] if len(embedding.shape) > 1 else [embedding]
+
+ # add tokens and get ids
+ self.tokenizer.add_tokens(tokens)
+ token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
+ token_ids_and_embeddings += zip(token_ids, embeddings)
+
+ logger.info(f"Loaded textual inversion embedding for {token}.")
+
+ # resize token embeddings and set all new embeddings
+ self.text_encoder.resize_token_embeddings(len(self.tokenizer))
+ for token_id, embedding in token_ids_and_embeddings:
+ self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding
+
+
+class LoraLoaderMixin:
+ r"""
+ Load LoRA layers into [`UNet2DConditionModel`] and
+ [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
+ """
+ text_encoder_name = TEXT_ENCODER_NAME
+ unet_name = UNET_NAME
+
+ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
+ """
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into self.unet and self.text_encoder.
+
+ All kwargs are forwarded to `self.lora_state_dict`.
+
+ See [`~loaders.LoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
+
+ See [`~loaders.LoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is loaded into
+ `self.unet`.
+
+ See [`~loaders.LoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state dict is loaded
+ into `self.text_encoder`.
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ See [`~loaders.LoraLoaderMixin.lora_state_dict`].
+
+ kwargs:
+ See [`~loaders.LoraLoaderMixin.lora_state_dict`].
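+
+ Example:
+
+ A minimal usage sketch (the repository id and weight file name are illustrative):
+
+ ```py
+ pipe.load_lora_weights("some-user/some-lora-repo", weight_name="pytorch_lora_weights.safetensors")
+ ```
+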
+ """
+ state_dict, network_alpha = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+ self.load_lora_into_unet(state_dict, network_alpha=network_alpha, unet=self.unet)
+ self.load_lora_into_text_encoder(
+ state_dict, network_alpha=network_alpha, text_encoder=self.text_encoder, lora_scale=self.lora_scale
+ )
+
+ @classmethod
+ def lora_state_dict(
+ cls,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ **kwargs,
+ ):
+ r"""
+ Return the state dict for the LoRA weights and the network alpha.
+
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+ This function is experimental and might change in the future.
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ Can be either:
+
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ with [`ModelMixin.save_pretrained`].
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
+ incompletely downloaded files are deleted.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+ mirror (`str`, *optional*):
+ Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
+ guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
+ information.
+
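+ Example:
+
+ A minimal sketch (the file path is illustrative):
+
+ ```py
+ from diffusers import StableDiffusionPipeline
+
+ # returns the (possibly converted) state dict and the network alpha, if one was found
+ state_dict, network_alpha = StableDiffusionPipeline.lora_state_dict("path/to/lora.safetensors")
+ ```
+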
+ """
+ # Load the main state dict first which has the LoRA layers for either of
+ # UNet and text encoder or both.
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+ force_download = kwargs.pop("force_download", False)
+ resume_download = kwargs.pop("resume_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+
+ if use_safetensors and not is_safetensors_available():
+ raise ValueError(
+ "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors`."
+ )
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = is_safetensors_available()
+ allow_pickle = True
+
+ user_agent = {
+ "file_type": "attn_procs_weights",
+ "framework": "pytorch",
+ }
+
+ model_file = None
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+ # Let's first try to load .safetensors weights
+ if (use_safetensors and weight_name is None) or (
+ weight_name is not None and weight_name.endswith(".safetensors")
+ ):
+ try:
+ model_file = _get_model_file(
+ pretrained_model_name_or_path_or_dict,
+ weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ )
+ state_dict = safetensors.torch.load_file(model_file, device="cpu")
+ except (IOError, safetensors.SafetensorError) as e:
+ if not allow_pickle:
+ raise e
+ # try loading non-safetensors weights
+ pass
+ if model_file is None:
+ model_file = _get_model_file(
+ pretrained_model_name_or_path_or_dict,
+ weights_name=weight_name or LORA_WEIGHT_NAME,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ )
+ state_dict = torch.load(model_file, map_location="cpu")
+ else:
+ state_dict = pretrained_model_name_or_path_or_dict
+
+ # Convert kohya-ss Style LoRA attn procs to diffusers attn procs
+ network_alpha = None
+ if all((k.startswith("lora_te_") or k.startswith("lora_unet_")) for k in state_dict.keys()):
+ state_dict, network_alpha = cls._convert_kohya_lora_to_diffusers(state_dict)
+
+ return state_dict, network_alpha
+
+ @classmethod
+ def load_lora_into_unet(cls, state_dict, network_alpha, unet):
+ """
+ This will load the LoRA layers specified in `state_dict` into `unet`
+
+ Parameters:
+ state_dict (`dict`):
+ A standard state dict containing the LoRA layer parameters. The keys can either be indexed directly
+ into the UNet or prefixed with an additional `unet`, which is used to distinguish them from the text
+ encoder LoRA layers.
+ network_alpha (`float`):
+ See `LoRALinearLayer` for more details.
+ unet (`UNet2DConditionModel`):
+ The UNet model to load the LoRA layers into.
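+
+ Example:
+
+ A minimal sketch (assumes `pipe` is a pipeline inheriting from this mixin and reuses a state dict
+ returned by [`~loaders.LoraLoaderMixin.lora_state_dict`]; the file path is illustrative):
+
+ ```py
+ state_dict, network_alpha = pipe.lora_state_dict("path/to/lora.safetensors")
+ pipe.load_lora_into_unet(state_dict, network_alpha=network_alpha, unet=pipe.unet)
+ ```
+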
+ """
+
+ # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
+ # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
+ # their prefixes.
+ keys = list(state_dict.keys())
+ if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys):
+ # Load the layers corresponding to UNet.
+ unet_keys = [k for k in keys if k.startswith(cls.unet_name)]
+ logger.info(f"Loading {cls.unet_name}.")
+ unet_lora_state_dict = {
+ k.replace(f"{cls.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys
+ }
+ unet.load_attn_procs(unet_lora_state_dict, network_alpha=network_alpha)
+
+ # Otherwise, we're dealing with the old format. This means the `state_dict` should only
+ # contain the module names of the `unet` as its keys WITHOUT any prefix.
+ elif not all(
+ key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in state_dict.keys()
+ ):
+ unet.load_attn_procs(state_dict)
+ warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
+ warnings.warn(warn_message)
+
+ @classmethod
+ def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lora_scale=1.0):
+ """
+ This will load the LoRA layers specified in `state_dict` into `text_encoder`
+
+ Parameters:
+ state_dict (`dict`):
+ A standard state dict containing the LoRA layer parameters. The keys should be prefixed with an
+ additional `text_encoder` to distinguish them from the UNet LoRA layers.
+ network_alpha (`float`):
+ See `LoRALinearLayer` for more details.
+ text_encoder (`CLIPTextModel`):
+ The text encoder model to load the LoRA layers into.
+ lora_scale (`float`):
+ How much to scale the output of the LoRA linear layer before it is added to the output of the regular
+ linear layer.
+ """
+
+ # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
+ # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
+ # their prefixes.
+ keys = list(state_dict.keys())
+ if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys):
+ # Load the layers corresponding to text encoder and make necessary adjustments.
+ text_encoder_keys = [k for k in keys if k.startswith(cls.text_encoder_name)]
+ text_encoder_lora_state_dict = {
+ k.replace(f"{cls.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
+ }
+ if len(text_encoder_lora_state_dict) > 0:
+ logger.info(f"Loading {cls.text_encoder_name}.")
+
+ if any("to_out_lora" in k for k in text_encoder_lora_state_dict.keys()):
+ # Convert from the old naming convention to the new naming convention.
+ #
+ # Previously, the old LoRA layers were stored on the state dict at the
+ # same level as the attention block i.e.
+ # `text_model.encoder.layers.11.self_attn.to_out_lora.up.weight`.
+ #
+ # There was no actual module at that point; the layers were monkey-patched onto the
+ # existing module. We want to be able to load them via their actual state dict.
+ # They're in `PatchedLoraProjection.lora_linear_layer` now.
+ for name, _ in text_encoder_attn_modules(text_encoder):
+ text_encoder_lora_state_dict[
+ f"{name}.q_proj.lora_linear_layer.up.weight"
+ ] = text_encoder_lora_state_dict.pop(f"{name}.to_q_lora.up.weight")
+ text_encoder_lora_state_dict[
+ f"{name}.k_proj.lora_linear_layer.up.weight"
+ ] = text_encoder_lora_state_dict.pop(f"{name}.to_k_lora.up.weight")
+ text_encoder_lora_state_dict[
+ f"{name}.v_proj.lora_linear_layer.up.weight"
+ ] = text_encoder_lora_state_dict.pop(f"{name}.to_v_lora.up.weight")
+ text_encoder_lora_state_dict[
+ f"{name}.out_proj.lora_linear_layer.up.weight"
+ ] = text_encoder_lora_state_dict.pop(f"{name}.to_out_lora.up.weight")
+
+ text_encoder_lora_state_dict[
+ f"{name}.q_proj.lora_linear_layer.down.weight"
+ ] = text_encoder_lora_state_dict.pop(f"{name}.to_q_lora.down.weight")
+ text_encoder_lora_state_dict[
+ f"{name}.k_proj.lora_linear_layer.down.weight"
+ ] = text_encoder_lora_state_dict.pop(f"{name}.to_k_lora.down.weight")
+ text_encoder_lora_state_dict[
+ f"{name}.v_proj.lora_linear_layer.down.weight"
+ ] = text_encoder_lora_state_dict.pop(f"{name}.to_v_lora.down.weight")
+ text_encoder_lora_state_dict[
+ f"{name}.out_proj.lora_linear_layer.down.weight"
+ ] = text_encoder_lora_state_dict.pop(f"{name}.to_out_lora.down.weight")
+
+ rank = text_encoder_lora_state_dict[
+ "text_model.encoder.layers.0.self_attn.out_proj.lora_linear_layer.up.weight"
+ ].shape[1]
+
+ cls._modify_text_encoder(text_encoder, lora_scale, network_alpha, rank=rank)
+
+ # set correct dtype & device
+ text_encoder_lora_state_dict = {
+ k: v.to(device=text_encoder.device, dtype=text_encoder.dtype)
+ for k, v in text_encoder_lora_state_dict.items()
+ }
+
+ load_state_dict_results = text_encoder.load_state_dict(text_encoder_lora_state_dict, strict=False)
+ if len(load_state_dict_results.unexpected_keys) != 0:
+ raise ValueError(
+ f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}"
+ )
+
+ @property
+ def lora_scale(self) -> float:
+ # property function that returns the lora scale which can be set at run time by the pipeline.
+ # if _lora_scale has not been set, return 1
+ return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
+
+ def _remove_text_encoder_monkey_patch(self):
+ self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
+
+ @classmethod
+ def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder):
+ for _, attn_module in text_encoder_attn_modules(text_encoder):
+ if isinstance(attn_module.q_proj, PatchedLoraProjection):
+ attn_module.q_proj = attn_module.q_proj.regular_linear_layer
+ attn_module.k_proj = attn_module.k_proj.regular_linear_layer
+ attn_module.v_proj = attn_module.v_proj.regular_linear_layer
+ attn_module.out_proj = attn_module.out_proj.regular_linear_layer
+
+ @classmethod
+ def _modify_text_encoder(cls, text_encoder, lora_scale=1, network_alpha=None, rank=4, dtype=None):
+ r"""
+ Monkey-patches the forward passes of attention modules of the text encoder.
+ """
+
+ # First, remove any monkey-patch that might have been applied before
+ cls._remove_text_encoder_monkey_patch_classmethod(text_encoder)
+
+ lora_parameters = []
+
+ for _, attn_module in text_encoder_attn_modules(text_encoder):
+ attn_module.q_proj = PatchedLoraProjection(
+ attn_module.q_proj, lora_scale, network_alpha, rank=rank, dtype=dtype
+ )
+ lora_parameters.extend(attn_module.q_proj.lora_linear_layer.parameters())
+
+ attn_module.k_proj = PatchedLoraProjection(
+ attn_module.k_proj, lora_scale, network_alpha, rank=rank, dtype=dtype
+ )
+ lora_parameters.extend(attn_module.k_proj.lora_linear_layer.parameters())
+
+ attn_module.v_proj = PatchedLoraProjection(
+ attn_module.v_proj, lora_scale, network_alpha, rank=rank, dtype=dtype
+ )
+ lora_parameters.extend(attn_module.v_proj.lora_linear_layer.parameters())
+
+ attn_module.out_proj = PatchedLoraProjection(
+ attn_module.out_proj, lora_scale, network_alpha, rank=rank, dtype=dtype
+ )
+ lora_parameters.extend(attn_module.out_proj.lora_linear_layer.parameters())
+
+ return lora_parameters
+
+ @classmethod
+ def save_lora_weights(
+ self,
+ save_directory: Union[str, os.PathLike],
+ unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = False,
+ ):
+ r"""
+ Save the LoRA parameters corresponding to the UNet and text encoder.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
+ unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the UNet.
+ text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
+ encoder LoRA state dict because it comes from 🤗 Transformers.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+ Whether the process calling this is the main process or not. Useful during distributed training when
+ you need to call this function on all processes. In this case, set `is_main_process=True` only on the
+ main process to avoid race conditions.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful during distributed training when you need to
+ replace `torch.save` with another method. Can be configured with the environment variable
+ `DIFFUSERS_SAVE_MODE`.
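+
+ Example:
+
+ A minimal sketch (assumes `unet_lora_layers` and `text_encoder_lora_layers` were produced during LoRA
+ training, e.g. as `AttnProcsLayers` modules or plain state dicts; the directory name is illustrative):
+
+ ```py
+ StableDiffusionPipeline.save_lora_weights(
+ save_directory="./lora-weights",
+ unet_lora_layers=unet_lora_layers,
+ text_encoder_lora_layers=text_encoder_lora_layers,
+ safe_serialization=True,
+ )
+ ```
+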
+ """
+ if os.path.isfile(save_directory):
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
+ return
+
+ if save_function is None:
+ if safe_serialization:
+
+ def save_function(weights, filename):
+ return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
+
+ else:
+ save_function = torch.save
+
+ os.makedirs(save_directory, exist_ok=True)
+
+ # Create a flat dictionary.
+ state_dict = {}
+ if unet_lora_layers is not None:
+ weights = (
+ unet_lora_layers.state_dict() if isinstance(unet_lora_layers, torch.nn.Module) else unet_lora_layers
+ )
+
+ unet_lora_state_dict = {f"{self.unet_name}.{module_name}": param for module_name, param in weights.items()}
+ state_dict.update(unet_lora_state_dict)
+
+ if text_encoder_lora_layers is not None:
+ weights = (
+ text_encoder_lora_layers.state_dict()
+ if isinstance(text_encoder_lora_layers, torch.nn.Module)
+ else text_encoder_lora_layers
+ )
+
+ text_encoder_lora_state_dict = {
+ f"{self.text_encoder_name}.{module_name}": param for module_name, param in weights.items()
+ }
+ state_dict.update(text_encoder_lora_state_dict)
+
+ # Save the model
+ if weight_name is None:
+ if safe_serialization:
+ weight_name = LORA_WEIGHT_NAME_SAFE
+ else:
+ weight_name = LORA_WEIGHT_NAME
+
+ save_function(state_dict, os.path.join(save_directory, weight_name))
+ logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
+
+ @classmethod
+ def _convert_kohya_lora_to_diffusers(cls, state_dict):
+ unet_state_dict = {}
+ te_state_dict = {}
+ network_alpha = None
+
+ for key, value in state_dict.items():
+ if "lora_down" in key:
+ lora_name = key.split(".")[0]
+ lora_name_up = lora_name + ".lora_up.weight"
+ lora_name_alpha = lora_name + ".alpha"
+ if lora_name_alpha in state_dict:
+ alpha = state_dict[lora_name_alpha].item()
+ if network_alpha is None:
+ network_alpha = alpha
+ elif network_alpha != alpha:
+ raise ValueError("Network alpha is not consistent")
+
+ if lora_name.startswith("lora_unet_"):
+ diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
+ diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
+ diffusers_name = diffusers_name.replace("mid.block", "mid_block")
+ diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
+ diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
+ diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
+ diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
+ diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
+ diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
+ if "transformer_blocks" in diffusers_name:
+ if "attn1" in diffusers_name or "attn2" in diffusers_name:
+ diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
+ diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
+ unet_state_dict[diffusers_name] = value
+ unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
+ elif lora_name.startswith("lora_te_"):
+ diffusers_name = key.replace("lora_te_", "").replace("_", ".")
+ diffusers_name = diffusers_name.replace("text.model", "text_model")
+ diffusers_name = diffusers_name.replace("self.attn", "self_attn")
+ diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
+ diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
+ diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
+ diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
+ if "self_attn" in diffusers_name:
+ te_state_dict[diffusers_name] = value
+ te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
+
+ unet_state_dict = {f"{UNET_NAME}.{module_name}": params for module_name, params in unet_state_dict.items()}
+ te_state_dict = {f"{TEXT_ENCODER_NAME}.{module_name}": params for module_name, params in te_state_dict.items()}
+ new_state_dict = {**unet_state_dict, **te_state_dict}
+ return new_state_dict, network_alpha
+
+ def unload_lora_weights(self):
+ """
+ Unloads the LoRA parameters.
+
+ Examples:
+
+ ```python
+ >>> # Assuming `pipeline` is already loaded with the LoRA parameters.
+ >>> pipeline.unload_lora_weights()
+ >>> ...
+ ```
+ """
+ is_unet_lora = all(
+ isinstance(processor, (LoRAAttnProcessor2_0, LoRAAttnProcessor, LoRAAttnAddedKVProcessor))
+ for _, processor in self.unet.attn_processors.items()
+ )
+ # Handle attention processors that are a mix of regular attention and AddedKV
+ # attention.
+ if is_unet_lora:
+ is_attn_procs_mixed = all(
+ isinstance(processor, (LoRAAttnProcessor2_0, LoRAAttnProcessor))
+ for _, processor in self.unet.attn_processors.items()
+ )
+ if not is_attn_procs_mixed:
+ unet_attn_proc_cls = AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
+ self.unet.set_attn_processor(unet_attn_proc_cls())
+ else:
+ self.unet.set_default_attn_processor()
+
+ # Safe to call the following regardless of LoRA.
+ self._remove_text_encoder_monkey_patch()
+
+
+class FromSingleFileMixin:
+ """
+ Load model weights saved in the `.ckpt` format into a [`DiffusionPipeline`].
+ """
+
+ @classmethod
+ def from_ckpt(cls, *args, **kwargs):
+ deprecation_message = "The function `from_ckpt` is deprecated in favor of `from_single_file` and will be removed in diffusers v.0.21. Please make sure to use `StableDiffusionPipeline.from_single_file(...)` instead."
+ deprecate("from_ckpt", "0.21.0", deprecation_message, standard_warn=False)
+ return cls.from_single_file(*args, **kwargs)
+
+ @classmethod
+ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
+ r"""
+ Instantiate a [`DiffusionPipeline`] from pretrained pipeline weights saved in the `.ckpt` or `.safetensors`
+ format. The pipeline is set in evaluation mode (`model.eval()`) by default.
+
+ Parameters:
+ pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
+ Can be either:
+ - A link to the `.ckpt` file (for example
+ `"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
+ - A path to a *file* containing all pipeline weights.
+ torch_dtype (`str` or `torch.dtype`, *optional*):
+ Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
+ dtype is automatically derived from the model's weights.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
+ incompletely downloaded files are deleted.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to True, the model
+ won't be downloaded from the Hub.
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ use_safetensors (`bool`, *optional*, defaults to `None`):
+ If set to `None`, the safetensors weights are downloaded if they're available **and** if the
+ safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
+ weights. If set to `False`, safetensors weights are not loaded.
+ extract_ema (`bool`, *optional*, defaults to `False`):
+ Whether to extract the EMA weights or not. Pass `True` to extract the EMA weights which usually yield
+ higher quality images for inference. Non-EMA weights are usually better to continue finetuning.
+ upcast_attention (`bool`, *optional*, defaults to `None`):
+ Whether the attention computation should always be upcasted.
+ image_size (`int`, *optional*, defaults to 512):
+ The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
+ Diffusion v2 base model. Use 768 for Stable Diffusion v2.
+ prediction_type (`str`, *optional*):
+ The prediction type the model was trained on. Use `'epsilon'` for all Stable Diffusion v1 models and
+ the Stable Diffusion v2 base model. Use `'v_prediction'` for Stable Diffusion v2.
+ num_in_channels (`int`, *optional*, defaults to `None`):
+ The number of input channels. If `None`, it will be automatically inferred.
+ scheduler_type (`str`, *optional*, defaults to `"pndm"`):
+ Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm",
+ "ddim"]`.
+ load_safety_checker (`bool`, *optional*, defaults to `True`):
+ Whether to load the safety checker or not.
+ text_encoder (`CLIPTextModel`, *optional*, defaults to `None`):
+ An instance of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel) to use,
+ specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)
+ variant. If this parameter is `None`, the function will load a new instance of [CLIP] by itself, if
+ needed.
+ tokenizer (`CLIPTokenizer`, *optional*, defaults to `None`):
+ An instance of
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer)
+ to use. If this parameter is `None`, the function will load a new instance of [CLIPTokenizer] by
+ itself, if needed.
+ kwargs (remaining dictionary of keyword arguments, *optional*):
+ Can be used to overwrite load and saveable variables (for example the pipeline components of the
+ specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
+ method. See example below for more information.
+
+ Examples:
+
+ ```py
+ >>> import torch
+ >>> from diffusers import StableDiffusionPipeline
+
+ >>> # Download pipeline from huggingface.co and cache.
+ >>> pipeline = StableDiffusionPipeline.from_single_file(
+ ... "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors"
+ ... )
+
+ >>> # Load pipeline from a local file
+ >>> # (assuming the checkpoint was downloaded to ./v1-5-pruned-emaonly.ckpt)
+ >>> pipeline = StableDiffusionPipeline.from_single_file("./v1-5-pruned-emaonly.ckpt")
+
+ >>> # Enable float16 and move to GPU
+ >>> pipeline = StableDiffusionPipeline.from_single_file(
+ ... "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt",
+ ... torch_dtype=torch.float16,
+ ... )
+ >>> pipeline.to("cuda")
+ ```
+ """
+ # import here to avoid circular dependency
+ from .pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt
+
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+ resume_download = kwargs.pop("resume_download", False)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ revision = kwargs.pop("revision", None)
+ extract_ema = kwargs.pop("extract_ema", False)
+ image_size = kwargs.pop("image_size", None)
+ scheduler_type = kwargs.pop("scheduler_type", "pndm")
+ num_in_channels = kwargs.pop("num_in_channels", None)
+ upcast_attention = kwargs.pop("upcast_attention", None)
+ load_safety_checker = kwargs.pop("load_safety_checker", True)
+ prediction_type = kwargs.pop("prediction_type", None)
+ text_encoder = kwargs.pop("text_encoder", None)
+ controlnet = kwargs.pop("controlnet", None)
+ tokenizer = kwargs.pop("tokenizer", None)
+
+ torch_dtype = kwargs.pop("torch_dtype", None)
+
+ use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
+
+ pipeline_name = cls.__name__
+ file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
+ from_safetensors = file_extension == "safetensors"
+
+ if from_safetensors and use_safetensors is False:
+ raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")
+
+ # TODO: For now we only support stable diffusion
+ stable_unclip = None
+ model_type = None
+
+ if pipeline_name in [
+ "StableDiffusionControlNetPipeline",
+ "StableDiffusionControlNetImg2ImgPipeline",
+ "StableDiffusionControlNetInpaintPipeline",
+ ]:
+ from .models.controlnet import ControlNetModel
+ from .pipelines.controlnet.multicontrolnet import MultiControlNetModel
+
+ # Model type will be inferred from the checkpoint.
+ if not isinstance(controlnet, (ControlNetModel, MultiControlNetModel)):
+ raise ValueError("ControlNet needs to be passed if loading from ControlNet pipeline.")
+ elif "StableDiffusion" in pipeline_name:
+ # Model type will be inferred from the checkpoint.
+ pass
+ elif pipeline_name == "StableUnCLIPPipeline":
+ model_type = "FrozenOpenCLIPEmbedder"
+ stable_unclip = "txt2img"
+ elif pipeline_name == "StableUnCLIPImg2ImgPipeline":
+ model_type = "FrozenOpenCLIPEmbedder"
+ stable_unclip = "img2img"
+ elif pipeline_name == "PaintByExamplePipeline":
+ model_type = "PaintByExample"
+ elif pipeline_name == "LDMTextToImagePipeline":
+ model_type = "LDMTextToImage"
+ else:
+ raise ValueError(f"Unhandled pipeline class: {pipeline_name}")
+
+ # remove huggingface url
+ for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
+ if pretrained_model_link_or_path.startswith(prefix):
+ pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
+
+ # Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
+ ckpt_path = Path(pretrained_model_link_or_path)
+ if not ckpt_path.is_file():
+ # get repo_id and (potentially nested) file path of ckpt in repo
+ repo_id = os.path.join(*ckpt_path.parts[:2])
+ file_path = os.path.join(*ckpt_path.parts[2:])
+
+ if file_path.startswith("blob/"):
+ file_path = file_path[len("blob/") :]
+
+ if file_path.startswith("main/"):
+ file_path = file_path[len("main/") :]
+
+ pretrained_model_link_or_path = hf_hub_download(
+ repo_id,
+ filename=file_path,
+ cache_dir=cache_dir,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ force_download=force_download,
+ )
+
+ pipe = download_from_original_stable_diffusion_ckpt(
+ pretrained_model_link_or_path,
+ pipeline_class=cls,
+ model_type=model_type,
+ stable_unclip=stable_unclip,
+ controlnet=controlnet,
+ from_safetensors=from_safetensors,
+ extract_ema=extract_ema,
+ image_size=image_size,
+ scheduler_type=scheduler_type,
+ num_in_channels=num_in_channels,
+ upcast_attention=upcast_attention,
+ load_safety_checker=load_safety_checker,
+ prediction_type=prediction_type,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ )
+
+ if torch_dtype is not None:
+ pipe.to(torch_dtype=torch_dtype)
+
+ return pipe
+
+
+class FromOriginalVAEMixin:
+ @classmethod
+ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
+ r"""
+ Instantiate an [`AutoencoderKL`] from pretrained VAE weights saved in the original `.ckpt` or
+ `.safetensors` format. The model is set in evaluation mode (`model.eval()`) by default.
+
+ Parameters:
+ pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
+ Can be either:
+ - A link to the `.ckpt` file (for example
+ `"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
+ - A path to a *file* containing all pipeline weights.
+ torch_dtype (`str` or `torch.dtype`, *optional*):
+ Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
+ dtype is automatically derived from the model's weights.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
+ incompletely downloaded files are deleted.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to True, the model
+ won't be downloaded from the Hub.
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ image_size (`int`, *optional*, defaults to 512):
+ The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
+ Diffusion v2 base model. Use 768 for Stable Diffusion v2.
+ use_safetensors (`bool`, *optional*, defaults to `None`):
+ If set to `None`, the safetensors weights are downloaded if they're available **and** if the
+ safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
+ weights. If set to `False`, safetensors weights are not loaded.
+ upcast_attention (`bool`, *optional*, defaults to `None`):
+ Whether the attention computation should always be upcasted.
+ scaling_factor (`float`, *optional*, defaults to 0.18215):
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z
+ = 1 / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution
+ Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
+ kwargs (remaining dictionary of keyword arguments, *optional*):
+ Can be used to overwrite load and saveable variables (for example the pipeline components of the
+ specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
+ method. See example below for more information.
+
+
+
+ Make sure to pass both `image_size` and `scaling_factor` to `from_single_file()` if you want to load
+ a VAE that accompanies a Stable Diffusion model of v2 or higher, or SDXL.
+
+
+
+ Examples:
+
+ ```py
+ from diffusers import AutoencoderKL
+
+ url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors" # can also be local file
+ model = AutoencoderKL.from_single_file(url)
+ ```
+ """
+ if not is_omegaconf_available():
+ raise ValueError(BACKENDS_MAPPING["omegaconf"][1])
+
+ from omegaconf import OmegaConf
+
+ from .models import AutoencoderKL
+
+ # import here to avoid circular dependency
+ from .pipelines.stable_diffusion.convert_from_ckpt import (
+ convert_ldm_vae_checkpoint,
+ create_vae_diffusers_config,
+ )
+
+ config_file = kwargs.pop("config_file", None)
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+ resume_download = kwargs.pop("resume_download", False)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ revision = kwargs.pop("revision", None)
+ image_size = kwargs.pop("image_size", None)
+ scaling_factor = kwargs.pop("scaling_factor", None)
+ kwargs.pop("upcast_attention", None)
+
+ torch_dtype = kwargs.pop("torch_dtype", None)
+
+ use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
+
+ file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
+ from_safetensors = file_extension == "safetensors"
+
+ if from_safetensors and use_safetensors is False:
+ raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")
+
+ # remove huggingface url
+ for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
+ if pretrained_model_link_or_path.startswith(prefix):
+ pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
+
+ # Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
+ ckpt_path = Path(pretrained_model_link_or_path)
+ if not ckpt_path.is_file():
+ # get repo_id and (potentially nested) file path of ckpt in repo
+ repo_id = "/".join(ckpt_path.parts[:2])
+ file_path = "/".join(ckpt_path.parts[2:])
+
+ if file_path.startswith("blob/"):
+ file_path = file_path[len("blob/") :]
+
+ if file_path.startswith("main/"):
+ file_path = file_path[len("main/") :]
+
+ pretrained_model_link_or_path = hf_hub_download(
+ repo_id,
+ filename=file_path,
+ cache_dir=cache_dir,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ force_download=force_download,
+ )
+
+ if from_safetensors:
+ from safetensors import safe_open
+
+ checkpoint = {}
+ with safe_open(pretrained_model_link_or_path, framework="pt", device="cpu") as f:
+ for key in f.keys():
+ checkpoint[key] = f.get_tensor(key)
+ else:
+ checkpoint = torch.load(pretrained_model_link_or_path, map_location="cpu")
+
+ if "state_dict" in checkpoint:
+ checkpoint = checkpoint["state_dict"]
+
+ if config_file is None:
+ config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
+ config_file = BytesIO(requests.get(config_url).content)
+
+ original_config = OmegaConf.load(config_file)
+
+ # default to sd-v1-5
+ image_size = image_size or 512
+
+ vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
+
+ # honor a user-supplied scaling_factor; otherwise read it from the original config
+ if scaling_factor is None:
+ if (
+ "model" in original_config
+ and "params" in original_config.model
+ and "scale_factor" in original_config.model.params
+ ):
+ scaling_factor = original_config.model.params.scale_factor
+ else:
+ scaling_factor = 0.18215 # default SD scaling factor
+
+ vae_config["scaling_factor"] = scaling_factor
+
+ ctx = init_empty_weights if is_accelerate_available() else nullcontext
+ with ctx():
+ vae = AutoencoderKL(**vae_config)
+
+ if is_accelerate_available():
+ for param_name, param in converted_vae_checkpoint.items():
+ set_module_tensor_to_device(vae, param_name, "cpu", value=param)
+ else:
+ vae.load_state_dict(converted_vae_checkpoint)
+
+ if torch_dtype is not None:
+ vae.to(torch_dtype=torch_dtype)
+
+ return vae
+
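As a quick illustration of the loader above, here is a minimal usage sketch (assuming the vendored `diffusers` package exposes the same top-level imports as upstream). The checkpoint URL is the one from the docstring example; the `image_size`/`scaling_factor` arguments are illustrative and matter mainly when the VAE belongs to an SD v2-or-higher or SDXL model:

```py
from diffusers import AutoencoderKL

# URL taken from the docstring example above; can also be a local .ckpt/.safetensors path.
url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors"

# Illustrative overrides: for a v1 VAE these match the defaults derived from the original config.
vae = AutoencoderKL.from_single_file(url, image_size=512, scaling_factor=0.18215)
print(vae.config.scaling_factor)
```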
+
+class FromOriginalControlnetMixin:
+ @classmethod
+ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
+ r"""
+ Instantiate a [`ControlNetModel`] from pretrained ControlNet weights saved in the original `.ckpt` or
+ `.safetensors` format. The model is set in evaluation mode (`model.eval()`) by default.
+
+ Parameters:
+ pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
+ Can be either:
+ - A link to the `.ckpt` file (for example
+ `"https://huggingface.co//blob/main/.ckpt"`) on the Hub.
+ - A path to a *file* containing all pipeline weights.
+ torch_dtype (`str` or `torch.dtype`, *optional*):
+ Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
+ dtype is automatically derived from the model's weights.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
+ incompletely downloaded files are deleted.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to True, the model
+ won't be downloaded from the Hub.
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ use_safetensors (`bool`, *optional*, defaults to `None`):
+ If set to `None`, the safetensors weights are downloaded if they're available **and** if the
+ safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
+ weights. If set to `False`, safetensors weights are not loaded.
+ image_size (`int`, *optional*, defaults to 512):
+ The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
+ Diffusion v2 base model. Use 768 for Stable Diffusion v2.
+ upcast_attention (`bool`, *optional*, defaults to `None`):
+ Whether the attention computation should always be upcasted.
+ kwargs (remaining dictionary of keyword arguments, *optional*):
+ Can be used to overwrite load and saveable variables (for example the pipeline components of the
+ specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
+ method. See example below for more information.
+
+ Examples:
+
+ ```py
+ from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
+
+ url = "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth" # can also be a local path
+ controlnet = ControlNetModel.from_single_file(url)
+
+ url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors" # can also be a local path
+ pipe = StableDiffusionControlNetPipeline.from_single_file(url, controlnet=controlnet)
+ ```
+ """
+ # import here to avoid circular dependency
+ from .pipelines.stable_diffusion.convert_from_ckpt import download_controlnet_from_original_ckpt
+
+ config_file = kwargs.pop("config_file", None)
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+ resume_download = kwargs.pop("resume_download", False)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ num_in_channels = kwargs.pop("num_in_channels", None)
+ use_linear_projection = kwargs.pop("use_linear_projection", None)
+ revision = kwargs.pop("revision", None)
+ extract_ema = kwargs.pop("extract_ema", False)
+ image_size = kwargs.pop("image_size", None)
+ upcast_attention = kwargs.pop("upcast_attention", None)
+
+ torch_dtype = kwargs.pop("torch_dtype", None)
+
+ use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
+
+ file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
+ from_safetensors = file_extension == "safetensors"
+
+ if from_safetensors and use_safetensors is False:
+ raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")
+
+ # remove huggingface url
+ for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
+ if pretrained_model_link_or_path.startswith(prefix):
+ pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
+
+ # Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
+ ckpt_path = Path(pretrained_model_link_or_path)
+ if not ckpt_path.is_file():
+ # get repo_id and (potentially nested) file path of ckpt in repo
+ repo_id = "/".join(ckpt_path.parts[:2])
+ file_path = "/".join(ckpt_path.parts[2:])
+
+ if file_path.startswith("blob/"):
+ file_path = file_path[len("blob/") :]
+
+ if file_path.startswith("main/"):
+ file_path = file_path[len("main/") :]
+
+ pretrained_model_link_or_path = hf_hub_download(
+ repo_id,
+ filename=file_path,
+ cache_dir=cache_dir,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ force_download=force_download,
+ )
+
+ if config_file is None:
+ config_url = "https://raw.githubusercontent.com/lllyasviel/ControlNet/main/models/cldm_v15.yaml"
+ config_file = BytesIO(requests.get(config_url).content)
+
+ image_size = image_size or 512
+
+ controlnet = download_controlnet_from_original_ckpt(
+ pretrained_model_link_or_path,
+ original_config_file=config_file,
+ image_size=image_size,
+ extract_ema=extract_ema,
+ num_in_channels=num_in_channels,
+ upcast_attention=upcast_attention,
+ from_safetensors=from_safetensors,
+ use_linear_projection=use_linear_projection,
+ )
+
+ if torch_dtype is not None:
+ controlnet.to(torch_dtype=torch_dtype)
+
+ return controlnet
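Both mixins above resolve a Hub URL the same way before calling `hf_hub_download`: strip the domain, take the first two path segments as `repo_id`, and drop any leading `blob/` / `main/` from the remaining file path. A small standalone sketch of that step (the helper name is hypothetical, not part of this diff):

```py
from pathlib import Path


def split_hub_url(link_or_path: str):
    # Hypothetical helper mirroring the parsing logic in the mixins above.
    for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
        if link_or_path.startswith(prefix):
            link_or_path = link_or_path[len(prefix):]

    parts = Path(link_or_path).parts
    repo_id = "/".join(parts[:2])
    file_path = "/".join(parts[2:])
    for marker in ("blob/", "main/"):
        if file_path.startswith(marker):
            file_path = file_path[len(marker):]
    return repo_id, file_path


# ('lllyasviel/ControlNet-v1-1', 'control_v11p_sd15_canny.pth')
print(split_hub_url("https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"))
```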
diff --git a/diffusers/models/README.md b/diffusers/models/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..fb91f59411265660e01d8b4bcc0b99e8b8fe9d55
--- /dev/null
+++ b/diffusers/models/README.md
@@ -0,0 +1,3 @@
+# Models
+
+For more detail on the models, please refer to the [docs](https://huggingface.co/docs/diffusers/api/models/overview).
\ No newline at end of file
diff --git a/diffusers/models/__init__.py b/diffusers/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e330a44691a73f51d109de2e268d179e9e86d87
--- /dev/null
+++ b/diffusers/models/__init__.py
@@ -0,0 +1,36 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..utils import is_flax_available, is_torch_available
+
+
+if is_torch_available():
+ from .adapter import MultiAdapter, T2IAdapter
+ from .autoencoder_kl import AutoencoderKL
+ from .controlnet import ControlNetModel
+ from .dual_transformer_2d import DualTransformer2DModel
+ from .modeling_utils import ModelMixin
+ from .prior_transformer import PriorTransformer
+ from .t5_film_transformer import T5FilmDecoder
+ from .transformer_2d import Transformer2DModel
+ from .unet_1d import UNet1DModel
+ from .unet_2d import UNet2DModel
+ from .unet_2d_condition import UNet2DConditionModel
+ from .unet_3d_condition import UNet3DConditionModel
+ from .vq_model import VQModel
+
+if is_flax_available():
+ from .controlnet_flax import FlaxControlNetModel
+ from .unet_2d_condition_flax import FlaxUNet2DConditionModel
+ from .vae_flax import FlaxAutoencoderKL
diff --git a/diffusers/models/__pycache__/__init__.cpython-310.pyc b/diffusers/models/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d41f8708cf72411f429c0152a8540dd52f8f7366
Binary files /dev/null and b/diffusers/models/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/models/__pycache__/__init__.cpython-38.pyc b/diffusers/models/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e7d85c887654ef4ce8dfecee355c598a812da002
Binary files /dev/null and b/diffusers/models/__pycache__/__init__.cpython-38.pyc differ
diff --git a/diffusers/models/__pycache__/activations.cpython-310.pyc b/diffusers/models/__pycache__/activations.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..430c21f17aff487f040247c33b564d7c1d6ada25
Binary files /dev/null and b/diffusers/models/__pycache__/activations.cpython-310.pyc differ
diff --git a/diffusers/models/__pycache__/activations.cpython-38.pyc b/diffusers/models/__pycache__/activations.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0c8a513da41b63b51c355044df7b93c3269dceda
Binary files /dev/null and b/diffusers/models/__pycache__/activations.cpython-38.pyc differ
diff --git a/diffusers/models/__pycache__/adapter.cpython-310.pyc b/diffusers/models/__pycache__/adapter.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d743cc5a8af08b4b44adb636ecdca658881ad37d
Binary files /dev/null and b/diffusers/models/__pycache__/adapter.cpython-310.pyc differ
diff --git a/diffusers/models/__pycache__/adapter.cpython-38.pyc b/diffusers/models/__pycache__/adapter.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..59fdd45dde4bd0bb24016ed510edc339b2620084
Binary files /dev/null and b/diffusers/models/__pycache__/adapter.cpython-38.pyc differ
diff --git a/diffusers/models/__pycache__/attention.cpython-310.pyc b/diffusers/models/__pycache__/attention.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..70cfc42e993409e73081e637a2047c088a9f4e01
Binary files /dev/null and b/diffusers/models/__pycache__/attention.cpython-310.pyc differ
diff --git a/diffusers/models/__pycache__/attention.cpython-38.pyc b/diffusers/models/__pycache__/attention.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..aa7699c170b1a9c2672ce06edd6e3747356a3288
Binary files /dev/null and b/diffusers/models/__pycache__/attention.cpython-38.pyc differ
diff --git a/diffusers/models/__pycache__/attention_processor.cpython-310.pyc b/diffusers/models/__pycache__/attention_processor.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8e9ed21e0de471a8a8d08b8c5271707b55bdafa8
Binary files /dev/null and b/diffusers/models/__pycache__/attention_processor.cpython-310.pyc differ
diff --git a/diffusers/models/__pycache__/attention_processor.cpython-38.pyc b/diffusers/models/__pycache__/attention_processor.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0ceeb4fb3d80d220b9978f137d72c43d005b092c
Binary files /dev/null and b/diffusers/models/__pycache__/attention_processor.cpython-38.pyc differ
diff --git a/diffusers/models/__pycache__/autoencoder_kl.cpython-310.pyc b/diffusers/models/__pycache__/autoencoder_kl.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8b3b7362866f45f2d9a42bd42f7b9e226212d024
Binary files /dev/null and b/diffusers/models/__pycache__/autoencoder_kl.cpython-310.pyc differ
diff --git a/diffusers/models/__pycache__/autoencoder_kl.cpython-38.pyc b/diffusers/models/__pycache__/autoencoder_kl.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3bcaa7b767796f38b2e5e3cd61eb09a2d19a2c24
Binary files /dev/null and b/diffusers/models/__pycache__/autoencoder_kl.cpython-38.pyc differ
diff --git a/diffusers/models/__pycache__/controlnet.cpython-310.pyc b/diffusers/models/__pycache__/controlnet.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fe07158e28a0e8977c6fe6cb356831025dcebb3b
Binary files /dev/null and b/diffusers/models/__pycache__/controlnet.cpython-310.pyc differ
diff --git a/diffusers/models/__pycache__/controlnet.cpython-38.pyc b/diffusers/models/__pycache__/controlnet.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f4ae9403fa341fa5121c0b00f96cc62d6418aa61
Binary files /dev/null and b/diffusers/models/__pycache__/controlnet.cpython-38.pyc differ
diff --git a/diffusers/models/__pycache__/controlnet_composer.cpython-310.pyc b/diffusers/models/__pycache__/controlnet_composer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3f26ae29eb6d8e6e335d47e27813cec8fbe267c3
Binary files /dev/null and b/diffusers/models/__pycache__/controlnet_composer.cpython-310.pyc differ
diff --git a/diffusers/models/__pycache__/controlnet_composer.cpython-38.pyc b/diffusers/models/__pycache__/controlnet_composer.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7c353b814edfb02522241318441c65223d8a2916
Binary files /dev/null and b/diffusers/models/__pycache__/controlnet_composer.cpython-38.pyc differ
diff --git a/diffusers/models/__pycache__/dual_transformer_2d.cpython-310.pyc b/diffusers/models/__pycache__/dual_transformer_2d.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cf9b8d3972723e0ebeca90e6344b1b63a4fbb9aa
Binary files /dev/null and b/diffusers/models/__pycache__/dual_transformer_2d.cpython-310.pyc differ
diff --git a/diffusers/models/__pycache__/dual_transformer_2d.cpython-38.pyc b/diffusers/models/__pycache__/dual_transformer_2d.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7755b8183aebb795eb674c34c50fbb46baf41916
Binary files /dev/null and b/diffusers/models/__pycache__/dual_transformer_2d.cpython-38.pyc differ
diff --git a/diffusers/models/__pycache__/embeddings.cpython-310.pyc b/diffusers/models/__pycache__/embeddings.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1d82a2d6acf832243e90144e2379fc51a2ebd46f
Binary files /dev/null and b/diffusers/models/__pycache__/embeddings.cpython-310.pyc differ
diff --git a/diffusers/models/__pycache__/embeddings.cpython-38.pyc b/diffusers/models/__pycache__/embeddings.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f93c7669346644de0de8adc993b83afd9cbe1a21
Binary files /dev/null and b/diffusers/models/__pycache__/embeddings.cpython-38.pyc differ
diff --git a/diffusers/models/__pycache__/modeling_utils.cpython-310.pyc b/diffusers/models/__pycache__/modeling_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4bc965ae71474dea4a451335a7187a988c483eba
Binary files /dev/null and b/diffusers/models/__pycache__/modeling_utils.cpython-310.pyc differ
diff --git a/diffusers/models/__pycache__/modeling_utils.cpython-38.pyc b/diffusers/models/__pycache__/modeling_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d5490aa849fd18b500efba6fec7c63d4313ef7c1
Binary files /dev/null and b/diffusers/models/__pycache__/modeling_utils.cpython-38.pyc differ
diff --git a/diffusers/models/__pycache__/prior_transformer.cpython-310.pyc b/diffusers/models/__pycache__/prior_transformer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5b32127351caa0cf733f34629da9e1a80aaab280
Binary files /dev/null and b/diffusers/models/__pycache__/prior_transformer.cpython-310.pyc differ
diff --git a/diffusers/models/__pycache__/prior_transformer.cpython-38.pyc b/diffusers/models/__pycache__/prior_transformer.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3018b1dfa7e673a3e112583749bf2557bda6cd9e
Binary files /dev/null and b/diffusers/models/__pycache__/prior_transformer.cpython-38.pyc differ
diff --git a/diffusers/models/__pycache__/resnet.cpython-310.pyc b/diffusers/models/__pycache__/resnet.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..63bb6d768605db91ae1c025e395ea6729b701284
Binary files /dev/null and b/diffusers/models/__pycache__/resnet.cpython-310.pyc differ
diff --git a/diffusers/models/__pycache__/resnet.cpython-38.pyc b/diffusers/models/__pycache__/resnet.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e5c143d3e5443ab82f60e0b8a71130794c7e777e
Binary files /dev/null and b/diffusers/models/__pycache__/resnet.cpython-38.pyc differ
diff --git a/diffusers/models/__pycache__/t5_film_transformer.cpython-310.pyc b/diffusers/models/__pycache__/t5_film_transformer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d48a732862a603a63551ee63f562dbb4c60ce98c
Binary files /dev/null and b/diffusers/models/__pycache__/t5_film_transformer.cpython-310.pyc differ
diff --git a/diffusers/models/__pycache__/t5_film_transformer.cpython-38.pyc b/diffusers/models/__pycache__/t5_film_transformer.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..17fc2eb047ebb67d14d13a635ed975bcc5c2d9a2
Binary files /dev/null and b/diffusers/models/__pycache__/t5_film_transformer.cpython-38.pyc differ
diff --git a/diffusers/models/__pycache__/transformer_2d.cpython-310.pyc b/diffusers/models/__pycache__/transformer_2d.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..67b09e58232c6297a886e2f50901c5a2be045ccf
Binary files /dev/null and b/diffusers/models/__pycache__/transformer_2d.cpython-310.pyc differ
diff --git a/diffusers/models/__pycache__/transformer_2d.cpython-38.pyc b/diffusers/models/__pycache__/transformer_2d.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2f3784d96ccdcaddfb777906a28e8417ffe23427
Binary files /dev/null and b/diffusers/models/__pycache__/transformer_2d.cpython-38.pyc differ
diff --git a/diffusers/models/__pycache__/transformer_temporal.cpython-310.pyc b/diffusers/models/__pycache__/transformer_temporal.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..75f463b8e2707cc17655b01e758f624bcc375f32
Binary files /dev/null and b/diffusers/models/__pycache__/transformer_temporal.cpython-310.pyc differ
diff --git a/diffusers/models/__pycache__/transformer_temporal.cpython-38.pyc b/diffusers/models/__pycache__/transformer_temporal.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..41796f0f1898457ffcd63a5c587d15c734b7bd4d
Binary files /dev/null and b/diffusers/models/__pycache__/transformer_temporal.cpython-38.pyc differ
diff --git a/diffusers/models/__pycache__/unet_1d.cpython-310.pyc b/diffusers/models/__pycache__/unet_1d.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d82d8fdaa98c571329579c6e0ce029044b7054a4
Binary files /dev/null and b/diffusers/models/__pycache__/unet_1d.cpython-310.pyc differ
diff --git a/diffusers/models/__pycache__/unet_1d.cpython-38.pyc b/diffusers/models/__pycache__/unet_1d.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0a05ba3ad92d7a1a774c102b5a5859be40f03cd6
Binary files /dev/null and b/diffusers/models/__pycache__/unet_1d.cpython-38.pyc differ
diff --git a/diffusers/models/__pycache__/unet_1d_blocks.cpython-310.pyc b/diffusers/models/__pycache__/unet_1d_blocks.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..db61b43437e3f1376f165c6c06e30508ad3e3c82
Binary files /dev/null and b/diffusers/models/__pycache__/unet_1d_blocks.cpython-310.pyc differ
diff --git a/diffusers/models/__pycache__/unet_1d_blocks.cpython-38.pyc b/diffusers/models/__pycache__/unet_1d_blocks.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..02206aaf19e703a6f0be9781b70fb12a0647e83c
Binary files /dev/null and b/diffusers/models/__pycache__/unet_1d_blocks.cpython-38.pyc differ
diff --git a/diffusers/models/__pycache__/unet_2d.cpython-310.pyc b/diffusers/models/__pycache__/unet_2d.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a563a6f94ef8ff56ada1bb45e11957e07db9423f
Binary files /dev/null and b/diffusers/models/__pycache__/unet_2d.cpython-310.pyc differ
diff --git a/diffusers/models/__pycache__/unet_2d.cpython-38.pyc b/diffusers/models/__pycache__/unet_2d.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a39afdf290f9a560a8fa9f71e4b046046c05f85e
Binary files /dev/null and b/diffusers/models/__pycache__/unet_2d.cpython-38.pyc differ
diff --git a/diffusers/models/__pycache__/unet_2d_blocks.cpython-310.pyc b/diffusers/models/__pycache__/unet_2d_blocks.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b0ef358590c1b02d291eccad421d9e193514a018
Binary files /dev/null and b/diffusers/models/__pycache__/unet_2d_blocks.cpython-310.pyc differ
diff --git a/diffusers/models/__pycache__/unet_2d_blocks.cpython-38.pyc b/diffusers/models/__pycache__/unet_2d_blocks.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8bfe3515ccad52b5b724374afb7f9475fbf04877
Binary files /dev/null and b/diffusers/models/__pycache__/unet_2d_blocks.cpython-38.pyc differ
diff --git a/diffusers/models/__pycache__/unet_2d_condition.cpython-310.pyc b/diffusers/models/__pycache__/unet_2d_condition.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dc60cef1f039a325a9d6fef9dbb32690c1cbc67a
Binary files /dev/null and b/diffusers/models/__pycache__/unet_2d_condition.cpython-310.pyc differ
diff --git a/diffusers/models/__pycache__/unet_2d_condition.cpython-38.pyc b/diffusers/models/__pycache__/unet_2d_condition.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ea411acd7fab6fcdde4ff9cda0c0f0a729be1fb1
Binary files /dev/null and b/diffusers/models/__pycache__/unet_2d_condition.cpython-38.pyc differ
diff --git a/diffusers/models/__pycache__/unet_2d_condition_multi_branch.cpython-310.pyc b/diffusers/models/__pycache__/unet_2d_condition_multi_branch.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a1283f9021f1fc7660bd666fd6612dada0fa003d
Binary files /dev/null and b/diffusers/models/__pycache__/unet_2d_condition_multi_branch.cpython-310.pyc differ
diff --git a/diffusers/models/__pycache__/unet_2d_condition_multi_branch_downup.cpython-310.pyc b/diffusers/models/__pycache__/unet_2d_condition_multi_branch_downup.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ca9aab4d96b39d8b54d160152ce568b534c8e0d7
Binary files /dev/null and b/diffusers/models/__pycache__/unet_2d_condition_multi_branch_downup.cpython-310.pyc differ
diff --git a/diffusers/models/__pycache__/unet_2d_condition_multi_branch_downup.cpython-38.pyc b/diffusers/models/__pycache__/unet_2d_condition_multi_branch_downup.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6f323b6876388e517b9902d45db6f670252728ac
Binary files /dev/null and b/diffusers/models/__pycache__/unet_2d_condition_multi_branch_downup.cpython-38.pyc differ
diff --git a/diffusers/models/__pycache__/unet_3d_blocks.cpython-310.pyc b/diffusers/models/__pycache__/unet_3d_blocks.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..04fdec397efe9eab97c56e7f62828cbbdac13c35
Binary files /dev/null and b/diffusers/models/__pycache__/unet_3d_blocks.cpython-310.pyc differ
diff --git a/diffusers/models/__pycache__/unet_3d_blocks.cpython-38.pyc b/diffusers/models/__pycache__/unet_3d_blocks.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6e14b6d8e3fc156e8bb4b03f14254e33c6bd06cf
Binary files /dev/null and b/diffusers/models/__pycache__/unet_3d_blocks.cpython-38.pyc differ
diff --git a/diffusers/models/__pycache__/unet_3d_condition.cpython-310.pyc b/diffusers/models/__pycache__/unet_3d_condition.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0744cf16bd4025d76da50a32f19651c2748f21d4
Binary files /dev/null and b/diffusers/models/__pycache__/unet_3d_condition.cpython-310.pyc differ
diff --git a/diffusers/models/__pycache__/unet_3d_condition.cpython-38.pyc b/diffusers/models/__pycache__/unet_3d_condition.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cdb9505e5e652469d8810beef3ee8b366a98143d
Binary files /dev/null and b/diffusers/models/__pycache__/unet_3d_condition.cpython-38.pyc differ
diff --git a/diffusers/models/__pycache__/vae.cpython-310.pyc b/diffusers/models/__pycache__/vae.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d880c6523df5a735260dab5fd9f578c76334f0a4
Binary files /dev/null and b/diffusers/models/__pycache__/vae.cpython-310.pyc differ
diff --git a/diffusers/models/__pycache__/vae.cpython-38.pyc b/diffusers/models/__pycache__/vae.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c8c9d73944fae981cc7fd4d1301b18edd6853207
Binary files /dev/null and b/diffusers/models/__pycache__/vae.cpython-38.pyc differ
diff --git a/diffusers/models/__pycache__/vq_model.cpython-310.pyc b/diffusers/models/__pycache__/vq_model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..de5463e3803e77fc9e78b1fa22ab9bc8f363a87f
Binary files /dev/null and b/diffusers/models/__pycache__/vq_model.cpython-310.pyc differ
diff --git a/diffusers/models/__pycache__/vq_model.cpython-38.pyc b/diffusers/models/__pycache__/vq_model.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4532f2e1c2b4d1b223b408c2ad81c09e0bee161c
Binary files /dev/null and b/diffusers/models/__pycache__/vq_model.cpython-38.pyc differ
diff --git a/diffusers/models/activations.py b/diffusers/models/activations.py
new file mode 100644
index 0000000000000000000000000000000000000000..64759b706e2f108803e51ccd50f9dff67ad49722
--- /dev/null
+++ b/diffusers/models/activations.py
@@ -0,0 +1,12 @@
+from torch import nn
+
+
+def get_activation(act_fn):
+ if act_fn in ["swish", "silu"]:
+ return nn.SiLU()
+ elif act_fn == "mish":
+ return nn.Mish()
+ elif act_fn == "gelu":
+ return nn.GELU()
+ else:
+ raise ValueError(f"Unsupported activation function: {act_fn}")
diff --git a/diffusers/models/adapter.py b/diffusers/models/adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..a65a3873b130a4b74b46ebfb34b99067ee1a6a6e
--- /dev/null
+++ b/diffusers/models/adapter.py
@@ -0,0 +1,291 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional
+
+import torch
+import torch.nn as nn
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from .modeling_utils import ModelMixin
+from .resnet import Downsample2D
+
+
+class MultiAdapter(ModelMixin):
+ r"""
+ MultiAdapter is a wrapper model that contains multiple adapter models and merges their outputs according to
+ user-assigned weighting.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
+ implements for all models (such as downloading or saving).
+
+ Parameters:
+ adapters (`List[T2IAdapter]`, *optional*, defaults to None):
+ A list of `T2IAdapter` model instances.
+ """
+
+ def __init__(self, adapters: List["T2IAdapter"]):
+ super(MultiAdapter, self).__init__()
+
+ self.num_adapter = len(adapters)
+ self.adapters = nn.ModuleList(adapters)
+
+ def forward(self, xs: torch.Tensor, adapter_weights: Optional[List[float]] = None) -> List[torch.Tensor]:
+ r"""
+ Args:
+ xs (`torch.Tensor`):
+ (batch, channel, height, width) input images for the multiple adapter models, concatenated along
+ dimension 1; `channel` should equal `num_adapter` * (number of channels per image).
+ adapter_weights (`List[float]`, *optional*, defaults to None):
+ List of floats representing the weight by which each adapter's output is multiplied before the
+ outputs are added together.
+ """
+ if adapter_weights is None:
+ adapter_weights = torch.tensor([1 / self.num_adapter] * self.num_adapter)
+ else:
+ adapter_weights = torch.tensor(adapter_weights)
+
+ if xs.shape[1] % self.num_adapter != 0:
+ raise ValueError(
+ f"Expecting multi-adapter's input have number of channel that cab be evenly divisible "
+ f"by num_adapter: {xs.shape[1]} % {self.num_adapter} != 0"
+ )
+ x_list = torch.chunk(xs, self.num_adapter, dim=1)
+ accume_state = None
+ for x, w, adapter in zip(x_list, adapter_weights, self.adapters):
+ features = adapter(x)
+ if accume_state is None:
+ accume_state = features
+ else:
+ for i in range(len(features)):
+ accume_state[i] += w * features[i]
+ return accume_state
+
+
+class T2IAdapter(ModelMixin, ConfigMixin):
+ r"""
+ A simple ResNet-like model that accepts images containing control signals such as keyposes and depth. The model
+ generates multiple feature maps that are used as additional conditioning in [`UNet2DConditionModel`]. The model's
+ architecture follows the original implementation of
+ [Adapter](https://github.com/TencentARC/T2I-Adapter/blob/686de4681515662c0ac2ffa07bf5dda83af1038a/ldm/modules/encoders/adapter.py#L97)
+ and
+ [AdapterLight](https://github.com/TencentARC/T2I-Adapter/blob/686de4681515662c0ac2ffa07bf5dda83af1038a/ldm/modules/encoders/adapter.py#L235).
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
+ implements for all models (such as downloading or saving).
+
+ Parameters:
+ in_channels (`int`, *optional*, defaults to 3):
+ Number of channels in the Adapter's input (*control image*). Set this parameter to 1 if you're using a
+ grayscale image as the *control image*.
+ channels (`List[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
+ The number of channels in each downsample block's output hidden state. `len(channels)` also
+ determines the number of downsample blocks in the Adapter.
+ num_res_blocks (`int`, *optional*, defaults to 2):
+ Number of ResNet blocks in each downsample block
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 3,
+ channels: List[int] = [320, 640, 1280, 1280],
+ num_res_blocks: int = 2,
+ downscale_factor: int = 8,
+ adapter_type: str = "full_adapter",
+ ):
+ super().__init__()
+
+ if adapter_type == "full_adapter":
+ self.adapter = FullAdapter(in_channels, channels, num_res_blocks, downscale_factor)
+ elif adapter_type == "light_adapter":
+ self.adapter = LightAdapter(in_channels, channels, num_res_blocks, downscale_factor)
+ else:
+ raise ValueError(f"unknown adapter_type: {type}. Choose either 'full_adapter' or 'simple_adapter'")
+
+ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
+ return self.adapter(x)
+
+ @property
+ def total_downscale_factor(self):
+ return self.adapter.total_downscale_factor
+
+
+# full adapter
+
+
+class FullAdapter(nn.Module):
+ def __init__(
+ self,
+ in_channels: int = 3,
+ channels: List[int] = [320, 640, 1280, 1280],
+ num_res_blocks: int = 2,
+ downscale_factor: int = 8,
+ ):
+ super().__init__()
+
+ in_channels = in_channels * downscale_factor**2
+
+ self.unshuffle = nn.PixelUnshuffle(downscale_factor)
+ self.conv_in = nn.Conv2d(in_channels, channels[0], kernel_size=3, padding=1)
+
+ self.body = nn.ModuleList(
+ [
+ AdapterBlock(channels[0], channels[0], num_res_blocks),
+ *[
+ AdapterBlock(channels[i - 1], channels[i], num_res_blocks, down=True)
+ for i in range(1, len(channels))
+ ],
+ ]
+ )
+
+ self.total_downscale_factor = downscale_factor * 2 ** (len(channels) - 1)
+
+ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
+ x = self.unshuffle(x)
+ x = self.conv_in(x)
+
+ features = []
+
+ for block in self.body:
+ x = block(x)
+ features.append(x)
+
+ return features
+
+
+class AdapterBlock(nn.Module):
+ def __init__(self, in_channels, out_channels, num_res_blocks, down=False):
+ super().__init__()
+
+ self.downsample = None
+ if down:
+ self.downsample = Downsample2D(in_channels)
+
+ self.in_conv = None
+ if in_channels != out_channels:
+ self.in_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
+
+ self.resnets = nn.Sequential(
+ *[AdapterResnetBlock(out_channels) for _ in range(num_res_blocks)],
+ )
+
+ def forward(self, x):
+ if self.downsample is not None:
+ x = self.downsample(x)
+
+ if self.in_conv is not None:
+ x = self.in_conv(x)
+
+ x = self.resnets(x)
+
+ return x
+
+
+class AdapterResnetBlock(nn.Module):
+ def __init__(self, channels):
+ super().__init__()
+ self.block1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
+ self.act = nn.ReLU()
+ self.block2 = nn.Conv2d(channels, channels, kernel_size=1)
+
+ def forward(self, x):
+ h = x
+ h = self.block1(h)
+ h = self.act(h)
+ h = self.block2(h)
+
+ return h + x
+
+
+# light adapter
+
+
+class LightAdapter(nn.Module):
+ def __init__(
+ self,
+ in_channels: int = 3,
+ channels: List[int] = [320, 640, 1280],
+ num_res_blocks: int = 4,
+ downscale_factor: int = 8,
+ ):
+ super().__init__()
+
+ in_channels = in_channels * downscale_factor**2
+
+ self.unshuffle = nn.PixelUnshuffle(downscale_factor)
+
+ self.body = nn.ModuleList(
+ [
+ LightAdapterBlock(in_channels, channels[0], num_res_blocks),
+ *[
+ LightAdapterBlock(channels[i], channels[i + 1], num_res_blocks, down=True)
+ for i in range(len(channels) - 1)
+ ],
+ LightAdapterBlock(channels[-1], channels[-1], num_res_blocks, down=True),
+ ]
+ )
+
+ self.total_downscale_factor = downscale_factor * (2 ** len(channels))
+
+ def forward(self, x):
+ x = self.unshuffle(x)
+
+ features = []
+
+ for block in self.body:
+ x = block(x)
+ features.append(x)
+
+ return features
+
+
+class LightAdapterBlock(nn.Module):
+ def __init__(self, in_channels, out_channels, num_res_blocks, down=False):
+ super().__init__()
+ mid_channels = out_channels // 4
+
+ self.downsample = None
+ if down:
+ self.downsample = Downsample2D(in_channels)
+
+ self.in_conv = nn.Conv2d(in_channels, mid_channels, kernel_size=1)
+ self.resnets = nn.Sequential(*[LightAdapterResnetBlock(mid_channels) for _ in range(num_res_blocks)])
+ self.out_conv = nn.Conv2d(mid_channels, out_channels, kernel_size=1)
+
+ def forward(self, x):
+ if self.downsample is not None:
+ x = self.downsample(x)
+
+ x = self.in_conv(x)
+ x = self.resnets(x)
+ x = self.out_conv(x)
+
+ return x
+
+
+class LightAdapterResnetBlock(nn.Module):
+ def __init__(self, channels):
+ super().__init__()
+ self.block1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
+ self.act = nn.ReLU()
+ self.block2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
+
+ def forward(self, x):
+ h = x
+ h = self.block1(h)
+ h = self.act(h)
+ h = self.block2(h)
+
+ return h + x
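To make the adapter's multi-scale output concrete, a short sketch under the assumption that the vendored `diffusers` package is importable and a 512×512 control image is used; the feature-map shapes follow from `downscale_factor=8` plus one extra 2× downsample per additional `channels` entry:

```py
import torch

from diffusers.models.adapter import MultiAdapter, T2IAdapter

adapter = T2IAdapter(in_channels=3, channels=[320, 640, 1280, 1280], num_res_blocks=2, downscale_factor=8)

control = torch.randn(1, 3, 512, 512)  # e.g. a depth or keypose map
features = adapter(control)
print([tuple(f.shape) for f in features])
# [(1, 320, 64, 64), (1, 640, 32, 32), (1, 1280, 16, 16), (1, 1280, 8, 8)]

# MultiAdapter expects the per-adapter control images concatenated along the channel dimension.
multi = MultiAdapter([adapter, T2IAdapter()])
merged = multi(torch.cat([control, control], dim=1), adapter_weights=[0.6, 0.4])
print(len(merged))  # one merged feature map per resolution
```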
diff --git a/diffusers/models/attention.py b/diffusers/models/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b05bf35e87f503df3e265bd587d1ca3f32f2bc5
--- /dev/null
+++ b/diffusers/models/attention.py
@@ -0,0 +1,389 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Dict, Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from ..utils import maybe_allow_in_graph
+from .activations import get_activation
+from .attention_processor import Attention
+from .embeddings import CombinedTimestepLabelEmbeddings
+
+
+@maybe_allow_in_graph
+class BasicTransformerBlock(nn.Module):
+ r"""
+ A basic Transformer block.
+
+ Parameters:
+ dim (`int`): The number of channels in the input and output.
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`): The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
+ only_cross_attention (`bool`, *optional*):
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
+ double_self_attention (`bool`, *optional*):
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+ num_embeds_ada_norm (`int`, *optional*):
+ The number of diffusion steps used during training. See `Transformer2DModel`.
+ attention_bias (`bool`, *optional*, defaults to `False`):
+ Configure if the attentions should contain a bias parameter.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ dropout=0.0,
+ cross_attention_dim: Optional[int] = None,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ attention_bias: bool = False,
+ only_cross_attention: bool = False,
+ double_self_attention: bool = False,
+ upcast_attention: bool = False,
+ norm_elementwise_affine: bool = True,
+ norm_type: str = "layer_norm",
+ final_dropout: bool = False,
+ ):
+ super().__init__()
+ self.only_cross_attention = only_cross_attention
+
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
+
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
+ raise ValueError(
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
+ )
+
+ # Define 3 blocks. Each block has its own normalization layer.
+ # 1. Self-Attn
+ if self.use_ada_layer_norm:
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
+ elif self.use_ada_layer_norm_zero:
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
+ else:
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
+ self.attn1 = Attention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+ upcast_attention=upcast_attention,
+ )
+
+ # 2. Cross-Attn
+ if cross_attention_dim is not None or double_self_attention:
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
+ # the second cross attention block.
+ self.norm2 = (
+ AdaLayerNorm(dim, num_embeds_ada_norm)
+ if self.use_ada_layer_norm
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
+ )
+ self.attn2 = Attention(
+ query_dim=dim,
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ upcast_attention=upcast_attention,
+ ) # is self-attn if encoder_hidden_states is none
+ else:
+ self.norm2 = None
+ self.attn2 = None
+
+ # 3. Feed-forward
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
+
+ # let chunk size default to None
+ self._chunk_size = None
+ self._chunk_dim = 0
+
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
+ # Sets chunk feed-forward
+ self._chunk_size = chunk_size
+ self._chunk_dim = dim
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ timestep: Optional[torch.LongTensor] = None,
+ cross_attention_kwargs: Dict[str, Any] = None,
+ class_labels: Optional[torch.LongTensor] = None,
+ ):
+ # Notice that normalization is always applied before the real computation in the following blocks.
+ # 1. Self-Attention
+ if self.use_ada_layer_norm:
+ norm_hidden_states = self.norm1(hidden_states, timestep)
+ elif self.use_ada_layer_norm_zero:
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
+ )
+ else:
+ norm_hidden_states = self.norm1(hidden_states)
+
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ if self.use_ada_layer_norm_zero:
+ attn_output = gate_msa.unsqueeze(1) * attn_output
+ hidden_states = attn_output + hidden_states
+
+ # 2. Cross-Attention
+ if self.attn2 is not None:
+ norm_hidden_states = (
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
+ )
+
+ attn_output = self.attn2(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ **cross_attention_kwargs,
+ )
+ hidden_states = attn_output + hidden_states
+
+ # 3. Feed-forward
+ norm_hidden_states = self.norm3(hidden_states)
+
+ if self.use_ada_layer_norm_zero:
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+ if self._chunk_size is not None:
+ # "feed_forward_chunk_size" can be used to save memory
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
+ raise ValueError(
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
+ )
+
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
+ ff_output = torch.cat(
+ [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
+ dim=self._chunk_dim,
+ )
+ else:
+ ff_output = self.ff(norm_hidden_states)
+
+ if self.use_ada_layer_norm_zero:
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+ hidden_states = ff_output + hidden_states
+
+ return hidden_states
+
+
+class FeedForward(nn.Module):
+ r"""
+ A feed-forward layer.
+
+ Parameters:
+ dim (`int`): The number of channels in the input.
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ dim_out: Optional[int] = None,
+ mult: int = 4,
+ dropout: float = 0.0,
+ activation_fn: str = "geglu",
+ final_dropout: bool = False,
+ ):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = dim_out if dim_out is not None else dim
+
+ if activation_fn == "gelu":
+ act_fn = GELU(dim, inner_dim)
+ if activation_fn == "gelu-approximate":
+ act_fn = GELU(dim, inner_dim, approximate="tanh")
+ elif activation_fn == "geglu":
+ act_fn = GEGLU(dim, inner_dim)
+ elif activation_fn == "geglu-approximate":
+ act_fn = ApproximateGELU(dim, inner_dim)
+
+ self.net = nn.ModuleList([])
+ # project in
+ self.net.append(act_fn)
+ # project dropout
+ self.net.append(nn.Dropout(dropout))
+ # project out
+ self.net.append(nn.Linear(inner_dim, dim_out))
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
+ if final_dropout:
+ self.net.append(nn.Dropout(dropout))
+
+ def forward(self, hidden_states):
+ for module in self.net:
+ hidden_states = module(hidden_states)
+ return hidden_states
+
+
+class GELU(nn.Module):
+ r"""
+ GELU activation function with tanh approximation support with `approximate="tanh"`.
+ """
+
+ def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out)
+ self.approximate = approximate
+
+ def gelu(self, gate):
+ if gate.device.type != "mps":
+ return F.gelu(gate, approximate=self.approximate)
+ # mps: gelu is not implemented for float16
+ return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
+
+ def forward(self, hidden_states):
+ hidden_states = self.proj(hidden_states)
+ hidden_states = self.gelu(hidden_states)
+ return hidden_states
+
+
+class GEGLU(nn.Module):
+ r"""
+ A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
+
+ Parameters:
+ dim_in (`int`): The number of channels in the input.
+ dim_out (`int`): The number of channels in the output.
+ """
+
+ def __init__(self, dim_in: int, dim_out: int):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2)
+
+ def gelu(self, gate):
+ if gate.device.type != "mps":
+ return F.gelu(gate)
+ # mps: gelu is not implemented for float16
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
+
+ def forward(self, hidden_states):
+ hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
+ return hidden_states * self.gelu(gate)
+
+
+class ApproximateGELU(nn.Module):
+ """
+ The approximate form of Gaussian Error Linear Unit (GELU)
+
+ For more details, see section 2: https://arxiv.org/abs/1606.08415
+ """
+
+ def __init__(self, dim_in: int, dim_out: int):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out)
+
+ def forward(self, x):
+ x = self.proj(x)
+ return x * torch.sigmoid(1.702 * x)
+
+
+class AdaLayerNorm(nn.Module):
+ """
+ Norm layer modified to incorporate timestep embeddings.
+ """
+
+ def __init__(self, embedding_dim, num_embeddings):
+ super().__init__()
+ self.emb = nn.Embedding(num_embeddings, embedding_dim)
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
+
+ def forward(self, x, timestep):
+ emb = self.linear(self.silu(self.emb(timestep)))
+ scale, shift = torch.chunk(emb, 2)
+ x = self.norm(x) * (1 + scale) + shift
+ return x
+
+
+class AdaLayerNormZero(nn.Module):
+ """
+ Norm layer adaptive layer norm zero (adaLN-Zero).
+ """
+
+ def __init__(self, embedding_dim, num_embeddings):
+ super().__init__()
+
+ self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
+
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
+
+ def forward(self, x, timestep, class_labels, hidden_dtype=None):
+ emb = self.linear(self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)))
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
+
+
+class AdaGroupNorm(nn.Module):
+ """
+ GroupNorm layer modified to incorporate timestep embeddings.
+ """
+
+ def __init__(
+ self, embedding_dim: int, out_dim: int, num_groups: int, act_fn: Optional[str] = None, eps: float = 1e-5
+ ):
+ super().__init__()
+ self.num_groups = num_groups
+ self.eps = eps
+
+ if act_fn is None:
+ self.act = None
+ else:
+ self.act = get_activation(act_fn)
+
+ self.linear = nn.Linear(embedding_dim, out_dim * 2)
+
+ def forward(self, x, emb):
+ if self.act:
+ emb = self.act(emb)
+ emb = self.linear(emb)
+ emb = emb[:, :, None, None]
+ scale, shift = emb.chunk(2, dim=1)
+
+ x = F.group_norm(x, self.num_groups, eps=self.eps)
+ x = x * (1 + scale) + shift
+ return x
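For concreteness, a minimal sketch of running `BasicTransformerBlock` above with cross-attention on dummy tensors; the widths (320 channels, 77 text tokens, 768-dim text embeddings) mirror a typical Stable Diffusion UNet stage and are purely illustrative:

```py
import torch

from diffusers.models.attention import BasicTransformerBlock

block = BasicTransformerBlock(
    dim=320,
    num_attention_heads=8,
    attention_head_dim=40,
    cross_attention_dim=768,  # width of the text encoder hidden states
)

hidden_states = torch.randn(2, 64 * 64, 320)     # (batch, sequence, channels)
encoder_hidden_states = torch.randn(2, 77, 768)  # (batch, text tokens, cross-attention dim)

out = block(hidden_states, encoder_hidden_states=encoder_hidden_states)
print(out.shape)  # torch.Size([2, 4096, 320])
```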
diff --git a/diffusers/models/attention_flax.py b/diffusers/models/attention_flax.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b160d2384311c1fb426b87c11e5fa1572584070
--- /dev/null
+++ b/diffusers/models/attention_flax.py
@@ -0,0 +1,446 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import functools
+import math
+
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+
+
+def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096):
+ """Multi-head dot product attention with a limited number of queries."""
+ num_kv, num_heads, k_features = key.shape[-3:]
+ v_features = value.shape[-1]
+ key_chunk_size = min(key_chunk_size, num_kv)
+ query = query / jnp.sqrt(k_features)
+
+ @functools.partial(jax.checkpoint, prevent_cse=False)
+ def summarize_chunk(query, key, value):
+ attn_weights = jnp.einsum("...qhd,...khd->...qhk", query, key, precision=precision)
+
+ max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
+ max_score = jax.lax.stop_gradient(max_score)
+ exp_weights = jnp.exp(attn_weights - max_score)
+
+ exp_values = jnp.einsum("...vhf,...qhv->...qhf", value, exp_weights, precision=precision)
+ max_score = jnp.einsum("...qhk->...qh", max_score)
+
+ return (exp_values, exp_weights.sum(axis=-1), max_score)
+
+ def chunk_scanner(chunk_idx):
+ # julienne key array
+ key_chunk = jax.lax.dynamic_slice(
+ operand=key,
+ start_indices=[0] * (key.ndim - 3) + [chunk_idx, 0, 0], # [...,k,h,d]
+ slice_sizes=list(key.shape[:-3]) + [key_chunk_size, num_heads, k_features], # [...,k,h,d]
+ )
+
+ # julienne value array
+ value_chunk = jax.lax.dynamic_slice(
+ operand=value,
+ start_indices=[0] * (value.ndim - 3) + [chunk_idx, 0, 0], # [...,v,h,d]
+ slice_sizes=list(value.shape[:-3]) + [key_chunk_size, num_heads, v_features], # [...,v,h,d]
+ )
+
+ return summarize_chunk(query, key_chunk, value_chunk)
+
+ chunk_values, chunk_weights, chunk_max = jax.lax.map(f=chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size))
+
+ global_max = jnp.max(chunk_max, axis=0, keepdims=True)
+ max_diffs = jnp.exp(chunk_max - global_max)
+
+ chunk_values *= jnp.expand_dims(max_diffs, axis=-1)
+ chunk_weights *= max_diffs
+
+ all_values = chunk_values.sum(axis=0)
+ all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0)
+
+ return all_values / all_weights
+
+
+def jax_memory_efficient_attention(
+ query, key, value, precision=jax.lax.Precision.HIGHEST, query_chunk_size: int = 1024, key_chunk_size: int = 4096
+):
+ r"""
+ Flax Memory-efficient multi-head dot product attention. https://arxiv.org/abs/2112.05682v2
+ https://github.com/AminRezaei0x443/memory-efficient-attention
+
+ Args:
+ query (`jnp.ndarray`): (batch..., query_length, head, query_key_depth_per_head)
+ key (`jnp.ndarray`): (batch..., key_value_length, head, query_key_depth_per_head)
+ value (`jnp.ndarray`): (batch..., key_value_length, head, value_depth_per_head)
+ precision (`jax.lax.Precision`, *optional*, defaults to `jax.lax.Precision.HIGHEST`):
+ numerical precision for computation
+        query_chunk_size (`int`, *optional*, defaults to 1024):
+            chunk size used to split the query array; it must divide query_length evenly (no remainder)
+        key_chunk_size (`int`, *optional*, defaults to 4096):
+            chunk size used to split the key and value arrays; it must divide key_value_length evenly (no remainder)
+
+ Returns:
+ (`jnp.ndarray`) with shape of (batch..., query_length, head, value_depth_per_head)
+ """
+ num_q, num_heads, q_features = query.shape[-3:]
+
+ def chunk_scanner(chunk_idx, _):
+ # julienne query array
+ query_chunk = jax.lax.dynamic_slice(
+ operand=query,
+ start_indices=([0] * (query.ndim - 3)) + [chunk_idx, 0, 0], # [...,q,h,d]
+ slice_sizes=list(query.shape[:-3]) + [min(query_chunk_size, num_q), num_heads, q_features], # [...,q,h,d]
+ )
+
+ return (
+            chunk_idx + query_chunk_size,  # next carry value; unused, so ignore it
+ _query_chunk_attention(
+ query=query_chunk, key=key, value=value, precision=precision, key_chunk_size=key_chunk_size
+ ),
+ )
+
+ _, res = jax.lax.scan(
+        f=chunk_scanner, init=0, xs=None, length=math.ceil(num_q / query_chunk_size)  # init is the start index, length the number of query chunks
+ )
+
+ return jnp.concatenate(res, axis=-3) # fuse the chunked result back
+
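+# Illustrative usage (shapes are assumptions, not defaults): both sequence lengths
+# must be divisible by the corresponding chunk sizes, and peak memory scales with
+# the chunk sizes rather than with the full query_length x key_value_length score
+# matrix.
+#
+#   q = jnp.ones((2, 4096, 8, 64))    # (batch, query_length, heads, head_dim)
+#   k = jnp.ones((2, 4096, 8, 64))    # (batch, key_value_length, heads, head_dim)
+#   v = jnp.ones((2, 4096, 8, 64))
+#   out = jax_memory_efficient_attention(q, k, v, query_chunk_size=1024, key_chunk_size=4096)
+#   # out.shape == (2, 4096, 8, 64)
+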
+
+class FlaxAttention(nn.Module):
+ r"""
+ A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762
+
+ Parameters:
+ query_dim (:obj:`int`):
+ Input hidden states dimension
+ heads (:obj:`int`, *optional*, defaults to 8):
+ Number of heads
+ dim_head (:obj:`int`, *optional*, defaults to 64):
+ Hidden states dimension inside each head
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
+ Dropout rate
+ use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
+ enable memory efficient attention https://arxiv.org/abs/2112.05682
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+ Parameters `dtype`
+
+ """
+ query_dim: int
+ heads: int = 8
+ dim_head: int = 64
+ dropout: float = 0.0
+ use_memory_efficient_attention: bool = False
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ inner_dim = self.dim_head * self.heads
+ self.scale = self.dim_head**-0.5
+
+ # Weights were exported with old names {to_q, to_k, to_v, to_out}
+ self.query = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_q")
+ self.key = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_k")
+ self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v")
+
+ self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out_0")
+ self.dropout_layer = nn.Dropout(rate=self.dropout)
+
+ def reshape_heads_to_batch_dim(self, tensor):
+ batch_size, seq_len, dim = tensor.shape
+ head_size = self.heads
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+ tensor = jnp.transpose(tensor, (0, 2, 1, 3))
+ tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
+ return tensor
+
+ def reshape_batch_dim_to_heads(self, tensor):
+ batch_size, seq_len, dim = tensor.shape
+ head_size = self.heads
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+ tensor = jnp.transpose(tensor, (0, 2, 1, 3))
+ tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size)
+ return tensor
+
+ def __call__(self, hidden_states, context=None, deterministic=True):
+ context = hidden_states if context is None else context
+
+ query_proj = self.query(hidden_states)
+ key_proj = self.key(context)
+ value_proj = self.value(context)
+
+ query_states = self.reshape_heads_to_batch_dim(query_proj)
+ key_states = self.reshape_heads_to_batch_dim(key_proj)
+ value_states = self.reshape_heads_to_batch_dim(value_proj)
+
+ if self.use_memory_efficient_attention:
+ query_states = query_states.transpose(1, 0, 2)
+ key_states = key_states.transpose(1, 0, 2)
+ value_states = value_states.transpose(1, 0, 2)
+
+            # this if/elif ladder creates a chunk size for each layer of the unet
+ # the chunk size is equal to the query_length dimension of the deepest layer of the unet
+
+ flatten_latent_dim = query_states.shape[-3]
+ if flatten_latent_dim % 64 == 0:
+ query_chunk_size = int(flatten_latent_dim / 64)
+ elif flatten_latent_dim % 16 == 0:
+ query_chunk_size = int(flatten_latent_dim / 16)
+ elif flatten_latent_dim % 4 == 0:
+ query_chunk_size = int(flatten_latent_dim / 4)
+ else:
+ query_chunk_size = int(flatten_latent_dim)
+
+ hidden_states = jax_memory_efficient_attention(
+ query_states, key_states, value_states, query_chunk_size=query_chunk_size, key_chunk_size=4096 * 4
+ )
+
+ hidden_states = hidden_states.transpose(1, 0, 2)
+ else:
+ # compute attentions
+ attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states)
+ attention_scores = attention_scores * self.scale
+ attention_probs = nn.softmax(attention_scores, axis=2)
+
+ # attend to values
+ hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states)
+
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+ hidden_states = self.proj_attn(hidden_states)
+ return self.dropout_layer(hidden_states, deterministic=deterministic)
+
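+# Illustrative sketch (sizes and seed are assumptions): FlaxAttention is a standard
+# flax.linen module, so parameters are created with `init` and used with `apply`;
+# passing no `context` performs self-attention.
+#
+#   attn = FlaxAttention(query_dim=320, heads=8, dim_head=40)
+#   hidden = jnp.ones((1, 64, 320))                       # (batch, tokens, query_dim)
+#   params = attn.init(jax.random.PRNGKey(0), hidden)
+#   out = attn.apply(params, hidden)                      # (1, 64, 320)
+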
+
+class FlaxBasicTransformerBlock(nn.Module):
+ r"""
+ A Flax transformer block layer with `GLU` (Gated Linear Unit) activation function as described in:
+ https://arxiv.org/abs/1706.03762
+
+
+ Parameters:
+ dim (:obj:`int`):
+ Inner hidden states dimension
+ n_heads (:obj:`int`):
+ Number of heads
+ d_head (:obj:`int`):
+ Hidden states dimension inside each head
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
+ Dropout rate
+ only_cross_attention (`bool`, defaults to `False`):
+ Whether to only apply cross attention.
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+ Parameters `dtype`
+ use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
+ enable memory efficient attention https://arxiv.org/abs/2112.05682
+ """
+ dim: int
+ n_heads: int
+ d_head: int
+ dropout: float = 0.0
+ only_cross_attention: bool = False
+ dtype: jnp.dtype = jnp.float32
+ use_memory_efficient_attention: bool = False
+
+ def setup(self):
+ # self attention (or cross_attention if only_cross_attention is True)
+ self.attn1 = FlaxAttention(
+ self.dim, self.n_heads, self.d_head, self.dropout, self.use_memory_efficient_attention, dtype=self.dtype
+ )
+ # cross attention
+ self.attn2 = FlaxAttention(
+ self.dim, self.n_heads, self.d_head, self.dropout, self.use_memory_efficient_attention, dtype=self.dtype
+ )
+ self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
+ self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
+ self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
+ self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
+ self.dropout_layer = nn.Dropout(rate=self.dropout)
+
+ def __call__(self, hidden_states, context, deterministic=True):
+ # self attention
+ residual = hidden_states
+ if self.only_cross_attention:
+ hidden_states = self.attn1(self.norm1(hidden_states), context, deterministic=deterministic)
+ else:
+ hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic)
+ hidden_states = hidden_states + residual
+
+ # cross attention
+ residual = hidden_states
+ hidden_states = self.attn2(self.norm2(hidden_states), context, deterministic=deterministic)
+ hidden_states = hidden_states + residual
+
+ # feed forward
+ residual = hidden_states
+ hidden_states = self.ff(self.norm3(hidden_states), deterministic=deterministic)
+ hidden_states = hidden_states + residual
+
+ return self.dropout_layer(hidden_states, deterministic=deterministic)
+
+
+class FlaxTransformer2DModel(nn.Module):
+ r"""
+ A Spatial Transformer layer with Gated Linear Unit (GLU) activation function as described in:
+ https://arxiv.org/pdf/1506.02025.pdf
+
+
+ Parameters:
+ in_channels (:obj:`int`):
+ Input number of channels
+ n_heads (:obj:`int`):
+ Number of heads
+ d_head (:obj:`int`):
+ Hidden states dimension inside each head
+ depth (:obj:`int`, *optional*, defaults to 1):
+ Number of transformers block
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
+ Dropout rate
+        use_linear_projection (`bool`, defaults to `False`):
+            Whether to use a linear (`nn.Dense`) layer instead of a 1x1 convolution for the input and output projections
+        only_cross_attention (`bool`, defaults to `False`):
+            Whether the transformer blocks should only apply cross attention
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+ Parameters `dtype`
+ use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
+ enable memory efficient attention https://arxiv.org/abs/2112.05682
+ """
+ in_channels: int
+ n_heads: int
+ d_head: int
+ depth: int = 1
+ dropout: float = 0.0
+ use_linear_projection: bool = False
+ only_cross_attention: bool = False
+ dtype: jnp.dtype = jnp.float32
+ use_memory_efficient_attention: bool = False
+
+ def setup(self):
+ self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)
+
+ inner_dim = self.n_heads * self.d_head
+ if self.use_linear_projection:
+ self.proj_in = nn.Dense(inner_dim, dtype=self.dtype)
+ else:
+ self.proj_in = nn.Conv(
+ inner_dim,
+ kernel_size=(1, 1),
+ strides=(1, 1),
+ padding="VALID",
+ dtype=self.dtype,
+ )
+
+ self.transformer_blocks = [
+ FlaxBasicTransformerBlock(
+ inner_dim,
+ self.n_heads,
+ self.d_head,
+ dropout=self.dropout,
+ only_cross_attention=self.only_cross_attention,
+ dtype=self.dtype,
+ use_memory_efficient_attention=self.use_memory_efficient_attention,
+ )
+ for _ in range(self.depth)
+ ]
+
+ if self.use_linear_projection:
+ self.proj_out = nn.Dense(inner_dim, dtype=self.dtype)
+ else:
+ self.proj_out = nn.Conv(
+ inner_dim,
+ kernel_size=(1, 1),
+ strides=(1, 1),
+ padding="VALID",
+ dtype=self.dtype,
+ )
+
+ self.dropout_layer = nn.Dropout(rate=self.dropout)
+
+ def __call__(self, hidden_states, context, deterministic=True):
+ batch, height, width, channels = hidden_states.shape
+ residual = hidden_states
+ hidden_states = self.norm(hidden_states)
+ if self.use_linear_projection:
+ hidden_states = hidden_states.reshape(batch, height * width, channels)
+ hidden_states = self.proj_in(hidden_states)
+ else:
+ hidden_states = self.proj_in(hidden_states)
+ hidden_states = hidden_states.reshape(batch, height * width, channels)
+
+ for transformer_block in self.transformer_blocks:
+ hidden_states = transformer_block(hidden_states, context, deterministic=deterministic)
+
+ if self.use_linear_projection:
+ hidden_states = self.proj_out(hidden_states)
+ hidden_states = hidden_states.reshape(batch, height, width, channels)
+ else:
+ hidden_states = hidden_states.reshape(batch, height, width, channels)
+ hidden_states = self.proj_out(hidden_states)
+
+ hidden_states = hidden_states + residual
+ return self.dropout_layer(hidden_states, deterministic=deterministic)
+
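+# Illustrative note (sizes are assumptions; the 768 context width assumes a CLIP-like
+# text encoder): unlike the PyTorch models, this Flax transformer expects
+# channels-last inputs, i.e. hidden_states of shape (batch, height, width, channels)
+# and a text context of shape (batch, tokens, context_dim).
+#
+#   model = FlaxTransformer2DModel(in_channels=320, n_heads=8, d_head=40, depth=1)
+#   sample = jnp.ones((1, 8, 8, 320))
+#   context = jnp.ones((1, 77, 768))
+#   params = model.init(jax.random.PRNGKey(0), sample, context)
+#   out = model.apply(params, sample, context)            # (1, 8, 8, 320)
+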
+
+class FlaxFeedForward(nn.Module):
+ r"""
+ Flax module that encapsulates two Linear layers separated by a non-linearity. It is the counterpart of PyTorch's
+ [`FeedForward`] class, with the following simplifications:
+ - The activation function is currently hardcoded to a gated linear unit from:
+ https://arxiv.org/abs/2002.05202
+ - `dim_out` is equal to `dim`.
+    - The number of hidden dimensions is hardcoded to `dim * 4` in [`FlaxGEGLU`].
+
+ Parameters:
+ dim (:obj:`int`):
+ Inner hidden states dimension
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
+ Dropout rate
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+ Parameters `dtype`
+ """
+ dim: int
+ dropout: float = 0.0
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ # The second linear layer needs to be called
+ # net_2 for now to match the index of the Sequential layer
+ self.net_0 = FlaxGEGLU(self.dim, self.dropout, self.dtype)
+ self.net_2 = nn.Dense(self.dim, dtype=self.dtype)
+
+ def __call__(self, hidden_states, deterministic=True):
+ hidden_states = self.net_0(hidden_states, deterministic=deterministic)
+ hidden_states = self.net_2(hidden_states)
+ return hidden_states
+
+
+class FlaxGEGLU(nn.Module):
+ r"""
+ Flax implementation of a Linear layer followed by the variant of the gated linear unit activation function from
+ https://arxiv.org/abs/2002.05202.
+
+ Parameters:
+ dim (:obj:`int`):
+ Input hidden states dimension
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
+ Dropout rate
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+ Parameters `dtype`
+ """
+ dim: int
+ dropout: float = 0.0
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ inner_dim = self.dim * 4
+ self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype)
+ self.dropout_layer = nn.Dropout(rate=self.dropout)
+
+ def __call__(self, hidden_states, deterministic=True):
+ hidden_states = self.proj(hidden_states)
+ hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
+ return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)
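+
+
+# Illustrative note (sizes are assumptions): the projection above produces
+# 2 * (dim * 4) features which are split into a linear half and a gated half and
+# combined as `linear * gelu(gate)` before dropout.
+#
+#   geglu = FlaxGEGLU(dim=320)
+#   h = jnp.ones((1, 64, 320))
+#   params = geglu.init(jax.random.PRNGKey(0), h)
+#   out = geglu.apply(params, h)                          # (1, 64, 1280)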
diff --git a/diffusers/models/attention_processor.py b/diffusers/models/attention_processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..da2920fa671a31e489d3cc207d179e06015ae9f6
--- /dev/null
+++ b/diffusers/models/attention_processor.py
@@ -0,0 +1,1647 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Callable, Optional, Union
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from ..utils import deprecate, logging, maybe_allow_in_graph
+from ..utils.import_utils import is_xformers_available
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+if is_xformers_available():
+ import xformers
+ import xformers.ops
+else:
+ xformers = None
+
+
+@maybe_allow_in_graph
+class Attention(nn.Module):
+ r"""
+ A cross attention layer.
+
+ Parameters:
+ query_dim (`int`): The number of channels in the query.
+ cross_attention_dim (`int`, *optional*):
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
+ heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
+ dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ bias (`bool`, *optional*, defaults to False):
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
+ """
+
+ def __init__(
+ self,
+ query_dim: int,
+ cross_attention_dim: Optional[int] = None,
+ heads: int = 8,
+ dim_head: int = 64,
+ dropout: float = 0.0,
+ bias=False,
+ upcast_attention: bool = False,
+ upcast_softmax: bool = False,
+ cross_attention_norm: Optional[str] = None,
+ cross_attention_norm_num_groups: int = 32,
+ added_kv_proj_dim: Optional[int] = None,
+ norm_num_groups: Optional[int] = None,
+ spatial_norm_dim: Optional[int] = None,
+ out_bias: bool = True,
+ scale_qk: bool = True,
+ only_cross_attention: bool = False,
+ eps: float = 1e-5,
+ rescale_output_factor: float = 1.0,
+ residual_connection: bool = False,
+ _from_deprecated_attn_block=False,
+ processor: Optional["AttnProcessor"] = None,
+ ):
+ super().__init__()
+ inner_dim = dim_head * heads
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
+ self.upcast_attention = upcast_attention
+ self.upcast_softmax = upcast_softmax
+ self.rescale_output_factor = rescale_output_factor
+ self.residual_connection = residual_connection
+ self.dropout = dropout
+
+ # we make use of this private variable to know whether this class is loaded
+        # with a deprecated state dict so that we can convert it on the fly
+ self._from_deprecated_attn_block = _from_deprecated_attn_block
+
+ self.scale_qk = scale_qk
+ self.scale = dim_head**-0.5 if self.scale_qk else 1.0
+
+ self.heads = heads
+ # for slice_size > 0 the attention score computation
+ # is split across the batch axis to save memory
+ # You can set slice_size with `set_attention_slice`
+ self.sliceable_head_dim = heads
+
+ self.added_kv_proj_dim = added_kv_proj_dim
+ self.only_cross_attention = only_cross_attention
+
+ if self.added_kv_proj_dim is None and self.only_cross_attention:
+ raise ValueError(
+ "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
+ )
+
+ if norm_num_groups is not None:
+ self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
+ else:
+ self.group_norm = None
+
+ if spatial_norm_dim is not None:
+ self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
+ else:
+ self.spatial_norm = None
+
+ if cross_attention_norm is None:
+ self.norm_cross = None
+ elif cross_attention_norm == "layer_norm":
+ self.norm_cross = nn.LayerNorm(cross_attention_dim)
+ elif cross_attention_norm == "group_norm":
+ if self.added_kv_proj_dim is not None:
+ # The given `encoder_hidden_states` are initially of shape
+ # (batch_size, seq_len, added_kv_proj_dim) before being projected
+ # to (batch_size, seq_len, cross_attention_dim). The norm is applied
+ # before the projection, so we need to use `added_kv_proj_dim` as
+ # the number of channels for the group norm.
+ norm_cross_num_channels = added_kv_proj_dim
+ else:
+ norm_cross_num_channels = cross_attention_dim
+
+ self.norm_cross = nn.GroupNorm(
+ num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
+ )
+ else:
+ raise ValueError(
+ f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
+ )
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
+
+ if not self.only_cross_attention:
+ # only relevant for the `AddedKVProcessor` classes
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
+ else:
+ self.to_k = None
+ self.to_v = None
+
+ if self.added_kv_proj_dim is not None:
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, inner_dim)
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, inner_dim)
+
+ self.to_out = nn.ModuleList([])
+ self.to_out.append(nn.Linear(inner_dim, query_dim, bias=out_bias))
+ self.to_out.append(nn.Dropout(dropout))
+
+ # set attention processor
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
+ if processor is None:
+ processor = (
+ AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
+ )
+ self.set_processor(processor)
+
+ def set_use_memory_efficient_attention_xformers(
+ self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
+ ):
+ is_lora = hasattr(self, "processor") and isinstance(
+ self.processor,
+ (LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, LoRAAttnAddedKVProcessor),
+ )
+ is_custom_diffusion = hasattr(self, "processor") and isinstance(
+ self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
+ )
+ is_added_kv_processor = hasattr(self, "processor") and isinstance(
+ self.processor,
+ (
+ AttnAddedKVProcessor,
+ AttnAddedKVProcessor2_0,
+ SlicedAttnAddedKVProcessor,
+ XFormersAttnAddedKVProcessor,
+ LoRAAttnAddedKVProcessor,
+ ),
+ )
+
+ if use_memory_efficient_attention_xformers:
+ if is_added_kv_processor and (is_lora or is_custom_diffusion):
+ raise NotImplementedError(
+ f"Memory efficient attention is currently not supported for LoRA or custom diffuson for attention processor type {self.processor}"
+ )
+ if not is_xformers_available():
+ raise ModuleNotFoundError(
+ (
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
+ " xformers"
+ ),
+ name="xformers",
+ )
+ elif not torch.cuda.is_available():
+ raise ValueError(
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
+ " only available for GPU "
+ )
+ else:
+ try:
+ # Make sure we can run the memory efficient attention
+ _ = xformers.ops.memory_efficient_attention(
+ torch.randn((1, 2, 40), device="cuda"),
+ torch.randn((1, 2, 40), device="cuda"),
+ torch.randn((1, 2, 40), device="cuda"),
+ )
+ except Exception as e:
+ raise e
+
+ if is_lora:
+ # TODO (sayakpaul): should we throw a warning if someone wants to use the xformers
+ # variant when using PT 2.0 now that we have LoRAAttnProcessor2_0?
+ processor = LoRAXFormersAttnProcessor(
+ hidden_size=self.processor.hidden_size,
+ cross_attention_dim=self.processor.cross_attention_dim,
+ rank=self.processor.rank,
+ attention_op=attention_op,
+ )
+ processor.load_state_dict(self.processor.state_dict())
+ processor.to(self.processor.to_q_lora.up.weight.device)
+ elif is_custom_diffusion:
+ processor = CustomDiffusionXFormersAttnProcessor(
+ train_kv=self.processor.train_kv,
+ train_q_out=self.processor.train_q_out,
+ hidden_size=self.processor.hidden_size,
+ cross_attention_dim=self.processor.cross_attention_dim,
+ attention_op=attention_op,
+ )
+ processor.load_state_dict(self.processor.state_dict())
+ if hasattr(self.processor, "to_k_custom_diffusion"):
+ processor.to(self.processor.to_k_custom_diffusion.weight.device)
+ elif is_added_kv_processor:
+ # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
+ # which uses this type of cross attention ONLY because the attention mask of format
+ # [0, ..., -10.000, ..., 0, ...,] is not supported
+ # throw warning
+ logger.info(
+ "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
+ )
+ processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
+ else:
+ processor = XFormersAttnProcessor(attention_op=attention_op)
+ else:
+ if is_lora:
+ attn_processor_class = (
+ LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
+ )
+ processor = attn_processor_class(
+ hidden_size=self.processor.hidden_size,
+ cross_attention_dim=self.processor.cross_attention_dim,
+ rank=self.processor.rank,
+ )
+ processor.load_state_dict(self.processor.state_dict())
+ processor.to(self.processor.to_q_lora.up.weight.device)
+ elif is_custom_diffusion:
+ processor = CustomDiffusionAttnProcessor(
+ train_kv=self.processor.train_kv,
+ train_q_out=self.processor.train_q_out,
+ hidden_size=self.processor.hidden_size,
+ cross_attention_dim=self.processor.cross_attention_dim,
+ )
+ processor.load_state_dict(self.processor.state_dict())
+ if hasattr(self.processor, "to_k_custom_diffusion"):
+ processor.to(self.processor.to_k_custom_diffusion.weight.device)
+ else:
+ # set attention processor
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
+ processor = (
+ AttnProcessor2_0()
+ if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
+ else AttnProcessor()
+ )
+
+ self.set_processor(processor)
+
+ def set_attention_slice(self, slice_size):
+ if slice_size is not None and slice_size > self.sliceable_head_dim:
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
+
+ if slice_size is not None and self.added_kv_proj_dim is not None:
+ processor = SlicedAttnAddedKVProcessor(slice_size)
+ elif slice_size is not None:
+ processor = SlicedAttnProcessor(slice_size)
+ elif self.added_kv_proj_dim is not None:
+ processor = AttnAddedKVProcessor()
+ else:
+ # set attention processor
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
+ processor = (
+ AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
+ )
+
+ self.set_processor(processor)
+
+ def set_processor(self, processor: "AttnProcessor"):
+ # if current processor is in `self._modules` and if passed `processor` is not, we need to
+ # pop `processor` from `self._modules`
+ if (
+ hasattr(self, "processor")
+ and isinstance(self.processor, torch.nn.Module)
+ and not isinstance(processor, torch.nn.Module)
+ ):
+ logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
+ self._modules.pop("processor")
+
+ self.processor = processor
+
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
+ # The `Attention` class can call different attention processors / attention functions
+ # here we simply pass along all tensors to the selected processor class
+ # For standard processors that are defined here, `**cross_attention_kwargs` is empty
+ return self.processor(
+ self,
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+
+ def batch_to_head_dim(self, tensor):
+ head_size = self.heads
+ batch_size, seq_len, dim = tensor.shape
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+ return tensor
+
+ def head_to_batch_dim(self, tensor, out_dim=3):
+ head_size = self.heads
+ batch_size, seq_len, dim = tensor.shape
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+ tensor = tensor.permute(0, 2, 1, 3)
+
+ if out_dim == 3:
+ tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
+
+ return tensor
+
+ def get_attention_scores(self, query, key, attention_mask=None):
+ dtype = query.dtype
+ if self.upcast_attention:
+ query = query.float()
+ key = key.float()
+
+ if attention_mask is None:
+ baddbmm_input = torch.empty(
+ query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
+ )
+ beta = 0
+ else:
+ baddbmm_input = attention_mask
+ beta = 1
+
+ attention_scores = torch.baddbmm(
+ baddbmm_input,
+ query,
+ key.transpose(-1, -2),
+ beta=beta,
+ alpha=self.scale,
+ )
+ del baddbmm_input
+
+ if self.upcast_softmax:
+ attention_scores = attention_scores.float()
+
+ attention_probs = attention_scores.softmax(dim=-1)
+ del attention_scores
+
+ attention_probs = attention_probs.to(dtype)
+
+ return attention_probs
+
+ def prepare_attention_mask(self, attention_mask, target_length, batch_size=None, out_dim=3):
+ if batch_size is None:
+ deprecate(
+ "batch_size=None",
+ "0.0.15",
+ (
+ "Not passing the `batch_size` parameter to `prepare_attention_mask` can lead to incorrect"
+ " attention mask preparation and is deprecated behavior. Please make sure to pass `batch_size` to"
+ " `prepare_attention_mask` when preparing the attention_mask."
+ ),
+ )
+ batch_size = 1
+
+ head_size = self.heads
+ if attention_mask is None:
+ return attention_mask
+
+ current_length: int = attention_mask.shape[-1]
+ if current_length != target_length:
+ if attention_mask.device.type == "mps":
+ # HACK: MPS: Does not support padding by greater than dimension of input tensor.
+ # Instead, we can manually construct the padding tensor.
+ padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
+ padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
+ attention_mask = torch.cat([attention_mask, padding], dim=2)
+ else:
+ # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
+ # we want to instead pad by (0, remaining_length), where remaining_length is:
+ # remaining_length: int = target_length - current_length
+ # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+
+ if out_dim == 3:
+ if attention_mask.shape[0] < batch_size * head_size:
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
+ elif out_dim == 4:
+ attention_mask = attention_mask.unsqueeze(1)
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
+
+ return attention_mask
+
+ def norm_encoder_hidden_states(self, encoder_hidden_states):
+ assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
+
+ if isinstance(self.norm_cross, nn.LayerNorm):
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+ elif isinstance(self.norm_cross, nn.GroupNorm):
+ # Group norm norms along the channels dimension and expects
+ # input to be in the shape of (N, C, *). In this case, we want
+ # to norm along the hidden dimension, so we need to move
+ # (batch_size, sequence_length, hidden_size) ->
+ # (batch_size, hidden_size, sequence_length)
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+ else:
+ assert False
+
+ return encoder_hidden_states
+
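+# Illustrative sketch (hyperparameters are assumptions): an `Attention` module
+# delegates its forward pass to an attention processor, which can be swapped at
+# runtime via `set_processor` without touching the projection weights.
+#
+#   attn = Attention(query_dim=320, cross_attention_dim=768, heads=8, dim_head=40)
+#   hidden = torch.randn(1, 64, 320)
+#   text = torch.randn(1, 77, 768)
+#   out = attn(hidden, encoder_hidden_states=text)        # (1, 64, 320)
+#   attn.set_processor(AttnProcessor())                   # force the classic baddbmm path
+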
+
+class AttnProcessor:
+ r"""
+ Default processor for performing attention-related computations.
+ """
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ temb=None,
+ ):
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ query = attn.head_to_batch_dim(query)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
+ hidden_states = torch.bmm(attention_probs, value)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+class LoRALinearLayer(nn.Module):
+ def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
+ super().__init__()
+
+ if rank > min(in_features, out_features):
+ raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")
+
+ self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
+ self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
+ # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
+ # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
+ self.network_alpha = network_alpha
+ self.rank = rank
+
+ nn.init.normal_(self.down.weight, std=1 / rank)
+ nn.init.zeros_(self.up.weight)
+
+ def forward(self, hidden_states):
+ orig_dtype = hidden_states.dtype
+ dtype = self.down.weight.dtype
+
+ down_hidden_states = self.down(hidden_states.to(dtype))
+ up_hidden_states = self.up(down_hidden_states)
+
+ if self.network_alpha is not None:
+ up_hidden_states *= self.network_alpha / self.rank
+
+ return up_hidden_states.to(orig_dtype)
+
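+# Illustrative note (sizes are assumptions): LoRALinearLayer is the low-rank update
+# `up(down(x))`, optionally rescaled by `network_alpha / rank`; the LoRA processors
+# below add it to the frozen projection as, e.g., `attn.to_q(x) + scale * to_q_lora(x)`.
+# Because `up` is zero-initialized, the update starts at zero.
+#
+#   lora = LoRALinearLayer(in_features=320, out_features=320, rank=4)
+#   delta = lora(torch.randn(1, 64, 320))                 # (1, 64, 320), all zeros at init
+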
+
+class LoRAAttnProcessor(nn.Module):
+ r"""
+ Processor for implementing the LoRA attention mechanism.
+
+ Args:
+ hidden_size (`int`, *optional*):
+ The hidden size of the attention layer.
+ cross_attention_dim (`int`, *optional*):
+ The number of channels in the `encoder_hidden_states`.
+ rank (`int`, defaults to 4):
+ The dimension of the LoRA update matrices.
+ network_alpha (`int`, *optional*):
+            Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
+ """
+
+ def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
+ super().__init__()
+
+ self.hidden_size = hidden_size
+ self.cross_attention_dim = cross_attention_dim
+ self.rank = rank
+
+ self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+ self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+ self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+ self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+
+ def __call__(
+ self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
+ ):
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
+ query = attn.head_to_batch_dim(query)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
+
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
+ hidden_states = torch.bmm(attention_probs, value)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+class CustomDiffusionAttnProcessor(nn.Module):
+ r"""
+ Processor for implementing attention for the Custom Diffusion method.
+
+ Args:
+ train_kv (`bool`, defaults to `True`):
+ Whether to newly train the key and value matrices corresponding to the text features.
+ train_q_out (`bool`, defaults to `True`):
+ Whether to newly train query matrices corresponding to the latent image features.
+ hidden_size (`int`, *optional*, defaults to `None`):
+ The hidden size of the attention layer.
+ cross_attention_dim (`int`, *optional*, defaults to `None`):
+ The number of channels in the `encoder_hidden_states`.
+ out_bias (`bool`, defaults to `True`):
+ Whether to include the bias parameter in `train_q_out`.
+ dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probability to use.
+ """
+
+ def __init__(
+ self,
+ train_kv=True,
+ train_q_out=True,
+ hidden_size=None,
+ cross_attention_dim=None,
+ out_bias=True,
+ dropout=0.0,
+ ):
+ super().__init__()
+ self.train_kv = train_kv
+ self.train_q_out = train_q_out
+
+ self.hidden_size = hidden_size
+ self.cross_attention_dim = cross_attention_dim
+
+        # the `_custom_diffusion` suffix in the layer names serves as an id for easy serialization and loading.
+ if self.train_kv:
+ self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+ self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+ if self.train_q_out:
+ self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
+ self.to_out_custom_diffusion = nn.ModuleList([])
+ self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
+ self.to_out_custom_diffusion.append(nn.Dropout(dropout))
+
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+ batch_size, sequence_length, _ = hidden_states.shape
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ if self.train_q_out:
+ query = self.to_q_custom_diffusion(hidden_states)
+ else:
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ crossattn = False
+ encoder_hidden_states = hidden_states
+ else:
+ crossattn = True
+ if attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ if self.train_kv:
+ key = self.to_k_custom_diffusion(encoder_hidden_states)
+ value = self.to_v_custom_diffusion(encoder_hidden_states)
+ else:
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ if crossattn:
+ detach = torch.ones_like(key)
+ detach[:, :1, :] = detach[:, :1, :] * 0.0
+ key = detach * key + (1 - detach) * key.detach()
+ value = detach * value + (1 - detach) * value.detach()
+
+ query = attn.head_to_batch_dim(query)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
+ hidden_states = torch.bmm(attention_probs, value)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ if self.train_q_out:
+ # linear proj
+ hidden_states = self.to_out_custom_diffusion[0](hidden_states)
+ # dropout
+ hidden_states = self.to_out_custom_diffusion[1](hidden_states)
+ else:
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ return hidden_states
+
+
+class AttnAddedKVProcessor:
+ r"""
+ Processor for performing attention-related computations with extra learnable key and value matrices for the text
+ encoder.
+ """
+
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+ residual = hidden_states
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+ query = attn.head_to_batch_dim(query)
+
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
+
+ if not attn.only_cross_attention:
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+ else:
+ key = encoder_hidden_states_key_proj
+ value = encoder_hidden_states_value_proj
+
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
+ hidden_states = torch.bmm(attention_probs, value)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
+ hidden_states = hidden_states + residual
+
+ return hidden_states
+
+
+class AttnAddedKVProcessor2_0:
+ r"""
+ Processor for performing scaled dot-product attention (enabled by default if you're using PyTorch 2.0), with extra
+ learnable key and value matrices for the text encoder.
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+ )
+
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+ residual = hidden_states
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, out_dim=4)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+ query = attn.head_to_batch_dim(query, out_dim=4)
+
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj, out_dim=4)
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)
+
+ if not attn.only_cross_attention:
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+ key = attn.head_to_batch_dim(key, out_dim=4)
+ value = attn.head_to_batch_dim(value, out_dim=4)
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+ else:
+ key = encoder_hidden_states_key_proj
+ value = encoder_hidden_states_value_proj
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
+ hidden_states = hidden_states + residual
+
+ return hidden_states
+
+
+class LoRAAttnAddedKVProcessor(nn.Module):
+ r"""
+ Processor for implementing the LoRA attention mechanism with extra learnable key and value matrices for the text
+ encoder.
+
+ Args:
+ hidden_size (`int`, *optional*):
+ The hidden size of the attention layer.
+ cross_attention_dim (`int`, *optional*, defaults to `None`):
+ The number of channels in the `encoder_hidden_states`.
+ rank (`int`, defaults to 4):
+ The dimension of the LoRA update matrices.
+
+ """
+
+ def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
+ super().__init__()
+
+ self.hidden_size = hidden_size
+ self.cross_attention_dim = cross_attention_dim
+ self.rank = rank
+
+ self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+ self.add_k_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+ self.add_v_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+ self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+ self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+ self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
+ residual = hidden_states
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
+ query = attn.head_to_batch_dim(query)
+
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + scale * self.add_k_proj_lora(
+ encoder_hidden_states
+ )
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + scale * self.add_v_proj_lora(
+ encoder_hidden_states
+ )
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
+
+ if not attn.only_cross_attention:
+ key = attn.to_k(hidden_states) + scale * self.to_k_lora(hidden_states)
+ value = attn.to_v(hidden_states) + scale * self.to_v_lora(hidden_states)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+ else:
+ key = encoder_hidden_states_key_proj
+ value = encoder_hidden_states_value_proj
+
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
+ hidden_states = torch.bmm(attention_probs, value)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
+ hidden_states = hidden_states + residual
+
+ return hidden_states
+
+
+class XFormersAttnAddedKVProcessor:
+ r"""
+ Processor for implementing memory efficient attention using xFormers.
+
+ Args:
+ attention_op (`Callable`, *optional*, defaults to `None`):
+ The base
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
+            use as the attention operator. It is recommended to leave this as `None` and let xFormers choose the best
+            operator.
+ """
+
+ def __init__(self, attention_op: Optional[Callable] = None):
+ self.attention_op = attention_op
+
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+ residual = hidden_states
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+ query = attn.head_to_batch_dim(query)
+
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
+
+ if not attn.only_cross_attention:
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+ else:
+ key = encoder_hidden_states_key_proj
+ value = encoder_hidden_states_value_proj
+
+ hidden_states = xformers.ops.memory_efficient_attention(
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
+ )
+ hidden_states = hidden_states.to(query.dtype)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
+ hidden_states = hidden_states + residual
+
+ return hidden_states
+
+
+class XFormersAttnProcessor:
+ r"""
+ Processor for implementing memory efficient attention using xFormers.
+
+ Args:
+ attention_op (`Callable`, *optional*, defaults to `None`):
+ The base
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
+            use as the attention operator. It is recommended to leave this as `None` and let xFormers choose the best
+            operator.
+ """
+
+ def __init__(self, attention_op: Optional[Callable] = None):
+ self.attention_op = attention_op
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ temb: Optional[torch.FloatTensor] = None,
+ ):
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, key_tokens, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
+ if attention_mask is not None:
+ # expand our mask's singleton query_tokens dimension:
+ # [batch*heads, 1, key_tokens] ->
+ # [batch*heads, query_tokens, key_tokens]
+ # so that it can be added as a bias onto the attention scores that xformers computes:
+ # [batch*heads, query_tokens, key_tokens]
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
+ _, query_tokens, _ = hidden_states.shape
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ query = attn.head_to_batch_dim(query).contiguous()
+ key = attn.head_to_batch_dim(key).contiguous()
+ value = attn.head_to_batch_dim(value).contiguous()
+
+ hidden_states = xformers.ops.memory_efficient_attention(
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
+ )
+ hidden_states = hidden_states.to(query.dtype)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+class AttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ temb=None,
+ ):
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+ inner_dim = hidden_states.shape[-1]
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+class LoRAXFormersAttnProcessor(nn.Module):
+ r"""
+ Processor for implementing the LoRA attention mechanism with memory efficient attention using xFormers.
+
+ Args:
+ hidden_size (`int`, *optional*):
+ The hidden size of the attention layer.
+ cross_attention_dim (`int`, *optional*):
+ The number of channels in the `encoder_hidden_states`.
+ rank (`int`, defaults to 4):
+ The dimension of the LoRA update matrices.
+ attention_op (`Callable`, *optional*, defaults to `None`):
+ The base
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
+ use as the attention operator. It is recommended to set it to `None` and allow xFormers to choose the best
+ operator.
+ network_alpha (`int`, *optional*):
+ Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
+
+ """
+
+ def __init__(
+ self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None, network_alpha=None
+ ):
+ super().__init__()
+
+ self.hidden_size = hidden_size
+ self.cross_attention_dim = cross_attention_dim
+ self.rank = rank
+ self.attention_op = attention_op
+
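+ # low-rank adapters applied on top of the frozen q/k/v/out projections; in __call__ their outputs are
+ # added to the base projections, weighted by `scale`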
+ self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+ self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+ self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+ self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+
+ def __call__(
+ self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
+ ):
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
+ query = attn.head_to_batch_dim(query).contiguous()
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
+
+ key = attn.head_to_batch_dim(key).contiguous()
+ value = attn.head_to_batch_dim(value).contiguous()
+
+ hidden_states = xformers.ops.memory_efficient_attention(
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
+ )
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+class LoRAAttnProcessor2_0(nn.Module):
+ r"""
+ Processor for implementing the LoRA attention mechanism using PyTorch 2.0's memory-efficient scaled dot-product
+ attention.
+
+ Args:
+ hidden_size (`int`):
+ The hidden size of the attention layer.
+ cross_attention_dim (`int`, *optional*):
+ The number of channels in the `encoder_hidden_states`.
+ rank (`int`, defaults to 4):
+ The dimension of the LoRA update matrices.
+ network_alpha (`int`, *optional*):
+ Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
+ """
+
+ def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
+ super().__init__()
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("LoRAAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+
+ self.hidden_size = hidden_size
+ self.cross_attention_dim = cross_attention_dim
+ self.rank = rank
+
+ self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+ self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+ self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+ self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
+ residual = hidden_states
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+ inner_dim = hidden_states.shape[-1]
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
+
+ head_dim = inner_dim // attn.heads
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+class CustomDiffusionXFormersAttnProcessor(nn.Module):
+ r"""
+ Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method.
+
+ Args:
+ train_kv (`bool`, defaults to `True`):
+ Whether to newly train the key and value matrices corresponding to the text features.
+ train_q_out (`bool`, defaults to `False`):
+ Whether to newly train the query matrices corresponding to the latent image features.
+ hidden_size (`int`, *optional*, defaults to `None`):
+ The hidden size of the attention layer.
+ cross_attention_dim (`int`, *optional*, defaults to `None`):
+ The number of channels in the `encoder_hidden_states`.
+ out_bias (`bool`, defaults to `True`):
+ Whether to include the bias parameter in `train_q_out`.
+ dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probability to use.
+ attention_op (`Callable`, *optional*, defaults to `None`):
+ The base
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to use
+ as the attention operator. It is recommended to set it to `None` and allow xFormers to choose the best operator.
+ """
+
+ def __init__(
+ self,
+ train_kv=True,
+ train_q_out=False,
+ hidden_size=None,
+ cross_attention_dim=None,
+ out_bias=True,
+ dropout=0.0,
+ attention_op: Optional[Callable] = None,
+ ):
+ super().__init__()
+ self.train_kv = train_kv
+ self.train_q_out = train_q_out
+
+ self.hidden_size = hidden_size
+ self.cross_attention_dim = cross_attention_dim
+ self.attention_op = attention_op
+
+ # `_custom_diffusion` id for easy serialization and loading.
+ if self.train_kv:
+ self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+ self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+ if self.train_q_out:
+ self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
+ self.to_out_custom_diffusion = nn.ModuleList([])
+ self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
+ self.to_out_custom_diffusion.append(nn.Dropout(dropout))
+
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ if self.train_q_out:
+ query = self.to_q_custom_diffusion(hidden_states)
+ else:
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ crossattn = False
+ encoder_hidden_states = hidden_states
+ else:
+ crossattn = True
+ if attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ if self.train_kv:
+ key = self.to_k_custom_diffusion(encoder_hidden_states)
+ value = self.to_v_custom_diffusion(encoder_hidden_states)
+ else:
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
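+ # for cross-attention, stop gradients flowing through the key/value of the first text token; only the
+ # remaining tokens contribute to the key/value weight updates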
+ if crossattn:
+ detach = torch.ones_like(key)
+ detach[:, :1, :] = detach[:, :1, :] * 0.0
+ key = detach * key + (1 - detach) * key.detach()
+ value = detach * value + (1 - detach) * value.detach()
+
+ query = attn.head_to_batch_dim(query).contiguous()
+ key = attn.head_to_batch_dim(key).contiguous()
+ value = attn.head_to_batch_dim(value).contiguous()
+
+ hidden_states = xformers.ops.memory_efficient_attention(
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
+ )
+ hidden_states = hidden_states.to(query.dtype)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ if self.train_q_out:
+ # linear proj
+ hidden_states = self.to_out_custom_diffusion[0](hidden_states)
+ # dropout
+ hidden_states = self.to_out_custom_diffusion[1](hidden_states)
+ else:
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+ return hidden_states
+
+
+class SlicedAttnProcessor:
+ r"""
+ Processor for implementing sliced attention.
+
+ Args:
+ slice_size (`int`, *optional*):
+ The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
+ `attention_head_dim` must be a multiple of the `slice_size`.
+ """
+
+ def __init__(self, slice_size):
+ self.slice_size = slice_size
+
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+ residual = hidden_states
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+ dim = query.shape[-1]
+ query = attn.head_to_batch_dim(query)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+
+ batch_size_attention, query_tokens, _ = query.shape
+ hidden_states = torch.zeros(
+ (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
+ )
+
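+ # compute attention in slices of `slice_size` along the (batch * heads) dimension so the full attention
+ # matrix is never materialized at once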
+ for i in range(batch_size_attention // self.slice_size):
+ start_idx = i * self.slice_size
+ end_idx = (i + 1) * self.slice_size
+
+ query_slice = query[start_idx:end_idx]
+ key_slice = key[start_idx:end_idx]
+ attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
+
+ attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
+
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
+
+ hidden_states[start_idx:end_idx] = attn_slice
+
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+class SlicedAttnAddedKVProcessor:
+ r"""
+ Processor for implementing sliced attention with extra learnable key and value matrices for the text encoder.
+
+ Args:
+ slice_size (`int`, *optional*):
+ The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
+ `attention_head_dim` must be a multiple of the `slice_size`.
+ """
+
+ def __init__(self, slice_size):
+ self.slice_size = slice_size
+
+ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
+
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+ dim = query.shape[-1]
+ query = attn.head_to_batch_dim(query)
+
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
+
+ if not attn.only_cross_attention:
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+ else:
+ key = encoder_hidden_states_key_proj
+ value = encoder_hidden_states_value_proj
+
+ batch_size_attention, query_tokens, _ = query.shape
+ hidden_states = torch.zeros(
+ (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
+ )
+
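+ # same slicing as in SlicedAttnProcessor: attend in chunks of `slice_size` along the (batch * heads) dimension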
+ for i in range(batch_size_attention // self.slice_size):
+ start_idx = i * self.slice_size
+ end_idx = (i + 1) * self.slice_size
+
+ query_slice = query[start_idx:end_idx]
+ key_slice = key[start_idx:end_idx]
+ attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
+
+ attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
+
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
+
+ hidden_states[start_idx:end_idx] = attn_slice
+
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
+ hidden_states = hidden_states + residual
+
+ return hidden_states
+
+
+AttentionProcessor = Union[
+ AttnProcessor,
+ AttnProcessor2_0,
+ XFormersAttnProcessor,
+ SlicedAttnProcessor,
+ AttnAddedKVProcessor,
+ SlicedAttnAddedKVProcessor,
+ AttnAddedKVProcessor2_0,
+ XFormersAttnAddedKVProcessor,
+ LoRAAttnProcessor,
+ LoRAXFormersAttnProcessor,
+ LoRAAttnProcessor2_0,
+ LoRAAttnAddedKVProcessor,
+ CustomDiffusionAttnProcessor,
+ CustomDiffusionXFormersAttnProcessor,
+]
+
+
+class SpatialNorm(nn.Module):
+ """
+ Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002
+ """
+
+ def __init__(
+ self,
+ f_channels,
+ zq_channels,
+ ):
+ super().__init__()
+ self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
+ self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
+ self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, f, zq):
+ f_size = f.shape[-2:]
+ zq = F.interpolate(zq, size=f_size, mode="nearest")
+ norm_f = self.norm_layer(f)
+ new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
+ return new_f
diff --git a/diffusers/models/autoencoder_kl.py b/diffusers/models/autoencoder_kl.py
new file mode 100644
index 0000000000000000000000000000000000000000..2390d2bc58261c76a38cd18dc48dbd7fb59a4d58
--- /dev/null
+++ b/diffusers/models/autoencoder_kl.py
@@ -0,0 +1,417 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..loaders import FromOriginalVAEMixin
+from ..utils import BaseOutput, apply_forward_hook
+from .attention_processor import AttentionProcessor, AttnProcessor
+from .modeling_utils import ModelMixin
+from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
+
+
+@dataclass
+class AutoencoderKLOutput(BaseOutput):
+ """
+ Output of AutoencoderKL encoding method.
+
+ Args:
+ latent_dist (`DiagonalGaussianDistribution`):
+ Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
+ `DiagonalGaussianDistribution` allows for sampling latents from the distribution.
+ """
+
+ latent_dist: "DiagonalGaussianDistribution"
+
+
+class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
+ r"""
+ A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
+ for all models (such as downloading or saving).
+
+ Parameters:
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
+ Tuple of downsample block types.
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
+ Tuple of upsample block types.
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
+ Tuple of block output channels.
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+ latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
+ sample_size (`int`, *optional*, defaults to `32`): Sample input size.
+ scaling_factor (`float`, *optional*, defaults to 0.18215):
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
+ Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
+ force_upcast (`bool`, *optional*, defaults to `True`):
+ If enabled, it will force the VAE to run in float32 for high-resolution image pipelines, such as SD-XL. The VAE
+ can be fine-tuned / trained to a lower range without losing too much precision, in which case
+ `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 3,
+ down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
+ up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
+ block_out_channels: Tuple[int] = (64,),
+ layers_per_block: int = 1,
+ act_fn: str = "silu",
+ latent_channels: int = 4,
+ norm_num_groups: int = 32,
+ sample_size: int = 32,
+ scaling_factor: float = 0.18215,
+ force_upcast: bool = True,
+ ):
+ super().__init__()
+
+ # pass init params to Encoder
+ self.encoder = Encoder(
+ in_channels=in_channels,
+ out_channels=latent_channels,
+ down_block_types=down_block_types,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ act_fn=act_fn,
+ norm_num_groups=norm_num_groups,
+ double_z=True,
+ )
+
+ # pass init params to Decoder
+ self.decoder = Decoder(
+ in_channels=latent_channels,
+ out_channels=out_channels,
+ up_block_types=up_block_types,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ norm_num_groups=norm_num_groups,
+ act_fn=act_fn,
+ )
+
+ self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
+ self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
+
+ self.use_slicing = False
+ self.use_tiling = False
+
+ # only relevant if vae tiling is enabled
+ self.tile_sample_min_size = self.config.sample_size
+ sample_size = (
+ self.config.sample_size[0]
+ if isinstance(self.config.sample_size, (list, tuple))
+ else self.config.sample_size
+ )
+ self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
+ self.tile_overlap_factor = 0.25
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, (Encoder, Decoder)):
+ module.gradient_checkpointing = value
+
+ def enable_tiling(self, use_tiling: bool = True):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and allows
+ processing larger images.
+ """
+ self.use_tiling = use_tiling
+
+ def disable_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
+ decoding in one step.
+ """
+ self.enable_tiling(False)
+
+ def enable_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.use_slicing = True
+
+ def disable_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
+ decoding in one step.
+ """
+ self.use_slicing = False
+
+ @property
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
+ indexed by their weight names.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "set_processor"):
+ processors[f"{name}.processor"] = module.processor
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
+ def set_default_attn_processor(self):
+ """
+ Disables custom attention processors and sets the default attention implementation.
+ """
+ self.set_attn_processor(AttnProcessor())
+
+ @apply_forward_hook
+ def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
+ if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
+ return self.tiled_encode(x, return_dict=return_dict)
+
+ if self.use_slicing and x.shape[0] > 1:
+ encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
+ h = torch.cat(encoded_slices)
+ else:
+ h = self.encoder(x)
+
+ moments = self.quant_conv(h)
+ posterior = DiagonalGaussianDistribution(moments)
+
+ if not return_dict:
+ return (posterior,)
+
+ return AutoencoderKLOutput(latent_dist=posterior)
+
+ def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+ if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
+ return self.tiled_decode(z, return_dict=return_dict)
+
+ z = self.post_quant_conv(z)
+ dec = self.decoder(z)
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
+
+ @apply_forward_hook
+ def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+ if self.use_slicing and z.shape[0] > 1:
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
+ decoded = torch.cat(decoded_slices)
+ else:
+ decoded = self._decode(z).sample
+
+ if not return_dict:
+ return (decoded,)
+
+ return DecoderOutput(sample=decoded)
+
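+ # linearly cross-fade the overlapping region of two neighboring tiles, vertically (blend_v) or horizontally (blend_h)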
+ def blend_v(self, a, b, blend_extent):
+ blend_extent = min(a.shape[2], b.shape[2], blend_extent)
+ for y in range(blend_extent):
+ b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
+ return b
+
+ def blend_h(self, a, b, blend_extent):
+ blend_extent = min(a.shape[3], b.shape[3], blend_extent)
+ for x in range(blend_extent):
+ b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
+ return b
+
+ def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
+ r"""Encode a batch of images using a tiled encoder.
+
+ When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
+ steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
+ different from non-tiled encoding because each tile is encoded independently, without context from neighboring
+ tiles. To avoid tiling artifacts, the tiles overlap and are blended together to form a smooth output. You may
+ still see tile-sized changes in the output, but they should be much less noticeable.
+
+ Args:
+ x (`torch.FloatTensor`): Input batch of images.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
+ If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
+ `tuple` is returned.
+ """
+ overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
+ blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
+ row_limit = self.tile_latent_min_size - blend_extent
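+ # tiles are sampled every `overlap_size` pixels, so neighboring tiles overlap; the overlapping `blend_extent`
+ # latent rows/columns are blended below, and each tile keeps `row_limit` rows/columns of the result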
+
+ # Split the image into 512x512 tiles and encode them separately.
+ rows = []
+ for i in range(0, x.shape[2], overlap_size):
+ row = []
+ for j in range(0, x.shape[3], overlap_size):
+ tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
+ tile = self.encoder(tile)
+ tile = self.quant_conv(tile)
+ row.append(tile)
+ rows.append(row)
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ # blend the above tile and the left tile
+ # to the current tile and add the current tile to the result row
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
+ result_row.append(tile[:, :, :row_limit, :row_limit])
+ result_rows.append(torch.cat(result_row, dim=3))
+
+ moments = torch.cat(result_rows, dim=2)
+ posterior = DiagonalGaussianDistribution(moments)
+
+ if not return_dict:
+ return (posterior,)
+
+ return AutoencoderKLOutput(latent_dist=posterior)
+
+ def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+ r"""
+ Decode a batch of images using a tiled decoder.
+
+ Args:
+ z (`torch.FloatTensor`): Input batch of latent vectors.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.vae.DecoderOutput`] or `tuple`:
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+ returned.
+ """
+ overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
+ blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
+ row_limit = self.tile_sample_min_size - blend_extent
+
+ # Split z into overlapping 64x64 tiles and decode them separately.
+ # The tiles have an overlap to avoid seams between tiles.
+ rows = []
+ for i in range(0, z.shape[2], overlap_size):
+ row = []
+ for j in range(0, z.shape[3], overlap_size):
+ tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
+ tile = self.post_quant_conv(tile)
+ decoded = self.decoder(tile)
+ row.append(decoded)
+ rows.append(row)
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ # blend the above tile and the left tile
+ # to the current tile and add the current tile to the result row
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
+ result_row.append(tile[:, :, :row_limit, :row_limit])
+ result_rows.append(torch.cat(result_row, dim=3))
+
+ dec = torch.cat(result_rows, dim=2)
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ sample_posterior: bool = False,
+ return_dict: bool = True,
+ generator: Optional[torch.Generator] = None,
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
+ r"""
+ Args:
+ sample (`torch.FloatTensor`): Input sample.
+ sample_posterior (`bool`, *optional*, defaults to `False`):
+ Whether to sample from the posterior.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+ """
+ x = sample
+ posterior = self.encode(x).latent_dist
+ if sample_posterior:
+ z = posterior.sample(generator=generator)
+ else:
+ z = posterior.mode()
+ dec = self.decode(z).sample
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
diff --git a/diffusers/models/controlnet.py b/diffusers/models/controlnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..354cd5851d4c203428b34db20698b3bcb0e4a626
--- /dev/null
+++ b/diffusers/models/controlnet.py
@@ -0,0 +1,823 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..loaders import FromOriginalControlnetMixin
+from ..utils import BaseOutput, logging
+from .attention_processor import AttentionProcessor, AttnProcessor
+from .embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
+from .modeling_utils import ModelMixin
+from .unet_2d_blocks import (
+ CrossAttnDownBlock2D,
+ DownBlock2D,
+ UNetMidBlock2DCrossAttn,
+ get_down_block,
+)
+from .unet_2d_condition import UNet2DConditionModel
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class ControlNetOutput(BaseOutput):
+ """
+ The output of [`ControlNetModel`].
+
+ Args:
+ down_block_res_samples (`tuple[torch.Tensor]`):
+ A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
+ be of shape `(batch_size, channel * resolution, height // resolution, width // resolution)`. Output can be
+ used to condition the original UNet's downsampling activations.
+ mid_block_res_sample (`torch.Tensor`):
+ The activation of the middle block (the lowest sample resolution). Each tensor should be of shape
+ `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
+ Output can be used to condition the original UNet's middle block activation.
+ """
+
+ down_block_res_samples: Tuple[torch.Tensor]
+ mid_block_res_sample: torch.Tensor
+
+
+class ControlNetConditioningEmbedding(nn.Module):
+ """
+ Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
+ [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
+ training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
+ convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
+ (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
+ model) to encode image-space conditions ... into feature maps ..."
+ """
+
+ def __init__(
+ self,
+ conditioning_embedding_channels: int,
+ conditioning_channels: int = 3,
+ block_out_channels: Tuple[int] = (16, 32, 96, 256),
+ ):
+ super().__init__()
+
+ self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
+
+ self.blocks = nn.ModuleList([])
+
+ for i in range(len(block_out_channels) - 1):
+ channel_in = block_out_channels[i]
+ channel_out = block_out_channels[i + 1]
+ self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
+ self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
+
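+ # the final projection is zero-initialized so the conditioning embedding contributes nothing at the start of training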
+ self.conv_out = zero_module(
+ nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
+ )
+
+ def forward(self, conditioning):
+ embedding = self.conv_in(conditioning)
+ embedding = F.silu(embedding)
+
+ for block in self.blocks:
+ embedding = block(embedding)
+ embedding = F.silu(embedding)
+
+ embedding = self.conv_out(embedding)
+
+ return embedding
+
+
+class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
+ """
+ A ControlNet model.
+
+ Args:
+ in_channels (`int`, defaults to 4):
+ The number of channels in the input sample.
+ flip_sin_to_cos (`bool`, defaults to `True`):
+ Whether to flip the sin to cos in the time embedding.
+ freq_shift (`int`, defaults to 0):
+ The frequency shift to apply to the time embedding.
+ down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
+ The tuple of downsample blocks to use.
+ only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
+ block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
+ The tuple of output channels for each block.
+ layers_per_block (`int`, defaults to 2):
+ The number of layers per block.
+ downsample_padding (`int`, defaults to 1):
+ The padding to use for the downsampling convolution.
+ mid_block_scale_factor (`float`, defaults to 1):
+ The scale factor to use for the mid block.
+ act_fn (`str`, defaults to "silu"):
+ The activation function to use.
+ norm_num_groups (`int`, *optional*, defaults to 32):
+ The number of groups to use for the normalization. If None, normalization and activation layers are skipped
+ in post-processing.
+ norm_eps (`float`, defaults to 1e-5):
+ The epsilon to use for the normalization.
+ cross_attention_dim (`int`, defaults to 1280):
+ The dimension of the cross attention features.
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
+ encoder_hid_dim (`int`, *optional*, defaults to None):
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
+ dimension to `cross_attention_dim`.
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
+ embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
+ attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
+ The dimension of the attention heads.
+ use_linear_projection (`bool`, defaults to `False`):
+ class_embed_type (`str`, *optional*, defaults to `None`):
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
+ addition_embed_type (`str`, *optional*, defaults to `None`):
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
+ "text". "text" will use the `TextTimeEmbedding` layer.
+ num_class_embeds (`int`, *optional*, defaults to 0):
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
+ class conditioning with `class_embed_type` equal to `None`.
+ upcast_attention (`bool`, defaults to `False`):
+ resnet_time_scale_shift (`str`, defaults to `"default"`):
+ Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
+ projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
+ The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
+ `class_embed_type="projection"`.
+ controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
+ The channel order of the conditioning image. Will convert to `rgb` if it's `bgr`.
+ conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
+ The tuple of output channel for each block in the `conditioning_embedding` layer.
+ global_pool_conditions (`bool`, defaults to `False`):
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 4,
+ conditioning_channels: int = 3,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ down_block_types: Tuple[str] = (
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "DownBlock2D",
+ ),
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ layers_per_block: int = 2,
+ downsample_padding: int = 1,
+ mid_block_scale_factor: float = 1,
+ act_fn: str = "silu",
+ norm_num_groups: Optional[int] = 32,
+ norm_eps: float = 1e-5,
+ cross_attention_dim: int = 1280,
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+ encoder_hid_dim: Optional[int] = None,
+ encoder_hid_dim_type: Optional[str] = None,
+ attention_head_dim: Union[int, Tuple[int]] = 8,
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
+ use_linear_projection: bool = False,
+ class_embed_type: Optional[str] = None,
+ addition_embed_type: Optional[str] = None,
+ addition_time_embed_dim: Optional[int] = None,
+ num_class_embeds: Optional[int] = None,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+ projection_class_embeddings_input_dim: Optional[int] = None,
+ controlnet_conditioning_channel_order: str = "rgb",
+ conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
+ global_pool_conditions: bool = False,
+ addition_embed_type_num_heads=64,
+ ):
+ super().__init__()
+
+ # If `num_attention_heads` is not defined (which is the case for most models)
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
+ # which is why we correct for the naming here.
+ num_attention_heads = num_attention_heads or attention_head_dim
+
+ # Check inputs
+ if len(block_out_channels) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
+ )
+
+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+
+ # input
+ conv_in_kernel = 3
+ conv_in_padding = (conv_in_kernel - 1) // 2
+ self.conv_in = nn.Conv2d(
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
+ )
+
+ # time
+ time_embed_dim = block_out_channels[0] * 4
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+ timestep_input_dim = block_out_channels[0]
+ self.time_embedding = TimestepEmbedding(
+ timestep_input_dim,
+ time_embed_dim,
+ act_fn=act_fn,
+ )
+
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
+ encoder_hid_dim_type = "text_proj"
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
+
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
+ raise ValueError(
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
+ )
+
+ if encoder_hid_dim_type == "text_proj":
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
+ elif encoder_hid_dim_type == "text_image_proj":
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
+ # case when `encoder_hid_dim_type == "text_image_proj"` (Kandinsky 2.1).
+ self.encoder_hid_proj = TextImageProjection(
+ text_embed_dim=encoder_hid_dim,
+ image_embed_dim=cross_attention_dim,
+ cross_attention_dim=cross_attention_dim,
+ )
+
+ elif encoder_hid_dim_type is not None:
+ raise ValueError(
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
+ )
+ else:
+ self.encoder_hid_proj = None
+
+ # class embedding
+ if class_embed_type is None and num_class_embeds is not None:
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+ elif class_embed_type == "timestep":
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+ elif class_embed_type == "identity":
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
+ elif class_embed_type == "projection":
+ if projection_class_embeddings_input_dim is None:
+ raise ValueError(
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
+ )
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
+ # 2. it projects from an arbitrary input dimension.
+ #
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+ else:
+ self.class_embedding = None
+
+ if addition_embed_type == "text":
+ if encoder_hid_dim is not None:
+ text_time_embedding_from_dim = encoder_hid_dim
+ else:
+ text_time_embedding_from_dim = cross_attention_dim
+
+ self.add_embedding = TextTimeEmbedding(
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
+ )
+ elif addition_embed_type == "text_image":
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1).
+ self.add_embedding = TextImageTimeEmbedding(
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
+ )
+ elif addition_embed_type == "text_time":
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+
+ elif addition_embed_type is not None:
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
+
+ # control net conditioning embedding
+ self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
+ conditioning_embedding_channels=block_out_channels[0],
+ block_out_channels=conditioning_embedding_out_channels,
+ conditioning_channels=conditioning_channels,
+ )
+
+ self.down_blocks = nn.ModuleList([])
+ self.controlnet_down_blocks = nn.ModuleList([])
+
+ if isinstance(only_cross_attention, bool):
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
+
+ if isinstance(attention_head_dim, int):
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
+
+ if isinstance(num_attention_heads, int):
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
+
+ # down
+ output_channel = block_out_channels[0]
+
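+ # every ControlNet output projection is a zero-initialized 1x1 conv ("zero convolution"), so the control branch
+ # initially leaves the base UNet's behavior unchanged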
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+ controlnet_block = zero_module(controlnet_block)
+ self.controlnet_down_blocks.append(controlnet_block)
+
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block,
+ transformer_layers_per_block=transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads[i],
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+ downsample_padding=downsample_padding,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ self.down_blocks.append(down_block)
+
+ for _ in range(layers_per_block):
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+ controlnet_block = zero_module(controlnet_block)
+ self.controlnet_down_blocks.append(controlnet_block)
+
+ if not is_final_block:
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+ controlnet_block = zero_module(controlnet_block)
+ self.controlnet_down_blocks.append(controlnet_block)
+
+ # mid
+ mid_block_channel = block_out_channels[-1]
+
+ controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
+ controlnet_block = zero_module(controlnet_block)
+ self.controlnet_mid_block = controlnet_block
+
+ self.mid_block = UNetMidBlock2DCrossAttn(
+ transformer_layers_per_block=transformer_layers_per_block[-1],
+ in_channels=mid_block_channel,
+ temb_channels=time_embed_dim,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads[-1],
+ resnet_groups=norm_num_groups,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+ )
+
+ @classmethod
+ def from_unet(
+ cls,
+ unet: UNet2DConditionModel,
+ controlnet_conditioning_channel_order: str = "rgb",
+ conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
+ load_weights_from_unet: bool = True,
+ ):
+ r"""
+ Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].
+
+ Parameters:
+ unet (`UNet2DConditionModel`):
+ The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
+ where applicable.
+ """
+ transformer_layers_per_block = (
+ unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
+ )
+ encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
+ encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
+ addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
+ addition_time_embed_dim = (
+ unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
+ )
+
+ controlnet = cls(
+ encoder_hid_dim=encoder_hid_dim,
+ encoder_hid_dim_type=encoder_hid_dim_type,
+ addition_embed_type=addition_embed_type,
+ addition_time_embed_dim=addition_time_embed_dim,
+ transformer_layers_per_block=transformer_layers_per_block,
+ in_channels=unet.config.in_channels,
+ flip_sin_to_cos=unet.config.flip_sin_to_cos,
+ freq_shift=unet.config.freq_shift,
+ down_block_types=unet.config.down_block_types,
+ only_cross_attention=unet.config.only_cross_attention,
+ block_out_channels=unet.config.block_out_channels,
+ layers_per_block=unet.config.layers_per_block,
+ downsample_padding=unet.config.downsample_padding,
+ mid_block_scale_factor=unet.config.mid_block_scale_factor,
+ act_fn=unet.config.act_fn,
+ norm_num_groups=unet.config.norm_num_groups,
+ norm_eps=unet.config.norm_eps,
+ cross_attention_dim=unet.config.cross_attention_dim,
+ attention_head_dim=unet.config.attention_head_dim,
+ num_attention_heads=unet.config.num_attention_heads,
+ use_linear_projection=unet.config.use_linear_projection,
+ class_embed_type=unet.config.class_embed_type,
+ num_class_embeds=unet.config.num_class_embeds,
+ upcast_attention=unet.config.upcast_attention,
+ resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
+ projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
+ controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
+ conditioning_embedding_out_channels=conditioning_embedding_out_channels,
+ )
+
+ if load_weights_from_unet:
+ controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
+ controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
+ controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
+
+ if controlnet.class_embedding:
+ controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
+
+ controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
+ controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())
+
+ return controlnet
+
+ @property
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
+ indexed by their weight names.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "set_processor"):
+ processors[f"{name}.processor"] = module.processor
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
+ def set_default_attn_processor(self):
+ """
+ Disables custom attention processors and sets the default attention implementation.
+ """
+ self.set_attn_processor(AttnProcessor())
+
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
+ def set_attention_slice(self, slice_size):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
+
+ Args:
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+ must be a multiple of `slice_size`.
+ """
+ sliceable_head_dims = []
+
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
+ if hasattr(module, "set_attention_slice"):
+ sliceable_head_dims.append(module.sliceable_head_dim)
+
+ for child in module.children():
+ fn_recursive_retrieve_sliceable_dims(child)
+
+ # retrieve number of attention layers
+ for module in self.children():
+ fn_recursive_retrieve_sliceable_dims(module)
+
+ num_sliceable_layers = len(sliceable_head_dims)
+
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
+ elif slice_size == "max":
+ # make smallest slice possible
+ slice_size = num_sliceable_layers * [1]
+
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+
+ if len(slice_size) != len(sliceable_head_dims):
+ raise ValueError(
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
+ )
+
+ for i in range(len(slice_size)):
+ size = slice_size[i]
+ dim = sliceable_head_dims[i]
+ if size is not None and size > dim:
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
+
+ # Recursively walk through all the children.
+ # Any children which exposes the set_attention_slice method
+ # gets the message
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+ if hasattr(module, "set_attention_slice"):
+ module.set_attention_slice(slice_size.pop())
+
+ for child in module.children():
+ fn_recursive_set_attention_slice(child, slice_size)
+
+ reversed_slice_size = list(reversed(slice_size))
+ for module in self.children():
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
+ module.gradient_checkpointing = value
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ controlnet_cond: torch.FloatTensor,
+ conditioning_scale: float = 1.0,
+ class_labels: Optional[torch.Tensor] = None,
+ timestep_cond: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guess_mode: bool = False,
+ return_dict: bool = True,
+ ) -> Union[ControlNetOutput, Tuple]:
+ """
+ The [`ControlNetModel`] forward method.
+
+ Args:
+ sample (`torch.FloatTensor`):
+ The noisy input tensor.
+ timestep (`Union[torch.Tensor, float, int]`):
+ The number of timesteps to denoise an input.
+ encoder_hidden_states (`torch.Tensor`):
+ The encoder hidden states.
+ controlnet_cond (`torch.FloatTensor`):
+ The conditioning image tensor of shape `(batch_size, channels, height, width)`.
+ conditioning_scale (`float`, defaults to `1.0`):
+ The scale factor for ControlNet outputs.
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
+ timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
+ added_cond_kwargs (`dict`):
+ Additional conditions for the Stable Diffusion XL UNet.
+ cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
+ guess_mode (`bool`, defaults to `False`):
+ In this mode, the ControlNet encoder tries its best to recognize the content of the input image even if
+ you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
+ return_dict (`bool`, defaults to `True`):
+ Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.controlnet.ControlNetOutput`] **or** `tuple`:
+ If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
+ returned where the first element is the sample tensor.
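+
+ Example (a minimal sketch of a single denoising step; shapes follow the default SD-style config and are
+ illustrative only):
+
+ import torch
+
+ controlnet = ControlNetModel()
+ sample = torch.randn(1, 4, 64, 64) # noisy latents
+ text_emb = torch.randn(1, 77, 1280) # encoder hidden states (matches the default `cross_attention_dim`)
+ cond_image = torch.randn(1, 3, 512, 512) # conditioning image, 8x the latent resolution
+ out = controlnet(sample, timestep=10, encoder_hidden_states=text_emb, controlnet_cond=cond_image)
+ # out.down_block_res_samples: 12 residuals; out.mid_block_res_sample: (1, 1280, 8, 8)
+ # downstream, these are added to the corresponding residuals of the frozen UNet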
+ """
+ # check channel order
+ channel_order = self.config.controlnet_conditioning_channel_order
+
+ if channel_order == "rgb":
+ # in rgb order by default
+ ...
+ elif channel_order == "bgr":
+ controlnet_cond = torch.flip(controlnet_cond, dims=[1])
+ else:
+ raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
+
+ # prepare attention_mask
+ if attention_mask is not None:
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps.expand(sample.shape[0])
+
+ t_emb = self.time_proj(timesteps)
+
+ # timesteps does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=sample.dtype)
+
+ emb = self.time_embedding(t_emb, timestep_cond)
+ aug_emb = None
+
+ if self.class_embedding is not None:
+ if class_labels is None:
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
+
+ if self.config.class_embed_type == "timestep":
+ class_labels = self.time_proj(class_labels)
+
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
+ emb = emb + class_emb
+
+ if "addition_embed_type" in self.config:
+ if self.config.addition_embed_type == "text":
+ aug_emb = self.add_embedding(encoder_hidden_states)
+
+ elif self.config.addition_embed_type == "text_time":
+ if "text_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
+ )
+ text_embeds = added_cond_kwargs.get("text_embeds")
+ if "time_ids" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
+ )
+ time_ids = added_cond_kwargs.get("time_ids")
+ time_embeds = self.add_time_proj(time_ids.flatten())
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
+
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
+ add_embeds = add_embeds.to(emb.dtype)
+ aug_emb = self.add_embedding(add_embeds)
+
+ emb = emb + aug_emb if aug_emb is not None else emb
+
+ # 2. pre-process
+ sample = self.conv_in(sample)
+
+ controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
+
+ sample = sample + controlnet_cond
+
+ # 3. down
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ )
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+
+ down_block_res_samples += res_samples
+
+ # 4. mid
+ if self.mid_block is not None:
+ sample = self.mid_block(
+ sample,
+ emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ )
+
+ # 5. Control net blocks
+
+ controlnet_down_block_res_samples = ()
+
+ for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
+ down_block_res_sample = controlnet_block(down_block_res_sample)
+ controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
+
+ down_block_res_samples = controlnet_down_block_res_samples
+
+ mid_block_res_sample = self.controlnet_mid_block(sample)
+
+ # 6. scaling
+ if guess_mode and not self.config.global_pool_conditions:
+ scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
+
+ scales = scales * conditioning_scale
+ down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
+ mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
+ else:
+ down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
+ mid_block_res_sample = mid_block_res_sample * conditioning_scale
+
+ if self.config.global_pool_conditions:
+ down_block_res_samples = [
+ torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
+ ]
+ mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
+
+ if not return_dict:
+ return (down_block_res_samples, mid_block_res_sample)
+
+ return ControlNetOutput(
+ down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
+ )
+
+
+def zero_module(module):
+ for p in module.parameters():
+ nn.init.zeros_(p)
+ return module
diff --git a/diffusers/models/controlnet_composer.py b/diffusers/models/controlnet_composer.py
new file mode 100644
index 0000000000000000000000000000000000000000..57ba3db9ede36664f2f366078276701529f3cda9
--- /dev/null
+++ b/diffusers/models/controlnet_composer.py
@@ -0,0 +1,886 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..loaders import FromOriginalControlnetMixin
+from ..utils import BaseOutput, logging
+from .attention_processor import AttentionProcessor, AttnProcessor
+from .embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
+from .modeling_utils import ModelMixin
+from .unet_2d_blocks import (
+ CrossAttnDownBlock2D,
+ DownBlock2D,
+ UNetMidBlock2DCrossAttn,
+ get_down_block,
+)
+from .unet_2d_condition import UNet2DConditionModel
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class ControlNetOutput(BaseOutput):
+ """
+ The output of [`ControlNetModel`].
+
+ Args:
+ down_block_res_samples (`tuple[torch.Tensor]`):
+ A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
+ be of shape `(batch_size, channel * resolution, height // resolution, width // resolution)`. Output can be
+ used to condition the original UNet's downsampling activations.
+ mid_block_res_sample (`torch.Tensor`):
+ The activation of the middle block (the lowest sample resolution). The tensor should be of shape
+ `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
+ Output can be used to condition the original UNet's middle block activation.
+ """
+
+ down_block_res_samples: Tuple[torch.Tensor]
+ mid_block_res_sample: torch.Tensor
+
+
+class ControlNetConditioningEmbedding(nn.Module):
+ """
+ Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
+ [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
+ training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
+ convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
+ (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
+ model) to encode image-space conditions ... into feature maps ..."
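+
+ This composer variant runs one such encoder per condition (`cond_num` branches) and fuses the resulting
+ embeddings either by `"sum"`/`"avg"`, by a zero-initialized learned convolution (`"learn"`), or, when a
+ `gating_matrix` is passed, by a per-pixel weighted sum over the conditions.
+
+ Example (a minimal sketch; shapes are illustrative):
+
+ import torch
+
+ embed = ControlNetConditioningEmbedding(conditioning_embedding_channels=320, cond_num=3, fusion="sum")
+ conds = [torch.randn(1, 3, 512, 512) for _ in range(3)] # e.g. depth / edge / segmentation maps in [-1, 1]
+ gating = torch.softmax(torch.randn(1, 3, 64, 64), dim=1) # per-pixel weights over the 3 conditions
+ out = embed(conds, gating_matrix=gating) # (1, 320, 64, 64); all zeros at init (zero-initialized conv_out)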
+ """
+
+ def __init__(
+ self,
+ conditioning_embedding_channels: int,
+ conditioning_channels: int = 3,
+ block_out_channels: Tuple[int] = (16, 32, 96, 256),
+ cond_num: int = 3,
+ fusion: str = "sum",
+ normalize_to_0_1: bool = True,
+ ):
+ super().__init__()
+
+ self.cond_num = cond_num
+ self.fusion = fusion
+ self.normalize_to_0_1 = normalize_to_0_1
+ self.conv_in_list = nn.ModuleList([])
+ for i in range(self.cond_num):
+ self.conv_in_list.append(nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1))
+
+ self.blocks_list = nn.ModuleList([])
+ for i in range(self.cond_num):
+ blocks = nn.ModuleList([])
+ for j in range(len(block_out_channels) - 1):
+ channel_in = block_out_channels[j]
+ channel_out = block_out_channels[j + 1]
+ blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
+ blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
+ self.blocks_list.append(blocks)
+
+ self.conv_out_list = nn.ModuleList([])
+ for i in range(self.cond_num):
+ self.conv_out_list.append(zero_module(nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)))
+
+ if self.fusion == "sum":
+ pass
+ elif self.fusion == "avg":
+ pass
+ elif self.fusion == "learn":
+ self.fusion_conv = zero_module(nn.Conv2d(conditioning_embedding_channels * self.cond_num, conditioning_embedding_channels, kernel_size=3, padding=1))
+ else:
+ raise ValueError(f"unsupported `fusion` mode: {fusion}")
+
+ def forward(self, conditioning, gating_matrix=None):
+ assert len(conditioning) == self.cond_num
+
+ if self.normalize_to_0_1:
+ conditioning = [x / 2. + 0.5 for x in conditioning]
+
+ embedding_list = []
+ for i in range(self.cond_num):
+ embedding = self.conv_in_list[i](conditioning[i])
+ embedding = F.silu(embedding)
+
+ for block in self.blocks_list[i]:
+ embedding = block(embedding)
+ embedding = F.silu(embedding)
+
+ embedding = self.conv_out_list[i](embedding)
+ embedding_list.append(embedding)
+
+ if gating_matrix is None:
+ if self.fusion == "sum" or self.fusion == "avg":
+ stacked_tensor = torch.stack(embedding_list, dim=0)
+ embedding = torch.sum(stacked_tensor, dim=0)
+ if self.fusion == "avg":
+ embedding = embedding / self.cond_num
+ elif self.fusion == "learn":
+ concat_tensor = torch.cat(embedding_list, dim=1)
+ embedding = self.fusion_conv(concat_tensor)
+ else:
+ raise ValueError(f"unsupported `fusion` mode: {self.fusion}")
+ else:
+ # each entry of embedding_list has shape (B, C, H, W); gating_matrix has shape (B, cond_num, H, W)
+ # stacked embeddings (B, cond_num, C, H, W) are weighted by (B, cond_num, 1, H, W) and summed over dim 1,
+ # giving a fused embedding of shape (B, C, H, W)
+ embedding = torch.stack(embedding_list, dim=1) # (B, 3, C, H, W)
+ gating_matrix = gating_matrix.unsqueeze(2) # (B, 3, 1, H, W)
+ embedding = embedding * gating_matrix
+ embedding = embedding.sum(dim=1)
+
+ return embedding
+
+
+class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
+ """
+ A ControlNet model.
+
+ Args:
+ in_channels (`int`, defaults to 4):
+ The number of channels in the input sample.
+ flip_sin_to_cos (`bool`, defaults to `True`):
+ Whether to flip the sin to cos in the time embedding.
+ freq_shift (`int`, defaults to 0):
+ The frequency shift to apply to the time embedding.
+ down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
+ The tuple of downsample blocks to use.
+ only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
+ block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
+ The tuple of output channels for each block.
+ layers_per_block (`int`, defaults to 2):
+ The number of layers per block.
+ downsample_padding (`int`, defaults to 1):
+ The padding to use for the downsampling convolution.
+ mid_block_scale_factor (`float`, defaults to 1):
+ The scale factor to use for the mid block.
+ act_fn (`str`, defaults to "silu"):
+ The activation function to use.
+ norm_num_groups (`int`, *optional*, defaults to 32):
+ The number of groups to use for the normalization. If None, normalization and activation layers is skipped
+ in post-processing.
+ norm_eps (`float`, defaults to 1e-5):
+ The epsilon to use for the normalization.
+ cross_attention_dim (`int`, defaults to 1280):
+ The dimension of the cross attention features.
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
+ encoder_hid_dim (`int`, *optional*, defaults to None):
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
+ dimension to `cross_attention_dim`.
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
+ embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
+ attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
+ The dimension of the attention heads.
+ use_linear_projection (`bool`, defaults to `False`):
+ class_embed_type (`str`, *optional*, defaults to `None`):
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
+ addition_embed_type (`str`, *optional*, defaults to `None`):
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
+ "text". "text" will use the `TextTimeEmbedding` layer.
+ num_class_embeds (`int`, *optional*, defaults to 0):
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
+ class conditioning with `class_embed_type` equal to `None`.
+ upcast_attention (`bool`, defaults to `False`):
+ resnet_time_scale_shift (`str`, defaults to `"default"`):
+ Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
+ projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
+ The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
+ `class_embed_type="projection"`.
+ controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
+ The channel order of the conditioning image. Will convert to `rgb` if it's `bgr`.
+ conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
+ The tuple of output channel for each block in the `conditioning_embedding` layer.
+ global_pool_conditions (`bool`, defaults to `False`):
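+ Whether to global-average-pool each ControlNet residual over its spatial dimensions before returning it.
+ cond_num (`int`, defaults to 3):
+ The number of conditioning images fused by [`ControlNetConditioningEmbedding`].
+ fusion (`str`, defaults to `"sum"`):
+ How the per-condition embeddings are combined: `"sum"`, `"avg"`, or `"learn"` (a zero-initialized 3x3
+ convolution over their channel-wise concatenation). Ignored when a `gating_matrix` is passed to `forward`.
+ normalize_to_0_1 (`bool`, defaults to `True`):
+ Whether conditioning images are rescaled from `[-1, 1]` to `[0, 1]` before being embedded.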
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 4,
+ conditioning_channels: int = 3,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ down_block_types: Tuple[str] = (
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "DownBlock2D",
+ ),
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ layers_per_block: int = 2,
+ downsample_padding: int = 1,
+ mid_block_scale_factor: float = 1,
+ act_fn: str = "silu",
+ norm_num_groups: Optional[int] = 32,
+ norm_eps: float = 1e-5,
+ cross_attention_dim: int = 1280,
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+ encoder_hid_dim: Optional[int] = None,
+ encoder_hid_dim_type: Optional[str] = None,
+ attention_head_dim: Union[int, Tuple[int]] = 8,
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
+ use_linear_projection: bool = False,
+ class_embed_type: Optional[str] = None,
+ addition_embed_type: Optional[str] = None,
+ addition_time_embed_dim: Optional[int] = None,
+ num_class_embeds: Optional[int] = None,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+ projection_class_embeddings_input_dim: Optional[int] = None,
+ controlnet_conditioning_channel_order: str = "rgb",
+ conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
+ global_pool_conditions: bool = False,
+ addition_embed_type_num_heads=64,
+ cond_num: int = 3,
+ fusion: str = "sum",
+ normalize_to_0_1: bool = True,
+ ):
+ super().__init__()
+
+ # If `num_attention_heads` is not defined (which is the case for most models)
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
+ # which is why we correct for the naming here.
+ num_attention_heads = num_attention_heads or attention_head_dim
+
+ # Check inputs
+ if len(block_out_channels) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
+ )
+
+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+
+ # input
+ conv_in_kernel = 3
+ conv_in_padding = (conv_in_kernel - 1) // 2
+ self.conv_in = nn.Conv2d(
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
+ )
+
+ # time
+ time_embed_dim = block_out_channels[0] * 4
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+ timestep_input_dim = block_out_channels[0]
+ self.time_embedding = TimestepEmbedding(
+ timestep_input_dim,
+ time_embed_dim,
+ act_fn=act_fn,
+ )
+
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
+ encoder_hid_dim_type = "text_proj"
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
+
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
+ raise ValueError(
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
+ )
+
+ if encoder_hid_dim_type == "text_proj":
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
+ elif encoder_hid_dim_type == "text_image_proj":
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently
+ # only use case (`encoder_hid_dim_type == "text_image_proj"`, Kandinsky 2.1)
+ self.encoder_hid_proj = TextImageProjection(
+ text_embed_dim=encoder_hid_dim,
+ image_embed_dim=cross_attention_dim,
+ cross_attention_dim=cross_attention_dim,
+ )
+
+ elif encoder_hid_dim_type is not None:
+ raise ValueError(
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
+ )
+ else:
+ self.encoder_hid_proj = None
+
+ # class embedding
+ if class_embed_type is None and num_class_embeds is not None:
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+ elif class_embed_type == "timestep":
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+ elif class_embed_type == "identity":
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
+ elif class_embed_type == "projection":
+ if projection_class_embeddings_input_dim is None:
+ raise ValueError(
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
+ )
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
+ # 2. it projects from an arbitrary input dimension.
+ #
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+ else:
+ self.class_embedding = None
+
+ if addition_embed_type == "text":
+ if encoder_hid_dim is not None:
+ text_time_embedding_from_dim = encoder_hid_dim
+ else:
+ text_time_embedding_from_dim = cross_attention_dim
+
+ self.add_embedding = TextTimeEmbedding(
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
+ )
+ elif addition_embed_type == "text_image":
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently
+ # only use case (`addition_embed_type == "text_image"`, Kandinsky 2.1)
+ self.add_embedding = TextImageTimeEmbedding(
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
+ )
+ elif addition_embed_type == "text_time":
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+
+ elif addition_embed_type is not None:
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
+
+ # control net conditioning embedding
+ self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
+ conditioning_embedding_channels=block_out_channels[0],
+ block_out_channels=conditioning_embedding_out_channels,
+ conditioning_channels=conditioning_channels,
+ cond_num=cond_num,
+ fusion=fusion,
+ normalize_to_0_1=normalize_to_0_1,
+ )
+
+ self.down_blocks = nn.ModuleList([])
+ self.controlnet_down_blocks = nn.ModuleList([])
+
+ if isinstance(only_cross_attention, bool):
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
+
+ if isinstance(attention_head_dim, int):
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
+
+ if isinstance(num_attention_heads, int):
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
+
+ # down
+ output_channel = block_out_channels[0]
+
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+ controlnet_block = zero_module(controlnet_block)
+ self.controlnet_down_blocks.append(controlnet_block)
+
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block,
+ transformer_layers_per_block=transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads[i],
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+ downsample_padding=downsample_padding,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ self.down_blocks.append(down_block)
+
+ for _ in range(layers_per_block):
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+ controlnet_block = zero_module(controlnet_block)
+ self.controlnet_down_blocks.append(controlnet_block)
+
+ if not is_final_block:
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+ controlnet_block = zero_module(controlnet_block)
+ self.controlnet_down_blocks.append(controlnet_block)
+
+ # mid
+ mid_block_channel = block_out_channels[-1]
+
+ controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
+ controlnet_block = zero_module(controlnet_block)
+ self.controlnet_mid_block = controlnet_block
+
+ self.mid_block = UNetMidBlock2DCrossAttn(
+ transformer_layers_per_block=transformer_layers_per_block[-1],
+ in_channels=mid_block_channel,
+ temb_channels=time_embed_dim,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads[-1],
+ resnet_groups=norm_num_groups,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+ )
+
+ @classmethod
+ def from_unet(
+ cls,
+ unet: UNet2DConditionModel,
+ controlnet_conditioning_channel_order: str = "rgb",
+ conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
+ load_weights_from_unet: bool = True,
+ cond_num: int = 3,
+ fusion: str = "sum",
+ normalize_to_0_1: bool = True,
+ ):
+ r"""
+ Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].
+
+ Parameters:
+ unet (`UNet2DConditionModel`):
+ The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
+ where applicable.
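+
+ Example (a minimal sketch; the model id and the fusion choice are illustrative):
+
+ from diffusers import UNet2DConditionModel
+
+ unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
+ controlnet = ControlNetModel.from_unet(unet, cond_num=3, fusion="learn")
+ # the multi-condition `controlnet_cond_embedding` is newly initialized rather than copied from the UNet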
+ """
+ transformer_layers_per_block = (
+ unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
+ )
+ encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
+ encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
+ addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
+ addition_time_embed_dim = (
+ unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
+ )
+
+ controlnet = cls(
+ encoder_hid_dim=encoder_hid_dim,
+ encoder_hid_dim_type=encoder_hid_dim_type,
+ addition_embed_type=addition_embed_type,
+ addition_time_embed_dim=addition_time_embed_dim,
+ transformer_layers_per_block=transformer_layers_per_block,
+ in_channels=unet.config.in_channels,
+ flip_sin_to_cos=unet.config.flip_sin_to_cos,
+ freq_shift=unet.config.freq_shift,
+ down_block_types=unet.config.down_block_types,
+ only_cross_attention=unet.config.only_cross_attention,
+ block_out_channels=unet.config.block_out_channels,
+ layers_per_block=unet.config.layers_per_block,
+ downsample_padding=unet.config.downsample_padding,
+ mid_block_scale_factor=unet.config.mid_block_scale_factor,
+ act_fn=unet.config.act_fn,
+ norm_num_groups=unet.config.norm_num_groups,
+ norm_eps=unet.config.norm_eps,
+ cross_attention_dim=unet.config.cross_attention_dim,
+ attention_head_dim=unet.config.attention_head_dim,
+ num_attention_heads=unet.config.num_attention_heads,
+ use_linear_projection=unet.config.use_linear_projection,
+ class_embed_type=unet.config.class_embed_type,
+ num_class_embeds=unet.config.num_class_embeds,
+ upcast_attention=unet.config.upcast_attention,
+ resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
+ projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
+ controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
+ conditioning_embedding_out_channels=conditioning_embedding_out_channels,
+ cond_num=cond_num,
+ fusion=fusion,
+ normalize_to_0_1=normalize_to_0_1,
+ )
+
+ if load_weights_from_unet:
+ controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
+ controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
+ controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
+
+ if controlnet.class_embedding:
+ controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
+
+ controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
+ controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())
+
+ return controlnet
+
+ @property
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
+ indexed by their weight names.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "set_processor"):
+ processors[f"{name}.processor"] = module.processor
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
+ def set_default_attn_processor(self):
+ """
+ Disables custom attention processors and sets the default attention implementation.
+ """
+ self.set_attn_processor(AttnProcessor())
+
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
+ def set_attention_slice(self, slice_size):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
+
+ Args:
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+ must be a multiple of `slice_size`.
+ """
+ sliceable_head_dims = []
+
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
+ if hasattr(module, "set_attention_slice"):
+ sliceable_head_dims.append(module.sliceable_head_dim)
+
+ for child in module.children():
+ fn_recursive_retrieve_sliceable_dims(child)
+
+ # retrieve number of attention layers
+ for module in self.children():
+ fn_recursive_retrieve_sliceable_dims(module)
+
+ num_sliceable_layers = len(sliceable_head_dims)
+
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
+ elif slice_size == "max":
+ # make smallest slice possible
+ slice_size = num_sliceable_layers * [1]
+
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+
+ if len(slice_size) != len(sliceable_head_dims):
+ raise ValueError(
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
+ )
+
+ for i in range(len(slice_size)):
+ size = slice_size[i]
+ dim = sliceable_head_dims[i]
+ if size is not None and size > dim:
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
+
+ # Recursively walk through all the children.
+ # Any children which exposes the set_attention_slice method
+ # gets the message
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+ if hasattr(module, "set_attention_slice"):
+ module.set_attention_slice(slice_size.pop())
+
+ for child in module.children():
+ fn_recursive_set_attention_slice(child, slice_size)
+
+ reversed_slice_size = list(reversed(slice_size))
+ for module in self.children():
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
+ module.gradient_checkpointing = value
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ controlnet_cond: List[torch.FloatTensor],
+ conditioning_scale: float = 1.0,
+ class_labels: Optional[torch.Tensor] = None,
+ timestep_cond: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guess_mode: bool = False,
+ return_dict: bool = True,
+ gating_matrix=None,
+ ) -> Union[ControlNetOutput, Tuple]:
+ """
+ The [`ControlNetModel`] forward method.
+
+ Args:
+ sample (`torch.FloatTensor`):
+ The noisy input tensor.
+ timestep (`Union[torch.Tensor, float, int]`):
+ The number of timesteps to denoise an input.
+ encoder_hidden_states (`torch.Tensor`):
+ The encoder hidden states.
+ controlnet_cond (`List[torch.FloatTensor]`):
+ A list of `cond_num` conditioning image tensors, each of shape `(batch_size, channels, height, width)`.
+ conditioning_scale (`float`, defaults to `1.0`):
+ The scale factor for ControlNet outputs.
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
+ timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
+ added_cond_kwargs (`dict`):
+ Additional conditions for the Stable Diffusion XL UNet.
+ cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
+ guess_mode (`bool`, defaults to `False`):
+ In this mode, the ControlNet encoder tries its best to recognize the content of the input image even if
+ you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
+ return_dict (`bool`, defaults to `True`):
+ Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
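+ gating_matrix (`torch.Tensor`, *optional*, defaults to `None`):
+ Per-pixel weights over the `cond_num` conditions, of shape `(batch_size, cond_num, height, width)` at the
+ resolution of the conditioning embedding (i.e. the latent resolution). When given, it overrides the
+ configured `fusion` mode of [`ControlNetConditioningEmbedding`].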
+
+ Returns:
+ [`~models.controlnet.ControlNetOutput`] **or** `tuple`:
+ If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
+ returned where the first element is the sample tensor.
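+
+ Example (a minimal sketch; shapes follow the default SD-style config and are illustrative only):
+
+ import torch
+
+ controlnet = ControlNetModel(cond_num=3, fusion="sum")
+ sample = torch.randn(1, 4, 64, 64) # noisy latents
+ text_emb = torch.randn(1, 77, 1280) # encoder hidden states
+ conds = [torch.randn(1, 3, 512, 512) for _ in range(3)] # three conditioning images in [-1, 1]
+ gating = torch.softmax(torch.randn(1, 3, 64, 64), dim=1) # optional per-pixel weights over the conditions
+ out = controlnet(sample, timestep=10, encoder_hidden_states=text_emb, controlnet_cond=conds, gating_matrix=gating)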
+ """
+ # check channel order
+ channel_order = self.config.controlnet_conditioning_channel_order
+
+ if channel_order == "rgb":
+ # in rgb order by default
+ ...
+ elif channel_order == "bgr":
+ # rebinding the loop variable would not modify the list, so build a new list of flipped tensors
+ controlnet_cond = [torch.flip(cond, dims=[1]) for cond in controlnet_cond]
+ else:
+ raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
+
+ # prepare attention_mask
+ if attention_mask is not None:
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps.expand(sample.shape[0])
+
+ t_emb = self.time_proj(timesteps)
+
+ # timesteps does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=sample.dtype)
+
+ emb = self.time_embedding(t_emb, timestep_cond)
+ aug_emb = None
+
+ if self.class_embedding is not None:
+ if class_labels is None:
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
+
+ if self.config.class_embed_type == "timestep":
+ class_labels = self.time_proj(class_labels)
+
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
+ emb = emb + class_emb
+
+ if "addition_embed_type" in self.config:
+ if self.config.addition_embed_type == "text":
+ aug_emb = self.add_embedding(encoder_hidden_states)
+
+ elif self.config.addition_embed_type == "text_time":
+ if "text_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
+ )
+ text_embeds = added_cond_kwargs.get("text_embeds")
+ if "time_ids" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
+ )
+ time_ids = added_cond_kwargs.get("time_ids")
+ time_embeds = self.add_time_proj(time_ids.flatten())
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
+
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
+ add_embeds = add_embeds.to(emb.dtype)
+ aug_emb = self.add_embedding(add_embeds)
+
+ emb = emb + aug_emb if aug_emb is not None else emb
+
+ # 2. pre-process
+ sample = self.conv_in(sample)
+
+ cond_embedding = self.controlnet_cond_embedding(controlnet_cond, gating_matrix=gating_matrix)
+
+ sample = sample + cond_embedding
+
+ # 3. down
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ )
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+
+ down_block_res_samples += res_samples
+
+ # 4. mid
+ if self.mid_block is not None:
+ sample = self.mid_block(
+ sample,
+ emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ )
+
+ # 5. Control net blocks
+
+ controlnet_down_block_res_samples = ()
+
+ for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
+ down_block_res_sample = controlnet_block(down_block_res_sample)
+ controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
+
+ down_block_res_samples = controlnet_down_block_res_samples
+
+ mid_block_res_sample = self.controlnet_mid_block(sample)
+
+ # 6. scaling
+ if guess_mode and not self.config.global_pool_conditions:
+ scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
+
+ scales = scales * conditioning_scale
+ down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
+ mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
+ else:
+ down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
+ mid_block_res_sample = mid_block_res_sample * conditioning_scale
+
+ if self.config.global_pool_conditions:
+ down_block_res_samples = [
+ torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
+ ]
+ mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
+
+ if not return_dict:
+ return (down_block_res_samples, mid_block_res_sample)
+
+ return ControlNetOutput(
+ down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
+ )
+
+
+def zero_module(module):
+ for p in module.parameters():
+ nn.init.zeros_(p)
+ return module
diff --git a/diffusers/models/controlnet_flax.py b/diffusers/models/controlnet_flax.py
new file mode 100644
index 0000000000000000000000000000000000000000..a826df48e41a632454c513877ec55be7f86089f9
--- /dev/null
+++ b/diffusers/models/controlnet_flax.py
@@ -0,0 +1,394 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional, Tuple, Union
+
+import flax
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+from flax.core.frozen_dict import FrozenDict
+
+from ..configuration_utils import ConfigMixin, flax_register_to_config
+from ..utils import BaseOutput
+from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
+from .modeling_flax_utils import FlaxModelMixin
+from .unet_2d_blocks_flax import (
+ FlaxCrossAttnDownBlock2D,
+ FlaxDownBlock2D,
+ FlaxUNetMidBlock2DCrossAttn,
+)
+
+
+@flax.struct.dataclass
+class FlaxControlNetOutput(BaseOutput):
+ """
+ The output of [`FlaxControlNetModel`].
+
+ Args:
+ down_block_res_samples (`jnp.ndarray`):
+ mid_block_res_sample (`jnp.ndarray`):
+ """
+
+ down_block_res_samples: jnp.ndarray
+ mid_block_res_sample: jnp.ndarray
+
+
+class FlaxControlNetConditioningEmbedding(nn.Module):
+ conditioning_embedding_channels: int
+ block_out_channels: Tuple[int] = (16, 32, 96, 256)
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ self.conv_in = nn.Conv(
+ self.block_out_channels[0],
+ kernel_size=(3, 3),
+ padding=((1, 1), (1, 1)),
+ dtype=self.dtype,
+ )
+
+ blocks = []
+ for i in range(len(self.block_out_channels) - 1):
+ channel_in = self.block_out_channels[i]
+ channel_out = self.block_out_channels[i + 1]
+ conv1 = nn.Conv(
+ channel_in,
+ kernel_size=(3, 3),
+ padding=((1, 1), (1, 1)),
+ dtype=self.dtype,
+ )
+ blocks.append(conv1)
+ conv2 = nn.Conv(
+ channel_out,
+ kernel_size=(3, 3),
+ strides=(2, 2),
+ padding=((1, 1), (1, 1)),
+ dtype=self.dtype,
+ )
+ blocks.append(conv2)
+ self.blocks = blocks
+
+ self.conv_out = nn.Conv(
+ self.conditioning_embedding_channels,
+ kernel_size=(3, 3),
+ padding=((1, 1), (1, 1)),
+ kernel_init=nn.initializers.zeros_init(),
+ bias_init=nn.initializers.zeros_init(),
+ dtype=self.dtype,
+ )
+
+ def __call__(self, conditioning):
+ embedding = self.conv_in(conditioning)
+ embedding = nn.silu(embedding)
+
+ for block in self.blocks:
+ embedding = block(embedding)
+ embedding = nn.silu(embedding)
+
+ embedding = self.conv_out(embedding)
+
+ return embedding
+
+
+@flax_register_to_config
+class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
+ r"""
+ A ControlNet model.
+
+ This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for its generic methods
+ implemented for all models (such as downloading or saving).
+
+ This model is also a Flax Linen [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
+ subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to its
+ general usage and behavior.
+
+ Inherent JAX features such as the following are supported:
+
+ - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
+ - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
+ - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
+ - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
+
+ Parameters:
+ sample_size (`int`, *optional*):
+ The size of the input sample.
+ in_channels (`int`, *optional*, defaults to 4):
+ The number of channels in the input sample.
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D")`):
+ The tuple of downsample blocks to use.
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
+ The tuple of output channels for each block.
+ layers_per_block (`int`, *optional*, defaults to 2):
+ The number of layers per block.
+ attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8):
+ The dimension of the attention heads.
+ num_attention_heads (`int` or `Tuple[int]`, *optional*):
+ The number of attention heads.
+ cross_attention_dim (`int`, *optional*, defaults to 768):
+ The dimension of the cross attention features.
+ dropout (`float`, *optional*, defaults to 0):
+ Dropout probability for down, up and bottleneck blocks.
+ flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
+ Whether to flip the sin to cos in the time embedding.
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
+ controlnet_conditioning_channel_order (`str`, *optional*, defaults to `rgb`):
+ The channel order of the conditioning image. Will convert to `rgb` if it's `bgr`.
+ conditioning_embedding_out_channels (`tuple`, *optional*, defaults to `(16, 32, 96, 256)`):
+ The tuple of output channel for each block in the `conditioning_embedding` layer.
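+
+ Example (a minimal sketch; initializing the full default config is slow and shown for illustration only):
+
+ import jax
+
+ controlnet = FlaxControlNetModel(sample_size=32)
+ params = controlnet.init_weights(jax.random.PRNGKey(0))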
+ """
+ sample_size: int = 32
+ in_channels: int = 4
+ down_block_types: Tuple[str] = (
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "DownBlock2D",
+ )
+ only_cross_attention: Union[bool, Tuple[bool]] = False
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
+ layers_per_block: int = 2
+ attention_head_dim: Union[int, Tuple[int]] = 8
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None
+ cross_attention_dim: int = 1280
+ dropout: float = 0.0
+ use_linear_projection: bool = False
+ dtype: jnp.dtype = jnp.float32
+ flip_sin_to_cos: bool = True
+ freq_shift: int = 0
+ controlnet_conditioning_channel_order: str = "rgb"
+ conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256)
+
+ def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
+ # init input tensors
+ sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
+ sample = jnp.zeros(sample_shape, dtype=jnp.float32)
+ timesteps = jnp.ones((1,), dtype=jnp.int32)
+ encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32)
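+ # The conditioning image is expected in pixel space; the default conditioning embedding
+ # downsamples it by a factor of 8 to the latent resolution, hence `sample_size * 8` below.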
+ controlnet_cond_shape = (1, 3, self.sample_size * 8, self.sample_size * 8)
+ controlnet_cond = jnp.zeros(controlnet_cond_shape, dtype=jnp.float32)
+
+ params_rng, dropout_rng = jax.random.split(rng)
+ rngs = {"params": params_rng, "dropout": dropout_rng}
+
+ return self.init(rngs, sample, timesteps, encoder_hidden_states, controlnet_cond)["params"]
+
+ def setup(self):
+ block_out_channels = self.block_out_channels
+ time_embed_dim = block_out_channels[0] * 4
+
+ # If `num_attention_heads` is not defined (which is the case for most models)
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
+ # which is why we correct for the naming here.
+ num_attention_heads = self.num_attention_heads or self.attention_head_dim
+
+ # input
+ self.conv_in = nn.Conv(
+ block_out_channels[0],
+ kernel_size=(3, 3),
+ strides=(1, 1),
+ padding=((1, 1), (1, 1)),
+ dtype=self.dtype,
+ )
+
+ # time
+ self.time_proj = FlaxTimesteps(
+ block_out_channels[0], flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.config.freq_shift
+ )
+ self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)
+
+ self.controlnet_cond_embedding = FlaxControlNetConditioningEmbedding(
+ conditioning_embedding_channels=block_out_channels[0],
+ block_out_channels=self.conditioning_embedding_out_channels,
+ )
+
+ only_cross_attention = self.only_cross_attention
+ if isinstance(only_cross_attention, bool):
+ only_cross_attention = (only_cross_attention,) * len(self.down_block_types)
+
+ if isinstance(num_attention_heads, int):
+ num_attention_heads = (num_attention_heads,) * len(self.down_block_types)
+
+ # down
+ down_blocks = []
+ controlnet_down_blocks = []
+
+ output_channel = block_out_channels[0]
+
+ controlnet_block = nn.Conv(
+ output_channel,
+ kernel_size=(1, 1),
+ padding="VALID",
+ kernel_init=nn.initializers.zeros_init(),
+ bias_init=nn.initializers.zeros_init(),
+ dtype=self.dtype,
+ )
+ controlnet_down_blocks.append(controlnet_block)
+
+ for i, down_block_type in enumerate(self.down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ if down_block_type == "CrossAttnDownBlock2D":
+ down_block = FlaxCrossAttnDownBlock2D(
+ in_channels=input_channel,
+ out_channels=output_channel,
+ dropout=self.dropout,
+ num_layers=self.layers_per_block,
+ num_attention_heads=num_attention_heads[i],
+ add_downsample=not is_final_block,
+ use_linear_projection=self.use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ dtype=self.dtype,
+ )
+ else:
+ down_block = FlaxDownBlock2D(
+ in_channels=input_channel,
+ out_channels=output_channel,
+ dropout=self.dropout,
+ num_layers=self.layers_per_block,
+ add_downsample=not is_final_block,
+ dtype=self.dtype,
+ )
+
+ down_blocks.append(down_block)
+
+ for _ in range(self.layers_per_block):
+ controlnet_block = nn.Conv(
+ output_channel,
+ kernel_size=(1, 1),
+ padding="VALID",
+ kernel_init=nn.initializers.zeros_init(),
+ bias_init=nn.initializers.zeros_init(),
+ dtype=self.dtype,
+ )
+ controlnet_down_blocks.append(controlnet_block)
+
+ if not is_final_block:
+ controlnet_block = nn.Conv(
+ output_channel,
+ kernel_size=(1, 1),
+ padding="VALID",
+ kernel_init=nn.initializers.zeros_init(),
+ bias_init=nn.initializers.zeros_init(),
+ dtype=self.dtype,
+ )
+ controlnet_down_blocks.append(controlnet_block)
+
+ self.down_blocks = down_blocks
+ self.controlnet_down_blocks = controlnet_down_blocks
+
+ # mid
+ mid_block_channel = block_out_channels[-1]
+ self.mid_block = FlaxUNetMidBlock2DCrossAttn(
+ in_channels=mid_block_channel,
+ dropout=self.dropout,
+ num_attention_heads=num_attention_heads[-1],
+ use_linear_projection=self.use_linear_projection,
+ dtype=self.dtype,
+ )
+
+ self.controlnet_mid_block = nn.Conv(
+ mid_block_channel,
+ kernel_size=(1, 1),
+ padding="VALID",
+ kernel_init=nn.initializers.zeros_init(),
+ bias_init=nn.initializers.zeros_init(),
+ dtype=self.dtype,
+ )
+
+ def __call__(
+ self,
+ sample,
+ timesteps,
+ encoder_hidden_states,
+ controlnet_cond,
+ conditioning_scale: float = 1.0,
+ return_dict: bool = True,
+ train: bool = False,
+ ) -> Union[FlaxControlNetOutput, Tuple]:
+ r"""
+ Args:
+ sample (`jnp.ndarray`): (batch, channel, height, width) noisy input tensor
+ timesteps (`jnp.ndarray` or `float` or `int`): timesteps
+ encoder_hidden_states (`jnp.ndarray`): (batch_size, sequence_length, hidden_size) encoder hidden states
+ controlnet_cond (`jnp.ndarray`): (batch, channel, height, width) the conditioning input tensor
+ conditioning_scale (`float`): the scale factor applied to the ControlNet outputs
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`FlaxControlNetOutput`] instead of a plain tuple.
+ train (`bool`, *optional*, defaults to `False`):
+ Use deterministic functions and disable dropout when not training.
+
+ Returns:
+ [`FlaxControlNetOutput`] or `tuple`:
+ [`FlaxControlNetOutput`] if `return_dict` is `True`, otherwise a `tuple` of
+ `(down_block_res_samples, mid_block_res_sample)`.
+ """
+ channel_order = self.controlnet_conditioning_channel_order
+ if channel_order == "bgr":
+ controlnet_cond = jnp.flip(controlnet_cond, axis=1)
+
+ # 1. time
+ if not isinstance(timesteps, jnp.ndarray):
+ timesteps = jnp.array([timesteps], dtype=jnp.int32)
+ elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0:
+ timesteps = timesteps.astype(dtype=jnp.float32)
+ timesteps = jnp.expand_dims(timesteps, 0)
+
+ t_emb = self.time_proj(timesteps)
+ t_emb = self.time_embedding(t_emb)
+
+ # 2. pre-process
+ sample = jnp.transpose(sample, (0, 2, 3, 1))
+ sample = self.conv_in(sample)
+
+ controlnet_cond = jnp.transpose(controlnet_cond, (0, 2, 3, 1))
+ controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
+ sample += controlnet_cond
+
+ # 3. down
+ down_block_res_samples = (sample,)
+ for down_block in self.down_blocks:
+ if isinstance(down_block, FlaxCrossAttnDownBlock2D):
+ sample, res_samples = down_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
+ else:
+ sample, res_samples = down_block(sample, t_emb, deterministic=not train)
+ down_block_res_samples += res_samples
+
+ # 4. mid
+ sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
+
+ # 5. controlnet blocks
+ controlnet_down_block_res_samples = ()
+ for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
+ down_block_res_sample = controlnet_block(down_block_res_sample)
+ controlnet_down_block_res_samples += (down_block_res_sample,)
+
+ down_block_res_samples = controlnet_down_block_res_samples
+
+ mid_block_res_sample = self.controlnet_mid_block(sample)
+
+ # 6. scaling
+ down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
+ mid_block_res_sample *= conditioning_scale
+
+ if not return_dict:
+ return (down_block_res_samples, mid_block_res_sample)
+
+ return FlaxControlNetOutput(
+ down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
+ )
diff --git a/diffusers/models/cross_attention.py b/diffusers/models/cross_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..44bc156b34cfa8536bdac0fee34709dfd66ae488
--- /dev/null
+++ b/diffusers/models/cross_attention.py
@@ -0,0 +1,94 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from ..utils import deprecate
+from .attention_processor import ( # noqa: F401
+ Attention,
+ AttentionProcessor,
+ AttnAddedKVProcessor,
+ AttnProcessor2_0,
+ LoRAAttnProcessor,
+ LoRALinearLayer,
+ LoRAXFormersAttnProcessor,
+ SlicedAttnAddedKVProcessor,
+ SlicedAttnProcessor,
+ XFormersAttnProcessor,
+)
+from .attention_processor import AttnProcessor as AttnProcessorRename # noqa: F401
+
+
+deprecate(
+ "cross_attention",
+ "0.20.0",
+ "Importing from cross_attention is deprecated. Please import from diffusers.models.attention_processor instead.",
+ standard_warn=False,
+)
+
+
+AttnProcessor = AttentionProcessor
+
+
+class CrossAttention(Attention):
+ def __init__(self, *args, **kwargs):
+ deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))}` instead."
+ deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
+ super().__init__(*args, **kwargs)
+
+
+class CrossAttnProcessor(AttnProcessorRename):
+ def __init__(self, *args, **kwargs):
+ deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))}` instead."
+ deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
+ super().__init__(*args, **kwargs)
+
+
+class LoRACrossAttnProcessor(LoRAAttnProcessor):
+ def __init__(self, *args, **kwargs):
+ deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))}` instead."
+ deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
+ super().__init__(*args, **kwargs)
+
+
+class CrossAttnAddedKVProcessor(AttnAddedKVProcessor):
+ def __init__(self, *args, **kwargs):
+ deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))}` instead."
+ deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
+ super().__init__(*args, **kwargs)
+
+
+class XFormersCrossAttnProcessor(XFormersAttnProcessor):
+ def __init__(self, *args, **kwargs):
+ deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))}` instead."
+ deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
+ super().__init__(*args, **kwargs)
+
+
+class LoRAXFormersCrossAttnProcessor(LoRAXFormersAttnProcessor):
+ def __init__(self, *args, **kwargs):
+ deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))}` instead."
+ deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
+ super().__init__(*args, **kwargs)
+
+
+class SlicedCrossAttnProcessor(SlicedAttnProcessor):
+ def __init__(self, *args, **kwargs):
+ deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))}` instead."
+ deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
+ super().__init__(*args, **kwargs)
+
+
+class SlicedCrossAttnAddedKVProcessor(SlicedAttnAddedKVProcessor):
+ def __init__(self, *args, **kwargs):
+ deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))}` instead."
+ deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
+ super().__init__(*args, **kwargs)
diff --git a/diffusers/models/dual_transformer_2d.py b/diffusers/models/dual_transformer_2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..3db7e73ca6afc5fa7c67c1902d79e67c1aa728bc
--- /dev/null
+++ b/diffusers/models/dual_transformer_2d.py
@@ -0,0 +1,151 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional
+
+from torch import nn
+
+from .transformer_2d import Transformer2DModel, Transformer2DModelOutput
+
+
+class DualTransformer2DModel(nn.Module):
+ """
+ Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.
+
+ Parameters:
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
+ in_channels (`int`, *optional*):
+ Pass if the input is continuous. The number of channels in the input and output.
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
+ dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
+ cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
+ sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
+ Note that this is fixed at training time as it is used for learning a number of position embeddings. See
+ `ImagePositionalEmbeddings`.
+ num_vector_embeds (`int`, *optional*):
+ Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
+ Includes the class for the masked latent pixel.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+ num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
+ The number of diffusion steps used during training. Note that this is fixed at training time as it is used
+ to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
+ up to but not more than `num_embeds_ada_norm` steps.
+ attention_bias (`bool`, *optional*):
+ Configure if the TransformerBlocks' attention should contain a bias parameter.
+ """
+
+ def __init__(
+ self,
+ num_attention_heads: int = 16,
+ attention_head_dim: int = 88,
+ in_channels: Optional[int] = None,
+ num_layers: int = 1,
+ dropout: float = 0.0,
+ norm_num_groups: int = 32,
+ cross_attention_dim: Optional[int] = None,
+ attention_bias: bool = False,
+ sample_size: Optional[int] = None,
+ num_vector_embeds: Optional[int] = None,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ ):
+ super().__init__()
+ self.transformers = nn.ModuleList(
+ [
+ Transformer2DModel(
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ in_channels=in_channels,
+ num_layers=num_layers,
+ dropout=dropout,
+ norm_num_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim,
+ attention_bias=attention_bias,
+ sample_size=sample_size,
+ num_vector_embeds=num_vector_embeds,
+ activation_fn=activation_fn,
+ num_embeds_ada_norm=num_embeds_ada_norm,
+ )
+ for _ in range(2)
+ ]
+ )
+
+ # Variables that can be set by a pipeline:
+
+ # The ratio of transformer1 to transformer2's output states to be combined during inference
+ self.mix_ratio = 0.5
+
+ # The shape of `encoder_hidden_states` is expected to be
+ # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
+ self.condition_lengths = [77, 257]
+
+ # Which transformer to use to encode which condition.
+ # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
+ self.transformer_index_for_condition = [1, 0]
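+
+ # Illustrative example with the defaults above: for `encoder_hidden_states` of shape
+ # (batch_size, 77 + 257, num_features), tokens [0:77] are encoded by `self.transformers[1]`
+ # and tokens [77:334] by `self.transformers[0]`.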
+
+ def forward(
+ self,
+ hidden_states,
+ encoder_hidden_states,
+ timestep=None,
+ attention_mask=None,
+ cross_attention_kwargs=None,
+ return_dict: bool = True,
+ ):
+ """
+ Args:
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` when discrete, or
+ `torch.FloatTensor` of shape `(batch size, channel, height, width)` when continuous):
+ Input hidden states.
+ encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
+ self-attention.
+ timestep ( `torch.long`, *optional*):
+ Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
+ attention_mask (`torch.FloatTensor`, *optional*):
+ Optional attention mask to be applied in Attention
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
+ [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
+ """
+ input_states = hidden_states
+
+ encoded_states = []
+ tokens_start = 0
+ # attention_mask is not used yet
+ for i in range(2):
+ # for each of the two transformers, pass the corresponding condition tokens
+ condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
+ transformer_index = self.transformer_index_for_condition[i]
+ encoded_state = self.transformers[transformer_index](
+ input_states,
+ encoder_hidden_states=condition_state,
+ timestep=timestep,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
+ )[0]
+ encoded_states.append(encoded_state - input_states)
+ tokens_start += self.condition_lengths[i]
+
+ output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
+ output_states = output_states + input_states
+
+ if not return_dict:
+ return (output_states,)
+
+ return Transformer2DModelOutput(sample=output_states)
diff --git a/diffusers/models/embeddings.py b/diffusers/models/embeddings.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5a0c5549ee9d282b4eaa41d496255ad26b74699
--- /dev/null
+++ b/diffusers/models/embeddings.py
@@ -0,0 +1,546 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+from typing import Optional
+
+import numpy as np
+import torch
+from torch import nn
+
+from .activations import get_activation
+
+
+def get_timestep_embedding(
+ timesteps: torch.Tensor,
+ embedding_dim: int,
+ flip_sin_to_cos: bool = False,
+ downscale_freq_shift: float = 1,
+ scale: float = 1,
+ max_period: int = 10000,
+):
+ """
+ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
+
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param embedding_dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
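+
+ Example (an illustrative shape check):
+ >>> import torch
+ >>> get_timestep_embedding(torch.tensor([1, 2]), embedding_dim=8).shape
+ torch.Size([2, 8])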
+ """
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
+
+ half_dim = embedding_dim // 2
+ exponent = -math.log(max_period) * torch.arange(
+ start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
+ )
+ exponent = exponent / (half_dim - downscale_freq_shift)
+
+ emb = torch.exp(exponent)
+ emb = timesteps[:, None].float() * emb[None, :]
+
+ # scale embeddings
+ emb = scale * emb
+
+ # concat sine and cosine embeddings
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
+
+ # flip sine and cosine embeddings
+ if flip_sin_to_cos:
+ emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
+
+ # zero pad
+ if embedding_dim % 2 == 1:
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+ return emb
+
+
+def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
+ """
+ grid_size: int of the grid height and width.
+ return: pos_embed of shape [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim]
+ (without or with cls_token).
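+
+ Example (an illustrative shape check):
+ >>> get_2d_sincos_pos_embed(embed_dim=16, grid_size=4).shape
+ (16, 16)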
+ """
+ grid_h = np.arange(grid_size, dtype=np.float32)
+ grid_w = np.arange(grid_size, dtype=np.float32)
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0)
+
+ grid = grid.reshape([2, 1, grid_size, grid_size])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if cls_token and extra_tokens > 0:
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
+ return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+ if embed_dim % 2 != 0:
+ raise ValueError("embed_dim must be divisible by 2")
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
+
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
+ return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+ """
+ embed_dim: output dimension for each position.
+ pos: a list of positions to be encoded: size (M,).
+ out: (M, D)
+ """
+ if embed_dim % 2 != 0:
+ raise ValueError("embed_dim must be divisible by 2")
+
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
+ omega /= embed_dim / 2.0
+ omega = 1.0 / 10000**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
+
+ emb_sin = np.sin(out) # (M, D/2)
+ emb_cos = np.cos(out) # (M, D/2)
+
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
+ return emb
+
+
+class PatchEmbed(nn.Module):
+ """2D Image to Patch Embedding"""
+
+ def __init__(
+ self,
+ height=224,
+ width=224,
+ patch_size=16,
+ in_channels=3,
+ embed_dim=768,
+ layer_norm=False,
+ flatten=True,
+ bias=True,
+ ):
+ super().__init__()
+
+ num_patches = (height // patch_size) * (width // patch_size)
+ self.flatten = flatten
+ self.layer_norm = layer_norm
+
+ self.proj = nn.Conv2d(
+ in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
+ )
+ if layer_norm:
+ self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
+ else:
+ self.norm = None
+
+ pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5))
+ self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
+
+ def forward(self, latent):
+ latent = self.proj(latent)
+ if self.flatten:
+ latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
+ if self.layer_norm:
+ latent = self.norm(latent)
+ return latent + self.pos_embed
+
+
+class TimestepEmbedding(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ time_embed_dim: int,
+ act_fn: str = "silu",
+ out_dim: int = None,
+ post_act_fn: Optional[str] = None,
+ cond_proj_dim=None,
+ ):
+ super().__init__()
+
+ self.linear_1 = nn.Linear(in_channels, time_embed_dim)
+
+ if cond_proj_dim is not None:
+ self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
+ else:
+ self.cond_proj = None
+
+ self.act = get_activation(act_fn)
+
+ if out_dim is not None:
+ time_embed_dim_out = out_dim
+ else:
+ time_embed_dim_out = time_embed_dim
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
+
+ if post_act_fn is None:
+ self.post_act = None
+ else:
+ self.post_act = get_activation(post_act_fn)
+
+ def forward(self, sample, condition=None):
+ if condition is not None:
+ sample = sample + self.cond_proj(condition)
+ sample = self.linear_1(sample)
+
+ if self.act is not None:
+ sample = self.act(sample)
+
+ sample = self.linear_2(sample)
+
+ if self.post_act is not None:
+ sample = self.post_act(sample)
+ return sample
+
+
+class Timesteps(nn.Module):
+ def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
+ super().__init__()
+ self.num_channels = num_channels
+ self.flip_sin_to_cos = flip_sin_to_cos
+ self.downscale_freq_shift = downscale_freq_shift
+
+ def forward(self, timesteps):
+ t_emb = get_timestep_embedding(
+ timesteps,
+ self.num_channels,
+ flip_sin_to_cos=self.flip_sin_to_cos,
+ downscale_freq_shift=self.downscale_freq_shift,
+ )
+ return t_emb
+
+
+class GaussianFourierProjection(nn.Module):
+ """Gaussian Fourier embeddings for noise levels."""
+
+ def __init__(
+ self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
+ ):
+ super().__init__()
+ self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
+ self.log = log
+ self.flip_sin_to_cos = flip_sin_to_cos
+
+ if set_W_to_weight:
+ # to delete later
+ self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
+
+ self.weight = self.W
+
+ def forward(self, x):
+ if self.log:
+ x = torch.log(x)
+
+ x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
+
+ if self.flip_sin_to_cos:
+ out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1)
+ else:
+ out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
+ return out
+
+
+class ImagePositionalEmbeddings(nn.Module):
+ """
+ Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
+ height and width of the latent space.
+
+ For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092
+
+ For VQ-diffusion:
+
+ Output vector embeddings are used as input for the transformer.
+
+ Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE.
+
+ Args:
+ num_embed (`int`):
+ Number of embeddings for the latent pixels embeddings.
+ height (`int`):
+ Height of the latent image i.e. the number of height embeddings.
+ width (`int`):
+ Width of the latent image i.e. the number of width embeddings.
+ embed_dim (`int`):
+ Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings.
+ """
+
+ def __init__(
+ self,
+ num_embed: int,
+ height: int,
+ width: int,
+ embed_dim: int,
+ ):
+ super().__init__()
+
+ self.height = height
+ self.width = width
+ self.num_embed = num_embed
+ self.embed_dim = embed_dim
+
+ self.emb = nn.Embedding(self.num_embed, embed_dim)
+ self.height_emb = nn.Embedding(self.height, embed_dim)
+ self.width_emb = nn.Embedding(self.width, embed_dim)
+
+ def forward(self, index):
+ emb = self.emb(index)
+
+ height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height))
+
+ # 1 x H x D -> 1 x H x 1 x D
+ height_emb = height_emb.unsqueeze(2)
+
+ width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width))
+
+ # 1 x W x D -> 1 x 1 x W x D
+ width_emb = width_emb.unsqueeze(1)
+
+ pos_emb = height_emb + width_emb
+
+ # 1 x H x W x D -> 1 x L x D
+ pos_emb = pos_emb.view(1, self.height * self.width, -1)
+
+ emb = emb + pos_emb[:, : emb.shape[1], :]
+
+ return emb
+
+
+class LabelEmbedding(nn.Module):
+ """
+ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
+
+ Args:
+ num_classes (`int`): The number of classes.
+ hidden_size (`int`): The size of the vector embeddings.
+ dropout_prob (`float`): The probability of dropping a label.
+ """
+
+ def __init__(self, num_classes, hidden_size, dropout_prob):
+ super().__init__()
+ use_cfg_embedding = dropout_prob > 0
+ self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
+ self.num_classes = num_classes
+ self.dropout_prob = dropout_prob
+
+ def token_drop(self, labels, force_drop_ids=None):
+ """
+ Drops labels to enable classifier-free guidance.
+ """
+ if force_drop_ids is None:
+ drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
+ else:
+ drop_ids = torch.tensor(force_drop_ids == 1)
+ labels = torch.where(drop_ids, self.num_classes, labels)
+ return labels
+
+ def forward(self, labels: torch.LongTensor, force_drop_ids=None):
+ use_dropout = self.dropout_prob > 0
+ if (self.training and use_dropout) or (force_drop_ids is not None):
+ labels = self.token_drop(labels, force_drop_ids)
+ embeddings = self.embedding_table(labels)
+ return embeddings
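+
+ # Illustrative note: with `dropout_prob > 0`, a fraction of labels in each training batch is
+ # replaced by the extra index `num_classes` (a learned "null" class), which is what later enables
+ # classifier-free guidance at sampling time.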
+
+
+class TextImageProjection(nn.Module):
+ def __init__(
+ self,
+ text_embed_dim: int = 1024,
+ image_embed_dim: int = 768,
+ cross_attention_dim: int = 768,
+ num_image_text_embeds: int = 10,
+ ):
+ super().__init__()
+
+ self.num_image_text_embeds = num_image_text_embeds
+ self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
+ self.text_proj = nn.Linear(text_embed_dim, cross_attention_dim)
+
+ def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor):
+ batch_size = text_embeds.shape[0]
+
+ # image
+ image_text_embeds = self.image_embeds(image_embeds)
+ image_text_embeds = image_text_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
+
+ # text
+ text_embeds = self.text_proj(text_embeds)
+
+ return torch.cat([image_text_embeds, text_embeds], dim=1)
+
+
+class ImageProjection(nn.Module):
+ def __init__(
+ self,
+ image_embed_dim: int = 768,
+ cross_attention_dim: int = 768,
+ num_image_text_embeds: int = 32,
+ ):
+ super().__init__()
+
+ self.num_image_text_embeds = num_image_text_embeds
+ self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
+ self.norm = nn.LayerNorm(cross_attention_dim)
+
+ def forward(self, image_embeds: torch.FloatTensor):
+ batch_size = image_embeds.shape[0]
+
+ # image
+ image_embeds = self.image_embeds(image_embeds)
+ image_embeds = image_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
+ image_embeds = self.norm(image_embeds)
+ return image_embeds
+
+
+class CombinedTimestepLabelEmbeddings(nn.Module):
+ def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
+ super().__init__()
+
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+ self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob)
+
+ def forward(self, timestep, class_labels, hidden_dtype=None):
+ timesteps_proj = self.time_proj(timestep)
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
+
+ class_labels = self.class_embedder(class_labels) # (N, D)
+
+ conditioning = timesteps_emb + class_labels # (N, D)
+
+ return conditioning
+
+
+class TextTimeEmbedding(nn.Module):
+ def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
+ super().__init__()
+ self.norm1 = nn.LayerNorm(encoder_dim)
+ self.pool = AttentionPooling(num_heads, encoder_dim)
+ self.proj = nn.Linear(encoder_dim, time_embed_dim)
+ self.norm2 = nn.LayerNorm(time_embed_dim)
+
+ def forward(self, hidden_states):
+ hidden_states = self.norm1(hidden_states)
+ hidden_states = self.pool(hidden_states)
+ hidden_states = self.proj(hidden_states)
+ hidden_states = self.norm2(hidden_states)
+ return hidden_states
+
+
+class TextImageTimeEmbedding(nn.Module):
+ def __init__(self, text_embed_dim: int = 768, image_embed_dim: int = 768, time_embed_dim: int = 1536):
+ super().__init__()
+ self.text_proj = nn.Linear(text_embed_dim, time_embed_dim)
+ self.text_norm = nn.LayerNorm(time_embed_dim)
+ self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
+
+ def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor):
+ # text
+ time_text_embeds = self.text_proj(text_embeds)
+ time_text_embeds = self.text_norm(time_text_embeds)
+
+ # image
+ time_image_embeds = self.image_proj(image_embeds)
+
+ return time_image_embeds + time_text_embeds
+
+
+class ImageTimeEmbedding(nn.Module):
+ def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
+ super().__init__()
+ self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
+ self.image_norm = nn.LayerNorm(time_embed_dim)
+
+ def forward(self, image_embeds: torch.FloatTensor):
+ # image
+ time_image_embeds = self.image_proj(image_embeds)
+ time_image_embeds = self.image_norm(time_image_embeds)
+ return time_image_embeds
+
+
+class ImageHintTimeEmbedding(nn.Module):
+ def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
+ super().__init__()
+ self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
+ self.image_norm = nn.LayerNorm(time_embed_dim)
+ self.input_hint_block = nn.Sequential(
+ nn.Conv2d(3, 16, 3, padding=1),
+ nn.SiLU(),
+ nn.Conv2d(16, 16, 3, padding=1),
+ nn.SiLU(),
+ nn.Conv2d(16, 32, 3, padding=1, stride=2),
+ nn.SiLU(),
+ nn.Conv2d(32, 32, 3, padding=1),
+ nn.SiLU(),
+ nn.Conv2d(32, 96, 3, padding=1, stride=2),
+ nn.SiLU(),
+ nn.Conv2d(96, 96, 3, padding=1),
+ nn.SiLU(),
+ nn.Conv2d(96, 256, 3, padding=1, stride=2),
+ nn.SiLU(),
+ nn.Conv2d(256, 4, 3, padding=1),
+ )
+
+ def forward(self, image_embeds: torch.FloatTensor, hint: torch.FloatTensor):
+ # image
+ time_image_embeds = self.image_proj(image_embeds)
+ time_image_embeds = self.image_norm(time_image_embeds)
+ hint = self.input_hint_block(hint)
+ return time_image_embeds, hint
+
+
+class AttentionPooling(nn.Module):
+ # Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54
+
+ def __init__(self, num_heads, embed_dim, dtype=None):
+ super().__init__()
+ self.dtype = dtype
+ self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim**0.5)
+ self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
+ self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
+ self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
+ self.num_heads = num_heads
+ self.dim_per_head = embed_dim // self.num_heads
+
+ def forward(self, x):
+ bs, length, width = x.size()
+
+ def shape(x):
+ # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
+ x = x.view(bs, -1, self.num_heads, self.dim_per_head)
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
+ x = x.transpose(1, 2)
+ # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
+ x = x.reshape(bs * self.num_heads, -1, self.dim_per_head)
+ # (bs*n_heads, length, dim_per_head) --> (bs*n_heads, dim_per_head, length)
+ x = x.transpose(1, 2)
+ return x
+
+ class_token = x.mean(dim=1, keepdim=True) + self.positional_embedding.to(x.dtype)
+ x = torch.cat([class_token, x], dim=1) # (bs, length+1, width)
+
+ # (bs*n_heads, class_token_length, dim_per_head)
+ q = shape(self.q_proj(class_token))
+ # (bs*n_heads, length+class_token_length, dim_per_head)
+ k = shape(self.k_proj(x))
+ v = shape(self.v_proj(x))
+
+ # (bs*n_heads, class_token_length, length+class_token_length):
+ scale = 1 / math.sqrt(math.sqrt(self.dim_per_head))
+ weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+
+ # (bs*n_heads, dim_per_head, class_token_length)
+ a = torch.einsum("bts,bcs->bct", weight, v)
+
+ # (bs, 1, width)
+ a = a.reshape(bs, -1, 1).transpose(1, 2)
+
+ return a[:, 0, :] # cls_token
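+
+
+# Example (illustrative): `AttentionPooling(num_heads=8, embed_dim=64)` maps a (batch, seq_len, 64)
+# tensor to a pooled (batch, 64) summary by attending from a mean-pooled class-token query, e.g.
+# `AttentionPooling(8, 64)(torch.randn(2, 16, 64)).shape == torch.Size([2, 64])`.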
diff --git a/diffusers/models/embeddings_flax.py b/diffusers/models/embeddings_flax.py
new file mode 100644
index 0000000000000000000000000000000000000000..88c2c45e4655b8013fa96e0b4408e3ec0a87c2c7
--- /dev/null
+++ b/diffusers/models/embeddings_flax.py
@@ -0,0 +1,95 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+
+import flax.linen as nn
+import jax.numpy as jnp
+
+
+def get_sinusoidal_embeddings(
+ timesteps: jnp.ndarray,
+ embedding_dim: int,
+ freq_shift: float = 1,
+ min_timescale: float = 1,
+ max_timescale: float = 1.0e4,
+ flip_sin_to_cos: bool = False,
+ scale: float = 1.0,
+) -> jnp.ndarray:
+ """Returns the positional encoding (same as Tensor2Tensor).
+
+ Args:
+ timesteps: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ embedding_dim: The number of output channels.
+ min_timescale: The smallest time unit (should probably be 0.0).
+ max_timescale: The largest time unit.
+ Returns:
+ a Tensor of timing signals [N, num_channels]
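+
+ Example (an illustrative shape check):
+ >>> import jax.numpy as jnp
+ >>> get_sinusoidal_embeddings(jnp.array([0.0, 1.0]), embedding_dim=8).shape
+ (2, 8)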
+ """
+ assert timesteps.ndim == 1, "Timesteps should be a 1d-array"
+ assert embedding_dim % 2 == 0, f"Embedding dimension {embedding_dim} should be even"
+ num_timescales = float(embedding_dim // 2)
+ log_timescale_increment = math.log(max_timescale / min_timescale) / (num_timescales - freq_shift)
+ inv_timescales = min_timescale * jnp.exp(jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment)
+ emb = jnp.expand_dims(timesteps, 1) * jnp.expand_dims(inv_timescales, 0)
+
+ # scale embeddings
+ scaled_time = scale * emb
+
+ if flip_sin_to_cos:
+ signal = jnp.concatenate([jnp.cos(scaled_time), jnp.sin(scaled_time)], axis=1)
+ else:
+ signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1)
+ signal = jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim])
+ return signal
+
+
+class FlaxTimestepEmbedding(nn.Module):
+ r"""
+ Time step Embedding Module. Learns embeddings for input time steps.
+
+ Args:
+ time_embed_dim (`int`, *optional*, defaults to `32`):
+ Time step embedding dimension
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+ Parameters `dtype`
+ """
+ time_embed_dim: int = 32
+ dtype: jnp.dtype = jnp.float32
+
+ @nn.compact
+ def __call__(self, temb):
+ temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_1")(temb)
+ temb = nn.silu(temb)
+ temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_2")(temb)
+ return temb
+
+
+class FlaxTimesteps(nn.Module):
+ r"""
+ Wrapper Module for sinusoidal Time step Embeddings as described in https://arxiv.org/abs/2006.11239
+
+ Args:
+ dim (`int`, *optional*, defaults to `32`):
+ Time step embedding dimension
+ """
+ dim: int = 32
+ flip_sin_to_cos: bool = False
+ freq_shift: float = 1
+
+ @nn.compact
+ def __call__(self, timesteps):
+ return get_sinusoidal_embeddings(
+ timesteps, embedding_dim=self.dim, flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.freq_shift
+ )
diff --git a/diffusers/models/modeling_flax_pytorch_utils.py b/diffusers/models/modeling_flax_pytorch_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9de83f87dab84d2e7fdd77b835db787cb4f1cb6
--- /dev/null
+++ b/diffusers/models/modeling_flax_pytorch_utils.py
@@ -0,0 +1,118 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch - Flax general utilities."""
+import re
+
+import jax.numpy as jnp
+from flax.traverse_util import flatten_dict, unflatten_dict
+from jax.random import PRNGKey
+
+from ..utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+def rename_key(key):
+ regex = r"\w+[.]\d+"
+ pats = re.findall(regex, key)
+ for pat in pats:
+ key = key.replace(pat, "_".join(pat.split(".")))
+ return key
+
+
+#####################
+# PyTorch => Flax #
+#####################
+
+
+# Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69
+# and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py
+def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict):
+ """Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary"""
+
+ # conv norm or layer norm
+ renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
+ if (
+ any("norm" in str_ for str_ in pt_tuple_key)
+ and (pt_tuple_key[-1] == "bias")
+ and (pt_tuple_key[:-1] + ("bias",) not in random_flax_state_dict)
+ and (pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict)
+ ):
+ renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
+ return renamed_pt_tuple_key, pt_tensor
+ elif pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict:
+ renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
+ return renamed_pt_tuple_key, pt_tensor
+
+ # embedding
+ if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict:
+ pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
+ return renamed_pt_tuple_key, pt_tensor
+
+ # conv layer
+ renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
+ if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4:
+ pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
+ return renamed_pt_tuple_key, pt_tensor
+
+ # linear layer
+ renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
+ if pt_tuple_key[-1] == "weight":
+ pt_tensor = pt_tensor.T
+ return renamed_pt_tuple_key, pt_tensor
+
+ # old PyTorch layer norm weight
+ renamed_pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
+ if pt_tuple_key[-1] == "gamma":
+ return renamed_pt_tuple_key, pt_tensor
+
+ # old PyTorch layer norm bias
+ renamed_pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
+ if pt_tuple_key[-1] == "beta":
+ return renamed_pt_tuple_key, pt_tensor
+
+ return pt_tuple_key, pt_tensor
+
+
+def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model, init_key=42):
+ # Step 1: Convert pytorch tensor to numpy
+ pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
+
+ # Step 2: Since the model is stateless, get random Flax params
+ random_flax_params = flax_model.init_weights(PRNGKey(init_key))
+
+ random_flax_state_dict = flatten_dict(random_flax_params)
+ flax_state_dict = {}
+
+ # Need to change some parameters name to match Flax names
+ for pt_key, pt_tensor in pt_state_dict.items():
+ renamed_pt_key = rename_key(pt_key)
+ pt_tuple_key = tuple(renamed_pt_key.split("."))
+
+ # Correctly rename weight parameters
+ flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict)
+
+ if flax_key in random_flax_state_dict:
+ if flax_tensor.shape != random_flax_state_dict[flax_key].shape:
+ raise ValueError(
+ f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
+ f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
+ )
+
+ # also add unexpected weight so that warning is thrown
+ flax_state_dict[flax_key] = jnp.asarray(flax_tensor)
+
+ return unflatten_dict(flax_state_dict)
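+
+
+# Illustrative usage sketch (the variable names here are hypothetical): given a loaded PyTorch state
+# dict and an instantiated Flax model from this repo,
+#   flax_params = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)
+# where keys such as ("conv_in", "weight") are renamed to ("conv_in", "kernel") with the tensor
+# transposed to HWIO layout, and norm "weight"/"bias" leaves map to Flax "scale"/"bias".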
diff --git a/diffusers/models/modeling_flax_utils.py b/diffusers/models/modeling_flax_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a6e1b3bba3d94e0252794cd0eda079f2c6f4183
--- /dev/null
+++ b/diffusers/models/modeling_flax_utils.py
@@ -0,0 +1,534 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from pickle import UnpicklingError
+from typing import Any, Dict, Union
+
+import jax
+import jax.numpy as jnp
+import msgpack.exceptions
+from flax.core.frozen_dict import FrozenDict, unfreeze
+from flax.serialization import from_bytes, to_bytes
+from flax.traverse_util import flatten_dict, unflatten_dict
+from huggingface_hub import hf_hub_download
+from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
+from requests import HTTPError
+
+from .. import __version__, is_torch_available
+from ..utils import (
+ CONFIG_NAME,
+ DIFFUSERS_CACHE,
+ FLAX_WEIGHTS_NAME,
+ HUGGINGFACE_CO_RESOLVE_ENDPOINT,
+ WEIGHTS_NAME,
+ logging,
+)
+from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
+
+
+logger = logging.get_logger(__name__)
+
+
+class FlaxModelMixin:
+ r"""
+ Base class for all Flax models.
+
+ [`FlaxModelMixin`] takes care of storing the model configuration and provides methods for loading, downloading and
+ saving models.
+
+ - **config_name** ([`str`]) -- Filename to save a model to when calling [`~FlaxModelMixin.save_pretrained`].
+ """
+ config_name = CONFIG_NAME
+ _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
+ _flax_internal_args = ["name", "parent", "dtype"]
+
+ @classmethod
+ def _from_config(cls, config, **kwargs):
+ """
+ All context managers that the model should be initialized under go here.
+ """
+ return cls(config, **kwargs)
+
+ def _cast_floating_to(self, params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
+ """
+ Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`.
+ """
+
+ # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
+ def conditional_cast(param):
+ if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
+ param = param.astype(dtype)
+ return param
+
+ if mask is None:
+ return jax.tree_map(conditional_cast, params)
+
+ flat_params = flatten_dict(params)
+ flat_mask, _ = jax.tree_flatten(mask)
+
+ for masked, key in zip(flat_mask, flat_params.keys()):
+ if masked:
+ param = flat_params[key]
+ flat_params[key] = conditional_cast(param)
+
+ return unflatten_dict(flat_params)
+
+ def to_bf16(self, params: Union[Dict, FrozenDict], mask: Any = None):
+ r"""
+ Cast the floating-point `params` to `jax.numpy.bfloat16`. This returns a new `params` tree and does not cast
+ the `params` in place.
+
+ This method can be used on a TPU to explicitly convert the model parameters to bfloat16 precision to do full
+ half-precision training or to save weights in bfloat16 for inference in order to save memory and improve speed.
+
+ Arguments:
+ params (`Union[Dict, FrozenDict]`):
+ A `PyTree` of model parameters.
+ mask (`Union[Dict, FrozenDict]`):
+ A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True`
+ for params you want to cast, and `False` for those you want to skip.
+
+ Examples:
+
+ ```python
+ >>> from diffusers import FlaxUNet2DConditionModel
+
+ >>> # load model
+ >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
+ >>> # By default, the model parameters will be in fp32 precision, to cast these to bfloat16 precision
+ >>> params = model.to_bf16(params)
+ >>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
+ >>> # then pass the mask as follows
+ >>> from flax import traverse_util
+
+ >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
+ >>> flat_params = traverse_util.flatten_dict(params)
+ >>> mask = {
+ ... path: (path[-2:] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
+ ... for path in flat_params
+ ... }
+ >>> mask = traverse_util.unflatten_dict(mask)
+ >>> params = model.to_bf16(params, mask)
+ ```"""
+ return self._cast_floating_to(params, jnp.bfloat16, mask)
+
+ def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None):
+ r"""
+ Cast the floating-point `params` to `jax.numpy.float32`. This method can be used to explicitly convert the
+ model parameters to fp32 precision. This returns a new `params` tree and does not cast the `params` in place.
+
+ Arguments:
+ params (`Union[Dict, FrozenDict]`):
+ A `PyTree` of model parameters.
+ mask (`Union[Dict, FrozenDict]`):
+ A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True`
+ for params you want to cast, and `False` for those you want to skip.
+
+ Examples:
+
+ ```python
+ >>> from diffusers import FlaxUNet2DConditionModel
+
+ >>> # Download model and configuration from huggingface.co
+ >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
+ >>> # By default, the model params will be in fp32, to illustrate the use of this method,
+ >>> # we'll first cast to fp16 and back to fp32
+ >>> params = model.to_fp16(params)
+ >>> # now cast back to fp32
+ >>> params = model.to_fp32(params)
+ ```"""
+ return self._cast_floating_to(params, jnp.float32, mask)
+
+ def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None):
+ r"""
+ Cast the floating-point `params` to `jax.numpy.float16`. This returns a new `params` tree and does not cast the
+ `params` in place.
+
+ This method can be used on a GPU to explicitly convert the model parameters to float16 precision to do full
+ half-precision training or to save weights in float16 for inference in order to save memory and improve speed.
+
+ Arguments:
+ params (`Union[Dict, FrozenDict]`):
+ A `PyTree` of model parameters.
+ mask (`Union[Dict, FrozenDict]`):
+ A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True`
+ for params you want to cast, and `False` for those you want to skip.
+
+ Examples:
+
+ ```python
+ >>> from diffusers import FlaxUNet2DConditionModel
+
+ >>> # load model
+ >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
+ >>> # By default, the model params will be in fp32, to cast these to float16
+ >>> params = model.to_fp16(params)
+ >>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
+ >>> # then pass the mask as follows
+ >>> from flax import traverse_util
+
+ >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
+ >>> flat_params = traverse_util.flatten_dict(params)
+ >>> mask = {
+ ... path: (path[-2:] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
+ ... for path in flat_params
+ ... }
+ >>> mask = traverse_util.unflatten_dict(mask)
+ >>> params = model.to_fp16(params, mask)
+ ```"""
+ return self._cast_floating_to(params, jnp.float16, mask)
+
+ def init_weights(self, rng: jax.random.KeyArray) -> Dict:
+ raise NotImplementedError(f"init_weights method has to be implemented for {self}")
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ pretrained_model_name_or_path: Union[str, os.PathLike],
+ dtype: jnp.dtype = jnp.float32,
+ *model_args,
+ **kwargs,
+ ):
+ r"""
+ Instantiate a pretrained Flax model from a pretrained model configuration.
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ Can be either:
+
+ - A string, the *model id* (for example `runwayml/stable-diffusion-v1-5`) of a pretrained model
+ hosted on the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ using [`~FlaxModelMixin.save_pretrained`].
+ dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
+ The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
+ `jax.numpy.bfloat16` (on TPUs).
+
+ This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
+ specified, all the computation will be performed with the given `dtype`.
+
+
+
+ This only specifies the dtype of the *computation* and does not influence the dtype of model
+ parameters.
+
+ If you wish to change the dtype of the model parameters, see [`~FlaxModelMixin.to_fp16`] and
+ [`~FlaxModelMixin.to_bf16`].
+
+
+
+ model_args (sequence of positional arguments, *optional*):
+ All remaining positional arguments are passed to the underlying model's `__init__` method.
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
+ incompletely downloaded files are deleted.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only(`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ from_pt (`bool`, *optional*, defaults to `False`):
+ Load the model weights from a PyTorch checkpoint save file.
+ kwargs (remaining dictionary of keyword arguments, *optional*):
+ Can be used to update the configuration object (after it is loaded) and initiate the model (for
+ example, `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
+ automatically loaded:
+
+ - If a configuration is provided with `config`, `kwargs` are directly passed to the underlying
+ model's `__init__` method (we assume all relevant updates to the configuration have already been
+ done).
+ - If a configuration is not provided, `kwargs` are first passed to the configuration class
+ initialization function [`~ConfigMixin.from_config`]. Each key of the `kwargs` that corresponds
+ to a configuration attribute is used to override said attribute with the supplied `kwargs` value.
+ Remaining keys that do not correspond to any configuration attribute are passed to the underlying
+ model's `__init__` function.
+
+ Examples:
+
+ ```python
+ >>> from diffusers import FlaxUNet2DConditionModel
+
+ >>> # Download model and configuration from huggingface.co and cache.
+ >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
+ >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable).
+ >>> model, params = FlaxUNet2DConditionModel.from_pretrained("./test/saved_model/")
+ ```
+
+ If you get the error message below, you need to finetune the weights for your downstream task:
+
+ ```bash
+ Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
+ - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
+ You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
+ ```
+ """
+ config = kwargs.pop("config", None)
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+ force_download = kwargs.pop("force_download", False)
+ from_pt = kwargs.pop("from_pt", False)
+ resume_download = kwargs.pop("resume_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", False)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+
+ user_agent = {
+ "diffusers": __version__,
+ "file_type": "model",
+ "framework": "flax",
+ }
+
+ # Load config if we don't provide a configuration
+ config_path = config if config is not None else pretrained_model_name_or_path
+ model, model_kwargs = cls.from_config(
+ config_path,
+ cache_dir=cache_dir,
+ return_unused_kwargs=True,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ subfolder=subfolder,
+ # model args
+ dtype=dtype,
+ **kwargs,
+ )
+
+ # Load model
+ pretrained_path_with_subfolder = (
+ pretrained_model_name_or_path
+ if subfolder is None
+ else os.path.join(pretrained_model_name_or_path, subfolder)
+ )
+ if os.path.isdir(pretrained_path_with_subfolder):
+ if from_pt:
+ if not os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)):
+ raise EnvironmentError(
+ f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_path_with_subfolder} "
+ )
+ model_file = os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)
+ elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME)):
+ # Load from a Flax checkpoint
+ model_file = os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME)
+ # Check if pytorch weights exist instead
+ elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)):
+ raise EnvironmentError(
+ f"{WEIGHTS_NAME} file found in directory {pretrained_path_with_subfolder}. Please load the model"
+ " using `from_pt=True`."
+ )
+ else:
+ raise EnvironmentError(
+ f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
+ f"{pretrained_path_with_subfolder}."
+ )
+ else:
+ try:
+ model_file = hf_hub_download(
+ pretrained_model_name_or_path,
+ filename=FLAX_WEIGHTS_NAME if not from_pt else WEIGHTS_NAME,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ user_agent=user_agent,
+ subfolder=subfolder,
+ revision=revision,
+ )
+
+ except RepositoryNotFoundError:
+ raise EnvironmentError(
+ f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
+ "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
+ "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
+ "login`."
+ )
+ except RevisionNotFoundError:
+ raise EnvironmentError(
+ f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
+ "this model name. Check the model page at "
+ f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
+ )
+ except EntryNotFoundError:
+ raise EnvironmentError(
+ f"{pretrained_model_name_or_path} does not appear to have a file named {FLAX_WEIGHTS_NAME}."
+ )
+ except HTTPError as err:
+ raise EnvironmentError(
+ f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
+ f"{err}"
+ )
+ except ValueError:
+ raise EnvironmentError(
+ f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
+ f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
+ f" directory containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}.\nCheckout your"
+ " internet connection or see how to run the library in offline mode at"
+ " 'https://huggingface.co/docs/transformers/installation#offline-mode'."
+ )
+ except EnvironmentError:
+ raise EnvironmentError(
+ f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
+ "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
+ f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
+ f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
+ )
+
+ if from_pt:
+ if is_torch_available():
+ from .modeling_utils import load_state_dict
+ else:
+ raise EnvironmentError(
+ "Can't load the model in PyTorch format because PyTorch is not installed. "
+ "Please, install PyTorch or use native Flax weights."
+ )
+
+ # Step 1: Get the pytorch file
+ pytorch_model_file = load_state_dict(model_file)
+
+ # Step 2: Convert the weights
+ state = convert_pytorch_state_dict_to_flax(pytorch_model_file, model)
+ else:
+ try:
+ with open(model_file, "rb") as state_f:
+ state = from_bytes(cls, state_f.read())
+ except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
+ try:
+ with open(model_file) as f:
+ if f.read().startswith("version"):
+ raise OSError(
+ "You seem to have cloned a repository without having git-lfs installed. Please"
+ " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
+ " folder you cloned."
+ )
+ else:
+ raise ValueError from e
+ except (UnicodeDecodeError, ValueError):
+ raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ")
+ # make sure all arrays are stored as jnp.ndarray
+            # NOTE: This works around a bug that will be fixed in Flax >= v0.3.4:
+ # https://github.com/google/flax/issues/1261
+ state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.devices("cpu")[0]), state)
+
+ # flatten dicts
+ state = flatten_dict(state)
+
+ params_shape_tree = jax.eval_shape(model.init_weights, rng=jax.random.PRNGKey(0))
+ required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys())
+
+ shape_state = flatten_dict(unfreeze(params_shape_tree))
+
+ missing_keys = required_params - set(state.keys())
+ unexpected_keys = set(state.keys()) - required_params
+
+ if missing_keys:
+ logger.warning(
+ f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. "
+ "Make sure to call model.init_weights to initialize the missing weights."
+ )
+ cls._missing_keys = missing_keys
+
+ for key in state.keys():
+ if key in shape_state and state[key].shape != shape_state[key].shape:
+ raise ValueError(
+ f"Trying to load the pretrained weight for {key} failed: checkpoint has shape "
+ f"{state[key].shape} which is incompatible with the model shape {shape_state[key].shape}. "
+ )
+
+ # remove unexpected keys to not be saved again
+ for unexpected_key in unexpected_keys:
+ del state[unexpected_key]
+
+ if len(unexpected_keys) > 0:
+ logger.warning(
+ f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
+ f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
+ f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
+ " with another architecture."
+ )
+ else:
+ logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
+
+ if len(missing_keys) > 0:
+ logger.warning(
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
+ f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
+ " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
+ )
+ else:
+ logger.info(
+ f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
+ f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
+ f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
+ " training."
+ )
+
+ return model, unflatten_dict(state)
+
+ def save_pretrained(
+ self,
+ save_directory: Union[str, os.PathLike],
+ params: Union[Dict, FrozenDict],
+ is_main_process: bool = True,
+ ):
+ """
+ Save a model and its configuration file to a directory so that it can be reloaded using the
+ [`~FlaxModelMixin.from_pretrained`] class method.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to save a model and its configuration file to. Will be created if it doesn't exist.
+ params (`Union[Dict, FrozenDict]`):
+ A `PyTree` of model parameters.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+                Whether the process calling this is the main process or not. Useful during distributed training when
+                you need to call this function on all processes. In this case, set `is_main_process=True` only on the
+                main process to avoid race conditions.
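+
+        Example (illustrative sketch; assumes `model` and `params` were returned by
+        [`~FlaxModelMixin.from_pretrained`]):
+
+        ```python
+        >>> model.save_pretrained("./my_flax_model", params=params)
+        ```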
+ """
+ if os.path.isfile(save_directory):
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
+ return
+
+ os.makedirs(save_directory, exist_ok=True)
+
+ model_to_save = self
+
+ # Attach architecture to the config
+ # Save the config
+ if is_main_process:
+ model_to_save.save_config(save_directory)
+
+ # save model
+ output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME)
+ with open(output_model_file, "wb") as f:
+ model_bytes = to_bytes(params)
+ f.write(model_bytes)
+
+ logger.info(f"Model weights saved in {output_model_file}")
diff --git a/diffusers/models/modeling_pytorch_flax_utils.py b/diffusers/models/modeling_pytorch_flax_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a61638ad02f7a38a1439f35dea5966c7c7d519d8
--- /dev/null
+++ b/diffusers/models/modeling_pytorch_flax_utils.py
@@ -0,0 +1,161 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch - Flax general utilities."""
+
+from pickle import UnpicklingError
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+from flax.serialization import from_bytes
+from flax.traverse_util import flatten_dict
+
+from ..utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+#####################
+# Flax => PyTorch #
+#####################
+
+
+# from https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_flax_pytorch_utils.py#L224-L352
+def load_flax_checkpoint_in_pytorch_model(pt_model, model_file):
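+    """Load a Flax checkpoint file into `pt_model` and return the updated PyTorch model."""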
+ try:
+ with open(model_file, "rb") as flax_state_f:
+ flax_state = from_bytes(None, flax_state_f.read())
+ except UnpicklingError as e:
+ try:
+ with open(model_file) as f:
+ if f.read().startswith("version"):
+ raise OSError(
+ "You seem to have cloned a repository without having git-lfs installed. Please"
+ " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
+ " folder you cloned."
+ )
+ else:
+ raise ValueError from e
+ except (UnicodeDecodeError, ValueError):
+ raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ")
+
+ return load_flax_weights_in_pytorch_model(pt_model, flax_state)
+
+
+def load_flax_weights_in_pytorch_model(pt_model, flax_state):
+ """Load flax checkpoints in a PyTorch model"""
+
+ try:
+ import torch # noqa: F401
+ except ImportError:
+ logger.error(
+ "Loading Flax weights in PyTorch requires both PyTorch and Flax to be installed. Please see"
+ " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
+ " instructions."
+ )
+ raise
+
+ # check if we have bf16 weights
+ is_type_bf16 = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype == jnp.bfloat16, flax_state)).values()
+ if any(is_type_bf16):
+        # convert all weights to fp32 if they are bf16, since torch.from_numpy cannot handle bf16
+        # and bf16 is not fully supported in PyTorch yet.
+ logger.warning(
+ "Found ``bfloat16`` weights in Flax model. Casting all ``bfloat16`` weights to ``float32`` "
+ "before loading those in PyTorch model."
+ )
+ flax_state = jax.tree_util.tree_map(
+ lambda params: params.astype(np.float32) if params.dtype == jnp.bfloat16 else params, flax_state
+ )
+
+ pt_model.base_model_prefix = ""
+
+ flax_state_dict = flatten_dict(flax_state, sep=".")
+ pt_model_dict = pt_model.state_dict()
+
+ # keep track of unexpected & missing keys
+ unexpected_keys = []
+ missing_keys = set(pt_model_dict.keys())
+
+ for flax_key_tuple, flax_tensor in flax_state_dict.items():
+ flax_key_tuple_array = flax_key_tuple.split(".")
+
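+        # Flax stores convolution kernels as (H, W, in_channels, out_channels) and dense kernels as
+        # (in_features, out_features), while PyTorch expects (out_channels, in_channels, H, W) and
+        # (out_features, in_features), hence the transposes below.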
+ if flax_key_tuple_array[-1] == "kernel" and flax_tensor.ndim == 4:
+ flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]
+ flax_tensor = jnp.transpose(flax_tensor, (3, 2, 0, 1))
+ elif flax_key_tuple_array[-1] == "kernel":
+ flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]
+ flax_tensor = flax_tensor.T
+ elif flax_key_tuple_array[-1] == "scale":
+ flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]
+
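+        # Flax numbers repeated submodules with underscores (e.g. "down_blocks_0") while the PyTorch model uses
+        # dotted indices ("down_blocks.0"), so replace "_<digit>" with ".<digit>" in each key segment; keys under
+        # "time_embedding" are left untouched.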
+ if "time_embedding" not in flax_key_tuple_array:
+ for i, flax_key_tuple_string in enumerate(flax_key_tuple_array):
+ flax_key_tuple_array[i] = (
+ flax_key_tuple_string.replace("_0", ".0")
+ .replace("_1", ".1")
+ .replace("_2", ".2")
+ .replace("_3", ".3")
+ .replace("_4", ".4")
+ .replace("_5", ".5")
+ .replace("_6", ".6")
+ .replace("_7", ".7")
+ .replace("_8", ".8")
+ .replace("_9", ".9")
+ )
+
+ flax_key = ".".join(flax_key_tuple_array)
+
+ if flax_key in pt_model_dict:
+ if flax_tensor.shape != pt_model_dict[flax_key].shape:
+ raise ValueError(
+ f"Flax checkpoint seems to be incorrect. Weight {flax_key_tuple} was expected "
+ f"to be of shape {pt_model_dict[flax_key].shape}, but is {flax_tensor.shape}."
+ )
+ else:
+ # add weight to pytorch dict
+ flax_tensor = np.asarray(flax_tensor) if not isinstance(flax_tensor, np.ndarray) else flax_tensor
+ pt_model_dict[flax_key] = torch.from_numpy(flax_tensor)
+ # remove from missing keys
+ missing_keys.remove(flax_key)
+ else:
+ # weight is not expected by PyTorch model
+ unexpected_keys.append(flax_key)
+
+ pt_model.load_state_dict(pt_model_dict)
+
+ # re-transform missing_keys to list
+ missing_keys = list(missing_keys)
+
+ if len(unexpected_keys) > 0:
+ logger.warning(
+ "Some weights of the Flax model were not used when initializing the PyTorch model"
+ f" {pt_model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are initializing"
+ f" {pt_model.__class__.__name__} from a Flax model trained on another task or with another architecture"
+ " (e.g. initializing a BertForSequenceClassification model from a FlaxBertForPreTraining model).\n- This"
+ f" IS NOT expected if you are initializing {pt_model.__class__.__name__} from a Flax model that you expect"
+ " to be exactly identical (e.g. initializing a BertForSequenceClassification model from a"
+ " FlaxBertForSequenceClassification model)."
+ )
+ if len(missing_keys) > 0:
+ logger.warning(
+ f"Some weights of {pt_model.__class__.__name__} were not initialized from the Flax model and are newly"
+ f" initialized: {missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to"
+ " use it for predictions and inference."
+ )
+
+ return pt_model
diff --git a/diffusers/models/modeling_utils.py b/diffusers/models/modeling_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1fa96514c0a9e39b9321550f6d85a8e11b0deb36
--- /dev/null
+++ b/diffusers/models/modeling_utils.py
@@ -0,0 +1,980 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import itertools
+import os
+import re
+from functools import partial
+from typing import Any, Callable, List, Optional, Tuple, Union
+
+import torch
+from torch import Tensor, device, nn
+
+from .. import __version__
+from ..utils import (
+ CONFIG_NAME,
+ DIFFUSERS_CACHE,
+ FLAX_WEIGHTS_NAME,
+ HF_HUB_OFFLINE,
+ SAFETENSORS_WEIGHTS_NAME,
+ WEIGHTS_NAME,
+ _add_variant,
+ _get_model_file,
+ deprecate,
+ is_accelerate_available,
+ is_safetensors_available,
+ is_torch_version,
+ logging,
+)
+
+
+logger = logging.get_logger(__name__)
+
+
+if is_torch_version(">=", "1.9.0"):
+ _LOW_CPU_MEM_USAGE_DEFAULT = True
+else:
+ _LOW_CPU_MEM_USAGE_DEFAULT = False
+
+
+if is_accelerate_available():
+ import accelerate
+ from accelerate.utils import set_module_tensor_to_device
+ from accelerate.utils.versions import is_torch_version
+
+if is_safetensors_available():
+ import safetensors
+
+
+def get_parameter_device(parameter: torch.nn.Module):
+ try:
+ parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
+ return next(parameters_and_buffers).device
+ except StopIteration:
+ # For torch.nn.DataParallel compatibility in PyTorch 1.5
+
+ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+ return tuples
+
+ gen = parameter._named_members(get_members_fn=find_tensor_attributes)
+ first_tuple = next(gen)
+ return first_tuple[1].device
+
+
+def get_parameter_dtype(parameter: torch.nn.Module):
+ try:
+ params = tuple(parameter.parameters())
+ if len(params) > 0:
+ return params[0].dtype
+
+ buffers = tuple(parameter.buffers())
+ if len(buffers) > 0:
+ return buffers[0].dtype
+
+ except StopIteration:
+ # For torch.nn.DataParallel compatibility in PyTorch 1.5
+
+ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+ return tuples
+
+ gen = parameter._named_members(get_members_fn=find_tensor_attributes)
+ first_tuple = next(gen)
+ return first_tuple[1].dtype
+
+
+def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
+ """
+ Reads a checkpoint file, returning properly formatted errors if they arise.
+ """
+ try:
+ if os.path.basename(checkpoint_file) == _add_variant(WEIGHTS_NAME, variant):
+ return torch.load(checkpoint_file, map_location="cpu")
+ else:
+ return safetensors.torch.load_file(checkpoint_file, device="cpu")
+ except Exception as e:
+ try:
+ with open(checkpoint_file) as f:
+ if f.read().startswith("version"):
+ raise OSError(
+ "You seem to have cloned a repository without having git-lfs installed. Please install "
+ "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
+ "you cloned."
+ )
+ else:
+ raise ValueError(
+ f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
+ "model. Make sure you have saved the model properly."
+ ) from e
+ except (UnicodeDecodeError, ValueError):
+ raise OSError(
+ f"Unable to load weights from checkpoint file for '{checkpoint_file}' "
+ f"at '{checkpoint_file}'. "
+ "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
+ )
+
+
+def _load_state_dict_into_model(model_to_load, state_dict):
+ # Convert old format to new format if needed from a PyTorch state_dict
+ # copy state_dict so _load_from_state_dict can modify it
+ state_dict = state_dict.copy()
+ error_msgs = []
+
+ # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
+ # so we need to apply the function recursively.
+ def load(module: torch.nn.Module, prefix=""):
+ args = (state_dict, prefix, {}, True, [], [], error_msgs)
+ module._load_from_state_dict(*args)
+
+ for name, child in module._modules.items():
+ if child is not None:
+ load(child, prefix + name + ".")
+
+ load(model_to_load)
+
+ return error_msgs
+
+
+class ModelMixin(torch.nn.Module):
+ r"""
+ Base class for all models.
+
+ [`ModelMixin`] takes care of storing the model configuration and provides methods for loading, downloading and
+ saving models.
+
+ - **config_name** ([`str`]) -- Filename to save a model to when calling [`~models.ModelMixin.save_pretrained`].
+ """
+ config_name = CONFIG_NAME
+ _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
+ _supports_gradient_checkpointing = False
+ _keys_to_ignore_on_load_unexpected = None
+
+ def __init__(self):
+ super().__init__()
+
+ def __getattr__(self, name: str) -> Any:
+ """The only reason we overwrite `getattr` here is to gracefully deprecate accessing
+ config attributes directly. See https://github.com/huggingface/diffusers/pull/3129 We need to overwrite
+ __getattr__ here in addition so that we don't trigger `torch.nn.Module`'s __getattr__':
+ https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
+ """
+
+ is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
+ is_attribute = name in self.__dict__
+
+ if is_in_config and not is_attribute:
+ deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'unet.config.{name}'."
+ deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False, stacklevel=3)
+ return self._internal_dict[name]
+
+ # call PyTorch's https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
+ return super().__getattr__(name)
+
+ @property
+ def is_gradient_checkpointing(self) -> bool:
+ """
+ Whether gradient checkpointing is activated for this model or not.
+ """
+ return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
+
+ def enable_gradient_checkpointing(self):
+ """
+ Activates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
+ *checkpoint activations* in other frameworks).
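+
+        Example (illustrative sketch; assumes `model` is a `ModelMixin` subclass instance that supports gradient
+        checkpointing, such as `UNet2DConditionModel`):
+
+        ```py
+        model.enable_gradient_checkpointing()
+        ```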
+ """
+ if not self._supports_gradient_checkpointing:
+ raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
+ self.apply(partial(self._set_gradient_checkpointing, value=True))
+
+ def disable_gradient_checkpointing(self):
+ """
+ Deactivates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
+ *checkpoint activations* in other frameworks).
+ """
+ if self._supports_gradient_checkpointing:
+ self.apply(partial(self._set_gradient_checkpointing, value=False))
+
+ def set_use_memory_efficient_attention_xformers(
+ self, valid: bool, attention_op: Optional[Callable] = None
+ ) -> None:
+ # Recursively walk through all the children.
+ # Any children which exposes the set_use_memory_efficient_attention_xformers method
+ # gets the message
+ def fn_recursive_set_mem_eff(module: torch.nn.Module):
+ if hasattr(module, "set_use_memory_efficient_attention_xformers"):
+ module.set_use_memory_efficient_attention_xformers(valid, attention_op)
+
+ for child in module.children():
+ fn_recursive_set_mem_eff(child)
+
+ for module in self.children():
+ if isinstance(module, torch.nn.Module):
+ fn_recursive_set_mem_eff(module)
+
+ def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
+ r"""
+ Enable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
+
+ When this option is enabled, you should observe lower GPU memory usage and a potential speed up during
+ inference. Speed up during training is not guaranteed.
+
+
+
+        ⚠️ When memory efficient attention and sliced attention are both enabled, memory efficient attention takes
+        precedence.
+
+
+
+ Parameters:
+ attention_op (`Callable`, *optional*):
+ Override the default `None` operator for use as `op` argument to the
+ [`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
+ function of xFormers.
+
+ Examples:
+
+ ```py
+ >>> import torch
+ >>> from diffusers import UNet2DConditionModel
+ >>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
+
+ >>> model = UNet2DConditionModel.from_pretrained(
+ ... "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16
+ ... )
+ >>> model = model.to("cuda")
+ >>> model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
+ ```
+ """
+ self.set_use_memory_efficient_attention_xformers(True, attention_op)
+
+ def disable_xformers_memory_efficient_attention(self):
+ r"""
+ Disable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
+ """
+ self.set_use_memory_efficient_attention_xformers(False)
+
+ def save_pretrained(
+ self,
+ save_directory: Union[str, os.PathLike],
+ is_main_process: bool = True,
+ save_function: Callable = None,
+ safe_serialization: bool = False,
+ variant: Optional[str] = None,
+ ):
+ """
+ Save a model and its configuration file to a directory so that it can be reloaded using the
+ [`~models.ModelMixin.from_pretrained`] class method.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to save a model and its configuration file to. Will be created if it doesn't exist.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+                Whether the process calling this is the main process or not. Useful during distributed training when
+                you need to call this function on all processes. In this case, set `is_main_process=True` only on the
+                main process to avoid race conditions.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful during distributed training when you need to
+ replace `torch.save` with another method. Can be configured with the environment variable
+ `DIFFUSERS_SAVE_MODE`.
+ safe_serialization (`bool`, *optional*, defaults to `False`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+ variant (`str`, *optional*):
+                If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
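+
+        Example (illustrative sketch; assumes `unet` is an instance of a `ModelMixin` subclass such as
+        `UNet2DConditionModel`):
+
+        ```py
+        unet.save_pretrained("./my_unet", safe_serialization=True)
+        ```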
+ """
+ if safe_serialization and not is_safetensors_available():
+ raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.")
+
+ if os.path.isfile(save_directory):
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
+ return
+
+ os.makedirs(save_directory, exist_ok=True)
+
+ model_to_save = self
+
+ # Attach architecture to the config
+ # Save the config
+ if is_main_process:
+ model_to_save.save_config(save_directory)
+
+ # Save the model
+ state_dict = model_to_save.state_dict()
+
+ weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
+ weights_name = _add_variant(weights_name, variant)
+
+ # Save the model
+ if safe_serialization:
+ safetensors.torch.save_file(
+ state_dict, os.path.join(save_directory, weights_name), metadata={"format": "pt"}
+ )
+ else:
+ torch.save(state_dict, os.path.join(save_directory, weights_name))
+
+ logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
+ r"""
+ Instantiate a pretrained PyTorch model from a pretrained model configuration.
+
+ The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To
+ train the model, set it back in training mode with `model.train()`.
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+ Can be either:
+
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ with [`~ModelMixin.save_pretrained`].
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ torch_dtype (`str` or `torch.dtype`, *optional*):
+ Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
+ dtype is automatically derived from the model's weights.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
+ incompletely downloaded files are deleted.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ output_loading_info (`bool`, *optional*, defaults to `False`):
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ from_flax (`bool`, *optional*, defaults to `False`):
+ Load the model weights from a Flax checkpoint save file.
+ subfolder (`str`, *optional*, defaults to `""`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+ mirror (`str`, *optional*):
+ Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
+ guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
+ information.
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
+ A map that specifies where each submodule should go. It doesn't need to be defined for each
+ parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
+ same device.
+
+ Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
+ more information about each option see [designing a device
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+ max_memory (`Dict`, *optional*):
+                A dictionary mapping device identifiers to their maximum memory. Will default to the maximum memory
+                available for each GPU and the available CPU RAM if unset.
+ offload_folder (`str` or `os.PathLike`, *optional*):
+ The path to offload weights if `device_map` contains the value `"disk"`.
+ offload_state_dict (`bool`, *optional*):
+ If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
+ the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
+ when there is some disk offload.
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading by only loading the pretrained weights and not initializing the weights. This also
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+ argument to `True` will raise an error.
+ variant (`str`, *optional*):
+ Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when
+ loading `from_flax`.
+ use_safetensors (`bool`, *optional*, defaults to `None`):
+ If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
+ `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
+ weights. If set to `False`, `safetensors` weights are not loaded.
+
+
+
+        To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log in with
+ `huggingface-cli login`. You can also activate the special
+ ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
+ firewalled environment.
+
+
+
+ Example:
+
+ ```py
+ from diffusers import UNet2DConditionModel
+
+ unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
+ ```
+
+ If you get the error message below, you need to finetune the weights for your downstream task:
+
+ ```bash
+ Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
+ - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
+ You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
+ ```
+ """
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+ ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
+ force_download = kwargs.pop("force_download", False)
+ from_flax = kwargs.pop("from_flax", False)
+ resume_download = kwargs.pop("resume_download", False)
+ proxies = kwargs.pop("proxies", None)
+ output_loading_info = kwargs.pop("output_loading_info", False)
+ local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ revision = kwargs.pop("revision", None)
+ torch_dtype = kwargs.pop("torch_dtype", None)
+ subfolder = kwargs.pop("subfolder", None)
+ device_map = kwargs.pop("device_map", None)
+ max_memory = kwargs.pop("max_memory", None)
+ offload_folder = kwargs.pop("offload_folder", None)
+ offload_state_dict = kwargs.pop("offload_state_dict", False)
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
+ variant = kwargs.pop("variant", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+
+ if use_safetensors and not is_safetensors_available():
+ raise ValueError(
+ "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
+ )
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = is_safetensors_available()
+ allow_pickle = True
+
+ if low_cpu_mem_usage and not is_accelerate_available():
+ low_cpu_mem_usage = False
+ logger.warning(
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+ " install accelerate\n```\n."
+ )
+
+ if device_map is not None and not is_accelerate_available():
+ raise NotImplementedError(
+ "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
+ " `device_map=None`. You can install accelerate with `pip install accelerate`."
+ )
+
+ # Check if we can handle device_map and dispatching the weights
+ if device_map is not None and not is_torch_version(">=", "1.9.0"):
+ raise NotImplementedError(
+ "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
+ " `device_map=None`."
+ )
+
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+ raise NotImplementedError(
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+ " `low_cpu_mem_usage=False`."
+ )
+
+ if low_cpu_mem_usage is False and device_map is not None:
+ raise ValueError(
+ f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
+ " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
+ )
+
+ # Load config if we don't provide a configuration
+ config_path = pretrained_model_name_or_path
+
+ user_agent = {
+ "diffusers": __version__,
+ "file_type": "model",
+ "framework": "pytorch",
+ }
+
+ # load config
+ config, unused_kwargs, commit_hash = cls.load_config(
+ config_path,
+ cache_dir=cache_dir,
+ return_unused_kwargs=True,
+ return_commit_hash=True,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ subfolder=subfolder,
+ device_map=device_map,
+ max_memory=max_memory,
+ offload_folder=offload_folder,
+ offload_state_dict=offload_state_dict,
+ user_agent=user_agent,
+ **kwargs,
+ )
+
+ # load model
+ model_file = None
+ if from_flax:
+ model_file = _get_model_file(
+ pretrained_model_name_or_path,
+ weights_name=FLAX_WEIGHTS_NAME,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ commit_hash=commit_hash,
+ )
+ model = cls.from_config(config, **unused_kwargs)
+
+ # Convert the weights
+ from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model
+
+ model = load_flax_checkpoint_in_pytorch_model(model, model_file)
+ else:
+ if use_safetensors:
+ try:
+ model_file = _get_model_file(
+ pretrained_model_name_or_path,
+ weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ commit_hash=commit_hash,
+ )
+ except IOError as e:
+ if not allow_pickle:
+ raise e
+ pass
+ if model_file is None:
+ model_file = _get_model_file(
+ pretrained_model_name_or_path,
+ weights_name=_add_variant(WEIGHTS_NAME, variant),
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ commit_hash=commit_hash,
+ )
+
+ if low_cpu_mem_usage:
+ # Instantiate model with empty weights
+ with accelerate.init_empty_weights():
+ model = cls.from_config(config, **unused_kwargs)
+
+ # if device_map is None, load the state dict and move the params from meta device to the cpu
+ if device_map is None:
+ param_device = "cpu"
+ state_dict = load_state_dict(model_file, variant=variant)
+ model._convert_deprecated_attention_blocks(state_dict)
+ # move the params from meta device to cpu
+ missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
+ if len(missing_keys) > 0:
+ raise ValueError(
+ f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
+ f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
+ " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
+ " those weights or else make sure your checkpoint file is correct."
+ )
+ unexpected_keys = []
+
+ empty_state_dict = model.state_dict()
+ for param_name, param in state_dict.items():
+ accepts_dtype = "dtype" in set(
+ inspect.signature(set_module_tensor_to_device).parameters.keys()
+ )
+
+ if param_name not in empty_state_dict:
+ unexpected_keys.append(param_name)
+ continue
+
+ if empty_state_dict[param_name].shape != param.shape:
+ raise ValueError(
+ f"Cannot load {pretrained_model_name_or_path} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
+ )
+
+ if accepts_dtype:
+ set_module_tensor_to_device(
+ model, param_name, param_device, value=param, dtype=torch_dtype
+ )
+ else:
+ set_module_tensor_to_device(model, param_name, param_device, value=param)
+
+ if cls._keys_to_ignore_on_load_unexpected is not None:
+ for pat in cls._keys_to_ignore_on_load_unexpected:
+ unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
+
+ if len(unexpected_keys) > 0:
+ logger.warn(
+ f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
+ )
+
+ else: # else let accelerate handle loading and dispatching.
+ # Load weights and dispatch according to the device_map
+ # by default the device_map is None and the weights are loaded on the CPU
+ try:
+ accelerate.load_checkpoint_and_dispatch(
+ model,
+ model_file,
+ device_map,
+ max_memory=max_memory,
+ offload_folder=offload_folder,
+ offload_state_dict=offload_state_dict,
+ dtype=torch_dtype,
+ )
+ except AttributeError as e:
+ # When using accelerate loading, we do not have the ability to load the state
+ # dict and rename the weight names manually. Additionally, accelerate skips
+ # torch loading conventions and directly writes into `module.{_buffers, _parameters}`
+ # (which look like they should be private variables?), so we can't use the standard hooks
+ # to rename parameters on load. We need to mimic the original weight names so the correct
+ # attributes are available. After we have loaded the weights, we convert the deprecated
+ # names to the new non-deprecated names. Then we _greatly encourage_ the user to convert
+ # the weights so we don't have to do this again.
+
+ if "'Attention' object has no attribute" in str(e):
+ logger.warn(
+ f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}"
+ " was saved with deprecated attention block weight names. We will load it with the deprecated attention block"
+ " names and convert them on the fly to the new attention block format. Please re-save the model after this conversion,"
+ " so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint,"
+ " please also re-upload it or open a PR on the original repository."
+ )
+ model._temp_convert_self_to_deprecated_attention_blocks()
+ accelerate.load_checkpoint_and_dispatch(
+ model,
+ model_file,
+ device_map,
+ max_memory=max_memory,
+ offload_folder=offload_folder,
+ offload_state_dict=offload_state_dict,
+ dtype=torch_dtype,
+ )
+ model._undo_temp_convert_self_to_deprecated_attention_blocks()
+ else:
+ raise e
+
+ loading_info = {
+ "missing_keys": [],
+ "unexpected_keys": [],
+ "mismatched_keys": [],
+ "error_msgs": [],
+ }
+ else:
+ model = cls.from_config(config, **unused_kwargs)
+
+ state_dict = load_state_dict(model_file, variant=variant)
+ model._convert_deprecated_attention_blocks(state_dict)
+
+ model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
+ model,
+ state_dict,
+ model_file,
+ pretrained_model_name_or_path,
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
+ )
+
+ loading_info = {
+ "missing_keys": missing_keys,
+ "unexpected_keys": unexpected_keys,
+ "mismatched_keys": mismatched_keys,
+ "error_msgs": error_msgs,
+ }
+
+ if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
+ raise ValueError(
+ f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
+ )
+ elif torch_dtype is not None:
+ model = model.to(torch_dtype)
+
+ model.register_to_config(_name_or_path=pretrained_model_name_or_path)
+
+ # Set model in evaluation mode to deactivate DropOut modules by default
+ model.eval()
+ if output_loading_info:
+ return model, loading_info
+
+ return model
+
+ @classmethod
+ def _load_pretrained_model(
+ cls,
+ model,
+ state_dict,
+ resolved_archive_file,
+ pretrained_model_name_or_path,
+ ignore_mismatched_sizes=False,
+ ):
+ # Retrieve missing & unexpected_keys
+ model_state_dict = model.state_dict()
+ loaded_keys = list(state_dict.keys())
+
+ expected_keys = list(model_state_dict.keys())
+
+ original_loaded_keys = loaded_keys
+
+ missing_keys = list(set(expected_keys) - set(loaded_keys))
+ unexpected_keys = list(set(loaded_keys) - set(expected_keys))
+
+ # Make sure we are able to load base models as well as derived models (with heads)
+ model_to_load = model
+
+ def _find_mismatched_keys(
+ state_dict,
+ model_state_dict,
+ loaded_keys,
+ ignore_mismatched_sizes,
+ ):
+ mismatched_keys = []
+ if ignore_mismatched_sizes:
+ for checkpoint_key in loaded_keys:
+ model_key = checkpoint_key
+
+ if (
+ model_key in model_state_dict
+ and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
+ ):
+ mismatched_keys.append(
+ (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
+ )
+ del state_dict[checkpoint_key]
+ return mismatched_keys
+
+ if state_dict is not None:
+ # Whole checkpoint
+ mismatched_keys = _find_mismatched_keys(
+ state_dict,
+ model_state_dict,
+ original_loaded_keys,
+ ignore_mismatched_sizes,
+ )
+ error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
+
+ if len(error_msgs) > 0:
+ error_msg = "\n\t".join(error_msgs)
+ if "size mismatch" in error_msg:
+ error_msg += (
+ "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
+ )
+ raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
+
+ if len(unexpected_keys) > 0:
+ logger.warning(
+ f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
+ f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
+ f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
+ " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
+ " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
+ f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
+ " identical (initializing a BertForSequenceClassification model from a"
+ " BertForSequenceClassification model)."
+ )
+ else:
+ logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
+ if len(missing_keys) > 0:
+ logger.warning(
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
+ f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
+ " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
+ )
+ elif len(mismatched_keys) == 0:
+ logger.info(
+ f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
+ f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
+ f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
+ " without further training."
+ )
+ if len(mismatched_keys) > 0:
+ mismatched_warning = "\n".join(
+ [
+ f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
+ for key, shape1, shape2 in mismatched_keys
+ ]
+ )
+ logger.warning(
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
+ f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
+ f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
+ " able to use it for predictions and inference."
+ )
+
+ return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
+
+ @property
+ def device(self) -> device:
+ """
+ `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
+ device).
+ """
+ return get_parameter_device(self)
+
+ @property
+ def dtype(self) -> torch.dtype:
+ """
+ `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
+ """
+ return get_parameter_dtype(self)
+
+ def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
+ """
+ Get number of (trainable or non-embedding) parameters in the module.
+
+ Args:
+ only_trainable (`bool`, *optional*, defaults to `False`):
+ Whether or not to return only the number of trainable parameters.
+ exclude_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether or not to return only the number of non-embedding parameters.
+
+ Returns:
+ `int`: The number of parameters.
+
+ Example:
+
+ ```py
+ from diffusers import UNet2DConditionModel
+
+ model_id = "runwayml/stable-diffusion-v1-5"
+ unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
+ unet.num_parameters(only_trainable=True)
+ 859520964
+ ```
+ """
+
+ if exclude_embeddings:
+ embedding_param_names = [
+ f"{name}.weight"
+ for name, module_type in self.named_modules()
+ if isinstance(module_type, torch.nn.Embedding)
+ ]
+ non_embedding_parameters = [
+ parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
+ ]
+ return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
+ else:
+ return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
+
+ def _convert_deprecated_attention_blocks(self, state_dict):
+ deprecated_attention_block_paths = []
+
+ def recursive_find_attn_block(name, module):
+ if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
+ deprecated_attention_block_paths.append(name)
+
+ for sub_name, sub_module in module.named_children():
+ sub_name = sub_name if name == "" else f"{name}.{sub_name}"
+ recursive_find_attn_block(sub_name, sub_module)
+
+ recursive_find_attn_block("", self)
+
+ # NOTE: we have to check if the deprecated parameters are in the state dict
+ # because it is possible we are loading from a state dict that was already
+ # converted
+
+ for path in deprecated_attention_block_paths:
+ # group_norm path stays the same
+
+ # query -> to_q
+ if f"{path}.query.weight" in state_dict:
+ state_dict[f"{path}.to_q.weight"] = state_dict.pop(f"{path}.query.weight")
+ if f"{path}.query.bias" in state_dict:
+ state_dict[f"{path}.to_q.bias"] = state_dict.pop(f"{path}.query.bias")
+
+ # key -> to_k
+ if f"{path}.key.weight" in state_dict:
+ state_dict[f"{path}.to_k.weight"] = state_dict.pop(f"{path}.key.weight")
+ if f"{path}.key.bias" in state_dict:
+ state_dict[f"{path}.to_k.bias"] = state_dict.pop(f"{path}.key.bias")
+
+ # value -> to_v
+ if f"{path}.value.weight" in state_dict:
+ state_dict[f"{path}.to_v.weight"] = state_dict.pop(f"{path}.value.weight")
+ if f"{path}.value.bias" in state_dict:
+ state_dict[f"{path}.to_v.bias"] = state_dict.pop(f"{path}.value.bias")
+
+ # proj_attn -> to_out.0
+ if f"{path}.proj_attn.weight" in state_dict:
+ state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight")
+ if f"{path}.proj_attn.bias" in state_dict:
+ state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
+
+ def _temp_convert_self_to_deprecated_attention_blocks(self):
+ deprecated_attention_block_modules = []
+
+ def recursive_find_attn_block(module):
+ if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
+ deprecated_attention_block_modules.append(module)
+
+ for sub_module in module.children():
+ recursive_find_attn_block(sub_module)
+
+ recursive_find_attn_block(self)
+
+ for module in deprecated_attention_block_modules:
+ module.query = module.to_q
+ module.key = module.to_k
+ module.value = module.to_v
+ module.proj_attn = module.to_out[0]
+
+ # We don't _have_ to delete the old attributes, but it's helpful to ensure
+ # that _all_ the weights are loaded into the new attributes and we're not
+ # making an incorrect assumption that this model should be converted when
+ # it really shouldn't be.
+ del module.to_q
+ del module.to_k
+ del module.to_v
+ del module.to_out
+
+ def _undo_temp_convert_self_to_deprecated_attention_blocks(self):
+ deprecated_attention_block_modules = []
+
+ def recursive_find_attn_block(module):
+ if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
+ deprecated_attention_block_modules.append(module)
+
+ for sub_module in module.children():
+ recursive_find_attn_block(sub_module)
+
+ recursive_find_attn_block(self)
+
+ for module in deprecated_attention_block_modules:
+ module.to_q = module.query
+ module.to_k = module.key
+ module.to_v = module.value
+ module.to_out = nn.ModuleList([module.proj_attn, nn.Dropout(module.dropout)])
+
+ del module.query
+ del module.key
+ del module.value
+ del module.proj_attn
diff --git a/diffusers/models/prior_transformer.py b/diffusers/models/prior_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f3c61dd7561742114947e3419c19fec8c2a824f
--- /dev/null
+++ b/diffusers/models/prior_transformer.py
@@ -0,0 +1,364 @@
+from dataclasses import dataclass
+from typing import Dict, Optional, Union
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import BaseOutput
+from .attention import BasicTransformerBlock
+from .attention_processor import AttentionProcessor, AttnProcessor
+from .embeddings import TimestepEmbedding, Timesteps
+from .modeling_utils import ModelMixin
+
+
+@dataclass
+class PriorTransformerOutput(BaseOutput):
+ """
+ The output of [`PriorTransformer`].
+
+ Args:
+ predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
+ The predicted CLIP image embedding conditioned on the CLIP text embedding input.
+ """
+
+ predicted_image_embedding: torch.FloatTensor
+
+
+class PriorTransformer(ModelMixin, ConfigMixin):
+ """
+ A Prior Transformer model.
+
+ Parameters:
+ num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
+ num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use.
+ embedding_dim (`int`, *optional*, defaults to 768): The dimension of the model input `hidden_states`.
+ num_embeddings (`int`, *optional*, defaults to 77):
+ The number of embeddings of the model input `hidden_states`.
+ additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the
+ projected `hidden_states`. The actual length of the used `hidden_states` is `num_embeddings +
+ additional_embeddings`.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ time_embed_act_fn (`str`, *optional*, defaults to 'silu'):
+ The activation function to use to create timestep embeddings.
+ norm_in_type (`str`, *optional*, defaults to None): The normalization layer to apply on hidden states before
+ passing to Transformer blocks. Set it to `None` if normalization is not needed.
+ embedding_proj_norm_type (`str`, *optional*, defaults to None):
+ The normalization layer to apply on the input `proj_embedding`. Set it to `None` if normalization is not
+ needed.
+ encoder_hid_proj_type (`str`, *optional*, defaults to `linear`):
+ The projection layer to apply on the input `encoder_hidden_states`. Set it to `None` if
+ `encoder_hidden_states` is `None`.
+ added_emb_type (`str`, *optional*, defaults to `prd`): Additional embeddings to condition the model.
+ Choose from `prd` or `None`. If `prd` is chosen, a token indicating the (quantized) dot product
+ between the text embedding and image embedding, as proposed in the unCLIP paper
+ https://arxiv.org/abs/2204.06125, is prepended. If it is `None`, no additional embeddings will be prepended.
+ time_embed_dim (`int`, *optional*, defaults to None): The dimension of timestep embeddings.
+ If None, will be set to `num_attention_heads * attention_head_dim`
+ embedding_proj_dim (`int`, *optional*, default to None):
+ The dimension of `proj_embedding`. If None, will be set to `embedding_dim`.
+ clip_embed_dim (`int`, *optional*, default to None):
+ The dimension of the output. If None, will be set to `embedding_dim`.
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ num_attention_heads: int = 32,
+ attention_head_dim: int = 64,
+ num_layers: int = 20,
+ embedding_dim: int = 768,
+ num_embeddings=77,
+ additional_embeddings=4,
+ dropout: float = 0.0,
+ time_embed_act_fn: str = "silu",
+ norm_in_type: Optional[str] = None, # layer
+ embedding_proj_norm_type: Optional[str] = None, # layer
+ encoder_hid_proj_type: Optional[str] = "linear", # linear
+ added_emb_type: Optional[str] = "prd", # prd
+ time_embed_dim: Optional[int] = None,
+ embedding_proj_dim: Optional[int] = None,
+ clip_embed_dim: Optional[int] = None,
+ ):
+ super().__init__()
+ self.num_attention_heads = num_attention_heads
+ self.attention_head_dim = attention_head_dim
+ inner_dim = num_attention_heads * attention_head_dim
+ self.additional_embeddings = additional_embeddings
+
+ time_embed_dim = time_embed_dim or inner_dim
+ embedding_proj_dim = embedding_proj_dim or embedding_dim
+ clip_embed_dim = clip_embed_dim or embedding_dim
+
+ self.time_proj = Timesteps(inner_dim, True, 0)
+ self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, out_dim=inner_dim, act_fn=time_embed_act_fn)
+
+ self.proj_in = nn.Linear(embedding_dim, inner_dim)
+
+ if embedding_proj_norm_type is None:
+ self.embedding_proj_norm = None
+ elif embedding_proj_norm_type == "layer":
+ self.embedding_proj_norm = nn.LayerNorm(embedding_proj_dim)
+ else:
+ raise ValueError(f"unsupported embedding_proj_norm_type: {embedding_proj_norm_type}")
+
+ self.embedding_proj = nn.Linear(embedding_proj_dim, inner_dim)
+
+ if encoder_hid_proj_type is None:
+ self.encoder_hidden_states_proj = None
+ elif encoder_hid_proj_type == "linear":
+ self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim)
+ else:
+ raise ValueError(f"unsupported encoder_hid_proj_type: {encoder_hid_proj_type}")
+
+ self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim))
+
+ if added_emb_type == "prd":
+ self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim))
+ elif added_emb_type is None:
+ self.prd_embedding = None
+ else:
+ raise ValueError(
+ f"`added_emb_type`: {added_emb_type} is not supported. Make sure to choose one of `'prd'` or `None`."
+ )
+
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(
+ inner_dim,
+ num_attention_heads,
+ attention_head_dim,
+ dropout=dropout,
+ activation_fn="gelu",
+ attention_bias=True,
+ )
+ for d in range(num_layers)
+ ]
+ )
+
+ if norm_in_type == "layer":
+ self.norm_in = nn.LayerNorm(inner_dim)
+ elif norm_in_type is None:
+ self.norm_in = None
+ else:
+ raise ValueError(f"Unsupported norm_in_type: {norm_in_type}.")
+
+ self.norm_out = nn.LayerNorm(inner_dim)
+
+ self.proj_to_clip_embeddings = nn.Linear(inner_dim, clip_embed_dim)
+
+ causal_attention_mask = torch.full(
+ [num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], -10000.0
+ )
+ causal_attention_mask.triu_(1)
+ causal_attention_mask = causal_attention_mask[None, ...]
+ self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False)
+
+ self.clip_mean = nn.Parameter(torch.zeros(1, clip_embed_dim))
+ self.clip_std = nn.Parameter(torch.zeros(1, clip_embed_dim))
+
+ @property
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
+ indexed by their weight names.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "set_processor"):
+ processors[f"{name}.processor"] = module.processor
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
+ def set_default_attn_processor(self):
+ """
+ Disables custom attention processors and sets the default attention implementation.
+ """
+ self.set_attn_processor(AttnProcessor())
+
+ def forward(
+ self,
+ hidden_states,
+ timestep: Union[torch.Tensor, float, int],
+ proj_embedding: torch.FloatTensor,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.BoolTensor] = None,
+ return_dict: bool = True,
+ ):
+ """
+ The [`PriorTransformer`] forward method.
+
+ Args:
+ hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
+ The currently predicted image embeddings.
+ timestep (`torch.LongTensor`):
+ Current denoising step.
+ proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
+ Projected embedding vector the denoising process is conditioned on.
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_embeddings, embedding_dim)`):
+ Hidden states of the text embeddings the denoising process is conditioned on.
+ attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`):
+ Text mask for the text embeddings.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.prior_transformer.PriorTransformerOutput`] instead of a plain
+ tuple.
+
+ Returns:
+ [`~models.prior_transformer.PriorTransformerOutput`] or `tuple`:
+ If return_dict is True, a [`~models.prior_transformer.PriorTransformerOutput`] is returned, otherwise a
+ tuple is returned where the first element is the sample tensor.
+ """
+ batch_size = hidden_states.shape[0]
+
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ timesteps = torch.tensor([timesteps], dtype=torch.long, device=hidden_states.device)
+ elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(hidden_states.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps * torch.ones(batch_size, dtype=timesteps.dtype, device=timesteps.device)
+
+ timesteps_projected = self.time_proj(timesteps)
+
+ # `time_proj` does not contain any weights and always returns f32 tensors,
+ # but `time_embedding` might be fp16, so we need to cast here.
+ timesteps_projected = timesteps_projected.to(dtype=self.dtype)
+ time_embeddings = self.time_embedding(timesteps_projected)
+
+ if self.embedding_proj_norm is not None:
+ proj_embedding = self.embedding_proj_norm(proj_embedding)
+
+ proj_embeddings = self.embedding_proj(proj_embedding)
+ if self.encoder_hidden_states_proj is not None and encoder_hidden_states is not None:
+ encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)
+ elif self.encoder_hidden_states_proj is not None and encoder_hidden_states is None:
+ raise ValueError("`encoder_hidden_states_proj` requires `encoder_hidden_states` to be set")
+
+ hidden_states = self.proj_in(hidden_states)
+
+ positional_embeddings = self.positional_embedding.to(hidden_states.dtype)
+
+ additional_embeds = []
+ additional_embeddings_len = 0
+
+ if encoder_hidden_states is not None:
+ additional_embeds.append(encoder_hidden_states)
+ additional_embeddings_len += encoder_hidden_states.shape[1]
+
+ if len(proj_embeddings.shape) == 2:
+ proj_embeddings = proj_embeddings[:, None, :]
+
+ if len(hidden_states.shape) == 2:
+ hidden_states = hidden_states[:, None, :]
+
+ additional_embeds = additional_embeds + [
+ proj_embeddings,
+ time_embeddings[:, None, :],
+ hidden_states,
+ ]
+
+ if self.prd_embedding is not None:
+ prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1)
+ additional_embeds.append(prd_embedding)
+
+ hidden_states = torch.cat(
+ additional_embeds,
+ dim=1,
+ )
+
+ # Allow positional_embedding to not include the `additional_embeddings` and instead pad it with zeros for these additional tokens
+ additional_embeddings_len = additional_embeddings_len + proj_embeddings.shape[1] + 1
+ if positional_embeddings.shape[1] < hidden_states.shape[1]:
+ positional_embeddings = F.pad(
+ positional_embeddings,
+ (
+ 0,
+ 0,
+ additional_embeddings_len,
+ self.prd_embedding.shape[1] if self.prd_embedding is not None else 0,
+ ),
+ value=0.0,
+ )
+
+ hidden_states = hidden_states + positional_embeddings
+
+ if attention_mask is not None:
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
+ attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0)
+ attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
+ attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)
+
+ if self.norm_in is not None:
+ hidden_states = self.norm_in(hidden_states)
+
+ for block in self.transformer_blocks:
+ hidden_states = block(hidden_states, attention_mask=attention_mask)
+
+ hidden_states = self.norm_out(hidden_states)
+
+ if self.prd_embedding is not None:
+ hidden_states = hidden_states[:, -1]
+ else:
+ hidden_states = hidden_states[:, additional_embeddings_len:]
+
+ predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states)
+
+ if not return_dict:
+ return (predicted_image_embedding,)
+
+ return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding)
+
+ def post_process_latents(self, prior_latents):
+ prior_latents = (prior_latents * self.clip_std) + self.clip_mean
+ return prior_latents
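+
+
+# Illustrative usage sketch (commented out): a minimal call of `PriorTransformer.forward`.
+# The tiny dimensions and random tensors are assumptions chosen only to make the
+# expected input/output shapes visible.
+#
+#   prior = PriorTransformer(
+#       num_attention_heads=2, attention_head_dim=8, num_layers=1,
+#       embedding_dim=16, num_embeddings=4,
+#   )
+#   image_emb = torch.randn(2, 16)        # current noisy image-embedding estimate
+#   text_emb = torch.randn(2, 16)         # pooled text embedding (`proj_embedding`)
+#   text_states = torch.randn(2, 4, 16)   # per-token text hidden states
+#   out = prior(image_emb, timestep=10, proj_embedding=text_emb,
+#               encoder_hidden_states=text_states)
+#   assert out.predicted_image_embedding.shape == (2, 16)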
diff --git a/diffusers/models/resnet.py b/diffusers/models/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..24c3b07e7cb65447ad996b00066d42a74700dd97
--- /dev/null
+++ b/diffusers/models/resnet.py
@@ -0,0 +1,877 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+# `TemporalConvLayer` Copyright 2023 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import partial
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .activations import get_activation
+from .attention import AdaGroupNorm
+from .attention_processor import SpatialNorm
+
+
+class Upsample1D(nn.Module):
+ """A 1D upsampling layer with an optional convolution.
+
+ Parameters:
+ channels (`int`):
+ number of channels in the inputs and outputs.
+ use_conv (`bool`, default `False`):
+ option to use a convolution.
+ use_conv_transpose (`bool`, default `False`):
+ option to use a convolution transpose.
+ out_channels (`int`, optional):
+ number of output channels. Defaults to `channels`.
+ """
+
+ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_conv_transpose = use_conv_transpose
+ self.name = name
+
+ self.conv = None
+ if use_conv_transpose:
+ self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
+ elif use_conv:
+ self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
+
+ def forward(self, inputs):
+ assert inputs.shape[1] == self.channels
+ if self.use_conv_transpose:
+ return self.conv(inputs)
+
+ outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
+
+ if self.use_conv:
+ outputs = self.conv(outputs)
+
+ return outputs
+
+
+class Downsample1D(nn.Module):
+ """A 1D downsampling layer with an optional convolution.
+
+ Parameters:
+ channels (`int`):
+ number of channels in the inputs and outputs.
+ use_conv (`bool`, default `False`):
+ option to use a convolution.
+ out_channels (`int`, optional):
+ number of output channels. Defaults to `channels`.
+ padding (`int`, default `1`):
+ padding for the convolution.
+ """
+
+ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.padding = padding
+ stride = 2
+ self.name = name
+
+ if use_conv:
+ self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
+ else:
+ assert self.channels == self.out_channels
+ self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)
+
+ def forward(self, inputs):
+ assert inputs.shape[1] == self.channels
+ return self.conv(inputs)
+
+
+class Upsample2D(nn.Module):
+ """A 2D upsampling layer with an optional convolution.
+
+ Parameters:
+ channels (`int`):
+ number of channels in the inputs and outputs.
+ use_conv (`bool`, default `False`):
+ option to use a convolution.
+ use_conv_transpose (`bool`, default `False`):
+ option to use a convolution transpose.
+ out_channels (`int`, optional):
+ number of output channels. Defaults to `channels`.
+ """
+
+ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_conv_transpose = use_conv_transpose
+ self.name = name
+
+ conv = None
+ if use_conv_transpose:
+ conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
+ elif use_conv:
+ conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)
+
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
+ if name == "conv":
+ self.conv = conv
+ else:
+ self.Conv2d_0 = conv
+
+ def forward(self, hidden_states, output_size=None):
+ assert hidden_states.shape[1] == self.channels
+
+ if self.use_conv_transpose:
+ return self.conv(hidden_states)
+
+ # Cast to float32 because the 'upsample_nearest2d_out_frame' op does not support bfloat16
+ # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
+ # https://github.com/pytorch/pytorch/issues/86679
+ dtype = hidden_states.dtype
+ if dtype == torch.bfloat16:
+ hidden_states = hidden_states.to(torch.float32)
+
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
+ if hidden_states.shape[0] >= 64:
+ hidden_states = hidden_states.contiguous()
+
+ # if `output_size` is passed we force the interpolation output
+ # size and do not make use of `scale_factor=2`
+ if output_size is None:
+ hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
+ else:
+ hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
+
+ # If the input is bfloat16, we cast back to bfloat16
+ if dtype == torch.bfloat16:
+ hidden_states = hidden_states.to(dtype)
+
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
+ if self.use_conv:
+ if self.name == "conv":
+ hidden_states = self.conv(hidden_states)
+ else:
+ hidden_states = self.Conv2d_0(hidden_states)
+
+ return hidden_states
+
+
+class Downsample2D(nn.Module):
+ """A 2D downsampling layer with an optional convolution.
+
+ Parameters:
+ channels (`int`):
+ number of channels in the inputs and outputs.
+ use_conv (`bool`, default `False`):
+ option to use a convolution.
+ out_channels (`int`, optional):
+ number of output channels. Defaults to `channels`.
+ padding (`int`, default `1`):
+ padding for the convolution.
+ """
+
+ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.padding = padding
+ stride = 2
+ self.name = name
+
+ if use_conv:
+ conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
+ else:
+ assert self.channels == self.out_channels
+ conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
+
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
+ if name == "conv":
+ self.Conv2d_0 = conv
+ self.conv = conv
+ elif name == "Conv2d_0":
+ self.conv = conv
+ else:
+ self.conv = conv
+
+ def forward(self, hidden_states):
+ assert hidden_states.shape[1] == self.channels
+ if self.use_conv and self.padding == 0:
+ pad = (0, 1, 0, 1)
+ hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
+
+ assert hidden_states.shape[1] == self.channels
+ hidden_states = self.conv(hidden_states)
+
+ return hidden_states
+
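+# Illustrative shape sketch (commented out) for the two 2D resampling layers above;
+# the channel count and spatial size are assumptions for demonstration.
+#
+#   x = torch.randn(1, 8, 16, 16)
+#   assert Upsample2D(8, use_conv=True)(x).shape == (1, 8, 32, 32)
+#   assert Downsample2D(8, use_conv=True)(x).shape == (1, 8, 8, 8)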
+
+class FirUpsample2D(nn.Module):
+ """A 2D FIR upsampling layer with an optional convolution.
+
+ Parameters:
+ channels (`int`):
+ number of channels in the inputs and outputs.
+ use_conv (`bool`, default `False`):
+ option to use a convolution.
+ out_channels (`int`, optional):
+ number of output channels. Defaults to `channels`.
+ fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
+ kernel for the FIR filter.
+ """
+
+ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
+ super().__init__()
+ out_channels = out_channels if out_channels else channels
+ if use_conv:
+ self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
+ self.use_conv = use_conv
+ self.fir_kernel = fir_kernel
+ self.out_channels = out_channels
+
+ def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
+ """Fused `upsample_2d()` followed by `Conv2d()`.
+
+ Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
+ efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
+ arbitrary order.
+
+ Args:
+ hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+ weight: Weight tensor of the shape `[filterH, filterW, inChannels,
+ outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
+ (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
+ factor: Integer upsampling factor (default: 2).
+ gain: Scaling factor for signal magnitude (default: 1.0).
+
+ Returns:
+ output: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
+ datatype as `hidden_states`.
+ """
+
+ assert isinstance(factor, int) and factor >= 1
+
+ # Setup filter kernel.
+ if kernel is None:
+ kernel = [1] * factor
+
+ # setup kernel
+ kernel = torch.tensor(kernel, dtype=torch.float32)
+ if kernel.ndim == 1:
+ kernel = torch.outer(kernel, kernel)
+ kernel /= torch.sum(kernel)
+
+ kernel = kernel * (gain * (factor**2))
+
+ if self.use_conv:
+ convH = weight.shape[2]
+ convW = weight.shape[3]
+ inC = weight.shape[1]
+
+ pad_value = (kernel.shape[0] - factor) - (convW - 1)
+
+ stride = (factor, factor)
+ # Determine data dimensions.
+ output_shape = (
+ (hidden_states.shape[2] - 1) * factor + convH,
+ (hidden_states.shape[3] - 1) * factor + convW,
+ )
+ output_padding = (
+ output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH,
+ output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW,
+ )
+ assert output_padding[0] >= 0 and output_padding[1] >= 0
+ num_groups = hidden_states.shape[1] // inC
+
+ # Transpose weights.
+ weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
+ weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
+ weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))
+
+ inverse_conv = F.conv_transpose2d(
+ hidden_states, weight, stride=stride, output_padding=output_padding, padding=0
+ )
+
+ output = upfirdn2d_native(
+ inverse_conv,
+ kernel.to(device=inverse_conv.device),
+ pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1),
+ )
+ else:
+ pad_value = kernel.shape[0] - factor
+ output = upfirdn2d_native(
+ hidden_states,
+ kernel.to(device=hidden_states.device),
+ up=factor,
+ pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
+ )
+
+ return output
+
+ def forward(self, hidden_states):
+ if self.use_conv:
+ height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
+ height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
+ else:
+ height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
+
+ return height
+
+
+class FirDownsample2D(nn.Module):
+ """A 2D FIR downsampling layer with an optional convolution.
+
+ Parameters:
+ channels (`int`):
+ number of channels in the inputs and outputs.
+ use_conv (`bool`, default `False`):
+ option to use a convolution.
+ out_channels (`int`, optional):
+ number of output channels. Defaults to `channels`.
+ fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
+ kernel for the FIR filter.
+ """
+
+ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
+ super().__init__()
+ out_channels = out_channels if out_channels else channels
+ if use_conv:
+ self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
+ self.fir_kernel = fir_kernel
+ self.use_conv = use_conv
+ self.out_channels = out_channels
+
+ def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
+ """Fused `Conv2d()` followed by `downsample_2d()`.
+ Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
+ efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
+ arbitrary order.
+
+ Args:
+ hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+ weight:
+ Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
+ performed by `inChannels = x.shape[0] // numGroups`.
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
+ factor`, which corresponds to average pooling.
+ factor: Integer downsampling factor (default: 2).
+ gain: Scaling factor for signal magnitude (default: 1.0).
+
+ Returns:
+ output: Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and
+ same datatype as `hidden_states`.
+ """
+
+ assert isinstance(factor, int) and factor >= 1
+ if kernel is None:
+ kernel = [1] * factor
+
+ # setup kernel
+ kernel = torch.tensor(kernel, dtype=torch.float32)
+ if kernel.ndim == 1:
+ kernel = torch.outer(kernel, kernel)
+ kernel /= torch.sum(kernel)
+
+ kernel = kernel * gain
+
+ if self.use_conv:
+ _, _, convH, convW = weight.shape
+ pad_value = (kernel.shape[0] - factor) + (convW - 1)
+ stride_value = [factor, factor]
+ upfirdn_input = upfirdn2d_native(
+ hidden_states,
+ kernel.to(device=hidden_states.device),
+ pad=((pad_value + 1) // 2, pad_value // 2),
+ )
+ output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
+ else:
+ pad_value = kernel.shape[0] - factor
+ output = upfirdn2d_native(
+ hidden_states,
+ kernel.to(device=hidden_states.device),
+ down=factor,
+ pad=((pad_value + 1) // 2, pad_value // 2),
+ )
+
+ return output
+
+ def forward(self, hidden_states):
+ if self.use_conv:
+ downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
+ hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
+ else:
+ hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
+
+ return hidden_states
+
+
+# downsample/upsample layers used in the k-upscaler; it might be possible to use FirDownsample2D/FirUpsample2D instead
+class KDownsample2D(nn.Module):
+ def __init__(self, pad_mode="reflect"):
+ super().__init__()
+ self.pad_mode = pad_mode
+ kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]])
+ self.pad = kernel_1d.shape[1] // 2 - 1
+ self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
+
+ def forward(self, inputs):
+ inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode)
+ weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
+ indices = torch.arange(inputs.shape[1], device=inputs.device)
+ kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
+ weight[indices, indices] = kernel
+ return F.conv2d(inputs, weight, stride=2)
+
+
+class KUpsample2D(nn.Module):
+ def __init__(self, pad_mode="reflect"):
+ super().__init__()
+ self.pad_mode = pad_mode
+ kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2
+ self.pad = kernel_1d.shape[1] // 2 - 1
+ self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
+
+ def forward(self, inputs):
+ inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode)
+ weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
+ indices = torch.arange(inputs.shape[1], device=inputs.device)
+ kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
+ weight[indices, indices] = kernel
+ return F.conv_transpose2d(inputs, weight, stride=2, padding=self.pad * 2 + 1)
+
+
+class ResnetBlock2D(nn.Module):
+ r"""
+ A Resnet block.
+
+ Parameters:
+ in_channels (`int`): The number of channels in the input.
+ out_channels (`int`, *optional*, default to be `None`):
+ The number of output channels for the first conv2d layer. If None, same as `in_channels`.
+ dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
+ temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
+ groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
+ groups_out (`int`, *optional*, default to None):
+ The number of groups to use for the second normalization layer. if set to None, same as `groups`.
+ eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
+ non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
+ time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
+ By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or
+ "ada_group" for a stronger conditioning with scale and shift.
+ kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
+ [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
+ output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
+ use_in_shortcut (`bool`, *optional*, default to `True`):
+ If `True`, add a 1x1 nn.conv2d layer for skip-connection.
+ up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.
+ down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer.
+ conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the
+ `conv_shortcut` output.
+ conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
+ If None, same as `out_channels`.
+ """
+
+ def __init__(
+ self,
+ *,
+ in_channels,
+ out_channels=None,
+ conv_shortcut=False,
+ dropout=0.0,
+ temb_channels=512,
+ groups=32,
+ groups_out=None,
+ pre_norm=True,
+ eps=1e-6,
+ non_linearity="swish",
+ skip_time_act=False,
+ time_embedding_norm="default", # default, scale_shift, ada_group, spatial
+ kernel=None,
+ output_scale_factor=1.0,
+ use_in_shortcut=None,
+ up=False,
+ down=False,
+ conv_shortcut_bias: bool = True,
+ conv_2d_out_channels: Optional[int] = None,
+ ):
+ super().__init__()
+ self.pre_norm = True  # pre-norm is always used here, regardless of the `pre_norm` argument
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+ self.up = up
+ self.down = down
+ self.output_scale_factor = output_scale_factor
+ self.time_embedding_norm = time_embedding_norm
+ self.skip_time_act = skip_time_act
+
+ if groups_out is None:
+ groups_out = groups
+
+ if self.time_embedding_norm == "ada_group":
+ self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
+ elif self.time_embedding_norm == "spatial":
+ self.norm1 = SpatialNorm(in_channels, temb_channels)
+ else:
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
+
+ self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+ if temb_channels is not None:
+ if self.time_embedding_norm == "default":
+ self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
+ elif self.time_embedding_norm == "scale_shift":
+ self.time_emb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
+ elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
+ self.time_emb_proj = None
+ else:
+ raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
+ else:
+ self.time_emb_proj = None
+
+ if self.time_embedding_norm == "ada_group":
+ self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
+ elif self.time_embedding_norm == "spatial":
+ self.norm2 = SpatialNorm(out_channels, temb_channels)
+ else:
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
+
+ self.dropout = torch.nn.Dropout(dropout)
+ conv_2d_out_channels = conv_2d_out_channels or out_channels
+ self.conv2 = torch.nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
+
+ self.nonlinearity = get_activation(non_linearity)
+
+ self.upsample = self.downsample = None
+ if self.up:
+ if kernel == "fir":
+ fir_kernel = (1, 3, 3, 1)
+ self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
+ elif kernel == "sde_vp":
+ self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
+ else:
+ self.upsample = Upsample2D(in_channels, use_conv=False)
+ elif self.down:
+ if kernel == "fir":
+ fir_kernel = (1, 3, 3, 1)
+ self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
+ elif kernel == "sde_vp":
+ self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
+ else:
+ self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")
+
+ self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut
+
+ self.conv_shortcut = None
+ if self.use_in_shortcut:
+ self.conv_shortcut = torch.nn.Conv2d(
+ in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias
+ )
+
+ def forward(self, input_tensor, temb):
+ hidden_states = input_tensor
+
+ if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
+ hidden_states = self.norm1(hidden_states, temb)
+ else:
+ hidden_states = self.norm1(hidden_states)
+
+ hidden_states = self.nonlinearity(hidden_states)
+
+ if self.upsample is not None:
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
+ if hidden_states.shape[0] >= 64:
+ input_tensor = input_tensor.contiguous()
+ hidden_states = hidden_states.contiguous()
+ input_tensor = self.upsample(input_tensor)
+ hidden_states = self.upsample(hidden_states)
+ elif self.downsample is not None:
+ input_tensor = self.downsample(input_tensor)
+ hidden_states = self.downsample(hidden_states)
+
+ hidden_states = self.conv1(hidden_states)
+
+ if self.time_emb_proj is not None:
+ if not self.skip_time_act:
+ temb = self.nonlinearity(temb)
+ temb = self.time_emb_proj(temb)[:, :, None, None]
+
+ if temb is not None and self.time_embedding_norm == "default":
+ hidden_states = hidden_states + temb
+
+ if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
+ hidden_states = self.norm2(hidden_states, temb)
+ else:
+ hidden_states = self.norm2(hidden_states)
+
+ if temb is not None and self.time_embedding_norm == "scale_shift":
+ scale, shift = torch.chunk(temb, 2, dim=1)
+ hidden_states = hidden_states * (1 + scale) + shift
+
+ hidden_states = self.nonlinearity(hidden_states)
+
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.conv2(hidden_states)
+
+ if self.conv_shortcut is not None:
+ input_tensor = self.conv_shortcut(input_tensor)
+
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
+
+ return output_tensor
+
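+# Illustrative usage sketch (commented out) for `ResnetBlock2D`; the channel sizes
+# and random inputs are assumptions for demonstration.
+#
+#   block = ResnetBlock2D(in_channels=64, out_channels=128, temb_channels=512)
+#   x = torch.randn(2, 64, 32, 32)
+#   temb = torch.randn(2, 512)
+#   assert block(x, temb).shape == (2, 128, 32, 32)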
+
+# unet_rl.py
+def rearrange_dims(tensor):
+ if len(tensor.shape) == 2:
+ return tensor[:, :, None]
+ if len(tensor.shape) == 3:
+ return tensor[:, :, None, :]
+ elif len(tensor.shape) == 4:
+ return tensor[:, :, 0, :]
+ else:
+ raise ValueError(f"`len(tensor.shape)`: {len(tensor.shape)} has to be 2, 3 or 4.")
+
+
+class Conv1dBlock(nn.Module):
+ """
+ Conv1d --> GroupNorm --> Mish
+ """
+
+ def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
+ super().__init__()
+
+ self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
+ self.group_norm = nn.GroupNorm(n_groups, out_channels)
+ self.mish = nn.Mish()
+
+ def forward(self, inputs):
+ intermediate_repr = self.conv1d(inputs)
+ intermediate_repr = rearrange_dims(intermediate_repr)
+ intermediate_repr = self.group_norm(intermediate_repr)
+ intermediate_repr = rearrange_dims(intermediate_repr)
+ output = self.mish(intermediate_repr)
+ return output
+
+
+# unet_rl.py
+class ResidualTemporalBlock1D(nn.Module):
+ def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5):
+ super().__init__()
+ self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
+ self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)
+
+ self.time_emb_act = nn.Mish()
+ self.time_emb = nn.Linear(embed_dim, out_channels)
+
+ self.residual_conv = (
+ nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
+ )
+
+ def forward(self, inputs, t):
+ """
+ Args:
+ inputs : [ batch_size x inp_channels x horizon ]
+ t : [ batch_size x embed_dim ]
+
+ returns:
+ out : [ batch_size x out_channels x horizon ]
+ """
+ t = self.time_emb_act(t)
+ t = self.time_emb(t)
+ out = self.conv_in(inputs) + rearrange_dims(t)
+ out = self.conv_out(out)
+ return out + self.residual_conv(inputs)
+
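+# Illustrative shape sketch (commented out) for `ResidualTemporalBlock1D`; the
+# sizes are assumptions for demonstration.
+#
+#   block = ResidualTemporalBlock1D(inp_channels=14, out_channels=32, embed_dim=64)
+#   x = torch.randn(4, 14, 16)   # [batch, inp_channels, horizon]
+#   t = torch.randn(4, 64)       # [batch, embed_dim]
+#   assert block(x, t).shape == (4, 32, 16)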
+
+def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
+ r"""Upsamples a batch of 2D images with the given filter.
+ Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
+ filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
+ `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is
+ a multiple of the upsampling factor.
+
+ Args:
+ hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
+ (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
+ factor: Integer upsampling factor (default: 2).
+ gain: Scaling factor for signal magnitude (default: 1.0).
+
+ Returns:
+ output: Tensor of the shape `[N, C, H * factor, W * factor]`
+ """
+ assert isinstance(factor, int) and factor >= 1
+ if kernel is None:
+ kernel = [1] * factor
+
+ kernel = torch.tensor(kernel, dtype=torch.float32)
+ if kernel.ndim == 1:
+ kernel = torch.outer(kernel, kernel)
+ kernel /= torch.sum(kernel)
+
+ kernel = kernel * (gain * (factor**2))
+ pad_value = kernel.shape[0] - factor
+ output = upfirdn2d_native(
+ hidden_states,
+ kernel.to(device=hidden_states.device),
+ up=factor,
+ pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
+ )
+ return output
+
+
+def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
+ r"""Downsamples a batch of 2D images with the given filter.
+ Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
+ given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
+ specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
+ shape is a multiple of the downsampling factor.
+
+ Args:
+ hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
+ (separable). The default is `[1] * factor`, which corresponds to average pooling.
+ factor: Integer downsampling factor (default: 2).
+ gain: Scaling factor for signal magnitude (default: 1.0).
+
+ Returns:
+ output: Tensor of the shape `[N, C, H // factor, W // factor]`
+ """
+
+ assert isinstance(factor, int) and factor >= 1
+ if kernel is None:
+ kernel = [1] * factor
+
+ kernel = torch.tensor(kernel, dtype=torch.float32)
+ if kernel.ndim == 1:
+ kernel = torch.outer(kernel, kernel)
+ kernel /= torch.sum(kernel)
+
+ kernel = kernel * gain
+ pad_value = kernel.shape[0] - factor
+ output = upfirdn2d_native(
+ hidden_states, kernel.to(device=hidden_states.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2)
+ )
+ return output
+
+
+def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)):
+ up_x = up_y = up
+ down_x = down_y = down
+ pad_x0 = pad_y0 = pad[0]
+ pad_x1 = pad_y1 = pad[1]
+
+ _, channel, in_h, in_w = tensor.shape
+ tensor = tensor.reshape(-1, in_h, in_w, 1)
+
+ _, in_h, in_w, minor = tensor.shape
+ kernel_h, kernel_w = kernel.shape
+
+ out = tensor.view(-1, in_h, 1, in_w, 1, minor)
+ out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
+ out = out.view(-1, in_h * up_y, in_w * up_x, minor)
+
+ out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
+ out = out.to(tensor.device) # Move back to mps if necessary
+ out = out[
+ :,
+ max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
+ max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
+ :,
+ ]
+
+ out = out.permute(0, 3, 1, 2)
+ out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
+ out = F.conv2d(out, w)
+ out = out.reshape(
+ -1,
+ minor,
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
+ )
+ out = out.permute(0, 2, 3, 1)
+ out = out[:, ::down_y, ::down_x, :]
+
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
+
+ return out.view(-1, channel, out_h, out_w)
+
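+# Illustrative sketch (commented out) of the FIR resampling helpers above. With the
+# default separable kernel `[1] * factor` they reduce to nearest-neighbor upsampling
+# and average pooling; the shapes are assumptions for demonstration.
+#
+#   x = torch.randn(1, 3, 8, 8)
+#   assert upsample_2d(x, factor=2).shape == (1, 3, 16, 16)
+#   assert downsample_2d(x, factor=2).shape == (1, 3, 4, 4)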
+
+class TemporalConvLayer(nn.Module):
+ """
+ Temporal convolutional layer that can be used for video (sequence of images) input. Code mostly copied from:
+ https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016
+ """
+
+ def __init__(self, in_dim, out_dim=None, dropout=0.0):
+ super().__init__()
+ out_dim = out_dim or in_dim
+ self.in_dim = in_dim
+ self.out_dim = out_dim
+
+ # conv layers
+ self.conv1 = nn.Sequential(
+ nn.GroupNorm(32, in_dim), nn.SiLU(), nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0))
+ )
+ self.conv2 = nn.Sequential(
+ nn.GroupNorm(32, out_dim),
+ nn.SiLU(),
+ nn.Dropout(dropout),
+ nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
+ )
+ self.conv3 = nn.Sequential(
+ nn.GroupNorm(32, out_dim),
+ nn.SiLU(),
+ nn.Dropout(dropout),
+ nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
+ )
+ self.conv4 = nn.Sequential(
+ nn.GroupNorm(32, out_dim),
+ nn.SiLU(),
+ nn.Dropout(dropout),
+ nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
+ )
+
+ # zero out the last layer params, so the conv block is the identity at initialization
+ nn.init.zeros_(self.conv4[-1].weight)
+ nn.init.zeros_(self.conv4[-1].bias)
+
+ def forward(self, hidden_states, num_frames=1):
+ hidden_states = (
+ hidden_states[None, :].reshape((-1, num_frames) + hidden_states.shape[1:]).permute(0, 2, 1, 3, 4)
+ )
+
+ identity = hidden_states
+ hidden_states = self.conv1(hidden_states)
+ hidden_states = self.conv2(hidden_states)
+ hidden_states = self.conv3(hidden_states)
+ hidden_states = self.conv4(hidden_states)
+
+ hidden_states = identity + hidden_states
+
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(
+ (hidden_states.shape[0] * hidden_states.shape[2], -1) + hidden_states.shape[3:]
+ )
+ return hidden_states
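+
+
+# Illustrative usage sketch (commented out) for `TemporalConvLayer`: frame features
+# of shape [batch * num_frames, channels, height, width] are reshaped to a 5D video
+# tensor internally and returned in the same flattened layout. Sizes are assumptions
+# for demonstration.
+#
+#   layer = TemporalConvLayer(in_dim=32)
+#   frames = torch.randn(2 * 4, 32, 8, 8)   # 2 clips, 4 frames each
+#   out = layer(frames, num_frames=4)
+#   assert out.shape == frames.shape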
diff --git a/diffusers/models/resnet_flax.py b/diffusers/models/resnet_flax.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a391f4b947e74beda03f26e376141b2b3c21502
--- /dev/null
+++ b/diffusers/models/resnet_flax.py
@@ -0,0 +1,124 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+
+
+class FlaxUpsample2D(nn.Module):
+ out_channels: int
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ self.conv = nn.Conv(
+ self.out_channels,
+ kernel_size=(3, 3),
+ strides=(1, 1),
+ padding=((1, 1), (1, 1)),
+ dtype=self.dtype,
+ )
+
+ def __call__(self, hidden_states):
+ batch, height, width, channels = hidden_states.shape
+ hidden_states = jax.image.resize(
+ hidden_states,
+ shape=(batch, height * 2, width * 2, channels),
+ method="nearest",
+ )
+ hidden_states = self.conv(hidden_states)
+ return hidden_states
+
+
+class FlaxDownsample2D(nn.Module):
+ out_channels: int
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ self.conv = nn.Conv(
+ self.out_channels,
+ kernel_size=(3, 3),
+ strides=(2, 2),
+ padding=((1, 1), (1, 1)), # padding="VALID",
+ dtype=self.dtype,
+ )
+
+ def __call__(self, hidden_states):
+ # pad = ((0, 0), (0, 1), (0, 1), (0, 0)) # pad height and width dim
+ # hidden_states = jnp.pad(hidden_states, pad_width=pad)
+ hidden_states = self.conv(hidden_states)
+ return hidden_states
+
+
+class FlaxResnetBlock2D(nn.Module):
+ in_channels: int
+ out_channels: int = None
+ dropout_prob: float = 0.0
+ use_nin_shortcut: bool = None
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ out_channels = self.in_channels if self.out_channels is None else self.out_channels
+
+ self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-5)
+ self.conv1 = nn.Conv(
+ out_channels,
+ kernel_size=(3, 3),
+ strides=(1, 1),
+ padding=((1, 1), (1, 1)),
+ dtype=self.dtype,
+ )
+
+ self.time_emb_proj = nn.Dense(out_channels, dtype=self.dtype)
+
+ self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-5)
+ self.dropout = nn.Dropout(self.dropout_prob)
+ self.conv2 = nn.Conv(
+ out_channels,
+ kernel_size=(3, 3),
+ strides=(1, 1),
+ padding=((1, 1), (1, 1)),
+ dtype=self.dtype,
+ )
+
+ use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut
+
+ self.conv_shortcut = None
+ if use_nin_shortcut:
+ self.conv_shortcut = nn.Conv(
+ out_channels,
+ kernel_size=(1, 1),
+ strides=(1, 1),
+ padding="VALID",
+ dtype=self.dtype,
+ )
+
+ def __call__(self, hidden_states, temb, deterministic=True):
+ residual = hidden_states
+ hidden_states = self.norm1(hidden_states)
+ hidden_states = nn.swish(hidden_states)
+ hidden_states = self.conv1(hidden_states)
+
+ temb = self.time_emb_proj(nn.swish(temb))
+ temb = jnp.expand_dims(jnp.expand_dims(temb, 1), 1)
+ hidden_states = hidden_states + temb
+
+ hidden_states = self.norm2(hidden_states)
+ hidden_states = nn.swish(hidden_states)
+ hidden_states = self.dropout(hidden_states, deterministic)
+ hidden_states = self.conv2(hidden_states)
+
+ if self.conv_shortcut is not None:
+ residual = self.conv_shortcut(residual)
+
+ return hidden_states + residual
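+
+
+# Illustrative usage sketch (commented out) for `FlaxResnetBlock2D`. Flax modules
+# use channels-last (NHWC) inputs; the sizes are assumptions for demonstration.
+#
+#   block = FlaxResnetBlock2D(in_channels=32, out_channels=64)
+#   x = jnp.ones((1, 8, 8, 32))
+#   temb = jnp.ones((1, 128))
+#   params = block.init(jax.random.PRNGKey(0), x, temb)
+#   out = block.apply(params, x, temb)   # shape (1, 8, 8, 64)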
diff --git a/diffusers/models/t5_film_transformer.py b/diffusers/models/t5_film_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c41e656a9dbe81edafd5a2958d49ff28e84fd01
--- /dev/null
+++ b/diffusers/models/t5_film_transformer.py
@@ -0,0 +1,321 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+
+import torch
+from torch import nn
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from .attention_processor import Attention
+from .embeddings import get_timestep_embedding
+from .modeling_utils import ModelMixin
+
+
+class T5FilmDecoder(ModelMixin, ConfigMixin):
+ @register_to_config
+ def __init__(
+ self,
+ input_dims: int = 128,
+ targets_length: int = 256,
+ max_decoder_noise_time: float = 2000.0,
+ d_model: int = 768,
+ num_layers: int = 12,
+ num_heads: int = 12,
+ d_kv: int = 64,
+ d_ff: int = 2048,
+ dropout_rate: float = 0.1,
+ ):
+ super().__init__()
+
+ self.conditioning_emb = nn.Sequential(
+ nn.Linear(d_model, d_model * 4, bias=False),
+ nn.SiLU(),
+ nn.Linear(d_model * 4, d_model * 4, bias=False),
+ nn.SiLU(),
+ )
+
+ self.position_encoding = nn.Embedding(targets_length, d_model)
+ self.position_encoding.weight.requires_grad = False
+
+ self.continuous_inputs_projection = nn.Linear(input_dims, d_model, bias=False)
+
+ self.dropout = nn.Dropout(p=dropout_rate)
+
+ self.decoders = nn.ModuleList()
+ for lyr_num in range(num_layers):
+ # FiLM conditional T5 decoder
+ lyr = DecoderLayer(d_model=d_model, d_kv=d_kv, num_heads=num_heads, d_ff=d_ff, dropout_rate=dropout_rate)
+ self.decoders.append(lyr)
+
+ self.decoder_norm = T5LayerNorm(d_model)
+
+ self.post_dropout = nn.Dropout(p=dropout_rate)
+ self.spec_out = nn.Linear(d_model, input_dims, bias=False)
+
+ def encoder_decoder_mask(self, query_input, key_input):
+ mask = torch.mul(query_input.unsqueeze(-1), key_input.unsqueeze(-2))
+ return mask.unsqueeze(-3)
+
+ def forward(self, encodings_and_masks, decoder_input_tokens, decoder_noise_time):
+ batch, _, _ = decoder_input_tokens.shape
+ assert decoder_noise_time.shape == (batch,)
+
+ # decoder_noise_time is in [0, 1), so rescale to expected timing range.
+ time_steps = get_timestep_embedding(
+ decoder_noise_time * self.config.max_decoder_noise_time,
+ embedding_dim=self.config.d_model,
+ max_period=self.config.max_decoder_noise_time,
+ ).to(dtype=self.dtype)
+
+ conditioning_emb = self.conditioning_emb(time_steps).unsqueeze(1)
+
+ assert conditioning_emb.shape == (batch, 1, self.config.d_model * 4)
+
+ seq_length = decoder_input_tokens.shape[1]
+
+ # If we want to use relative positions for audio context, we can just offset
+ # this sequence by the length of encodings_and_masks.
+ decoder_positions = torch.broadcast_to(
+ torch.arange(seq_length, device=decoder_input_tokens.device),
+ (batch, seq_length),
+ )
+
+ position_encodings = self.position_encoding(decoder_positions)
+
+ inputs = self.continuous_inputs_projection(decoder_input_tokens)
+ inputs += position_encodings
+ y = self.dropout(inputs)
+
+ # decoder: No padding present.
+ decoder_mask = torch.ones(
+ decoder_input_tokens.shape[:2], device=decoder_input_tokens.device, dtype=inputs.dtype
+ )
+
+ # Translate encoding masks to encoder-decoder masks.
+ encodings_and_encdec_masks = [(x, self.encoder_decoder_mask(decoder_mask, y)) for x, y in encodings_and_masks]
+
+ # cross attend style: concat encodings
+ encoded = torch.cat([x[0] for x in encodings_and_encdec_masks], dim=1)
+ encoder_decoder_mask = torch.cat([x[1] for x in encodings_and_encdec_masks], dim=-1)
+
+ for lyr in self.decoders:
+ y = lyr(
+ y,
+ conditioning_emb=conditioning_emb,
+ encoder_hidden_states=encoded,
+ encoder_attention_mask=encoder_decoder_mask,
+ )[0]
+
+ y = self.decoder_norm(y)
+ y = self.post_dropout(y)
+
+ spec_out = self.spec_out(y)
+ return spec_out
+
+
+class DecoderLayer(nn.Module):
+ def __init__(self, d_model, d_kv, num_heads, d_ff, dropout_rate, layer_norm_epsilon=1e-6):
+ super().__init__()
+ self.layer = nn.ModuleList()
+
+ # cond self attention: layer 0
+ self.layer.append(
+ T5LayerSelfAttentionCond(d_model=d_model, d_kv=d_kv, num_heads=num_heads, dropout_rate=dropout_rate)
+ )
+
+ # cross attention: layer 1
+ self.layer.append(
+ T5LayerCrossAttention(
+ d_model=d_model,
+ d_kv=d_kv,
+ num_heads=num_heads,
+ dropout_rate=dropout_rate,
+ layer_norm_epsilon=layer_norm_epsilon,
+ )
+ )
+
+ # Film Cond MLP + dropout: last layer
+ self.layer.append(
+ T5LayerFFCond(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate, layer_norm_epsilon=layer_norm_epsilon)
+ )
+
+ def forward(
+ self,
+ hidden_states,
+ conditioning_emb=None,
+ attention_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ encoder_decoder_position_bias=None,
+ ):
+ hidden_states = self.layer[0](
+ hidden_states,
+ conditioning_emb=conditioning_emb,
+ attention_mask=attention_mask,
+ )
+
+ if encoder_hidden_states is not None:
+ encoder_extended_attention_mask = torch.where(encoder_attention_mask > 0, 0, -1e10).to(
+ encoder_hidden_states.dtype
+ )
+
+ hidden_states = self.layer[1](
+ hidden_states,
+ key_value_states=encoder_hidden_states,
+ attention_mask=encoder_extended_attention_mask,
+ )
+
+ # Apply Film Conditional Feed Forward layer
+ hidden_states = self.layer[-1](hidden_states, conditioning_emb)
+
+ return (hidden_states,)
+
+
+class T5LayerSelfAttentionCond(nn.Module):
+ def __init__(self, d_model, d_kv, num_heads, dropout_rate):
+ super().__init__()
+ self.layer_norm = T5LayerNorm(d_model)
+ self.FiLMLayer = T5FiLMLayer(in_features=d_model * 4, out_features=d_model)
+ self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False)
+ self.dropout = nn.Dropout(dropout_rate)
+
+ def forward(
+ self,
+ hidden_states,
+ conditioning_emb=None,
+ attention_mask=None,
+ ):
+ # pre_self_attention_layer_norm
+ normed_hidden_states = self.layer_norm(hidden_states)
+
+ if conditioning_emb is not None:
+ normed_hidden_states = self.FiLMLayer(normed_hidden_states, conditioning_emb)
+
+ # Self-attention block
+ attention_output = self.attention(normed_hidden_states)
+
+ hidden_states = hidden_states + self.dropout(attention_output)
+
+ return hidden_states
+
+
+class T5LayerCrossAttention(nn.Module):
+ def __init__(self, d_model, d_kv, num_heads, dropout_rate, layer_norm_epsilon):
+ super().__init__()
+ self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False)
+ self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
+ self.dropout = nn.Dropout(dropout_rate)
+
+ def forward(
+ self,
+ hidden_states,
+ key_value_states=None,
+ attention_mask=None,
+ ):
+ normed_hidden_states = self.layer_norm(hidden_states)
+ attention_output = self.attention(
+ normed_hidden_states,
+ encoder_hidden_states=key_value_states,
+ attention_mask=attention_mask.squeeze(1),
+ )
+ layer_output = hidden_states + self.dropout(attention_output)
+ return layer_output
+
+
+class T5LayerFFCond(nn.Module):
+ def __init__(self, d_model, d_ff, dropout_rate, layer_norm_epsilon):
+ super().__init__()
+ self.DenseReluDense = T5DenseGatedActDense(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate)
+ self.film = T5FiLMLayer(in_features=d_model * 4, out_features=d_model)
+ self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
+ self.dropout = nn.Dropout(dropout_rate)
+
+ def forward(self, hidden_states, conditioning_emb=None):
+ forwarded_states = self.layer_norm(hidden_states)
+ if conditioning_emb is not None:
+ forwarded_states = self.film(forwarded_states, conditioning_emb)
+
+ forwarded_states = self.DenseReluDense(forwarded_states)
+ hidden_states = hidden_states + self.dropout(forwarded_states)
+ return hidden_states
+
+
+class T5DenseGatedActDense(nn.Module):
+ def __init__(self, d_model, d_ff, dropout_rate):
+ super().__init__()
+ self.wi_0 = nn.Linear(d_model, d_ff, bias=False)
+ self.wi_1 = nn.Linear(d_model, d_ff, bias=False)
+ self.wo = nn.Linear(d_ff, d_model, bias=False)
+ self.dropout = nn.Dropout(dropout_rate)
+ self.act = NewGELUActivation()
+
+ def forward(self, hidden_states):
+ hidden_gelu = self.act(self.wi_0(hidden_states))
+ hidden_linear = self.wi_1(hidden_states)
+ hidden_states = hidden_gelu * hidden_linear
+ hidden_states = self.dropout(hidden_states)
+
+ hidden_states = self.wo(hidden_states)
+ return hidden_states
+
+
+class T5LayerNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ # T5 uses a layer norm that only scales and does not shift, also known as Root Mean Square
+ # Layer Normalization (https://arxiv.org/abs/1910.07467); the variance is therefore computed
+ # without subtracting the mean and there is no bias. Additionally, we make sure that the
+ # accumulation for half-precision inputs is done in fp32.
+
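+ # i.e. hidden_states = weight * hidden_states / sqrt(mean(hidden_states**2, dim=-1) + eps)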
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+
+ # convert into half-precision if necessary
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
+ hidden_states = hidden_states.to(self.weight.dtype)
+
+ return self.weight * hidden_states
+
+
+class NewGELUActivation(nn.Module):
+ """
+ Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
+ the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
+ """
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
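+ # tanh approximation of GELU: 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))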
+ return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
+
+
+class T5FiLMLayer(nn.Module):
+ """
+ FiLM Layer
+ """
+
+ def __init__(self, in_features, out_features):
+ super().__init__()
+ self.scale_bias = nn.Linear(in_features, out_features * 2, bias=False)
+
+ def forward(self, x, conditioning_emb):
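+ # project the conditioning embedding to a per-feature (scale, shift) pair and modulate x as x * (1 + scale) + shift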
+ emb = self.scale_bias(conditioning_emb)
+ scale, shift = torch.chunk(emb, 2, -1)
+ x = x * (1 + scale) + shift
+ return x
diff --git a/diffusers/models/transformer_2d.py b/diffusers/models/transformer_2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..83da16838ae2248c31faada9cd5704d20500459c
--- /dev/null
+++ b/diffusers/models/transformer_2d.py
@@ -0,0 +1,341 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Any, Dict, Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..models.embeddings import ImagePositionalEmbeddings
+from ..utils import BaseOutput, deprecate
+from .attention import BasicTransformerBlock
+from .embeddings import PatchEmbed
+from .modeling_utils import ModelMixin
+
+
+@dataclass
+class Transformer2DModelOutput(BaseOutput):
+ """
+ The output of [`Transformer2DModel`].
+
+ Args:
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
+ The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
+ distributions for the unnoised latent pixels.
+ """
+
+ sample: torch.FloatTensor
+
+
+class Transformer2DModel(ModelMixin, ConfigMixin):
+ """
+ A 2D Transformer model for image-like data.
+
+ Parameters:
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
+ in_channels (`int`, *optional*):
+ The number of channels in the input and output (specify if the input is **continuous**).
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
+ This is fixed during training since it is used to learn a number of position embeddings.
+ num_vector_embeds (`int`, *optional*):
+ The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
+ Includes the class for the masked latent pixel.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
+ num_embeds_ada_norm ( `int`, *optional*):
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
+ added to the hidden states.
+
+ During inference, you can denoise for up to but not more than `num_embeds_ada_norm` steps.
+ attention_bias (`bool`, *optional*):
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ num_attention_heads: int = 16,
+ attention_head_dim: int = 88,
+ in_channels: Optional[int] = None,
+ out_channels: Optional[int] = None,
+ num_layers: int = 1,
+ dropout: float = 0.0,
+ norm_num_groups: int = 32,
+ cross_attention_dim: Optional[int] = None,
+ attention_bias: bool = False,
+ sample_size: Optional[int] = None,
+ num_vector_embeds: Optional[int] = None,
+ patch_size: Optional[int] = None,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ use_linear_projection: bool = False,
+ only_cross_attention: bool = False,
+ upcast_attention: bool = False,
+ norm_type: str = "layer_norm",
+ norm_elementwise_affine: bool = True,
+ ):
+ super().__init__()
+ self.use_linear_projection = use_linear_projection
+ self.num_attention_heads = num_attention_heads
+ self.attention_head_dim = attention_head_dim
+ inner_dim = num_attention_heads * attention_head_dim
+
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
+ # Define whether input is continuous or discrete depending on configuration
+ self.is_input_continuous = (in_channels is not None) and (patch_size is None)
+ self.is_input_vectorized = num_vector_embeds is not None
+ self.is_input_patches = in_channels is not None and patch_size is not None
+
+ if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
+ deprecation_message = (
+ f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
+ " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
+ " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
+ " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
+ " would be very nice if you could open a Pull request for the `transformer/config.json` file"
+ )
+ deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
+ norm_type = "ada_norm"
+
+ if self.is_input_continuous and self.is_input_vectorized:
+ raise ValueError(
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
+ " sure that either `in_channels` or `num_vector_embeds` is None."
+ )
+ elif self.is_input_vectorized and self.is_input_patches:
+ raise ValueError(
+ f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
+ " sure that either `num_vector_embeds` or `num_patches` is None."
+ )
+ elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
+ raise ValueError(
+ f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
+ f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
+ )
+
+ # 2. Define input layers
+ if self.is_input_continuous:
+ self.in_channels = in_channels
+
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+ if use_linear_projection:
+ self.proj_in = nn.Linear(in_channels, inner_dim)
+ else:
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+ elif self.is_input_vectorized:
+ assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
+ assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
+
+ self.height = sample_size
+ self.width = sample_size
+ self.num_vector_embeds = num_vector_embeds
+ self.num_latent_pixels = self.height * self.width
+
+ self.latent_image_embedding = ImagePositionalEmbeddings(
+ num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
+ )
+ elif self.is_input_patches:
+ assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
+
+ self.height = sample_size
+ self.width = sample_size
+
+ self.patch_size = patch_size
+ self.pos_embed = PatchEmbed(
+ height=sample_size,
+ width=sample_size,
+ patch_size=patch_size,
+ in_channels=in_channels,
+ embed_dim=inner_dim,
+ )
+
+ # 3. Define transformers blocks
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(
+ inner_dim,
+ num_attention_heads,
+ attention_head_dim,
+ dropout=dropout,
+ cross_attention_dim=cross_attention_dim,
+ activation_fn=activation_fn,
+ num_embeds_ada_norm=num_embeds_ada_norm,
+ attention_bias=attention_bias,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ norm_type=norm_type,
+ norm_elementwise_affine=norm_elementwise_affine,
+ )
+ for d in range(num_layers)
+ ]
+ )
+
+ # 4. Define output layers
+ self.out_channels = in_channels if out_channels is None else out_channels
+ if self.is_input_continuous:
+ # TODO: should use out_channels for continuous projections
+ if use_linear_projection:
+ self.proj_out = nn.Linear(inner_dim, in_channels)
+ else:
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+ elif self.is_input_vectorized:
+ self.norm_out = nn.LayerNorm(inner_dim)
+ self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
+ elif self.is_input_patches:
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
+ self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
+ self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ timestep: Optional[torch.LongTensor] = None,
+ class_labels: Optional[torch.LongTensor] = None,
+ cross_attention_kwargs: Dict[str, Any] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ ):
+ """
+ The [`Transformer2DModel`] forward method.
+
+ Args:
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
+ Input `hidden_states`.
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
+ self-attention.
+ timestep ( `torch.LongTensor`, *optional*):
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
+ `AdaLayerNormZero`.
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
+
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
+
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
+ above. This bias will be added to the cross-attention scores.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
+ tuple.
+
+ Returns:
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+ `tuple` where the first element is the sample tensor.
+ """
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
+ # expects mask of shape:
+ # [batch, key_tokens]
+ # adds singleton query_tokens dimension:
+ # [batch, 1, key_tokens]
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+ if attention_mask is not None and attention_mask.ndim == 2:
+ # assume that mask is expressed as:
+ # (1 = keep, 0 = discard)
+ # convert mask into a bias that can be added to attention scores:
+ # (keep = +0, discard = -10000.0)
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+ # 1. Input
+ if self.is_input_continuous:
+ batch, _, height, width = hidden_states.shape
+ residual = hidden_states
+
+ hidden_states = self.norm(hidden_states)
+ if not self.use_linear_projection:
+ hidden_states = self.proj_in(hidden_states)
+ inner_dim = hidden_states.shape[1]
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
+ else:
+ inner_dim = hidden_states.shape[1]
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
+ hidden_states = self.proj_in(hidden_states)
+ elif self.is_input_vectorized:
+ hidden_states = self.latent_image_embedding(hidden_states)
+ elif self.is_input_patches:
+ hidden_states = self.pos_embed(hidden_states)
+
+ # 2. Blocks
+ for block in self.transformer_blocks:
+ hidden_states = block(
+ hidden_states,
+ attention_mask=attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ timestep=timestep,
+ cross_attention_kwargs=cross_attention_kwargs,
+ class_labels=class_labels,
+ )
+
+ # 3. Output
+ if self.is_input_continuous:
+ if not self.use_linear_projection:
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
+ hidden_states = self.proj_out(hidden_states)
+ else:
+ hidden_states = self.proj_out(hidden_states)
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
+
+ output = hidden_states + residual
+ elif self.is_input_vectorized:
+ hidden_states = self.norm_out(hidden_states)
+ logits = self.out(hidden_states)
+ # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
+ logits = logits.permute(0, 2, 1)
+
+ # log(p(x_0))
+ output = F.log_softmax(logits.double(), dim=1).float()
+ elif self.is_input_patches:
+ # TODO: cleanup!
+ conditioning = self.transformer_blocks[0].norm1.emb(
+ timestep, class_labels, hidden_dtype=hidden_states.dtype
+ )
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
+ hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
+ hidden_states = self.proj_out_2(hidden_states)
+
+ # unpatchify
+ height = width = int(hidden_states.shape[1] ** 0.5)
+ hidden_states = hidden_states.reshape(
+ shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
+ )
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
+ output = hidden_states.reshape(
+ shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
+ )
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer2DModelOutput(sample=output)
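+
+
+# Minimal usage sketch (illustrative only, not part of the library): a continuous-input
+# transformer over a UNet-style feature map. The sizes below are arbitrary example values.
+#
+#   model = Transformer2DModel(
+#       num_attention_heads=8,
+#       attention_head_dim=40,      # inner_dim = 8 * 40 = 320
+#       in_channels=320,
+#       num_layers=1,
+#       cross_attention_dim=768,
+#   )
+#   hidden_states = torch.randn(2, 320, 32, 32)
+#   encoder_hidden_states = torch.randn(2, 77, 768)
+#   out = model(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
+#   # out.shape == (2, 320, 32, 32): continuous inputs keep their spatial shape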
diff --git a/diffusers/models/transformer_temporal.py b/diffusers/models/transformer_temporal.py
new file mode 100644
index 0000000000000000000000000000000000000000..cfafdb055bcfedc911b0a19d1e5da8089a18b215
--- /dev/null
+++ b/diffusers/models/transformer_temporal.py
@@ -0,0 +1,179 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+from torch import nn
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import BaseOutput
+from .attention import BasicTransformerBlock
+from .modeling_utils import ModelMixin
+
+
+@dataclass
+class TransformerTemporalModelOutput(BaseOutput):
+ """
+ The output of [`TransformerTemporalModel`].
+
+ Args:
+ sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`):
+ The hidden states output conditioned on `encoder_hidden_states` input.
+ """
+
+ sample: torch.FloatTensor
+
+
+class TransformerTemporalModel(ModelMixin, ConfigMixin):
+ """
+ A Transformer model for video-like data.
+
+ Parameters:
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
+ in_channels (`int`, *optional*):
+ The number of channels in the input and output (specify if the input is **continuous**).
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
+ This is fixed during training since it is used to learn a number of position embeddings.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
+ attention_bias (`bool`, *optional*):
+ Configure if the `TransformerBlock` attention should contain a bias parameter.
+ double_self_attention (`bool`, *optional*):
+ Configure if each `TransformerBlock` should contain two self-attention layers.
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ num_attention_heads: int = 16,
+ attention_head_dim: int = 88,
+ in_channels: Optional[int] = None,
+ out_channels: Optional[int] = None,
+ num_layers: int = 1,
+ dropout: float = 0.0,
+ norm_num_groups: int = 32,
+ cross_attention_dim: Optional[int] = None,
+ attention_bias: bool = False,
+ sample_size: Optional[int] = None,
+ activation_fn: str = "geglu",
+ norm_elementwise_affine: bool = True,
+ double_self_attention: bool = True,
+ ):
+ super().__init__()
+ self.num_attention_heads = num_attention_heads
+ self.attention_head_dim = attention_head_dim
+ inner_dim = num_attention_heads * attention_head_dim
+
+ self.in_channels = in_channels
+
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+ self.proj_in = nn.Linear(in_channels, inner_dim)
+
+ # 3. Define transformers blocks
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(
+ inner_dim,
+ num_attention_heads,
+ attention_head_dim,
+ dropout=dropout,
+ cross_attention_dim=cross_attention_dim,
+ activation_fn=activation_fn,
+ attention_bias=attention_bias,
+ double_self_attention=double_self_attention,
+ norm_elementwise_affine=norm_elementwise_affine,
+ )
+ for d in range(num_layers)
+ ]
+ )
+
+ self.proj_out = nn.Linear(inner_dim, in_channels)
+
+ def forward(
+ self,
+ hidden_states,
+ encoder_hidden_states=None,
+ timestep=None,
+ class_labels=None,
+ num_frames=1,
+ cross_attention_kwargs=None,
+ return_dict: bool = True,
+ ):
+ """
+ The [`TransformerTemporalModel`] forward method.
+
+ Args:
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
+ Input hidden_states.
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
+ self-attention.
+ timestep ( `torch.long`, *optional*):
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
+ `AdaLayerNormZero`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of a plain
+ tuple.
+
+ Returns:
+ [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
+ If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
+ returned, otherwise a `tuple` where the first element is the sample tensor.
+ """
+ # 1. Input
+ batch_frames, channel, height, width = hidden_states.shape
+ batch_size = batch_frames // num_frames
+
+ residual = hidden_states
+
+ hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width)
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
+
+ hidden_states = self.norm(hidden_states)
+ hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)
+
+ hidden_states = self.proj_in(hidden_states)
+
+ # 2. Blocks
+ for block in self.transformer_blocks:
+ hidden_states = block(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ timestep=timestep,
+ cross_attention_kwargs=cross_attention_kwargs,
+ class_labels=class_labels,
+ )
+
+ # 3. Output
+ hidden_states = self.proj_out(hidden_states)
+ hidden_states = (
+ hidden_states[None, None, :]
+ .reshape(batch_size, height, width, channel, num_frames)
+ .permute(0, 3, 4, 1, 2)
+ .contiguous()
+ )
+ hidden_states = hidden_states.reshape(batch_frames, channel, height, width)
+
+ output = hidden_states + residual
+
+ if not return_dict:
+ return (output,)
+
+ return TransformerTemporalModelOutput(sample=output)
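+
+
+# Minimal usage sketch (illustrative only, not part of the library): temporal self-attention
+# over a stack of per-frame feature maps. The sizes below are arbitrary example values.
+#
+#   model = TransformerTemporalModel(
+#       num_attention_heads=8,
+#       attention_head_dim=40,
+#       in_channels=320,
+#       num_layers=1,
+#   )
+#   num_frames = 8
+#   hidden_states = torch.randn(2 * num_frames, 320, 32, 32)  # (batch * frames, C, H, W)
+#   out = model(hidden_states, num_frames=num_frames).sample
+#   # out.shape == (2 * num_frames, 320, 32, 32)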
diff --git a/diffusers/models/unet_1d.py b/diffusers/models/unet_1d.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b617388f3917c97e8aef39ec0f386eb2e4a1254
--- /dev/null
+++ b/diffusers/models/unet_1d.py
@@ -0,0 +1,255 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import BaseOutput
+from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
+from .modeling_utils import ModelMixin
+from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block
+
+
+@dataclass
+class UNet1DOutput(BaseOutput):
+ """
+ The output of [`UNet1DModel`].
+
+ Args:
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, sample_size)`):
+ The hidden states output from the last layer of the model.
+ """
+
+ sample: torch.FloatTensor
+
+
+class UNet1DModel(ModelMixin, ConfigMixin):
+ r"""
+ A 1D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
+ for all models (such as downloading or saving).
+
+ Parameters:
+ sample_size (`int`, *optional*): Default length of sample. Should be adaptable at runtime.
+ in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample.
+ out_channels (`int`, *optional*, defaults to 2): Number of channels in the output.
+ extra_in_channels (`int`, *optional*, defaults to 0):
+ Number of additional channels to be added to the input of the first down block. Useful for cases where the
+ input data has more channels than what the model was initially designed for.
+ time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use.
+ freq_shift (`float`, *optional*, defaults to 0.0): Frequency shift for Fourier time embedding.
+ flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
+ Whether to flip sin to cos for Fourier time embedding.
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D")`):
+ Tuple of downsample block types.
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip")`):
+ Tuple of upsample block types.
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(32, 32, 64)`):
+ Tuple of block output channels.
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock1D"`): Block type for middle of UNet.
+ out_block_type (`str`, *optional*, defaults to `None`): Optional output processing block of UNet.
+ act_fn (`str`, *optional*, defaults to `None`): Optional activation function in UNet blocks.
+ norm_num_groups (`int`, *optional*, defaults to 8): The number of groups for normalization.
+ layers_per_block (`int`, *optional*, defaults to 1): The number of layers per block.
+ downsample_each_block (`bool`, *optional*, defaults to `False`):
+ Experimental feature for using a UNet without upsampling.
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: int = 65536,
+ sample_rate: Optional[int] = None,
+ in_channels: int = 2,
+ out_channels: int = 2,
+ extra_in_channels: int = 0,
+ time_embedding_type: str = "fourier",
+ flip_sin_to_cos: bool = True,
+ use_timestep_embedding: bool = False,
+ freq_shift: float = 0.0,
+ down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"),
+ up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
+ mid_block_type: str = "UNetMidBlock1D",
+ out_block_type: str = None,
+ block_out_channels: Tuple[int] = (32, 32, 64),
+ act_fn: str = None,
+ norm_num_groups: int = 8,
+ layers_per_block: int = 1,
+ downsample_each_block: bool = False,
+ ):
+ super().__init__()
+ self.sample_size = sample_size
+
+ # time
+ if time_embedding_type == "fourier":
+ self.time_proj = GaussianFourierProjection(
+ embedding_size=8, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
+ )
+ timestep_input_dim = 2 * block_out_channels[0]
+ elif time_embedding_type == "positional":
+ self.time_proj = Timesteps(
+ block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=freq_shift
+ )
+ timestep_input_dim = block_out_channels[0]
+
+ if use_timestep_embedding:
+ time_embed_dim = block_out_channels[0] * 4
+ self.time_mlp = TimestepEmbedding(
+ in_channels=timestep_input_dim,
+ time_embed_dim=time_embed_dim,
+ act_fn=act_fn,
+ out_dim=block_out_channels[0],
+ )
+
+ self.down_blocks = nn.ModuleList([])
+ self.mid_block = None
+ self.up_blocks = nn.ModuleList([])
+ self.out_block = None
+
+ # down
+ output_channel = in_channels
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+
+ if i == 0:
+ input_channel += extra_in_channels
+
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=block_out_channels[0],
+ add_downsample=not is_final_block or downsample_each_block,
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ self.mid_block = get_mid_block(
+ mid_block_type,
+ in_channels=block_out_channels[-1],
+ mid_channels=block_out_channels[-1],
+ out_channels=block_out_channels[-1],
+ embed_dim=block_out_channels[0],
+ num_layers=layers_per_block,
+ add_downsample=downsample_each_block,
+ )
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ output_channel = reversed_block_out_channels[0]
+ if out_block_type is None:
+ final_upsample_channels = out_channels
+ else:
+ final_upsample_channels = block_out_channels[0]
+
+ for i, up_block_type in enumerate(up_block_types):
+ prev_output_channel = output_channel
+ output_channel = (
+ reversed_block_out_channels[i + 1] if i < len(up_block_types) - 1 else final_upsample_channels
+ )
+
+ is_final_block = i == len(block_out_channels) - 1
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=layers_per_block,
+ in_channels=prev_output_channel,
+ out_channels=output_channel,
+ temb_channels=block_out_channels[0],
+ add_upsample=not is_final_block,
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # out
+ num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
+ self.out_block = get_out_block(
+ out_block_type=out_block_type,
+ num_groups_out=num_groups_out,
+ embed_dim=block_out_channels[0],
+ out_channels=out_channels,
+ act_fn=act_fn,
+ fc_dim=block_out_channels[-1] // 4,
+ )
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ return_dict: bool = True,
+ ) -> Union[UNet1DOutput, Tuple]:
+ r"""
+ The [`UNet1DModel`] forward method.
+
+ Args:
+ sample (`torch.FloatTensor`):
+ The noisy input tensor with the following shape `(batch_size, num_channels, sample_size)`.
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.unet_1d.UNet1DOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.unet_1d.UNet1DOutput`] or `tuple`:
+ If `return_dict` is True, an [`~models.unet_1d.UNet1DOutput`] is returned, otherwise a `tuple` is
+ returned where the first element is the sample tensor.
+ """
+
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
+ elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ timestep_embed = self.time_proj(timesteps)
+ if self.config.use_timestep_embedding:
+ timestep_embed = self.time_mlp(timestep_embed)
+ else:
+ timestep_embed = timestep_embed[..., None]
+ timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype)
+ timestep_embed = timestep_embed.broadcast_to((sample.shape[:1] + timestep_embed.shape[1:]))
+
+ # 2. down
+ down_block_res_samples = ()
+ for downsample_block in self.down_blocks:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=timestep_embed)
+ down_block_res_samples += res_samples
+
+ # 3. mid
+ if self.mid_block:
+ sample = self.mid_block(sample, timestep_embed)
+
+ # 4. up
+ for i, upsample_block in enumerate(self.up_blocks):
+ res_samples = down_block_res_samples[-1:]
+ down_block_res_samples = down_block_res_samples[:-1]
+ sample = upsample_block(sample, res_hidden_states_tuple=res_samples, temb=timestep_embed)
+
+ # 5. post-process
+ if self.out_block:
+ sample = self.out_block(sample, timestep_embed)
+
+ if not return_dict:
+ return (sample,)
+
+ return UNet1DOutput(sample=sample)
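+
+
+# Minimal usage sketch (illustrative only, not part of the library). With the default block
+# types, the first down block concatenates the 16-channel Fourier time embedding to the
+# input, so `extra_in_channels=16` is needed; any sample length divisible by 8 works here.
+#
+#   model = UNet1DModel(extra_in_channels=16)   # in/out channels default to 2
+#   sample = torch.randn(1, 2, 2048)
+#   out = model(sample, timestep=10).sample
+#   # out.shape == (1, 2, 2048)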
diff --git a/diffusers/models/unet_1d_blocks.py b/diffusers/models/unet_1d_blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..84ae48e0f8c4f3da6132a02c3e89f7c976a2b150
--- /dev/null
+++ b/diffusers/models/unet_1d_blocks.py
@@ -0,0 +1,656 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from .activations import get_activation
+from .resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D, rearrange_dims
+
+
+class DownResnetBlock1D(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels=None,
+ num_layers=1,
+ conv_shortcut=False,
+ temb_channels=32,
+ groups=32,
+ groups_out=None,
+ non_linearity=None,
+ time_embedding_norm="default",
+ output_scale_factor=1.0,
+ add_downsample=True,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+ self.time_embedding_norm = time_embedding_norm
+ self.add_downsample = add_downsample
+ self.output_scale_factor = output_scale_factor
+
+ if groups_out is None:
+ groups_out = groups
+
+ # there will always be at least one resnet
+ resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=temb_channels)]
+
+ for _ in range(num_layers):
+ resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels))
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if non_linearity is None:
+ self.nonlinearity = None
+ else:
+ self.nonlinearity = get_activation(non_linearity)
+
+ self.downsample = None
+ if add_downsample:
+ self.downsample = Downsample1D(out_channels, use_conv=True, padding=1)
+
+ def forward(self, hidden_states, temb=None):
+ output_states = ()
+
+ hidden_states = self.resnets[0](hidden_states, temb)
+ for resnet in self.resnets[1:]:
+ hidden_states = resnet(hidden_states, temb)
+
+ output_states += (hidden_states,)
+
+ if self.nonlinearity is not None:
+ hidden_states = self.nonlinearity(hidden_states)
+
+ if self.downsample is not None:
+ hidden_states = self.downsample(hidden_states)
+
+ return hidden_states, output_states
+
+
+class UpResnetBlock1D(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels=None,
+ num_layers=1,
+ temb_channels=32,
+ groups=32,
+ groups_out=None,
+ non_linearity=None,
+ time_embedding_norm="default",
+ output_scale_factor=1.0,
+ add_upsample=True,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.time_embedding_norm = time_embedding_norm
+ self.add_upsample = add_upsample
+ self.output_scale_factor = output_scale_factor
+
+ if groups_out is None:
+ groups_out = groups
+
+ # there will always be at least one resnet
+ resnets = [ResidualTemporalBlock1D(2 * in_channels, out_channels, embed_dim=temb_channels)]
+
+ for _ in range(num_layers):
+ resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels))
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if non_linearity is None:
+ self.nonlinearity = None
+ else:
+ self.nonlinearity = get_activation(non_linearity)
+
+ self.upsample = None
+ if add_upsample:
+ self.upsample = Upsample1D(out_channels, use_conv_transpose=True)
+
+ def forward(self, hidden_states, res_hidden_states_tuple=None, temb=None):
+ if res_hidden_states_tuple is not None:
+ res_hidden_states = res_hidden_states_tuple[-1]
+ hidden_states = torch.cat((hidden_states, res_hidden_states), dim=1)
+
+ hidden_states = self.resnets[0](hidden_states, temb)
+ for resnet in self.resnets[1:]:
+ hidden_states = resnet(hidden_states, temb)
+
+ if self.nonlinearity is not None:
+ hidden_states = self.nonlinearity(hidden_states)
+
+ if self.upsample is not None:
+ hidden_states = self.upsample(hidden_states)
+
+ return hidden_states
+
+
+class ValueFunctionMidBlock1D(nn.Module):
+ def __init__(self, in_channels, out_channels, embed_dim):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.embed_dim = embed_dim
+
+ self.res1 = ResidualTemporalBlock1D(in_channels, in_channels // 2, embed_dim=embed_dim)
+ self.down1 = Downsample1D(out_channels // 2, use_conv=True)
+ self.res2 = ResidualTemporalBlock1D(in_channels // 2, in_channels // 4, embed_dim=embed_dim)
+ self.down2 = Downsample1D(out_channels // 4, use_conv=True)
+
+ def forward(self, x, temb=None):
+ x = self.res1(x, temb)
+ x = self.down1(x)
+ x = self.res2(x, temb)
+ x = self.down2(x)
+ return x
+
+
+class MidResTemporalBlock1D(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ embed_dim,
+ num_layers: int = 1,
+ add_downsample: bool = False,
+ add_upsample: bool = False,
+ non_linearity=None,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.add_downsample = add_downsample
+
+ # there will always be at least one resnet
+ resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=embed_dim)]
+
+ for _ in range(num_layers):
+ resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=embed_dim))
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if non_linearity is None:
+ self.nonlinearity = None
+ else:
+ self.nonlinearity = get_activation(non_linearity)
+
+ self.upsample = None
+ if add_upsample:
+ self.upsample = Downsample1D(out_channels, use_conv=True)
+
+ self.downsample = None
+ if add_downsample:
+ self.downsample = Downsample1D(out_channels, use_conv=True)
+
+ if self.upsample and self.downsample:
+ raise ValueError("Block cannot downsample and upsample")
+
+ def forward(self, hidden_states, temb):
+ hidden_states = self.resnets[0](hidden_states, temb)
+ for resnet in self.resnets[1:]:
+ hidden_states = resnet(hidden_states, temb)
+
+ if self.upsample:
+ hidden_states = self.upsample(hidden_states)
+ if self.downsample:
+ hidden_states = self.downsample(hidden_states)
+
+ return hidden_states
+
+
+class OutConv1DBlock(nn.Module):
+ def __init__(self, num_groups_out, out_channels, embed_dim, act_fn):
+ super().__init__()
+ self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2)
+ self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim)
+ self.final_conv1d_act = get_activation(act_fn)
+ self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1)
+
+ def forward(self, hidden_states, temb=None):
+ hidden_states = self.final_conv1d_1(hidden_states)
+ hidden_states = rearrange_dims(hidden_states)
+ hidden_states = self.final_conv1d_gn(hidden_states)
+ hidden_states = rearrange_dims(hidden_states)
+ hidden_states = self.final_conv1d_act(hidden_states)
+ hidden_states = self.final_conv1d_2(hidden_states)
+ return hidden_states
+
+
+class OutValueFunctionBlock(nn.Module):
+ def __init__(self, fc_dim, embed_dim, act_fn="mish"):
+ super().__init__()
+ self.final_block = nn.ModuleList(
+ [
+ nn.Linear(fc_dim + embed_dim, fc_dim // 2),
+ get_activation(act_fn),
+ nn.Linear(fc_dim // 2, 1),
+ ]
+ )
+
+ def forward(self, hidden_states, temb):
+ hidden_states = hidden_states.view(hidden_states.shape[0], -1)
+ hidden_states = torch.cat((hidden_states, temb), dim=-1)
+ for layer in self.final_block:
+ hidden_states = layer(hidden_states)
+
+ return hidden_states
+
+
+_kernels = {
+ "linear": [1 / 8, 3 / 8, 3 / 8, 1 / 8],
+ "cubic": [-0.01171875, -0.03515625, 0.11328125, 0.43359375, 0.43359375, 0.11328125, -0.03515625, -0.01171875],
+ "lanczos3": [
+ 0.003689131001010537,
+ 0.015056144446134567,
+ -0.03399861603975296,
+ -0.066637322306633,
+ 0.13550527393817902,
+ 0.44638532400131226,
+ 0.44638532400131226,
+ 0.13550527393817902,
+ -0.066637322306633,
+ -0.03399861603975296,
+ 0.015056144446134567,
+ 0.003689131001010537,
+ ],
+}
+
+
+class Downsample1d(nn.Module):
+ def __init__(self, kernel="linear", pad_mode="reflect"):
+ super().__init__()
+ self.pad_mode = pad_mode
+ kernel_1d = torch.tensor(_kernels[kernel])
+ self.pad = kernel_1d.shape[0] // 2 - 1
+ self.register_buffer("kernel", kernel_1d)
+
+ def forward(self, hidden_states):
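+ # build a per-channel (diagonal) convolution weight so the same 1D kernel filters every
+ # channel independently, then downsample by a factor of 2 with a strided convolution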
+ hidden_states = F.pad(hidden_states, (self.pad,) * 2, self.pad_mode)
+ weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
+ indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
+ kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1)
+ weight[indices, indices] = kernel
+ return F.conv1d(hidden_states, weight, stride=2)
+
+
+class Upsample1d(nn.Module):
+ def __init__(self, kernel="linear", pad_mode="reflect"):
+ super().__init__()
+ self.pad_mode = pad_mode
+ kernel_1d = torch.tensor(_kernels[kernel]) * 2
+ self.pad = kernel_1d.shape[0] // 2 - 1
+ self.register_buffer("kernel", kernel_1d)
+
+ def forward(self, hidden_states, temb=None):
+ hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode)
+ weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
+ indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
+ kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1)
+ weight[indices, indices] = kernel
+ return F.conv_transpose1d(hidden_states, weight, stride=2, padding=self.pad * 2 + 1)
+
+
+class SelfAttention1d(nn.Module):
+ def __init__(self, in_channels, n_head=1, dropout_rate=0.0):
+ super().__init__()
+ self.channels = in_channels
+ self.group_norm = nn.GroupNorm(1, num_channels=in_channels)
+ self.num_heads = n_head
+
+ self.query = nn.Linear(self.channels, self.channels)
+ self.key = nn.Linear(self.channels, self.channels)
+ self.value = nn.Linear(self.channels, self.channels)
+
+ self.proj_attn = nn.Linear(self.channels, self.channels, bias=True)
+
+ self.dropout = nn.Dropout(dropout_rate, inplace=True)
+
+ def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
+ new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
+ # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
+ new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
+ return new_projection
+
+ def forward(self, hidden_states):
+ residual = hidden_states
+ batch, channel_dim, seq = hidden_states.shape
+
+ hidden_states = self.group_norm(hidden_states)
+ hidden_states = hidden_states.transpose(1, 2)
+
+ query_proj = self.query(hidden_states)
+ key_proj = self.key(hidden_states)
+ value_proj = self.value(hidden_states)
+
+ query_states = self.transpose_for_scores(query_proj)
+ key_states = self.transpose_for_scores(key_proj)
+ value_states = self.transpose_for_scores(value_proj)
+
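+ # scaling q and k by 1 / sqrt(sqrt(d)) each gives the usual 1 / sqrt(d) factor on the
+ # scores while staying better behaved in half precision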
+ scale = 1 / math.sqrt(math.sqrt(key_states.shape[-1]))
+
+ attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale)
+ attention_probs = torch.softmax(attention_scores, dim=-1)
+
+ # compute attention output
+ hidden_states = torch.matmul(attention_probs, value_states)
+
+ hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
+ new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
+ hidden_states = hidden_states.view(new_hidden_states_shape)
+
+ # compute next hidden_states
+ hidden_states = self.proj_attn(hidden_states)
+ hidden_states = hidden_states.transpose(1, 2)
+ hidden_states = self.dropout(hidden_states)
+
+ output = hidden_states + residual
+
+ return output
+
+
+class ResConvBlock(nn.Module):
+ def __init__(self, in_channels, mid_channels, out_channels, is_last=False):
+ super().__init__()
+ self.is_last = is_last
+ self.has_conv_skip = in_channels != out_channels
+
+ if self.has_conv_skip:
+ self.conv_skip = nn.Conv1d(in_channels, out_channels, 1, bias=False)
+
+ self.conv_1 = nn.Conv1d(in_channels, mid_channels, 5, padding=2)
+ self.group_norm_1 = nn.GroupNorm(1, mid_channels)
+ self.gelu_1 = nn.GELU()
+ self.conv_2 = nn.Conv1d(mid_channels, out_channels, 5, padding=2)
+
+ if not self.is_last:
+ self.group_norm_2 = nn.GroupNorm(1, out_channels)
+ self.gelu_2 = nn.GELU()
+
+ def forward(self, hidden_states):
+ residual = self.conv_skip(hidden_states) if self.has_conv_skip else hidden_states
+
+ hidden_states = self.conv_1(hidden_states)
+ hidden_states = self.group_norm_1(hidden_states)
+ hidden_states = self.gelu_1(hidden_states)
+ hidden_states = self.conv_2(hidden_states)
+
+ if not self.is_last:
+ hidden_states = self.group_norm_2(hidden_states)
+ hidden_states = self.gelu_2(hidden_states)
+
+ output = hidden_states + residual
+ return output
+
+
+class UNetMidBlock1D(nn.Module):
+ def __init__(self, mid_channels, in_channels, out_channels=None):
+ super().__init__()
+
+ out_channels = in_channels if out_channels is None else out_channels
+
+ # there is always at least one resnet
+ self.down = Downsample1d("cubic")
+ resnets = [
+ ResConvBlock(in_channels, mid_channels, mid_channels),
+ ResConvBlock(mid_channels, mid_channels, mid_channels),
+ ResConvBlock(mid_channels, mid_channels, mid_channels),
+ ResConvBlock(mid_channels, mid_channels, mid_channels),
+ ResConvBlock(mid_channels, mid_channels, mid_channels),
+ ResConvBlock(mid_channels, mid_channels, out_channels),
+ ]
+ attentions = [
+ SelfAttention1d(mid_channels, mid_channels // 32),
+ SelfAttention1d(mid_channels, mid_channels // 32),
+ SelfAttention1d(mid_channels, mid_channels // 32),
+ SelfAttention1d(mid_channels, mid_channels // 32),
+ SelfAttention1d(mid_channels, mid_channels // 32),
+ SelfAttention1d(out_channels, out_channels // 32),
+ ]
+ self.up = Upsample1d(kernel="cubic")
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ def forward(self, hidden_states, temb=None):
+ hidden_states = self.down(hidden_states)
+ for attn, resnet in zip(self.attentions, self.resnets):
+ hidden_states = resnet(hidden_states)
+ hidden_states = attn(hidden_states)
+
+ hidden_states = self.up(hidden_states)
+
+ return hidden_states
+
+
+class AttnDownBlock1D(nn.Module):
+ def __init__(self, out_channels, in_channels, mid_channels=None):
+ super().__init__()
+ mid_channels = out_channels if mid_channels is None else mid_channels
+
+ self.down = Downsample1d("cubic")
+ resnets = [
+ ResConvBlock(in_channels, mid_channels, mid_channels),
+ ResConvBlock(mid_channels, mid_channels, mid_channels),
+ ResConvBlock(mid_channels, mid_channels, out_channels),
+ ]
+ attentions = [
+ SelfAttention1d(mid_channels, mid_channels // 32),
+ SelfAttention1d(mid_channels, mid_channels // 32),
+ SelfAttention1d(out_channels, out_channels // 32),
+ ]
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ def forward(self, hidden_states, temb=None):
+ hidden_states = self.down(hidden_states)
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ hidden_states = resnet(hidden_states)
+ hidden_states = attn(hidden_states)
+
+ return hidden_states, (hidden_states,)
+
+
+class DownBlock1D(nn.Module):
+ def __init__(self, out_channels, in_channels, mid_channels=None):
+ super().__init__()
+ mid_channels = out_channels if mid_channels is None else mid_channels
+
+ self.down = Downsample1d("cubic")
+ resnets = [
+ ResConvBlock(in_channels, mid_channels, mid_channels),
+ ResConvBlock(mid_channels, mid_channels, mid_channels),
+ ResConvBlock(mid_channels, mid_channels, out_channels),
+ ]
+
+ self.resnets = nn.ModuleList(resnets)
+
+ def forward(self, hidden_states, temb=None):
+ hidden_states = self.down(hidden_states)
+
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states)
+
+ return hidden_states, (hidden_states,)
+
+
+class DownBlock1DNoSkip(nn.Module):
+ def __init__(self, out_channels, in_channels, mid_channels=None):
+ super().__init__()
+ mid_channels = out_channels if mid_channels is None else mid_channels
+
+ resnets = [
+ ResConvBlock(in_channels, mid_channels, mid_channels),
+ ResConvBlock(mid_channels, mid_channels, mid_channels),
+ ResConvBlock(mid_channels, mid_channels, out_channels),
+ ]
+
+ self.resnets = nn.ModuleList(resnets)
+
+ def forward(self, hidden_states, temb=None):
+ hidden_states = torch.cat([hidden_states, temb], dim=1)
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states)
+
+ return hidden_states, (hidden_states,)
+
+
+class AttnUpBlock1D(nn.Module):
+ def __init__(self, in_channels, out_channels, mid_channels=None):
+ super().__init__()
+ mid_channels = out_channels if mid_channels is None else mid_channels
+
+ resnets = [
+ ResConvBlock(2 * in_channels, mid_channels, mid_channels),
+ ResConvBlock(mid_channels, mid_channels, mid_channels),
+ ResConvBlock(mid_channels, mid_channels, out_channels),
+ ]
+ attentions = [
+ SelfAttention1d(mid_channels, mid_channels // 32),
+ SelfAttention1d(mid_channels, mid_channels // 32),
+ SelfAttention1d(out_channels, out_channels // 32),
+ ]
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+ self.up = Upsample1d(kernel="cubic")
+
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
+ res_hidden_states = res_hidden_states_tuple[-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ hidden_states = resnet(hidden_states)
+ hidden_states = attn(hidden_states)
+
+ hidden_states = self.up(hidden_states)
+
+ return hidden_states
+
+
+class UpBlock1D(nn.Module):
+ def __init__(self, in_channels, out_channels, mid_channels=None):
+ super().__init__()
+ mid_channels = in_channels if mid_channels is None else mid_channels
+
+ resnets = [
+ ResConvBlock(2 * in_channels, mid_channels, mid_channels),
+ ResConvBlock(mid_channels, mid_channels, mid_channels),
+ ResConvBlock(mid_channels, mid_channels, out_channels),
+ ]
+
+ self.resnets = nn.ModuleList(resnets)
+ self.up = Upsample1d(kernel="cubic")
+
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
+ res_hidden_states = res_hidden_states_tuple[-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states)
+
+ hidden_states = self.up(hidden_states)
+
+ return hidden_states
+
+
+class UpBlock1DNoSkip(nn.Module):
+ def __init__(self, in_channels, out_channels, mid_channels=None):
+ super().__init__()
+ mid_channels = in_channels if mid_channels is None else mid_channels
+
+ resnets = [
+ ResConvBlock(2 * in_channels, mid_channels, mid_channels),
+ ResConvBlock(mid_channels, mid_channels, mid_channels),
+ ResConvBlock(mid_channels, mid_channels, out_channels, is_last=True),
+ ]
+
+ self.resnets = nn.ModuleList(resnets)
+
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
+ res_hidden_states = res_hidden_states_tuple[-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states)
+
+ return hidden_states
+
+
+def get_down_block(down_block_type, num_layers, in_channels, out_channels, temb_channels, add_downsample):
+ if down_block_type == "DownResnetBlock1D":
+ return DownResnetBlock1D(
+ in_channels=in_channels,
+ num_layers=num_layers,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ )
+ elif down_block_type == "DownBlock1D":
+ return DownBlock1D(out_channels=out_channels, in_channels=in_channels)
+ elif down_block_type == "AttnDownBlock1D":
+ return AttnDownBlock1D(out_channels=out_channels, in_channels=in_channels)
+ elif down_block_type == "DownBlock1DNoSkip":
+ return DownBlock1DNoSkip(out_channels=out_channels, in_channels=in_channels)
+ raise ValueError(f"{down_block_type} does not exist.")
+
+
+def get_up_block(up_block_type, num_layers, in_channels, out_channels, temb_channels, add_upsample):
+ if up_block_type == "UpResnetBlock1D":
+ return UpResnetBlock1D(
+ in_channels=in_channels,
+ num_layers=num_layers,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ )
+ elif up_block_type == "UpBlock1D":
+ return UpBlock1D(in_channels=in_channels, out_channels=out_channels)
+ elif up_block_type == "AttnUpBlock1D":
+ return AttnUpBlock1D(in_channels=in_channels, out_channels=out_channels)
+ elif up_block_type == "UpBlock1DNoSkip":
+ return UpBlock1DNoSkip(in_channels=in_channels, out_channels=out_channels)
+ raise ValueError(f"{up_block_type} does not exist.")
+
+
+def get_mid_block(mid_block_type, num_layers, in_channels, mid_channels, out_channels, embed_dim, add_downsample):
+ if mid_block_type == "MidResTemporalBlock1D":
+ return MidResTemporalBlock1D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ embed_dim=embed_dim,
+ add_downsample=add_downsample,
+ )
+ elif mid_block_type == "ValueFunctionMidBlock1D":
+ return ValueFunctionMidBlock1D(in_channels=in_channels, out_channels=out_channels, embed_dim=embed_dim)
+ elif mid_block_type == "UNetMidBlock1D":
+ return UNetMidBlock1D(in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels)
+ raise ValueError(f"{mid_block_type} does not exist.")
+
+
+def get_out_block(*, out_block_type, num_groups_out, embed_dim, out_channels, act_fn, fc_dim):
+ if out_block_type == "OutConv1DBlock":
+ return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn)
+ elif out_block_type == "ValueFunction":
+ return OutValueFunctionBlock(fc_dim, embed_dim, act_fn)
+ return None
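+
+
+# Minimal usage sketch of the block factories (illustrative only, not part of the library):
+#
+#   down = get_down_block("DownBlock1D", num_layers=1, in_channels=32, out_channels=64,
+#                         temb_channels=32, add_downsample=True)
+#   x = torch.randn(1, 32, 256)
+#   x, skips = down(x)                           # downsampled to (1, 64, 128)
+#   up = get_up_block("UpBlock1D", num_layers=1, in_channels=64, out_channels=32,
+#                     temb_channels=32, add_upsample=True)
+#   x = up(x, res_hidden_states_tuple=skips)     # back to (1, 32, 256)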
diff --git a/diffusers/models/unet_2d.py b/diffusers/models/unet_2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b17acd3d829519465ec0d8daa41b16184aa70f2
--- /dev/null
+++ b/diffusers/models/unet_2d.py
@@ -0,0 +1,329 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import BaseOutput
+from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
+from .modeling_utils import ModelMixin
+from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
+
+
+@dataclass
+class UNet2DOutput(BaseOutput):
+ """
+ The output of [`UNet2DModel`].
+
+ Args:
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ The hidden states output from the last layer of the model.
+ """
+
+ sample: torch.FloatTensor
+
+
+class UNet2DModel(ModelMixin, ConfigMixin):
+ r"""
+ A 2D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.
+
+    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
+ for all models (such as downloading or saving).
+
+ Parameters:
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
+ Height and width of input/output sample. Dimensions must be a multiple of `2 ** (len(block_out_channels) -
+ 1)`.
+ in_channels (`int`, *optional*, defaults to 3): Number of channels in the input sample.
+ out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
+ time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use.
+ freq_shift (`int`, *optional*, defaults to 0): Frequency shift for Fourier time embedding.
+ flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
+ Whether to flip sin to cos for Fourier time embedding.
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`):
+ Tuple of downsample block types.
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
+            Block type for the middle of the UNet; it can be either `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`.
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`):
+ Tuple of upsample block types.
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`):
+ Tuple of block output channels.
+ layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block.
+ mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block.
+ downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution.
+        downsample_type (`str`, *optional*, defaults to `conv`):
+            The downsample type for downsampling layers. Choose between "conv" and "resnet".
+        upsample_type (`str`, *optional*, defaults to `conv`):
+            The upsample type for upsampling layers. Choose between "conv" and "resnet".
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+ attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
+ norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization.
+ norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for normalization.
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
+ class_embed_type (`str`, *optional*, defaults to `None`):
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
+ `"timestep"`, or `"identity"`.
+ num_class_embeds (`int`, *optional*, defaults to `None`):
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim` when performing class
+ conditioning with `class_embed_type` equal to `None`.
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: Optional[Union[int, Tuple[int, int]]] = None,
+ in_channels: int = 3,
+ out_channels: int = 3,
+ center_input_sample: bool = False,
+ time_embedding_type: str = "positional",
+ freq_shift: int = 0,
+ flip_sin_to_cos: bool = True,
+ down_block_types: Tuple[str] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
+ up_block_types: Tuple[str] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
+ block_out_channels: Tuple[int] = (224, 448, 672, 896),
+ layers_per_block: int = 2,
+ mid_block_scale_factor: float = 1,
+ downsample_padding: int = 1,
+ downsample_type: str = "conv",
+ upsample_type: str = "conv",
+ act_fn: str = "silu",
+ attention_head_dim: Optional[int] = 8,
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-5,
+ resnet_time_scale_shift: str = "default",
+ add_attention: bool = True,
+ class_embed_type: Optional[str] = None,
+ num_class_embeds: Optional[int] = None,
+ ):
+ super().__init__()
+
+ self.sample_size = sample_size
+ time_embed_dim = block_out_channels[0] * 4
+
+ # Check inputs
+ if len(down_block_types) != len(up_block_types):
+ raise ValueError(
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
+ )
+
+ if len(block_out_channels) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
+ )
+
+ # input
+ self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
+
+ # time
+ if time_embedding_type == "fourier":
+ self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16)
+ timestep_input_dim = 2 * block_out_channels[0]
+ elif time_embedding_type == "positional":
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+ timestep_input_dim = block_out_channels[0]
+
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+
+ # class embedding
+ if class_embed_type is None and num_class_embeds is not None:
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+ elif class_embed_type == "timestep":
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+ elif class_embed_type == "identity":
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
+ else:
+ self.class_embedding = None
+
+ self.down_blocks = nn.ModuleList([])
+ self.mid_block = None
+ self.up_blocks = nn.ModuleList([])
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
+ downsample_padding=downsample_padding,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ downsample_type=downsample_type,
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ self.mid_block = UNetMidBlock2D(
+ in_channels=block_out_channels[-1],
+ temb_channels=time_embed_dim,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
+ resnet_groups=norm_num_groups,
+ add_attention=add_attention,
+ )
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+ is_final_block = i == len(block_out_channels) - 1
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=layers_per_block + 1,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ prev_output_channel=prev_output_channel,
+ temb_channels=time_embed_dim,
+ add_upsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ upsample_type=upsample_type,
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # out
+ num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps)
+ self.conv_act = nn.SiLU()
+ self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ class_labels: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ ) -> Union[UNet2DOutput, Tuple]:
+ r"""
+ The [`UNet2DModel`] forward method.
+
+ Args:
+ sample (`torch.FloatTensor`):
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
+            timestep (`torch.FloatTensor` or `float` or `int`): The timestep at which to denoise the input.
+ class_labels (`torch.FloatTensor`, *optional*, defaults to `None`):
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.unet_2d.UNet2DOutput`] or `tuple`:
+ If `return_dict` is True, an [`~models.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is
+ returned where the first element is the sample tensor.
+ """
+ # 0. center input if necessary
+ if self.config.center_input_sample:
+ sample = 2 * sample - 1.0
+
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
+ elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)
+
+ t_emb = self.time_proj(timesteps)
+
+ # timesteps does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=self.dtype)
+ emb = self.time_embedding(t_emb)
+
+ if self.class_embedding is not None:
+ if class_labels is None:
+ raise ValueError("class_labels should be provided when doing class conditioning")
+
+ if self.config.class_embed_type == "timestep":
+ class_labels = self.time_proj(class_labels)
+
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
+ emb = emb + class_emb
+
+ # 2. pre-process
+ skip_sample = sample
+ sample = self.conv_in(sample)
+
+ # 3. down
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "skip_conv"):
+ sample, res_samples, skip_sample = downsample_block(
+ hidden_states=sample, temb=emb, skip_sample=skip_sample
+ )
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+
+ down_block_res_samples += res_samples
+
+ # 4. mid
+ sample = self.mid_block(sample, emb)
+
+ # 5. up
+ skip_sample = None
+ for upsample_block in self.up_blocks:
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+ if hasattr(upsample_block, "skip_conv"):
+ sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
+ else:
+ sample = upsample_block(sample, res_samples, emb)
+
+ # 6. post-process
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ if skip_sample is not None:
+ sample += skip_sample
+
+ if self.config.time_embedding_type == "fourier":
+ timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
+ sample = sample / timesteps
+
+ if not return_dict:
+ return (sample,)
+
+ return UNet2DOutput(sample=sample)
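# A minimal usage sketch (not part of the patch) for the UNet2DModel defined above, run on
# random data. The import path assumes the vendored `diffusers` package at the repository
# root is importable (a pip-installed diffusers exposes the same class); the channel and
# spatial sizes below are illustrative only.
import torch

from diffusers.models.unet_2d import UNet2DModel

model = UNet2DModel(
    sample_size=32,
    in_channels=3,
    out_channels=3,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)

noisy = torch.randn(1, 3, 32, 32)     # (batch, channels, height, width)
timestep = torch.tensor([10])         # diffusion timestep; forward() broadcasts it over the batch
out = model(noisy, timestep).sample   # UNet2DOutput.sample keeps the input shape
assert out.shape == noisy.shape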
diff --git a/diffusers/models/unet_2d_blocks.py b/diffusers/models/unet_2d_blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..469e501b814b4673ab7f18378aecb348cebbfcdf
--- /dev/null
+++ b/diffusers/models/unet_2d_blocks.py
@@ -0,0 +1,3182 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Dict, Optional, Tuple
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from ..utils import is_torch_version, logging
+from .attention import AdaGroupNorm
+from .attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
+from .dual_transformer_2d import DualTransformer2DModel
+from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D
+from .transformer_2d import Transformer2DModel
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def get_down_block(
+ down_block_type,
+ num_layers,
+ in_channels,
+ out_channels,
+ temb_channels,
+ add_downsample,
+ resnet_eps,
+ resnet_act_fn,
+ transformer_layers_per_block=1,
+ num_attention_heads=None,
+ resnet_groups=None,
+ cross_attention_dim=None,
+ downsample_padding=None,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
+ resnet_time_scale_shift="default",
+ resnet_skip_time_act=False,
+ resnet_out_scale_factor=1.0,
+ cross_attention_norm=None,
+ attention_head_dim=None,
+ downsample_type=None,
+):
+ # If attn head dim is not defined, we default it to the number of heads
+ if attention_head_dim is None:
+        logger.warning(
+ f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
+ )
+ attention_head_dim = num_attention_heads
+
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
+ if down_block_type == "DownBlock2D":
+ return DownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif down_block_type == "ResnetDownsampleBlock2D":
+ return ResnetDownsampleBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ skip_time_act=resnet_skip_time_act,
+ output_scale_factor=resnet_out_scale_factor,
+ )
+ elif down_block_type == "AttnDownBlock2D":
+ if add_downsample is False:
+ downsample_type = None
+ else:
+ downsample_type = downsample_type or "conv" # default to 'conv'
+ return AttnDownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ attention_head_dim=attention_head_dim,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ downsample_type=downsample_type,
+ )
+ elif down_block_type == "CrossAttnDownBlock2D":
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
+ return CrossAttnDownBlock2D(
+ num_layers=num_layers,
+ transformer_layers_per_block=transformer_layers_per_block,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif down_block_type == "SimpleCrossAttnDownBlock2D":
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D")
+ return SimpleCrossAttnDownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ cross_attention_dim=cross_attention_dim,
+ attention_head_dim=attention_head_dim,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ skip_time_act=resnet_skip_time_act,
+ output_scale_factor=resnet_out_scale_factor,
+ only_cross_attention=only_cross_attention,
+ cross_attention_norm=cross_attention_norm,
+ )
+ elif down_block_type == "SkipDownBlock2D":
+ return SkipDownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ downsample_padding=downsample_padding,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif down_block_type == "AttnSkipDownBlock2D":
+ return AttnSkipDownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ attention_head_dim=attention_head_dim,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif down_block_type == "DownEncoderBlock2D":
+ return DownEncoderBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif down_block_type == "AttnDownEncoderBlock2D":
+ return AttnDownEncoderBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ attention_head_dim=attention_head_dim,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif down_block_type == "KDownBlock2D":
+ return KDownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ )
+ elif down_block_type == "KCrossAttnDownBlock2D":
+ return KCrossAttnDownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ cross_attention_dim=cross_attention_dim,
+ attention_head_dim=attention_head_dim,
+ add_self_attention=True if not add_downsample else False,
+ )
+ raise ValueError(f"{down_block_type} does not exist.")
+
+
+def get_up_block(
+ up_block_type,
+ num_layers,
+ in_channels,
+ out_channels,
+ prev_output_channel,
+ temb_channels,
+ add_upsample,
+ resnet_eps,
+ resnet_act_fn,
+ transformer_layers_per_block=1,
+ num_attention_heads=None,
+ resnet_groups=None,
+ cross_attention_dim=None,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
+ resnet_time_scale_shift="default",
+ resnet_skip_time_act=False,
+ resnet_out_scale_factor=1.0,
+ cross_attention_norm=None,
+ attention_head_dim=None,
+ upsample_type=None,
+):
+ # If attn head dim is not defined, we default it to the number of heads
+ if attention_head_dim is None:
+        logger.warning(
+ f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
+ )
+ attention_head_dim = num_attention_heads
+
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
+ if up_block_type == "UpBlock2D":
+ return UpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif up_block_type == "ResnetUpsampleBlock2D":
+ return ResnetUpsampleBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ skip_time_act=resnet_skip_time_act,
+ output_scale_factor=resnet_out_scale_factor,
+ )
+ elif up_block_type == "CrossAttnUpBlock2D":
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
+ return CrossAttnUpBlock2D(
+ num_layers=num_layers,
+ transformer_layers_per_block=transformer_layers_per_block,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif up_block_type == "SimpleCrossAttnUpBlock2D":
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D")
+ return SimpleCrossAttnUpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ cross_attention_dim=cross_attention_dim,
+ attention_head_dim=attention_head_dim,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ skip_time_act=resnet_skip_time_act,
+ output_scale_factor=resnet_out_scale_factor,
+ only_cross_attention=only_cross_attention,
+ cross_attention_norm=cross_attention_norm,
+ )
+ elif up_block_type == "AttnUpBlock2D":
+ if add_upsample is False:
+ upsample_type = None
+ else:
+ upsample_type = upsample_type or "conv" # default to 'conv'
+
+ return AttnUpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ attention_head_dim=attention_head_dim,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ upsample_type=upsample_type,
+ )
+ elif up_block_type == "SkipUpBlock2D":
+ return SkipUpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif up_block_type == "AttnSkipUpBlock2D":
+ return AttnSkipUpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ attention_head_dim=attention_head_dim,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif up_block_type == "UpDecoderBlock2D":
+ return UpDecoderBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ temb_channels=temb_channels,
+ )
+ elif up_block_type == "AttnUpDecoderBlock2D":
+ return AttnUpDecoderBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ attention_head_dim=attention_head_dim,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ temb_channels=temb_channels,
+ )
+ elif up_block_type == "KUpBlock2D":
+ return KUpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ )
+ elif up_block_type == "KCrossAttnUpBlock2D":
+ return KCrossAttnUpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ cross_attention_dim=cross_attention_dim,
+ attention_head_dim=attention_head_dim,
+ )
+
+ raise ValueError(f"{up_block_type} does not exist.")
+
+
+class UNetMidBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default", # default, spatial
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ add_attention: bool = True,
+ attention_head_dim=1,
+ output_scale_factor=1.0,
+ ):
+ super().__init__()
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+ self.add_attention = add_attention
+
+ # there is always at least one resnet
+ resnets = [
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ ]
+ attentions = []
+
+ if attention_head_dim is None:
+            logger.warning(
+                f"It is not recommended to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
+ )
+ attention_head_dim = in_channels
+
+ for _ in range(num_layers):
+ if self.add_attention:
+ attentions.append(
+ Attention(
+ in_channels,
+ heads=in_channels // attention_head_dim,
+ dim_head=attention_head_dim,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ norm_num_groups=resnet_groups if resnet_time_scale_shift == "default" else None,
+ spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
+ residual_connection=True,
+ bias=True,
+ upcast_softmax=True,
+ _from_deprecated_attn_block=True,
+ )
+ )
+ else:
+ attentions.append(None)
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ def forward(self, hidden_states, temb=None):
+ hidden_states = self.resnets[0](hidden_states, temb)
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
+ if attn is not None:
+ hidden_states = attn(hidden_states, temb=temb)
+ hidden_states = resnet(hidden_states, temb)
+
+ return hidden_states
+
+
+class UNetMidBlock2DCrossAttn(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ transformer_layers_per_block: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ num_attention_heads=1,
+ output_scale_factor=1.0,
+ cross_attention_dim=1280,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ upcast_attention=False,
+ ):
+ super().__init__()
+
+ self.has_cross_attention = True
+ self.num_attention_heads = num_attention_heads
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+
+ # there is always at least one resnet
+ resnets = [
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ ]
+ attentions = []
+
+ for _ in range(num_layers):
+ if not dual_cross_attention:
+ attentions.append(
+ Transformer2DModel(
+ num_attention_heads,
+ in_channels // num_attention_heads,
+ in_channels=in_channels,
+ num_layers=transformer_layers_per_block,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+ )
+ )
+ else:
+ attentions.append(
+ DualTransformer2DModel(
+ num_attention_heads,
+ in_channels // num_attention_heads,
+ in_channels=in_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ )
+ )
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ ) -> torch.FloatTensor:
+ hidden_states = self.resnets[0](hidden_states, temb)
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+ hidden_states = resnet(hidden_states, temb)
+
+ return hidden_states
+
+
+class UNetMidBlock2DSimpleCrossAttn(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attention_head_dim=1,
+ output_scale_factor=1.0,
+ cross_attention_dim=1280,
+ skip_time_act=False,
+ only_cross_attention=False,
+ cross_attention_norm=None,
+ ):
+ super().__init__()
+
+ self.has_cross_attention = True
+
+ self.attention_head_dim = attention_head_dim
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+
+ self.num_heads = in_channels // self.attention_head_dim
+
+ # there is always at least one resnet
+ resnets = [
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ skip_time_act=skip_time_act,
+ )
+ ]
+ attentions = []
+
+ for _ in range(num_layers):
+ processor = (
+ AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
+ )
+
+ attentions.append(
+ Attention(
+ query_dim=in_channels,
+ cross_attention_dim=in_channels,
+ heads=self.num_heads,
+ dim_head=self.attention_head_dim,
+ added_kv_proj_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ bias=True,
+ upcast_softmax=True,
+ only_cross_attention=only_cross_attention,
+ cross_attention_norm=cross_attention_norm,
+ processor=processor,
+ )
+ )
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ skip_time_act=skip_time_act,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ ):
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+
+ if attention_mask is None:
+ # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
+ mask = None if encoder_hidden_states is None else encoder_attention_mask
+ else:
+ # when attention_mask is defined: we don't even check for encoder_attention_mask.
+ # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks.
+ # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask.
+ # then we can simplify this whole if/else block to:
+ # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
+ mask = attention_mask
+
+ hidden_states = self.resnets[0](hidden_states, temb)
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
+ # attn
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=mask,
+ **cross_attention_kwargs,
+ )
+
+ # resnet
+ hidden_states = resnet(hidden_states, temb)
+
+ return hidden_states
+
+
+class AttnDownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attention_head_dim=1,
+ output_scale_factor=1.0,
+ downsample_padding=1,
+ downsample_type="conv",
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+ self.downsample_type = downsample_type
+
+ if attention_head_dim is None:
+            logger.warning(
+                f"It is not recommended to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}."
+ )
+ attention_head_dim = out_channels
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ attentions.append(
+ Attention(
+ out_channels,
+ heads=out_channels // attention_head_dim,
+ dim_head=attention_head_dim,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ norm_num_groups=resnet_groups,
+ residual_connection=True,
+ bias=True,
+ upcast_softmax=True,
+ _from_deprecated_attn_block=True,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if downsample_type == "conv":
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ elif downsample_type == "resnet":
+ self.downsamplers = nn.ModuleList(
+ [
+ ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ down=True,
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ def forward(self, hidden_states, temb=None, upsample_size=None):
+ output_states = ()
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states)
+ output_states = output_states + (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ if self.downsample_type == "resnet":
+ hidden_states = downsampler(hidden_states, temb=temb)
+ else:
+ hidden_states = downsampler(hidden_states)
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class CrossAttnDownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ transformer_layers_per_block: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ num_attention_heads=1,
+ cross_attention_dim=1280,
+ output_scale_factor=1.0,
+ downsample_padding=1,
+ add_downsample=True,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.has_cross_attention = True
+ self.num_attention_heads = num_attention_heads
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ if not dual_cross_attention:
+ attentions.append(
+ Transformer2DModel(
+ num_attention_heads,
+ out_channels // num_attention_heads,
+ in_channels=out_channels,
+ num_layers=transformer_layers_per_block,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ )
+ )
+ else:
+ attentions.append(
+ DualTransformer2DModel(
+ num_attention_heads,
+ out_channels // num_attention_heads,
+ in_channels=out_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ )
+ )
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ additional_residuals=None,
+ ):
+ output_states = ()
+
+ blocks = list(zip(self.resnets, self.attentions))
+
+ for i, (resnet, attn) in enumerate(blocks):
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ **ckpt_kwargs,
+ )
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(attn, return_dict=False),
+ hidden_states,
+ encoder_hidden_states,
+ None, # timestep
+ None, # class_labels
+ cross_attention_kwargs,
+ attention_mask,
+ encoder_attention_mask,
+ **ckpt_kwargs,
+ )[0]
+ else:
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+
+ # apply additional residuals to the output of the last pair of resnet and attention blocks
+ if i == len(blocks) - 1 and additional_residuals is not None:
+ hidden_states = hidden_states + additional_residuals
+
+ output_states = output_states + (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states = output_states + (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class DownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor=1.0,
+ add_downsample=True,
+ downsample_padding=1,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states, temb=None):
+ output_states = ()
+
+ for resnet in self.resnets:
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ if is_torch_version(">=", "1.11.0"):
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+ )
+ else:
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb)
+
+ output_states = output_states + (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states = output_states + (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class DownEncoderBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor=1.0,
+ add_downsample=True,
+ downsample_padding=1,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=None,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ def forward(self, hidden_states):
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states, temb=None)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ return hidden_states
+
+
+class AttnDownEncoderBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attention_head_dim=1,
+ output_scale_factor=1.0,
+ add_downsample=True,
+ downsample_padding=1,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ if attention_head_dim is None:
+            logger.warning(
+                f"It is not recommended to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}."
+ )
+ attention_head_dim = out_channels
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=None,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ attentions.append(
+ Attention(
+ out_channels,
+ heads=out_channels // attention_head_dim,
+ dim_head=attention_head_dim,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ norm_num_groups=resnet_groups,
+ residual_connection=True,
+ bias=True,
+ upcast_softmax=True,
+ _from_deprecated_attn_block=True,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ def forward(self, hidden_states):
+ for resnet, attn in zip(self.resnets, self.attentions):
+ hidden_states = resnet(hidden_states, temb=None)
+ hidden_states = attn(hidden_states)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ return hidden_states
+
+
+class AttnSkipDownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_pre_norm: bool = True,
+ attention_head_dim=1,
+ output_scale_factor=np.sqrt(2.0),
+ add_downsample=True,
+ ):
+ super().__init__()
+ self.attentions = nn.ModuleList([])
+ self.resnets = nn.ModuleList([])
+
+ if attention_head_dim is None:
+            logger.warning(
+                f"It is not recommended to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}."
+ )
+ attention_head_dim = out_channels
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ self.resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(in_channels // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ self.attentions.append(
+ Attention(
+ out_channels,
+ heads=out_channels // attention_head_dim,
+ dim_head=attention_head_dim,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ norm_num_groups=32,
+ residual_connection=True,
+ bias=True,
+ upcast_softmax=True,
+ _from_deprecated_attn_block=True,
+ )
+ )
+
+ if add_downsample:
+ self.resnet_down = ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ use_in_shortcut=True,
+ down=True,
+ kernel="fir",
+ )
+ self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)])
+ self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
+ else:
+ self.resnet_down = None
+ self.downsamplers = None
+ self.skip_conv = None
+
+ def forward(self, hidden_states, temb=None, skip_sample=None):
+ output_states = ()
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states)
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ hidden_states = self.resnet_down(hidden_states, temb)
+ for downsampler in self.downsamplers:
+ skip_sample = downsampler(skip_sample)
+
+ hidden_states = self.skip_conv(skip_sample) + hidden_states
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states, skip_sample
+
+
+class SkipDownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_pre_norm: bool = True,
+ output_scale_factor=np.sqrt(2.0),
+ add_downsample=True,
+ downsample_padding=1,
+ ):
+ super().__init__()
+ self.resnets = nn.ModuleList([])
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ self.resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(in_channels // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ if add_downsample:
+ self.resnet_down = ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ use_in_shortcut=True,
+ down=True,
+ kernel="fir",
+ )
+ self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)])
+ self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
+ else:
+ self.resnet_down = None
+ self.downsamplers = None
+ self.skip_conv = None
+
+ def forward(self, hidden_states, temb=None, skip_sample=None):
+ output_states = ()
+
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states, temb)
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ hidden_states = self.resnet_down(hidden_states, temb)
+ for downsampler in self.downsamplers:
+ skip_sample = downsampler(skip_sample)
+
+ hidden_states = self.skip_conv(skip_sample) + hidden_states
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states, skip_sample
+
+
+class ResnetDownsampleBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor=1.0,
+ add_downsample=True,
+ skip_time_act=False,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ skip_time_act=skip_time_act,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ skip_time_act=skip_time_act,
+ down=True,
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states, temb=None):
+ output_states = ()
+
+ for resnet in self.resnets:
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ if is_torch_version(">=", "1.11.0"):
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+ )
+ else:
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb)
+
+ output_states = output_states + (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states, temb)
+
+ output_states = output_states + (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class SimpleCrossAttnDownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attention_head_dim=1,
+ cross_attention_dim=1280,
+ output_scale_factor=1.0,
+ add_downsample=True,
+ skip_time_act=False,
+ only_cross_attention=False,
+ cross_attention_norm=None,
+ ):
+ super().__init__()
+
+ self.has_cross_attention = True
+
+ resnets = []
+ attentions = []
+
+ self.attention_head_dim = attention_head_dim
+ self.num_heads = out_channels // self.attention_head_dim
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ skip_time_act=skip_time_act,
+ )
+ )
+
+ processor = (
+ AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
+ )
+
+ attentions.append(
+ Attention(
+ query_dim=out_channels,
+ cross_attention_dim=out_channels,
+ heads=self.num_heads,
+ dim_head=attention_head_dim,
+ added_kv_proj_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ bias=True,
+ upcast_softmax=True,
+ only_cross_attention=only_cross_attention,
+ cross_attention_norm=cross_attention_norm,
+ processor=processor,
+ )
+ )
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ skip_time_act=skip_time_act,
+ down=True,
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ ):
+ output_states = ()
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+
+ if attention_mask is None:
+ # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
+ mask = None if encoder_hidden_states is None else encoder_attention_mask
+ else:
+ # when attention_mask is defined: we don't even check for encoder_attention_mask.
+ # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks.
+ # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask.
+ # then we can simplify this whole if/else block to:
+ # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
+ mask = attention_mask
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(attn, return_dict=False),
+ hidden_states,
+ encoder_hidden_states,
+ mask,
+ cross_attention_kwargs,
+ )[0]
+ else:
+ hidden_states = resnet(hidden_states, temb)
+
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=mask,
+ **cross_attention_kwargs,
+ )
+
+ output_states = output_states + (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states, temb)
+
+ output_states = output_states + (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class KDownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 4,
+ resnet_eps: float = 1e-5,
+ resnet_act_fn: str = "gelu",
+ resnet_group_size: int = 32,
+ add_downsample=False,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ groups = in_channels // resnet_group_size
+ groups_out = out_channels // resnet_group_size
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ dropout=dropout,
+ temb_channels=temb_channels,
+ groups=groups,
+ groups_out=groups_out,
+ eps=resnet_eps,
+ non_linearity=resnet_act_fn,
+ time_embedding_norm="ada_group",
+ conv_shortcut_bias=False,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ # YiYi's comments- might be able to use FirDownsample2D, look into details later
+ self.downsamplers = nn.ModuleList([KDownsample2D()])
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states, temb=None):
+ output_states = ()
+
+ for resnet in self.resnets:
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ if is_torch_version(">=", "1.11.0"):
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+ )
+ else:
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb)
+
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ return hidden_states, output_states
+
+
+class KCrossAttnDownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ cross_attention_dim: int,
+ dropout: float = 0.0,
+ num_layers: int = 4,
+ resnet_group_size: int = 32,
+ add_downsample=True,
+ attention_head_dim: int = 64,
+ add_self_attention: bool = False,
+ resnet_eps: float = 1e-5,
+ resnet_act_fn: str = "gelu",
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.has_cross_attention = True
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ groups = in_channels // resnet_group_size
+ groups_out = out_channels // resnet_group_size
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ dropout=dropout,
+ temb_channels=temb_channels,
+ groups=groups,
+ groups_out=groups_out,
+ eps=resnet_eps,
+ non_linearity=resnet_act_fn,
+ time_embedding_norm="ada_group",
+ conv_shortcut_bias=False,
+ )
+ )
+ attentions.append(
+ KAttentionBlock(
+ out_channels,
+ out_channels // attention_head_dim,
+ attention_head_dim,
+ cross_attention_dim=cross_attention_dim,
+ temb_channels=temb_channels,
+ attention_bias=True,
+ add_self_attention=add_self_attention,
+ cross_attention_norm="layer_norm",
+ group_size=resnet_group_size,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+ self.attentions = nn.ModuleList(attentions)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList([KDownsample2D()])
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ ):
+ output_states = ()
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ **ckpt_kwargs,
+ )
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(attn, return_dict=False),
+ hidden_states,
+ encoder_hidden_states,
+ temb,
+ attention_mask,
+ cross_attention_kwargs,
+ encoder_attention_mask,
+ **ckpt_kwargs,
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ emb=temb,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+
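+            # blocks built without a downsampler append None placeholders, which tells the corresponding
+            # k-diffusion up block to skip the skip-connection concatenation for this stage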
+ if self.downsamplers is None:
+ output_states += (None,)
+ else:
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ return hidden_states, output_states
+
+
+class AttnUpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attention_head_dim=1,
+ output_scale_factor=1.0,
+ upsample_type="conv",
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.upsample_type = upsample_type
+
+ if attention_head_dim is None:
+            logger.warning(
+                f"It is not recommended to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}."
+ )
+ attention_head_dim = out_channels
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ attentions.append(
+ Attention(
+ out_channels,
+ heads=out_channels // attention_head_dim,
+ dim_head=attention_head_dim,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ norm_num_groups=resnet_groups,
+ residual_connection=True,
+ bias=True,
+ upcast_softmax=True,
+ _from_deprecated_attn_block=True,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if upsample_type == "conv":
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ elif upsample_type == "resnet":
+ self.upsamplers = nn.ModuleList(
+ [
+ ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ up=True,
+ )
+ ]
+ )
+ else:
+ self.upsamplers = None
+
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
+ for resnet, attn in zip(self.resnets, self.attentions):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ if self.upsample_type == "resnet":
+ hidden_states = upsampler(hidden_states, temb=temb)
+ else:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
+class CrossAttnUpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ prev_output_channel: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ transformer_layers_per_block: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ num_attention_heads=1,
+ cross_attention_dim=1280,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.has_cross_attention = True
+ self.num_attention_heads = num_attention_heads
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ if not dual_cross_attention:
+ attentions.append(
+ Transformer2DModel(
+ num_attention_heads,
+ out_channels // num_attention_heads,
+ in_channels=out_channels,
+ num_layers=transformer_layers_per_block,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ )
+ )
+ else:
+ attentions.append(
+ DualTransformer2DModel(
+ num_attention_heads,
+ out_channels // num_attention_heads,
+ in_channels=out_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ )
+ )
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ upsample_size: Optional[int] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ ):
+ for resnet, attn in zip(self.resnets, self.attentions):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ **ckpt_kwargs,
+ )
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(attn, return_dict=False),
+ hidden_states,
+ encoder_hidden_states,
+ None, # timestep
+ None, # class_labels
+ cross_attention_kwargs,
+ attention_mask,
+ encoder_attention_mask,
+ **ckpt_kwargs,
+ )[0]
+ else:
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
+
+
+class UpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
+ for resnet in self.resnets:
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ if is_torch_version(">=", "1.11.0"):
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+ )
+ else:
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
+
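+# Usage sketch for UpBlock2D (illustrative shapes only, not part of the library documentation):
+#     block = UpBlock2D(in_channels=320, prev_output_channel=640, out_channels=320, temb_channels=1280)
+#     x = torch.randn(1, 640, 32, 32)
+#     skips = (torch.randn(1, 320, 32, 32),)   # one skip tensor per resnet, consumed last-to-first
+#     temb = torch.randn(1, 1280)
+#     out = block(x, skips, temb)              # upsampled to (1, 320, 64, 64) when add_upsample=True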
+
+class UpDecoderBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default", # default, spatial
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ temb_channels=None,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ input_channels = in_channels if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=input_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ def forward(self, hidden_states, temb=None):
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states, temb=temb)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
+class AttnUpDecoderBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attention_head_dim=1,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ temb_channels=None,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ if attention_head_dim is None:
+            logger.warning(
+                f"It is not recommended to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}."
+ )
+ attention_head_dim = out_channels
+
+ for i in range(num_layers):
+ input_channels = in_channels if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=input_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ attentions.append(
+ Attention(
+ out_channels,
+ heads=out_channels // attention_head_dim,
+ dim_head=attention_head_dim,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ norm_num_groups=resnet_groups if resnet_time_scale_shift != "spatial" else None,
+ spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
+ residual_connection=True,
+ bias=True,
+ upcast_softmax=True,
+ _from_deprecated_attn_block=True,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ def forward(self, hidden_states, temb=None):
+ for resnet, attn in zip(self.resnets, self.attentions):
+ hidden_states = resnet(hidden_states, temb=temb)
+ hidden_states = attn(hidden_states, temb=temb)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
+class AttnSkipUpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_pre_norm: bool = True,
+ attention_head_dim=1,
+ output_scale_factor=np.sqrt(2.0),
+ add_upsample=True,
+ ):
+ super().__init__()
+ self.attentions = nn.ModuleList([])
+ self.resnets = nn.ModuleList([])
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ self.resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+                    groups=min((resnet_in_channels + res_skip_channels) // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ if attention_head_dim is None:
+            logger.warning(
+                f"It is not recommended to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}."
+ )
+ attention_head_dim = out_channels
+
+ self.attentions.append(
+ Attention(
+ out_channels,
+ heads=out_channels // attention_head_dim,
+ dim_head=attention_head_dim,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ norm_num_groups=32,
+ residual_connection=True,
+ bias=True,
+ upcast_softmax=True,
+ _from_deprecated_attn_block=True,
+ )
+ )
+
+ self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
+ if add_upsample:
+ self.resnet_up = ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(out_channels // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ use_in_shortcut=True,
+ up=True,
+ kernel="fir",
+ )
+ self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+ self.skip_norm = torch.nn.GroupNorm(
+ num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
+ )
+ self.act = nn.SiLU()
+ else:
+ self.resnet_up = None
+ self.skip_conv = None
+ self.skip_norm = None
+ self.act = None
+
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
+ for resnet in self.resnets:
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ hidden_states = resnet(hidden_states, temb)
+
+ hidden_states = self.attentions[0](hidden_states)
+
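+        # progressively build the image-space skip branch: FIR-upsample the incoming skip sample (if any) and add
+        # the projection of the current hidden states, in the style of score-SDE / NCSN++ UNets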
+ if skip_sample is not None:
+ skip_sample = self.upsampler(skip_sample)
+ else:
+ skip_sample = 0
+
+ if self.resnet_up is not None:
+ skip_sample_states = self.skip_norm(hidden_states)
+ skip_sample_states = self.act(skip_sample_states)
+ skip_sample_states = self.skip_conv(skip_sample_states)
+
+ skip_sample = skip_sample + skip_sample_states
+
+ hidden_states = self.resnet_up(hidden_states, temb)
+
+ return hidden_states, skip_sample
+
+
+class SkipUpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_pre_norm: bool = True,
+ output_scale_factor=np.sqrt(2.0),
+ add_upsample=True,
+ upsample_padding=1,
+ ):
+ super().__init__()
+ self.resnets = nn.ModuleList([])
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ self.resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min((resnet_in_channels + res_skip_channels) // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
+ if add_upsample:
+ self.resnet_up = ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(out_channels // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ use_in_shortcut=True,
+ up=True,
+ kernel="fir",
+ )
+ self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+ self.skip_norm = torch.nn.GroupNorm(
+ num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
+ )
+ self.act = nn.SiLU()
+ else:
+ self.resnet_up = None
+ self.skip_conv = None
+ self.skip_norm = None
+ self.act = None
+
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
+ for resnet in self.resnets:
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ hidden_states = resnet(hidden_states, temb)
+
+ if skip_sample is not None:
+ skip_sample = self.upsampler(skip_sample)
+ else:
+ skip_sample = 0
+
+ if self.resnet_up is not None:
+ skip_sample_states = self.skip_norm(hidden_states)
+ skip_sample_states = self.act(skip_sample_states)
+ skip_sample_states = self.skip_conv(skip_sample_states)
+
+ skip_sample = skip_sample + skip_sample_states
+
+ hidden_states = self.resnet_up(hidden_states, temb)
+
+ return hidden_states, skip_sample
+
+
+class ResnetUpsampleBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ skip_time_act=False,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ skip_time_act=skip_time_act,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList(
+ [
+ ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ skip_time_act=skip_time_act,
+ up=True,
+ )
+ ]
+ )
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
+ for resnet in self.resnets:
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ if is_torch_version(">=", "1.11.0"):
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+ )
+ else:
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, temb)
+
+ return hidden_states
+
+
+class SimpleCrossAttnUpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ prev_output_channel: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attention_head_dim=1,
+ cross_attention_dim=1280,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ skip_time_act=False,
+ only_cross_attention=False,
+ cross_attention_norm=None,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.has_cross_attention = True
+ self.attention_head_dim = attention_head_dim
+
+ self.num_heads = out_channels // self.attention_head_dim
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ skip_time_act=skip_time_act,
+ )
+ )
+
+ processor = (
+ AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
+ )
+
+ attentions.append(
+ Attention(
+ query_dim=out_channels,
+ cross_attention_dim=out_channels,
+ heads=self.num_heads,
+ dim_head=self.attention_head_dim,
+ added_kv_proj_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ bias=True,
+ upcast_softmax=True,
+ only_cross_attention=only_cross_attention,
+ cross_attention_norm=cross_attention_norm,
+ processor=processor,
+ )
+ )
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList(
+ [
+ ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ skip_time_act=skip_time_act,
+ up=True,
+ )
+ ]
+ )
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ upsample_size: Optional[int] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ ):
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+
+ if attention_mask is None:
+ # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
+ mask = None if encoder_hidden_states is None else encoder_attention_mask
+ else:
+ # when attention_mask is defined: we don't even check for encoder_attention_mask.
+ # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks.
+ # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask.
+ # then we can simplify this whole if/else block to:
+ # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
+ mask = attention_mask
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ # resnet
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(attn, return_dict=False),
+ hidden_states,
+ encoder_hidden_states,
+ mask,
+ cross_attention_kwargs,
+ )[0]
+ else:
+ hidden_states = resnet(hidden_states, temb)
+
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=mask,
+ **cross_attention_kwargs,
+ )
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, temb)
+
+ return hidden_states
+
+
+class KUpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 5,
+ resnet_eps: float = 1e-5,
+ resnet_act_fn: str = "gelu",
+ resnet_group_size: Optional[int] = 32,
+ add_upsample=True,
+ ):
+ super().__init__()
+ resnets = []
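+        # k-diffusion style bookkeeping: the single skip tensor is concatenated with the hidden states, so the
+        # first resnet receives 2 * out_channels and the last resnet maps back to the down path's channel count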
+ k_in_channels = 2 * out_channels
+ k_out_channels = in_channels
+ num_layers = num_layers - 1
+
+ for i in range(num_layers):
+ in_channels = k_in_channels if i == 0 else out_channels
+ groups = in_channels // resnet_group_size
+ groups_out = out_channels // resnet_group_size
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=k_out_channels if (i == num_layers - 1) else out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=groups,
+ groups_out=groups_out,
+ dropout=dropout,
+ non_linearity=resnet_act_fn,
+ time_embedding_norm="ada_group",
+ conv_shortcut_bias=False,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([KUpsample2D()])
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
+ res_hidden_states_tuple = res_hidden_states_tuple[-1]
+ if res_hidden_states_tuple is not None:
+ hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)
+
+ for resnet in self.resnets:
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ if is_torch_version(">=", "1.11.0"):
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+ )
+ else:
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
+class KCrossAttnUpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 4,
+ resnet_eps: float = 1e-5,
+ resnet_act_fn: str = "gelu",
+ resnet_group_size: int = 32,
+ attention_head_dim=1, # attention dim_head
+ cross_attention_dim: int = 768,
+ add_upsample: bool = True,
+ upcast_attention: bool = False,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ is_first_block = in_channels == out_channels == temb_channels
+ is_middle_block = in_channels != out_channels
+        add_self_attention = is_first_block
+
+ self.has_cross_attention = True
+ self.attention_head_dim = attention_head_dim
+
+ # in_channels, and out_channels for the block (k-unet)
+ k_in_channels = out_channels if is_first_block else 2 * out_channels
+ k_out_channels = in_channels
+
+ num_layers = num_layers - 1
+
+ for i in range(num_layers):
+ in_channels = k_in_channels if i == 0 else out_channels
+ groups = in_channels // resnet_group_size
+ groups_out = out_channels // resnet_group_size
+
+ if is_middle_block and (i == num_layers - 1):
+ conv_2d_out_channels = k_out_channels
+ else:
+ conv_2d_out_channels = None
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ conv_2d_out_channels=conv_2d_out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=groups,
+ groups_out=groups_out,
+ dropout=dropout,
+ non_linearity=resnet_act_fn,
+ time_embedding_norm="ada_group",
+ conv_shortcut_bias=False,
+ )
+ )
+ attentions.append(
+ KAttentionBlock(
+ k_out_channels if (i == num_layers - 1) else out_channels,
+ k_out_channels // attention_head_dim
+ if (i == num_layers - 1)
+ else out_channels // attention_head_dim,
+ attention_head_dim,
+ cross_attention_dim=cross_attention_dim,
+ temb_channels=temb_channels,
+ attention_bias=True,
+ add_self_attention=add_self_attention,
+ cross_attention_norm="layer_norm",
+ upcast_attention=upcast_attention,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+ self.attentions = nn.ModuleList(attentions)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([KUpsample2D()])
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ upsample_size: Optional[int] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ ):
+ res_hidden_states_tuple = res_hidden_states_tuple[-1]
+ if res_hidden_states_tuple is not None:
+ hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ **ckpt_kwargs,
+ )
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(attn, return_dict=False),
+ hidden_states,
+ encoder_hidden_states,
+ temb,
+ attention_mask,
+ cross_attention_kwargs,
+ encoder_attention_mask,
+ **ckpt_kwargs,
+ )[0]
+ else:
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ emb=temb,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
+# can potentially later be renamed to `No-feed-forward` attention
+class KAttentionBlock(nn.Module):
+ r"""
+ A basic Transformer block.
+
+ Parameters:
+ dim (`int`): The number of channels in the input and output.
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`): The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
+        attention_bias (`bool`, *optional*, defaults to `False`):
+            Configure if the attention layers should contain a bias parameter.
+        add_self_attention (`bool`, *optional*, defaults to `False`):
+            Whether to add a self-attention layer before the cross-attention layer.
+ """
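+    # Usage sketch (illustrative shapes and values; the 77-token context length is just an example):
+    #     block = KAttentionBlock(dim=256, num_attention_heads=4, attention_head_dim=64,
+    #                             cross_attention_dim=768, temb_channels=1280)
+    #     x = torch.randn(1, 256, 32, 32)      # 4D feature map input (unlike BasicTransformerBlock)
+    #     emb = torch.randn(1, 1280)           # required by the AdaGroupNorm layers
+    #     ctx = torch.randn(1, 77, 768)        # encoder hidden states for cross-attention
+    #     out = block(x, encoder_hidden_states=ctx, emb=emb)   # -> (1, 256, 32, 32)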
+
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ dropout: float = 0.0,
+ cross_attention_dim: Optional[int] = None,
+ attention_bias: bool = False,
+ upcast_attention: bool = False,
+ temb_channels: int = 768, # for ada_group_norm
+ add_self_attention: bool = False,
+ cross_attention_norm: Optional[str] = None,
+ group_size: int = 32,
+ ):
+ super().__init__()
+ self.add_self_attention = add_self_attention
+
+ # 1. Self-Attn
+ if add_self_attention:
+ self.norm1 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size))
+ self.attn1 = Attention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ cross_attention_dim=None,
+ cross_attention_norm=None,
+ )
+
+ # 2. Cross-Attn
+ self.norm2 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size))
+ self.attn2 = Attention(
+ query_dim=dim,
+ cross_attention_dim=cross_attention_dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ upcast_attention=upcast_attention,
+ cross_attention_norm=cross_attention_norm,
+ )
+
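+    # helpers that flatten a (batch, channels, height, width) feature map into a (batch, height * width, channels)
+    # token sequence for attention and restore the spatial layout afterwards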
+    def _to_3d(self, hidden_states, height, width):
+        return hidden_states.permute(0, 2, 3, 1).reshape(hidden_states.shape[0], height * width, -1)
+
+    def _to_4d(self, hidden_states, height, width):
+        return hidden_states.permute(0, 2, 1).reshape(hidden_states.shape[0], -1, height, width)
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ # TODO: mark emb as non-optional (self.norm2 requires it).
+ # requires assessing impact of change to positional param interface.
+ emb: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ ):
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+
+ # 1. Self-Attention
+ if self.add_self_attention:
+ norm_hidden_states = self.norm1(hidden_states, emb)
+
+            height, width = norm_hidden_states.shape[2:]
+            norm_hidden_states = self._to_3d(norm_hidden_states, height, width)
+
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+            attn_output = self._to_4d(attn_output, height, width)
+
+ hidden_states = attn_output + hidden_states
+
+ # 2. Cross-Attention/None
+ norm_hidden_states = self.norm2(hidden_states, emb)
+
+        height, width = norm_hidden_states.shape[2:]
+        norm_hidden_states = self._to_3d(norm_hidden_states, height, width)
+ attn_output = self.attn2(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask if encoder_hidden_states is None else encoder_attention_mask,
+ **cross_attention_kwargs,
+ )
+        attn_output = self._to_4d(attn_output, height, width)
+
+ hidden_states = attn_output + hidden_states
+
+ return hidden_states
diff --git a/diffusers/models/unet_2d_blocks_flax.py b/diffusers/models/unet_2d_blocks_flax.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d1447570dda34b814bdc1660dfd37874fed0125
--- /dev/null
+++ b/diffusers/models/unet_2d_blocks_flax.py
@@ -0,0 +1,377 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import flax.linen as nn
+import jax.numpy as jnp
+
+from .attention_flax import FlaxTransformer2DModel
+from .resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D
+
+
+class FlaxCrossAttnDownBlock2D(nn.Module):
+ r"""
+ Cross Attention 2D Downsizing block - original architecture from Unet transformers:
+ https://arxiv.org/abs/2103.06104
+
+ Parameters:
+ in_channels (:obj:`int`):
+ Input channels
+ out_channels (:obj:`int`):
+ Output channels
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
+ Dropout rate
+ num_layers (:obj:`int`, *optional*, defaults to 1):
+            Number of attention block layers
+        num_attention_heads (:obj:`int`, *optional*, defaults to 1):
+            Number of attention heads of each spatial transformer block
+        add_downsample (:obj:`bool`, *optional*, defaults to `True`):
+            Whether to add a downsampling layer at the end of the block
+ use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
+ enable memory efficient attention https://arxiv.org/abs/2112.05682
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+ Parameters `dtype`
+ """
+ in_channels: int
+ out_channels: int
+ dropout: float = 0.0
+ num_layers: int = 1
+ num_attention_heads: int = 1
+ add_downsample: bool = True
+ use_linear_projection: bool = False
+ only_cross_attention: bool = False
+ use_memory_efficient_attention: bool = False
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ resnets = []
+ attentions = []
+
+ for i in range(self.num_layers):
+ in_channels = self.in_channels if i == 0 else self.out_channels
+
+ res_block = FlaxResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=self.out_channels,
+ dropout_prob=self.dropout,
+ dtype=self.dtype,
+ )
+ resnets.append(res_block)
+
+ attn_block = FlaxTransformer2DModel(
+ in_channels=self.out_channels,
+ n_heads=self.num_attention_heads,
+ d_head=self.out_channels // self.num_attention_heads,
+ depth=1,
+ use_linear_projection=self.use_linear_projection,
+ only_cross_attention=self.only_cross_attention,
+ use_memory_efficient_attention=self.use_memory_efficient_attention,
+ dtype=self.dtype,
+ )
+ attentions.append(attn_block)
+
+ self.resnets = resnets
+ self.attentions = attentions
+
+ if self.add_downsample:
+ self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
+
+ def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True):
+ output_states = ()
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
+ hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
+ output_states += (hidden_states,)
+
+ if self.add_downsample:
+ hidden_states = self.downsamplers_0(hidden_states)
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class FlaxDownBlock2D(nn.Module):
+ r"""
+ Flax 2D downsizing block
+
+ Parameters:
+ in_channels (:obj:`int`):
+ Input channels
+ out_channels (:obj:`int`):
+ Output channels
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
+ Dropout rate
+ num_layers (:obj:`int`, *optional*, defaults to 1):
+            Number of ResNet layers
+        add_downsample (:obj:`bool`, *optional*, defaults to `True`):
+            Whether to add a downsampling layer at the end of the block
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+ Parameters `dtype`
+ """
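+    # Usage sketch (illustrative only; the shapes and the 1280-dim time embedding are assumptions):
+    #     import jax, jax.numpy as jnp
+    #     block = FlaxDownBlock2D(in_channels=32, out_channels=64, num_layers=2)
+    #     sample = jnp.zeros((1, 64, 64, 32))                 # NHWC layout, as used by the Flax UNet
+    #     temb = jnp.zeros((1, 1280))
+    #     variables = block.init(jax.random.PRNGKey(0), sample, temb)
+    #     hidden, skips = block.apply(variables, sample, temb)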
+ in_channels: int
+ out_channels: int
+ dropout: float = 0.0
+ num_layers: int = 1
+ add_downsample: bool = True
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ resnets = []
+
+ for i in range(self.num_layers):
+ in_channels = self.in_channels if i == 0 else self.out_channels
+
+ res_block = FlaxResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=self.out_channels,
+ dropout_prob=self.dropout,
+ dtype=self.dtype,
+ )
+ resnets.append(res_block)
+ self.resnets = resnets
+
+ if self.add_downsample:
+ self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
+
+ def __call__(self, hidden_states, temb, deterministic=True):
+ output_states = ()
+
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
+ output_states += (hidden_states,)
+
+ if self.add_downsample:
+ hidden_states = self.downsamplers_0(hidden_states)
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class FlaxCrossAttnUpBlock2D(nn.Module):
+ r"""
+ Cross Attention 2D Upsampling block - original architecture from Unet transformers:
+ https://arxiv.org/abs/2103.06104
+
+ Parameters:
+ in_channels (:obj:`int`):
+ Input channels
+ out_channels (:obj:`int`):
+ Output channels
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
+ Dropout rate
+ num_layers (:obj:`int`, *optional*, defaults to 1):
+ Number of attention blocks layers
+            Number of attention block layers
+        num_attention_heads (:obj:`int`, *optional*, defaults to 1):
+            Number of attention heads of each spatial transformer block
+        add_upsample (:obj:`bool`, *optional*, defaults to `True`):
+            Whether to add an upsampling layer at the end of the block
+ enable memory efficient attention https://arxiv.org/abs/2112.05682
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+ Parameters `dtype`
+ """
+ in_channels: int
+ out_channels: int
+ prev_output_channel: int
+ dropout: float = 0.0
+ num_layers: int = 1
+ num_attention_heads: int = 1
+ add_upsample: bool = True
+ use_linear_projection: bool = False
+ only_cross_attention: bool = False
+ use_memory_efficient_attention: bool = False
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ resnets = []
+ attentions = []
+
+ for i in range(self.num_layers):
+ res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels
+ resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels
+
+ res_block = FlaxResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=self.out_channels,
+ dropout_prob=self.dropout,
+ dtype=self.dtype,
+ )
+ resnets.append(res_block)
+
+ attn_block = FlaxTransformer2DModel(
+ in_channels=self.out_channels,
+ n_heads=self.num_attention_heads,
+ d_head=self.out_channels // self.num_attention_heads,
+ depth=1,
+ use_linear_projection=self.use_linear_projection,
+ only_cross_attention=self.only_cross_attention,
+ use_memory_efficient_attention=self.use_memory_efficient_attention,
+ dtype=self.dtype,
+ )
+ attentions.append(attn_block)
+
+ self.resnets = resnets
+ self.attentions = attentions
+
+ if self.add_upsample:
+ self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
+
+ def __call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states, deterministic=True):
+ for resnet, attn in zip(self.resnets, self.attentions):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
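+            # Flax blocks use channels-last (NHWC) tensors, so skip connections are concatenated on the last axis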
+ hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1)
+
+ hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
+ hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
+
+ if self.add_upsample:
+ hidden_states = self.upsamplers_0(hidden_states)
+
+ return hidden_states
+
+
+class FlaxUpBlock2D(nn.Module):
+ r"""
+ Flax 2D upsampling block
+
+ Parameters:
+ in_channels (:obj:`int`):
+ Input channels
+ out_channels (:obj:`int`):
+ Output channels
+ prev_output_channel (:obj:`int`):
+ Output channels from the previous block
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
+ Dropout rate
+ num_layers (:obj:`int`, *optional*, defaults to 1):
+            Number of ResNet layers
+        add_upsample (:obj:`bool`, *optional*, defaults to `True`):
+            Whether to add an upsampling layer at the end of the block
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+ Parameters `dtype`
+ """
+ in_channels: int
+ out_channels: int
+ prev_output_channel: int
+ dropout: float = 0.0
+ num_layers: int = 1
+ add_upsample: bool = True
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ resnets = []
+
+ for i in range(self.num_layers):
+ res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels
+ resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels
+
+ res_block = FlaxResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=self.out_channels,
+ dropout_prob=self.dropout,
+ dtype=self.dtype,
+ )
+ resnets.append(res_block)
+
+ self.resnets = resnets
+
+ if self.add_upsample:
+ self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
+
+ def __call__(self, hidden_states, res_hidden_states_tuple, temb, deterministic=True):
+ for resnet in self.resnets:
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1)
+
+ hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
+
+ if self.add_upsample:
+ hidden_states = self.upsamplers_0(hidden_states)
+
+ return hidden_states
+
+
+class FlaxUNetMidBlock2DCrossAttn(nn.Module):
+ r"""
+ Cross Attention 2D Mid-level block - original architecture from Unet transformers: https://arxiv.org/abs/2103.06104
+
+ Parameters:
+ in_channels (:obj:`int`):
+ Input channels
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
+ Dropout rate
+ num_layers (:obj:`int`, *optional*, defaults to 1):
+            Number of attention block layers
+ num_attention_heads (:obj:`int`, *optional*, defaults to 1):
+ Number of attention heads of each spatial transformer block
+ use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
+ enable memory efficient attention https://arxiv.org/abs/2112.05682
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+ Parameters `dtype`
+ """
+ in_channels: int
+ dropout: float = 0.0
+ num_layers: int = 1
+ num_attention_heads: int = 1
+ use_linear_projection: bool = False
+ use_memory_efficient_attention: bool = False
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ # there is always at least one resnet
+ resnets = [
+ FlaxResnetBlock2D(
+ in_channels=self.in_channels,
+ out_channels=self.in_channels,
+ dropout_prob=self.dropout,
+ dtype=self.dtype,
+ )
+ ]
+
+ attentions = []
+
+ for _ in range(self.num_layers):
+ attn_block = FlaxTransformer2DModel(
+ in_channels=self.in_channels,
+ n_heads=self.num_attention_heads,
+ d_head=self.in_channels // self.num_attention_heads,
+ depth=1,
+ use_linear_projection=self.use_linear_projection,
+ use_memory_efficient_attention=self.use_memory_efficient_attention,
+ dtype=self.dtype,
+ )
+ attentions.append(attn_block)
+
+ res_block = FlaxResnetBlock2D(
+ in_channels=self.in_channels,
+ out_channels=self.in_channels,
+ dropout_prob=self.dropout,
+ dtype=self.dtype,
+ )
+ resnets.append(res_block)
+
+ self.resnets = resnets
+ self.attentions = attentions
+
+ def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True):
+ hidden_states = self.resnets[0](hidden_states, temb)
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
+ hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
+ hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
+
+ return hidden_states
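+
+
+# Call-order note: the mid block applies resnets[0] first and then alternates attention and
+# resnet blocks, i.e. resnet -> (attn -> resnet) * num_layers. A minimal, hypothetical
+# instantiation (shapes are illustrative; tensors are channels-last as elsewhere in this file):
+#
+# mid = FlaxUNetMidBlock2DCrossAttn(in_channels=1280, num_attention_heads=8)
+# variables = mid.init(
+# jax.random.PRNGKey(0),
+# jnp.zeros((1, 8, 8, 1280)), # hidden_states
+# jnp.zeros((1, 1280)), # temb
+# jnp.zeros((1, 77, 1280)), # encoder_hidden_states
+# )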
diff --git a/diffusers/models/unet_2d_condition.py b/diffusers/models/unet_2d_condition.py
new file mode 100644
index 0000000000000000000000000000000000000000..c76552fb3cd3445ebc93b54679380bbc1cc8b09a
--- /dev/null
+++ b/diffusers/models/unet_2d_condition.py
@@ -0,0 +1,1101 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..loaders import UNet2DConditionLoadersMixin
+from ..utils import BaseOutput, logging
+from .activations import get_activation
+from .attention_processor import AttentionProcessor, AttnProcessor
+from .embeddings import (
+ GaussianFourierProjection,
+ ImageHintTimeEmbedding,
+ ImageProjection,
+ ImageTimeEmbedding,
+ TextImageProjection,
+ TextImageTimeEmbedding,
+ TextTimeEmbedding,
+ TimestepEmbedding,
+ Timesteps,
+)
+from .modeling_utils import ModelMixin
+from .unet_2d_blocks import (
+ CrossAttnDownBlock2D,
+ CrossAttnUpBlock2D,
+ DownBlock2D,
+ UNetMidBlock2DCrossAttn,
+ UNetMidBlock2DSimpleCrossAttn,
+ UpBlock2D,
+ get_down_block,
+ get_up_block,
+)
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+import safetensors.torch  # the torch submodule is not imported by a bare `import safetensors`
+from safetensors import safe_open
+
+@dataclass
+class UNet2DConditionOutput(BaseOutput):
+ """
+ The output of [`UNet2DConditionModel`].
+
+ Args:
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
+ """
+
+ sample: torch.FloatTensor = None
+
+
+class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
+ r"""
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
+ shaped output.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
+ for all models (such as downloading or saving).
+
+ Parameters:
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
+ Height and width of input/output sample.
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
+ Whether to flip the sin to cos in the time embedding.
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
+ The tuple of downsample blocks to use.
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
+ Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
+ The tuple of upsample blocks to use.
+ only_cross_attention (`bool` or `Tuple[bool]`, *optional*, defaults to `False`):
+ Whether to include self-attention in the basic transformer blocks, see
+ [`~models.attention.BasicTransformerBlock`].
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
+ The tuple of output channels for each block.
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
+ If `None`, normalization and activation layers are skipped in post-processing.
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
+ The dimension of the cross attention features.
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
+ encoder_hid_dim (`int`, *optional*, defaults to None):
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
+ dimension to `cross_attention_dim`.
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
+ embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
+ num_attention_heads (`int`, *optional*):
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
+ class_embed_type (`str`, *optional*, defaults to `None`):
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
+ addition_embed_type (`str`, *optional*, defaults to `None`):
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
+ "text". "text" will use the `TextTimeEmbedding` layer.
+ addition_time_embed_dim (`int`, *optional*, defaults to 256):
+ Dimension for the additional timestep embeddings used by the `"text_time"` and `"time"` addition embedding types.
+ num_class_embeds (`int`, *optional*, defaults to `None`):
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
+ class conditioning with `class_embed_type` equal to `None`.
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
+ An optional override for the dimension of the projected time embedding.
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
+ timestep_post_act (`str`, *optional*, defaults to `None`):
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
+ The dimension of `cond_proj` layer in the timestep embedding.
+ conv_in_kernel (`int`, *optional*, defaults to `3`): The kernel size of `conv_in` layer.
+ conv_out_kernel (`int`, *optional*, defaults to `3`): The kernel size of `conv_out` layer.
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
+ embeddings with the class embeddings.
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
+ otherwise.
+ """
+
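+ # Construction sketch (illustrative, not a prescribed configuration): the defaults below mirror
+ # the Stable Diffusion 1.x UNet layout except that SD 1.x uses `cross_attention_dim=768`, e.g.
+ #
+ # unet = UNet2DConditionModel(sample_size=64, cross_attention_dim=768)
+ # pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+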
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: Optional[int] = None,
+ in_channels: int = 4,
+ out_channels: int = 4,
+ center_input_sample: bool = False,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ down_block_types: Tuple[str] = (
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "DownBlock2D",
+ ),
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ layers_per_block: Union[int, Tuple[int]] = 2,
+ downsample_padding: int = 1,
+ mid_block_scale_factor: float = 1,
+ act_fn: str = "silu",
+ norm_num_groups: Optional[int] = 32,
+ norm_eps: float = 1e-5,
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+ encoder_hid_dim: Optional[int] = None,
+ encoder_hid_dim_type: Optional[str] = None,
+ attention_head_dim: Union[int, Tuple[int]] = 8,
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = False,
+ class_embed_type: Optional[str] = None,
+ addition_embed_type: Optional[str] = None,
+ addition_time_embed_dim: Optional[int] = 256,
+ num_class_embeds: Optional[int] = None,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+ resnet_skip_time_act: bool = False,
+ resnet_out_scale_factor: int = 1.0,
+ time_embedding_type: str = "positional",
+ time_embedding_dim: Optional[int] = None,
+ time_embedding_act_fn: Optional[str] = None,
+ timestep_post_act: Optional[str] = None,
+ time_cond_proj_dim: Optional[int] = None,
+ conv_in_kernel: int = 3,
+ conv_out_kernel: int = 3,
+ projection_class_embeddings_input_dim: Optional[int] = None,
+ class_embeddings_concat: bool = False,
+ mid_block_only_cross_attention: Optional[bool] = None,
+ cross_attention_norm: Optional[str] = None,
+ addition_embed_type_num_heads=64,
+ size_cond=False,
+ ):
+ super().__init__()
+
+ self.sample_size = sample_size
+
+ if num_attention_heads is not None:
+ raise ValueError(
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
+ )
+
+ # If `num_attention_heads` is not defined (which is the case for most models)
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
+ # which is why we correct for the naming here.
+ num_attention_heads = num_attention_heads or attention_head_dim
+
+ # Check inputs
+ if len(down_block_types) != len(up_block_types):
+ raise ValueError(
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
+ )
+
+ if len(block_out_channels) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
+ )
+
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
+ )
+
+ # input
+ conv_in_padding = (conv_in_kernel - 1) // 2
+ self.conv_in = nn.Conv2d(
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
+ )
+
+ # time
+ if time_embedding_type == "fourier":
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
+ if time_embed_dim % 2 != 0:
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
+ self.time_proj = GaussianFourierProjection(
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
+ )
+ timestep_input_dim = time_embed_dim
+ elif time_embedding_type == "positional":
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
+
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+ timestep_input_dim = block_out_channels[0]
+ else:
+ raise ValueError(
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
+ )
+
+ self.time_embedding = TimestepEmbedding(
+ timestep_input_dim,
+ time_embed_dim,
+ act_fn=act_fn,
+ post_act_fn=timestep_post_act,
+ cond_proj_dim=time_cond_proj_dim,
+ )
+
+ # self.size_cond = size_cond
+ # if self.size_cond:
+ # size_embed_dim = block_out_channels[0] * 2
+ # if size_embed_dim % 2 != 0:
+ # raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {size_embed_dim}.")
+ # self.size_h_proj = GaussianFourierProjection(
+ # size_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
+ # )
+ # self.size_w_proj = GaussianFourierProjection(
+ # size_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
+ # )
+ # size_input_dim = size_embed_dim
+ # self.H_embedding = TimestepEmbedding(
+ # size_input_dim,
+ # size_embed_dim,
+ # act_fn=act_fn,
+ # post_act_fn=timestep_post_act,
+ # cond_proj_dim=time_cond_proj_dim,
+ # )
+ # self.W_embedding = TimestepEmbedding(
+ # size_input_dim,
+ # size_embed_dim,
+ # act_fn=act_fn,
+ # post_act_fn=timestep_post_act,
+ # cond_proj_dim=time_cond_proj_dim,
+ # )
+
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
+ encoder_hid_dim_type = "text_proj"
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
+
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
+ raise ValueError(
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
+ )
+
+ if encoder_hid_dim_type == "text_proj":
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
+ elif encoder_hid_dim_type == "text_image_proj":
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
+ # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)
+ self.encoder_hid_proj = TextImageProjection(
+ text_embed_dim=encoder_hid_dim,
+ image_embed_dim=cross_attention_dim,
+ cross_attention_dim=cross_attention_dim,
+ )
+ elif encoder_hid_dim_type == "image_proj":
+ # Kandinsky 2.2
+ self.encoder_hid_proj = ImageProjection(
+ image_embed_dim=encoder_hid_dim,
+ cross_attention_dim=cross_attention_dim,
+ )
+ elif encoder_hid_dim_type is not None:
+ raise ValueError(
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj' or 'image_proj'."
+ )
+ else:
+ self.encoder_hid_proj = None
+
+ # class embedding
+ if class_embed_type is None and num_class_embeds is not None:
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+ elif class_embed_type == "timestep":
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
+ elif class_embed_type == "identity":
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
+ elif class_embed_type == "projection":
+ if projection_class_embeddings_input_dim is None:
+ raise ValueError(
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
+ )
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
+ # 2. it projects from an arbitrary input dimension.
+ #
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+ elif class_embed_type == "simple_projection":
+ if projection_class_embeddings_input_dim is None:
+ raise ValueError(
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
+ )
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
+ else:
+ self.class_embedding = None
+
+ if addition_embed_type == "text":
+ if encoder_hid_dim is not None:
+ text_time_embedding_from_dim = encoder_hid_dim
+ else:
+ text_time_embedding_from_dim = cross_attention_dim
+
+ self.add_embedding = TextTimeEmbedding(
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
+ )
+ elif addition_embed_type == "text_image":
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
+ self.add_embedding = TextImageTimeEmbedding(
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
+ )
+ elif addition_embed_type == "text_time":
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
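+ # (SDXL-style conditioning: `projection_class_embeddings_input_dim` is expected to equal the
+ # pooled text embedding size plus `addition_time_embed_dim * len(time_ids)`)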
+ elif addition_embed_type == "time":
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
+ self.add_embedding = TimestepEmbedding(addition_time_embed_dim * 8, time_embed_dim, act_fn=act_fn, post_act_fn=timestep_post_act, cond_proj_dim=time_cond_proj_dim)
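+ # the `* 8` assumes eight `time_ids` entries per sample, matching the flatten/reshape done in
+ # `forward` for `addition_embed_type == "time"`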
+ elif addition_embed_type == "image":
+ # Kandinsky 2.2
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
+ elif addition_embed_type == "image_hint":
+ # Kandinsky 2.2 ControlNet
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
+ elif addition_embed_type is not None:
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text', 'text_image', 'text_time', 'time', 'image' or 'image_hint'.")
+
+ if time_embedding_act_fn is None:
+ self.time_embed_act = None
+ else:
+ self.time_embed_act = get_activation(time_embedding_act_fn)
+
+ self.down_blocks = nn.ModuleList([])
+ self.up_blocks = nn.ModuleList([])
+
+ if isinstance(only_cross_attention, bool):
+ if mid_block_only_cross_attention is None:
+ mid_block_only_cross_attention = only_cross_attention
+
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
+
+ if mid_block_only_cross_attention is None:
+ mid_block_only_cross_attention = False
+
+ if isinstance(num_attention_heads, int):
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
+
+ if isinstance(attention_head_dim, int):
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
+
+ if isinstance(cross_attention_dim, int):
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
+
+ if isinstance(layers_per_block, int):
+ layers_per_block = [layers_per_block] * len(down_block_types)
+
+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+
+ if class_embeddings_concat:
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
+ # regular time embeddings
+ blocks_time_embed_dim = time_embed_dim * 2
+ else:
+ blocks_time_embed_dim = time_embed_dim
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block[i],
+ transformer_layers_per_block=transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=blocks_time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim[i],
+ num_attention_heads=num_attention_heads[i],
+ downsample_padding=downsample_padding,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ resnet_skip_time_act=resnet_skip_time_act,
+ resnet_out_scale_factor=resnet_out_scale_factor,
+ cross_attention_norm=cross_attention_norm,
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
+ self.mid_block = UNetMidBlock2DCrossAttn(
+ transformer_layers_per_block=transformer_layers_per_block[-1],
+ in_channels=block_out_channels[-1],
+ temb_channels=blocks_time_embed_dim,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ cross_attention_dim=cross_attention_dim[-1],
+ num_attention_heads=num_attention_heads[-1],
+ resnet_groups=norm_num_groups,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+ )
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
+ self.mid_block = UNetMidBlock2DSimpleCrossAttn(
+ in_channels=block_out_channels[-1],
+ temb_channels=blocks_time_embed_dim,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ cross_attention_dim=cross_attention_dim[-1],
+ attention_head_dim=attention_head_dim[-1],
+ resnet_groups=norm_num_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ skip_time_act=resnet_skip_time_act,
+ only_cross_attention=mid_block_only_cross_attention,
+ cross_attention_norm=cross_attention_norm,
+ )
+ elif mid_block_type is None:
+ self.mid_block = None
+ else:
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
+
+ # count how many layers upsample the images
+ self.num_upsamplers = 0
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
+ reversed_layers_per_block = list(reversed(layers_per_block))
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
+ only_cross_attention = list(reversed(only_cross_attention))
+
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ is_final_block = i == len(block_out_channels) - 1
+
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+ # add upsample block for all BUT final layer
+ if not is_final_block:
+ add_upsample = True
+ self.num_upsamplers += 1
+ else:
+ add_upsample = False
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=reversed_layers_per_block[i] + 1,
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ prev_output_channel=prev_output_channel,
+ temb_channels=blocks_time_embed_dim,
+ add_upsample=add_upsample,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=reversed_cross_attention_dim[i],
+ num_attention_heads=reversed_num_attention_heads[i],
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ resnet_skip_time_act=resnet_skip_time_act,
+ resnet_out_scale_factor=resnet_out_scale_factor,
+ cross_attention_norm=cross_attention_norm,
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # out
+ if norm_num_groups is not None:
+ self.conv_norm_out = nn.GroupNorm(
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
+ )
+
+ self.conv_act = get_activation(act_fn)
+
+ else:
+ self.conv_norm_out = None
+ self.conv_act = None
+
+ conv_out_padding = (conv_out_kernel - 1) // 2
+ self.conv_out = nn.Conv2d(
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
+ )
+
+ @property
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
+ indexed by their weight names.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "set_processor"):
+ processors[f"{name}.processor"] = module.processor
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
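+ # A processor dict is keyed by module path, mirroring `attn_processors`, e.g. (illustrative):
+ # {"mid_block.attentions.0.transformer_blocks.0.attn1.processor": AttnProcessor(), ...}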
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ def set_default_attn_processor(self):
+ """
+ Disables custom attention processors and sets the default attention implementation.
+ """
+ self.set_attn_processor(AttnProcessor())
+
+ def set_attention_slice(self, slice_size):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
+
+ Args:
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+ must be a multiple of `slice_size`.
+ """
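+ # Usage sketch (illustrative): `unet.set_attention_slice("auto")` halves every sliceable head
+ # dimension, while a list such as `[2, 2, 4, ...]` must supply one slice size per sliceable layer.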
+ sliceable_head_dims = []
+
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
+ if hasattr(module, "set_attention_slice"):
+ sliceable_head_dims.append(module.sliceable_head_dim)
+
+ for child in module.children():
+ fn_recursive_retrieve_sliceable_dims(child)
+
+ # retrieve number of attention layers
+ for module in self.children():
+ fn_recursive_retrieve_sliceable_dims(module)
+
+ num_sliceable_layers = len(sliceable_head_dims)
+
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
+ elif slice_size == "max":
+ # make smallest slice possible
+ slice_size = num_sliceable_layers * [1]
+
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+
+ if len(slice_size) != len(sliceable_head_dims):
+ raise ValueError(
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
+ )
+
+ for i in range(len(slice_size)):
+ size = slice_size[i]
+ dim = sliceable_head_dims[i]
+ if size is not None and size > dim:
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
+
+ # Recursively walk through all the children.
+ # Any children which exposes the set_attention_slice method
+ # gets the message
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+ if hasattr(module, "set_attention_slice"):
+ module.set_attention_slice(slice_size.pop())
+
+ for child in module.children():
+ fn_recursive_set_attention_slice(child, slice_size)
+
+ reversed_slice_size = list(reversed(slice_size))
+ for module in self.children():
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
+ module.gradient_checkpointing = value
+
+ @classmethod
+ def from_pretrained_2d(cls, pretrained_model_path, weight=None, subfolder=None, size_cond=False, fp16=False):
+ import os
+ import json
+ if subfolder is not None:
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
+ config_file = os.path.join(pretrained_model_path, 'config.json')
+ if not os.path.isfile(config_file):
+ raise RuntimeError(f"{config_file} does not exist")
+ with open(config_file, "r") as f:
+ config = json.load(f)
+
+ if "stable-diffusion-xl-base-1.0" in pretrained_model_path or "Realistic_Vision_V5.1_noVAE" in pretrained_model_path:
+ from diffusers.utils import SAFETENSORS_WEIGHTS_NAME # diffusion_pytorch_model.safetensors
+ model_postfix = SAFETENSORS_WEIGHTS_NAME
+ if fp16:
+ model_postfix = "diffusion_pytorch_model.fp16.safetensors"
+ else:
+ from diffusers.utils import WEIGHTS_NAME # diffusion_pytorch_model.bin
+ model_postfix = WEIGHTS_NAME
+
+ config["size_cond"] = size_cond
+ if size_cond:
+ # config.addition_embed_type = "time"
+ config["addition_embed_type"] = "time"
+
+ model = cls.from_config(config)
+
+ if weight is None:
+ print("Loading from pretrained SD")
+ model_file = os.path.join(pretrained_model_path, model_postfix)
+ if not os.path.isfile(model_file):
+ raise RuntimeError(f"{model_file} does not exist")
+ if "stable-diffusion-xl-base-1.0" in pretrained_model_path or "Realistic_Vision_V5.1_noVAE" in pretrained_model_path:
+ # state_dict = {}
+ # with safe_open(model_file, framework="pt", device=0) as f:
+ # for k in f.keys():
+ # state_dict[k] = f.get_tensor(k)
+ state_dict = safetensors.torch.load_file(model_file, device="cpu")
+ else:
+ state_dict = torch.load(model_file, map_location="cpu")
+
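+ # keys present in this model but missing from the checkpoint (e.g. newly added embedding
+ # layers) keep their freshly initialized values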
+ for k, v in model.state_dict().items():
+ # print(k)
+ if k not in state_dict.keys():
+ state_dict.update({k: v})
+ model.load_state_dict(state_dict)
+
+ # conv_in_weights = model.conv_in.weight.clone()
+ # model.conv_in = nn.Conv2d(4 + 4*n_poses, conv_in_weights.shape[0], kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) # input noise + n poses
+ # # model.conv_in = InflatedConv3d(4 + 4*n_poses, conv_in_weights.shape[0], kernel_size=3, padding=(1, 1))
+ # # nn.Conv2d(4 + 4*n_poses, weights.shape[0], kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) # input noise + n poses
+ # with torch.no_grad():
+ # model.conv_in.weight[:, :4] = conv_in_weights # original weights
+ # model.conv_in.weight[:, 4:] = torch.zeros(model.conv_in.weight[:, 4:].shape) # new weights initialized to zero
+
+ return model
+
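+ # Loading sketch (paths are illustrative; the directory must contain `config.json` and the
+ # diffusers weight file):
+ #
+ # unet = UNet2DConditionModel.from_pretrained_2d("./checkpoints/stable-diffusion-v1-5", subfolder="unet")
+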
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ # H=None,
+ # W=None,
+ class_labels: Optional[torch.Tensor] = None,
+ timestep_cond: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ ) -> Union[UNet2DConditionOutput, Tuple]:
+ r"""
+ The [`UNet2DConditionModel`] forward method.
+
+ Args:
+ sample (`torch.FloatTensor`):
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
+ timestep (`torch.FloatTensor` or `float` or `int`): The timestep at which to denoise the input.
+ encoder_hidden_states (`torch.FloatTensor`):
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
+ encoder_attention_mask (`torch.Tensor`):
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
+ tuple.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
+ added_cond_kwargs: (`dict`, *optional*):
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
+ are passed along to the UNet blocks.
+
+ Returns:
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
+ If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
+ a `tuple` is returned where the first element is the sample tensor.
+ """
+ # By default samples have to be at least a multiple of the overall upsampling factor.
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
+ # on the fly if necessary.
+ default_overall_up_factor = 2**self.num_upsamplers
+
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+ forward_upsample_size = False
+ upsample_size = None
+
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
+ logger.info("Forward upsample size to force interpolation output size.")
+ forward_upsample_size = True
+
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
+ # expects mask of shape:
+ # [batch, key_tokens]
+ # adds singleton query_tokens dimension:
+ # [batch, 1, key_tokens]
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+ if attention_mask is not None:
+ # assume that mask is expressed as:
+ # (1 = keep, 0 = discard)
+ # convert mask into a bias that can be added to attention scores:
+ # (keep = +0, discard = -10000.0)
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
+ if encoder_attention_mask is not None:
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+ # 0. center input if necessary
+ if self.config.center_input_sample:
+ sample = 2 * sample - 1.0
+
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps.expand(sample.shape[0])
+
+ t_emb = self.time_proj(timesteps)
+
+ # `Timesteps` does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=sample.dtype)
+
+ emb = self.time_embedding(t_emb, timestep_cond)
+
+ # if self.size_cond:
+ # h_emb = self.H_embedding(self.size_h_proj(H))
+ # w_emb = self.W_embedding(self.size_w_proj(W))
+ # emb = emb + torch.cat((h_emb, w_emb), dim=1)
+
+ aug_emb = None
+
+ if self.class_embedding is not None:
+ if class_labels is None:
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
+
+ if self.config.class_embed_type == "timestep":
+ class_labels = self.time_proj(class_labels)
+
+ # `Timesteps` does not contain any weights and will always return f32 tensors
+ # there might be better ways to encapsulate this.
+ class_labels = class_labels.to(dtype=sample.dtype)
+
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
+
+ if self.config.class_embeddings_concat:
+ emb = torch.cat([emb, class_emb], dim=-1)
+ else:
+ emb = emb + class_emb
+
+ if self.config.addition_embed_type == "text":
+ aug_emb = self.add_embedding(encoder_hidden_states)
+ elif self.config.addition_embed_type == "text_image":
+ # Kandinsky 2.1 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+ )
+
+ image_embs = added_cond_kwargs.get("image_embeds")
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
+ aug_emb = self.add_embedding(text_embs, image_embs)
+ elif self.config.addition_embed_type == "text_time":
+ if "text_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
+ )
+ text_embeds = added_cond_kwargs.get("text_embeds")
+ if "time_ids" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
+ )
+ time_ids = added_cond_kwargs.get("time_ids")
+ time_embeds = self.add_time_proj(time_ids.flatten())
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
+
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
+ add_embeds = add_embeds.to(emb.dtype)
+ aug_emb = self.add_embedding(add_embeds)
+ elif self.config.addition_embed_type == "time":
+ if added_cond_kwargs is None or "time_ids" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
+ )
+ time_ids = added_cond_kwargs.get("time_ids")
+ batch_size = time_ids.shape[0]
+ # time_embeds = self.add_time_proj(time_ids)
+ time_embeds = self.add_time_proj(time_ids.flatten())
+ time_embeds = time_embeds.reshape((batch_size, -1))
+
+ add_embeds = time_embeds
+ add_embeds = add_embeds.to(emb.dtype)
+ aug_emb = self.add_embedding(add_embeds)
+ elif self.config.addition_embed_type == "image":
+ # Kandinsky 2.2 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+ )
+ image_embs = added_cond_kwargs.get("image_embeds")
+ aug_emb = self.add_embedding(image_embs)
+ elif self.config.addition_embed_type == "image_hint":
+ # Kandinsky 2.2 - style
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
+ )
+ image_embs = added_cond_kwargs.get("image_embeds")
+ hint = added_cond_kwargs.get("hint")
+ aug_emb, hint = self.add_embedding(image_embs, hint)
+ sample = torch.cat([sample, hint], dim=1)
+
+ emb = emb + aug_emb if aug_emb is not None else emb
+
+ if self.time_embed_act is not None:
+ emb = self.time_embed_act(emb)
+
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
+ # Kandinsky 2.1 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+ )
+
+ image_embeds = added_cond_kwargs.get("image_embeds")
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
+ # Kandinsky 2.2 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+ )
+ image_embeds = added_cond_kwargs.get("image_embeds")
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
+ # 2. pre-process
+ sample = self.conv_in(sample)
+
+ # 3. down
+
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
+ is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
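+ # ControlNet supplies both the mid-block residual and the per-down-block residuals, whereas a
+ # T2I-Adapter supplies only the down-block residuals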
+
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ # For t2i-adapter CrossAttnDownBlock2D
+ additional_residuals = {}
+ if is_adapter and len(down_block_additional_residuals) > 0:
+ additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)
+
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ **additional_residuals,
+ )
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+
+ if is_adapter and len(down_block_additional_residuals) > 0:
+ sample += down_block_additional_residuals.pop(0)
+
+ down_block_res_samples += res_samples
+
+ if is_controlnet:
+ new_down_block_res_samples = ()
+
+ for down_block_res_sample, down_block_additional_residual in zip(
+ down_block_res_samples, down_block_additional_residuals
+ ):
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
+
+ down_block_res_samples = new_down_block_res_samples
+
+ # 4. mid
+ if self.mid_block is not None:
+ sample = self.mid_block(
+ sample,
+ emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+
+ if is_controlnet:
+ sample = sample + mid_block_additional_residual
+
+ # 5. up
+ for i, upsample_block in enumerate(self.up_blocks):
+ is_final_block = i == len(self.up_blocks) - 1
+
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+ # if we have not reached the final block and need to forward the
+ # upsample size, we do it here
+ if not is_final_block and forward_upsample_size:
+ upsample_size = down_block_res_samples[-1].shape[2:]
+
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ upsample_size=upsample_size,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+ else:
+ sample = upsample_block(
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
+ )
+
+ # 6. post-process
+ if self.conv_norm_out:
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ if not return_dict:
+ return (sample,)
+
+ return UNet2DConditionOutput(sample=sample)
diff --git a/diffusers/models/unet_2d_condition_flax.py b/diffusers/models/unet_2d_condition_flax.py
new file mode 100644
index 0000000000000000000000000000000000000000..de39bc75d2e392a423c9ea09e979b9f42d818dc1
--- /dev/null
+++ b/diffusers/models/unet_2d_condition_flax.py
@@ -0,0 +1,357 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional, Tuple, Union
+
+import flax
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+from flax.core.frozen_dict import FrozenDict
+
+from ..configuration_utils import ConfigMixin, flax_register_to_config
+from ..utils import BaseOutput
+from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
+from .modeling_flax_utils import FlaxModelMixin
+from .unet_2d_blocks_flax import (
+ FlaxCrossAttnDownBlock2D,
+ FlaxCrossAttnUpBlock2D,
+ FlaxDownBlock2D,
+ FlaxUNetMidBlock2DCrossAttn,
+ FlaxUpBlock2D,
+)
+
+
+@flax.struct.dataclass
+class FlaxUNet2DConditionOutput(BaseOutput):
+ """
+ The output of [`FlaxUNet2DConditionModel`].
+
+ Args:
+ sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
+ """
+
+ sample: jnp.ndarray
+
+
+@flax_register_to_config
+class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
+ r"""
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
+ shaped output.
+
+ This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for its generic methods
+ implemented for all models (such as downloading or saving).
+
+ This model is also a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
+ subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to its
+ general usage and behavior.
+
+ Inherent JAX features such as the following are supported:
+ - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
+ - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
+ - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
+ - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
+
+ Parameters:
+ sample_size (`int`, *optional*):
+ The size of the input sample.
+ in_channels (`int`, *optional*, defaults to 4):
+ The number of channels in the input sample.
+ out_channels (`int`, *optional*, defaults to 4):
+ The number of channels in the output.
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
+ The tuple of downsample blocks to use.
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
+ The tuple of upsample blocks to use.
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
+ The tuple of output channels for each block.
+ layers_per_block (`int`, *optional*, defaults to 2):
+ The number of layers per block.
+ attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8):
+ The dimension of the attention heads.
+ num_attention_heads (`int` or `Tuple[int]`, *optional*):
+ The number of attention heads.
+ cross_attention_dim (`int`, *optional*, defaults to 1280):
+ The dimension of the cross attention features.
+ dropout (`float`, *optional*, defaults to 0):
+ Dropout probability for down, up and bottleneck blocks.
+ flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
+ Whether to flip the sin to cos in the time embedding.
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
+ use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
+ Enable memory efficient attention as described [here](https://arxiv.org/abs/2112.05682).
+ """
+
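+ # Usage sketch (illustrative; `sample` is channels-first, matching `init_weights`):
+ #
+ # unet = FlaxUNet2DConditionModel(sample_size=64)
+ # params = unet.init_weights(jax.random.PRNGKey(0))
+ # out = unet.apply({"params": params}, sample, timesteps, encoder_hidden_states).sample
+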
+ sample_size: int = 32
+ in_channels: int = 4
+ out_channels: int = 4
+ down_block_types: Tuple[str] = (
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "DownBlock2D",
+ )
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
+ only_cross_attention: Union[bool, Tuple[bool]] = False
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
+ layers_per_block: int = 2
+ attention_head_dim: Union[int, Tuple[int]] = 8
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None
+ cross_attention_dim: int = 1280
+ dropout: float = 0.0
+ use_linear_projection: bool = False
+ dtype: jnp.dtype = jnp.float32
+ flip_sin_to_cos: bool = True
+ freq_shift: int = 0
+ use_memory_efficient_attention: bool = False
+
+ def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
+ # init input tensors
+ sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
+ sample = jnp.zeros(sample_shape, dtype=jnp.float32)
+ timesteps = jnp.ones((1,), dtype=jnp.int32)
+ encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32)
+
+ params_rng, dropout_rng = jax.random.split(rng)
+ rngs = {"params": params_rng, "dropout": dropout_rng}
+
+ return self.init(rngs, sample, timesteps, encoder_hidden_states)["params"]
+
+ def setup(self):
+ block_out_channels = self.block_out_channels
+ time_embed_dim = block_out_channels[0] * 4
+
+ if self.num_attention_heads is not None:
+ raise ValueError(
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
+ )
+
+ # If `num_attention_heads` is not defined (which is the case for most models)
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
+ # which is why we correct for the naming here.
+ num_attention_heads = self.num_attention_heads or self.attention_head_dim
+
+ # input
+ self.conv_in = nn.Conv(
+ block_out_channels[0],
+ kernel_size=(3, 3),
+ strides=(1, 1),
+ padding=((1, 1), (1, 1)),
+ dtype=self.dtype,
+ )
+
+ # time
+ self.time_proj = FlaxTimesteps(
+ block_out_channels[0], flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.config.freq_shift
+ )
+ self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)
+
+ only_cross_attention = self.only_cross_attention
+ if isinstance(only_cross_attention, bool):
+ only_cross_attention = (only_cross_attention,) * len(self.down_block_types)
+
+ if isinstance(num_attention_heads, int):
+ num_attention_heads = (num_attention_heads,) * len(self.down_block_types)
+
+ # down
+ down_blocks = []
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(self.down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ if down_block_type == "CrossAttnDownBlock2D":
+ down_block = FlaxCrossAttnDownBlock2D(
+ in_channels=input_channel,
+ out_channels=output_channel,
+ dropout=self.dropout,
+ num_layers=self.layers_per_block,
+ num_attention_heads=num_attention_heads[i],
+ add_downsample=not is_final_block,
+ use_linear_projection=self.use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ use_memory_efficient_attention=self.use_memory_efficient_attention,
+ dtype=self.dtype,
+ )
+ else:
+ down_block = FlaxDownBlock2D(
+ in_channels=input_channel,
+ out_channels=output_channel,
+ dropout=self.dropout,
+ num_layers=self.layers_per_block,
+ add_downsample=not is_final_block,
+ dtype=self.dtype,
+ )
+
+ down_blocks.append(down_block)
+ self.down_blocks = down_blocks
+
+ # mid
+ self.mid_block = FlaxUNetMidBlock2DCrossAttn(
+ in_channels=block_out_channels[-1],
+ dropout=self.dropout,
+ num_attention_heads=num_attention_heads[-1],
+ use_linear_projection=self.use_linear_projection,
+ use_memory_efficient_attention=self.use_memory_efficient_attention,
+ dtype=self.dtype,
+ )
+
+ # up
+ up_blocks = []
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
+ only_cross_attention = list(reversed(only_cross_attention))
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(self.up_block_types):
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+ is_final_block = i == len(block_out_channels) - 1
+
+ if up_block_type == "CrossAttnUpBlock2D":
+ up_block = FlaxCrossAttnUpBlock2D(
+ in_channels=input_channel,
+ out_channels=output_channel,
+ prev_output_channel=prev_output_channel,
+ num_layers=self.layers_per_block + 1,
+ num_attention_heads=reversed_num_attention_heads[i],
+ add_upsample=not is_final_block,
+ dropout=self.dropout,
+ use_linear_projection=self.use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ use_memory_efficient_attention=self.use_memory_efficient_attention,
+ dtype=self.dtype,
+ )
+ else:
+ up_block = FlaxUpBlock2D(
+ in_channels=input_channel,
+ out_channels=output_channel,
+ prev_output_channel=prev_output_channel,
+ num_layers=self.layers_per_block + 1,
+ add_upsample=not is_final_block,
+ dropout=self.dropout,
+ dtype=self.dtype,
+ )
+
+ up_blocks.append(up_block)
+ prev_output_channel = output_channel
+ self.up_blocks = up_blocks
+
+ # out
+ self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-5)
+ self.conv_out = nn.Conv(
+ self.out_channels,
+ kernel_size=(3, 3),
+ strides=(1, 1),
+ padding=((1, 1), (1, 1)),
+ dtype=self.dtype,
+ )
+
+ def __call__(
+ self,
+ sample,
+ timesteps,
+ encoder_hidden_states,
+ down_block_additional_residuals=None,
+ mid_block_additional_residual=None,
+ return_dict: bool = True,
+ train: bool = False,
+ ) -> Union[FlaxUNet2DConditionOutput, Tuple]:
+ r"""
+ Args:
+ sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor
+ timesteps (`jnp.ndarray` or `float` or `int`): timesteps
+ encoder_hidden_states (`jnp.ndarray`): (batch_size, sequence_length, hidden_size) encoder hidden states
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
+ plain tuple.
+ train (`bool`, *optional*, defaults to `False`):
+ Use deterministic functions and disable dropout when not training.
+
+ Returns:
+ [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
+ [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is the sample tensor.
+ """
+ # 1. time
+ if not isinstance(timesteps, jnp.ndarray):
+ timesteps = jnp.array([timesteps], dtype=jnp.int32)
+ elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0:
+ timesteps = timesteps.astype(dtype=jnp.float32)
+ timesteps = jnp.expand_dims(timesteps, 0)
+
+ t_emb = self.time_proj(timesteps)
+ t_emb = self.time_embedding(t_emb)
+
+ # 2. pre-process
+ sample = jnp.transpose(sample, (0, 2, 3, 1))
+ sample = self.conv_in(sample)
+
+ # 3. down
+ down_block_res_samples = (sample,)
+ for down_block in self.down_blocks:
+ if isinstance(down_block, FlaxCrossAttnDownBlock2D):
+ sample, res_samples = down_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
+ else:
+ sample, res_samples = down_block(sample, t_emb, deterministic=not train)
+ down_block_res_samples += res_samples
+
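+ # ControlNet-style conditioning: when residuals are provided, add them to the stored skip
+ # connections here (and, below, to the mid-block output) before the up path consumes them.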
+ if down_block_additional_residuals is not None:
+ new_down_block_res_samples = ()
+
+ for down_block_res_sample, down_block_additional_residual in zip(
+ down_block_res_samples, down_block_additional_residuals
+ ):
+ down_block_res_sample += down_block_additional_residual
+ new_down_block_res_samples += (down_block_res_sample,)
+
+ down_block_res_samples = new_down_block_res_samples
+
+ # 4. mid
+ sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
+
+ if mid_block_additional_residual is not None:
+ sample += mid_block_additional_residual
+
+ # 5. up
+ for up_block in self.up_blocks:
+ res_samples = down_block_res_samples[-(self.layers_per_block + 1) :]
+ down_block_res_samples = down_block_res_samples[: -(self.layers_per_block + 1)]
+ if isinstance(up_block, FlaxCrossAttnUpBlock2D):
+ sample = up_block(
+ sample,
+ temb=t_emb,
+ encoder_hidden_states=encoder_hidden_states,
+ res_hidden_states_tuple=res_samples,
+ deterministic=not train,
+ )
+ else:
+ sample = up_block(sample, temb=t_emb, res_hidden_states_tuple=res_samples, deterministic=not train)
+
+ # 6. post-process
+ sample = self.conv_norm_out(sample)
+ sample = nn.silu(sample)
+ sample = self.conv_out(sample)
+ sample = jnp.transpose(sample, (0, 3, 1, 2))
+
+ if not return_dict:
+ return (sample,)
+
+ return FlaxUNet2DConditionOutput(sample=sample)
diff --git a/diffusers/models/unet_2d_condition_multi_branch.py b/diffusers/models/unet_2d_condition_multi_branch.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4784f0cc1d15181257d9ff2e454f69a2d5cc8a5
--- /dev/null
+++ b/diffusers/models/unet_2d_condition_multi_branch.py
@@ -0,0 +1,1205 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..loaders import UNet2DConditionLoadersMixin
+from ..utils import BaseOutput, logging
+from .activations import get_activation
+from .attention_processor import AttentionProcessor, AttnProcessor
+from .embeddings import (
+ GaussianFourierProjection,
+ ImageHintTimeEmbedding,
+ ImageProjection,
+ ImageTimeEmbedding,
+ TextImageProjection,
+ TextImageTimeEmbedding,
+ TextTimeEmbedding,
+ TimestepEmbedding,
+ Timesteps,
+)
+from .modeling_utils import ModelMixin
+from .unet_2d_blocks import (
+ CrossAttnDownBlock2D,
+ CrossAttnUpBlock2D,
+ DownBlock2D,
+ UNetMidBlock2DCrossAttn,
+ UNetMidBlock2DSimpleCrossAttn,
+ UpBlock2D,
+ get_down_block,
+ get_up_block,
+)
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+from safetensors import safe_open
+import safetensors
+import copy
+
+@dataclass
+class UNet2DConditionOutput(BaseOutput):
+ """
+ The output of [`UNet2DConditionModel`].
+
+ Args:
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
+ """
+
+ sample: torch.FloatTensor = None
+
+
+class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
+ r"""
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
+ shaped output.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
+ for all models (such as downloading or saving).
+
+ Parameters:
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
+ Height and width of input/output sample.
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
+ Whether to flip the sin to cos in the time embedding.
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
+ The tuple of downsample blocks to use.
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
+ Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
+ The tuple of upsample blocks to use.
+ only_cross_attention (`bool` or `Tuple[bool]`, *optional*, defaults to `False`):
+ Whether to include self-attention in the basic transformer blocks, see
+ [`~models.attention.BasicTransformerBlock`].
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
+ The tuple of output channels for each block.
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
+ If `None`, normalization and activation layers are skipped in post-processing.
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
+ The dimension of the cross attention features.
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
+ encoder_hid_dim (`int`, *optional*, defaults to None):
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
+ dimension to `cross_attention_dim`.
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
+ embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
+ num_attention_heads (`int`, *optional*):
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
+ class_embed_type (`str`, *optional*, defaults to `None`):
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
+ addition_embed_type (`str`, *optional*, defaults to `None`):
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
+ "text". "text" will use the `TextTimeEmbedding` layer.
+ addition_time_embed_dim (`int`, *optional*, defaults to 256):
+ Dimension for the timestep embeddings.
+ num_class_embeds (`int`, *optional*, defaults to `None`):
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
+ class conditioning with `class_embed_type` equal to `None`.
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
+ An optional override for the dimension of the projected time embedding.
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
+ timestep_post_act (`str`, *optional*, defaults to `None`):
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
+ The dimension of `cond_proj` layer in the timestep embedding.
+ conv_in_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_in` layer.
+ conv_out_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_out` layer.
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
+ embeddings with the class embeddings.
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
+ otherwise.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: Optional[int] = None,
+ in_channels: int = 4,
+ out_channels: int = 4,
+ center_input_sample: bool = False,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ down_block_types: Tuple[str] = (
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "DownBlock2D",
+ ),
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ layers_per_block: Union[int, Tuple[int]] = 2,
+ downsample_padding: int = 1,
+ mid_block_scale_factor: float = 1,
+ act_fn: str = "silu",
+ norm_num_groups: Optional[int] = 32,
+ norm_eps: float = 1e-5,
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+ encoder_hid_dim: Optional[int] = None,
+ encoder_hid_dim_type: Optional[str] = None,
+ attention_head_dim: Union[int, Tuple[int]] = 8,
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = False,
+ class_embed_type: Optional[str] = None,
+ addition_embed_type: Optional[str] = None,
+ addition_time_embed_dim: Optional[int] = 256,
+ num_class_embeds: Optional[int] = None,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+ resnet_skip_time_act: bool = False,
+ resnet_out_scale_factor: float = 1.0,
+ time_embedding_type: str = "positional",
+ time_embedding_dim: Optional[int] = None,
+ time_embedding_act_fn: Optional[str] = None,
+ timestep_post_act: Optional[str] = None,
+ time_cond_proj_dim: Optional[int] = None,
+ conv_in_kernel: int = 3,
+ conv_out_kernel: int = 3,
+ projection_class_embeddings_input_dim: Optional[int] = None,
+ class_embeddings_concat: bool = False,
+ mid_block_only_cross_attention: Optional[bool] = None,
+ cross_attention_norm: Optional[str] = None,
+ addition_embed_type_num_heads=64,
+ size_cond=False,
+ cond_num=5,
+ copy_last_n_block=1,
+ ):
+ super().__init__()
+
+ self.sample_size = sample_size
+
+ if num_attention_heads is not None:
+ raise ValueError(
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
+ )
+
+ # If `num_attention_heads` is not defined (which is the case for most models)
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
+ # which is why we correct for the naming here.
+ num_attention_heads = num_attention_heads or attention_head_dim
+
+ # Check inputs
+ if len(down_block_types) != len(up_block_types):
+ raise ValueError(
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
+ )
+
+ if len(block_out_channels) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
+ )
+
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
+ )
+
+ # input
+ conv_in_padding = (conv_in_kernel - 1) // 2
+ self.conv_in = nn.Conv2d(
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
+ )
+
+ # time
+ if time_embedding_type == "fourier":
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
+ if time_embed_dim % 2 != 0:
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
+ self.time_proj = GaussianFourierProjection(
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
+ )
+ timestep_input_dim = time_embed_dim
+ elif time_embedding_type == "positional":
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
+
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+ timestep_input_dim = block_out_channels[0]
+ else:
+ raise ValueError(
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
+ )
+
+ self.time_embedding = TimestepEmbedding(
+ timestep_input_dim,
+ time_embed_dim,
+ act_fn=act_fn,
+ post_act_fn=timestep_post_act,
+ cond_proj_dim=time_cond_proj_dim,
+ )
+
+ # self.size_cond = size_cond
+ # if self.size_cond:
+ # size_embed_dim = block_out_channels[0] * 2
+ # if size_embed_dim % 2 != 0:
+ # raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {size_embed_dim}.")
+ # self.size_h_proj = GaussianFourierProjection(
+ # size_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
+ # )
+ # self.size_w_proj = GaussianFourierProjection(
+ # size_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
+ # )
+ # size_input_dim = size_embed_dim
+ # self.H_embedding = TimestepEmbedding(
+ # size_input_dim,
+ # size_embed_dim,
+ # act_fn=act_fn,
+ # post_act_fn=timestep_post_act,
+ # cond_proj_dim=time_cond_proj_dim,
+ # )
+ # self.W_embedding = TimestepEmbedding(
+ # size_input_dim,
+ # size_embed_dim,
+ # act_fn=act_fn,
+ # post_act_fn=timestep_post_act,
+ # cond_proj_dim=time_cond_proj_dim,
+ # )
+
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
+ encoder_hid_dim_type = "text_proj"
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
+
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
+ raise ValueError(
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
+ )
+
+ if encoder_hid_dim_type == "text_proj":
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
+ elif encoder_hid_dim_type == "text_image_proj":
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
+ # case when `encoder_hid_dim_type == "text_image_proj"` (Kandinsky 2.1)
+ self.encoder_hid_proj = TextImageProjection(
+ text_embed_dim=encoder_hid_dim,
+ image_embed_dim=cross_attention_dim,
+ cross_attention_dim=cross_attention_dim,
+ )
+ elif encoder_hid_dim_type == "image_proj":
+ # Kandinsky 2.2
+ self.encoder_hid_proj = ImageProjection(
+ image_embed_dim=encoder_hid_dim,
+ cross_attention_dim=cross_attention_dim,
+ )
+ elif encoder_hid_dim_type is not None:
+ raise ValueError(
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
+ )
+ else:
+ self.encoder_hid_proj = None
+
+ # class embedding
+ if class_embed_type is None and num_class_embeds is not None:
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+ elif class_embed_type == "timestep":
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
+ elif class_embed_type == "identity":
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
+ elif class_embed_type == "projection":
+ if projection_class_embeddings_input_dim is None:
+ raise ValueError(
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
+ )
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
+ # 2. it projects from an arbitrary input dimension.
+ #
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+ elif class_embed_type == "simple_projection":
+ if projection_class_embeddings_input_dim is None:
+ raise ValueError(
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
+ )
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
+ else:
+ self.class_embedding = None
+
+ if addition_embed_type == "text":
+ if encoder_hid_dim is not None:
+ text_time_embedding_from_dim = encoder_hid_dim
+ else:
+ text_time_embedding_from_dim = cross_attention_dim
+
+ self.add_embedding = TextTimeEmbedding(
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
+ )
+ elif addition_embed_type == "text_image":
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
+ self.add_embedding = TextImageTimeEmbedding(
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
+ )
+ elif addition_embed_type == "text_time":
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+ elif addition_embed_type == "time":
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
+ self.add_embedding = TimestepEmbedding(addition_time_embed_dim * 6, time_embed_dim, act_fn=act_fn, post_act_fn=timestep_post_act, cond_proj_dim=time_cond_proj_dim)
+ elif addition_embed_type == "image":
+ # Kandinsky 2.2
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
+ elif addition_embed_type == "image_hint":
+ # Kandinsky 2.2 ControlNet
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
+ elif addition_embed_type is not None:
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
+
+ if time_embedding_act_fn is None:
+ self.time_embed_act = None
+ else:
+ self.time_embed_act = get_activation(time_embedding_act_fn)
+
+ self.down_blocks = nn.ModuleList([])
+ self.up_blocks = nn.ModuleList([])
+
+ if isinstance(only_cross_attention, bool):
+ if mid_block_only_cross_attention is None:
+ mid_block_only_cross_attention = only_cross_attention
+
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
+
+ if mid_block_only_cross_attention is None:
+ mid_block_only_cross_attention = False
+
+ if isinstance(num_attention_heads, int):
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
+
+ if isinstance(attention_head_dim, int):
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
+
+ if isinstance(cross_attention_dim, int):
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
+
+ if isinstance(layers_per_block, int):
+ layers_per_block = [layers_per_block] * len(down_block_types)
+
+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+
+ if class_embeddings_concat:
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
+ # regular time embeddings
+ blocks_time_embed_dim = time_embed_dim * 2
+ else:
+ blocks_time_embed_dim = time_embed_dim
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block[i],
+ transformer_layers_per_block=transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=blocks_time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim[i],
+ num_attention_heads=num_attention_heads[i],
+ downsample_padding=downsample_padding,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ resnet_skip_time_act=resnet_skip_time_act,
+ resnet_out_scale_factor=resnet_out_scale_factor,
+ cross_attention_norm=cross_attention_norm,
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
+ self.mid_block = UNetMidBlock2DCrossAttn(
+ transformer_layers_per_block=transformer_layers_per_block[-1],
+ in_channels=block_out_channels[-1],
+ temb_channels=blocks_time_embed_dim,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ cross_attention_dim=cross_attention_dim[-1],
+ num_attention_heads=num_attention_heads[-1],
+ resnet_groups=norm_num_groups,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+ )
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
+ self.mid_block = UNetMidBlock2DSimpleCrossAttn(
+ in_channels=block_out_channels[-1],
+ temb_channels=blocks_time_embed_dim,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ cross_attention_dim=cross_attention_dim[-1],
+ attention_head_dim=attention_head_dim[-1],
+ resnet_groups=norm_num_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ skip_time_act=resnet_skip_time_act,
+ only_cross_attention=mid_block_only_cross_attention,
+ cross_attention_norm=cross_attention_norm,
+ )
+ elif mid_block_type is None:
+ self.mid_block = None
+ else:
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
+
+ # count how many layers upsample the images
+ self.num_upsamplers = 0
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
+ reversed_layers_per_block = list(reversed(layers_per_block))
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
+ only_cross_attention = list(reversed(only_cross_attention))
+
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ is_final_block = i == len(block_out_channels) - 1
+
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+ # add upsample block for all BUT final layer
+ if not is_final_block:
+ add_upsample = True
+ self.num_upsamplers += 1
+ else:
+ add_upsample = False
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=reversed_layers_per_block[i] + 1,
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ prev_output_channel=prev_output_channel,
+ temb_channels=blocks_time_embed_dim,
+ add_upsample=add_upsample,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=reversed_cross_attention_dim[i],
+ num_attention_heads=reversed_num_attention_heads[i],
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ resnet_skip_time_act=resnet_skip_time_act,
+ resnet_out_scale_factor=resnet_out_scale_factor,
+ cross_attention_norm=cross_attention_norm,
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
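+ # Multi-branch decoder heads: each of the `cond_num` condition branches gets its own deep copy
+ # of the last `copy_last_n_block` up blocks (and, below, of the output norm/act/conv layers),
+ # so every branch can produce a separate prediction from the shared trunk features.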
+ self.cond_num = cond_num
+ self.copy_last_n_block = copy_last_n_block
+ self.norm_num_groups = norm_num_groups
+ self.up_blocks_branch = nn.ModuleList([])
+ for i in range(self.cond_num):
+ copy_block_list = nn.ModuleList([])
+ for j in range(self.copy_last_n_block, 0, -1):
+ copy_block_list.append(copy.deepcopy(self.up_blocks[-j]))
+ self.up_blocks_branch.append(copy_block_list)
+
+ # out
+ if norm_num_groups is not None:
+ self.conv_norm_out = nn.GroupNorm(
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
+ )
+
+ self.conv_act = get_activation(act_fn)
+
+ self.conv_norm_out_branch = nn.ModuleList([])
+ self.conv_act_branch = nn.ModuleList([])
+ for i in range(self.cond_num):
+ self.conv_norm_out_branch.append(copy.deepcopy(self.conv_norm_out))
+ self.conv_act_branch.append(copy.deepcopy(self.conv_act))
+
+ else:
+ self.conv_norm_out = None
+ self.conv_act = None
+ self.conv_norm_out_branch = [None] * self.cond_num
+ self.conv_act_branch = [None] * self.cond_num
+
+ conv_out_padding = (conv_out_kernel - 1) // 2
+ self.conv_out = nn.Conv2d(
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
+ )
+ self.conv_out_branch = nn.ModuleList([])
+ for i in range(self.cond_num):
+ self.conv_out_branch.append(copy.deepcopy(self.conv_out))
+
+ @property
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
+ indexed by their weight names.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "set_processor"):
+ processors[f"{name}.processor"] = module.processor
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
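+ As an illustration (an assumed key layout, not taken from this file), a key for a typical
+ Stable Diffusion UNet could look like
+ `"down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor"`.
+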
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ def set_default_attn_processor(self):
+ """
+ Disables custom attention processors and sets the default attention implementation.
+ """
+ self.set_attn_processor(AttnProcessor())
+
+ def set_attention_slice(self, slice_size):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
+
+ Args:
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+ must be a multiple of `slice_size`.
+ """
+ sliceable_head_dims = []
+
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
+ if hasattr(module, "set_attention_slice"):
+ sliceable_head_dims.append(module.sliceable_head_dim)
+
+ for child in module.children():
+ fn_recursive_retrieve_sliceable_dims(child)
+
+ # retrieve number of attention layers
+ for module in self.children():
+ fn_recursive_retrieve_sliceable_dims(module)
+
+ num_sliceable_layers = len(sliceable_head_dims)
+
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
+ elif slice_size == "max":
+ # make smallest slice possible
+ slice_size = num_sliceable_layers * [1]
+
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+
+ if len(slice_size) != len(sliceable_head_dims):
+ raise ValueError(
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
+ )
+
+ for i in range(len(slice_size)):
+ size = slice_size[i]
+ dim = sliceable_head_dims[i]
+ if size is not None and size > dim:
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
+
+ # Recursively walk through all the children.
+ # Any children which exposes the set_attention_slice method
+ # gets the message
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+ if hasattr(module, "set_attention_slice"):
+ module.set_attention_slice(slice_size.pop())
+
+ for child in module.children():
+ fn_recursive_set_attention_slice(child, slice_size)
+
+ reversed_slice_size = list(reversed(slice_size))
+ for module in self.children():
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
+ module.gradient_checkpointing = value
+
+ @classmethod
+ def from_pretrained_2d(cls, pretrained_model_path, weight=None, subfolder=None, size_cond=False, fp16=False, cond_num=5, copy_last_n_block=1):
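+ # Builds the multi-branch UNet from the config of a standard Stable Diffusion checkpoint and
+ # copies the pretrained trunk weights into every branch. A hedged usage sketch (the path is a
+ # placeholder, not from this repository):
+ #   model = UNet2DConditionModel.from_pretrained_2d("path/to/sd-model", subfolder="unet", cond_num=5)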
+ import os
+ import json
+ if subfolder is not None:
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
+ config_file = os.path.join(pretrained_model_path, 'config.json')
+ if not os.path.isfile(config_file):
+ raise RuntimeError(f"{config_file} does not exist")
+ with open(config_file, "r") as f:
+ config = json.load(f)
+
+ if "stable-diffusion-xl-base-1.0" in pretrained_model_path or "Realistic_Vision_V5.1_noVAE" in pretrained_model_path:
+ from diffusers.utils import SAFETENSORS_WEIGHTS_NAME # diffusion_pytorch_model.safetensors
+ model_postfix = SAFETENSORS_WEIGHTS_NAME
+ if fp16:
+ model_postfix = "diffusion_pytorch_model.fp16.safetensors"
+ else:
+ from diffusers.utils import WEIGHTS_NAME # diffusion_pytorch_model.bin
+ model_postfix = WEIGHTS_NAME
+
+ config["size_cond"] = size_cond
+ if size_cond:
+ # config.addition_embed_type = "time"
+ config["addition_embed_type"] = "time"
+
+ config["cond_num"] = cond_num
+ config["copy_last_n_block"] = copy_last_n_block
+
+ model = cls.from_config(config)
+
+ if weight is None:
+ print("Loading from pretrained SD")
+ model_file = os.path.join(pretrained_model_path, model_postfix)
+ if not os.path.isfile(model_file):
+ raise RuntimeError(f"{model_file} does not exist")
+ if "stable-diffusion-xl-base-1.0" in pretrained_model_path or "Realistic_Vision_V5.1_noVAE" in pretrained_model_path:
+ # state_dict = {}
+ # with safe_open(model_file, framework="pt", device=0) as f:
+ # for k in f.keys():
+ # state_dict[k] = f.get_tensor(k)
+ state_dict = safetensors.torch.load_file(model_file, device="cpu")
+ else:
+ state_dict = torch.load(model_file, map_location="cpu")
+
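+ # keys that exist only in this multi-branch variant (e.g. the copied branch blocks) are missing
+ # from the pretrained checkpoint, so fill them with the freshly initialized weights to allow a
+ # strict load_state_dict below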
+ for k, v in model.state_dict().items():
+ # print(k)
+ if k not in state_dict.keys():
+ state_dict.update({k: v})
+ model.load_state_dict(state_dict)
+
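+ # after loading, re-copy the (now pretrained) trunk modules into every branch so all branches
+ # start from the same Stable Diffusion weights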
+ for i in range(model.cond_num):
+ for j in range(model.copy_last_n_block, 0, -1):
+ model.up_blocks_branch[i][-j] = copy.deepcopy(model.up_blocks[-j])
+ if model.norm_num_groups is not None:
+ for i in range(model.cond_num):
+ model.conv_norm_out_branch[i] = copy.deepcopy(model.conv_norm_out)
+ model.conv_act_branch[i] = copy.deepcopy(model.conv_act)
+ for i in range(model.cond_num):
+ model.conv_out_branch[i] = copy.deepcopy(model.conv_out)
+ # conv_in_weights = model.conv_in.weight.clone()
+ # model.conv_in = nn.Conv2d(4 + 4*n_poses, conv_in_weights.shape[0], kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) # input noise + n poses
+ # # model.conv_in = InflatedConv3d(4 + 4*n_poses, conv_in_weights.shape[0], kernel_size=3, padding=(1, 1))
+ # # nn.Conv2d(4 + 4*n_poses, weights.shape[0], kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) # input noise + n poses
+ # with torch.no_grad():
+ # model.conv_in.weight[:, :4] = conv_in_weights # original weights
+ # model.conv_in.weight[:, 4:] = torch.zeros(model.conv_in.weight[:, 4:].shape) # new weights initialized to zero
+
+ return model
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ # H=None,
+ # W=None,
+ class_labels: Optional[torch.Tensor] = None,
+ timestep_cond: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ ) -> Union[UNet2DConditionOutput, Tuple]:
+ r"""
+ The [`UNet2DConditionModel`] forward method.
+
+ Args:
+ sample (`torch.FloatTensor`):
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
+ encoder_hidden_states (`torch.FloatTensor`):
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
+ encoder_attention_mask (`torch.Tensor`):
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
+ tuple.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
+ added_cond_kwargs: (`dict`, *optional*):
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
+ are passed along to the UNet blocks.
+
+ Returns:
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
+ If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
+ a `tuple` is returned where the first element is the sample tensor.
+ """
+ # By default samples have to be at least a multiple of the overall upsampling factor.
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
+ # on the fly if necessary.
+ default_overall_up_factor = 2**self.num_upsamplers
+
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+ forward_upsample_size = False
+ upsample_size = None
+
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
+ logger.info("Forward upsample size to force interpolation output size.")
+ forward_upsample_size = True
+
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
+ # expects mask of shape:
+ # [batch, key_tokens]
+ # adds singleton query_tokens dimension:
+ # [batch, 1, key_tokens]
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+ if attention_mask is not None:
+ # assume that mask is expressed as:
+ # (1 = keep, 0 = discard)
+ # convert mask into a bias that can be added to attention scores:
+ # (keep = +0, discard = -10000.0)
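+ # e.g. a mask row [1, 1, 0] becomes the additive bias [0.0, 0.0, -10000.0]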
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
+ if encoder_attention_mask is not None:
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+ # 0. center input if necessary
+ if self.config.center_input_sample:
+ sample = 2 * sample - 1.0
+
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps.expand(sample.shape[0])
+
+ t_emb = self.time_proj(timesteps)
+
+ # `Timesteps` does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=sample.dtype)
+
+ emb = self.time_embedding(t_emb, timestep_cond)
+
+ # if self.size_cond:
+ # h_emb = self.H_embedding(self.size_h_proj(H))
+ # w_emb = self.W_embedding(self.size_w_proj(W))
+ # emb = emb + torch.cat((h_emb, w_emb), dim=1)
+
+ aug_emb = None
+
+ if self.class_embedding is not None:
+ if class_labels is None:
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
+
+ if self.config.class_embed_type == "timestep":
+ class_labels = self.time_proj(class_labels)
+
+ # `Timesteps` does not contain any weights and will always return f32 tensors
+ # there might be better ways to encapsulate this.
+ class_labels = class_labels.to(dtype=sample.dtype)
+
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
+
+ if self.config.class_embeddings_concat:
+ emb = torch.cat([emb, class_emb], dim=-1)
+ else:
+ emb = emb + class_emb
+
+ if self.config.addition_embed_type == "text":
+ aug_emb = self.add_embedding(encoder_hidden_states)
+ elif self.config.addition_embed_type == "text_image":
+ # Kandinsky 2.1 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+ )
+
+ image_embs = added_cond_kwargs.get("image_embeds")
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
+ aug_emb = self.add_embedding(text_embs, image_embs)
+ elif self.config.addition_embed_type == "text_time":
+ if "text_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
+ )
+ text_embeds = added_cond_kwargs.get("text_embeds")
+ if "time_ids" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
+ )
+ time_ids = added_cond_kwargs.get("time_ids")
+ time_embeds = self.add_time_proj(time_ids.flatten())
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
+
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
+ add_embeds = add_embeds.to(emb.dtype)
+ aug_emb = self.add_embedding(add_embeds)
+ elif self.config.addition_embed_type == "time":
+ time_ids = added_cond_kwargs.get("time_ids")
+ batch_size = time_ids.shape[0]
+ # time_embeds = self.add_time_proj(time_ids)
+ time_embeds = self.add_time_proj(time_ids.flatten())
+ time_embeds = time_embeds.reshape((batch_size, -1))
+
+ add_embeds = time_embeds
+ add_embeds = add_embeds.to(emb.dtype)
+ aug_emb = self.add_embedding(add_embeds)
+ elif self.config.addition_embed_type == "image":
+ # Kandinsky 2.2 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+ )
+ image_embs = added_cond_kwargs.get("image_embeds")
+ aug_emb = self.add_embedding(image_embs)
+ elif self.config.addition_embed_type == "image_hint":
+ # Kandinsky 2.2 - style
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
+ )
+ image_embs = added_cond_kwargs.get("image_embeds")
+ hint = added_cond_kwargs.get("hint")
+ aug_emb, hint = self.add_embedding(image_embs, hint)
+ sample = torch.cat([sample, hint], dim=1)
+
+ emb = emb + aug_emb if aug_emb is not None else emb
+
+ if self.time_embed_act is not None:
+ emb = self.time_embed_act(emb)
+
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
+ # Kandinsky 2.1 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+ )
+
+ image_embeds = added_cond_kwargs.get("image_embeds")
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
+ # Kandinsky 2.2 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+ )
+ image_embeds = added_cond_kwargs.get("image_embeds")
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
+ # 2. pre-process
+ sample = self.conv_in(sample)
+
+ # 3. down
+
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
+ is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
+
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ # For t2i-adapter CrossAttnDownBlock2D
+ additional_residuals = {}
+ if is_adapter and len(down_block_additional_residuals) > 0:
+ additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)
+
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ **additional_residuals,
+ )
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+
+ if is_adapter and len(down_block_additional_residuals) > 0:
+ sample += down_block_additional_residuals.pop(0)
+
+ down_block_res_samples += res_samples
+
+ if is_controlnet:
+ new_down_block_res_samples = ()
+
+ for down_block_res_sample, down_block_additional_residual in zip(
+ down_block_res_samples, down_block_additional_residuals
+ ):
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
+
+ down_block_res_samples = new_down_block_res_samples
+
+ # 4. mid
+ if self.mid_block is not None:
+ sample = self.mid_block(
+ sample,
+ emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+
+ if is_controlnet:
+ sample = sample + mid_block_additional_residual
+
+ # 5. up
+ for i, upsample_block in enumerate(self.up_blocks[:-self.copy_last_n_block]):
+ is_final_block = False
+
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+ # if we have not reached the final block and need to forward the
+ # upsample size, we do it here
+ if not is_final_block and forward_upsample_size:
+ upsample_size = down_block_res_samples[-1].shape[2:]
+
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ upsample_size=upsample_size,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+ else:
+ sample = upsample_block(
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
+ )
+
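+ # branch decoding: each condition branch runs its own copy of the last `copy_last_n_block`
+ # up blocks on a clone of the shared features, consuming its own copy of the remaining skips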
+ sample_list = []
+ for j in range(self.cond_num):
+ sample_copy = sample.clone()
+ down_block_res_samples_copy = down_block_res_samples + tuple()
+ for i, upsample_block in enumerate(self.up_blocks_branch[j]):
+ is_final_block = i == len(self.up_blocks_branch[j]) - 1
+
+ res_samples = down_block_res_samples_copy[-len(upsample_block.resnets) :]
+ down_block_res_samples_copy = down_block_res_samples_copy[: -len(upsample_block.resnets)]
+
+ # if we have not reached the final block and need to forward the
+ # upsample size, we do it here
+ if not is_final_block and forward_upsample_size:
+ upsample_size = down_block_res_samples_copy[-1].shape[2:]
+
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+ sample_copy = upsample_block(
+ hidden_states=sample_copy,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ upsample_size=upsample_size,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+ else:
+ sample_copy = upsample_block(
+ hidden_states=sample_copy, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
+ )
+ sample_list.append(sample_copy)
+
+ for i, upsample_block in enumerate(self.up_blocks[-self.copy_last_n_block:]):
+ is_final_block = i == len(self.up_blocks[-self.copy_last_n_block:]) - 1
+
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+ # if we have not reached the final block and need to forward the
+ # upsample size, we do it here
+ if not is_final_block and forward_upsample_size:
+ upsample_size = down_block_res_samples[-1].shape[2:]
+
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ upsample_size=upsample_size,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+ else:
+ sample = upsample_block(
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
+ )
+
+ if self.conv_norm_out:
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ # 6. post-process
+ output_list = [sample]
+ for i, sample_tmp in enumerate(sample_list):
+ if self.conv_norm_out_branch[i]:
+ sample_tmp = self.conv_norm_out_branch[i](sample_tmp)
+ sample_tmp = self.conv_act_branch[i](sample_tmp)
+ sample_tmp = self.conv_out_branch[i](sample_tmp)
+ output_list.append(sample_tmp)
+ sample = torch.cat(output_list, dim=1)
+
+ if not return_dict:
+ return (sample,)
+
+ return UNet2DConditionOutput(sample=sample)
diff --git a/diffusers/models/unet_2d_condition_multi_branch_downup.py b/diffusers/models/unet_2d_condition_multi_branch_downup.py
new file mode 100644
index 0000000000000000000000000000000000000000..1062419d37c7704dadf791296536e0bf75d73b91
--- /dev/null
+++ b/diffusers/models/unet_2d_condition_multi_branch_downup.py
@@ -0,0 +1,1319 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..loaders import UNet2DConditionLoadersMixin
+from ..utils import BaseOutput, logging
+from .activations import get_activation
+from .attention_processor import AttentionProcessor, AttnProcessor
+from .embeddings import (
+ GaussianFourierProjection,
+ ImageHintTimeEmbedding,
+ ImageProjection,
+ ImageTimeEmbedding,
+ TextImageProjection,
+ TextImageTimeEmbedding,
+ TextTimeEmbedding,
+ TimestepEmbedding,
+ Timesteps,
+)
+from .modeling_utils import ModelMixin
+from .unet_2d_blocks import (
+ CrossAttnDownBlock2D,
+ CrossAttnUpBlock2D,
+ DownBlock2D,
+ UNetMidBlock2DCrossAttn,
+ UNetMidBlock2DSimpleCrossAttn,
+ UpBlock2D,
+ get_down_block,
+ get_up_block,
+)
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+import copy
+
+# `safetensors.torch` must be imported explicitly so that
+# `safetensors.torch.load_file` used below is available.
+import safetensors.torch
+from safetensors import safe_open
+
+@dataclass
+class UNet2DConditionOutput(BaseOutput):
+ """
+ The output of [`UNet2DConditionModel`].
+
+ Args:
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
+ """
+
+ sample: torch.FloatTensor = None
+
+
+class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
+ r"""
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
+ shaped output.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
+ for all models (such as downloading or saving).
+
+ Parameters:
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
+ Height and width of input/output sample.
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
+ Whether to flip the sin to cos in the time embedding.
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
+ The tuple of downsample blocks to use.
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
+ Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
+ The tuple of upsample blocks to use.
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
+ Whether to include self-attention in the basic transformer blocks, see
+ [`~models.attention.BasicTransformerBlock`].
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
+ The tuple of output channels for each block.
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
+ If `None`, normalization and activation layers are skipped in post-processing.
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
+ The dimension of the cross attention features.
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
+ encoder_hid_dim (`int`, *optional*, defaults to None):
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
+ dimension to `cross_attention_dim`.
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
+ embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
+ num_attention_heads (`int`, *optional*):
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
+ class_embed_type (`str`, *optional*, defaults to `None`):
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
+ addition_embed_type (`str`, *optional*, defaults to `None`):
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
+ "text". "text" will use the `TextTimeEmbedding` layer.
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
+ Dimension for the timestep embeddings.
+ num_class_embeds (`int`, *optional*, defaults to `None`):
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
+ class conditioning with `class_embed_type` equal to `None`.
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
+ An optional override for the dimension of the projected time embedding.
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
+ timestep_post_act (`str`, *optional*, defaults to `None`):
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
+ The dimension of `cond_proj` layer in the timestep embedding.
+ conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
+ conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
+ embeddings with the class embeddings.
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Defaults to `False`
+ otherwise.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: Optional[int] = None,
+ in_channels: int = 4,
+ out_channels: int = 4,
+ center_input_sample: bool = False,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ down_block_types: Tuple[str] = (
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "DownBlock2D",
+ ),
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ layers_per_block: Union[int, Tuple[int]] = 2,
+ downsample_padding: int = 1,
+ mid_block_scale_factor: float = 1,
+ act_fn: str = "silu",
+ norm_num_groups: Optional[int] = 32,
+ norm_eps: float = 1e-5,
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+ encoder_hid_dim: Optional[int] = None,
+ encoder_hid_dim_type: Optional[str] = None,
+ attention_head_dim: Union[int, Tuple[int]] = 8,
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = False,
+ class_embed_type: Optional[str] = None,
+ addition_embed_type: Optional[str] = None,
+ addition_time_embed_dim: Optional[int] = 256,
+ num_class_embeds: Optional[int] = None,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+ resnet_skip_time_act: bool = False,
+ resnet_out_scale_factor: int = 1.0,
+ time_embedding_type: str = "positional",
+ time_embedding_dim: Optional[int] = None,
+ time_embedding_act_fn: Optional[str] = None,
+ timestep_post_act: Optional[str] = None,
+ time_cond_proj_dim: Optional[int] = None,
+ conv_in_kernel: int = 3,
+ conv_out_kernel: int = 3,
+ projection_class_embeddings_input_dim: Optional[int] = None,
+ class_embeddings_concat: bool = False,
+ mid_block_only_cross_attention: Optional[bool] = None,
+ cross_attention_norm: Optional[str] = None,
+ addition_embed_type_num_heads=64,
+ size_cond=False,
+ branch_num=5,
+ copy_last_n_block=1,
+ copy_first_n_block=1,
+ fusion: str = "sum",
+ off_wa=False,
+ ):
+ super().__init__()
+
+ self.sample_size = sample_size
+
+ if num_attention_heads is not None:
+ raise ValueError(
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
+ )
+
+ # If `num_attention_heads` is not defined (which is the case for most models)
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
+ # which is why we correct for the naming here.
+ num_attention_heads = num_attention_heads or attention_head_dim
+
+ # Check inputs
+ if len(down_block_types) != len(up_block_types):
+ raise ValueError(
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
+ )
+
+ if len(block_out_channels) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
+ )
+
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
+ )
+
+ self.branch_num = branch_num
+ self.copy_last_n_block = copy_last_n_block
+ self.copy_first_n_block = copy_first_n_block
+ self.norm_num_groups = norm_num_groups
+ self.fusion = fusion
+
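+ # Fusion of the per-branch features with the main branch: "sum" adds them,
+ # "avg" additionally divides by (branch_num + 1), and "learn" concatenates
+ # all (branch_num + 1) feature maps and projects them back with a 3x3
+ # convolution (`fusion_conv`).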
+ if self.fusion == "learn":
+ self.fusion_conv = nn.Conv2d(
+ block_out_channels[self.copy_first_n_block - 1] * (self.branch_num + 1),
+ block_out_channels[self.copy_first_n_block - 1],
+ kernel_size=3,
+ padding=1,
+ )
+ elif self.fusion not in ("sum", "avg"):
+ raise ValueError(f"unknown fusion type: {self.fusion}")
+
+ # input
+ conv_in_padding = (conv_in_kernel - 1) // 2
+ self.conv_in = nn.Conv2d(
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
+ )
+
+ self.conv_in_branch = nn.ModuleList([])
+ for i in range(self.branch_num):
+ self.conv_in_branch.append(copy.deepcopy(self.conv_in))
+
+ # time
+ if time_embedding_type == "fourier":
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
+ if time_embed_dim % 2 != 0:
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
+ self.time_proj = GaussianFourierProjection(
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
+ )
+ timestep_input_dim = time_embed_dim
+ elif time_embedding_type == "positional":
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
+
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+ timestep_input_dim = block_out_channels[0]
+ else:
+ raise ValueError(
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
+ )
+
+ self.time_embedding = TimestepEmbedding(
+ timestep_input_dim,
+ time_embed_dim,
+ act_fn=act_fn,
+ post_act_fn=timestep_post_act,
+ cond_proj_dim=time_cond_proj_dim,
+ )
+
+ # self.size_cond = size_cond
+ # if self.size_cond:
+ # size_embed_dim = block_out_channels[0] * 2
+ # if size_embed_dim % 2 != 0:
+ # raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {size_embed_dim}.")
+ # self.size_h_proj = GaussianFourierProjection(
+ # size_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
+ # )
+ # self.size_w_proj = GaussianFourierProjection(
+ # size_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
+ # )
+ # size_input_dim = size_embed_dim
+ # self.H_embedding = TimestepEmbedding(
+ # size_input_dim,
+ # size_embed_dim,
+ # act_fn=act_fn,
+ # post_act_fn=timestep_post_act,
+ # cond_proj_dim=time_cond_proj_dim,
+ # )
+ # self.W_embedding = TimestepEmbedding(
+ # size_input_dim,
+ # size_embed_dim,
+ # act_fn=act_fn,
+ # post_act_fn=timestep_post_act,
+ # cond_proj_dim=time_cond_proj_dim,
+ # )
+
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
+ encoder_hid_dim_type = "text_proj"
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
+
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
+ raise ValueError(
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
+ )
+
+ if encoder_hid_dim_type == "text_proj":
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
+ elif encoder_hid_dim_type == "text_image_proj":
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
+ # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)
+ self.encoder_hid_proj = TextImageProjection(
+ text_embed_dim=encoder_hid_dim,
+ image_embed_dim=cross_attention_dim,
+ cross_attention_dim=cross_attention_dim,
+ )
+ elif encoder_hid_dim_type == "image_proj":
+ # Kandinsky 2.2
+ self.encoder_hid_proj = ImageProjection(
+ image_embed_dim=encoder_hid_dim,
+ cross_attention_dim=cross_attention_dim,
+ )
+ elif encoder_hid_dim_type is not None:
+ raise ValueError(
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
+ )
+ else:
+ self.encoder_hid_proj = None
+
+ # class embedding
+ if class_embed_type is None and num_class_embeds is not None:
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+ elif class_embed_type == "timestep":
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
+ elif class_embed_type == "identity":
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
+ elif class_embed_type == "projection":
+ if projection_class_embeddings_input_dim is None:
+ raise ValueError(
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
+ )
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
+ # 2. it projects from an arbitrary input dimension.
+ #
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+ elif class_embed_type == "simple_projection":
+ if projection_class_embeddings_input_dim is None:
+ raise ValueError(
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
+ )
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
+ else:
+ self.class_embedding = None
+
+ if addition_embed_type == "text":
+ if encoder_hid_dim is not None:
+ text_time_embedding_from_dim = encoder_hid_dim
+ else:
+ text_time_embedding_from_dim = cross_attention_dim
+
+ self.add_embedding = TextTimeEmbedding(
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
+ )
+ elif addition_embed_type == "text_image":
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
+ self.add_embedding = TextImageTimeEmbedding(
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
+ )
+ elif addition_embed_type == "text_time":
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+ elif addition_embed_type == "time":
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
+ if off_wa:
+ size_cond_num = 6
+ else:
+ size_cond_num = 8
+ self.add_embedding = TimestepEmbedding(addition_time_embed_dim * size_cond_num, time_embed_dim, act_fn=act_fn, post_act_fn=timestep_post_act, cond_proj_dim=time_cond_proj_dim)
+ elif addition_embed_type == "image":
+ # Kandinsky 2.2
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
+ elif addition_embed_type == "image_hint":
+ # Kandinsky 2.2 ControlNet
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
+ elif addition_embed_type is not None:
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
+
+ if time_embedding_act_fn is None:
+ self.time_embed_act = None
+ else:
+ self.time_embed_act = get_activation(time_embedding_act_fn)
+
+ self.down_blocks = nn.ModuleList([])
+ self.up_blocks = nn.ModuleList([])
+
+ if isinstance(only_cross_attention, bool):
+ if mid_block_only_cross_attention is None:
+ mid_block_only_cross_attention = only_cross_attention
+
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
+
+ if mid_block_only_cross_attention is None:
+ mid_block_only_cross_attention = False
+
+ if isinstance(num_attention_heads, int):
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
+
+ if isinstance(attention_head_dim, int):
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
+
+ if isinstance(cross_attention_dim, int):
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
+
+ if isinstance(layers_per_block, int):
+ layers_per_block = [layers_per_block] * len(down_block_types)
+
+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+
+ if class_embeddings_concat:
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
+ # regular time embeddings
+ blocks_time_embed_dim = time_embed_dim * 2
+ else:
+ blocks_time_embed_dim = time_embed_dim
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block[i],
+ transformer_layers_per_block=transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=blocks_time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim[i],
+ num_attention_heads=num_attention_heads[i],
+ downsample_padding=downsample_padding,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ resnet_skip_time_act=resnet_skip_time_act,
+ resnet_out_scale_factor=resnet_out_scale_factor,
+ cross_attention_norm=cross_attention_norm,
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+ )
+ self.down_blocks.append(down_block)
+
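+ # Each side branch gets its own copy of the first `copy_first_n_block` down
+ # blocks so the extra noisy inputs are encoded independently before their
+ # features are fused with the main branch.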
+ self.down_blocks_branch = nn.ModuleList([])
+ for i in range(self.branch_num):
+ copy_block_list = nn.ModuleList([])
+ for j in range(self.copy_first_n_block):
+ copy_block_list.append(copy.deepcopy(self.down_blocks[j]))
+ self.down_blocks_branch.append(copy_block_list)
+
+ # mid
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
+ self.mid_block = UNetMidBlock2DCrossAttn(
+ transformer_layers_per_block=transformer_layers_per_block[-1],
+ in_channels=block_out_channels[-1],
+ temb_channels=blocks_time_embed_dim,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ cross_attention_dim=cross_attention_dim[-1],
+ num_attention_heads=num_attention_heads[-1],
+ resnet_groups=norm_num_groups,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+ )
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
+ self.mid_block = UNetMidBlock2DSimpleCrossAttn(
+ in_channels=block_out_channels[-1],
+ temb_channels=blocks_time_embed_dim,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ cross_attention_dim=cross_attention_dim[-1],
+ attention_head_dim=attention_head_dim[-1],
+ resnet_groups=norm_num_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ skip_time_act=resnet_skip_time_act,
+ only_cross_attention=mid_block_only_cross_attention,
+ cross_attention_norm=cross_attention_norm,
+ )
+ elif mid_block_type is None:
+ self.mid_block = None
+ else:
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
+
+ # count how many layers upsample the images
+ self.num_upsamplers = 0
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
+ reversed_layers_per_block = list(reversed(layers_per_block))
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
+ only_cross_attention = list(reversed(only_cross_attention))
+
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ is_final_block = i == len(block_out_channels) - 1
+
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+ # add upsample block for all BUT final layer
+ if not is_final_block:
+ add_upsample = True
+ self.num_upsamplers += 1
+ else:
+ add_upsample = False
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=reversed_layers_per_block[i] + 1,
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ prev_output_channel=prev_output_channel,
+ temb_channels=blocks_time_embed_dim,
+ add_upsample=add_upsample,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=reversed_cross_attention_dim[i],
+ num_attention_heads=reversed_num_attention_heads[i],
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ resnet_skip_time_act=resnet_skip_time_act,
+ resnet_out_scale_factor=resnet_out_scale_factor,
+ cross_attention_norm=cross_attention_norm,
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
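+ # Mirror the encoder-side copies on the decoder: each branch duplicates the
+ # last `copy_last_n_block` up blocks so it can decode its own prediction.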
+ self.up_blocks_branch = nn.ModuleList([])
+ for i in range(self.branch_num):
+ copy_block_list = nn.ModuleList([])
+ for j in range(self.copy_last_n_block, 0, -1):
+ copy_block_list.append(copy.deepcopy(self.up_blocks[-j]))
+ self.up_blocks_branch.append(copy_block_list)
+
+ # out
+ if norm_num_groups is not None:
+ self.conv_norm_out = nn.GroupNorm(
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
+ )
+
+ self.conv_act = get_activation(act_fn)
+
+ self.conv_norm_out_branch = nn.ModuleList([])
+ self.conv_act_branch = nn.ModuleList([])
+ for i in range(self.branch_num):
+ self.conv_norm_out_branch.append(copy.deepcopy(self.conv_norm_out))
+ self.conv_act_branch.append(copy.deepcopy(self.conv_act))
+
+ else:
+ self.conv_norm_out = None
+ self.conv_act = None
+ self.conv_norm_out_branch = [None] * self.branch_num
+ self.conv_act_branch = [None] * self.branch_num
+
+ conv_out_padding = (conv_out_kernel - 1) // 2
+ self.conv_out = nn.Conv2d(
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
+ )
+ self.conv_out_branch = nn.ModuleList([])
+ for i in range(self.branch_num):
+ self.conv_out_branch.append(copy.deepcopy(self.conv_out))
+
+ @property
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
+ indexed by their weight names.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "set_processor"):
+ processors[f"{name}.processor"] = module.processor
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ def set_default_attn_processor(self):
+ """
+ Disables custom attention processors and sets the default attention implementation.
+ """
+ self.set_attn_processor(AttnProcessor())
+
+ def set_attention_slice(self, slice_size):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
+
+ Args:
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+ must be a multiple of `slice_size`.
+ """
+ sliceable_head_dims = []
+
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
+ if hasattr(module, "set_attention_slice"):
+ sliceable_head_dims.append(module.sliceable_head_dim)
+
+ for child in module.children():
+ fn_recursive_retrieve_sliceable_dims(child)
+
+ # retrieve number of attention layers
+ for module in self.children():
+ fn_recursive_retrieve_sliceable_dims(module)
+
+ num_sliceable_layers = len(sliceable_head_dims)
+
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
+ elif slice_size == "max":
+ # make smallest slice possible
+ slice_size = num_sliceable_layers * [1]
+
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+
+ if len(slice_size) != len(sliceable_head_dims):
+ raise ValueError(
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
+ )
+
+ for i in range(len(slice_size)):
+ size = slice_size[i]
+ dim = sliceable_head_dims[i]
+ if size is not None and size > dim:
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
+
+ # Recursively walk through all the children.
+ # Any children which exposes the set_attention_slice method
+ # gets the message
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+ if hasattr(module, "set_attention_slice"):
+ module.set_attention_slice(slice_size.pop())
+
+ for child in module.children():
+ fn_recursive_set_attention_slice(child, slice_size)
+
+ reversed_slice_size = list(reversed(slice_size))
+ for module in self.children():
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
+ module.gradient_checkpointing = value
+
+ @classmethod
+ def from_pretrained_2d(cls, pretrained_model_path, weight=None, subfolder=None, size_cond=False, fp16=False, branch_num=5, copy_first_n_block=1, copy_last_n_block=1, fusion="sum", off_wa=False):
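+ # Hypothetical usage sketch (paths and arguments are illustrative only):
+ #   model = UNet2DConditionModel.from_pretrained_2d(
+ #       "path/to/stable-diffusion", subfolder="unet", branch_num=5,
+ #       copy_first_n_block=1, copy_last_n_block=1, fusion="sum")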
+ import os
+ import json
+ if subfolder is not None:
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
+ config_file = os.path.join(pretrained_model_path, 'config.json')
+ if not os.path.isfile(config_file):
+ raise RuntimeError(f"{config_file} does not exist")
+ with open(config_file, "r") as f:
+ config = json.load(f)
+
+ if "stable-diffusion-xl-base-1.0" in pretrained_model_path or "Realistic_Vision_V5.1_noVAE" in pretrained_model_path:
+ from diffusers.utils import SAFETENSORS_WEIGHTS_NAME # diffusion_pytorch_model.safetensors
+ model_postfix = SAFETENSORS_WEIGHTS_NAME
+ if fp16:
+ model_postfix = "diffusion_pytorch_model.fp16.safetensors"
+ else:
+ from diffusers.utils import WEIGHTS_NAME # diffusion_pytorch_model.bin
+ model_postfix = WEIGHTS_NAME
+
+ config["size_cond"] = size_cond
+ if size_cond:
+ # config.addition_embed_type = "time"
+ config["addition_embed_type"] = "time"
+
+ config["branch_num"] = branch_num
+ config["copy_first_n_block"] = copy_first_n_block
+ config["copy_last_n_block"] = copy_last_n_block
+ config["fusion"] = fusion
+ config["off_wa"] = off_wa
+
+ model = cls.from_config(config)
+
+ if weight is None:
+ print("Loading from pretrained SD")
+ model_file = os.path.join(pretrained_model_path, model_postfix)
+ if not os.path.isfile(model_file):
+ raise RuntimeError(f"{model_file} does not exist")
+ if "stable-diffusion-xl-base-1.0" in pretrained_model_path or "Realistic_Vision_V5.1_noVAE" in pretrained_model_path:
+ # state_dict = {}
+ # with safe_open(model_file, framework="pt", device=0) as f:
+ # for k in f.keys():
+ # state_dict[k] = f.get_tensor(k)
+ state_dict = safetensors.torch.load_file(model_file, device="cpu")
+ else:
+ state_dict = torch.load(model_file, map_location="cpu")
+
+ for k, v in model.state_dict().items():
+ # print(k)
+ if k not in state_dict.keys():
+ state_dict.update({k: v})
+ model.load_state_dict(state_dict)
+
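+ # After loading, re-copy the (now pretrained) main-branch modules into the
+ # side branches so every branch starts from the same pretrained weights.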
+ for i in range(model.branch_num):
+ for j in range(model.copy_first_n_block):
+ model.down_blocks_branch[i][j] = copy.deepcopy(model.down_blocks[j])
+ for i in range(model.branch_num):
+ for j in range(model.copy_last_n_block, 0, -1):
+ model.up_blocks_branch[i][-j] = copy.deepcopy(model.up_blocks[-j])
+ if model.norm_num_groups is not None:
+ for i in range(model.branch_num):
+ model.conv_norm_out_branch[i] = copy.deepcopy(model.conv_norm_out)
+ model.conv_act_branch[i] = copy.deepcopy(model.conv_act)
+ for i in range(model.branch_num):
+ model.conv_out_branch[i] = copy.deepcopy(model.conv_out)
+ # conv_in_weights = model.conv_in.weight.clone()
+ # model.conv_in = nn.Conv2d(4 + 4*n_poses, conv_in_weights.shape[0], kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) # input noise + n poses
+ # # model.conv_in = InflatedConv3d(4 + 4*n_poses, conv_in_weights.shape[0], kernel_size=3, padding=(1, 1))
+ # # nn.Conv2d(4 + 4*n_poses, weights.shape[0], kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) # input noise + n poses
+ # with torch.no_grad():
+ # model.conv_in.weight[:, :4] = conv_in_weights # original weights
+ # model.conv_in.weight[:, 4:] = torch.zeros(model.conv_in.weight[:, 4:].shape) # new weights initialized to zero
+
+ return model
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ sample_noisy_list: List[torch.FloatTensor],
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ # H=None,
+ # W=None,
+ class_labels: Optional[torch.Tensor] = None,
+ timestep_cond: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ ) -> Union[UNet2DConditionOutput, Tuple]:
+ r"""
+ The [`UNet2DConditionModel`] forward method.
+
+ Args:
+ sample (`torch.FloatTensor`):
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
+ encoder_hidden_states (`torch.FloatTensor`):
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
+ encoder_attention_mask (`torch.Tensor`):
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
+ tuple.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
+ added_cond_kwargs: (`dict`, *optional*):
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
+ are passed along to the UNet blocks.
+
+ Returns:
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
+ If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
+ a `tuple` is returned where the first element is the sample tensor.
+ """
+ # By default samples have to be at least a multiple of the overall upsampling factor.
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
+ # on the fly if necessary.
+ default_overall_up_factor = 2**self.num_upsamplers
+
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+ forward_upsample_size = False
+ upsample_size = None
+
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
+ logger.info("Forward upsample size to force interpolation output size.")
+ forward_upsample_size = True
+
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
+ # expects mask of shape:
+ # [batch, key_tokens]
+ # adds singleton query_tokens dimension:
+ # [batch, 1, key_tokens]
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+ if attention_mask is not None:
+ # assume that mask is expressed as:
+ # (1 = keep, 0 = discard)
+ # convert mask into a bias that can be added to attention scores:
+ # (keep = +0, discard = -10000.0)
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
+ if encoder_attention_mask is not None:
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+ # 0. center input if necessary
+ if self.config.center_input_sample:
+ sample = 2 * sample - 1.0
+
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps.expand(sample.shape[0])
+
+ t_emb = self.time_proj(timesteps)
+
+ # `Timesteps` does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=sample.dtype)
+
+ emb = self.time_embedding(t_emb, timestep_cond)
+
+ # if self.size_cond:
+ # h_emb = self.H_embedding(self.size_h_proj(H))
+ # w_emb = self.W_embedding(self.size_w_proj(W))
+ # emb = emb + torch.cat((h_emb, w_emb), dim=1)
+
+ aug_emb = None
+
+ if self.class_embedding is not None:
+ if class_labels is None:
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
+
+ if self.config.class_embed_type == "timestep":
+ class_labels = self.time_proj(class_labels)
+
+ # `Timesteps` does not contain any weights and will always return f32 tensors
+ # there might be better ways to encapsulate this.
+ class_labels = class_labels.to(dtype=sample.dtype)
+
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
+
+ if self.config.class_embeddings_concat:
+ emb = torch.cat([emb, class_emb], dim=-1)
+ else:
+ emb = emb + class_emb
+
+ if self.config.addition_embed_type == "text":
+ aug_emb = self.add_embedding(encoder_hidden_states)
+ elif self.config.addition_embed_type == "text_image":
+ # Kandinsky 2.1 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+ )
+
+ image_embs = added_cond_kwargs.get("image_embeds")
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
+ aug_emb = self.add_embedding(text_embs, image_embs)
+ elif self.config.addition_embed_type == "text_time":
+ if "text_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
+ )
+ text_embeds = added_cond_kwargs.get("text_embeds")
+ if "time_ids" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
+ )
+ time_ids = added_cond_kwargs.get("time_ids")
+ time_embeds = self.add_time_proj(time_ids.flatten())
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
+
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
+ add_embeds = add_embeds.to(emb.dtype)
+ aug_emb = self.add_embedding(add_embeds)
+ elif self.config.addition_embed_type == "time":
+ time_ids = added_cond_kwargs.get("time_ids")
+ batch_size = time_ids.shape[0]
+ # time_embeds = self.add_time_proj(time_ids)
+ time_embeds = self.add_time_proj(time_ids.flatten())
+ time_embeds = time_embeds.reshape((batch_size, -1))
+
+ add_embeds = time_embeds
+ add_embeds = add_embeds.to(emb.dtype)
+ aug_emb = self.add_embedding(add_embeds)
+ elif self.config.addition_embed_type == "image":
+ # Kandinsky 2.2 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+ )
+ image_embs = added_cond_kwargs.get("image_embeds")
+ aug_emb = self.add_embedding(image_embs)
+ elif self.config.addition_embed_type == "image_hint":
+ # Kandinsky 2.2 - style
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
+ )
+ image_embs = added_cond_kwargs.get("image_embeds")
+ hint = added_cond_kwargs.get("hint")
+ aug_emb, hint = self.add_embedding(image_embs, hint)
+ sample = torch.cat([sample, hint], dim=1)
+
+ emb = emb + aug_emb if aug_emb is not None else emb
+
+ if self.time_embed_act is not None:
+ emb = self.time_embed_act(emb)
+
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
+ # Kandinsky 2.1 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+ )
+
+ image_embeds = added_cond_kwargs.get("image_embeds")
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
+ # Kandinsky 2.2 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+ )
+ image_embeds = added_cond_kwargs.get("image_embeds")
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
+
+ # 2. pre-process
+ sample = self.conv_in(sample)
+
+ sample_in_list = []
+ for i in range(self.branch_num):
+ sample_in_list.append(self.conv_in_branch[i](sample_noisy_list[i]))
+
+ # 3. down
+
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
+ is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
+
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks[:self.copy_first_n_block]:
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ # For t2i-adapter CrossAttnDownBlock2D
+ additional_residuals = {}
+ if is_adapter and len(down_block_additional_residuals) > 0:
+ additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)
+
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ **additional_residuals,
+ )
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+
+ if is_adapter and len(down_block_additional_residuals) > 0:
+ sample += down_block_additional_residuals.pop(0)
+
+ down_block_res_samples += res_samples
+
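+ # Each branch keeps its own tuple of skip connections, seeded with its own
+ # conv_in output; the later (shared) down blocks append the same residuals
+ # to every branch so the branch decoders reuse the shared skips.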
+ down_block_res_samples_list = []
+ for i in range(self.branch_num):
+ down_block_res_samples_list.append((sample_in_list[i],))
+
+ for j in range(self.branch_num):
+ for i, downsample_block in enumerate(self.down_blocks_branch[j]):
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ # For t2i-adapter CrossAttnDownBlock2D
+ additional_residuals = {}
+ if is_adapter and len(down_block_additional_residuals) > 0:
+ additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)
+
+ sample_in_list[j], res_samples = downsample_block(
+ hidden_states=sample_in_list[j],
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ **additional_residuals,
+ )
+ else:
+ sample_in_list[j], res_samples = downsample_block(hidden_states=sample_in_list[j], temb=emb)
+
+ if is_adapter and len(down_block_additional_residuals) > 0:
+ sample_in_list[j] += down_block_additional_residuals.pop(0)
+
+ down_block_res_samples_list[j] += res_samples
+
+ # sample = (sample + sample_normal) / 2.
+ sample_list = [sample]
+ for i in range(self.branch_num):
+ sample_list.append(sample_in_list[i])
+
+ if self.fusion == "sum" or self.fusion == "avg":
+ stacked_tensor = torch.stack(sample_list, dim=0)
+ sample = torch.sum(stacked_tensor, dim=0)
+ if self.fusion == "avg":
+ sample = sample / (1 + self.branch_num)
+ elif self.fusion == "learn":
+ concat_tensor = torch.cat(sample_list, dim=1)
+ sample = self.fusion_conv(concat_tensor)
+ else:
+ raise ValueError(f"unknown fusion type: {self.fusion}")
+
+ for downsample_block in self.down_blocks[self.copy_first_n_block:]:
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ # For t2i-adapter CrossAttnDownBlock2D
+ additional_residuals = {}
+ if is_adapter and len(down_block_additional_residuals) > 0:
+ additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)
+
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ **additional_residuals,
+ )
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+
+ if is_adapter and len(down_block_additional_residuals) > 0:
+ sample += down_block_additional_residuals.pop(0)
+
+ down_block_res_samples += res_samples
+ for i in range(self.branch_num):
+ down_block_res_samples_list[i] += res_samples
+
+ if is_controlnet:
+ new_down_block_res_samples = ()
+
+ for down_block_res_sample, down_block_additional_residual in zip(
+ down_block_res_samples, down_block_additional_residuals
+ ):
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
+
+ down_block_res_samples = new_down_block_res_samples
+
+ # 4. mid
+ if self.mid_block is not None:
+ sample = self.mid_block(
+ sample,
+ emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+
+ if is_controlnet:
+ sample = sample + mid_block_additional_residual
+
+ # 5. up
+ for i, upsample_block in enumerate(self.up_blocks[:-self.copy_last_n_block]):
+ is_final_block = False
+
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+ for j in range(self.branch_num):
+ down_block_res_samples_list[j] = down_block_res_samples_list[j][: -len(upsample_block.resnets)]
+
+ # if we have not reached the final block and need to forward the
+ # upsample size, we do it here
+ if not is_final_block and forward_upsample_size:
+ upsample_size = down_block_res_samples[-1].shape[2:]
+
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ upsample_size=upsample_size,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+ else:
+ sample = upsample_block(
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
+ )
+
+ sample_list = []
+ for j in range(self.branch_num):
+ sample_copy = sample.clone()
+ for i, upsample_block in enumerate(self.up_blocks_branch[j]):
+ is_final_block = i == len(self.up_blocks_branch[j]) - 1
+
+ res_samples = down_block_res_samples_list[j][-len(upsample_block.resnets) :]
+ down_block_res_samples_list[j] = down_block_res_samples_list[j][: -len(upsample_block.resnets)]
+
+ # if we have not reached the final block and need to forward the
+ # upsample size, we do it here
+ if not is_final_block and forward_upsample_size:
+ upsample_size = down_block_res_samples_list[j][-1].shape[2:]
+
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+ sample_copy = upsample_block(
+ hidden_states=sample_copy,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ upsample_size=upsample_size,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+ else:
+ sample_copy = upsample_block(
+ hidden_states=sample_copy, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
+ )
+ sample_list.append(sample_copy)
+
+ for i, upsample_block in enumerate(self.up_blocks[-self.copy_last_n_block:]):
+ is_final_block = i == len(self.up_blocks[-self.copy_last_n_block:]) - 1
+
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+ # if we have not reached the final block and need to forward the
+ # upsample size, we do it here
+ if not is_final_block and forward_upsample_size:
+ upsample_size = down_block_res_samples[-1].shape[2:]
+
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ upsample_size=upsample_size,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+ else:
+ sample = upsample_block(
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
+ )
+
+ if self.conv_norm_out:
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ # 6. post-process
+ output_list = [sample]
+ for i, sample_tmp in enumerate(sample_list):
+ if self.conv_norm_out_branch[i]:
+ sample_tmp = self.conv_norm_out_branch[i](sample_tmp)
+ sample_tmp = self.conv_act_branch[i](sample_tmp)
+ sample_tmp = self.conv_out_branch[i](sample_tmp)
+ output_list.append(sample_tmp)
+ sample = torch.cat(output_list, dim=1)
+
+ if not return_dict:
+ return (sample,)
+
+ return UNet2DConditionOutput(sample=sample)
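+
+
+def _split_branch_outputs(sample: torch.Tensor, out_channels: int):
+ # Hedged helper sketch (not part of the upstream diffusers API): the multi-branch forward
+ # above concatenates the base prediction and every branch prediction along the channel axis,
+ # so the combined tensor splits back into 1 + branch_num chunks of `out_channels` channels.
+ # `out_channels` is an assumption and must match the channel count of `conv_out`/`conv_out_branch`.
+ base, *branches = torch.split(sample, out_channels, dim=1)
+ return base, branches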
diff --git a/diffusers/models/unet_3d_blocks.py b/diffusers/models/unet_3d_blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab5c393518e2ad8edf21069dfcd417392001569d
--- /dev/null
+++ b/diffusers/models/unet_3d_blocks.py
@@ -0,0 +1,679 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from torch import nn
+
+from .resnet import Downsample2D, ResnetBlock2D, TemporalConvLayer, Upsample2D
+from .transformer_2d import Transformer2DModel
+from .transformer_temporal import TransformerTemporalModel
+
+
+def get_down_block(
+ down_block_type,
+ num_layers,
+ in_channels,
+ out_channels,
+ temb_channels,
+ add_downsample,
+ resnet_eps,
+ resnet_act_fn,
+ num_attention_heads,
+ resnet_groups=None,
+ cross_attention_dim=None,
+ downsample_padding=None,
+ dual_cross_attention=False,
+ use_linear_projection=True,
+ only_cross_attention=False,
+ upcast_attention=False,
+ resnet_time_scale_shift="default",
+):
+ if down_block_type == "DownBlock3D":
+ return DownBlock3D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif down_block_type == "CrossAttnDownBlock3D":
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
+ return CrossAttnDownBlock3D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ raise ValueError(f"{down_block_type} does not exist.")
+
+
+def get_up_block(
+ up_block_type,
+ num_layers,
+ in_channels,
+ out_channels,
+ prev_output_channel,
+ temb_channels,
+ add_upsample,
+ resnet_eps,
+ resnet_act_fn,
+ num_attention_heads,
+ resnet_groups=None,
+ cross_attention_dim=None,
+ dual_cross_attention=False,
+ use_linear_projection=True,
+ only_cross_attention=False,
+ upcast_attention=False,
+ resnet_time_scale_shift="default",
+):
+ if up_block_type == "UpBlock3D":
+ return UpBlock3D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif up_block_type == "CrossAttnUpBlock3D":
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
+ return CrossAttnUpBlock3D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ raise ValueError(f"{up_block_type} does not exist.")
+
+
+class UNetMidBlock3DCrossAttn(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ num_attention_heads=1,
+ output_scale_factor=1.0,
+ cross_attention_dim=1280,
+ dual_cross_attention=False,
+ use_linear_projection=True,
+ upcast_attention=False,
+ ):
+ super().__init__()
+
+ self.has_cross_attention = True
+ self.num_attention_heads = num_attention_heads
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+
+ # there is always at least one resnet
+ resnets = [
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ ]
+ temp_convs = [
+ TemporalConvLayer(
+ in_channels,
+ in_channels,
+ dropout=0.1,
+ )
+ ]
+ attentions = []
+ temp_attentions = []
+
+ for _ in range(num_layers):
+ attentions.append(
+ Transformer2DModel(
+ in_channels // num_attention_heads,
+ num_attention_heads,
+ in_channels=in_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+ )
+ )
+ temp_attentions.append(
+ TransformerTemporalModel(
+ in_channels // num_attention_heads,
+ num_attention_heads,
+ in_channels=in_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ )
+ )
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ temp_convs.append(
+ TemporalConvLayer(
+ in_channels,
+ in_channels,
+ dropout=0.1,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+ self.temp_convs = nn.ModuleList(temp_convs)
+ self.attentions = nn.ModuleList(attentions)
+ self.temp_attentions = nn.ModuleList(temp_attentions)
+
+ def forward(
+ self,
+ hidden_states,
+ temb=None,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ num_frames=1,
+ cross_attention_kwargs=None,
+ ):
+ hidden_states = self.resnets[0](hidden_states, temb)
+ hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames)
+ for attn, temp_attn, resnet, temp_conv in zip(
+ self.attentions, self.temp_attentions, self.resnets[1:], self.temp_convs[1:]
+ ):
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
+ )[0]
+ hidden_states = temp_attn(
+ hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
+ )[0]
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
+
+ return hidden_states
+
+
+class CrossAttnDownBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ num_attention_heads=1,
+ cross_attention_dim=1280,
+ output_scale_factor=1.0,
+ downsample_padding=1,
+ add_downsample=True,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+ temp_attentions = []
+ temp_convs = []
+
+ self.has_cross_attention = True
+ self.num_attention_heads = num_attention_heads
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ temp_convs.append(
+ TemporalConvLayer(
+ out_channels,
+ out_channels,
+ dropout=0.1,
+ )
+ )
+ attentions.append(
+ Transformer2DModel(
+ out_channels // num_attention_heads,
+ num_attention_heads,
+ in_channels=out_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ )
+ )
+ temp_attentions.append(
+ TransformerTemporalModel(
+ out_channels // num_attention_heads,
+ num_attention_heads,
+ in_channels=out_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ )
+ )
+ self.resnets = nn.ModuleList(resnets)
+ self.temp_convs = nn.ModuleList(temp_convs)
+ self.attentions = nn.ModuleList(attentions)
+ self.temp_attentions = nn.ModuleList(temp_attentions)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states,
+ temb=None,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ num_frames=1,
+ cross_attention_kwargs=None,
+ ):
+ # TODO(Patrick, William) - attention mask is not used
+ output_states = ()
+
+ for resnet, temp_conv, attn, temp_attn in zip(
+ self.resnets, self.temp_convs, self.attentions, self.temp_attentions
+ ):
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
+ )[0]
+ hidden_states = temp_attn(
+ hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
+ )[0]
+
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class DownBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor=1.0,
+ add_downsample=True,
+ downsample_padding=1,
+ ):
+ super().__init__()
+ resnets = []
+ temp_convs = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ temp_convs.append(
+ TemporalConvLayer(
+ out_channels,
+ out_channels,
+ dropout=0.1,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+ self.temp_convs = nn.ModuleList(temp_convs)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states, temb=None, num_frames=1):
+ output_states = ()
+
+ for resnet, temp_conv in zip(self.resnets, self.temp_convs):
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
+
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class CrossAttnUpBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ prev_output_channel: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ num_attention_heads=1,
+ cross_attention_dim=1280,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
+ ):
+ super().__init__()
+ resnets = []
+ temp_convs = []
+ attentions = []
+ temp_attentions = []
+
+ self.has_cross_attention = True
+ self.num_attention_heads = num_attention_heads
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ temp_convs.append(
+ TemporalConvLayer(
+ out_channels,
+ out_channels,
+ dropout=0.1,
+ )
+ )
+ attentions.append(
+ Transformer2DModel(
+ out_channels // num_attention_heads,
+ num_attention_heads,
+ in_channels=out_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ )
+ )
+ temp_attentions.append(
+ TransformerTemporalModel(
+ out_channels // num_attention_heads,
+ num_attention_heads,
+ in_channels=out_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ )
+ )
+ self.resnets = nn.ModuleList(resnets)
+ self.temp_convs = nn.ModuleList(temp_convs)
+ self.attentions = nn.ModuleList(attentions)
+ self.temp_attentions = nn.ModuleList(temp_attentions)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states,
+ res_hidden_states_tuple,
+ temb=None,
+ encoder_hidden_states=None,
+ upsample_size=None,
+ attention_mask=None,
+ num_frames=1,
+ cross_attention_kwargs=None,
+ ):
+ # TODO(Patrick, William) - attention mask is not used
+ for resnet, temp_conv, attn, temp_attn in zip(
+ self.resnets, self.temp_convs, self.attentions, self.temp_attentions
+ ):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
+ )[0]
+ hidden_states = temp_attn(
+ hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
+ )[0]
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
+
+
+class UpBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ ):
+ super().__init__()
+ resnets = []
+ temp_convs = []
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ temp_convs.append(
+ TemporalConvLayer(
+ out_channels,
+ out_channels,
+ dropout=0.1,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+ self.temp_convs = nn.ModuleList(temp_convs)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, num_frames=1):
+ for resnet, temp_conv in zip(self.resnets, self.temp_convs):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
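+
+
+def _fold_frames_into_batch(sample: torch.Tensor):
+ # Hedged illustration (not part of the upstream file): the 3D blocks above reuse 2D resnets
+ # and spatial attention by folding the frame axis into the batch axis, while the temporal
+ # layers receive `num_frames` to unfold it again. The input is assumed to be video features
+ # of shape (batch, channels, frames, height, width), mirroring UNet3DConditionModel.forward.
+ batch, channels, num_frames, height, width = sample.shape
+ folded = sample.permute(0, 2, 1, 3, 4).reshape(batch * num_frames, channels, height, width)
+ return folded, num_frames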
diff --git a/diffusers/models/unet_3d_condition.py b/diffusers/models/unet_3d_condition.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff2a8f1179ef9654b5234d63528468e59e371b10
--- /dev/null
+++ b/diffusers/models/unet_3d_condition.py
@@ -0,0 +1,627 @@
+# Copyright 2023 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved.
+# Copyright 2023 The ModelScope Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..loaders import UNet2DConditionLoadersMixin
+from ..utils import BaseOutput, logging
+from .attention_processor import AttentionProcessor, AttnProcessor
+from .embeddings import TimestepEmbedding, Timesteps
+from .modeling_utils import ModelMixin
+from .transformer_temporal import TransformerTemporalModel
+from .unet_3d_blocks import (
+ CrossAttnDownBlock3D,
+ CrossAttnUpBlock3D,
+ DownBlock3D,
+ UNetMidBlock3DCrossAttn,
+ UpBlock3D,
+ get_down_block,
+ get_up_block,
+)
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class UNet3DConditionOutput(BaseOutput):
+ """
+ The output of [`UNet3DConditionModel`].
+
+ Args:
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, num_frames, height, width)`):
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
+ """
+
+ sample: torch.FloatTensor
+
+
+class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
+ r"""
+ A conditional 3D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
+ shaped output.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
+ for all models (such as downloading or saving).
+
+ Parameters:
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
+ Height and width of input/output sample.
+ in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
+ out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D")`):
+ The tuple of downsample blocks to use.
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D")`):
+ The tuple of upsample blocks to use.
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
+ The tuple of output channels for each block.
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
+ If `None`, normalization and activation layers are skipped in post-processing.
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
+ cross_attention_dim (`int`, *optional*, defaults to 1024): The dimension of the cross attention features.
+ attention_head_dim (`int`, *optional*, defaults to 64): The dimension of the attention heads.
+ num_attention_heads (`int`, *optional*): The number of attention heads.
+ """
+
+ _supports_gradient_checkpointing = False
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: Optional[int] = None,
+ in_channels: int = 4,
+ out_channels: int = 4,
+ down_block_types: Tuple[str] = (
+ "CrossAttnDownBlock3D",
+ "CrossAttnDownBlock3D",
+ "CrossAttnDownBlock3D",
+ "DownBlock3D",
+ ),
+ up_block_types: Tuple[str] = ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"),
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ layers_per_block: int = 2,
+ downsample_padding: int = 1,
+ mid_block_scale_factor: float = 1,
+ act_fn: str = "silu",
+ norm_num_groups: Optional[int] = 32,
+ norm_eps: float = 1e-5,
+ cross_attention_dim: int = 1024,
+ attention_head_dim: Union[int, Tuple[int]] = 64,
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
+ ):
+ super().__init__()
+
+ self.sample_size = sample_size
+
+ if num_attention_heads is not None:
+ raise NotImplementedError(
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
+ )
+
+ # If `num_attention_heads` is not defined (which is the case for most models)
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
+ # which is why we correct for the naming here.
+ num_attention_heads = num_attention_heads or attention_head_dim
+
+ # Check inputs
+ if len(down_block_types) != len(up_block_types):
+ raise ValueError(
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
+ )
+
+ if len(block_out_channels) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
+ )
+
+ # input
+ conv_in_kernel = 3
+ conv_out_kernel = 3
+ conv_in_padding = (conv_in_kernel - 1) // 2
+ self.conv_in = nn.Conv2d(
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
+ )
+
+ # time
+ time_embed_dim = block_out_channels[0] * 4
+ self.time_proj = Timesteps(block_out_channels[0], True, 0)
+ timestep_input_dim = block_out_channels[0]
+
+ self.time_embedding = TimestepEmbedding(
+ timestep_input_dim,
+ time_embed_dim,
+ act_fn=act_fn,
+ )
+
+ self.transformer_in = TransformerTemporalModel(
+ num_attention_heads=8,
+ attention_head_dim=attention_head_dim,
+ in_channels=block_out_channels[0],
+ num_layers=1,
+ )
+
+ # class embedding
+ self.down_blocks = nn.ModuleList([])
+ self.up_blocks = nn.ModuleList([])
+
+ if isinstance(num_attention_heads, int):
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads[i],
+ downsample_padding=downsample_padding,
+ dual_cross_attention=False,
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ self.mid_block = UNetMidBlock3DCrossAttn(
+ in_channels=block_out_channels[-1],
+ temb_channels=time_embed_dim,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads[-1],
+ resnet_groups=norm_num_groups,
+ dual_cross_attention=False,
+ )
+
+ # count how many layers upsample the images
+ self.num_upsamplers = 0
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
+
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ is_final_block = i == len(block_out_channels) - 1
+
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+ # add upsample block for all BUT final layer
+ if not is_final_block:
+ add_upsample = True
+ self.num_upsamplers += 1
+ else:
+ add_upsample = False
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=layers_per_block + 1,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ prev_output_channel=prev_output_channel,
+ temb_channels=time_embed_dim,
+ add_upsample=add_upsample,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=reversed_num_attention_heads[i],
+ dual_cross_attention=False,
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # out
+ if norm_num_groups is not None:
+ self.conv_norm_out = nn.GroupNorm(
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
+ )
+ self.conv_act = nn.SiLU()
+ else:
+ self.conv_norm_out = None
+ self.conv_act = None
+
+ conv_out_padding = (conv_out_kernel - 1) // 2
+ self.conv_out = nn.Conv2d(
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
+ )
+
+ @property
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
+ indexed by their weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "set_processor"):
+ processors[f"{name}.processor"] = module.processor
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
+ def set_attention_slice(self, slice_size):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
+
+ Args:
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+ must be a multiple of `slice_size`.
+ """
+ sliceable_head_dims = []
+
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
+ if hasattr(module, "set_attention_slice"):
+ sliceable_head_dims.append(module.sliceable_head_dim)
+
+ for child in module.children():
+ fn_recursive_retrieve_sliceable_dims(child)
+
+ # retrieve number of attention layers
+ for module in self.children():
+ fn_recursive_retrieve_sliceable_dims(module)
+
+ num_sliceable_layers = len(sliceable_head_dims)
+
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
+ elif slice_size == "max":
+ # make smallest slice possible
+ slice_size = num_sliceable_layers * [1]
+
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+
+ if len(slice_size) != len(sliceable_head_dims):
+ raise ValueError(
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
+ )
+
+ for i in range(len(slice_size)):
+ size = slice_size[i]
+ dim = sliceable_head_dims[i]
+ if size is not None and size > dim:
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
+
+ # Recursively walk through all the children.
+ # Any children which exposes the set_attention_slice method
+ # gets the message
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+ if hasattr(module, "set_attention_slice"):
+ module.set_attention_slice(slice_size.pop())
+
+ for child in module.children():
+ fn_recursive_set_attention_slice(child, slice_size)
+
+ reversed_slice_size = list(reversed(slice_size))
+ for module in self.children():
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
+
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ def enable_forward_chunking(self, chunk_size=None, dim=0):
+ """
+ Enables [feed forward
+ chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers) on the feed-forward layers.
+
+ Parameters:
+ chunk_size (`int`, *optional*):
+ The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
+ over each tensor of dim=`dim`.
+ dim (`int`, *optional*, defaults to `0`):
+ The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
+ or dim=1 (sequence length).
+ """
+ if dim not in [0, 1]:
+ raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
+
+ # By default chunk size is 1
+ chunk_size = chunk_size or 1
+
+ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
+ if hasattr(module, "set_chunk_feed_forward"):
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
+
+ for child in module.children():
+ fn_recursive_feed_forward(child, chunk_size, dim)
+
+ for module in self.children():
+ fn_recursive_feed_forward(module, chunk_size, dim)
+
+ def disable_forward_chunking(self):
+ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
+ if hasattr(module, "set_chunk_feed_forward"):
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
+
+ for child in module.children():
+ fn_recursive_feed_forward(child, chunk_size, dim)
+
+ for module in self.children():
+ fn_recursive_feed_forward(module, None, 0)
+
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
+ def set_default_attn_processor(self):
+ """
+ Disables custom attention processors and sets the default attention implementation.
+ """
+ self.set_attn_processor(AttnProcessor())
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
+ module.gradient_checkpointing = value
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ class_labels: Optional[torch.Tensor] = None,
+ timestep_cond: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ ) -> Union[UNet3DConditionOutput, Tuple]:
+ r"""
+ The [`UNet3DConditionModel`] forward method.
+
+ Args:
+ sample (`torch.FloatTensor`):
+ The noisy input tensor with the following shape `(batch, channel, num_frames, height, width)`.
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
+ encoder_hidden_states (`torch.FloatTensor`):
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.unet_3d_condition.UNet3DConditionOutput`] instead of a plain
+ tuple.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
+
+ Returns:
+ [`~models.unet_3d_condition.UNet3DConditionOutput`] or `tuple`:
+ If `return_dict` is True, an [`~models.unet_3d_condition.UNet3DConditionOutput`] is returned, otherwise
+ a `tuple` is returned where the first element is the sample tensor.
+ """
+ # By default samples have to be at least a multiple of the overall upsampling factor.
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
+ # on the fly if necessary.
+ default_overall_up_factor = 2**self.num_upsamplers
+
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+ forward_upsample_size = False
+ upsample_size = None
+
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
+ logger.info("Forward upsample size to force interpolation output size.")
+ forward_upsample_size = True
+
+ # prepare attention_mask
+ if attention_mask is not None:
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ num_frames = sample.shape[2]
+ timesteps = timesteps.expand(sample.shape[0])
+
+ t_emb = self.time_proj(timesteps)
+
+ # timesteps does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=self.dtype)
+
+ emb = self.time_embedding(t_emb, timestep_cond)
+ emb = emb.repeat_interleave(repeats=num_frames, dim=0)
+ encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
+
+ # 2. pre-process
+ sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:])
+ sample = self.conv_in(sample)
+
+ sample = self.transformer_in(
+ sample,
+ num_frames=num_frames,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # 3. down
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ num_frames=num_frames,
+ cross_attention_kwargs=cross_attention_kwargs,
+ )
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames)
+
+ down_block_res_samples += res_samples
+
+ if down_block_additional_residuals is not None:
+ new_down_block_res_samples = ()
+
+ for down_block_res_sample, down_block_additional_residual in zip(
+ down_block_res_samples, down_block_additional_residuals
+ ):
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
+ new_down_block_res_samples += (down_block_res_sample,)
+
+ down_block_res_samples = new_down_block_res_samples
+
+ # 4. mid
+ if self.mid_block is not None:
+ sample = self.mid_block(
+ sample,
+ emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ num_frames=num_frames,
+ cross_attention_kwargs=cross_attention_kwargs,
+ )
+
+ if mid_block_additional_residual is not None:
+ sample = sample + mid_block_additional_residual
+
+ # 5. up
+ for i, upsample_block in enumerate(self.up_blocks):
+ is_final_block = i == len(self.up_blocks) - 1
+
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+ # if we have not reached the final block and need to forward the
+ # upsample size, we do it here
+ if not is_final_block and forward_upsample_size:
+ upsample_size = down_block_res_samples[-1].shape[2:]
+
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ encoder_hidden_states=encoder_hidden_states,
+ upsample_size=upsample_size,
+ attention_mask=attention_mask,
+ num_frames=num_frames,
+ cross_attention_kwargs=cross_attention_kwargs,
+ )
+ else:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ upsample_size=upsample_size,
+ num_frames=num_frames,
+ )
+
+ # 6. post-process
+ if self.conv_norm_out:
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+
+ sample = self.conv_out(sample)
+
+ # reshape to (batch, channel, num_frames, height, width)
+ sample = sample[None, :].reshape((-1, num_frames) + sample.shape[1:]).permute(0, 2, 1, 3, 4)
+
+ if not return_dict:
+ return (sample,)
+
+ return UNet3DConditionOutput(sample=sample)
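+
+
+def _example_denoise_step(unet: UNet3DConditionModel, latents, timestep, text_embeddings):
+ # Hedged usage sketch (illustrative only; the argument names are assumptions): `latents` is
+ # expected to be shaped (batch, channels, num_frames, height, width) and `text_embeddings`
+ # (batch, sequence_length, cross_attention_dim). The call returns the prediction with the
+ # same layout as `latents`.
+ return unet(latents, timestep, encoder_hidden_states=text_embeddings).sample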
diff --git a/diffusers/models/vae.py b/diffusers/models/vae.py
new file mode 100644
index 0000000000000000000000000000000000000000..edd516dd380aa6f5888174bbd5f3df86be187feb
--- /dev/null
+++ b/diffusers/models/vae.py
@@ -0,0 +1,441 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Optional
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from ..utils import BaseOutput, is_torch_version, randn_tensor
+from .attention_processor import SpatialNorm
+from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
+
+
+@dataclass
+class DecoderOutput(BaseOutput):
+ """
+ Output of decoding method.
+
+ Args:
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ The decoded output sample from the last layer of the model.
+ """
+
+ sample: torch.FloatTensor
+
+
+class Encoder(nn.Module):
+ def __init__(
+ self,
+ in_channels=3,
+ out_channels=3,
+ down_block_types=("DownEncoderBlock2D",),
+ block_out_channels=(64,),
+ layers_per_block=2,
+ norm_num_groups=32,
+ act_fn="silu",
+ double_z=True,
+ ):
+ super().__init__()
+ self.layers_per_block = layers_per_block
+
+ self.conv_in = torch.nn.Conv2d(
+ in_channels,
+ block_out_channels[0],
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ )
+
+ self.mid_block = None
+ self.down_blocks = nn.ModuleList([])
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=self.layers_per_block,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ add_downsample=not is_final_block,
+ resnet_eps=1e-6,
+ downsample_padding=0,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ attention_head_dim=output_channel,
+ temb_channels=None,
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ self.mid_block = UNetMidBlock2D(
+ in_channels=block_out_channels[-1],
+ resnet_eps=1e-6,
+ resnet_act_fn=act_fn,
+ output_scale_factor=1,
+ resnet_time_scale_shift="default",
+ attention_head_dim=block_out_channels[-1],
+ resnet_groups=norm_num_groups,
+ temb_channels=None,
+ )
+
+ # out
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
+ self.conv_act = nn.SiLU()
+
+ conv_out_channels = 2 * out_channels if double_z else out_channels
+ self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)
+
+ self.gradient_checkpointing = False
+
+ def forward(self, x):
+ sample = x
+ sample = self.conv_in(sample)
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ # down
+ if is_torch_version(">=", "1.11.0"):
+ for down_block in self.down_blocks:
+ sample = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(down_block), sample, use_reentrant=False
+ )
+ # middle
+ sample = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(self.mid_block), sample, use_reentrant=False
+ )
+ else:
+ for down_block in self.down_blocks:
+ sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)
+ # middle
+ sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
+
+ else:
+ # down
+ for down_block in self.down_blocks:
+ sample = down_block(sample)
+
+ # middle
+ sample = self.mid_block(sample)
+
+ # post-process
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ return sample
+
+
+class Decoder(nn.Module):
+ def __init__(
+ self,
+ in_channels=3,
+ out_channels=3,
+ up_block_types=("UpDecoderBlock2D",),
+ block_out_channels=(64,),
+ layers_per_block=2,
+ norm_num_groups=32,
+ act_fn="silu",
+ norm_type="group", # group, spatial
+ ):
+ super().__init__()
+ self.layers_per_block = layers_per_block
+
+ self.conv_in = nn.Conv2d(
+ in_channels,
+ block_out_channels[-1],
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ )
+
+ self.mid_block = None
+ self.up_blocks = nn.ModuleList([])
+
+ temb_channels = in_channels if norm_type == "spatial" else None
+
+ # mid
+ self.mid_block = UNetMidBlock2D(
+ in_channels=block_out_channels[-1],
+ resnet_eps=1e-6,
+ resnet_act_fn=act_fn,
+ output_scale_factor=1,
+ resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
+ attention_head_dim=block_out_channels[-1],
+ resnet_groups=norm_num_groups,
+ temb_channels=temb_channels,
+ )
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+
+ is_final_block = i == len(block_out_channels) - 1
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=self.layers_per_block + 1,
+ in_channels=prev_output_channel,
+ out_channels=output_channel,
+ prev_output_channel=None,
+ add_upsample=not is_final_block,
+ resnet_eps=1e-6,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ attention_head_dim=output_channel,
+ temb_channels=temb_channels,
+ resnet_time_scale_shift=norm_type,
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # out
+ if norm_type == "spatial":
+ self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
+ else:
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
+ self.conv_act = nn.SiLU()
+ self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
+
+ self.gradient_checkpointing = False
+
+ def forward(self, z, latent_embeds=None):
+ sample = z
+ sample = self.conv_in(sample)
+
+ upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ if is_torch_version(">=", "1.11.0"):
+ # middle
+ sample = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(self.mid_block), sample, latent_embeds, use_reentrant=False
+ )
+ sample = sample.to(upscale_dtype)
+
+ # up
+ for up_block in self.up_blocks:
+ sample = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(up_block), sample, latent_embeds, use_reentrant=False
+ )
+ else:
+ # middle
+ sample = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(self.mid_block), sample, latent_embeds
+ )
+ sample = sample.to(upscale_dtype)
+
+ # up
+ for up_block in self.up_blocks:
+ sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
+ else:
+ # middle
+ sample = self.mid_block(sample, latent_embeds)
+ sample = sample.to(upscale_dtype)
+
+ # up
+ for up_block in self.up_blocks:
+ sample = up_block(sample, latent_embeds)
+
+ # post-process
+ if latent_embeds is None:
+ sample = self.conv_norm_out(sample)
+ else:
+ sample = self.conv_norm_out(sample, latent_embeds)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ return sample
+
+
+class VectorQuantizer(nn.Module):
+ """
+ Improved version of the original VectorQuantizer that can be used as a drop-in replacement. Mostly avoids
+ costly matrix multiplications and allows for post-hoc remapping of indices.
+ """
+
+ # NOTE: due to a bug the beta term was applied to the wrong term. For
+ # backwards compatibility we use the buggy version by default, but you can
+ # specify legacy=False to fix it.
+ def __init__(
+ self, n_e, vq_embed_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True
+ ):
+ super().__init__()
+ self.n_e = n_e
+ self.vq_embed_dim = vq_embed_dim
+ self.beta = beta
+ self.legacy = legacy
+
+ self.embedding = nn.Embedding(self.n_e, self.vq_embed_dim)
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
+
+ self.remap = remap
+ if self.remap is not None:
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
+ self.re_embed = self.used.shape[0]
+ self.unknown_index = unknown_index # "random" or "extra" or integer
+ if self.unknown_index == "extra":
+ self.unknown_index = self.re_embed
+ self.re_embed = self.re_embed + 1
+ print(
+ f"Remapping {self.n_e} indices to {self.re_embed} indices. "
+ f"Using {self.unknown_index} for unknown indices."
+ )
+ else:
+ self.re_embed = n_e
+
+ self.sane_index_shape = sane_index_shape
+
+ def remap_to_used(self, inds):
+ ishape = inds.shape
+ assert len(ishape) > 1
+ inds = inds.reshape(ishape[0], -1)
+ used = self.used.to(inds)
+ match = (inds[:, :, None] == used[None, None, ...]).long()
+ new = match.argmax(-1)
+ unknown = match.sum(2) < 1
+ if self.unknown_index == "random":
+ new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
+ else:
+ new[unknown] = self.unknown_index
+ return new.reshape(ishape)
+
+ def unmap_to_all(self, inds):
+ ishape = inds.shape
+ assert len(ishape) > 1
+ inds = inds.reshape(ishape[0], -1)
+ used = self.used.to(inds)
+ if self.re_embed > self.used.shape[0]: # extra token
+ inds[inds >= self.used.shape[0]] = 0 # simply set to zero
+ back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
+ return back.reshape(ishape)
+
+ def forward(self, z):
+ # reshape z -> (batch, height, width, channel) and flatten
+ z = z.permute(0, 2, 3, 1).contiguous()
+ z_flattened = z.view(-1, self.vq_embed_dim)
+
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+ min_encoding_indices = torch.argmin(torch.cdist(z_flattened, self.embedding.weight), dim=1)
+
+ z_q = self.embedding(min_encoding_indices).view(z.shape)
+ perplexity = None
+ min_encodings = None
+
+ # compute loss for embedding
+ if not self.legacy:
+ loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
+ else:
+ loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
+
+ # preserve gradients
+ z_q = z + (z_q - z).detach()
+
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ if self.remap is not None:
+ min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis
+ min_encoding_indices = self.remap_to_used(min_encoding_indices)
+ min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
+
+ if self.sane_index_shape:
+ min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])
+
+ return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
+
+ def get_codebook_entry(self, indices, shape):
+ # shape specifying (batch, height, width, channel)
+ if self.remap is not None:
+ indices = indices.reshape(shape[0], -1) # add batch axis
+ indices = self.unmap_to_all(indices)
+ indices = indices.reshape(-1) # flatten again
+
+ # get quantized latent vectors
+ z_q = self.embedding(indices)
+
+ if shape is not None:
+ z_q = z_q.view(shape)
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ return z_q
+
+
+class DiagonalGaussianDistribution(object):
+ def __init__(self, parameters, deterministic=False):
+ self.parameters = parameters
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+ self.deterministic = deterministic
+ self.std = torch.exp(0.5 * self.logvar)
+ self.var = torch.exp(self.logvar)
+ if self.deterministic:
+ self.var = self.std = torch.zeros_like(
+ self.mean, device=self.parameters.device, dtype=self.parameters.dtype
+ )
+
+ def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
+ # make sure sample is on the same device as the parameters and has same dtype
+ sample = randn_tensor(
+ self.mean.shape, generator=generator, device=self.parameters.device, dtype=self.parameters.dtype
+ )
+ x = self.mean + self.std * sample
+ return x
+
+ def kl(self, other=None):
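+ # closed-form KL divergence between diagonal Gaussians, summed over the C, H, W dims;
+ # with `other=None` this reduces to KL(N(mean, var) || N(0, I))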
+ if self.deterministic:
+ return torch.Tensor([0.0])
+ else:
+ if other is None:
+ return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])
+ else:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean - other.mean, 2) / other.var
+ + self.var / other.var
+ - 1.0
+ - self.logvar
+ + other.logvar,
+ dim=[1, 2, 3],
+ )
+
+ def nll(self, sample, dims=[1, 2, 3]):
+ if self.deterministic:
+ return torch.Tensor([0.0])
+ logtwopi = np.log(2.0 * np.pi)
+ return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)
+
+ def mode(self):
+ return self.mean
diff --git a/diffusers/models/vae_flax.py b/diffusers/models/vae_flax.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8f5b1d0e399ab8e58d81d396d19b6f082192f5a
--- /dev/null
+++ b/diffusers/models/vae_flax.py
@@ -0,0 +1,869 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# JAX implementation of VQGAN from taming-transformers https://github.com/CompVis/taming-transformers
+
+import math
+from functools import partial
+from typing import Tuple
+
+import flax
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+from flax.core.frozen_dict import FrozenDict
+
+from ..configuration_utils import ConfigMixin, flax_register_to_config
+from ..utils import BaseOutput
+from .modeling_flax_utils import FlaxModelMixin
+
+
+@flax.struct.dataclass
+class FlaxDecoderOutput(BaseOutput):
+ """
+ Output of decoding method.
+
+ Args:
+ sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):
+ The decoded output sample from the last layer of the model.
+ dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
+ The `dtype` of the parameters.
+ """
+
+ sample: jnp.ndarray
+
+
+@flax.struct.dataclass
+class FlaxAutoencoderKLOutput(BaseOutput):
+ """
+ Output of AutoencoderKL encoding method.
+
+ Args:
+ latent_dist (`FlaxDiagonalGaussianDistribution`):
+ Encoded outputs of `Encoder` represented as the mean and logvar of `FlaxDiagonalGaussianDistribution`.
+ `FlaxDiagonalGaussianDistribution` allows for sampling latents from the distribution.
+ """
+
+ latent_dist: "FlaxDiagonalGaussianDistribution"
+
+
+class FlaxUpsample2D(nn.Module):
+ """
+ Flax implementation of 2D Upsample layer
+
+ Args:
+ in_channels (`int`):
+ Input channels
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+ Parameters `dtype`
+ """
+
+ in_channels: int
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ self.conv = nn.Conv(
+ self.in_channels,
+ kernel_size=(3, 3),
+ strides=(1, 1),
+ padding=((1, 1), (1, 1)),
+ dtype=self.dtype,
+ )
+
+ def __call__(self, hidden_states):
+ batch, height, width, channels = hidden_states.shape
+ hidden_states = jax.image.resize(
+ hidden_states,
+ shape=(batch, height * 2, width * 2, channels),
+ method="nearest",
+ )
+ hidden_states = self.conv(hidden_states)
+ return hidden_states
+
+
+class FlaxDownsample2D(nn.Module):
+ """
+ Flax implementation of 2D Downsample layer
+
+ Args:
+ in_channels (`int`):
+ Input channels
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+ Parameters `dtype`
+ """
+
+ in_channels: int
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ self.conv = nn.Conv(
+ self.in_channels,
+ kernel_size=(3, 3),
+ strides=(2, 2),
+ padding="VALID",
+ dtype=self.dtype,
+ )
+
+ def __call__(self, hidden_states):
+ pad = ((0, 0), (0, 1), (0, 1), (0, 0)) # pad height and width dim
+ hidden_states = jnp.pad(hidden_states, pad_width=pad)
+ hidden_states = self.conv(hidden_states)
+ return hidden_states
+
+
+class FlaxResnetBlock2D(nn.Module):
+ """
+ Flax implementation of 2D Resnet Block.
+
+ Args:
+ in_channels (`int`):
+ Input channels
+ out_channels (`int`):
+ Output channels
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
+ Dropout rate
+ groups (:obj:`int`, *optional*, defaults to `32`):
+ The number of groups to use for group norm.
+ use_nin_shortcut (:obj:`bool`, *optional*, defaults to `None`):
+ Whether to use `nin_shortcut`. This activates a new layer inside the ResNet block.
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+ Parameters `dtype`
+ """
+
+ in_channels: int
+ out_channels: int = None
+ dropout: float = 0.0
+ groups: int = 32
+ use_nin_shortcut: bool = None
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ out_channels = self.in_channels if self.out_channels is None else self.out_channels
+
+ self.norm1 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6)
+ self.conv1 = nn.Conv(
+ out_channels,
+ kernel_size=(3, 3),
+ strides=(1, 1),
+ padding=((1, 1), (1, 1)),
+ dtype=self.dtype,
+ )
+
+ self.norm2 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6)
+ self.dropout_layer = nn.Dropout(self.dropout)
+ self.conv2 = nn.Conv(
+ out_channels,
+ kernel_size=(3, 3),
+ strides=(1, 1),
+ padding=((1, 1), (1, 1)),
+ dtype=self.dtype,
+ )
+
+ use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut
+
+ self.conv_shortcut = None
+ if use_nin_shortcut:
+ self.conv_shortcut = nn.Conv(
+ out_channels,
+ kernel_size=(1, 1),
+ strides=(1, 1),
+ padding="VALID",
+ dtype=self.dtype,
+ )
+
+ def __call__(self, hidden_states, deterministic=True):
+ residual = hidden_states
+ hidden_states = self.norm1(hidden_states)
+ hidden_states = nn.swish(hidden_states)
+ hidden_states = self.conv1(hidden_states)
+
+ hidden_states = self.norm2(hidden_states)
+ hidden_states = nn.swish(hidden_states)
+ hidden_states = self.dropout_layer(hidden_states, deterministic)
+ hidden_states = self.conv2(hidden_states)
+
+ if self.conv_shortcut is not None:
+ residual = self.conv_shortcut(residual)
+
+ return hidden_states + residual
+
+
+class FlaxAttentionBlock(nn.Module):
+ r"""
+ Flax Convolutional based multi-head attention block for diffusion-based VAE.
+
+ Parameters:
+ channels (:obj:`int`):
+ Input channels
+ num_head_channels (:obj:`int`, *optional*, defaults to `None`):
+ Number of attention heads
+ num_groups (:obj:`int`, *optional*, defaults to `32`):
+ The number of groups to use for group norm
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+ Parameters `dtype`
+
+ """
+ channels: int
+ num_head_channels: int = None
+ num_groups: int = 32
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ self.num_heads = self.channels // self.num_head_channels if self.num_head_channels is not None else 1
+
+ dense = partial(nn.Dense, self.channels, dtype=self.dtype)
+
+ self.group_norm = nn.GroupNorm(num_groups=self.num_groups, epsilon=1e-6)
+ self.query, self.key, self.value = dense(), dense(), dense()
+ self.proj_attn = dense()
+
+ def transpose_for_scores(self, projection):
+ new_projection_shape = projection.shape[:-1] + (self.num_heads, -1)
+ # move heads to 2nd position (B, T, H * D) -> (B, T, H, D)
+ new_projection = projection.reshape(new_projection_shape)
+ # (B, T, H, D) -> (B, H, T, D)
+ new_projection = jnp.transpose(new_projection, (0, 2, 1, 3))
+ return new_projection
+
+ def __call__(self, hidden_states):
+ residual = hidden_states
+ batch, height, width, channels = hidden_states.shape
+
+ hidden_states = self.group_norm(hidden_states)
+
+ hidden_states = hidden_states.reshape((batch, height * width, channels))
+
+ query = self.query(hidden_states)
+ key = self.key(hidden_states)
+ value = self.value(hidden_states)
+
+ # transpose
+ query = self.transpose_for_scores(query)
+ key = self.transpose_for_scores(key)
+ value = self.transpose_for_scores(value)
+
+ # compute attentions
+ scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
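+ # scaling both query and key by d_head**-0.25 is equivalent to the usual 1/sqrt(d_head)
+ # factor on the attention logits, but keeps the intermediate values smaller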
+ attn_weights = jnp.einsum("...qc,...kc->...qk", query * scale, key * scale)
+ attn_weights = nn.softmax(attn_weights, axis=-1)
+
+ # attend to values
+ hidden_states = jnp.einsum("...kc,...qk->...qc", value, attn_weights)
+
+ hidden_states = jnp.transpose(hidden_states, (0, 2, 1, 3))
+ new_hidden_states_shape = hidden_states.shape[:-2] + (self.channels,)
+ hidden_states = hidden_states.reshape(new_hidden_states_shape)
+
+ hidden_states = self.proj_attn(hidden_states)
+ hidden_states = hidden_states.reshape((batch, height, width, channels))
+ hidden_states = hidden_states + residual
+ return hidden_states
+
+
+class FlaxDownEncoderBlock2D(nn.Module):
+ r"""
+ Flax Resnet blocks-based Encoder block for diffusion-based VAE.
+
+ Parameters:
+ in_channels (:obj:`int`):
+ Input channels
+ out_channels (:obj:`int`):
+ Output channels
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
+ Dropout rate
+ num_layers (:obj:`int`, *optional*, defaults to 1):
+ Number of Resnet layer block
+ resnet_groups (:obj:`int`, *optional*, defaults to `32`):
+ The number of groups to use for the Resnet block group norm
+ add_downsample (:obj:`bool`, *optional*, defaults to `True`):
+ Whether to add downsample layer
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+ Parameters `dtype`
+ """
+ in_channels: int
+ out_channels: int
+ dropout: float = 0.0
+ num_layers: int = 1
+ resnet_groups: int = 32
+ add_downsample: bool = True
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ resnets = []
+ for i in range(self.num_layers):
+ in_channels = self.in_channels if i == 0 else self.out_channels
+
+ res_block = FlaxResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=self.out_channels,
+ dropout=self.dropout,
+ groups=self.resnet_groups,
+ dtype=self.dtype,
+ )
+ resnets.append(res_block)
+ self.resnets = resnets
+
+ if self.add_downsample:
+ self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
+
+ def __call__(self, hidden_states, deterministic=True):
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states, deterministic=deterministic)
+
+ if self.add_downsample:
+ hidden_states = self.downsamplers_0(hidden_states)
+
+ return hidden_states
+
+
+class FlaxUpDecoderBlock2D(nn.Module):
+ r"""
+ Flax Resnet blocks-based Decoder block for diffusion-based VAE.
+
+ Parameters:
+ in_channels (:obj:`int`):
+ Input channels
+ out_channels (:obj:`int`):
+ Output channels
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
+ Dropout rate
+ num_layers (:obj:`int`, *optional*, defaults to 1):
+ Number of Resnet layer block
+ resnet_groups (:obj:`int`, *optional*, defaults to `32`):
+ The number of groups to use for the Resnet block group norm
+ add_upsample (:obj:`bool`, *optional*, defaults to `True`):
+ Whether to add upsample layer
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+ Parameters `dtype`
+ """
+ in_channels: int
+ out_channels: int
+ dropout: float = 0.0
+ num_layers: int = 1
+ resnet_groups: int = 32
+ add_upsample: bool = True
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ resnets = []
+ for i in range(self.num_layers):
+ in_channels = self.in_channels if i == 0 else self.out_channels
+ res_block = FlaxResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=self.out_channels,
+ dropout=self.dropout,
+ groups=self.resnet_groups,
+ dtype=self.dtype,
+ )
+ resnets.append(res_block)
+
+ self.resnets = resnets
+
+ if self.add_upsample:
+ self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
+
+ def __call__(self, hidden_states, deterministic=True):
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states, deterministic=deterministic)
+
+ if self.add_upsample:
+ hidden_states = self.upsamplers_0(hidden_states)
+
+ return hidden_states
+
+
+class FlaxUNetMidBlock2D(nn.Module):
+ r"""
+ Flax Unet Mid-Block module.
+
+ Parameters:
+ in_channels (:obj:`int`):
+ Input channels
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
+ Dropout rate
+ num_layers (:obj:`int`, *optional*, defaults to 1):
+ Number of Resnet layer block
+ resnet_groups (:obj:`int`, *optional*, defaults to `32`):
+ The number of groups to use for the Resnet and Attention block group norm
+ num_attention_heads (:obj:`int`, *optional*, defaults to `1`):
+ Number of attention heads for each attention block
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+ Parameters `dtype`
+ """
+ in_channels: int
+ dropout: float = 0.0
+ num_layers: int = 1
+ resnet_groups: int = 32
+ num_attention_heads: int = 1
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ resnet_groups = self.resnet_groups if self.resnet_groups is not None else min(self.in_channels // 4, 32)
+
+ # there is always at least one resnet
+ resnets = [
+ FlaxResnetBlock2D(
+ in_channels=self.in_channels,
+ out_channels=self.in_channels,
+ dropout=self.dropout,
+ groups=resnet_groups,
+ dtype=self.dtype,
+ )
+ ]
+
+ attentions = []
+
+ for _ in range(self.num_layers):
+ attn_block = FlaxAttentionBlock(
+ channels=self.in_channels,
+ num_head_channels=self.num_attention_heads,
+ num_groups=resnet_groups,
+ dtype=self.dtype,
+ )
+ attentions.append(attn_block)
+
+ res_block = FlaxResnetBlock2D(
+ in_channels=self.in_channels,
+ out_channels=self.in_channels,
+ dropout=self.dropout,
+ groups=resnet_groups,
+ dtype=self.dtype,
+ )
+ resnets.append(res_block)
+
+ self.resnets = resnets
+ self.attentions = attentions
+
+ def __call__(self, hidden_states, deterministic=True):
+ hidden_states = self.resnets[0](hidden_states, deterministic=deterministic)
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
+ hidden_states = attn(hidden_states)
+ hidden_states = resnet(hidden_states, deterministic=deterministic)
+
+ return hidden_states
+
+
+class FlaxEncoder(nn.Module):
+ r"""
+ Flax Implementation of VAE Encoder.
+
+ This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
+ subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to
+ general usage and behavior.
+
+ Finally, this model supports inherent JAX features such as:
+ - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
+ - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
+ - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
+ - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
+
+ Parameters:
+ in_channels (:obj:`int`, *optional*, defaults to 3):
+ Input channels
+ out_channels (:obj:`int`, *optional*, defaults to 3):
+ Output channels
+ down_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(DownEncoderBlock2D)`):
+ DownEncoder block type
+ block_out_channels (:obj:`Tuple[int]`, *optional*, defaults to `(64,)`):
+ Tuple containing the number of output channels for each block
+ layers_per_block (:obj:`int`, *optional*, defaults to `2`):
+ Number of Resnet layer for each block
+ norm_num_groups (:obj:`int`, *optional*, defaults to `32`):
+ The number of groups to use for normalization.
+ act_fn (:obj:`str`, *optional*, defaults to `silu`):
+ Activation function
+ double_z (:obj:`bool`, *optional*, defaults to `False`):
+ Whether to double the last output channels
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+ Parameters `dtype`
+ """
+ in_channels: int = 3
+ out_channels: int = 3
+ down_block_types: Tuple[str] = ("DownEncoderBlock2D",)
+ block_out_channels: Tuple[int] = (64,)
+ layers_per_block: int = 2
+ norm_num_groups: int = 32
+ act_fn: str = "silu"
+ double_z: bool = False
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ block_out_channels = self.block_out_channels
+ # in
+ self.conv_in = nn.Conv(
+ block_out_channels[0],
+ kernel_size=(3, 3),
+ strides=(1, 1),
+ padding=((1, 1), (1, 1)),
+ dtype=self.dtype,
+ )
+
+ # downsampling
+ down_blocks = []
+ output_channel = block_out_channels[0]
+ for i, _ in enumerate(self.down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = FlaxDownEncoderBlock2D(
+ in_channels=input_channel,
+ out_channels=output_channel,
+ num_layers=self.layers_per_block,
+ resnet_groups=self.norm_num_groups,
+ add_downsample=not is_final_block,
+ dtype=self.dtype,
+ )
+ down_blocks.append(down_block)
+ self.down_blocks = down_blocks
+
+ # middle
+ self.mid_block = FlaxUNetMidBlock2D(
+ in_channels=block_out_channels[-1],
+ resnet_groups=self.norm_num_groups,
+ num_attention_heads=None,
+ dtype=self.dtype,
+ )
+
+ # end
+ conv_out_channels = 2 * self.out_channels if self.double_z else self.out_channels
+ self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6)
+ self.conv_out = nn.Conv(
+ conv_out_channels,
+ kernel_size=(3, 3),
+ strides=(1, 1),
+ padding=((1, 1), (1, 1)),
+ dtype=self.dtype,
+ )
+
+ def __call__(self, sample, deterministic: bool = True):
+ # in
+ sample = self.conv_in(sample)
+
+ # downsampling
+ for block in self.down_blocks:
+ sample = block(sample, deterministic=deterministic)
+
+ # middle
+ sample = self.mid_block(sample, deterministic=deterministic)
+
+ # end
+ sample = self.conv_norm_out(sample)
+ sample = nn.swish(sample)
+ sample = self.conv_out(sample)
+
+ return sample
+
+
+class FlaxDecoder(nn.Module):
+ r"""
+ Flax Implementation of VAE Decoder.
+
+ This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
+ subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to
+ general usage and behavior.
+
+ Finally, this model supports inherent JAX features such as:
+ - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
+ - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
+ - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
+ - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
+
+ Parameters:
+ in_channels (:obj:`int`, *optional*, defaults to 3):
+ Input channels
+ out_channels (:obj:`int`, *optional*, defaults to 3):
+ Output channels
+ up_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(UpDecoderBlock2D)`):
+ UpDecoder block type
+ block_out_channels (:obj:`Tuple[int]`, *optional*, defaults to `(64,)`):
+ Tuple containing the number of output channels for each block
+ layers_per_block (:obj:`int`, *optional*, defaults to `2`):
+ Number of Resnet layer for each block
+ norm_num_groups (:obj:`int`, *optional*, defaults to `32`):
+ The number of groups to use for normalization.
+ act_fn (:obj:`str`, *optional*, defaults to `silu`):
+ Activation function
+ double_z (:obj:`bool`, *optional*, defaults to `False`):
+ Whether to double the last output channels
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+ parameters `dtype`
+ """
+ in_channels: int = 3
+ out_channels: int = 3
+ up_block_types: Tuple[str] = ("UpDecoderBlock2D",)
+ block_out_channels: Tuple[int] = (64,)
+ layers_per_block: int = 2
+ norm_num_groups: int = 32
+ act_fn: str = "silu"
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ block_out_channels = self.block_out_channels
+
+ # z to block_in
+ self.conv_in = nn.Conv(
+ block_out_channels[-1],
+ kernel_size=(3, 3),
+ strides=(1, 1),
+ padding=((1, 1), (1, 1)),
+ dtype=self.dtype,
+ )
+
+ # middle
+ self.mid_block = FlaxUNetMidBlock2D(
+ in_channels=block_out_channels[-1],
+ resnet_groups=self.norm_num_groups,
+ num_attention_heads=None,
+ dtype=self.dtype,
+ )
+
+ # upsampling
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ output_channel = reversed_block_out_channels[0]
+ up_blocks = []
+ for i, _ in enumerate(self.up_block_types):
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+
+ is_final_block = i == len(block_out_channels) - 1
+
+ up_block = FlaxUpDecoderBlock2D(
+ in_channels=prev_output_channel,
+ out_channels=output_channel,
+ num_layers=self.layers_per_block + 1,
+ resnet_groups=self.norm_num_groups,
+ add_upsample=not is_final_block,
+ dtype=self.dtype,
+ )
+ up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ self.up_blocks = up_blocks
+
+ # end
+ self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6)
+ self.conv_out = nn.Conv(
+ self.out_channels,
+ kernel_size=(3, 3),
+ strides=(1, 1),
+ padding=((1, 1), (1, 1)),
+ dtype=self.dtype,
+ )
+
+ def __call__(self, sample, deterministic: bool = True):
+ # z to block_in
+ sample = self.conv_in(sample)
+
+ # middle
+ sample = self.mid_block(sample, deterministic=deterministic)
+
+ # upsampling
+ for block in self.up_blocks:
+ sample = block(sample, deterministic=deterministic)
+
+ sample = self.conv_norm_out(sample)
+ sample = nn.swish(sample)
+ sample = self.conv_out(sample)
+
+ return sample
+
+
+class FlaxDiagonalGaussianDistribution(object):
+ def __init__(self, parameters, deterministic=False):
+ # Last axis to account for channels-last
+ self.mean, self.logvar = jnp.split(parameters, 2, axis=-1)
+ self.logvar = jnp.clip(self.logvar, -30.0, 20.0)
+ self.deterministic = deterministic
+ self.std = jnp.exp(0.5 * self.logvar)
+ self.var = jnp.exp(self.logvar)
+ if self.deterministic:
+ self.var = self.std = jnp.zeros_like(self.mean)
+
+ def sample(self, key):
+ return self.mean + self.std * jax.random.normal(key, self.mean.shape)
+
+ def kl(self, other=None):
+ if self.deterministic:
+ return jnp.array([0.0])
+
+ if other is None:
+ return 0.5 * jnp.sum(self.mean**2 + self.var - 1.0 - self.logvar, axis=[1, 2, 3])
+
+ return 0.5 * jnp.sum(
+ jnp.square(self.mean - other.mean) / other.var + self.var / other.var - 1.0 - self.logvar + other.logvar,
+ axis=[1, 2, 3],
+ )
+
+ def nll(self, sample, axis=[1, 2, 3]):
+ if self.deterministic:
+ return jnp.array([0.0])
+
+ logtwopi = jnp.log(2.0 * jnp.pi)
+ return 0.5 * jnp.sum(logtwopi + self.logvar + jnp.square(sample - self.mean) / self.var, axis=axis)
+
+ def mode(self):
+ return self.mean
+
+
+@flax_register_to_config
+class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
+ r"""
+ Flax implementation of a VAE model with KL loss for decoding latent representations.
+
+ This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for its generic methods
+ implemented for all models (such as downloading or saving).
+
+ This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
+ subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matter related to its
+ general usage and behavior.
+
+ Inherent JAX features such as the following are supported:
+
+ - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
+ - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
+ - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
+ - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
+
+ Parameters:
+ in_channels (`int`, *optional*, defaults to 3):
+ Number of channels in the input image.
+ out_channels (`int`, *optional*, defaults to 3):
+ Number of channels in the output.
+ down_block_types (`Tuple[str]`, *optional*, defaults to `(DownEncoderBlock2D)`):
+ Tuple of downsample block types.
+ up_block_types (`Tuple[str]`, *optional*, defaults to `(UpDecoderBlock2D)`):
+ Tuple of upsample block types.
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
+ Tuple of block output channels.
+ layers_per_block (`int`, *optional*, defaults to `2`):
+ Number of ResNet layer for each block.
+ act_fn (`str`, *optional*, defaults to `silu`):
+ The activation function to use.
+ latent_channels (`int`, *optional*, defaults to `4`):
+ Number of channels in the latent space.
+ norm_num_groups (`int`, *optional*, defaults to `32`):
+ The number of groups for normalization.
+ sample_size (`int`, *optional*, defaults to 32):
+ Sample input size.
+ scaling_factor (`float`, *optional*, defaults to 0.18215):
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
+ Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
+ dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
+ The `dtype` of the parameters.
+ """
+ in_channels: int = 3
+ out_channels: int = 3
+ down_block_types: Tuple[str] = ("DownEncoderBlock2D",)
+ up_block_types: Tuple[str] = ("UpDecoderBlock2D",)
+ block_out_channels: Tuple[int] = (64,)
+ layers_per_block: int = 1
+ act_fn: str = "silu"
+ latent_channels: int = 4
+ norm_num_groups: int = 32
+ sample_size: int = 32
+ scaling_factor: float = 0.18215
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ self.encoder = FlaxEncoder(
+ in_channels=self.config.in_channels,
+ out_channels=self.config.latent_channels,
+ down_block_types=self.config.down_block_types,
+ block_out_channels=self.config.block_out_channels,
+ layers_per_block=self.config.layers_per_block,
+ act_fn=self.config.act_fn,
+ norm_num_groups=self.config.norm_num_groups,
+ double_z=True,
+ dtype=self.dtype,
+ )
+ self.decoder = FlaxDecoder(
+ in_channels=self.config.latent_channels,
+ out_channels=self.config.out_channels,
+ up_block_types=self.config.up_block_types,
+ block_out_channels=self.config.block_out_channels,
+ layers_per_block=self.config.layers_per_block,
+ norm_num_groups=self.config.norm_num_groups,
+ act_fn=self.config.act_fn,
+ dtype=self.dtype,
+ )
+ self.quant_conv = nn.Conv(
+ 2 * self.config.latent_channels,
+ kernel_size=(1, 1),
+ strides=(1, 1),
+ padding="VALID",
+ dtype=self.dtype,
+ )
+ self.post_quant_conv = nn.Conv(
+ self.config.latent_channels,
+ kernel_size=(1, 1),
+ strides=(1, 1),
+ padding="VALID",
+ dtype=self.dtype,
+ )
+
+ def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
+ # init input tensors
+ sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
+ sample = jnp.zeros(sample_shape, dtype=jnp.float32)
+
+ params_rng, dropout_rng, gaussian_rng = jax.random.split(rng, 3)
+ rngs = {"params": params_rng, "dropout": dropout_rng, "gaussian": gaussian_rng}
+
+ return self.init(rngs, sample)["params"]
+
+ def encode(self, sample, deterministic: bool = True, return_dict: bool = True):
+ sample = jnp.transpose(sample, (0, 2, 3, 1))
+
+ hidden_states = self.encoder(sample, deterministic=deterministic)
+ moments = self.quant_conv(hidden_states)
+ posterior = FlaxDiagonalGaussianDistribution(moments)
+
+ if not return_dict:
+ return (posterior,)
+
+ return FlaxAutoencoderKLOutput(latent_dist=posterior)
+
+ def decode(self, latents, deterministic: bool = True, return_dict: bool = True):
+ if latents.shape[-1] != self.config.latent_channels:
+ latents = jnp.transpose(latents, (0, 2, 3, 1))
+
+ hidden_states = self.post_quant_conv(latents)
+ hidden_states = self.decoder(hidden_states, deterministic=deterministic)
+
+ hidden_states = jnp.transpose(hidden_states, (0, 3, 1, 2))
+
+ if not return_dict:
+ return (hidden_states,)
+
+ return FlaxDecoderOutput(sample=hidden_states)
+
+ def __call__(self, sample, sample_posterior=False, deterministic: bool = True, return_dict: bool = True):
+ # always request dict outputs internally so the `.latent_dist` / `.sample` accesses below are valid
+ posterior = self.encode(sample, deterministic=deterministic, return_dict=True)
+ if sample_posterior:
+ rng = self.make_rng("gaussian")
+ hidden_states = posterior.latent_dist.sample(rng)
+ else:
+ hidden_states = posterior.latent_dist.mode()
+
+ sample = self.decode(hidden_states, return_dict=True).sample
+
+ if not return_dict:
+ return (sample,)
+
+ return FlaxDecoderOutput(sample=sample)
diff --git a/diffusers/models/vq_model.py b/diffusers/models/vq_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..687449e8c7557473c0af994b30ef4c7dfba9718c
--- /dev/null
+++ b/diffusers/models/vq_model.py
@@ -0,0 +1,167 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import BaseOutput, apply_forward_hook
+from .modeling_utils import ModelMixin
+from .vae import Decoder, DecoderOutput, Encoder, VectorQuantizer
+
+
+@dataclass
+class VQEncoderOutput(BaseOutput):
+ """
+ Output of VQModel encoding method.
+
+ Args:
+ latents (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ The encoded output sample from the last layer of the model.
+ """
+
+ latents: torch.FloatTensor
+
+
+class VQModel(ModelMixin, ConfigMixin):
+ r"""
+ A VQ-VAE model for decoding latent representations.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
+ for all models (such as downloading or saving).
+
+ Parameters:
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
+ Tuple of downsample block types.
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
+ Tuple of upsample block types.
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
+ Tuple of block output channels.
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+ latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space.
+ sample_size (`int`, *optional*, defaults to `32`): Sample input size.
+ num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE.
+ vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE.
+ scaling_factor (`float`, *optional*, defaults to `0.18215`):
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
+ Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 3,
+ down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
+ up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
+ block_out_channels: Tuple[int] = (64,),
+ layers_per_block: int = 1,
+ act_fn: str = "silu",
+ latent_channels: int = 3,
+ sample_size: int = 32,
+ num_vq_embeddings: int = 256,
+ norm_num_groups: int = 32,
+ vq_embed_dim: Optional[int] = None,
+ scaling_factor: float = 0.18215,
+ norm_type: str = "group", # group, spatial
+ ):
+ super().__init__()
+
+ # pass init params to Encoder
+ self.encoder = Encoder(
+ in_channels=in_channels,
+ out_channels=latent_channels,
+ down_block_types=down_block_types,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ act_fn=act_fn,
+ norm_num_groups=norm_num_groups,
+ double_z=False,
+ )
+
+ vq_embed_dim = vq_embed_dim if vq_embed_dim is not None else latent_channels
+
+ self.quant_conv = nn.Conv2d(latent_channels, vq_embed_dim, 1)
+ self.quantize = VectorQuantizer(num_vq_embeddings, vq_embed_dim, beta=0.25, remap=None, sane_index_shape=False)
+ self.post_quant_conv = nn.Conv2d(vq_embed_dim, latent_channels, 1)
+
+ # pass init params to Decoder
+ self.decoder = Decoder(
+ in_channels=latent_channels,
+ out_channels=out_channels,
+ up_block_types=up_block_types,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ act_fn=act_fn,
+ norm_num_groups=norm_num_groups,
+ norm_type=norm_type,
+ )
+
+ @apply_forward_hook
+ def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput:
+ h = self.encoder(x)
+ h = self.quant_conv(h)
+
+ if not return_dict:
+ return (h,)
+
+ return VQEncoderOutput(latents=h)
+
+ @apply_forward_hook
+ def decode(
+ self, h: torch.FloatTensor, force_not_quantize: bool = False, return_dict: bool = True
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
+ # also go through quantization layer
+ if not force_not_quantize:
+ quant, emb_loss, info = self.quantize(h)
+ else:
+ quant = h
+ quant2 = self.post_quant_conv(quant)
+ dec = self.decoder(quant2, quant if self.config.norm_type == "spatial" else None)
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
+
+ def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+ r"""
+ The [`VQModel`] forward method.
+
+ Args:
+ sample (`torch.FloatTensor`): Input sample.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`models.vae.DecoderOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.vae.DecoderOutput`] or `tuple`:
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple`
+ is returned.
+ """
+ x = sample
+ h = self.encode(x).latents
+ dec = self.decode(h).sample
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
diff --git a/diffusers/optimization.py b/diffusers/optimization.py
new file mode 100644
index 0000000000000000000000000000000000000000..46e6125a0f5565b80ced30dfc147f8168ef35a5c
--- /dev/null
+++ b/diffusers/optimization.py
@@ -0,0 +1,354 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch optimization for diffusion models."""
+
+import math
+from enum import Enum
+from typing import Optional, Union
+
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import LambdaLR
+
+from .utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class SchedulerType(Enum):
+ LINEAR = "linear"
+ COSINE = "cosine"
+ COSINE_WITH_RESTARTS = "cosine_with_restarts"
+ POLYNOMIAL = "polynomial"
+ CONSTANT = "constant"
+ CONSTANT_WITH_WARMUP = "constant_with_warmup"
+ PIECEWISE_CONSTANT = "piecewise_constant"
+
+
+def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
+ """
+ Create a schedule with a constant learning rate, using the learning rate set in optimizer.
+
+ Args:
+ optimizer ([`~torch.optim.Optimizer`]):
+ The optimizer for which to schedule the learning rate.
+ last_epoch (`int`, *optional*, defaults to -1):
+ The index of the last epoch when resuming training.
+
+ Return:
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+ """
+ return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch)
+
+
+def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1):
+ """
+ Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate
+ increases linearly between 0 and the initial lr set in the optimizer.
+
+ Args:
+ optimizer ([`~torch.optim.Optimizer`]):
+ The optimizer for which to schedule the learning rate.
+ num_warmup_steps (`int`):
+ The number of steps for the warmup phase.
+ last_epoch (`int`, *optional*, defaults to -1):
+ The index of the last epoch when resuming training.
+
+ Return:
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+ """
+
+ def lr_lambda(current_step: int):
+ if current_step < num_warmup_steps:
+ return float(current_step) / float(max(1.0, num_warmup_steps))
+ return 1.0
+
+ return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
+
+
+def get_piecewise_constant_schedule(optimizer: Optimizer, step_rules: str, last_epoch: int = -1):
+ """
+ Create a schedule with a learning rate that follows the piecewise constant multipliers defined by
+ `step_rules`, applied to the initial learning rate set in the optimizer.
+
+ Args:
+ optimizer ([`~torch.optim.Optimizer`]):
+ The optimizer for which to schedule the learning rate.
+ step_rules (`string`):
+ The rules that define the learning rate multiplier. For example, `step_rules="1:10,0.1:20,0.01:30,0.005"`
+ means the learning rate is multiplied by 1 for steps 0-9, by 0.1 for steps 10-19, by 0.01 for steps 20-29,
+ and by 0.005 for every step from 30 onwards.
+ last_epoch (`int`, *optional*, defaults to -1):
+ The index of the last epoch when resuming training.
+
+ Return:
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+ """
+
+ rules_dict = {}
+ rule_list = step_rules.split(",")
+ for rule_str in rule_list[:-1]:
+ value_str, steps_str = rule_str.split(":")
+ steps = int(steps_str)
+ value = float(value_str)
+ rules_dict[steps] = value
+ last_lr_multiple = float(rule_list[-1])
+
+ def create_rules_function(rules_dict, last_lr_multiple):
+ def rule_func(steps: int) -> float:
+ sorted_steps = sorted(rules_dict.keys())
+ for i, sorted_step in enumerate(sorted_steps):
+ if steps < sorted_step:
+ return rules_dict[sorted_steps[i]]
+ return last_lr_multiple
+
+ return rule_func
+
+ rules_func = create_rules_function(rules_dict, last_lr_multiple)
+
+ return LambdaLR(optimizer, rules_func, last_epoch=last_epoch)
+
+
+def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
+ """
+ Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
+ a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
+
+ Args:
+ optimizer ([`~torch.optim.Optimizer`]):
+ The optimizer for which to schedule the learning rate.
+ num_warmup_steps (`int`):
+ The number of steps for the warmup phase.
+ num_training_steps (`int`):
+ The total number of training steps.
+ last_epoch (`int`, *optional*, defaults to -1):
+ The index of the last epoch when resuming training.
+
+ Return:
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+ """
+
+ def lr_lambda(current_step: int):
+ if current_step < num_warmup_steps:
+ return float(current_step) / float(max(1, num_warmup_steps))
+ return max(
+ 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
+ )
+
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
+def get_cosine_schedule_with_warmup(
+ optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1
+):
+ """
+ Create a schedule with a learning rate that decreases following the values of the cosine function between the
+ initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
+ initial lr set in the optimizer.
+
+ Args:
+ optimizer ([`~torch.optim.Optimizer`]):
+ The optimizer for which to schedule the learning rate.
+ num_warmup_steps (`int`):
+ The number of steps for the warmup phase.
+ num_training_steps (`int`):
+ The total number of training steps.
+ num_cycles (`float`, *optional*, defaults to 0.5):
+ The number of periods of the cosine function in a schedule (the default is to just decrease from the max
+ value to 0 following a half-cosine).
+ last_epoch (`int`, *optional*, defaults to -1):
+ The index of the last epoch when resuming training.
+
+ Return:
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+ """
+
+ def lr_lambda(current_step):
+ if current_step < num_warmup_steps:
+ return float(current_step) / float(max(1, num_warmup_steps))
+ progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
+ return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
+
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
+def get_cosine_with_hard_restarts_schedule_with_warmup(
+ optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1
+):
+ """
+ Create a schedule with a learning rate that decreases following the values of the cosine function between the
+ initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases
+ linearly between 0 and the initial lr set in the optimizer.
+
+ Args:
+ optimizer ([`~torch.optim.Optimizer`]):
+ The optimizer for which to schedule the learning rate.
+ num_warmup_steps (`int`):
+ The number of steps for the warmup phase.
+ num_training_steps (`int`):
+ The total number of training steps.
+ num_cycles (`int`, *optional*, defaults to 1):
+ The number of hard restarts to use.
+ last_epoch (`int`, *optional*, defaults to -1):
+ The index of the last epoch when resuming training.
+
+ Return:
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+ """
+
+ def lr_lambda(current_step):
+ if current_step < num_warmup_steps:
+ return float(current_step) / float(max(1, num_warmup_steps))
+ progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
+ if progress >= 1.0:
+ return 0.0
+ return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0))))
+
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
+def get_polynomial_decay_schedule_with_warmup(
+ optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1
+):
+ """
+ Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the
+ optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the
+ initial lr set in the optimizer.
+
+ Args:
+ optimizer ([`~torch.optim.Optimizer`]):
+ The optimizer for which to schedule the learning rate.
+ num_warmup_steps (`int`):
+ The number of steps for the warmup phase.
+ num_training_steps (`int`):
+ The total number of training steps.
+ lr_end (`float`, *optional*, defaults to 1e-7):
+ The end LR.
+ power (`float`, *optional*, defaults to 1.0):
+ Power factor.
+ last_epoch (`int`, *optional*, defaults to -1):
+ The index of the last epoch when resuming training.
+
+ Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT
+ implementation at
+ https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37
+
+ Return:
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+
+ """
+
+ lr_init = optimizer.defaults["lr"]
+ if not (lr_init > lr_end):
+ raise ValueError(f"lr_end ({lr_end}) must be be smaller than initial lr ({lr_init})")
+
+ def lr_lambda(current_step: int):
+ if current_step < num_warmup_steps:
+ return float(current_step) / float(max(1, num_warmup_steps))
+ elif current_step > num_training_steps:
+ return lr_end / lr_init # as LambdaLR multiplies by lr_init
+ else:
+ lr_range = lr_init - lr_end
+ decay_steps = num_training_steps - num_warmup_steps
+ pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps
+ decay = lr_range * pct_remaining**power + lr_end
+ return decay / lr_init # as LambdaLR multiplies by lr_init
+
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
+TYPE_TO_SCHEDULER_FUNCTION = {
+ SchedulerType.LINEAR: get_linear_schedule_with_warmup,
+ SchedulerType.COSINE: get_cosine_schedule_with_warmup,
+ SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup,
+ SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup,
+ SchedulerType.CONSTANT: get_constant_schedule,
+ SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
+ SchedulerType.PIECEWISE_CONSTANT: get_piecewise_constant_schedule,
+}
+
+
+def get_scheduler(
+ name: Union[str, SchedulerType],
+ optimizer: Optimizer,
+ step_rules: Optional[str] = None,
+ num_warmup_steps: Optional[int] = None,
+ num_training_steps: Optional[int] = None,
+ num_cycles: int = 1,
+ power: float = 1.0,
+ last_epoch: int = -1,
+):
+ """
+ Unified API to get any scheduler from its name.
+
+ Args:
+ name (`str` or `SchedulerType`):
+ The name of the scheduler to use.
+ optimizer (`torch.optim.Optimizer`):
+ The optimizer that will be used during training.
+ step_rules (`str`, *optional*):
+ A string representing the step rules to use. This is only used by the `PIECEWISE_CONSTANT` scheduler.
+ num_warmup_steps (`int`, *optional*):
+ The number of warmup steps to do. This is not required by all schedulers (hence the argument being
+ optional), the function will raise an error if it's unset and the scheduler type requires it.
+ num_training_steps (`int`, *optional*):
+ The number of training steps to do. This is not required by all schedulers (hence the argument being
+ optional), the function will raise an error if it's unset and the scheduler type requires it.
+ num_cycles (`int`, *optional*):
+ The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
+ power (`float`, *optional*, defaults to 1.0):
+ Power factor. See the `POLYNOMIAL` scheduler.
+ last_epoch (`int`, *optional*, defaults to -1):
+ The index of the last epoch when resuming training.
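+
+ Example (an illustrative sketch, assuming an existing `torch.optim` optimizer named `optimizer`):
+
+ ```py
+ >>> from diffusers.optimization import get_scheduler
+ >>> lr_scheduler = get_scheduler(
+ ... "cosine", optimizer=optimizer, num_warmup_steps=500, num_training_steps=10_000
+ ... )
+ >>> lr_scheduler.step() # call once after each optimizer step during training
+ ```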
+ """
+ name = SchedulerType(name)
+ schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
+ if name == SchedulerType.CONSTANT:
+ return schedule_func(optimizer, last_epoch=last_epoch)
+
+ if name == SchedulerType.PIECEWISE_CONSTANT:
+ return schedule_func(optimizer, step_rules=step_rules, last_epoch=last_epoch)
+
+ # All other schedulers require `num_warmup_steps`
+ if num_warmup_steps is None:
+ raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
+
+ if name == SchedulerType.CONSTANT_WITH_WARMUP:
+ return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, last_epoch=last_epoch)
+
+ # All other schedulers require `num_training_steps`
+ if num_training_steps is None:
+ raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
+
+ if name == SchedulerType.COSINE_WITH_RESTARTS:
+ return schedule_func(
+ optimizer,
+ num_warmup_steps=num_warmup_steps,
+ num_training_steps=num_training_steps,
+ num_cycles=num_cycles,
+ last_epoch=last_epoch,
+ )
+
+ if name == SchedulerType.POLYNOMIAL:
+ return schedule_func(
+ optimizer,
+ num_warmup_steps=num_warmup_steps,
+ num_training_steps=num_training_steps,
+ power=power,
+ last_epoch=last_epoch,
+ )
+
+ return schedule_func(
+ optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, last_epoch=last_epoch
+ )
diff --git a/diffusers/pipeline_utils.py b/diffusers/pipeline_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..87709d5f616cdfb195ed4527e4b630a86136c29c
--- /dev/null
+++ b/diffusers/pipeline_utils.py
@@ -0,0 +1,29 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# NOTE: This file is deprecated and will be removed in a future version.
+# It only exists so that `from diffusers.pipeline_utils import DiffusionPipeline` temporarily keeps working
+
+from .pipelines import DiffusionPipeline, ImagePipelineOutput # noqa: F401
+from .utils import deprecate
+
+
+deprecate(
+ "pipelines_utils",
+ "0.22.0",
+ "Importing `DiffusionPipeline` or `ImagePipelineOutput` from diffusers.pipeline_utils is deprecated. Please import from diffusers.pipelines.pipeline_utils instead.",
+ standard_warn=False,
+ stacklevel=3,
+)
diff --git a/diffusers/pipelines/README.md b/diffusers/pipelines/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..7562040596e9028ed56431817f42f4379ecf3435
--- /dev/null
+++ b/diffusers/pipelines/README.md
@@ -0,0 +1,171 @@
+# 🧨 Diffusers Pipelines
+
+Pipelines provide a simple way to run state-of-the-art diffusion models in inference.
+Most diffusion systems consist of multiple independently-trained models and highly adaptable scheduler
+components - all of which are needed to have a functioning end-to-end diffusion system.
+
+As an example, [Stable Diffusion](https://huggingface.co/blog/stable_diffusion) has three independently trained models:
+- [Autoencoder](https://github.com/huggingface/diffusers/blob/5cbed8e0d157f65d3ddc2420dfd09f2df630e978/src/diffusers/models/vae.py#L392)
+- [Conditional Unet](https://github.com/huggingface/diffusers/blob/5cbed8e0d157f65d3ddc2420dfd09f2df630e978/src/diffusers/models/unet_2d_condition.py#L12)
+- [CLIP text encoder](https://huggingface.co/docs/transformers/main/en/model_doc/clip#transformers.CLIPTextModel)
+- a scheduler component, [scheduler](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_pndm.py),
+- a [CLIPImageProcessor](https://huggingface.co/docs/transformers/main/en/model_doc/clip#transformers.CLIPImageProcessor),
+- as well as a [safety checker](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py).
+All of these components are necessary to run stable diffusion in inference even though they were trained
+or created independently from each other.
+
+To that end, we strive to offer all open-sourced, state-of-the-art diffusion systems under a unified API.
+More specifically, we strive to provide pipelines that
+- 1. can load the officially published weights and yield 1-to-1 the same outputs as the original implementation according to the corresponding paper (*e.g.* [LDMTextToImagePipeline](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/latent_diffusion), uses the officially released weights of [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752)),
+- 2. have a simple user interface to run the model in inference (see the [Pipelines API](#pipelines-api) section),
+- 3. are easy to understand with code that is self-explanatory and can be read along-side the official paper (see [Pipelines summary](#pipelines-summary)),
+- 4. can easily be contributed by the community (see the [Contribution](#contribution) section).
+
+**Note** that pipelines do not (and should not) offer any training functionality.
+If you are looking for *official* training examples, please have a look at [examples](https://github.com/huggingface/diffusers/tree/main/examples).
+
+
+## Pipelines Summary
+
+The following table summarizes all officially supported pipelines, their corresponding paper, and, if
+available, a colab notebook to try them out directly.
+
+| Pipeline | Source | Tasks | Colab
+|-------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------|:---:|:---:|
+| [dance diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/dance_diffusion) | [**Dance Diffusion**](https://github.com/Harmonai-org/sample-generator) | *Unconditional Audio Generation* |
+| [ddpm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | *Unconditional Image Generation* |
+| [ddim](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | *Unconditional Image Generation* | [Open In Colab](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
+| [latent_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | *Text-to-Image Generation* |
+| [latent_diffusion_uncond](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | *Unconditional Image Generation* |
+| [pndm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | *Unconditional Image Generation* |
+| [score_sde_ve](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | *Unconditional Image Generation* |
+| [score_sde_vp](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | *Unconditional Image Generation* |
+| [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Text-to-Image Generation* | [Open In Colab](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_diffusion.ipynb)
+| [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Image-to-Image Text-Guided Generation* | [Open In Colab](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
+| [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Text-Guided Image Inpainting* | [Open In Colab](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
+| [stochastic_karras_ve](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | *Unconditional Image Generation* |
+
+**Note**: Pipelines are simple examples of how to play around with the diffusion systems as described in the corresponding papers.
+However, most of them can be adapted to use different scheduler components or even different model components. Some pipeline examples are shown in the [Examples](#examples) below.
+
+## Pipelines API
+
+Diffusion models often consist of multiple independently-trained models or other previously existing components.
+
+
+Each model has been trained independently on a different task, and the scheduler can easily be swapped out and replaced with a different one.
+However, during inference we want to be able to easily load all components and use them together, even if one component, *e.g.* CLIP's text encoder, originates from a different library, such as [Transformers](https://github.com/huggingface/transformers). To that end, all pipelines provide the following functionality:
+
+- [`from_pretrained` method](https://github.com/huggingface/diffusers/blob/5cbed8e0d157f65d3ddc2420dfd09f2df630e978/src/diffusers/pipeline_utils.py#L139) that accepts a Hugging Face Hub repository id, *e.g.* [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) or a path to a local directory, *e.g.*
+"./stable-diffusion". To correctly retrieve which models and components should be loaded, one has to provide a `model_index.json` file, *e.g.* [runwayml/stable-diffusion-v1-5/model_index.json](https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/model_index.json), which defines all components that should be
+loaded into the pipelines. More specifically, for each model/component one needs to define the format `<name>: ["<library>", "<class name>"]`. `<name>` is the attribute name given to the loaded instance of `<class name>`, which can be found in the library or pipeline folder called `"<library>"`.
+- [`save_pretrained`](https://github.com/huggingface/diffusers/blob/5cbed8e0d157f65d3ddc2420dfd09f2df630e978/src/diffusers/pipeline_utils.py#L90) that accepts a local path, *e.g.* `./stable-diffusion` under which all models/components of the pipeline will be saved. For each component/model a folder is created inside the local path that is named after the given attribute name, *e.g.* `./stable_diffusion/unet`.
+In addition, a `model_index.json` file is created at the root of the local path, *e.g.* `./stable_diffusion/model_index.json` so that the complete pipeline can again be instantiated
+from the local path.
+- [`to`](https://github.com/huggingface/diffusers/blob/5cbed8e0d157f65d3ddc2420dfd09f2df630e978/src/diffusers/pipeline_utils.py#L118) which accepts a `string` or `torch.device` to move all models that are of type `torch.nn.Module` to the passed device. The behavior is fully analogous to [PyTorch's `to` method](https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.to).
+- [`__call__`] method to use the pipeline in inference. `__call__` defines the inference logic of the pipeline and should ideally encompass all aspects of it, from pre-processing to forwarding tensors to the different models and schedulers, as well as post-processing. The API of the `__call__` method can vary strongly from pipeline to pipeline. *E.g.* a text-to-image pipeline, such as [`StableDiffusionPipeline`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py), should accept, among other things, the text prompt to generate the image. A pure image generation pipeline, such as [DDPMPipeline](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/ddpm), on the other hand, can be run without providing any inputs. To better understand what inputs can be adapted for
+each pipeline, one should look directly into the respective pipeline. A minimal usage sketch of these methods is shown below.
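+
+Putting these pieces together, a minimal sketch of the API looks as follows (the repository id and the local save path are only examples):
+
+```python
+from diffusers import DiffusionPipeline
+
+# `from_pretrained` downloads and instantiates every component listed in the repository's model_index.json
+pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
+
+# `to` moves all components of type torch.nn.Module to the given device
+pipe = pipe.to("cuda")
+
+# `__call__` runs inference; the accepted arguments depend on the concrete pipeline
+image = pipe("a photo of an astronaut riding a horse on mars").images[0]
+
+# `save_pretrained` writes each component into its own sub-folder plus a model_index.json
+pipe.save_pretrained("./stable-diffusion-local")
+```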
+
+**Note**: All pipelines have PyTorch's autograd disabled by decorating the `__call__` method with a [`torch.no_grad`](https://pytorch.org/docs/stable/generated/torch.no_grad.html) decorator because pipelines should
+not be used for training. If you want to store the gradients during the forward pass, we recommend writing your own pipeline; see also our [community-examples](https://github.com/huggingface/diffusers/tree/main/examples/community).
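+
+As a rough, hypothetical sketch of such a custom pipeline (not an official implementation), one can follow the structure of `DDPMPipeline` but simply omit the `torch.no_grad` decorator on `__call__`:
+
+```python
+import torch
+
+from diffusers import DiffusionPipeline
+
+
+class GradDDPMPipeline(DiffusionPipeline):
+    """Unconditional DDPM-style pipeline that keeps autograd enabled (illustrative only)."""
+
+    def __init__(self, unet, scheduler):
+        super().__init__()
+        # register the components so that save_pretrained / to / device handling work as usual
+        self.register_modules(unet=unet, scheduler=scheduler)
+
+    def __call__(self, batch_size=1, num_inference_steps=50):
+        sample_size = self.unet.config.sample_size
+        image = torch.randn(
+            (batch_size, self.unet.config.in_channels, sample_size, sample_size),
+            device=self.device,
+        )
+        self.scheduler.set_timesteps(num_inference_steps)
+        # unrolled diffusion loop; gradients are kept because there is no @torch.no_grad()
+        for t in self.scheduler.timesteps:
+            noise_pred = self.unet(image, t).sample
+            image = self.scheduler.step(noise_pred, t, image).prev_sample
+        return image
+```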
+
+## Contribution
+
+We are more than happy to accept any contribution to the officially supported pipelines 🤗. We aspire
+for all of our pipelines to be **self-contained**, **easy-to-tweak**, **beginner-friendly** and **one-purpose-only**.
+
+- **Self-contained**: A pipeline shall be as self-contained as possible. More specifically, this means that all functionality should either be defined directly in the pipeline file itself, be inherited from (and only from) the [`DiffusionPipeline` class](https://github.com/huggingface/diffusers/blob/5cbed8e0d157f65d3ddc2420dfd09f2df630e978/src/diffusers/pipeline_utils.py#L56), or be directly attached to the model and scheduler components of the pipeline.
+- **Easy-to-use**: Pipelines should be extremely easy to use - one should be able to load the pipeline and
+use it for its designated task, *e.g.* text-to-image generation, in just a couple of lines of code. Most
+logic including pre-processing, an unrolled diffusion loop, and post-processing should all happen inside the `__call__` method.
+- **Easy-to-tweak**: Certain pipelines will not be able to handle all use cases and tasks that you might like them to. If you want to use a certain pipeline for a specific use case that is not yet supported, you might have to copy the pipeline file and tweak the code to your needs. We try to make the pipeline code as readable as possible so that each part –from pre-processing to diffusing to post-processing– can easily be adapted. If you would like the community to benefit from your customized pipeline, we would love to see a contribution to our [community-examples](https://github.com/huggingface/diffusers/tree/main/examples/community). If you feel that an important pipeline should be part of the official pipelines but isn't, a contribution to the [official pipelines](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines) would be even better.
+- **One-purpose-only**: Pipelines should be used for one task and one task only. Even if two tasks are very similar from a modeling point of view, *e.g.* image2image translation and in-painting, pipelines shall be used for one task only to keep them *easy-to-tweak* and *readable*.
+
+## Examples
+
+### Text-to-Image generation with Stable Diffusion
+
+```python
+# make sure you're logged in with `huggingface-cli login`
+from diffusers import StableDiffusionPipeline
+
+pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
+pipe = pipe.to("cuda")
+
+prompt = "a photo of an astronaut riding a horse on mars"
+image = pipe(prompt).images[0]
+
+image.save("astronaut_rides_horse.png")
+```
+
+### Image-to-Image text-guided generation with Stable Diffusion
+
+The `StableDiffusionImg2ImgPipeline` lets you pass a text prompt and an initial image to condition the generation of new images.
+
+```python
+import torch
+import requests
+from PIL import Image
+from io import BytesIO
+
+from diffusers import StableDiffusionImg2ImgPipeline
+
+# load the pipeline
+device = "cuda"
+pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5",
+ torch_dtype=torch.float16,
+).to(device)
+
+# let's download an initial image
+url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
+
+response = requests.get(url)
+init_image = Image.open(BytesIO(response.content)).convert("RGB")
+init_image = init_image.resize((768, 512))
+
+prompt = "A fantasy landscape, trending on artstation"
+
+images = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images
+
+images[0].save("fantasy_landscape.png")
+```
+You can also run this example on Colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
+
+### Tweak prompts reusing seeds and latents
+
+You can generate your own latents to reproduce results, or tweak your prompt on a specific result you liked. [This notebook](https://github.com/pcuenca/diffusers-examples/blob/main/notebooks/stable-diffusion-seeds.ipynb) shows how to do it step by step. You can also run it in Google Colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pcuenca/diffusers-examples/blob/main/notebooks/stable-diffusion-seeds.ipynb).
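+
+As a rough sketch of the idea (not the notebook's exact code), you can either pass a seeded `torch.Generator` or pre-generate the `latents` yourself and reuse them while tweaking the prompt:
+
+```python
+import torch
+
+from diffusers import StableDiffusionPipeline
+
+pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to("cuda")
+prompt = "a photo of an astronaut riding a horse on mars"
+
+# a seeded generator reproduces the same image for the same prompt
+generator = torch.Generator(device="cuda").manual_seed(1024)
+image = pipe(prompt, generator=generator).images[0]
+
+# alternatively, pre-generate the latents once and reuse them while tweaking the prompt
+height, width = 512, 512
+latents = torch.randn(
+    (1, pipe.unet.config.in_channels, height // pipe.vae_scale_factor, width // pipe.vae_scale_factor),
+    generator=torch.Generator(device="cuda").manual_seed(1024),
+    device="cuda",
+)
+image = pipe(prompt + ", oil painting", latents=latents).images[0]
+```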
+
+
+### In-painting using Stable Diffusion
+
+The `StableDiffusionInpaintPipeline` lets you edit specific parts of an image by providing a mask and text prompt.
+
+```python
+import PIL
+import requests
+import torch
+from io import BytesIO
+
+from diffusers import StableDiffusionInpaintPipeline
+
+def download_image(url):
+ response = requests.get(url)
+ return PIL.Image.open(BytesIO(response.content)).convert("RGB")
+
+img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+
+init_image = download_image(img_url).resize((512, 512))
+mask_image = download_image(mask_url).resize((512, 512))
+
+pipe = StableDiffusionInpaintPipeline.from_pretrained(
+ "runwayml/stable-diffusion-inpainting",
+ torch_dtype=torch.float16,
+)
+pipe = pipe.to("cuda")
+
+prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
+image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
+```
+
+You can also run this example on Colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
diff --git a/diffusers/pipelines/__init__.py b/diffusers/pipelines/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..802ae4f5bc9490aa247812d65a4dc17606747fea
--- /dev/null
+++ b/diffusers/pipelines/__init__.py
@@ -0,0 +1,188 @@
+from ..utils import (
+ OptionalDependencyNotAvailable,
+ is_flax_available,
+ is_invisible_watermark_available,
+ is_k_diffusion_available,
+ is_librosa_available,
+ is_note_seq_available,
+ is_onnx_available,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ..utils.dummy_pt_objects import * # noqa F403
+else:
+ from .consistency_models import ConsistencyModelPipeline
+ from .dance_diffusion import DanceDiffusionPipeline
+ from .ddim import DDIMPipeline
+ from .ddpm import DDPMPipeline
+ from .dit import DiTPipeline
+ from .latent_diffusion import LDMSuperResolutionPipeline
+ from .latent_diffusion_uncond import LDMPipeline
+ from .pipeline_utils import AudioPipelineOutput, DiffusionPipeline, ImagePipelineOutput
+ from .pndm import PNDMPipeline
+ from .repaint import RePaintPipeline
+ from .score_sde_ve import ScoreSdeVePipeline
+ from .stochastic_karras_ve import KarrasVePipeline
+
+try:
+ if not (is_torch_available() and is_librosa_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ..utils.dummy_torch_and_librosa_objects import * # noqa F403
+else:
+ from .audio_diffusion import AudioDiffusionPipeline, Mel
+
+try:
+ if not (is_torch_available() and is_transformers_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ..utils.dummy_torch_and_transformers_objects import * # noqa F403
+else:
+ from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline
+ from .audioldm import AudioLDMPipeline
+ from .controlnet import (
+ StableDiffusionControlNetImg2ImgPipeline,
+ StableDiffusionControlNetInpaintPipeline,
+ StableDiffusionControlNetPipeline,
+ )
+ from .deepfloyd_if import (
+ IFImg2ImgPipeline,
+ IFImg2ImgSuperResolutionPipeline,
+ IFInpaintingPipeline,
+ IFInpaintingSuperResolutionPipeline,
+ IFPipeline,
+ IFSuperResolutionPipeline,
+ )
+ from .kandinsky import (
+ KandinskyImg2ImgPipeline,
+ KandinskyInpaintPipeline,
+ KandinskyPipeline,
+ KandinskyPriorPipeline,
+ )
+ from .kandinsky2_2 import (
+ KandinskyV22ControlnetImg2ImgPipeline,
+ KandinskyV22ControlnetPipeline,
+ KandinskyV22Img2ImgPipeline,
+ KandinskyV22InpaintPipeline,
+ KandinskyV22Pipeline,
+ KandinskyV22PriorEmb2EmbPipeline,
+ KandinskyV22PriorPipeline,
+ )
+ from .latent_diffusion import LDMTextToImagePipeline
+ from .paint_by_example import PaintByExamplePipeline
+ from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
+ from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
+ from .stable_diffusion import (
+ CycleDiffusionPipeline,
+ StableDiffusionAttendAndExcitePipeline,
+ StableDiffusionDepth2ImgPipeline,
+ StableDiffusionDiffEditPipeline,
+ StableDiffusionImageVariationPipeline,
+ StableDiffusionImg2ImgPipeline,
+ StableDiffusionInpaintPipeline,
+ StableDiffusionInpaintPipelineLegacy,
+ StableDiffusionInstructPix2PixPipeline,
+ StableDiffusionLatentUpscalePipeline,
+ StableDiffusionLDM3DPipeline,
+ StableDiffusionModelEditingPipeline,
+ StableDiffusionPanoramaPipeline,
+ StableDiffusionParadigmsPipeline,
+ StableDiffusionPipeline,
+ StableDiffusionPix2PixZeroPipeline,
+ StableDiffusionSAGPipeline,
+ StableDiffusionUpscalePipeline,
+ StableUnCLIPImg2ImgPipeline,
+ StableUnCLIPPipeline,
+ )
+ from .stable_diffusion_safe import StableDiffusionPipelineSafe
+ from .t2i_adapter import StableDiffusionAdapterPipeline
+ from .text_to_video_synthesis import TextToVideoSDPipeline, TextToVideoZeroPipeline, VideoToVideoSDPipeline
+ from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline
+ from .unidiffuser import ImageTextPipelineOutput, UniDiffuserModel, UniDiffuserPipeline, UniDiffuserTextDecoder
+ from .versatile_diffusion import (
+ VersatileDiffusionDualGuidedPipeline,
+ VersatileDiffusionImageVariationPipeline,
+ VersatileDiffusionPipeline,
+ VersatileDiffusionTextToImagePipeline,
+ )
+ from .vq_diffusion import VQDiffusionPipeline
+
+
+try:
+ if not (is_torch_available() and is_transformers_available() and is_invisible_watermark_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ..utils.dummy_torch_and_transformers_and_invisible_watermark_objects import * # noqa F403
+else:
+ from .controlnet import StableDiffusionXLControlNetPipeline
+ from .stable_diffusion_xl import (
+ StableDiffusionXLImg2ImgPipeline,
+ StableDiffusionXLInpaintPipeline,
+ StableDiffusionXLPipeline,
+ )
+
+try:
+ if not is_onnx_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ..utils.dummy_onnx_objects import * # noqa F403
+else:
+ from .onnx_utils import OnnxRuntimeModel
+
+try:
+ if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ..utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403
+else:
+ from .stable_diffusion import (
+ OnnxStableDiffusionImg2ImgPipeline,
+ OnnxStableDiffusionInpaintPipeline,
+ OnnxStableDiffusionInpaintPipelineLegacy,
+ OnnxStableDiffusionPipeline,
+ OnnxStableDiffusionUpscalePipeline,
+ StableDiffusionOnnxPipeline,
+ )
+
+try:
+ if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ..utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403
+else:
+ from .stable_diffusion import StableDiffusionKDiffusionPipeline
+
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ..utils.dummy_flax_objects import * # noqa F403
+else:
+ from .pipeline_flax_utils import FlaxDiffusionPipeline
+
+
+try:
+ if not (is_flax_available() and is_transformers_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ..utils.dummy_flax_and_transformers_objects import * # noqa F403
+else:
+ from .controlnet import FlaxStableDiffusionControlNetPipeline
+ from .stable_diffusion import (
+ FlaxStableDiffusionImg2ImgPipeline,
+ FlaxStableDiffusionInpaintPipeline,
+ FlaxStableDiffusionPipeline,
+ )
+try:
+ if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ..utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403
+else:
+ from .spectrogram_diffusion import MidiProcessor, SpectrogramDiffusionPipeline
diff --git a/diffusers/pipelines/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c33b6696d01f832459cffd982a75a8e8018cc064
Binary files /dev/null and b/diffusers/pipelines/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/pipelines/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ceb2e953d66aeb81ecf4ec660b54414ba00852e0
Binary files /dev/null and b/diffusers/pipelines/__pycache__/__init__.cpython-38.pyc differ
diff --git a/diffusers/pipelines/__pycache__/pipeline_utils.cpython-310.pyc b/diffusers/pipelines/__pycache__/pipeline_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..193100f53deab1189c140af56c2a982a0ac1d4b9
Binary files /dev/null and b/diffusers/pipelines/__pycache__/pipeline_utils.cpython-310.pyc differ
diff --git a/diffusers/pipelines/__pycache__/pipeline_utils.cpython-38.pyc b/diffusers/pipelines/__pycache__/pipeline_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..46c4aeee1e5f0cf7109de40874e96de6b455eb7f
Binary files /dev/null and b/diffusers/pipelines/__pycache__/pipeline_utils.cpython-38.pyc differ
diff --git a/diffusers/pipelines/alt_diffusion/__init__.py b/diffusers/pipelines/alt_diffusion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dab2d8db1045ef27ff5d2234951c1488f547401b
--- /dev/null
+++ b/diffusers/pipelines/alt_diffusion/__init__.py
@@ -0,0 +1,33 @@
+from dataclasses import dataclass
+from typing import List, Optional, Union
+
+import numpy as np
+import PIL
+from PIL import Image
+
+from ...utils import BaseOutput, is_torch_available, is_transformers_available
+
+
+@dataclass
+# Copied from diffusers.pipelines.stable_diffusion.__init__.StableDiffusionPipelineOutput with Stable->Alt
+class AltDiffusionPipelineOutput(BaseOutput):
+ """
+ Output class for Alt Diffusion pipelines.
+
+ Args:
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+ num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
+ nsfw_content_detected (`List[bool]`)
+ List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, or `None` if safety checking could not be performed.
+ """
+
+ images: Union[List[PIL.Image.Image], np.ndarray]
+ nsfw_content_detected: Optional[List[bool]]
+
+
+if is_transformers_available() and is_torch_available():
+ from .modeling_roberta_series import RobertaSeriesModelWithTransformation
+ from .pipeline_alt_diffusion import AltDiffusionPipeline
+ from .pipeline_alt_diffusion_img2img import AltDiffusionImg2ImgPipeline
diff --git a/diffusers/pipelines/alt_diffusion/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/alt_diffusion/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..43e61607e3f39a7092ce2a1a3ab49b2b9f96aad7
Binary files /dev/null and b/diffusers/pipelines/alt_diffusion/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/pipelines/alt_diffusion/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/alt_diffusion/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b74bf23dae3832663a8a3b35377b539e45879acc
Binary files /dev/null and b/diffusers/pipelines/alt_diffusion/__pycache__/__init__.cpython-38.pyc differ
diff --git a/diffusers/pipelines/alt_diffusion/__pycache__/modeling_roberta_series.cpython-310.pyc b/diffusers/pipelines/alt_diffusion/__pycache__/modeling_roberta_series.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b9b5ba958c7db1c78799ecd0cc671ea60255bcbe
Binary files /dev/null and b/diffusers/pipelines/alt_diffusion/__pycache__/modeling_roberta_series.cpython-310.pyc differ
diff --git a/diffusers/pipelines/alt_diffusion/__pycache__/modeling_roberta_series.cpython-38.pyc b/diffusers/pipelines/alt_diffusion/__pycache__/modeling_roberta_series.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2d2ef8a8ef2ece127bad886d836132a35d459127
Binary files /dev/null and b/diffusers/pipelines/alt_diffusion/__pycache__/modeling_roberta_series.cpython-38.pyc differ
diff --git a/diffusers/pipelines/alt_diffusion/__pycache__/pipeline_alt_diffusion.cpython-310.pyc b/diffusers/pipelines/alt_diffusion/__pycache__/pipeline_alt_diffusion.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b23733118c8bbd0b8b46100cf1ab1bb847266cd2
Binary files /dev/null and b/diffusers/pipelines/alt_diffusion/__pycache__/pipeline_alt_diffusion.cpython-310.pyc differ
diff --git a/diffusers/pipelines/alt_diffusion/__pycache__/pipeline_alt_diffusion.cpython-38.pyc b/diffusers/pipelines/alt_diffusion/__pycache__/pipeline_alt_diffusion.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..55f78c2677e7e6ce3a99850e3d3942ba8f16fb60
Binary files /dev/null and b/diffusers/pipelines/alt_diffusion/__pycache__/pipeline_alt_diffusion.cpython-38.pyc differ
diff --git a/diffusers/pipelines/alt_diffusion/__pycache__/pipeline_alt_diffusion_img2img.cpython-310.pyc b/diffusers/pipelines/alt_diffusion/__pycache__/pipeline_alt_diffusion_img2img.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c0e33f8242771561ac1c510f427890eceaf433cc
Binary files /dev/null and b/diffusers/pipelines/alt_diffusion/__pycache__/pipeline_alt_diffusion_img2img.cpython-310.pyc differ
diff --git a/diffusers/pipelines/alt_diffusion/__pycache__/pipeline_alt_diffusion_img2img.cpython-38.pyc b/diffusers/pipelines/alt_diffusion/__pycache__/pipeline_alt_diffusion_img2img.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..65c7c463e041b7d207e30e37555e5007570342de
Binary files /dev/null and b/diffusers/pipelines/alt_diffusion/__pycache__/pipeline_alt_diffusion_img2img.cpython-38.pyc differ
diff --git a/diffusers/pipelines/alt_diffusion/modeling_roberta_series.py b/diffusers/pipelines/alt_diffusion/modeling_roberta_series.py
new file mode 100644
index 0000000000000000000000000000000000000000..f73ef15d7de7948a9cbad246027ca71f4a6db198
--- /dev/null
+++ b/diffusers/pipelines/alt_diffusion/modeling_roberta_series.py
@@ -0,0 +1,124 @@
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import RobertaPreTrainedModel, XLMRobertaConfig, XLMRobertaModel
+from transformers.utils import ModelOutput
+
+
+@dataclass
+class TransformationModelOutput(ModelOutput):
+ """
+ Base class for text model's outputs that also contains a pooling of the last hidden states.
+
+ Args:
+ text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
+ The text embeddings obtained by applying the projection layer to the pooler_output.
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ projection_state: Optional[torch.FloatTensor] = None
+ last_hidden_state: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+class RobertaSeriesConfig(XLMRobertaConfig):
+ def __init__(
+ self,
+ pad_token_id=1,
+ bos_token_id=0,
+ eos_token_id=2,
+ project_dim=512,
+ pooler_fn="cls",
+ learn_encoder=False,
+ use_attention_mask=True,
+ **kwargs,
+ ):
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+ self.project_dim = project_dim
+ self.pooler_fn = pooler_fn
+ self.learn_encoder = learn_encoder
+ self.use_attention_mask = use_attention_mask
+
+
+class RobertaSeriesModelWithTransformation(RobertaPreTrainedModel):
+ _keys_to_ignore_on_load_unexpected = [r"pooler", r"logit_scale"]
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+ base_model_prefix = "roberta"
+ config_class = RobertaSeriesConfig
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.roberta = XLMRobertaModel(config)
+ self.transformation = nn.Linear(config.hidden_size, config.project_dim)
+ self.has_pre_transformation = getattr(config, "has_pre_transformation", False)
+ if self.has_pre_transformation:
+ self.transformation_pre = nn.Linear(config.hidden_size, config.project_dim)
+ self.pre_LN = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.post_init()
+
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ ):
+ r""" """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.base_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=True if self.has_pre_transformation else output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ if self.has_pre_transformation:
+ sequence_output2 = outputs["hidden_states"][-2]
+ sequence_output2 = self.pre_LN(sequence_output2)
+ projection_state2 = self.transformation_pre(sequence_output2)
+
+ return TransformationModelOutput(
+ projection_state=projection_state2,
+ last_hidden_state=outputs.last_hidden_state,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+ else:
+ projection_state = self.transformation(outputs.last_hidden_state)
+ return TransformationModelOutput(
+ projection_state=projection_state,
+ last_hidden_state=outputs.last_hidden_state,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
diff --git a/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..507c082d93638b594e8be21607eae05f471aecfa
--- /dev/null
+++ b/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
@@ -0,0 +1,732 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import warnings
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import torch
+from packaging import version
+from transformers import CLIPImageProcessor, XLMRobertaTokenizer
+
+from diffusers.utils import is_accelerate_available, is_accelerate_version
+
+from ...configuration_utils import FrozenDict
+from ...image_processor import VaeImageProcessor
+from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...schedulers import KarrasDiffusionSchedulers
+from ...utils import deprecate, logging, randn_tensor, replace_example_docstring
+from ..pipeline_utils import DiffusionPipeline
+from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from . import AltDiffusionPipelineOutput, RobertaSeriesModelWithTransformation
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import AltDiffusionPipeline
+
+ >>> pipe = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion-m9", torch_dtype=torch.float16)
+ >>> pipe = pipe.to("cuda")
+
+ >>> # "dark elf princess, highly detailed, d & d, fantasy, highly detailed, digital painting, trending on artstation, concept art, sharp focus, illustration, art by artgerm and greg rutkowski and fuji choko and viktoria gavrilenko and hoang lap"
+ >>> prompt = "黑暗精灵公主,非常详细,幻想,非常详细,数字绘画,概念艺术,敏锐的焦点,插图"
+ >>> image = pipe(prompt).images[0]
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ """
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker
+class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
+ r"""
+ Pipeline for text-to-image generation using Alt Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ In addition the pipeline inherits the following loading methods:
+ - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+ - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+ - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
+
+ as well as the following saving methods:
+ - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`RobertaSeriesModelWithTransformation`]):
+ Frozen text-encoder. Alt Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.RobertaSeriesModelWithTransformation),
+ specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`XLMRobertaTokenizer`):
+ Tokenizer of class
+ [XLMRobertaTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.XLMRobertaTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: RobertaSeriesModelWithTransformation,
+ tokenizer: XLMRobertaTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+ )
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["clip_sample"] = False
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Alt Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+ version.parse(unet.config._diffusers_version).base_version
+ ) < version.parse("0.9.0.dev0")
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+ deprecation_message = (
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+ " the `unet/config.json` file"
+ )
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(unet.config)
+ new_config["sample_size"] = 64
+ unet._internal_dict = FrozenDict(new_config)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
+ steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
+ several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
+ """
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ hook = None
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+ if self.safety_checker is not None:
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ def decode_latents(self, latents):
+ warnings.warn(
+ (
+ "The decode_latents method is deprecated and will be removed in a future version. Please"
+ " use VaeImageProcessor instead"
+ ),
+ FutureWarning,
+ )
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16. of
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ if output_type != "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return AltDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
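For readers skimming the diff, a minimal text-to-image usage sketch of the pipeline added above (assuming the class exported by this file is `AltDiffusionPipeline` and reusing the `BAAI/AltDiffusion-m9` checkpoint from the img2img example further below; the checkpoint choice and output handling are illustrative, not part of the patch):

```py
# Minimal sketch, not part of the patch: text-to-image with the pipeline above.
# Assumes the class is AltDiffusionPipeline and that "BAAI/AltDiffusion-m9" also works for txt2img.
import torch
from diffusers import AltDiffusionPipeline

pipe = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion-m9", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# guidance_rescale > 0 applies the zero-terminal-SNR fix from arXiv:2305.08891 (section 3.4)
image = pipe(
    prompt="A fantasy landscape, trending on artstation",
    num_inference_steps=50,
    guidance_scale=7.5,
    guidance_rescale=0.7,
).images[0]
image.save("fantasy_landscape.png")
```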
diff --git a/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6f297122ba4c93127fccf4b9c1ace9a90092bbd
--- /dev/null
+++ b/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py
@@ -0,0 +1,758 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import warnings
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import PIL
+import torch
+from packaging import version
+from transformers import CLIPImageProcessor, XLMRobertaTokenizer
+
+from diffusers.utils import is_accelerate_available, is_accelerate_version
+
+from ...configuration_utils import FrozenDict
+from ...image_processor import VaeImageProcessor
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...schedulers import KarrasDiffusionSchedulers
+from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor, replace_example_docstring
+from ..pipeline_utils import DiffusionPipeline
+from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from . import AltDiffusionPipelineOutput, RobertaSeriesModelWithTransformation
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import requests
+ >>> import torch
+ >>> from PIL import Image
+ >>> from io import BytesIO
+
+ >>> from diffusers import AltDiffusionImg2ImgPipeline
+
+ >>> device = "cuda"
+ >>> model_id_or_path = "BAAI/AltDiffusion-m9"
+ >>> pipe = AltDiffusionImg2ImgPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
+ >>> pipe = pipe.to(device)
+
+ >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
+
+ >>> response = requests.get(url)
+ >>> init_image = Image.open(BytesIO(response.content)).convert("RGB")
+ >>> init_image = init_image.resize((768, 512))
+
+ >>> # "A fantasy landscape, trending on artstation"
+ >>> prompt = "幻想风景, artstation"
+
+ >>> images = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images
+ >>> images[0].save("幻想风景.png")
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
+def preprocess(image):
+ warnings.warn(
+ "The preprocess method is deprecated and will be removed in a future version. Please"
+ " use VaeImageProcessor.preprocess instead",
+ FutureWarning,
+ )
+ if isinstance(image, torch.Tensor):
+ return image
+ elif isinstance(image, PIL.Image.Image):
+ image = [image]
+
+ if isinstance(image[0], PIL.Image.Image):
+ w, h = image[0].size
+ w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
+
+ image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
+ image = np.concatenate(image, axis=0)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image.transpose(0, 3, 1, 2)
+ image = 2.0 * image - 1.0
+ image = torch.from_numpy(image)
+ elif isinstance(image[0], torch.Tensor):
+ image = torch.cat(image, dim=0)
+ return image
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker
+class AltDiffusionImg2ImgPipeline(
+ DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+):
+ r"""
+ Pipeline for text-guided image-to-image generation using Alt Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ In addition the pipeline inherits the following loading methods:
+ - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+ - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+ - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
+
+ as well as the following saving methods:
+ - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`RobertaSeriesModelWithTransformation`]):
+ Frozen text-encoder. Alt Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.RobertaSeriesModelWithTransformation),
+ specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`XLMRobertaTokenizer`):
+ Tokenizer of class
+ [XLMRobertaTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.XLMRobertaTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: RobertaSeriesModelWithTransformation,
+ tokenizer: XLMRobertaTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+ )
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["clip_sample"] = False
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Alt Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+ version.parse(unet.config._diffusers_version).base_version
+ ) < version.parse("0.9.0.dev0")
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+ deprecation_message = (
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+ " the `unet/config.json` file"
+ )
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(unet.config)
+ new_config["sample_size"] = 64
+ unet._internal_dict = FrozenDict(new_config)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ hook = None
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+ if self.safety_checker is not None:
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ def decode_latents(self, latents):
+ warnings.warn(
+ (
+ "The decode_latents method is deprecated and will be removed in a future version. Please"
+ " use VaeImageProcessor instead"
+ ),
+ FutureWarning,
+ )
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
+ ):
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep, 0)
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+
+ return timesteps, num_inference_steps - t_start
+
+ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
+ raise ValueError(
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
+ )
+
+ image = image.to(device=device, dtype=dtype)
+
+ batch_size = batch_size * num_images_per_prompt
+
+ if image.shape[1] == 4:
+ init_latents = image
+
+ else:
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective"
+ f" batch size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ elif isinstance(generator, list):
+ init_latents = [
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
+ ]
+ init_latents = torch.cat(init_latents, dim=0)
+ else:
+ init_latents = self.vae.encode(image).latent_dist.sample(generator)
+
+ init_latents = self.vae.config.scaling_factor * init_latents
+
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ deprecation_message = (
+ f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
+ " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
+ " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
+ " your script to pass as many initial images as text prompts to suppress this warning."
+ )
+ deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
+ additional_image_per_prompt = batch_size // init_latents.shape[0]
+ init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ init_latents = torch.cat([init_latents], dim=0)
+
+ shape = init_latents.shape
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+ # get latents
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+ latents = init_latents
+
+ return latents
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ image: Union[
+ torch.FloatTensor,
+ PIL.Image.Image,
+ np.ndarray,
+ List[torch.FloatTensor],
+ List[PIL.Image.Image],
+ List[np.ndarray],
+ ] = None,
+ strength: float = 0.8,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: Optional[float] = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
+ process. Can also accept image latents as `image`; if latents are passed directly, they will not be
+ encoded again.
+ strength (`float`, *optional*, defaults to 0.8):
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
+ will be used as a starting point, adding more noise to it the larger the `strength`. The number of
+ denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
+ be maximum and the denoising process will run for the full number of iterations specified in
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference. This parameter will be modulated by `strength`.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text
+ `prompt`, usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
+ is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+ device = self._execution_device
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+
+ # 4. Preprocess image
+ image = self.image_processor.preprocess(image)
+
+ # 5. set timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+ # 6. Prepare latent variables
+ latents = self.prepare_latents(
+ image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
+ )
+
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 8. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ if output_type != "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return AltDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
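Since `strength` in the img2img pipeline above only controls how many of the scheduler's timesteps are actually run (via `get_timesteps`), a standalone sketch of that arithmetic may help when tuning it; the helper name is hypothetical and simply mirrors the method body:

```py
# Hypothetical helper mirroring AltDiffusionImg2ImgPipeline.get_timesteps above.
def effective_steps(num_inference_steps: int, strength: float) -> int:
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = max(num_inference_steps - init_timestep, 0)
    return num_inference_steps - t_start

# strength=0.8 with 50 scheduler steps runs 40 denoising steps;
# strength=1.0 runs all 50 steps and effectively ignores the input image.
assert effective_steps(50, 0.8) == 40
assert effective_steps(50, 1.0) == 50
```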
diff --git a/diffusers/pipelines/audio_diffusion/__init__.py b/diffusers/pipelines/audio_diffusion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..58554c45ea52b9897293217652db36fdace7549f
--- /dev/null
+++ b/diffusers/pipelines/audio_diffusion/__init__.py
@@ -0,0 +1,2 @@
+from .mel import Mel
+from .pipeline_audio_diffusion import AudioDiffusionPipeline
diff --git a/diffusers/pipelines/audio_diffusion/mel.py b/diffusers/pipelines/audio_diffusion/mel.py
new file mode 100644
index 0000000000000000000000000000000000000000..1bf28fd25a5a5d39416eaf6bfd76b7f6945f4b19
--- /dev/null
+++ b/diffusers/pipelines/audio_diffusion/mel.py
@@ -0,0 +1,160 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import numpy as np # noqa: E402
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...schedulers.scheduling_utils import SchedulerMixin
+
+
+try:
+ import librosa # noqa: E402
+
+ _librosa_can_be_imported = True
+ _import_error = ""
+except Exception as e:
+ _librosa_can_be_imported = False
+ _import_error = (
+ f"Cannot import librosa because {e}. Make sure to correctly install librosa to be able to install it."
+ )
+
+
+from PIL import Image # noqa: E402
+
+
+class Mel(ConfigMixin, SchedulerMixin):
+ """
+ Parameters:
+ x_res (`int`): x resolution of spectrogram (time)
+ y_res (`int`): y resolution of spectrogram (frequency bins)
+ sample_rate (`int`): sample rate of audio
+ n_fft (`int`): length of the FFT window
+ hop_length (`int`): hop length (a higher number is recommended for lower than 256 y_res)
+ top_db (`int`): dynamic range of the spectrogram in decibels
+ n_iter (`int`): number of iterations for Griffin-Lim mel inversion
+ """
+
+ config_name = "mel_config.json"
+
+ @register_to_config
+ def __init__(
+ self,
+ x_res: int = 256,
+ y_res: int = 256,
+ sample_rate: int = 22050,
+ n_fft: int = 2048,
+ hop_length: int = 512,
+ top_db: int = 80,
+ n_iter: int = 32,
+ ):
+ self.hop_length = hop_length
+ self.sr = sample_rate
+ self.n_fft = n_fft
+ self.top_db = top_db
+ self.n_iter = n_iter
+ self.set_resolution(x_res, y_res)
+ self.audio = None
+
+ if not _librosa_can_be_imported:
+ raise ValueError(_import_error)
+
+ def set_resolution(self, x_res: int, y_res: int):
+ """Set resolution.
+
+ Args:
+ x_res (`int`): x resolution of spectrogram (time)
+ y_res (`int`): y resolution of spectrogram (frequency bins)
+ """
+ self.x_res = x_res
+ self.y_res = y_res
+ self.n_mels = self.y_res
+ self.slice_size = self.x_res * self.hop_length - 1
+
+ def load_audio(self, audio_file: str = None, raw_audio: np.ndarray = None):
+ """Load audio.
+
+ Args:
+ audio_file (`str`): must be a file on disk due to Librosa limitation or
+ raw_audio (`np.ndarray`): audio as numpy array
+ """
+ if audio_file is not None:
+ self.audio, _ = librosa.load(audio_file, mono=True, sr=self.sr)
+ else:
+ self.audio = raw_audio
+
+ # Pad with silence if necessary.
+ if len(self.audio) < self.x_res * self.hop_length:
+ self.audio = np.concatenate([self.audio, np.zeros((self.x_res * self.hop_length - len(self.audio),))])
+
+ def get_number_of_slices(self) -> int:
+ """Get number of slices in audio.
+
+ Returns:
+ `int`: number of spectrograms audio can be sliced into
+ """
+ return len(self.audio) // self.slice_size
+
+ def get_audio_slice(self, slice: int = 0) -> np.ndarray:
+ """Get slice of audio.
+
+ Args:
+ slice (`int`): slice number of audio (out of get_number_of_slices())
+
+ Returns:
+ `np.ndarray`: audio as numpy array
+ """
+ return self.audio[self.slice_size * slice : self.slice_size * (slice + 1)]
+
+ def get_sample_rate(self) -> int:
+ """Get sample rate:
+
+ Returns:
+ `int`: sample rate of audio
+ """
+ return self.sr
+
+ def audio_slice_to_image(self, slice: int) -> Image.Image:
+ """Convert slice of audio to spectrogram.
+
+ Args:
+ slice (`int`): slice number of audio to convert (out of get_number_of_slices())
+
+ Returns:
+ `PIL Image`: grayscale image of x_res x y_res
+ """
+ S = librosa.feature.melspectrogram(
+ y=self.get_audio_slice(slice), sr=self.sr, n_fft=self.n_fft, hop_length=self.hop_length, n_mels=self.n_mels
+ )
+ log_S = librosa.power_to_db(S, ref=np.max, top_db=self.top_db)
+ bytedata = (((log_S + self.top_db) * 255 / self.top_db).clip(0, 255) + 0.5).astype(np.uint8)
+ image = Image.fromarray(bytedata)
+ return image
+
+ def image_to_audio(self, image: Image.Image) -> np.ndarray:
+ """Converts spectrogram to audio.
+
+ Args:
+ image (`PIL Image`): x_res x y_res grayscale image
+
+ Returns:
+ audio (`np.ndarray`): raw audio
+ """
+ bytedata = np.frombuffer(image.tobytes(), dtype="uint8").reshape((image.height, image.width))
+ log_S = bytedata.astype("float") * self.top_db / 255 - self.top_db
+ S = librosa.db_to_power(log_S)
+ audio = librosa.feature.inverse.mel_to_audio(
+ S, sr=self.sr, n_fft=self.n_fft, hop_length=self.hop_length, n_iter=self.n_iter
+ )
+ return audio
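A short round-trip sketch for the `Mel` helper above (the audio path is a placeholder, `librosa` must be installed, and the import path follows the package `__init__.py` added earlier in this diff):

```py
# Sketch only: audio -> mel spectrogram image -> audio, using the Mel helper above.
from diffusers.pipelines.audio_diffusion import Mel

mel = Mel(x_res=256, y_res=256, sample_rate=22050, n_fft=2048, hop_length=512)
mel.load_audio(audio_file="clip.wav")  # placeholder path; a raw_audio numpy array also works

image = mel.audio_slice_to_image(0)    # 256x256 grayscale PIL image of the first slice
audio = mel.image_to_audio(image)      # lossy reconstruction via Griffin-Lim
print(mel.get_number_of_slices(), mel.get_sample_rate(), audio.shape)
```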
diff --git a/diffusers/pipelines/audio_diffusion/pipeline_audio_diffusion.py b/diffusers/pipelines/audio_diffusion/pipeline_audio_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..629a2e7d32ca307c91b55359ccd93c8fb12884ff
--- /dev/null
+++ b/diffusers/pipelines/audio_diffusion/pipeline_audio_diffusion.py
@@ -0,0 +1,249 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from math import acos, sin
+from typing import List, Tuple, Union
+
+import numpy as np
+import torch
+from PIL import Image
+
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...schedulers import DDIMScheduler, DDPMScheduler
+from ...utils import randn_tensor
+from ..pipeline_utils import AudioPipelineOutput, BaseOutput, DiffusionPipeline, ImagePipelineOutput
+from .mel import Mel
+
+
+class AudioDiffusionPipeline(DiffusionPipeline):
+ """
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Parameters:
+ vqvae ([`AutoencoderKL`]): Variational AutoEncoder for Latent Audio Diffusion or None
+ unet ([`UNet2DConditionModel`]): UNET model
+ mel ([`Mel`]): transform audio <-> spectrogram
+ scheduler ([`DDIMScheduler` or `DDPMScheduler`]): de-noising scheduler
+ """
+
+ _optional_components = ["vqvae"]
+
+ def __init__(
+ self,
+ vqvae: AutoencoderKL,
+ unet: UNet2DConditionModel,
+ mel: Mel,
+ scheduler: Union[DDIMScheduler, DDPMScheduler],
+ ):
+ super().__init__()
+ self.register_modules(unet=unet, scheduler=scheduler, mel=mel, vqvae=vqvae)
+
+ def get_default_steps(self) -> int:
+ """Returns default number of steps recommended for inference
+
+ Returns:
+ `int`: number of steps
+ """
+ return 50 if isinstance(self.scheduler, DDIMScheduler) else 1000
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ batch_size: int = 1,
+ audio_file: str = None,
+ raw_audio: np.ndarray = None,
+ slice: int = 0,
+ start_step: int = 0,
+ steps: int = None,
+ generator: torch.Generator = None,
+ mask_start_secs: float = 0,
+ mask_end_secs: float = 0,
+ step_generator: torch.Generator = None,
+ eta: float = 0,
+ noise: torch.Tensor = None,
+ encoding: torch.Tensor = None,
+ return_dict=True,
+ ) -> Union[
+ Union[AudioPipelineOutput, ImagePipelineOutput],
+ Tuple[List[Image.Image], Tuple[int, List[np.ndarray]]],
+ ]:
+ """Generate random mel spectrogram from audio input and convert to audio.
+
+ Args:
+ batch_size (`int`): number of samples to generate
+ audio_file (`str`): must be a file on disk due to Librosa limitation or
+ raw_audio (`np.ndarray`): audio as numpy array
+ slice (`int`): slice number of audio to convert
+ start_step (int): step to start from
+ steps (`int`): number of de-noising steps (defaults to 50 for DDIM, 1000 for DDPM)
+ generator (`torch.Generator`): random number generator or None
+ mask_start_secs (`float`): number of seconds of audio to mask (not generate) at start
+ mask_end_secs (`float`): number of seconds of audio to mask (not generate) at end
+ step_generator (`torch.Generator`): random number generator used to de-noise or None
+ eta (`float`): parameter between 0 and 1 used with DDIM scheduler
+ noise (`torch.Tensor`): noise tensor of shape (batch_size, 1, height, width) or None
+ encoding (`torch.Tensor`): for UNet2DConditionModel shape (batch_size, seq_length, cross_attention_dim)
+ return_dict (`bool`): if True return AudioPipelineOutput, ImagePipelineOutput else Tuple
+
+ Returns:
+ `List[PIL Image]`: mel spectrograms; (`int`, `List[np.ndarray]`): sample rate and raw audios
+ """
+
+ steps = steps or self.get_default_steps()
+ self.scheduler.set_timesteps(steps)
+ step_generator = step_generator or generator
+ # For backwards compatibility
+ if type(self.unet.config.sample_size) == int:
+ self.unet.config.sample_size = (self.unet.config.sample_size, self.unet.config.sample_size)
+ if noise is None:
+ noise = randn_tensor(
+ (
+ batch_size,
+ self.unet.config.in_channels,
+ self.unet.config.sample_size[0],
+ self.unet.config.sample_size[1],
+ ),
+ generator=generator,
+ device=self.device,
+ )
+ images = noise
+ mask = None
+
+ if audio_file is not None or raw_audio is not None:
+ self.mel.load_audio(audio_file, raw_audio)
+ input_image = self.mel.audio_slice_to_image(slice)
+ input_image = np.frombuffer(input_image.tobytes(), dtype="uint8").reshape(
+ (input_image.height, input_image.width)
+ )
+ input_image = (input_image / 255) * 2 - 1
+ input_images = torch.tensor(input_image[np.newaxis, :, :], dtype=torch.float).to(self.device)
+
+ if self.vqvae is not None:
+ input_images = self.vqvae.encode(torch.unsqueeze(input_images, 0)).latent_dist.sample(
+ generator=generator
+ )[0]
+ input_images = self.vqvae.config.scaling_factor * input_images
+
+ if start_step > 0:
+ images[0, 0] = self.scheduler.add_noise(input_images, noise, self.scheduler.timesteps[start_step - 1])
+
+ pixels_per_second = (
+ self.unet.config.sample_size[1] * self.mel.get_sample_rate() / self.mel.x_res / self.mel.hop_length
+ )
+ mask_start = int(mask_start_secs * pixels_per_second)
+ mask_end = int(mask_end_secs * pixels_per_second)
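+ # Pre-compute the noised input image at every remaining timestep so the masked
+ # (non-generated) start/end columns can be re-imposed inside the denoising loop below.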
+ mask = self.scheduler.add_noise(input_images, noise, torch.tensor(self.scheduler.timesteps[start_step:]))
+
+ for step, t in enumerate(self.progress_bar(self.scheduler.timesteps[start_step:])):
+ if isinstance(self.unet, UNet2DConditionModel):
+ model_output = self.unet(images, t, encoding)["sample"]
+ else:
+ model_output = self.unet(images, t)["sample"]
+
+ if isinstance(self.scheduler, DDIMScheduler):
+ images = self.scheduler.step(
+ model_output=model_output,
+ timestep=t,
+ sample=images,
+ eta=eta,
+ generator=step_generator,
+ )["prev_sample"]
+ else:
+ images = self.scheduler.step(
+ model_output=model_output,
+ timestep=t,
+ sample=images,
+ generator=step_generator,
+ )["prev_sample"]
+
+ if mask is not None:
+ if mask_start > 0:
+ images[:, :, :, :mask_start] = mask[:, step, :, :mask_start]
+ if mask_end > 0:
+ images[:, :, :, -mask_end:] = mask[:, step, :, -mask_end:]
+
+ if self.vqvae is not None:
+ # 0.18215 was scaling factor used in training to ensure unit variance
+ images = 1 / self.vqvae.config.scaling_factor * images
+ images = self.vqvae.decode(images)["sample"]
+
+ images = (images / 2 + 0.5).clamp(0, 1)
+ images = images.cpu().permute(0, 2, 3, 1).numpy()
+ images = (images * 255).round().astype("uint8")
+ images = list(
+ (Image.fromarray(_[:, :, 0]) for _ in images)
+ if images.shape[3] == 1
+ else (Image.fromarray(_, mode="RGB").convert("L") for _ in images)
+ )
+
+ audios = [self.mel.image_to_audio(_) for _ in images]
+ if not return_dict:
+ return images, (self.mel.get_sample_rate(), audios)
+
+ return BaseOutput(**AudioPipelineOutput(np.array(audios)[:, np.newaxis, :]), **ImagePipelineOutput(images))
+
+ @torch.no_grad()
+ def encode(self, images: List[Image.Image], steps: int = 50) -> np.ndarray:
+ """Reverse step process: recover noisy image from generated image.
+
+ Args:
+ images (`List[PIL Image]`): list of images to encode
+ steps (`int`): number of encoding steps to perform (defaults to 50)
+
+ Returns:
+ `np.ndarray`: noise tensor of shape (batch_size, 1, height, width)
+ """
+
+ # Only works with DDIM as this method is deterministic
+ assert isinstance(self.scheduler, DDIMScheduler)
+ self.scheduler.set_timesteps(steps)
+ sample = np.array(
+ [np.frombuffer(image.tobytes(), dtype="uint8").reshape((1, image.height, image.width)) for image in images]
+ )
+ sample = (sample / 255) * 2 - 1
+ sample = torch.Tensor(sample).to(self.device)
+
+ for t in self.progress_bar(torch.flip(self.scheduler.timesteps, (0,))):
+ prev_timestep = t - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
+ alpha_prod_t = self.scheduler.alphas_cumprod[t]
+ alpha_prod_t_prev = (
+ self.scheduler.alphas_cumprod[prev_timestep]
+ if prev_timestep >= 0
+ else self.scheduler.final_alpha_cumprod
+ )
+ beta_prod_t = 1 - alpha_prod_t
+ model_output = self.unet(sample, t)["sample"]
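+ # Reverse of the deterministic DDIM update: recover the x_0 estimate from the
+ # current sample and predicted noise, then re-noise it to timestep t.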
+ pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * model_output
+ sample = (sample - pred_sample_direction) * alpha_prod_t_prev ** (-0.5)
+ sample = sample * alpha_prod_t ** (0.5) + beta_prod_t ** (0.5) * model_output
+
+ return sample
+
+ @staticmethod
+ def slerp(x0: torch.Tensor, x1: torch.Tensor, alpha: float) -> torch.Tensor:
+ """Spherical Linear intERPolation
+
+ Args:
+ x0 (`torch.Tensor`): first tensor to interpolate between
+ x1 (`torch.Tensor`): second tensor to interpolate between
+ alpha (`float`): interpolation between 0 and 1
+
+ Returns:
+ `torch.Tensor`: interpolated tensor
+ """
+
+ theta = acos(torch.dot(torch.flatten(x0), torch.flatten(x1)) / torch.norm(x0) / torch.norm(x1))
+ return sin((1 - alpha) * theta) * x0 / sin(theta) + sin(alpha * theta) * x1 / sin(theta)
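A hedged sketch of how the `slerp` helper above could be used to interpolate between two noise seeds for `AudioDiffusionPipeline` (the checkpoint id and the 1x1x256x256 noise shape are assumptions, not taken from this patch):

```py
# Sketch only: spherically interpolate between two noise tensors and re-generate.
import torch
from diffusers import AudioDiffusionPipeline

pipe = AudioDiffusionPipeline.from_pretrained("teticio/audio-diffusion-256").to("cuda")  # assumed checkpoint

noise_a = torch.randn(1, 1, 256, 256, device=pipe.device)  # (batch, 1, height, width), see the __call__ docstring
noise_b = torch.randn(1, 1, 256, 256, device=pipe.device)

for alpha in (0.0, 0.5, 1.0):
    noise = AudioDiffusionPipeline.slerp(noise_a, noise_b, alpha)
    out = pipe(noise=noise, batch_size=1)
    out.images[0].save(f"mel_{alpha}.png")   # mel spectrogram image
    # out.audios[0] is the corresponding raw audio at pipe.mel.get_sample_rate()
```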
diff --git a/diffusers/pipelines/audioldm/__init__.py b/diffusers/pipelines/audioldm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ddef6c3f3253afd1f59c14b685a5d14d7622150
--- /dev/null
+++ b/diffusers/pipelines/audioldm/__init__.py
@@ -0,0 +1,17 @@
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ is_torch_available,
+ is_transformers_available,
+ is_transformers_version,
+)
+
+
+try:
+ if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import (
+ AudioLDMPipeline,
+ )
+else:
+ from .pipeline_audioldm import AudioLDMPipeline
diff --git a/diffusers/pipelines/audioldm/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/audioldm/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a35ebabb57a682b2c2d4a7d348eaf97da81f70d8
Binary files /dev/null and b/diffusers/pipelines/audioldm/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/pipelines/audioldm/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/audioldm/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0bb97f4b8a55a2041615dba295162342566e8839
Binary files /dev/null and b/diffusers/pipelines/audioldm/__pycache__/__init__.cpython-38.pyc differ
diff --git a/diffusers/pipelines/audioldm/__pycache__/pipeline_audioldm.cpython-310.pyc b/diffusers/pipelines/audioldm/__pycache__/pipeline_audioldm.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3a46b85fa2e2511efaa2b8c0407f29f09976ecdf
Binary files /dev/null and b/diffusers/pipelines/audioldm/__pycache__/pipeline_audioldm.cpython-310.pyc differ
diff --git a/diffusers/pipelines/audioldm/__pycache__/pipeline_audioldm.cpython-38.pyc b/diffusers/pipelines/audioldm/__pycache__/pipeline_audioldm.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..08672343a352fa7fb50609eb36eab4dd29340a08
Binary files /dev/null and b/diffusers/pipelines/audioldm/__pycache__/pipeline_audioldm.cpython-38.pyc differ
diff --git a/diffusers/pipelines/audioldm/pipeline_audioldm.py b/diffusers/pipelines/audioldm/pipeline_audioldm.py
new file mode 100644
index 0000000000000000000000000000000000000000..6da8e809103e3a2670ddec96e6612122a70f4b80
--- /dev/null
+++ b/diffusers/pipelines/audioldm/pipeline_audioldm.py
@@ -0,0 +1,566 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from transformers import ClapTextModelWithProjection, RobertaTokenizer, RobertaTokenizerFast, SpeechT5HifiGan
+
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...schedulers import KarrasDiffusionSchedulers
+from ...utils import logging, randn_tensor, replace_example_docstring
+from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import AudioLDMPipeline
+
+ >>> pipe = AudioLDMPipeline.from_pretrained("cvssp/audioldm", torch_dtype=torch.float16)
+ >>> pipe = pipe.to("cuda")
+
+ >>> prompt = "A hammer hitting a wooden surface"
+ >>> audio = pipe(prompt).audios[0]
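+
+ >>> # save the generated waveform (illustrative; assumes `scipy` is installed and the checkpoint's
+ >>> # vocoder runs at 16 kHz, as the released AudioLDM checkpoints do)
+ >>> from scipy.io import wavfile
+ >>> wavfile.write("hammer.wav", rate=16000, data=audio)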
+ ```
+"""
+
+
+class AudioLDMPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-to-audio generation using AudioLDM.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode audios to and from latent representations.
+ text_encoder ([`ClapTextModelWithProjection`]):
+ Frozen text-encoder. AudioLDM uses the text portion of
+ [CLAP](https://huggingface.co/docs/transformers/main/model_doc/clap#transformers.ClapTextModelWithProjection),
+ specifically the [RoBERTa HTSAT-unfused](https://huggingface.co/laion/clap-htsat-unfused) variant.
+ tokenizer ([`PreTrainedTokenizer`]):
+ Tokenizer of class
+ [RobertaTokenizer](https://huggingface.co/docs/transformers/model_doc/roberta#transformers.RobertaTokenizer).
+ unet ([`UNet2DConditionModel`]): U-Net architecture to denoise the encoded audio latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded audio latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ vocoder ([`SpeechT5HifiGan`]):
+ Vocoder of class
+ [SpeechT5HifiGan](https://huggingface.co/docs/transformers/main/en/model_doc/speecht5#transformers.SpeechT5HifiGan).
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: ClapTextModelWithProjection,
+ tokenizer: Union[RobertaTokenizer, RobertaTokenizerFast],
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ vocoder: SpeechT5HifiGan,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ vocoder=vocoder,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
+ steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_waveforms_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device (`torch.device`):
+ torch device
+ num_waveforms_per_prompt (`int`):
+ number of waveforms that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the audio generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ """
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ attention_mask = text_inputs.attention_mask
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLAP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask.to(device),
+ )
+ prompt_embeds = prompt_embeds.text_embeds
+ # additional L_2 normalization over each hidden-state
+ prompt_embeds = F.normalize(prompt_embeds, dim=-1)
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ (
+ bs_embed,
+ seq_len,
+ ) = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_waveforms_per_prompt)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_waveforms_per_prompt, seq_len)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ uncond_input_ids = uncond_input.input_ids.to(device)
+ attention_mask = uncond_input.attention_mask.to(device)
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input_ids,
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds.text_embeds
+ # additional L_2 normalization over each hidden-state
+ negative_prompt_embeds = F.normalize(negative_prompt_embeds, dim=-1)
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_waveforms_per_prompt)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_waveforms_per_prompt, seq_len)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
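+ # the batch is now laid out as [negative_prompt_embeds; prompt_embeds]; the denoising loop
+ # relies on this ordering when it splits the prediction again with `noise_pred.chunk(2)`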
+
+ return prompt_embeds
+
+ def decode_latents(self, latents):
+ latents = 1 / self.vae.config.scaling_factor * latents
+ mel_spectrogram = self.vae.decode(latents).sample
+ return mel_spectrogram
+
+ def mel_spectrogram_to_waveform(self, mel_spectrogram):
+ if mel_spectrogram.dim() == 4:
+ mel_spectrogram = mel_spectrogram.squeeze(1)
+
+ waveform = self.vocoder(mel_spectrogram)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ waveform = waveform.cpu().float()
+ return waveform
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
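+ # e.g. `DDIMScheduler.step` accepts both `eta` and `generator`, while `PNDMScheduler.step`
+ # accepts neither, in which case the returned kwargs dict simply stays empty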
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ audio_length_in_s,
+ vocoder_upsample_factor,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ min_audio_length_in_s = vocoder_upsample_factor * self.vae_scale_factor
+ if audio_length_in_s < min_audio_length_in_s:
+ raise ValueError(
+ f"`audio_length_in_s` has to be a positive value greater than or equal to {min_audio_length_in_s}, but "
+ f"is {audio_length_in_s}."
+ )
+
+ if self.vocoder.config.model_in_dim % self.vae_scale_factor != 0:
+ raise ValueError(
+ f"The number of frequency bins in the vocoder's log-mel spectrogram has to be divisible by the "
+ f"VAE scale factor, but got {self.vocoder.config.model_in_dim} bins and a scale factor of "
+ f"{self.vae_scale_factor}."
+ )
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents with width->self.vocoder.config.model_in_dim
+ def prepare_latents(self, batch_size, num_channels_latents, height, dtype, device, generator, latents=None):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ height // self.vae_scale_factor,
+ self.vocoder.config.model_in_dim // self.vae_scale_factor,
+ )
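+ # the third axis is the number of spectrogram time frames and the last axis the number of mel
+ # bins (the vocoder's `model_in_dim`), both reduced by the VAE scale factor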
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ audio_length_in_s: Optional[float] = None,
+ num_inference_steps: int = 10,
+ guidance_scale: float = 2.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_waveforms_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: Optional[int] = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ output_type: Optional[str] = "np",
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the audio generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ audio_length_in_s (`float`, *optional*, defaults to 5.12):
+ The length of the generated audio sample in seconds.
+ num_inference_steps (`int`, *optional*, defaults to 10):
+ The number of denoising steps. More denoising steps usually lead to a higher quality audio at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 2.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate audio that is closely linked to the text `prompt`,
+ usually at the expense of lower sound quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the audio generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_waveforms_per_prompt (`int`, *optional*, defaults to 1):
+ The number of waveforms to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for audio
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.AudioPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
+ `self.processor` in
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+ output_type (`str`, *optional*, defaults to `"np"`):
+ The output format of the generated audio. Choose between:
+ - `"np"`: Return Numpy `np.ndarray` objects.
+ - `"pt"`: Return PyTorch `torch.Tensor` objects.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.AudioPipelineOutput`] or `tuple`:
+ [`~pipelines.AudioPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated audios.
+ """
+ # 0. Convert audio input length from seconds to spectrogram height
+ vocoder_upsample_factor = np.prod(self.vocoder.config.upsample_rates) / self.vocoder.config.sampling_rate
+
+ if audio_length_in_s is None:
+ audio_length_in_s = self.unet.config.sample_size * self.vae_scale_factor * vocoder_upsample_factor
+
+ height = int(audio_length_in_s / vocoder_upsample_factor)
+
+ original_waveform_length = int(audio_length_in_s * self.vocoder.config.sampling_rate)
+ if height % self.vae_scale_factor != 0:
+ height = int(np.ceil(height / self.vae_scale_factor)) * self.vae_scale_factor
+ logger.info(
+ f"Audio length in seconds {audio_length_in_s} is increased to {height * vocoder_upsample_factor} "
+ f"so that it can be handled by the model. It will be cut to {audio_length_in_s} after the "
+ f"denoising process."
+ )
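+ # illustrative numbers (not tied to a specific checkpoint): with np.prod(upsample_rates) == 256 and
+ # sampling_rate == 16000, each spectrogram frame covers 256 / 16000 = 0.016 s, so a requested 5.12 s
+ # of audio maps to height == 320 frames, already divisible by a typical VAE scale factor of 4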
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ audio_length_in_s,
+ vocoder_upsample_factor,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_waveforms_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_waveforms_per_prompt,
+ num_channels_latents,
+ height,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
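+ # AudioLDM conditions the UNet on the pooled CLAP text embedding through the class-labels
+ # argument (a learned projection), which is why `encoder_hidden_states` is None here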
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=None,
+ class_labels=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ ).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ # 8. Post-processing
+ mel_spectrogram = self.decode_latents(latents)
+
+ audio = self.mel_spectrogram_to_waveform(mel_spectrogram)
+
+ audio = audio[:, :original_waveform_length]
+
+ if output_type == "np":
+ audio = audio.numpy()
+
+ if not return_dict:
+ return (audio,)
+
+ return AudioPipelineOutput(audios=audio)
diff --git a/diffusers/pipelines/consistency_models/__init__.py b/diffusers/pipelines/consistency_models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd78ddb3aae232a734bd911e92d8c9a07019945d
--- /dev/null
+++ b/diffusers/pipelines/consistency_models/__init__.py
@@ -0,0 +1 @@
+from .pipeline_consistency_models import ConsistencyModelPipeline
diff --git a/diffusers/pipelines/consistency_models/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/consistency_models/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..71116b6dd095d20b3524478bf2df6b1c0fa04342
Binary files /dev/null and b/diffusers/pipelines/consistency_models/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/pipelines/consistency_models/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/consistency_models/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ad3e0b8225c672098a46a7f4c10d453a7f27c727
Binary files /dev/null and b/diffusers/pipelines/consistency_models/__pycache__/__init__.cpython-38.pyc differ
diff --git a/diffusers/pipelines/consistency_models/__pycache__/pipeline_consistency_models.cpython-310.pyc b/diffusers/pipelines/consistency_models/__pycache__/pipeline_consistency_models.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d22abe15e344025bd93db9e18279506d0ff052aa
Binary files /dev/null and b/diffusers/pipelines/consistency_models/__pycache__/pipeline_consistency_models.cpython-310.pyc differ
diff --git a/diffusers/pipelines/consistency_models/__pycache__/pipeline_consistency_models.cpython-38.pyc b/diffusers/pipelines/consistency_models/__pycache__/pipeline_consistency_models.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2614bd0e8fce455ba6e3009181c16180e2984b81
Binary files /dev/null and b/diffusers/pipelines/consistency_models/__pycache__/pipeline_consistency_models.cpython-38.pyc differ
diff --git a/diffusers/pipelines/consistency_models/pipeline_consistency_models.py b/diffusers/pipelines/consistency_models/pipeline_consistency_models.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec4af7afe5adf99aef45f8ff0578cefa39549182
--- /dev/null
+++ b/diffusers/pipelines/consistency_models/pipeline_consistency_models.py
@@ -0,0 +1,293 @@
+from typing import Callable, List, Optional, Union
+
+import torch
+
+from ...models import UNet2DModel
+from ...schedulers import CMStochasticIterativeScheduler
+from ...utils import (
+ is_accelerate_available,
+ is_accelerate_version,
+ logging,
+ randn_tensor,
+ replace_example_docstring,
+)
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+
+ >>> from diffusers import ConsistencyModelPipeline
+
+ >>> device = "cuda"
+ >>> # Load the cd_imagenet64_l2 checkpoint.
+ >>> model_id_or_path = "openai/diffusers-cd_imagenet64_l2"
+ >>> pipe = ConsistencyModelPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
+ >>> pipe.to(device)
+
+ >>> # Onestep Sampling
+ >>> image = pipe(num_inference_steps=1).images[0]
+ >>> image.save("cd_imagenet64_l2_onestep_sample.png")
+
+ >>> # Onestep sampling, class-conditional image generation
+ >>> # ImageNet-64 class label 145 corresponds to king penguins
+ >>> image = pipe(num_inference_steps=1, class_labels=145).images[0]
+ >>> image.save("cd_imagenet64_l2_onestep_sample_penguin.png")
+
+ >>> # Multistep sampling, class-conditional image generation
+ >>> # Timesteps can be explicitly specified; the particular timesteps below are from the original Github repo:
+ >>> # https://github.com/openai/consistency_models/blob/main/scripts/launch.sh#L77
+ >>> image = pipe(num_inference_steps=None, timesteps=[22, 0], class_labels=145).images[0]
+ >>> image.save("cd_imagenet64_l2_multistep_sample_penguin.png")
+ ```
+"""
+
+
+class ConsistencyModelPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for consistency models for unconditional or class-conditional image generation, as introduced in [1].
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ [1] Song, Yang and Dhariwal, Prafulla and Chen, Mark and Sutskever, Ilya. "Consistency Models"
+ https://arxiv.org/pdf/2303.01469
+
+ Args:
+ unet ([`UNet2DModel`]):
+ Unconditional or class-conditional U-Net architecture to denoise image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the image latents. Currently only compatible
+ with [`CMStochasticIterativeScheduler`].
+ """
+
+ def __init__(self, unet: UNet2DModel, scheduler: CMStochasticIterativeScheduler) -> None:
+ super().__init__()
+
+ self.register_modules(
+ unet=unet,
+ scheduler=scheduler,
+ )
+
+ self.safety_checker = None
+
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ hook = None
+ for cpu_offloaded_model in [self.unet]:
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+ if self.safety_checker is not None:
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
+ def prepare_latents(self, batch_size, num_channels, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels, height, width)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device=device, dtype=dtype)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ # Follows diffusers.VaeImageProcessor.postprocess
+ def postprocess_image(self, sample: torch.FloatTensor, output_type: str = "pil"):
+ if output_type not in ["pt", "np", "pil"]:
+ raise ValueError(
+ f"output_type={output_type} is not supported. Make sure to choose one of ['pt', 'np', or 'pil']"
+ )
+
+ # Equivalent to diffusers.VaeImageProcessor.denormalize
+ sample = (sample / 2 + 0.5).clamp(0, 1)
+ if output_type == "pt":
+ return sample
+
+ # Equivalent to diffusers.VaeImageProcessor.pt_to_numpy
+ sample = sample.cpu().permute(0, 2, 3, 1).numpy()
+ if output_type == "np":
+ return sample
+
+ # Output_type must be 'pil'
+ sample = self.numpy_to_pil(sample)
+ return sample
+
+ def prepare_class_labels(self, batch_size, device, class_labels=None):
+ if self.unet.config.num_class_embeds is not None:
+ if isinstance(class_labels, list):
+ class_labels = torch.tensor(class_labels, dtype=torch.int)
+ elif isinstance(class_labels, int):
+ assert batch_size == 1, "Batch size must be 1 if classes is an int"
+ class_labels = torch.tensor([class_labels], dtype=torch.int)
+ elif class_labels is None:
+ # Randomly generate batch_size class labels
+ # TODO: should use generator here? int analogue of randn_tensor is not exposed in ...utils
+ class_labels = torch.randint(0, self.unet.config.num_class_embeds, size=(batch_size,))
+ class_labels = class_labels.to(device)
+ else:
+ class_labels = None
+ return class_labels
+
+ def check_inputs(self, num_inference_steps, timesteps, latents, batch_size, img_size, callback_steps):
+ if num_inference_steps is None and timesteps is None:
+ raise ValueError("Exactly one of `num_inference_steps` or `timesteps` must be supplied.")
+
+ if num_inference_steps is not None and timesteps is not None:
+ logger.warning(
+ f"Both `num_inference_steps`: {num_inference_steps} and `timesteps`: {timesteps} are supplied;"
+ " `timesteps` will be used over `num_inference_steps`."
+ )
+
+ if latents is not None:
+ expected_shape = (batch_size, 3, img_size, img_size)
+ if latents.shape != expected_shape:
+ raise ValueError(f"The shape of latents is {latents.shape} but is expected to be {expected_shape}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ batch_size: int = 1,
+ class_labels: Optional[Union[torch.Tensor, List[int], int]] = None,
+ num_inference_steps: int = 1,
+ timesteps: List[int] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ ):
+ r"""
+ Args:
+ batch_size (`int`, *optional*, defaults to 1):
+ The number of images to generate.
+ class_labels (`torch.Tensor` or `List[int]` or `int`, *optional*):
+ Optional class labels for conditioning class-conditional consistency models. Will not be used if the
+ model is not class-conditional.
+ num_inference_steps (`int`, *optional*, defaults to 1):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process. If not defined, equally spaced `num_inference_steps`
+ timesteps are used. Must be in descending order.
+ generator (`torch.Generator`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is
+ True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+ # 0. Prepare call parameters
+ img_size = self.unet.config.sample_size
+ device = self._execution_device
+
+ # 1. Check inputs
+ self.check_inputs(num_inference_steps, timesteps, latents, batch_size, img_size, callback_steps)
+
+ # 2. Prepare image latents
+ # Sample image latents x_0 ~ N(0, sigma_0^2 * I)
+ sample = self.prepare_latents(
+ batch_size=batch_size,
+ num_channels=self.unet.config.in_channels,
+ height=img_size,
+ width=img_size,
+ dtype=self.unet.dtype,
+ device=device,
+ generator=generator,
+ latents=latents,
+ )
+
+ # 3. Handle class_labels for class-conditional models
+ class_labels = self.prepare_class_labels(batch_size, device, class_labels=class_labels)
+
+ # 4. Prepare timesteps
+ if timesteps is not None:
+ self.scheduler.set_timesteps(timesteps=timesteps, device=device)
+ timesteps = self.scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ self.scheduler.set_timesteps(num_inference_steps)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Denoising loop
+ # Multistep sampling: implements Algorithm 1 in the paper
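+ # at each timestep the consistency model maps the (scaled) noisy sample directly to a clean
+ # estimate; for multistep sampling the scheduler step then re-noises that estimate to the next,
+ # smaller noise level before the following iteration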
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ scaled_sample = self.scheduler.scale_model_input(sample, t)
+ model_output = self.unet(scaled_sample, t, class_labels=class_labels, return_dict=False)[0]
+
+ sample = self.scheduler.step(model_output, t, sample, generator=generator)[0]
+
+ # call the callback, if provided
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, sample)
+
+ # 6. Post-process image sample
+ image = self.postprocess_image(sample, output_type=output_type)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
diff --git a/diffusers/pipelines/controlnet/__init__.py b/diffusers/pipelines/controlnet/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd0b9a724e5e0a1a5cb864f0e5b8f698859f8d8a
--- /dev/null
+++ b/diffusers/pipelines/controlnet/__init__.py
@@ -0,0 +1,27 @@
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ is_flax_available,
+ is_invisible_watermark_available,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+if is_transformers_available() and is_torch_available() and is_invisible_watermark_available():
+ from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
+
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
+else:
+ from .multicontrolnet import MultiControlNetModel
+ from .pipeline_controlnet import StableDiffusionControlNetPipeline
+ from .pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline
+ from .pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline
+
+
+if is_transformers_available() and is_flax_available():
+ from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline
diff --git a/diffusers/pipelines/controlnet/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/controlnet/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a54a503882d87a2f553efc35da9d78a8a6677698
Binary files /dev/null and b/diffusers/pipelines/controlnet/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/pipelines/controlnet/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/controlnet/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..975ff33fecdac637658cbe620de8a9fd2f835857
Binary files /dev/null and b/diffusers/pipelines/controlnet/__pycache__/__init__.cpython-38.pyc differ
diff --git a/diffusers/pipelines/controlnet/__pycache__/multicontrolnet.cpython-310.pyc b/diffusers/pipelines/controlnet/__pycache__/multicontrolnet.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..782bedb4fc75f18c0dca8c55e27f508d3d8783ce
Binary files /dev/null and b/diffusers/pipelines/controlnet/__pycache__/multicontrolnet.cpython-310.pyc differ
diff --git a/diffusers/pipelines/controlnet/__pycache__/multicontrolnet.cpython-38.pyc b/diffusers/pipelines/controlnet/__pycache__/multicontrolnet.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9bd510e9041a69bffa495e17860f00830e362683
Binary files /dev/null and b/diffusers/pipelines/controlnet/__pycache__/multicontrolnet.cpython-38.pyc differ
diff --git a/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet.cpython-310.pyc b/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b54423c3dbfbaadda20214188cde3881cb87e54c
Binary files /dev/null and b/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet.cpython-310.pyc differ
diff --git a/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet.cpython-38.pyc b/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a0f519310705b76bb25f46427bda52ae2d46da70
Binary files /dev/null and b/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet.cpython-38.pyc differ
diff --git a/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_img2img.cpython-310.pyc b/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_img2img.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6618d4ab24658e0cb462ea9a7d30fbcb39245a46
Binary files /dev/null and b/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_img2img.cpython-310.pyc differ
diff --git a/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_img2img.cpython-38.pyc b/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_img2img.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..83ba08d05a298ada1d5a75df37750ee494589635
Binary files /dev/null and b/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_img2img.cpython-38.pyc differ
diff --git a/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_inpaint.cpython-310.pyc b/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_inpaint.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c82edc68f1680a4b1f86323ab98e5801c0f7d02e
Binary files /dev/null and b/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_inpaint.cpython-310.pyc differ
diff --git a/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_inpaint.cpython-38.pyc b/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_inpaint.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a001913f00c91cc7a5cd50189ee2cb6a028f6bc2
Binary files /dev/null and b/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_inpaint.cpython-38.pyc differ
diff --git a/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_sd_xl.cpython-310.pyc b/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_sd_xl.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9c06e4f57bdec1752894cb63b55b82d8a1240548
Binary files /dev/null and b/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_sd_xl.cpython-310.pyc differ
diff --git a/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_sd_xl.cpython-38.pyc b/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_sd_xl.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2de3a54fd7b28f2b735ff866d7ac38e39f75f3ff
Binary files /dev/null and b/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_sd_xl.cpython-38.pyc differ
diff --git a/diffusers/pipelines/controlnet/multicontrolnet.py b/diffusers/pipelines/controlnet/multicontrolnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..2df8c7edd7aee9b374fc687d225da4fa827b24a5
--- /dev/null
+++ b/diffusers/pipelines/controlnet/multicontrolnet.py
@@ -0,0 +1,185 @@
+import os
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+
+from ...models.controlnet import ControlNetModel, ControlNetOutput
+from ...models.modeling_utils import ModelMixin
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class MultiControlNetModel(ModelMixin):
+ r"""
+ Multiple `ControlNetModel` wrapper class for Multi-ControlNet
+
+ This module is a wrapper for multiple instances of the `ControlNetModel`. The `forward()` API is designed to be
+ compatible with `ControlNetModel`.
+
+ Args:
+ controlnets (`List[ControlNetModel]`):
+ Provides additional conditioning to the unet during the denoising process. You must pass multiple
+ `ControlNetModel` instances as a list.
+ """
+
+ def __init__(self, controlnets: Union[List[ControlNetModel], Tuple[ControlNetModel]]):
+ super().__init__()
+ self.nets = nn.ModuleList(controlnets)
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ controlnet_cond: List[torch.Tensor],
+ conditioning_scale: List[float],
+ class_labels: Optional[torch.Tensor] = None,
+ timestep_cond: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guess_mode: bool = False,
+ return_dict: bool = True,
+ ) -> Union[ControlNetOutput, Tuple]:
+ for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
+ down_samples, mid_sample = controlnet(
+ sample=sample,
+ timestep=timestep,
+ encoder_hidden_states=encoder_hidden_states,
+ controlnet_cond=image,
+ conditioning_scale=scale,
+ class_labels=class_labels,
+ timestep_cond=timestep_cond,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ guess_mode=guess_mode,
+ return_dict=return_dict,
+ )
+
+ # merge samples
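+ # each ControlNet has already applied its own `conditioning_scale` to its outputs, so the
+ # per-block down residuals and the mid-block residual are simply summed element-wise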
+ if i == 0:
+ down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
+ else:
+ down_block_res_samples = [
+ samples_prev + samples_curr
+ for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
+ ]
+ mid_block_res_sample += mid_sample
+
+ return down_block_res_samples, mid_block_res_sample
+
+ def save_pretrained(
+ self,
+ save_directory: Union[str, os.PathLike],
+ is_main_process: bool = True,
+ save_function: Callable = None,
+ safe_serialization: bool = False,
+ variant: Optional[str] = None,
+ ):
+ """
+ Save a model and its configuration file to a directory, so that it can be re-loaded using the
+ [`~pipelines.controlnet.MultiControlNetModel.from_pretrained`] class method.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to which to save. Will be created if it doesn't exist.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+ Whether the process calling this is the main process or not. Useful during distributed training (e.g. on
+ TPUs) when you need to call this function on all processes. In this case, set `is_main_process=True` only on
+ the main process to avoid race conditions.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful during distributed training (e.g. on TPUs) when one
+ needs to replace `torch.save` with another method. Can be configured with the environment variable
+ `DIFFUSERS_SAVE_MODE`.
+ safe_serialization (`bool`, *optional*, defaults to `False`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
+ variant (`str`, *optional*):
+ If specified, weights are saved in the format pytorch_model.<variant>.bin.
+ """
+ idx = 0
+ model_path_to_save = save_directory
+ for controlnet in self.nets:
+ controlnet.save_pretrained(
+ model_path_to_save,
+ is_main_process=is_main_process,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ variant=variant,
+ )
+
+ idx += 1
+ model_path_to_save = model_path_to_save + f"_{idx}"
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs):
+ r"""
+ Instantiate a pretrained MultiControlNet model from multiple pre-trained controlnet models.
+
+ The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
+ the model, you should first set it back in training mode with `model.train()`.
+
+ The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
+ pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
+ task.
+
+ The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
+ weights are discarded.
+
+ Parameters:
+ pretrained_model_path (`os.PathLike`):
+ A path to a *directory* containing model weights saved using
+ [`~diffusers.pipelines.controlnet.MultiControlNetModel.save_pretrained`], e.g.,
+ `./my_model_directory/controlnet`.
+ torch_dtype (`str` or `torch.dtype`, *optional*):
+ Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
+ will be automatically derived from the model's weights.
+ output_loading_info(`bool`, *optional*, defaults to `False`):
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
+ A map that specifies where each submodule should go. It doesn't need to be refined to each
+ parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
+ same device.
+
+ To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
+ more information about each option see [designing a device
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+ max_memory (`Dict`, *optional*):
+ A dictionary device identifier to maximum memory. Will default to the maximum memory available for each
+ GPU and the available CPU RAM if unset.
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+ Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
+ also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
+ model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
+ setting this argument to `True` will raise an error.
+ variant (`str`, *optional*):
+ If specified, load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
+ ignored when using `from_flax`.
+ use_safetensors (`bool`, *optional*, defaults to `None`):
+ If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the
+ `safetensors` library is installed. If set to `True`, the model will be forcibly loaded from
+ `safetensors` weights. If set to `False`, loading will *not* use `safetensors`.
+ """
+ idx = 0
+ controlnets = []
+
+ # load controlnet and append to list until no controlnet directory exists anymore
+ # first controlnet has to be saved under `./mydirectory/controlnet` to be compliant with `DiffusionPipeline.from_pretrained`
+ # second, third, ... controlnets have to be saved under `./mydirectory/controlnet_1`, `./mydirectory/controlnet_2`, ...
+ model_path_to_load = pretrained_model_path
+ while os.path.isdir(model_path_to_load):
+ controlnet = ControlNetModel.from_pretrained(model_path_to_load, **kwargs)
+ controlnets.append(controlnet)
+
+ idx += 1
+ model_path_to_load = pretrained_model_path + f"_{idx}"
+
+ logger.info(f"{len(controlnets)} controlnets loaded from {pretrained_model_path}.")
+
+ if len(controlnets) == 0:
+ raise ValueError(
+ f"No ControlNets found under {os.path.dirname(pretrained_model_path)}. Expected at least {pretrained_model_path + '_0'}."
+ )
+
+ return cls(controlnets)
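+
+ # A minimal save/load sketch of the suffix convention documented above (paths are illustrative):
+ # saving MultiControlNetModel([cn_a, cn_b]) to "./mydir/controlnet" writes the first net to
+ # "./mydir/controlnet" and the second to "./mydir/controlnet_1"; calling
+ # MultiControlNetModel.from_pretrained("./mydir/controlnet") then walks the suffixes back in order.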
diff --git a/diffusers/pipelines/controlnet/pipeline_controlnet.py b/diffusers/pipelines/controlnet/pipeline_controlnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..2fcc0c67edf2b6e635de46821da519922cc48f17
--- /dev/null
+++ b/diffusers/pipelines/controlnet/pipeline_controlnet.py
@@ -0,0 +1,1012 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import inspect
+import warnings
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL.Image
+import torch
+import torch.nn.functional as F
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from ...image_processor import VaeImageProcessor
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
+from ...schedulers import KarrasDiffusionSchedulers
+from ...utils import (
+ is_accelerate_available,
+ is_accelerate_version,
+ is_compiled_module,
+ logging,
+ randn_tensor,
+ replace_example_docstring,
+)
+from ..pipeline_utils import DiffusionPipeline
+from ..stable_diffusion import StableDiffusionPipelineOutput
+from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from .multicontrolnet import MultiControlNetModel
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> # !pip install opencv-python transformers accelerate
+ >>> from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
+ >>> from diffusers.utils import load_image
+ >>> import numpy as np
+ >>> import torch
+
+ >>> import cv2
+ >>> from PIL import Image
+
+ >>> # download an image
+ >>> image = load_image(
+ ... "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png"
+ ... )
+ >>> image = np.array(image)
+
+ >>> # get canny image
+ >>> image = cv2.Canny(image, 100, 200)
+ >>> image = image[:, :, None]
+ >>> image = np.concatenate([image, image, image], axis=2)
+ >>> canny_image = Image.fromarray(image)
+
+ >>> # load control net and stable diffusion v1-5
+ >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
+ >>> pipe = StableDiffusionControlNetPipeline.from_pretrained(
+ ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
+ ... )
+
+ >>> # speed up diffusion process with faster scheduler and memory optimization
+ >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+ >>> # remove following line if xformers is not installed
+ >>> pipe.enable_xformers_memory_efficient_attention()
+
+ >>> pipe.enable_model_cpu_offload()
+
+ >>> # generate image
+ >>> generator = torch.manual_seed(0)
+ >>> image = pipe(
+ ... "futuristic-looking woman", num_inference_steps=20, generator=generator, image=canny_image
+ ... ).images[0]
+ ```
+"""
+
+
+class StableDiffusionControlNetPipeline(
+ DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ In addition, the pipeline inherits the following loading methods:
+ - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
+ Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets
+ as a list, the outputs from each ControlNet are added together to create one combined additional
+ conditioning.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ if isinstance(controlnet, (list, tuple)):
+ controlnet = MultiControlNetModel(controlnet)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ controlnet=controlnet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
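+ # each VAE block beyond the first halves the spatial resolution, e.g. 8x total for the SD v1.5 VAE (512 px image -> 64x64 latents)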
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
+ self.control_image_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
+ )
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
+ steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
+ several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
+ """
+ self.vae.enable_tiling()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better because the `unet` runs iteratively.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ hook = None
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+ if self.safety_checker is not None:
+ # the safety checker can offload the vae again
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
+
+ # the ControlNet hook has to be manually offloaded as it alternates with the UNet
+ cpu_offload_with_hook(self.controlnet, device)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ def decode_latents(self, latents):
+ warnings.warn(
+ "The decode_latents method is deprecated and will be removed in a future version. Please"
+ " use VaeImageProcessor instead",
+ FutureWarning,
+ )
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ image,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ controlnet_conditioning_scale=1.0,
+ control_guidance_start=0.0,
+ control_guidance_end=1.0,
+ ):
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ # `prompt` needs more sophisticated handling when there are multiple
+ # conditionings.
+ if isinstance(self.controlnet, MultiControlNetModel):
+ if isinstance(prompt, list):
+ logger.warning(
+ f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
+ " prompts. The conditionings will be fixed across the prompts."
+ )
+
+ # Check `image`
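+ # a compiled ControlNet (PyTorch 2.x `torch.compile`) is wrapped in an OptimizedModule, so the
+ # underlying model has to be reached through `_orig_mod`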
+ is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
+ self.controlnet, torch._dynamo.eval_frame.OptimizedModule
+ )
+ if (
+ isinstance(self.controlnet, ControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
+ ):
+ self.check_image(image, prompt, prompt_embeds)
+ elif (
+ isinstance(self.controlnet, MultiControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
+ ):
+ if not isinstance(image, list):
+ raise TypeError("For multiple controlnets: `image` must be type `list`")
+
+ # When `image` is a nested list:
+ # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
+ elif any(isinstance(i, list) for i in image):
+ raise ValueError("A single batch of multiple conditionings are supported at the moment.")
+ elif len(image) != len(self.controlnet.nets):
+ raise ValueError(
+ f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
+ )
+
+ for image_ in image:
+ self.check_image(image_, prompt, prompt_embeds)
+ else:
+ assert False
+
+ # Check `controlnet_conditioning_scale`
+ if (
+ isinstance(self.controlnet, ControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
+ ):
+ if not isinstance(controlnet_conditioning_scale, float):
+ raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
+ elif (
+ isinstance(self.controlnet, MultiControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
+ ):
+ if isinstance(controlnet_conditioning_scale, list):
+ if any(isinstance(i, list) for i in controlnet_conditioning_scale):
+ raise ValueError("A single batch of multiple conditionings are supported at the moment.")
+ elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
+ self.controlnet.nets
+ ):
+ raise ValueError(
+ "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
+ " the same length as the number of controlnets"
+ )
+ else:
+ assert False
+
+ if len(control_guidance_start) != len(control_guidance_end):
+ raise ValueError(
+ f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
+ )
+
+ if isinstance(self.controlnet, MultiControlNetModel):
+ if len(control_guidance_start) != len(self.controlnet.nets):
+ raise ValueError(
+ f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
+ )
+
+ for start, end in zip(control_guidance_start, control_guidance_end):
+ if start >= end:
+ raise ValueError(
+ f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
+ )
+ if start < 0.0:
+ raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
+ if end > 1.0:
+ raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
+
+ def check_image(self, image, prompt, prompt_embeds):
+ image_is_pil = isinstance(image, PIL.Image.Image)
+ image_is_tensor = isinstance(image, torch.Tensor)
+ image_is_np = isinstance(image, np.ndarray)
+ image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
+ image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
+ image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
+
+ if (
+ not image_is_pil
+ and not image_is_tensor
+ and not image_is_np
+ and not image_is_pil_list
+ and not image_is_tensor_list
+ and not image_is_np_list
+ ):
+ raise TypeError(
+ f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
+ )
+
+ if image_is_pil:
+ image_batch_size = 1
+ else:
+ image_batch_size = len(image)
+
+ if prompt is not None and isinstance(prompt, str):
+ prompt_batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ prompt_batch_size = len(prompt)
+ elif prompt_embeds is not None:
+ prompt_batch_size = prompt_embeds.shape[0]
+
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size:
+ raise ValueError(
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
+ )
+
+ def prepare_image(
+ self,
+ image,
+ width,
+ height,
+ batch_size,
+ num_images_per_prompt,
+ device,
+ dtype,
+ do_classifier_free_guidance=False,
+ guess_mode=False,
+ ):
+ image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
+ image_batch_size = image.shape[0]
+
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_images_per_prompt
+
+ image = image.repeat_interleave(repeat_by, dim=0)
+
+ image = image.to(device=device, dtype=dtype)
+
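+ # duplicate the conditioning image so it matches the [unconditional, conditional] latent batch used for classifier-free guidance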
+ if do_classifier_free_guidance and not guess_mode:
+ image = torch.cat([image] * 2)
+
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ image: Union[
+ torch.FloatTensor,
+ PIL.Image.Image,
+ np.ndarray,
+ List[torch.FloatTensor],
+ List[PIL.Image.Image],
+ List[np.ndarray],
+ ] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
+ guess_mode: bool = False,
+ control_guidance_start: Union[float, List[float]] = 0.0,
+ control_guidance_end: Union[float, List[float]] = 1.0,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
+ `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+ The ControlNet input condition. ControlNet uses this input condition to generate guidance for the UNet. If
+ the type is specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
+ also be accepted as an image. The dimensions of the output image default to `image`'s dimensions. If
+ height and/or width are passed, `image` is resized according to them. If multiple ControlNets are
+ specified in init, images must be passed as a list such that each element of the list can be correctly
+ batched for input to a single controlnet.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2 of the [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text
+ `prompt`, usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
+ to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
+ corresponding scale as a list.
+ guess_mode (`bool`, *optional*, defaults to `False`):
+ In this mode, the ControlNet encoder tries its best to recognize the content of the input image even if
+ you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
+ control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
+ The fraction of total steps (between 0.0 and 1.0) at which the ControlNet starts applying.
+ control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The fraction of total steps (between 0.0 and 1.0) at which the ControlNet stops applying.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
+
+ # align format for control guidance
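+ # broadcast scalar values so that there is one (start, end) pair per ControlNet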
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
+ control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [
+ control_guidance_end
+ ]
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ image,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ controlnet_conditioning_scale,
+ control_guidance_start,
+ control_guidance_end,
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
+
+ global_pool_conditions = (
+ controlnet.config.global_pool_conditions
+ if isinstance(controlnet, ControlNetModel)
+ else controlnet.nets[0].config.global_pool_conditions
+ )
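+ # checkpoints trained with globally pooled conditioning (global_pool_conditions=True) always run in guess mode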
+ guess_mode = guess_mode or global_pool_conditions
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+
+ # 4. Prepare image
+ if isinstance(controlnet, ControlNetModel):
+ image = self.prepare_image(
+ image=image,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=controlnet.dtype,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+ height, width = image.shape[-2:]
+ elif isinstance(controlnet, MultiControlNetModel):
+ images = []
+
+ for image_ in image:
+ image_ = self.prepare_image(
+ image=image_,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=controlnet.dtype,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+
+ images.append(image_)
+
+ image = images
+ height, width = image[0].shape[-2:]
+ else:
+ assert False
+
+ # 5. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 6. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7.1 Create tensor stating which controlnets to keep
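+ # for each timestep, keeps[j] is 1.0 while the step index falls inside the
+ # [control_guidance_start[j], control_guidance_end[j]] fraction of the schedule and 0.0 otherwise,
+ # so each ControlNet is switched off outside its configured window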
+ controlnet_keep = []
+ for i in range(len(timesteps)):
+ keeps = [
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+ for s, e in zip(control_guidance_start, control_guidance_end)
+ ]
+ controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps)
+
+ # 8. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # controlnet(s) inference
+ if guess_mode and do_classifier_free_guidance:
+ # Infer ControlNet only for the conditional batch.
+ control_model_input = latents
+ control_model_input = self.scheduler.scale_model_input(control_model_input, t)
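+ # prompt_embeds is [negative, positive] when doing classifier-free guidance, so keep only the text-conditioned half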
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
+ else:
+ control_model_input = latent_model_input
+ controlnet_prompt_embeds = prompt_embeds
+
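+ # combine the user-provided conditioning scale with the keep mask so residuals are zeroed outside the guidance window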
+ if isinstance(controlnet_keep[i], list):
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
+ else:
+ cond_scale = controlnet_conditioning_scale * controlnet_keep[i]
+
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
+ control_model_input,
+ t,
+ encoder_hidden_states=controlnet_prompt_embeds,
+ controlnet_cond=image,
+ conditioning_scale=cond_scale,
+ guess_mode=guess_mode,
+ return_dict=False,
+ )
+
+ if guess_mode and do_classifier_free_guidance:
+ # Inferred ControlNet only for the conditional batch.
+ # To apply the output of ControlNet to both the unconditional and conditional batches,
+ # add 0 to the unconditional batch to keep it unchanged.
+ down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
+ mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
+
+ # predict the noise residual
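+ # the ControlNet residuals are injected into the UNet's down-block skip connections and mid block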
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ down_block_additional_residuals=down_block_res_samples,
+ mid_block_additional_residual=mid_block_res_sample,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
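+ # classifier-free guidance: push the prediction away from the unconditional estimate toward the text-conditioned one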
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ # If we do sequential model offloading, let's offload unet and controlnet
+ # manually for max memory savings
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.unet.to("cpu")
+ self.controlnet.to("cpu")
+ torch.cuda.empty_cache()
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py b/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
new file mode 100644
index 0000000000000000000000000000000000000000..c29a00a3542b6347a550a3c05d3776de19425273
--- /dev/null
+++ b/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
@@ -0,0 +1,1105 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import inspect
+import warnings
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL.Image
+import torch
+import torch.nn.functional as F
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from ...image_processor import VaeImageProcessor
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
+from ...schedulers import KarrasDiffusionSchedulers
+from ...utils import (
+ deprecate,
+ is_accelerate_available,
+ is_accelerate_version,
+ is_compiled_module,
+ logging,
+ randn_tensor,
+ replace_example_docstring,
+)
+from ..pipeline_utils import DiffusionPipeline
+from ..stable_diffusion import StableDiffusionPipelineOutput
+from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from .multicontrolnet import MultiControlNetModel
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> # !pip install opencv-python transformers accelerate
+ >>> from diffusers import StableDiffusionControlNetImg2ImgPipeline, ControlNetModel, UniPCMultistepScheduler
+ >>> from diffusers.utils import load_image
+ >>> import numpy as np
+ >>> import torch
+
+ >>> import cv2
+ >>> from PIL import Image
+
+ >>> # download an image
+ >>> image = load_image(
+ ... "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png"
+ ... )
+ >>> np_image = np.array(image)
+
+ >>> # get canny image
+ >>> np_image = cv2.Canny(np_image, 100, 200)
+ >>> np_image = np_image[:, :, None]
+ >>> np_image = np.concatenate([np_image, np_image, np_image], axis=2)
+ >>> canny_image = Image.fromarray(np_image)
+
+ >>> # load control net and stable diffusion v1-5
+ >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
+ >>> pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
+ ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
+ ... )
+
+ >>> # speed up diffusion process with faster scheduler and memory optimization
+ >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+ >>> pipe.enable_model_cpu_offload()
+
+ >>> # generate image
+ >>> generator = torch.manual_seed(0)
+ >>> image = pipe(
+ ... "futuristic-looking woman",
+ ... num_inference_steps=20,
+ ... generator=generator,
+ ... image=image,
+ ... control_image=canny_image,
+ ... ).images[0]
+ ```
+"""
+
+
+def prepare_image(image):
+ if isinstance(image, torch.Tensor):
+ # Batch single image
+ if image.ndim == 3:
+ image = image.unsqueeze(0)
+
+ image = image.to(dtype=torch.float32)
+ else:
+ # preprocess image
+ if isinstance(image, (PIL.Image.Image, np.ndarray)):
+ image = [image]
+
+ if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
+ image = [np.array(i.convert("RGB"))[None, :] for i in image]
+ image = np.concatenate(image, axis=0)
+ elif isinstance(image, list) and isinstance(image[0], np.ndarray):
+ image = np.concatenate([i[None, :] for i in image], axis=0)
+
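+ # NHWC pixel values (typically uint8 in [0, 255]) -> NCHW float32 in [-1, 1], the range the VAE encoder expects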
+ image = image.transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+
+ return image
+
+
+class StableDiffusionControlNetImg2ImgPipeline(
+ DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+):
+ r"""
+ Pipeline for image-to-image generation using Stable Diffusion with ControlNet guidance.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ In addition the pipeline inherits the following loading methods:
+ - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
+ Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets
+ as a list, the outputs from each ControlNet are added together to create one combined additional
+ conditioning.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ if isinstance(controlnet, (list, tuple)):
+ controlnet = MultiControlNetModel(controlnet)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ controlnet=controlnet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
+ self.control_image_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
+ )
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
+ steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
+ several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
+ """
+ self.vae.enable_tiling()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better because the `unet` runs iteratively.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ hook = None
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+ if self.safety_checker is not None:
+ # the safety checker can offload the vae again
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
+
+ # the ControlNet hook has to be manually offloaded as it alternates with the UNet
+ cpu_offload_with_hook(self.controlnet, device)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ def decode_latents(self, latents):
+ warnings.warn(
+ "The decode_latents method is deprecated and will be removed in a future version. Please"
+ " use VaeImageProcessor instead",
+ FutureWarning,
+ )
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ image,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ controlnet_conditioning_scale=1.0,
+ control_guidance_start=0.0,
+ control_guidance_end=1.0,
+ ):
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ # `prompt` needs more sophisticated handling when there are multiple
+ # conditionings.
+ if isinstance(self.controlnet, MultiControlNetModel):
+ if isinstance(prompt, list):
+ logger.warning(
+ f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
+ " prompts. The conditionings will be fixed across the prompts."
+ )
+
+ # Check `image`
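+ # `torch.compile` wraps the controlnet in an `OptimizedModule`; in that case the original
+ # module is available as `_orig_mod`, which the isinstance checks below also inspect.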
+ is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
+ self.controlnet, torch._dynamo.eval_frame.OptimizedModule
+ )
+ if (
+ isinstance(self.controlnet, ControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
+ ):
+ self.check_image(image, prompt, prompt_embeds)
+ elif (
+ isinstance(self.controlnet, MultiControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
+ ):
+ if not isinstance(image, list):
+ raise TypeError("For multiple controlnets: `image` must be type `list`")
+
+ # When `image` is a nested list:
+ # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
+ elif any(isinstance(i, list) for i in image):
+ raise ValueError("A single batch of multiple conditionings are supported at the moment.")
+ elif len(image) != len(self.controlnet.nets):
+ raise ValueError(
+ f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
+ )
+
+ for image_ in image:
+ self.check_image(image_, prompt, prompt_embeds)
+ else:
+ assert False
+
+ # Check `controlnet_conditioning_scale`
+ if (
+ isinstance(self.controlnet, ControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
+ ):
+ if not isinstance(controlnet_conditioning_scale, float):
+ raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
+ elif (
+ isinstance(self.controlnet, MultiControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
+ ):
+ if isinstance(controlnet_conditioning_scale, list):
+ if any(isinstance(i, list) for i in controlnet_conditioning_scale):
+ raise ValueError("A single batch of multiple conditionings are supported at the moment.")
+ elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
+ self.controlnet.nets
+ ):
+ raise ValueError(
+ "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
+ " the same length as the number of controlnets"
+ )
+ else:
+ assert False
+
+ if len(control_guidance_start) != len(control_guidance_end):
+ raise ValueError(
+ f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
+ )
+
+ if isinstance(self.controlnet, MultiControlNetModel):
+ if len(control_guidance_start) != len(self.controlnet.nets):
+ raise ValueError(
+ f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
+ )
+
+ for start, end in zip(control_guidance_start, control_guidance_end):
+ if start >= end:
+ raise ValueError(
+ f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
+ )
+ if start < 0.0:
+ raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
+ if end > 1.0:
+ raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
+
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
+ def check_image(self, image, prompt, prompt_embeds):
+ image_is_pil = isinstance(image, PIL.Image.Image)
+ image_is_tensor = isinstance(image, torch.Tensor)
+ image_is_np = isinstance(image, np.ndarray)
+ image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
+ image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
+ image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
+
+ if (
+ not image_is_pil
+ and not image_is_tensor
+ and not image_is_np
+ and not image_is_pil_list
+ and not image_is_tensor_list
+ and not image_is_np_list
+ ):
+ raise TypeError(
+ f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
+ )
+
+ if image_is_pil:
+ image_batch_size = 1
+ else:
+ image_batch_size = len(image)
+
+ if prompt is not None and isinstance(prompt, str):
+ prompt_batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ prompt_batch_size = len(prompt)
+ elif prompt_embeds is not None:
+ prompt_batch_size = prompt_embeds.shape[0]
+
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size:
+ raise ValueError(
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
+ )
+
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
+ def prepare_control_image(
+ self,
+ image,
+ width,
+ height,
+ batch_size,
+ num_images_per_prompt,
+ device,
+ dtype,
+ do_classifier_free_guidance=False,
+ guess_mode=False,
+ ):
+ image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
+ image_batch_size = image.shape[0]
+
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_images_per_prompt
+
+ image = image.repeat_interleave(repeat_by, dim=0)
+
+ image = image.to(device=device, dtype=dtype)
+
+ if do_classifier_free_guidance and not guess_mode:
+ image = torch.cat([image] * 2)
+
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep, 0)
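+ # e.g. with num_inference_steps=50 and strength=0.8 this gives init_timestep=40 and
+ # t_start=10, so denoising runs over the last 40 scheduler timesteps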
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+
+ return timesteps, num_inference_steps - t_start
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.prepare_latents
+ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
+ raise ValueError(
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
+ )
+
+ image = image.to(device=device, dtype=dtype)
+
+ batch_size = batch_size * num_images_per_prompt
+
+ if image.shape[1] == 4:
+ init_latents = image
+
+ else:
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ elif isinstance(generator, list):
+ init_latents = [
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
+ ]
+ init_latents = torch.cat(init_latents, dim=0)
+ else:
+ init_latents = self.vae.encode(image).latent_dist.sample(generator)
+
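+ # scale the encoded latents by the VAE scaling factor so they match the latent
+ # distribution the UNet was trained on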
+ init_latents = self.vae.config.scaling_factor * init_latents
+
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ deprecation_message = (
+ f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
+ " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
+ " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
+ " your script to pass as many initial images as text prompts to suppress this warning."
+ )
+ deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
+ additional_image_per_prompt = batch_size // init_latents.shape[0]
+ init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ init_latents = torch.cat([init_latents], dim=0)
+
+ shape = init_latents.shape
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+ # get latents
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
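+ # the image latents are noised to `timestep`, the point where the shortened
+ # (strength-dependent) denoising schedule begins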
+ latents = init_latents
+
+ return latents
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ image: Union[
+ torch.FloatTensor,
+ PIL.Image.Image,
+ np.ndarray,
+ List[torch.FloatTensor],
+ List[PIL.Image.Image],
+ List[np.ndarray],
+ ] = None,
+ control_image: Union[
+ torch.FloatTensor,
+ PIL.Image.Image,
+ np.ndarray,
+ List[torch.FloatTensor],
+ List[PIL.Image.Image],
+ List[np.ndarray],
+ ] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ strength: float = 0.8,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ controlnet_conditioning_scale: Union[float, List[float]] = 0.8,
+ guess_mode: bool = False,
+ control_guidance_start: Union[float, List[float]] = 0.0,
+ control_guidance_end: Union[float, List[float]] = 1.0,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
+ `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+ The initial image to be used as the starting point for the image generation process. Can also accept
+ image latents as `image`; if latents are passed directly, they will not be encoded again.
+ control_image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
+ `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+ The ControlNet input condition. ControlNet uses this input condition to generate guidance for the UNet.
+ If the type is specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
+ also be accepted as an image. The dimensions of the output image default to `image`'s dimensions. If
+ height and/or width are passed, `image` is resized according to them. If multiple ControlNets are
+ specified in init, images must be passed as a list such that each element of the list can be correctly
+ batched for input to a single controlnet.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
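+ strength (`float`, *optional*, defaults to 0.8):
+ Indicates how much to transform the reference `image`. Must be between 0 and 1. `image` will be used as
+ a starting point, adding more noise to it the larger the `strength`. The denoising process then runs for
+ `int(num_inference_steps * strength)` steps, so `strength=1` adds maximum noise and runs the full
+ schedule.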
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 0.8):
+ The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
+ to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
+ corresponding scale as a list. Note that by default a smaller conditioning scale is used here than in
+ [`~StableDiffusionControlNetPipeline.__call__`].
+ guess_mode (`bool`, *optional*, defaults to `False`):
+ In this mode, the ControlNet encoder will try its best to recognize the content of the input image even
+ if you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
+ control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
+ The percentage of total steps at which the controlnet starts applying.
+ control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The percentage of total steps at which the controlnet stops applying.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
+
+ # align format for control guidance
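+ # scalar start/end values are broadcast to per-controlnet lists so every ControlNet
+ # ends up with its own (start, end) guidance window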
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
+ control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [
+ control_guidance_end
+ ]
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ control_image,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ controlnet_conditioning_scale,
+ control_guidance_start,
+ control_guidance_end,
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
+
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
+
+ global_pool_conditions = (
+ controlnet.config.global_pool_conditions
+ if isinstance(controlnet, ControlNetModel)
+ else controlnet.nets[0].config.global_pool_conditions
+ )
+ guess_mode = guess_mode or global_pool_conditions
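+ # ControlNets whose config enables `global_pool_conditions` are intended to be run in
+ # guess mode, so it is switched on for them automatically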
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+ # 4. Prepare image
+ image = self.image_processor.preprocess(image).to(dtype=torch.float32)
+
+ # 5. Prepare controlnet_conditioning_image
+ if isinstance(controlnet, ControlNetModel):
+ control_image = self.prepare_control_image(
+ image=control_image,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=controlnet.dtype,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+ elif isinstance(controlnet, MultiControlNetModel):
+ control_images = []
+
+ for control_image_ in control_image:
+ control_image_ = self.prepare_control_image(
+ image=control_image_,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=controlnet.dtype,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+
+ control_images.append(control_image_)
+
+ control_image = control_images
+ else:
+ assert False
+
+ # 6. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+ # 7. Prepare latent variables
+ latents = self.prepare_latents(
+ image,
+ latent_timestep,
+ batch_size,
+ num_images_per_prompt,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ )
+
+ # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 8.1 Create tensor stating which controlnets to keep
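+ # keeps[j] is 1.0 while the fractional progress of step i lies inside the
+ # [control_guidance_start[j], control_guidance_end[j]] window and 0.0 otherwise, so
+ # each ControlNet only contributes residuals during its configured window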
+ controlnet_keep = []
+ for i in range(len(timesteps)):
+ keeps = [
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+ for s, e in zip(control_guidance_start, control_guidance_end)
+ ]
+ controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps)
+
+ # 9. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # controlnet(s) inference
+ if guess_mode and do_classifier_free_guidance:
+ # Infer ControlNet only for the conditional batch.
+ control_model_input = latents
+ control_model_input = self.scheduler.scale_model_input(control_model_input, t)
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
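+ # prompt_embeds is [negative, positive] along the batch dimension, so chunk(2)[1]
+ # selects the conditional half for the ControlNet pass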
+ else:
+ control_model_input = latent_model_input
+ controlnet_prompt_embeds = prompt_embeds
+
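+ # the effective conditioning scale is the user-provided scale gated by the keep
+ # mask computed above (0.0 outside a ControlNet's guidance window)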
+ if isinstance(controlnet_keep[i], list):
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
+ else:
+ cond_scale = controlnet_conditioning_scale * controlnet_keep[i]
+
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
+ control_model_input,
+ t,
+ encoder_hidden_states=controlnet_prompt_embeds,
+ controlnet_cond=control_image,
+ conditioning_scale=cond_scale,
+ guess_mode=guess_mode,
+ return_dict=False,
+ )
+
+ if guess_mode and do_classifier_free_guidance:
+ # Inferred ControlNet only for the conditional batch.
+ # To apply the output of ControlNet to both the unconditional and conditional batches,
+ # add 0 to the unconditional batch to keep it unchanged.
+ down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
+ mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ down_block_additional_residuals=down_block_res_samples,
+ mid_block_additional_residual=mid_block_res_sample,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ # If we do sequential model offloading, let's offload unet and controlnet
+ # manually for max memory savings
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.unet.to("cpu")
+ self.controlnet.to("cpu")
+ torch.cuda.empty_cache()
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7481a0d4326c4560beafd3bde6631792de964d3
--- /dev/null
+++ b/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
@@ -0,0 +1,1355 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This model implementation is heavily inspired by https://github.com/haofanwang/ControlNet-for-Diffusers/
+
+import inspect
+import warnings
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL.Image
+import torch
+import torch.nn.functional as F
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from ...image_processor import VaeImageProcessor
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
+from ...schedulers import KarrasDiffusionSchedulers
+from ...utils import (
+ is_accelerate_available,
+ is_accelerate_version,
+ is_compiled_module,
+ logging,
+ randn_tensor,
+ replace_example_docstring,
+)
+from ..pipeline_utils import DiffusionPipeline
+from ..stable_diffusion import StableDiffusionPipelineOutput
+from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from .multicontrolnet import MultiControlNetModel
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> # !pip install transformers accelerate
+ >>> from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, DDIMScheduler
+ >>> from diffusers.utils import load_image
+ >>> import numpy as np
+ >>> import torch
+
+ >>> init_image = load_image(
+ ... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy.png"
+ ... )
+ >>> init_image = init_image.resize((512, 512))
+
+ >>> generator = torch.Generator(device="cpu").manual_seed(1)
+
+ >>> mask_image = load_image(
+ ... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy_mask.png"
+ ... )
+ >>> mask_image = mask_image.resize((512, 512))
+
+
+ >>> def make_inpaint_condition(image, image_mask):
+ ... image = np.array(image.convert("RGB")).astype(np.float32) / 255.0
+ ... image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0
+
+ ... assert image.shape[0:2] == image_mask.shape[0:2], "image and image_mask must have the same image size"
+ ... image[image_mask > 0.5] = -1.0 # set as masked pixel
+ ... image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)
+ ... image = torch.from_numpy(image)
+ ... return image
+
+
+ >>> control_image = make_inpaint_condition(init_image, mask_image)
+
+ >>> controlnet = ControlNetModel.from_pretrained(
+ ... "lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16
+ ... )
+ >>> pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
+ ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
+ ... )
+
+ >>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+ >>> pipe.enable_model_cpu_offload()
+
+ >>> # generate image
+ >>> image = pipe(
+ ... "a handsome man with ray-ban sunglasses",
+ ... num_inference_steps=20,
+ ... generator=generator,
+ ... eta=1.0,
+ ... image=init_image,
+ ... mask_image=mask_image,
+ ... control_image=control_image,
+ ... ).images[0]
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.prepare_mask_and_masked_image
+def prepare_mask_and_masked_image(image, mask, height, width, return_image=False):
+ """
+ Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
+ converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
+ ``image`` and ``1`` for the ``mask``.
+
+ The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
+ binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
+
+ Args:
+ image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
+ It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
+ ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
+ mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
+ It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
+ ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
+
+
+ Raises:
+ ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range.
+ ValueError: ``torch.Tensor`` mask should be in the ``[0, 1]`` range.
+ ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
+ TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not (or the other way around).
+
+ Returns:
+ tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
+ dimensions: ``batch x channels x height x width``.
+ """
+
+ if image is None:
+ raise ValueError("`image` input cannot be undefined.")
+
+ if mask is None:
+ raise ValueError("`mask_image` input cannot be undefined.")
+
+ if isinstance(image, torch.Tensor):
+ if not isinstance(mask, torch.Tensor):
+ raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not")
+
+ # Batch single image
+ if image.ndim == 3:
+ assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
+ image = image.unsqueeze(0)
+
+ # Batch and add channel dim for single mask
+ if mask.ndim == 2:
+ mask = mask.unsqueeze(0).unsqueeze(0)
+
+ # Batch single mask or add channel dim
+ if mask.ndim == 3:
+ # Single batched mask, no channel dim or single mask not batched but channel dim
+ if mask.shape[0] == 1:
+ mask = mask.unsqueeze(0)
+
+ # Batched masks no channel dim
+ else:
+ mask = mask.unsqueeze(1)
+
+ assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
+ assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
+ assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
+
+ # Check image is in [-1, 1]
+ if image.min() < -1 or image.max() > 1:
+ raise ValueError("Image should be in [-1, 1] range")
+
+ # Check mask is in [0, 1]
+ if mask.min() < 0 or mask.max() > 1:
+ raise ValueError("Mask should be in [0, 1] range")
+
+ # Binarize mask
+ mask[mask < 0.5] = 0
+ mask[mask >= 0.5] = 1
+
+ # Image as float32
+ image = image.to(dtype=torch.float32)
+ elif isinstance(mask, torch.Tensor):
+ raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
+ else:
+ # preprocess image
+ if isinstance(image, (PIL.Image.Image, np.ndarray)):
+ image = [image]
+ if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
+ # resize all images w.r.t. the passed height and width
+ image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]
+ image = [np.array(i.convert("RGB"))[None, :] for i in image]
+ image = np.concatenate(image, axis=0)
+ elif isinstance(image, list) and isinstance(image[0], np.ndarray):
+ image = np.concatenate([i[None, :] for i in image], axis=0)
+
+ image = image.transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+
+ # preprocess mask
+ if isinstance(mask, (PIL.Image.Image, np.ndarray)):
+ mask = [mask]
+
+ if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
+ mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]
+ mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
+ mask = mask.astype(np.float32) / 255.0
+ elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
+ mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
+
+ mask[mask < 0.5] = 0
+ mask[mask >= 0.5] = 1
+ mask = torch.from_numpy(mask)
+
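+ # the masked image keeps the original pixels only where mask < 0.5; regions to be
+ # inpainted (mask >= 0.5) are zeroed out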
+ masked_image = image * (mask < 0.5)
+
+ # n.b. ensure backwards compatibility as old function does not return image
+ if return_image:
+ return mask, masked_image, image
+
+ return mask, masked_image
+
+
+class StableDiffusionControlNetInpaintPipeline(
+ DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+):
+ r"""
+ Pipeline for text-guided image inpainting using Stable Diffusion with ControlNet guidance.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ In addition the pipeline inherits the following loading methods:
+ - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+
+
+
+ This pipeline can be used both with checkpoints that have been specifically fine-tuned for inpainting, such as
+ [runwayml/stable-diffusion-inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting)
+ as well as default text-to-image stable diffusion checkpoints, such as
+ [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5).
+ Default text-to-image stable diffusion checkpoints might be preferable for controlnets that have been fine-tuned on
+ those, such as [lllyasviel/control_v11p_sd15_inpaint](https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint).
+
+
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
+ Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets
+ as a list, the outputs from each ControlNet are added together to create one combined additional
+ conditioning.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ if isinstance(controlnet, (list, tuple)):
+ controlnet = MultiControlNetModel(controlnet)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ controlnet=controlnet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
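+ # e.g. the standard SD VAE has 4 entries in `block_out_channels`, giving a scale factor of 8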
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.control_image_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
+ )
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
+ steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
+ several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
+ """
+ self.vae.enable_tiling()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ hook = None
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+ if self.safety_checker is not None:
+ # the safety checker can offload the vae again
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
+
+ # the controlnet hook has to be offloaded manually as it alternates with the unet
+ cpu_offload_with_hook(self.controlnet, device)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ def decode_latents(self, latents):
+ warnings.warn(
+ "The decode_latents method is deprecated and will be removed in a future version. Please"
+ " use VaeImageProcessor instead",
+ FutureWarning,
+ )
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep, 0)
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+
+ return timesteps, num_inference_steps - t_start
+
+ def check_inputs(
+ self,
+ prompt,
+ image,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ controlnet_conditioning_scale=1.0,
+ control_guidance_start=0.0,
+ control_guidance_end=1.0,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ # `prompt` needs more sophisticated handling when there are multiple
+ # conditionings.
+ if isinstance(self.controlnet, MultiControlNetModel):
+ if isinstance(prompt, list):
+ logger.warning(
+ f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
+ " prompts. The conditionings will be fixed across the prompts."
+ )
+
+ # Check `image`
+ is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
+ self.controlnet, torch._dynamo.eval_frame.OptimizedModule
+ )
+ if (
+ isinstance(self.controlnet, ControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
+ ):
+ self.check_image(image, prompt, prompt_embeds)
+ elif (
+ isinstance(self.controlnet, MultiControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
+ ):
+ if not isinstance(image, list):
+ raise TypeError("For multiple controlnets: `image` must be type `list`")
+
+ # When `image` is a nested list:
+ # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
+ elif any(isinstance(i, list) for i in image):
+ raise ValueError("A single batch of multiple conditionings are supported at the moment.")
+ elif len(image) != len(self.controlnet.nets):
+ raise ValueError(
+ f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
+ )
+
+ for image_ in image:
+ self.check_image(image_, prompt, prompt_embeds)
+ else:
+ assert False
+
+ # Check `controlnet_conditioning_scale`
+ if (
+ isinstance(self.controlnet, ControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
+ ):
+ if not isinstance(controlnet_conditioning_scale, float):
+ raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
+ elif (
+ isinstance(self.controlnet, MultiControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
+ ):
+ if isinstance(controlnet_conditioning_scale, list):
+ if any(isinstance(i, list) for i in controlnet_conditioning_scale):
+ raise ValueError("Only a single batch of multiple conditionings is supported at the moment.")
+ if len(controlnet_conditioning_scale) != len(self.controlnet.nets):
+ raise ValueError(
+ "For multiple controlnets: When `controlnet_conditioning_scale` is specified as a `list`, it"
+ " must have the same length as the number of controlnets."
+ )
+ else:
+ assert False
+
+ if len(control_guidance_start) != len(control_guidance_end):
+ raise ValueError(
+ f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
+ )
+
+ if isinstance(self.controlnet, MultiControlNetModel):
+ if len(control_guidance_start) != len(self.controlnet.nets):
+ raise ValueError(
+ f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
+ )
+
+ for start, end in zip(control_guidance_start, control_guidance_end):
+ if start >= end:
+ raise ValueError(
+ f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
+ )
+ if start < 0.0:
+ raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
+ if end > 1.0:
+ raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
+
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
+ def check_image(self, image, prompt, prompt_embeds):
+ image_is_pil = isinstance(image, PIL.Image.Image)
+ image_is_tensor = isinstance(image, torch.Tensor)
+ image_is_np = isinstance(image, np.ndarray)
+ image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
+ image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
+ image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
+
+ if (
+ not image_is_pil
+ and not image_is_tensor
+ and not image_is_np
+ and not image_is_pil_list
+ and not image_is_tensor_list
+ and not image_is_np_list
+ ):
+ raise TypeError(
+ f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
+ )
+
+ if image_is_pil:
+ image_batch_size = 1
+ else:
+ image_batch_size = len(image)
+
+ if prompt is not None and isinstance(prompt, str):
+ prompt_batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ prompt_batch_size = len(prompt)
+ elif prompt_embeds is not None:
+ prompt_batch_size = prompt_embeds.shape[0]
+
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size:
+ raise ValueError(
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
+ )
+
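+ # Illustrative sketch only (not part of the upstream docstring): with a single prompt string and a
+ # list of two PIL control images, `check_image` computes prompt_batch_size = 1 and
+ # image_batch_size = 2, so it raises because a multi-image batch must match the prompt batch size:
+ #   >>> pipe.check_image([img_a, img_b], "a bird", None)  # `img_a`, `img_b` are hypothetical PIL images
+ #   ValueError: If image batch size is not 1, image batch size must be same as prompt batch size. ...
+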
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
+ def prepare_control_image(
+ self,
+ image,
+ width,
+ height,
+ batch_size,
+ num_images_per_prompt,
+ device,
+ dtype,
+ do_classifier_free_guidance=False,
+ guess_mode=False,
+ ):
+ image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
+ image_batch_size = image.shape[0]
+
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_images_per_prompt
+
+ image = image.repeat_interleave(repeat_by, dim=0)
+
+ image = image.to(device=device, dtype=dtype)
+
+ if do_classifier_free_guidance and not guess_mode:
+ image = torch.cat([image] * 2)
+
+ return image
+
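+ # Illustrative sketch only: a single 512x512 control image with batch_size=2,
+ # num_images_per_prompt=1 and classifier-free guidance enabled (guess_mode=False) is first
+ # preprocessed to a (1, 3, 512, 512) tensor, repeated to (2, 3, 512, 512), and then concatenated
+ # with itself for the unconditional pass, giving (4, 3, 512, 512). With guess_mode=True the final
+ # concatenation is skipped and the shape stays (2, 3, 512, 512).
+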
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_latents
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ image=None,
+ timestep=None,
+ is_strength_max=True,
+ return_noise=False,
+ return_image_latents=False,
+ ):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if (image is None or timestep is None) and not is_strength_max:
+ raise ValueError(
+ "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise."
+ "However, either the image or the noise timestep has not been provided."
+ )
+
+ if return_image_latents or (latents is None and not is_strength_max):
+ image = image.to(device=device, dtype=dtype)
+ image_latents = self._encode_vae_image(image=image, generator=generator)
+
+ if latents is None:
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ # if strength is 1, initialise the latents to pure noise; otherwise initialise them to image + noise
+ latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
+ # if pure noise then scale the initial latents by the Scheduler's init sigma
+ latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
+ else:
+ noise = latents.to(device)
+ latents = noise * self.scheduler.init_noise_sigma
+
+ outputs = (latents,)
+
+ if return_noise:
+ outputs += (noise,)
+
+ if return_image_latents:
+ outputs += (image_latents,)
+
+ return outputs
+
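+ # Illustrative sketch only: for strength == 1.0 the initial latents ignore `image` entirely and start
+ # as pure noise scaled by `self.scheduler.init_noise_sigma`; for strength == 0.5 the encoded `image`
+ # latents are noised with `self.scheduler.add_noise(image_latents, noise, timestep)` at the timestep
+ # corresponding to the remaining half of the schedule, so part of the original image survives in the
+ # starting point.
+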
+ def _default_height_width(self, height, width, image):
+ # NOTE: the images in a list may each have different dimensions, so just
+ # checking the first image is not _exactly_ correct, but it is simple.
+ while isinstance(image, list):
+ image = image[0]
+
+ if height is None:
+ if isinstance(image, PIL.Image.Image):
+ height = image.height
+ elif isinstance(image, torch.Tensor):
+ height = image.shape[2]
+
+ height = (height // 8) * 8 # round down to nearest multiple of 8
+
+ if width is None:
+ if isinstance(image, PIL.Image.Image):
+ width = image.width
+ elif isinstance(image, torch.Tensor):
+ width = image.shape[3]
+
+ width = (width // 8) * 8 # round down to nearest multiple of 8
+
+ return height, width
+
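+ # Illustrative sketch only: a 769x513 PIL image (width x height) yields
+ # height = (513 // 8) * 8 = 512 and width = (769 // 8) * 8 = 768, i.e. both dimensions are rounded
+ # down to the nearest multiple of 8 so they stay compatible with the VAE downsampling factor.
+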
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_mask_latents
+ def prepare_mask_latents(
+ self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
+ ):
+ # resize the mask to latents shape as we concatenate the mask to the latents
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+ # and half precision
+ mask = torch.nn.functional.interpolate(
+ mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
+ )
+ mask = mask.to(device=device, dtype=dtype)
+
+ masked_image = masked_image.to(device=device, dtype=dtype)
+ masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
+
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+ if mask.shape[0] < batch_size:
+ if not batch_size % mask.shape[0] == 0:
+ raise ValueError(
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+ " of masks that you pass is divisible by the total requested batch size."
+ )
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+ if masked_image_latents.shape[0] < batch_size:
+ if not batch_size % masked_image_latents.shape[0] == 0:
+ raise ValueError(
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
+ )
+ masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
+
+ mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
+ masked_image_latents = (
+ torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
+ )
+
+ # aligning device to prevent device errors when concating it with the latent model input
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+ return mask, masked_image_latents
+
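+ # Illustrative sketch only (assuming the usual vae_scale_factor of 8): a (1, 1, 512, 512) mask is
+ # interpolated to (1, 1, 64, 64), repeated up to the requested batch size, and, when classifier-free
+ # guidance is used, doubled along the batch dimension together with `masked_image_latents` so both
+ # halves of the CFG batch see the same mask.
+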
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline._encode_vae_image
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+ if isinstance(generator, list):
+ image_latents = [
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
+
+ image_latents = self.vae.config.scaling_factor * image_latents
+
+ return image_latents
+
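+ # Illustrative sketch only: a (1, 3, 512, 512) image in [-1, 1] is encoded to a (1, 4, 64, 64) latent
+ # sample, which is then multiplied by `self.vae.config.scaling_factor` (0.18215 for the standard
+ # Stable Diffusion VAE) so the latents match the scale the UNet was trained on.
+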
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ image: Union[torch.Tensor, PIL.Image.Image] = None,
+ mask_image: Union[torch.Tensor, PIL.Image.Image] = None,
+ control_image: Union[
+ torch.FloatTensor,
+ PIL.Image.Image,
+ np.ndarray,
+ List[torch.FloatTensor],
+ List[PIL.Image.Image],
+ List[np.ndarray],
+ ] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ strength: float = 1.0,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ controlnet_conditioning_scale: Union[float, List[float]] = 0.5,
+ guess_mode: bool = False,
+ control_guidance_start: Union[float, List[float]] = 0.0,
+ control_guidance_end: Union[float, List[float]] = 1.0,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ image (`torch.Tensor` or `PIL.Image.Image`):
+ The image to be inpainted: the region selected by `mask_image` will be repainted according to
+ `prompt`.
+ mask_image (`torch.Tensor` or `PIL.Image.Image`):
+ The mask marking the region to repaint. Pixels with value 1 (white) are repainted, while pixels
+ with value 0 (black) are preserved.
+ control_image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`,
+ `List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`):
+ The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
+ the type is specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
+ also be accepted as an image. The dimensions of the output image default to `image`'s dimensions. If
+ height and/or width are passed, the images are resized according to them. If multiple ControlNets are
+ specified in init, control images must be passed as a list such that each element of the list can be
+ correctly batched for input to a single controlnet.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ strength (`float`, *optional*, defaults to 1.):
+ Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be
+ between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the
+ `strength`. The number of denoising steps depends on the amount of noise initially added. When
+ `strength` is 1, added noise will be maximum and the denoising process will run for the full number of
+ iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked
+ portion of the reference `image`.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 0.5):
+ The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
+ to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
+ corresponding scale as a list. Note that by default, we use a smaller conditioning scale for inpainting
+ than for [`~StableDiffusionControlNetPipeline.__call__`].
+ guess_mode (`bool`, *optional*, defaults to `False`):
+ In this mode, the ControlNet encoder will try its best to recognize the content of the input image even
+ if you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
+ control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
+ The percentage of total steps at which the controlnet starts applying.
+ control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The percentage of total steps at which the controlnet stops applying.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
+
+ # 0. Default height and width to unet
+ height, width = self._default_height_width(height, width, image)
+
+ # align format for control guidance
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
+ control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [
+ control_guidance_end
+ ]
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ control_image,
+ height,
+ width,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ controlnet_conditioning_scale,
+ control_guidance_start,
+ control_guidance_end,
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
+
+ global_pool_conditions = (
+ controlnet.config.global_pool_conditions
+ if isinstance(controlnet, ControlNetModel)
+ else controlnet.nets[0].config.global_pool_conditions
+ )
+ guess_mode = guess_mode or global_pool_conditions
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+
+ # 4. Prepare image
+ if isinstance(controlnet, ControlNetModel):
+ control_image = self.prepare_control_image(
+ image=control_image,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=controlnet.dtype,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+ elif isinstance(controlnet, MultiControlNetModel):
+ control_images = []
+
+ for control_image_ in control_image:
+ control_image_ = self.prepare_control_image(
+ image=control_image_,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=controlnet.dtype,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+
+ control_images.append(control_image_)
+
+ control_image = control_images
+ else:
+ assert False
+
+ # 4.1 Preprocess mask and image - resizes image and mask w.r.t height and width
+ mask, masked_image, init_image = prepare_mask_and_masked_image(
+ image, mask_image, height, width, return_image=True
+ )
+
+ # 5. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps, num_inference_steps = self.get_timesteps(
+ num_inference_steps=num_inference_steps, strength=strength, device=device
+ )
+ # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+ # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
+ is_strength_max = strength == 1.0
+
+ # 6. Prepare latent variables
+ num_channels_latents = self.vae.config.latent_channels
+ num_channels_unet = self.unet.config.in_channels
+ return_image_latents = num_channels_unet == 4
+ latents_outputs = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ image=init_image,
+ timestep=latent_timestep,
+ is_strength_max=is_strength_max,
+ return_noise=True,
+ return_image_latents=return_image_latents,
+ )
+
+ if return_image_latents:
+ latents, noise, image_latents = latents_outputs
+ else:
+ latents, noise = latents_outputs
+
+ # 7. Prepare mask latent variables
+ mask, masked_image_latents = self.prepare_mask_latents(
+ mask,
+ masked_image,
+ batch_size * num_images_per_prompt,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ do_classifier_free_guidance,
+ )
+
+ # 7.1 Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7.2 Create tensor stating which controlnets to keep
+ controlnet_keep = []
+ for i in range(len(timesteps)):
+ keeps = [
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+ for s, e in zip(control_guidance_start, control_guidance_end)
+ ]
+ controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps)
+
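+ # Illustrative sketch only: with 10 timesteps, control_guidance_start=0.0 and
+ # control_guidance_end=0.5, `controlnet_keep` is 1.0 for steps 0-4 and 0.0 for steps 5-9, so the
+ # ControlNet residuals are effectively zeroed out (conditioning scale 0) for the second half of the
+ # denoising loop.
+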
+ # 8. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # controlnet(s) inference
+ if guess_mode and do_classifier_free_guidance:
+ # Infer ControlNet only for the conditional batch.
+ control_model_input = latents
+ control_model_input = self.scheduler.scale_model_input(control_model_input, t)
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
+ else:
+ control_model_input = latent_model_input
+ controlnet_prompt_embeds = prompt_embeds
+
+ if isinstance(controlnet_keep[i], list):
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
+ else:
+ cond_scale = controlnet_conditioning_scale * controlnet_keep[i]
+
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
+ control_model_input,
+ t,
+ encoder_hidden_states=controlnet_prompt_embeds,
+ controlnet_cond=control_image,
+ conditioning_scale=cond_scale,
+ guess_mode=guess_mode,
+ return_dict=False,
+ )
+
+ if guess_mode and do_classifier_free_guidance:
+ # Inferred ControlNet only for the conditional batch.
+ # To apply the output of ControlNet to both the unconditional and conditional batches,
+ # add 0 to the unconditional batch to keep it unchanged.
+ down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
+ mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
+
+ # predict the noise residual
+ if num_channels_unet == 9:
+ latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
+
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ down_block_additional_residuals=down_block_res_samples,
+ mid_block_additional_residual=mid_block_res_sample,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ if num_channels_unet == 4:
+ init_latents_proper = image_latents[:1]
+ init_mask = mask[:1]
+
+ if i < len(timesteps) - 1:
+ noise_timestep = timesteps[i + 1]
+ init_latents_proper = self.scheduler.add_noise(
+ init_latents_proper, noise, torch.tensor([noise_timestep])
+ )
+
+ latents = (1 - init_mask) * init_latents_proper + init_mask * latents
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ # If we do sequential model offloading, let's offload unet and controlnet
+ # manually for max memory savings
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.unet.to("cpu")
+ self.controlnet.to("cpu")
+ torch.cuda.empty_cache()
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd4338946e6ff2972c38d9dc8ea01a404d77d987
--- /dev/null
+++ b/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
@@ -0,0 +1,960 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL.Image
+import torch
+import torch.nn.functional as F
+from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
+
+from ...image_processor import VaeImageProcessor
+from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
+from ...models.attention_processor import (
+ AttnProcessor2_0,
+ LoRAAttnProcessor2_0,
+ LoRAXFormersAttnProcessor,
+ XFormersAttnProcessor,
+)
+from ...schedulers import KarrasDiffusionSchedulers
+from ...utils import (
+ is_accelerate_available,
+ is_accelerate_version,
+ is_compiled_module,
+ logging,
+ randn_tensor,
+ replace_example_docstring,
+)
+from ..pipeline_utils import DiffusionPipeline
+from ..stable_diffusion_xl import StableDiffusionXLPipelineOutput
+from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
+from .multicontrolnet import MultiControlNetModel
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> # To be updated when there's a useful ControlNet checkpoint
+ >>> # compatible with SDXL.
+ ```
+"""
+
+
+class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet guidance.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ In addition the pipeline inherits the following loading methods:
+ - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+ - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ text_encoder_2 ([`CLIPTextModelWithProjection`]):
+ Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
+ specifically the
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
+ variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`CLIPTokenizer`):
+ Second Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
+ Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets
+ as a list, the outputs from each ControlNet are added together to create one combined additional
+ conditioning.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ text_encoder_2: CLIPTextModelWithProjection,
+ tokenizer: CLIPTokenizer,
+ tokenizer_2: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ controlnet: ControlNetModel,
+ scheduler: KarrasDiffusionSchedulers,
+ force_zeros_for_empty_prompt: bool = True,
+ ):
+ super().__init__()
+
+ if isinstance(controlnet, (list, tuple)):
+ raise ValueError("MultiControlNet is not yet supported.")
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ unet=unet,
+ controlnet=controlnet,
+ scheduler=scheduler,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
+ self.control_image_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
+ )
+ self.watermark = StableDiffusionXLWatermarker()
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
+ steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
+ several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
+ """
+ self.vae.enable_tiling()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ hook = None
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+ # the ControlNet hook has to be manually offloaded as it alternates with the unet
+ cpu_offload_with_hook(self.controlnet, device)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
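+ # Hedged usage sketch (assumes `accelerate>=0.17.0` is installed; the `from_pretrained` arguments are
+ # elided and `control_image` is a hypothetical, already prepared conditioning image):
+ #   >>> pipe = StableDiffusionXLControlNetPipeline.from_pretrained(...)
+ #   >>> pipe.enable_model_cpu_offload()
+ #   >>> image = pipe("a prompt", image=control_image).images[0]
+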
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ do_classifier_free_guidance: bool = True,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # Define tokenizers and text encoders
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+ text_encoders = (
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+ )
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ prompt_embeds_list = []
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = text_encoder(
+ text_input_ids.to(device),
+ output_hidden_states=True,
+ )
+
+ # We are always only interested in the pooled output of the final text encoder
+ pooled_prompt_embeds = prompt_embeds[0]
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+
+ # get unconditional embeddings for classifier free guidance
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ uncond_tokens: List[str]
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ negative_prompt_embeds_list = []
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ negative_prompt_embeds = text_encoder(
+ uncond_input.input_ids.to(device),
+ output_hidden_states=True,
+ )
+ # We are always only interested in the pooled output of the final text encoder
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(
+ batch_size * num_images_per_prompt, seq_len, -1
+ )
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
+
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+
+ bs_embed = pooled_prompt_embeds.shape[0]
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
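+ # Illustrative sketch only (assuming the standard SDXL text encoders): the penultimate hidden states of
+ # both encoders are concatenated along the last dimension, so `prompt_embeds` has shape
+ # (batch_size * num_images_per_prompt, 77, 768 + 1280 = 2048), while `pooled_prompt_embeds` comes from
+ # the second encoder only and has shape (batch_size * num_images_per_prompt, 1280).
+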
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
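+ # Illustrative sketch only: for `DDIMScheduler` the returned dict is
+ # {"eta": eta, "generator": generator} because its `step` signature accepts both, whereas schedulers
+ # whose `step` accepts neither argument get an empty dict.
+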
+ def check_inputs(
+ self,
+ prompt,
+ image,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ controlnet_conditioning_scale=1.0,
+ control_guidance_start=0.0,
+ control_guidance_end=1.0,
+ ):
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ # Check `image`
+ is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
+ self.controlnet, torch._dynamo.eval_frame.OptimizedModule
+ )
+ if (
+ isinstance(self.controlnet, ControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
+ ):
+ self.check_image(image, prompt, prompt_embeds)
+ else:
+ assert False
+
+ # Check `controlnet_conditioning_scale`
+ if (
+ isinstance(self.controlnet, ControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
+ ):
+ if not isinstance(controlnet_conditioning_scale, float):
+ raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
+ else:
+ assert False
+
+ if len(control_guidance_start) != len(control_guidance_end):
+ raise ValueError(
+ f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
+ )
+
+ for start, end in zip(control_guidance_start, control_guidance_end):
+ if start >= end:
+ raise ValueError(
+ f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
+ )
+ if start < 0.0:
+ raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
+ if end > 1.0:
+ raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
+
+ def check_image(self, image, prompt, prompt_embeds):
+ image_is_pil = isinstance(image, PIL.Image.Image)
+ image_is_tensor = isinstance(image, torch.Tensor)
+ image_is_np = isinstance(image, np.ndarray)
+ image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
+ image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
+ image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
+
+ if (
+ not image_is_pil
+ and not image_is_tensor
+ and not image_is_np
+ and not image_is_pil_list
+ and not image_is_tensor_list
+ and not image_is_np_list
+ ):
+ raise TypeError(
+ f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
+ )
+
+ if image_is_pil:
+ image_batch_size = 1
+ else:
+ image_batch_size = len(image)
+
+ if prompt is not None and isinstance(prompt, str):
+ prompt_batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ prompt_batch_size = len(prompt)
+ elif prompt_embeds is not None:
+ prompt_batch_size = prompt_embeds.shape[0]
+
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size:
+ raise ValueError(
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
+ )
+
+ def prepare_image(
+ self,
+ image,
+ width,
+ height,
+ batch_size,
+ num_images_per_prompt,
+ device,
+ dtype,
+ do_classifier_free_guidance=False,
+ guess_mode=False,
+ ):
+ image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
+ image_batch_size = image.shape[0]
+
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_images_per_prompt
+
+ image = image.repeat_interleave(repeat_by, dim=0)
+
+ image = image.to(device=device, dtype=dtype)
+
+ if do_classifier_free_guidance and not guess_mode:
+ image = torch.cat([image] * 2)
+
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
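+ # Illustrative sketch only: for a 1024x1024 generation with the usual vae_scale_factor of 8 and
+ # 4 latent channels, the latents tensor has shape (batch_size, 4, 128, 128), either sampled with
+ # `randn_tensor` or taken from the user-provided `latents`, and then scaled by
+ # `self.scheduler.init_noise_sigma`.
+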
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids
+ def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+
+ passed_add_embed_dim = (
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
+ )
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
+
+ if expected_add_embed_dim != passed_add_embed_dim:
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
+ )
+
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+ return add_time_ids
+
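+ # Illustrative sketch only (assuming the standard SDXL UNet config with addition_time_embed_dim=256
+ # and text_encoder_2 projection_dim=1280): original_size=(1024, 1024), crops_coords_top_left=(0, 0)
+ # and target_size=(1024, 1024) produce add_time_ids = tensor([[1024, 1024, 0, 0, 1024, 1024]]) and
+ # passed_add_embed_dim = 256 * 6 + 1280 = 2816, which must match
+ # `unet.add_embedding.linear_1.in_features`.
+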
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
+ def upcast_vae(self):
+ dtype = self.vae.dtype
+ self.vae.to(dtype=torch.float32)
+ use_torch_2_0_or_xformers = isinstance(
+ self.vae.decoder.mid_block.attentions[0].processor,
+ (
+ AttnProcessor2_0,
+ XFormersAttnProcessor,
+ LoRAXFormersAttnProcessor,
+ LoRAAttnProcessor2_0,
+ ),
+ )
+ # if xformers or torch_2_0 is used attention block does not need
+ # to be in float32 which can save lots of memory
+ if use_torch_2_0_or_xformers:
+ self.vae.post_quant_conv.to(dtype)
+ self.vae.decoder.conv_in.to(dtype)
+ self.vae.decoder.mid_block.to(dtype)
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ image: Union[
+ torch.FloatTensor,
+ PIL.Image.Image,
+ np.ndarray,
+ List[torch.FloatTensor],
+ List[PIL.Image.Image],
+ List[np.ndarray],
+ ] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
+ guess_mode: bool = False,
+ control_guidance_start: Union[float, List[float]] = 0.0,
+ control_guidance_end: Union[float, List[float]] = 1.0,
+ original_size: Tuple[int, int] = (1024, 1024),
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ target_size: Tuple[int, int] = (1024, 1024),
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
+ `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+ The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
+ the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
+ also be accepted as an image. The dimensions of the output image default to `image`'s dimensions. If
+ height and/or width are passed, `image` is resized according to them. If multiple ControlNets are
+ specified in init, images must be passed as a list such that each element of the list can be correctly
+ batched for input to a single controlnet.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
+ to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
+ corresponding scale as a list.
+ guess_mode (`bool`, *optional*, defaults to `False`):
+                In this mode, the ControlNet encoder will try its best to recognize the content of the input image
+                even if you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
+ control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
+ The percentage of total steps at which the controlnet starts applying.
+ control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The percentage of total steps at which the controlnet stops applying.
+            original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+                The original size of the image used for SDXL's micro-conditioning. If `original_size` is not the
+                same as `target_size`, the image will appear to be down- or up-sampled, as explained in section 2.2
+                of the SDXL paper (https://arxiv.org/abs/2307.01952).
+            crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+                The top-left crop coordinates used for SDXL's micro-conditioning; (0, 0) usually produces
+                well-centered images.
+            target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+                The target resolution used for SDXL's micro-conditioning; in most cases this should match the
+                requested output resolution.
+ Examples:
+
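+            A minimal usage sketch (the pipeline class name and the checkpoint ids below are illustrative
+            assumptions, not defined by this file; adapt them to the class defined above):
+
+            ```py
+            >>> import torch
+            >>> from diffusers import ControlNetModel
+            >>> from diffusers.utils import load_image
+
+            >>> # `StableDiffusionXLControlNetPipeline` stands in for the pipeline class defined in this module
+            >>> controlnet = ControlNetModel.from_pretrained(
+            ...     "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16
+            ... )
+            >>> pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
+            ...     "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
+            ... ).to("cuda")
+
+            >>> canny_image = load_image(
+            ...     "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
+            ... )
+            >>> image = pipe(
+            ...     "a bird, best quality", image=canny_image, controlnet_conditioning_scale=0.5
+            ... ).images[0]
+            ```
+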
+ Returns:
+            [`StableDiffusionXLPipelineOutput`] or `tuple`: [`StableDiffusionXLPipelineOutput`] if `return_dict` is
+            True, otherwise a `tuple` containing the output images.
+ """
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
+
+ # align format for control guidance
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
+ control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [
+ control_guidance_end
+ ]
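+        # For example, with the scalar defaults and a MultiControlNetModel wrapping two ControlNets, the
+        # scalars are broadcast to control_guidance_start=[0.0, 0.0] and control_guidance_end=[1.0, 1.0].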
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ image,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ controlnet_conditioning_scale,
+ control_guidance_start,
+ control_guidance_end,
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ global_pool_conditions = (
+ controlnet.config.global_pool_conditions
+ if isinstance(controlnet, ControlNetModel)
+ else controlnet.nets[0].config.global_pool_conditions
+ )
+ guess_mode = guess_mode or global_pool_conditions
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+
+ # 4. Prepare image
+ if isinstance(controlnet, ControlNetModel):
+ image = self.prepare_image(
+ image=image,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=controlnet.dtype,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+ height, width = image.shape[-2:]
+ else:
+            # Only a single `ControlNetModel` is supported by this pipeline at the moment.
+            assert False, "`controlnet` must be a `ControlNetModel`; MultiControlNet is not supported here."
+
+ # 5. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 6. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7.1 Create tensor stating which controlnets to keep
+ controlnet_keep = []
+ for i in range(len(timesteps)):
+ keeps = [
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+ for s, e in zip(control_guidance_start, control_guidance_end)
+ ]
+ controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps)
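+        # Worked example: with 50 timesteps, control_guidance_start=0.0 and control_guidance_end=0.5,
+        # controlnet_keep is 1.0 for the first 25 steps and 0.0 afterwards, so the ControlNet residuals
+        # only influence the first half of the denoising schedule.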
+
+ # 7.2 Prepare added time ids & embeddings
+ add_text_embeds = pooled_prompt_embeds
+ add_time_ids = self._get_add_time_ids(
+ original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
+ )
+
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+ add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
+
+ prompt_embeds = prompt_embeds.to(device)
+ add_text_embeds = add_text_embeds.to(device)
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+
+ # 8. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # controlnet(s) inference
+ if guess_mode and do_classifier_free_guidance:
+ # Infer ControlNet only for the conditional batch.
+ control_model_input = latents
+ control_model_input = self.scheduler.scale_model_input(control_model_input, t)
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
+ else:
+ control_model_input = latent_model_input
+ controlnet_prompt_embeds = prompt_embeds
+
+ if isinstance(controlnet_keep[i], list):
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
+ else:
+ cond_scale = controlnet_conditioning_scale * controlnet_keep[i]
+
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
+ control_model_input,
+ t,
+ encoder_hidden_states=controlnet_prompt_embeds,
+ controlnet_cond=image,
+ conditioning_scale=cond_scale,
+ guess_mode=guess_mode,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )
+
+ if guess_mode and do_classifier_free_guidance:
+                    # The ControlNet was only run for the conditional batch. To apply its output to both the
+                    # unconditional and conditional batches, prepend zeros for the unconditional half so that it
+                    # stays unchanged.
+ down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
+ mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ down_block_additional_residuals=down_block_res_samples,
+ mid_block_additional_residual=mid_block_res_sample,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ # If we do sequential model offloading, let's offload unet and controlnet
+ # manually for max memory savings
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.unet.to("cpu")
+ self.controlnet.to("cpu")
+ torch.cuda.empty_cache()
+
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
+ self.upcast_vae()
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ else:
+ image = latents
+ return StableDiffusionXLPipelineOutput(images=image)
+
+ image = self.watermark.apply_watermark(image)
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image,)
+
+ return StableDiffusionXLPipelineOutput(images=image)
diff --git a/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py b/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..872297605683485544cdb12217bf679d5223a56c
--- /dev/null
+++ b/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py
@@ -0,0 +1,537 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import warnings
+from functools import partial
+from typing import Dict, List, Optional, Union
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+from flax.core.frozen_dict import FrozenDict
+from flax.jax_utils import unreplicate
+from flax.training.common_utils import shard
+from PIL import Image
+from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
+
+from ...models import FlaxAutoencoderKL, FlaxControlNetModel, FlaxUNet2DConditionModel
+from ...schedulers import (
+ FlaxDDIMScheduler,
+ FlaxDPMSolverMultistepScheduler,
+ FlaxLMSDiscreteScheduler,
+ FlaxPNDMScheduler,
+)
+from ...utils import PIL_INTERPOLATION, logging, replace_example_docstring
+from ..pipeline_flax_utils import FlaxDiffusionPipeline
+from ..stable_diffusion import FlaxStableDiffusionPipelineOutput
+from ..stable_diffusion.safety_checker_flax import FlaxStableDiffusionSafetyChecker
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+# Set to True to use python for loop instead of jax.fori_loop for easier debugging
+DEBUG = False
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import jax
+ >>> import numpy as np
+ >>> import jax.numpy as jnp
+ >>> from flax.jax_utils import replicate
+ >>> from flax.training.common_utils import shard
+ >>> from diffusers.utils import load_image
+ >>> from PIL import Image
+ >>> from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel
+
+
+ >>> def image_grid(imgs, rows, cols):
+ ... w, h = imgs[0].size
+ ... grid = Image.new("RGB", size=(cols * w, rows * h))
+ ... for i, img in enumerate(imgs):
+ ... grid.paste(img, box=(i % cols * w, i // cols * h))
+ ... return grid
+
+
+ >>> def create_key(seed=0):
+ ... return jax.random.PRNGKey(seed)
+
+
+ >>> rng = create_key(0)
+
+ >>> # get canny image
+ >>> canny_image = load_image(
+ ... "https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/blog_post_cell_10_output_0.jpeg"
+ ... )
+
+ >>> prompts = "best quality, extremely detailed"
+ >>> negative_prompts = "monochrome, lowres, bad anatomy, worst quality, low quality"
+
+ >>> # load control net and stable diffusion v1-5
+ >>> controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
+ ... "lllyasviel/sd-controlnet-canny", from_pt=True, dtype=jnp.float32
+ ... )
+ >>> pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
+ ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, revision="flax", dtype=jnp.float32
+ ... )
+ >>> params["controlnet"] = controlnet_params
+
+ >>> num_samples = jax.device_count()
+ >>> rng = jax.random.split(rng, jax.device_count())
+
+ >>> prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
+ >>> negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
+ >>> processed_image = pipe.prepare_image_inputs([canny_image] * num_samples)
+
+ >>> p_params = replicate(params)
+ >>> prompt_ids = shard(prompt_ids)
+ >>> negative_prompt_ids = shard(negative_prompt_ids)
+ >>> processed_image = shard(processed_image)
+
+ >>> output = pipe(
+ ... prompt_ids=prompt_ids,
+ ... image=processed_image,
+ ... params=p_params,
+ ... prng_seed=rng,
+ ... num_inference_steps=50,
+ ... neg_prompt_ids=negative_prompt_ids,
+ ... jit=True,
+ ... ).images
+
+ >>> output_images = pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
+ >>> output_images = image_grid(output_images, num_samples // 4, 4)
+ >>> output_images.save("generated_image.png")
+ ```
+"""
+
+
+class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion with ControlNet Guidance.
+
+ This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`FlaxAutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`FlaxCLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.FlaxCLIPTextModel),
+ specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+        controlnet ([`FlaxControlNetModel`]):
+ Provides additional conditioning to the unet during the denoising process.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or
+ [`FlaxDPMSolverMultistepScheduler`].
+ safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPFeatureExtractor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ def __init__(
+ self,
+ vae: FlaxAutoencoderKL,
+ text_encoder: FlaxCLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: FlaxUNet2DConditionModel,
+ controlnet: FlaxControlNetModel,
+ scheduler: Union[
+ FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler
+ ],
+ safety_checker: FlaxStableDiffusionSafetyChecker,
+ feature_extractor: CLIPFeatureExtractor,
+ dtype: jnp.dtype = jnp.float32,
+ ):
+ super().__init__()
+ self.dtype = dtype
+
+ if safety_checker is None:
+ logger.warn(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ controlnet=controlnet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+
+ def prepare_text_inputs(self, prompt: Union[str, List[str]]):
+ if not isinstance(prompt, (str, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ text_input = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="np",
+ )
+
+ return text_input.input_ids
+
+ def prepare_image_inputs(self, image: Union[Image.Image, List[Image.Image]]):
+ if not isinstance(image, (Image.Image, list)):
+ raise ValueError(f"image has to be of type `PIL.Image.Image` or list but is {type(image)}")
+
+ if isinstance(image, Image.Image):
+ image = [image]
+
+ processed_images = jnp.concatenate([preprocess(img, jnp.float32) for img in image])
+
+ return processed_images
+
+ def _get_has_nsfw_concepts(self, features, params):
+ has_nsfw_concepts = self.safety_checker(features, params)
+ return has_nsfw_concepts
+
+ def _run_safety_checker(self, images, safety_model_params, jit=False):
+ # safety_model_params should already be replicated when jit is True
+ pil_images = [Image.fromarray(image) for image in images]
+ features = self.feature_extractor(pil_images, return_tensors="np").pixel_values
+
+ if jit:
+ features = shard(features)
+ has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params)
+ has_nsfw_concepts = unshard(has_nsfw_concepts)
+ safety_model_params = unreplicate(safety_model_params)
+ else:
+ has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params)
+
+ images_was_copied = False
+ for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
+ if has_nsfw_concept:
+ if not images_was_copied:
+ images_was_copied = True
+ images = images.copy()
+
+ images[idx] = np.zeros(images[idx].shape, dtype=np.uint8) # black image
+
+ if any(has_nsfw_concepts):
+ warnings.warn(
+ "Potential NSFW content was detected in one or more images. A black image will be returned"
+ " instead. Try again with a different prompt and/or seed."
+ )
+
+ return images, has_nsfw_concepts
+
+ def _generate(
+ self,
+ prompt_ids: jnp.array,
+ image: jnp.array,
+ params: Union[Dict, FrozenDict],
+ prng_seed: jax.random.KeyArray,
+ num_inference_steps: int,
+ guidance_scale: float,
+ latents: Optional[jnp.array] = None,
+ neg_prompt_ids: Optional[jnp.array] = None,
+ controlnet_conditioning_scale: float = 1.0,
+ ):
+ height, width = image.shape[-2:]
+ if height % 64 != 0 or width % 64 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 64 but are {height} and {width}.")
+
+ # get prompt text embeddings
+ prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0]
+
+ # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0`
+ # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0`
+ batch_size = prompt_ids.shape[0]
+
+ max_length = prompt_ids.shape[-1]
+
+ if neg_prompt_ids is None:
+ uncond_input = self.tokenizer(
+ [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
+ ).input_ids
+ else:
+ uncond_input = neg_prompt_ids
+ negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
+ context = jnp.concatenate([negative_prompt_embeds, prompt_embeds])
+
+ image = jnp.concatenate([image] * 2)
+
+ latents_shape = (
+ batch_size,
+ self.unet.config.in_channels,
+ height // self.vae_scale_factor,
+ width // self.vae_scale_factor,
+ )
+ if latents is None:
+ latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32)
+ else:
+ if latents.shape != latents_shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
+
+ def loop_body(step, args):
+ latents, scheduler_state = args
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ latents_input = jnp.concatenate([latents] * 2)
+
+ t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
+ timestep = jnp.broadcast_to(t, latents_input.shape[0])
+
+ latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t)
+
+ down_block_res_samples, mid_block_res_sample = self.controlnet.apply(
+ {"params": params["controlnet"]},
+ jnp.array(latents_input),
+ jnp.array(timestep, dtype=jnp.int32),
+ encoder_hidden_states=context,
+ controlnet_cond=image,
+ conditioning_scale=controlnet_conditioning_scale,
+ return_dict=False,
+ )
+
+ # predict the noise residual
+ noise_pred = self.unet.apply(
+ {"params": params["unet"]},
+ jnp.array(latents_input),
+ jnp.array(timestep, dtype=jnp.int32),
+ encoder_hidden_states=context,
+ down_block_additional_residuals=down_block_res_samples,
+ mid_block_additional_residual=mid_block_res_sample,
+ ).sample
+
+ # perform guidance
+ noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
+ return latents, scheduler_state
+
+ scheduler_state = self.scheduler.set_timesteps(
+ params["scheduler"], num_inference_steps=num_inference_steps, shape=latents_shape
+ )
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * params["scheduler"].init_noise_sigma
+
+ if DEBUG:
+ # run with python for loop
+ for i in range(num_inference_steps):
+ latents, scheduler_state = loop_body(i, (latents, scheduler_state))
+ else:
+ latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state))
+
+ # scale and decode the image latents with vae
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample
+
+ image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
+ return image
+
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt_ids: jnp.array,
+ image: jnp.array,
+ params: Union[Dict, FrozenDict],
+ prng_seed: jax.random.KeyArray,
+ num_inference_steps: int = 50,
+ guidance_scale: Union[float, jnp.array] = 7.5,
+ latents: jnp.array = None,
+ neg_prompt_ids: jnp.array = None,
+ controlnet_conditioning_scale: Union[float, jnp.array] = 1.0,
+ return_dict: bool = True,
+ jit: bool = False,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt_ids (`jnp.array`):
+ The prompt or prompts to guide the image generation.
+ image (`jnp.array`):
+                Array representing the ControlNet input condition. ControlNet uses this input condition to generate
+                guidance for the UNet.
+ params (`Dict` or `FrozenDict`): Dictionary containing the model parameters/weights
+ prng_seed (`jax.random.KeyArray` or `jax.Array`): Array containing random number generator key
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. A higher guidance scale encourages the model to generate images closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ latents (`jnp.array`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling with the supplied random key (`prng_seed`).
+ controlnet_conditioning_scale (`float` or `jnp.array`, *optional*, defaults to 1.0):
+ The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
+ to the residual in the original unet.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
+ a plain tuple.
+ jit (`bool`, defaults to `False`):
+ Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument
+ exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
+            `tuple`. When returning a tuple, the first element is a list with the generated images, and the second
+ element is a list of `bool`s denoting whether the corresponding generated image likely represents
+ "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
+ """
+
+ height, width = image.shape[-2:]
+
+ if isinstance(guidance_scale, float):
+ # Convert to a tensor so each device gets a copy. Follow the prompt_ids for
+ # shape information, as they may be sharded (when `jit` is `True`), or not.
+ guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0])
+ if len(prompt_ids.shape) > 2:
+ # Assume sharded
+ guidance_scale = guidance_scale[:, None]
+
+ if isinstance(controlnet_conditioning_scale, float):
+ # Convert to a tensor so each device gets a copy. Follow the prompt_ids for
+ # shape information, as they may be sharded (when `jit` is `True`), or not.
+ controlnet_conditioning_scale = jnp.array([controlnet_conditioning_scale] * prompt_ids.shape[0])
+ if len(prompt_ids.shape) > 2:
+ # Assume sharded
+ controlnet_conditioning_scale = controlnet_conditioning_scale[:, None]
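+        # For example, with 8 local devices and sharded inputs, prompt_ids has a leading axis of size 8,
+        # so a scalar scale of 7.5 becomes a jnp array of shape (8, 1) and every device gets its own copy.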
+
+ if jit:
+ images = _p_generate(
+ self,
+ prompt_ids,
+ image,
+ params,
+ prng_seed,
+ num_inference_steps,
+ guidance_scale,
+ latents,
+ neg_prompt_ids,
+ controlnet_conditioning_scale,
+ )
+ else:
+ images = self._generate(
+ prompt_ids,
+ image,
+ params,
+ prng_seed,
+ num_inference_steps,
+ guidance_scale,
+ latents,
+ neg_prompt_ids,
+ controlnet_conditioning_scale,
+ )
+
+ if self.safety_checker is not None:
+ safety_params = params["safety_checker"]
+ images_uint8_casted = (images * 255).round().astype("uint8")
+ num_devices, batch_size = images.shape[:2]
+
+ images_uint8_casted = np.asarray(images_uint8_casted).reshape(num_devices * batch_size, height, width, 3)
+ images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit)
+ images = np.array(images)
+
+ # block images
+ if any(has_nsfw_concept):
+ for i, is_nsfw in enumerate(has_nsfw_concept):
+ if is_nsfw:
+ images[i] = np.asarray(images_uint8_casted[i])
+
+ images = images.reshape(num_devices, batch_size, height, width, 3)
+ else:
+ images = np.asarray(images)
+ has_nsfw_concept = False
+
+ if not return_dict:
+ return (images, has_nsfw_concept)
+
+ return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)
+
+
+# Static argnums are pipe, num_inference_steps. A change would trigger recompilation.
+# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`).
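+# For example, with `jit=True` on an 8-device host, `shard()` gives `prompt_ids` the shape
+# (8, batch_per_device, 77); `pmap` then runs `_p_generate` once per device on its slice, while `pipe`
+# and `num_inference_steps` are broadcast (changing `num_inference_steps` triggers recompilation).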
+@partial(
+ jax.pmap,
+ in_axes=(None, 0, 0, 0, 0, None, 0, 0, 0, 0),
+ static_broadcasted_argnums=(0, 5),
+)
+def _p_generate(
+ pipe,
+ prompt_ids,
+ image,
+ params,
+ prng_seed,
+ num_inference_steps,
+ guidance_scale,
+ latents,
+ neg_prompt_ids,
+ controlnet_conditioning_scale,
+):
+ return pipe._generate(
+ prompt_ids,
+ image,
+ params,
+ prng_seed,
+ num_inference_steps,
+ guidance_scale,
+ latents,
+ neg_prompt_ids,
+ controlnet_conditioning_scale,
+ )
+
+
+@partial(jax.pmap, static_broadcasted_argnums=(0,))
+def _p_get_has_nsfw_concepts(pipe, features, params):
+ return pipe._get_has_nsfw_concepts(features, params)
+
+
+def unshard(x: jnp.ndarray):
+ # einops.rearrange(x, 'd b ... -> (d b) ...')
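+    # e.g. a pmap output of shape (8, 1, 512, 512, 3) is flattened to (8, 512, 512, 3)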
+ num_devices, batch_size = x.shape[:2]
+ rest = x.shape[2:]
+ return x.reshape(num_devices * batch_size, *rest)
+
+
+def preprocess(image, dtype):
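+    # Illustrative example: a 513x769 PIL image is resized down to 512x768 and returned as a
+    # (1, 3, 768, 512) array scaled to [0, 1] (NCHW layout, as the ControlNet expects).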
+ image = image.convert("RGB")
+ w, h = image.size
+ w, h = (x - x % 64 for x in (w, h)) # resize to integer multiple of 64
+ image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
+ image = jnp.array(image).astype(dtype) / 255.0
+ image = image[None].transpose(0, 3, 1, 2)
+ return image
diff --git a/diffusers/pipelines/dance_diffusion/__init__.py b/diffusers/pipelines/dance_diffusion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..55d7f8ff9807083a10c844f7003cf0696d8258a3
--- /dev/null
+++ b/diffusers/pipelines/dance_diffusion/__init__.py
@@ -0,0 +1 @@
+from .pipeline_dance_diffusion import DanceDiffusionPipeline
diff --git a/diffusers/pipelines/dance_diffusion/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/dance_diffusion/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f42451c15f481fc23eed91a04e95cd1de62a0333
Binary files /dev/null and b/diffusers/pipelines/dance_diffusion/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/pipelines/dance_diffusion/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/dance_diffusion/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6f7111675e18cc7da8d3e71f3ef2e1aaf806f51c
Binary files /dev/null and b/diffusers/pipelines/dance_diffusion/__pycache__/__init__.cpython-38.pyc differ
diff --git a/diffusers/pipelines/dance_diffusion/__pycache__/pipeline_dance_diffusion.cpython-310.pyc b/diffusers/pipelines/dance_diffusion/__pycache__/pipeline_dance_diffusion.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a7a32766b9f30a474d2ab9b08bebcbb0a0c134e0
Binary files /dev/null and b/diffusers/pipelines/dance_diffusion/__pycache__/pipeline_dance_diffusion.cpython-310.pyc differ
diff --git a/diffusers/pipelines/dance_diffusion/__pycache__/pipeline_dance_diffusion.cpython-38.pyc b/diffusers/pipelines/dance_diffusion/__pycache__/pipeline_dance_diffusion.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ec664e0be7f909df39bd25a322e1f4bea0b7df71
Binary files /dev/null and b/diffusers/pipelines/dance_diffusion/__pycache__/pipeline_dance_diffusion.cpython-38.pyc differ
diff --git a/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py b/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3eb32273b6da97ac807f3e4fbebb5edd53bb879
--- /dev/null
+++ b/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py
@@ -0,0 +1,125 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import List, Optional, Tuple, Union
+
+import torch
+
+from ...utils import logging, randn_tensor
+from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class DanceDiffusionPipeline(DiffusionPipeline):
+ r"""
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Parameters:
+ unet ([`UNet1DModel`]): U-Net architecture to denoise the encoded audio.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded audio. Can be one of
+ [`IPNDMScheduler`].
+ """
+
+ def __init__(self, unet, scheduler):
+ super().__init__()
+ self.register_modules(unet=unet, scheduler=scheduler)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ batch_size: int = 1,
+ num_inference_steps: int = 100,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ audio_length_in_s: Optional[float] = None,
+ return_dict: bool = True,
+ ) -> Union[AudioPipelineOutput, Tuple]:
+ r"""
+ Args:
+ batch_size (`int`, *optional*, defaults to 1):
+ The number of audio samples to generate.
+            num_inference_steps (`int`, *optional*, defaults to 100):
+ The number of denoising steps. More denoising steps usually lead to a higher-quality audio sample at
+ the expense of slower inference.
+ generator (`torch.Generator`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ audio_length_in_s (`float`, *optional*, defaults to `self.unet.config.sample_size/self.unet.config.sample_rate`):
+ The length of the generated audio sample in seconds. Note that the output of the pipeline, *i.e.*
+ `sample_size`, will be `audio_length_in_s` * `self.unet.config.sample_rate`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.AudioPipelineOutput`] instead of a plain tuple.
+
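+        Examples:
+            A minimal usage sketch (the checkpoint id is an illustrative assumption):
+
+            ```py
+            >>> from diffusers import DanceDiffusionPipeline
+
+            >>> pipe = DanceDiffusionPipeline.from_pretrained("harmonai/maestro-150k")
+            >>> audio = pipe(audio_length_in_s=4.0, num_inference_steps=100).audios[0]
+            >>> # `audio` is a numpy array of shape (channels, samples) with values in [-1, 1]
+            ```
+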
+ Returns:
+ [`~pipelines.AudioPipelineOutput`] or `tuple`: [`~pipelines.utils.AudioPipelineOutput`] if `return_dict` is
+ True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated audio.
+ """
+
+ if audio_length_in_s is None:
+ audio_length_in_s = self.unet.config.sample_size / self.unet.config.sample_rate
+
+ sample_size = audio_length_in_s * self.unet.config.sample_rate
+
+ down_scale_factor = 2 ** len(self.unet.up_blocks)
+ if sample_size < 3 * down_scale_factor:
+ raise ValueError(
+ f"{audio_length_in_s} is too small. Make sure it's bigger or equal to"
+ f" {3 * down_scale_factor / self.unet.config.sample_rate}."
+ )
+
+ original_sample_size = int(sample_size)
+ if sample_size % down_scale_factor != 0:
+ sample_size = (
+ (audio_length_in_s * self.unet.config.sample_rate) // down_scale_factor + 1
+ ) * down_scale_factor
+ logger.info(
+ f"{audio_length_in_s} is increased to {sample_size / self.unet.config.sample_rate} so that it can be handled"
+ f" by the model. It will be cut to {original_sample_size / self.unet.config.sample_rate} after the denoising"
+ " process."
+ )
+ sample_size = int(sample_size)
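+        # Worked example (assuming a 16 kHz model with 4 up blocks, i.e. down_scale_factor 16):
+        # audio_length_in_s=2.0005 gives sample_size 32008, which is padded up to 32016 for the U-Net
+        # and trimmed back to 32008 samples after denoising.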
+
+ dtype = next(self.unet.parameters()).dtype
+ shape = (batch_size, self.unet.config.in_channels, sample_size)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ audio = randn_tensor(shape, generator=generator, device=self._execution_device, dtype=dtype)
+
+ # set step values
+ self.scheduler.set_timesteps(num_inference_steps, device=audio.device)
+ self.scheduler.timesteps = self.scheduler.timesteps.to(dtype)
+
+ for t in self.progress_bar(self.scheduler.timesteps):
+ # 1. predict noise model_output
+ model_output = self.unet(audio, t).sample
+
+            # 2. compute previous audio sample: x_t -> x_t-1
+ audio = self.scheduler.step(model_output, t, audio).prev_sample
+
+ audio = audio.clamp(-1, 1).float().cpu().numpy()
+
+ audio = audio[:, :, :original_sample_size]
+
+ if not return_dict:
+ return (audio,)
+
+ return AudioPipelineOutput(audios=audio)
diff --git a/diffusers/pipelines/ddim/__init__.py b/diffusers/pipelines/ddim/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..85e8118e75e7e4352f8efb12552ba9fff4bf491c
--- /dev/null
+++ b/diffusers/pipelines/ddim/__init__.py
@@ -0,0 +1 @@
+from .pipeline_ddim import DDIMPipeline
diff --git a/diffusers/pipelines/ddim/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/ddim/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..045acf07f2783feaeb5484b3aba4dd0d3bb9fae4
Binary files /dev/null and b/diffusers/pipelines/ddim/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/pipelines/ddim/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/ddim/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..505a37aec1e246f5a08a8b40e0bf2c0caf600675
Binary files /dev/null and b/diffusers/pipelines/ddim/__pycache__/__init__.cpython-38.pyc differ
diff --git a/diffusers/pipelines/ddim/__pycache__/pipeline_ddim.cpython-310.pyc b/diffusers/pipelines/ddim/__pycache__/pipeline_ddim.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..49ba3c57ea6d0ed4fe347fb014053d1fe9378766
Binary files /dev/null and b/diffusers/pipelines/ddim/__pycache__/pipeline_ddim.cpython-310.pyc differ
diff --git a/diffusers/pipelines/ddim/__pycache__/pipeline_ddim.cpython-38.pyc b/diffusers/pipelines/ddim/__pycache__/pipeline_ddim.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..410712086c01574b06e62b3eb124ee5bf0682230
Binary files /dev/null and b/diffusers/pipelines/ddim/__pycache__/pipeline_ddim.cpython-38.pyc differ
diff --git a/diffusers/pipelines/ddim/pipeline_ddim.py b/diffusers/pipelines/ddim/pipeline_ddim.py
new file mode 100644
index 0000000000000000000000000000000000000000..c24aa6c797935cd90ff9dad25a5ca5c07686216d
--- /dev/null
+++ b/diffusers/pipelines/ddim/pipeline_ddim.py
@@ -0,0 +1,122 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Tuple, Union
+
+import torch
+
+from ...schedulers import DDIMScheduler
+from ...utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+
+
+class DDIMPipeline(DiffusionPipeline):
+ r"""
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Parameters:
+ unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of
+ [`DDPMScheduler`], or [`DDIMScheduler`].
+ """
+
+ def __init__(self, unet, scheduler):
+ super().__init__()
+
+ # make sure scheduler can always be converted to DDIM
+ scheduler = DDIMScheduler.from_config(scheduler.config)
+
+ self.register_modules(unet=unet, scheduler=scheduler)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ batch_size: int = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ eta: float = 0.0,
+ num_inference_steps: int = 50,
+ use_clipped_model_output: Optional[bool] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ ) -> Union[ImagePipelineOutput, Tuple]:
+ r"""
+ Args:
+ batch_size (`int`, *optional*, defaults to 1):
+ The number of images to generate.
+ generator (`torch.Generator`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ eta (`float`, *optional*, defaults to 0.0):
+ The eta parameter which controls the scale of the variance (0 is DDIM and 1 is one type of DDPM).
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ use_clipped_model_output (`bool`, *optional*, defaults to `None`):
+                If `True` or `False`, see the documentation for `DDIMScheduler.step`. If `None`, nothing is passed
+                downstream to the scheduler, so use `None` for schedulers which do not support this argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
+
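+        Examples:
+            A minimal usage sketch (the checkpoint id is an illustrative assumption):
+
+            ```py
+            >>> from diffusers import DDIMPipeline
+
+            >>> pipe = DDIMPipeline.from_pretrained("google/ddpm-cifar10-32")
+            >>> # eta=0.0 keeps sampling deterministic; eta=1.0 adds DDPM-like variance
+            >>> image = pipe(num_inference_steps=50, eta=0.0).images[0]
+            >>> image.save("ddim_sample.png")
+            ```
+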
+ Returns:
+ [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is
+            True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+
+ # Sample gaussian noise to begin loop
+ if isinstance(self.unet.config.sample_size, int):
+ image_shape = (
+ batch_size,
+ self.unet.config.in_channels,
+ self.unet.config.sample_size,
+ self.unet.config.sample_size,
+ )
+ else:
+ image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size)
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ image = randn_tensor(image_shape, generator=generator, device=self._execution_device, dtype=self.unet.dtype)
+
+ # set step values
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ for t in self.progress_bar(self.scheduler.timesteps):
+ # 1. predict noise model_output
+ model_output = self.unet(image, t).sample
+
+ # 2. predict previous mean of image x_t-1 and add variance depending on eta
+ # eta corresponds to η in paper and should be between [0, 1]
+ # do x_t -> x_t-1
+ image = self.scheduler.step(
+ model_output, t, image, eta=eta, use_clipped_model_output=use_clipped_model_output, generator=generator
+ ).prev_sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
diff --git a/diffusers/pipelines/ddpm/__init__.py b/diffusers/pipelines/ddpm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb228ee012e80493b617b314c867ecadba7ca1ce
--- /dev/null
+++ b/diffusers/pipelines/ddpm/__init__.py
@@ -0,0 +1 @@
+from .pipeline_ddpm import DDPMPipeline
diff --git a/diffusers/pipelines/ddpm/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/ddpm/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..72457c7e29b43d24a3d06663e01789182fcfe88c
Binary files /dev/null and b/diffusers/pipelines/ddpm/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/pipelines/ddpm/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/ddpm/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ac1647bb5e9f229b9d4d3f55574b4c66c0f20ca2
Binary files /dev/null and b/diffusers/pipelines/ddpm/__pycache__/__init__.cpython-38.pyc differ
diff --git a/diffusers/pipelines/ddpm/__pycache__/pipeline_ddpm.cpython-310.pyc b/diffusers/pipelines/ddpm/__pycache__/pipeline_ddpm.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..96cf144e91518159b1b81690e3bce51b2e7addba
Binary files /dev/null and b/diffusers/pipelines/ddpm/__pycache__/pipeline_ddpm.cpython-310.pyc differ
diff --git a/diffusers/pipelines/ddpm/__pycache__/pipeline_ddpm.cpython-38.pyc b/diffusers/pipelines/ddpm/__pycache__/pipeline_ddpm.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..71e5a15d92a2c283c002bc12dc95254761490069
Binary files /dev/null and b/diffusers/pipelines/ddpm/__pycache__/pipeline_ddpm.cpython-38.pyc differ
diff --git a/diffusers/pipelines/ddpm/pipeline_ddpm.py b/diffusers/pipelines/ddpm/pipeline_ddpm.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4290daf852c2f3204a64b9955c9b53089d64bbc
--- /dev/null
+++ b/diffusers/pipelines/ddpm/pipeline_ddpm.py
@@ -0,0 +1,105 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import List, Optional, Tuple, Union
+
+import torch
+
+from ...utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+
+
+class DDPMPipeline(DiffusionPipeline):
+ r"""
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Parameters:
+ unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of
+ [`DDPMScheduler`], or [`DDIMScheduler`].
+ """
+
+ def __init__(self, unet, scheduler):
+ super().__init__()
+ self.register_modules(unet=unet, scheduler=scheduler)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ batch_size: int = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ num_inference_steps: int = 1000,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ ) -> Union[ImagePipelineOutput, Tuple]:
+ r"""
+ Args:
+ batch_size (`int`, *optional*, defaults to 1):
+ The number of images to generate.
+ generator (`torch.Generator`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ num_inference_steps (`int`, *optional*, defaults to 1000):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
+
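+        Examples:
+            A minimal usage sketch (the checkpoint id is an illustrative assumption):
+
+            ```py
+            >>> from diffusers import DDPMPipeline
+
+            >>> pipe = DDPMPipeline.from_pretrained("google/ddpm-cat-256")
+            >>> image = pipe(num_inference_steps=1000).images[0]
+            >>> image.save("ddpm_sample.png")
+            ```
+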
+ Returns:
+ [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is
+            True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+ # Sample gaussian noise to begin loop
+ if isinstance(self.unet.config.sample_size, int):
+ image_shape = (
+ batch_size,
+ self.unet.config.in_channels,
+ self.unet.config.sample_size,
+ self.unet.config.sample_size,
+ )
+ else:
+ image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size)
+
+ if self.device.type == "mps":
+ # randn does not work reproducibly on mps
+ image = randn_tensor(image_shape, generator=generator)
+ image = image.to(self.device)
+ else:
+ image = randn_tensor(image_shape, generator=generator, device=self.device)
+
+ # set step values
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ for t in self.progress_bar(self.scheduler.timesteps):
+ # 1. predict noise model_output
+ model_output = self.unet(image, t).sample
+
+ # 2. compute previous image: x_t -> x_t-1
+ image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
diff --git a/diffusers/pipelines/deepfloyd_if/__init__.py b/diffusers/pipelines/deepfloyd_if/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..93414f20e7339a147ffa2d3dd36c871dfecda8e4
--- /dev/null
+++ b/diffusers/pipelines/deepfloyd_if/__init__.py
@@ -0,0 +1,54 @@
+from dataclasses import dataclass
+from typing import List, Optional, Union
+
+import numpy as np
+import PIL
+
+from ...utils import BaseOutput, OptionalDependencyNotAvailable, is_torch_available, is_transformers_available
+from .timesteps import (
+ fast27_timesteps,
+ smart27_timesteps,
+ smart50_timesteps,
+ smart100_timesteps,
+ smart185_timesteps,
+ super27_timesteps,
+ super40_timesteps,
+ super100_timesteps,
+)
+
+
+@dataclass
+class IFPipelineOutput(BaseOutput):
+ """
+    Output class for DeepFloyd IF pipelines.
+
+    Args:
+        images (`List[PIL.Image.Image]` or `np.ndarray`):
+            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+            num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
+        nsfw_detected (`List[bool]`):
+            List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
+            (nsfw) content or a watermark. `None` if safety checking could not be performed.
+        watermark_detected (`List[bool]`):
+            List of flags denoting whether the corresponding generated image likely has a watermark. `None` if safety
+            checking could not be performed.
+ """
+
+ images: Union[List[PIL.Image.Image], np.ndarray]
+ nsfw_detected: Optional[List[bool]]
+ watermark_detected: Optional[List[bool]]
+
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
+else:
+ from .pipeline_if import IFPipeline
+ from .pipeline_if_img2img import IFImg2ImgPipeline
+ from .pipeline_if_img2img_superresolution import IFImg2ImgSuperResolutionPipeline
+ from .pipeline_if_inpainting import IFInpaintingPipeline
+ from .pipeline_if_inpainting_superresolution import IFInpaintingSuperResolutionPipeline
+ from .pipeline_if_superresolution import IFSuperResolutionPipeline
+ from .safety_checker import IFSafetyChecker
+ from .watermark import IFWatermarker
diff --git a/diffusers/pipelines/deepfloyd_if/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/deepfloyd_if/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..679bb7a4b36455c1953311722e404c9b701d7238
Binary files /dev/null and b/diffusers/pipelines/deepfloyd_if/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/pipelines/deepfloyd_if/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/deepfloyd_if/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a41bbec97202b97f94fce6eb87225b9c10d8b6ee
Binary files /dev/null and b/diffusers/pipelines/deepfloyd_if/__pycache__/__init__.cpython-38.pyc differ
diff --git a/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if.cpython-310.pyc b/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2ddeb021863501dd9e09fbd29954bfae49186fa5
Binary files /dev/null and b/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if.cpython-310.pyc differ
diff --git a/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if.cpython-38.pyc b/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0b22bd7001afa5362757b8ccde63ed646e36512b
Binary files /dev/null and b/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if.cpython-38.pyc differ
diff --git a/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_img2img.cpython-310.pyc b/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_img2img.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..79993dfdc9021f21f5388f581bb332fb3c38d78f
Binary files /dev/null and b/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_img2img.cpython-310.pyc differ
diff --git a/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_img2img.cpython-38.pyc b/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_img2img.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ce6c54ff9b838c60b85504809d336a3464516a48
Binary files /dev/null and b/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_img2img.cpython-38.pyc differ
diff --git a/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_img2img_superresolution.cpython-310.pyc b/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_img2img_superresolution.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ba91903c44975c3df3a288d8d7b36087bd052680
Binary files /dev/null and b/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_img2img_superresolution.cpython-310.pyc differ
diff --git a/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_img2img_superresolution.cpython-38.pyc b/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_img2img_superresolution.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..674f4d2500046898a6e2433da2762e7321be3298
Binary files /dev/null and b/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_img2img_superresolution.cpython-38.pyc differ
diff --git a/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_inpainting.cpython-310.pyc b/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_inpainting.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e92b7ae88e9386fdd0be786acac33c3108e7ff93
Binary files /dev/null and b/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_inpainting.cpython-310.pyc differ
diff --git a/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_inpainting.cpython-38.pyc b/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_inpainting.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..203f15a76f7b317903f2d4c4e4426abb1163f5ec
Binary files /dev/null and b/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_inpainting.cpython-38.pyc differ
diff --git a/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_inpainting_superresolution.cpython-310.pyc b/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_inpainting_superresolution.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d565672b37f5dfca573826d524aa67db5030fe09
Binary files /dev/null and b/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_inpainting_superresolution.cpython-310.pyc differ
diff --git a/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_inpainting_superresolution.cpython-38.pyc b/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_inpainting_superresolution.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e5dd8f295bb53d067f27f64f6098dcf3a469571c
Binary files /dev/null and b/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_inpainting_superresolution.cpython-38.pyc differ
diff --git a/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_superresolution.cpython-310.pyc b/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_superresolution.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2917ffa56dd37d2bc6040d6805cab4ab7a9631d6
Binary files /dev/null and b/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_superresolution.cpython-310.pyc differ
diff --git a/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_superresolution.cpython-38.pyc b/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_superresolution.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c579f7ab530c509c3c6b9e2bb2a320d0a35bd239
Binary files /dev/null and b/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_superresolution.cpython-38.pyc differ
diff --git a/diffusers/pipelines/deepfloyd_if/__pycache__/safety_checker.cpython-310.pyc b/diffusers/pipelines/deepfloyd_if/__pycache__/safety_checker.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..25f33752dc5dc47c56e105355c03e5ab1537bf8c
Binary files /dev/null and b/diffusers/pipelines/deepfloyd_if/__pycache__/safety_checker.cpython-310.pyc differ
diff --git a/diffusers/pipelines/deepfloyd_if/__pycache__/safety_checker.cpython-38.pyc b/diffusers/pipelines/deepfloyd_if/__pycache__/safety_checker.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c09292ed709ee00f40af3eec29072a77d5494269
Binary files /dev/null and b/diffusers/pipelines/deepfloyd_if/__pycache__/safety_checker.cpython-38.pyc differ
diff --git a/diffusers/pipelines/deepfloyd_if/__pycache__/timesteps.cpython-310.pyc b/diffusers/pipelines/deepfloyd_if/__pycache__/timesteps.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..31a5962c75f4bed23e9853698aa73a9fd9ce1973
Binary files /dev/null and b/diffusers/pipelines/deepfloyd_if/__pycache__/timesteps.cpython-310.pyc differ
diff --git a/diffusers/pipelines/deepfloyd_if/__pycache__/timesteps.cpython-38.pyc b/diffusers/pipelines/deepfloyd_if/__pycache__/timesteps.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..29c3af9e1597507385b5f99d1e07ba291d4f2144
Binary files /dev/null and b/diffusers/pipelines/deepfloyd_if/__pycache__/timesteps.cpython-38.pyc differ
diff --git a/diffusers/pipelines/deepfloyd_if/__pycache__/watermark.cpython-310.pyc b/diffusers/pipelines/deepfloyd_if/__pycache__/watermark.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fbf06ca889ba0615e4085ee0420e9aebb7662a1e
Binary files /dev/null and b/diffusers/pipelines/deepfloyd_if/__pycache__/watermark.cpython-310.pyc differ
diff --git a/diffusers/pipelines/deepfloyd_if/__pycache__/watermark.cpython-38.pyc b/diffusers/pipelines/deepfloyd_if/__pycache__/watermark.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..35450ba13ad33d8c87985240ba7b3b388c1fe9d4
Binary files /dev/null and b/diffusers/pipelines/deepfloyd_if/__pycache__/watermark.cpython-38.pyc differ
diff --git a/diffusers/pipelines/deepfloyd_if/pipeline_if.py b/diffusers/pipelines/deepfloyd_if/pipeline_if.py
new file mode 100644
index 0000000000000000000000000000000000000000..ffcb1ab32d357dd0c546bd96def75207752e06cb
--- /dev/null
+++ b/diffusers/pipelines/deepfloyd_if/pipeline_if.py
@@ -0,0 +1,816 @@
+import html
+import inspect
+import re
+import urllib.parse as ul
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import torch
+from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer
+
+from ...loaders import LoraLoaderMixin
+from ...models import UNet2DConditionModel
+from ...schedulers import DDPMScheduler
+from ...utils import (
+ BACKENDS_MAPPING,
+ is_accelerate_available,
+ is_accelerate_version,
+ is_bs4_available,
+ is_ftfy_available,
+ logging,
+ randn_tensor,
+ replace_example_docstring,
+)
+from ..pipeline_utils import DiffusionPipeline
+from . import IFPipelineOutput
+from .safety_checker import IFSafetyChecker
+from .watermark import IFWatermarker
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+if is_bs4_available():
+ from bs4 import BeautifulSoup
+
+if is_ftfy_available():
+ import ftfy
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> from diffusers import IFPipeline, IFSuperResolutionPipeline, DiffusionPipeline
+ >>> from diffusers.utils import pt_to_pil
+ >>> import torch
+
+ >>> pipe = IFPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
+ >>> pipe.enable_model_cpu_offload()
+
+ >>> prompt = 'a photo of a kangaroo wearing an orange hoodie and blue sunglasses standing in front of the eiffel tower holding a sign that says "very deep learning"'
+ >>> prompt_embeds, negative_embeds = pipe.encode_prompt(prompt)
+
+ >>> image = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, output_type="pt").images
+
+ >>> # save intermediate image
+ >>> pil_image = pt_to_pil(image)
+ >>> pil_image[0].save("./if_stage_I.png")
+
+ >>> super_res_1_pipe = IFSuperResolutionPipeline.from_pretrained(
+ ... "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16
+ ... )
+ >>> super_res_1_pipe.enable_model_cpu_offload()
+
+ >>> image = super_res_1_pipe(
+ ... image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, output_type="pt"
+ ... ).images
+
+ >>> # save intermediate image
+ >>> pil_image = pt_to_pil(image)
+ >>> pil_image[0].save("./if_stage_II.png")
+
+ >>> safety_modules = {
+ ... "feature_extractor": pipe.feature_extractor,
+ ... "safety_checker": pipe.safety_checker,
+ ... "watermarker": pipe.watermarker,
+ ... }
+ >>> super_res_2_pipe = DiffusionPipeline.from_pretrained(
+ ... "stabilityai/stable-diffusion-x4-upscaler", **safety_modules, torch_dtype=torch.float16
+ ... )
+ >>> super_res_2_pipe.enable_model_cpu_offload()
+
+ >>> image = super_res_2_pipe(
+ ... prompt=prompt,
+ ... image=image,
+ ... ).images
+ >>> image[0].save("./if_stage_III.png")
+ ```
+"""
+
+
+class IFPipeline(DiffusionPipeline, LoraLoaderMixin):
+ tokenizer: T5Tokenizer
+ text_encoder: T5EncoderModel
+
+ unet: UNet2DConditionModel
+ scheduler: DDPMScheduler
+
+ feature_extractor: Optional[CLIPImageProcessor]
+ safety_checker: Optional[IFSafetyChecker]
+
+ watermarker: Optional[IFWatermarker]
+
+ bad_punct_regex = re.compile(
+ r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
+ ) # noqa
+
+ _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"]
+
+ def __init__(
+ self,
+ tokenizer: T5Tokenizer,
+ text_encoder: T5EncoderModel,
+ unet: UNet2DConditionModel,
+ scheduler: DDPMScheduler,
+ safety_checker: Optional[IFSafetyChecker],
+ feature_extractor: Optional[CLIPImageProcessor],
+ watermarker: Optional[IFWatermarker],
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the IF license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ self.register_modules(
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ watermarker=watermarker,
+ )
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ hook = None
+
+ if self.text_encoder is not None:
+ _, hook = cpu_offload_with_hook(self.text_encoder, device, prev_module_hook=hook)
+
+ # Accelerate will move the next model to the device _before_ calling the offload hook of the
+ # previous model. This will cause both models to be present on the device at the same time.
+ # IF uses T5 for its text encoder which is really large. We can manually call the offload
+ # hook for the text encoder to ensure it's moved to the cpu before the unet is moved to
+ # the GPU.
+ self.text_encoder_offload_hook = hook
+
+ _, hook = cpu_offload_with_hook(self.unet, device, prev_module_hook=hook)
+
+ # if the safety checker isn't called, `unet_offload_hook` will have to be called to manually offload the unet
+ self.unet_offload_hook = hook
+
+ if self.safety_checker is not None:
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
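+ # Editorial sketch (not upstream code): typical use of the offload hooks set up above,
+ # assuming a single CUDA GPU is available:
+ #
+ #   pipe = IFPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
+ #   pipe.enable_model_cpu_offload()   # T5, UNet and safety checker move to the GPU one at a time
+ #   image = pipe("a red panda").images[0]
+ #   pipe.remove_all_hooks()           # detach the accelerate hooks when done
+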
+ def remove_all_hooks(self):
+ if is_accelerate_available():
+ from accelerate.hooks import remove_hook_from_module
+ else:
+ raise ImportError("Please install accelerate via `pip install accelerate`")
+
+ for model in [self.text_encoder, self.unet, self.safety_checker]:
+ if model is not None:
+ remove_hook_from_module(model, recurse=True)
+
+ self.unet_offload_hook = None
+ self.text_encoder_offload_hook = None
+ self.final_offload_hook = None
+
+ @torch.no_grad()
+ def encode_prompt(
+ self,
+ prompt,
+ do_classifier_free_guidance=True,
+ num_images_per_prompt=1,
+ device=None,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ clean_caption: bool = False,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`, *optional*):
+ torch device to place the resulting embeddings on
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead.
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ """
+ if prompt is not None and negative_prompt is not None:
+ if type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+
+ if device is None:
+ device = self._execution_device
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF
+ max_length = 77
+
+ if prompt_embeds is None:
+ prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {max_length} tokens: {removed_text}"
+ )
+
+ attention_mask = text_inputs.attention_mask.to(device)
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ if self.text_encoder is not None:
+ dtype = self.text_encoder.dtype
+ elif self.unet is not None:
+ dtype = self.unet.dtype
+ else:
+ dtype = None
+
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_attention_mask=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ attention_mask = uncond_input.attention_mask.to(device)
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ else:
+ negative_prompt_embeds = None
+
+ return prompt_embeds, negative_prompt_embeds
+
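+ # Editorial sketch (not upstream code): encode_prompt can be used on its own to cache
+ # T5 embeddings and reuse them across pipeline calls, e.g.
+ #
+ #   prompt_embeds, negative_embeds = pipe.encode_prompt("an oil painting of a lighthouse")
+ #   images = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds).images
+ #
+ # Both tensors have shape (batch * num_images_per_prompt, 77, hidden_dim); the negative
+ # embeddings are None when classifier-free guidance is disabled.
+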
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+ image, nsfw_detected, watermark_detected = self.safety_checker(
+ images=image,
+ clip_input=safety_checker_input.pixel_values.to(dtype=dtype),
+ )
+ else:
+ nsfw_detected = None
+ watermark_detected = None
+
+ if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
+ self.unet_offload_hook.offload()
+
+ return image, nsfw_detected, watermark_detected
+
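+ # Editorial note (not upstream code): with a configured checker the returned flags are
+ # per-image lists, e.g. nsfw_detected == [False] and watermark_detected == [False] for a
+ # single benign sample; both are None when the safety checker is disabled.
+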
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
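+ # Editorial note (not upstream code): for the DDPMScheduler used by this pipeline,
+ # `step()` accepts a generator but no eta, so the dict built above is typically just
+ # {"generator": generator}; eta would only be forwarded to DDIM-style schedulers.
+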
+ def check_inputs(
+ self,
+ prompt,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ def prepare_intermediate_images(self, batch_size, num_channels, height, width, dtype, device, generator):
+ shape = (batch_size, num_channels, height, width)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ intermediate_images = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ intermediate_images = intermediate_images * self.scheduler.init_noise_sigma
+ return intermediate_images
+
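+ # Editorial sketch (not upstream code): DDPM schedulers report init_noise_sigma == 1.0,
+ # so the intermediate images start as plain Gaussian noise of shape
+ # (batch_size, num_channels, height, width), e.g. (2, 3, 64, 64) for two 64x64 samples.
+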
+ def _text_preprocessing(self, text, clean_caption=False):
+ if clean_caption and not is_bs4_available():
+ logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
+ logger.warn("Setting `clean_caption` to False...")
+ clean_caption = False
+
+ if clean_caption and not is_ftfy_available():
+ logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
+ logger.warn("Setting `clean_caption` to False...")
+ clean_caption = False
+
+ if not isinstance(text, (tuple, list)):
+ text = [text]
+
+ def process(text: str):
+ if clean_caption:
+ text = self._clean_caption(text)
+ text = self._clean_caption(text)
+ else:
+ text = text.lower().strip()
+ return text
+
+ return [process(t) for t in text]
+
+ def _clean_caption(self, caption):
+ caption = str(caption)
+ caption = ul.unquote_plus(caption)
+ caption = caption.strip().lower()
+ caption = re.sub("", "person", caption)
+ # urls:
+ caption = re.sub(
+ r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ caption = re.sub(
+ r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ # html:
+ caption = BeautifulSoup(caption, features="html.parser").text
+
+ # @
+ caption = re.sub(r"@[\w\d]+\b", "", caption)
+
+ # 31C0—31EF CJK Strokes
+ # 31F0—31FF Katakana Phonetic Extensions
+ # 3200—32FF Enclosed CJK Letters and Months
+ # 3300—33FF CJK Compatibility
+ # 3400—4DBF CJK Unified Ideographs Extension A
+ # 4DC0—4DFF Yijing Hexagram Symbols
+ # 4E00—9FFF CJK Unified Ideographs
+ caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
+ caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
+ caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
+ caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
+ caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
+ caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
+ caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
+ #######################################################
+
+ # все виды тире / all types of dash --> "-"
+ caption = re.sub(
+ r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
+ "-",
+ caption,
+ )
+
+ # кавычки к одному стандарту
+ caption = re.sub(r"[`´«»“”¨]", '"', caption)
+ caption = re.sub(r"[‘’]", "'", caption)
+
+ # "
+ caption = re.sub(r""?", "", caption)
+ # &
+ caption = re.sub(r"&", "", caption)
+
+ # ip addresses:
+ caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
+
+ # article ids:
+ caption = re.sub(r"\d:\d\d\s+$", "", caption)
+
+ # \n
+ caption = re.sub(r"\\n", " ", caption)
+
+ # "#123"
+ caption = re.sub(r"#\d{1,3}\b", "", caption)
+ # "#12345.."
+ caption = re.sub(r"#\d{5,}\b", "", caption)
+ # "123456.."
+ caption = re.sub(r"\b\d{6,}\b", "", caption)
+ # filenames:
+ caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
+
+ #
+ caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
+ caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
+
+ caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
+ caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
+
+ # this-is-my-cute-cat / this_is_my_cute_cat
+ regex2 = re.compile(r"(?:\-|\_)")
+ if len(re.findall(regex2, caption)) > 3:
+ caption = re.sub(regex2, " ", caption)
+
+ caption = ftfy.fix_text(caption)
+ caption = html.unescape(html.unescape(caption))
+
+ caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
+ caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
+ caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
+
+ caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
+ caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
+ caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
+ caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
+ caption = re.sub(r"\bpage\s+\d+\b", "", caption)
+
+ caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
+
+ caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
+
+ caption = re.sub(r"\b\s+\:\s+", r": ", caption)
+ caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
+ caption = re.sub(r"\s+", " ", caption)
+
+ caption = caption.strip()
+
+ caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
+ caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
+ caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
+ caption = re.sub(r"^\.\S+$", "", caption)
+
+ return caption.strip()
+
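+ # Editorial example (not upstream code), assuming bs4 and ftfy are installed:
+ #   _clean_caption("Check https://example.com <b>my cat</b> &amp; FREE SHIPPING!!!")
+ # drops the url, the html tags, the ampersand and the shipping spam, leaving roughly
+ # "check my cat !!!" after whitespace is collapsed.
+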
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_inference_steps: int = 100,
+ timesteps: List[int] = None,
+ guidance_scale: float = 7.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ clean_caption: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ):
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ num_inference_steps (`int`, *optional*, defaults to 100):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
+ timesteps are used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 7.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size):
+ The width in pixels of the generated image.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.deepfloyd_if.IFPipelineOutput`] instead of a plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ clean_caption (`bool`, *optional*, defaults to `True`):
+ Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
+ be installed. If the dependencies are not installed, the embeddings will be created from the raw
+ prompt.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+
+ Examples:
+
+ Returns:
+ [`~pipelines.deepfloyd_if.IFPipelineOutput`] or `tuple`:
+ [`~pipelines.deepfloyd_if.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is a list with the generated images, and the second element is a list
+ of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw)
+ or watermarked content, according to the `safety_checker`.
+ """
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
+
+ # 2. Define call parameters
+ height = height or self.unet.config.sample_size
+ width = width or self.unet.config.sample_size
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt,
+ do_classifier_free_guidance,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ clean_caption=clean_caption,
+ )
+
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ # 4. Prepare timesteps
+ if timesteps is not None:
+ self.scheduler.set_timesteps(timesteps=timesteps, device=device)
+ timesteps = self.scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare intermediate images
+ intermediate_images = self.prepare_intermediate_images(
+ batch_size * num_images_per_prompt,
+ self.unet.config.in_channels,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # HACK: see comment in `enable_model_cpu_offload`
+ if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None:
+ self.text_encoder_offload_hook.offload()
+
+ # 7. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ model_input = (
+ torch.cat([intermediate_images] * 2) if do_classifier_free_guidance else intermediate_images
+ )
+ model_input = self.scheduler.scale_model_input(model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1)
+ noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+ noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
+
+ if self.scheduler.config.variance_type not in ["learned", "learned_range"]:
+ noise_pred, _ = noise_pred.split(model_input.shape[1], dim=1)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ intermediate_images = self.scheduler.step(
+ noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
+ )[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, intermediate_images)
+
+ image = intermediate_images
+
+ if output_type == "pil":
+ # 8. Post-processing
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+ # 9. Run safety checker
+ image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)
+
+ # 10. Convert to PIL
+ image = self.numpy_to_pil(image)
+
+ # 11. Apply watermark
+ if self.watermarker is not None:
+ image = self.watermarker.apply_watermark(image, self.unet.config.sample_size)
+ elif output_type == "pt":
+ nsfw_detected = None
+ watermark_detected = None
+
+ if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
+ self.unet_offload_hook.offload()
+ else:
+ # 8. Post-processing
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+ # 9. Run safety checker
+ image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image, nsfw_detected, watermark_detected)
+
+ return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected)
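+
+# Editorial sketch (not upstream code) of the guidance split performed in the denoising
+# loop above: the stage-I UNet predicts a learned variance, so its output has twice the
+# number of image channels. Assuming in_channels == 3 and classifier-free guidance:
+#
+#   noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)           # (B, 6, H, W) each
+#   eps_uncond, _ = noise_pred_uncond.split(3, dim=1)
+#   eps_text, predicted_variance = noise_pred_text.split(3, dim=1)
+#   eps = eps_uncond + guidance_scale * (eps_text - eps_uncond)
+#   noise_pred = torch.cat([eps, predicted_variance], dim=1)           # variance re-attached
+#
+# The DDPM step then consumes both the guided noise and the predicted variance.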
diff --git a/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py b/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c4c5229e1c48b4a561f24d88a776007cdf22f1a
--- /dev/null
+++ b/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py
@@ -0,0 +1,940 @@
+import html
+import inspect
+import re
+import urllib.parse as ul
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import PIL
+import torch
+from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer
+
+from ...loaders import LoraLoaderMixin
+from ...models import UNet2DConditionModel
+from ...schedulers import DDPMScheduler
+from ...utils import (
+ BACKENDS_MAPPING,
+ PIL_INTERPOLATION,
+ is_accelerate_available,
+ is_accelerate_version,
+ is_bs4_available,
+ is_ftfy_available,
+ logging,
+ randn_tensor,
+ replace_example_docstring,
+)
+from ..pipeline_utils import DiffusionPipeline
+from . import IFPipelineOutput
+from .safety_checker import IFSafetyChecker
+from .watermark import IFWatermarker
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+if is_bs4_available():
+ from bs4 import BeautifulSoup
+
+if is_ftfy_available():
+ import ftfy
+
+
+def resize(images: PIL.Image.Image, img_size: int) -> PIL.Image.Image:
+ w, h = images.size
+
+ coef = w / h
+
+ w, h = img_size, img_size
+
+ if coef >= 1:
+ w = int(round(img_size / 8 * coef) * 8)
+ else:
+ h = int(round(img_size / 8 / coef) * 8)
+
+ images = images.resize((w, h), resample=PIL_INTERPOLATION["bicubic"], reducing_gap=None)
+
+ return images
+
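+# Editorial worked example (not upstream code): resize() preserves the aspect ratio and
+# snaps the scaled side to a multiple of 8. For img_size=64 and a 768x512 input
+# (coef = 1.5 >= 1): w = round(64 / 8 * 1.5) * 8 = 96, h = 64, i.e. a 96x64 output.
+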
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> from diffusers import IFImg2ImgPipeline, IFImg2ImgSuperResolutionPipeline, DiffusionPipeline
+ >>> from diffusers.utils import pt_to_pil
+ >>> import torch
+ >>> from PIL import Image
+ >>> import requests
+ >>> from io import BytesIO
+
+ >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
+ >>> response = requests.get(url)
+ >>> original_image = Image.open(BytesIO(response.content)).convert("RGB")
+ >>> original_image = original_image.resize((768, 512))
+
+ >>> pipe = IFImg2ImgPipeline.from_pretrained(
+ ... "DeepFloyd/IF-I-XL-v1.0",
+ ... variant="fp16",
+ ... torch_dtype=torch.float16,
+ ... )
+ >>> pipe.enable_model_cpu_offload()
+
+ >>> prompt = "A fantasy landscape in style minecraft"
+ >>> prompt_embeds, negative_embeds = pipe.encode_prompt(prompt)
+
+ >>> image = pipe(
+ ... image=original_image,
+ ... prompt_embeds=prompt_embeds,
+ ... negative_prompt_embeds=negative_embeds,
+ ... output_type="pt",
+ ... ).images
+
+ >>> # save intermediate image
+ >>> pil_image = pt_to_pil(image)
+ >>> pil_image[0].save("./if_stage_I.png")
+
+ >>> super_res_1_pipe = IFImg2ImgSuperResolutionPipeline.from_pretrained(
+ ... "DeepFloyd/IF-II-L-v1.0",
+ ... text_encoder=None,
+ ... variant="fp16",
+ ... torch_dtype=torch.float16,
+ ... )
+ >>> super_res_1_pipe.enable_model_cpu_offload()
+
+ >>> image = super_res_1_pipe(
+ ... image=image,
+ ... original_image=original_image,
+ ... prompt_embeds=prompt_embeds,
+ ... negative_prompt_embeds=negative_embeds,
+ ... ).images
+ >>> image[0].save("./if_stage_II.png")
+ ```
+"""
+
+
+class IFImg2ImgPipeline(DiffusionPipeline, LoraLoaderMixin):
+ tokenizer: T5Tokenizer
+ text_encoder: T5EncoderModel
+
+ unet: UNet2DConditionModel
+ scheduler: DDPMScheduler
+
+ feature_extractor: Optional[CLIPImageProcessor]
+ safety_checker: Optional[IFSafetyChecker]
+
+ watermarker: Optional[IFWatermarker]
+
+ bad_punct_regex = re.compile(
+ r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
+ ) # noqa
+
+ _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"]
+
+ def __init__(
+ self,
+ tokenizer: T5Tokenizer,
+ text_encoder: T5EncoderModel,
+ unet: UNet2DConditionModel,
+ scheduler: DDPMScheduler,
+ safety_checker: Optional[IFSafetyChecker],
+ feature_extractor: Optional[CLIPImageProcessor],
+ watermarker: Optional[IFWatermarker],
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the IF license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ self.register_modules(
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ watermarker=watermarker,
+ )
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.enable_model_cpu_offload
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ hook = None
+
+ if self.text_encoder is not None:
+ _, hook = cpu_offload_with_hook(self.text_encoder, device, prev_module_hook=hook)
+
+ # Accelerate will move the next model to the device _before_ calling the offload hook of the
+ # previous model. This will cause both models to be present on the device at the same time.
+ # IF uses T5 for its text encoder which is really large. We can manually call the offload
+ # hook for the text encoder to ensure it's moved to the cpu before the unet is moved to
+ # the GPU.
+ self.text_encoder_offload_hook = hook
+
+ _, hook = cpu_offload_with_hook(self.unet, device, prev_module_hook=hook)
+
+ # if the safety checker isn't called, `unet_offload_hook` will have to be called to manually offload the unet
+ self.unet_offload_hook = hook
+
+ if self.safety_checker is not None:
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.remove_all_hooks
+ def remove_all_hooks(self):
+ if is_accelerate_available():
+ from accelerate.hooks import remove_hook_from_module
+ else:
+ raise ImportError("Please install accelerate via `pip install accelerate`")
+
+ for model in [self.text_encoder, self.unet, self.safety_checker]:
+ if model is not None:
+ remove_hook_from_module(model, recurse=True)
+
+ self.unet_offload_hook = None
+ self.text_encoder_offload_hook = None
+ self.final_offload_hook = None
+
+ @torch.no_grad()
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt,
+ do_classifier_free_guidance=True,
+ num_images_per_prompt=1,
+ device=None,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ clean_caption: bool = False,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`, *optional*):
+ torch device to place the resulting embeddings on
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead.
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ """
+ if prompt is not None and negative_prompt is not None:
+ if type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+
+ if device is None:
+ device = self._execution_device
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF
+ max_length = 77
+
+ if prompt_embeds is None:
+ prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {max_length} tokens: {removed_text}"
+ )
+
+ attention_mask = text_inputs.attention_mask.to(device)
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ if self.text_encoder is not None:
+ dtype = self.text_encoder.dtype
+ elif self.unet is not None:
+ dtype = self.unet.dtype
+ else:
+ dtype = None
+
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_attention_mask=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ attention_mask = uncond_input.attention_mask.to(device)
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ else:
+ negative_prompt_embeds = None
+
+ return prompt_embeds, negative_prompt_embeds
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.run_safety_checker
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+ image, nsfw_detected, watermark_detected = self.safety_checker(
+ images=image,
+ clip_input=safety_checker_input.pixel_values.to(dtype=dtype),
+ )
+ else:
+ nsfw_detected = None
+ watermark_detected = None
+
+ if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
+ self.unet_offload_hook.offload()
+
+ return image, nsfw_detected, watermark_detected
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ image,
+ batch_size,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if isinstance(image, list):
+ check_image_type = image[0]
+ else:
+ check_image_type = image
+
+ if (
+ not isinstance(check_image_type, torch.Tensor)
+ and not isinstance(check_image_type, PIL.Image.Image)
+ and not isinstance(check_image_type, np.ndarray)
+ ):
+ raise ValueError(
+ "`image` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is"
+ f" {type(check_image_type)}"
+ )
+
+ if isinstance(image, list):
+ image_batch_size = len(image)
+ elif isinstance(image, torch.Tensor):
+ image_batch_size = image.shape[0]
+ elif isinstance(image, PIL.Image.Image):
+ image_batch_size = 1
+ elif isinstance(image, np.ndarray):
+ image_batch_size = image.shape[0]
+ else:
+ assert False
+
+ if batch_size != image_batch_size:
+ raise ValueError(f"image batch size: {image_batch_size} must be same as prompt batch size {batch_size}")
+
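+ # Editorial note (not upstream code): `image` may be a PIL.Image.Image, a np.ndarray, a
+ # torch.Tensor, or a list of any of these; a single PIL image counts as batch size 1 and
+ # must match the prompt batch size.
+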
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
+ def _text_preprocessing(self, text, clean_caption=False):
+ if clean_caption and not is_bs4_available():
+ logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
+ logger.warn("Setting `clean_caption` to False...")
+ clean_caption = False
+
+ if clean_caption and not is_ftfy_available():
+ logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
+ logger.warn("Setting `clean_caption` to False...")
+ clean_caption = False
+
+ if not isinstance(text, (tuple, list)):
+ text = [text]
+
+ def process(text: str):
+ if clean_caption:
+ text = self._clean_caption(text)
+ text = self._clean_caption(text)
+ else:
+ text = text.lower().strip()
+ return text
+
+ return [process(t) for t in text]
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
+ def _clean_caption(self, caption):
+ caption = str(caption)
+ caption = ul.unquote_plus(caption)
+ caption = caption.strip().lower()
+ caption = re.sub("", "person", caption)
+ # urls:
+ caption = re.sub(
+ r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ caption = re.sub(
+ r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ # html:
+ caption = BeautifulSoup(caption, features="html.parser").text
+
+ # @
+ caption = re.sub(r"@[\w\d]+\b", "", caption)
+
+ # 31C0—31EF CJK Strokes
+ # 31F0—31FF Katakana Phonetic Extensions
+ # 3200—32FF Enclosed CJK Letters and Months
+ # 3300—33FF CJK Compatibility
+ # 3400—4DBF CJK Unified Ideographs Extension A
+ # 4DC0—4DFF Yijing Hexagram Symbols
+ # 4E00—9FFF CJK Unified Ideographs
+ caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
+ caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
+ caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
+ caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
+ caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
+ caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
+ caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
+ #######################################################
+
+ # все виды тире / all types of dash --> "-"
+ caption = re.sub(
+ r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
+ "-",
+ caption,
+ )
+
+ # кавычки к одному стандарту
+ caption = re.sub(r"[`´«»“”¨]", '"', caption)
+ caption = re.sub(r"[‘’]", "'", caption)
+
+ # "
+ caption = re.sub(r""?", "", caption)
+ # &
+ caption = re.sub(r"&", "", caption)
+
+ # ip addresses:
+ caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
+
+ # article ids:
+ caption = re.sub(r"\d:\d\d\s+$", "", caption)
+
+ # \n
+ caption = re.sub(r"\\n", " ", caption)
+
+ # "#123"
+ caption = re.sub(r"#\d{1,3}\b", "", caption)
+ # "#12345.."
+ caption = re.sub(r"#\d{5,}\b", "", caption)
+ # "123456.."
+ caption = re.sub(r"\b\d{6,}\b", "", caption)
+ # filenames:
+ caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
+
+ #
+ caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
+ caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
+
+ caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
+ caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
+
+ # this-is-my-cute-cat / this_is_my_cute_cat
+ regex2 = re.compile(r"(?:\-|\_)")
+ if len(re.findall(regex2, caption)) > 3:
+ caption = re.sub(regex2, " ", caption)
+
+ caption = ftfy.fix_text(caption)
+ caption = html.unescape(html.unescape(caption))
+
+ caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
+ caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
+ caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
+
+ caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
+ caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
+ caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
+ caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
+ caption = re.sub(r"\bpage\s+\d+\b", "", caption)
+
+ caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
+
+ caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
+
+ caption = re.sub(r"\b\s+\:\s+", r": ", caption)
+ caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
+ caption = re.sub(r"\s+", " ", caption)
+
+ caption.strip()
+
+ caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
+ caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
+ caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
+ caption = re.sub(r"^\.\S+$", "", caption)
+
+ return caption.strip()
+
+ def preprocess_image(self, image: PIL.Image.Image) -> torch.Tensor:
+ if not isinstance(image, list):
+ image = [image]
+
+ def numpy_to_pt(images):
+ if images.ndim == 3:
+ images = images[..., None]
+
+ images = torch.from_numpy(images.transpose(0, 3, 1, 2))
+ return images
+
+ if isinstance(image[0], PIL.Image.Image):
+ new_image = []
+
+ for image_ in image:
+ image_ = image_.convert("RGB")
+ image_ = resize(image_, self.unet.sample_size)
+ image_ = np.array(image_)
+ image_ = image_.astype(np.float32)
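+ # map uint8 pixel values from [0, 255] into [-1, 1], the value range the IF UNet expects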
+ image_ = image_ / 127.5 - 1
+ new_image.append(image_)
+
+ image = new_image
+
+ image = np.stack(image, axis=0) # to np
+ image = numpy_to_pt(image) # to pt
+
+ elif isinstance(image[0], np.ndarray):
+ image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
+ image = numpy_to_pt(image)
+
+ elif isinstance(image[0], torch.Tensor):
+ image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
+
+ return image
+
+ def get_timesteps(self, num_inference_steps, strength):
+ # get the original timestep using init_timestep
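+ # `strength` controls how much of the schedule is actually run: e.g. with
+ # num_inference_steps=80 and strength=0.7 only the final 56 timesteps are kept,
+ # so the reference image is noised to an intermediate level rather than to pure noise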
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep, 0)
+ timesteps = self.scheduler.timesteps[t_start:]
+
+ return timesteps, num_inference_steps - t_start
+
+ def prepare_intermediate_images(
+ self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None
+ ):
+ _, channels, height, width = image.shape
+
+ batch_size = batch_size * num_images_per_prompt
+
+ shape = (batch_size, channels, height, width)
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
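+ # sample fresh noise and noise the clean image up to the starting timestep, so denoising can resume from a partially noised state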
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+ image = image.repeat_interleave(num_images_per_prompt, dim=0)
+ image = self.scheduler.add_noise(image, noise, timestep)
+
+ return image
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ image: Union[
+ PIL.Image.Image, torch.Tensor, np.ndarray, List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray]
+ ] = None,
+ strength: float = 0.7,
+ num_inference_steps: int = 80,
+ timesteps: List[int] = None,
+ guidance_scale: float = 10.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ clean_caption: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ):
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` instead.
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
+ process.
+ strength (`float`, *optional*, defaults to 0.7):
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
+ will be used as a starting point, adding more noise to it the larger the `strength`. The number of
+ denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
+ be maximum and the denoising process will run for the full number of iterations specified in
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
+ num_inference_steps (`int`, *optional*, defaults to 80):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process. If not defined, equally spaced `num_inference_steps`
+ timesteps are used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 10.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text
+ `prompt`, usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ clean_caption (`bool`, *optional*, defaults to `True`):
+ Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
+ be installed. If the dependencies are not installed, the embeddings will be created from the raw
+ prompt.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is a list with the generated images, and the second element is a list
+ of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw)
+ or watermarked content, according to the `safety_checker`.
+ """
+ # 1. Check inputs. Raise error if not correct
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ self.check_inputs(
+ prompt, image, batch_size, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
+ )
+
+ # 2. Define call parameters
+ device = self._execution_device
+
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt,
+ do_classifier_free_guidance,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ clean_caption=clean_caption,
+ )
+
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ dtype = prompt_embeds.dtype
+
+ # 4. Prepare timesteps
+ if timesteps is not None:
+ self.scheduler.set_timesteps(timesteps=timesteps, device=device)
+ timesteps = self.scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength)
+
+ # 5. Prepare intermediate images
+ image = self.preprocess_image(image)
+ image = image.to(device=device, dtype=dtype)
+
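+ # the first retained timestep determines how strongly the reference image is noised below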
+ noise_timestep = timesteps[0:1]
+ noise_timestep = noise_timestep.repeat(batch_size * num_images_per_prompt)
+
+ intermediate_images = self.prepare_intermediate_images(
+ image, noise_timestep, batch_size, num_images_per_prompt, dtype, device, generator
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # HACK: see comment in `enable_model_cpu_offload`
+ if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None:
+ self.text_encoder_offload_hook.offload()
+
+ # 7. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ model_input = (
+ torch.cat([intermediate_images] * 2) if do_classifier_free_guidance else intermediate_images
+ )
+ model_input = self.scheduler.scale_model_input(model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
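+ # the IF UNet predicts both the noise and a learned variance; classifier-free guidance
+ # is applied to the noise channels only, then the predicted variance is re-attached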
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1)
+ noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+ noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
+
+ if self.scheduler.config.variance_type not in ["learned", "learned_range"]:
+ noise_pred, _ = noise_pred.split(model_input.shape[1], dim=1)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ intermediate_images = self.scheduler.step(
+ noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
+ )[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, intermediate_images)
+
+ image = intermediate_images
+
+ if output_type == "pil":
+ # 8. Post-processing
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+ # 9. Run safety checker
+ image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)
+
+ # 10. Convert to PIL
+ image = self.numpy_to_pil(image)
+
+ # 11. Apply watermark
+ if self.watermarker is not None:
+ self.watermarker.apply_watermark(image, self.unet.config.sample_size)
+ elif output_type == "pt":
+ nsfw_detected = None
+ watermark_detected = None
+
+ if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
+ self.unet_offload_hook.offload()
+ else:
+ # 8. Post-processing
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+ # 9. Run safety checker
+ image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image, nsfw_detected, watermark_detected)
+
+ return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected)
diff --git a/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py b/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ea0bafa51a011d4c1741c46090aaae55b1e06e5
--- /dev/null
+++ b/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py
@@ -0,0 +1,1058 @@
+import html
+import inspect
+import re
+import urllib.parse as ul
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import PIL
+import torch
+import torch.nn.functional as F
+from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer
+
+from ...loaders import LoraLoaderMixin
+from ...models import UNet2DConditionModel
+from ...schedulers import DDPMScheduler
+from ...utils import (
+ BACKENDS_MAPPING,
+ PIL_INTERPOLATION,
+ is_accelerate_available,
+ is_accelerate_version,
+ is_bs4_available,
+ is_ftfy_available,
+ logging,
+ randn_tensor,
+ replace_example_docstring,
+)
+from ..pipeline_utils import DiffusionPipeline
+from . import IFPipelineOutput
+from .safety_checker import IFSafetyChecker
+from .watermark import IFWatermarker
+
+
+if is_bs4_available():
+ from bs4 import BeautifulSoup
+
+if is_ftfy_available():
+ import ftfy
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.resize
+def resize(images: PIL.Image.Image, img_size: int) -> PIL.Image.Image:
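+ # resize so the shorter side equals img_size while preserving the aspect ratio; both dimensions are rounded to multiples of 8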
+ w, h = images.size
+
+ coef = w / h
+
+ w, h = img_size, img_size
+
+ if coef >= 1:
+ w = int(round(img_size / 8 * coef) * 8)
+ else:
+ h = int(round(img_size / 8 / coef) * 8)
+
+ images = images.resize((w, h), resample=PIL_INTERPOLATION["bicubic"], reducing_gap=None)
+
+ return images
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> from diffusers import IFImg2ImgPipeline, IFImg2ImgSuperResolutionPipeline, DiffusionPipeline
+ >>> from diffusers.utils import pt_to_pil
+ >>> import torch
+ >>> from PIL import Image
+ >>> import requests
+ >>> from io import BytesIO
+
+ >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
+ >>> response = requests.get(url)
+ >>> original_image = Image.open(BytesIO(response.content)).convert("RGB")
+ >>> original_image = original_image.resize((768, 512))
+
+ >>> pipe = IFImg2ImgPipeline.from_pretrained(
+ ... "DeepFloyd/IF-I-XL-v1.0",
+ ... variant="fp16",
+ ... torch_dtype=torch.float16,
+ ... )
+ >>> pipe.enable_model_cpu_offload()
+
+ >>> prompt = "A fantasy landscape in style minecraft"
+ >>> prompt_embeds, negative_embeds = pipe.encode_prompt(prompt)
+
+ >>> image = pipe(
+ ... image=original_image,
+ ... prompt_embeds=prompt_embeds,
+ ... negative_prompt_embeds=negative_embeds,
+ ... output_type="pt",
+ ... ).images
+
+ >>> # save intermediate image
+ >>> pil_image = pt_to_pil(image)
+ >>> pil_image[0].save("./if_stage_I.png")
+
+ >>> super_res_1_pipe = IFImg2ImgSuperResolutionPipeline.from_pretrained(
+ ... "DeepFloyd/IF-II-L-v1.0",
+ ... text_encoder=None,
+ ... variant="fp16",
+ ... torch_dtype=torch.float16,
+ ... )
+ >>> super_res_1_pipe.enable_model_cpu_offload()
+
+ >>> image = super_res_1_pipe(
+ ... image=image,
+ ... original_image=original_image,
+ ... prompt_embeds=prompt_embeds,
+ ... negative_prompt_embeds=negative_embeds,
+ ... ).images
+ >>> image[0].save("./if_stage_II.png")
+ ```
+"""
+
+
+class IFImg2ImgSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
+ tokenizer: T5Tokenizer
+ text_encoder: T5EncoderModel
+
+ unet: UNet2DConditionModel
+ scheduler: DDPMScheduler
+ image_noising_scheduler: DDPMScheduler
+
+ feature_extractor: Optional[CLIPImageProcessor]
+ safety_checker: Optional[IFSafetyChecker]
+
+ watermarker: Optional[IFWatermarker]
+
+ bad_punct_regex = re.compile(
+ r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
+ ) # noqa
+
+ _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ tokenizer: T5Tokenizer,
+ text_encoder: T5EncoderModel,
+ unet: UNet2DConditionModel,
+ scheduler: DDPMScheduler,
+ image_noising_scheduler: DDPMScheduler,
+ safety_checker: Optional[IFSafetyChecker],
+ feature_extractor: Optional[CLIPImageProcessor],
+ watermarker: Optional[IFWatermarker],
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the IF license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ if unet.config.in_channels != 6:
+ logger.warn(
+ "It seems like you have loaded a checkpoint that shall not be used for super resolution from {unet.config._name_or_path} as it accepts {unet.config.in_channels} input channels instead of 6. Please make sure to pass a super resolution checkpoint as the `'unet'`: IFSuperResolutionPipeline.from_pretrained(unet=super_resolution_unet, ...)`."
+ )
+
+ self.register_modules(
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ unet=unet,
+ scheduler=scheduler,
+ image_noising_scheduler=image_noising_scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ watermarker=watermarker,
+ )
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.enable_model_cpu_offload
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ hook = None
+
+ if self.text_encoder is not None:
+ _, hook = cpu_offload_with_hook(self.text_encoder, device, prev_module_hook=hook)
+
+ # Accelerate will move the next model to the device _before_ calling the offload hook of the
+ # previous model. This will cause both models to be present on the device at the same time.
+ # IF uses T5 for its text encoder which is really large. We can manually call the offload
+ # hook for the text encoder to ensure it's moved to the cpu before the unet is moved to
+ # the GPU.
+ self.text_encoder_offload_hook = hook
+
+ _, hook = cpu_offload_with_hook(self.unet, device, prev_module_hook=hook)
+
+ # if the safety checker isn't called, `unet_offload_hook` will have to be called to manually offload the unet
+ self.unet_offload_hook = hook
+
+ if self.safety_checker is not None:
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.remove_all_hooks
+ def remove_all_hooks(self):
+ if is_accelerate_available():
+ from accelerate.hooks import remove_hook_from_module
+ else:
+ raise ImportError("Please install accelerate via `pip install accelerate`")
+
+ for model in [self.text_encoder, self.unet, self.safety_checker]:
+ if model is not None:
+ remove_hook_from_module(model, recurse=True)
+
+ self.unet_offload_hook = None
+ self.text_encoder_offload_hook = None
+ self.final_offload_hook = None
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
+ def _text_preprocessing(self, text, clean_caption=False):
+ if clean_caption and not is_bs4_available():
+ logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
+ logger.warn("Setting `clean_caption` to False...")
+ clean_caption = False
+
+ if clean_caption and not is_ftfy_available():
+ logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
+ logger.warn("Setting `clean_caption` to False...")
+ clean_caption = False
+
+ if not isinstance(text, (tuple, list)):
+ text = [text]
+
+ def process(text: str):
+ if clean_caption:
+ text = self._clean_caption(text)
+ text = self._clean_caption(text)
+ else:
+ text = text.lower().strip()
+ return text
+
+ return [process(t) for t in text]
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
+ def _clean_caption(self, caption):
+ caption = str(caption)
+ caption = ul.unquote_plus(caption)
+ caption = caption.strip().lower()
+ caption = re.sub("", "person", caption)
+ # urls:
+ caption = re.sub(
+ r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ caption = re.sub(
+ r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ # html:
+ caption = BeautifulSoup(caption, features="html.parser").text
+
+ # @
+ caption = re.sub(r"@[\w\d]+\b", "", caption)
+
+ # 31C0—31EF CJK Strokes
+ # 31F0—31FF Katakana Phonetic Extensions
+ # 3200—32FF Enclosed CJK Letters and Months
+ # 3300—33FF CJK Compatibility
+ # 3400—4DBF CJK Unified Ideographs Extension A
+ # 4DC0—4DFF Yijing Hexagram Symbols
+ # 4E00—9FFF CJK Unified Ideographs
+ caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
+ caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
+ caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
+ caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
+ caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
+ caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
+ caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
+ #######################################################
+
+ # все виды тире / all types of dash --> "-"
+ caption = re.sub(
+ r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
+ "-",
+ caption,
+ )
+
+ # кавычки к одному стандарту
+ caption = re.sub(r"[`´«»“”¨]", '"', caption)
+ caption = re.sub(r"[‘’]", "'", caption)
+
+ # "
+ caption = re.sub(r""?", "", caption)
+ # &
+ caption = re.sub(r"&", "", caption)
+
+ # ip addresses:
+ caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
+
+ # article ids:
+ caption = re.sub(r"\d:\d\d\s+$", "", caption)
+
+ # \n
+ caption = re.sub(r"\\n", " ", caption)
+
+ # "#123"
+ caption = re.sub(r"#\d{1,3}\b", "", caption)
+ # "#12345.."
+ caption = re.sub(r"#\d{5,}\b", "", caption)
+ # "123456.."
+ caption = re.sub(r"\b\d{6,}\b", "", caption)
+ # filenames:
+ caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
+
+ #
+ caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
+ caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
+
+ caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
+ caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
+
+ # this-is-my-cute-cat / this_is_my_cute_cat
+ regex2 = re.compile(r"(?:\-|\_)")
+ if len(re.findall(regex2, caption)) > 3:
+ caption = re.sub(regex2, " ", caption)
+
+ caption = ftfy.fix_text(caption)
+ caption = html.unescape(html.unescape(caption))
+
+ caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
+ caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
+ caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
+
+ caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
+ caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
+ caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
+ caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
+ caption = re.sub(r"\bpage\s+\d+\b", "", caption)
+
+ caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
+
+ caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
+
+ caption = re.sub(r"\b\s+\:\s+", r": ", caption)
+ caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
+ caption = re.sub(r"\s+", " ", caption)
+
+ caption.strip()
+
+ caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
+ caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
+ caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
+ caption = re.sub(r"^\.\S+$", "", caption)
+
+ return caption.strip()
+
+ @torch.no_grad()
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt,
+ do_classifier_free_guidance=True,
+ num_images_per_prompt=1,
+ device=None,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ clean_caption: bool = False,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`, *optional*):
+ torch device to place the resulting embeddings on
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead.
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ """
+ if prompt is not None and negative_prompt is not None:
+ if type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+
+ if device is None:
+ device = self._execution_device
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF
+ max_length = 77
+
+ if prompt_embeds is None:
+ prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {max_length} tokens: {removed_text}"
+ )
+
+ attention_mask = text_inputs.attention_mask.to(device)
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ if self.text_encoder is not None:
+ dtype = self.text_encoder.dtype
+ elif self.unet is not None:
+ dtype = self.unet.dtype
+ else:
+ dtype = None
+
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_attention_mask=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ attention_mask = uncond_input.attention_mask.to(device)
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ else:
+ negative_prompt_embeds = None
+
+ return prompt_embeds, negative_prompt_embeds
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.run_safety_checker
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+ image, nsfw_detected, watermark_detected = self.safety_checker(
+ images=image,
+ clip_input=safety_checker_input.pixel_values.to(dtype=dtype),
+ )
+ else:
+ nsfw_detected = None
+ watermark_detected = None
+
+ if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
+ self.unet_offload_hook.offload()
+
+ return image, nsfw_detected, watermark_detected
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ image,
+ original_image,
+ batch_size,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ # image
+
+ if isinstance(image, list):
+ check_image_type = image[0]
+ else:
+ check_image_type = image
+
+ if (
+ not isinstance(check_image_type, torch.Tensor)
+ and not isinstance(check_image_type, PIL.Image.Image)
+ and not isinstance(check_image_type, np.ndarray)
+ ):
+ raise ValueError(
+ "`image` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is"
+ f" {type(check_image_type)}"
+ )
+
+ if isinstance(image, list):
+ image_batch_size = len(image)
+ elif isinstance(image, torch.Tensor):
+ image_batch_size = image.shape[0]
+ elif isinstance(image, PIL.Image.Image):
+ image_batch_size = 1
+ elif isinstance(image, np.ndarray):
+ image_batch_size = image.shape[0]
+ else:
+ assert False
+
+ if batch_size != image_batch_size:
+ raise ValueError(f"image batch size: {image_batch_size} must be same as prompt batch size {batch_size}")
+
+ # original_image
+
+ if isinstance(original_image, list):
+ check_image_type = original_image[0]
+ else:
+ check_image_type = original_image
+
+ if (
+ not isinstance(check_image_type, torch.Tensor)
+ and not isinstance(check_image_type, PIL.Image.Image)
+ and not isinstance(check_image_type, np.ndarray)
+ ):
+ raise ValueError(
+ "`original_image` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is"
+ f" {type(check_image_type)}"
+ )
+
+ if isinstance(original_image, list):
+ image_batch_size = len(original_image)
+ elif isinstance(original_image, torch.Tensor):
+ image_batch_size = original_image.shape[0]
+ elif isinstance(original_image, PIL.Image.Image):
+ image_batch_size = 1
+ elif isinstance(original_image, np.ndarray):
+ image_batch_size = original_image.shape[0]
+ else:
+ assert False
+
+ if batch_size != image_batch_size:
+ raise ValueError(
+ f"original_image batch size: {image_batch_size} must be same as prompt batch size {batch_size}"
+ )
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.IFImg2ImgPipeline.preprocess_image with preprocess_image -> preprocess_original_image
+ def preprocess_original_image(self, image: PIL.Image.Image) -> torch.Tensor:
+ if not isinstance(image, list):
+ image = [image]
+
+ def numpy_to_pt(images):
+ if images.ndim == 3:
+ images = images[..., None]
+
+ images = torch.from_numpy(images.transpose(0, 3, 1, 2))
+ return images
+
+ if isinstance(image[0], PIL.Image.Image):
+ new_image = []
+
+ for image_ in image:
+ image_ = image_.convert("RGB")
+ image_ = resize(image_, self.unet.sample_size)
+ image_ = np.array(image_)
+ image_ = image_.astype(np.float32)
+ image_ = image_ / 127.5 - 1
+ new_image.append(image_)
+
+ image = new_image
+
+ image = np.stack(image, axis=0) # to np
+ image = numpy_to_pt(image) # to pt
+
+ elif isinstance(image[0], np.ndarray):
+ image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
+ image = numpy_to_pt(image)
+
+ elif isinstance(image[0], torch.Tensor):
+ image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
+
+ return image
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_superresolution.IFSuperResolutionPipeline.preprocess_image
+ def preprocess_image(self, image: PIL.Image.Image, num_images_per_prompt, device) -> torch.Tensor:
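+ # unlike `preprocess_original_image`, the stage-I output is not resized here; PIL inputs are normalized to [-1, 1], and the batch is cast to the UNet dtype/device and repeated for each image per prompt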
+ if not isinstance(image, torch.Tensor) and not isinstance(image, list):
+ image = [image]
+
+ if isinstance(image[0], PIL.Image.Image):
+ image = [np.array(i).astype(np.float32) / 127.5 - 1.0 for i in image]
+
+ image = np.stack(image, axis=0) # to np
+ image = torch.from_numpy(image.transpose(0, 3, 1, 2))
+ elif isinstance(image[0], np.ndarray):
+ image = np.stack(image, axis=0) # to np
+ if image.ndim == 5:
+ image = image[0]
+
+ image = torch.from_numpy(image.transpose(0, 3, 1, 2))
+ elif isinstance(image, list) and isinstance(image[0], torch.Tensor):
+ dims = image[0].ndim
+
+ if dims == 3:
+ image = torch.stack(image, dim=0)
+ elif dims == 4:
+ image = torch.concat(image, dim=0)
+ else:
+ raise ValueError(f"Image must have 3 or 4 dimensions, instead got {dims}")
+
+ image = image.to(device=device, dtype=self.unet.dtype)
+
+ image = image.repeat_interleave(num_images_per_prompt, dim=0)
+
+ return image
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.IFImg2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength):
+ # get the original timestep using init_timestep
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep, 0)
+ timesteps = self.scheduler.timesteps[t_start:]
+
+ return timesteps, num_inference_steps - t_start
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.IFImg2ImgPipeline.prepare_intermediate_images
+ def prepare_intermediate_images(
+ self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None
+ ):
+ _, channels, height, width = image.shape
+
+ batch_size = batch_size * num_images_per_prompt
+
+ shape = (batch_size, channels, height, width)
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+ image = image.repeat_interleave(num_images_per_prompt, dim=0)
+ image = self.scheduler.add_noise(image, noise, timestep)
+
+ return image
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image: Union[PIL.Image.Image, np.ndarray, torch.FloatTensor],
+ original_image: Union[
+ PIL.Image.Image, torch.Tensor, np.ndarray, List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray]
+ ] = None,
+ strength: float = 0.8,
+ prompt: Union[str, List[str]] = None,
+ num_inference_steps: int = 50,
+ timesteps: List[int] = None,
+ guidance_scale: float = 4.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ noise_level: int = 250,
+ clean_caption: bool = True,
+ ):
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
+ process.
+ original_image (`torch.FloatTensor` or `PIL.Image.Image`):
+ The original image that `image` was varied from.
+ strength (`float`, *optional*, defaults to 0.8):
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
+ will be used as a starting point, adding more noise to it the larger the `strength`. The number of
+ denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
+ be maximum and the denoising process will run for the full number of iterations specified in
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` instead.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process. If not defined, equally spaced `num_inference_steps`
+ timesteps are used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 4.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text
+ `prompt`, usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+ noise_level (`int`, *optional*, defaults to 250):
+ The amount of noise to add to the upscaled image. Must be in the range `[0, 1000)`
+ clean_caption (`bool`, *optional*, defaults to `True`):
+ Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
+ be installed. If the dependencies are not installed, the embeddings will be created from the raw
+ prompt.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is a list with the generated images, and the second element is a list
+ of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw)
+ or watermarked content, according to the `safety_checker`.
+ """
+ # 1. Check inputs. Raise error if not correct
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ self.check_inputs(
+ prompt,
+ image,
+ original_image,
+ batch_size,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ )
+
+ # 2. Define call parameters
+
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ device = self._execution_device
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt,
+ do_classifier_free_guidance,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ clean_caption=clean_caption,
+ )
+
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ dtype = prompt_embeds.dtype
+
+ # 4. Prepare timesteps
+ if timesteps is not None:
+ self.scheduler.set_timesteps(timesteps=timesteps, device=device)
+ timesteps = self.scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength)
+
+ # 5. prepare original image
+ original_image = self.preprocess_original_image(original_image)
+ original_image = original_image.to(device=device, dtype=dtype)
+
+ # 6. Prepare intermediate images
+ noise_timestep = timesteps[0:1]
+ noise_timestep = noise_timestep.repeat(batch_size * num_images_per_prompt)
+
+ intermediate_images = self.prepare_intermediate_images(
+ original_image,
+ noise_timestep,
+ batch_size,
+ num_images_per_prompt,
+ dtype,
+ device,
+ generator,
+ )
+
+ # 7. Prepare upscaled image and noise level
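+ # the stage-I output is bilinearly upsampled to the stage-II resolution and corrupted with
+ # `noise_level` steps of noise; the UNet is conditioned on this noise level via `class_labels`
+ # and on the noisy upscale via channel-wise concatenation in the denoising loop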
+ _, _, height, width = original_image.shape
+
+ image = self.preprocess_image(image, num_images_per_prompt, device)
+
+ upscaled = F.interpolate(image, (height, width), mode="bilinear", align_corners=True)
+
+ noise_level = torch.tensor([noise_level] * upscaled.shape[0], device=upscaled.device)
+ noise = randn_tensor(upscaled.shape, generator=generator, device=upscaled.device, dtype=upscaled.dtype)
+ upscaled = self.image_noising_scheduler.add_noise(upscaled, noise, timesteps=noise_level)
+
+ if do_classifier_free_guidance:
+ noise_level = torch.cat([noise_level] * 2)
+
+ # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # HACK: see comment in `enable_model_cpu_offload`
+ if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None:
+ self.text_encoder_offload_hook.offload()
+
+ # 9. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ model_input = torch.cat([intermediate_images, upscaled], dim=1)
+
+ model_input = torch.cat([model_input] * 2) if do_classifier_free_guidance else model_input
+ model_input = self.scheduler.scale_model_input(model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ class_labels=noise_level,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
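+ # the model input has twice as many channels as the image (noisy sample + upscaled
+ # conditioning), so the noise / learned-variance split uses half of the input channel count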
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1)
+ noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1] // 2, dim=1)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+ noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
+
+ if self.scheduler.config.variance_type not in ["learned", "learned_range"]:
+ noise_pred, _ = noise_pred.split(intermediate_images.shape[1], dim=1)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ intermediate_images = self.scheduler.step(
+ noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
+ )[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, intermediate_images)
+
+ image = intermediate_images
+
+ if output_type == "pil":
+ # 10. Post-processing
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+ # 11. Run safety checker
+ image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)
+
+ # 12. Convert to PIL
+ image = self.numpy_to_pil(image)
+
+ # 13. Apply watermark
+ if self.watermarker is not None:
+ self.watermarker.apply_watermark(image, self.unet.config.sample_size)
+ elif output_type == "pt":
+ nsfw_detected = None
+ watermark_detected = None
+
+ if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
+ self.unet_offload_hook.offload()
+ else:
+ # 10. Post-processing
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+ # 11. Run safety checker
+ image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image, nsfw_detected, watermark_detected)
+
+ return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected)
diff --git a/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py b/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py
new file mode 100644
index 0000000000000000000000000000000000000000..b11bf780de472e92325b749e516bdc7883c22297
--- /dev/null
+++ b/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py
@@ -0,0 +1,1059 @@
+import html
+import inspect
+import re
+import urllib.parse as ul
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import PIL
+import torch
+from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer
+
+from ...loaders import LoraLoaderMixin
+from ...models import UNet2DConditionModel
+from ...schedulers import DDPMScheduler
+from ...utils import (
+ BACKENDS_MAPPING,
+ PIL_INTERPOLATION,
+ is_accelerate_available,
+ is_accelerate_version,
+ is_bs4_available,
+ is_ftfy_available,
+ logging,
+ randn_tensor,
+ replace_example_docstring,
+)
+from ..pipeline_utils import DiffusionPipeline
+from . import IFPipelineOutput
+from .safety_checker import IFSafetyChecker
+from .watermark import IFWatermarker
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+if is_bs4_available():
+ from bs4 import BeautifulSoup
+
+if is_ftfy_available():
+ import ftfy
+
+
+# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.resize
+def resize(images: PIL.Image.Image, img_size: int) -> PIL.Image.Image:
+ w, h = images.size
+
+ coef = w / h
+
+ w, h = img_size, img_size
+
+ if coef >= 1:
+ w = int(round(img_size / 8 * coef) * 8)
+ else:
+ h = int(round(img_size / 8 / coef) * 8)
+
+ images = images.resize((w, h), resample=PIL_INTERPOLATION["bicubic"], reducing_gap=None)
+
+ return images
+
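+# Illustrative example: with img_size=64 and a 640x480 input, coef = 640 / 480 >= 1, so
+# w = int(round(64 / 8 * coef) * 8) = 88 and h = 64; the aspect ratio is roughly preserved
+# while both sides stay multiples of 8.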
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> from diffusers import IFInpaintingPipeline, IFInpaintingSuperResolutionPipeline, DiffusionPipeline
+ >>> from diffusers.utils import pt_to_pil
+ >>> import torch
+ >>> from PIL import Image
+ >>> import requests
+ >>> from io import BytesIO
+
+ >>> url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/if/person.png"
+ >>> response = requests.get(url)
+ >>> original_image = Image.open(BytesIO(response.content)).convert("RGB")
+ >>> original_image = original_image
+
+ >>> url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/if/glasses_mask.png"
+ >>> response = requests.get(url)
+ >>> mask_image = Image.open(BytesIO(response.content))
+ >>> mask_image = mask_image
+
+ >>> pipe = IFInpaintingPipeline.from_pretrained(
+ ... "DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16
+ ... )
+ >>> pipe.enable_model_cpu_offload()
+
+ >>> prompt = "blue sunglasses"
+ >>> prompt_embeds, negative_embeds = pipe.encode_prompt(prompt)
+
+ >>> image = pipe(
+ ... image=original_image,
+ ... mask_image=mask_image,
+ ... prompt_embeds=prompt_embeds,
+ ... negative_prompt_embeds=negative_embeds,
+ ... output_type="pt",
+ ... ).images
+
+ >>> # save intermediate image
+ >>> pil_image = pt_to_pil(image)
+ >>> pil_image[0].save("./if_stage_I.png")
+
+ >>> super_res_1_pipe = IFInpaintingSuperResolutionPipeline.from_pretrained(
+ ... "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16
+ ... )
+ >>> super_res_1_pipe.enable_model_cpu_offload()
+
+ >>> image = super_res_1_pipe(
+ ... image=image,
+ ... mask_image=mask_image,
+ ... original_image=original_image,
+ ... prompt_embeds=prompt_embeds,
+ ... negative_prompt_embeds=negative_embeds,
+ ... ).images
+ >>> image[0].save("./if_stage_II.png")
+ ```
+"""
+
+
+class IFInpaintingPipeline(DiffusionPipeline, LoraLoaderMixin):
+ tokenizer: T5Tokenizer
+ text_encoder: T5EncoderModel
+
+ unet: UNet2DConditionModel
+ scheduler: DDPMScheduler
+
+ feature_extractor: Optional[CLIPImageProcessor]
+ safety_checker: Optional[IFSafetyChecker]
+
+ watermarker: Optional[IFWatermarker]
+
+ bad_punct_regex = re.compile(
+ r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
+ ) # noqa
+
+ _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"]
+
+ def __init__(
+ self,
+ tokenizer: T5Tokenizer,
+ text_encoder: T5EncoderModel,
+ unet: UNet2DConditionModel,
+ scheduler: DDPMScheduler,
+ safety_checker: Optional[IFSafetyChecker],
+ feature_extractor: Optional[CLIPImageProcessor],
+ watermarker: Optional[IFWatermarker],
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+                " that you abide by the conditions of the IF license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+                " strongly recommend keeping the safety filter enabled in all public-facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+                f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ self.register_modules(
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ watermarker=watermarker,
+ )
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.enable_model_cpu_offload
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
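+        # A minimal usage sketch (same checkpoint as in the example docstring):
+        #
+        #     pipe = IFInpaintingPipeline.from_pretrained(
+        #         "DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16
+        #     )
+        #     pipe.enable_model_cpu_offload()
+        #
+        # Each sub-model is then moved to the GPU only when it runs and offloaded when the next one starts.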
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ hook = None
+
+ if self.text_encoder is not None:
+ _, hook = cpu_offload_with_hook(self.text_encoder, device, prev_module_hook=hook)
+
+ # Accelerate will move the next model to the device _before_ calling the offload hook of the
+ # previous model. This will cause both models to be present on the device at the same time.
+ # IF uses T5 for its text encoder which is really large. We can manually call the offload
+ # hook for the text encoder to ensure it's moved to the cpu before the unet is moved to
+ # the GPU.
+ self.text_encoder_offload_hook = hook
+
+ _, hook = cpu_offload_with_hook(self.unet, device, prev_module_hook=hook)
+
+ # if the safety checker isn't called, `unet_offload_hook` will have to be called to manually offload the unet
+ self.unet_offload_hook = hook
+
+ if self.safety_checker is not None:
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.remove_all_hooks
+ def remove_all_hooks(self):
+ if is_accelerate_available():
+ from accelerate.hooks import remove_hook_from_module
+ else:
+ raise ImportError("Please install accelerate via `pip install accelerate`")
+
+ for model in [self.text_encoder, self.unet, self.safety_checker]:
+ if model is not None:
+ remove_hook_from_module(model, recurse=True)
+
+ self.unet_offload_hook = None
+ self.text_encoder_offload_hook = None
+ self.final_offload_hook = None
+
+ @torch.no_grad()
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt,
+ do_classifier_free_guidance=True,
+ num_images_per_prompt=1,
+ device=None,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ clean_caption: bool = False,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`, *optional*):
+ torch device to place the resulting embeddings on
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if
+                `guidance_scale` is less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ """
+ if prompt is not None and negative_prompt is not None:
+ if type(prompt) is not type(negative_prompt):
+ raise TypeError(
+                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+
+ if device is None:
+ device = self._execution_device
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF
+ max_length = 77
+
+ if prompt_embeds is None:
+ prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
+ logger.warning(
+                    "The following part of your input was truncated because the text encoder can only handle sequences up to"
+ f" {max_length} tokens: {removed_text}"
+ )
+
+ attention_mask = text_inputs.attention_mask.to(device)
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ if self.text_encoder is not None:
+ dtype = self.text_encoder.dtype
+ elif self.unet is not None:
+ dtype = self.unet.dtype
+ else:
+ dtype = None
+
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_attention_mask=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ attention_mask = uncond_input.attention_mask.to(device)
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ else:
+ negative_prompt_embeds = None
+
+ return prompt_embeds, negative_prompt_embeds
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.run_safety_checker
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+ image, nsfw_detected, watermark_detected = self.safety_checker(
+ images=image,
+ clip_input=safety_checker_input.pixel_values.to(dtype=dtype),
+ )
+ else:
+ nsfw_detected = None
+ watermark_detected = None
+
+ if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
+ self.unet_offload_hook.offload()
+
+ return image, nsfw_detected, watermark_detected
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
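+        # Illustrative note: the DDPMScheduler used by the IF pipelines accepts `generator` but not
+        # `eta`, so in practice this typically returns just {"generator": generator}.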
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ image,
+ mask_image,
+ batch_size,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ # image
+
+ if isinstance(image, list):
+ check_image_type = image[0]
+ else:
+ check_image_type = image
+
+ if (
+ not isinstance(check_image_type, torch.Tensor)
+ and not isinstance(check_image_type, PIL.Image.Image)
+ and not isinstance(check_image_type, np.ndarray)
+ ):
+ raise ValueError(
+ "`image` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is"
+ f" {type(check_image_type)}"
+ )
+
+ if isinstance(image, list):
+ image_batch_size = len(image)
+ elif isinstance(image, torch.Tensor):
+ image_batch_size = image.shape[0]
+ elif isinstance(image, PIL.Image.Image):
+ image_batch_size = 1
+ elif isinstance(image, np.ndarray):
+ image_batch_size = image.shape[0]
+ else:
+ assert False
+
+ if batch_size != image_batch_size:
+ raise ValueError(f"image batch size: {image_batch_size} must be same as prompt batch size {batch_size}")
+
+ # mask_image
+
+ if isinstance(mask_image, list):
+ check_image_type = mask_image[0]
+ else:
+ check_image_type = mask_image
+
+ if (
+ not isinstance(check_image_type, torch.Tensor)
+ and not isinstance(check_image_type, PIL.Image.Image)
+ and not isinstance(check_image_type, np.ndarray)
+ ):
+ raise ValueError(
+ "`mask_image` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is"
+ f" {type(check_image_type)}"
+ )
+
+ if isinstance(mask_image, list):
+ image_batch_size = len(mask_image)
+ elif isinstance(mask_image, torch.Tensor):
+ image_batch_size = mask_image.shape[0]
+ elif isinstance(mask_image, PIL.Image.Image):
+ image_batch_size = 1
+ elif isinstance(mask_image, np.ndarray):
+ image_batch_size = mask_image.shape[0]
+ else:
+ assert False
+
+ if image_batch_size != 1 and batch_size != image_batch_size:
+ raise ValueError(
+ f"mask_image batch size: {image_batch_size} must be `1` or the same as prompt batch size {batch_size}"
+ )
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
+ def _text_preprocessing(self, text, clean_caption=False):
+ if clean_caption and not is_bs4_available():
+ logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
+ logger.warn("Setting `clean_caption` to False...")
+ clean_caption = False
+
+ if clean_caption and not is_ftfy_available():
+ logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
+ logger.warn("Setting `clean_caption` to False...")
+ clean_caption = False
+
+ if not isinstance(text, (tuple, list)):
+ text = [text]
+
+ def process(text: str):
+ if clean_caption:
+ text = self._clean_caption(text)
+ text = self._clean_caption(text)
+ else:
+ text = text.lower().strip()
+ return text
+
+ return [process(t) for t in text]
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
+ def _clean_caption(self, caption):
+ caption = str(caption)
+ caption = ul.unquote_plus(caption)
+ caption = caption.strip().lower()
+        caption = re.sub("<person>", "person", caption)
+ # urls:
+ caption = re.sub(
+ r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ caption = re.sub(
+ r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ # html:
+ caption = BeautifulSoup(caption, features="html.parser").text
+
+ # @
+ caption = re.sub(r"@[\w\d]+\b", "", caption)
+
+ # 31C0—31EF CJK Strokes
+ # 31F0—31FF Katakana Phonetic Extensions
+ # 3200—32FF Enclosed CJK Letters and Months
+ # 3300—33FF CJK Compatibility
+ # 3400—4DBF CJK Unified Ideographs Extension A
+ # 4DC0—4DFF Yijing Hexagram Symbols
+ # 4E00—9FFF CJK Unified Ideographs
+ caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
+ caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
+ caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
+ caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
+ caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
+ caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
+ caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
+ #######################################################
+
+ # все виды тире / all types of dash --> "-"
+ caption = re.sub(
+ r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
+ "-",
+ caption,
+ )
+
+ # кавычки к одному стандарту
+ caption = re.sub(r"[`´«»“”¨]", '"', caption)
+ caption = re.sub(r"[‘’]", "'", caption)
+
+        # &quot;
+        caption = re.sub(r"&quot;?", "", caption)
+        # &amp
+        caption = re.sub(r"&amp", "", caption)
+
+        # ip addresses:
+ caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
+
+ # article ids:
+ caption = re.sub(r"\d:\d\d\s+$", "", caption)
+
+ # \n
+ caption = re.sub(r"\\n", " ", caption)
+
+ # "#123"
+ caption = re.sub(r"#\d{1,3}\b", "", caption)
+ # "#12345.."
+ caption = re.sub(r"#\d{5,}\b", "", caption)
+ # "123456.."
+ caption = re.sub(r"\b\d{6,}\b", "", caption)
+ # filenames:
+ caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
+
+ #
+ caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
+ caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
+
+ caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
+ caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
+
+ # this-is-my-cute-cat / this_is_my_cute_cat
+ regex2 = re.compile(r"(?:\-|\_)")
+ if len(re.findall(regex2, caption)) > 3:
+ caption = re.sub(regex2, " ", caption)
+
+ caption = ftfy.fix_text(caption)
+ caption = html.unescape(html.unescape(caption))
+
+ caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
+ caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
+ caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
+
+ caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
+ caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
+ caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
+ caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
+ caption = re.sub(r"\bpage\s+\d+\b", "", caption)
+
+ caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
+
+ caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
+
+ caption = re.sub(r"\b\s+\:\s+", r": ", caption)
+ caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
+ caption = re.sub(r"\s+", " ", caption)
+
+        caption = caption.strip()
+
+ caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
+ caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
+ caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
+ caption = re.sub(r"^\.\S+$", "", caption)
+
+ return caption.strip()
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.IFImg2ImgPipeline.preprocess_image
+ def preprocess_image(self, image: PIL.Image.Image) -> torch.Tensor:
+ if not isinstance(image, list):
+ image = [image]
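+        # Illustrative note: PIL inputs are resized so that each side is a multiple of 8 close to the
+        # unet sample size and rescaled from [0, 255] to [-1, 1]; numpy and tensor inputs are assumed
+        # to already be in that range.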
+
+ def numpy_to_pt(images):
+ if images.ndim == 3:
+ images = images[..., None]
+
+ images = torch.from_numpy(images.transpose(0, 3, 1, 2))
+ return images
+
+ if isinstance(image[0], PIL.Image.Image):
+ new_image = []
+
+ for image_ in image:
+ image_ = image_.convert("RGB")
+ image_ = resize(image_, self.unet.sample_size)
+ image_ = np.array(image_)
+ image_ = image_.astype(np.float32)
+ image_ = image_ / 127.5 - 1
+ new_image.append(image_)
+
+ image = new_image
+
+ image = np.stack(image, axis=0) # to np
+ image = numpy_to_pt(image) # to pt
+
+ elif isinstance(image[0], np.ndarray):
+ image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
+ image = numpy_to_pt(image)
+
+ elif isinstance(image[0], torch.Tensor):
+ image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
+
+ return image
+
+ def preprocess_mask_image(self, mask_image) -> torch.Tensor:
+ if not isinstance(mask_image, list):
+ mask_image = [mask_image]
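+        # Illustrative note: masks end up as (batch, 1, height, width) tensors binarized at 0.5, where
+        # 1 marks the region to repaint and 0 the region to keep.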
+
+ if isinstance(mask_image[0], torch.Tensor):
+ mask_image = torch.cat(mask_image, axis=0) if mask_image[0].ndim == 4 else torch.stack(mask_image, axis=0)
+
+ if mask_image.ndim == 2:
+ # Batch and add channel dim for single mask
+ mask_image = mask_image.unsqueeze(0).unsqueeze(0)
+ elif mask_image.ndim == 3 and mask_image.shape[0] == 1:
+ # Single mask, the 0'th dimension is considered to be
+ # the existing batch size of 1
+ mask_image = mask_image.unsqueeze(0)
+ elif mask_image.ndim == 3 and mask_image.shape[0] != 1:
+ # Batch of mask, the 0'th dimension is considered to be
+ # the batching dimension
+ mask_image = mask_image.unsqueeze(1)
+
+ mask_image[mask_image < 0.5] = 0
+ mask_image[mask_image >= 0.5] = 1
+
+ elif isinstance(mask_image[0], PIL.Image.Image):
+ new_mask_image = []
+
+ for mask_image_ in mask_image:
+ mask_image_ = mask_image_.convert("L")
+ mask_image_ = resize(mask_image_, self.unet.sample_size)
+ mask_image_ = np.array(mask_image_)
+ mask_image_ = mask_image_[None, None, :]
+ new_mask_image.append(mask_image_)
+
+ mask_image = new_mask_image
+
+ mask_image = np.concatenate(mask_image, axis=0)
+ mask_image = mask_image.astype(np.float32) / 255.0
+ mask_image[mask_image < 0.5] = 0
+ mask_image[mask_image >= 0.5] = 1
+ mask_image = torch.from_numpy(mask_image)
+
+ elif isinstance(mask_image[0], np.ndarray):
+ mask_image = np.concatenate([m[None, None, :] for m in mask_image], axis=0)
+
+ mask_image[mask_image < 0.5] = 0
+ mask_image[mask_image >= 0.5] = 1
+ mask_image = torch.from_numpy(mask_image)
+
+ return mask_image
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.IFImg2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength):
+ # get the original timestep using init_timestep
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep, 0)
+ timesteps = self.scheduler.timesteps[t_start:]
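+        # Illustrative example: with num_inference_steps=50 and strength=0.8, init_timestep = 40 and
+        # t_start = 10, so the last 40 scheduler timesteps are kept and denoising runs for 40 steps.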
+
+ return timesteps, num_inference_steps - t_start
+
+ def prepare_intermediate_images(
+ self, image, timestep, batch_size, num_images_per_prompt, dtype, device, mask_image, generator=None
+ ):
+ image_batch_size, channels, height, width = image.shape
+
+ batch_size = batch_size * num_images_per_prompt
+
+ shape = (batch_size, channels, height, width)
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+ image = image.repeat_interleave(num_images_per_prompt, dim=0)
+ noised_image = self.scheduler.add_noise(image, noise, timestep)
+
+ image = (1 - mask_image) * image + mask_image * noised_image
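+        # Illustrative note: with the binary mask (1 = repaint, 0 = keep), only the masked region starts
+        # from the noised image at `timestep`; the unmasked region keeps the clean input image.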
+
+ return image
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ image: Union[
+ PIL.Image.Image, torch.Tensor, np.ndarray, List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray]
+ ] = None,
+ mask_image: Union[
+ PIL.Image.Image, torch.Tensor, np.ndarray, List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray]
+ ] = None,
+ strength: float = 1.0,
+ num_inference_steps: int = 50,
+ timesteps: List[int] = None,
+ guidance_scale: float = 7.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ clean_caption: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ):
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+                instead.
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
+ process.
+ mask_image (`PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
+ repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
+ to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
+ instead of 3, so the expected shape would be `(B, H, W, 1)`.
+            strength (`float`, *optional*, defaults to 1.0):
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
+ will be used as a starting point, adding more noise to it the larger the `strength`. The number of
+ denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
+ be maximum and the denoising process will run for the full number of iterations specified in
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+                Custom timesteps to use for the denoising process. If not defined, equally spaced `num_inference_steps`
+ timesteps are used. Must be in descending order.
+            guidance_scale (`float`, *optional*, defaults to 7.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ clean_caption (`bool`, *optional*, defaults to `True`):
+ Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
+ be installed. If the dependencies are not installed, the embeddings will be created from the raw
+ prompt.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is a list with the generated images, and the second element is a list
+ of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw)
+ or watermarked content, according to the `safety_checker`.
+ """
+ # 1. Check inputs. Raise error if not correct
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ self.check_inputs(
+ prompt,
+ image,
+ mask_image,
+ batch_size,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ )
+
+ # 2. Define call parameters
+ device = self._execution_device
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt,
+ do_classifier_free_guidance,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ clean_caption=clean_caption,
+ )
+
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ dtype = prompt_embeds.dtype
+
+ # 4. Prepare timesteps
+ if timesteps is not None:
+ self.scheduler.set_timesteps(timesteps=timesteps, device=device)
+ timesteps = self.scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength)
+
+ # 5. Prepare intermediate images
+ image = self.preprocess_image(image)
+ image = image.to(device=device, dtype=dtype)
+
+ mask_image = self.preprocess_mask_image(mask_image)
+ mask_image = mask_image.to(device=device, dtype=dtype)
+
+ if mask_image.shape[0] == 1:
+ mask_image = mask_image.repeat_interleave(batch_size * num_images_per_prompt, dim=0)
+ else:
+ mask_image = mask_image.repeat_interleave(num_images_per_prompt, dim=0)
+
+ noise_timestep = timesteps[0:1]
+ noise_timestep = noise_timestep.repeat(batch_size * num_images_per_prompt)
+
+ intermediate_images = self.prepare_intermediate_images(
+ image, noise_timestep, batch_size, num_images_per_prompt, dtype, device, mask_image, generator
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # HACK: see comment in `enable_model_cpu_offload`
+ if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None:
+ self.text_encoder_offload_hook.offload()
+
+ # 7. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ model_input = (
+ torch.cat([intermediate_images] * 2) if do_classifier_free_guidance else intermediate_images
+ )
+ model_input = self.scheduler.scale_model_input(model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1)
+ noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+ noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
+
+ if self.scheduler.config.variance_type not in ["learned", "learned_range"]:
+ noise_pred, _ = noise_pred.split(model_input.shape[1], dim=1)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ prev_intermediate_images = intermediate_images
+
+ intermediate_images = self.scheduler.step(
+ noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
+ )[0]
+
+ intermediate_images = (1 - mask_image) * prev_intermediate_images + mask_image * intermediate_images
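+                # Illustrative note: resetting the unmasked region to its pre-step value on every
+                # iteration keeps it pinned to the clean input image, so only the masked region is
+                # actually denoised.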
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, intermediate_images)
+
+ image = intermediate_images
+
+ if output_type == "pil":
+ # 8. Post-processing
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+ # 9. Run safety checker
+ image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)
+
+ # 10. Convert to PIL
+ image = self.numpy_to_pil(image)
+
+ # 11. Apply watermark
+ if self.watermarker is not None:
+ self.watermarker.apply_watermark(image, self.unet.config.sample_size)
+ elif output_type == "pt":
+ nsfw_detected = None
+ watermark_detected = None
+
+ if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
+ self.unet_offload_hook.offload()
+ else:
+ # 8. Post-processing
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+ # 9. Run safety checker
+ image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image, nsfw_detected, watermark_detected)
+
+ return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected)
diff --git a/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py b/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py
new file mode 100644
index 0000000000000000000000000000000000000000..2971570aada23ca1f88d212b7f82eb6e9c2118eb
--- /dev/null
+++ b/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py
@@ -0,0 +1,1169 @@
+import html
+import inspect
+import re
+import urllib.parse as ul
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import PIL
+import torch
+import torch.nn.functional as F
+from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer
+
+from ...loaders import LoraLoaderMixin
+from ...models import UNet2DConditionModel
+from ...schedulers import DDPMScheduler
+from ...utils import (
+ BACKENDS_MAPPING,
+ PIL_INTERPOLATION,
+ is_accelerate_available,
+ is_accelerate_version,
+ is_bs4_available,
+ is_ftfy_available,
+ logging,
+ randn_tensor,
+ replace_example_docstring,
+)
+from ..pipeline_utils import DiffusionPipeline
+from . import IFPipelineOutput
+from .safety_checker import IFSafetyChecker
+from .watermark import IFWatermarker
+
+
+if is_bs4_available():
+ from bs4 import BeautifulSoup
+
+if is_ftfy_available():
+ import ftfy
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.resize
+def resize(images: PIL.Image.Image, img_size: int) -> PIL.Image.Image:
+ w, h = images.size
+
+ coef = w / h
+
+ w, h = img_size, img_size
+
+ if coef >= 1:
+ w = int(round(img_size / 8 * coef) * 8)
+ else:
+ h = int(round(img_size / 8 / coef) * 8)
+
+ images = images.resize((w, h), resample=PIL_INTERPOLATION["bicubic"], reducing_gap=None)
+
+ return images
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> from diffusers import IFInpaintingPipeline, IFInpaintingSuperResolutionPipeline, DiffusionPipeline
+ >>> from diffusers.utils import pt_to_pil
+ >>> import torch
+ >>> from PIL import Image
+ >>> import requests
+ >>> from io import BytesIO
+
+ >>> url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/if/person.png"
+ >>> response = requests.get(url)
+ >>> original_image = Image.open(BytesIO(response.content)).convert("RGB")
+ >>> original_image = original_image
+
+ >>> url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/if/glasses_mask.png"
+ >>> response = requests.get(url)
+ >>> mask_image = Image.open(BytesIO(response.content))
+ >>> mask_image = mask_image
+
+ >>> pipe = IFInpaintingPipeline.from_pretrained(
+ ... "DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16
+ ... )
+ >>> pipe.enable_model_cpu_offload()
+
+ >>> prompt = "blue sunglasses"
+
+ >>> prompt_embeds, negative_embeds = pipe.encode_prompt(prompt)
+ >>> image = pipe(
+ ... image=original_image,
+ ... mask_image=mask_image,
+ ... prompt_embeds=prompt_embeds,
+ ... negative_prompt_embeds=negative_embeds,
+ ... output_type="pt",
+ ... ).images
+
+ >>> # save intermediate image
+ >>> pil_image = pt_to_pil(image)
+ >>> pil_image[0].save("./if_stage_I.png")
+
+ >>> super_res_1_pipe = IFInpaintingSuperResolutionPipeline.from_pretrained(
+ ... "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16
+ ... )
+ >>> super_res_1_pipe.enable_model_cpu_offload()
+
+ >>> image = super_res_1_pipe(
+ ... image=image,
+ ... mask_image=mask_image,
+ ... original_image=original_image,
+ ... prompt_embeds=prompt_embeds,
+ ... negative_prompt_embeds=negative_embeds,
+ ... ).images
+ >>> image[0].save("./if_stage_II.png")
+ ```
+"""
+
+
+class IFInpaintingSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
+ tokenizer: T5Tokenizer
+ text_encoder: T5EncoderModel
+
+ unet: UNet2DConditionModel
+ scheduler: DDPMScheduler
+ image_noising_scheduler: DDPMScheduler
+
+ feature_extractor: Optional[CLIPImageProcessor]
+ safety_checker: Optional[IFSafetyChecker]
+
+ watermarker: Optional[IFWatermarker]
+
+ bad_punct_regex = re.compile(
+ r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
+ ) # noqa
+
+ _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"]
+
+ def __init__(
+ self,
+ tokenizer: T5Tokenizer,
+ text_encoder: T5EncoderModel,
+ unet: UNet2DConditionModel,
+ scheduler: DDPMScheduler,
+ image_noising_scheduler: DDPMScheduler,
+ safety_checker: Optional[IFSafetyChecker],
+ feature_extractor: Optional[CLIPImageProcessor],
+ watermarker: Optional[IFWatermarker],
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+                " that you abide by the conditions of the IF license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+                " strongly recommend keeping the safety filter enabled in all public-facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+                f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ if unet.config.in_channels != 6:
+ logger.warn(
+                f"It seems like you have loaded a checkpoint that is not meant to be used for super resolution from {unet.config._name_or_path} as it accepts {unet.config.in_channels} input channels instead of 6. Please make sure to pass a super resolution checkpoint as the `unet`: `IFSuperResolutionPipeline.from_pretrained(unet=super_resolution_unet, ...)`."
+ )
+
+ self.register_modules(
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ unet=unet,
+ scheduler=scheduler,
+ image_noising_scheduler=image_noising_scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ watermarker=watermarker,
+ )
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.enable_model_cpu_offload
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ hook = None
+
+ if self.text_encoder is not None:
+ _, hook = cpu_offload_with_hook(self.text_encoder, device, prev_module_hook=hook)
+
+ # Accelerate will move the next model to the device _before_ calling the offload hook of the
+ # previous model. This will cause both models to be present on the device at the same time.
+ # IF uses T5 for its text encoder which is really large. We can manually call the offload
+ # hook for the text encoder to ensure it's moved to the cpu before the unet is moved to
+ # the GPU.
+ self.text_encoder_offload_hook = hook
+
+ _, hook = cpu_offload_with_hook(self.unet, device, prev_module_hook=hook)
+
+ # if the safety checker isn't called, `unet_offload_hook` will have to be called to manually offload the unet
+ self.unet_offload_hook = hook
+
+ if self.safety_checker is not None:
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.remove_all_hooks
+ def remove_all_hooks(self):
+ if is_accelerate_available():
+ from accelerate.hooks import remove_hook_from_module
+ else:
+ raise ImportError("Please install accelerate via `pip install accelerate`")
+
+ for model in [self.text_encoder, self.unet, self.safety_checker]:
+ if model is not None:
+ remove_hook_from_module(model, recurse=True)
+
+ self.unet_offload_hook = None
+ self.text_encoder_offload_hook = None
+ self.final_offload_hook = None
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
+ def _text_preprocessing(self, text, clean_caption=False):
+ if clean_caption and not is_bs4_available():
+ logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
+ logger.warn("Setting `clean_caption` to False...")
+ clean_caption = False
+
+ if clean_caption and not is_ftfy_available():
+ logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
+ logger.warn("Setting `clean_caption` to False...")
+ clean_caption = False
+
+ if not isinstance(text, (tuple, list)):
+ text = [text]
+
+ def process(text: str):
+ if clean_caption:
+ text = self._clean_caption(text)
+ text = self._clean_caption(text)
+ else:
+ text = text.lower().strip()
+ return text
+
+ return [process(t) for t in text]
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
+ def _clean_caption(self, caption):
+ caption = str(caption)
+ caption = ul.unquote_plus(caption)
+ caption = caption.strip().lower()
+        caption = re.sub("<person>", "person", caption)
+ # urls:
+ caption = re.sub(
+ r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ caption = re.sub(
+ r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ # html:
+ caption = BeautifulSoup(caption, features="html.parser").text
+
+ # @
+ caption = re.sub(r"@[\w\d]+\b", "", caption)
+
+ # 31C0—31EF CJK Strokes
+ # 31F0—31FF Katakana Phonetic Extensions
+ # 3200—32FF Enclosed CJK Letters and Months
+ # 3300—33FF CJK Compatibility
+ # 3400—4DBF CJK Unified Ideographs Extension A
+ # 4DC0—4DFF Yijing Hexagram Symbols
+ # 4E00—9FFF CJK Unified Ideographs
+ caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
+ caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
+ caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
+ caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
+ caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
+ caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
+ caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
+ #######################################################
+
+ # все виды тире / all types of dash --> "-"
+ caption = re.sub(
+ r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
+ "-",
+ caption,
+ )
+
+ # кавычки к одному стандарту
+ caption = re.sub(r"[`´«»“”¨]", '"', caption)
+ caption = re.sub(r"[‘’]", "'", caption)
+
+        # &quot;
+        caption = re.sub(r"&quot;?", "", caption)
+        # &amp
+        caption = re.sub(r"&amp", "", caption)
+
+        # ip addresses:
+ caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
+
+ # article ids:
+ caption = re.sub(r"\d:\d\d\s+$", "", caption)
+
+ # \n
+ caption = re.sub(r"\\n", " ", caption)
+
+ # "#123"
+ caption = re.sub(r"#\d{1,3}\b", "", caption)
+ # "#12345.."
+ caption = re.sub(r"#\d{5,}\b", "", caption)
+ # "123456.."
+ caption = re.sub(r"\b\d{6,}\b", "", caption)
+ # filenames:
+ caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
+
+ #
+ caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
+ caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
+
+ caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
+ caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
+
+ # this-is-my-cute-cat / this_is_my_cute_cat
+ regex2 = re.compile(r"(?:\-|\_)")
+ if len(re.findall(regex2, caption)) > 3:
+ caption = re.sub(regex2, " ", caption)
+
+ caption = ftfy.fix_text(caption)
+ caption = html.unescape(html.unescape(caption))
+
+ caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
+ caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
+ caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
+
+ caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
+ caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
+ caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
+ caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
+ caption = re.sub(r"\bpage\s+\d+\b", "", caption)
+
+ caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
+
+ caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
+
+ caption = re.sub(r"\b\s+\:\s+", r": ", caption)
+ caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
+ caption = re.sub(r"\s+", " ", caption)
+
+        caption = caption.strip()
+
+ caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
+ caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
+ caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
+ caption = re.sub(r"^\.\S+$", "", caption)
+
+ return caption.strip()
+
+ @torch.no_grad()
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt,
+ do_classifier_free_guidance=True,
+ num_images_per_prompt=1,
+ device=None,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ clean_caption: bool = False,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`, *optional*):
+ torch device to place the resulting embeddings on
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if
+                `guidance_scale` is less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ """
+ if prompt is not None and negative_prompt is not None:
+ if type(prompt) is not type(negative_prompt):
+ raise TypeError(
+                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+
+ if device is None:
+ device = self._execution_device
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF
+ max_length = 77
+
+ if prompt_embeds is None:
+ prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {max_length} tokens: {removed_text}"
+ )
+
+ attention_mask = text_inputs.attention_mask.to(device)
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ if self.text_encoder is not None:
+ dtype = self.text_encoder.dtype
+ elif self.unet is not None:
+ dtype = self.unet.dtype
+ else:
+ dtype = None
+
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_attention_mask=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ attention_mask = uncond_input.attention_mask.to(device)
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ else:
+ negative_prompt_embeds = None
+
+ return prompt_embeds, negative_prompt_embeds
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.run_safety_checker
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+ image, nsfw_detected, watermark_detected = self.safety_checker(
+ images=image,
+ clip_input=safety_checker_input.pixel_values.to(dtype=dtype),
+ )
+ else:
+ nsfw_detected = None
+ watermark_detected = None
+
+ if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
+ self.unet_offload_hook.offload()
+
+ return image, nsfw_detected, watermark_detected
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ image,
+ original_image,
+ mask_image,
+ batch_size,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ # image
+
+ if isinstance(image, list):
+ check_image_type = image[0]
+ else:
+ check_image_type = image
+
+ if (
+ not isinstance(check_image_type, torch.Tensor)
+ and not isinstance(check_image_type, PIL.Image.Image)
+ and not isinstance(check_image_type, np.ndarray)
+ ):
+ raise ValueError(
+ "`image` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is"
+ f" {type(check_image_type)}"
+ )
+
+ if isinstance(image, list):
+ image_batch_size = len(image)
+ elif isinstance(image, torch.Tensor):
+ image_batch_size = image.shape[0]
+ elif isinstance(image, PIL.Image.Image):
+ image_batch_size = 1
+ elif isinstance(image, np.ndarray):
+ image_batch_size = image.shape[0]
+ else:
+ assert False
+
+ if batch_size != image_batch_size:
+ raise ValueError(f"image batch size: {image_batch_size} must be same as prompt batch size {batch_size}")
+
+ # original_image
+
+ if isinstance(original_image, list):
+ check_image_type = original_image[0]
+ else:
+ check_image_type = original_image
+
+ if (
+ not isinstance(check_image_type, torch.Tensor)
+ and not isinstance(check_image_type, PIL.Image.Image)
+ and not isinstance(check_image_type, np.ndarray)
+ ):
+ raise ValueError(
+ "`original_image` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is"
+ f" {type(check_image_type)}"
+ )
+
+ if isinstance(original_image, list):
+ image_batch_size = len(original_image)
+ elif isinstance(original_image, torch.Tensor):
+ image_batch_size = original_image.shape[0]
+ elif isinstance(original_image, PIL.Image.Image):
+ image_batch_size = 1
+ elif isinstance(original_image, np.ndarray):
+ image_batch_size = original_image.shape[0]
+ else:
+ assert False
+
+ if batch_size != image_batch_size:
+ raise ValueError(
+ f"original_image batch size: {image_batch_size} must be same as prompt batch size {batch_size}"
+ )
+
+ # mask_image
+
+ if isinstance(mask_image, list):
+ check_image_type = mask_image[0]
+ else:
+ check_image_type = mask_image
+
+ if (
+ not isinstance(check_image_type, torch.Tensor)
+ and not isinstance(check_image_type, PIL.Image.Image)
+ and not isinstance(check_image_type, np.ndarray)
+ ):
+ raise ValueError(
+ "`mask_image` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is"
+ f" {type(check_image_type)}"
+ )
+
+ if isinstance(mask_image, list):
+ image_batch_size = len(mask_image)
+ elif isinstance(mask_image, torch.Tensor):
+ image_batch_size = mask_image.shape[0]
+ elif isinstance(mask_image, PIL.Image.Image):
+ image_batch_size = 1
+ elif isinstance(mask_image, np.ndarray):
+ image_batch_size = mask_image.shape[0]
+ else:
+ assert False
+
+ if image_batch_size != 1 and batch_size != image_batch_size:
+ raise ValueError(
+ f"mask_image batch size: {image_batch_size} must be `1` or the same as prompt batch size {batch_size}"
+ )
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.IFImg2ImgPipeline.preprocess_image with preprocess_image -> preprocess_original_image
+ def preprocess_original_image(self, image: PIL.Image.Image) -> torch.Tensor:
+ if not isinstance(image, list):
+ image = [image]
+
+ def numpy_to_pt(images):
+ if images.ndim == 3:
+ images = images[..., None]
+
+ images = torch.from_numpy(images.transpose(0, 3, 1, 2))
+ return images
+
+ if isinstance(image[0], PIL.Image.Image):
+ new_image = []
+
+ for image_ in image:
+ image_ = image_.convert("RGB")
+ image_ = resize(image_, self.unet.sample_size)
+ image_ = np.array(image_)
+ image_ = image_.astype(np.float32)
+ image_ = image_ / 127.5 - 1
+ new_image.append(image_)
+
+ image = new_image
+
+ image = np.stack(image, axis=0) # to np
+ image = numpy_to_pt(image) # to pt
+
+ elif isinstance(image[0], np.ndarray):
+ image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
+ image = numpy_to_pt(image)
+
+ elif isinstance(image[0], torch.Tensor):
+ image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
+
+ return image
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_superresolution.IFSuperResolutionPipeline.preprocess_image
+ def preprocess_image(self, image: PIL.Image.Image, num_images_per_prompt, device) -> torch.Tensor:
+ if not isinstance(image, torch.Tensor) and not isinstance(image, list):
+ image = [image]
+
+ if isinstance(image[0], PIL.Image.Image):
+ image = [np.array(i).astype(np.float32) / 127.5 - 1.0 for i in image]
+
+ image = np.stack(image, axis=0) # to np
+ image = torch.from_numpy(image.transpose(0, 3, 1, 2))
+ elif isinstance(image[0], np.ndarray):
+ image = np.stack(image, axis=0) # to np
+ if image.ndim == 5:
+ image = image[0]
+
+ image = torch.from_numpy(image.transpose(0, 3, 1, 2))
+ elif isinstance(image, list) and isinstance(image[0], torch.Tensor):
+ dims = image[0].ndim
+
+ if dims == 3:
+ image = torch.stack(image, dim=0)
+ elif dims == 4:
+ image = torch.concat(image, dim=0)
+ else:
+ raise ValueError(f"Image must have 3 or 4 dimensions, instead got {dims}")
+
+ image = image.to(device=device, dtype=self.unet.dtype)
+
+ image = image.repeat_interleave(num_images_per_prompt, dim=0)
+
+ return image
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_inpainting.IFInpaintingPipeline.preprocess_mask_image
+ def preprocess_mask_image(self, mask_image) -> torch.Tensor:
+ if not isinstance(mask_image, list):
+ mask_image = [mask_image]
+
+ if isinstance(mask_image[0], torch.Tensor):
+ mask_image = torch.cat(mask_image, axis=0) if mask_image[0].ndim == 4 else torch.stack(mask_image, axis=0)
+
+ if mask_image.ndim == 2:
+ # Batch and add channel dim for single mask
+ mask_image = mask_image.unsqueeze(0).unsqueeze(0)
+ elif mask_image.ndim == 3 and mask_image.shape[0] == 1:
+ # Single mask, the 0'th dimension is considered to be
+ # the existing batch size of 1
+ mask_image = mask_image.unsqueeze(0)
+ elif mask_image.ndim == 3 and mask_image.shape[0] != 1:
+ # Batch of mask, the 0'th dimension is considered to be
+ # the batching dimension
+ mask_image = mask_image.unsqueeze(1)
+
+ mask_image[mask_image < 0.5] = 0
+ mask_image[mask_image >= 0.5] = 1
+
+ elif isinstance(mask_image[0], PIL.Image.Image):
+ new_mask_image = []
+
+ for mask_image_ in mask_image:
+ mask_image_ = mask_image_.convert("L")
+ mask_image_ = resize(mask_image_, self.unet.sample_size)
+ mask_image_ = np.array(mask_image_)
+ mask_image_ = mask_image_[None, None, :]
+ new_mask_image.append(mask_image_)
+
+ mask_image = new_mask_image
+
+ mask_image = np.concatenate(mask_image, axis=0)
+ mask_image = mask_image.astype(np.float32) / 255.0
+ mask_image[mask_image < 0.5] = 0
+ mask_image[mask_image >= 0.5] = 1
+ mask_image = torch.from_numpy(mask_image)
+
+ elif isinstance(mask_image[0], np.ndarray):
+ mask_image = np.concatenate([m[None, None, :] for m in mask_image], axis=0)
+
+ mask_image[mask_image < 0.5] = 0
+ mask_image[mask_image >= 0.5] = 1
+ mask_image = torch.from_numpy(mask_image)
+
+ return mask_image
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.IFImg2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength):
+ # get the original timestep using init_timestep
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep, 0)
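+ # keep only the last `init_timestep` entries of the schedule; strength == 1.0 runs the full schedule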
+ timesteps = self.scheduler.timesteps[t_start:]
+
+ return timesteps, num_inference_steps - t_start
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_inpainting.IFInpaintingPipeline.prepare_intermediate_images
+ def prepare_intermediate_images(
+ self, image, timestep, batch_size, num_images_per_prompt, dtype, device, mask_image, generator=None
+ ):
+ image_batch_size, channels, height, width = image.shape
+
+ batch_size = batch_size * num_images_per_prompt
+
+ shape = (batch_size, channels, height, width)
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+ image = image.repeat_interleave(num_images_per_prompt, dim=0)
+ noised_image = self.scheduler.add_noise(image, noise, timestep)
+
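+ # blend so that unmasked pixels keep the original image and masked pixels start from the noised image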
+ image = (1 - mask_image) * image + mask_image * noised_image
+
+ return image
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image: Union[PIL.Image.Image, np.ndarray, torch.FloatTensor],
+ original_image: Union[
+ PIL.Image.Image, torch.Tensor, np.ndarray, List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray]
+ ] = None,
+ mask_image: Union[
+ PIL.Image.Image, torch.Tensor, np.ndarray, List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray]
+ ] = None,
+ strength: float = 0.8,
+ prompt: Union[str, List[str]] = None,
+ num_inference_steps: int = 100,
+ timesteps: List[int] = None,
+ guidance_scale: float = 4.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ noise_level: int = 0,
+ clean_caption: bool = True,
+ ):
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
+ process.
+ original_image (`torch.FloatTensor` or `PIL.Image.Image`):
+ The original image that `image` was varied from.
+ mask_image (`PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
+ repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
+ to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
+ instead of 3, so the expected shape would be `(B, H, W, 1)`.
+ strength (`float`, *optional*, defaults to 0.8):
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
+ will be used as a starting point, adding more noise to it the larger the `strength`. The number of
+ denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
+ be maximum and the denoising process will run for the full number of iterations specified in
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ num_inference_steps (`int`, *optional*, defaults to 100):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
+ timesteps are used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 4.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+ noise_level (`int`, *optional*, defaults to 0):
+ The amount of noise to add to the upscaled image. Must be in the range `[0, 1000)`
+ clean_caption (`bool`, *optional*, defaults to `True`):
+ Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
+ be installed. If the dependencies are not installed, the embeddings will be created from the raw
+ prompt.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is a list with the generated images, and the second element is a list
+ of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw)
+ or watermarked content, according to the `safety_checker`.
+ """
+ # 1. Check inputs. Raise error if not correct
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ self.check_inputs(
+ prompt,
+ image,
+ original_image,
+ mask_image,
+ batch_size,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ )
+
+ # 2. Define call parameters
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ device = self._execution_device
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt,
+ do_classifier_free_guidance,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ clean_caption=clean_caption,
+ )
+
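+ # batch the negative and positive embeddings together so a single unet forward pass
+ # yields both the unconditional and the conditional prediction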
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ dtype = prompt_embeds.dtype
+
+ # 4. Prepare timesteps
+ if timesteps is not None:
+ self.scheduler.set_timesteps(timesteps=timesteps, device=device)
+ timesteps = self.scheduler.timesteps
+ num_inference_steps = len(timesteps)
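+ # custom timesteps take precedence; the step count is inferred from their length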
+ else:
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength)
+
+ # 5. prepare original image
+ original_image = self.preprocess_original_image(original_image)
+ original_image = original_image.to(device=device, dtype=dtype)
+
+ # 6. prepare mask image
+ mask_image = self.preprocess_mask_image(mask_image)
+ mask_image = mask_image.to(device=device, dtype=dtype)
+
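+ # a single mask is broadcast over the whole batch; otherwise masks are only repeated per image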
+ if mask_image.shape[0] == 1:
+ mask_image = mask_image.repeat_interleave(batch_size * num_images_per_prompt, dim=0)
+ else:
+ mask_image = mask_image.repeat_interleave(num_images_per_prompt, dim=0)
+
+ # 6. Prepare intermediate images
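+ # noise the original image to the first retained timestep so denoising can start from it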
+ noise_timestep = timesteps[0:1]
+ noise_timestep = noise_timestep.repeat(batch_size * num_images_per_prompt)
+
+ intermediate_images = self.prepare_intermediate_images(
+ original_image,
+ noise_timestep,
+ batch_size,
+ num_images_per_prompt,
+ dtype,
+ device,
+ mask_image,
+ generator,
+ )
+
+ # 7. Prepare upscaled image and noise level
+ _, _, height, width = original_image.shape
+
+ image = self.preprocess_image(image, num_images_per_prompt, device)
+
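+ # upscale the stage-I output to the resolution of `original_image`, then apply noise
+ # conditioning augmentation controlled by `noise_level`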
+ upscaled = F.interpolate(image, (height, width), mode="bilinear", align_corners=True)
+
+ noise_level = torch.tensor([noise_level] * upscaled.shape[0], device=upscaled.device)
+ noise = randn_tensor(upscaled.shape, generator=generator, device=upscaled.device, dtype=upscaled.dtype)
+ upscaled = self.image_noising_scheduler.add_noise(upscaled, noise, timesteps=noise_level)
+
+ if do_classifier_free_guidance:
+ noise_level = torch.cat([noise_level] * 2)
+
+ # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # HACK: see comment in `enable_model_cpu_offload`
+ if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None:
+ self.text_encoder_offload_hook.offload()
+
+ # 9. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
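+ # the unet input is the noisy intermediate image concatenated channel-wise with the
+ # noised upscaled image (3 + 3 = 6 input channels)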
+ model_input = torch.cat([intermediate_images, upscaled], dim=1)
+
+ model_input = torch.cat([model_input] * 2) if do_classifier_free_guidance else model_input
+ model_input = self.scheduler.scale_model_input(model_input, t)
+
+ # predict the noise residual
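+ # `noise_level` is passed as `class_labels` so the unet is conditioned on the amount
+ # of augmentation noise added to the upscaled image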
+ noise_pred = self.unet(
+ model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ class_labels=noise_level,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
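+ # each prediction stacks a noise estimate and a predicted variance along the channel dim;
+ # guidance is applied to the noise estimate only and the variance is re-attached afterwards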
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1)
+ noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1] // 2, dim=1)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+ noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
+
+ if self.scheduler.config.variance_type not in ["learned", "learned_range"]:
+ noise_pred, _ = noise_pred.split(intermediate_images.shape[1], dim=1)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ prev_intermediate_images = intermediate_images
+
+ intermediate_images = self.scheduler.step(
+ noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
+ )[0]
+
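+ # re-impose the unmasked region from the previous sample so only masked pixels are updated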
+ intermediate_images = (1 - mask_image) * prev_intermediate_images + mask_image * intermediate_images
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, intermediate_images)
+
+ image = intermediate_images
+
+ if output_type == "pil":
+ # 10. Post-processing
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+ # 11. Run safety checker
+ image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)
+
+ # 12. Convert to PIL
+ image = self.numpy_to_pil(image)
+
+ # 13. Apply watermark
+ if self.watermarker is not None:
+ self.watermarker.apply_watermark(image, self.unet.config.sample_size)
+ elif output_type == "pt":
+ nsfw_detected = None
+ watermark_detected = None
+
+ if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
+ self.unet_offload_hook.offload()
+ else:
+ # 10. Post-processing
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+ # 11. Run safety checker
+ image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image, nsfw_detected, watermark_detected)
+
+ return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected)
diff --git a/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py b/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py
new file mode 100644
index 0000000000000000000000000000000000000000..8bdd01fe748e411c032dbb5316c6761ecd5716e5
--- /dev/null
+++ b/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py
@@ -0,0 +1,914 @@
+import html
+import inspect
+import re
+import urllib.parse as ul
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import PIL
+import torch
+import torch.nn.functional as F
+from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer
+
+from ...loaders import LoraLoaderMixin
+from ...models import UNet2DConditionModel
+from ...schedulers import DDPMScheduler
+from ...utils import (
+ BACKENDS_MAPPING,
+ is_accelerate_available,
+ is_accelerate_version,
+ is_bs4_available,
+ is_ftfy_available,
+ logging,
+ randn_tensor,
+ replace_example_docstring,
+)
+from ..pipeline_utils import DiffusionPipeline
+from . import IFPipelineOutput
+from .safety_checker import IFSafetyChecker
+from .watermark import IFWatermarker
+
+
+if is_bs4_available():
+ from bs4 import BeautifulSoup
+
+if is_ftfy_available():
+ import ftfy
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> from diffusers import IFPipeline, IFSuperResolutionPipeline, DiffusionPipeline
+ >>> from diffusers.utils import pt_to_pil
+ >>> import torch
+
+ >>> pipe = IFPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
+ >>> pipe.enable_model_cpu_offload()
+
+ >>> prompt = 'a photo of a kangaroo wearing an orange hoodie and blue sunglasses standing in front of the eiffel tower holding a sign that says "very deep learning"'
+ >>> prompt_embeds, negative_embeds = pipe.encode_prompt(prompt)
+
+ >>> image = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, output_type="pt").images
+
+ >>> # save intermediate image
+ >>> pil_image = pt_to_pil(image)
+ >>> pil_image[0].save("./if_stage_I.png")
+
+ >>> super_res_1_pipe = IFSuperResolutionPipeline.from_pretrained(
+ ... "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16
+ ... )
+ >>> super_res_1_pipe.enable_model_cpu_offload()
+
+ >>> image = super_res_1_pipe(
+ ... image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds
+ ... ).images
+ >>> image[0].save("./if_stage_II.png")
+ ```
+"""
+
+
+class IFSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
+ tokenizer: T5Tokenizer
+ text_encoder: T5EncoderModel
+
+ unet: UNet2DConditionModel
+ scheduler: DDPMScheduler
+ image_noising_scheduler: DDPMScheduler
+
+ feature_extractor: Optional[CLIPImageProcessor]
+ safety_checker: Optional[IFSafetyChecker]
+
+ watermarker: Optional[IFWatermarker]
+
+ bad_punct_regex = re.compile(
+ r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
+ ) # noqa
+
+ _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"]
+
+ def __init__(
+ self,
+ tokenizer: T5Tokenizer,
+ text_encoder: T5EncoderModel,
+ unet: UNet2DConditionModel,
+ scheduler: DDPMScheduler,
+ image_noising_scheduler: DDPMScheduler,
+ safety_checker: Optional[IFSafetyChecker],
+ feature_extractor: Optional[CLIPImageProcessor],
+ watermarker: Optional[IFWatermarker],
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the IF license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ if unet.config.in_channels != 6:
+ logger.warn(
+ "It seems like you have loaded a checkpoint that shall not be used for super resolution from {unet.config._name_or_path} as it accepts {unet.config.in_channels} input channels instead of 6. Please make sure to pass a super resolution checkpoint as the `'unet'`: IFSuperResolutionPipeline.from_pretrained(unet=super_resolution_unet, ...)`."
+ )
+
+ self.register_modules(
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ unet=unet,
+ scheduler=scheduler,
+ image_noising_scheduler=image_noising_scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ watermarker=watermarker,
+ )
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.enable_model_cpu_offload
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ hook = None
+
+ if self.text_encoder is not None:
+ _, hook = cpu_offload_with_hook(self.text_encoder, device, prev_module_hook=hook)
+
+ # Accelerate will move the next model to the device _before_ calling the offload hook of the
+ # previous model. This will cause both models to be present on the device at the same time.
+ # IF uses T5 for its text encoder which is really large. We can manually call the offload
+ # hook for the text encoder to ensure it's moved to the cpu before the unet is moved to
+ # the GPU.
+ self.text_encoder_offload_hook = hook
+
+ _, hook = cpu_offload_with_hook(self.unet, device, prev_module_hook=hook)
+
+ # if the safety checker isn't called, `unet_offload_hook` will have to be called to manually offload the unet
+ self.unet_offload_hook = hook
+
+ if self.safety_checker is not None:
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.remove_all_hooks
+ def remove_all_hooks(self):
+ if is_accelerate_available():
+ from accelerate.hooks import remove_hook_from_module
+ else:
+ raise ImportError("Please install accelerate via `pip install accelerate`")
+
+ for model in [self.text_encoder, self.unet, self.safety_checker]:
+ if model is not None:
+ remove_hook_from_module(model, recurse=True)
+
+ self.unet_offload_hook = None
+ self.text_encoder_offload_hook = None
+ self.final_offload_hook = None
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
+ def _text_preprocessing(self, text, clean_caption=False):
+ if clean_caption and not is_bs4_available():
+ logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
+ logger.warn("Setting `clean_caption` to False...")
+ clean_caption = False
+
+ if clean_caption and not is_ftfy_available():
+ logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
+ logger.warn("Setting `clean_caption` to False...")
+ clean_caption = False
+
+ if not isinstance(text, (tuple, list)):
+ text = [text]
+
+ def process(text: str):
+ if clean_caption:
+ text = self._clean_caption(text)
+ text = self._clean_caption(text)
+ else:
+ text = text.lower().strip()
+ return text
+
+ return [process(t) for t in text]
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
+ def _clean_caption(self, caption):
+ caption = str(caption)
+ caption = ul.unquote_plus(caption)
+ caption = caption.strip().lower()
+ caption = re.sub("", "person", caption)
+ # urls:
+ caption = re.sub(
+ r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ caption = re.sub(
+ r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ # html:
+ caption = BeautifulSoup(caption, features="html.parser").text
+
+ # @<nickname>
+ caption = re.sub(r"@[\w\d]+\b", "", caption)
+
+ # 31C0—31EF CJK Strokes
+ # 31F0—31FF Katakana Phonetic Extensions
+ # 3200—32FF Enclosed CJK Letters and Months
+ # 3300—33FF CJK Compatibility
+ # 3400—4DBF CJK Unified Ideographs Extension A
+ # 4DC0—4DFF Yijing Hexagram Symbols
+ # 4E00—9FFF CJK Unified Ideographs
+ caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
+ caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
+ caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
+ caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
+ caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
+ caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
+ caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
+ #######################################################
+
+ # все виды тире / all types of dash --> "-"
+ caption = re.sub(
+ r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
+ "-",
+ caption,
+ )
+
+ # кавычки к одному стандарту
+ caption = re.sub(r"[`´«»“”¨]", '"', caption)
+ caption = re.sub(r"[‘’]", "'", caption)
+
+ # "
+ caption = re.sub(r""?", "", caption)
+ # &
+ caption = re.sub(r"&", "", caption)
+
+ # ip adresses:
+ caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
+
+ # article ids:
+ caption = re.sub(r"\d:\d\d\s+$", "", caption)
+
+ # \n
+ caption = re.sub(r"\\n", " ", caption)
+
+ # "#123"
+ caption = re.sub(r"#\d{1,3}\b", "", caption)
+ # "#12345.."
+ caption = re.sub(r"#\d{5,}\b", "", caption)
+ # "123456.."
+ caption = re.sub(r"\b\d{6,}\b", "", caption)
+ # filenames:
+ caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
+
+ #
+ caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
+ caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
+
+ caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
+ caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
+
+ # this-is-my-cute-cat / this_is_my_cute_cat
+ regex2 = re.compile(r"(?:\-|\_)")
+ if len(re.findall(regex2, caption)) > 3:
+ caption = re.sub(regex2, " ", caption)
+
+ caption = ftfy.fix_text(caption)
+ caption = html.unescape(html.unescape(caption))
+
+ caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
+ caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
+ caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
+
+ caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
+ caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
+ caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
+ caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
+ caption = re.sub(r"\bpage\s+\d+\b", "", caption)
+
+ caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
+
+ caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
+
+ caption = re.sub(r"\b\s+\:\s+", r": ", caption)
+ caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
+ caption = re.sub(r"\s+", " ", caption)
+
+ caption = caption.strip()
+
+ caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
+ caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
+ caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
+ caption = re.sub(r"^\.\S+$", "", caption)
+
+ return caption.strip()
+
+ @torch.no_grad()
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt,
+ do_classifier_free_guidance=True,
+ num_images_per_prompt=1,
+ device=None,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ clean_caption: bool = False,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`, *optional*):
+ torch device to place the resulting embeddings on
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead.
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ """
+ if prompt is not None and negative_prompt is not None:
+ if type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+
+ if device is None:
+ device = self._execution_device
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF
+ max_length = 77
+
+ if prompt_embeds is None:
+ prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {max_length} tokens: {removed_text}"
+ )
+
+ attention_mask = text_inputs.attention_mask.to(device)
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ if self.text_encoder is not None:
+ dtype = self.text_encoder.dtype
+ elif self.unet is not None:
+ dtype = self.unet.dtype
+ else:
+ dtype = None
+
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_attention_mask=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ attention_mask = uncond_input.attention_mask.to(device)
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ else:
+ negative_prompt_embeds = None
+
+ return prompt_embeds, negative_prompt_embeds
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.run_safety_checker
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+ image, nsfw_detected, watermark_detected = self.safety_checker(
+ images=image,
+ clip_input=safety_checker_input.pixel_values.to(dtype=dtype),
+ )
+ else:
+ nsfw_detected = None
+ watermark_detected = None
+
+ if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
+ self.unet_offload_hook.offload()
+
+ return image, nsfw_detected, watermark_detected
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ image,
+ batch_size,
+ noise_level,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if noise_level < 0 or noise_level >= self.image_noising_scheduler.config.num_train_timesteps:
+ raise ValueError(
+ f"`noise_level`: {noise_level} must be a valid timestep in `self.noising_scheduler`, [0, {self.image_noising_scheduler.config.num_train_timesteps})"
+ )
+
+ if isinstance(image, list):
+ check_image_type = image[0]
+ else:
+ check_image_type = image
+
+ if (
+ not isinstance(check_image_type, torch.Tensor)
+ and not isinstance(check_image_type, PIL.Image.Image)
+ and not isinstance(check_image_type, np.ndarray)
+ ):
+ raise ValueError(
+ "`image` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is"
+ f" {type(check_image_type)}"
+ )
+
+ if isinstance(image, list):
+ image_batch_size = len(image)
+ elif isinstance(image, torch.Tensor):
+ image_batch_size = image.shape[0]
+ elif isinstance(image, PIL.Image.Image):
+ image_batch_size = 1
+ elif isinstance(image, np.ndarray):
+ image_batch_size = image.shape[0]
+ else:
+ assert False
+
+ if batch_size != image_batch_size:
+ raise ValueError(f"image batch size: {image_batch_size} must be same as prompt batch size {batch_size}")
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.prepare_intermediate_images
+ def prepare_intermediate_images(self, batch_size, num_channels, height, width, dtype, device, generator):
+ shape = (batch_size, num_channels, height, width)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ intermediate_images = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ intermediate_images = intermediate_images * self.scheduler.init_noise_sigma
+ return intermediate_images
+
+ def preprocess_image(self, image, num_images_per_prompt, device):
+ if not isinstance(image, torch.Tensor) and not isinstance(image, list):
+ image = [image]
+
+ if isinstance(image[0], PIL.Image.Image):
+ image = [np.array(i).astype(np.float32) / 127.5 - 1.0 for i in image]
+
+ image = np.stack(image, axis=0) # to np
+ image = torch.from_numpy(image.transpose(0, 3, 1, 2))
+ elif isinstance(image[0], np.ndarray):
+ image = np.stack(image, axis=0) # to np
+ if image.ndim == 5:
+ image = image[0]
+
+ image = torch.from_numpy(image.transpose(0, 3, 1, 2))
+ elif isinstance(image, list) and isinstance(image[0], torch.Tensor):
+ dims = image[0].ndim
+
+ if dims == 3:
+ image = torch.stack(image, dim=0)
+ elif dims == 4:
+ image = torch.concat(image, dim=0)
+ else:
+ raise ValueError(f"Image must have 3 or 4 dimensions, instead got {dims}")
+
+ image = image.to(device=device, dtype=self.unet.dtype)
+
+ image = image.repeat_interleave(num_images_per_prompt, dim=0)
+
+ return image
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ height: int = None,
+ width: int = None,
+ image: Union[PIL.Image.Image, np.ndarray, torch.FloatTensor] = None,
+ num_inference_steps: int = 50,
+ timesteps: List[int] = None,
+ guidance_scale: float = 4.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ noise_level: int = 250,
+ clean_caption: bool = True,
+ ):
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size):
+ The width in pixels of the generated image.
+ image (`PIL.Image.Image`, `np.ndarray`, `torch.FloatTensor`):
+ The image to be upscaled.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
+ timesteps are used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 4.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+ noise_level (`int`, *optional*, defaults to 250):
+ The amount of noise to add to the upscaled image. Must be in the range `[0, 1000)`
+ clean_caption (`bool`, *optional*, defaults to `True`):
+ Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
+ be installed. If the dependencies are not installed, the embeddings will be created from the raw
+ prompt.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is a list with the generated images, and the second element is a list
+ of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw)
+ or watermarked content, according to the `safety_checker`.
+ """
+ # 1. Check inputs. Raise error if not correct
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ self.check_inputs(
+ prompt,
+ image,
+ batch_size,
+ noise_level,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ )
+
+ # 2. Define call parameters
+
+ height = height or self.unet.config.sample_size
+ width = width or self.unet.config.sample_size
+
+ device = self._execution_device
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt,
+ do_classifier_free_guidance,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ clean_caption=clean_caption,
+ )
+
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ # 4. Prepare timesteps
+ if timesteps is not None:
+ self.scheduler.set_timesteps(timesteps=timesteps, device=device)
+ timesteps = self.scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare intermediate images
+ num_channels = self.unet.config.in_channels // 2
+ intermediate_images = self.prepare_intermediate_images(
+ batch_size * num_images_per_prompt,
+ num_channels,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Prepare upscaled image and noise level
+ image = self.preprocess_image(image, num_images_per_prompt, device)
+ upscaled = F.interpolate(image, (height, width), mode="bilinear", align_corners=True)
+
+ noise_level = torch.tensor([noise_level] * upscaled.shape[0], device=upscaled.device)
+ noise = randn_tensor(upscaled.shape, generator=generator, device=upscaled.device, dtype=upscaled.dtype)
+ upscaled = self.image_noising_scheduler.add_noise(upscaled, noise, timesteps=noise_level)
+
+ if do_classifier_free_guidance:
+ noise_level = torch.cat([noise_level] * 2)
+
+ # HACK: see comment in `enable_model_cpu_offload`
+ if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None:
+ self.text_encoder_offload_hook.offload()
+
+ # 8. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ model_input = torch.cat([intermediate_images, upscaled], dim=1)
+
+ model_input = torch.cat([model_input] * 2) if do_classifier_free_guidance else model_input
+ model_input = self.scheduler.scale_model_input(model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ class_labels=noise_level,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1)
+ noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1] // 2, dim=1)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+ noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
+
+ if self.scheduler.config.variance_type not in ["learned", "learned_range"]:
+ noise_pred, _ = noise_pred.split(intermediate_images.shape[1], dim=1)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ intermediate_images = self.scheduler.step(
+ noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
+ )[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, intermediate_images)
+
+ image = intermediate_images
+
+ if output_type == "pil":
+ # 9. Post-processing
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+ # 10. Run safety checker
+ image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)
+
+ # 11. Convert to PIL
+ image = self.numpy_to_pil(image)
+
+ # 12. Apply watermark
+ if self.watermarker is not None:
+ self.watermarker.apply_watermark(image, self.unet.config.sample_size)
+ elif output_type == "pt":
+ nsfw_detected = None
+ watermark_detected = None
+
+ if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
+ self.unet_offload_hook.offload()
+ else:
+ # 9. Post-processing
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+ # 10. Run safety checker
+ image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image, nsfw_detected, watermark_detected)
+
+ return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected)
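+
+
+if __name__ == "__main__":
+    # Minimal, self-contained sketch (not part of the pipeline) of the guidance step used in the
+    # denoising loop above: chunk the batched prediction into unconditional/text halves, split off
+    # the learned variance, recombine epsilon with the guidance weight, and re-attach the
+    # text-branch variance. Shapes are illustrative toy values, not real model dimensions.
+    import torch
+
+    guidance_scale = 4.0
+    latent_channels = 3  # the stage-II UNet sees [intermediate_images, upscaled], i.e. 2 * 3 input channels
+    noise_pred = torch.randn(2, 2 * latent_channels, 8, 8)  # [uncond; text] halves, epsilon + learned variance
+
+    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+    noise_pred_uncond, _ = noise_pred_uncond.split(latent_channels, dim=1)
+    noise_pred_text, predicted_variance = noise_pred_text.split(latent_channels, dim=1)
+    guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+    guided = torch.cat([guided, predicted_variance], dim=1)
+    print(guided.shape)  # torch.Size([1, 6, 8, 8]); with guidance_scale == 1 this reduces to the text prediction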
diff --git a/diffusers/pipelines/deepfloyd_if/safety_checker.py b/diffusers/pipelines/deepfloyd_if/safety_checker.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ffeed580bbea1514b11bf7a168a952328d8f424
--- /dev/null
+++ b/diffusers/pipelines/deepfloyd_if/safety_checker.py
@@ -0,0 +1,59 @@
+import numpy as np
+import torch
+import torch.nn as nn
+from transformers import CLIPConfig, CLIPVisionModelWithProjection, PreTrainedModel
+
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class IFSafetyChecker(PreTrainedModel):
+ config_class = CLIPConfig
+
+ _no_split_modules = ["CLIPEncoderLayer"]
+
+ def __init__(self, config: CLIPConfig):
+ super().__init__(config)
+
+ self.vision_model = CLIPVisionModelWithProjection(config.vision_config)
+
+ self.p_head = nn.Linear(config.vision_config.projection_dim, 1)
+ self.w_head = nn.Linear(config.vision_config.projection_dim, 1)
+
+ @torch.no_grad()
+ def forward(self, clip_input, images, p_threshold=0.5, w_threshold=0.5):
+ image_embeds = self.vision_model(clip_input)[0]
+
+ nsfw_detected = self.p_head(image_embeds)
+ nsfw_detected = nsfw_detected.flatten()
+ nsfw_detected = nsfw_detected > p_threshold
+ nsfw_detected = nsfw_detected.tolist()
+
+ if any(nsfw_detected):
+ logger.warning(
+ "Potential NSFW content was detected in one or more images. A black image will be returned instead."
+ " Try again with a different prompt and/or seed."
+ )
+
+ for idx, nsfw_detected_ in enumerate(nsfw_detected):
+ if nsfw_detected_:
+ images[idx] = np.zeros(images[idx].shape)
+
+ watermark_detected = self.w_head(image_embeds)
+ watermark_detected = watermark_detected.flatten()
+ watermark_detected = watermark_detected > w_threshold
+ watermark_detected = watermark_detected.tolist()
+
+ if any(watermark_detected):
+ logger.warning(
+ "Potential watermarked content was detected in one or more images. A black image will be returned instead."
+ " Try again with a different prompt and/or seed."
+ )
+
+ for idx, watermark_detected_ in enumerate(watermark_detected):
+ if watermark_detected_:
+ images[idx] = np.zeros(images[idx].shape)
+
+ return images, nsfw_detected, watermark_detected
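+
+
+if __name__ == "__main__":
+    # Illustrative sketch of the thresholding logic in `forward` above, using random embeddings
+    # and a freshly initialised linear head instead of a real CLIP vision tower. It only shows how
+    # scores above the threshold lead to blacked-out images, not actual NSFW/watermark detection.
+    projection_dim = 16
+    image_embeds = torch.randn(4, projection_dim)
+    p_head = nn.Linear(projection_dim, 1)
+
+    images = [np.ones((64, 64, 3), dtype=np.float32) for _ in range(4)]
+    nsfw_detected = (p_head(image_embeds).flatten() > 0.5).tolist()
+    for idx, flagged in enumerate(nsfw_detected):
+        if flagged:
+            images[idx] = np.zeros(images[idx].shape)
+    print("flagged:", nsfw_detected)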
diff --git a/diffusers/pipelines/deepfloyd_if/timesteps.py b/diffusers/pipelines/deepfloyd_if/timesteps.py
new file mode 100644
index 0000000000000000000000000000000000000000..d44285c017bbb2ccffa4ae86dd77792a048625d9
--- /dev/null
+++ b/diffusers/pipelines/deepfloyd_if/timesteps.py
@@ -0,0 +1,579 @@
+fast27_timesteps = [
+ 999,
+ 800,
+ 799,
+ 600,
+ 599,
+ 500,
+ 400,
+ 399,
+ 377,
+ 355,
+ 333,
+ 311,
+ 288,
+ 266,
+ 244,
+ 222,
+ 200,
+ 199,
+ 177,
+ 155,
+ 133,
+ 111,
+ 88,
+ 66,
+ 44,
+ 22,
+ 0,
+]
+
+smart27_timesteps = [
+ 999,
+ 976,
+ 952,
+ 928,
+ 905,
+ 882,
+ 858,
+ 857,
+ 810,
+ 762,
+ 715,
+ 714,
+ 572,
+ 429,
+ 428,
+ 286,
+ 285,
+ 238,
+ 190,
+ 143,
+ 142,
+ 118,
+ 95,
+ 71,
+ 47,
+ 24,
+ 0,
+]
+
+smart50_timesteps = [
+ 999,
+ 988,
+ 977,
+ 966,
+ 955,
+ 944,
+ 933,
+ 922,
+ 911,
+ 900,
+ 899,
+ 879,
+ 859,
+ 840,
+ 820,
+ 800,
+ 799,
+ 766,
+ 733,
+ 700,
+ 699,
+ 650,
+ 600,
+ 599,
+ 500,
+ 499,
+ 400,
+ 399,
+ 350,
+ 300,
+ 299,
+ 266,
+ 233,
+ 200,
+ 199,
+ 179,
+ 159,
+ 140,
+ 120,
+ 100,
+ 99,
+ 88,
+ 77,
+ 66,
+ 55,
+ 44,
+ 33,
+ 22,
+ 11,
+ 0,
+]
+
+smart100_timesteps = [
+ 999,
+ 995,
+ 992,
+ 989,
+ 985,
+ 981,
+ 978,
+ 975,
+ 971,
+ 967,
+ 964,
+ 961,
+ 957,
+ 956,
+ 951,
+ 947,
+ 942,
+ 937,
+ 933,
+ 928,
+ 923,
+ 919,
+ 914,
+ 913,
+ 908,
+ 903,
+ 897,
+ 892,
+ 887,
+ 881,
+ 876,
+ 871,
+ 870,
+ 864,
+ 858,
+ 852,
+ 846,
+ 840,
+ 834,
+ 828,
+ 827,
+ 820,
+ 813,
+ 806,
+ 799,
+ 792,
+ 785,
+ 784,
+ 777,
+ 770,
+ 763,
+ 756,
+ 749,
+ 742,
+ 741,
+ 733,
+ 724,
+ 716,
+ 707,
+ 699,
+ 698,
+ 688,
+ 677,
+ 666,
+ 656,
+ 655,
+ 645,
+ 634,
+ 623,
+ 613,
+ 612,
+ 598,
+ 584,
+ 570,
+ 569,
+ 555,
+ 541,
+ 527,
+ 526,
+ 505,
+ 484,
+ 483,
+ 462,
+ 440,
+ 439,
+ 396,
+ 395,
+ 352,
+ 351,
+ 308,
+ 307,
+ 264,
+ 263,
+ 220,
+ 219,
+ 176,
+ 132,
+ 88,
+ 44,
+ 0,
+]
+
+smart185_timesteps = [
+ 999,
+ 997,
+ 995,
+ 992,
+ 990,
+ 988,
+ 986,
+ 984,
+ 981,
+ 979,
+ 977,
+ 975,
+ 972,
+ 970,
+ 968,
+ 966,
+ 964,
+ 961,
+ 959,
+ 957,
+ 956,
+ 954,
+ 951,
+ 949,
+ 946,
+ 944,
+ 941,
+ 939,
+ 936,
+ 934,
+ 931,
+ 929,
+ 926,
+ 924,
+ 921,
+ 919,
+ 916,
+ 914,
+ 913,
+ 910,
+ 907,
+ 905,
+ 902,
+ 899,
+ 896,
+ 893,
+ 891,
+ 888,
+ 885,
+ 882,
+ 879,
+ 877,
+ 874,
+ 871,
+ 870,
+ 867,
+ 864,
+ 861,
+ 858,
+ 855,
+ 852,
+ 849,
+ 846,
+ 843,
+ 840,
+ 837,
+ 834,
+ 831,
+ 828,
+ 827,
+ 824,
+ 821,
+ 817,
+ 814,
+ 811,
+ 808,
+ 804,
+ 801,
+ 798,
+ 795,
+ 791,
+ 788,
+ 785,
+ 784,
+ 780,
+ 777,
+ 774,
+ 770,
+ 766,
+ 763,
+ 760,
+ 756,
+ 752,
+ 749,
+ 746,
+ 742,
+ 741,
+ 737,
+ 733,
+ 730,
+ 726,
+ 722,
+ 718,
+ 714,
+ 710,
+ 707,
+ 703,
+ 699,
+ 698,
+ 694,
+ 690,
+ 685,
+ 681,
+ 677,
+ 673,
+ 669,
+ 664,
+ 660,
+ 656,
+ 655,
+ 650,
+ 646,
+ 641,
+ 636,
+ 632,
+ 627,
+ 622,
+ 618,
+ 613,
+ 612,
+ 607,
+ 602,
+ 596,
+ 591,
+ 586,
+ 580,
+ 575,
+ 570,
+ 569,
+ 563,
+ 557,
+ 551,
+ 545,
+ 539,
+ 533,
+ 527,
+ 526,
+ 519,
+ 512,
+ 505,
+ 498,
+ 491,
+ 484,
+ 483,
+ 474,
+ 466,
+ 457,
+ 449,
+ 440,
+ 439,
+ 428,
+ 418,
+ 407,
+ 396,
+ 395,
+ 381,
+ 366,
+ 352,
+ 351,
+ 330,
+ 308,
+ 307,
+ 286,
+ 264,
+ 263,
+ 242,
+ 220,
+ 219,
+ 176,
+ 175,
+ 132,
+ 131,
+ 88,
+ 44,
+ 0,
+]
+
+super27_timesteps = [
+ 999,
+ 991,
+ 982,
+ 974,
+ 966,
+ 958,
+ 950,
+ 941,
+ 933,
+ 925,
+ 916,
+ 908,
+ 900,
+ 899,
+ 874,
+ 850,
+ 825,
+ 800,
+ 799,
+ 700,
+ 600,
+ 500,
+ 400,
+ 300,
+ 200,
+ 100,
+ 0,
+]
+
+super40_timesteps = [
+ 999,
+ 992,
+ 985,
+ 978,
+ 971,
+ 964,
+ 957,
+ 949,
+ 942,
+ 935,
+ 928,
+ 921,
+ 914,
+ 907,
+ 900,
+ 899,
+ 879,
+ 859,
+ 840,
+ 820,
+ 800,
+ 799,
+ 766,
+ 733,
+ 700,
+ 699,
+ 650,
+ 600,
+ 599,
+ 500,
+ 499,
+ 400,
+ 399,
+ 300,
+ 299,
+ 200,
+ 199,
+ 100,
+ 99,
+ 0,
+]
+
+super100_timesteps = [
+ 999,
+ 996,
+ 992,
+ 989,
+ 985,
+ 982,
+ 979,
+ 975,
+ 972,
+ 968,
+ 965,
+ 961,
+ 958,
+ 955,
+ 951,
+ 948,
+ 944,
+ 941,
+ 938,
+ 934,
+ 931,
+ 927,
+ 924,
+ 920,
+ 917,
+ 914,
+ 910,
+ 907,
+ 903,
+ 900,
+ 899,
+ 891,
+ 884,
+ 876,
+ 869,
+ 861,
+ 853,
+ 846,
+ 838,
+ 830,
+ 823,
+ 815,
+ 808,
+ 800,
+ 799,
+ 788,
+ 777,
+ 766,
+ 755,
+ 744,
+ 733,
+ 722,
+ 711,
+ 700,
+ 699,
+ 688,
+ 677,
+ 666,
+ 655,
+ 644,
+ 633,
+ 622,
+ 611,
+ 600,
+ 599,
+ 585,
+ 571,
+ 557,
+ 542,
+ 528,
+ 514,
+ 500,
+ 499,
+ 485,
+ 471,
+ 457,
+ 442,
+ 428,
+ 414,
+ 400,
+ 399,
+ 379,
+ 359,
+ 340,
+ 320,
+ 300,
+ 299,
+ 279,
+ 259,
+ 240,
+ 220,
+ 200,
+ 199,
+ 166,
+ 133,
+ 100,
+ 99,
+ 66,
+ 33,
+ 0,
+]
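+
+
+if __name__ == "__main__":
+    # Sanity check for the schedules above: the `timesteps` argument of the IF pipelines must be
+    # strictly descending (see the pipeline docstrings), e.g. `pipe(..., timesteps=fast27_timesteps)`;
+    # that call is an illustrative usage sketch, not something this module performs.
+    schedules = {
+        "fast27": fast27_timesteps,
+        "smart27": smart27_timesteps,
+        "smart50": smart50_timesteps,
+        "smart100": smart100_timesteps,
+        "smart185": smart185_timesteps,
+        "super27": super27_timesteps,
+        "super40": super40_timesteps,
+        "super100": super100_timesteps,
+    }
+    for name, schedule in schedules.items():
+        assert all(a > b for a, b in zip(schedule, schedule[1:])), f"{name} is not strictly descending"
+        print(f"{name}: {len(schedule)} steps, {schedule[0]} -> {schedule[-1]}")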
diff --git a/diffusers/pipelines/deepfloyd_if/watermark.py b/diffusers/pipelines/deepfloyd_if/watermark.py
new file mode 100644
index 0000000000000000000000000000000000000000..db33dec0ef9ad5909e79358e9d89bdc0ed9c9909
--- /dev/null
+++ b/diffusers/pipelines/deepfloyd_if/watermark.py
@@ -0,0 +1,46 @@
+from typing import List
+
+import PIL
+import torch
+from PIL import Image
+
+from ...configuration_utils import ConfigMixin
+from ...models.modeling_utils import ModelMixin
+from ...utils import PIL_INTERPOLATION
+
+
+class IFWatermarker(ModelMixin, ConfigMixin):
+ def __init__(self):
+ super().__init__()
+
+ self.register_buffer("watermark_image", torch.zeros((62, 62, 4)))
+ self.watermark_image_as_pil = None
+
+ def apply_watermark(self, images: List[PIL.Image.Image], sample_size=None):
+ # copied from https://github.com/deep-floyd/IF/blob/b77482e36ca2031cb94dbca1001fc1e6400bf4ab/deepfloyd_if/modules/base.py#L287
+
+ h = images[0].height
+ w = images[0].width
+
+ sample_size = sample_size or h
+
+ coef = min(h / sample_size, w / sample_size)
+ img_h, img_w = (int(h / coef), int(w / coef)) if coef < 1 else (h, w)
+
+ S1, S2 = 1024**2, img_w * img_h
+ K = (S2 / S1) ** 0.5
+ wm_size, wm_x, wm_y = int(K * 62), img_w - int(14 * K), img_h - int(14 * K)
+
+ if self.watermark_image_as_pil is None:
+ watermark_image = self.watermark_image.to(torch.uint8).cpu().numpy()
+ watermark_image = Image.fromarray(watermark_image, mode="RGBA")
+ self.watermark_image_as_pil = watermark_image
+
+ wm_img = self.watermark_image_as_pil.resize(
+ (wm_size, wm_size), PIL_INTERPOLATION["bicubic"], reducing_gap=None
+ )
+
+ for pil_img in images:
+ pil_img.paste(wm_img, box=(wm_x - wm_size, wm_y - wm_size, wm_x, wm_y), mask=wm_img.split()[-1])
+
+ return images
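+
+
+if __name__ == "__main__":
+    # Worked example (illustrative only) of the placement arithmetic above for a 256x256 stage-II
+    # output with the default sample_size: K scales the 62x62 watermark by the square root of the
+    # area ratio to a 1024x1024 reference, and the paste box sits int(14 * K) pixels in from the
+    # bottom-right corner.
+    img_h = img_w = 256
+    S1, S2 = 1024**2, img_w * img_h
+    K = (S2 / S1) ** 0.5                                    # 0.25
+    wm_size = int(K * 62)                                   # 15
+    wm_x, wm_y = img_w - int(14 * K), img_h - int(14 * K)   # 253, 253
+    print(wm_size, (wm_x - wm_size, wm_y - wm_size, wm_x, wm_y))  # 15 (238, 238, 253, 253)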
diff --git a/diffusers/pipelines/dit/__init__.py b/diffusers/pipelines/dit/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ef0729cb4905d5e177ba15533375fce50084406
--- /dev/null
+++ b/diffusers/pipelines/dit/__init__.py
@@ -0,0 +1 @@
+from .pipeline_dit import DiTPipeline
diff --git a/diffusers/pipelines/dit/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/dit/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..32dd16165b7eaa736a9c50f18b4662a9537b4b4b
Binary files /dev/null and b/diffusers/pipelines/dit/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/pipelines/dit/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/dit/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..19eaacc0386c7f7336f2fa59cf5a83c099094217
Binary files /dev/null and b/diffusers/pipelines/dit/__pycache__/__init__.cpython-38.pyc differ
diff --git a/diffusers/pipelines/dit/__pycache__/pipeline_dit.cpython-310.pyc b/diffusers/pipelines/dit/__pycache__/pipeline_dit.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..886af5f803953a880efeb8d02562b834e8a00fa5
Binary files /dev/null and b/diffusers/pipelines/dit/__pycache__/pipeline_dit.cpython-310.pyc differ
diff --git a/diffusers/pipelines/dit/__pycache__/pipeline_dit.cpython-38.pyc b/diffusers/pipelines/dit/__pycache__/pipeline_dit.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b0906480834d5f9ccc32ddaa0fca641a56dab659
Binary files /dev/null and b/diffusers/pipelines/dit/__pycache__/pipeline_dit.cpython-38.pyc differ
diff --git a/diffusers/pipelines/dit/pipeline_dit.py b/diffusers/pipelines/dit/pipeline_dit.py
new file mode 100644
index 0000000000000000000000000000000000000000..07fd2835ccf0f2f6d485b45b398dd25049c75228
--- /dev/null
+++ b/diffusers/pipelines/dit/pipeline_dit.py
@@ -0,0 +1,199 @@
+# Attribution-NonCommercial 4.0 International (CC BY-NC 4.0)
+# William Peebles and Saining Xie
+#
+# Copyright (c) 2021 OpenAI
+# MIT License
+#
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+
+from ...models import AutoencoderKL, Transformer2DModel
+from ...schedulers import KarrasDiffusionSchedulers
+from ...utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+
+
+class DiTPipeline(DiffusionPipeline):
+ r"""
+ This pipeline inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Parameters:
+ transformer ([`Transformer2DModel`]):
+            A class-conditioned `Transformer2DModel` (DiT) to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ scheduler ([`DDIMScheduler`]):
+            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ """
+
+ def __init__(
+ self,
+ transformer: Transformer2DModel,
+ vae: AutoencoderKL,
+ scheduler: KarrasDiffusionSchedulers,
+ id2label: Optional[Dict[int, str]] = None,
+ ):
+ super().__init__()
+ self.register_modules(transformer=transformer, vae=vae, scheduler=scheduler)
+
+        # create an ImageNet label -> id dictionary for easier use
+ self.labels = {}
+ if id2label is not None:
+ for key, value in id2label.items():
+ for label in value.split(","):
+ self.labels[label.lstrip().rstrip()] = int(key)
+ self.labels = dict(sorted(self.labels.items()))
+
+ def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
+ r"""
+
+ Map label strings, *e.g.* from ImageNet, to corresponding class ids.
+
+ Parameters:
+            label (`str` or `List[str]`): label strings to be mapped to class ids.
+
+ Returns:
+ `list` of `int`: Class ids to be processed by pipeline.
+ """
+
+ if not isinstance(label, list):
+            label = [label]
+
+ for l in label:
+ if l not in self.labels:
+ raise ValueError(
+ f"{l} does not exist. Please make sure to select one of the following labels: \n {self.labels}."
+ )
+
+ return [self.labels[l] for l in label]
+
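+    # Illustrative usage (assumes the pipeline was loaded from a checkpoint whose config carries
+    # the standard ImageNet-1k id2label mapping; the concrete ids are an assumption for the example):
+    #     >>> pipe.get_label_ids(["white shark", "golden retriever"])
+    #     [2, 207]
+    # Unknown label strings raise a ValueError listing the available labels.
+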
+ @torch.no_grad()
+ def __call__(
+ self,
+ class_labels: List[int],
+ guidance_scale: float = 4.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ num_inference_steps: int = 50,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ ) -> Union[ImagePipelineOutput, Tuple]:
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ class_labels (List[int]):
+                List of ImageNet class labels for the images to be generated.
+ guidance_scale (`float`, *optional*, defaults to 4.0):
+ Scale of the guidance signal.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`ImagePipelineOutput`] instead of a plain tuple.
+ """
+
+ batch_size = len(class_labels)
+ latent_size = self.transformer.config.sample_size
+ latent_channels = self.transformer.config.in_channels
+
+ latents = randn_tensor(
+ shape=(batch_size, latent_channels, latent_size, latent_size),
+ generator=generator,
+ device=self._execution_device,
+ dtype=self.transformer.dtype,
+ )
+ latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1 else latents
+
+ class_labels = torch.tensor(class_labels, device=self._execution_device).reshape(-1)
+ class_null = torch.tensor([1000] * batch_size, device=self._execution_device)
+ class_labels_input = torch.cat([class_labels, class_null], 0) if guidance_scale > 1 else class_labels
+
+ # set step values
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ for t in self.progress_bar(self.scheduler.timesteps):
+ if guidance_scale > 1:
+ half = latent_model_input[: len(latent_model_input) // 2]
+ latent_model_input = torch.cat([half, half], dim=0)
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ timesteps = t
+ if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = latent_model_input.device.type == "mps"
+ if isinstance(timesteps, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=latent_model_input.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(latent_model_input.device)
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps.expand(latent_model_input.shape[0])
+ # predict noise model_output
+ noise_pred = self.transformer(
+ latent_model_input, timestep=timesteps, class_labels=class_labels_input
+ ).sample
+
+ # perform guidance
+ if guidance_scale > 1:
+ eps, rest = noise_pred[:, :latent_channels], noise_pred[:, latent_channels:]
+ cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
+
+ half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
+ eps = torch.cat([half_eps, half_eps], dim=0)
+
+ noise_pred = torch.cat([eps, rest], dim=1)
+
+ # learned sigma
+ if self.transformer.config.out_channels // 2 == latent_channels:
+ model_output, _ = torch.split(noise_pred, latent_channels, dim=1)
+ else:
+ model_output = noise_pred
+
+ # compute previous image: x_t -> x_t-1
+ latent_model_input = self.scheduler.step(model_output, t, latent_model_input).prev_sample
+
+ if guidance_scale > 1:
+ latents, _ = latent_model_input.chunk(2, dim=0)
+ else:
+ latents = latent_model_input
+
+ latents = 1 / self.vae.config.scaling_factor * latents
+ samples = self.vae.decode(latents).sample
+
+ samples = (samples / 2 + 0.5).clamp(0, 1)
+
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ samples = samples.cpu().permute(0, 2, 3, 1).float().numpy()
+
+ if output_type == "pil":
+ samples = self.numpy_to_pil(samples)
+
+ if not return_dict:
+ return (samples,)
+
+ return ImagePipelineOutput(images=samples)
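+
+
+if __name__ == "__main__":
+    # Illustrative usage sketch; the checkpoint id, device, seed and step count are assumptions
+    # for the example, not requirements of the pipeline. `get_label_ids` maps human-readable
+    # ImageNet labels to class ids, and `__call__` guides generation against the null class
+    # (id 1000) whenever `guidance_scale > 1`, as implemented above.
+    pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", torch_dtype=torch.float16)
+    pipe.to("cuda")
+
+    class_ids = pipe.get_label_ids(["white shark", "golden retriever"])
+    generator = torch.manual_seed(33)
+    images = pipe(class_labels=class_ids, generator=generator, num_inference_steps=25).images
+    images[0].save("dit_sample.png")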
diff --git a/diffusers/pipelines/kandinsky/__init__.py b/diffusers/pipelines/kandinsky/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..242ff799e529abbb268b3562a9671db42d9de37e
--- /dev/null
+++ b/diffusers/pipelines/kandinsky/__init__.py
@@ -0,0 +1,19 @@
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ is_torch_available,
+ is_transformers_available,
+ is_transformers_version,
+)
+
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import KandinskyPipeline, KandinskyPriorPipeline
+else:
+ from .pipeline_kandinsky import KandinskyPipeline
+ from .pipeline_kandinsky_img2img import KandinskyImg2ImgPipeline
+ from .pipeline_kandinsky_inpaint import KandinskyInpaintPipeline
+ from .pipeline_kandinsky_prior import KandinskyPriorPipeline, KandinskyPriorPipelineOutput
+ from .text_encoder import MultilingualCLIP
diff --git a/diffusers/pipelines/kandinsky/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/kandinsky/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7a8cc27936214128b0d2d8d9f9a5921c83b2fe6a
Binary files /dev/null and b/diffusers/pipelines/kandinsky/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/pipelines/kandinsky/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/kandinsky/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..002c546c28e7144a22e81d1eaa675ae046357498
Binary files /dev/null and b/diffusers/pipelines/kandinsky/__pycache__/__init__.cpython-38.pyc differ
diff --git a/diffusers/pipelines/kandinsky/__pycache__/pipeline_kandinsky.cpython-310.pyc b/diffusers/pipelines/kandinsky/__pycache__/pipeline_kandinsky.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..09a72cfa0543e44fbd5a4d86cf16504df16ee8bc
Binary files /dev/null and b/diffusers/pipelines/kandinsky/__pycache__/pipeline_kandinsky.cpython-310.pyc differ
diff --git a/diffusers/pipelines/kandinsky/__pycache__/pipeline_kandinsky.cpython-38.pyc b/diffusers/pipelines/kandinsky/__pycache__/pipeline_kandinsky.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6e3b2501783290c5c6f97761f8ad59b1e6bbfb5e
Binary files /dev/null and b/diffusers/pipelines/kandinsky/__pycache__/pipeline_kandinsky.cpython-38.pyc differ
diff --git a/diffusers/pipelines/kandinsky/__pycache__/pipeline_kandinsky_img2img.cpython-310.pyc b/diffusers/pipelines/kandinsky/__pycache__/pipeline_kandinsky_img2img.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..aa36860a2e14dec3318d43da49259d1c78e97fd3
Binary files /dev/null and b/diffusers/pipelines/kandinsky/__pycache__/pipeline_kandinsky_img2img.cpython-310.pyc differ
diff --git a/diffusers/pipelines/kandinsky/__pycache__/pipeline_kandinsky_img2img.cpython-38.pyc b/diffusers/pipelines/kandinsky/__pycache__/pipeline_kandinsky_img2img.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7e938f6793940f0e60bce0455cfc83d9c2a86337
Binary files /dev/null and b/diffusers/pipelines/kandinsky/__pycache__/pipeline_kandinsky_img2img.cpython-38.pyc differ
diff --git a/diffusers/pipelines/kandinsky/__pycache__/pipeline_kandinsky_inpaint.cpython-310.pyc b/diffusers/pipelines/kandinsky/__pycache__/pipeline_kandinsky_inpaint.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..55005faafefbef98e84fe66a19dfa821f5994da9
Binary files /dev/null and b/diffusers/pipelines/kandinsky/__pycache__/pipeline_kandinsky_inpaint.cpython-310.pyc differ
diff --git a/diffusers/pipelines/kandinsky/__pycache__/pipeline_kandinsky_inpaint.cpython-38.pyc b/diffusers/pipelines/kandinsky/__pycache__/pipeline_kandinsky_inpaint.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5a24732a35092fb398e6b4a081000fc9c5f338e6
Binary files /dev/null and b/diffusers/pipelines/kandinsky/__pycache__/pipeline_kandinsky_inpaint.cpython-38.pyc differ
diff --git a/diffusers/pipelines/kandinsky/__pycache__/pipeline_kandinsky_prior.cpython-310.pyc b/diffusers/pipelines/kandinsky/__pycache__/pipeline_kandinsky_prior.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2e43c43de80edbbca5374f14a68dcdc0de74e0bf
Binary files /dev/null and b/diffusers/pipelines/kandinsky/__pycache__/pipeline_kandinsky_prior.cpython-310.pyc differ
diff --git a/diffusers/pipelines/kandinsky/__pycache__/pipeline_kandinsky_prior.cpython-38.pyc b/diffusers/pipelines/kandinsky/__pycache__/pipeline_kandinsky_prior.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bb762bca6343c0b550e5643a6bda4d2dbcc37fcc
Binary files /dev/null and b/diffusers/pipelines/kandinsky/__pycache__/pipeline_kandinsky_prior.cpython-38.pyc differ
diff --git a/diffusers/pipelines/kandinsky/__pycache__/text_encoder.cpython-310.pyc b/diffusers/pipelines/kandinsky/__pycache__/text_encoder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ff5286fe1e4b74bb3a1623a4ee199c158bd692fe
Binary files /dev/null and b/diffusers/pipelines/kandinsky/__pycache__/text_encoder.cpython-310.pyc differ
diff --git a/diffusers/pipelines/kandinsky/__pycache__/text_encoder.cpython-38.pyc b/diffusers/pipelines/kandinsky/__pycache__/text_encoder.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2bcae69dab418b390b94f5e5ddd72890506a6155
Binary files /dev/null and b/diffusers/pipelines/kandinsky/__pycache__/text_encoder.cpython-38.pyc differ
diff --git a/diffusers/pipelines/kandinsky/pipeline_kandinsky.py b/diffusers/pipelines/kandinsky/pipeline_kandinsky.py
new file mode 100644
index 0000000000000000000000000000000000000000..31edb915e8f76df6910a4a336286cc0b93df8722
--- /dev/null
+++ b/diffusers/pipelines/kandinsky/pipeline_kandinsky.py
@@ -0,0 +1,418 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Union
+
+import torch
+from transformers import (
+ XLMRobertaTokenizer,
+)
+
+from ...models import UNet2DConditionModel, VQModel
+from ...pipelines import DiffusionPipeline
+from ...pipelines.pipeline_utils import ImagePipelineOutput
+from ...schedulers import DDIMScheduler, DDPMScheduler
+from ...utils import (
+ is_accelerate_available,
+ is_accelerate_version,
+ logging,
+ randn_tensor,
+ replace_example_docstring,
+)
+from .text_encoder import MultilingualCLIP
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> from diffusers import KandinskyPipeline, KandinskyPriorPipeline
+ >>> import torch
+
+ >>> pipe_prior = KandinskyPriorPipeline.from_pretrained("kandinsky-community/Kandinsky-2-1-prior")
+ >>> pipe_prior.to("cuda")
+
+ >>> prompt = "red cat, 4k photo"
+ >>> out = pipe_prior(prompt)
+ >>> image_emb = out.image_embeds
+ >>> negative_image_emb = out.negative_image_embeds
+
+ >>> pipe = KandinskyPipeline.from_pretrained("kandinsky-community/kandinsky-2-1")
+ >>> pipe.to("cuda")
+
+ >>> image = pipe(
+ ... prompt,
+ ... image_embeds=image_emb,
+ ... negative_image_embeds=negative_image_emb,
+ ... height=768,
+ ... width=768,
+ ... num_inference_steps=100,
+ ... ).images
+
+ >>> image[0].save("cat.png")
+ ```
+"""
+
+
+def get_new_h_w(h, w, scale_factor=8):
+ new_h = h // scale_factor**2
+ if h % scale_factor**2 != 0:
+ new_h += 1
+ new_w = w // scale_factor**2
+ if w % scale_factor**2 != 0:
+ new_w += 1
+ return new_h * scale_factor, new_w * scale_factor
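+
+# Worked example (illustrative) of the rounding above: the unit is scale_factor**2 = 64 pixels,
+# and the returned values are latent sizes whose decoded resolution is the next multiple of 64.
+#     >>> get_new_h_w(768, 768)   # already a multiple of 64 -> 96x96 latents (768x768 pixels)
+#     (96, 96)
+#     >>> get_new_h_w(500, 500)   # rounded up -> 64x64 latents (512x512 pixels after decoding)
+#     (64, 64)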
+
+
+class KandinskyPipeline(DiffusionPipeline):
+ """
+ Pipeline for text-to-image generation using Kandinsky
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ text_encoder ([`MultilingualCLIP`]):
+ Frozen text-encoder.
+ tokenizer ([`XLMRobertaTokenizer`]):
+            Tokenizer used to tokenize the prompt.
+ scheduler (Union[`DDIMScheduler`,`DDPMScheduler`]):
+ A scheduler to be used in combination with `unet` to generate image latents.
+ unet ([`UNet2DConditionModel`]):
+ Conditional U-Net architecture to denoise the image embedding.
+ movq ([`VQModel`]):
+ MoVQ Decoder to generate the image from the latents.
+ """
+
+ def __init__(
+ self,
+ text_encoder: MultilingualCLIP,
+ tokenizer: XLMRobertaTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[DDIMScheduler, DDPMScheduler],
+ movq: VQModel,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ movq=movq,
+ )
+ self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)
+
+ # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
+ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ if latents.shape != shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+ latents = latents.to(device)
+
+ latents = latents * scheduler.init_noise_sigma
+ return latents
+
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ ):
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
+ # get prompt text embeddings
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ truncation=True,
+ max_length=77,
+ return_attention_mask=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ text_input_ids = text_input_ids.to(device)
+ text_mask = text_inputs.attention_mask.to(device)
+
+ prompt_embeds, text_encoder_hidden_states = self.text_encoder(
+ input_ids=text_input_ids, attention_mask=text_mask
+ )
+
+ prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
+
+ if do_classifier_free_guidance:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=77,
+ truncation=True,
+ return_attention_mask=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ uncond_text_input_ids = uncond_input.input_ids.to(device)
+ uncond_text_mask = uncond_input.attention_mask.to(device)
+
+ negative_prompt_embeds, uncond_text_encoder_hidden_states = self.text_encoder(
+ input_ids=uncond_text_input_ids, attention_mask=uncond_text_mask
+ )
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+
+ seq_len = negative_prompt_embeds.shape[1]
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)
+
+ seq_len = uncond_text_encoder_hidden_states.shape[1]
+ uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
+ uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
+ batch_size * num_images_per_prompt, seq_len, -1
+ )
+ uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)
+
+ # done duplicates
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+ text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])
+
+ text_mask = torch.cat([uncond_text_mask, text_mask])
+
+ return prompt_embeds, text_encoder_hidden_states, text_mask
+
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+        method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ hook = None
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.movq]:
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
+ negative_image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ height: int = 512,
+ width: int = 512,
+ num_inference_steps: int = 100,
+ guidance_scale: float = 4.0,
+ num_images_per_prompt: int = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ ):
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+            image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`):
+                The CLIP image embeddings for the text prompt, which will be used to condition the image generation.
+            negative_image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`):
+                The CLIP image embeddings for the negative text prompt, which will be used to condition the image
+                generation.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ height (`int`, *optional*, defaults to 512):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to 512):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 100):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 4.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages the model to generate images that are closely linked to the text
+                `prompt`, usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
+ (`np.array`) or `"pt"` (`torch.Tensor`).
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.ImagePipelineOutput`] or `tuple`
+ """
+
+ if isinstance(prompt, str):
+ batch_size = 1
+ elif isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ device = self._execution_device
+
+ batch_size = batch_size * num_images_per_prompt
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ prompt_embeds, text_encoder_hidden_states, _ = self._encode_prompt(
+ prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+ )
+
+ if isinstance(image_embeds, list):
+ image_embeds = torch.cat(image_embeds, dim=0)
+ if isinstance(negative_image_embeds, list):
+ negative_image_embeds = torch.cat(negative_image_embeds, dim=0)
+
+ if do_classifier_free_guidance:
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+
+ image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to(
+ dtype=prompt_embeds.dtype, device=device
+ )
+
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps_tensor = self.scheduler.timesteps
+
+ num_channels_latents = self.unet.config.in_channels
+
+ height, width = get_new_h_w(height, width, self.movq_scale_factor)
+
+ # create initial latent
+ latents = self.prepare_latents(
+ (batch_size, num_channels_latents, height, width),
+ text_encoder_hidden_states.dtype,
+ device,
+ generator,
+ latents,
+ self.scheduler,
+ )
+
+ for i, t in enumerate(self.progress_bar(timesteps_tensor)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+ added_cond_kwargs = {"text_embeds": prompt_embeds, "image_embeds": image_embeds}
+ noise_pred = self.unet(
+ sample=latent_model_input,
+ timestep=t,
+ encoder_hidden_states=text_encoder_hidden_states,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ if do_classifier_free_guidance:
+ noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1)
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ _, variance_pred_text = variance_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+ noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1)
+
+ if not (
+ hasattr(self.scheduler.config, "variance_type")
+ and self.scheduler.config.variance_type in ["learned", "learned_range"]
+ ):
+ noise_pred, _ = noise_pred.split(latents.shape[1], dim=1)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(
+ noise_pred,
+ t,
+ latents,
+ generator=generator,
+ ).prev_sample
+ # post-processing
+ image = self.movq.decode(latents, force_not_quantize=True)["sample"]
+
+ if output_type not in ["pt", "np", "pil"]:
+            raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported, not output_type={output_type}")
+
+ if output_type in ["np", "pil"]:
+ image = image * 0.5 + 0.5
+ image = image.clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
diff --git a/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py b/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3e4b04febacc5c751f59c37f6102268f497166d
--- /dev/null
+++ b/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py
@@ -0,0 +1,504 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Union
+
+import numpy as np
+import PIL
+import torch
+from PIL import Image
+from transformers import (
+ XLMRobertaTokenizer,
+)
+
+from ...models import UNet2DConditionModel, VQModel
+from ...pipelines import DiffusionPipeline
+from ...pipelines.pipeline_utils import ImagePipelineOutput
+from ...schedulers import DDIMScheduler
+from ...utils import (
+ is_accelerate_available,
+ is_accelerate_version,
+ logging,
+ randn_tensor,
+ replace_example_docstring,
+)
+from .text_encoder import MultilingualCLIP
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> from diffusers import KandinskyImg2ImgPipeline, KandinskyPriorPipeline
+ >>> from diffusers.utils import load_image
+ >>> import torch
+
+ >>> pipe_prior = KandinskyPriorPipeline.from_pretrained(
+ ... "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16
+ ... )
+ >>> pipe_prior.to("cuda")
+
+ >>> prompt = "A red cartoon frog, 4k"
+ >>> image_emb, zero_image_emb = pipe_prior(prompt, return_dict=False)
+
+ >>> pipe = KandinskyImg2ImgPipeline.from_pretrained(
+ ... "kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16
+ ... )
+ >>> pipe.to("cuda")
+
+ >>> init_image = load_image(
+ ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+ ... "/kandinsky/frog.png"
+ ... )
+
+ >>> image = pipe(
+ ... prompt,
+ ... image=init_image,
+ ... image_embeds=image_emb,
+ ... negative_image_embeds=zero_image_emb,
+ ... height=768,
+ ... width=768,
+ ... num_inference_steps=100,
+ ... strength=0.2,
+ ... ).images
+
+ >>> image[0].save("red_frog.png")
+ ```
+"""
+
+
+def get_new_h_w(h, w, scale_factor=8):
+ new_h = h // scale_factor**2
+ if h % scale_factor**2 != 0:
+ new_h += 1
+ new_w = w // scale_factor**2
+ if w % scale_factor**2 != 0:
+ new_w += 1
+ return new_h * scale_factor, new_w * scale_factor
+
+
+def prepare_image(pil_image, w=512, h=512):
+ pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
+ arr = np.array(pil_image.convert("RGB"))
+ arr = arr.astype(np.float32) / 127.5 - 1
+ arr = np.transpose(arr, [2, 0, 1])
+ image = torch.from_numpy(arr).unsqueeze(0)
+ return image
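+
+# Illustrative shape/range check for the helper above: the returned tensor is NCHW, resized to
+# (w, h) and rescaled from [0, 255] to [-1, 1].
+#     >>> prepare_image(Image.new("RGB", (640, 480)), w=512, h=512).shape
+#     torch.Size([1, 3, 512, 512])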
+
+
+class KandinskyImg2ImgPipeline(DiffusionPipeline):
+ """
+ Pipeline for image-to-image generation using Kandinsky
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ text_encoder ([`MultilingualCLIP`]):
+ Frozen text-encoder.
+ tokenizer ([`XLMRobertaTokenizer`]):
+            Tokenizer used to tokenize the prompt.
+ scheduler ([`DDIMScheduler`]):
+ A scheduler to be used in combination with `unet` to generate image latents.
+ unet ([`UNet2DConditionModel`]):
+ Conditional U-Net architecture to denoise the image embedding.
+ movq ([`VQModel`]):
+ MoVQ image encoder and decoder
+ """
+
+ def __init__(
+ self,
+ text_encoder: MultilingualCLIP,
+ movq: VQModel,
+ tokenizer: XLMRobertaTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: DDIMScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ movq=movq,
+ )
+ self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)
+
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep, 0)
+ timesteps = self.scheduler.timesteps[t_start:]
+
+ return timesteps, num_inference_steps - t_start
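+
+    # Illustrative example of the slicing above: with num_inference_steps=100 and the default
+    # strength=0.3, init_timestep = 30 and t_start = 70, so only the last 30 entries of the
+    # scheduler's timesteps are used for denoising.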
+
+ def prepare_latents(self, latents, latent_timestep, shape, dtype, device, generator, scheduler):
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ if latents.shape != shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+ latents = latents.to(device)
+
+ latents = latents * scheduler.init_noise_sigma
+
+ shape = latents.shape
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+ latents = self.add_noise(latents, noise, latent_timestep)
+ return latents
+
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ ):
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
+ # get prompt text embeddings
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=77,
+ truncation=True,
+ return_attention_mask=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ text_input_ids = text_input_ids.to(device)
+ text_mask = text_inputs.attention_mask.to(device)
+
+ prompt_embeds, text_encoder_hidden_states = self.text_encoder(
+ input_ids=text_input_ids, attention_mask=text_mask
+ )
+
+ prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
+
+ if do_classifier_free_guidance:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=77,
+ truncation=True,
+ return_attention_mask=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ uncond_text_input_ids = uncond_input.input_ids.to(device)
+ uncond_text_mask = uncond_input.attention_mask.to(device)
+
+ negative_prompt_embeds, uncond_text_encoder_hidden_states = self.text_encoder(
+ input_ids=uncond_text_input_ids, attention_mask=uncond_text_mask
+ )
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+
+ seq_len = negative_prompt_embeds.shape[1]
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)
+
+ seq_len = uncond_text_encoder_hidden_states.shape[1]
+ uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
+ uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
+ batch_size * num_images_per_prompt, seq_len, -1
+ )
+ uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)
+
+ # done duplicates
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+ text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])
+
+ text_mask = torch.cat([uncond_text_mask, text_mask])
+
+ return prompt_embeds, text_encoder_hidden_states, text_mask
+
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+        method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ hook = None
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.movq]:
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
+ # add_noise method to overwrite the one in the scheduler, because it uses a different beta schedule for adding noise than for sampling
+ def add_noise(
+ self,
+ original_samples: torch.FloatTensor,
+ noise: torch.FloatTensor,
+ timesteps: torch.IntTensor,
+ ) -> torch.FloatTensor:
+ betas = torch.linspace(0.0001, 0.02, 1000, dtype=torch.float32)
+ alphas = 1.0 - betas
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
+ alphas_cumprod = alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+ timesteps = timesteps.to(original_samples.device)
+
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+
+ return noisy_samples
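+ # A minimal illustration of what `add_noise` computes (comment-only sketch; the tensor
+ # names below are hypothetical): with the linear DDPM betas defined above, the noised
+ # sample is x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise, e.g.
+ #   x0 = torch.randn(1, 4, 96, 96)       # e.g. a 4-channel movq latent
+ #   eps = torch.randn_like(x0)           # Gaussian noise
+ #   t = torch.tensor([298])              # index into the 1000-step schedule
+ #   x_t = pipe.add_noise(x0, eps, t)     # same shape as x0; `pipe` is a hypothetical instance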
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]],
+ image_embeds: torch.FloatTensor,
+ negative_image_embeds: torch.FloatTensor,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ height: int = 512,
+ width: int = 512,
+ num_inference_steps: int = 100,
+ strength: float = 0.3,
+ guidance_scale: float = 7.0,
+ num_images_per_prompt: int = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ ):
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ image (`torch.FloatTensor`, `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
+ process.
+ image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`):
+ The CLIP image embeddings for the text prompt, which will be used to condition the image generation.
+ negative_image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`):
+ The CLIP image embeddings for the negative text prompt, which will be used to condition the image generation.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ height (`int`, *optional*, defaults to 512):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to 512):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 100):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ strength (`float`, *optional*, defaults to 0.3):
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
+ will be used as a starting point, adding more noise to it the larger the `strength`. The number of
+ denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
+ be maximum and the denoising process will run for the full number of iterations specified in
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
+ guidance_scale (`float`, *optional*, defaults to 7.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
+ (`np.array`) or `"pt"` (`torch.Tensor`).
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.ImagePipelineOutput`] or `tuple`
+ """
+ # 1. Define call parameters
+ if isinstance(prompt, str):
+ batch_size = 1
+ elif isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ device = self._execution_device
+
+ batch_size = batch_size * num_images_per_prompt
+
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 2. get text and image embeddings
+ prompt_embeds, text_encoder_hidden_states, _ = self._encode_prompt(
+ prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+ )
+
+ if isinstance(image_embeds, list):
+ image_embeds = torch.cat(image_embeds, dim=0)
+ if isinstance(negative_image_embeds, list):
+ negative_image_embeds = torch.cat(negative_image_embeds, dim=0)
+
+ if do_classifier_free_guidance:
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+
+ image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to(
+ dtype=prompt_embeds.dtype, device=device
+ )
+
+ # 3. pre-processing initial image
+ if not isinstance(image, list):
+ image = [image]
+ if not all(isinstance(i, (PIL.Image.Image, torch.Tensor)) for i in image):
+ raise ValueError(
+ f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor"
+ )
+
+ image = torch.cat([prepare_image(i, width, height) for i in image], dim=0)
+ image = image.to(dtype=prompt_embeds.dtype, device=device)
+
+ latents = self.movq.encode(image)["latents"]
+ latents = latents.repeat_interleave(num_images_per_prompt, dim=0)
+
+ # 4. set timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+
+ timesteps_tensor, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+
+ # the formula to calculate the timestep for add_noise is taken from the original Kandinsky repo
+ latent_timestep = int(self.scheduler.config.num_train_timesteps * strength) - 2
+
+ latent_timestep = torch.tensor([latent_timestep] * batch_size, dtype=timesteps_tensor.dtype, device=device)
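+ # e.g. with the default 1000 training timesteps and strength=0.3 this evaluates to
+ # int(1000 * 0.3) - 2 = 298, i.e. the initial latents are noised to roughly 30% of
+ # the diffusion process before denoising starts (assuming the default scheduler
+ # config; other configs change `num_train_timesteps`).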
+
+ num_channels_latents = self.unet.config.in_channels
+
+ height, width = get_new_h_w(height, width, self.movq_scale_factor)
+
+ # 5. Create initial latent
+ latents = self.prepare_latents(
+ latents,
+ latent_timestep,
+ (batch_size, num_channels_latents, height, width),
+ text_encoder_hidden_states.dtype,
+ device,
+ generator,
+ self.scheduler,
+ )
+
+ # 6. Denoising loop
+ for i, t in enumerate(self.progress_bar(timesteps_tensor)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+ added_cond_kwargs = {"text_embeds": prompt_embeds, "image_embeds": image_embeds}
+ noise_pred = self.unet(
+ sample=latent_model_input,
+ timestep=t,
+ encoder_hidden_states=text_encoder_hidden_states,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ if do_classifier_free_guidance:
+ noise_pred, _ = noise_pred.split(latents.shape[1], dim=1)
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
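+ # Illustration of the classifier-free guidance combination above: the final
+ # prediction starts from the unconditional branch and moves along the direction
+ # (text - uncond) scaled by `guidance_scale`, so with guidance_scale=7.0 the
+ # text-conditioned direction is weighted seven times more strongly.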
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(
+ noise_pred,
+ t,
+ latents,
+ generator=generator,
+ ).prev_sample
+
+ # 7. post-processing
+ image = self.movq.decode(latents, force_not_quantize=True)["sample"]
+
+ if output_type not in ["pt", "np", "pil"]:
+ raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")
+
+ if output_type in ["np", "pil"]:
+ image = image * 0.5 + 0.5
+ image = image.clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
diff --git a/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py b/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4a7dc225947d5f70ad05c13886c79c4b34ee70b
--- /dev/null
+++ b/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py
@@ -0,0 +1,630 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from copy import deepcopy
+from typing import List, Optional, Union
+
+import numpy as np
+import PIL
+import torch
+import torch.nn.functional as F
+from PIL import Image
+from transformers import (
+ XLMRobertaTokenizer,
+)
+
+from ...models import UNet2DConditionModel, VQModel
+from ...pipelines import DiffusionPipeline
+from ...pipelines.pipeline_utils import ImagePipelineOutput
+from ...schedulers import DDIMScheduler
+from ...utils import (
+ is_accelerate_available,
+ is_accelerate_version,
+ logging,
+ randn_tensor,
+ replace_example_docstring,
+)
+from .text_encoder import MultilingualCLIP
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> from diffusers import KandinskyInpaintPipeline, KandinskyPriorPipeline
+ >>> from diffusers.utils import load_image
+ >>> import torch
+ >>> import numpy as np
+
+ >>> pipe_prior = KandinskyPriorPipeline.from_pretrained(
+ ... "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16
+ ... )
+ >>> pipe_prior.to("cuda")
+
+ >>> prompt = "a hat"
+ >>> image_emb, zero_image_emb = pipe_prior(prompt, return_dict=False)
+
+ >>> pipe = KandinskyInpaintPipeline.from_pretrained(
+ ... "kandinsky-community/kandinsky-2-1-inpaint", torch_dtype=torch.float16
+ ... )
+ >>> pipe.to("cuda")
+
+ >>> init_image = load_image(
+ ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+ ... "/kandinsky/cat.png"
+ ... )
+
+ >>> mask = np.ones((768, 768), dtype=np.float32)
+ >>> mask[:250, 250:-250] = 0
+
+ >>> out = pipe(
+ ... prompt,
+ ... image=init_image,
+ ... mask_image=mask,
+ ... image_embeds=image_emb,
+ ... negative_image_embeds=zero_image_emb,
+ ... height=768,
+ ... width=768,
+ ... num_inference_steps=50,
+ ... )
+
+ >>> image = out.images[0]
+ >>> image.save("cat_with_hat.png")
+ ```
+"""
+
+
+def get_new_h_w(h, w, scale_factor=8):
+ new_h = h // scale_factor**2
+ if h % scale_factor**2 != 0:
+ new_h += 1
+ new_w = w // scale_factor**2
+ if w % scale_factor**2 != 0:
+ new_w += 1
+ return new_h * scale_factor, new_w * scale_factor
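+ # Worked example for the rounding above (illustration only): get_new_h_w(768, 768, 8)
+ # returns (96, 96), since 768 // 8**2 == 12 exactly and 12 * 8 == 96; for a height of
+ # 500, 500 // 64 == 7 with a remainder, so it is rounded up to 8 and the latent side
+ # becomes 8 * 8 == 64.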
+
+
+def prepare_mask(masks):
+ prepared_masks = []
+ for mask in masks:
+ old_mask = deepcopy(mask)
+ for i in range(mask.shape[1]):
+ for j in range(mask.shape[2]):
+ if old_mask[0][i][j] == 1:
+ continue
+ if i != 0:
+ mask[:, i - 1, j] = 0
+ if j != 0:
+ mask[:, i, j - 1] = 0
+ if i != 0 and j != 0:
+ mask[:, i - 1, j - 1] = 0
+ if i != mask.shape[1] - 1:
+ mask[:, i + 1, j] = 0
+ if j != mask.shape[2] - 1:
+ mask[:, i, j + 1] = 0
+ if i != mask.shape[1] - 1 and j != mask.shape[2] - 1:
+ mask[:, i + 1, j + 1] = 0
+ prepared_masks.append(mask)
+ return torch.stack(prepared_masks, dim=0)
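+ # Behaviour note for `prepare_mask` (illustration only): for every pixel whose old
+ # value is 0, the loop also zeroes six of its eight neighbours (up, left, up-left,
+ # down, right, down-right), growing the hole by one pixel. E.g. a 3x3 mask of ones
+ # with a single 0 in the centre becomes [[0, 0, 1], [0, 0, 0], [1, 0, 0]].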
+
+
+def prepare_mask_and_masked_image(image, mask, height, width):
+ r"""
+ Prepares a pair (mask, image) to be consumed by the Kandinsky inpaint pipeline. This means that those inputs will
+ be converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for
+ the ``image`` and ``1`` for the ``mask``.
+
+ The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
+ binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
+
+ Args:
+ image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
+ It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
+ ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
+ mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
+ It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
+ ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
+ height (`int`, *optional*, defaults to 512):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to 512):
+ The width in pixels of the generated image.
+
+
+ Raises:
+ ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
+ should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
+ TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
+ (or the other way around).
+
+ Returns:
+ tuple[torch.Tensor]: The pair (mask, image) as ``torch.Tensor`` with 4
+ dimensions: ``batch x channels x height x width``.
+ """
+
+ if image is None:
+ raise ValueError("`image` input cannot be undefined.")
+
+ if mask is None:
+ raise ValueError("`mask_image` input cannot be undefined.")
+
+ if isinstance(image, torch.Tensor):
+ if not isinstance(mask, torch.Tensor):
+ raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)}) is not")
+
+ # Batch single image
+ if image.ndim == 3:
+ assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
+ image = image.unsqueeze(0)
+
+ # Batch and add channel dim for single mask
+ if mask.ndim == 2:
+ mask = mask.unsqueeze(0).unsqueeze(0)
+
+ # Batch single mask or add channel dim
+ if mask.ndim == 3:
+ # Single batched mask, no channel dim or single mask not batched but channel dim
+ if mask.shape[0] == 1:
+ mask = mask.unsqueeze(0)
+
+ # Batched masks no channel dim
+ else:
+ mask = mask.unsqueeze(1)
+
+ assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
+ assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
+ assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
+
+ # Check image is in [-1, 1]
+ if image.min() < -1 or image.max() > 1:
+ raise ValueError("Image should be in [-1, 1] range")
+
+ # Check mask is in [0, 1]
+ if mask.min() < 0 or mask.max() > 1:
+ raise ValueError("Mask should be in [0, 1] range")
+
+ # Binarize mask
+ mask[mask < 0.5] = 0
+ mask[mask >= 0.5] = 1
+
+ # Image as float32
+ image = image.to(dtype=torch.float32)
+ elif isinstance(mask, torch.Tensor):
+ raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)}) is not")
+ else:
+ # preprocess image
+ if isinstance(image, (PIL.Image.Image, np.ndarray)):
+ image = [image]
+
+ if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
+ # resize all images w.r.t. the passed height and width
+ image = [i.resize((width, height), resample=Image.BICUBIC, reducing_gap=1) for i in image]
+ image = [np.array(i.convert("RGB"))[None, :] for i in image]
+ image = np.concatenate(image, axis=0)
+ elif isinstance(image, list) and isinstance(image[0], np.ndarray):
+ image = np.concatenate([i[None, :] for i in image], axis=0)
+
+ image = image.transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+
+ # preprocess mask
+ if isinstance(mask, (PIL.Image.Image, np.ndarray)):
+ mask = [mask]
+
+ if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
+ mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]
+ mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
+ mask = mask.astype(np.float32) / 255.0
+ elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
+ mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
+
+ mask[mask < 0.5] = 0
+ mask[mask >= 0.5] = 1
+ mask = torch.from_numpy(mask)
+
+ return mask, image
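+ # A minimal usage sketch of `prepare_mask_and_masked_image` (hypothetical file name,
+ # not executed by the pipeline):
+ #   init_image = PIL.Image.open("photo.png")           # any RGB image
+ #   mask = np.ones((512, 512), dtype=np.float32)       # same spatial size as target
+ #   mask_t, image_t = prepare_mask_and_masked_image(init_image, mask, 512, 512)
+ #   # image_t: float32 tensor of shape (1, 3, 512, 512), normalized to [-1, 1]
+ #   # mask_t:  float32 tensor of shape (1, 1, 512, 512), binarized to {0, 1}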
+
+
+class KandinskyInpaintPipeline(DiffusionPipeline):
+ """
+ Pipeline for text-guided image inpainting using Kandinsky2.1
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ text_encoder ([`MultilingualCLIP`]):
+ Frozen text-encoder.
+ tokenizer ([`XLMRobertaTokenizer`]):
+ Tokenizer of class
+ scheduler ([`DDIMScheduler`]):
+ A scheduler to be used in combination with `unet` to generate image latents.
+ unet ([`UNet2DConditionModel`]):
+ Conditional U-Net architecture to denoise the image embedding.
+ movq ([`VQModel`]):
+ MoVQ image encoder and decoder
+ """
+
+ def __init__(
+ self,
+ text_encoder: MultilingualCLIP,
+ movq: VQModel,
+ tokenizer: XLMRobertaTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: DDIMScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ text_encoder=text_encoder,
+ movq=movq,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ )
+ self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)
+
+ # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
+ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ if latents.shape != shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+ latents = latents.to(device)
+
+ latents = latents * scheduler.init_noise_sigma
+ return latents
+
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ ):
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
+ # get prompt text embeddings
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=77,
+ truncation=True,
+ return_attention_mask=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ text_input_ids = text_input_ids.to(device)
+ text_mask = text_inputs.attention_mask.to(device)
+
+ prompt_embeds, text_encoder_hidden_states = self.text_encoder(
+ input_ids=text_input_ids, attention_mask=text_mask
+ )
+
+ prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
+
+ if do_classifier_free_guidance:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=77,
+ truncation=True,
+ return_attention_mask=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ uncond_text_input_ids = uncond_input.input_ids.to(device)
+ uncond_text_mask = uncond_input.attention_mask.to(device)
+
+ negative_prompt_embeds, uncond_text_encoder_hidden_states = self.text_encoder(
+ input_ids=uncond_text_input_ids, attention_mask=uncond_text_mask
+ )
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+
+ seq_len = negative_prompt_embeds.shape[1]
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)
+
+ seq_len = uncond_text_encoder_hidden_states.shape[1]
+ uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
+ uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
+ batch_size * num_images_per_prompt, seq_len, -1
+ )
+ uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)
+
+ # done duplicates
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+ text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])
+
+ text_mask = torch.cat([uncond_text_mask, text_mask])
+
+ return prompt_embeds, text_encoder_hidden_states, text_mask
+
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ hook = None
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.movq]:
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ image: Union[torch.FloatTensor, PIL.Image.Image],
+ mask_image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
+ image_embeds: torch.FloatTensor,
+ negative_image_embeds: torch.FloatTensor,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ height: int = 512,
+ width: int = 512,
+ num_inference_steps: int = 100,
+ guidance_scale: float = 4.0,
+ num_images_per_prompt: int = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ ):
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ image (`torch.FloatTensor`, `PIL.Image.Image` or `np.ndarray`):
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
+ process.
+ mask_image (`PIL.Image.Image`,`torch.FloatTensor` or `np.ndarray`):
+ `Image`, or a tensor representing an image batch, to mask `image`. White pixels in the mask will be
+ repainted, while black pixels will be preserved. You can pass a pytorch tensor as mask only if the
+ image you passed is a pytorch tensor, and it should contain one color channel (L) instead of 3, so the
+ expected shape would be either `(B, 1, H, W)`, `(B, H, W)`, `(1, H, W)` or `(H, W)`. If image is a PIL
+ image or numpy array, mask should also be either a PIL image or a numpy array. If it is a PIL image, it
+ will be converted to a single channel (luminance) before use. If it is a numpy array, the expected
+ shape is `(H, W)`.
+ image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`):
+ The CLIP image embeddings for the text prompt, which will be used to condition the image generation.
+ negative_image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`):
+ The CLIP image embeddings for the negative text prompt, which will be used to condition the image generation.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ height (`int`, *optional*, defaults to 512):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to 512):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 100):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 4.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
+ (`np.array`) or `"pt"` (`torch.Tensor`).
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.ImagePipelineOutput`] or `tuple`
+ """
+
+ # Define call parameters
+ if isinstance(prompt, str):
+ batch_size = 1
+ elif isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ device = self._execution_device
+
+ batch_size = batch_size * num_images_per_prompt
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ prompt_embeds, text_encoder_hidden_states, _ = self._encode_prompt(
+ prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+ )
+
+ if isinstance(image_embeds, list):
+ image_embeds = torch.cat(image_embeds, dim=0)
+ if isinstance(negative_image_embeds, list):
+ negative_image_embeds = torch.cat(negative_image_embeds, dim=0)
+
+ if do_classifier_free_guidance:
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+
+ image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to(
+ dtype=prompt_embeds.dtype, device=device
+ )
+
+ # preprocess image and mask
+ mask_image, image = prepare_mask_and_masked_image(image, mask_image, height, width)
+
+ image = image.to(dtype=prompt_embeds.dtype, device=device)
+ image = self.movq.encode(image)["latents"]
+
+ mask_image = mask_image.to(dtype=prompt_embeds.dtype, device=device)
+
+ image_shape = tuple(image.shape[-2:])
+ mask_image = F.interpolate(
+ mask_image,
+ image_shape,
+ mode="nearest",
+ )
+ mask_image = prepare_mask(mask_image)
+ masked_image = image * mask_image
+
+ mask_image = mask_image.repeat_interleave(num_images_per_prompt, dim=0)
+ masked_image = masked_image.repeat_interleave(num_images_per_prompt, dim=0)
+ if do_classifier_free_guidance:
+ mask_image = mask_image.repeat(2, 1, 1, 1)
+ masked_image = masked_image.repeat(2, 1, 1, 1)
+
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps_tensor = self.scheduler.timesteps
+
+ num_channels_latents = self.movq.config.latent_channels
+
+ # get h, w for latents
+ sample_height, sample_width = get_new_h_w(height, width, self.movq_scale_factor)
+
+ # create initial latent
+ latents = self.prepare_latents(
+ (batch_size, num_channels_latents, sample_height, sample_width),
+ text_encoder_hidden_states.dtype,
+ device,
+ generator,
+ latents,
+ self.scheduler,
+ )
+
+ # Check that sizes of mask, masked image and latents match with expected
+ num_channels_mask = mask_image.shape[1]
+ num_channels_masked_image = masked_image.shape[1]
+ if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
+ raise ValueError(
+ f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
+ f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
+ f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ " `pipeline.unet` or your `mask_image` or `image` input."
+ )
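+ # e.g. for the Kandinsky 2.1 inpainting checkpoint the UNet is assumed to expect
+ # in_channels == 9, i.e. 4 (movq latent) + 1 (mask) + 4 (masked-image latent); the
+ # exact value comes from the loaded `unet` config rather than from this pipeline.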
+
+ for i, t in enumerate(self.progress_bar(timesteps_tensor)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = torch.cat([latent_model_input, masked_image, mask_image], dim=1)
+
+ added_cond_kwargs = {"text_embeds": prompt_embeds, "image_embeds": image_embeds}
+ noise_pred = self.unet(
+ sample=latent_model_input,
+ timestep=t,
+ encoder_hidden_states=text_encoder_hidden_states,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ if do_classifier_free_guidance:
+ noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1)
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ _, variance_pred_text = variance_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+ noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1)
+
+ if not (
+ hasattr(self.scheduler.config, "variance_type")
+ and self.scheduler.config.variance_type in ["learned", "learned_range"]
+ ):
+ noise_pred, _ = noise_pred.split(latents.shape[1], dim=1)
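+ # When the scheduler does not model a learned variance, the extra predicted
+ # variance channels are dropped here so that only the noise (epsilon) prediction
+ # with `latents.shape[1]` channels is passed to `scheduler.step` below.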
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(
+ noise_pred,
+ t,
+ latents,
+ generator=generator,
+ ).prev_sample
+
+ # post-processing
+ image = self.movq.decode(latents, force_not_quantize=True)["sample"]
+
+ if output_type not in ["pt", "np", "pil"]:
+ raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")
+
+ if output_type in ["np", "pil"]:
+ image = image * 0.5 + 0.5
+ image = image.clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
diff --git a/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py b/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d99180b3d1106b553b69529e7ede3bfdef4b65f
--- /dev/null
+++ b/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py
@@ -0,0 +1,541 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import List, Optional, Union
+
+import numpy as np
+import PIL
+import torch
+from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection
+
+from ...models import PriorTransformer
+from ...pipelines import DiffusionPipeline
+from ...schedulers import UnCLIPScheduler
+from ...utils import (
+ BaseOutput,
+ logging,
+ randn_tensor,
+ replace_example_docstring,
+)
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> from diffusers import KandinskyPipeline, KandinskyPriorPipeline
+ >>> import torch
+
+ >>> pipe_prior = KandinskyPriorPipeline.from_pretrained("kandinsky-community/kandinsky-2-1-prior")
+ >>> pipe_prior.to("cuda")
+
+ >>> prompt = "red cat, 4k photo"
+ >>> out = pipe_prior(prompt)
+ >>> image_emb = out.image_embeds
+ >>> negative_image_emb = out.negative_image_embeds
+
+ >>> pipe = KandinskyPipeline.from_pretrained("kandinsky-community/kandinsky-2-1")
+ >>> pipe.to("cuda")
+
+ >>> image = pipe(
+ ... prompt,
+ ... image_embeds=image_emb,
+ ... negative_image_embeds=negative_image_emb,
+ ... height=768,
+ ... width=768,
+ ... num_inference_steps=100,
+ ... ).images
+
+ >>> image[0].save("cat.png")
+ ```
+"""
+
+EXAMPLE_INTERPOLATE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> from diffusers import KandinskyPriorPipeline, KandinskyPipeline
+ >>> from diffusers.utils import load_image
+ >>> import PIL
+
+ >>> import torch
+ >>> from torchvision import transforms
+
+ >>> pipe_prior = KandinskyPriorPipeline.from_pretrained(
+ ... "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16
+ ... )
+ >>> pipe_prior.to("cuda")
+
+ >>> img1 = load_image(
+ ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+ ... "/kandinsky/cat.png"
+ ... )
+
+ >>> img2 = load_image(
+ ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+ ... "/kandinsky/starry_night.jpeg"
+ ... )
+
+ >>> images_texts = ["a cat", img1, img2]
+ >>> weights = [0.3, 0.3, 0.4]
+ >>> image_emb, zero_image_emb = pipe_prior.interpolate(images_texts, weights)
+
+ >>> pipe = KandinskyPipeline.from_pretrained("kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16)
+ >>> pipe.to("cuda")
+
+ >>> image = pipe(
+ ... "",
+ ... image_embeds=image_emb,
+ ... negative_image_embeds=zero_image_emb,
+ ... height=768,
+ ... width=768,
+ ... num_inference_steps=150,
+ ... ).images[0]
+
+ >>> image.save("starry_cat.png")
+ ```
+"""
+
+
+@dataclass
+class KandinskyPriorPipelineOutput(BaseOutput):
+ """
+ Output class for KandinskyPriorPipeline.
+
+ Args:
+ image_embeds (`torch.FloatTensor` or `np.ndarray`):
+ CLIP image embeddings for the text prompt
+ negative_image_embeds (`torch.FloatTensor` or `np.ndarray`):
+ CLIP image embeddings for the unconditional tokens
+ """
+
+ image_embeds: Union[torch.FloatTensor, np.ndarray]
+ negative_image_embeds: Union[torch.FloatTensor, np.ndarray]
+
+
+class KandinskyPriorPipeline(DiffusionPipeline):
+ """
+ Pipeline for generating image prior for Kandinsky
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ prior ([`PriorTransformer`]):
+ The canonical unCLIP prior to approximate the image embedding from the text embedding.
+ image_encoder ([`CLIPVisionModelWithProjection`]):
+ Frozen image-encoder.
+ text_encoder ([`CLIPTextModelWithProjection`]):
+ Frozen text-encoder.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ scheduler ([`UnCLIPScheduler`]):
+ A scheduler to be used in combination with `prior` to generate image embedding.
+ """
+
+ _exclude_from_cpu_offload = ["prior"]
+
+ def __init__(
+ self,
+ prior: PriorTransformer,
+ image_encoder: CLIPVisionModelWithProjection,
+ text_encoder: CLIPTextModelWithProjection,
+ tokenizer: CLIPTokenizer,
+ scheduler: UnCLIPScheduler,
+ image_processor: CLIPImageProcessor,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ prior=prior,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ scheduler=scheduler,
+ image_encoder=image_encoder,
+ image_processor=image_processor,
+ )
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_INTERPOLATE_DOC_STRING)
+ def interpolate(
+ self,
+ images_and_prompts: List[Union[str, PIL.Image.Image, torch.FloatTensor]],
+ weights: List[float],
+ num_images_per_prompt: int = 1,
+ num_inference_steps: int = 25,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ negative_prior_prompt: Optional[str] = None,
+ negative_prompt: Union[str] = "",
+ guidance_scale: float = 4.0,
+ device=None,
+ ):
+ """
+ Function invoked when using the prior pipeline for interpolation.
+
+ Args:
+ images_and_prompts (`List[Union[str, PIL.Image.Image, torch.FloatTensor]]`):
+ list of prompts and images to guide the image generation.
+ weights (`List[float]`):
+ list of weights for each condition in `images_and_prompts`
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ num_inference_steps (`int`, *optional*, defaults to 25):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ negative_prior_prompt (`str`, *optional*):
+ The prompt not to guide the prior diffusion process. Ignored when not using guidance (i.e., ignored if
+ `guidance_scale` is less than `1`).
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt not to guide the image generation. Ignored when not using guidance (i.e., ignored if
+ `guidance_scale` is less than `1`).
+ guidance_scale (`float`, *optional*, defaults to 4.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+
+ Examples:
+
+ Returns:
+ [`KandinskyPriorPipelineOutput`] or `tuple`
+ """
+
+ device = device or self.device
+
+ if len(images_and_prompts) != len(weights):
+ raise ValueError(
+ f"`images_and_prompts` contains {len(images_and_prompts)} items and `weights` contains {len(weights)} items - they should be lists of the same length"
+ )
+
+ image_embeddings = []
+ for cond, weight in zip(images_and_prompts, weights):
+ if isinstance(cond, str):
+ image_emb = self(
+ cond,
+ num_inference_steps=num_inference_steps,
+ num_images_per_prompt=num_images_per_prompt,
+ generator=generator,
+ latents=latents,
+ negative_prompt=negative_prior_prompt,
+ guidance_scale=guidance_scale,
+ ).image_embeds
+
+ elif isinstance(cond, (PIL.Image.Image, torch.Tensor)):
+ if isinstance(cond, PIL.Image.Image):
+ cond = (
+ self.image_processor(cond, return_tensors="pt")
+ .pixel_values[0]
+ .unsqueeze(0)
+ .to(dtype=self.image_encoder.dtype, device=device)
+ )
+
+ image_emb = self.image_encoder(cond)["image_embeds"]
+
+ else:
+ raise ValueError(
+ f"`images_and_prompts` can only contain elements of type `str`, `PIL.Image.Image` or `torch.Tensor`, but got {type(cond)}"
+ )
+
+ image_embeddings.append(image_emb * weight)
+
+ image_emb = torch.cat(image_embeddings).sum(dim=0, keepdim=True)
+
+ out_zero = self(
+ negative_prompt,
+ num_inference_steps=num_inference_steps,
+ num_images_per_prompt=num_images_per_prompt,
+ generator=generator,
+ latents=latents,
+ negative_prompt=negative_prior_prompt,
+ guidance_scale=guidance_scale,
+ )
+ zero_image_emb = out_zero.negative_image_embeds if negative_prompt == "" else out_zero.image_embeds
+
+ return KandinskyPriorPipelineOutput(image_embeds=image_emb, negative_image_embeds=zero_image_emb)
+
+ # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
+ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ if latents.shape != shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+ latents = latents.to(device)
+
+ latents = latents * scheduler.init_noise_sigma
+ return latents
+
+ def get_zero_embed(self, batch_size=1, device=None):
+ device = device or self.device
+ zero_img = torch.zeros(1, 3, self.image_encoder.config.image_size, self.image_encoder.config.image_size).to(
+ device=device, dtype=self.image_encoder.dtype
+ )
+ zero_image_emb = self.image_encoder(zero_img)["image_embeds"]
+ zero_image_emb = zero_image_emb.repeat(batch_size, 1)
+ return zero_image_emb
+
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ ):
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
+ # get prompt text embeddings
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ text_mask = text_inputs.attention_mask.bool().to(device)
+
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+ text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+
+ text_encoder_output = self.text_encoder(text_input_ids.to(device))
+
+ prompt_embeds = text_encoder_output.text_embeds
+ text_encoder_hidden_states = text_encoder_output.last_hidden_state
+
+ prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
+
+ if do_classifier_free_guidance:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ uncond_text_mask = uncond_input.attention_mask.bool().to(device)
+ negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))
+
+ negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds
+ uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+
+ seq_len = negative_prompt_embeds.shape[1]
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)
+
+ seq_len = uncond_text_encoder_hidden_states.shape[1]
+ uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
+ uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
+ batch_size * num_images_per_prompt, seq_len, -1
+ )
+ uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)
+
+ # done duplicates
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+ text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])
+
+ text_mask = torch.cat([uncond_text_mask, text_mask])
+
+ return prompt_embeds, text_encoder_hidden_states, text_mask
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: int = 1,
+ num_inference_steps: int = 25,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ guidance_scale: float = 4.0,
+ output_type: Optional[str] = "pt",
+ return_dict: bool = True,
+ ):
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ num_inference_steps (`int`, *optional*, defaults to 25):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ guidance_scale (`float`, *optional*, defaults to 4.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ output_type (`str`, *optional*, defaults to `"pt"`):
+ The output format of the generated image. Choose between: `"np"` (`np.array`) or `"pt"`
+ (`torch.Tensor`).
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
+
+ Examples:
+
+ Returns:
+ [`KandinskyPriorPipelineOutput`] or `tuple`
+ """
+
+ if isinstance(prompt, str):
+ prompt = [prompt]
+ elif not isinstance(prompt, list):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if isinstance(negative_prompt, str):
+ negative_prompt = [negative_prompt]
+ elif not isinstance(negative_prompt, list) and negative_prompt is not None:
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+ # if the negative prompt is defined we double the batch size to
+ # directly retrieve the negative prompt embedding
+ if negative_prompt is not None:
+ prompt = prompt + negative_prompt
+ negative_prompt = 2 * negative_prompt
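+ # e.g. prompt=["red cat"] and negative_prompt=["blurry"] become
+ # prompt=["red cat", "blurry"] and negative_prompt=["blurry", "blurry"], so one
+ # prior pass produces both embeddings, which are split again by `chunk(2)` below.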
+
+ device = self._execution_device
+
+ batch_size = len(prompt)
+ batch_size = batch_size * num_images_per_prompt
+
+ do_classifier_free_guidance = guidance_scale > 1.0
+ prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(
+ prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+ )
+
+ # prior
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ prior_timesteps_tensor = self.scheduler.timesteps
+
+ embedding_dim = self.prior.config.embedding_dim
+
+ latents = self.prepare_latents(
+ (batch_size, embedding_dim),
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ self.scheduler,
+ )
+
+ for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+ predicted_image_embedding = self.prior(
+ latent_model_input,
+ timestep=t,
+ proj_embedding=prompt_embeds,
+ encoder_hidden_states=text_encoder_hidden_states,
+ attention_mask=text_mask,
+ ).predicted_image_embedding
+
+ if do_classifier_free_guidance:
+ predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2)
+ predicted_image_embedding = predicted_image_embedding_uncond + guidance_scale * (
+ predicted_image_embedding_text - predicted_image_embedding_uncond
+ )
+
+ if i + 1 == prior_timesteps_tensor.shape[0]:
+ prev_timestep = None
+ else:
+ prev_timestep = prior_timesteps_tensor[i + 1]
+
+ latents = self.scheduler.step(
+ predicted_image_embedding,
+ timestep=t,
+ sample=latents,
+ generator=generator,
+ prev_timestep=prev_timestep,
+ ).prev_sample
+
+ latents = self.prior.post_process_latents(latents)
+
+ image_embeddings = latents
+
+ # if a negative prompt has been defined, we split the image embedding into two
+ if negative_prompt is None:
+ zero_embeds = self.get_zero_embed(latents.shape[0], device=latents.device)
+ else:
+ image_embeddings, zero_embeds = image_embeddings.chunk(2)
+
+ if output_type not in ["pt", "np"]:
+ raise ValueError(f"Only the output types `pt` and `np` are supported not output_type={output_type}")
+
+ if output_type == "np":
+ image_embeddings = image_embeddings.cpu().numpy()
+ zero_embeds = zero_embeds.cpu().numpy()
+
+ if not return_dict:
+ return (image_embeddings, zero_embeds)
+
+ return KandinskyPriorPipelineOutput(image_embeds=image_embeddings, negative_image_embeds=zero_embeds)
diff --git a/diffusers/pipelines/kandinsky/text_encoder.py b/diffusers/pipelines/kandinsky/text_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..caa0029f00ca22818819d5b76b57ec489c6da1d6
--- /dev/null
+++ b/diffusers/pipelines/kandinsky/text_encoder.py
@@ -0,0 +1,27 @@
+import torch
+from transformers import PreTrainedModel, XLMRobertaConfig, XLMRobertaModel
+
+
+class MCLIPConfig(XLMRobertaConfig):
+ model_type = "M-CLIP"
+
+ def __init__(self, transformerDimSize=1024, imageDimSize=768, **kwargs):
+ self.transformerDimensions = transformerDimSize
+ self.numDims = imageDimSize
+ super().__init__(**kwargs)
+
+
+class MultilingualCLIP(PreTrainedModel):
+ config_class = MCLIPConfig
+
+ def __init__(self, config, *args, **kwargs):
+ super().__init__(config, *args, **kwargs)
+ self.transformer = XLMRobertaModel(config)
+ self.LinearTransformation = torch.nn.Linear(
+ in_features=config.transformerDimensions, out_features=config.numDims
+ )
+
+ def forward(self, input_ids, attention_mask):
+ embs = self.transformer(input_ids=input_ids, attention_mask=attention_mask)[0]
+ embs2 = (embs * attention_mask.unsqueeze(2)).sum(dim=1) / attention_mask.sum(dim=1)[:, None]
+ return self.LinearTransformation(embs2), embs
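+ # Note on the forward pass above: `embs` holds the per-token hidden states from
+ # XLM-RoBERTa, `embs2` is their attention-mask-weighted mean over the sequence, and
+ # the call returns (LinearTransformation(embs2), embs), i.e. a pooled embedding
+ # projected to the image-embedding dimension plus the token-level hidden states.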
diff --git a/diffusers/pipelines/kandinsky2_2/__init__.py b/diffusers/pipelines/kandinsky2_2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..648164b9f1ba657feb686a70ad2a4e367f898e20
--- /dev/null
+++ b/diffusers/pipelines/kandinsky2_2/__init__.py
@@ -0,0 +1,7 @@
+from .pipeline_kandinsky2_2 import KandinskyV22Pipeline
+from .pipeline_kandinsky2_2_controlnet import KandinskyV22ControlnetPipeline
+from .pipeline_kandinsky2_2_controlnet_img2img import KandinskyV22ControlnetImg2ImgPipeline
+from .pipeline_kandinsky2_2_img2img import KandinskyV22Img2ImgPipeline
+from .pipeline_kandinsky2_2_inpainting import KandinskyV22InpaintPipeline
+from .pipeline_kandinsky2_2_prior import KandinskyV22PriorPipeline
+from .pipeline_kandinsky2_2_prior_emb2emb import KandinskyV22PriorEmb2EmbPipeline
diff --git a/diffusers/pipelines/kandinsky2_2/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/kandinsky2_2/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..eeb8fdfef5461688b1e4682c7361726a12e39c75
Binary files /dev/null and b/diffusers/pipelines/kandinsky2_2/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/pipelines/kandinsky2_2/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/kandinsky2_2/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f026433b3fcc6784da85c9aefe92f4f9b95aab50
Binary files /dev/null and b/diffusers/pipelines/kandinsky2_2/__pycache__/__init__.cpython-38.pyc differ
diff --git a/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2.cpython-310.pyc b/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3f21ed28416d19815453a4d783d82e5b53cd8399
Binary files /dev/null and b/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2.cpython-310.pyc differ
diff --git a/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2.cpython-38.pyc b/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..483d98bd87ff8442a8b8d56a95530b4706d61b9d
Binary files /dev/null and b/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2.cpython-38.pyc differ
diff --git a/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_controlnet.cpython-310.pyc b/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_controlnet.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6a47f9c194095fb7d56d7406400ce639a50e8099
Binary files /dev/null and b/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_controlnet.cpython-310.pyc differ
diff --git a/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_controlnet.cpython-38.pyc b/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_controlnet.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8164953faae83a22c730c0a3960502adc358bf22
Binary files /dev/null and b/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_controlnet.cpython-38.pyc differ
diff --git a/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_controlnet_img2img.cpython-310.pyc b/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_controlnet_img2img.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..98fc4f8679d66baa34b966668bf2432cfee2e1ca
Binary files /dev/null and b/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_controlnet_img2img.cpython-310.pyc differ
diff --git a/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_controlnet_img2img.cpython-38.pyc b/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_controlnet_img2img.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..078d473864639d099f0c996a31be139ce1860757
Binary files /dev/null and b/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_controlnet_img2img.cpython-38.pyc differ
diff --git a/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_img2img.cpython-310.pyc b/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_img2img.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..99a294348529832329145fdab092cd925a3ace94
Binary files /dev/null and b/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_img2img.cpython-310.pyc differ
diff --git a/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_img2img.cpython-38.pyc b/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_img2img.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b658ffa0c19bc29033b431faed31fa3149d93346
Binary files /dev/null and b/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_img2img.cpython-38.pyc differ
diff --git a/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_inpainting.cpython-310.pyc b/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_inpainting.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..709546e6617e18d1969c619c0a9653aae74df741
Binary files /dev/null and b/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_inpainting.cpython-310.pyc differ
diff --git a/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_inpainting.cpython-38.pyc b/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_inpainting.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f0869b29f128d52755b38f7d98203a642da4ff9d
Binary files /dev/null and b/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_inpainting.cpython-38.pyc differ
diff --git a/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_prior.cpython-310.pyc b/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_prior.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..215d89367edda2d601230003afee72f70c378c2f
Binary files /dev/null and b/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_prior.cpython-310.pyc differ
diff --git a/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_prior.cpython-38.pyc b/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_prior.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..45add2c13fab29e27f453d8d3b6f4d9b1df58b55
Binary files /dev/null and b/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_prior.cpython-38.pyc differ
diff --git a/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_prior_emb2emb.cpython-310.pyc b/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_prior_emb2emb.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..057c72e4a00e4af38ff82870c1d9a30f049dd401
Binary files /dev/null and b/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_prior_emb2emb.cpython-310.pyc differ
diff --git a/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_prior_emb2emb.cpython-38.pyc b/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_prior_emb2emb.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..072ff2d46a0eb296655632b1184485c351ca57ae
Binary files /dev/null and b/diffusers/pipelines/kandinsky2_2/__pycache__/pipeline_kandinsky2_2_prior_emb2emb.cpython-38.pyc differ
diff --git a/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py b/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py
new file mode 100644
index 0000000000000000000000000000000000000000..6db5260b04d82d7301d887dc3095b7b64aff0ab5
--- /dev/null
+++ b/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py
@@ -0,0 +1,277 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Union
+
+import torch
+
+from ...models import UNet2DConditionModel, VQModel
+from ...pipelines import DiffusionPipeline
+from ...pipelines.pipeline_utils import ImagePipelineOutput
+from ...schedulers import DDPMScheduler
+from ...utils import (
+ is_accelerate_available,
+ is_accelerate_version,
+ logging,
+ randn_tensor,
+ replace_example_docstring,
+)
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> from diffusers import KandinskyV22Pipeline, KandinskyV22PriorPipeline
+ >>> import torch
+
+ >>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained("kandinsky-community/kandinsky-2-2-prior")
+ >>> pipe_prior.to("cuda")
+ >>> prompt = "red cat, 4k photo"
+ >>> out = pipe_prior(prompt)
+ >>> image_emb = out.image_embeds
+ >>> zero_image_emb = out.negative_image_embeds
+ >>> pipe = KandinskyV22Pipeline.from_pretrained("kandinsky-community/kandinsky-2-2-decoder")
+ >>> pipe.to("cuda")
+ >>> image = pipe(
+ ... image_embeds=image_emb,
+ ... negative_image_embeds=zero_image_emb,
+ ... height=768,
+ ... width=768,
+ ... num_inference_steps=50,
+ ... ).images
+ >>> image[0].save("cat.png")
+ ```
+"""
+
+
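+# Rounds the requested pixel height/width up to a multiple of scale_factor**2 and returns the matching
+# latent-space size (the rounded pixel size divided by scale_factor), which is the resolution the UNet sees.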
+def downscale_height_and_width(height, width, scale_factor=8):
+ new_height = height // scale_factor**2
+ if height % scale_factor**2 != 0:
+ new_height += 1
+ new_width = width // scale_factor**2
+ if width % scale_factor**2 != 0:
+ new_width += 1
+ return new_height * scale_factor, new_width * scale_factor
+
+
+class KandinskyV22Pipeline(DiffusionPipeline):
+ """
+ Pipeline for text-to-image generation using Kandinsky
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ scheduler (Union[`DDIMScheduler`, `DDPMScheduler`]):
+ A scheduler to be used in combination with `unet` to generate image latents.
+ unet ([`UNet2DConditionModel`]):
+ Conditional U-Net architecture to denoise the image embedding.
+ movq ([`VQModel`]):
+ MoVQ Decoder to generate the image from the latents.
+ """
+
+ def __init__(
+ self,
+ unet: UNet2DConditionModel,
+ scheduler: DDPMScheduler,
+ movq: VQModel,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ unet=unet,
+ scheduler=scheduler,
+ movq=movq,
+ )
+ self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)
+
+ # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
+ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ if latents.shape != shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+ latents = latents.to(device)
+
+ latents = latents * scheduler.init_noise_sigma
+ return latents
+
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ hook = None
+ for cpu_offloaded_model in [self.unet, self.movq]:
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
+ negative_image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
+ height: int = 512,
+ width: int = 512,
+ num_inference_steps: int = 100,
+ guidance_scale: float = 4.0,
+ num_images_per_prompt: int = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ ):
+ """
+ Args:
+ Function invoked when calling the pipeline for generation.
+ image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`):
+ The clip image embeddings for text prompt, that will be used to condition the image generation.
+ negative_image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`):
+ The clip image embeddings for negative text prompt, will be used to condition the image generation.
+ height (`int`, *optional*, defaults to 512):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to 512):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 100):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 4.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` in equation 2 of the [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance is enabled by setting `guidance_scale > 1`. A
+ higher guidance scale encourages the model to generate images that are closely linked to the text
+ `prompt`, usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a
+ latents tensor will be generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
+ (`np.array`) or `"pt"` (`torch.Tensor`).
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.ImagePipelineOutput`] or `tuple`
+ """
+ device = self._execution_device
+
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ if isinstance(image_embeds, list):
+ image_embeds = torch.cat(image_embeds, dim=0)
+ batch_size = image_embeds.shape[0] * num_images_per_prompt
+ if isinstance(negative_image_embeds, list):
+ negative_image_embeds = torch.cat(negative_image_embeds, dim=0)
+
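+ # For classifier-free guidance, the positive and negative image embeddings are repeated per prompt and
+ # stacked along the batch dimension so that a single UNet forward pass evaluates both branches.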
+ if do_classifier_free_guidance:
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+
+ image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to(dtype=self.unet.dtype, device=device)
+
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps_tensor = self.scheduler.timesteps
+
+ num_channels_latents = self.unet.config.in_channels
+
+ height, width = downscale_height_and_width(height, width, self.movq_scale_factor)
+
+ # create initial latent
+ latents = self.prepare_latents(
+ (batch_size, num_channels_latents, height, width),
+ image_embeds.dtype,
+ device,
+ generator,
+ latents,
+ self.scheduler,
+ )
+
+ for i, t in enumerate(self.progress_bar(timesteps_tensor)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+ added_cond_kwargs = {"image_embeds": image_embeds}
+ noise_pred = self.unet(
+ sample=latent_model_input,
+ timestep=t,
+ encoder_hidden_states=None,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
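+ # The UNet predicts noise and a learned variance stacked along the channel axis: split them, apply
+ # classifier-free guidance to the noise halves only, and keep the variance from the conditioned pass.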
+ if do_classifier_free_guidance:
+ noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1)
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ _, variance_pred_text = variance_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+ noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1)
+
+ if not (
+ hasattr(self.scheduler.config, "variance_type")
+ and self.scheduler.config.variance_type in ["learned", "learned_range"]
+ ):
+ noise_pred, _ = noise_pred.split(latents.shape[1], dim=1)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(
+ noise_pred,
+ t,
+ latents,
+ generator=generator,
+ )[0]
+ # post-processing
+ image = self.movq.decode(latents, force_not_quantize=True)["sample"]
+
+ if output_type not in ["pt", "np", "pil"]:
+ raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")
+
+ if output_type in ["np", "pil"]:
+ image = image * 0.5 + 0.5
+ image = image.clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
diff --git a/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py b/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6fc269117470dd0130b3546ece0f509e936126e
--- /dev/null
+++ b/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py
@@ -0,0 +1,331 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Union
+
+import torch
+
+from ...models import UNet2DConditionModel, VQModel
+from ...pipelines import DiffusionPipeline
+from ...pipelines.pipeline_utils import ImagePipelineOutput
+from ...schedulers import DDPMScheduler
+from ...utils import (
+ is_accelerate_available,
+ is_accelerate_version,
+ logging,
+ randn_tensor,
+ replace_example_docstring,
+)
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> import numpy as np
+
+ >>> from diffusers import KandinskyV22PriorPipeline, KandinskyV22ControlnetPipeline
+ >>> from transformers import pipeline
+ >>> from diffusers.utils import load_image
+
+
+ >>> def make_hint(image, depth_estimator):
+ ... image = depth_estimator(image)["depth"]
+ ... image = np.array(image)
+ ... image = image[:, :, None]
+ ... image = np.concatenate([image, image, image], axis=2)
+ ... detected_map = torch.from_numpy(image).float() / 255.0
+ ... hint = detected_map.permute(2, 0, 1)
+ ... return hint
+
+
+ >>> depth_estimator = pipeline("depth-estimation")
+
+ >>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained(
+ ... "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16
+ ... )
+ >>> pipe_prior = pipe_prior.to("cuda")
+
+ >>> pipe = KandinskyV22ControlnetPipeline.from_pretrained(
+ ... "kandinsky-community/kandinsky-2-2-controlnet-depth", torch_dtype=torch.float16
+ ... )
+ >>> pipe = pipe.to("cuda")
+
+
+ >>> img = load_image(
+ ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+ ... "/kandinsky/cat.png"
+ ... ).resize((768, 768))
+
+ >>> hint = make_hint(img, depth_estimator).unsqueeze(0).half().to("cuda")
+
+ >>> prompt = "A robot, 4k photo"
+ >>> negative_prior_prompt = "lowres, text, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, username, watermark, signature"
+
+ >>> generator = torch.Generator(device="cuda").manual_seed(43)
+
+ >>> image_emb, zero_image_emb = pipe_prior(
+ ... prompt=prompt, negative_prompt=negative_prior_prompt, generator=generator
+ ... ).to_tuple()
+
+ >>> images = pipe(
+ ... image_embeds=image_emb,
+ ... negative_image_embeds=zero_image_emb,
+ ... hint=hint,
+ ... num_inference_steps=50,
+ ... generator=generator,
+ ... height=768,
+ ... width=768,
+ ... ).images
+
+ >>> images[0].save("robot_cat.png")
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.downscale_height_and_width
+def downscale_height_and_width(height, width, scale_factor=8):
+ new_height = height // scale_factor**2
+ if height % scale_factor**2 != 0:
+ new_height += 1
+ new_width = width // scale_factor**2
+ if width % scale_factor**2 != 0:
+ new_width += 1
+ return new_height * scale_factor, new_width * scale_factor
+
+
+class KandinskyV22ControlnetPipeline(DiffusionPipeline):
+ """
+ Pipeline for text-to-image generation using Kandinsky
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ scheduler ([`DDPMScheduler`]):
+ A scheduler to be used in combination with `unet` to generate image latents.
+ unet ([`UNet2DConditionModel`]):
+ Conditional U-Net architecture to denoise the image embedding.
+ movq ([`VQModel`]):
+ MoVQ Decoder to generate the image from the latents.
+ """
+
+ def __init__(
+ self,
+ unet: UNet2DConditionModel,
+ scheduler: DDPMScheduler,
+ movq: VQModel,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ unet=unet,
+ scheduler=scheduler,
+ movq=movq,
+ )
+ self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)
+
+ # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
+ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ if latents.shape != shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+ latents = latents.to(device)
+
+ latents = latents * scheduler.init_noise_sigma
+ return latents
+
+ # Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.KandinskyV22Pipeline.enable_model_cpu_offload
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ hook = None
+ for cpu_offloaded_model in [self.unet, self.movq]:
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
+ negative_image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
+ hint: torch.FloatTensor,
+ height: int = 512,
+ width: int = 512,
+ num_inference_steps: int = 100,
+ guidance_scale: float = 4.0,
+ num_images_per_prompt: int = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ ):
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`):
+ The CLIP image embeddings for the text prompt, used to condition the image generation.
+ negative_image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`):
+ The CLIP image embeddings for the negative text prompt, used to condition the image generation.
+ hint (`torch.FloatTensor`):
+ The controlnet condition.
+ height (`int`, *optional*, defaults to 512):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to 512):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 100):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 4.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` in equation 2 of the [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance is enabled by setting `guidance_scale > 1`. A
+ higher guidance scale encourages the model to generate images that are closely linked to the text
+ `prompt`, usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a
+ latents tensor will be generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
+ (`np.array`) or `"pt"` (`torch.Tensor`).
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.ImagePipelineOutput`] or `tuple`
+ """
+ device = self._execution_device
+
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ if isinstance(image_embeds, list):
+ image_embeds = torch.cat(image_embeds, dim=0)
+ if isinstance(negative_image_embeds, list):
+ negative_image_embeds = torch.cat(negative_image_embeds, dim=0)
+ if isinstance(hint, list):
+ hint = torch.cat(hint, dim=0)
+
+ batch_size = image_embeds.shape[0] * num_images_per_prompt
+
+ if do_classifier_free_guidance:
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ hint = hint.repeat_interleave(num_images_per_prompt, dim=0)
+
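+ # Duplicate the hint and stack negative/positive embeddings so the doubled (uncond + cond) batch used
+ # for classifier-free guidance sees the same controlnet condition in both halves.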
+ image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to(dtype=self.unet.dtype, device=device)
+ hint = torch.cat([hint, hint], dim=0).to(dtype=self.unet.dtype, device=device)
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps_tensor = self.scheduler.timesteps
+
+ num_channels_latents = self.movq.config.latent_channels
+
+ height, width = downscale_height_and_width(height, width, self.movq_scale_factor)
+
+ # create initial latent
+ latents = self.prepare_latents(
+ (batch_size, num_channels_latents, height, width),
+ image_embeds.dtype,
+ device,
+ generator,
+ latents,
+ self.scheduler,
+ )
+
+ for i, t in enumerate(self.progress_bar(timesteps_tensor)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+ added_cond_kwargs = {"image_embeds": image_embeds, "hint": hint}
+ noise_pred = self.unet(
+ sample=latent_model_input,
+ timestep=t,
+ encoder_hidden_states=None,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ if do_classifier_free_guidance:
+ noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1)
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ _, variance_pred_text = variance_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+ noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1)
+
+ if not (
+ hasattr(self.scheduler.config, "variance_type")
+ and self.scheduler.config.variance_type in ["learned", "learned_range"]
+ ):
+ noise_pred, _ = noise_pred.split(latents.shape[1], dim=1)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(
+ noise_pred,
+ t,
+ latents,
+ generator=generator,
+ )[0]
+ # post-processing
+ image = self.movq.decode(latents, force_not_quantize=True)["sample"]
+
+ if output_type not in ["pt", "np", "pil"]:
+ raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")
+
+ if output_type in ["np", "pil"]:
+ image = image * 0.5 + 0.5
+ image = image.clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
diff --git a/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py b/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a25624b7267bd3f8d82fa7fc6514920193adfc0
--- /dev/null
+++ b/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py
@@ -0,0 +1,393 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Union
+
+import numpy as np
+import PIL
+import torch
+from PIL import Image
+
+from ...models import UNet2DConditionModel, VQModel
+from ...pipelines import DiffusionPipeline
+from ...pipelines.pipeline_utils import ImagePipelineOutput
+from ...schedulers import DDPMScheduler
+from ...utils import (
+ is_accelerate_available,
+ is_accelerate_version,
+ logging,
+ randn_tensor,
+ replace_example_docstring,
+)
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> import numpy as np
+
+ >>> from diffusers import KandinskyV22PriorEmb2EmbPipeline, KandinskyV22ControlnetImg2ImgPipeline
+ >>> from transformers import pipeline
+ >>> from diffusers.utils import load_image
+
+
+ >>> def make_hint(image, depth_estimator):
+ ... image = depth_estimator(image)["depth"]
+ ... image = np.array(image)
+ ... image = image[:, :, None]
+ ... image = np.concatenate([image, image, image], axis=2)
+ ... detected_map = torch.from_numpy(image).float() / 255.0
+ ... hint = detected_map.permute(2, 0, 1)
+ ... return hint
+
+
+ >>> depth_estimator = pipeline("depth-estimation")
+
+ >>> pipe_prior = KandinskyV22PriorEmb2EmbPipeline.from_pretrained(
+ ... "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16
+ ... )
+ >>> pipe_prior = pipe_prior.to("cuda")
+
+ >>> pipe = KandinskyV22ControlnetImg2ImgPipeline.from_pretrained(
+ ... "kandinsky-community/kandinsky-2-2-controlnet-depth", torch_dtype=torch.float16
+ ... )
+ >>> pipe = pipe.to("cuda")
+
+ >>> img = load_image(
+ ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+ ... "/kandinsky/cat.png"
+ ... ).resize((768, 768))
+
+
+ >>> hint = make_hint(img, depth_estimator).unsqueeze(0).half().to("cuda")
+
+ >>> prompt = "A robot, 4k photo"
+ >>> negative_prior_prompt = "lowres, text, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, username, watermark, signature"
+
+ >>> generator = torch.Generator(device="cuda").manual_seed(43)
+
+ >>> img_emb = pipe_prior(prompt=prompt, image=img, strength=0.85, generator=generator)
+ >>> negative_emb = pipe_prior(prompt=negative_prior_prompt, image=img, strength=1, generator=generator)
+
+ >>> images = pipe(
+ ... image=img,
+ ... strength=0.5,
+ ... image_embeds=img_emb.image_embeds,
+ ... negative_image_embeds=negative_emb.image_embeds,
+ ... hint=hint,
+ ... num_inference_steps=50,
+ ... generator=generator,
+ ... height=768,
+ ... width=768,
+ ... ).images
+
+ >>> images[0].save("robot_cat.png")
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.downscale_height_and_width
+def downscale_height_and_width(height, width, scale_factor=8):
+ new_height = height // scale_factor**2
+ if height % scale_factor**2 != 0:
+ new_height += 1
+ new_width = width // scale_factor**2
+ if width % scale_factor**2 != 0:
+ new_width += 1
+ return new_height * scale_factor, new_width * scale_factor
+
+
+# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.prepare_image
+def prepare_image(pil_image, w=512, h=512):
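+ # Resize to (w, h), convert to RGB, rescale pixel values from [0, 255] to [-1, 1], and return a
+ # 1 x C x H x W float tensor.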
+ pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
+ arr = np.array(pil_image.convert("RGB"))
+ arr = arr.astype(np.float32) / 127.5 - 1
+ arr = np.transpose(arr, [2, 0, 1])
+ image = torch.from_numpy(arr).unsqueeze(0)
+ return image
+
+
+class KandinskyV22ControlnetImg2ImgPipeline(DiffusionPipeline):
+ """
+ Pipeline for image-to-image generation using Kandinsky
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ scheduler ([`DDPMScheduler`]):
+ A scheduler to be used in combination with `unet` to generate image latents.
+ unet ([`UNet2DConditionModel`]):
+ Conditional U-Net architecture to denoise the image embedding.
+ movq ([`VQModel`]):
+ MoVQ Decoder to generate the image from the latents.
+ """
+
+ def __init__(
+ self,
+ unet: UNet2DConditionModel,
+ scheduler: DDPMScheduler,
+ movq: VQModel,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ unet=unet,
+ scheduler=scheduler,
+ movq=movq,
+ )
+ self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)
+
+ # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.KandinskyImg2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep, 0)
+ timesteps = self.scheduler.timesteps[t_start:]
+
+ return timesteps, num_inference_steps - t_start
+
+ # Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2_img2img.KandinskyV22Img2ImgPipeline.prepare_latents
+ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
+ raise ValueError(
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
+ )
+
+ image = image.to(device=device, dtype=dtype)
+
+ batch_size = batch_size * num_images_per_prompt
+
+ if image.shape[1] == 4:
+ init_latents = image
+
+ else:
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ elif isinstance(generator, list):
+ init_latents = [
+ self.movq.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
+ ]
+ init_latents = torch.cat(init_latents, dim=0)
+ else:
+ init_latents = self.movq.encode(image).latent_dist.sample(generator)
+
+ init_latents = self.movq.config.scaling_factor * init_latents
+
+ init_latents = torch.cat([init_latents], dim=0)
+
+ shape = init_latents.shape
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+ # get latents
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+
+ latents = init_latents
+
+ return latents
+
+ # Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.KandinskyV22Pipeline.enable_model_cpu_offload
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ hook = None
+ for cpu_offloaded_model in [self.unet, self.movq]:
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
+ image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]],
+ negative_image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
+ hint: torch.FloatTensor,
+ height: int = 512,
+ width: int = 512,
+ num_inference_steps: int = 100,
+ guidance_scale: float = 4.0,
+ strength: float = 0.3,
+ num_images_per_prompt: int = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ ):
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`):
+ The CLIP image embeddings for the text prompt, used to condition the image generation.
+ image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
+ process. Can also accept image latents as `image`; if latents are passed directly, they will not be
+ encoded again.
+ strength (`float`, *optional*, defaults to 0.3):
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
+ will be used as a starting point, adding more noise to it the larger the `strength`. The number of
+ denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
+ be maximum and the denoising process will run for the full number of iterations specified in
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
+ hint (`torch.FloatTensor`):
+ The controlnet condition.
+ negative_image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`):
+ The CLIP image embeddings for the negative text prompt, used to condition the image generation.
+ height (`int`, *optional*, defaults to 512):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to 512):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 100):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 4.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` in equation 2 of the [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance is enabled by setting `guidance_scale > 1`. A
+ higher guidance scale encourages the model to generate images that are closely linked to the text
+ `prompt`, usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
+ (`np.array`) or `"pt"` (`torch.Tensor`).
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.ImagePipelineOutput`] or `tuple`
+ """
+ device = self._execution_device
+
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ if isinstance(image_embeds, list):
+ image_embeds = torch.cat(image_embeds, dim=0)
+ if isinstance(negative_image_embeds, list):
+ negative_image_embeds = torch.cat(negative_image_embeds, dim=0)
+ if isinstance(hint, list):
+ hint = torch.cat(hint, dim=0)
+
+ batch_size = image_embeds.shape[0]
+
+ if do_classifier_free_guidance:
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ hint = hint.repeat_interleave(num_images_per_prompt, dim=0)
+
+ image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to(dtype=self.unet.dtype, device=device)
+ hint = torch.cat([hint, hint], dim=0).to(dtype=self.unet.dtype, device=device)
+
+ if not isinstance(image, list):
+ image = [image]
+ if not all(isinstance(i, (PIL.Image.Image, torch.Tensor)) for i in image):
+ raise ValueError(
+ f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor"
+ )
+
+ image = torch.cat([prepare_image(i, width, height) for i in image], dim=0)
+ image = image.to(dtype=image_embeds.dtype, device=device)
+
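+ # Encode the init image into MoVQ latents; `get_timesteps` trims the schedule according to `strength`,
+ # and `prepare_latents` noises the latents to the first retained timestep so denoising starts from the image.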
+ latents = self.movq.encode(image)["latents"]
+ latents = latents.repeat_interleave(num_images_per_prompt, dim=0)
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+ height, width = downscale_height_and_width(height, width, self.movq_scale_factor)
+ latents = self.prepare_latents(
+ latents, latent_timestep, batch_size, num_images_per_prompt, image_embeds.dtype, device, generator
+ )
+ for i, t in enumerate(self.progress_bar(timesteps)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+ added_cond_kwargs = {"image_embeds": image_embeds, "hint": hint}
+ noise_pred = self.unet(
+ sample=latent_model_input,
+ timestep=t,
+ encoder_hidden_states=None,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ if do_classifier_free_guidance:
+ noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1)
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ _, variance_pred_text = variance_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+ noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1)
+
+ if not (
+ hasattr(self.scheduler.config, "variance_type")
+ and self.scheduler.config.variance_type in ["learned", "learned_range"]
+ ):
+ noise_pred, _ = noise_pred.split(latents.shape[1], dim=1)
+
+ # compute the previous noisy sample x_t -> x_t-1
+
+ latents = self.scheduler.step(
+ noise_pred,
+ t,
+ latents,
+ generator=generator,
+ )[0]
+
+ # post-processing
+ image = self.movq.decode(latents, force_not_quantize=True)["sample"]
+
+ if output_type not in ["pt", "np", "pil"]:
+ raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")
+
+ if output_type in ["np", "pil"]:
+ image = image * 0.5 + 0.5
+ image = image.clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
diff --git a/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py b/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py
new file mode 100644
index 0000000000000000000000000000000000000000..26976ad0c925f09f2f550741c4c7a26f83c68700
--- /dev/null
+++ b/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py
@@ -0,0 +1,357 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Union
+
+import numpy as np
+import PIL
+import torch
+from PIL import Image
+
+from ...models import UNet2DConditionModel, VQModel
+from ...pipelines import DiffusionPipeline
+from ...pipelines.pipeline_utils import ImagePipelineOutput
+from ...schedulers import DDPMScheduler
+from ...utils import (
+ is_accelerate_available,
+ is_accelerate_version,
+ logging,
+ randn_tensor,
+ replace_example_docstring,
+)
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> from diffusers import KandinskyV22Img2ImgPipeline, KandinskyV22PriorPipeline
+ >>> from diffusers.utils import load_image
+ >>> import torch
+
+ >>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained(
+ ... "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16
+ ... )
+ >>> pipe_prior.to("cuda")
+
+ >>> prompt = "A red cartoon frog, 4k"
+ >>> image_emb, zero_image_emb = pipe_prior(prompt, return_dict=False)
+
+ >>> pipe = KandinskyV22Img2ImgPipeline.from_pretrained(
+ ... "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16
+ ... )
+ >>> pipe.to("cuda")
+
+ >>> init_image = load_image(
+ ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+ ... "/kandinsky/frog.png"
+ ... )
+
+ >>> image = pipe(
+ ... image=init_image,
+ ... image_embeds=image_emb,
+ ... negative_image_embeds=zero_image_emb,
+ ... height=768,
+ ... width=768,
+ ... num_inference_steps=100,
+ ... strength=0.2,
+ ... ).images
+
+ >>> image[0].save("red_frog.png")
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.downscale_height_and_width
+def downscale_height_and_width(height, width, scale_factor=8):
+ new_height = height // scale_factor**2
+ if height % scale_factor**2 != 0:
+ new_height += 1
+ new_width = width // scale_factor**2
+ if width % scale_factor**2 != 0:
+ new_width += 1
+ return new_height * scale_factor, new_width * scale_factor
+
+
+# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.prepare_image
+def prepare_image(pil_image, w=512, h=512):
+ pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
+ arr = np.array(pil_image.convert("RGB"))
+ arr = arr.astype(np.float32) / 127.5 - 1
+ arr = np.transpose(arr, [2, 0, 1])
+ image = torch.from_numpy(arr).unsqueeze(0)
+ return image
+
+
+class KandinskyV22Img2ImgPipeline(DiffusionPipeline):
+ """
+ Pipeline for image-to-image generation using Kandinsky
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ scheduler ([`DDPMScheduler`]):
+ A scheduler to be used in combination with `unet` to generate image latents.
+ unet ([`UNet2DConditionModel`]):
+ Conditional U-Net architecture to denoise the image embedding.
+ movq ([`VQModel`]):
+ MoVQ Decoder to generate the image from the latents.
+ """
+
+ def __init__(
+ self,
+ unet: UNet2DConditionModel,
+ scheduler: DDPMScheduler,
+ movq: VQModel,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ unet=unet,
+ scheduler=scheduler,
+ movq=movq,
+ )
+ self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)
+
+ # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.KandinskyImg2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep, 0)
+ timesteps = self.scheduler.timesteps[t_start:]
+
+ return timesteps, num_inference_steps - t_start
+
+ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
+ raise ValueError(
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
+ )
+
+ image = image.to(device=device, dtype=dtype)
+
+ batch_size = batch_size * num_images_per_prompt
+
+ if image.shape[1] == 4:
+ init_latents = image
+
+ else:
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ elif isinstance(generator, list):
+ init_latents = [
+ self.movq.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
+ ]
+ init_latents = torch.cat(init_latents, dim=0)
+ else:
+ init_latents = self.movq.encode(image).latent_dist.sample(generator)
+
+ init_latents = self.movq.config.scaling_factor * init_latents
+
+ init_latents = torch.cat([init_latents], dim=0)
+
+ shape = init_latents.shape
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+ # get latents
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+
+ latents = init_latents
+
+ return latents
+
+ # Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.KandinskyV22Pipeline.enable_model_cpu_offload
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ hook = None
+ for cpu_offloaded_model in [self.unet, self.movq]:
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
+ image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]],
+ negative_image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
+ height: int = 512,
+ width: int = 512,
+ num_inference_steps: int = 100,
+ guidance_scale: float = 4.0,
+ strength: float = 0.3,
+ num_images_per_prompt: int = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ ):
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`):
+ The CLIP image embeddings for the text prompt, used to condition the image generation.
+ image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
+ process. Can also accept image latents as `image`; if latents are passed directly, they will not be
+ encoded again.
+ strength (`float`, *optional*, defaults to 0.3):
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
+ will be used as a starting point, adding more noise to it the larger the `strength`. The number of
+ denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
+ be maximum and the denoising process will run for the full number of iterations specified in
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
+ negative_image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`):
+ The CLIP image embeddings for the negative text prompt, used to condition the image generation.
+ height (`int`, *optional*, defaults to 512):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to 512):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 100):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 4.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` in equation 2 of the [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance is enabled by setting `guidance_scale > 1`. A
+ higher guidance scale encourages the model to generate images that are closely linked to the text
+ `prompt`, usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
+ (`np.array`) or `"pt"` (`torch.Tensor`).
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.ImagePipelineOutput`] or `tuple`
+ """
+ device = self._execution_device
+
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ if isinstance(image_embeds, list):
+ image_embeds = torch.cat(image_embeds, dim=0)
+ batch_size = image_embeds.shape[0]
+ if isinstance(negative_image_embeds, list):
+ negative_image_embeds = torch.cat(negative_image_embeds, dim=0)
+
+ if do_classifier_free_guidance:
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+
+ image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to(dtype=self.unet.dtype, device=device)
+
+ if not isinstance(image, list):
+ image = [image]
+ if not all(isinstance(i, (PIL.Image.Image, torch.Tensor)) for i in image):
+ raise ValueError(
+ f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor"
+ )
+
+ image = torch.cat([prepare_image(i, width, height) for i in image], dim=0)
+ image = image.to(dtype=image_embeds.dtype, device=device)
+
+ latents = self.movq.encode(image)["latents"]
+ latents = latents.repeat_interleave(num_images_per_prompt, dim=0)
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+ height, width = downscale_height_and_width(height, width, self.movq_scale_factor)
+ latents = self.prepare_latents(
+ latents, latent_timestep, batch_size, num_images_per_prompt, image_embeds.dtype, device, generator
+ )
+ for i, t in enumerate(self.progress_bar(timesteps)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+ added_cond_kwargs = {"image_embeds": image_embeds}
+ noise_pred = self.unet(
+ sample=latent_model_input,
+ timestep=t,
+ encoder_hidden_states=None,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
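+ # When the UNet predicts a learned variance, its output has twice the latent
+ # channels: the first half is the noise prediction and the second half the
+ # variance. Guidance is applied to the noise half only, and the text-branch
+ # variance is re-attached so the scheduler can still use it.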
+ if do_classifier_free_guidance:
+ noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1)
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ _, variance_pred_text = variance_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+ noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1)
+
+ if not (
+ hasattr(self.scheduler.config, "variance_type")
+ and self.scheduler.config.variance_type in ["learned", "learned_range"]
+ ):
+ noise_pred, _ = noise_pred.split(latents.shape[1], dim=1)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(
+ noise_pred,
+ t,
+ latents,
+ generator=generator,
+ )[0]
+
+ # post-processing
+ image = self.movq.decode(latents, force_not_quantize=True)["sample"]
+
+ if output_type not in ["pt", "np", "pil"]:
+ raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")
+
+ if output_type in ["np", "pil"]:
+ image = image * 0.5 + 0.5
+ image = image.clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
diff --git a/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py b/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py
new file mode 100644
index 0000000000000000000000000000000000000000..27bd01984dbf2bfb6d1bf3d848dc998c87d3077a
--- /dev/null
+++ b/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py
@@ -0,0 +1,490 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from copy import deepcopy
+from typing import List, Optional, Union
+
+import numpy as np
+import PIL
+import torch
+import torch.nn.functional as F
+from PIL import Image
+
+from ...models import UNet2DConditionModel, VQModel
+from ...pipelines import DiffusionPipeline
+from ...pipelines.pipeline_utils import ImagePipelineOutput
+from ...schedulers import DDPMScheduler
+from ...utils import (
+ is_accelerate_available,
+ is_accelerate_version,
+ logging,
+ randn_tensor,
+ replace_example_docstring,
+)
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> from diffusers import KandinskyV22InpaintPipeline, KandinskyV22PriorPipeline
+ >>> from diffusers.utils import load_image
+ >>> import torch
+ >>> import numpy as np
+
+ >>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained(
+ ... "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16
+ ... )
+ >>> pipe_prior.to("cuda")
+
+ >>> prompt = "a hat"
+ >>> image_emb, zero_image_emb = pipe_prior(prompt, return_dict=False)
+
+ >>> pipe = KandinskyV22InpaintPipeline.from_pretrained(
+ ... "kandinsky-community/kandinsky-2-2-decoder-inpaint", torch_dtype=torch.float16
+ ... )
+ >>> pipe.to("cuda")
+
+ >>> init_image = load_image(
+ ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+ ... "/kandinsky/cat.png"
+ ... )
+
+ >>> mask = np.ones((768, 768), dtype=np.float32)
+ >>> mask[:250, 250:-250] = 0
+
+ >>> out = pipe(
+ ... image=init_image,
+ ... mask_image=mask,
+ ... image_embeds=image_emb,
+ ... negative_image_embeds=zero_image_emb,
+ ... height=768,
+ ... width=768,
+ ... num_inference_steps=50,
+ ... )
+
+ >>> image = out.images[0]
+ >>> image.save("cat_with_hat.png")
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.downscale_height_and_width
+def downscale_height_and_width(height, width, scale_factor=8):
+ new_height = height // scale_factor**2
+ if height % scale_factor**2 != 0:
+ new_height += 1
+ new_width = width // scale_factor**2
+ if width % scale_factor**2 != 0:
+ new_width += 1
+ return new_height * scale_factor, new_width * scale_factor
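+# Illustrative example (added note): with the MoVQ scale factor of 8 used by these
+# pipelines, downscale_height_and_width(768, 768) -> (96, 96) and
+# downscale_height_and_width(512, 512) -> (64, 64); sizes that are not a multiple of
+# scale_factor**2 are effectively rounded up to the next multiple before scaling.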
+
+
+# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_inpaint.prepare_mask
+def prepare_mask(masks):
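+ # Added note: each mask uses 1 = keep, 0 = repaint; the loop below also zeroes the
+ # 8-connected neighbours of every repaint pixel, growing the repainted region by one
+ # pixel so the seam between kept and generated areas is re-generated as well.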
+ prepared_masks = []
+ for mask in masks:
+ old_mask = deepcopy(mask)
+ for i in range(mask.shape[1]):
+ for j in range(mask.shape[2]):
+ if old_mask[0][i][j] == 1:
+ continue
+ if i != 0:
+ mask[:, i - 1, j] = 0
+ if j != 0:
+ mask[:, i, j - 1] = 0
+ if i != 0 and j != 0:
+ mask[:, i - 1, j - 1] = 0
+ if i != mask.shape[1] - 1:
+ mask[:, i + 1, j] = 0
+ if j != mask.shape[2] - 1:
+ mask[:, i, j + 1] = 0
+ if i != mask.shape[1] - 1 and j != mask.shape[2] - 1:
+ mask[:, i + 1, j + 1] = 0
+ prepared_masks.append(mask)
+ return torch.stack(prepared_masks, dim=0)
+
+
+# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_inpaint.prepare_mask_and_masked_image
+def prepare_mask_and_masked_image(image, mask, height, width):
+ r"""
+ Prepares a pair (mask, image) to be consumed by the Kandinsky inpaint pipeline. This means that those inputs will
+ be converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for
+ the ``image`` and ``1`` for the ``mask``.
+
+ The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
+ binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
+
+ Args:
+ image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
+ It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
+ ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
+ mask (Union[np.array, PIL.Image, torch.Tensor]): The mask to apply to the image, i.e. regions to inpaint.
+ It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
+ ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
+ height (`int`, *optional*, defaults to 512):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to 512):
+ The width in pixels of the generated image.
+
+
+ Raises:
+ ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
+ should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
+ TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
+ (or the other way around).
+
+ Returns:
+ tuple[torch.Tensor]: The pair (mask, image) as ``torch.Tensor`` with 4
+ dimensions: ``batch x channels x height x width``.
+ """
+
+ if image is None:
+ raise ValueError("`image` input cannot be undefined.")
+
+ if mask is None:
+ raise ValueError("`mask_image` input cannot be undefined.")
+
+ if isinstance(image, torch.Tensor):
+ if not isinstance(mask, torch.Tensor):
+ raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not")
+
+ # Batch single image
+ if image.ndim == 3:
+ assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
+ image = image.unsqueeze(0)
+
+ # Batch and add channel dim for single mask
+ if mask.ndim == 2:
+ mask = mask.unsqueeze(0).unsqueeze(0)
+
+ # Batch single mask or add channel dim
+ if mask.ndim == 3:
+ # Single batched mask, no channel dim or single mask not batched but channel dim
+ if mask.shape[0] == 1:
+ mask = mask.unsqueeze(0)
+
+ # Batched masks no channel dim
+ else:
+ mask = mask.unsqueeze(1)
+
+ assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
+ assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
+ assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
+
+ # Check image is in [-1, 1]
+ if image.min() < -1 or image.max() > 1:
+ raise ValueError("Image should be in [-1, 1] range")
+
+ # Check mask is in [0, 1]
+ if mask.min() < 0 or mask.max() > 1:
+ raise ValueError("Mask should be in [0, 1] range")
+
+ # Binarize mask
+ mask[mask < 0.5] = 0
+ mask[mask >= 0.5] = 1
+
+ # Image as float32
+ image = image.to(dtype=torch.float32)
+ elif isinstance(mask, torch.Tensor):
+ raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
+ else:
+ # preprocess image
+ if isinstance(image, (PIL.Image.Image, np.ndarray)):
+ image = [image]
+
+ if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
+ # resize all images w.r.t. the passed height and width
+ image = [i.resize((width, height), resample=Image.BICUBIC, reducing_gap=1) for i in image]
+ image = [np.array(i.convert("RGB"))[None, :] for i in image]
+ image = np.concatenate(image, axis=0)
+ elif isinstance(image, list) and isinstance(image[0], np.ndarray):
+ image = np.concatenate([i[None, :] for i in image], axis=0)
+
+ image = image.transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+
+ # preprocess mask
+ if isinstance(mask, (PIL.Image.Image, np.ndarray)):
+ mask = [mask]
+
+ if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
+ mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]
+ mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
+ mask = mask.astype(np.float32) / 255.0
+ elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
+ mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
+
+ mask[mask < 0.5] = 0
+ mask[mask >= 0.5] = 1
+ mask = torch.from_numpy(mask)
+
+ return mask, image
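+ # Illustrative shapes (added note): a 768x768 RGB PIL image together with a
+ # (768, 768) float mask comes back as an image tensor of shape (1, 3, 768, 768)
+ # in [-1, 1] and a mask tensor of shape (1, 1, 768, 768) binarized to {0, 1}.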
+
+
+class KandinskyV22InpaintPipeline(DiffusionPipeline):
+ """
+ Pipeline for text-guided image inpainting using Kandinsky 2.2
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ scheduler ([`DDPMScheduler`]):
+ A scheduler to be used in combination with `unet` to generate image latents.
+ unet ([`UNet2DConditionModel`]):
+ Conditional U-Net architecture to denoise the image embedding.
+ movq ([`VQModel`]):
+ MoVQ Decoder to generate the image from the latents.
+ """
+
+ def __init__(
+ self,
+ unet: UNet2DConditionModel,
+ scheduler: DDPMScheduler,
+ movq: VQModel,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ unet=unet,
+ scheduler=scheduler,
+ movq=movq,
+ )
+ self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)
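+ # e.g. a MoVQ config with 4 entries in block_out_channels yields 2 ** 3 == 8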
+
+ # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
+ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ if latents.shape != shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+ latents = latents.to(device)
+
+ latents = latents * scheduler.init_noise_sigma
+ return latents
+
+ # Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.KandinskyV22Pipeline.enable_model_cpu_offload
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ hook = None
+ for cpu_offloaded_model in [self.unet, self.movq]:
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
+ image: Union[torch.FloatTensor, PIL.Image.Image],
+ mask_image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
+ negative_image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
+ height: int = 512,
+ width: int = 512,
+ num_inference_steps: int = 100,
+ guidance_scale: float = 4.0,
+ num_images_per_prompt: int = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ ):
+ """
+ Args:
+ Function invoked when calling the pipeline for generation.
+ image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`):
+ The CLIP image embeddings for the text prompt, which will be used to condition the image generation.
+ image (`PIL.Image.Image` or `torch.FloatTensor`):
+ `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
+ be masked out with `mask_image` and repainted according to `image_embeds`.
+ mask_image (`PIL.Image.Image`, `np.ndarray` or `torch.FloatTensor`):
+ Tensor representing an image batch, to mask `image`. Black pixels in the mask will be repainted, while
+ white pixels will be preserved. If `mask_image` is a PIL image, it will be converted to a single
+ channel (luminance) before use. If it is a tensor, it should contain one color channel (L) instead of 3,
+ so the expected shape would be `(B, 1, H, W)`.
+ negative_image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`):
+ The CLIP image embeddings for the negative text prompt, which will be used to condition the image generation.
+ height (`int`, *optional*, defaults to 512):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to 512):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 100):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 4.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
+ (`np.array`) or `"pt"` (`torch.Tensor`).
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.ImagePipelineOutput`] or `tuple`
+ """
+ device = self._execution_device
+
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ if isinstance(image_embeds, list):
+ image_embeds = torch.cat(image_embeds, dim=0)
+ batch_size = image_embeds.shape[0] * num_images_per_prompt
+ if isinstance(negative_image_embeds, list):
+ negative_image_embeds = torch.cat(negative_image_embeds, dim=0)
+
+ if do_classifier_free_guidance:
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+
+ image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to(dtype=self.unet.dtype, device=device)
+
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps_tensor = self.scheduler.timesteps
+
+ # preprocess image and mask
+ mask_image, image = prepare_mask_and_masked_image(image, mask_image, height, width)
+
+ image = image.to(dtype=image_embeds.dtype, device=device)
+ image = self.movq.encode(image)["latents"]
+
+ mask_image = mask_image.to(dtype=image_embeds.dtype, device=device)
+
+ image_shape = tuple(image.shape[-2:])
+ mask_image = F.interpolate(
+ mask_image,
+ image_shape,
+ mode="nearest",
+ )
+ mask_image = prepare_mask(mask_image)
+ masked_image = image * mask_image
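+ # mask_image is now at latent resolution (1 = keep, 0 = repaint) and masked_image
+ # holds the encoded image latents with the repaint region zeroed out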
+
+ mask_image = mask_image.repeat_interleave(num_images_per_prompt, dim=0)
+ masked_image = masked_image.repeat_interleave(num_images_per_prompt, dim=0)
+ if do_classifier_free_guidance:
+ mask_image = mask_image.repeat(2, 1, 1, 1)
+ masked_image = masked_image.repeat(2, 1, 1, 1)
+
+ num_channels_latents = self.movq.config.latent_channels
+
+ height, width = downscale_height_and_width(height, width, self.movq_scale_factor)
+
+ # create initial latent
+ latents = self.prepare_latents(
+ (batch_size, num_channels_latents, height, width),
+ image_embeds.dtype,
+ device,
+ generator,
+ latents,
+ self.scheduler,
+ )
+ noise = torch.clone(latents)
+ for i, t in enumerate(self.progress_bar(timesteps_tensor)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = torch.cat([latent_model_input, masked_image, mask_image], dim=1)
+
+ added_cond_kwargs = {"image_embeds": image_embeds}
+ noise_pred = self.unet(
+ sample=latent_model_input,
+ timestep=t,
+ encoder_hidden_states=None,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ if do_classifier_free_guidance:
+ noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1)
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ _, variance_pred_text = variance_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+ noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1)
+
+ if not (
+ hasattr(self.scheduler.config, "variance_type")
+ and self.scheduler.config.variance_type in ["learned", "learned_range"]
+ ):
+ noise_pred, _ = noise_pred.split(latents.shape[1], dim=1)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(
+ noise_pred,
+ t,
+ latents,
+ generator=generator,
+ )[0]
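+
+ # keep the known region locked to the original image latents: they are re-noised to
+ # the next timestep and pasted back wherever init_mask == 1, so only the masked
+ # (init_mask == 0) area is actually generated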
+ init_latents_proper = image[:1]
+ init_mask = mask_image[:1]
+
+ if i < len(timesteps_tensor) - 1:
+ noise_timestep = timesteps_tensor[i + 1]
+ init_latents_proper = self.scheduler.add_noise(
+ init_latents_proper, noise, torch.tensor([noise_timestep])
+ )
+
+ latents = init_mask * init_latents_proper + (1 - init_mask) * latents
+ # post-processing
+ latents = mask_image[:1] * image[:1] + (1 - mask_image[:1]) * latents
+ image = self.movq.decode(latents, force_not_quantize=True)["sample"]
+
+ if output_type not in ["pt", "np", "pil"]:
+ raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")
+
+ if output_type in ["np", "pil"]:
+ image = image * 0.5 + 0.5
+ image = image.clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
diff --git a/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py b/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py
new file mode 100644
index 0000000000000000000000000000000000000000..1bfc6523cdf944a0cbc31e159eb0199032b53216
--- /dev/null
+++ b/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py
@@ -0,0 +1,501 @@
+from typing import List, Optional, Union
+
+import PIL
+import torch
+from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection
+
+from ...models import PriorTransformer
+from ...pipelines import DiffusionPipeline
+from ...schedulers import UnCLIPScheduler
+from ...utils import (
+ logging,
+ randn_tensor,
+ replace_example_docstring,
+)
+from ..kandinsky import KandinskyPriorPipelineOutput
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> from diffusers import KandinskyV22Pipeline, KandinskyV22PriorPipeline
+ >>> import torch
+
+ >>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained("kandinsky-community/kandinsky-2-2-prior")
+ >>> pipe_prior.to("cuda")
+ >>> prompt = "red cat, 4k photo"
+ >>> image_emb, negative_image_emb = pipe_prior(prompt).to_tuple()
+
+ >>> pipe = KandinskyV22Pipeline.from_pretrained("kandinsky-community/kandinsky-2-2-decoder")
+ >>> pipe.to("cuda")
+ >>> image = pipe(
+ ... image_embeds=image_emb,
+ ... negative_image_embeds=negative_image_emb,
+ ... height=768,
+ ... width=768,
+ ... num_inference_steps=50,
+ ... ).images
+ >>> image[0].save("cat.png")
+ ```
+"""
+
+EXAMPLE_INTERPOLATE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> from diffusers import KandinskyV22PriorPipeline, KandinskyV22Pipeline
+ >>> from diffusers.utils import load_image
+ >>> import PIL
+ >>> import torch
+ >>> from torchvision import transforms
+
+ >>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained(
+ ... "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16
+ ... )
+ >>> pipe_prior.to("cuda")
+ >>> img1 = load_image(
+ ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+ ... "/kandinsky/cat.png"
+ ... )
+ >>> img2 = load_image(
+ ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+ ... "/kandinsky/starry_night.jpeg"
+ ... )
+ >>> images_texts = ["a cat", img1, img2]
+ >>> weights = [0.3, 0.3, 0.4]
+ >>> out = pipe_prior.interpolate(images_texts, weights)
+ >>> pipe = KandinskyV22Pipeline.from_pretrained(
+ ... "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16
+ ... )
+ >>> pipe.to("cuda")
+ >>> image = pipe(
+ ... image_embeds=out.image_embeds,
+ ... negative_image_embeds=out.negative_image_embeds,
+ ... height=768,
+ ... width=768,
+ ... num_inference_steps=50,
+ ... ).images[0]
+ >>> image.save("starry_cat.png")
+ ```
+"""
+
+
+class KandinskyV22PriorPipeline(DiffusionPipeline):
+ """
+ Pipeline for generating the image prior for Kandinsky 2.2
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ prior ([`PriorTransformer`]):
+ The canonical unCLIP prior to approximate the image embedding from the text embedding.
+ image_encoder ([`CLIPVisionModelWithProjection`]):
+ Frozen image-encoder.
+ text_encoder ([`CLIPTextModelWithProjection`]):
+ Frozen text-encoder.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ scheduler ([`UnCLIPScheduler`]):
+ A scheduler to be used in combination with `prior` to generate image embedding.
+ image_processor ([`CLIPImageProcessor`]):
+ An image processor used to preprocess images for the CLIP image encoder.
+ """
+
+ _exclude_from_cpu_offload = ["prior"]
+
+ def __init__(
+ self,
+ prior: PriorTransformer,
+ image_encoder: CLIPVisionModelWithProjection,
+ text_encoder: CLIPTextModelWithProjection,
+ tokenizer: CLIPTokenizer,
+ scheduler: UnCLIPScheduler,
+ image_processor: CLIPImageProcessor,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ prior=prior,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ scheduler=scheduler,
+ image_encoder=image_encoder,
+ image_processor=image_processor,
+ )
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_INTERPOLATE_DOC_STRING)
+ def interpolate(
+ self,
+ images_and_prompts: List[Union[str, PIL.Image.Image, torch.FloatTensor]],
+ weights: List[float],
+ num_images_per_prompt: int = 1,
+ num_inference_steps: int = 25,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ negative_prior_prompt: Optional[str] = None,
+ negative_prompt: Union[str] = "",
+ guidance_scale: float = 4.0,
+ device=None,
+ ):
+ """
+ Function invoked when using the prior pipeline for interpolation.
+
+ Args:
+ images_and_prompts (`List[Union[str, PIL.Image.Image, torch.FloatTensor]]`):
+ list of prompts and images to guide the image generation.
+ weights: (`List[float]`):
+ list of weights for each condition in `images_and_prompts`
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ num_inference_steps (`int`, *optional*, defaults to 25):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ negative_prior_prompt (`str`, *optional*):
+ The prompt not to guide the prior diffusion process. Ignored when not using guidance (i.e., ignored if
+ `guidance_scale` is less than `1`).
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt not to guide the image generation. Ignored when not using guidance (i.e., ignored if
+ `guidance_scale` is less than `1`).
+ guidance_scale (`float`, *optional*, defaults to 4.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+
+ Examples:
+
+ Returns:
+ [`KandinskyPriorPipelineOutput`] or `tuple`
+ """
+
+ device = device or self.device
+
+ if len(images_and_prompts) != len(weights):
+ raise ValueError(
+ f"`images_and_prompts` contains {len(images_and_prompts)} items and `weights` contains {len(weights)} items - they should be lists of same length"
+ )
+
+ image_embeddings = []
+ for cond, weight in zip(images_and_prompts, weights):
+ if isinstance(cond, str):
+ image_emb = self(
+ cond,
+ num_inference_steps=num_inference_steps,
+ num_images_per_prompt=num_images_per_prompt,
+ generator=generator,
+ latents=latents,
+ negative_prompt=negative_prior_prompt,
+ guidance_scale=guidance_scale,
+ ).image_embeds.unsqueeze(0)
+
+ elif isinstance(cond, (PIL.Image.Image, torch.Tensor)):
+ if isinstance(cond, PIL.Image.Image):
+ cond = (
+ self.image_processor(cond, return_tensors="pt")
+ .pixel_values[0]
+ .unsqueeze(0)
+ .to(dtype=self.image_encoder.dtype, device=device)
+ )
+
+ image_emb = self.image_encoder(cond)["image_embeds"].repeat(num_images_per_prompt, 1).unsqueeze(0)
+
+ else:
+ raise ValueError(
+ f"`images_and_prompts` can only contains elements to be of type `str`, `PIL.Image.Image` or `torch.Tensor` but is {type(cond)}"
+ )
+
+ image_embeddings.append(image_emb * weight)
+
+ image_emb = torch.cat(image_embeddings).sum(dim=0)
+
+ out_zero = self(
+ negative_prompt,
+ num_inference_steps=num_inference_steps,
+ num_images_per_prompt=num_images_per_prompt,
+ generator=generator,
+ latents=latents,
+ negative_prompt=negative_prior_prompt,
+ guidance_scale=guidance_scale,
+ )
+ zero_image_emb = out_zero.negative_image_embeds if negative_prompt == "" else out_zero.image_embeds
+
+ return KandinskyPriorPipelineOutput(image_embeds=image_emb, negative_image_embeds=zero_image_emb)
+
+ # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
+ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ if latents.shape != shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+ latents = latents.to(device)
+
+ latents = latents * scheduler.init_noise_sigma
+ return latents
+
+ # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_prior.KandinskyPriorPipeline.get_zero_embed
+ def get_zero_embed(self, batch_size=1, device=None):
+ device = device or self.device
+ zero_img = torch.zeros(1, 3, self.image_encoder.config.image_size, self.image_encoder.config.image_size).to(
+ device=device, dtype=self.image_encoder.dtype
+ )
+ zero_image_emb = self.image_encoder(zero_img)["image_embeds"]
+ zero_image_emb = zero_image_emb.repeat(batch_size, 1)
+ return zero_image_emb
+
+ # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_prior.KandinskyPriorPipeline._encode_prompt
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ ):
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
+ # get prompt text embeddings
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ text_mask = text_inputs.attention_mask.bool().to(device)
+
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+ text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+
+ text_encoder_output = self.text_encoder(text_input_ids.to(device))
+
+ prompt_embeds = text_encoder_output.text_embeds
+ text_encoder_hidden_states = text_encoder_output.last_hidden_state
+
+ prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
+
+ if do_classifier_free_guidance:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ uncond_text_mask = uncond_input.attention_mask.bool().to(device)
+ negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))
+
+ negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds
+ uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+
+ seq_len = negative_prompt_embeds.shape[1]
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)
+
+ seq_len = uncond_text_encoder_hidden_states.shape[1]
+ uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
+ uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
+ batch_size * num_images_per_prompt, seq_len, -1
+ )
+ uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)
+
+ # done duplicates
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+ text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])
+
+ text_mask = torch.cat([uncond_text_mask, text_mask])
+
+ return prompt_embeds, text_encoder_hidden_states, text_mask
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: int = 1,
+ num_inference_steps: int = 25,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ guidance_scale: float = 4.0,
+ output_type: Optional[str] = "pt", # pt only
+ return_dict: bool = True,
+ ):
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ num_inference_steps (`int`, *optional*, defaults to 25):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ guidance_scale (`float`, *optional*, defaults to 4.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ output_type (`str`, *optional*, defaults to `"pt"`):
+ The output format of the generated image. Choose between: `"np"` (`np.array`) or `"pt"`
+ (`torch.Tensor`).
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`KandinskyPriorPipelineOutput`] instead of a plain tuple.
+
+ Examples:
+
+ Returns:
+ [`KandinskyPriorPipelineOutput`] or `tuple`
+ """
+
+ if isinstance(prompt, str):
+ prompt = [prompt]
+ elif not isinstance(prompt, list):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if isinstance(negative_prompt, str):
+ negative_prompt = [negative_prompt]
+ elif not isinstance(negative_prompt, list) and negative_prompt is not None:
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+ # if the negative prompt is defined we double the batch size to
+ # directly retrieve the negative prompt embedding
+ if negative_prompt is not None:
+ prompt = prompt + negative_prompt
+ negative_prompt = 2 * negative_prompt
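+ # the negative prompt list is duplicated as well so its length matches the doubled
+ # `prompt` list passed to `_encode_prompt` below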
+
+ device = self._execution_device
+
+ batch_size = len(prompt)
+ batch_size = batch_size * num_images_per_prompt
+
+ do_classifier_free_guidance = guidance_scale > 1.0
+ prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(
+ prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+ )
+
+ # prior
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ prior_timesteps_tensor = self.scheduler.timesteps
+
+ embedding_dim = self.prior.config.embedding_dim
+
+ latents = self.prepare_latents(
+ (batch_size, embedding_dim),
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ self.scheduler,
+ )
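+ # for the prior, these "latents" are CLIP image embeddings of shape
+ # (batch_size, embedding_dim) rather than spatial feature maps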
+
+ for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+ predicted_image_embedding = self.prior(
+ latent_model_input,
+ timestep=t,
+ proj_embedding=prompt_embeds,
+ encoder_hidden_states=text_encoder_hidden_states,
+ attention_mask=text_mask,
+ ).predicted_image_embedding
+
+ if do_classifier_free_guidance:
+ predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2)
+ predicted_image_embedding = predicted_image_embedding_uncond + guidance_scale * (
+ predicted_image_embedding_text - predicted_image_embedding_uncond
+ )
+
+ if i + 1 == prior_timesteps_tensor.shape[0]:
+ prev_timestep = None
+ else:
+ prev_timestep = prior_timesteps_tensor[i + 1]
+
+ latents = self.scheduler.step(
+ predicted_image_embedding,
+ timestep=t,
+ sample=latents,
+ generator=generator,
+ prev_timestep=prev_timestep,
+ ).prev_sample
+
+ latents = self.prior.post_process_latents(latents)
+
+ image_embeddings = latents
+
+ # if a negative prompt has been defined, we split the image embedding into its prompt and negative-prompt halves
+ if negative_prompt is None:
+ zero_embeds = self.get_zero_embed(latents.shape[0], device=latents.device)
+ else:
+ image_embeddings, zero_embeds = image_embeddings.chunk(2)
+
+ if output_type not in ["pt", "np"]:
+ raise ValueError(f"Only the output types `pt` and `np` are supported not output_type={output_type}")
+
+ if output_type == "np":
+ image_embeddings = image_embeddings.cpu().numpy()
+ zero_embeds = zero_embeds.cpu().numpy()
+
+ if not return_dict:
+ return (image_embeddings, zero_embeds)
+
+ return KandinskyPriorPipelineOutput(image_embeds=image_embeddings, negative_image_embeds=zero_embeds)
diff --git a/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py b/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py
new file mode 100644
index 0000000000000000000000000000000000000000..bcbeb78bac1020c29e3c1c22dac4dc455714b9ab
--- /dev/null
+++ b/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py
@@ -0,0 +1,565 @@
+from typing import List, Optional, Union
+
+import PIL
+import torch
+from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection
+
+from ...models import PriorTransformer
+from ...pipelines import DiffusionPipeline
+from ...schedulers import UnCLIPScheduler
+from ...utils import (
+ logging,
+ randn_tensor,
+ replace_example_docstring,
+)
+from ..kandinsky import KandinskyPriorPipelineOutput
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> from diffusers import KandinskyV22Pipeline, KandinskyV22PriorEmb2EmbPipeline
+ >>> from diffusers.utils import load_image
+ >>> import torch
+
+ >>> pipe_prior = KandinskyV22PriorEmb2EmbPipeline.from_pretrained(
+ ... "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16
+ ... )
+ >>> pipe_prior.to("cuda")
+
+ >>> prompt = "red cat, 4k photo"
+ >>> img = load_image(
+ ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+ ... "/kandinsky/cat.png"
+ ... )
+ >>> image_emb, negative_image_emb = pipe_prior(prompt, image=img, strength=0.2).to_tuple()
+
+ >>> pipe = KandinskyV22Pipeline.from_pretrained(
+ ... "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16
+ ... )
+ >>> pipe.to("cuda")
+
+ >>> image = pipe(
+ ... image_embeds=image_emb,
+ ... negative_image_embeds=negative_image_emb,
+ ... height=768,
+ ... width=768,
+ ... num_inference_steps=100,
+ ... ).images
+
+ >>> image[0].save("cat.png")
+ ```
+"""
+
+EXAMPLE_INTERPOLATE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> from diffusers import KandinskyV22PriorEmb2EmbPipeline, KandinskyV22Pipeline
+ >>> from diffusers.utils import load_image
+ >>> import PIL
+
+ >>> import torch
+ >>> from torchvision import transforms
+
+ >>> pipe_prior = KandinskyV22PriorEmb2EmbPipeline.from_pretrained(
+ ... "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16
+ ... )
+ >>> pipe_prior.to("cuda")
+
+ >>> img1 = load_image(
+ ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+ ... "/kandinsky/cat.png"
+ ... )
+
+ >>> img2 = load_image(
+ ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+ ... "/kandinsky/starry_night.jpeg"
+ ... )
+
+ >>> images_texts = ["a cat", img1, img2]
+ >>> weights = [0.3, 0.3, 0.4]
+ >>> image_emb, zero_image_emb = pipe_prior.interpolate(images_texts, weights)
+
+ >>> pipe = KandinskyV22Pipeline.from_pretrained(
+ ... "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16
+ ... )
+ >>> pipe.to("cuda")
+
+ >>> image = pipe(
+ ... image_embeds=image_emb,
+ ... negative_image_embeds=zero_image_emb,
+ ... height=768,
+ ... width=768,
+ ... num_inference_steps=150,
+ ... ).images[0]
+
+ >>> image.save("starry_cat.png")
+ ```
+"""
+
+
+class KandinskyV22PriorEmb2EmbPipeline(DiffusionPipeline):
+ """
+ Pipeline for generating the image prior for Kandinsky 2.2
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ prior ([`PriorTransformer`]):
+ The canonical unCLIP prior to approximate the image embedding from the text embedding.
+ image_encoder ([`CLIPVisionModelWithProjection`]):
+ Frozen image-encoder.
+ text_encoder ([`CLIPTextModelWithProjection`]):
+ Frozen text-encoder.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ scheduler ([`UnCLIPScheduler`]):
+ A scheduler to be used in combination with `prior` to generate image embedding.
+ image_processor ([`CLIPImageProcessor`]):
+ An image processor used to preprocess images for the CLIP image encoder.
+ """
+
+ _exclude_from_cpu_offload = ["prior"]
+
+ def __init__(
+ self,
+ prior: PriorTransformer,
+ image_encoder: CLIPVisionModelWithProjection,
+ text_encoder: CLIPTextModelWithProjection,
+ tokenizer: CLIPTokenizer,
+ scheduler: UnCLIPScheduler,
+ image_processor: CLIPImageProcessor,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ prior=prior,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ scheduler=scheduler,
+ image_encoder=image_encoder,
+ image_processor=image_processor,
+ )
+
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep, 0)
+ timesteps = self.scheduler.timesteps[t_start:]
+
+ return timesteps, num_inference_steps - t_start
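+ # Illustrative example (added note): with num_inference_steps=25 and strength=0.3,
+ # init_timestep = 7 and t_start = 18, so the first 18 scheduler timesteps are skipped
+ # and only the last 7 are run, mirroring img2img behaviour on the embedding.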
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_INTERPOLATE_DOC_STRING)
+ def interpolate(
+ self,
+ images_and_prompts: List[Union[str, PIL.Image.Image, torch.FloatTensor]],
+ weights: List[float],
+ num_images_per_prompt: int = 1,
+ num_inference_steps: int = 25,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ negative_prior_prompt: Optional[str] = None,
+ negative_prompt: Union[str] = "",
+ guidance_scale: float = 4.0,
+ device=None,
+ ):
+ """
+ Function invoked when using the prior pipeline for interpolation.
+
+ Args:
+ images_and_prompts (`List[Union[str, PIL.Image.Image, torch.FloatTensor]]`):
+ list of prompts and images to guide the image generation.
+ weights: (`List[float]`):
+ list of weights for each condition in `images_and_prompts`
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ num_inference_steps (`int`, *optional*, defaults to 25):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ negative_prior_prompt (`str`, *optional*):
+ The prompt not to guide the prior diffusion process. Ignored when not using guidance (i.e., ignored if
+ `guidance_scale` is less than `1`).
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt not to guide the image generation. Ignored when not using guidance (i.e., ignored if
+ `guidance_scale` is less than `1`).
+ guidance_scale (`float`, *optional*, defaults to 4.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+
+ Examples:
+
+ Returns:
+ [`KandinskyPriorPipelineOutput`] or `tuple`
+ """
+
+ device = device or self.device
+
+ if len(images_and_prompts) != len(weights):
+ raise ValueError(
+ f"`images_and_prompts` contains {len(images_and_prompts)} items and `weights` contains {len(weights)} items - they should be lists of same length"
+ )
+
+ image_embeddings = []
+ for cond, weight in zip(images_and_prompts, weights):
+ if isinstance(cond, str):
+ image_emb = self(
+ cond,
+ num_inference_steps=num_inference_steps,
+ num_images_per_prompt=num_images_per_prompt,
+ generator=generator,
+ latents=latents,
+ negative_prompt=negative_prior_prompt,
+ guidance_scale=guidance_scale,
+ ).image_embeds.unsqueeze(0)
+
+ elif isinstance(cond, (PIL.Image.Image, torch.Tensor)):
+ image_emb = self._encode_image(
+ cond, device=device, num_images_per_prompt=num_images_per_prompt
+ ).unsqueeze(0)
+
+ else:
+ raise ValueError(
+ f"`images_and_prompts` can only contains elements to be of type `str`, `PIL.Image.Image` or `torch.Tensor` but is {type(cond)}"
+ )
+
+ image_embeddings.append(image_emb * weight)
+
+ image_emb = torch.cat(image_embeddings).sum(dim=0)
+
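+ # added note: unlike KandinskyV22PriorPipeline.interpolate, the negative image
+ # embedding returned here is simply Gaussian noise (torch.randn_like) with the same
+ # shape as the interpolated embedding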
+ return KandinskyPriorPipelineOutput(image_embeds=image_emb, negative_image_embeds=torch.randn_like(image_emb))
+
+ def _encode_image(
+ self,
+ image: Union[torch.Tensor, List[PIL.Image.Image]],
+ device,
+ num_images_per_prompt,
+ ):
+ if not isinstance(image, torch.Tensor):
+ image = self.image_processor(image, return_tensors="pt").pixel_values.to(
+ dtype=self.image_encoder.dtype, device=device
+ )
+
+ image_emb = self.image_encoder(image)["image_embeds"] # B, D
+ image_emb = image_emb.repeat_interleave(num_images_per_prompt, dim=0)
+ image_emb = image_emb.to(device=device)
+
+ return image_emb
+
+ def prepare_latents(self, emb, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
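+ # added note: `emb` is a batch of CLIP image embeddings; it is duplicated to the
+ # requested batch size and then noised to the starting timestep selected via
+ # `strength`, so the prior denoises from a partially noised embedding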
+ emb = emb.to(device=device, dtype=dtype)
+
+ batch_size = batch_size * num_images_per_prompt
+
+ init_latents = emb
+
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
+ additional_image_per_prompt = batch_size // init_latents.shape[0]
+ init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ init_latents = torch.cat([init_latents], dim=0)
+
+ shape = init_latents.shape
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+ # get latents
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+ latents = init_latents
+
+ return latents
+
+ # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_prior.KandinskyPriorPipeline.get_zero_embed
+ def get_zero_embed(self, batch_size=1, device=None):
+ device = device or self.device
+ zero_img = torch.zeros(1, 3, self.image_encoder.config.image_size, self.image_encoder.config.image_size).to(
+ device=device, dtype=self.image_encoder.dtype
+ )
+ zero_image_emb = self.image_encoder(zero_img)["image_embeds"]
+ zero_image_emb = zero_image_emb.repeat(batch_size, 1)
+ return zero_image_emb
+
+ # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_prior.KandinskyPriorPipeline._encode_prompt
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ ):
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
+ # get prompt text embeddings
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ text_mask = text_inputs.attention_mask.bool().to(device)
+
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+ text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+
+ text_encoder_output = self.text_encoder(text_input_ids.to(device))
+
+ prompt_embeds = text_encoder_output.text_embeds
+ text_encoder_hidden_states = text_encoder_output.last_hidden_state
+
+ prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
+
+ if do_classifier_free_guidance:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ uncond_text_mask = uncond_input.attention_mask.bool().to(device)
+ negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))
+
+ negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds
+ uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+
+ seq_len = negative_prompt_embeds.shape[1]
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)
+
+ seq_len = uncond_text_encoder_hidden_states.shape[1]
+ uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
+ uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
+ batch_size * num_images_per_prompt, seq_len, -1
+ )
+ uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)
+
+ # done duplicates
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+ text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])
+
+ text_mask = torch.cat([uncond_text_mask, text_mask])
+
+ return prompt_embeds, text_encoder_hidden_states, text_mask
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ image: Union[torch.Tensor, List[torch.Tensor], PIL.Image.Image, List[PIL.Image.Image]],
+ strength: float = 0.3,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: int = 1,
+ num_inference_steps: int = 25,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ guidance_scale: float = 4.0,
+ output_type: Optional[str] = "pt", # "pt" or "np"
+ return_dict: bool = True,
+ ):
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ strength (`float`, *optional*, defaults to 0.3):
+ Conceptually, indicates how much to transform the reference `image` embedding. Must be between 0 and 1.
+ `image` will be used as a starting point, adding more noise to it the larger the `strength`. The number of
+ denoising steps depends on the amount of noise initially added.
+ image (`torch.Tensor`, `PIL.Image.Image`, or a list of either):
+ The image, or a precomputed 2-D image embedding, used to condition the prior.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ num_inference_steps (`int`, *optional*, defaults to 25):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ guidance_scale (`float`, *optional*, defaults to 4.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ output_type (`str`, *optional*, defaults to `"pt"`):
+ The output format of the generated image. Choose between: `"np"` (`np.array`) or `"pt"`
+ (`torch.Tensor`).
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
+
+ Examples:
+
+ Returns:
+ [`KandinskyPriorPipelineOutput`] or `tuple`
+ """
+
+ if isinstance(prompt, str):
+ prompt = [prompt]
+ elif not isinstance(prompt, list):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if isinstance(negative_prompt, str):
+ negative_prompt = [negative_prompt]
+ elif not isinstance(negative_prompt, list) and negative_prompt is not None:
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+ # if the negative prompt is defined we double the batch size to
+ # directly retrieve the negative prompt embedding
+ if negative_prompt is not None:
+ prompt = prompt + negative_prompt
+ negative_prompt = 2 * negative_prompt
+
+ device = self._execution_device
+
+ batch_size = len(prompt)
+ batch_size = batch_size * num_images_per_prompt
+
+ do_classifier_free_guidance = guidance_scale > 1.0
+ prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(
+ prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+ )
+
+ if not isinstance(image, List):
+ image = [image]
+
+ if isinstance(image[0], torch.Tensor):
+ image = torch.cat(image, dim=0)
+
+ if isinstance(image, torch.Tensor) and image.ndim == 2:
+ # allow user to pass image_embeds directly
+ image_embeds = image.repeat_interleave(num_images_per_prompt, dim=0)
+ elif isinstance(image, torch.Tensor) and image.ndim != 4:
+ raise ValueError(
+ f"If passing `image` as a PyTorch tensor or a list of tensors, please make sure each tensor has shape [batch_size, channels, height, width]; currently {image[0].unsqueeze(0).shape}"
+ )
+ else:
+ image_embeds = self._encode_image(image, device, num_images_per_prompt)
+
+ # prior
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+
+ latents = image_embeds
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+ latent_timestep = timesteps[:1].repeat(batch_size)
+ latents = self.prepare_latents(
+ latents,
+ latent_timestep,
+ batch_size // num_images_per_prompt,
+ num_images_per_prompt,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ )
+
+ for i, t in enumerate(self.progress_bar(timesteps)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+ predicted_image_embedding = self.prior(
+ latent_model_input,
+ timestep=t,
+ proj_embedding=prompt_embeds,
+ encoder_hidden_states=text_encoder_hidden_states,
+ attention_mask=text_mask,
+ ).predicted_image_embedding
+
+ if do_classifier_free_guidance:
+ predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2)
+ predicted_image_embedding = predicted_image_embedding_uncond + guidance_scale * (
+ predicted_image_embedding_text - predicted_image_embedding_uncond
+ )
+
+ if i + 1 == timesteps.shape[0]:
+ prev_timestep = None
+ else:
+ prev_timestep = timesteps[i + 1]
+
+ latents = self.scheduler.step(
+ predicted_image_embedding,
+ timestep=t,
+ sample=latents,
+ generator=generator,
+ prev_timestep=prev_timestep,
+ ).prev_sample
+
+ latents = self.prior.post_process_latents(latents)
+
+ image_embeddings = latents
+
+ # if a negative prompt has been defined, split the image embedding into its prompt and negative-prompt halves
+ if negative_prompt is None:
+ zero_embeds = self.get_zero_embed(latents.shape[0], device=latents.device)
+ else:
+ image_embeddings, zero_embeds = image_embeddings.chunk(2)
+
+ if output_type not in ["pt", "np"]:
+ raise ValueError(f"Only the output types `pt` and `np` are supported, not output_type={output_type}")
+
+ if output_type == "np":
+ image_embeddings = image_embeddings.cpu().numpy()
+ zero_embeds = zero_embeds.cpu().numpy()
+
+ if not return_dict:
+ return (image_embeddings, zero_embeds)
+
+ return KandinskyPriorPipelineOutput(image_embeds=image_embeddings, negative_image_embeds=zero_embeds)
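
The denoising loop above applies classifier-free guidance directly to the predicted image embedding rather than to UNet noise. The snippet below is a minimal, self-contained sketch of that combination step on dummy tensors; the batch and embedding sizes are assumptions chosen only for illustration.

```python
import torch

# Minimal sketch of the guidance step in the prior loop above.
# Sizes are illustrative assumptions; real embeddings come from self.prior(...).
guidance_scale = 4.0
batch_size, embed_dim = 2, 768

# Under classifier-free guidance the batch holds [unconditional, conditional] halves.
predicted_image_embedding = torch.randn(2 * batch_size, embed_dim)

uncond, text = predicted_image_embedding.chunk(2)
guided = uncond + guidance_scale * (text - uncond)
print(guided.shape)  # torch.Size([2, 768])
```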
diff --git a/diffusers/pipelines/latent_diffusion/__init__.py b/diffusers/pipelines/latent_diffusion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0cce9a89bcbeaac8468d75e9d16c9d3731f738c7
--- /dev/null
+++ b/diffusers/pipelines/latent_diffusion/__init__.py
@@ -0,0 +1,6 @@
+from ...utils import is_transformers_available
+from .pipeline_latent_diffusion_superresolution import LDMSuperResolutionPipeline
+
+
+if is_transformers_available():
+ from .pipeline_latent_diffusion import LDMBertModel, LDMTextToImagePipeline
diff --git a/diffusers/pipelines/latent_diffusion/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/latent_diffusion/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e25ca759f420cc63716e58b9b12a507f6073b84a
Binary files /dev/null and b/diffusers/pipelines/latent_diffusion/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/pipelines/latent_diffusion/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/latent_diffusion/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8cf103f1700b6b5f5aabf86a2f04fc7a45c61432
Binary files /dev/null and b/diffusers/pipelines/latent_diffusion/__pycache__/__init__.cpython-38.pyc differ
diff --git a/diffusers/pipelines/latent_diffusion/__pycache__/pipeline_latent_diffusion.cpython-310.pyc b/diffusers/pipelines/latent_diffusion/__pycache__/pipeline_latent_diffusion.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1ebabe73d24df451cd1d583965213afde489396f
Binary files /dev/null and b/diffusers/pipelines/latent_diffusion/__pycache__/pipeline_latent_diffusion.cpython-310.pyc differ
diff --git a/diffusers/pipelines/latent_diffusion/__pycache__/pipeline_latent_diffusion.cpython-38.pyc b/diffusers/pipelines/latent_diffusion/__pycache__/pipeline_latent_diffusion.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..df45fb238fc00cfa5cc10dc9a5e168b1cc461e1e
Binary files /dev/null and b/diffusers/pipelines/latent_diffusion/__pycache__/pipeline_latent_diffusion.cpython-38.pyc differ
diff --git a/diffusers/pipelines/latent_diffusion/__pycache__/pipeline_latent_diffusion_superresolution.cpython-310.pyc b/diffusers/pipelines/latent_diffusion/__pycache__/pipeline_latent_diffusion_superresolution.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..257cd82a2c07c2617929ba61c9be2e1469200506
Binary files /dev/null and b/diffusers/pipelines/latent_diffusion/__pycache__/pipeline_latent_diffusion_superresolution.cpython-310.pyc differ
diff --git a/diffusers/pipelines/latent_diffusion/__pycache__/pipeline_latent_diffusion_superresolution.cpython-38.pyc b/diffusers/pipelines/latent_diffusion/__pycache__/pipeline_latent_diffusion_superresolution.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7a5d2591784cba978dab74cac7a11b867dfef006
Binary files /dev/null and b/diffusers/pipelines/latent_diffusion/__pycache__/pipeline_latent_diffusion_superresolution.cpython-38.pyc differ
diff --git a/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab7c28b96cc85724d925928ba5b065a9ba2bf095
--- /dev/null
+++ b/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
@@ -0,0 +1,726 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer
+from transformers.activations import ACT2FN
+from transformers.modeling_outputs import BaseModelOutput
+from transformers.utils import logging
+
+from ...models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
+from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from ...utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+
+
+class LDMTextToImagePipeline(DiffusionPipeline):
+ r"""
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Parameters:
+ vqvae ([`VQModel`]):
+ Vector-quantized (VQ) Model to encode and decode images to and from latent representations.
+ bert ([`LDMBertModel`]):
+ Text-encoder model based on [BERT](https://huggingface.co/docs/transformers/model_doc/bert) architecture.
+ tokenizer (`transformers.BertTokenizer`):
+ Tokenizer of class
+ [BertTokenizer](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ """
+
+ def __init__(
+ self,
+ vqvae: Union[VQModel, AutoencoderKL],
+ bert: PreTrainedModel,
+ tokenizer: PreTrainedTokenizer,
+ unet: Union[UNet2DModel, UNet2DConditionModel],
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ ):
+ super().__init__()
+ self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
+ self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 1.0,
+ eta: Optional[float] = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ **kwargs,
+ ) -> Union[Tuple, ImagePipelineOutput]:
+ r"""
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 1.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text
+ `prompt`, usually at the expense of lower image quality.
+ generator (`torch.Generator`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is
+ True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ if isinstance(prompt, str):
+ batch_size = 1
+ elif isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ # get unconditional embeddings for classifier free guidance
+ if guidance_scale != 1.0:
+ uncond_input = self.tokenizer(
+ [""] * batch_size, padding="max_length", max_length=77, truncation=True, return_tensors="pt"
+ )
+ negative_prompt_embeds = self.bert(uncond_input.input_ids.to(self._execution_device))[0]
+
+ # get prompt text embeddings
+ text_input = self.tokenizer(prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt")
+ prompt_embeds = self.bert(text_input.input_ids.to(self._execution_device))[0]
+
+ # get the initial random noise unless the user supplied it
+ latents_shape = (batch_size, self.unet.config.in_channels, height // 8, width // 8)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(
+ latents_shape, generator=generator, device=self._execution_device, dtype=prompt_embeds.dtype
+ )
+ else:
+ if latents.shape != latents_shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
+ latents = latents.to(self._execution_device)
+
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+
+ extra_kwargs = {}
+ if accepts_eta:
+ extra_kwargs["eta"] = eta
+
+ for t in self.progress_bar(self.scheduler.timesteps):
+ if guidance_scale == 1.0:
+ # guidance_scale of 1 means no guidance
+ latents_input = latents
+ context = prompt_embeds
+ else:
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ latents_input = torch.cat([latents] * 2)
+ context = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ # predict the noise residual
+ noise_pred = self.unet(latents_input, t, encoder_hidden_states=context).sample
+ # perform guidance
+ if guidance_scale != 1.0:
+ noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample
+
+ # scale and decode the image latents with vae
+ latents = 1 / self.vqvae.config.scaling_factor * latents
+ image = self.vqvae.decode(latents).sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
+
+
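
As a rough usage sketch for the pipeline above: the checkpoint name `CompVis/ldm-text2im-large-256`, the device selection, and the top-level `DiffusionPipeline` import path are assumptions (they mirror how upstream diffusers documents `LDMTextToImagePipeline`), not requirements of this file.

```python
import torch
from diffusers import DiffusionPipeline

# Sketch only; the checkpoint name and device are assumptions.
pipe = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")
pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")

# guidance_scale > 1.0 enables the classifier-free guidance branch above.
image = pipe(
    "a painting of a squirrel eating a burger",
    num_inference_steps=50,
    guidance_scale=6.0,
).images[0]
image.save("ldm_text2im.png")
```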
+################################################################################
+# Code for the text transformer model
+################################################################################
+""" PyTorch LDMBERT model."""
+
+
+logger = logging.get_logger(__name__)
+
+LDMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "ldm-bert",
+ # See all LDMBert models at https://huggingface.co/models?filter=ldmbert
+]
+
+
+LDMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+ "ldm-bert": "https://huggingface.co/valhalla/ldm-bert/blob/main/config.json",
+}
+
+
+""" LDMBERT model configuration"""
+
+
+class LDMBertConfig(PretrainedConfig):
+ model_type = "ldmbert"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
+
+ def __init__(
+ self,
+ vocab_size=30522,
+ max_position_embeddings=77,
+ encoder_layers=32,
+ encoder_ffn_dim=5120,
+ encoder_attention_heads=8,
+ head_dim=64,
+ encoder_layerdrop=0.0,
+ activation_function="gelu",
+ d_model=1280,
+ dropout=0.1,
+ attention_dropout=0.0,
+ activation_dropout=0.0,
+ init_std=0.02,
+ classifier_dropout=0.0,
+ scale_embedding=False,
+ use_cache=True,
+ pad_token_id=0,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.d_model = d_model
+ self.encoder_ffn_dim = encoder_ffn_dim
+ self.encoder_layers = encoder_layers
+ self.encoder_attention_heads = encoder_attention_heads
+ self.head_dim = head_dim
+ self.dropout = dropout
+ self.attention_dropout = attention_dropout
+ self.activation_dropout = activation_dropout
+ self.activation_function = activation_function
+ self.init_std = init_std
+ self.encoder_layerdrop = encoder_layerdrop
+ self.classifier_dropout = classifier_dropout
+ self.use_cache = use_cache
+ self.num_hidden_layers = encoder_layers
+ self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
+
+ super().__init__(pad_token_id=pad_token_id, **kwargs)
+
+
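
Because `attribute_map` aliases `hidden_size` and `num_attention_heads` onto `d_model` and `encoder_attention_heads`, both spellings resolve to the same values. A small sketch, assuming the vendored package is importable under the usual `diffusers` path:

```python
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig

# Defaults above: d_model=1280, encoder_attention_heads=8, max_position_embeddings=77.
config = LDMBertConfig()
print(config.d_model, config.hidden_size)                          # 1280 1280
print(config.encoder_attention_heads, config.num_attention_heads)  # 8 8
print(config.max_position_embeddings)                              # 77
```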
+def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+ """
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+ """
+ bsz, src_len = mask.size()
+ tgt_len = tgt_len if tgt_len is not None else src_len
+
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+
+ inverted_mask = 1.0 - expanded_mask
+
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
+
+
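
`_expand_mask` turns a `[bsz, seq_len]` padding mask into the additive `[bsz, 1, tgt_len, src_len]` bias that the attention layers expect, with padded positions set to the dtype minimum. A small numeric sketch (import path assumed to follow the file layout in this diff):

```python
import torch
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import _expand_mask

# One sequence of length 4 whose last token is padding.
mask = torch.tensor([[1, 1, 1, 0]])
bias = _expand_mask(mask, torch.float32)

print(bias.shape)     # torch.Size([1, 1, 4, 4])
print(bias[0, 0, 0])  # tensor([0., 0., 0., -3.4028e+38]) -> the padded column is masked out
```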
+# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->LDMBert
+class LDMBertAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ head_dim: int,
+ dropout: float = 0.0,
+ is_decoder: bool = False,
+ bias: bool = False,
+ ):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.head_dim = head_dim
+ self.inner_dim = head_dim * num_heads
+
+ self.scaling = self.head_dim**-0.5
+ self.is_decoder = is_decoder
+
+ self.k_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias)
+ self.v_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias)
+ self.q_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias)
+ self.out_proj = nn.Linear(self.inner_dim, embed_dim)
+
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ key_value_states: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+
+ bsz, tgt_len, _ = hidden_states.size()
+
+ # get query proj
+ query_states = self.q_proj(hidden_states) * self.scaling
+ # get key, value proj
+ if is_cross_attention and past_key_value is not None:
+ # reuse k,v, cross_attentions
+ key_states = past_key_value[0]
+ value_states = past_key_value[1]
+ elif is_cross_attention:
+ # cross_attentions
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+ elif past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+ else:
+ # self_attention
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+ if self.is_decoder:
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_states, value_states)
+
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+ key_states = key_states.view(*proj_shape)
+ value_states = value_states.view(*proj_shape)
+
+ src_len = key_states.size(1)
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if layer_head_mask is not None:
+ if layer_head_mask.size() != (self.num_heads,):
+ raise ValueError(
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+ f" {layer_head_mask.size()}"
+ )
+ attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ if output_attentions:
+ # this operation is a bit awkward, but it's required to
+ # make sure that attn_weights keeps its gradient.
+ # In order to do so, attn_weights have to be reshaped
+ # twice and have to be reused in the following
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+ else:
+ attn_weights_reshaped = None
+
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+ attn_output = torch.bmm(attn_probs, value_states)
+
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+ attn_output = attn_output.transpose(1, 2)
+
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
+ # partitioned across GPUs when using tensor-parallelism.
+ attn_output = attn_output.reshape(bsz, tgt_len, self.inner_dim)
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights_reshaped, past_key_value
+
+
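
Unlike standard BERT, `LDMBertAttention` projects into `inner_dim = num_heads * head_dim` (512 with the defaults) rather than back into `embed_dim` (1280), which is why `out_proj` maps `inner_dim` to `embed_dim`. A shape-only sketch of the batched-matmul path, with sizes taken from the config defaults:

```python
import torch

# Shape-only sketch; sizes follow the config defaults (d_model=1280, 8 heads, head_dim=64).
bsz, seq_len, num_heads, head_dim = 2, 77, 8, 64
inner_dim = num_heads * head_dim  # 512, not equal to embed_dim=1280

q = torch.randn(bsz * num_heads, seq_len, head_dim)
k = torch.randn(bsz * num_heads, seq_len, head_dim)
v = torch.randn(bsz * num_heads, seq_len, head_dim)

attn_weights = torch.bmm(q, k.transpose(1, 2)).softmax(dim=-1)  # (bsz*heads, seq, seq)
attn_output = torch.bmm(attn_weights, v)                        # (bsz*heads, seq, head_dim)
attn_output = attn_output.view(bsz, num_heads, seq_len, head_dim).transpose(1, 2)
attn_output = attn_output.reshape(bsz, seq_len, inner_dim)      # out_proj then maps 512 -> 1280
print(attn_output.shape)  # torch.Size([2, 77, 512])
```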
+class LDMBertEncoderLayer(nn.Module):
+ def __init__(self, config: LDMBertConfig):
+ super().__init__()
+ self.embed_dim = config.d_model
+ self.self_attn = LDMBertAttention(
+ embed_dim=self.embed_dim,
+ num_heads=config.encoder_attention_heads,
+ head_dim=config.head_dim,
+ dropout=config.attention_dropout,
+ )
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+ self.dropout = config.dropout
+ self.activation_fn = ACT2FN[config.activation_function]
+ self.activation_dropout = config.activation_dropout
+ self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
+ self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ attention_mask: torch.FloatTensor,
+ layer_head_mask: torch.FloatTensor,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
+ attention_mask (`torch.FloatTensor`): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
+ `(encoder_attention_heads,)`.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ """
+ residual = hidden_states
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+ hidden_states, attn_weights, _ = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ layer_head_mask=layer_head_mask,
+ output_attentions=output_attentions,
+ )
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.final_layer_norm(hidden_states)
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+
+ if hidden_states.dtype == torch.float16 and (
+ torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
+ ):
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs
+
+
+# Copied from transformers.models.bart.modeling_bart.BartPretrainedModel with Bart->LDMBert
+class LDMBertPreTrainedModel(PreTrainedModel):
+ config_class = LDMBertConfig
+ base_model_prefix = "model"
+ _supports_gradient_checkpointing = True
+ _keys_to_ignore_on_load_unexpected = [r"encoder\.version", r"decoder\.version"]
+
+ def _init_weights(self, module):
+ std = self.config.init_std
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, (LDMBertEncoder,)):
+ module.gradient_checkpointing = value
+
+ @property
+ def dummy_inputs(self):
+ pad_token = self.config.pad_token_id
+ input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
+ dummy_inputs = {
+ "attention_mask": input_ids.ne(pad_token),
+ "input_ids": input_ids,
+ }
+ return dummy_inputs
+
+
+class LDMBertEncoder(LDMBertPreTrainedModel):
+ """
+ Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
+ [`LDMBertEncoderLayer`].
+
+ Args:
+ config: LDMBertConfig
+ embed_tokens (nn.Embedding): output embedding
+ """
+
+ def __init__(self, config: LDMBertConfig):
+ super().__init__(config)
+
+ self.dropout = config.dropout
+
+ embed_dim = config.d_model
+ self.padding_idx = config.pad_token_id
+ self.max_source_positions = config.max_position_embeddings
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim)
+ self.embed_positions = nn.Embedding(config.max_position_embeddings, embed_dim)
+ self.layers = nn.ModuleList([LDMBertEncoderLayer(config) for _ in range(config.encoder_layers)])
+ self.layer_norm = nn.LayerNorm(embed_dim)
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutput]:
+ r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+ provide it.
+
+ Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.BaseModelOutput`] instead of a plain tuple.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ seq_len = input_shape[1]
+ if position_ids is None:
+ position_ids = torch.arange(seq_len, dtype=torch.long, device=inputs_embeds.device).expand((1, -1))
+ embed_pos = self.embed_positions(position_ids)
+
+ hidden_states = inputs_embeds + embed_pos
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+ # expand attention_mask
+ if attention_mask is not None:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)
+
+ encoder_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ # check if head_mask has a correct number of layers specified if desired
+ if head_mask is not None:
+ if head_mask.size()[0] != (len(self.layers)):
+ raise ValueError(
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}."
+ )
+
+ for idx, encoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs, output_attentions)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(encoder_layer),
+ hidden_states,
+ attention_mask,
+ (head_mask[idx] if head_mask is not None else None),
+ )
+ else:
+ layer_outputs = encoder_layer(
+ hidden_states,
+ attention_mask,
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1],)
+
+ hidden_states = self.layer_norm(hidden_states)
+
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+ )
+
+
+class LDMBertModel(LDMBertPreTrainedModel):
+ _no_split_modules = []
+
+ def __init__(self, config: LDMBertConfig):
+ super().__init__(config)
+ self.model = LDMBertEncoder(config)
+ self.to_logits = nn.Linear(config.hidden_size, config.vocab_size)
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ return outputs
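
A quick end-to-end sketch of `LDMBertModel`: the deliberately tiny config is an assumption so the forward pass runs instantly, whereas real LDM-BERT checkpoints use the defaults above (32 layers, d_model=1280).

```python
import torch
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import (
    LDMBertConfig,
    LDMBertModel,
)

# Deliberately tiny config (an assumption for the sketch); real checkpoints use the defaults.
config = LDMBertConfig(
    encoder_layers=2, d_model=64, encoder_ffn_dim=128, encoder_attention_heads=4, head_dim=16
)
model = LDMBertModel(config).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 77))
with torch.no_grad():
    last_hidden_state = model(input_ids)[0]
print(last_hidden_state.shape)  # torch.Size([1, 77, 64])
```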
diff --git a/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py b/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae620d325307605fa08fa977b9865dfc9adff057
--- /dev/null
+++ b/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py
@@ -0,0 +1,159 @@
+import inspect
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import PIL
+import torch
+import torch.utils.checkpoint
+
+from ...models import UNet2DModel, VQModel
+from ...schedulers import (
+ DDIMScheduler,
+ DPMSolverMultistepScheduler,
+ EulerAncestralDiscreteScheduler,
+ EulerDiscreteScheduler,
+ LMSDiscreteScheduler,
+ PNDMScheduler,
+)
+from ...utils import PIL_INTERPOLATION, randn_tensor
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+
+
+def preprocess(image):
+ w, h = image.size
+ w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32
+ image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image[None].transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image)
+ return 2.0 * image - 1.0
+
+
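
`preprocess` resizes the input down to the nearest multiple of 32 and rescales pixel values from `[0, 255]` to `[-1, 1]`, returning an NCHW tensor. A quick sketch with a synthetic image (import path assumed to follow the file layout in this diff):

```python
import numpy as np
import PIL.Image
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion_superresolution import preprocess

# A synthetic 130x70 RGB image: 130 -> 128 and 70 -> 64 (nearest multiples of 32).
img = PIL.Image.fromarray(np.random.randint(0, 256, (70, 130, 3), dtype=np.uint8))
tensor = preprocess(img)

print(tensor.shape)                              # torch.Size([1, 3, 64, 128])
print(float(tensor.min()), float(tensor.max()))  # values lie in [-1.0, 1.0]
```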
+class LDMSuperResolutionPipeline(DiffusionPipeline):
+ r"""
+ A pipeline for image super-resolution using Latent Diffusion.
+
+ This class inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Parameters:
+ vqvae ([`VQModel`]):
+ Vector-quantized (VQ) VAE Model to encode and decode images to and from latent representations.
+ unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`],
+ [`EulerAncestralDiscreteScheduler`], [`DPMSolverMultistepScheduler`], or [`PNDMScheduler`].
+ """
+
+ def __init__(
+ self,
+ vqvae: VQModel,
+ unet: UNet2DModel,
+ scheduler: Union[
+ DDIMScheduler,
+ PNDMScheduler,
+ LMSDiscreteScheduler,
+ EulerDiscreteScheduler,
+ EulerAncestralDiscreteScheduler,
+ DPMSolverMultistepScheduler,
+ ],
+ ):
+ super().__init__()
+ self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ image: Union[torch.Tensor, PIL.Image.Image] = None,
+ batch_size: Optional[int] = 1,
+ num_inference_steps: Optional[int] = 100,
+ eta: Optional[float] = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ ) -> Union[Tuple, ImagePipelineOutput]:
+ r"""
+ Args:
+ image (`torch.Tensor` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
+ process.
+ batch_size (`int`, *optional*, defaults to 1):
+ Number of images to generate.
+ num_inference_steps (`int`, *optional*, defaults to 100):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is
+ True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+ if isinstance(image, PIL.Image.Image):
+ batch_size = 1
+ elif isinstance(image, torch.Tensor):
+ batch_size = image.shape[0]
+ else:
+ raise ValueError(f"`image` has to be of type `PIL.Image.Image` or `torch.Tensor` but is {type(image)}")
+
+ if isinstance(image, PIL.Image.Image):
+ image = preprocess(image)
+
+ height, width = image.shape[-2:]
+
+ # in_channels should be 6: 3 for latents, 3 for low resolution image
+ latents_shape = (batch_size, self.unet.config.in_channels // 2, height, width)
+ latents_dtype = next(self.unet.parameters()).dtype
+
+ latents = randn_tensor(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
+
+ image = image.to(device=self.device, dtype=latents_dtype)
+
+ # set timesteps and move to the correct device
+ self.scheduler.set_timesteps(num_inference_steps, device=self.device)
+ timesteps_tensor = self.scheduler.timesteps
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature.
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_kwargs = {}
+ if accepts_eta:
+ extra_kwargs["eta"] = eta
+
+ for t in self.progress_bar(timesteps_tensor):
+ # concat latents and low resolution image in the channel dimension.
+ latents_input = torch.cat([latents, image], dim=1)
+ latents_input = self.scheduler.scale_model_input(latents_input, t)
+ # predict the noise residual
+ noise_pred = self.unet(latents_input, t).sample
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample
+
+ # decode the image latents with the VQVAE
+ image = self.vqvae.decode(latents).sample
+ image = torch.clamp(image, -1.0, 1.0)
+ image = image / 2 + 0.5
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
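
Usage sketch for the super-resolution pipeline above. The `CompVis/ldm-super-resolution-4x-openimages` checkpoint, the local input path, the 128x128 input size, and the top-level import path are assumptions mirroring how upstream diffusers documents `LDMSuperResolutionPipeline`.

```python
import torch
from PIL import Image
from diffusers import LDMSuperResolutionPipeline

# Sketch only; checkpoint name and input image are assumptions.
pipe = LDMSuperResolutionPipeline.from_pretrained("CompVis/ldm-super-resolution-4x-openimages")
pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")

low_res = Image.open("low_res.png").convert("RGB").resize((128, 128))
upscaled = pipe(low_res, num_inference_steps=100, eta=1.0).images[0]
upscaled.save("upscaled.png")
```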
diff --git a/diffusers/pipelines/latent_diffusion_uncond/__init__.py b/diffusers/pipelines/latent_diffusion_uncond/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b9fc5270a62bbb18d1393263101d4b9f73b7511
--- /dev/null
+++ b/diffusers/pipelines/latent_diffusion_uncond/__init__.py
@@ -0,0 +1 @@
+from .pipeline_latent_diffusion_uncond import LDMPipeline
diff --git a/diffusers/pipelines/latent_diffusion_uncond/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/latent_diffusion_uncond/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1217a15beaedd74c3fa3d6918e7e082a15359791
Binary files /dev/null and b/diffusers/pipelines/latent_diffusion_uncond/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/pipelines/latent_diffusion_uncond/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/latent_diffusion_uncond/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f3a2aa4da26c8cb79661cd3a313a5e4fdec44a88
Binary files /dev/null and b/diffusers/pipelines/latent_diffusion_uncond/__pycache__/__init__.cpython-38.pyc differ
diff --git a/diffusers/pipelines/latent_diffusion_uncond/__pycache__/pipeline_latent_diffusion_uncond.cpython-310.pyc b/diffusers/pipelines/latent_diffusion_uncond/__pycache__/pipeline_latent_diffusion_uncond.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..865be1756ab40c346bfaba993d4253ca0e8dee4a
Binary files /dev/null and b/diffusers/pipelines/latent_diffusion_uncond/__pycache__/pipeline_latent_diffusion_uncond.cpython-310.pyc differ
diff --git a/diffusers/pipelines/latent_diffusion_uncond/__pycache__/pipeline_latent_diffusion_uncond.cpython-38.pyc b/diffusers/pipelines/latent_diffusion_uncond/__pycache__/pipeline_latent_diffusion_uncond.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8f53e4259755891e426adcc0de9eabea7ce713c5
Binary files /dev/null and b/diffusers/pipelines/latent_diffusion_uncond/__pycache__/pipeline_latent_diffusion_uncond.cpython-38.pyc differ
diff --git a/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py b/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
new file mode 100644
index 0000000000000000000000000000000000000000..73c607a27187eb93a55570a825a4beee329a256c
--- /dev/null
+++ b/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
@@ -0,0 +1,111 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import List, Optional, Tuple, Union
+
+import torch
+
+from ...models import UNet2DModel, VQModel
+from ...schedulers import DDIMScheduler
+from ...utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+
+
+class LDMPipeline(DiffusionPipeline):
+ r"""
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Parameters:
+ vqvae ([`VQModel`]):
+ Vector-quantized (VQ) Model to encode and decode images to and from latent representations.
+ unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ [`DDIMScheduler`] is to be used in combination with `unet` to denoise the encoded image latents.
+ """
+
+ def __init__(self, vqvae: VQModel, unet: UNet2DModel, scheduler: DDIMScheduler):
+ super().__init__()
+ self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ batch_size: int = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ eta: float = 0.0,
+ num_inference_steps: int = 50,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ **kwargs,
+ ) -> Union[Tuple, ImagePipelineOutput]:
+ r"""
+ Args:
+ batch_size (`int`, *optional*, defaults to 1):
+ Number of images to generate.
+ generator (`torch.Generator`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is
+ True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+
+ latents = randn_tensor(
+ (batch_size, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size),
+ generator=generator,
+ )
+ latents = latents.to(self.device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+
+ extra_kwargs = {}
+ if accepts_eta:
+ extra_kwargs["eta"] = eta
+
+ for t in self.progress_bar(self.scheduler.timesteps):
+ latent_model_input = self.scheduler.scale_model_input(latents, t)
+ # predict the noise residual
+ noise_prediction = self.unet(latent_model_input, t).sample
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_prediction, t, latents, **extra_kwargs).prev_sample
+
+ # decode the image latents with the VAE
+ image = self.vqvae.decode(latents).sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
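
Usage sketch for the unconditional pipeline above; the `CompVis/ldm-celebahq-256` checkpoint is an assumption (it is the checkpoint upstream diffusers documents for `LDMPipeline`).

```python
from diffusers import DiffusionPipeline

# Sketch only; the checkpoint name is an assumption.
pipe = DiffusionPipeline.from_pretrained("CompVis/ldm-celebahq-256")

# eta is forwarded to the DDIM scheduler step; output is PIL by default.
images = pipe(batch_size=2, num_inference_steps=50, eta=0.3).images
for i, img in enumerate(images):
    img.save(f"ldm_uncond_{i}.png")
```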
diff --git a/diffusers/pipelines/onnx_utils.py b/diffusers/pipelines/onnx_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..07c32e4e84bfee0241733a077fef9c0dec06905e
--- /dev/null
+++ b/diffusers/pipelines/onnx_utils.py
@@ -0,0 +1,212 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import shutil
+from pathlib import Path
+from typing import Optional, Union
+
+import numpy as np
+from huggingface_hub import hf_hub_download
+
+from ..utils import ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME, is_onnx_available, logging
+
+
+if is_onnx_available():
+ import onnxruntime as ort
+
+
+logger = logging.get_logger(__name__)
+
+ORT_TO_NP_TYPE = {
+ "tensor(bool)": np.bool_,
+ "tensor(int8)": np.int8,
+ "tensor(uint8)": np.uint8,
+ "tensor(int16)": np.int16,
+ "tensor(uint16)": np.uint16,
+ "tensor(int32)": np.int32,
+ "tensor(uint32)": np.uint32,
+ "tensor(int64)": np.int64,
+ "tensor(uint64)": np.uint64,
+ "tensor(float16)": np.float16,
+ "tensor(float)": np.float32,
+ "tensor(double)": np.float64,
+}
+
+
+class OnnxRuntimeModel:
+ def __init__(self, model=None, **kwargs):
+ logger.info("`diffusers.OnnxRuntimeModel` is experimental and might change in the future.")
+ self.model = model
+ self.model_save_dir = kwargs.get("model_save_dir", None)
+ self.latest_model_name = kwargs.get("latest_model_name", ONNX_WEIGHTS_NAME)
+
+ def __call__(self, **kwargs):
+ inputs = {k: np.array(v) for k, v in kwargs.items()}
+ return self.model.run(None, inputs)
+
+ @staticmethod
+ def load_model(path: Union[str, Path], provider=None, sess_options=None):
+ """
+ Loads an ONNX Inference session with an ExecutionProvider. Default provider is `CPUExecutionProvider`
+
+ Arguments:
+ path (`str` or `Path`):
+ Directory from which to load
+ provider(`str`, *optional*):
+ Onnxruntime execution provider to use for loading the model, defaults to `CPUExecutionProvider`
+ """
+ if provider is None:
+ logger.info("No onnxruntime provider specified, using CPUExecutionProvider")
+ provider = "CPUExecutionProvider"
+
+ return ort.InferenceSession(path, providers=[provider], sess_options=sess_options)
+
+ def _save_pretrained(self, save_directory: Union[str, Path], file_name: Optional[str] = None, **kwargs):
+ """
+ Save a model and its configuration file to a directory, so that it can be re-loaded using the
+ [`~optimum.onnxruntime.modeling_ort.ORTModel.from_pretrained`] class method. It will always save the
+ latest_model_name.
+
+ Arguments:
+ save_directory (`str` or `Path`):
+ Directory where to save the model file.
+ file_name(`str`, *optional*):
+ Overwrites the default model file name from `"model.onnx"` to `file_name`. This allows you to save the
+ model with a different name.
+ """
+ model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME
+
+ src_path = self.model_save_dir.joinpath(self.latest_model_name)
+ dst_path = Path(save_directory).joinpath(model_file_name)
+ try:
+ shutil.copyfile(src_path, dst_path)
+ except shutil.SameFileError:
+ pass
+
+ # copy external weights (for models >2GB)
+ src_path = self.model_save_dir.joinpath(ONNX_EXTERNAL_WEIGHTS_NAME)
+ if src_path.exists():
+ dst_path = Path(save_directory).joinpath(ONNX_EXTERNAL_WEIGHTS_NAME)
+ try:
+ shutil.copyfile(src_path, dst_path)
+ except shutil.SameFileError:
+ pass
+
+ def save_pretrained(
+ self,
+ save_directory: Union[str, os.PathLike],
+ **kwargs,
+ ):
+ """
+ Save a model to a directory, so that it can be re-loaded using the [`~OnnxModel.from_pretrained`] class
+ method.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to which to save. Will be created if it doesn't exist.
+ """
+ if os.path.isfile(save_directory):
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
+ return
+
+ os.makedirs(save_directory, exist_ok=True)
+
+ # saving model weights/files
+ self._save_pretrained(save_directory, **kwargs)
+
+ @classmethod
+ def _from_pretrained(
+ cls,
+ model_id: Union[str, Path],
+        use_auth_token: Optional[Union[bool, str]] = None,
+        revision: Optional[str] = None,
+ force_download: bool = False,
+ cache_dir: Optional[str] = None,
+ file_name: Optional[str] = None,
+ provider: Optional[str] = None,
+ sess_options: Optional["ort.SessionOptions"] = None,
+ **kwargs,
+ ):
+ """
+ Load a model from a directory or the HF Hub.
+
+ Arguments:
+ model_id (`str` or `Path`):
+ Directory from which to load
+            use_auth_token (`str` or `bool`, *optional*):
+                Token needed to load models from a private or gated repository.
+            revision (`str`, *optional*):
+                The specific model version to use. It can be a branch name, a tag name, or a commit id.
+ cache_dir (`Union[str, Path]`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
+ standard cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+            file_name (`str`, *optional*):
+                Overwrites the default model file name from `"model.onnx"` to `file_name`. This allows you to load
+                different model files from the same repository or directory.
+            provider (`str`, *optional*):
+                The ONNX Runtime execution provider, e.g. `CPUExecutionProvider` or `CUDAExecutionProvider`.
+            sess_options (`onnxruntime.SessionOptions`, *optional*):
+                Session options passed to `onnxruntime.InferenceSession`.
+ kwargs (`Dict`, *optional*):
+ kwargs will be passed to the model during initialization
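+
+        Example (illustrative; the repository id, file name, and input name are placeholders):
+
+        ```py
+        >>> model = OnnxRuntimeModel.from_pretrained(
+        ...     "my-org/my-onnx-model", file_name="model.onnx", provider="CPUExecutionProvider"
+        ... )
+        >>> outputs = model(sample=np.ones((1, 4, 64, 64), dtype=np.float32))
+        ```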
+ """
+ model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME
+ # load model from local directory
+ if os.path.isdir(model_id):
+ model = OnnxRuntimeModel.load_model(
+ os.path.join(model_id, model_file_name), provider=provider, sess_options=sess_options
+ )
+ kwargs["model_save_dir"] = Path(model_id)
+ # load model from hub
+ else:
+ # download model
+ model_cache_path = hf_hub_download(
+ repo_id=model_id,
+ filename=model_file_name,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ )
+ kwargs["model_save_dir"] = Path(model_cache_path).parent
+ kwargs["latest_model_name"] = Path(model_cache_path).name
+ model = OnnxRuntimeModel.load_model(model_cache_path, provider=provider, sess_options=sess_options)
+ return cls(model=model, **kwargs)
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ model_id: Union[str, Path],
+ force_download: bool = True,
+ use_auth_token: Optional[str] = None,
+ cache_dir: Optional[str] = None,
+ **model_kwargs,
+ ):
+ revision = None
+ if len(str(model_id).split("@")) == 2:
+            model_id, revision = str(model_id).split("@")
+
+ return cls._from_pretrained(
+ model_id=model_id,
+ revision=revision,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ use_auth_token=use_auth_token,
+ **model_kwargs,
+ )
diff --git a/diffusers/pipelines/paint_by_example/__init__.py b/diffusers/pipelines/paint_by_example/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0fc8cb71e3f4e1e8baf16c7143658ca64934306
--- /dev/null
+++ b/diffusers/pipelines/paint_by_example/__init__.py
@@ -0,0 +1,13 @@
+from dataclasses import dataclass
+from typing import List, Optional, Union
+
+import numpy as np
+import PIL
+from PIL import Image
+
+from ...utils import is_torch_available, is_transformers_available
+
+
+if is_transformers_available() and is_torch_available():
+ from .image_encoder import PaintByExampleImageEncoder
+ from .pipeline_paint_by_example import PaintByExamplePipeline
diff --git a/diffusers/pipelines/paint_by_example/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/paint_by_example/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..eeffcf96e39a190fa9dab2909c06f6320357979c
Binary files /dev/null and b/diffusers/pipelines/paint_by_example/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/pipelines/paint_by_example/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/paint_by_example/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6f6db4581be9a78ecad6dcfd633de745e765deca
Binary files /dev/null and b/diffusers/pipelines/paint_by_example/__pycache__/__init__.cpython-38.pyc differ
diff --git a/diffusers/pipelines/paint_by_example/__pycache__/image_encoder.cpython-310.pyc b/diffusers/pipelines/paint_by_example/__pycache__/image_encoder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e5304bfe2805f636cb4fa7bfe7641d9628524ef8
Binary files /dev/null and b/diffusers/pipelines/paint_by_example/__pycache__/image_encoder.cpython-310.pyc differ
diff --git a/diffusers/pipelines/paint_by_example/__pycache__/image_encoder.cpython-38.pyc b/diffusers/pipelines/paint_by_example/__pycache__/image_encoder.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..637ac265a63419ee946d5d364dc2de8ecdce915a
Binary files /dev/null and b/diffusers/pipelines/paint_by_example/__pycache__/image_encoder.cpython-38.pyc differ
diff --git a/diffusers/pipelines/paint_by_example/__pycache__/pipeline_paint_by_example.cpython-310.pyc b/diffusers/pipelines/paint_by_example/__pycache__/pipeline_paint_by_example.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7c421b21dd0d17647a7ac44f6b5113bd9f9058ae
Binary files /dev/null and b/diffusers/pipelines/paint_by_example/__pycache__/pipeline_paint_by_example.cpython-310.pyc differ
diff --git a/diffusers/pipelines/paint_by_example/__pycache__/pipeline_paint_by_example.cpython-38.pyc b/diffusers/pipelines/paint_by_example/__pycache__/pipeline_paint_by_example.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..24484935753f0629b253fa553c6e423902a879ae
Binary files /dev/null and b/diffusers/pipelines/paint_by_example/__pycache__/pipeline_paint_by_example.cpython-38.pyc differ
diff --git a/diffusers/pipelines/paint_by_example/image_encoder.py b/diffusers/pipelines/paint_by_example/image_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..831489eefed167264c8fd8f57e1ed59610ebb858
--- /dev/null
+++ b/diffusers/pipelines/paint_by_example/image_encoder.py
@@ -0,0 +1,67 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+from torch import nn
+from transformers import CLIPPreTrainedModel, CLIPVisionModel
+
+from ...models.attention import BasicTransformerBlock
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class PaintByExampleImageEncoder(CLIPPreTrainedModel):
+ def __init__(self, config, proj_size=768):
+ super().__init__(config)
+ self.proj_size = proj_size
+
+ self.model = CLIPVisionModel(config)
+ self.mapper = PaintByExampleMapper(config)
+ self.final_layer_norm = nn.LayerNorm(config.hidden_size)
+ self.proj_out = nn.Linear(config.hidden_size, self.proj_size)
+
+        # unconditional embedding used for classifier-free guidance scaling
+ self.uncond_vector = nn.Parameter(torch.randn((1, 1, self.proj_size)))
+
+ def forward(self, pixel_values, return_uncond_vector=False):
+ clip_output = self.model(pixel_values=pixel_values)
+ latent_states = clip_output.pooler_output
+ latent_states = self.mapper(latent_states[:, None])
+ latent_states = self.final_layer_norm(latent_states)
+ latent_states = self.proj_out(latent_states)
+ if return_uncond_vector:
+ return latent_states, self.uncond_vector
+
+ return latent_states
+
+
+class PaintByExampleMapper(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ num_layers = (config.num_hidden_layers + 1) // 5
+ hid_size = config.hidden_size
+ num_heads = 1
+ self.blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(hid_size, num_heads, hid_size, activation_fn="gelu", attention_bias=True)
+ for _ in range(num_layers)
+ ]
+ )
+
+ def forward(self, hidden_states):
+ for block in self.blocks:
+ hidden_states = block(hidden_states)
+
+ return hidden_states
diff --git a/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py b/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
new file mode 100644
index 0000000000000000000000000000000000000000..f844834b527d4cffcffaa589df22c633e358074e
--- /dev/null
+++ b/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
@@ -0,0 +1,557 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import warnings
+from typing import Callable, List, Optional, Union
+
+import numpy as np
+import PIL
+import torch
+from transformers import CLIPImageProcessor
+
+from ...image_processor import VaeImageProcessor
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from ...utils import logging, randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from ..stable_diffusion import StableDiffusionPipelineOutput
+from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from .image_encoder import PaintByExampleImageEncoder
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def prepare_mask_and_masked_image(image, mask):
+ """
+ Prepares a pair (image, mask) to be consumed by the Paint by Example pipeline. This means that those inputs will be
+ converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
+ ``image`` and ``1`` for the ``mask``.
+
+ The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
+ binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
+
+ Args:
+ image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
+ It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
+ ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
+        mask (Union[np.array, PIL.Image, torch.Tensor]): The mask to apply to the image, i.e. regions to inpaint.
+ It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
+ ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
+
+
+ Raises:
+ ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
+ should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
+        TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not (or the other way around).
+
+ Returns:
+ tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
+ dimensions: ``batch x channels x height x width``.
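+
+    Example (a minimal sketch using PIL inputs; the sizes are illustrative):
+
+    ```py
+    >>> from PIL import Image
+    >>> image = Image.new("RGB", (64, 64), color=(255, 0, 0))
+    >>> mask = Image.new("L", (64, 64), color=255)  # white = region to repaint
+    >>> mask_t, masked_image_t = prepare_mask_and_masked_image(image, mask)
+    >>> mask_t.shape, masked_image_t.shape
+    (torch.Size([1, 1, 64, 64]), torch.Size([1, 3, 64, 64]))
+    ```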
+ """
+ if isinstance(image, torch.Tensor):
+ if not isinstance(mask, torch.Tensor):
+            raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)}) is not")
+
+ # Batch single image
+ if image.ndim == 3:
+ assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
+ image = image.unsqueeze(0)
+
+ # Batch and add channel dim for single mask
+ if mask.ndim == 2:
+ mask = mask.unsqueeze(0).unsqueeze(0)
+
+ # Batch single mask or add channel dim
+ if mask.ndim == 3:
+ # Batched mask
+ if mask.shape[0] == image.shape[0]:
+ mask = mask.unsqueeze(1)
+ else:
+ mask = mask.unsqueeze(0)
+
+ assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
+ assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
+ assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
+ assert mask.shape[1] == 1, "Mask image must have a single channel"
+
+ # Check image is in [-1, 1]
+ if image.min() < -1 or image.max() > 1:
+ raise ValueError("Image should be in [-1, 1] range")
+
+ # Check mask is in [0, 1]
+ if mask.min() < 0 or mask.max() > 1:
+ raise ValueError("Mask should be in [0, 1] range")
+
+        # paint-by-example inverts the mask
+ mask = 1 - mask
+
+ # Binarize mask
+ mask[mask < 0.5] = 0
+ mask[mask >= 0.5] = 1
+
+ # Image as float32
+ image = image.to(dtype=torch.float32)
+ elif isinstance(mask, torch.Tensor):
+        raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)}) is not")
+ else:
+ if isinstance(image, PIL.Image.Image):
+ image = [image]
+
+ image = np.concatenate([np.array(i.convert("RGB"))[None, :] for i in image], axis=0)
+ image = image.transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+
+ # preprocess mask
+ if isinstance(mask, PIL.Image.Image):
+ mask = [mask]
+
+ mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
+ mask = mask.astype(np.float32) / 255.0
+
+        # paint-by-example inverts the mask
+ mask = 1 - mask
+
+ mask[mask < 0.5] = 0
+ mask[mask >= 0.5] = 1
+ mask = torch.from_numpy(mask)
+
+ masked_image = image * mask
+
+ return mask, masked_image
+
+
+class PaintByExamplePipeline(DiffusionPipeline):
+ r"""
+ Pipeline for image-guided image inpainting using Stable Diffusion. *This is an experimental feature*.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ image_encoder ([`PaintByExampleImageEncoder`]):
+ Encodes the example input image. The unet is conditioned on the example image instead of a text prompt.
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+            Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
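+
+    Example (an illustrative sketch; the checkpoint id and image paths are placeholders for your own assets):
+
+    ```py
+    >>> import PIL
+    >>> import torch
+    >>> from diffusers import PaintByExamplePipeline
+
+    >>> pipe = PaintByExamplePipeline.from_pretrained(
+    ...     "Fantasy-Studio/Paint-by-Example", torch_dtype=torch.float16
+    ... )
+    >>> pipe = pipe.to("cuda")
+
+    >>> init_image = PIL.Image.open("scene.png").convert("RGB").resize((512, 512))
+    >>> mask_image = PIL.Image.open("mask.png").convert("L").resize((512, 512))
+    >>> example_image = PIL.Image.open("reference.png").convert("RGB").resize((512, 512))
+
+    >>> result = pipe(image=init_image, mask_image=mask_image, example_image=example_image).images[0]
+    >>> result.save("painted.png")
+    ```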
+ """
+ # TODO: feature_extractor is required to encode initial images (if they are in PIL format),
+ # we should give a descriptive message if the pipeline doesn't have one.
+ _optional_components = ["safety_checker"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ image_encoder: PaintByExampleImageEncoder,
+ unet: UNet2DConditionModel,
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = False,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ image_encoder=image_encoder,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ def decode_latents(self, latents):
+ warnings.warn(
+ "The decode_latents method is deprecated and will be removed in a future version. Please"
+ " use VaeImageProcessor instead",
+ FutureWarning,
+ )
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_image_variation.StableDiffusionImageVariationPipeline.check_inputs
+ def check_inputs(self, image, height, width, callback_steps):
+ if (
+ not isinstance(image, torch.Tensor)
+ and not isinstance(image, PIL.Image.Image)
+ and not isinstance(image, list)
+ ):
+ raise ValueError(
+ "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
+ f" {type(image)}"
+ )
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_mask_latents
+ def prepare_mask_latents(
+ self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
+ ):
+ # resize the mask to latents shape as we concatenate the mask to the latents
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+ # and half precision
+ mask = torch.nn.functional.interpolate(
+ mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
+ )
+ mask = mask.to(device=device, dtype=dtype)
+
+ masked_image = masked_image.to(device=device, dtype=dtype)
+ masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
+
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+ if mask.shape[0] < batch_size:
+ if not batch_size % mask.shape[0] == 0:
+ raise ValueError(
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+ " of masks that you pass is divisible by the total requested batch size."
+ )
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+ if masked_image_latents.shape[0] < batch_size:
+ if not batch_size % masked_image_latents.shape[0] == 0:
+ raise ValueError(
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
+ )
+ masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
+
+ mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
+ masked_image_latents = (
+ torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
+ )
+
+        # aligning device to prevent device errors when concatenating it with the latent model input
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+ return mask, masked_image_latents
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline._encode_vae_image
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+ if isinstance(generator, list):
+ image_latents = [
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
+
+ image_latents = self.vae.config.scaling_factor * image_latents
+
+ return image_latents
+
+ def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(images=image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ image_embeddings, negative_prompt_embeds = self.image_encoder(image, return_uncond_vector=True)
+
+ # duplicate image embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = image_embeddings.shape
+ image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1)
+ image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ if do_classifier_free_guidance:
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, image_embeddings.shape[0], 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, 1, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])
+
+ return image_embeddings
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ example_image: Union[torch.FloatTensor, PIL.Image.Image],
+ image: Union[torch.FloatTensor, PIL.Image.Image],
+ mask_image: Union[torch.FloatTensor, PIL.Image.Image],
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 5.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ example_image (`torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]`):
+ The exemplar image to guide the image generation.
+ image (`torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]`):
+ `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
+                be masked out with `mask_image` and repainted according to `example_image`.
+ mask_image (`torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]`):
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
+ repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
+ to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
+                instead of 3, so the expected shape would be `(B, 1, H, W)`.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+            guidance_scale (`float`, *optional*, defaults to 5.0):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2 of the [Imagen
+                paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. A higher guidance scale encourages the model to generate images that are closely linked to the
+                `example_image`, usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ # 1. Define call parameters
+ if isinstance(image, PIL.Image.Image):
+ batch_size = 1
+ elif isinstance(image, list):
+ batch_size = len(image)
+ else:
+ batch_size = image.shape[0]
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 2. Preprocess mask and image
+ mask, masked_image = prepare_mask_and_masked_image(image, mask_image)
+ height, width = masked_image.shape[-2:]
+
+ # 3. Check inputs
+ self.check_inputs(example_image, height, width, callback_steps)
+
+ # 4. Encode input image
+ image_embeddings = self._encode_image(
+ example_image, device, num_images_per_prompt, do_classifier_free_guidance
+ )
+
+ # 5. set timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 6. Prepare latent variables
+ num_channels_latents = self.vae.config.latent_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ image_embeddings.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 7. Prepare mask latent variables
+ mask, masked_image_latents = self.prepare_mask_latents(
+ mask,
+ masked_image,
+ batch_size * num_images_per_prompt,
+ height,
+ width,
+ image_embeddings.dtype,
+ device,
+ generator,
+ do_classifier_free_guidance,
+ )
+
+ # 8. Check that sizes of mask, masked image and latents match
+ num_channels_mask = mask.shape[1]
+ num_channels_masked_image = masked_image_latents.shape[1]
+ if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
+ raise ValueError(
+ f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
+ f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
+ f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ " `pipeline.unet` or your `mask_image` or `image` input."
+ )
+
+ # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 10. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+ # concat latents, mask, masked_image_latents in the channel dimension
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+ latent_model_input = torch.cat([latent_model_input, masked_image_latents, mask], dim=1)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=image_embeddings).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/pipelines/pipeline_flax_utils.py b/diffusers/pipelines/pipeline_flax_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1c4b9f53953c1f18e3cc90088dfddd612cbfa63
--- /dev/null
+++ b/diffusers/pipelines/pipeline_flax_utils.py
@@ -0,0 +1,568 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib
+import inspect
+import os
+from typing import Any, Dict, List, Optional, Union
+
+import flax
+import numpy as np
+import PIL
+from flax.core.frozen_dict import FrozenDict
+from huggingface_hub import snapshot_download
+from PIL import Image
+from tqdm.auto import tqdm
+
+from ..configuration_utils import ConfigMixin
+from ..models.modeling_flax_utils import FLAX_WEIGHTS_NAME, FlaxModelMixin
+from ..schedulers.scheduling_utils_flax import SCHEDULER_CONFIG_NAME, FlaxSchedulerMixin
+from ..utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, http_user_agent, is_transformers_available, logging
+
+
+if is_transformers_available():
+ from transformers import FlaxPreTrainedModel
+
+INDEX_FILE = "diffusion_flax_model.bin"
+
+
+logger = logging.get_logger(__name__)
+
+
+LOADABLE_CLASSES = {
+ "diffusers": {
+ "FlaxModelMixin": ["save_pretrained", "from_pretrained"],
+ "FlaxSchedulerMixin": ["save_pretrained", "from_pretrained"],
+ "FlaxDiffusionPipeline": ["save_pretrained", "from_pretrained"],
+ },
+ "transformers": {
+ "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
+ "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"],
+ "FlaxPreTrainedModel": ["save_pretrained", "from_pretrained"],
+ "FeatureExtractionMixin": ["save_pretrained", "from_pretrained"],
+ "ProcessorMixin": ["save_pretrained", "from_pretrained"],
+ "ImageProcessingMixin": ["save_pretrained", "from_pretrained"],
+ },
+}
+
+ALL_IMPORTABLE_CLASSES = {}
+for library in LOADABLE_CLASSES:
+ ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
+
+
+def import_flax_or_no_model(module, class_name):
+ try:
+ # 1. First make sure that if a Flax object is present, import this one
+ class_obj = getattr(module, "Flax" + class_name)
+    except AttributeError:
+        # 2. If this doesn't work, it's not a model and we don't append "Flax"
+        try:
+            class_obj = getattr(module, class_name)
+        except AttributeError:
+            raise ValueError(f"Neither Flax{class_name} nor {class_name} exist in {module}")
+
+ return class_obj
+
+
+@flax.struct.dataclass
+class FlaxImagePipelineOutput(BaseOutput):
+ """
+ Output class for image pipelines.
+
+ Args:
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
+ List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
+ num_channels)`.
+ """
+
+ images: Union[List[PIL.Image.Image], np.ndarray]
+
+
+class FlaxDiffusionPipeline(ConfigMixin):
+ r"""
+    Base class for all Flax diffusion pipelines.
+
+ [`FlaxDiffusionPipeline`] takes care of storing all components (models, schedulers, processors) for diffusion
+ pipelines and handles methods for loading, downloading and saving models as well as a few methods common to all
+ pipelines to:
+
+ - enabling/disabling the progress bar for the denoising iteration
+
+ Class attributes:
+
+ - **config_name** ([`str`]) -- name of the config file that will store the class and module names of all
+ components of the diffusion pipeline.
+ """
+ config_name = "model_index.json"
+
+ def register_modules(self, **kwargs):
+ # import it here to avoid circular import
+ from diffusers import pipelines
+
+ for name, module in kwargs.items():
+ if module is None:
+ register_dict = {name: (None, None)}
+ else:
+ # retrieve library
+ library = module.__module__.split(".")[0]
+
+ # check if the module is a pipeline module
+ pipeline_dir = module.__module__.split(".")[-2]
+ path = module.__module__.split(".")
+ is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
+
+ # if library is not in LOADABLE_CLASSES, then it is a custom module.
+ # Or if it's a pipeline module, then the module is inside the pipeline
+ # folder so we set the library to module name.
+ if library not in LOADABLE_CLASSES or is_pipeline_module:
+ library = pipeline_dir
+
+ # retrieve class_name
+ class_name = module.__class__.__name__
+
+ register_dict = {name: (library, class_name)}
+
+ # save model index config
+ self.register_to_config(**register_dict)
+
+ # set models
+ setattr(self, name, module)
+
+ def save_pretrained(self, save_directory: Union[str, os.PathLike], params: Union[Dict, FrozenDict]):
+ # TODO: handle inference_state
+ """
+ Save all variables of the pipeline that can be saved and loaded as well as the pipelines configuration file to
+ a directory. A pipeline variable can be saved and loaded if its class implements both a save and loading
+ method. The pipeline can easily be re-loaded using the `[`~FlaxDiffusionPipeline.from_pretrained`]` class
+ method.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to which to save. Will be created if it doesn't exist.
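+
+        Example (illustrative; assumes `pipeline` and `params` were returned by [`~FlaxDiffusionPipeline.from_pretrained`]):
+
+        ```py
+        >>> pipeline.save_pretrained("./my_pipeline_directory", params=params)
+        ```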
+ """
+ self.save_config(save_directory)
+
+ model_index_dict = dict(self.config)
+ model_index_dict.pop("_class_name")
+ model_index_dict.pop("_diffusers_version")
+ model_index_dict.pop("_module", None)
+
+ for pipeline_component_name in model_index_dict.keys():
+ sub_model = getattr(self, pipeline_component_name)
+ if sub_model is None:
+ # edge case for saving a pipeline with safety_checker=None
+ continue
+
+ model_cls = sub_model.__class__
+
+ save_method_name = None
+ # search for the model's base class in LOADABLE_CLASSES
+ for library_name, library_classes in LOADABLE_CLASSES.items():
+ library = importlib.import_module(library_name)
+ for base_class, save_load_methods in library_classes.items():
+ class_candidate = getattr(library, base_class, None)
+ if class_candidate is not None and issubclass(model_cls, class_candidate):
+ # if we found a suitable base class in LOADABLE_CLASSES then grab its save method
+ save_method_name = save_load_methods[0]
+ break
+ if save_method_name is not None:
+ break
+
+ save_method = getattr(sub_model, save_method_name)
+ expects_params = "params" in set(inspect.signature(save_method).parameters.keys())
+
+ if expects_params:
+ save_method(
+ os.path.join(save_directory, pipeline_component_name), params=params[pipeline_component_name]
+ )
+ else:
+ save_method(os.path.join(save_directory, pipeline_component_name))
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
+ r"""
+ Instantiate a Flax diffusion pipeline from pre-trained pipeline weights.
+
+ The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).
+
+ The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
+ pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
+ task.
+
+ The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
+ weights are discarded.
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+ Can be either:
+
+ - A string, the *repo id* of a pretrained pipeline hosted inside a model repo on
+                      https://huggingface.co/. Valid repo ids have to be located under a user or organization name, like
+ `CompVis/ldm-text2im-large-256`.
+ - A path to a *directory* containing pipeline weights saved using
+ [`~FlaxDiffusionPipeline.save_pretrained`], e.g., `./my_pipeline_directory/`.
+ dtype (`str` or `jnp.dtype`, *optional*):
+ Override the default `jnp.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
+ will be automatically derived from the model's weights.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+ file exists.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ output_loading_info(`bool`, *optional*, defaults to `False`):
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ local_files_only(`bool`, *optional*, defaults to `False`):
+ Whether or not to only look at local files (i.e., do not try to download the model).
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `huggingface-cli login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+ mirror (`str`, *optional*):
+ Mirror source to accelerate downloads in China. If you are from China and have an accessibility
+ problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
+                Please refer to the mirror site for more information.
+
+ kwargs (remaining dictionary of keyword arguments, *optional*):
+                Can be used to overwrite loadable and saveable variables, *i.e.* the pipeline components, of the
+                specific pipeline class. The overwritten components are then directly passed to the pipeline's
+ `__init__` method. See example below for more information.
+
+
+        It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
+        models](https://huggingface.co/docs/hub/models-gated#gated-models), *e.g.* `"runwayml/stable-diffusion-v1-5"`.
+
+        Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
+        this method in a firewalled environment.
+
+ Examples:
+
+ ```py
+ >>> from diffusers import FlaxDiffusionPipeline
+
+ >>> # Download pipeline from huggingface.co and cache.
+ >>> # Requires to be logged in to Hugging Face hub,
+        >>> # see more in [the documentation](https://huggingface.co/docs/hub/security-tokens)
+ >>> pipeline, params = FlaxDiffusionPipeline.from_pretrained(
+ ... "runwayml/stable-diffusion-v1-5",
+ ... revision="bf16",
+ ... dtype=jnp.bfloat16,
+ ... )
+
+ >>> # Download pipeline, but use a different scheduler
+ >>> from diffusers import FlaxDPMSolverMultistepScheduler
+
+ >>> model_id = "runwayml/stable-diffusion-v1-5"
+ >>> dpmpp, dpmpp_state = FlaxDPMSolverMultistepScheduler.from_pretrained(
+ ... model_id,
+ ... subfolder="scheduler",
+ ... )
+
+ >>> dpm_pipe, dpm_params = FlaxStableDiffusionPipeline.from_pretrained(
+ ... model_id, revision="bf16", dtype=jnp.bfloat16, scheduler=dpmpp
+ ... )
+ >>> dpm_params["scheduler"] = dpmpp_state
+ ```
+ """
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+ resume_download = kwargs.pop("resume_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", False)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ revision = kwargs.pop("revision", None)
+ from_pt = kwargs.pop("from_pt", False)
+ use_memory_efficient_attention = kwargs.pop("use_memory_efficient_attention", False)
+ dtype = kwargs.pop("dtype", None)
+
+ # 1. Download the checkpoints and configs
+ # use snapshot download here to get it working from from_pretrained
+ if not os.path.isdir(pretrained_model_name_or_path):
+ config_dict = cls.load_config(
+ pretrained_model_name_or_path,
+ cache_dir=cache_dir,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ )
+ # make sure we only download sub-folders and `diffusers` filenames
+ folder_names = [k for k in config_dict.keys() if not k.startswith("_")]
+ allow_patterns = [os.path.join(k, "*") for k in folder_names]
+ allow_patterns += [FLAX_WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, cls.config_name]
+
+ # make sure we don't download PyTorch weights, unless when using from_pt
+ ignore_patterns = "*.bin" if not from_pt else []
+
+ if cls != FlaxDiffusionPipeline:
+ requested_pipeline_class = cls.__name__
+ else:
+ requested_pipeline_class = config_dict.get("_class_name", cls.__name__)
+ requested_pipeline_class = (
+ requested_pipeline_class
+ if requested_pipeline_class.startswith("Flax")
+ else "Flax" + requested_pipeline_class
+ )
+
+ user_agent = {"pipeline_class": requested_pipeline_class}
+ user_agent = http_user_agent(user_agent)
+
+ # download all allow_patterns
+ cached_folder = snapshot_download(
+ pretrained_model_name_or_path,
+ cache_dir=cache_dir,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ allow_patterns=allow_patterns,
+ ignore_patterns=ignore_patterns,
+ user_agent=user_agent,
+ )
+ else:
+ cached_folder = pretrained_model_name_or_path
+
+ config_dict = cls.load_config(cached_folder)
+
+ # 2. Load the pipeline class, if using custom module then load it from the hub
+ # if we load from explicit class, let's use it
+ if cls != FlaxDiffusionPipeline:
+ pipeline_class = cls
+ else:
+ diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
+ class_name = (
+ config_dict["_class_name"]
+ if config_dict["_class_name"].startswith("Flax")
+ else "Flax" + config_dict["_class_name"]
+ )
+ pipeline_class = getattr(diffusers_module, class_name)
+
+ # some modules can be passed directly to the init
+ # in this case they are already instantiated in `kwargs`
+ # extract them here
+ expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
+ passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
+
+ init_dict, _, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
+
+ init_kwargs = {}
+
+ # inference_params
+ params = {}
+
+ # import it here to avoid circular import
+ from diffusers import pipelines
+
+ # 3. Load each module in the pipeline
+ for name, (library_name, class_name) in init_dict.items():
+ if class_name is None:
+ # edge case for when the pipeline was saved with safety_checker=None
+ init_kwargs[name] = None
+ continue
+
+ is_pipeline_module = hasattr(pipelines, library_name)
+ loaded_sub_model = None
+ sub_model_should_be_defined = True
+
+ # if the model is in a pipeline module, then we load it from the pipeline
+ if name in passed_class_obj:
+ # 1. check that passed_class_obj has correct parent class
+ if not is_pipeline_module:
+ library = importlib.import_module(library_name)
+ class_obj = getattr(library, class_name)
+ importable_classes = LOADABLE_CLASSES[library_name]
+ class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
+
+ expected_class_obj = None
+ for class_name, class_candidate in class_candidates.items():
+ if class_candidate is not None and issubclass(class_obj, class_candidate):
+ expected_class_obj = class_candidate
+
+ if not issubclass(passed_class_obj[name].__class__, expected_class_obj):
+ raise ValueError(
+ f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
+ f" {expected_class_obj}"
+ )
+ elif passed_class_obj[name] is None:
+ logger.warning(
+ f"You have passed `None` for {name} to disable its functionality in {pipeline_class}. Note"
+ f" that this might lead to problems when using {pipeline_class} and is not recommended."
+ )
+ sub_model_should_be_defined = False
+ else:
+ logger.warning(
+ f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
+ " has the correct type"
+ )
+
+ # set passed class object
+ loaded_sub_model = passed_class_obj[name]
+ elif is_pipeline_module:
+ pipeline_module = getattr(pipelines, library_name)
+ class_obj = import_flax_or_no_model(pipeline_module, class_name)
+
+ importable_classes = ALL_IMPORTABLE_CLASSES
+ class_candidates = {c: class_obj for c in importable_classes.keys()}
+ else:
+ # else we just import it from the library.
+ library = importlib.import_module(library_name)
+ class_obj = import_flax_or_no_model(library, class_name)
+
+ importable_classes = LOADABLE_CLASSES[library_name]
+ class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
+
+ if loaded_sub_model is None and sub_model_should_be_defined:
+ load_method_name = None
+ for class_name, class_candidate in class_candidates.items():
+ if class_candidate is not None and issubclass(class_obj, class_candidate):
+ load_method_name = importable_classes[class_name][1]
+
+ load_method = getattr(class_obj, load_method_name)
+
+ # check if the module is in a subdirectory
+ if os.path.isdir(os.path.join(cached_folder, name)):
+ loadable_folder = os.path.join(cached_folder, name)
+ else:
+                    loadable_folder = cached_folder
+
+ if issubclass(class_obj, FlaxModelMixin):
+ loaded_sub_model, loaded_params = load_method(
+ loadable_folder,
+ from_pt=from_pt,
+ use_memory_efficient_attention=use_memory_efficient_attention,
+ dtype=dtype,
+ )
+ params[name] = loaded_params
+ elif is_transformers_available() and issubclass(class_obj, FlaxPreTrainedModel):
+ if from_pt:
+ # TODO(Suraj): Fix this in Transformers. We should be able to use `_do_init=False` here
+ loaded_sub_model = load_method(loadable_folder, from_pt=from_pt)
+ loaded_params = loaded_sub_model.params
+ del loaded_sub_model._params
+ else:
+ loaded_sub_model, loaded_params = load_method(loadable_folder, _do_init=False)
+ params[name] = loaded_params
+ elif issubclass(class_obj, FlaxSchedulerMixin):
+ loaded_sub_model, scheduler_state = load_method(loadable_folder)
+ params[name] = scheduler_state
+ else:
+ loaded_sub_model = load_method(loadable_folder)
+
+ init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
+
+ # 4. Potentially add passed objects if expected
+ missing_modules = set(expected_modules) - set(init_kwargs.keys())
+ passed_modules = list(passed_class_obj.keys())
+
+ if len(missing_modules) > 0 and missing_modules <= set(passed_modules):
+ for module in missing_modules:
+ init_kwargs[module] = passed_class_obj.get(module, None)
+ elif len(missing_modules) > 0:
+ passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs
+ raise ValueError(
+ f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
+ )
+
+ model = pipeline_class(**init_kwargs, dtype=dtype)
+ return model, params
+
+ @staticmethod
+ def _get_signature_keys(obj):
+ parameters = inspect.signature(obj.__init__).parameters
+ required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
+ optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
+ expected_modules = set(required_parameters.keys()) - {"self"}
+ return expected_modules, optional_parameters
+
+ @property
+ def components(self) -> Dict[str, Any]:
+ r"""
+
+ The `self.components` property can be useful to run different pipelines with the same weights and
+        configurations without having to re-allocate memory.
+
+ Examples:
+
+ ```py
+ >>> from diffusers import (
+ ... FlaxStableDiffusionPipeline,
+ ... FlaxStableDiffusionImg2ImgPipeline,
+ ... )
+
+ >>> text2img = FlaxStableDiffusionPipeline.from_pretrained(
+ ... "runwayml/stable-diffusion-v1-5", revision="bf16", dtype=jnp.bfloat16
+ ... )
+ >>> img2img = FlaxStableDiffusionImg2ImgPipeline(**text2img.components)
+ ```
+
+ Returns:
+ A dictionary containing all the modules needed to initialize the pipeline.
+ """
+ expected_modules, optional_parameters = self._get_signature_keys(self)
+ components = {
+ k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters
+ }
+
+ if set(components.keys()) != expected_modules:
+ raise ValueError(
+ f"{self} has been incorrectly initialized or {self.__class__} is incorrectly implemented. Expected"
+ f" {expected_modules} to be defined, but {components} are defined."
+ )
+
+ return components
+
+ @staticmethod
+ def numpy_to_pil(images):
+ """
+ Convert a numpy image or a batch of images to a PIL image.
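+
+        Example (a minimal sketch; the array values are assumed to be in `[0, 1]`):
+
+        ```py
+        >>> import numpy as np
+        >>> images = np.random.rand(2, 64, 64, 3)
+        >>> pil_images = FlaxDiffusionPipeline.numpy_to_pil(images)
+        >>> len(pil_images)
+        2
+        ```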
+ """
+ if images.ndim == 3:
+ images = images[None, ...]
+ images = (images * 255).round().astype("uint8")
+ if images.shape[-1] == 1:
+ # special case for grayscale (single channel) images
+ pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
+ else:
+ pil_images = [Image.fromarray(image) for image in images]
+
+ return pil_images
+
+ # TODO: make it compatible with jax.lax
+ def progress_bar(self, iterable):
+ if not hasattr(self, "_progress_bar_config"):
+ self._progress_bar_config = {}
+ elif not isinstance(self._progress_bar_config, dict):
+ raise ValueError(
+ f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
+ )
+
+ return tqdm(iterable, **self._progress_bar_config)
+
+ def set_progress_bar_config(self, **kwargs):
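+        """
+        Store keyword arguments that `progress_bar` forwards to `tqdm`, e.g. `pipe.set_progress_bar_config(disable=True)`
+        (an illustrative call) to silence the denoising progress bar.
+        """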
+ self._progress_bar_config = kwargs
diff --git a/diffusers/pipelines/pipeline_utils.py b/diffusers/pipelines/pipeline_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad52c6ac1c59ad1b7855f0208c1549123d2623fd
--- /dev/null
+++ b/diffusers/pipelines/pipeline_utils.py
@@ -0,0 +1,1566 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import fnmatch
+import importlib
+import inspect
+import os
+import re
+import sys
+import warnings
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import PIL
+import torch
+from huggingface_hub import hf_hub_download, model_info, snapshot_download
+from packaging import version
+from requests.exceptions import HTTPError
+from tqdm.auto import tqdm
+
+import diffusers
+
+from .. import __version__
+from ..configuration_utils import ConfigMixin
+from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
+from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
+from ..utils import (
+ CONFIG_NAME,
+ DEPRECATED_REVISION_ARGS,
+ DIFFUSERS_CACHE,
+ HF_HUB_OFFLINE,
+ SAFETENSORS_WEIGHTS_NAME,
+ WEIGHTS_NAME,
+ BaseOutput,
+ deprecate,
+ get_class_from_dynamic_module,
+ is_accelerate_available,
+ is_accelerate_version,
+ is_compiled_module,
+ is_safetensors_available,
+ is_torch_version,
+ is_transformers_available,
+ logging,
+ numpy_to_pil,
+)
+
+
+if is_transformers_available():
+ import transformers
+ from transformers import PreTrainedModel
+ from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
+ from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
+ from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
+
+from ..utils import FLAX_WEIGHTS_NAME, ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME
+
+
+if is_accelerate_available():
+ import accelerate
+
+
+INDEX_FILE = "diffusion_pytorch_model.bin"
+CUSTOM_PIPELINE_FILE_NAME = "pipeline.py"
+DUMMY_MODULES_FOLDER = "diffusers.utils"
+TRANSFORMERS_DUMMY_MODULES_FOLDER = "transformers.utils"
+
+
+logger = logging.get_logger(__name__)
+
+
+LOADABLE_CLASSES = {
+ "diffusers": {
+ "ModelMixin": ["save_pretrained", "from_pretrained"],
+ "SchedulerMixin": ["save_pretrained", "from_pretrained"],
+ "DiffusionPipeline": ["save_pretrained", "from_pretrained"],
+ "OnnxRuntimeModel": ["save_pretrained", "from_pretrained"],
+ },
+ "transformers": {
+ "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
+ "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"],
+ "PreTrainedModel": ["save_pretrained", "from_pretrained"],
+ "FeatureExtractionMixin": ["save_pretrained", "from_pretrained"],
+ "ProcessorMixin": ["save_pretrained", "from_pretrained"],
+ "ImageProcessingMixin": ["save_pretrained", "from_pretrained"],
+ },
+ "onnxruntime.training": {
+ "ORTModule": ["save_pretrained", "from_pretrained"],
+ },
+}
+
+ALL_IMPORTABLE_CLASSES = {}
+for library in LOADABLE_CLASSES:
+ ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
+
+
+@dataclass
+class ImagePipelineOutput(BaseOutput):
+ """
+ Output class for image pipelines.
+
+ Args:
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
+ List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
+ num_channels)`.
+ """
+
+ images: Union[List[PIL.Image.Image], np.ndarray]
+
+
+@dataclass
+class AudioPipelineOutput(BaseOutput):
+ """
+ Output class for audio pipelines.
+
+ Args:
+ audios (`np.ndarray`)
+            Denoised audio samples as a NumPy array of shape `(batch_size, num_channels, sample_rate)`.
+ """
+
+ audios: np.ndarray
+
+
+def is_safetensors_compatible(filenames, variant=None, passed_components=None) -> bool:
+ """
+ Checking for safetensors compatibility:
+ - By default, all models are saved with the default pytorch serialization, so we use the list of default pytorch
+ files to know which safetensors files are needed.
+ - The model is safetensors compatible only if there is a matching safetensors file for every default pytorch file.
+
+ Converting default pytorch serialized filenames to safetensors serialized filenames:
+ - For models from the diffusers library, just replace the ".bin" extension with ".safetensors"
+ - For models from the transformers library, the filename changes from "pytorch_model" to "model", and the ".bin"
+ extension is replaced with ".safetensors"
+ """
+ pt_filenames = []
+
+ sf_filenames = set()
+
+ passed_components = passed_components or []
+
+ for filename in filenames:
+ _, extension = os.path.splitext(filename)
+
+ if len(filename.split("/")) == 2 and filename.split("/")[0] in passed_components:
+ continue
+
+ if extension == ".bin":
+ pt_filenames.append(filename)
+ elif extension == ".safetensors":
+ sf_filenames.add(filename)
+
+ for filename in pt_filenames:
+        # filename = 'foo/bar/baz.bam' -> path = 'foo/bar', filename = 'baz', extension = '.bam'
+ path, filename = os.path.split(filename)
+ filename, extension = os.path.splitext(filename)
+
+ if filename.startswith("pytorch_model"):
+ filename = filename.replace("pytorch_model", "model")
+ else:
+ filename = filename
+
+ expected_sf_filename = os.path.join(path, filename)
+ expected_sf_filename = f"{expected_sf_filename}.safetensors"
+
+ if expected_sf_filename not in sf_filenames:
+ logger.warning(f"{expected_sf_filename} not found")
+ return False
+
+ return True
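+
+# Illustration of the rule above, with hypothetical repo filenames:
+#   is_safetensors_compatible(["unet/diffusion_pytorch_model.bin",
+#                              "unet/diffusion_pytorch_model.safetensors"])  -> True
+#   is_safetensors_compatible(["text_encoder/pytorch_model.bin"])  -> False
+#   (the expected counterpart here would be "text_encoder/model.safetensors")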
+
+
+def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLike], str]:
+ weight_names = [
+ WEIGHTS_NAME,
+ SAFETENSORS_WEIGHTS_NAME,
+ FLAX_WEIGHTS_NAME,
+ ONNX_WEIGHTS_NAME,
+ ONNX_EXTERNAL_WEIGHTS_NAME,
+ ]
+
+ if is_transformers_available():
+ weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
+
+ # model_pytorch, diffusion_model_pytorch, ...
+ weight_prefixes = [w.split(".")[0] for w in weight_names]
+ # .bin, .safetensors, ...
+ weight_suffixs = [w.split(".")[-1] for w in weight_names]
+ # -00001-of-00002
+ transformers_index_format = r"\d{5}-of-\d{5}"
+
+ if variant is not None:
+ # `diffusion_pytorch_model.fp16.bin` as well as `model.fp16-00001-of-00002.safetensors`
+ variant_file_re = re.compile(
+ rf"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixs)})$"
+ )
+ # `text_encoder/pytorch_model.bin.index.fp16.json`
+ variant_index_re = re.compile(
+ rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
+ )
+
+ # `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors`
+ non_variant_file_re = re.compile(
+ rf"({'|'.join(weight_prefixes)})(-{transformers_index_format})?\.({'|'.join(weight_suffixs)})$"
+ )
+ # `text_encoder/pytorch_model.bin.index.json`
+ non_variant_index_re = re.compile(rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.json")
+
+ if variant is not None:
+ variant_weights = {f for f in filenames if variant_file_re.match(f.split("/")[-1]) is not None}
+ variant_indexes = {f for f in filenames if variant_index_re.match(f.split("/")[-1]) is not None}
+ variant_filenames = variant_weights | variant_indexes
+ else:
+ variant_filenames = set()
+
+ non_variant_weights = {f for f in filenames if non_variant_file_re.match(f.split("/")[-1]) is not None}
+ non_variant_indexes = {f for f in filenames if non_variant_index_re.match(f.split("/")[-1]) is not None}
+ non_variant_filenames = non_variant_weights | non_variant_indexes
+
+ # all variant filenames will be used by default
+ usable_filenames = set(variant_filenames)
+
+ def convert_to_variant(filename):
+ if "index" in filename:
+ variant_filename = filename.replace("index", f"index.{variant}")
+ elif re.compile(f"^(.*?){transformers_index_format}").match(filename) is not None:
+ variant_filename = f"{filename.split('-')[0]}.{variant}-{'-'.join(filename.split('-')[1:])}"
+ else:
+ variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}"
+ return variant_filename
+
+ for f in non_variant_filenames:
+ variant_filename = convert_to_variant(f)
+ if variant_filename not in usable_filenames:
+ usable_filenames.add(f)
+
+ return usable_filenames, variant_filenames
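+
+# Example with hypothetical filenames: for variant="fp16" and
+#   filenames = {"unet/diffusion_pytorch_model.bin", "unet/diffusion_pytorch_model.fp16.bin"}
+# the variant file is preferred, so usable_filenames == {"unet/diffusion_pytorch_model.fp16.bin"};
+# a non-variant file is only kept when no matching variant file exists for it.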
+
+
+def warn_deprecated_model_variant(pretrained_model_name_or_path, use_auth_token, variant, revision, model_filenames):
+ info = model_info(
+ pretrained_model_name_or_path,
+ use_auth_token=use_auth_token,
+ revision=None,
+ )
+ filenames = {sibling.rfilename for sibling in info.siblings}
+ comp_model_filenames, _ = variant_compatible_siblings(filenames, variant=revision)
+ comp_model_filenames = [".".join(f.split(".")[:1] + f.split(".")[2:]) for f in comp_model_filenames]
+
+ if set(comp_model_filenames) == set(model_filenames):
+ warnings.warn(
+ f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` even though you can load it via `variant=`{revision}`. Loading model variants via `revision='{revision}'` is deprecated and will be removed in diffusers v1. Please use `variant='{revision}'` instead.",
+ FutureWarning,
+ )
+ else:
+ warnings.warn(
+ f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have the required variant filenames in the 'main' branch. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {revision} files' so that the correct variant file can be added.",
+ FutureWarning,
+ )
+
+
+def maybe_raise_or_warn(
+ library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module
+):
+ """Simple helper method to raise or warn in case incorrect module has been passed"""
+ if not is_pipeline_module:
+ library = importlib.import_module(library_name)
+ class_obj = getattr(library, class_name)
+ class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
+
+ expected_class_obj = None
+ for class_name, class_candidate in class_candidates.items():
+ if class_candidate is not None and issubclass(class_obj, class_candidate):
+ expected_class_obj = class_candidate
+
+ # Dynamo wraps the original model in a private class.
+ # I didn't find a public API to get the original class.
+ sub_model = passed_class_obj[name]
+ model_cls = sub_model.__class__
+ if is_compiled_module(sub_model):
+ model_cls = sub_model._orig_mod.__class__
+
+ if not issubclass(model_cls, expected_class_obj):
+ raise ValueError(
+ f"{passed_class_obj[name]} is of type: {model_cls}, but should be" f" {expected_class_obj}"
+ )
+ else:
+ logger.warning(
+ f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
+ " has the correct type"
+ )
+
+
+def get_class_obj_and_candidates(library_name, class_name, importable_classes, pipelines, is_pipeline_module):
+ """Simple helper method to retrieve class object of module as well as potential parent class objects"""
+ if is_pipeline_module:
+ pipeline_module = getattr(pipelines, library_name)
+
+ class_obj = getattr(pipeline_module, class_name)
+ class_candidates = {c: class_obj for c in importable_classes.keys()}
+ else:
+ # else we just import it from the library.
+ library = importlib.import_module(library_name)
+
+ class_obj = getattr(library, class_name)
+ class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
+
+ return class_obj, class_candidates
+
+
+def _get_pipeline_class(class_obj, config, custom_pipeline=None, cache_dir=None, revision=None):
+ if custom_pipeline is not None:
+ if custom_pipeline.endswith(".py"):
+ path = Path(custom_pipeline)
+ # decompose into folder & file
+ file_name = path.name
+ custom_pipeline = path.parent.absolute()
+ else:
+ file_name = CUSTOM_PIPELINE_FILE_NAME
+
+ return get_class_from_dynamic_module(
+ custom_pipeline, module_file=file_name, cache_dir=cache_dir, revision=revision
+ )
+
+ if class_obj != DiffusionPipeline:
+ return class_obj
+
+ diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0])
+ return getattr(diffusers_module, config["_class_name"])
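+
+# For illustration: called with `class_obj=DiffusionPipeline` and a config whose `_class_name` is
+# "StableDiffusionPipeline", this resolves to `diffusers.StableDiffusionPipeline`.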
+
+
+def load_sub_model(
+ library_name: str,
+ class_name: str,
+ importable_classes: List[Any],
+ pipelines: Any,
+ is_pipeline_module: bool,
+ pipeline_class: Any,
+ torch_dtype: torch.dtype,
+ provider: Any,
+ sess_options: Any,
+ device_map: Optional[Union[Dict[str, torch.device], str]],
+ max_memory: Optional[Dict[Union[int, str], Union[int, str]]],
+ offload_folder: Optional[Union[str, os.PathLike]],
+ offload_state_dict: bool,
+ model_variants: Dict[str, str],
+ name: str,
+ from_flax: bool,
+ variant: str,
+ low_cpu_mem_usage: bool,
+ cached_folder: Union[str, os.PathLike],
+):
+ """Helper method to load the module `name` from `library_name` and `class_name`"""
+ # retrieve class candidates
+ class_obj, class_candidates = get_class_obj_and_candidates(
+ library_name, class_name, importable_classes, pipelines, is_pipeline_module
+ )
+
+ load_method_name = None
+    # retrieve the load method name
+ for class_name, class_candidate in class_candidates.items():
+ if class_candidate is not None and issubclass(class_obj, class_candidate):
+ load_method_name = importable_classes[class_name][1]
+
+ # if load method name is None, then we have a dummy module -> raise Error
+ if load_method_name is None:
+ none_module = class_obj.__module__
+ is_dummy_path = none_module.startswith(DUMMY_MODULES_FOLDER) or none_module.startswith(
+ TRANSFORMERS_DUMMY_MODULES_FOLDER
+ )
+ if is_dummy_path and "dummy" in none_module:
+ # call class_obj for nice error message of missing requirements
+ class_obj()
+
+ raise ValueError(
+ f"The component {class_obj} of {pipeline_class} cannot be loaded as it does not seem to have"
+ f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}."
+ )
+
+ load_method = getattr(class_obj, load_method_name)
+
+ # add kwargs to loading method
+ loading_kwargs = {}
+ if issubclass(class_obj, torch.nn.Module):
+ loading_kwargs["torch_dtype"] = torch_dtype
+ if issubclass(class_obj, diffusers.OnnxRuntimeModel):
+ loading_kwargs["provider"] = provider
+ loading_kwargs["sess_options"] = sess_options
+
+ is_diffusers_model = issubclass(class_obj, diffusers.ModelMixin)
+
+ if is_transformers_available():
+ transformers_version = version.parse(version.parse(transformers.__version__).base_version)
+ else:
+ transformers_version = "N/A"
+
+ is_transformers_model = (
+ is_transformers_available()
+ and issubclass(class_obj, PreTrainedModel)
+ and transformers_version >= version.parse("4.20.0")
+ )
+
+    # When loading a transformers model with `device_map=None`, the weights are randomly initialized before the
+    # checkpoint is loaded (unlike diffusers models). To make default loading faster we set the
+    # `low_cpu_mem_usage=low_cpu_mem_usage` flag, which is `True` by default. This skips the redundant weight
+    # initialization and significantly speeds up loading.
+ if is_diffusers_model or is_transformers_model:
+ loading_kwargs["device_map"] = device_map
+ loading_kwargs["max_memory"] = max_memory
+ loading_kwargs["offload_folder"] = offload_folder
+ loading_kwargs["offload_state_dict"] = offload_state_dict
+ loading_kwargs["variant"] = model_variants.pop(name, None)
+ if from_flax:
+ loading_kwargs["from_flax"] = True
+
+ # the following can be deleted once the minimum required `transformers` version
+ # is higher than 4.27
+ if (
+ is_transformers_model
+ and loading_kwargs["variant"] is not None
+ and transformers_version < version.parse("4.27.0")
+ ):
+ raise ImportError(
+ f"When passing `variant='{variant}'`, please make sure to upgrade your `transformers` version to at least 4.27.0.dev0"
+ )
+ elif is_transformers_model and loading_kwargs["variant"] is None:
+ loading_kwargs.pop("variant")
+
+ # if `from_flax` and model is transformer model, can currently not load with `low_cpu_mem_usage`
+ if not (from_flax and is_transformers_model):
+ loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
+ else:
+ loading_kwargs["low_cpu_mem_usage"] = False
+
+ # check if the module is in a subdirectory
+ if os.path.isdir(os.path.join(cached_folder, name)):
+ loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
+ else:
+ # else load from the root directory
+ loaded_sub_model = load_method(cached_folder, **loading_kwargs)
+
+ return loaded_sub_model
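+
+# For example, loading name="unet" backed by a diffusers `UNet2DConditionModel` typically amounts to
+# `UNet2DConditionModel.from_pretrained(os.path.join(cached_folder, "unet"), torch_dtype=..., variant=...)`.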
+
+
+class DiffusionPipeline(ConfigMixin):
+ r"""
+ Base class for all pipelines.
+
+ [`DiffusionPipeline`] stores all components (models, schedulers, and processors) for diffusion pipelines and
+ provides methods for loading, downloading and saving models. It also includes methods to:
+
+ - move all PyTorch modules to the device of your choice
+        - enable or disable the progress bar for the denoising iteration
+
+ Class attributes:
+
+ - **config_name** (`str`) -- The configuration filename that stores the class and module names of all the
+ diffusion pipeline's components.
+ - **_optional_components** (List[`str`]) -- List of all optional components that don't have to be passed to the
+ pipeline to function (should be overridden by subclasses).
+ """
+ config_name = "model_index.json"
+ _optional_components = []
+ _exclude_from_cpu_offload = []
+
+ def register_modules(self, **kwargs):
+ # import it here to avoid circular import
+ from diffusers import pipelines
+
+ for name, module in kwargs.items():
+ # retrieve library
+ if module is None:
+ register_dict = {name: (None, None)}
+ else:
+ # register the config from the original module, not the dynamo compiled one
+ if is_compiled_module(module):
+ not_compiled_module = module._orig_mod
+ else:
+ not_compiled_module = module
+
+ library = not_compiled_module.__module__.split(".")[0]
+
+ # check if the module is a pipeline module
+ module_path_items = not_compiled_module.__module__.split(".")
+ pipeline_dir = module_path_items[-2] if len(module_path_items) > 2 else None
+
+ path = not_compiled_module.__module__.split(".")
+ is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
+
+ # if library is not in LOADABLE_CLASSES, then it is a custom module.
+ # Or if it's a pipeline module, then the module is inside the pipeline
+ # folder so we set the library to module name.
+ if is_pipeline_module:
+ library = pipeline_dir
+ elif library not in LOADABLE_CLASSES:
+ library = not_compiled_module.__module__
+
+ # retrieve class_name
+ class_name = not_compiled_module.__class__.__name__
+
+ register_dict = {name: (library, class_name)}
+
+ # save model index config
+ self.register_to_config(**register_dict)
+
+ # set models
+ setattr(self, name, module)
+
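+    # (Illustrative call site: a concrete pipeline's `__init__` typically calls
+    # `self.register_modules(unet=unet, scheduler=scheduler, ...)`, which records each component's
+    # library and class name in `model_index.json` and sets it as an attribute on the pipeline.)
+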
+ def __setattr__(self, name: str, value: Any):
+ if name in self.__dict__ and hasattr(self.config, name):
+ # We need to overwrite the config if name exists in config
+ if isinstance(getattr(self.config, name), (tuple, list)):
+ if value is not None and self.config[name][0] is not None:
+ class_library_tuple = (value.__module__.split(".")[0], value.__class__.__name__)
+ else:
+ class_library_tuple = (None, None)
+
+ self.register_to_config(**{name: class_library_tuple})
+ else:
+ self.register_to_config(**{name: value})
+
+ super().__setattr__(name, value)
+
+ def save_pretrained(
+ self,
+ save_directory: Union[str, os.PathLike],
+ safe_serialization: bool = False,
+ variant: Optional[str] = None,
+ ):
+ """
+ Save all saveable variables of the pipeline to a directory. A pipeline variable can be saved and loaded if its
+ class implements both a save and loading method. The pipeline is easily reloaded using the
+ [`~DiffusionPipeline.from_pretrained`] class method.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to save a pipeline to. Will be created if it doesn't exist.
+ safe_serialization (`bool`, *optional*, defaults to `False`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+ variant (`str`, *optional*):
+                If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
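+
+        Example (illustrative save path; any loaded pipeline works):
+
+        ```py
+        >>> from diffusers import DiffusionPipeline
+
+        >>> pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
+        >>> pipe.save_pretrained("./my_pipeline", safe_serialization=True)
+        ```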
+ """
+ model_index_dict = dict(self.config)
+ model_index_dict.pop("_class_name", None)
+ model_index_dict.pop("_diffusers_version", None)
+ model_index_dict.pop("_module", None)
+ model_index_dict.pop("_name_or_path", None)
+
+ expected_modules, optional_kwargs = self._get_signature_keys(self)
+
+ def is_saveable_module(name, value):
+ if name not in expected_modules:
+ return False
+ if name in self._optional_components and value[0] is None:
+ return False
+ return True
+
+ model_index_dict = {k: v for k, v in model_index_dict.items() if is_saveable_module(k, v)}
+ for pipeline_component_name in model_index_dict.keys():
+ sub_model = getattr(self, pipeline_component_name)
+ model_cls = sub_model.__class__
+
+ # Dynamo wraps the original model in a private class.
+ # I didn't find a public API to get the original class.
+ if is_compiled_module(sub_model):
+ sub_model = sub_model._orig_mod
+ model_cls = sub_model.__class__
+
+ save_method_name = None
+ # search for the model's base class in LOADABLE_CLASSES
+ for library_name, library_classes in LOADABLE_CLASSES.items():
+ if library_name in sys.modules:
+ library = importlib.import_module(library_name)
+ else:
+ logger.info(
+ f"{library_name} is not installed. Cannot save {pipeline_component_name} as {library_classes} from {library_name}"
+ )
+
+ for base_class, save_load_methods in library_classes.items():
+ class_candidate = getattr(library, base_class, None)
+ if class_candidate is not None and issubclass(model_cls, class_candidate):
+ # if we found a suitable base class in LOADABLE_CLASSES then grab its save method
+ save_method_name = save_load_methods[0]
+ break
+ if save_method_name is not None:
+ break
+
+ if save_method_name is None:
+ logger.warn(f"self.{pipeline_component_name}={sub_model} of type {type(sub_model)} cannot be saved.")
+                # make sure that unsaveable components are not loaded again when the pipeline is reloaded
+ self.register_to_config(**{pipeline_component_name: (None, None)})
+ continue
+
+ save_method = getattr(sub_model, save_method_name)
+
+ # Call the save method with the argument safe_serialization only if it's supported
+ save_method_signature = inspect.signature(save_method)
+ save_method_accept_safe = "safe_serialization" in save_method_signature.parameters
+ save_method_accept_variant = "variant" in save_method_signature.parameters
+
+ save_kwargs = {}
+ if save_method_accept_safe:
+ save_kwargs["safe_serialization"] = safe_serialization
+ if save_method_accept_variant:
+ save_kwargs["variant"] = variant
+
+ save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)
+
+ # finally save the config
+ self.save_config(save_directory)
+
+ def to(
+ self,
+ torch_device: Optional[Union[str, torch.device]] = None,
+ torch_dtype: Optional[torch.dtype] = None,
+ silence_dtype_warnings: bool = False,
+ ):
+ if torch_device is None and torch_dtype is None:
+ return self
+
+ # throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU.
+ def module_is_sequentially_offloaded(module):
+ if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
+ return False
+
+ return hasattr(module, "_hf_hook") and not isinstance(
+ module._hf_hook, (accelerate.hooks.CpuOffload, accelerate.hooks.AlignDevicesHook)
+ )
+
+ def module_is_offloaded(module):
+ if not is_accelerate_available() or is_accelerate_version("<", "0.17.0.dev0"):
+ return False
+
+ return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, accelerate.hooks.CpuOffload)
+
+ # .to("cuda") would raise an error if the pipeline is sequentially offloaded, so we raise our own to make it clearer
+ pipeline_is_sequentially_offloaded = any(
+ module_is_sequentially_offloaded(module) for _, module in self.components.items()
+ )
+ if pipeline_is_sequentially_offloaded and torch.device(torch_device).type == "cuda":
+ raise ValueError(
+ "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
+ )
+
+ # Display a warning in this case (the operation succeeds but the benefits are lost)
+ pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items())
+ if pipeline_is_offloaded and torch.device(torch_device).type == "cuda":
+ logger.warning(
+ f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading."
+ )
+
+ module_names, _ = self._get_signature_keys(self)
+ modules = [getattr(self, n, None) for n in module_names]
+ modules = [m for m in modules if isinstance(m, torch.nn.Module)]
+
+ is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded
+ for module in modules:
+ is_loaded_in_8bit = hasattr(module, "is_loaded_in_8bit") and module.is_loaded_in_8bit
+
+ if is_loaded_in_8bit and torch_dtype is not None:
+ logger.warning(
+ f"The module '{module.__class__.__name__}' has been loaded in 8bit and conversion to {torch_dtype} is not yet supported. Module is still in 8bit precision."
+ )
+
+ if is_loaded_in_8bit and torch_device is not None:
+ logger.warning(
+ f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {torch_dtype} via `.to()` is not yet supported. Module is still on {module.device}."
+ )
+ else:
+ module.to(torch_device, torch_dtype)
+
+ if (
+ module.dtype == torch.float16
+ and str(torch_device) in ["cpu"]
+ and not silence_dtype_warnings
+ and not is_offloaded
+ ):
+ logger.warning(
+ "Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` device. It"
+ " is not recommended to move them to `cpu` as running them will fail. Please make"
+ " sure to use an accelerator to run the pipeline in inference, due to the lack of"
+ " support for`float16` operations on this device in PyTorch. Please, remove the"
+ " `torch_dtype=torch.float16` argument, or use another device for inference."
+ )
+ return self
+
+ @property
+ def device(self) -> torch.device:
+ r"""
+ Returns:
+ `torch.device`: The torch device on which the pipeline is located.
+ """
+ module_names, _ = self._get_signature_keys(self)
+ modules = [getattr(self, n, None) for n in module_names]
+ modules = [m for m in modules if isinstance(m, torch.nn.Module)]
+
+ for module in modules:
+ return module.device
+
+ return torch.device("cpu")
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
+ r"""
+ Instantiate a PyTorch diffusion pipeline from pretrained pipeline weights.
+
+ The pipeline is set in evaluation mode (`model.eval()`) by default.
+
+ If you get the error message below, you need to finetune the weights for your downstream task:
+
+ ```
+ Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
+ - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
+ You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
+ ```
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+ Can be either:
+
+ - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline
+ hosted on the Hub.
+ - A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights
+ saved using
+ [`~DiffusionPipeline.save_pretrained`].
+ torch_dtype (`str` or `torch.dtype`, *optional*):
+ Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
+ dtype is automatically derived from the model's weights.
+ custom_pipeline (`str`, *optional*):
+
+                🧪 This is an experimental feature and may change in the future.
+
+ Can be either:
+
+ - A string, the *repo id* (for example `hf-internal-testing/diffusers-dummy-pipeline`) of a custom
+ pipeline hosted on the Hub. The repository must contain a file called pipeline.py that defines
+ the custom pipeline.
+ - A string, the *file name* of a community pipeline hosted on GitHub under
+ [Community](https://github.com/huggingface/diffusers/tree/main/examples/community). Valid file
+ names must match the file name and not the pipeline script (`clip_guided_stable_diffusion`
+ instead of `clip_guided_stable_diffusion.py`). Community pipelines are always loaded from the
+ current main branch of GitHub.
+ - A path to a directory (`./my_pipeline_directory/`) containing a custom pipeline. The directory
+ must contain a file called `pipeline.py` that defines the custom pipeline.
+
+
+ For more information on how to load and create custom pipelines, please have a look at [Loading and
+ Adding Custom
+ Pipelines](https://huggingface.co/docs/diffusers/using-diffusers/custom_pipeline_overview)
+
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
+ incompletely downloaded files are deleted.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ output_loading_info(`bool`, *optional*, defaults to `False`):
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ custom_revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id similar to
+ `revision` when loading a custom pipeline from the Hub. It can be a 🤗 Diffusers version when loading a
+ custom pipeline from GitHub, otherwise it defaults to `"main"` when loading from the Hub.
+ mirror (`str`, *optional*):
+ Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not
+ guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
+ information.
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
+ A map that specifies where each submodule should go. It doesn’t need to be defined for each
+ parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
+ same device.
+
+ Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
+ more information about each option see [designing a device
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+ max_memory (`Dict`, *optional*):
+ A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
+ each GPU and the available CPU RAM if unset.
+ offload_folder (`str` or `os.PathLike`, *optional*):
+ The path to offload weights if device_map contains the value `"disk"`.
+ offload_state_dict (`bool`, *optional*):
+ If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
+ the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
+ when there is some disk offload.
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+ Speed up model loading only loading the pretrained weights and not initializing the weights. This also
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+ argument to `True` will raise an error.
+ use_safetensors (`bool`, *optional*, defaults to `None`):
+ If set to `None`, the safetensors weights are downloaded if they're available **and** if the
+ safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
+ weights. If set to `False`, safetensors weights are not loaded.
+ kwargs (remaining dictionary of keyword arguments, *optional*):
+ Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
+ class). The overwritten components are passed directly to the pipelines `__init__` method. See example
+ below for more information.
+ variant (`str`, *optional*):
+ Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
+ loading `from_flax`.
+
+        To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
+        `huggingface-cli login`.
+
+ Examples:
+
+ ```py
+ >>> from diffusers import DiffusionPipeline
+
+ >>> # Download pipeline from huggingface.co and cache.
+ >>> pipeline = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")
+
+ >>> # Download pipeline that requires an authorization token
+ >>> # For more information on access tokens, please refer to this section
+        >>> # of the documentation: https://huggingface.co/docs/hub/security-tokens
+ >>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
+
+ >>> # Use a different scheduler
+ >>> from diffusers import LMSDiscreteScheduler
+
+ >>> scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config)
+ >>> pipeline.scheduler = scheduler
+ ```
+ """
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+ resume_download = kwargs.pop("resume_download", False)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ revision = kwargs.pop("revision", None)
+ from_flax = kwargs.pop("from_flax", False)
+ torch_dtype = kwargs.pop("torch_dtype", None)
+ custom_pipeline = kwargs.pop("custom_pipeline", None)
+ custom_revision = kwargs.pop("custom_revision", None)
+ provider = kwargs.pop("provider", None)
+ sess_options = kwargs.pop("sess_options", None)
+ device_map = kwargs.pop("device_map", None)
+ max_memory = kwargs.pop("max_memory", None)
+ offload_folder = kwargs.pop("offload_folder", None)
+ offload_state_dict = kwargs.pop("offload_state_dict", False)
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
+ variant = kwargs.pop("variant", None)
+ use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
+
+ # 1. Download the checkpoints and configs
+ # use snapshot download here to get it working from from_pretrained
+ if not os.path.isdir(pretrained_model_name_or_path):
+ cached_folder = cls.download(
+ pretrained_model_name_or_path,
+ cache_dir=cache_dir,
+ resume_download=resume_download,
+ force_download=force_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ from_flax=from_flax,
+ use_safetensors=use_safetensors,
+ custom_pipeline=custom_pipeline,
+ custom_revision=custom_revision,
+ variant=variant,
+ **kwargs,
+ )
+ else:
+ cached_folder = pretrained_model_name_or_path
+
+ config_dict = cls.load_config(cached_folder)
+
+ # pop out "_ignore_files" as it is only needed for download
+ config_dict.pop("_ignore_files", None)
+
+ # 2. Define which model components should load variants
+ # We retrieve the information by matching whether variant
+ # model checkpoints exist in the subfolders
+ model_variants = {}
+ if variant is not None:
+ for folder in os.listdir(cached_folder):
+ folder_path = os.path.join(cached_folder, folder)
+ is_folder = os.path.isdir(folder_path) and folder in config_dict
+ variant_exists = is_folder and any(
+ p.split(".")[1].startswith(variant) for p in os.listdir(folder_path)
+ )
+ if variant_exists:
+ model_variants[folder] = variant
+
+ # 3. Load the pipeline class, if using custom module then load it from the hub
+ # if we load from explicit class, let's use it
+ pipeline_class = _get_pipeline_class(
+ cls, config_dict, custom_pipeline=custom_pipeline, cache_dir=cache_dir, revision=custom_revision
+ )
+
+ # DEPRECATED: To be removed in 1.0.0
+ if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse(
+ version.parse(config_dict["_diffusers_version"]).base_version
+ ) <= version.parse("0.5.1"):
+ from diffusers import StableDiffusionInpaintPipeline, StableDiffusionInpaintPipelineLegacy
+
+ pipeline_class = StableDiffusionInpaintPipelineLegacy
+
+ deprecation_message = (
+ "You are using a legacy checkpoint for inpainting with Stable Diffusion, therefore we are loading the"
+ f" {StableDiffusionInpaintPipelineLegacy} class instead of {StableDiffusionInpaintPipeline}. For"
+ " better inpainting results, we strongly suggest using Stable Diffusion's official inpainting"
+ " checkpoint: https://huggingface.co/runwayml/stable-diffusion-inpainting instead or adapting your"
+ f" checkpoint {pretrained_model_name_or_path} to the format of"
+ " https://huggingface.co/runwayml/stable-diffusion-inpainting. Note that we do not actively maintain"
+ " the {StableDiffusionInpaintPipelineLegacy} class and will likely remove it in version 1.0.0."
+ )
+ deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False)
+
+ # 4. Define expected modules given pipeline signature
+ # and define non-None initialized modules (=`init_kwargs`)
+
+ # some modules can be passed directly to the init
+ # in this case they are already instantiated in `kwargs`
+ # extract them here
+ expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
+ passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
+ passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
+
+ init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
+
+ # define init kwargs
+ init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict}
+ init_kwargs = {**init_kwargs, **passed_pipe_kwargs}
+
+ # remove `null` components
+ def load_module(name, value):
+ if value[0] is None:
+ return False
+ if name in passed_class_obj and passed_class_obj[name] is None:
+ return False
+ return True
+
+ init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
+
+ # Special case: safety_checker must be loaded separately when using `from_flax`
+ if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj:
+ raise NotImplementedError(
+ "The safety checker cannot be automatically loaded when loading weights `from_flax`."
+ " Please, pass `safety_checker=None` to `from_pretrained`, and load the safety checker"
+ " separately if you need it."
+ )
+
+ # 5. Throw nice warnings / errors for fast accelerate loading
+ if len(unused_kwargs) > 0:
+ logger.warning(
+ f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored."
+ )
+
+ if low_cpu_mem_usage and not is_accelerate_available():
+ low_cpu_mem_usage = False
+ logger.warning(
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+ " install accelerate\n```\n."
+ )
+
+ if device_map is not None and not is_torch_version(">=", "1.9.0"):
+ raise NotImplementedError(
+ "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
+ " `device_map=None`."
+ )
+
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+ raise NotImplementedError(
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+ " `low_cpu_mem_usage=False`."
+ )
+
+ if low_cpu_mem_usage is False and device_map is not None:
+ raise ValueError(
+ f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and"
+ " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
+ )
+
+ # import it here to avoid circular import
+ from diffusers import pipelines
+
+ # 6. Load each module in the pipeline
+ for name, (library_name, class_name) in tqdm(init_dict.items(), desc="Loading pipeline components..."):
+ # 6.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
+ if class_name.startswith("Flax"):
+ class_name = class_name[4:]
+
+ # 6.2 Define all importable classes
+ is_pipeline_module = hasattr(pipelines, library_name)
+ importable_classes = ALL_IMPORTABLE_CLASSES
+ loaded_sub_model = None
+
+ # 6.3 Use passed sub model or load class_name from library_name
+ if name in passed_class_obj:
+ # if the model is in a pipeline module, then we load it from the pipeline
+ # check that passed_class_obj has correct parent class
+ maybe_raise_or_warn(
+ library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module
+ )
+
+ loaded_sub_model = passed_class_obj[name]
+ else:
+ # load sub model
+ loaded_sub_model = load_sub_model(
+ library_name=library_name,
+ class_name=class_name,
+ importable_classes=importable_classes,
+ pipelines=pipelines,
+ is_pipeline_module=is_pipeline_module,
+ pipeline_class=pipeline_class,
+ torch_dtype=torch_dtype,
+ provider=provider,
+ sess_options=sess_options,
+ device_map=device_map,
+ max_memory=max_memory,
+ offload_folder=offload_folder,
+ offload_state_dict=offload_state_dict,
+ model_variants=model_variants,
+ name=name,
+ from_flax=from_flax,
+ variant=variant,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ cached_folder=cached_folder,
+ )
+ logger.info(
+ f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."
+ )
+
+ init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
+
+ # 7. Potentially add passed objects if expected
+ missing_modules = set(expected_modules) - set(init_kwargs.keys())
+ passed_modules = list(passed_class_obj.keys())
+ optional_modules = pipeline_class._optional_components
+ if len(missing_modules) > 0 and missing_modules <= set(passed_modules + optional_modules):
+ for module in missing_modules:
+ init_kwargs[module] = passed_class_obj.get(module, None)
+ elif len(missing_modules) > 0:
+ passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs
+ raise ValueError(
+ f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
+ )
+
+ # 8. Instantiate the pipeline
+ model = pipeline_class(**init_kwargs)
+
+ # 9. Save where the model was instantiated from
+ model.register_to_config(_name_or_path=pretrained_model_name_or_path)
+ return model
+
+ @property
+ def name_or_path(self) -> str:
+ return getattr(self.config, "_name_or_path", None)
+
+ @property
+ def _execution_device(self):
+ r"""
+ Returns the device on which the pipeline's models will be executed. After calling
+ [`~DiffusionPipeline.enable_sequential_cpu_offload`] the execution device can only be inferred from
+ Accelerate's module hooks.
+ """
+ for name, model in self.components.items():
+ if not isinstance(model, torch.nn.Module) or name in self._exclude_from_cpu_offload:
+ continue
+
+ if not hasattr(model, "_hf_hook"):
+ return self.device
+ for module in model.modules():
+ if (
+ hasattr(module, "_hf_hook")
+ and hasattr(module._hf_hook, "execution_device")
+ and module._hf_hook.execution_device is not None
+ ):
+ return torch.device(module._hf_hook.execution_device)
+ return self.device
+
+ def enable_sequential_cpu_offload(self, gpu_id: int = 0, device: Union[torch.device, str] = "cuda"):
+ r"""
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
+ text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
+        `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
+ Note that offloading happens on a submodule basis. Memory savings are higher than with
+ `enable_model_cpu_offload`, but performance is lower.
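+
+        Example (sketch; assumes `accelerate` is installed and a CUDA device is available):
+
+        ```py
+        >>> from diffusers import DiffusionPipeline
+
+        >>> pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
+        >>> pipe.enable_sequential_cpu_offload()
+        >>> # submodules are now moved to the GPU only while their `forward` passes run
+        ```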
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
+ from accelerate import cpu_offload
+ else:
+ raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
+
+ if device == "cuda":
+ device = torch.device(f"{device}:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ for name, model in self.components.items():
+ if not isinstance(model, torch.nn.Module):
+ continue
+
+ if name in self._exclude_from_cpu_offload:
+ model.to(device)
+ else:
+ # make sure to offload buffers if not all high level weights
+ # are of type nn.Module
+ offload_buffers = len(model._parameters) > 0
+ cpu_offload(model, device, offload_buffers=offload_buffers)
+
+ @classmethod
+ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
+ r"""
+ Download and cache a PyTorch diffusion pipeline from pretrained pipeline weights.
+
+ Parameters:
+ pretrained_model_name (`str` or `os.PathLike`, *optional*):
+ A string, the *repository id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline
+ hosted on the Hub.
+ custom_pipeline (`str`, *optional*):
+ Can be either:
+
+ - A string, the *repository id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained
+ pipeline hosted on the Hub. The repository must contain a file called `pipeline.py` that defines
+ the custom pipeline.
+
+ - A string, the *file name* of a community pipeline hosted on GitHub under
+ [Community](https://github.com/huggingface/diffusers/tree/main/examples/community). Valid file
+ names must match the file name and not the pipeline script (`clip_guided_stable_diffusion`
+ instead of `clip_guided_stable_diffusion.py`). Community pipelines are always loaded from the
+ current `main` branch of GitHub.
+
+ - A path to a *directory* (`./my_pipeline_directory/`) containing a custom pipeline. The directory
+ must contain a file called `pipeline.py` that defines the custom pipeline.
+
+
+                🧪 This is an experimental feature and may change in the future.
+
+ For more information on how to load and create custom pipelines, take a look at [How to contribute a
+ community pipeline](https://huggingface.co/docs/diffusers/main/en/using-diffusers/contribute_pipeline).
+
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
+ incompletely downloaded files are deleted.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ output_loading_info(`bool`, *optional*, defaults to `False`):
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ custom_revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id similar to
+ `revision` when loading a custom pipeline from the Hub. It can be a 🤗 Diffusers version when loading a
+ custom pipeline from GitHub, otherwise it defaults to `"main"` when loading from the Hub.
+ mirror (`str`, *optional*):
+ Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
+ guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
+ information.
+ variant (`str`, *optional*):
+ Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
+ loading `from_flax`.
+
+ Returns:
+ `os.PathLike`:
+ A path to the downloaded pipeline.
+
+        To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
+        `huggingface-cli login`.
+
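+
+        Example (sketch; the repo id and variant are illustrative):
+
+        ```py
+        >>> from diffusers import DiffusionPipeline
+
+        >>> local_dir = DiffusionPipeline.download("runwayml/stable-diffusion-v1-5", variant="fp16")
+        ```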
+ """
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+ resume_download = kwargs.pop("resume_download", False)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ revision = kwargs.pop("revision", None)
+ from_flax = kwargs.pop("from_flax", False)
+ custom_pipeline = kwargs.pop("custom_pipeline", None)
+ custom_revision = kwargs.pop("custom_revision", None)
+ variant = kwargs.pop("variant", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+
+ if use_safetensors and not is_safetensors_available():
+ raise ValueError(
+ "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
+ )
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = is_safetensors_available()
+ allow_pickle = True
+
+ pipeline_is_cached = False
+ allow_patterns = None
+ ignore_patterns = None
+
+ if not local_files_only:
+ try:
+ info = model_info(
+ pretrained_model_name,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ )
+ except HTTPError as e:
+ logger.warn(f"Couldn't connect to the Hub: {e}.\nWill try to load from local cache.")
+ local_files_only = True
+
+ if not local_files_only:
+ config_file = hf_hub_download(
+ pretrained_model_name,
+ cls.config_name,
+ cache_dir=cache_dir,
+ revision=revision,
+ proxies=proxies,
+ force_download=force_download,
+ resume_download=resume_download,
+ use_auth_token=use_auth_token,
+ )
+
+ config_dict = cls._dict_from_json_file(config_file)
+
+ ignore_filenames = config_dict.pop("_ignore_files", [])
+
+ # retrieve all folder_names that contain relevant files
+ folder_names = [k for k, v in config_dict.items() if isinstance(v, list)]
+
+ filenames = {sibling.rfilename for sibling in info.siblings}
+ model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
+
+ if len(variant_filenames) == 0 and variant is not None:
+ deprecation_message = (
+ f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
+ f"The default model files: {model_filenames} will be loaded instead. Make sure to not load from `variant={variant}`"
+ "if such variant modeling files are not available. Doing so will lead to an error in v0.22.0 as defaulting to non-variant"
+ "modeling files is deprecated."
+ )
+ deprecate("no variant default", "0.22.0", deprecation_message, standard_warn=False)
+
+ # remove ignored filenames
+ model_filenames = set(model_filenames) - set(ignore_filenames)
+ variant_filenames = set(variant_filenames) - set(ignore_filenames)
+
+ # if the whole pipeline is cached we don't have to ping the Hub
+ if revision in DEPRECATED_REVISION_ARGS and version.parse(
+ version.parse(__version__).base_version
+ ) >= version.parse("0.20.0"):
+ warn_deprecated_model_variant(
+ pretrained_model_name, use_auth_token, variant, revision, model_filenames
+ )
+
+ model_folder_names = {os.path.split(f)[0] for f in model_filenames}
+
+ # all filenames compatible with variant will be added
+ allow_patterns = list(model_filenames)
+
+ # allow all patterns from non-model folders
+ # this enables downloading schedulers, tokenizers, ...
+ allow_patterns += [f"{k}/*" for k in folder_names if k not in model_folder_names]
+ # also allow downloading config.json files with the model
+ allow_patterns += [os.path.join(k, "config.json") for k in model_folder_names]
+
+ allow_patterns += [
+ SCHEDULER_CONFIG_NAME,
+ CONFIG_NAME,
+ cls.config_name,
+ CUSTOM_PIPELINE_FILE_NAME,
+ ]
+
+ # retrieve passed components that should not be downloaded
+ pipeline_class = _get_pipeline_class(
+ cls, config_dict, custom_pipeline=custom_pipeline, cache_dir=cache_dir, revision=custom_revision
+ )
+ expected_components, _ = cls._get_signature_keys(pipeline_class)
+ passed_components = [k for k in expected_components if k in kwargs]
+
+ if (
+ use_safetensors
+ and not allow_pickle
+ and not is_safetensors_compatible(
+ model_filenames, variant=variant, passed_components=passed_components
+ )
+ ):
+ raise EnvironmentError(
+ f"Could not found the necessary `safetensors` weights in {model_filenames} (variant={variant})"
+ )
+ if from_flax:
+ ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
+ elif use_safetensors and is_safetensors_compatible(
+ model_filenames, variant=variant, passed_components=passed_components
+ ):
+ ignore_patterns = ["*.bin", "*.msgpack"]
+
+ safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")}
+ safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")}
+ if (
+ len(safetensors_variant_filenames) > 0
+ and safetensors_model_filenames != safetensors_variant_filenames
+ ):
+ logger.warn(
+ f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
+ )
+ else:
+ ignore_patterns = ["*.safetensors", "*.msgpack"]
+
+ bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")}
+ bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")}
+ if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
+ logger.warn(
+ f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
+ )
+
+ # Don't download any objects that are passed
+ allow_patterns = [
+ p for p in allow_patterns if not (len(p.split("/")) == 2 and p.split("/")[0] in passed_components)
+ ]
+ # Don't download index files of forbidden patterns either
+ ignore_patterns = ignore_patterns + [f"{i}.index.*json" for i in ignore_patterns]
+
+ re_ignore_pattern = [re.compile(fnmatch.translate(p)) for p in ignore_patterns]
+ re_allow_pattern = [re.compile(fnmatch.translate(p)) for p in allow_patterns]
+
+ expected_files = [f for f in filenames if not any(p.match(f) for p in re_ignore_pattern)]
+ expected_files = [f for f in expected_files if any(p.match(f) for p in re_allow_pattern)]
+
+ snapshot_folder = Path(config_file).parent
+ pipeline_is_cached = all((snapshot_folder / f).is_file() for f in expected_files)
+
+ if pipeline_is_cached and not force_download:
+ # if the pipeline is cached, we can directly return it
+ # else call snapshot_download
+ return snapshot_folder
+
+ user_agent = {"pipeline_class": cls.__name__}
+ if custom_pipeline is not None and not custom_pipeline.endswith(".py"):
+ user_agent["custom_pipeline"] = custom_pipeline
+
+ # download all allow_patterns - ignore_patterns
+ cached_folder = snapshot_download(
+ pretrained_model_name,
+ cache_dir=cache_dir,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ allow_patterns=allow_patterns,
+ ignore_patterns=ignore_patterns,
+ user_agent=user_agent,
+ )
+
+ return cached_folder
+
+ @staticmethod
+ def _get_signature_keys(obj):
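+ # Inspect the pipeline's `__init__` signature: parameters without a default value are the required
+ # modules (minus `self`), while parameters with a default are optional.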
+ parameters = inspect.signature(obj.__init__).parameters
+ required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
+ optional_parameters = {k for k, v in parameters.items() if v.default != inspect._empty}
+ expected_modules = set(required_parameters.keys()) - {"self"}
+ return expected_modules, optional_parameters
+
+ @property
+ def components(self) -> Dict[str, Any]:
+ r"""
+ The `self.components` property can be useful to run different pipelines with the same weights and
+ configurations without reallocating additional memory.
+
+ Returns (`dict`):
+ A dictionary containing all the modules needed to initialize the pipeline.
+
+ Examples:
+
+ ```py
+ >>> from diffusers import (
+ ... StableDiffusionPipeline,
+ ... StableDiffusionImg2ImgPipeline,
+ ... StableDiffusionInpaintPipeline,
+ ... )
+
+ >>> text2img = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
+ >>> img2img = StableDiffusionImg2ImgPipeline(**text2img.components)
+ >>> inpaint = StableDiffusionInpaintPipeline(**text2img.components)
+ ```
+ """
+ expected_modules, optional_parameters = self._get_signature_keys(self)
+ components = {
+ k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters
+ }
+
+ if set(components.keys()) != expected_modules:
+ raise ValueError(
+ f"{self} has been incorrectly initialized or {self.__class__} is incorrectly implemented. Expected"
+ f" {expected_modules} to be defined, but {components.keys()} are defined."
+ )
+
+ return components
+
+ @staticmethod
+ def numpy_to_pil(images):
+ """
+ Convert a NumPy image or a batch of images to a PIL image.
+ """
+ return numpy_to_pil(images)
+
+ def progress_bar(self, iterable=None, total=None):
+ if not hasattr(self, "_progress_bar_config"):
+ self._progress_bar_config = {}
+ elif not isinstance(self._progress_bar_config, dict):
+ raise ValueError(
+ f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
+ )
+
+ if iterable is not None:
+ return tqdm(iterable, **self._progress_bar_config)
+ elif total is not None:
+ return tqdm(total=total, **self._progress_bar_config)
+ else:
+ raise ValueError("Either `total` or `iterable` has to be defined.")
+
+ def set_progress_bar_config(self, **kwargs):
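+ r"""
+ Store keyword arguments that `progress_bar` forwards to `tqdm`. A minimal sketch, assuming standard
+ `tqdm` keyword arguments:
+
+ ```py
+ >>> # disable the denoising progress bar entirely
+ >>> pipe.set_progress_bar_config(disable=True)
+
+ >>> # or give it a custom description and drop it once finished
+ >>> pipe.set_progress_bar_config(desc="Denoising", leave=False)
+ ```
+ """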
+ self._progress_bar_config = kwargs
+
+ def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
+ r"""
+ Enable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
+
+ When this option is enabled, you should observe lower GPU memory usage and a potential speed up during
+ inference. Speed up during training is not guaranteed.
+
+
+
+ ⚠️ When memory efficient attention and sliced attention are both enabled, memory efficient attention takes
+ precedence.
+
+
+
+ Parameters:
+ attention_op (`Callable`, *optional*):
+ Override the default `None` operator for use as `op` argument to the
+ [`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
+ function of xFormers.
+
+ Examples:
+
+ ```py
+ >>> import torch
+ >>> from diffusers import DiffusionPipeline
+ >>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
+
+ >>> pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16)
+ >>> pipe = pipe.to("cuda")
+ >>> pipe.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
+ >>> # Workaround: Flash Attention does not accept the VAE's attention shape, so fall back to the default op for the VAE
+ >>> pipe.vae.enable_xformers_memory_efficient_attention(attention_op=None)
+ ```
+ """
+ self.set_use_memory_efficient_attention_xformers(True, attention_op)
+
+ def disable_xformers_memory_efficient_attention(self):
+ r"""
+ Disable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
+ """
+ self.set_use_memory_efficient_attention_xformers(False)
+
+ def set_use_memory_efficient_attention_xformers(
+ self, valid: bool, attention_op: Optional[Callable] = None
+ ) -> None:
+ # Recursively walk through all the children.
+ # Any children which exposes the set_use_memory_efficient_attention_xformers method
+ # gets the message
+ def fn_recursive_set_mem_eff(module: torch.nn.Module):
+ if hasattr(module, "set_use_memory_efficient_attention_xformers"):
+ module.set_use_memory_efficient_attention_xformers(valid, attention_op)
+
+ for child in module.children():
+ fn_recursive_set_mem_eff(child)
+
+ module_names, _ = self._get_signature_keys(self)
+ modules = [getattr(self, n, None) for n in module_names]
+ modules = [m for m in modules if isinstance(m, torch.nn.Module)]
+
+ for module in modules:
+ fn_recursive_set_mem_eff(module)
+
+ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
+ several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+ Args:
+ slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+ `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+ must be a multiple of `slice_size`.
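+
+ Example (a minimal sketch; the checkpoint name below is only an illustrative choice):
+
+ ```py
+ >>> import torch
+ >>> from diffusers import DiffusionPipeline
+
+ >>> pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
+ >>> pipe = pipe.to("cuda")
+
+ >>> # compute attention in two steps per layer instead of one ("auto" slicing)
+ >>> pipe.enable_attention_slicing()
+ >>> image = pipe("a photo of an astronaut riding a horse").images[0]
+ ```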
+ """
+ self.set_attention_slice(slice_size)
+
+ def disable_attention_slicing(self):
+ r"""
+ Disable sliced attention computation. If `enable_attention_slicing` was previously called, attention is
+ computed in one step.
+ """
+ # set slice_size = `None` to disable attention slicing
+ self.enable_attention_slicing(None)
+
+ def set_attention_slice(self, slice_size: Optional[int]):
+ module_names, _ = self._get_signature_keys(self)
+ modules = [getattr(self, n, None) for n in module_names]
+ modules = [m for m in modules if isinstance(m, torch.nn.Module) and hasattr(m, "set_attention_slice")]
+
+ for module in modules:
+ module.set_attention_slice(slice_size)
diff --git a/diffusers/pipelines/pndm/__init__.py b/diffusers/pipelines/pndm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..488eb4f5f2b29c071fdc044ef282bc2838148c1e
--- /dev/null
+++ b/diffusers/pipelines/pndm/__init__.py
@@ -0,0 +1 @@
+from .pipeline_pndm import PNDMPipeline
diff --git a/diffusers/pipelines/pndm/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/pndm/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..eecc3a7d6435cca7561a6ff05c97142fb0769593
Binary files /dev/null and b/diffusers/pipelines/pndm/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/pipelines/pndm/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/pndm/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..72c740b00b74c5228132938a010c4ed6a3d30d2d
Binary files /dev/null and b/diffusers/pipelines/pndm/__pycache__/__init__.cpython-38.pyc differ
diff --git a/diffusers/pipelines/pndm/__pycache__/pipeline_pndm.cpython-310.pyc b/diffusers/pipelines/pndm/__pycache__/pipeline_pndm.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..df3cd0ece13264032940cd172d4b7df3728c7552
Binary files /dev/null and b/diffusers/pipelines/pndm/__pycache__/pipeline_pndm.cpython-310.pyc differ
diff --git a/diffusers/pipelines/pndm/__pycache__/pipeline_pndm.cpython-38.pyc b/diffusers/pipelines/pndm/__pycache__/pipeline_pndm.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4625b6034a7f31c6dce792c4d0be8c2aac13a419
Binary files /dev/null and b/diffusers/pipelines/pndm/__pycache__/pipeline_pndm.cpython-38.pyc differ
diff --git a/diffusers/pipelines/pndm/pipeline_pndm.py b/diffusers/pipelines/pndm/pipeline_pndm.py
new file mode 100644
index 0000000000000000000000000000000000000000..361444079311ad87eb53fc41f02643c4f4bf3c93
--- /dev/null
+++ b/diffusers/pipelines/pndm/pipeline_pndm.py
@@ -0,0 +1,99 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import List, Optional, Tuple, Union
+
+import torch
+
+from ...models import UNet2DModel
+from ...schedulers import PNDMScheduler
+from ...utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+
+
+class PNDMPipeline(DiffusionPipeline):
+ r"""
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Parameters:
+ unet (`UNet2DModel`): U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ The `PNDMScheduler` to be used in combination with `unet` to denoise the encoded image.
+ """
+
+ unet: UNet2DModel
+ scheduler: PNDMScheduler
+
+ def __init__(self, unet: UNet2DModel, scheduler: PNDMScheduler):
+ super().__init__()
+
+ scheduler = PNDMScheduler.from_config(scheduler.config)
+
+ self.register_modules(unet=unet, scheduler=scheduler)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ batch_size: int = 1,
+ num_inference_steps: int = 50,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ **kwargs,
+ ) -> Union[ImagePipelineOutput, Tuple]:
+ r"""
+ Args:
+ batch_size (`int`, *optional*, defaults to 1): The number of images to generate.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ generator (`torch.Generator`, *optional*): A [torch
+ generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose
+ between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a
+ [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is
+ True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
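+
+ Example (a minimal sketch; the checkpoint `google/ddpm-cifar10-32` is only an illustrative choice, any
+ `UNet2DModel` checkpoint with a compatible scheduler config should work):
+
+ ```py
+ >>> from diffusers import PNDMPipeline
+
+ >>> # load the unet and scheduler configuration from the Hub
+ >>> pipe = PNDMPipeline.from_pretrained("google/ddpm-cifar10-32")
+
+ >>> # run the PNDM sampling loop and save the first generated image
+ >>> image = pipe(num_inference_steps=50).images[0]
+ >>> image.save("pndm_sample.png")
+ ```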
+ """
+ # For more information on the sampling method you can take a look at Algorithm 2 of
+ # the official paper: https://arxiv.org/pdf/2202.09778.pdf
+
+ # Sample gaussian noise to begin loop
+ image = randn_tensor(
+ (batch_size, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size),
+ generator=generator,
+ device=self.device,
+ )
+
+ self.scheduler.set_timesteps(num_inference_steps)
+ for t in self.progress_bar(self.scheduler.timesteps):
+ model_output = self.unet(image, t).sample
+
+ image = self.scheduler.step(model_output, t, image).prev_sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
diff --git a/diffusers/pipelines/repaint/__init__.py b/diffusers/pipelines/repaint/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..16bc86d1cedf6243fb92f7ba331b5a6188133298
--- /dev/null
+++ b/diffusers/pipelines/repaint/__init__.py
@@ -0,0 +1 @@
+from .pipeline_repaint import RePaintPipeline
diff --git a/diffusers/pipelines/repaint/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/repaint/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8bddf43c6760aeb888c6fe703a561ef126862f5a
Binary files /dev/null and b/diffusers/pipelines/repaint/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/pipelines/repaint/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/repaint/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2157b1f6bde68443aee4bd152f240edcf6a28f90
Binary files /dev/null and b/diffusers/pipelines/repaint/__pycache__/__init__.cpython-38.pyc differ
diff --git a/diffusers/pipelines/repaint/__pycache__/pipeline_repaint.cpython-310.pyc b/diffusers/pipelines/repaint/__pycache__/pipeline_repaint.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..710eb6c5ceed6eff8ab7549474d59ae0f3bb4aff
Binary files /dev/null and b/diffusers/pipelines/repaint/__pycache__/pipeline_repaint.cpython-310.pyc differ
diff --git a/diffusers/pipelines/repaint/__pycache__/pipeline_repaint.cpython-38.pyc b/diffusers/pipelines/repaint/__pycache__/pipeline_repaint.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9868d7136f5640556b5e1f197bf4e48220420e8c
Binary files /dev/null and b/diffusers/pipelines/repaint/__pycache__/pipeline_repaint.cpython-38.pyc differ
diff --git a/diffusers/pipelines/repaint/pipeline_repaint.py b/diffusers/pipelines/repaint/pipeline_repaint.py
new file mode 100644
index 0000000000000000000000000000000000000000..6527a023a74f080f17f0943c7cfd9e446d64d3e2
--- /dev/null
+++ b/diffusers/pipelines/repaint/pipeline_repaint.py
@@ -0,0 +1,177 @@
+# Copyright 2023 ETH Zurich Computer Vision Lab and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import warnings
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import PIL
+import torch
+
+from ...models import UNet2DModel
+from ...schedulers import RePaintScheduler
+from ...utils import PIL_INTERPOLATION, logging, randn_tensor
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
+def _preprocess_image(image: Union[List, PIL.Image.Image, torch.Tensor]):
+ warnings.warn(
+ "The preprocess method is deprecated and will be removed in a future version. Please"
+ " use VaeImageProcessor.preprocess instead",
+ FutureWarning,
+ )
+ if isinstance(image, torch.Tensor):
+ return image
+ elif isinstance(image, PIL.Image.Image):
+ image = [image]
+
+ if isinstance(image[0], PIL.Image.Image):
+ w, h = image[0].size
+ w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
+
+ image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
+ image = np.concatenate(image, axis=0)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image.transpose(0, 3, 1, 2)
+ image = 2.0 * image - 1.0
+ image = torch.from_numpy(image)
+ elif isinstance(image[0], torch.Tensor):
+ image = torch.cat(image, dim=0)
+ return image
+
+
+def _preprocess_mask(mask: Union[List, PIL.Image.Image, torch.Tensor]):
+ if isinstance(mask, torch.Tensor):
+ return mask
+ elif isinstance(mask, PIL.Image.Image):
+ mask = [mask]
+
+ if isinstance(mask[0], PIL.Image.Image):
+ w, h = mask[0].size
+ w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32
+ mask = [np.array(m.convert("L").resize((w, h), resample=PIL_INTERPOLATION["nearest"]))[None, :] for m in mask]
+ mask = np.concatenate(mask, axis=0)
+ mask = mask.astype(np.float32) / 255.0
+ mask[mask < 0.5] = 0
+ mask[mask >= 0.5] = 1
+ mask = torch.from_numpy(mask)
+ elif isinstance(mask[0], torch.Tensor):
+ mask = torch.cat(mask, dim=0)
+ return mask
+
+
+class RePaintPipeline(DiffusionPipeline):
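+ r"""
+ Pipeline for image inpainting with the RePaint resampling strategy (https://arxiv.org/pdf/2201.09865.pdf).
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Parameters:
+ unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image.
+ scheduler ([`RePaintScheduler`]):
+ The [`RePaintScheduler`] to be used in combination with `unet` to denoise the encoded image.
+ """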
+ unet: UNet2DModel
+ scheduler: RePaintScheduler
+
+ def __init__(self, unet, scheduler):
+ super().__init__()
+ self.register_modules(unet=unet, scheduler=scheduler)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ image: Union[torch.Tensor, PIL.Image.Image],
+ mask_image: Union[torch.Tensor, PIL.Image.Image],
+ num_inference_steps: int = 250,
+ eta: float = 0.0,
+ jump_length: int = 10,
+ jump_n_sample: int = 10,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ ) -> Union[ImagePipelineOutput, Tuple]:
+ r"""
+ Args:
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
+ The original image to inpaint on.
+ mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
+ The mask_image where 0.0 values define which part of the original image to inpaint (change).
+ num_inference_steps (`int`, *optional*, defaults to 250):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ eta (`float`):
+ The weight of the added noise in a diffusion step. Its value is between 0.0 and 1.0; 0.0 corresponds to
+ the DDIM scheduler and 1.0 to the DDPM scheduler.
+ jump_length (`int`, *optional*, defaults to 10):
+ The number of steps taken forward in time before going backward in time for a single jump ("j" in the
+ RePaint paper). Take a look at Figures 9 and 10 in https://arxiv.org/pdf/2201.09865.pdf.
+ jump_n_sample (`int`, *optional*, defaults to 10):
+ The number of times a forward time jump is made for a given chosen time sample. Take a look at Figures 9
+ and 10 in https://arxiv.org/pdf/2201.09865.pdf.
+ generator (`torch.Generator`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is
+ True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
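+
+ Example (a minimal sketch; the checkpoint name and the local image/mask paths are only illustrative
+ assumptions):
+
+ ```py
+ >>> import PIL
+ >>> import torch
+ >>> from diffusers import RePaintPipeline, RePaintScheduler
+
+ >>> # load a DDPM checkpoint together with the RePaint scheduler
+ >>> scheduler = RePaintScheduler.from_pretrained("google/ddpm-ema-celebahq-256")
+ >>> pipe = RePaintPipeline.from_pretrained("google/ddpm-ema-celebahq-256", scheduler=scheduler)
+ >>> pipe = pipe.to("cuda")
+
+ >>> original_image = PIL.Image.open("original.png")  # hypothetical 256x256 input image
+ >>> mask_image = PIL.Image.open("mask.png")  # hypothetical mask; 0.0 marks the region to inpaint
+
+ >>> output = pipe(
+ ...     image=original_image,
+ ...     mask_image=mask_image,
+ ...     num_inference_steps=250,
+ ...     eta=0.0,
+ ...     jump_length=10,
+ ...     jump_n_sample=10,
+ ... )
+ >>> inpainted_image = output.images[0]
+ ```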
+ """
+
+ original_image = image
+
+ original_image = _preprocess_image(original_image)
+ original_image = original_image.to(device=self._execution_device, dtype=self.unet.dtype)
+ mask_image = _preprocess_mask(mask_image)
+ mask_image = mask_image.to(device=self._execution_device, dtype=self.unet.dtype)
+
+ batch_size = original_image.shape[0]
+
+ # sample gaussian noise to begin the loop
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ image_shape = original_image.shape
+ image = randn_tensor(image_shape, generator=generator, device=self._execution_device, dtype=self.unet.dtype)
+
+ # set step values
+ self.scheduler.set_timesteps(num_inference_steps, jump_length, jump_n_sample, self._execution_device)
+ self.scheduler.eta = eta
+
+ t_last = self.scheduler.timesteps[0] + 1
+ generator = generator[0] if isinstance(generator, list) else generator
+ for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
+ if t < t_last:
+ # predict the noise residual
+ model_output = self.unet(image, t).sample
+ # compute previous image: x_t -> x_t-1
+ image = self.scheduler.step(model_output, t, image, original_image, mask_image, generator).prev_sample
+
+ else:
+ # compute the reverse: x_t-1 -> x_t
+ image = self.scheduler.undo_step(image, t_last, generator)
+ t_last = t
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
diff --git a/diffusers/pipelines/score_sde_ve/__init__.py b/diffusers/pipelines/score_sde_ve/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7c2a85c067b707c155e78a3c8b84562999134e7
--- /dev/null
+++ b/diffusers/pipelines/score_sde_ve/__init__.py
@@ -0,0 +1 @@
+from .pipeline_score_sde_ve import ScoreSdeVePipeline
diff --git a/diffusers/pipelines/score_sde_ve/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/score_sde_ve/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8efda4baeb6ee8efd7408452fc9fcc8c90eb1189
Binary files /dev/null and b/diffusers/pipelines/score_sde_ve/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/pipelines/score_sde_ve/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/score_sde_ve/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1785560c9ce94a4305d84023e145bc25b2671450
Binary files /dev/null and b/diffusers/pipelines/score_sde_ve/__pycache__/__init__.cpython-38.pyc differ
diff --git a/diffusers/pipelines/score_sde_ve/__pycache__/pipeline_score_sde_ve.cpython-310.pyc b/diffusers/pipelines/score_sde_ve/__pycache__/pipeline_score_sde_ve.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..789757be70f99b234225666e72571e09ce9cefdb
Binary files /dev/null and b/diffusers/pipelines/score_sde_ve/__pycache__/pipeline_score_sde_ve.cpython-310.pyc differ
diff --git a/diffusers/pipelines/score_sde_ve/__pycache__/pipeline_score_sde_ve.cpython-38.pyc b/diffusers/pipelines/score_sde_ve/__pycache__/pipeline_score_sde_ve.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3805b5a27912b4fcbc5952bd72d2e1d2449f7910
Binary files /dev/null and b/diffusers/pipelines/score_sde_ve/__pycache__/pipeline_score_sde_ve.cpython-38.pyc differ
diff --git a/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py b/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ff7b8ee460b58f98c4bd767f70946dc4da2a893
--- /dev/null
+++ b/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py
@@ -0,0 +1,101 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Tuple, Union
+
+import torch
+
+from ...models import UNet2DModel
+from ...schedulers import ScoreSdeVeScheduler
+from ...utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+
+
+class ScoreSdeVePipeline(DiffusionPipeline):
+ r"""
+ Parameters:
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+ unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image. scheduler ([`SchedulerMixin`]):
+ The [`ScoreSdeVeScheduler`] scheduler to be used in combination with `unet` to denoise the encoded image.
+ """
+ unet: UNet2DModel
+ scheduler: ScoreSdeVeScheduler
+
+ def __init__(self, unet: UNet2DModel, scheduler: ScoreSdeVeScheduler):
+ super().__init__()
+ self.register_modules(unet=unet, scheduler=scheduler)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ batch_size: int = 1,
+ num_inference_steps: int = 2000,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ **kwargs,
+ ) -> Union[ImagePipelineOutput, Tuple]:
+ r"""
+ Args:
+ batch_size (`int`, *optional*, defaults to 1):
+ The number of images to generate.
+ num_inference_steps (`int`, *optional*, defaults to 2000):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ generator (`torch.Generator`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is
+ True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
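+
+ Example (a minimal sketch; the checkpoint `google/ncsnpp-celebahq-256` is only an illustrative choice):
+
+ ```py
+ >>> from diffusers import ScoreSdeVePipeline
+
+ >>> # load an NCSN++ score model together with its VE-SDE scheduler
+ >>> pipe = ScoreSdeVePipeline.from_pretrained("google/ncsnpp-celebahq-256")
+ >>> pipe = pipe.to("cuda")
+
+ >>> # run the predictor-corrector sampling loop (2000 steps by default)
+ >>> image = pipe(num_inference_steps=2000).images[0]
+ >>> image.save("sde_ve_sample.png")
+ ```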
+ """
+
+ img_size = self.unet.config.sample_size
+ shape = (batch_size, 3, img_size, img_size)
+
+ model = self.unet
+
+ sample = randn_tensor(shape, generator=generator) * self.scheduler.init_noise_sigma
+ sample = sample.to(self.device)
+
+ self.scheduler.set_timesteps(num_inference_steps)
+ self.scheduler.set_sigmas(num_inference_steps)
+
+ for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
+ sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=self.device)
+
+ # correction step
+ for _ in range(self.scheduler.config.correct_steps):
+ model_output = self.unet(sample, sigma_t).sample
+ sample = self.scheduler.step_correct(model_output, sample, generator=generator).prev_sample
+
+ # prediction step
+ model_output = model(sample, sigma_t).sample
+ output = self.scheduler.step_pred(model_output, t, sample, generator=generator)
+
+ sample, sample_mean = output.prev_sample, output.prev_sample_mean
+
+ sample = sample_mean.clamp(0, 1)
+ sample = sample.cpu().permute(0, 2, 3, 1).numpy()
+ if output_type == "pil":
+ sample = self.numpy_to_pil(sample)
+
+ if not return_dict:
+ return (sample,)
+
+ return ImagePipelineOutput(images=sample)
diff --git a/diffusers/pipelines/semantic_stable_diffusion/__init__.py b/diffusers/pipelines/semantic_stable_diffusion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e312c5e30138e106930421ad8c55c23f01e60e7
--- /dev/null
+++ b/diffusers/pipelines/semantic_stable_diffusion/__init__.py
@@ -0,0 +1,31 @@
+from dataclasses import dataclass
+from enum import Enum
+from typing import List, Optional, Union
+
+import numpy as np
+import PIL
+from PIL import Image
+
+from ...utils import BaseOutput, is_torch_available, is_transformers_available
+
+
+@dataclass
+class SemanticStableDiffusionPipelineOutput(BaseOutput):
+ """
+ Output class for the Semantic Stable Diffusion pipeline.
+
+ Args:
+ images (`List[PIL.Image.Image]` or `np.ndarray`):
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+ num_channels)`. PIL images or numpy array represent the denoised images of the diffusion pipeline.
+ nsfw_content_detected (`List[bool]`):
+ List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, or `None` if safety checking could not be performed.
+ """
+
+ images: Union[List[PIL.Image.Image], np.ndarray]
+ nsfw_content_detected: Optional[List[bool]]
+
+
+if is_transformers_available() and is_torch_available():
+ from .pipeline_semantic_stable_diffusion import SemanticStableDiffusionPipeline
diff --git a/diffusers/pipelines/semantic_stable_diffusion/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/semantic_stable_diffusion/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..96ca9446d7e783e2ce9eb7bd761e83e795691ca1
Binary files /dev/null and b/diffusers/pipelines/semantic_stable_diffusion/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/pipelines/semantic_stable_diffusion/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/semantic_stable_diffusion/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e932685eab8d3751a447cf670722a421d2f8347b
Binary files /dev/null and b/diffusers/pipelines/semantic_stable_diffusion/__pycache__/__init__.cpython-38.pyc differ
diff --git a/diffusers/pipelines/semantic_stable_diffusion/__pycache__/pipeline_semantic_stable_diffusion.cpython-310.pyc b/diffusers/pipelines/semantic_stable_diffusion/__pycache__/pipeline_semantic_stable_diffusion.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..54061b7e030cd037550c42216559ad66f4b94a68
Binary files /dev/null and b/diffusers/pipelines/semantic_stable_diffusion/__pycache__/pipeline_semantic_stable_diffusion.cpython-310.pyc differ
diff --git a/diffusers/pipelines/semantic_stable_diffusion/__pycache__/pipeline_semantic_stable_diffusion.cpython-38.pyc b/diffusers/pipelines/semantic_stable_diffusion/__pycache__/pipeline_semantic_stable_diffusion.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e164125bb2fb7ff83dd49d33765627959a11ff6b
Binary files /dev/null and b/diffusers/pipelines/semantic_stable_diffusion/__pycache__/pipeline_semantic_stable_diffusion.cpython-38.pyc differ
diff --git a/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py b/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..911a5018de18de505323420f4220551d2b4f8624
--- /dev/null
+++ b/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py
@@ -0,0 +1,724 @@
+import inspect
+import warnings
+from itertools import repeat
+from typing import Callable, List, Optional, Union
+
+import torch
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from ...image_processor import VaeImageProcessor
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from ...schedulers import KarrasDiffusionSchedulers
+from ...utils import logging, randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from . import SemanticStableDiffusionPipelineOutput
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import SemanticStableDiffusionPipeline
+
+ >>> pipe = SemanticStableDiffusionPipeline.from_pretrained(
+ ... "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
+ ... )
+ >>> pipe = pipe.to("cuda")
+
+ >>> out = pipe(
+ ... prompt="a photo of the face of a woman",
+ ... num_images_per_prompt=1,
+ ... guidance_scale=7,
+ ... editing_prompt=[
+ ... "smiling, smile", # Concepts to apply
+ ... "glasses, wearing glasses",
+ ... "curls, wavy hair, curly hair",
+ ... "beard, full beard, mustache",
+ ... ],
+ ... reverse_editing_direction=[
+ ... False,
+ ... False,
+ ... False,
+ ... False,
+ ... ], # Direction of guidance i.e. increase all concepts
+ ... edit_warmup_steps=[10, 10, 10, 10], # Warmup period for each concept
+ ... edit_guidance_scale=[4, 5, 5, 5.4], # Guidance scale for each concept
+ ... edit_threshold=[
+ ... 0.99,
+ ... 0.975,
+ ... 0.925,
+ ... 0.96,
+ ... ], # Threshold for each concept. Threshold equals the percentile of the latent space that will be discarded. I.e. threshold=0.99 uses 1% of the latent dimensions
+ ... edit_momentum_scale=0.3, # Momentum scale that will be added to the latent guidance
+ ... edit_mom_beta=0.6, # Momentum beta
+ ... edit_weights=[1, 1, 1, 1], # Weights of the individual concepts against each other
+ ... )
+ >>> image = out.images[0]
+ ```
+"""
+
+
+class SemanticStableDiffusionPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-to-image generation with latent editing.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ This model builds on the implementation of [`StableDiffusionPipeline`].
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`Q16SafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ def decode_latents(self, latents):
+ warnings.warn(
+ "The decode_latents method is deprecated and will be removed in a future version. Please"
+ " use VaeImageProcessor instead",
+ FutureWarning,
+ )
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: int = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ editing_prompt: Optional[Union[str, List[str]]] = None,
+ editing_prompt_embeddings: Optional[torch.Tensor] = None,
+ reverse_editing_direction: Optional[Union[bool, List[bool]]] = False,
+ edit_guidance_scale: Optional[Union[float, List[float]]] = 5,
+ edit_warmup_steps: Optional[Union[int, List[int]]] = 10,
+ edit_cooldown_steps: Optional[Union[int, List[int]]] = None,
+ edit_threshold: Optional[Union[float, List[float]]] = 0.9,
+ edit_momentum_scale: Optional[float] = 0.1,
+ edit_mom_beta: Optional[float] = 0.4,
+ edit_weights: Optional[List[float]] = None,
+ sem_guidance: Optional[List[torch.Tensor]] = None,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ editing_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to use for Semantic guidance. Semantic guidance is disabled by setting
+ `editing_prompt = None`. Guidance direction of prompt should be specified via
+ `reverse_editing_direction`.
+ editing_prompt_embeddings (`torch.Tensor`, *optional*):
+ Pre-computed embeddings to use for semantic guidance. Guidance direction of embedding should be
+ specified via `reverse_editing_direction`.
+ reverse_editing_direction (`bool` or `List[bool]`, *optional*, defaults to `False`):
+ Whether the corresponding prompt in `editing_prompt` should be increased or decreased.
+ edit_guidance_scale (`float` or `List[float]`, *optional*, defaults to 5):
+ Guidance scale for semantic guidance. If provided as list values should correspond to `editing_prompt`.
+ `edit_guidance_scale` is defined as `s_e` of equation 6 of [SEGA
+ Paper](https://arxiv.org/pdf/2301.12247.pdf).
+ edit_warmup_steps (`float` or `List[float]`, *optional*, defaults to 10):
+ Number of diffusion steps (for each prompt) for which semantic guidance will not be applied. Momentum
+ will still be calculated for those steps and applied once all warmup periods are over.
+ `edit_warmup_steps` is defined as `delta` (δ) of [SEGA Paper](https://arxiv.org/pdf/2301.12247.pdf).
+ edit_cooldown_steps (`float` or `List[float]`, *optional*, defaults to `None`):
+ Number of diffusion steps (for each prompt) after which semantic guidance will no longer be applied.
+ edit_threshold (`float` or `List[float]`, *optional*, defaults to 0.9):
+ Threshold of semantic guidance.
+ edit_momentum_scale (`float`, *optional*, defaults to 0.1):
+ Scale of the momentum to be added to the semantic guidance at each diffusion step. If set to 0.0
+ momentum will be disabled. Momentum is already built up during warmup, i.e. for diffusion steps smaller
+ than `sld_warmup_steps`. Momentum will only be added to latent guidance once all warmup periods are
+ finished. `edit_momentum_scale` is defined as `s_m` of equation 7 of [SEGA
+ Paper](https://arxiv.org/pdf/2301.12247.pdf).
+ edit_mom_beta (`float`, *optional*, defaults to 0.4):
+ Defines how semantic guidance momentum builds up. `edit_mom_beta` indicates how much of the previous
+ momentum will be kept. Momentum is already built up during warmup, i.e. for diffusion steps smaller
+ than `edit_warmup_steps`. `edit_mom_beta` is defined as `beta_m` (β) of equation 8 of [SEGA
+ Paper](https://arxiv.org/pdf/2301.12247.pdf).
+ edit_weights (`List[float]`, *optional*, defaults to `None`):
+ Indicates how much each individual concept should influence the overall guidance. If no weights are
+ provided all concepts are applied equally. `edit_weights` is defined as `g_i` of equation 9 of [SEGA
+ Paper](https://arxiv.org/pdf/2301.12247.pdf).
+ sem_guidance (`List[torch.Tensor]`, *optional*):
+ List of pre-generated guidance vectors to be applied at generation. Length of the list has to
+ correspond to `num_inference_steps`.
+
+ Returns:
+ [`~pipelines.semantic_stable_diffusion.SemanticStableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.semantic_stable_diffusion.SemanticStableDiffusionPipelineOutput`] if `return_dict` is True,
+ otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images, and the
+ second element is a list of `bool`s denoting whether the corresponding generated image likely represents
+ "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(prompt, height, width, callback_steps)
+
+ # 2. Define call parameters
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
+
+ if editing_prompt:
+ enable_edit_guidance = True
+ if isinstance(editing_prompt, str):
+ editing_prompt = [editing_prompt]
+ enabled_editing_prompts = len(editing_prompt)
+ elif editing_prompt_embeddings is not None:
+ enable_edit_guidance = True
+ enabled_editing_prompts = editing_prompt_embeddings.shape[0]
+ else:
+ enabled_editing_prompts = 0
+ enable_edit_guidance = False
+
+ # get prompt text embeddings
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+
+ if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
+ removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+ text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+ text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = text_embeddings.shape
+ text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
+ text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ if enable_edit_guidance:
+ # get safety text embeddings
+ if editing_prompt_embeddings is None:
+ edit_concepts_input = self.tokenizer(
+ [x for item in editing_prompt for x in repeat(item, batch_size)],
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ return_tensors="pt",
+ )
+
+ edit_concepts_input_ids = edit_concepts_input.input_ids
+
+ if edit_concepts_input_ids.shape[-1] > self.tokenizer.model_max_length:
+ removed_text = self.tokenizer.batch_decode(
+ edit_concepts_input_ids[:, self.tokenizer.model_max_length :]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+ edit_concepts_input_ids = edit_concepts_input_ids[:, : self.tokenizer.model_max_length]
+ edit_concepts = self.text_encoder(edit_concepts_input_ids.to(self.device))[0]
+ else:
+ edit_concepts = editing_prompt_embeddings.to(self.device).repeat(batch_size, 1, 1)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ bs_embed_edit, seq_len_edit, _ = edit_concepts.shape
+ edit_concepts = edit_concepts.repeat(1, num_images_per_prompt, 1)
+ edit_concepts = edit_concepts.view(bs_embed_edit * num_images_per_prompt, seq_len_edit, -1)
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+ # get unconditional embeddings for classifier free guidance
+
+ if do_classifier_free_guidance:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""]
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = text_input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = uncond_embeddings.shape[1]
+ uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1)
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ if enable_edit_guidance:
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings, edit_concepts])
+ else:
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+ # get the initial random noise unless the user supplied it
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=self.device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ text_embeddings.dtype,
+ self.device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs.
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # Initialize edit_momentum to None
+ edit_momentum = None
+
+ self.uncond_estimates = None
+ self.text_estimates = None
+ self.edit_estimates = None
+ self.sem_guidance = None
+
+ for i, t in enumerate(self.progress_bar(timesteps)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = (
+ torch.cat([latents] * (2 + enabled_editing_prompts)) if do_classifier_free_guidance else latents
+ )
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_out = noise_pred.chunk(2 + enabled_editing_prompts) # [b,4, 64, 64]
+ noise_pred_uncond, noise_pred_text = noise_pred_out[0], noise_pred_out[1]
+ noise_pred_edit_concepts = noise_pred_out[2:]
+
+ # default text guidance
+ noise_guidance = guidance_scale * (noise_pred_text - noise_pred_uncond)
+ # noise_guidance = (noise_pred_text - noise_pred_edit_concepts[0])
+
+ if self.uncond_estimates is None:
+ self.uncond_estimates = torch.zeros((num_inference_steps + 1, *noise_pred_uncond.shape))
+ self.uncond_estimates[i] = noise_pred_uncond.detach().cpu()
+
+ if self.text_estimates is None:
+ self.text_estimates = torch.zeros((num_inference_steps + 1, *noise_pred_text.shape))
+ self.text_estimates[i] = noise_pred_text.detach().cpu()
+
+ if self.edit_estimates is None and enable_edit_guidance:
+ self.edit_estimates = torch.zeros(
+ (num_inference_steps + 1, len(noise_pred_edit_concepts), *noise_pred_edit_concepts[0].shape)
+ )
+
+ if self.sem_guidance is None:
+ self.sem_guidance = torch.zeros((num_inference_steps + 1, *noise_pred_text.shape))
+
+ if edit_momentum is None:
+ edit_momentum = torch.zeros_like(noise_guidance)
+
+ if enable_edit_guidance:
+ concept_weights = torch.zeros(
+ (len(noise_pred_edit_concepts), noise_guidance.shape[0]),
+ device=self.device,
+ dtype=noise_guidance.dtype,
+ )
+ noise_guidance_edit = torch.zeros(
+ (len(noise_pred_edit_concepts), *noise_guidance.shape),
+ device=self.device,
+ dtype=noise_guidance.dtype,
+ )
+ # noise_guidance_edit = torch.zeros_like(noise_guidance)
+ warmup_inds = []
+ for c, noise_pred_edit_concept in enumerate(noise_pred_edit_concepts):
+ self.edit_estimates[i, c] = noise_pred_edit_concept
+ if isinstance(edit_guidance_scale, list):
+ edit_guidance_scale_c = edit_guidance_scale[c]
+ else:
+ edit_guidance_scale_c = edit_guidance_scale
+
+ if isinstance(edit_threshold, list):
+ edit_threshold_c = edit_threshold[c]
+ else:
+ edit_threshold_c = edit_threshold
+ if isinstance(reverse_editing_direction, list):
+ reverse_editing_direction_c = reverse_editing_direction[c]
+ else:
+ reverse_editing_direction_c = reverse_editing_direction
+ if edit_weights:
+ edit_weight_c = edit_weights[c]
+ else:
+ edit_weight_c = 1.0
+ if isinstance(edit_warmup_steps, list):
+ edit_warmup_steps_c = edit_warmup_steps[c]
+ else:
+ edit_warmup_steps_c = edit_warmup_steps
+
+ if isinstance(edit_cooldown_steps, list):
+ edit_cooldown_steps_c = edit_cooldown_steps[c]
+ elif edit_cooldown_steps is None:
+ edit_cooldown_steps_c = i + 1
+ else:
+ edit_cooldown_steps_c = edit_cooldown_steps
+ if i >= edit_warmup_steps_c:
+ warmup_inds.append(c)
+ if i >= edit_cooldown_steps_c:
+ noise_guidance_edit[c, :, :, :, :] = torch.zeros_like(noise_pred_edit_concept)
+ continue
+
+ noise_guidance_edit_tmp = noise_pred_edit_concept - noise_pred_uncond
+ # tmp_weights = (noise_pred_text - noise_pred_edit_concept).sum(dim=(1, 2, 3))
+ tmp_weights = (noise_guidance - noise_pred_edit_concept).sum(dim=(1, 2, 3))
+
+ tmp_weights = torch.full_like(tmp_weights, edit_weight_c) # * (1 / enabled_editing_prompts)
+ if reverse_editing_direction_c:
+ noise_guidance_edit_tmp = noise_guidance_edit_tmp * -1
+ concept_weights[c, :] = tmp_weights
+
+ noise_guidance_edit_tmp = noise_guidance_edit_tmp * edit_guidance_scale_c
+
+ # torch.quantile function expects float32
+ if noise_guidance_edit_tmp.dtype == torch.float32:
+ tmp = torch.quantile(
+ torch.abs(noise_guidance_edit_tmp).flatten(start_dim=2),
+ edit_threshold_c,
+ dim=2,
+ keepdim=False,
+ )
+ else:
+ tmp = torch.quantile(
+ torch.abs(noise_guidance_edit_tmp).flatten(start_dim=2).to(torch.float32),
+ edit_threshold_c,
+ dim=2,
+ keepdim=False,
+ ).to(noise_guidance_edit_tmp.dtype)
+
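+ # keep only guidance values whose magnitude is above the per-sample quantile threshold; the rest are zeroed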
+ noise_guidance_edit_tmp = torch.where(
+ torch.abs(noise_guidance_edit_tmp) >= tmp[:, :, None, None],
+ noise_guidance_edit_tmp,
+ torch.zeros_like(noise_guidance_edit_tmp),
+ )
+ noise_guidance_edit[c, :, :, :, :] = noise_guidance_edit_tmp
+
+ # noise_guidance_edit = noise_guidance_edit + noise_guidance_edit_tmp
+
+ warmup_inds = torch.tensor(warmup_inds).to(self.device)
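+ # if only some concepts have passed their warm-up, blend just those into the guidance for this step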
+ if len(noise_pred_edit_concepts) > warmup_inds.shape[0] > 0:
+ concept_weights = concept_weights.to("cpu") # Offload to cpu
+ noise_guidance_edit = noise_guidance_edit.to("cpu")
+
+ concept_weights_tmp = torch.index_select(concept_weights.to(self.device), 0, warmup_inds)
+ concept_weights_tmp = torch.where(
+ concept_weights_tmp < 0, torch.zeros_like(concept_weights_tmp), concept_weights_tmp
+ )
+ concept_weights_tmp = concept_weights_tmp / concept_weights_tmp.sum(dim=0)
+ # concept_weights_tmp = torch.nan_to_num(concept_weights_tmp)
+
+ noise_guidance_edit_tmp = torch.index_select(
+ noise_guidance_edit.to(self.device), 0, warmup_inds
+ )
+ noise_guidance_edit_tmp = torch.einsum(
+ "cb,cbijk->bijk", concept_weights_tmp, noise_guidance_edit_tmp
+ )
+ noise_guidance = noise_guidance + noise_guidance_edit_tmp
+
+ self.sem_guidance[i] = noise_guidance_edit_tmp.detach().cpu()
+
+ del noise_guidance_edit_tmp
+ del concept_weights_tmp
+ concept_weights = concept_weights.to(self.device)
+ noise_guidance_edit = noise_guidance_edit.to(self.device)
+
+ concept_weights = torch.where(
+ concept_weights < 0, torch.zeros_like(concept_weights), concept_weights
+ )
+
+ concept_weights = torch.nan_to_num(concept_weights)
+
+ noise_guidance_edit = torch.einsum("cb,cbijk->bijk", concept_weights, noise_guidance_edit)
+
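+ # add the accumulated momentum term, then update it as an exponential moving average of the edit guidance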
+ noise_guidance_edit = noise_guidance_edit + edit_momentum_scale * edit_momentum
+
+ edit_momentum = edit_mom_beta * edit_momentum + (1 - edit_mom_beta) * noise_guidance_edit
+
+ if warmup_inds.shape[0] == len(noise_pred_edit_concepts):
+ noise_guidance = noise_guidance + noise_guidance_edit
+ self.sem_guidance[i] = noise_guidance_edit.detach().cpu()
+
+ if sem_guidance is not None:
+ edit_guidance = sem_guidance[i].to(self.device)
+ noise_guidance = noise_guidance + edit_guidance
+
+ noise_pred = noise_pred_uncond + noise_guidance
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ # 8. Post-processing
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ image, has_nsfw_concept = self.run_safety_checker(image, self.device, text_embeddings.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return SemanticStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/pipelines/shap_e/__init__.py b/diffusers/pipelines/shap_e/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..04aa1f2f6d7852877e4c7f8b07cd15a8d1d496f5
--- /dev/null
+++ b/diffusers/pipelines/shap_e/__init__.py
@@ -0,0 +1,27 @@
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ is_torch_available,
+ is_transformers_available,
+ is_transformers_version,
+)
+
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import ShapEPipeline
+else:
+ from .camera import create_pan_cameras
+ from .pipeline_shap_e import ShapEPipeline
+ from .pipeline_shap_e_img2img import ShapEImg2ImgPipeline
+ from .renderer import (
+ BoundingBoxVolume,
+ ImportanceRaySampler,
+ MLPNeRFModelOutput,
+ MLPNeRSTFModel,
+ ShapEParamsProjModel,
+ ShapERenderer,
+ StratifiedRaySampler,
+ VoidNeRFModel,
+ )
diff --git a/diffusers/pipelines/shap_e/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/shap_e/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a79af1c2b64a6bfd6b8a00f69c2ea209434552eb
Binary files /dev/null and b/diffusers/pipelines/shap_e/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/pipelines/shap_e/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/shap_e/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..71bb328f3133c9297127006c84e0190dc490c853
Binary files /dev/null and b/diffusers/pipelines/shap_e/__pycache__/__init__.cpython-38.pyc differ
diff --git a/diffusers/pipelines/shap_e/__pycache__/camera.cpython-310.pyc b/diffusers/pipelines/shap_e/__pycache__/camera.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3271468d32ceff0cfdd4baf905e42bd1eb50da45
Binary files /dev/null and b/diffusers/pipelines/shap_e/__pycache__/camera.cpython-310.pyc differ
diff --git a/diffusers/pipelines/shap_e/__pycache__/camera.cpython-38.pyc b/diffusers/pipelines/shap_e/__pycache__/camera.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..54a30e960a8dde49f46b2194739af5ad2fa38e57
Binary files /dev/null and b/diffusers/pipelines/shap_e/__pycache__/camera.cpython-38.pyc differ
diff --git a/diffusers/pipelines/shap_e/__pycache__/pipeline_shap_e.cpython-310.pyc b/diffusers/pipelines/shap_e/__pycache__/pipeline_shap_e.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..12f4d8adf926e89e7790347a075c258965eff465
Binary files /dev/null and b/diffusers/pipelines/shap_e/__pycache__/pipeline_shap_e.cpython-310.pyc differ
diff --git a/diffusers/pipelines/shap_e/__pycache__/pipeline_shap_e.cpython-38.pyc b/diffusers/pipelines/shap_e/__pycache__/pipeline_shap_e.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..930a2d9566c7b4ca233cdb3cef903baaba357e50
Binary files /dev/null and b/diffusers/pipelines/shap_e/__pycache__/pipeline_shap_e.cpython-38.pyc differ
diff --git a/diffusers/pipelines/shap_e/__pycache__/pipeline_shap_e_img2img.cpython-310.pyc b/diffusers/pipelines/shap_e/__pycache__/pipeline_shap_e_img2img.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e2d7ae7e7f1d0e3d91249a0a7dc4c88893db54f1
Binary files /dev/null and b/diffusers/pipelines/shap_e/__pycache__/pipeline_shap_e_img2img.cpython-310.pyc differ
diff --git a/diffusers/pipelines/shap_e/__pycache__/pipeline_shap_e_img2img.cpython-38.pyc b/diffusers/pipelines/shap_e/__pycache__/pipeline_shap_e_img2img.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fb320ee0b0b8de9dd76886e0451513550d65256d
Binary files /dev/null and b/diffusers/pipelines/shap_e/__pycache__/pipeline_shap_e_img2img.cpython-38.pyc differ
diff --git a/diffusers/pipelines/shap_e/__pycache__/renderer.cpython-310.pyc b/diffusers/pipelines/shap_e/__pycache__/renderer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e2540fc9bf3b0aff5f07a52fe9c768cc903d917f
Binary files /dev/null and b/diffusers/pipelines/shap_e/__pycache__/renderer.cpython-310.pyc differ
diff --git a/diffusers/pipelines/shap_e/__pycache__/renderer.cpython-38.pyc b/diffusers/pipelines/shap_e/__pycache__/renderer.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ca37932bafab2a20951d4df7dc3a4970d046b2d3
Binary files /dev/null and b/diffusers/pipelines/shap_e/__pycache__/renderer.cpython-38.pyc differ
diff --git a/diffusers/pipelines/shap_e/camera.py b/diffusers/pipelines/shap_e/camera.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ef0d66070223a80eed59da8d842389fed0c7aef
--- /dev/null
+++ b/diffusers/pipelines/shap_e/camera.py
@@ -0,0 +1,147 @@
+# Copyright 2023 Open AI and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Tuple
+
+import numpy as np
+import torch
+
+
+@dataclass
+class DifferentiableProjectiveCamera:
+ """
+ Implements a batched, differentiable, standard pinhole camera
+ """
+
+ origin: torch.Tensor # [batch_size x 3]
+ x: torch.Tensor # [batch_size x 3]
+ y: torch.Tensor # [batch_size x 3]
+ z: torch.Tensor # [batch_size x 3]
+ width: int
+ height: int
+ x_fov: float
+ y_fov: float
+ shape: Tuple[int]
+
+ def __post_init__(self):
+ assert self.x.shape[0] == self.y.shape[0] == self.z.shape[0] == self.origin.shape[0]
+ assert self.x.shape[1] == self.y.shape[1] == self.z.shape[1] == self.origin.shape[1] == 3
+ assert len(self.x.shape) == len(self.y.shape) == len(self.z.shape) == len(self.origin.shape) == 2
+
+ def resolution(self):
+ return torch.from_numpy(np.array([self.width, self.height], dtype=np.float32))
+
+ def fov(self):
+ return torch.from_numpy(np.array([self.x_fov, self.y_fov], dtype=np.float32))
+
+ def get_image_coords(self) -> torch.Tensor:
+ """
+ :return: coords of shape (width * height, 2)
+ """
+ pixel_indices = torch.arange(self.height * self.width)
+ coords = torch.stack(
+ [
+ pixel_indices % self.width,
+ torch.div(pixel_indices, self.width, rounding_mode="trunc"),
+ ],
+ axis=1,
+ )
+ return coords
+
+ @property
+ def camera_rays(self):
+ batch_size, *inner_shape = self.shape
+ inner_batch_size = int(np.prod(inner_shape))
+
+ coords = self.get_image_coords()
+ coords = torch.broadcast_to(coords.unsqueeze(0), [batch_size * inner_batch_size, *coords.shape])
+ rays = self.get_camera_rays(coords)
+
+ rays = rays.view(batch_size, inner_batch_size * self.height * self.width, 2, 3)
+
+ return rays
+
+ def get_camera_rays(self, coords: torch.Tensor) -> torch.Tensor:
+ batch_size, *shape, n_coords = coords.shape
+ assert n_coords == 2
+ assert batch_size == self.origin.shape[0]
+
+ flat = coords.view(batch_size, -1, 2)
+
+ res = self.resolution()
+ fov = self.fov()
+
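+ # map pixel coordinates to [-1, 1] and scale by tan(fov / 2) to get offsets on the image plane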
+ fracs = (flat.float() / (res - 1)) * 2 - 1
+ fracs = fracs * torch.tan(fov / 2)
+
+ fracs = fracs.view(batch_size, -1, 2)
+ directions = (
+ self.z.view(batch_size, 1, 3)
+ + self.x.view(batch_size, 1, 3) * fracs[:, :, :1]
+ + self.y.view(batch_size, 1, 3) * fracs[:, :, 1:]
+ )
+ directions = directions / directions.norm(dim=-1, keepdim=True)
+ rays = torch.stack(
+ [
+ torch.broadcast_to(self.origin.view(batch_size, 1, 3), [batch_size, directions.shape[1], 3]),
+ directions,
+ ],
+ dim=2,
+ )
+ return rays.view(batch_size, *shape, 2, 3)
+
+ def resize_image(self, width: int, height: int) -> "DifferentiableProjectiveCamera":
+ """
+ Creates a new camera for the resized view assuming the aspect ratio does not change.
+ """
+ assert width * self.height == height * self.width, "The aspect ratio should not change."
+ return DifferentiableProjectiveCamera(
+ origin=self.origin,
+ x=self.x,
+ y=self.y,
+ z=self.z,
+ width=width,
+ height=height,
+ x_fov=self.x_fov,
+ y_fov=self.y_fov,
+ )
+
+
+def create_pan_cameras(size: int) -> DifferentiableProjectiveCamera:
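+ # Creates 20 cameras orbiting the origin at radius 4 (shape=(1, 20)); the resulting `camera_rays`
+ # property yields origin/direction pairs of shape (1, 20 * size * size, 2, 3).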
+ origins = []
+ xs = []
+ ys = []
+ zs = []
+ for theta in np.linspace(0, 2 * np.pi, num=20):
+ z = np.array([np.sin(theta), np.cos(theta), -0.5])
+ z /= np.sqrt(np.sum(z**2))
+ origin = -z * 4
+ x = np.array([np.cos(theta), -np.sin(theta), 0.0])
+ y = np.cross(z, x)
+ origins.append(origin)
+ xs.append(x)
+ ys.append(y)
+ zs.append(z)
+ return DifferentiableProjectiveCamera(
+ origin=torch.from_numpy(np.stack(origins, axis=0)).float(),
+ x=torch.from_numpy(np.stack(xs, axis=0)).float(),
+ y=torch.from_numpy(np.stack(ys, axis=0)).float(),
+ z=torch.from_numpy(np.stack(zs, axis=0)).float(),
+ width=size,
+ height=size,
+ x_fov=0.7,
+ y_fov=0.7,
+ shape=(1, len(xs)),
+ )
diff --git a/diffusers/pipelines/shap_e/pipeline_shap_e.py b/diffusers/pipelines/shap_e/pipeline_shap_e.py
new file mode 100644
index 0000000000000000000000000000000000000000..fdcbe550860a403d2869321affd9045249472dad
--- /dev/null
+++ b/diffusers/pipelines/shap_e/pipeline_shap_e.py
@@ -0,0 +1,354 @@
+# Copyright 2023 Open AI and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from dataclasses import dataclass
+from typing import List, Optional, Union
+
+import numpy as np
+import PIL
+import torch
+from transformers import CLIPTextModelWithProjection, CLIPTokenizer
+
+from ...models import PriorTransformer
+from ...pipelines import DiffusionPipeline
+from ...schedulers import HeunDiscreteScheduler
+from ...utils import (
+ BaseOutput,
+ is_accelerate_available,
+ is_accelerate_version,
+ logging,
+ randn_tensor,
+ replace_example_docstring,
+)
+from .renderer import ShapERenderer
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import DiffusionPipeline
+ >>> from diffusers.utils import export_to_gif
+
+ >>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ >>> repo = "openai/shap-e"
+ >>> pipe = DiffusionPipeline.from_pretrained(repo, torch_dtype=torch.float16)
+ >>> pipe = pipe.to(device)
+
+ >>> guidance_scale = 15.0
+ >>> prompt = "a shark"
+
+ >>> images = pipe(
+ ... prompt,
+ ... guidance_scale=guidance_scale,
+ ... num_inference_steps=64,
+ ... frame_size=256,
+ ... ).images
+
+ >>> gif_path = export_to_gif(images[0], "shark_3d.gif")
+ ```
+"""
+
+
+@dataclass
+class ShapEPipelineOutput(BaseOutput):
+ """
+ Output class for ShapEPipeline.
+
+ Args:
+ images (`List[List[PIL.Image.Image]]` or `List[List[np.ndarray]]`):
+ A list of rendered image frames for each generated 3D asset.
+ """
+
+ images: Union[List[List[PIL.Image.Image]], List[List[np.ndarray]]]
+
+
+class ShapEPipeline(DiffusionPipeline):
+ """
+ Pipeline for generating the latent representation of a 3D asset and rendering it with the NeRF method using Shap-E.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ prior ([`PriorTransformer`]):
+ The canonical unCLIP prior to approximate the image embedding from the text embedding.
+ text_encoder ([`CLIPTextModelWithProjection`]):
+ Frozen text-encoder.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ scheduler ([`HeunDiscreteScheduler`]):
+ A scheduler to be used in combination with `prior` to generate image embedding.
+ renderer ([`ShapERenderer`]):
+ Shap-E renderer projects the generated latents into parameters of an MLP that's used to create 3D objects
+ with the NeRF rendering method
+ """
+
+ def __init__(
+ self,
+ prior: PriorTransformer,
+ text_encoder: CLIPTextModelWithProjection,
+ tokenizer: CLIPTokenizer,
+ scheduler: HeunDiscreteScheduler,
+ renderer: ShapERenderer,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ prior=prior,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ scheduler=scheduler,
+ renderer=renderer,
+ )
+
+ # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
+ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ if latents.shape != shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+ latents = latents.to(device)
+
+ latents = latents * scheduler.init_noise_sigma
+ return latents
+
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `prior`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ hook = None
+ for cpu_offloaded_model in [self.text_encoder, self.prior, self.renderer]:
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+ if getattr(self, "safety_checker", None) is not None:
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ ):
+
+ # YiYi Notes: set pad_token_id to be 0, not sure why I can't set in the config file
+ self.tokenizer.pad_token_id = 0
+ # get prompt text embeddings
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ text_encoder_output = self.text_encoder(text_input_ids.to(device))
+ prompt_embeds = text_encoder_output.text_embeds
+
+ prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ # in Shap-E the prompt_embeds are normalized and then rescaled later
+ prompt_embeds = prompt_embeds / torch.linalg.norm(prompt_embeds, dim=-1, keepdim=True)
+
+ if do_classifier_free_guidance:
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ # Rescale the features to have unit variance
+ prompt_embeds = math.sqrt(prompt_embeds.shape[1]) * prompt_embeds
+
+ return prompt_embeds
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: str,
+ num_images_per_prompt: int = 1,
+ num_inference_steps: int = 25,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ guidance_scale: float = 4.0,
+ frame_size: int = 64,
+ output_type: Optional[str] = "pil", # pil, np, latent
+ return_dict: bool = True,
+ ):
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ num_inference_steps (`int`, *optional*, defaults to 25):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ guidance_scale (`float`, *optional*, defaults to 4.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text
+ `prompt`, usually at the expense of lower image quality.
+ frame_size (`int`, *optional*, defaults to 64):
+ The width and height of each image frame of the generated 3D output.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between `"pil"` (`PIL.Image.Image`), `"np"`
+ (`np.array`), or `"latent"`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
+
+ Examples:
+
+ Returns:
+ [`ShapEPipelineOutput`] or `tuple`
+ """
+
+ if isinstance(prompt, str):
+ batch_size = 1
+ elif isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ device = self._execution_device
+
+ batch_size = batch_size * num_images_per_prompt
+
+ do_classifier_free_guidance = guidance_scale > 1.0
+ prompt_embeds = self._encode_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance)
+
+ # prior
+
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ num_embeddings = self.prior.config.num_embeddings
+ embedding_dim = self.prior.config.embedding_dim
+
+ latents = self.prepare_latents(
+ (batch_size, num_embeddings * embedding_dim),
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ self.scheduler,
+ )
+
+ # YiYi notes: for testing only to match ldm, we can directly create a latents with desired shape: batch_size, num_embeddings, embedding_dim
+ latents = latents.reshape(latents.shape[0], num_embeddings, embedding_dim)
+
+ for i, t in enumerate(self.progress_bar(timesteps)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ scaled_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ noise_pred = self.prior(
+ scaled_model_input,
+ timestep=t,
+ proj_embedding=prompt_embeds,
+ ).predicted_image_embedding
+
+ # remove the variance
+ noise_pred, _ = noise_pred.split(
+ scaled_model_input.shape[2], dim=2
+ ) # batch_size, num_embeddings, embedding_dim
+
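+ # classifier-free guidance: split the unconditional / conditional predictions and recombine them with the guidance scale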
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)
+
+ latents = self.scheduler.step(
+ noise_pred,
+ timestep=t,
+ sample=latents,
+ ).prev_sample
+
+ if output_type == "latent":
+ return ShapEPipelineOutput(images=latents)
+
+ images = []
+ for i, latent in enumerate(latents):
+ image = self.renderer.decode(
+ latent[None, :],
+ device,
+ size=frame_size,
+ ray_batch_size=4096,
+ n_coarse_samples=64,
+ n_fine_samples=128,
+ )
+ images.append(image)
+
+ images = torch.stack(images)
+
+ if output_type not in ["np", "pil"]:
+ raise ValueError(f"Only the output types `pil` and `np` are supported not output_type={output_type}")
+
+ images = images.cpu().numpy()
+
+ if output_type == "pil":
+ images = [self.numpy_to_pil(image) for image in images]
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (images,)
+
+ return ShapEPipelineOutput(images=images)
diff --git a/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py b/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py
new file mode 100644
index 0000000000000000000000000000000000000000..08c585c5ad736d6f12397409187dc77659b771e3
--- /dev/null
+++ b/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py
@@ -0,0 +1,312 @@
+# Copyright 2023 Open AI and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import List, Optional, Union
+
+import numpy as np
+import PIL
+import torch
+from transformers import CLIPImageProcessor, CLIPVisionModel
+
+from ...models import PriorTransformer
+from ...pipelines import DiffusionPipeline
+from ...schedulers import HeunDiscreteScheduler
+from ...utils import (
+ BaseOutput,
+ logging,
+ randn_tensor,
+ replace_example_docstring,
+)
+from .renderer import ShapERenderer
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> from PIL import Image
+ >>> import torch
+ >>> from diffusers import DiffusionPipeline
+ >>> from diffusers.utils import export_to_gif, load_image
+
+ >>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ >>> repo = "openai/shap-e-img2img"
+ >>> pipe = DiffusionPipeline.from_pretrained(repo, torch_dtype=torch.float16)
+ >>> pipe = pipe.to(device)
+
+ >>> guidance_scale = 3.0
+ >>> image_url = "https://hf.co/datasets/diffusers/docs-images/resolve/main/shap-e/corgi.png"
+ >>> image = load_image(image_url).convert("RGB")
+
+ >>> images = pipe(
+ ... image,
+ ... guidance_scale=guidance_scale,
+ ... num_inference_steps=64,
+ ... frame_size=256,
+ ... ).images
+
+ >>> gif_path = export_to_gif(images[0], "corgi_3d.gif")
+ ```
+"""
+
+
+@dataclass
+class ShapEPipelineOutput(BaseOutput):
+ """
+ Output class for ShapEPipeline.
+
+ Args:
+ images (`PIL.Image.Image` or `np.ndarray`):
+ Rendered image frames for the generated 3D asset.
+ """
+
+ images: Union[PIL.Image.Image, np.ndarray]
+
+
+class ShapEImg2ImgPipeline(DiffusionPipeline):
+ """
+ Pipeline for generating the latent representation of a 3D asset from an input image and rendering it with the NeRF method using Shap-E.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ prior ([`PriorTransformer`]):
+ The canonical unCLIP prior to approximate the latent representation from the CLIP image embedding.
+ image_encoder ([`CLIPVisionModel`]):
+ Frozen CLIP image-encoder used to embed the input image.
+ image_processor ([`CLIPImageProcessor`]):
+ A `CLIPImageProcessor` to preprocess the input image before encoding.
+ scheduler ([`HeunDiscreteScheduler`]):
+ A scheduler to be used in combination with `prior` to generate image embedding.
+ renderer ([`ShapERenderer`]):
+ Shap-E renderer projects the generated latents into parameters of an MLP that's used to create 3D objects
+ with the NeRF rendering method
+ """
+
+ def __init__(
+ self,
+ prior: PriorTransformer,
+ image_encoder: CLIPVisionModel,
+ image_processor: CLIPImageProcessor,
+ scheduler: HeunDiscreteScheduler,
+ renderer: ShapERenderer,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ prior=prior,
+ image_encoder=image_encoder,
+ image_processor=image_processor,
+ scheduler=scheduler,
+ renderer=renderer,
+ )
+
+ # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
+ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ if latents.shape != shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+ latents = latents.to(device)
+
+ latents = latents * scheduler.init_noise_sigma
+ return latents
+
+ def _encode_image(
+ self,
+ image,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ ):
+ if isinstance(image, List) and isinstance(image[0], torch.Tensor):
+ image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
+
+ if not isinstance(image, torch.Tensor):
+ image = self.image_processor(image, return_tensors="pt").pixel_values[0].unsqueeze(0)
+
+ image = image.to(dtype=self.image_encoder.dtype, device=device)
+
+ image_embeds = self.image_encoder(image)["last_hidden_state"]
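+ # drop the leading CLS token and keep only the patch embeddings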
+ image_embeds = image_embeds[:, 1:, :].contiguous() # batch_size, dim, 256
+
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+
+ if do_classifier_free_guidance:
+ negative_image_embeds = torch.zeros_like(image_embeds)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ image_embeds = torch.cat([negative_image_embeds, image_embeds])
+
+ return image_embeds
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image: Union[PIL.Image.Image, List[PIL.Image.Image]],
+ num_images_per_prompt: int = 1,
+ num_inference_steps: int = 25,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ guidance_scale: float = 4.0,
+ frame_size: int = 64,
+ output_type: Optional[str] = "pil", # pil, np, latent
+ return_dict: bool = True,
+ ):
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ image (`PIL.Image.Image` or `List[PIL.Image.Image]`):
+ The image or images to condition the 3D generation on.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per input image.
+ num_inference_steps (`int`, *optional*, defaults to 25):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ guidance_scale (`float`, *optional*, defaults to 4.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the input
+ image, usually at the expense of lower image quality.
+ frame_size (`int`, *optional*, defaults to 64):
+ The width and height of each image frame of the generated 3D output.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between `"pil"` (`PIL.Image.Image`), `"np"`
+ (`np.array`), or `"latent"`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
+
+ Examples:
+
+ Returns:
+ [`ShapEPipelineOutput`] or `tuple`
+ """
+
+ if isinstance(image, PIL.Image.Image):
+ batch_size = 1
+ elif isinstance(image, torch.Tensor):
+ batch_size = image.shape[0]
+ elif isinstance(image, list) and isinstance(image[0], (torch.Tensor, PIL.Image.Image)):
+ batch_size = len(image)
+ else:
+ raise ValueError(
+ f"`image` has to be of type `PIL.Image.Image`, `torch.Tensor`, `List[PIL.Image.Image]` or `List[torch.Tensor]` but is {type(image)}"
+ )
+
+ device = self._execution_device
+
+ batch_size = batch_size * num_images_per_prompt
+
+ do_classifier_free_guidance = guidance_scale > 1.0
+ image_embeds = self._encode_image(image, device, num_images_per_prompt, do_classifier_free_guidance)
+
+ # prior
+
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ num_embeddings = self.prior.config.num_embeddings
+ embedding_dim = self.prior.config.embedding_dim
+
+ latents = self.prepare_latents(
+ (batch_size, num_embeddings * embedding_dim),
+ image_embeds.dtype,
+ device,
+ generator,
+ latents,
+ self.scheduler,
+ )
+
+ # YiYi notes: for testing only to match ldm, we can directly create a latents with desired shape: batch_size, num_embeddings, embedding_dim
+ latents = latents.reshape(latents.shape[0], num_embeddings, embedding_dim)
+
+ for i, t in enumerate(self.progress_bar(timesteps)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ scaled_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ noise_pred = self.prior(
+ scaled_model_input,
+ timestep=t,
+ proj_embedding=image_embeds,
+ ).predicted_image_embedding
+
+ # remove the variance
+ noise_pred, _ = noise_pred.split(
+ scaled_model_input.shape[2], dim=2
+ ) # batch_size, num_embeddings, embedding_dim
+
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)
+
+ latents = self.scheduler.step(
+ noise_pred,
+ timestep=t,
+ sample=latents,
+ ).prev_sample
+
+ if output_type == "latent":
+ return ShapEPipelineOutput(images=latents)
+
+ images = []
+ for i, latent in enumerate(latents):
+ image = self.renderer.decode(
+ latent[None, :],
+ device,
+ size=frame_size,
+ ray_batch_size=4096,
+ n_coarse_samples=64,
+ n_fine_samples=128,
+ )
+
+ images.append(image)
+
+ images = torch.stack(images)
+
+ if output_type not in ["np", "pil"]:
+ raise ValueError(f"Only the output types `pil` and `np` are supported not output_type={output_type}")
+
+ images = images.cpu().numpy()
+
+ if output_type == "pil":
+ images = [self.numpy_to_pil(image) for image in images]
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (images,)
+
+ return ShapEPipelineOutput(images=images)
diff --git a/diffusers/pipelines/shap_e/renderer.py b/diffusers/pipelines/shap_e/renderer.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b075e671f63d9f6cbddcfb205df1ba38a426e6f
--- /dev/null
+++ b/diffusers/pipelines/shap_e/renderer.py
@@ -0,0 +1,709 @@
+# Copyright 2023 Open AI and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...models import ModelMixin
+from ...utils import BaseOutput
+from .camera import create_pan_cameras
+
+
+def sample_pmf(pmf: torch.Tensor, n_samples: int) -> torch.Tensor:
+ r"""
+ Sample from the given discrete probability distribution with replacement.
+
+ The i-th bin is assumed to have mass pmf[i].
+
+ Args:
+ pmf: [batch_size, *shape, support_size, 1] where (pmf.sum(dim=-2) == 1).all()
+ n_samples: number of samples
+
+ Return:
+ indices sampled with replacement
+ """
+
+ *shape, support_size, last_dim = pmf.shape
+ assert last_dim == 1
+
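+ # inverse-CDF sampling: build the cumulative distribution and look up uniform random values with searchsorted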
+ cdf = torch.cumsum(pmf.view(-1, support_size), dim=1)
+ inds = torch.searchsorted(cdf, torch.rand(cdf.shape[0], n_samples, device=cdf.device))
+
+ return inds.view(*shape, n_samples, 1).clamp(0, support_size - 1)
+
+
+def posenc_nerf(x: torch.Tensor, min_deg: int = 0, max_deg: int = 15) -> torch.Tensor:
+ """
+ Concatenate x and its positional encodings, following NeRF.
+
+ Reference: https://arxiv.org/pdf/2210.04628.pdf
+ """
+ if min_deg == max_deg:
+ return x
+
+ scales = 2.0 ** torch.arange(min_deg, max_deg, dtype=x.dtype, device=x.device)
+ *shape, dim = x.shape
+ xb = (x.reshape(-1, 1, dim) * scales.view(1, -1, 1)).reshape(*shape, -1)
+ assert xb.shape[-1] == dim * (max_deg - min_deg)
+ emb = torch.cat([xb, xb + math.pi / 2.0], axis=-1).sin()
+ return torch.cat([x, emb], dim=-1)
+
+
+def encode_position(position):
+ return posenc_nerf(position, min_deg=0, max_deg=15)
+
+
+def encode_direction(position, direction=None):
+ if direction is None:
+ return torch.zeros_like(posenc_nerf(position, min_deg=0, max_deg=8))
+ else:
+ return posenc_nerf(direction, min_deg=0, max_deg=8)
+
+
+def _sanitize_name(x: str) -> str:
+ return x.replace(".", "__")
+
+
+def integrate_samples(volume_range, ts, density, channels):
+ r"""
+ Function integrating the model output.
+
+ Args:
+ volume_range: Specifies the integral range [t0, t1]
+ ts: timesteps
+ density: torch.Tensor [batch_size, *shape, n_samples, 1]
+ channels: torch.Tensor [batch_size, *shape, n_samples, n_channels]
+ Returns:
+ channels: integrated rgb output
+ weights: torch.Tensor [batch_size, *shape, n_samples, 1], where (density * transmittance)[i] is the
+ weight of the rgb output at [..., i, :]
+ transmittance: transmittance of this volume
+ """
+
+ # 1. Calculate the weights
+ _, _, dt = volume_range.partition(ts)
+ ddensity = density * dt
+
+ mass = torch.cumsum(ddensity, dim=-2)
+ transmittance = torch.exp(-mass[..., -1, :])
+
+ alphas = 1.0 - torch.exp(-ddensity)
+ Ts = torch.exp(torch.cat([torch.zeros_like(mass[..., :1, :]), -mass[..., :-1, :]], dim=-2))
+ # This is the probability of light hitting and reflecting off of
+ # something at depth [..., i, :].
+ weights = alphas * Ts
+
+ # 2. Integrate channels
+ channels = torch.sum(channels * weights, dim=-2)
+
+ return channels, weights, transmittance
+
+
+class VoidNeRFModel(nn.Module):
+ """
+ Implements the default empty space model where all queries are rendered as background.
+ """
+
+ def __init__(self, background, channel_scale=255.0):
+ super().__init__()
+ background = nn.Parameter(torch.from_numpy(np.array(background)).to(dtype=torch.float32) / channel_scale)
+
+ self.register_buffer("background", background)
+
+ def forward(self, position):
+ background = self.background[None].to(position.device)
+
+ shape = position.shape[:-1]
+ ones = [1] * (len(shape) - 1)
+ n_channels = background.shape[-1]
+ background = torch.broadcast_to(background.view(background.shape[0], *ones, n_channels), [*shape, n_channels])
+
+ return background
+
+
+@dataclass
+class VolumeRange:
+ t0: torch.Tensor
+ t1: torch.Tensor
+ intersected: torch.Tensor
+
+ def __post_init__(self):
+ assert self.t0.shape == self.t1.shape == self.intersected.shape
+
+ def partition(self, ts):
+ """
+ Partitions t0 and t1 into n_samples intervals.
+
+ Args:
+ ts: [batch_size, *shape, n_samples, 1]
+
+ Return:
+ lower: [batch_size, *shape, n_samples, 1]
+ upper: [batch_size, *shape, n_samples, 1]
+ delta: [batch_size, *shape, n_samples, 1]
+
+ where ts \\in [lower, upper] and delta = upper - lower.
+ """
+
+ mids = (ts[..., 1:, :] + ts[..., :-1, :]) * 0.5
+ lower = torch.cat([self.t0[..., None, :], mids], dim=-2)
+ upper = torch.cat([mids, self.t1[..., None, :]], dim=-2)
+ delta = upper - lower
+ assert lower.shape == upper.shape == delta.shape == ts.shape
+ return lower, upper, delta
+
+
+class BoundingBoxVolume(nn.Module):
+ """
+ Axis-aligned bounding box defined by the two opposite corners.
+ """
+
+ def __init__(
+ self,
+ *,
+ bbox_min,
+ bbox_max,
+ min_dist: float = 0.0,
+ min_t_range: float = 1e-3,
+ ):
+ """
+ Args:
+ bbox_min: the left/bottommost corner of the bounding box
+ bbox_max: the other corner of the bounding box
+ min_dist: all rays should start at least this distance away from the origin.
+ """
+ super().__init__()
+
+ self.min_dist = min_dist
+ self.min_t_range = min_t_range
+
+ self.bbox_min = torch.tensor(bbox_min)
+ self.bbox_max = torch.tensor(bbox_max)
+ self.bbox = torch.stack([self.bbox_min, self.bbox_max])
+ assert self.bbox.shape == (2, 3)
+ assert min_dist >= 0.0
+ assert min_t_range > 0.0
+
+ def intersect(
+ self,
+ origin: torch.Tensor,
+ direction: torch.Tensor,
+ t0_lower: Optional[torch.Tensor] = None,
+ epsilon=1e-6,
+ ):
+ """
+ Args:
+ origin: [batch_size, *shape, 3]
+ direction: [batch_size, *shape, 3]
+ t0_lower: Optional [batch_size, *shape, 1] lower bound of t0 when intersecting this volume.
+ params: Optional meta parameters in case Volume is parametric
+ epsilon: to stabilize calculations
+
+ Return:
+ A tuple of (t0, t1, intersected) where each has a shape [batch_size, *shape, 1]. If a ray intersects with
+ the volume, `o + td` is in the volume for all t in [t0, t1]. If the volume is bounded, t1 is guaranteed to
+ be on the boundary of the volume.
+ """
+
+ batch_size, *shape, _ = origin.shape
+ ones = [1] * len(shape)
+ bbox = self.bbox.view(1, *ones, 2, 3).to(origin.device)
+
+ def _safe_divide(a, b, epsilon=1e-6):
+ return a / torch.where(b < 0, b - epsilon, b + epsilon)
+
+ ts = _safe_divide(bbox - origin[..., None, :], direction[..., None, :], epsilon=epsilon)
+
+ # Cases to think about:
+ #
+ # 1. t1 <= t0: the ray does not pass through the AABB.
+ # 2. t0 < t1 <= 0: the ray intersects but the BB is behind the origin.
+ # 3. t0 <= 0 <= t1: the ray starts from inside the BB
+ # 4. 0 <= t0 < t1: the ray is not inside and intersects with the BB twice.
+ #
+ # 1 and 4 are clearly handled from t0 < t1 below.
+ # Making t0 at least min_dist (>= 0) takes care of 2 and 3.
+ t0 = ts.min(dim=-2).values.max(dim=-1, keepdim=True).values.clamp(self.min_dist)
+ t1 = ts.max(dim=-2).values.min(dim=-1, keepdim=True).values
+ assert t0.shape == t1.shape == (batch_size, *shape, 1)
+ if t0_lower is not None:
+ assert t0.shape == t0_lower.shape
+ t0 = torch.maximum(t0, t0_lower)
+
+ intersected = t0 + self.min_t_range < t1
+ t0 = torch.where(intersected, t0, torch.zeros_like(t0))
+ t1 = torch.where(intersected, t1, torch.ones_like(t1))
+
+ return VolumeRange(t0=t0, t1=t1, intersected=intersected)
+
+
+class StratifiedRaySampler(nn.Module):
+ """
+ Instead of fixed intervals, a sample is drawn uniformly at random from each interval.
+ """
+
+ def __init__(self, depth_mode: str = "linear"):
+ """
+ :param depth_mode: linear samples ts linearly in depth. harmonic ensures
+ closer points are sampled more densely.
+ """
+ super().__init__()
+ self.depth_mode = depth_mode
+ assert self.depth_mode in ("linear", "geometric", "harmonic")
+
+ def sample(
+ self,
+ t0: torch.Tensor,
+ t1: torch.Tensor,
+ n_samples: int,
+ epsilon: float = 1e-3,
+ ) -> torch.Tensor:
+ """
+ Args:
+ t0: start time has shape [batch_size, *shape, 1]
+ t1: finish time has shape [batch_size, *shape, 1]
+ n_samples: number of ts to sample
+ Return:
+ sampled ts of shape [batch_size, *shape, n_samples, 1]
+ """
+ ones = [1] * (len(t0.shape) - 1)
+ ts = torch.linspace(0, 1, n_samples).view(*ones, n_samples).to(t0.dtype).to(t0.device)
+
+ if self.depth_mode == "linear":
+ ts = t0 * (1.0 - ts) + t1 * ts
+ elif self.depth_mode == "geometric":
+ ts = (t0.clamp(epsilon).log() * (1.0 - ts) + t1.clamp(epsilon).log() * ts).exp()
+ elif self.depth_mode == "harmonic":
+ # The original NeRF recommends this interpolation scheme for
+ # spherical scenes, but there could be some weird edge cases when
+ # the observer crosses from the inner to outer volume.
+ ts = 1.0 / (1.0 / t0.clamp(epsilon) * (1.0 - ts) + 1.0 / t1.clamp(epsilon) * ts)
+
+ mids = 0.5 * (ts[..., 1:] + ts[..., :-1])
+ upper = torch.cat([mids, t1], dim=-1)
+ lower = torch.cat([t0, mids], dim=-1)
+ # yiyi notes: add a random seed here for testing, don't forget to remove
+ torch.manual_seed(0)
+ t_rand = torch.rand_like(ts)
+
+ ts = lower + (upper - lower) * t_rand
+ return ts.unsqueeze(-1)
+
+
+class ImportanceRaySampler(nn.Module):
+ """
+ Given the initial estimate of densities, this samples more from regions/bins expected to have objects.
+ """
+
+ def __init__(
+ self,
+ volume_range: VolumeRange,
+ ts: torch.Tensor,
+ weights: torch.Tensor,
+ blur_pool: bool = False,
+ alpha: float = 1e-5,
+ ):
+ """
+ Args:
+ volume_range: the range in which a ray intersects the given volume.
+ ts: earlier samples from the coarse rendering step
+ weights: discretized version of density * transmittance
+ blur_pool: if true, use 2-tap max + 2-tap blur filter from mip-NeRF.
+ alpha: small value to add to weights.
+ """
+ super().__init__()
+ self.volume_range = volume_range
+ self.ts = ts.clone().detach()
+ self.weights = weights.clone().detach()
+ self.blur_pool = blur_pool
+ self.alpha = alpha
+
+ @torch.no_grad()
+ def sample(self, t0: torch.Tensor, t1: torch.Tensor, n_samples: int) -> torch.Tensor:
+ """
+ Args:
+ t0: start time has shape [batch_size, *shape, 1]
+ t1: finish time has shape [batch_size, *shape, 1]
+ n_samples: number of ts to sample
+ Return:
+ sampled ts of shape [batch_size, *shape, n_samples, 1]
+ """
+ lower, upper, _ = self.volume_range.partition(self.ts)
+
+ batch_size, *shape, n_coarse_samples, _ = self.ts.shape
+
+ weights = self.weights
+ if self.blur_pool:
+ padded = torch.cat([weights[..., :1, :], weights, weights[..., -1:, :]], dim=-2)
+ maxes = torch.maximum(padded[..., :-1, :], padded[..., 1:, :])
+ weights = 0.5 * (maxes[..., :-1, :] + maxes[..., 1:, :])
+ weights = weights + self.alpha
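+ # normalize the smoothed weights into a per-ray pmf and resample bins with replacement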
+ pmf = weights / weights.sum(dim=-2, keepdim=True)
+ inds = sample_pmf(pmf, n_samples)
+ assert inds.shape == (batch_size, *shape, n_samples, 1)
+ assert (inds >= 0).all() and (inds < n_coarse_samples).all()
+
+ t_rand = torch.rand(inds.shape, device=inds.device)
+ lower_ = torch.gather(lower, -2, inds)
+ upper_ = torch.gather(upper, -2, inds)
+
+ ts = lower_ + (upper_ - lower_) * t_rand
+ ts = torch.sort(ts, dim=-2).values
+ return ts
+
+
+@dataclass
+class MLPNeRFModelOutput(BaseOutput):
+ density: torch.Tensor
+ signed_distance: torch.Tensor
+ channels: torch.Tensor
+ ts: torch.Tensor
+
+
+class MLPNeRSTFModel(ModelMixin, ConfigMixin):
+ @register_to_config
+ def __init__(
+ self,
+ d_hidden: int = 256,
+ n_output: int = 12,
+ n_hidden_layers: int = 6,
+ act_fn: str = "swish",
+ insert_direction_at: int = 4,
+ ):
+ super().__init__()
+
+ # Instantiate the MLP
+
+ # Find out the dimension of encoded position and direction
+ dummy = torch.eye(1, 3)
+ d_posenc_pos = encode_position(position=dummy).shape[-1]
+ d_posenc_dir = encode_direction(position=dummy).shape[-1]
+
+ mlp_widths = [d_hidden] * n_hidden_layers
+ input_widths = [d_posenc_pos] + mlp_widths
+ output_widths = mlp_widths + [n_output]
+
+ if insert_direction_at is not None:
+ input_widths[insert_direction_at] += d_posenc_dir
+
+ self.mlp = nn.ModuleList([nn.Linear(d_in, d_out) for d_in, d_out in zip(input_widths, output_widths)])
+
+ if act_fn == "swish":
+ # self.activation = swish
+ # yiyi testing:
+ self.activation = lambda x: F.silu(x)
+ else:
+ raise ValueError(f"Unsupported activation function {act_fn}")
+
+ self.sdf_activation = torch.tanh
+ self.density_activation = torch.nn.functional.relu
+ self.channel_activation = torch.sigmoid
+
+ def map_indices_to_keys(self, output):
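+ # the 12 MLP outputs split into: sdf (1), coarse/fine density (1 each), stf rgb (3), coarse/fine nerf rgb (3 each)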
+ h_map = {
+ "sdf": (0, 1),
+ "density_coarse": (1, 2),
+ "density_fine": (2, 3),
+ "stf": (3, 6),
+ "nerf_coarse": (6, 9),
+ "nerf_fine": (9, 12),
+ }
+
+ mapped_output = {k: output[..., start:end] for k, (start, end) in h_map.items()}
+
+ return mapped_output
+
+ def forward(self, *, position, direction, ts, nerf_level="coarse"):
+ h = encode_position(position)
+
+ h_preact = h
+ h_directionless = None
+ for i, layer in enumerate(self.mlp):
+ if i == self.config.insert_direction_at: # 4 in the config
+ h_directionless = h_preact
+ h_direction = encode_direction(position, direction=direction)
+ h = torch.cat([h, h_direction], dim=-1)
+
+ h = layer(h)
+
+ h_preact = h
+
+ if i < len(self.mlp) - 1:
+ h = self.activation(h)
+
+ h_final = h
+ if h_directionless is None:
+ h_directionless = h_preact
+
+ activation = self.map_indices_to_keys(h_final)
+
+ if nerf_level == "coarse":
+ h_density = activation["density_coarse"]
+ h_channels = activation["nerf_coarse"]
+ else:
+ h_density = activation["density_fine"]
+ h_channels = activation["nerf_fine"]
+
+ density = self.density_activation(h_density)
+ signed_distance = self.sdf_activation(activation["sdf"])
+ channels = self.channel_activation(h_channels)
+
+ # yiyi notes: I think signed_distance is not used
+ return MLPNeRFModelOutput(density=density, signed_distance=signed_distance, channels=channels, ts=ts)
+
+
+class ChannelsProj(nn.Module):
+ def __init__(
+ self,
+ *,
+ vectors: int,
+ channels: int,
+ d_latent: int,
+ ):
+ super().__init__()
+ self.proj = nn.Linear(d_latent, vectors * channels)
+ self.norm = nn.LayerNorm(channels)
+ self.d_latent = d_latent
+ self.vectors = vectors
+ self.channels = channels
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x_bvd = x
+ w_vcd = self.proj.weight.view(self.vectors, self.channels, self.d_latent)
+ b_vc = self.proj.bias.view(1, self.vectors, self.channels)
+ h = torch.einsum("bvd,vcd->bvc", x_bvd, w_vcd)
+ h = self.norm(h)
+
+ h = h + b_vc
+ return h
+
+
+class ShapEParamsProjModel(ModelMixin, ConfigMixin):
+ """
+ Projects the latent representation of a 3D asset to obtain the weights of a multi-layer perceptron (MLP).
+
+ For more details, see the original Shap-E paper.
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ *,
+ param_names: Tuple[str] = (
+ "nerstf.mlp.0.weight",
+ "nerstf.mlp.1.weight",
+ "nerstf.mlp.2.weight",
+ "nerstf.mlp.3.weight",
+ ),
+ param_shapes: Tuple[Tuple[int]] = (
+ (256, 93),
+ (256, 256),
+ (256, 256),
+ (256, 256),
+ ),
+ d_latent: int = 1024,
+ ):
+ super().__init__()
+
+ # check inputs
+ if len(param_names) != len(param_shapes):
+ raise ValueError("Must provide same number of `param_names` as `param_shapes`")
+ self.projections = nn.ModuleDict({})
+ for k, (vectors, channels) in zip(param_names, param_shapes):
+ self.projections[_sanitize_name(k)] = ChannelsProj(
+ vectors=vectors,
+ channels=channels,
+ d_latent=d_latent,
+ )
+
+ def forward(self, x: torch.Tensor):
+ out = {}
+ start = 0
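+ # slice the flat latent into consecutive blocks of `vectors` rows and project each block into one MLP weight tensor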
+ for k, shape in zip(self.config.param_names, self.config.param_shapes):
+ vectors, _ = shape
+ end = start + vectors
+ x_bvd = x[:, start:end]
+ out[k] = self.projections[_sanitize_name(k)](x_bvd).reshape(len(x), *shape)
+ start = end
+ return out
+
+
+class ShapERenderer(ModelMixin, ConfigMixin):
+ @register_to_config
+ def __init__(
+ self,
+ *,
+ param_names: Tuple[str] = (
+ "nerstf.mlp.0.weight",
+ "nerstf.mlp.1.weight",
+ "nerstf.mlp.2.weight",
+ "nerstf.mlp.3.weight",
+ ),
+ param_shapes: Tuple[Tuple[int]] = (
+ (256, 93),
+ (256, 256),
+ (256, 256),
+ (256, 256),
+ ),
+ d_latent: int = 1024,
+ d_hidden: int = 256,
+ n_output: int = 12,
+ n_hidden_layers: int = 6,
+ act_fn: str = "swish",
+ insert_direction_at: int = 4,
+ background: Tuple[float] = (
+ 255.0,
+ 255.0,
+ 255.0,
+ ),
+ ):
+ super().__init__()
+
+ self.params_proj = ShapEParamsProjModel(
+ param_names=param_names,
+ param_shapes=param_shapes,
+ d_latent=d_latent,
+ )
+ self.mlp = MLPNeRSTFModel(d_hidden, n_output, n_hidden_layers, act_fn, insert_direction_at)
+ self.void = VoidNeRFModel(background=background, channel_scale=255.0)
+ self.volume = BoundingBoxVolume(bbox_max=[1.0, 1.0, 1.0], bbox_min=[-1.0, -1.0, -1.0])
+
+ @torch.no_grad()
+ def render_rays(self, rays, sampler, n_samples, prev_model_out=None, render_with_direction=False):
+ """
+ Perform volumetric rendering over a partition of possible t's in the union of rendering volumes (written below
+ with some abuse of notations)
+
+ C(r) := sum(
+ transmittance(t[i]) * integrate(
+ lambda t: density(t) * channels(t) * transmittance(t), [t[i], t[i + 1]],
+ ) for i in range(len(parts))
+ ) + transmittance(t[-1]) * void_model(t[-1]).channels
+
+ where
+
+ 1) transmittance(s) := exp(-integrate(density, [t[0], s])) calculates the probability of light passing through
+ the volume specified by [t[0], s]. (transmittance of 1 means light can pass freely) 2) density and channels are
+ obtained by evaluating the appropriate part.model at time t. 3) [t[i], t[i + 1]] is defined as the range of t
+ where the ray intersects (parts[i].volume \\ union(part.volume for part in parts[:i])) at the surface of the
+ shell (if bounded). If the ray does not intersect, the integral over this segment is evaluated as 0 and
+ transmittance(t[i + 1]) := transmittance(t[i]). 4) The last term is integration to infinity (e.g. [t[-1],
+ math.inf]) that is evaluated by the void_model (i.e. we consider this space to be empty).
+
+ args:
+ rays: [batch_size x ... x 2 x 3] origin and direction. sampler: disjoint volume integrals. n_samples:
+ number of ts to sample. prev_model_outputs: model outputs from the previous rendering step, including
+
+ :return: A tuple of
+ - `channels`
+ - A importance samplers for additional fine-grained rendering
+ - raw model output
+ """
+ origin, direction = rays[..., 0, :], rays[..., 1, :]
+
+ # Integrate over [t[i], t[i + 1]]
+
+        # 1. Intersect the rays with the current volume and sample ts to integrate along.
+ vrange = self.volume.intersect(origin, direction, t0_lower=None)
+ ts = sampler.sample(vrange.t0, vrange.t1, n_samples)
+ ts = ts.to(rays.dtype)
+
+ if prev_model_out is not None:
+ # Append the previous ts now before fprop because previous
+ # rendering used a different model and we can't reuse the output.
+ ts = torch.sort(torch.cat([ts, prev_model_out.ts], dim=-2), dim=-2).values
+
+ batch_size, *_shape, _t0_dim = vrange.t0.shape
+ _, *ts_shape, _ts_dim = ts.shape
+
+ # 2. Get the points along the ray and query the model
+ directions = torch.broadcast_to(direction.unsqueeze(-2), [batch_size, *ts_shape, 3])
+ positions = origin.unsqueeze(-2) + ts * directions
+
+ directions = directions.to(self.mlp.dtype)
+ positions = positions.to(self.mlp.dtype)
+
+ optional_directions = directions if render_with_direction else None
+
+ model_out = self.mlp(
+ position=positions,
+ direction=optional_directions,
+ ts=ts,
+ nerf_level="coarse" if prev_model_out is None else "fine",
+ )
+
+ # 3. Integrate the model results
+ channels, weights, transmittance = integrate_samples(
+ vrange, model_out.ts, model_out.density, model_out.channels
+ )
+
+ # 4. Clean up results that do not intersect with the volume.
+ transmittance = torch.where(vrange.intersected, transmittance, torch.ones_like(transmittance))
+ channels = torch.where(vrange.intersected, channels, torch.zeros_like(channels))
+        # 5. Integrate to infinity (e.g. [t[-1], math.inf]); this is evaluated by the void_model (i.e. we consider this space to be empty).
+ channels = channels + transmittance * self.void(origin)
+
+ weighted_sampler = ImportanceRaySampler(vrange, ts=model_out.ts, weights=weights)
+
+ return channels, weighted_sampler, model_out
+
+ @torch.no_grad()
+ def decode(
+ self,
+ latents,
+ device,
+ size: int = 64,
+ ray_batch_size: int = 4096,
+ n_coarse_samples=64,
+ n_fine_samples=128,
+ ):
+        # project the parameters from the generated latents
+ projected_params = self.params_proj(latents)
+
+ # update the mlp layers of the renderer
+ for name, param in self.mlp.state_dict().items():
+ if f"nerstf.{name}" in projected_params.keys():
+ param.copy_(projected_params[f"nerstf.{name}"].squeeze(0))
+
+ # create cameras object
+ camera = create_pan_cameras(size)
+ rays = camera.camera_rays
+ rays = rays.to(device)
+ n_batches = rays.shape[1] // ray_batch_size
+
+ coarse_sampler = StratifiedRaySampler()
+
+ images = []
+
+ for idx in range(n_batches):
+ rays_batch = rays[:, idx * ray_batch_size : (idx + 1) * ray_batch_size]
+
+ # render rays with coarse, stratified samples.
+ _, fine_sampler, coarse_model_out = self.render_rays(rays_batch, coarse_sampler, n_coarse_samples)
+ # Then, render with additional importance-weighted ray samples.
+ channels, _, _ = self.render_rays(
+ rays_batch, fine_sampler, n_fine_samples, prev_model_out=coarse_model_out
+ )
+
+ images.append(channels)
+
+ images = torch.cat(images, dim=1)
+ images = images.view(*camera.shape, camera.height, camera.width, -1).squeeze(0)
+
+ return images
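+
+
+# A minimal decode sketch (illustrative, not part of the original source): `latents` is expected to
+# have shape [batch, sum of per-parameter vector counts, d_latent], i.e. [batch, 1024, 1024] with
+# the default config, so that `params_proj` can slice it into per-layer MLP weights.
+#
+#     renderer = ShapERenderer()
+#     images = renderer.decode(latents, device=latents.device, size=64)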
diff --git a/diffusers/pipelines/spectrogram_diffusion/__init__.py b/diffusers/pipelines/spectrogram_diffusion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..05b14a857630e7a7c001a8ae4c23772dfc62a08a
--- /dev/null
+++ b/diffusers/pipelines/spectrogram_diffusion/__init__.py
@@ -0,0 +1,26 @@
+# flake8: noqa
+from ...utils import is_note_seq_available, is_transformers_available, is_torch_available
+from ...utils import OptionalDependencyNotAvailable
+
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
+else:
+ from .notes_encoder import SpectrogramNotesEncoder
+ from .continous_encoder import SpectrogramContEncoder
+    from .pipeline_spectrogram_diffusion import (
+        SpectrogramDiffusionPipeline,
+        T5FilmDecoder,
+    )
+
+try:
+ if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403
+else:
+ from .midi_utils import MidiProcessor
diff --git a/diffusers/pipelines/spectrogram_diffusion/continous_encoder.py b/diffusers/pipelines/spectrogram_diffusion/continous_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..556136d4023df32e4df2477523463829a0722db4
--- /dev/null
+++ b/diffusers/pipelines/spectrogram_diffusion/continous_encoder.py
@@ -0,0 +1,92 @@
+# Copyright 2022 The Music Spectrogram Diffusion Authors.
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.nn as nn
+from transformers.modeling_utils import ModuleUtilsMixin
+from transformers.models.t5.modeling_t5 import (
+ T5Block,
+ T5Config,
+ T5LayerNorm,
+)
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...models import ModelMixin
+
+
+class SpectrogramContEncoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
+ @register_to_config
+ def __init__(
+ self,
+ input_dims: int,
+ targets_context_length: int,
+ d_model: int,
+ dropout_rate: float,
+ num_layers: int,
+ num_heads: int,
+ d_kv: int,
+ d_ff: int,
+ feed_forward_proj: str,
+ is_decoder: bool = False,
+ ):
+ super().__init__()
+
+ self.input_proj = nn.Linear(input_dims, d_model, bias=False)
+
+ self.position_encoding = nn.Embedding(targets_context_length, d_model)
+ self.position_encoding.weight.requires_grad = False
+
+ self.dropout_pre = nn.Dropout(p=dropout_rate)
+
+ t5config = T5Config(
+ d_model=d_model,
+ num_heads=num_heads,
+ d_kv=d_kv,
+ d_ff=d_ff,
+ feed_forward_proj=feed_forward_proj,
+ dropout_rate=dropout_rate,
+ is_decoder=is_decoder,
+ is_encoder_decoder=False,
+ )
+ self.encoders = nn.ModuleList()
+ for lyr_num in range(num_layers):
+ lyr = T5Block(t5config)
+ self.encoders.append(lyr)
+
+ self.layer_norm = T5LayerNorm(d_model)
+ self.dropout_post = nn.Dropout(p=dropout_rate)
+
+ def forward(self, encoder_inputs, encoder_inputs_mask):
+ x = self.input_proj(encoder_inputs)
+
+ # terminal relative positional encodings
+ max_positions = encoder_inputs.shape[1]
+ input_positions = torch.arange(max_positions, device=encoder_inputs.device)
+
+ seq_lens = encoder_inputs_mask.sum(-1)
+ input_positions = torch.roll(input_positions.unsqueeze(0), tuple(seq_lens.tolist()), dims=0)
+ x += self.position_encoding(input_positions)
+
+ x = self.dropout_pre(x)
+
+        # invert the attention mask
+ input_shape = encoder_inputs.size()
+ extended_attention_mask = self.get_extended_attention_mask(encoder_inputs_mask, input_shape)
+
+ for lyr in self.encoders:
+ x = lyr(x, extended_attention_mask)[0]
+ x = self.layer_norm(x)
+
+ return self.dropout_post(x), encoder_inputs_mask
diff --git a/diffusers/pipelines/spectrogram_diffusion/midi_utils.py b/diffusers/pipelines/spectrogram_diffusion/midi_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..08d0878db588aa38a2e602a3bc5f6505b9457575
--- /dev/null
+++ b/diffusers/pipelines/spectrogram_diffusion/midi_utils.py
@@ -0,0 +1,667 @@
+# Copyright 2022 The Music Spectrogram Diffusion Authors.
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import dataclasses
+import math
+import os
+from typing import Any, Callable, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from ...utils import is_note_seq_available
+from .pipeline_spectrogram_diffusion import TARGET_FEATURE_LENGTH
+
+
+if is_note_seq_available():
+ import note_seq
+else:
+ raise ImportError("Please install note-seq via `pip install note-seq`")
+
+
+INPUT_FEATURE_LENGTH = 2048
+
+SAMPLE_RATE = 16000
+HOP_SIZE = 320
+FRAME_RATE = int(SAMPLE_RATE // HOP_SIZE)
+
+DEFAULT_STEPS_PER_SECOND = 100
+DEFAULT_MAX_SHIFT_SECONDS = 10
+DEFAULT_NUM_VELOCITY_BINS = 1
+
+SLAKH_CLASS_PROGRAMS = {
+ "Acoustic Piano": 0,
+ "Electric Piano": 4,
+ "Chromatic Percussion": 8,
+ "Organ": 16,
+ "Acoustic Guitar": 24,
+ "Clean Electric Guitar": 26,
+ "Distorted Electric Guitar": 29,
+ "Acoustic Bass": 32,
+ "Electric Bass": 33,
+ "Violin": 40,
+ "Viola": 41,
+ "Cello": 42,
+ "Contrabass": 43,
+ "Orchestral Harp": 46,
+ "Timpani": 47,
+ "String Ensemble": 48,
+ "Synth Strings": 50,
+ "Choir and Voice": 52,
+ "Orchestral Hit": 55,
+ "Trumpet": 56,
+ "Trombone": 57,
+ "Tuba": 58,
+ "French Horn": 60,
+ "Brass Section": 61,
+ "Soprano/Alto Sax": 64,
+ "Tenor Sax": 66,
+ "Baritone Sax": 67,
+ "Oboe": 68,
+ "English Horn": 69,
+ "Bassoon": 70,
+ "Clarinet": 71,
+ "Pipe": 73,
+ "Synth Lead": 80,
+ "Synth Pad": 88,
+}
+
+
+@dataclasses.dataclass
+class NoteRepresentationConfig:
+ """Configuration note representations."""
+
+ onsets_only: bool
+ include_ties: bool
+
+
+@dataclasses.dataclass
+class NoteEventData:
+ pitch: int
+ velocity: Optional[int] = None
+ program: Optional[int] = None
+ is_drum: Optional[bool] = None
+ instrument: Optional[int] = None
+
+
+@dataclasses.dataclass
+class NoteEncodingState:
+ """Encoding state for note transcription, keeping track of active pitches."""
+
+ # velocity bin for active pitches and programs
+ active_pitches: MutableMapping[Tuple[int, int], int] = dataclasses.field(default_factory=dict)
+
+
+@dataclasses.dataclass
+class EventRange:
+ type: str
+ min_value: int
+ max_value: int
+
+
+@dataclasses.dataclass
+class Event:
+ type: str
+ value: int
+
+
+class Tokenizer:
+ def __init__(self, regular_ids: int):
+ # The special tokens: 0=PAD, 1=EOS, and 2=UNK
+ self._num_special_tokens = 3
+ self._num_regular_tokens = regular_ids
+
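+    # Illustrative example (not part of the original source): with regular_ids=5, encode([0, 4])
+    # returns [3, 7, 1, 0, 0, ...]: regular ids are shifted past the 3 special tokens
+    # (PAD=0, EOS=1, UNK=2), an EOS token is appended, and the result is padded with 0s up to
+    # INPUT_FEATURE_LENGTH.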
+ def encode(self, token_ids):
+ encoded = []
+ for token_id in token_ids:
+ if not 0 <= token_id < self._num_regular_tokens:
+ raise ValueError(
+ f"token_id {token_id} does not fall within valid range of [0, {self._num_regular_tokens})"
+ )
+ encoded.append(token_id + self._num_special_tokens)
+
+ # Add EOS token
+ encoded.append(1)
+
+        # Pad up to INPUT_FEATURE_LENGTH
+ encoded = encoded + [0] * (INPUT_FEATURE_LENGTH - len(encoded))
+
+ return encoded
+
+
+class Codec:
+ """Encode and decode events.
+
+ Useful for declaring what certain ranges of a vocabulary should be used for. This is intended to be used from
+ Python before encoding or after decoding with GenericTokenVocabulary. This class is more lightweight and does not
+ include things like EOS or UNK token handling.
+
+ To ensure that 'shift' events are always the first block of the vocab and start at 0, that event type is required
+ and specified separately.
+ """
+
+ def __init__(self, max_shift_steps: int, steps_per_second: float, event_ranges: List[EventRange]):
+ """Define Codec.
+
+ Args:
+ max_shift_steps: Maximum number of shift steps that can be encoded.
+ steps_per_second: Shift steps will be interpreted as having a duration of
+ 1 / steps_per_second.
+ event_ranges: Other supported event types and their ranges.
+ """
+ self.steps_per_second = steps_per_second
+ self._shift_range = EventRange(type="shift", min_value=0, max_value=max_shift_steps)
+ self._event_ranges = [self._shift_range] + event_ranges
+ # Ensure all event types have unique names.
+ assert len(self._event_ranges) == len({er.type for er in self._event_ranges})
+
+ @property
+ def num_classes(self) -> int:
+ return sum(er.max_value - er.min_value + 1 for er in self._event_ranges)
+
+ # The next couple methods are simplified special case methods just for shift
+ # events that are intended to be used from within autograph functions.
+
+ def is_shift_event_index(self, index: int) -> bool:
+ return (self._shift_range.min_value <= index) and (index <= self._shift_range.max_value)
+
+ @property
+ def max_shift_steps(self) -> int:
+ return self._shift_range.max_value
+
+ def encode_event(self, event: Event) -> int:
+ """Encode an event to an index."""
+ offset = 0
+ for er in self._event_ranges:
+ if event.type == er.type:
+ if not er.min_value <= event.value <= er.max_value:
+ raise ValueError(
+ f"Event value {event.value} is not within valid range "
+ f"[{er.min_value}, {er.max_value}] for type {event.type}"
+ )
+ return offset + event.value - er.min_value
+ offset += er.max_value - er.min_value + 1
+
+ raise ValueError(f"Unknown event type: {event.type}")
+
+ def event_type_range(self, event_type: str) -> Tuple[int, int]:
+ """Return [min_id, max_id] for an event type."""
+ offset = 0
+ for er in self._event_ranges:
+ if event_type == er.type:
+ return offset, offset + (er.max_value - er.min_value)
+ offset += er.max_value - er.min_value + 1
+
+ raise ValueError(f"Unknown event type: {event_type}")
+
+ def decode_event_index(self, index: int) -> Event:
+ """Decode an event index to an Event."""
+ offset = 0
+ for er in self._event_ranges:
+ if offset <= index <= offset + er.max_value - er.min_value:
+ return Event(type=er.type, value=er.min_value + index - offset)
+ offset += er.max_value - er.min_value + 1
+
+ raise ValueError(f"Unknown event index: {index}")
+
+
+@dataclasses.dataclass
+class ProgramGranularity:
+ # both tokens_map_fn and program_map_fn should be idempotent
+ tokens_map_fn: Callable[[Sequence[int], Codec], Sequence[int]]
+ program_map_fn: Callable[[int], int]
+
+
+def drop_programs(tokens, codec: Codec):
+ """Drops program change events from a token sequence."""
+ min_program_id, max_program_id = codec.event_type_range("program")
+ return tokens[(tokens < min_program_id) | (tokens > max_program_id)]
+
+
+def programs_to_midi_classes(tokens, codec):
+ """Modifies program events to be the first program in the MIDI class."""
+ min_program_id, max_program_id = codec.event_type_range("program")
+ is_program = (tokens >= min_program_id) & (tokens <= max_program_id)
+ return np.where(is_program, min_program_id + 8 * ((tokens - min_program_id) // 8), tokens)
+
+
+PROGRAM_GRANULARITIES = {
+ # "flat" granularity; drop program change tokens and set NoteSequence
+ # programs to zero
+ "flat": ProgramGranularity(tokens_map_fn=drop_programs, program_map_fn=lambda program: 0),
+ # map each program to the first program in its MIDI class
+ "midi_class": ProgramGranularity(
+ tokens_map_fn=programs_to_midi_classes, program_map_fn=lambda program: 8 * (program // 8)
+ ),
+ # leave programs as is
+ "full": ProgramGranularity(tokens_map_fn=lambda tokens, codec: tokens, program_map_fn=lambda program: program),
+}
+
+
+def frame(signal, frame_length, frame_step, pad_end=False, pad_value=0, axis=-1):
+ """
+ equivalent of tf.signal.frame
+ """
+ signal_length = signal.shape[axis]
+ if pad_end:
+ frames_overlap = frame_length - frame_step
+ rest_samples = np.abs(signal_length - frames_overlap) % np.abs(frame_length - frames_overlap)
+ pad_size = int(frame_length - rest_samples)
+
+ if pad_size != 0:
+ pad_axis = [0] * signal.ndim
+ pad_axis[axis] = pad_size
+ signal = F.pad(signal, pad_axis, "constant", pad_value)
+ frames = signal.unfold(axis, frame_length, frame_step)
+ return frames
+
+
+def program_to_slakh_program(program):
+ # this is done very hackily, probably should use a custom mapping
+ for slakh_program in sorted(SLAKH_CLASS_PROGRAMS.values(), reverse=True):
+ if program >= slakh_program:
+ return slakh_program
+
+
+def audio_to_frames(
+ samples,
+ hop_size: int,
+ frame_rate: int,
+) -> Tuple[Sequence[Sequence[int]], torch.Tensor]:
+ """Convert audio samples to non-overlapping frames and frame times."""
+ frame_size = hop_size
+ samples = np.pad(samples, [0, frame_size - len(samples) % frame_size], mode="constant")
+
+ # Split audio into frames.
+ frames = frame(
+ torch.Tensor(samples).unsqueeze(0),
+ frame_length=frame_size,
+ frame_step=frame_size,
+        pad_end=False,  # TODO check why it's off by 1 here when True
+ )
+
+ num_frames = len(samples) // frame_size
+
+ times = np.arange(num_frames) / frame_rate
+ return frames, times
+
+
+def note_sequence_to_onsets_and_offsets_and_programs(
+ ns: note_seq.NoteSequence,
+) -> Tuple[Sequence[float], Sequence[NoteEventData]]:
+ """Extract onset & offset times and pitches & programs from a NoteSequence.
+
+ The onset & offset times will not necessarily be in sorted order.
+
+ Args:
+ ns: NoteSequence from which to extract onsets and offsets.
+
+ Returns:
+      times: A list of note onset and offset times.
+      values: A list of NoteEventData objects where velocity is zero for note offsets.
+ """
+ # Sort by program and pitch and put offsets before onsets as a tiebreaker for
+ # subsequent stable sort.
+ notes = sorted(ns.notes, key=lambda note: (note.is_drum, note.program, note.pitch))
+ times = [note.end_time for note in notes if not note.is_drum] + [note.start_time for note in notes]
+ values = [
+ NoteEventData(pitch=note.pitch, velocity=0, program=note.program, is_drum=False)
+ for note in notes
+ if not note.is_drum
+ ] + [
+ NoteEventData(pitch=note.pitch, velocity=note.velocity, program=note.program, is_drum=note.is_drum)
+ for note in notes
+ ]
+ return times, values
+
+
+def num_velocity_bins_from_codec(codec: Codec):
+ """Get number of velocity bins from event codec."""
+ lo, hi = codec.event_type_range("velocity")
+ return hi - lo
+
+
+# segment an array into segments of length n
+def segment(a, n):
+ return [a[i : i + n] for i in range(0, len(a), n)]
+
+
+def velocity_to_bin(velocity, num_velocity_bins):
+ if velocity == 0:
+ return 0
+ else:
+ return math.ceil(num_velocity_bins * velocity / note_seq.MAX_MIDI_VELOCITY)
+
+
+def note_event_data_to_events(
+ state: Optional[NoteEncodingState],
+ value: NoteEventData,
+ codec: Codec,
+) -> Sequence[Event]:
+ """Convert note event data to a sequence of events."""
+ if value.velocity is None:
+ # onsets only, no program or velocity
+ return [Event("pitch", value.pitch)]
+ else:
+ num_velocity_bins = num_velocity_bins_from_codec(codec)
+ velocity_bin = velocity_to_bin(value.velocity, num_velocity_bins)
+ if value.program is None:
+ # onsets + offsets + velocities only, no programs
+ if state is not None:
+ state.active_pitches[(value.pitch, 0)] = velocity_bin
+ return [Event("velocity", velocity_bin), Event("pitch", value.pitch)]
+ else:
+ if value.is_drum:
+ # drum events use a separate vocabulary
+ return [Event("velocity", velocity_bin), Event("drum", value.pitch)]
+ else:
+ # program + velocity + pitch
+ if state is not None:
+ state.active_pitches[(value.pitch, value.program)] = velocity_bin
+ return [
+ Event("program", value.program),
+ Event("velocity", velocity_bin),
+ Event("pitch", value.pitch),
+ ]
+
+
+def note_encoding_state_to_events(state: NoteEncodingState) -> Sequence[Event]:
+ """Output program and pitch events for active notes plus a final tie event."""
+ events = []
+ for pitch, program in sorted(state.active_pitches.keys(), key=lambda k: k[::-1]):
+ if state.active_pitches[(pitch, program)]:
+ events += [Event("program", program), Event("pitch", pitch)]
+ events.append(Event("tie", 0))
+ return events
+
+
+def encode_and_index_events(
+ state, event_times, event_values, codec, frame_times, encode_event_fn, encoding_state_to_events_fn=None
+):
+ """Encode a sequence of timed events and index to audio frame times.
+
+ Encodes time shifts as repeated single step shifts for later run length encoding.
+
+ Optionally, also encodes a sequence of "state events", keeping track of the current encoding state at each audio
+ frame. This can be used e.g. to prepend events representing the current state to a targets segment.
+
+ Args:
+ state: Initial event encoding state.
+ event_times: Sequence of event times.
+ event_values: Sequence of event values.
+ encode_event_fn: Function that transforms event value into a sequence of one
+ or more Event objects.
+ codec: An Codec object that maps Event objects to indices.
+ frame_times: Time for every audio frame.
+ encoding_state_to_events_fn: Function that transforms encoding state into a
+ sequence of one or more Event objects.
+
+    Returns:
+      events: Encoded events and shifts.
+      event_start_indices: Corresponding start event index for every audio frame. Note: one event can correspond
+          to multiple audio indices due to sampling rate differences. This makes splitting sequences tricky
+          because the same event can appear at the end of one sequence and the beginning of another.
+      event_end_indices: Corresponding end event index for every audio frame. Used to ensure when slicing that one
+          chunk ends where the next begins. Should always be true that
+          event_end_indices[i] = event_start_indices[i + 1].
+      state_events: Encoded "state" events representing the encoding state before each event.
+      state_event_indices: Corresponding state event index for every audio frame.
+ """
+ indices = np.argsort(event_times, kind="stable")
+ event_steps = [round(event_times[i] * codec.steps_per_second) for i in indices]
+ event_values = [event_values[i] for i in indices]
+
+ events = []
+ state_events = []
+ event_start_indices = []
+ state_event_indices = []
+
+ cur_step = 0
+ cur_event_idx = 0
+ cur_state_event_idx = 0
+
+ def fill_event_start_indices_to_cur_step():
+ while (
+ len(event_start_indices) < len(frame_times)
+ and frame_times[len(event_start_indices)] < cur_step / codec.steps_per_second
+ ):
+ event_start_indices.append(cur_event_idx)
+ state_event_indices.append(cur_state_event_idx)
+
+ for event_step, event_value in zip(event_steps, event_values):
+ while event_step > cur_step:
+ events.append(codec.encode_event(Event(type="shift", value=1)))
+ cur_step += 1
+ fill_event_start_indices_to_cur_step()
+ cur_event_idx = len(events)
+ cur_state_event_idx = len(state_events)
+ if encoding_state_to_events_fn:
+ # Dump state to state events *before* processing the next event, because
+ # we want to capture the state prior to the occurrence of the event.
+ for e in encoding_state_to_events_fn(state):
+ state_events.append(codec.encode_event(e))
+
+ for e in encode_event_fn(state, event_value, codec):
+ events.append(codec.encode_event(e))
+
+ # After the last event, continue filling out the event_start_indices array.
+ # The inequality is not strict because if our current step lines up exactly
+ # with (the start of) an audio frame, we need to add an additional shift event
+ # to "cover" that frame.
+ while cur_step / codec.steps_per_second <= frame_times[-1]:
+ events.append(codec.encode_event(Event(type="shift", value=1)))
+ cur_step += 1
+ fill_event_start_indices_to_cur_step()
+ cur_event_idx = len(events)
+
+ # Now fill in event_end_indices. We need this extra array to make sure that
+ # when we slice events, each slice ends exactly where the subsequent slice
+ # begins.
+ event_end_indices = event_start_indices[1:] + [len(events)]
+
+ events = np.array(events).astype(np.int32)
+ state_events = np.array(state_events).astype(np.int32)
+ event_start_indices = segment(np.array(event_start_indices).astype(np.int32), TARGET_FEATURE_LENGTH)
+ event_end_indices = segment(np.array(event_end_indices).astype(np.int32), TARGET_FEATURE_LENGTH)
+ state_event_indices = segment(np.array(state_event_indices).astype(np.int32), TARGET_FEATURE_LENGTH)
+
+ outputs = []
+ for start_indices, end_indices, event_indices in zip(event_start_indices, event_end_indices, state_event_indices):
+ outputs.append(
+ {
+ "inputs": events,
+ "event_start_indices": start_indices,
+ "event_end_indices": end_indices,
+ "state_events": state_events,
+ "state_event_indices": event_indices,
+ }
+ )
+
+ return outputs
+
+
+def extract_sequence_with_indices(features, state_events_end_token=None, feature_key="inputs"):
+ """Extract target sequence corresponding to audio token segment."""
+ features = features.copy()
+ start_idx = features["event_start_indices"][0]
+ end_idx = features["event_end_indices"][-1]
+
+ features[feature_key] = features[feature_key][start_idx:end_idx]
+
+ if state_events_end_token is not None:
+ # Extract the state events corresponding to the audio start token, and
+ # prepend them to the targets array.
+ state_event_start_idx = features["state_event_indices"][0]
+ state_event_end_idx = state_event_start_idx + 1
+ while features["state_events"][state_event_end_idx - 1] != state_events_end_token:
+ state_event_end_idx += 1
+ features[feature_key] = np.concatenate(
+ [
+ features["state_events"][state_event_start_idx:state_event_end_idx],
+ features[feature_key],
+ ],
+ axis=0,
+ )
+
+ return features
+
+
+def map_midi_programs(
+ feature, codec: Codec, granularity_type: str = "full", feature_key: str = "inputs"
+) -> Mapping[str, Any]:
+ """Apply MIDI program map to token sequences."""
+ granularity = PROGRAM_GRANULARITIES[granularity_type]
+
+ feature[feature_key] = granularity.tokens_map_fn(feature[feature_key], codec)
+ return feature
+
+
+def run_length_encode_shifts_fn(
+    features,
+    codec: Codec,
+    feature_key: str = "inputs",
+    state_change_event_types: Sequence[str] = (),
+) -> Mapping[str, Any]:
+    """Run-length encode shifts in `features[feature_key]` for a given codec.
+
+    Args:
+        features: Dict of features to process.
+        codec: The Codec to use for shift events.
+        feature_key: The feature key for which to run-length encode shifts.
+        state_change_event_types: A list of event types that represent state changes; tokens corresponding to these
+            event types will be interpreted as state changes and redundant ones will be removed.
+
+    Returns:
+        The feature dict with single-step shifts run-length encoded.
+    """
+ state_change_event_ranges = [codec.event_type_range(event_type) for event_type in state_change_event_types]
+
+ def run_length_encode_shifts(features: MutableMapping[str, Any]) -> Mapping[str, Any]:
+ """Combine leading/interior shifts, trim trailing shifts.
+
+ Args:
+ features: Dict of features to process.
+
+ Returns:
+ A dict of features.
+ """
+ events = features[feature_key]
+
+ shift_steps = 0
+ total_shift_steps = 0
+ output = np.array([], dtype=np.int32)
+
+ current_state = np.zeros(len(state_change_event_ranges), dtype=np.int32)
+
+ for event in events:
+ if codec.is_shift_event_index(event):
+ shift_steps += 1
+ total_shift_steps += 1
+
+ else:
+ # If this event is a state change and has the same value as the current
+ # state, we can skip it entirely.
+ is_redundant = False
+ for i, (min_index, max_index) in enumerate(state_change_event_ranges):
+ if (min_index <= event) and (event <= max_index):
+ if current_state[i] == event:
+ is_redundant = True
+ current_state[i] = event
+ if is_redundant:
+ continue
+
+ # Once we've reached a non-shift event, RLE all previous shift events
+ # before outputting the non-shift event.
+ if shift_steps > 0:
+ shift_steps = total_shift_steps
+ while shift_steps > 0:
+ output_steps = np.minimum(codec.max_shift_steps, shift_steps)
+ output = np.concatenate([output, [output_steps]], axis=0)
+ shift_steps -= output_steps
+ output = np.concatenate([output, [event]], axis=0)
+
+ features[feature_key] = output
+ return features
+
+ return run_length_encode_shifts(features)
+
+
+def note_representation_processor_chain(features, codec: Codec, note_representation_config: NoteRepresentationConfig):
+ tie_token = codec.encode_event(Event("tie", 0))
+ state_events_end_token = tie_token if note_representation_config.include_ties else None
+
+ features = extract_sequence_with_indices(
+ features, state_events_end_token=state_events_end_token, feature_key="inputs"
+ )
+
+ features = map_midi_programs(features, codec)
+
+ features = run_length_encode_shifts_fn(features, codec, state_change_event_types=["velocity", "program"])
+
+ return features
+
+
+class MidiProcessor:
+ def __init__(self):
+ self.codec = Codec(
+ max_shift_steps=DEFAULT_MAX_SHIFT_SECONDS * DEFAULT_STEPS_PER_SECOND,
+ steps_per_second=DEFAULT_STEPS_PER_SECOND,
+ event_ranges=[
+ EventRange("pitch", note_seq.MIN_MIDI_PITCH, note_seq.MAX_MIDI_PITCH),
+ EventRange("velocity", 0, DEFAULT_NUM_VELOCITY_BINS),
+ EventRange("tie", 0, 0),
+ EventRange("program", note_seq.MIN_MIDI_PROGRAM, note_seq.MAX_MIDI_PROGRAM),
+ EventRange("drum", note_seq.MIN_MIDI_PITCH, note_seq.MAX_MIDI_PITCH),
+ ],
+ )
+ self.tokenizer = Tokenizer(self.codec.num_classes)
+ self.note_representation_config = NoteRepresentationConfig(onsets_only=False, include_ties=True)
+
+ def __call__(self, midi: Union[bytes, os.PathLike, str]):
+ if not isinstance(midi, bytes):
+ with open(midi, "rb") as f:
+ midi = f.read()
+
+ ns = note_seq.midi_to_note_sequence(midi)
+ ns_sus = note_seq.apply_sustain_control_changes(ns)
+
+ for note in ns_sus.notes:
+ if not note.is_drum:
+ note.program = program_to_slakh_program(note.program)
+
+ samples = np.zeros(int(ns_sus.total_time * SAMPLE_RATE))
+
+ _, frame_times = audio_to_frames(samples, HOP_SIZE, FRAME_RATE)
+ times, values = note_sequence_to_onsets_and_offsets_and_programs(ns_sus)
+
+ events = encode_and_index_events(
+ state=NoteEncodingState(),
+ event_times=times,
+ event_values=values,
+ frame_times=frame_times,
+ codec=self.codec,
+ encode_event_fn=note_event_data_to_events,
+ encoding_state_to_events_fn=note_encoding_state_to_events,
+ )
+
+ events = [
+ note_representation_processor_chain(event, self.codec, self.note_representation_config) for event in events
+ ]
+ input_tokens = [self.tokenizer.encode(event["inputs"]) for event in events]
+
+ return input_tokens
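+
+
+# Minimal usage sketch (illustrative, not part of the original source; assumes a MIDI file on disk):
+#
+#     processor = MidiProcessor()
+#     segments = processor("song.mid")  # hypothetical path
+#     # one token list of length INPUT_FEATURE_LENGTH per TARGET_FEATURE_LENGTH-frame segment
+#     print(len(segments), len(segments[0]))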
diff --git a/diffusers/pipelines/spectrogram_diffusion/notes_encoder.py b/diffusers/pipelines/spectrogram_diffusion/notes_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..94eaa176f3e5a15f4065e78b4b7714fa8c51ca83
--- /dev/null
+++ b/diffusers/pipelines/spectrogram_diffusion/notes_encoder.py
@@ -0,0 +1,86 @@
+# Copyright 2022 The Music Spectrogram Diffusion Authors.
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.nn as nn
+from transformers.modeling_utils import ModuleUtilsMixin
+from transformers.models.t5.modeling_t5 import T5Block, T5Config, T5LayerNorm
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...models import ModelMixin
+
+
+class SpectrogramNotesEncoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
+ @register_to_config
+ def __init__(
+ self,
+ max_length: int,
+ vocab_size: int,
+ d_model: int,
+ dropout_rate: float,
+ num_layers: int,
+ num_heads: int,
+ d_kv: int,
+ d_ff: int,
+ feed_forward_proj: str,
+ is_decoder: bool = False,
+ ):
+ super().__init__()
+
+ self.token_embedder = nn.Embedding(vocab_size, d_model)
+
+ self.position_encoding = nn.Embedding(max_length, d_model)
+ self.position_encoding.weight.requires_grad = False
+
+ self.dropout_pre = nn.Dropout(p=dropout_rate)
+
+ t5config = T5Config(
+ vocab_size=vocab_size,
+ d_model=d_model,
+ num_heads=num_heads,
+ d_kv=d_kv,
+ d_ff=d_ff,
+ dropout_rate=dropout_rate,
+ feed_forward_proj=feed_forward_proj,
+ is_decoder=is_decoder,
+ is_encoder_decoder=False,
+ )
+
+ self.encoders = nn.ModuleList()
+ for lyr_num in range(num_layers):
+ lyr = T5Block(t5config)
+ self.encoders.append(lyr)
+
+ self.layer_norm = T5LayerNorm(d_model)
+ self.dropout_post = nn.Dropout(p=dropout_rate)
+
+ def forward(self, encoder_input_tokens, encoder_inputs_mask):
+ x = self.token_embedder(encoder_input_tokens)
+
+ seq_length = encoder_input_tokens.shape[1]
+ inputs_positions = torch.arange(seq_length, device=encoder_input_tokens.device)
+ x += self.position_encoding(inputs_positions)
+
+ x = self.dropout_pre(x)
+
+        # invert the attention mask
+ input_shape = encoder_input_tokens.size()
+ extended_attention_mask = self.get_extended_attention_mask(encoder_inputs_mask, input_shape)
+
+ for lyr in self.encoders:
+ x = lyr(x, extended_attention_mask)[0]
+ x = self.layer_norm(x)
+
+ return self.dropout_post(x), encoder_inputs_mask
diff --git a/diffusers/pipelines/spectrogram_diffusion/pipeline_spectrogram_diffusion.py b/diffusers/pipelines/spectrogram_diffusion/pipeline_spectrogram_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..66155ebf7f35cbe224bf21fd54c47f3b5ee32a37
--- /dev/null
+++ b/diffusers/pipelines/spectrogram_diffusion/pipeline_spectrogram_diffusion.py
@@ -0,0 +1,210 @@
+# Copyright 2022 The Music Spectrogram Diffusion Authors.
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Any, Callable, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ...models import T5FilmDecoder
+from ...schedulers import DDPMScheduler
+from ...utils import is_onnx_available, logging, randn_tensor
+
+
+if is_onnx_available():
+ from ..onnx_utils import OnnxRuntimeModel
+
+from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
+from .continous_encoder import SpectrogramContEncoder
+from .notes_encoder import SpectrogramNotesEncoder
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+TARGET_FEATURE_LENGTH = 256
+
+
+class SpectrogramDiffusionPipeline(DiffusionPipeline):
+ _optional_components = ["melgan"]
+
+ def __init__(
+ self,
+ notes_encoder: SpectrogramNotesEncoder,
+ continuous_encoder: SpectrogramContEncoder,
+ decoder: T5FilmDecoder,
+ scheduler: DDPMScheduler,
+ melgan: OnnxRuntimeModel if is_onnx_available() else Any,
+ ) -> None:
+ super().__init__()
+
+ # From MELGAN
+ self.min_value = math.log(1e-5) # Matches MelGAN training.
+ self.max_value = 4.0 # Largest value for most examples
+ self.n_dims = 128
+
+ self.register_modules(
+ notes_encoder=notes_encoder,
+ continuous_encoder=continuous_encoder,
+ decoder=decoder,
+ scheduler=scheduler,
+ melgan=melgan,
+ )
+
+ def scale_features(self, features, output_range=(-1.0, 1.0), clip=False):
+ """Linearly scale features to network outputs range."""
+ min_out, max_out = output_range
+ if clip:
+ features = torch.clip(features, self.min_value, self.max_value)
+ # Scale to [0, 1].
+ zero_one = (features - self.min_value) / (self.max_value - self.min_value)
+ # Scale to [min_out, max_out].
+ return zero_one * (max_out - min_out) + min_out
+
+ def scale_to_features(self, outputs, input_range=(-1.0, 1.0), clip=False):
+ """Invert by linearly scaling network outputs to features range."""
+ min_out, max_out = input_range
+ outputs = torch.clip(outputs, min_out, max_out) if clip else outputs
+ # Scale to [0, 1].
+ zero_one = (outputs - min_out) / (max_out - min_out)
+ # Scale to [self.min_value, self.max_value].
+ return zero_one * (self.max_value - self.min_value) + self.min_value
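+    # Illustrative example (not part of the original source): with min_value=log(1e-5) and
+    # max_value=4.0, scale_features maps max_value to +1.0 and min_value to -1.0, and
+    # scale_to_features inverts that mapping back to the mel-feature range.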
+
+ def encode(self, input_tokens, continuous_inputs, continuous_mask):
+ tokens_mask = input_tokens > 0
+ tokens_encoded, tokens_mask = self.notes_encoder(
+ encoder_input_tokens=input_tokens, encoder_inputs_mask=tokens_mask
+ )
+
+ continuous_encoded, continuous_mask = self.continuous_encoder(
+ encoder_inputs=continuous_inputs, encoder_inputs_mask=continuous_mask
+ )
+
+ return [(tokens_encoded, tokens_mask), (continuous_encoded, continuous_mask)]
+
+ def decode(self, encodings_and_masks, input_tokens, noise_time):
+ timesteps = noise_time
+ if not torch.is_tensor(timesteps):
+ timesteps = torch.tensor([timesteps], dtype=torch.long, device=input_tokens.device)
+ elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(input_tokens.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps * torch.ones(input_tokens.shape[0], dtype=timesteps.dtype, device=timesteps.device)
+
+ logits = self.decoder(
+ encodings_and_masks=encodings_and_masks, decoder_input_tokens=input_tokens, decoder_noise_time=timesteps
+ )
+ return logits
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ input_tokens: List[List[int]],
+ generator: Optional[torch.Generator] = None,
+ num_inference_steps: int = 100,
+ return_dict: bool = True,
+ output_type: str = "numpy",
+        callback: Optional[Callable[[int, np.ndarray], None]] = None,
+ callback_steps: int = 1,
+ ) -> Union[AudioPipelineOutput, Tuple]:
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ pred_mel = np.zeros([1, TARGET_FEATURE_LENGTH, self.n_dims], dtype=np.float32)
+ full_pred_mel = np.zeros([1, 0, self.n_dims], np.float32)
+ ones = torch.ones((1, TARGET_FEATURE_LENGTH), dtype=bool, device=self.device)
+
+ for i, encoder_input_tokens in enumerate(input_tokens):
+ if i == 0:
+ encoder_continuous_inputs = torch.from_numpy(pred_mel[:1].copy()).to(
+ device=self.device, dtype=self.decoder.dtype
+ )
+ # The first chunk has no previous context.
+ encoder_continuous_mask = torch.zeros((1, TARGET_FEATURE_LENGTH), dtype=bool, device=self.device)
+ else:
+ # The full song pipeline does not feed in a context feature, so the mask
+ # will be all 0s after the feature converter. Because we know we're
+ # feeding in a full context chunk from the previous prediction, set it
+ # to all 1s.
+ encoder_continuous_mask = ones
+
+ encoder_continuous_inputs = self.scale_features(
+ encoder_continuous_inputs, output_range=[-1.0, 1.0], clip=True
+ )
+
+ encodings_and_masks = self.encode(
+ input_tokens=torch.IntTensor([encoder_input_tokens]).to(device=self.device),
+ continuous_inputs=encoder_continuous_inputs,
+ continuous_mask=encoder_continuous_mask,
+ )
+
+            # Sample gaussian noise with the shape of encoder_continuous_inputs to begin the denoising loop
+ x = randn_tensor(
+ shape=encoder_continuous_inputs.shape,
+ generator=generator,
+ device=self.device,
+ dtype=self.decoder.dtype,
+ )
+
+ # set step values
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ # Denoising diffusion loop
+ for j, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
+ output = self.decode(
+ encodings_and_masks=encodings_and_masks,
+ input_tokens=x,
+ noise_time=t / self.scheduler.config.num_train_timesteps, # rescale to [0, 1)
+ )
+
+ # Compute previous output: x_t -> x_t-1
+ x = self.scheduler.step(output, t, x, generator=generator).prev_sample
+
+ mel = self.scale_to_features(x, input_range=[-1.0, 1.0])
+ encoder_continuous_inputs = mel[:1]
+ pred_mel = mel.cpu().float().numpy()
+
+ full_pred_mel = np.concatenate([full_pred_mel, pred_mel[:1]], axis=1)
+
+ # call the callback, if provided
+ if callback is not None and i % callback_steps == 0:
+ callback(i, full_pred_mel)
+
+ logger.info("Generated segment", i)
+
+ if output_type == "numpy" and not is_onnx_available():
+ raise ValueError(
+ "Cannot return output in 'np' format if ONNX is not available. Make sure to have ONNX installed or set 'output_type' to 'mel'."
+ )
+ elif output_type == "numpy" and self.melgan is None:
+ raise ValueError(
+ "Cannot return output in 'np' format if melgan component is not defined. Make sure to define `self.melgan` or set 'output_type' to 'mel'."
+ )
+
+ if output_type == "numpy":
+ output = self.melgan(input_features=full_pred_mel.astype(np.float32))
+ else:
+ output = full_pred_mel
+
+ if not return_dict:
+ return (output,)
+
+ return AudioPipelineOutput(audios=output)
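+
+
+# Minimal usage sketch (illustrative, not part of the original source; assumes the
+# `google/music-spectrogram-diffusion` checkpoint plus note-seq and onnxruntime are available):
+#
+#     from diffusers.pipelines.spectrogram_diffusion import MidiProcessor, SpectrogramDiffusionPipeline
+#
+#     pipe = SpectrogramDiffusionPipeline.from_pretrained("google/music-spectrogram-diffusion").to("cuda")
+#     processor = MidiProcessor()
+#     output = pipe(processor("song.mid"))  # hypothetical MIDI path
+#     audio = output.audios[0]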
diff --git a/diffusers/pipelines/stable_diffusion/README.md b/diffusers/pipelines/stable_diffusion/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..66df9a811afbf70a5e943ed1a1e3e7c6955e6c25
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/README.md
@@ -0,0 +1,176 @@
+# Stable Diffusion
+
+## Overview
+
+Stable Diffusion was proposed in [Stable Diffusion Announcement](https://stability.ai/blog/stable-diffusion-announcement) by Patrick Esser and Robin Rombach and the Stability AI team.
+
+The summary of the model is the following:
+
+*Stable Diffusion is a text-to-image model that will empower billions of people to create stunning art within seconds. It is a breakthrough in speed and quality meaning that it can run on consumer GPUs. You can see some of the amazing output that has been created by this model without pre or post-processing on this page. The model itself builds upon the work of the team at CompVis and Runway in their widely used latent diffusion model combined with insights from the conditional diffusion models by our lead generative AI developer Katherine Crowson, Dall-E 2 by Open AI, Imagen by Google Brain and many others. We are delighted that AI media generation is a cooperative field and hope it can continue this way to bring the gift of creativity to all.*
+
+## Tips:
+
+- Stable Diffusion has the same architecture as [Latent Diffusion](https://arxiv.org/abs/2112.10752) but uses a frozen CLIP Text Encoder instead of training the text encoder jointly with the diffusion model.
+- An in-detail explanation of the Stable Diffusion model can be found under [Stable Diffusion with 🧨 Diffusers](https://huggingface.co/blog/stable_diffusion).
+- If you don't want to rely on the Hugging Face Hub and pass an authentication token, you can
+download the weights with `git lfs install; git clone https://huggingface.co/runwayml/stable-diffusion-v1-5` and instead pass the local path of the cloned folder to `from_pretrained`, as shown below.
+- Stable Diffusion can work with a variety of different samplers as is shown below.
+
+## Available Pipelines:
+
+| Pipeline | Tasks | Colab
+|---|---|:---:|
+| [pipeline_stable_diffusion.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py) | *Text-to-Image Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
+| [pipeline_stable_diffusion_img2img](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py) | *Image-to-Image Text-Guided Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
+| [pipeline_stable_diffusion_inpaint](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py) | *Text-Guided Image Inpainting* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
+
+## Examples:
+
+### Using Stable Diffusion without being logged into the Hub
+
+If you want to download the model weights using a single Python line, you need to be logged in via `huggingface-cli login`.
+
+```python
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
+```
+
+This, however, can make it difficult to build applications on top of `diffusers`, as you will always have to pass the token around. A potential way to solve this issue is to download the weights to a local path `"./stable-diffusion-v1-5"`:
+
+```bash
+git lfs install
+git clone https://huggingface.co/runwayml/stable-diffusion-v1-5
+```
+
+and simply passing the local path to `from_pretrained`:
+
+```python
+from diffusers import StableDiffusionPipeline
+
+pipe = StableDiffusionPipeline.from_pretrained("./stable-diffusion-v1-5")
+```
+
+### Text-to-Image with default PLMS scheduler
+
+```python
+# make sure you're logged in with `huggingface-cli login`
+from diffusers import StableDiffusionPipeline
+
+pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
+pipe = pipe.to("cuda")
+
+prompt = "a photo of an astronaut riding a horse on mars"
+image = pipe(prompt).images[0]
+
+image.save("astronaut_rides_horse.png")
+```
+
+### Text-to-Image with DDIM scheduler
+
+```python
+# make sure you're logged in with `huggingface-cli login`
+from diffusers import StableDiffusionPipeline, DDIMScheduler
+
+scheduler = DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
+
+pipe = StableDiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5",
+ scheduler=scheduler,
+).to("cuda")
+
+prompt = "a photo of an astronaut riding a horse on mars"
+image = pipe(prompt).images[0]
+
+image.save("astronaut_rides_horse.png")
+```
+
+### Text-to-Image with K-LMS scheduler
+
+```python
+# make sure you're logged in with `huggingface-cli login`
+from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler
+
+lms = LMSDiscreteScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
+
+pipe = StableDiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5",
+ scheduler=lms,
+).to("cuda")
+
+prompt = "a photo of an astronaut riding a horse on mars"
+image = pipe(prompt).images[0]
+
+image.save("astronaut_rides_horse.png")
+```
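+
+### Image-to-Image with the Img2Img pipeline
+
+The pipelines table above also lists an image-to-image pipeline. The snippet below is a minimal sketch of how it is typically used; the input image URL and the `strength`/`guidance_scale` values are illustrative assumptions rather than part of this document.
+
+```python
+import requests
+from io import BytesIO
+
+from PIL import Image
+
+from diffusers import StableDiffusionImg2ImgPipeline
+
+pipe = StableDiffusionImg2ImgPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to("cuda")
+
+# download an initial image to condition on
+url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
+init_image = Image.open(BytesIO(requests.get(url).content)).convert("RGB").resize((768, 512))
+
+prompt = "A fantasy landscape, trending on artstation"
+image = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images[0]
+image.save("fantasy_landscape.png")
+```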
+
+### CycleDiffusion using Stable Diffusion and DDIM scheduler
+
+```python
+import requests
+import torch
+from PIL import Image
+from io import BytesIO
+
+from diffusers import CycleDiffusionPipeline, DDIMScheduler
+
+
+# load the scheduler; CycleDiffusion only supports stochastic schedulers
+model_id_or_path = "CompVis/stable-diffusion-v1-4"
+scheduler = DDIMScheduler.from_pretrained(model_id_or_path, subfolder="scheduler")
+
+# load the pipeline (make sure you're logged in with `huggingface-cli login`)
+pipe = CycleDiffusionPipeline.from_pretrained(model_id_or_path, scheduler=scheduler).to("cuda")
+
+# let's download an initial image
+url = "https://raw.githubusercontent.com/ChenWu98/cycle-diffusion/main/data/dalle2/An%20astronaut%20riding%20a%20horse.png"
+response = requests.get(url)
+init_image = Image.open(BytesIO(response.content)).convert("RGB")
+init_image = init_image.resize((512, 512))
+init_image.save("horse.png")
+
+# let's specify a prompt
+source_prompt = "An astronaut riding a horse"
+prompt = "An astronaut riding an elephant"
+
+# call the pipeline
+image = pipe(
+ prompt=prompt,
+ source_prompt=source_prompt,
+ image=init_image,
+ num_inference_steps=100,
+ eta=0.1,
+ strength=0.8,
+ guidance_scale=2,
+ source_guidance_scale=1,
+).images[0]
+
+image.save("horse_to_elephant.png")
+
+# let's try another example
+# See more samples at the original repo: https://github.com/ChenWu98/cycle-diffusion
+url = "https://raw.githubusercontent.com/ChenWu98/cycle-diffusion/main/data/dalle2/A%20black%20colored%20car.png"
+response = requests.get(url)
+init_image = Image.open(BytesIO(response.content)).convert("RGB")
+init_image = init_image.resize((512, 512))
+init_image.save("black.png")
+
+source_prompt = "A black colored car"
+prompt = "A blue colored car"
+
+# call the pipeline
+torch.manual_seed(0)
+image = pipe(
+ prompt=prompt,
+ source_prompt=source_prompt,
+ image=init_image,
+ num_inference_steps=100,
+ eta=0.1,
+ strength=0.85,
+ guidance_scale=3,
+ source_guidance_scale=1,
+).images[0]
+
+image.save("black_to_blue.png")
+```
diff --git a/diffusers/pipelines/stable_diffusion/__init__.py b/diffusers/pipelines/stable_diffusion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..33ab05a1dacbdfdfc02966675de4c30cb1069a10
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/__init__.py
@@ -0,0 +1,136 @@
+from dataclasses import dataclass
+from typing import List, Optional, Union
+
+import numpy as np
+import PIL
+from PIL import Image
+
+from ...utils import (
+ BaseOutput,
+ OptionalDependencyNotAvailable,
+ is_flax_available,
+ is_k_diffusion_available,
+ is_k_diffusion_version,
+ is_onnx_available,
+ is_torch_available,
+ is_transformers_available,
+ is_transformers_version,
+)
+
+
+@dataclass
+class StableDiffusionPipelineOutput(BaseOutput):
+ """
+ Output class for Stable Diffusion pipelines.
+
+ Args:
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+        num_channels)`. PIL images or numpy array represent the denoised images of the diffusion pipeline.
+ nsfw_content_detected (`List[bool]`)
+ List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, or `None` if safety checking could not be performed.
+ """
+
+ images: Union[List[PIL.Image.Image], np.ndarray]
+ nsfw_content_detected: Optional[List[bool]]
+
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
+else:
+ from .pipeline_cycle_diffusion import CycleDiffusionPipeline
+ from .pipeline_stable_diffusion import StableDiffusionPipeline
+ from .pipeline_stable_diffusion_attend_and_excite import StableDiffusionAttendAndExcitePipeline
+ from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
+ from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
+ from .pipeline_stable_diffusion_inpaint_legacy import StableDiffusionInpaintPipelineLegacy
+ from .pipeline_stable_diffusion_instruct_pix2pix import StableDiffusionInstructPix2PixPipeline
+ from .pipeline_stable_diffusion_latent_upscale import StableDiffusionLatentUpscalePipeline
+ from .pipeline_stable_diffusion_ldm3d import StableDiffusionLDM3DPipeline
+ from .pipeline_stable_diffusion_model_editing import StableDiffusionModelEditingPipeline
+ from .pipeline_stable_diffusion_panorama import StableDiffusionPanoramaPipeline
+ from .pipeline_stable_diffusion_paradigms import StableDiffusionParadigmsPipeline
+ from .pipeline_stable_diffusion_sag import StableDiffusionSAGPipeline
+ from .pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline
+ from .pipeline_stable_unclip import StableUnCLIPPipeline
+ from .pipeline_stable_unclip_img2img import StableUnCLIPImg2ImgPipeline
+ from .safety_checker import StableDiffusionSafetyChecker
+ from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
+
+try:
+ if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import StableDiffusionImageVariationPipeline
+else:
+ from .pipeline_stable_diffusion_image_variation import StableDiffusionImageVariationPipeline
+
+
+try:
+ if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.26.0")):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import (
+ StableDiffusionDepth2ImgPipeline,
+ StableDiffusionDiffEditPipeline,
+ StableDiffusionPix2PixZeroPipeline,
+ )
+else:
+ from .pipeline_stable_diffusion_depth2img import StableDiffusionDepth2ImgPipeline
+ from .pipeline_stable_diffusion_diffedit import StableDiffusionDiffEditPipeline
+ from .pipeline_stable_diffusion_pix2pix_zero import StableDiffusionPix2PixZeroPipeline
+
+
+try:
+ if not (
+ is_torch_available()
+ and is_transformers_available()
+ and is_k_diffusion_available()
+ and is_k_diffusion_version(">=", "0.0.12")
+ ):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403
+else:
+ from .pipeline_stable_diffusion_k_diffusion import StableDiffusionKDiffusionPipeline
+
+try:
+ if not (is_transformers_available() and is_onnx_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils.dummy_onnx_objects import * # noqa F403
+else:
+ from .pipeline_onnx_stable_diffusion import OnnxStableDiffusionPipeline, StableDiffusionOnnxPipeline
+ from .pipeline_onnx_stable_diffusion_img2img import OnnxStableDiffusionImg2ImgPipeline
+ from .pipeline_onnx_stable_diffusion_inpaint import OnnxStableDiffusionInpaintPipeline
+ from .pipeline_onnx_stable_diffusion_inpaint_legacy import OnnxStableDiffusionInpaintPipelineLegacy
+ from .pipeline_onnx_stable_diffusion_upscale import OnnxStableDiffusionUpscalePipeline
+
+if is_transformers_available() and is_flax_available():
+ import flax
+
+ @flax.struct.dataclass
+ class FlaxStableDiffusionPipelineOutput(BaseOutput):
+ """
+ Output class for Stable Diffusion pipelines.
+
+ Args:
+            images (`np.ndarray`):
+                Array of shape `(batch_size, height, width, num_channels)` with images from the diffusion pipeline.
+            nsfw_content_detected (`List[bool]`):
+                List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
+                (nsfw) content.
+ """
+
+ images: np.ndarray
+ nsfw_content_detected: List[bool]
+
+ from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState
+ from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline
+ from .pipeline_flax_stable_diffusion_img2img import FlaxStableDiffusionImg2ImgPipeline
+ from .pipeline_flax_stable_diffusion_inpaint import FlaxStableDiffusionInpaintPipeline
+ from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
diff --git a/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
new file mode 100644
index 0000000000000000000000000000000000000000..0eeb80f12dfc9a10f0479ce711ee85869d839dd5
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
@@ -0,0 +1,1636 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Conversion script for the Stable Diffusion checkpoints."""
+
+import re
+from contextlib import nullcontext
+from io import BytesIO
+from typing import Optional
+
+import requests
+import torch
+from transformers import (
+ AutoFeatureExtractor,
+ BertTokenizerFast,
+ CLIPImageProcessor,
+ CLIPTextConfig,
+ CLIPTextModel,
+ CLIPTextModelWithProjection,
+ CLIPTokenizer,
+ CLIPVisionConfig,
+ CLIPVisionModelWithProjection,
+)
+
+from ...models import (
+ AutoencoderKL,
+ ControlNetModel,
+ PriorTransformer,
+ UNet2DConditionModel,
+)
+from ...schedulers import (
+ DDIMScheduler,
+ DDPMScheduler,
+ DPMSolverMultistepScheduler,
+ EulerAncestralDiscreteScheduler,
+ EulerDiscreteScheduler,
+ HeunDiscreteScheduler,
+ LMSDiscreteScheduler,
+ PNDMScheduler,
+ UnCLIPScheduler,
+)
+from ...utils import is_accelerate_available, is_omegaconf_available, is_safetensors_available, logging
+from ...utils.import_utils import BACKENDS_MAPPING
+from ..latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
+from ..paint_by_example import PaintByExampleImageEncoder
+from ..pipeline_utils import DiffusionPipeline
+from .safety_checker import StableDiffusionSafetyChecker
+from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
+
+
+if is_accelerate_available():
+ from accelerate import init_empty_weights
+ from accelerate.utils import set_module_tensor_to_device
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def shave_segments(path, n_shave_prefix_segments=1):
+ """
+    Removes path segments. A positive n_shave_prefix_segments drops that many leading segments; a negative value
+    drops trailing segments.
+ """
+ if n_shave_prefix_segments >= 0:
+ return ".".join(path.split(".")[n_shave_prefix_segments:])
+ else:
+ return ".".join(path.split(".")[:n_shave_prefix_segments])
+
+
+def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
+ """
+ Updates paths inside resnets to the new naming scheme (local renaming)
+ """
+ mapping = []
+ for old_item in old_list:
+ new_item = old_item.replace("in_layers.0", "norm1")
+ new_item = new_item.replace("in_layers.2", "conv1")
+
+ new_item = new_item.replace("out_layers.0", "norm2")
+ new_item = new_item.replace("out_layers.3", "conv2")
+
+ new_item = new_item.replace("emb_layers.1", "time_emb_proj")
+ new_item = new_item.replace("skip_connection", "conv_shortcut")
+
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
+
+ mapping.append({"old": old_item, "new": new_item})
+
+ return mapping
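+# e.g. "input_blocks.1.0.in_layers.0.weight" is mapped to "input_blocks.1.0.norm1.weight" here; the block-level
+# renaming (input_blocks.* -> down_blocks.*) is applied later via the meta paths passed to assign_to_checkpoint.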
+
+
+def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
+ """
+ Updates paths inside resnets to the new naming scheme (local renaming)
+ """
+ mapping = []
+ for old_item in old_list:
+ new_item = old_item
+
+ new_item = new_item.replace("nin_shortcut", "conv_shortcut")
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
+
+ mapping.append({"old": old_item, "new": new_item})
+
+ return mapping
+
+
+def renew_attention_paths(old_list, n_shave_prefix_segments=0):
+ """
+ Updates paths inside attentions to the new naming scheme (local renaming)
+ """
+ mapping = []
+ for old_item in old_list:
+ new_item = old_item
+
+ # new_item = new_item.replace('norm.weight', 'group_norm.weight')
+ # new_item = new_item.replace('norm.bias', 'group_norm.bias')
+
+ # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
+ # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
+
+ # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
+
+ mapping.append({"old": old_item, "new": new_item})
+
+ return mapping
+
+
+def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
+ """
+ Updates paths inside attentions to the new naming scheme (local renaming)
+ """
+ mapping = []
+ for old_item in old_list:
+ new_item = old_item
+
+ new_item = new_item.replace("norm.weight", "group_norm.weight")
+ new_item = new_item.replace("norm.bias", "group_norm.bias")
+
+ new_item = new_item.replace("q.weight", "to_q.weight")
+ new_item = new_item.replace("q.bias", "to_q.bias")
+
+ new_item = new_item.replace("k.weight", "to_k.weight")
+ new_item = new_item.replace("k.bias", "to_k.bias")
+
+ new_item = new_item.replace("v.weight", "to_v.weight")
+ new_item = new_item.replace("v.bias", "to_v.bias")
+
+ new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
+ new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
+
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
+
+ mapping.append({"old": old_item, "new": new_item})
+
+ return mapping
+
+
+def assign_to_checkpoint(
+ paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
+):
+ """
+    This performs the final conversion step: it takes the locally converted weights, applies the global renaming,
+    splits attention layers where requested, and applies any additional replacements.
+
+    The resulting weights are assigned to the new checkpoint.
+ """
+ assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
+
+ # Splits the attention layers into three variables.
+ if attention_paths_to_split is not None:
+ for path, path_map in attention_paths_to_split.items():
+ old_tensor = old_checkpoint[path]
+ channels = old_tensor.shape[0] // 3
+
+ target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
+
+ num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
+
+ old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
+ query, key, value = old_tensor.split(channels // num_heads, dim=1)
+
+ checkpoint[path_map["query"]] = query.reshape(target_shape)
+ checkpoint[path_map["key"]] = key.reshape(target_shape)
+ checkpoint[path_map["value"]] = value.reshape(target_shape)
+
+ for path in paths:
+ new_path = path["new"]
+
+ # These have already been assigned
+ if attention_paths_to_split is not None and new_path in attention_paths_to_split:
+ continue
+
+ # Global renaming happens here
+ new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
+ new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
+ new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
+
+ if additional_replacements is not None:
+ for replacement in additional_replacements:
+ new_path = new_path.replace(replacement["old"], replacement["new"])
+
+ # proj_attn.weight has to be converted from conv 1D to linear
+ is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path)
+ shape = old_checkpoint[path["old"]].shape
+ if is_attn_weight and len(shape) == 3:
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
+ elif is_attn_weight and len(shape) == 4:
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0]
+ else:
+ checkpoint[new_path] = old_checkpoint[path["old"]]
+
+
+def conv_attn_to_linear(checkpoint):
+ keys = list(checkpoint.keys())
+ attn_keys = ["query.weight", "key.weight", "value.weight"]
+ for key in keys:
+ if ".".join(key.split(".")[-2:]) in attn_keys:
+ if checkpoint[key].ndim > 2:
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
+ elif "proj_attn.weight" in key:
+ if checkpoint[key].ndim > 2:
+ checkpoint[key] = checkpoint[key][:, :, 0]
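+# The q/k/v projections in the LDM VAE attention blocks are stored as 1x1 convolutions (weight shape (C, C, 1, 1));
+# conv_attn_to_linear drops the trailing spatial dims so they can be loaded as (C, C) linear weights.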
+
+
+def create_unet_diffusers_config(original_config, image_size: int, controlnet=False):
+ """
+    Creates a diffusers UNet (or ControlNet) config from the config of the LDM model.
+ """
+ if controlnet:
+ unet_params = original_config.model.params.control_stage_config.params
+ else:
+ if "unet_config" in original_config.model.params and original_config.model.params.unet_config is not None:
+ unet_params = original_config.model.params.unet_config.params
+ else:
+ unet_params = original_config.model.params.network_config.params
+
+ vae_params = original_config.model.params.first_stage_config.params.ddconfig
+
+ block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
+
+ down_block_types = []
+ resolution = 1
+ for i in range(len(block_out_channels)):
+ block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
+ down_block_types.append(block_type)
+ if i != len(block_out_channels) - 1:
+ resolution *= 2
+
+ up_block_types = []
+ for i in range(len(block_out_channels)):
+ block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
+ up_block_types.append(block_type)
+ resolution //= 2
+
+ if unet_params.transformer_depth is not None:
+ transformer_layers_per_block = (
+ unet_params.transformer_depth
+ if isinstance(unet_params.transformer_depth, int)
+ else list(unet_params.transformer_depth)
+ )
+ else:
+ transformer_layers_per_block = 1
+
+ vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
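+    # e.g. the SD v1/v2 VAE uses ch_mult = [1, 2, 4, 4], giving vae_scale_factor = 2**3 = 8
+    # (a 512x512 image maps to a 64x64 latent).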
+
+ head_dim = unet_params.num_heads if "num_heads" in unet_params else None
+ use_linear_projection = (
+ unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
+ )
+ if use_linear_projection:
+ # stable diffusion 2-base-512 and 2-768
+ if head_dim is None:
+ head_dim_mult = unet_params.model_channels // unet_params.num_head_channels
+ head_dim = [head_dim_mult * c for c in list(unet_params.channel_mult)]
+
+ class_embed_type = None
+ addition_embed_type = None
+ addition_time_embed_dim = None
+ projection_class_embeddings_input_dim = None
+ context_dim = None
+
+ if unet_params.context_dim is not None:
+ context_dim = (
+ unet_params.context_dim if isinstance(unet_params.context_dim, int) else unet_params.context_dim[0]
+ )
+
+ if "num_classes" in unet_params:
+ if unet_params.num_classes == "sequential":
+ if context_dim in [2048, 1280]:
+ # SDXL
+ addition_embed_type = "text_time"
+ addition_time_embed_dim = 256
+ else:
+ class_embed_type = "projection"
+ assert "adm_in_channels" in unet_params
+ projection_class_embeddings_input_dim = unet_params.adm_in_channels
+ else:
+ raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}")
+
+ config = {
+ "sample_size": image_size // vae_scale_factor,
+ "in_channels": unet_params.in_channels,
+ "down_block_types": tuple(down_block_types),
+ "block_out_channels": tuple(block_out_channels),
+ "layers_per_block": unet_params.num_res_blocks,
+ "cross_attention_dim": context_dim,
+ "attention_head_dim": head_dim,
+ "use_linear_projection": use_linear_projection,
+ "class_embed_type": class_embed_type,
+ "addition_embed_type": addition_embed_type,
+ "addition_time_embed_dim": addition_time_embed_dim,
+ "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
+ "transformer_layers_per_block": transformer_layers_per_block,
+ }
+
+ if controlnet:
+ config["conditioning_channels"] = unet_params.hint_channels
+ else:
+ config["out_channels"] = unet_params.out_channels
+ config["up_block_types"] = tuple(up_block_types)
+
+ return config
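+# Minimal usage sketch (hypothetical variable names; assumes the original LDM yaml, e.g. v1-inference.yaml,
+# has been loaded with OmegaConf):
+#     original_config = OmegaConf.load("v1-inference.yaml")
+#     unet_config = create_unet_diffusers_config(original_config, image_size=512)
+#     unet = UNet2DConditionModel(**unet_config)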
+
+
+def create_vae_diffusers_config(original_config, image_size: int):
+ """
+    Creates a diffusers VAE config from the config of the LDM model.
+ """
+ vae_params = original_config.model.params.first_stage_config.params.ddconfig
+ _ = original_config.model.params.first_stage_config.params.embed_dim
+
+ block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
+ down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
+ up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
+
+ config = {
+ "sample_size": image_size,
+ "in_channels": vae_params.in_channels,
+ "out_channels": vae_params.out_ch,
+ "down_block_types": tuple(down_block_types),
+ "up_block_types": tuple(up_block_types),
+ "block_out_channels": tuple(block_out_channels),
+ "latent_channels": vae_params.z_channels,
+ "layers_per_block": vae_params.num_res_blocks,
+ }
+ return config
+
+
+def create_diffusers_schedular(original_config):
+    scheduler = DDIMScheduler(
+        num_train_timesteps=original_config.model.params.timesteps,
+        beta_start=original_config.model.params.linear_start,
+        beta_end=original_config.model.params.linear_end,
+        beta_schedule="scaled_linear",
+    )
+    return scheduler
+
+
+def create_ldm_bert_config(original_config):
+    bert_params = original_config.model.params.cond_stage_config.params
+ config = LDMBertConfig(
+ d_model=bert_params.n_embed,
+ encoder_layers=bert_params.n_layer,
+ encoder_ffn_dim=bert_params.n_embed * 4,
+ )
+ return config
+
+
+def convert_ldm_unet_checkpoint(
+ checkpoint, config, path=None, extract_ema=False, controlnet=False, skip_extract_state_dict=False
+):
+ """
+ Takes a state dict and a config, and returns a converted checkpoint.
+ """
+
+ if skip_extract_state_dict:
+ unet_state_dict = checkpoint
+ else:
+ # extract state_dict for UNet
+ unet_state_dict = {}
+ keys = list(checkpoint.keys())
+
+ if controlnet:
+ unet_key = "control_model."
+ else:
+ unet_key = "model.diffusion_model."
+
+        # at least 100 parameters have to start with `model_ema` for the checkpoint to be treated as EMA
+ if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
+ logger.warning(f"Checkpoint {path} has both EMA and non-EMA weights.")
+ logger.warning(
+ "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
+ " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
+ )
+ for key in keys:
+ if key.startswith("model.diffusion_model"):
+ flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
+ else:
+ if sum(k.startswith("model_ema") for k in keys) > 100:
+ logger.warning(
+ "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
+ " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
+ )
+
+ for key in keys:
+ if key.startswith(unet_key):
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
+
+ new_checkpoint = {}
+
+ new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
+ new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
+ new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
+ new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
+
+ if config["class_embed_type"] is None:
+ # No parameters to port
+ ...
+ elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
+ new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
+ new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
+ new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
+ new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
+ else:
+ raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")
+
+ if config["addition_embed_type"] == "text_time":
+ new_checkpoint["add_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
+ new_checkpoint["add_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
+ new_checkpoint["add_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
+ new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
+
+ new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
+ new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
+
+ if not controlnet:
+ new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
+ new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
+ new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
+ new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
+
+ # Retrieves the keys for the input blocks only
+ num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
+ input_blocks = {
+ layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
+ for layer_id in range(num_input_blocks)
+ }
+
+ # Retrieves the keys for the middle blocks only
+ num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
+ middle_blocks = {
+ layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
+ for layer_id in range(num_middle_blocks)
+ }
+
+ # Retrieves the keys for the output blocks only
+ num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
+ output_blocks = {
+ layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
+ for layer_id in range(num_output_blocks)
+ }
+
+ for i in range(1, num_input_blocks):
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
+ layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
+
+ resnets = [
+ key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
+ ]
+ attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
+
+ if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
+ f"input_blocks.{i}.0.op.weight"
+ )
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
+ f"input_blocks.{i}.0.op.bias"
+ )
+
+ paths = renew_resnet_paths(resnets)
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ if len(attentions):
+ paths = renew_attention_paths(attentions)
+ meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ resnet_0 = middle_blocks[0]
+ attentions = middle_blocks[1]
+ resnet_1 = middle_blocks[2]
+
+ resnet_0_paths = renew_resnet_paths(resnet_0)
+ assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
+
+ resnet_1_paths = renew_resnet_paths(resnet_1)
+ assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
+
+ attentions_paths = renew_attention_paths(attentions)
+ meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
+ assign_to_checkpoint(
+ attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ for i in range(num_output_blocks):
+ block_id = i // (config["layers_per_block"] + 1)
+ layer_in_block_id = i % (config["layers_per_block"] + 1)
+ output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
+ output_block_list = {}
+
+ for layer in output_block_layers:
+ layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
+ if layer_id in output_block_list:
+ output_block_list[layer_id].append(layer_name)
+ else:
+ output_block_list[layer_id] = [layer_name]
+
+ if len(output_block_list) > 1:
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
+ attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
+
+ resnet_0_paths = renew_resnet_paths(resnets)
+ paths = renew_resnet_paths(resnets)
+
+ meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
+ if ["conv.bias", "conv.weight"] in output_block_list.values():
+ index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
+ f"output_blocks.{i}.{index}.conv.weight"
+ ]
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
+ f"output_blocks.{i}.{index}.conv.bias"
+ ]
+
+ # Clear attentions as they have been attributed above.
+ if len(attentions) == 2:
+ attentions = []
+
+ if len(attentions):
+ paths = renew_attention_paths(attentions)
+ meta_path = {
+ "old": f"output_blocks.{i}.1",
+ "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
+ }
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+ else:
+ resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
+ for path in resnet_0_paths:
+ old_path = ".".join(["output_blocks", str(i), path["old"]])
+ new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
+
+ new_checkpoint[new_path] = unet_state_dict[old_path]
+
+ if controlnet:
+ # conditioning embedding
+
+ orig_index = 0
+
+ new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop(
+ f"input_hint_block.{orig_index}.weight"
+ )
+ new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop(
+ f"input_hint_block.{orig_index}.bias"
+ )
+
+ orig_index += 2
+
+ diffusers_index = 0
+
+ while diffusers_index < 6:
+ new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop(
+ f"input_hint_block.{orig_index}.weight"
+ )
+ new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop(
+ f"input_hint_block.{orig_index}.bias"
+ )
+ diffusers_index += 1
+ orig_index += 2
+
+ new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop(
+ f"input_hint_block.{orig_index}.weight"
+ )
+ new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop(
+ f"input_hint_block.{orig_index}.bias"
+ )
+
+ # down blocks
+ for i in range(num_input_blocks):
+ new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight")
+ new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias")
+
+ # mid block
+ new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight")
+ new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias")
+
+ return new_checkpoint
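+# Minimal usage sketch (hypothetical variable names; `checkpoint` is the raw .ckpt state dict and `unet_config`
+# comes from create_unet_diffusers_config above):
+#     converted_unet = convert_ldm_unet_checkpoint(checkpoint, unet_config, path=ckpt_path, extract_ema=False)
+#     unet = UNet2DConditionModel(**unet_config)
+#     unet.load_state_dict(converted_unet)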
+
+
+def convert_ldm_vae_checkpoint(checkpoint, config):
+ # extract state dict for VAE
+ vae_state_dict = {}
+ keys = list(checkpoint.keys())
+ vae_key = "first_stage_model." if any(k.startswith("first_stage_model.") for k in keys) else ""
+ for key in keys:
+ if key.startswith(vae_key):
+ vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
+
+ new_checkpoint = {}
+
+ new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
+ new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
+ new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
+ new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
+ new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
+ new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
+
+ new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
+ new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
+ new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
+ new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
+ new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
+ new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
+
+ new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
+ new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
+ new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
+ new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
+
+ # Retrieves the keys for the encoder down blocks only
+ num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
+ down_blocks = {
+ layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
+ }
+
+ # Retrieves the keys for the decoder up blocks only
+ num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
+ up_blocks = {
+ layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
+ }
+
+ for i in range(num_down_blocks):
+ resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
+
+ if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
+ f"encoder.down.{i}.downsample.conv.weight"
+ )
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
+ f"encoder.down.{i}.downsample.conv.bias"
+ )
+
+ paths = renew_vae_resnet_paths(resnets)
+ meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+ mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
+ num_mid_res_blocks = 2
+ for i in range(1, num_mid_res_blocks + 1):
+ resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
+
+ paths = renew_vae_resnet_paths(resnets)
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+ mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
+ paths = renew_vae_attention_paths(mid_attentions)
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+ conv_attn_to_linear(new_checkpoint)
+
+ for i in range(num_up_blocks):
+ block_id = num_up_blocks - 1 - i
+ resnets = [
+ key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
+ ]
+
+ if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
+ f"decoder.up.{block_id}.upsample.conv.weight"
+ ]
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
+ f"decoder.up.{block_id}.upsample.conv.bias"
+ ]
+
+ paths = renew_vae_resnet_paths(resnets)
+ meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+ mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
+ num_mid_res_blocks = 2
+ for i in range(1, num_mid_res_blocks + 1):
+ resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
+
+ paths = renew_vae_resnet_paths(resnets)
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+ mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
+ paths = renew_vae_attention_paths(mid_attentions)
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+ conv_attn_to_linear(new_checkpoint)
+ return new_checkpoint
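+# Minimal usage sketch (hypothetical variable names; `vae_config` comes from create_vae_diffusers_config above):
+#     converted_vae = convert_ldm_vae_checkpoint(checkpoint, vae_config)
+#     vae = AutoencoderKL(**vae_config)
+#     vae.load_state_dict(converted_vae)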
+
+
+def convert_ldm_bert_checkpoint(checkpoint, config):
+ def _copy_attn_layer(hf_attn_layer, pt_attn_layer):
+ hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight
+ hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight
+ hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight
+
+ hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight
+ hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias
+
+ def _copy_linear(hf_linear, pt_linear):
+ hf_linear.weight = pt_linear.weight
+ hf_linear.bias = pt_linear.bias
+
+ def _copy_layer(hf_layer, pt_layer):
+ # copy layer norms
+ _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0])
+ _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0])
+
+ # copy attn
+ _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1])
+
+ # copy MLP
+ pt_mlp = pt_layer[1][1]
+ _copy_linear(hf_layer.fc1, pt_mlp.net[0][0])
+ _copy_linear(hf_layer.fc2, pt_mlp.net[2])
+
+ def _copy_layers(hf_layers, pt_layers):
+ for i, hf_layer in enumerate(hf_layers):
+ if i != 0:
+ i += i
+ pt_layer = pt_layers[i : i + 2]
+ _copy_layer(hf_layer, pt_layer)
+
+ hf_model = LDMBertModel(config).eval()
+
+ # copy embeds
+ hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight
+ hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight
+
+ # copy layer norm
+ _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm)
+
+ # copy hidden layers
+ _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers)
+
+ _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits)
+
+ return hf_model
+
+
+def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder=None):
+ if text_encoder is None:
+ config_name = "openai/clip-vit-large-patch14"
+ config = CLIPTextConfig.from_pretrained(config_name)
+
+ ctx = init_empty_weights if is_accelerate_available() else nullcontext
+ with ctx():
+ text_model = CLIPTextModel(config)
+
+ keys = list(checkpoint.keys())
+
+ text_model_dict = {}
+
+ remove_prefixes = ["cond_stage_model.transformer", "conditioner.embedders.0.transformer"]
+
+ for key in keys:
+ for prefix in remove_prefixes:
+ if key.startswith(prefix):
+ text_model_dict[key[len(prefix + ".") :]] = checkpoint[key]
+
+ if is_accelerate_available():
+ for param_name, param in text_model_dict.items():
+ set_module_tensor_to_device(text_model, param_name, "cpu", value=param)
+ else:
+ text_model.load_state_dict(text_model_dict)
+
+ return text_model
+
+
+textenc_conversion_lst = [
+ ("positional_embedding", "text_model.embeddings.position_embedding.weight"),
+ ("token_embedding.weight", "text_model.embeddings.token_embedding.weight"),
+ ("ln_final.weight", "text_model.final_layer_norm.weight"),
+ ("ln_final.bias", "text_model.final_layer_norm.bias"),
+ ("text_projection", "text_projection.weight"),
+]
+textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst}
+
+textenc_transformer_conversion_lst = [
+ # (stable-diffusion, HF Diffusers)
+ ("resblocks.", "text_model.encoder.layers."),
+ ("ln_1", "layer_norm1"),
+ ("ln_2", "layer_norm2"),
+ (".c_fc.", ".fc1."),
+ (".c_proj.", ".fc2."),
+ (".attn", ".self_attn"),
+ ("ln_final.", "transformer.text_model.final_layer_norm."),
+ ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
+ ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
+]
+protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst}
+textenc_pattern = re.compile("|".join(protected.keys()))
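+# e.g. after the "cond_stage_model.model.transformer." prefix is stripped, the OpenCLIP key
+# "resblocks.0.ln_1.weight" is rewritten by textenc_pattern to "text_model.encoder.layers.0.layer_norm1.weight".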
+
+
+def convert_paint_by_example_checkpoint(checkpoint):
+ config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14")
+ model = PaintByExampleImageEncoder(config)
+
+ keys = list(checkpoint.keys())
+
+ text_model_dict = {}
+
+ for key in keys:
+ if key.startswith("cond_stage_model.transformer"):
+ text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
+
+ # load clip vision
+ model.model.load_state_dict(text_model_dict)
+
+ # load mapper
+ keys_mapper = {
+ k[len("cond_stage_model.mapper.res") :]: v
+ for k, v in checkpoint.items()
+ if k.startswith("cond_stage_model.mapper")
+ }
+
+ MAPPING = {
+ "attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"],
+ "attn.c_proj": ["attn1.to_out.0"],
+ "ln_1": ["norm1"],
+ "ln_2": ["norm3"],
+ "mlp.c_fc": ["ff.net.0.proj"],
+ "mlp.c_proj": ["ff.net.2"],
+ }
+
+ mapped_weights = {}
+ for key, value in keys_mapper.items():
+ prefix = key[: len("blocks.i")]
+ suffix = key.split(prefix)[-1].split(".")[-1]
+ name = key.split(prefix)[-1].split(suffix)[0][1:-1]
+ mapped_names = MAPPING[name]
+
+ num_splits = len(mapped_names)
+ for i, mapped_name in enumerate(mapped_names):
+ new_name = ".".join([prefix, mapped_name, suffix])
+ shape = value.shape[0] // num_splits
+ mapped_weights[new_name] = value[i * shape : (i + 1) * shape]
+
+ model.mapper.load_state_dict(mapped_weights)
+
+ # load final layer norm
+ model.final_layer_norm.load_state_dict(
+ {
+ "bias": checkpoint["cond_stage_model.final_ln.bias"],
+ "weight": checkpoint["cond_stage_model.final_ln.weight"],
+ }
+ )
+
+ # load final proj
+ model.proj_out.load_state_dict(
+ {
+ "bias": checkpoint["proj_out.bias"],
+ "weight": checkpoint["proj_out.weight"],
+ }
+ )
+
+ # load uncond vector
+ model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"])
+ return model
+
+
+def convert_open_clip_checkpoint(
+ checkpoint, config_name, prefix="cond_stage_model.model.", has_projection=False, **config_kwargs
+):
+ # text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
+ # text_model = CLIPTextModelWithProjection.from_pretrained(
+ # "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", projection_dim=1280
+ # )
+ config = CLIPTextConfig.from_pretrained(config_name, **config_kwargs)
+
+ ctx = init_empty_weights if is_accelerate_available() else nullcontext
+ with ctx():
+ text_model = CLIPTextModelWithProjection(config) if has_projection else CLIPTextModel(config)
+
+ keys = list(checkpoint.keys())
+
+ keys_to_ignore = []
+ if config_name == "stabilityai/stable-diffusion-2" and config.num_hidden_layers == 23:
+ # make sure to remove all keys > 22
+ keys_to_ignore += [k for k in keys if k.startswith("cond_stage_model.model.transformer.resblocks.23")]
+ keys_to_ignore += ["cond_stage_model.model.text_projection"]
+
+ text_model_dict = {}
+
+ if prefix + "text_projection" in checkpoint:
+ d_model = int(checkpoint[prefix + "text_projection"].shape[0])
+ else:
+ d_model = 1024
+
+ text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")
+
+ for key in keys:
+ if key in keys_to_ignore:
+ continue
+ if key[len(prefix) :] in textenc_conversion_map:
+ if key.endswith("text_projection"):
+ value = checkpoint[key].T
+ else:
+ value = checkpoint[key]
+
+ text_model_dict[textenc_conversion_map[key[len(prefix) :]]] = value
+
+ if key.startswith(prefix + "transformer."):
+ new_key = key[len(prefix + "transformer.") :]
+ if new_key.endswith(".in_proj_weight"):
+ new_key = new_key[: -len(".in_proj_weight")]
+ new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
+ text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :]
+ text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :]
+ text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :]
+ elif new_key.endswith(".in_proj_bias"):
+ new_key = new_key[: -len(".in_proj_bias")]
+ new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
+ text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model]
+ text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2]
+ text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :]
+ else:
+ new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
+
+ text_model_dict[new_key] = checkpoint[key]
+
+ if is_accelerate_available():
+ for param_name, param in text_model_dict.items():
+ set_module_tensor_to_device(text_model, param_name, "cpu", value=param)
+ else:
+ text_model.load_state_dict(text_model_dict)
+
+ return text_model
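+
+# Usage sketch for the converter above (the state dict and its SD2.x key layout are
+# assumed; this mirrors how the function is called further below for SD2 checkpoints):
+#
+#     text_encoder = convert_open_clip_checkpoint(
+#         state_dict, "stabilityai/stable-diffusion-2", subfolder="text_encoder"
+#     )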
+
+
+def stable_unclip_image_encoder(original_config):
+ """
+ Returns the image processor and clip image encoder for the img2img unclip pipeline.
+
+ We currently know of two types of stable unclip models which separately use the clip and the openclip image
+ encoders.
+ """
+
+ image_embedder_config = original_config.model.params.embedder_config
+
+ sd_clip_image_embedder_class = image_embedder_config.target
+ sd_clip_image_embedder_class = sd_clip_image_embedder_class.split(".")[-1]
+
+ if sd_clip_image_embedder_class == "ClipImageEmbedder":
+ clip_model_name = image_embedder_config.params.model
+
+ if clip_model_name == "ViT-L/14":
+ feature_extractor = CLIPImageProcessor()
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
+ else:
+ raise NotImplementedError(f"Unknown CLIP checkpoint name in stable diffusion checkpoint {clip_model_name}")
+
+ elif sd_clip_image_embedder_class == "FrozenOpenCLIPImageEmbedder":
+ feature_extractor = CLIPImageProcessor()
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
+ else:
+ raise NotImplementedError(
+ f"Unknown CLIP image embedder class in stable diffusion checkpoint {sd_clip_image_embedder_class}"
+ )
+
+ return feature_extractor, image_encoder
+
+
+def stable_unclip_image_noising_components(
+ original_config, clip_stats_path: Optional[str] = None, device: Optional[str] = None
+):
+ """
+ Returns the noising components for the img2img and txt2img unclip pipelines.
+
+ Converts the stability noise augmentor into
+ 1. a `StableUnCLIPImageNormalizer` for holding the CLIP stats
+ 2. a `DDPMScheduler` for holding the noise schedule
+
+ If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided.
+ """
+ noise_aug_config = original_config.model.params.noise_aug_config
+ noise_aug_class = noise_aug_config.target
+ noise_aug_class = noise_aug_class.split(".")[-1]
+
+ if noise_aug_class == "CLIPEmbeddingNoiseAugmentation":
+ noise_aug_config = noise_aug_config.params
+ embedding_dim = noise_aug_config.timestep_dim
+ max_noise_level = noise_aug_config.noise_schedule_config.timesteps
+ beta_schedule = noise_aug_config.noise_schedule_config.beta_schedule
+
+ image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedding_dim)
+ image_noising_scheduler = DDPMScheduler(num_train_timesteps=max_noise_level, beta_schedule=beta_schedule)
+
+ if "clip_stats_path" in noise_aug_config:
+ if clip_stats_path is None:
+ raise ValueError("This stable unclip config requires a `clip_stats_path`")
+
+ clip_mean, clip_std = torch.load(clip_stats_path, map_location=device)
+ clip_mean = clip_mean[None, :]
+ clip_std = clip_std[None, :]
+
+ clip_stats_state_dict = {
+ "mean": clip_mean,
+ "std": clip_std,
+ }
+
+ image_normalizer.load_state_dict(clip_stats_state_dict)
+ else:
+ raise NotImplementedError(f"Unknown noise augmentor class: {noise_aug_class}")
+
+ return image_normalizer, image_noising_scheduler
+
+
+def convert_controlnet_checkpoint(
+ checkpoint,
+ original_config,
+ checkpoint_path,
+ image_size,
+ upcast_attention,
+ extract_ema,
+ use_linear_projection=None,
+ cross_attention_dim=None,
+):
+ ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True)
+ ctrlnet_config["upcast_attention"] = upcast_attention
+
+ ctrlnet_config.pop("sample_size")
+
+ if use_linear_projection is not None:
+ ctrlnet_config["use_linear_projection"] = use_linear_projection
+
+ if cross_attention_dim is not None:
+ ctrlnet_config["cross_attention_dim"] = cross_attention_dim
+
+ controlnet = ControlNetModel(**ctrlnet_config)
+
+    # Some controlnet ckpt files are distributed independently from the rest of the
+    # model components, e.g. https://huggingface.co/thibaud/controlnet-sd21/
+ if "time_embed.0.weight" in checkpoint:
+ skip_extract_state_dict = True
+ else:
+ skip_extract_state_dict = False
+
+ converted_ctrl_checkpoint = convert_ldm_unet_checkpoint(
+ checkpoint,
+ ctrlnet_config,
+ path=checkpoint_path,
+ extract_ema=extract_ema,
+ controlnet=True,
+ skip_extract_state_dict=skip_extract_state_dict,
+ )
+
+ controlnet.load_state_dict(converted_ctrl_checkpoint)
+
+ return controlnet
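+
+# Usage sketch (the state dict, config and path are assumed inputs; the download_*
+# helpers below call this with the same positional arguments):
+#
+#     controlnet = convert_controlnet_checkpoint(
+#         state_dict, original_config, "control_sd15_canny.pth", 512, None, False
+#     )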
+
+
+def download_from_original_stable_diffusion_ckpt(
+ checkpoint_path: str,
+ original_config_file: str = None,
+ image_size: Optional[int] = None,
+ prediction_type: str = None,
+ model_type: str = None,
+ extract_ema: bool = False,
+ scheduler_type: str = "pndm",
+ num_in_channels: Optional[int] = None,
+ upcast_attention: Optional[bool] = None,
+ device: str = None,
+ from_safetensors: bool = False,
+ stable_unclip: Optional[str] = None,
+ stable_unclip_prior: Optional[str] = None,
+ clip_stats_path: Optional[str] = None,
+ controlnet: Optional[bool] = None,
+ load_safety_checker: bool = True,
+ pipeline_class: DiffusionPipeline = None,
+ local_files_only=False,
+ vae_path=None,
+ text_encoder=None,
+ tokenizer=None,
+) -> DiffusionPipeline:
+ """
+ Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
+ config file.
+
+ Although many of the arguments can be automatically inferred, some of these rely on brittle checks against the
+ global step count, which will likely fail for models that have undergone further fine-tuning. Therefore, it is
+ recommended that you override the default values and/or supply an `original_config_file` wherever possible.
+
+ Args:
+ checkpoint_path (`str`): Path to `.ckpt` file.
+ original_config_file (`str`):
+ Path to `.yaml` config file corresponding to the original architecture. If `None`, will be automatically
+ inferred by looking for a key that only exists in SD2.0 models.
+        image_size (`int`, *optional*, defaults to `None`):
+            The image size that the model was trained on. If `None`, it will be inferred where possible. Use 512 for
+            Stable Diffusion v1.X and Stable Diffusion v2 Base. Use 768 for Stable Diffusion v2.
+ prediction_type (`str`, *optional*):
+ The prediction type that the model was trained on. Use `'epsilon'` for Stable Diffusion v1.X and Stable
+ Diffusion v2 Base. Use `'v_prediction'` for Stable Diffusion v2.
+ num_in_channels (`int`, *optional*, defaults to None):
+ The number of input channels. If `None`, it will be automatically inferred.
+ scheduler_type (`str`, *optional*, defaults to 'pndm'):
+ Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm",
+ "ddim"]`.
+ model_type (`str`, *optional*, defaults to `None`):
+ The pipeline type. `None` to automatically infer, or one of `["FrozenOpenCLIPEmbedder",
+ "FrozenCLIPEmbedder", "PaintByExample"]`.
+ extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for
+ checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults to
+ `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for
+ inference. Non-EMA weights are usually better to continue fine-tuning.
+ upcast_attention (`bool`, *optional*, defaults to `None`):
+ Whether the attention computation should always be upcasted. This is necessary when running stable
+ diffusion 2.1.
+ device (`str`, *optional*, defaults to `None`):
+ The device to use. Pass `None` to determine automatically.
+        from_safetensors (`bool`, *optional*, defaults to `False`):
+ If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.
+ load_safety_checker (`bool`, *optional*, defaults to `True`):
+ Whether to load the safety checker or not. Defaults to `True`.
+ pipeline_class (`str`, *optional*, defaults to `None`):
+ The pipeline class to use. Pass `None` to determine automatically.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether or not to only look at local files (i.e., do not try to download the model).
+ text_encoder (`CLIPTextModel`, *optional*, defaults to `None`):
+ An instance of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel)
+ to use, specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)
+ variant. If this parameter is `None`, the function will load a new instance of [CLIP] by itself, if needed.
+ tokenizer (`CLIPTokenizer`, *optional*, defaults to `None`):
+ An instance of
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer)
+ to use. If this parameter is `None`, the function will load a new instance of [CLIPTokenizer] by itself, if
+ needed.
+    return: A `DiffusionPipeline` object representing the passed-in `.ckpt`/`.safetensors` file.
+ """
+
+ # import pipelines here to avoid circular import error when using from_single_file method
+ from diffusers import (
+ LDMTextToImagePipeline,
+ PaintByExamplePipeline,
+ StableDiffusionControlNetPipeline,
+ StableDiffusionInpaintPipeline,
+ StableDiffusionPipeline,
+ StableDiffusionXLImg2ImgPipeline,
+ StableUnCLIPImg2ImgPipeline,
+ StableUnCLIPPipeline,
+ )
+
+ if pipeline_class is None:
+ pipeline_class = StableDiffusionPipeline if not controlnet else StableDiffusionControlNetPipeline
+
+ if prediction_type == "v-prediction":
+ prediction_type = "v_prediction"
+
+ if not is_omegaconf_available():
+ raise ValueError(BACKENDS_MAPPING["omegaconf"][1])
+
+ from omegaconf import OmegaConf
+
+ if from_safetensors:
+ if not is_safetensors_available():
+ raise ValueError(BACKENDS_MAPPING["safetensors"][1])
+
+ from safetensors.torch import load_file as safe_load
+
+ checkpoint = safe_load(checkpoint_path, device="cpu")
+ else:
+ if device is None:
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ checkpoint = torch.load(checkpoint_path, map_location=device)
+ else:
+ checkpoint = torch.load(checkpoint_path, map_location=device)
+
+ # Sometimes models don't have the global_step item
+ if "global_step" in checkpoint:
+ global_step = checkpoint["global_step"]
+ else:
+ logger.debug("global_step key not found in model")
+ global_step = None
+
+ # NOTE: this while loop isn't great but this controlnet checkpoint has one additional
+ # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21
+ while "state_dict" in checkpoint:
+ checkpoint = checkpoint["state_dict"]
+
+ if original_config_file is None:
+ key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
+ key_name_sd_xl_base = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias"
+ key_name_sd_xl_refiner = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias"
+
+ # model_type = "v1"
+ config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
+
+ if key_name_v2_1 in checkpoint and checkpoint[key_name_v2_1].shape[-1] == 1024:
+ # model_type = "v2"
+ config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
+
+ if global_step == 110000:
+ # v2.1 needs to upcast attention
+ upcast_attention = True
+ elif key_name_sd_xl_base in checkpoint:
+ # only base xl has two text embedders
+ config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
+ elif key_name_sd_xl_refiner in checkpoint:
+            # the refiner xl has only a single text embedder
+ config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml"
+
+ original_config_file = BytesIO(requests.get(config_url).content)
+
+ original_config = OmegaConf.load(original_config_file)
+
+ # Convert the text model.
+ if (
+ model_type is None
+ and "cond_stage_config" in original_config.model.params
+ and original_config.model.params.cond_stage_config is not None
+ ):
+ model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
+ logger.debug(f"no `model_type` given, `model_type` inferred as: {model_type}")
+ elif model_type is None and original_config.model.params.network_config is not None:
+ if original_config.model.params.network_config.params.context_dim == 2048:
+ model_type = "SDXL"
+ else:
+ model_type = "SDXL-Refiner"
+ if image_size is None:
+ image_size = 1024
+
+ if num_in_channels is None and pipeline_class == StableDiffusionInpaintPipeline:
+ num_in_channels = 9
+ elif num_in_channels is None:
+ num_in_channels = 4
+
+ if "unet_config" in original_config.model.params:
+ original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
+
+ if (
+ "parameterization" in original_config["model"]["params"]
+ and original_config["model"]["params"]["parameterization"] == "v"
+ ):
+ if prediction_type is None:
+            # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type="epsilon"`
+ # as it relies on a brittle global step parameter here
+ prediction_type = "epsilon" if global_step == 875000 else "v_prediction"
+ if image_size is None:
+            # NOTE: For stable diffusion 2 base one has to pass `image_size=512`
+ # as it relies on a brittle global step parameter here
+ image_size = 512 if global_step == 875000 else 768
+ else:
+ if prediction_type is None:
+ prediction_type = "epsilon"
+ if image_size is None:
+ image_size = 512
+
+    if controlnet is None:
+        controlnet = "control_stage_config" in original_config.model.params
+
+    if controlnet:
+        controlnet = convert_controlnet_checkpoint(
+            checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
+        )
+
+ num_train_timesteps = getattr(original_config.model.params, "timesteps", None) or 1000
+
+ if model_type in ["SDXL", "SDXL-Refiner"]:
+ scheduler_dict = {
+ "beta_schedule": "scaled_linear",
+ "beta_start": 0.00085,
+ "beta_end": 0.012,
+ "interpolation_type": "linear",
+ "num_train_timesteps": num_train_timesteps,
+ "prediction_type": "epsilon",
+ "sample_max_value": 1.0,
+ "set_alpha_to_one": False,
+ "skip_prk_steps": True,
+ "steps_offset": 1,
+ "timestep_spacing": "leading",
+ }
+ scheduler = EulerDiscreteScheduler.from_config(scheduler_dict)
+ scheduler_type = "euler"
+ else:
+ beta_start = getattr(original_config.model.params, "linear_start", None) or 0.02
+ beta_end = getattr(original_config.model.params, "linear_end", None) or 0.085
+ scheduler = DDIMScheduler(
+ beta_end=beta_end,
+ beta_schedule="scaled_linear",
+ beta_start=beta_start,
+ num_train_timesteps=num_train_timesteps,
+ steps_offset=1,
+ clip_sample=False,
+ set_alpha_to_one=False,
+ prediction_type=prediction_type,
+ )
+ # make sure scheduler works correctly with DDIM
+ scheduler.register_to_config(clip_sample=False)
+
+ if scheduler_type == "pndm":
+ config = dict(scheduler.config)
+ config["skip_prk_steps"] = True
+ scheduler = PNDMScheduler.from_config(config)
+ elif scheduler_type == "lms":
+ scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
+ elif scheduler_type == "heun":
+ scheduler = HeunDiscreteScheduler.from_config(scheduler.config)
+ elif scheduler_type == "euler":
+ scheduler = EulerDiscreteScheduler.from_config(scheduler.config)
+ elif scheduler_type == "euler-ancestral":
+ scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config)
+ elif scheduler_type == "dpm":
+ scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
+ elif scheduler_type == "ddim":
+ scheduler = scheduler
+ else:
+ raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
+
+ # Convert the UNet2DConditionModel model.
+ unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
+ unet_config["upcast_attention"] = upcast_attention
+ converted_unet_checkpoint = convert_ldm_unet_checkpoint(
+ checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema
+ )
+
+ ctx = init_empty_weights if is_accelerate_available() else nullcontext
+ with ctx():
+ unet = UNet2DConditionModel(**unet_config)
+
+ if is_accelerate_available():
+ for param_name, param in converted_unet_checkpoint.items():
+ set_module_tensor_to_device(unet, param_name, "cpu", value=param)
+ else:
+ unet.load_state_dict(converted_unet_checkpoint)
+
+ # Convert the VAE model.
+ if vae_path is None:
+ vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
+
+ if (
+ "model" in original_config
+ and "params" in original_config.model
+ and "scale_factor" in original_config.model.params
+ ):
+ vae_scaling_factor = original_config.model.params.scale_factor
+ else:
+ vae_scaling_factor = 0.18215 # default SD scaling factor
+
+ vae_config["scaling_factor"] = vae_scaling_factor
+
+ ctx = init_empty_weights if is_accelerate_available() else nullcontext
+ with ctx():
+ vae = AutoencoderKL(**vae_config)
+
+ if is_accelerate_available():
+ for param_name, param in converted_vae_checkpoint.items():
+ set_module_tensor_to_device(vae, param_name, "cpu", value=param)
+ else:
+ vae.load_state_dict(converted_vae_checkpoint)
+ else:
+ vae = AutoencoderKL.from_pretrained(vae_path)
+
+ if model_type == "FrozenOpenCLIPEmbedder":
+ config_name = "stabilityai/stable-diffusion-2"
+ config_kwargs = {"subfolder": "text_encoder"}
+
+ text_model = convert_open_clip_checkpoint(checkpoint, config_name, **config_kwargs)
+ tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer")
+
+ if stable_unclip is None:
+ if controlnet:
+ pipe = pipeline_class(
+ vae=vae,
+ text_encoder=text_model,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ controlnet=controlnet,
+ safety_checker=None,
+ feature_extractor=None,
+ requires_safety_checker=False,
+ )
+ else:
+ pipe = pipeline_class(
+ vae=vae,
+ text_encoder=text_model,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=None,
+ feature_extractor=None,
+ requires_safety_checker=False,
+ )
+ else:
+ image_normalizer, image_noising_scheduler = stable_unclip_image_noising_components(
+ original_config, clip_stats_path=clip_stats_path, device=device
+ )
+
+ if stable_unclip == "img2img":
+ feature_extractor, image_encoder = stable_unclip_image_encoder(original_config)
+
+ pipe = StableUnCLIPImg2ImgPipeline(
+ # image encoding components
+ feature_extractor=feature_extractor,
+ image_encoder=image_encoder,
+ # image noising components
+ image_normalizer=image_normalizer,
+ image_noising_scheduler=image_noising_scheduler,
+ # regular denoising components
+ tokenizer=tokenizer,
+ text_encoder=text_model,
+ unet=unet,
+ scheduler=scheduler,
+ # vae
+ vae=vae,
+ )
+ elif stable_unclip == "txt2img":
+ if stable_unclip_prior is None or stable_unclip_prior == "karlo":
+ karlo_model = "kakaobrain/karlo-v1-alpha"
+ prior = PriorTransformer.from_pretrained(karlo_model, subfolder="prior")
+
+ prior_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
+ prior_text_model = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
+
+ prior_scheduler = UnCLIPScheduler.from_pretrained(karlo_model, subfolder="prior_scheduler")
+ prior_scheduler = DDPMScheduler.from_config(prior_scheduler.config)
+ else:
+ raise NotImplementedError(f"unknown prior for stable unclip model: {stable_unclip_prior}")
+
+ pipe = StableUnCLIPPipeline(
+ # prior components
+ prior_tokenizer=prior_tokenizer,
+ prior_text_encoder=prior_text_model,
+ prior=prior,
+ prior_scheduler=prior_scheduler,
+ # image noising components
+ image_normalizer=image_normalizer,
+ image_noising_scheduler=image_noising_scheduler,
+ # regular denoising components
+ tokenizer=tokenizer,
+ text_encoder=text_model,
+ unet=unet,
+ scheduler=scheduler,
+ # vae
+ vae=vae,
+ )
+ else:
+ raise NotImplementedError(f"unknown `stable_unclip` type: {stable_unclip}")
+ elif model_type == "PaintByExample":
+ vision_model = convert_paint_by_example_checkpoint(checkpoint)
+ tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
+ feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
+ pipe = PaintByExamplePipeline(
+ vae=vae,
+ image_encoder=vision_model,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=None,
+ feature_extractor=feature_extractor,
+ )
+ elif model_type == "FrozenCLIPEmbedder":
+ text_model = convert_ldm_clip_checkpoint(
+ checkpoint, local_files_only=local_files_only, text_encoder=text_encoder
+ )
+ tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") if tokenizer is None else tokenizer
+
+ if load_safety_checker:
+ safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
+ feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
+ else:
+ safety_checker = None
+ feature_extractor = None
+
+ if controlnet:
+ pipe = pipeline_class(
+ vae=vae,
+ text_encoder=text_model,
+ tokenizer=tokenizer,
+ unet=unet,
+ controlnet=controlnet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ else:
+ pipe = pipeline_class(
+ vae=vae,
+ text_encoder=text_model,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ elif model_type in ["SDXL", "SDXL-Refiner"]:
+ if model_type == "SDXL":
+ tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
+ text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only)
+ tokenizer_2 = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!")
+
+ config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
+ config_kwargs = {"projection_dim": 1280}
+ text_encoder_2 = convert_open_clip_checkpoint(
+ checkpoint, config_name, prefix="conditioner.embedders.1.model.", has_projection=True, **config_kwargs
+ )
+
+ pipe = pipeline_class(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ text_encoder_2=text_encoder_2,
+ tokenizer_2=tokenizer_2,
+ unet=unet,
+ scheduler=scheduler,
+ force_zeros_for_empty_prompt=True,
+ )
+ else:
+ tokenizer = None
+ text_encoder = None
+ tokenizer_2 = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!")
+
+ config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
+ config_kwargs = {"projection_dim": 1280}
+ text_encoder_2 = convert_open_clip_checkpoint(
+ checkpoint, config_name, prefix="conditioner.embedders.0.model.", has_projection=True, **config_kwargs
+ )
+
+ pipe = StableDiffusionXLImg2ImgPipeline(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ text_encoder_2=text_encoder_2,
+ tokenizer_2=tokenizer_2,
+ unet=unet,
+ scheduler=scheduler,
+ requires_aesthetics_score=True,
+ force_zeros_for_empty_prompt=False,
+ )
+ else:
+ text_config = create_ldm_bert_config(original_config)
+ text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
+ tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
+ pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
+
+ return pipe
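+
+# Usage sketch (checkpoint and config paths are placeholders, not files shipped here):
+#
+#     pipe = download_from_original_stable_diffusion_ckpt(
+#         checkpoint_path="v1-5-pruned-emaonly.ckpt",
+#         original_config_file="v1-inference.yaml",
+#         extract_ema=True,
+#     )
+#     pipe.save_pretrained("converted/stable-diffusion-v1-5")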
+
+
+def download_controlnet_from_original_ckpt(
+ checkpoint_path: str,
+ original_config_file: str,
+ image_size: int = 512,
+ extract_ema: bool = False,
+ num_in_channels: Optional[int] = None,
+ upcast_attention: Optional[bool] = None,
+ device: str = None,
+ from_safetensors: bool = False,
+ use_linear_projection: Optional[bool] = None,
+ cross_attention_dim: Optional[bool] = None,
+) -> DiffusionPipeline:
+ if not is_omegaconf_available():
+ raise ValueError(BACKENDS_MAPPING["omegaconf"][1])
+
+ from omegaconf import OmegaConf
+
+ if from_safetensors:
+ if not is_safetensors_available():
+ raise ValueError(BACKENDS_MAPPING["safetensors"][1])
+
+ from safetensors import safe_open
+
+ checkpoint = {}
+ with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
+ for key in f.keys():
+ checkpoint[key] = f.get_tensor(key)
+ else:
+ if device is None:
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ checkpoint = torch.load(checkpoint_path, map_location=device)
+ else:
+ checkpoint = torch.load(checkpoint_path, map_location=device)
+
+ # NOTE: this while loop isn't great but this controlnet checkpoint has one additional
+ # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21
+ while "state_dict" in checkpoint:
+ checkpoint = checkpoint["state_dict"]
+
+ original_config = OmegaConf.load(original_config_file)
+
+ if num_in_channels is not None:
+ original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
+
+ if "control_stage_config" not in original_config.model.params:
+ raise ValueError("`control_stage_config` not present in original config")
+
+ controlnet = convert_controlnet_checkpoint(
+ checkpoint,
+ original_config,
+ checkpoint_path,
+ image_size,
+ upcast_attention,
+ extract_ema,
+ use_linear_projection=use_linear_projection,
+ cross_attention_dim=cross_attention_dim,
+ )
+
+ return controlnet
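+
+# Usage sketch (paths are placeholders; passing `from_safetensors=True` would load a
+# `.safetensors` checkpoint instead):
+#
+#     controlnet = download_controlnet_from_original_ckpt(
+#         checkpoint_path="control_sd15_canny.pth",
+#         original_config_file="cldm_v15.yaml",
+#     )
+#     controlnet.save_pretrained("converted/controlnet-canny")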
diff --git a/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py b/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a68c4d059c6ac4180cf5d0556b4a9d380497213
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py
@@ -0,0 +1,796 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import warnings
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import PIL
+import torch
+from packaging import version
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from diffusers.utils import is_accelerate_available, is_accelerate_version
+
+from ...configuration_utils import FrozenDict
+from ...image_processor import VaeImageProcessor
+from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...schedulers import DDIMScheduler
+from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from . import StableDiffusionPipelineOutput
+from .safety_checker import StableDiffusionSafetyChecker
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
+def preprocess(image):
+ warnings.warn(
+ "The preprocess method is deprecated and will be removed in a future version. Please"
+ " use VaeImageProcessor.preprocess instead",
+ FutureWarning,
+ )
+ if isinstance(image, torch.Tensor):
+ return image
+ elif isinstance(image, PIL.Image.Image):
+ image = [image]
+
+ if isinstance(image[0], PIL.Image.Image):
+ w, h = image[0].size
+ w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
+
+ image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
+ image = np.concatenate(image, axis=0)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image.transpose(0, 3, 1, 2)
+ image = 2.0 * image - 1.0
+ image = torch.from_numpy(image)
+ elif isinstance(image[0], torch.Tensor):
+ image = torch.cat(image, dim=0)
+ return image
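+
+# Shape sketch for the deprecated helper above (sizes are illustrative):
+#
+#     image = PIL.Image.new("RGB", (513, 769))   # width x height
+#     preprocess(image).shape                    # -> torch.Size([1, 3, 768, 512]), values in [-1, 1]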
+
+
+def posterior_sample(scheduler, latents, timestep, clean_latents, generator, eta):
+ # 1. get previous step value (=t-1)
+ prev_timestep = timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps
+
+ if prev_timestep <= 0:
+ return clean_latents
+
+ # 2. compute alphas, betas
+ alpha_prod_t = scheduler.alphas_cumprod[timestep]
+ alpha_prod_t_prev = (
+ scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.final_alpha_cumprod
+ )
+
+ variance = scheduler._get_variance(timestep, prev_timestep)
+ std_dev_t = eta * variance ** (0.5)
+
+ # direction pointing to x_t
+ e_t = (latents - alpha_prod_t ** (0.5) * clean_latents) / (1 - alpha_prod_t) ** (0.5)
+ dir_xt = (1.0 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * e_t
+ noise = std_dev_t * randn_tensor(
+ clean_latents.shape, dtype=clean_latents.dtype, device=clean_latents.device, generator=generator
+ )
+ prev_latents = alpha_prod_t_prev ** (0.5) * clean_latents + dir_xt + noise
+
+ return prev_latents
+
+
+def compute_noise(scheduler, prev_latents, latents, timestep, noise_pred, eta):
+ # 1. get previous step value (=t-1)
+ prev_timestep = timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps
+
+ # 2. compute alphas, betas
+ alpha_prod_t = scheduler.alphas_cumprod[timestep]
+ alpha_prod_t_prev = (
+ scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.final_alpha_cumprod
+ )
+
+ beta_prod_t = 1 - alpha_prod_t
+
+ # 3. compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5)
+
+ # 4. Clip "predicted x_0"
+ if scheduler.config.clip_sample:
+ pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
+
+ # 5. compute variance: "sigma_t(η)" -> see formula (16)
+ # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
+ variance = scheduler._get_variance(timestep, prev_timestep)
+ std_dev_t = eta * variance ** (0.5)
+
+ # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * noise_pred
+
+ noise = (prev_latents - (alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction)) / (
+ variance ** (0.5) * eta
+ )
+ return noise
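+
+# In other words, compute_noise inverts a stochastic DDIM step: given x_t (`latents`),
+# the model output and the observed x_{t-1} (`prev_latents`), it solves formula (12) of
+# https://arxiv.org/pdf/2010.02502.pdf for the noise term (a restatement using the
+# names defined above):
+#
+#     noise = (prev_latents - (alpha_prod_t_prev**0.5 * pred_original_sample + pred_sample_direction)) \
+#         / (eta * variance**0.5)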
+
+
+class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
+ r"""
+ Pipeline for text-guided image to image generation using Stable Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: DDIMScheduler,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+ version.parse(unet.config._diffusers_version).base_version
+ ) < version.parse("0.9.0.dev0")
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+ deprecation_message = (
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
+ " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+ " the `unet/config.json` file"
+ )
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(unet.config)
+ new_config["sample_size"] = 64
+ unet._internal_dict = FrozenDict(new_config)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ hook = None
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+ if self.safety_checker is not None:
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+            # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+            # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs
+ def check_inputs(
+ self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
+ ):
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ def decode_latents(self, latents):
+ warnings.warn(
+ "The decode_latents method is deprecated and will be removed in a future version. Please"
+ " use VaeImageProcessor instead",
+ FutureWarning,
+ )
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep, 0)
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+
+ return timesteps, num_inference_steps - t_start
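+
+    # Worked example (illustrative numbers): with num_inference_steps=50 and strength=0.8,
+    # init_timestep = 40 and t_start = 10, so denoising runs over the last 40 scheduler
+    # timesteps and the method returns (timesteps[10 * order:], 40).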
+
+ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
+ image = image.to(device=device, dtype=dtype)
+
+ batch_size = image.shape[0]
+
+ if image.shape[1] == 4:
+ init_latents = image
+
+ else:
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if isinstance(generator, list):
+ init_latents = [
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
+ ]
+ init_latents = torch.cat(init_latents, dim=0)
+ else:
+ init_latents = self.vae.encode(image).latent_dist.sample(generator)
+
+ init_latents = self.vae.config.scaling_factor * init_latents
+
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ deprecation_message = (
+ f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
+ " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
+ " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
+ " your script to pass as many initial images as text prompts to suppress this warning."
+ )
+ deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
+ additional_image_per_prompt = batch_size // init_latents.shape[0]
+ init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0)
+ elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
+
+ # add noise to latents using the timestep
+ shape = init_latents.shape
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+ # get latents
+ clean_latents = init_latents
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+ latents = init_latents
+
+ return latents, clean_latents
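+
+    # Note: unlike the standard img2img pipeline, this returns both the noised latents and
+    # the un-noised `clean_latents`; the latter are used by `posterior_sample` (defined
+    # above) during the denoising loop to keep the source branch tied to the input image.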
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ source_prompt: Union[str, List[str]],
+ image: Union[
+ torch.FloatTensor,
+ PIL.Image.Image,
+ np.ndarray,
+ List[torch.FloatTensor],
+ List[PIL.Image.Image],
+ List[np.ndarray],
+ ] = None,
+ strength: float = 0.8,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ source_guidance_scale: Optional[float] = 1,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: Optional[float] = 0.1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ image (`torch.FloatTensor` `np.ndarray`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
+                process. Can also accept image latents as `image`; if latents are passed directly, they will not be
+                encoded again.
+ strength (`float`, *optional*, defaults to 0.8):
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
+ will be used as a starting point, adding more noise to it the larger the `strength`. The number of
+ denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
+ be maximum and the denoising process will run for the full number of iterations specified in
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference. This parameter will be modulated by `strength`.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ source_guidance_scale (`float`, *optional*, defaults to 1):
+ Guidance scale for the source prompt. This is useful to control the amount of influence the source
+                prompt has on the encoding.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.1):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ # 1. Check inputs
+ self.check_inputs(prompt, strength, callback_steps)
+
+ # 2. Define call parameters
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
+ device = self._execution_device
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ prompt_embeds=prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+ source_prompt_embeds = self._encode_prompt(
+ source_prompt, device, num_images_per_prompt, do_classifier_free_guidance, None
+ )
+
+ # 4. Preprocess image
+ image = self.image_processor.preprocess(image)
+
+ # 5. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+ # 6. Prepare latent variables
+ latents, clean_latents = self.prepare_latents(
+ image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
+ )
+ source_latents = latents
+
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+ generator = extra_step_kwargs.pop("generator", None)
+
+ # 8. Denoising loop
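+ # For higher-order schedulers (`scheduler.order > 1`) the first few timesteps are warmup
+ # steps; the progress bar below only advances once per full solver step.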
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2)
+ source_latent_model_input = torch.cat([source_latents] * 2)
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+ source_latent_model_input = self.scheduler.scale_model_input(source_latent_model_input, t)
+
+ # predict the noise residual
+ concat_latent_model_input = torch.stack(
+ [
+ source_latent_model_input[0],
+ latent_model_input[0],
+ source_latent_model_input[1],
+ latent_model_input[1],
+ ],
+ dim=0,
+ )
+ concat_prompt_embeds = torch.stack(
+ [
+ source_prompt_embeds[0],
+ prompt_embeds[0],
+ source_prompt_embeds[1],
+ prompt_embeds[1],
+ ],
+ dim=0,
+ )
+ concat_noise_pred = self.unet(
+ concat_latent_model_input,
+ t,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_hidden_states=concat_prompt_embeds,
+ ).sample
+
+ # perform guidance
+ (
+ source_noise_pred_uncond,
+ noise_pred_uncond,
+ source_noise_pred_text,
+ noise_pred_text,
+ ) = concat_noise_pred.chunk(4, dim=0)
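+ # Classifier-free guidance: start from the unconditional prediction and move towards the
+ # text-conditioned one, e.g. with guidance_scale=7.5 the update is uncond + 7.5 * (text - uncond).
+ # The same rule is applied to the source branch with its own source_guidance_scale.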
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+ source_noise_pred = source_noise_pred_uncond + source_guidance_scale * (
+ source_noise_pred_text - source_noise_pred_uncond
+ )
+
+ # Sample source_latents from the posterior distribution.
+ prev_source_latents = posterior_sample(
+ self.scheduler, source_latents, t, clean_latents, generator=generator, **extra_step_kwargs
+ )
+ # Compute noise.
+ noise = compute_noise(
+ self.scheduler, prev_source_latents, source_latents, t, source_noise_pred, **extra_step_kwargs
+ )
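+ # The recovered noise is fed back in below as `variance_noise`, so the target branch reuses
+ # the exact stochasticity of the source trajectory rather than sampling fresh noise.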
+ source_latents = prev_source_latents
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(
+ noise_pred, t, latents, variance_noise=noise, **extra_step_kwargs
+ ).prev_sample
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ # 9. Post-processing
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b4f77029ce45497abea4807e97dc8656aaa6a99
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
@@ -0,0 +1,470 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import warnings
+from functools import partial
+from typing import Dict, List, Optional, Union
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+from flax.core.frozen_dict import FrozenDict
+from flax.jax_utils import unreplicate
+from flax.training.common_utils import shard
+from packaging import version
+from PIL import Image
+from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel
+
+from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
+from ...schedulers import (
+ FlaxDDIMScheduler,
+ FlaxDPMSolverMultistepScheduler,
+ FlaxLMSDiscreteScheduler,
+ FlaxPNDMScheduler,
+)
+from ...utils import deprecate, logging, replace_example_docstring
+from ..pipeline_flax_utils import FlaxDiffusionPipeline
+from . import FlaxStableDiffusionPipelineOutput
+from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+# Set to True to use python for loop instead of jax.fori_loop for easier debugging
+DEBUG = False
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import jax
+ >>> import numpy as np
+ >>> from flax.jax_utils import replicate
+ >>> from flax.training.common_utils import shard
+
+ >>> from diffusers import FlaxStableDiffusionPipeline
+
+ >>> pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
+ ... "runwayml/stable-diffusion-v1-5", revision="bf16", dtype=jax.numpy.bfloat16
+ ... )
+
+ >>> prompt = "a photo of an astronaut riding a horse on mars"
+
+ >>> prng_seed = jax.random.PRNGKey(0)
+ >>> num_inference_steps = 50
+
+ >>> num_samples = jax.device_count()
+ >>> prompt = num_samples * [prompt]
+ >>> prompt_ids = pipeline.prepare_inputs(prompt)
+ # shard inputs and rng
+
+ >>> params = replicate(params)
+ >>> prng_seed = jax.random.split(prng_seed, jax.device_count())
+ >>> prompt_ids = shard(prompt_ids)
+
+ >>> images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
+ >>> images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
+ ```
+"""
+
+
+class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion.
+
+ This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`FlaxAutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`FlaxCLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.FlaxCLIPTextModel),
+ specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or
+ [`FlaxDPMSolverMultistepScheduler`].
+ safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ def __init__(
+ self,
+ vae: FlaxAutoencoderKL,
+ text_encoder: FlaxCLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: FlaxUNet2DConditionModel,
+ scheduler: Union[
+ FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler
+ ],
+ safety_checker: FlaxStableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ dtype: jnp.dtype = jnp.float32,
+ ):
+ super().__init__()
+ self.dtype = dtype
+
+ if safety_checker is None:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+ version.parse(unet.config._diffusers_version).base_version
+ ) < version.parse("0.9.0.dev0")
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+ deprecation_message = (
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
+ " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+ " the `unet/config.json` file"
+ )
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(unet.config)
+ new_config["sample_size"] = 64
+ unet._internal_dict = FrozenDict(new_config)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
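+ # With the default 4 `block_out_channels` this gives a scale factor of 8, i.e. a 512x512
+ # image maps to a 64x64 latent grid.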
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+
+ def prepare_inputs(self, prompt: Union[str, List[str]]):
+ if not isinstance(prompt, (str, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ text_input = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="np",
+ )
+ return text_input.input_ids
+
+ def _get_has_nsfw_concepts(self, features, params):
+ has_nsfw_concepts = self.safety_checker(features, params)
+ return has_nsfw_concepts
+
+ def _run_safety_checker(self, images, safety_model_params, jit=False):
+ # safety_model_params should already be replicated when jit is True
+ pil_images = [Image.fromarray(image) for image in images]
+ features = self.feature_extractor(pil_images, return_tensors="np").pixel_values
+
+ if jit:
+ features = shard(features)
+ has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params)
+ has_nsfw_concepts = unshard(has_nsfw_concepts)
+ safety_model_params = unreplicate(safety_model_params)
+ else:
+ has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params)
+
+ images_was_copied = False
+ for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
+ if has_nsfw_concept:
+ if not images_was_copied:
+ images_was_copied = True
+ images = images.copy()
+
+ images[idx] = np.zeros(images[idx].shape, dtype=np.uint8) # black image
+
+ if any(has_nsfw_concepts):
+ warnings.warn(
+ "Potential NSFW content was detected in one or more images. A black image will be returned"
+ " instead. Try again with a different prompt and/or seed."
+ )
+
+ return images, has_nsfw_concepts
+
+ def _generate(
+ self,
+ prompt_ids: jnp.array,
+ params: Union[Dict, FrozenDict],
+ prng_seed: jax.random.KeyArray,
+ num_inference_steps: int,
+ height: int,
+ width: int,
+ guidance_scale: float,
+ latents: Optional[jnp.array] = None,
+ neg_prompt_ids: Optional[jnp.array] = None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ # get prompt text embeddings
+ prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0]
+
+ # TODO: classifier-free guidance is currently always applied; implement the conditional
+ # `do_classifier_free_guidance = guidance_scale > 1.0` instead of assuming it.
+ batch_size = prompt_ids.shape[0]
+
+ max_length = prompt_ids.shape[-1]
+
+ if neg_prompt_ids is None:
+ uncond_input = self.tokenizer(
+ [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
+ ).input_ids
+ else:
+ uncond_input = neg_prompt_ids
+ negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
+ context = jnp.concatenate([negative_prompt_embeds, prompt_embeds])
+
+ # Ensure model output will be `float32` before going into the scheduler
+ guidance_scale = jnp.array([guidance_scale], dtype=jnp.float32)
+
+ latents_shape = (
+ batch_size,
+ self.unet.config.in_channels,
+ height // self.vae_scale_factor,
+ width // self.vae_scale_factor,
+ )
+ if latents is None:
+ latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32)
+ else:
+ if latents.shape != latents_shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
+
+ def loop_body(step, args):
+ latents, scheduler_state = args
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ latents_input = jnp.concatenate([latents] * 2)
+
+ t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
+ timestep = jnp.broadcast_to(t, latents_input.shape[0])
+
+ latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet.apply(
+ {"params": params["unet"]},
+ jnp.array(latents_input),
+ jnp.array(timestep, dtype=jnp.int32),
+ encoder_hidden_states=context,
+ ).sample
+ # perform guidance
+ noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
+ return latents, scheduler_state
+
+ scheduler_state = self.scheduler.set_timesteps(
+ params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape
+ )
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * params["scheduler"].init_noise_sigma
+
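+ # The DEBUG path below runs the sampler as a plain Python loop (easy to step through with a
+ # debugger); otherwise jax.lax.fori_loop keeps the whole loop inside one traced computation,
+ # which is what makes this function jit/pmap friendly.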
+ if DEBUG:
+ # run with python for loop
+ for i in range(num_inference_steps):
+ latents, scheduler_state = loop_body(i, (latents, scheduler_state))
+ else:
+ latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state))
+
+ # scale and decode the image latents with vae
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample
+
+ image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
+ return image
+
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt_ids: jnp.array,
+ params: Union[Dict, FrozenDict],
+ prng_seed: jax.random.KeyArray,
+ num_inference_steps: int = 50,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ guidance_scale: Union[float, jnp.array] = 7.5,
+ latents: jnp.array = None,
+ neg_prompt_ids: jnp.array = None,
+ return_dict: bool = True,
+ jit: bool = False,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2 of the [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the
+ text `prompt`, usually at the expense of lower image quality.
+ latents (`jnp.array`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a
+ latents array will be generated by sampling with the supplied `prng_seed`.
+ jit (`bool`, defaults to `False`):
+ Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument
+ exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
+ a plain tuple.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is a list with the generated images, and the second
+ element is a list of `bool`s denoting whether the corresponding generated image likely represents
+ "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ if isinstance(guidance_scale, float):
+ # Convert to a tensor so each device gets a copy. Follow the prompt_ids for
+ # shape information, as they may be sharded (when `jit` is `True`), or not.
+ guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0])
+ if len(prompt_ids.shape) > 2:
+ # Assume sharded
+ guidance_scale = guidance_scale[:, None]
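+ # When sharded, the leading axis of `prompt_ids` is the device axis, so `guidance_scale`
+ # becomes shape (num_devices, 1) and `pmap` hands each device its own length-1 copy.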
+
+ if jit:
+ images = _p_generate(
+ self,
+ prompt_ids,
+ params,
+ prng_seed,
+ num_inference_steps,
+ height,
+ width,
+ guidance_scale,
+ latents,
+ neg_prompt_ids,
+ )
+ else:
+ images = self._generate(
+ prompt_ids,
+ params,
+ prng_seed,
+ num_inference_steps,
+ height,
+ width,
+ guidance_scale,
+ latents,
+ neg_prompt_ids,
+ )
+
+ if self.safety_checker is not None:
+ safety_params = params["safety_checker"]
+ images_uint8_casted = (images * 255).round().astype("uint8")
+ num_devices, batch_size = images.shape[:2]
+
+ images_uint8_casted = np.asarray(images_uint8_casted).reshape(num_devices * batch_size, height, width, 3)
+ images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit)
+ images = np.asarray(images)
+
+ # block images
+ if any(has_nsfw_concept):
+ for i, is_nsfw in enumerate(has_nsfw_concept):
+ if is_nsfw:
+ images[i] = np.asarray(images_uint8_casted[i])
+
+ images = images.reshape(num_devices, batch_size, height, width, 3)
+ else:
+ images = np.asarray(images)
+ has_nsfw_concept = False
+
+ if not return_dict:
+ return (images, has_nsfw_concept)
+
+ return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)
+
+
+# Static argnums are pipe, num_inference_steps, height, width. A change would trigger recompilation.
+# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`).
+@partial(
+ jax.pmap,
+ in_axes=(None, 0, 0, 0, None, None, None, 0, 0, 0),
+ static_broadcasted_argnums=(0, 4, 5, 6),
+)
+def _p_generate(
+ pipe,
+ prompt_ids,
+ params,
+ prng_seed,
+ num_inference_steps,
+ height,
+ width,
+ guidance_scale,
+ latents,
+ neg_prompt_ids,
+):
+ return pipe._generate(
+ prompt_ids,
+ params,
+ prng_seed,
+ num_inference_steps,
+ height,
+ width,
+ guidance_scale,
+ latents,
+ neg_prompt_ids,
+ )
+
+
+@partial(jax.pmap, static_broadcasted_argnums=(0,))
+def _p_get_has_nsfw_concepts(pipe, features, params):
+ return pipe._get_has_nsfw_concepts(features, params)
+
+
+def unshard(x: jnp.ndarray):
+ # einops.rearrange(x, 'd b ... -> (d b) ...')
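+ # e.g. an output sharded as (8 devices, 1, 512, 512, 3) becomes (8, 512, 512, 3)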
+ num_devices, batch_size = x.shape[:2]
+ rest = x.shape[2:]
+ return x.reshape(num_devices * batch_size, *rest)
diff --git a/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py b/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..bec2424ece4dc91fbafd530d525e36d1fb84c4ff
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py
@@ -0,0 +1,28 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# NOTE: This file is deprecated and will be removed in a future version.
+# It only exists so that `from diffusers.pipelines import DiffusionPipeline` temporarily keeps working.
+
+from ...utils import deprecate
+from ..controlnet.pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline # noqa: F401
+
+
+deprecate(
+ "stable diffusion controlnet",
+ "0.22.0",
+ "Importing `FlaxStableDiffusionControlNetPipeline` from diffusers.pipelines.stable_diffusion.flax_pipeline_stable_diffusion_controlnet is deprecated. Please import `from diffusers import FlaxStableDiffusionControlNetPipeline` instead.",
+ standard_warn=False,
+ stacklevel=3,
+)
diff --git a/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py b/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a387af364b7467a9f88d537071a48e001f99b69
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py
@@ -0,0 +1,527 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import warnings
+from functools import partial
+from typing import Dict, List, Optional, Union
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+from flax.core.frozen_dict import FrozenDict
+from flax.jax_utils import unreplicate
+from flax.training.common_utils import shard
+from PIL import Image
+from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel
+
+from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
+from ...schedulers import (
+ FlaxDDIMScheduler,
+ FlaxDPMSolverMultistepScheduler,
+ FlaxLMSDiscreteScheduler,
+ FlaxPNDMScheduler,
+)
+from ...utils import PIL_INTERPOLATION, logging, replace_example_docstring
+from ..pipeline_flax_utils import FlaxDiffusionPipeline
+from . import FlaxStableDiffusionPipelineOutput
+from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+# Set to True to use python for loop instead of jax.fori_loop for easier debugging
+DEBUG = False
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import jax
+ >>> import numpy as np
+ >>> import jax.numpy as jnp
+ >>> from flax.jax_utils import replicate
+ >>> from flax.training.common_utils import shard
+ >>> import requests
+ >>> from io import BytesIO
+ >>> from PIL import Image
+ >>> from diffusers import FlaxStableDiffusionImg2ImgPipeline
+
+
+ >>> def create_key(seed=0):
+ ... return jax.random.PRNGKey(seed)
+
+
+ >>> rng = create_key(0)
+
+ >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
+ >>> response = requests.get(url)
+ >>> init_img = Image.open(BytesIO(response.content)).convert("RGB")
+ >>> init_img = init_img.resize((768, 512))
+
+ >>> prompts = "A fantasy landscape, trending on artstation"
+
+ >>> pipeline, params = FlaxStableDiffusionImg2ImgPipeline.from_pretrained(
+ ... "CompVis/stable-diffusion-v1-4",
+ ... revision="flax",
+ ... dtype=jnp.bfloat16,
+ ... )
+
+ >>> num_samples = jax.device_count()
+ >>> rng = jax.random.split(rng, jax.device_count())
+ >>> prompt_ids, processed_image = pipeline.prepare_inputs(
+ ... prompt=[prompts] * num_samples, image=[init_img] * num_samples
+ ... )
+ >>> p_params = replicate(params)
+ >>> prompt_ids = shard(prompt_ids)
+ >>> processed_image = shard(processed_image)
+
+ >>> output = pipeline(
+ ... prompt_ids=prompt_ids,
+ ... image=processed_image,
+ ... params=p_params,
+ ... prng_seed=rng,
+ ... strength=0.75,
+ ... num_inference_steps=50,
+ ... jit=True,
+ ... height=512,
+ ... width=768,
+ ... ).images
+
+ >>> output_images = pipeline.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
+ ```
+"""
+
+
+class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
+ r"""
+ Pipeline for image-to-image generation using Stable Diffusion.
+
+ This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`FlaxAutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`FlaxCLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.FlaxCLIPTextModel),
+ specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or
+ [`FlaxDPMSolverMultistepScheduler`].
+ safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ def __init__(
+ self,
+ vae: FlaxAutoencoderKL,
+ text_encoder: FlaxCLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: FlaxUNet2DConditionModel,
+ scheduler: Union[
+ FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler
+ ],
+ safety_checker: FlaxStableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ dtype: jnp.dtype = jnp.float32,
+ ):
+ super().__init__()
+ self.dtype = dtype
+
+ if safety_checker is None:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend keeping the safety filter enabled in all public-facing circumstances, disabling"
+ " it only for use cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+
+ def prepare_inputs(self, prompt: Union[str, List[str]], image: Union[Image.Image, List[Image.Image]]):
+ if not isinstance(prompt, (str, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if not isinstance(image, (Image.Image, list)):
+ raise ValueError(f"image has to be of type `PIL.Image.Image` or list but is {type(image)}")
+
+ if isinstance(image, Image.Image):
+ image = [image]
+
+ processed_images = jnp.concatenate([preprocess(img, jnp.float32) for img in image])
+
+ text_input = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="np",
+ )
+ return text_input.input_ids, processed_images
+
+ def _get_has_nsfw_concepts(self, features, params):
+ has_nsfw_concepts = self.safety_checker(features, params)
+ return has_nsfw_concepts
+
+ def _run_safety_checker(self, images, safety_model_params, jit=False):
+ # safety_model_params should already be replicated when jit is True
+ pil_images = [Image.fromarray(image) for image in images]
+ features = self.feature_extractor(pil_images, return_tensors="np").pixel_values
+
+ if jit:
+ features = shard(features)
+ has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params)
+ has_nsfw_concepts = unshard(has_nsfw_concepts)
+ safety_model_params = unreplicate(safety_model_params)
+ else:
+ has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params)
+
+ images_was_copied = False
+ for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
+ if has_nsfw_concept:
+ if not images_was_copied:
+ images_was_copied = True
+ images = images.copy()
+
+ images[idx] = np.zeros(images[idx].shape, dtype=np.uint8) # black image
+
+ if any(has_nsfw_concepts):
+ warnings.warn(
+ "Potential NSFW content was detected in one or more images. A black image will be returned"
+ " instead. Try again with a different prompt and/or seed."
+ )
+
+ return images, has_nsfw_concepts
+
+ def get_timestep_start(self, num_inference_steps, strength):
+ # get the original timestep using init_timestep
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep, 0)
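+ # e.g. strength=0.75 with num_inference_steps=50 gives init_timestep=37 and t_start=13,
+ # so denoising skips the first 13 scheduler steps and runs the remaining 37.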
+
+ return t_start
+
+ def _generate(
+ self,
+ prompt_ids: jnp.array,
+ image: jnp.array,
+ params: Union[Dict, FrozenDict],
+ prng_seed: jax.random.KeyArray,
+ start_timestep: int,
+ num_inference_steps: int,
+ height: int,
+ width: int,
+ guidance_scale: float,
+ noise: Optional[jnp.array] = None,
+ neg_prompt_ids: Optional[jnp.array] = None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ # get prompt text embeddings
+ prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0]
+
+ # TODO: classifier-free guidance is currently always applied; implement the conditional
+ # `do_classifier_free_guidance = guidance_scale > 1.0` instead of assuming it.
+ batch_size = prompt_ids.shape[0]
+
+ max_length = prompt_ids.shape[-1]
+
+ if neg_prompt_ids is None:
+ uncond_input = self.tokenizer(
+ [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
+ ).input_ids
+ else:
+ uncond_input = neg_prompt_ids
+ negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
+ context = jnp.concatenate([negative_prompt_embeds, prompt_embeds])
+
+ latents_shape = (
+ batch_size,
+ self.unet.config.in_channels,
+ height // self.vae_scale_factor,
+ width // self.vae_scale_factor,
+ )
+ if noise is None:
+ noise = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32)
+ else:
+ if noise.shape != latents_shape:
+ raise ValueError(f"Unexpected latents shape, got {noise.shape}, expected {latents_shape}")
+
+ # Create init_latents
+ init_latent_dist = self.vae.apply({"params": params["vae"]}, image, method=self.vae.encode).latent_dist
+ init_latents = init_latent_dist.sample(key=prng_seed).transpose((0, 3, 1, 2))
+ init_latents = self.vae.config.scaling_factor * init_latents
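+ # The encoded image is scaled by the VAE's scaling_factor (0.18215 for the Stable Diffusion
+ # v1 VAE) so the latents match the range the UNet was trained on.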
+
+ def loop_body(step, args):
+ latents, scheduler_state = args
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ latents_input = jnp.concatenate([latents] * 2)
+
+ t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
+ timestep = jnp.broadcast_to(t, latents_input.shape[0])
+
+ latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet.apply(
+ {"params": params["unet"]},
+ jnp.array(latents_input),
+ jnp.array(timestep, dtype=jnp.int32),
+ encoder_hidden_states=context,
+ ).sample
+ # perform guidance
+ noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
+ return latents, scheduler_state
+
+ scheduler_state = self.scheduler.set_timesteps(
+ params["scheduler"], num_inference_steps=num_inference_steps, shape=latents_shape
+ )
+
+ latent_timestep = scheduler_state.timesteps[start_timestep : start_timestep + 1].repeat(batch_size)
+
+ latents = self.scheduler.add_noise(params["scheduler"], init_latents, noise, latent_timestep)
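+ # Noise the encoded image up to the timestep selected by `strength`; denoising then starts
+ # from this partially noised latent instead of from pure Gaussian noise.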
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * params["scheduler"].init_noise_sigma
+
+ if DEBUG:
+ # run with python for loop
+ for i in range(start_timestep, num_inference_steps):
+ latents, scheduler_state = loop_body(i, (latents, scheduler_state))
+ else:
+ latents, _ = jax.lax.fori_loop(start_timestep, num_inference_steps, loop_body, (latents, scheduler_state))
+
+ # scale and decode the image latents with vae
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample
+
+ image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
+ return image
+
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt_ids: jnp.array,
+ image: jnp.array,
+ params: Union[Dict, FrozenDict],
+ prng_seed: jax.random.KeyArray,
+ strength: float = 0.8,
+ num_inference_steps: int = 50,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ guidance_scale: Union[float, jnp.array] = 7.5,
+ noise: jnp.array = None,
+ neg_prompt_ids: jnp.array = None,
+ return_dict: bool = True,
+ jit: bool = False,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt_ids (`jnp.array`):
+ The prompt or prompts to guide the image generation.
+ image (`jnp.array`):
+ Array representing an image batch, that will be used as the starting point for the process.
+ params (`Dict` or `FrozenDict`): Dictionary containing the model parameters/weights
+ prng_seed (`jax.random.KeyArray` or `jax.Array`): Array containing random number generator key
+ strength (`float`, *optional*, defaults to 0.8):
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
+ will be used as a starting point, adding more noise to it the larger the `strength`. The number of
+ denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
+ be maximum and the denoising process will run for the full number of iterations specified in
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2 of the [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the
+ text `prompt`, usually at the expense of lower image quality.
+ noise (`jnp.array`, *optional*):
+ Pre-generated noise, sampled from a Gaussian distribution, to be added to the encoded image latents.
+ Can be used to tweak the same generation with different prompts. If not provided, the noise will be
+ generated by sampling with the supplied `prng_seed`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
+ a plain tuple.
+ jit (`bool`, defaults to `False`):
+ Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument
+ exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is a list with the generated images, and the second
+ element is a list of `bool`s denoting whether the corresponding generated image likely represents
+ "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ if isinstance(guidance_scale, float):
+ # Convert to a tensor so each device gets a copy. Follow the prompt_ids for
+ # shape information, as they may be sharded (when `jit` is `True`), or not.
+ guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0])
+ if len(prompt_ids.shape) > 2:
+ # Assume sharded
+ guidance_scale = guidance_scale[:, None]
+
+ start_timestep = self.get_timestep_start(num_inference_steps, strength)
+
+ if jit:
+ images = _p_generate(
+ self,
+ prompt_ids,
+ image,
+ params,
+ prng_seed,
+ start_timestep,
+ num_inference_steps,
+ height,
+ width,
+ guidance_scale,
+ noise,
+ neg_prompt_ids,
+ )
+ else:
+ images = self._generate(
+ prompt_ids,
+ image,
+ params,
+ prng_seed,
+ start_timestep,
+ num_inference_steps,
+ height,
+ width,
+ guidance_scale,
+ noise,
+ neg_prompt_ids,
+ )
+
+ if self.safety_checker is not None:
+ safety_params = params["safety_checker"]
+ images_uint8_casted = (images * 255).round().astype("uint8")
+ num_devices, batch_size = images.shape[:2]
+
+ images_uint8_casted = np.asarray(images_uint8_casted).reshape(num_devices * batch_size, height, width, 3)
+ images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit)
+ images = np.asarray(images)
+
+ # block images
+ if any(has_nsfw_concept):
+ for i, is_nsfw in enumerate(has_nsfw_concept):
+ if is_nsfw:
+ images[i] = np.asarray(images_uint8_casted[i])
+
+ images = images.reshape(num_devices, batch_size, height, width, 3)
+ else:
+ images = np.asarray(images)
+ has_nsfw_concept = False
+
+ if not return_dict:
+ return (images, has_nsfw_concept)
+
+ return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)
+
+
+# Static argnums are pipe, start_timestep, num_inference_steps, height, width. A change would trigger recompilation.
+# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`).
+@partial(
+ jax.pmap,
+ in_axes=(None, 0, 0, 0, 0, None, None, None, None, 0, 0, 0),
+ static_broadcasted_argnums=(0, 5, 6, 7, 8),
+)
+def _p_generate(
+ pipe,
+ prompt_ids,
+ image,
+ params,
+ prng_seed,
+ start_timestep,
+ num_inference_steps,
+ height,
+ width,
+ guidance_scale,
+ noise,
+ neg_prompt_ids,
+):
+ return pipe._generate(
+ prompt_ids,
+ image,
+ params,
+ prng_seed,
+ start_timestep,
+ num_inference_steps,
+ height,
+ width,
+ guidance_scale,
+ noise,
+ neg_prompt_ids,
+ )
+
+
+@partial(jax.pmap, static_broadcasted_argnums=(0,))
+def _p_get_has_nsfw_concepts(pipe, features, params):
+ return pipe._get_has_nsfw_concepts(features, params)
+
+
+def unshard(x: jnp.ndarray):
+ # einops.rearrange(x, 'd b ... -> (d b) ...')
+ num_devices, batch_size = x.shape[:2]
+ rest = x.shape[2:]
+ return x.reshape(num_devices * batch_size, *rest)
+
+
+def preprocess(image, dtype):
+ w, h = image.size
+ w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32
+ image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
+ image = jnp.array(image).astype(dtype) / 255.0
+ image = image[None].transpose(0, 3, 1, 2)
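+ # map pixel values from [0, 1] to [-1, 1], the range expected by the VAE encoder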
+ return 2.0 * image - 1.0
diff --git a/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py b/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py
new file mode 100644
index 0000000000000000000000000000000000000000..abb57f8b62e9aab62b7dc83329ab2a3c1f623532
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py
@@ -0,0 +1,580 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import warnings
+from functools import partial
+from typing import Dict, List, Optional, Union
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+from flax.core.frozen_dict import FrozenDict
+from flax.jax_utils import unreplicate
+from flax.training.common_utils import shard
+from packaging import version
+from PIL import Image
+from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel
+
+from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
+from ...schedulers import (
+ FlaxDDIMScheduler,
+ FlaxDPMSolverMultistepScheduler,
+ FlaxLMSDiscreteScheduler,
+ FlaxPNDMScheduler,
+)
+from ...utils import PIL_INTERPOLATION, deprecate, logging, replace_example_docstring
+from ..pipeline_flax_utils import FlaxDiffusionPipeline
+from . import FlaxStableDiffusionPipelineOutput
+from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+# Set to True to use python for loop instead of jax.fori_loop for easier debugging
+DEBUG = False
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import jax
+ >>> import numpy as np
+ >>> from flax.jax_utils import replicate
+ >>> from flax.training.common_utils import shard
+ >>> import PIL
+ >>> import requests
+ >>> from io import BytesIO
+ >>> from diffusers import FlaxStableDiffusionInpaintPipeline
+
+
+ >>> def download_image(url):
+ ... response = requests.get(url)
+ ... return PIL.Image.open(BytesIO(response.content)).convert("RGB")
+
+
+ >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+ >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+
+ >>> init_image = download_image(img_url).resize((512, 512))
+ >>> mask_image = download_image(mask_url).resize((512, 512))
+
+ >>> pipeline, params = FlaxStableDiffusionInpaintPipeline.from_pretrained(
+ ... "xvjiarui/stable-diffusion-2-inpainting"
+ ... )
+
+ >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
+ >>> prng_seed = jax.random.PRNGKey(0)
+ >>> num_inference_steps = 50
+
+ >>> num_samples = jax.device_count()
+ >>> prompt = num_samples * [prompt]
+ >>> init_image = num_samples * [init_image]
+ >>> mask_image = num_samples * [mask_image]
+ >>> prompt_ids, processed_masked_images, processed_masks = pipeline.prepare_inputs(
+ ... prompt, init_image, mask_image
+ ... )
+ # shard inputs and rng
+
+ >>> params = replicate(params)
+ >>> prng_seed = jax.random.split(prng_seed, jax.device_count())
+ >>> prompt_ids = shard(prompt_ids)
+ >>> processed_masked_images = shard(processed_masked_images)
+ >>> processed_masks = shard(processed_masks)
+
+ >>> images = pipeline(
+ ... prompt_ids, processed_masks, processed_masked_images, params, prng_seed, num_inference_steps, jit=True
+ ... ).images
+ >>> images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
+ ```
+"""
+
+
+class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline):
+ r"""
+ Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*.
+
+ This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`FlaxAutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`FlaxCLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.FlaxCLIPTextModel),
+ specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or
+ [`FlaxDPMSolverMultistepScheduler`].
+ safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ def __init__(
+ self,
+ vae: FlaxAutoencoderKL,
+ text_encoder: FlaxCLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: FlaxUNet2DConditionModel,
+ scheduler: Union[
+ FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler
+ ],
+ safety_checker: FlaxStableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ dtype: jnp.dtype = jnp.float32,
+ ):
+ super().__init__()
+ self.dtype = dtype
+
+ if safety_checker is None:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+ version.parse(unet.config._diffusers_version).base_version
+ ) < version.parse("0.9.0.dev0")
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+ deprecation_message = (
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
+ " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+ " the `unet/config.json` file"
+ )
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(unet.config)
+ new_config["sample_size"] = 64
+ unet._internal_dict = FrozenDict(new_config)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+
+ def prepare_inputs(
+ self,
+ prompt: Union[str, List[str]],
+ image: Union[Image.Image, List[Image.Image]],
+ mask: Union[Image.Image, List[Image.Image]],
+ ):
+ if not isinstance(prompt, (str, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if not isinstance(image, (Image.Image, list)):
+ raise ValueError(f"image has to be of type `PIL.Image.Image` or list but is {type(image)}")
+
+ if isinstance(image, Image.Image):
+ image = [image]
+
+ if not isinstance(mask, (Image.Image, list)):
+ raise ValueError(f"image has to be of type `PIL.Image.Image` or list but is {type(image)}")
+
+ if isinstance(mask, Image.Image):
+ mask = [mask]
+
+ processed_images = jnp.concatenate([preprocess_image(img, jnp.float32) for img in image])
+ processed_masks = jnp.concatenate([preprocess_mask(m, jnp.float32) for m in mask])
+ # processed_masks[processed_masks < 0.5] = 0
+ processed_masks = processed_masks.at[processed_masks < 0.5].set(0)
+ # processed_masks[processed_masks >= 0.5] = 1
+ processed_masks = processed_masks.at[processed_masks >= 0.5].set(1)
+
+ processed_masked_images = processed_images * (processed_masks < 0.5)
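+ # Mask convention: values of 1 mark the region to inpaint, so the masked image keeps only
+ # the pixels where the mask is 0.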
+
+ text_input = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="np",
+ )
+ return text_input.input_ids, processed_masked_images, processed_masks
+
+ def _get_has_nsfw_concepts(self, features, params):
+ has_nsfw_concepts = self.safety_checker(features, params)
+ return has_nsfw_concepts
+
+ def _run_safety_checker(self, images, safety_model_params, jit=False):
+ # safety_model_params should already be replicated when jit is True
+ pil_images = [Image.fromarray(image) for image in images]
+ features = self.feature_extractor(pil_images, return_tensors="np").pixel_values
+
+ if jit:
+ features = shard(features)
+ has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params)
+ has_nsfw_concepts = unshard(has_nsfw_concepts)
+ safety_model_params = unreplicate(safety_model_params)
+ else:
+ has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params)
+
+ images_was_copied = False
+ for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
+ if has_nsfw_concept:
+ if not images_was_copied:
+ images_was_copied = True
+ images = images.copy()
+
+ images[idx] = np.zeros(images[idx].shape, dtype=np.uint8) # black image
+
+ if any(has_nsfw_concepts):
+ warnings.warn(
+ "Potential NSFW content was detected in one or more images. A black image will be returned"
+ " instead. Try again with a different prompt and/or seed."
+ )
+
+ return images, has_nsfw_concepts
+
+ def _generate(
+ self,
+ prompt_ids: jnp.array,
+ mask: jnp.array,
+ masked_image: jnp.array,
+ params: Union[Dict, FrozenDict],
+ prng_seed: jax.random.KeyArray,
+ num_inference_steps: int,
+ height: int,
+ width: int,
+ guidance_scale: float,
+ latents: Optional[jnp.array] = None,
+ neg_prompt_ids: Optional[jnp.array] = None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ # get prompt text embeddings
+ prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0]
+
+ # TODO: it is currently assumed that classifier-free guidance is always enabled;
+ # implement the conditional `do_classifier_free_guidance = guidance_scale > 1.0`
+ batch_size = prompt_ids.shape[0]
+
+ max_length = prompt_ids.shape[-1]
+
+ if neg_prompt_ids is None:
+ uncond_input = self.tokenizer(
+ [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
+ ).input_ids
+ else:
+ uncond_input = neg_prompt_ids
+ negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
+ context = jnp.concatenate([negative_prompt_embeds, prompt_embeds])
+
+ latents_shape = (
+ batch_size,
+ self.vae.config.latent_channels,
+ height // self.vae_scale_factor,
+ width // self.vae_scale_factor,
+ )
+ if latents is None:
+ latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=self.dtype)
+ else:
+ if latents.shape != latents_shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
+
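+ # derive a separate RNG key so the masked-image latents are sampled independently of the initial noise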
+ prng_seed, mask_prng_seed = jax.random.split(prng_seed)
+
+ masked_image_latent_dist = self.vae.apply(
+ {"params": params["vae"]}, masked_image, method=self.vae.encode
+ ).latent_dist
+ masked_image_latents = masked_image_latent_dist.sample(key=mask_prng_seed).transpose((0, 3, 1, 2))
+ masked_image_latents = self.vae.config.scaling_factor * masked_image_latents
+ del mask_prng_seed
+
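+ # downsample the mask to the latent resolution; nearest-neighbour interpolation keeps it strictly binary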
+ mask = jax.image.resize(mask, (*mask.shape[:-2], *masked_image_latents.shape[-2:]), method="nearest")
+
+ # 8. Check that sizes of mask, masked image and latents match
+ num_channels_latents = self.vae.config.latent_channels
+ num_channels_mask = mask.shape[1]
+ num_channels_masked_image = masked_image_latents.shape[1]
+ if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
+ raise ValueError(
+ f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
+ f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
+ f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ " `pipeline.unet` or your `mask_image` or `image` input."
+ )
+
+ def loop_body(step, args):
+ latents, mask, masked_image_latents, scheduler_state = args
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ latents_input = jnp.concatenate([latents] * 2)
+ mask_input = jnp.concatenate([mask] * 2)
+ masked_image_latents_input = jnp.concatenate([masked_image_latents] * 2)
+
+ t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
+ timestep = jnp.broadcast_to(t, latents_input.shape[0])
+
+ latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t)
+ # concat latents, mask, masked_image_latents in the channel dimension
+ latents_input = jnp.concatenate([latents_input, mask_input, masked_image_latents_input], axis=1)
+
+ # predict the noise residual
+ noise_pred = self.unet.apply(
+ {"params": params["unet"]},
+ jnp.array(latents_input),
+ jnp.array(timestep, dtype=jnp.int32),
+ encoder_hidden_states=context,
+ ).sample
+ # perform guidance
+ noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
+ return latents, mask, masked_image_latents, scheduler_state
+
+ scheduler_state = self.scheduler.set_timesteps(
+ params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape
+ )
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * params["scheduler"].init_noise_sigma
+
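+ # DEBUG runs the denoising loop as a plain Python loop (easier to inspect); otherwise the
+ # whole loop stays inside a single compiled `jax.lax.fori_loop`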
+ if DEBUG:
+ # run with python for loop
+ for i in range(num_inference_steps):
+ latents, mask, masked_image_latents, scheduler_state = loop_body(
+ i, (latents, mask, masked_image_latents, scheduler_state)
+ )
+ else:
+ latents, _, _, _ = jax.lax.fori_loop(
+ 0, num_inference_steps, loop_body, (latents, mask, masked_image_latents, scheduler_state)
+ )
+
+ # scale and decode the image latents with vae
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample
+
+ image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
+ return image
+
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt_ids: jnp.array,
+ mask: jnp.array,
+ masked_image: jnp.array,
+ params: Union[Dict, FrozenDict],
+ prng_seed: jax.random.KeyArray,
+ num_inference_steps: int = 50,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ guidance_scale: Union[float, jnp.array] = 7.5,
+ latents: jnp.array = None,
+ neg_prompt_ids: jnp.array = None,
+ return_dict: bool = True,
+ jit: bool = False,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt_ids (`jnp.array`):
+ The tokenized prompt ids (as returned by `prepare_inputs`) used to guide the image generation.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ latents (`jnp.array`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `prng_seed`.
+ jit (`bool`, defaults to `False`):
+ Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument
+ exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
+ a plain tuple.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is a list with the generated images, and the second
+ element is a list of `bool`s denoting whether the corresponding generated image likely represents
+ "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ masked_image = jax.image.resize(masked_image, (*masked_image.shape[:-2], height, width), method="bicubic")
+ mask = jax.image.resize(mask, (*mask.shape[:-2], height, width), method="nearest")
+
+ if isinstance(guidance_scale, float):
+ # Convert to a tensor so each device gets a copy. Follow the prompt_ids for
+ # shape information, as they may be sharded (when `jit` is `True`), or not.
+ guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0])
+ if len(prompt_ids.shape) > 2:
+ # Assume sharded
+ guidance_scale = guidance_scale[:, None]
+
+ if jit:
+ images = _p_generate(
+ self,
+ prompt_ids,
+ mask,
+ masked_image,
+ params,
+ prng_seed,
+ num_inference_steps,
+ height,
+ width,
+ guidance_scale,
+ latents,
+ neg_prompt_ids,
+ )
+ else:
+ images = self._generate(
+ prompt_ids,
+ mask,
+ masked_image,
+ params,
+ prng_seed,
+ num_inference_steps,
+ height,
+ width,
+ guidance_scale,
+ latents,
+ neg_prompt_ids,
+ )
+
+ if self.safety_checker is not None:
+ safety_params = params["safety_checker"]
+ images_uint8_casted = (images * 255).round().astype("uint8")
+ num_devices, batch_size = images.shape[:2]
+
+ images_uint8_casted = np.asarray(images_uint8_casted).reshape(num_devices * batch_size, height, width, 3)
+ images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit)
+ images = np.asarray(images)
+
+ # replace flagged images with the blacked-out versions returned by the safety checker
+ if any(has_nsfw_concept):
+ for i, is_nsfw in enumerate(has_nsfw_concept):
+ if is_nsfw:
+ images[i] = np.asarray(images_uint8_casted[i])
+
+ images = images.reshape(num_devices, batch_size, height, width, 3)
+ else:
+ images = np.asarray(images)
+ has_nsfw_concept = False
+
+ if not return_dict:
+ return (images, has_nsfw_concept)
+
+ return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)
+
+
+# Static argnums are pipe, num_inference_steps, height, width. A change would trigger recompilation.
+# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`).
+@partial(
+ jax.pmap,
+ in_axes=(None, 0, 0, 0, 0, 0, None, None, None, 0, 0, 0),
+ static_broadcasted_argnums=(0, 6, 7, 8),
+)
+def _p_generate(
+ pipe,
+ prompt_ids,
+ mask,
+ masked_image,
+ params,
+ prng_seed,
+ num_inference_steps,
+ height,
+ width,
+ guidance_scale,
+ latents,
+ neg_prompt_ids,
+):
+ return pipe._generate(
+ prompt_ids,
+ mask,
+ masked_image,
+ params,
+ prng_seed,
+ num_inference_steps,
+ height,
+ width,
+ guidance_scale,
+ latents,
+ neg_prompt_ids,
+ )
+
+
+@partial(jax.pmap, static_broadcasted_argnums=(0,))
+def _p_get_has_nsfw_concepts(pipe, features, params):
+ return pipe._get_has_nsfw_concepts(features, params)
+
+
+def unshard(x: jnp.ndarray):
+ # einops.rearrange(x, 'd b ... -> (d b) ...')
+ num_devices, batch_size = x.shape[:2]
+ rest = x.shape[2:]
+ return x.reshape(num_devices * batch_size, *rest)
+
+
+def preprocess_image(image, dtype):
+ w, h = image.size
+ w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32
+ image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
+ image = jnp.array(image).astype(dtype) / 255.0
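+ # add a batch dimension and convert HWC -> CHW, then map [0, 1] to [-1, 1]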
+ image = image[None].transpose(0, 3, 1, 2)
+ return 2.0 * image - 1.0
+
+
+def preprocess_mask(mask, dtype):
+ w, h = mask.size
+ w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32
+ mask = mask.resize((w, h))
+ mask = jnp.array(mask.convert("L")).astype(dtype) / 255.0
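+ # add batch and channel dimensions so the mask has shape (1, 1, H, W)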
+ mask = jnp.expand_dims(mask, axis=(0, 1))
+
+ return mask
diff --git a/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py b/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb02f6cb321cb02ec5bd7badc0f6c73f06ae1e41
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py
@@ -0,0 +1,485 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Callable, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import CLIPImageProcessor, CLIPTokenizer
+
+from ...configuration_utils import FrozenDict
+from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from ...utils import deprecate, logging
+from ..onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
+from ..pipeline_utils import DiffusionPipeline
+from . import StableDiffusionPipelineOutput
+
+
+logger = logging.get_logger(__name__)
+
+
+class OnnxStableDiffusionPipeline(DiffusionPipeline):
+ vae_encoder: OnnxRuntimeModel
+ vae_decoder: OnnxRuntimeModel
+ text_encoder: OnnxRuntimeModel
+ tokenizer: CLIPTokenizer
+ unet: OnnxRuntimeModel
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
+ safety_checker: OnnxRuntimeModel
+ feature_extractor: CLIPImageProcessor
+
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae_encoder: OnnxRuntimeModel,
+ vae_decoder: OnnxRuntimeModel,
+ text_encoder: OnnxRuntimeModel,
+ tokenizer: CLIPTokenizer,
+ unet: OnnxRuntimeModel,
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ safety_checker: OnnxRuntimeModel,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+ )
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["clip_sample"] = False
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ self.register_modules(
+ vae_encoder=vae_encoder,
+ vae_decoder=vae_decoder,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ def _encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ num_images_per_prompt: Optional[int],
+ do_classifier_free_guidance: bool,
+ negative_prompt: Optional[str],
+ prompt_embeds: Optional[np.ndarray] = None,
+ negative_prompt_embeds: Optional[np.ndarray] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ prompt to be encoded
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ prompt_embeds (`np.ndarray`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`np.ndarray`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ """
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # get prompt text embeddings
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="np",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids
+
+ if not np.array_equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
+
+ prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt] * batch_size
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="np",
+ )
+ negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]
+
+ if do_classifier_free_guidance:
+ negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ def check_inputs(
+ self,
+ prompt: Union[str, List[str]],
+ height: Optional[int],
+ width: Optional[int],
+ callback_steps: int,
+ negative_prompt: Optional[str] = None,
+ prompt_embeds: Optional[np.ndarray] = None,
+ negative_prompt_embeds: Optional[np.ndarray] = None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ height: Optional[int] = 512,
+ width: Optional[int] = 512,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: Optional[float] = 0.0,
+ generator: Optional[np.random.RandomState] = None,
+ latents: Optional[np.ndarray] = None,
+ prompt_embeds: Optional[np.ndarray] = None,
+ negative_prompt_embeds: Optional[np.ndarray] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
+ callback_steps: int = 1,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ height (`int`, *optional*, defaults to 512):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to 512):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
+ is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`np.random.RandomState`, *optional*):
+ A `np.random.RandomState` to make generation deterministic.
+ latents (`np.ndarray`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`np.ndarray`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`np.ndarray`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+
+ # check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
+ )
+
+ # define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if generator is None:
+ generator = np.random
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ )
+
+ # get the initial random noise unless the user supplied it
+ latents_dtype = prompt_embeds.dtype
+ latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
+ if latents is None:
+ latents = generator.randn(*latents_shape).astype(latents_dtype)
+ elif latents.shape != latents_shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
+
+ # set timesteps
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ latents = latents * np.float64(self.scheduler.init_noise_sigma)
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
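+ # the exported UNet may expect float or int64 timesteps; read the expected dtype from the ONNX graph inputs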
+ timestep_dtype = next(
+ (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)"
+ )
+ timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
+
+ for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
+ latent_model_input = latent_model_input.cpu().numpy()
+
+ # predict the noise residual
+ timestep = np.array([t], dtype=timestep_dtype)
+ noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)
+ noise_pred = noise_pred[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ scheduler_output = self.scheduler.step(
+ torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
+ )
+ latents = scheduler_output.prev_sample.numpy()
+
+ # call the callback, if provided
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
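+ # 0.18215 is the Stable Diffusion VAE scaling factor; undo it before decoding the latents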
+ latents = 1 / 0.18215 * latents
+ # image = self.vae_decoder(latent_sample=latents)[0]
+ # there seems to be a strange result when using the half-precision vae decoder with batch size > 1, so decode one latent at a time
+ image = np.concatenate(
+ [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
+ )
+
+ image = np.clip(image / 2 + 0.5, 0, 1)
+ image = image.transpose((0, 2, 3, 1))
+
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(
+ self.numpy_to_pil(image), return_tensors="np"
+ ).pixel_values.astype(image.dtype)
+
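+ # the safety checker does not support batched inputs yet, so run it image by image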
+ images, has_nsfw_concept = [], []
+ for i in range(image.shape[0]):
+ image_i, has_nsfw_concept_i = self.safety_checker(
+ clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
+ )
+ images.append(image_i)
+ has_nsfw_concept.append(has_nsfw_concept_i[0])
+ image = np.concatenate(images)
+ else:
+ has_nsfw_concept = None
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+
+
+class StableDiffusionOnnxPipeline(OnnxStableDiffusionPipeline):
+ def __init__(
+ self,
+ vae_encoder: OnnxRuntimeModel,
+ vae_decoder: OnnxRuntimeModel,
+ text_encoder: OnnxRuntimeModel,
+ tokenizer: CLIPTokenizer,
+ unet: OnnxRuntimeModel,
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ safety_checker: OnnxRuntimeModel,
+ feature_extractor: CLIPImageProcessor,
+ ):
+ deprecation_message = "Please use `OnnxStableDiffusionPipeline` instead of `StableDiffusionOnnxPipeline`."
+ deprecate("StableDiffusionOnnxPipeline", "1.0.0", deprecation_message)
+ super().__init__(
+ vae_encoder=vae_encoder,
+ vae_decoder=vae_decoder,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
diff --git a/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py b/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py
new file mode 100644
index 0000000000000000000000000000000000000000..293ed7d981b80a30cfad9a4a84478c7209a1cea7
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py
@@ -0,0 +1,552 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import warnings
+from typing import Callable, List, Optional, Union
+
+import numpy as np
+import PIL
+import torch
+from transformers import CLIPImageProcessor, CLIPTokenizer
+
+from ...configuration_utils import FrozenDict
+from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from ...utils import PIL_INTERPOLATION, deprecate, logging
+from ..onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
+from ..pipeline_utils import DiffusionPipeline
+from . import StableDiffusionPipelineOutput
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess with 8->64
+def preprocess(image):
+ warnings.warn(
+ (
+ "The preprocess method is deprecated and will be removed in a future version. Please"
+ " use VaeImageProcessor.preprocess instead"
+ ),
+ FutureWarning,
+ )
+ if isinstance(image, torch.Tensor):
+ return image
+ elif isinstance(image, PIL.Image.Image):
+ image = [image]
+
+ if isinstance(image[0], PIL.Image.Image):
+ w, h = image[0].size
+ w, h = (x - x % 64 for x in (w, h)) # resize to integer multiple of 64
+
+ image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
+ image = np.concatenate(image, axis=0)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image.transpose(0, 3, 1, 2)
+ image = 2.0 * image - 1.0
+ image = torch.from_numpy(image)
+ elif isinstance(image[0], torch.Tensor):
+ image = torch.cat(image, dim=0)
+ return image
+
+
+class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-guided image to image generation using Stable Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+ vae_encoder: OnnxRuntimeModel
+ vae_decoder: OnnxRuntimeModel
+ text_encoder: OnnxRuntimeModel
+ tokenizer: CLIPTokenizer
+ unet: OnnxRuntimeModel
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
+ safety_checker: OnnxRuntimeModel
+ feature_extractor: CLIPImageProcessor
+
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae_encoder: OnnxRuntimeModel,
+ vae_decoder: OnnxRuntimeModel,
+ text_encoder: OnnxRuntimeModel,
+ tokenizer: CLIPTokenizer,
+ unet: OnnxRuntimeModel,
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ safety_checker: OnnxRuntimeModel,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+ )
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["clip_sample"] = False
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ self.register_modules(
+ vae_encoder=vae_encoder,
+ vae_decoder=vae_decoder,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt
+ def _encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ num_images_per_prompt: Optional[int],
+ do_classifier_free_guidance: bool,
+ negative_prompt: Optional[str],
+ prompt_embeds: Optional[np.ndarray] = None,
+ negative_prompt_embeds: Optional[np.ndarray] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ prompt to be encoded
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ prompt_embeds (`np.ndarray`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`np.ndarray`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ """
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # get prompt text embeddings
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="np",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids
+
+ if not np.array_equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
+
+ prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt] * batch_size
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="np",
+ )
+ negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]
+
+ if do_classifier_free_guidance:
+ negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ def check_inputs(
+ self,
+ prompt: Union[str, List[str]],
+ callback_steps: int,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ prompt_embeds: Optional[np.ndarray] = None,
+ negative_prompt_embeds: Optional[np.ndarray] = None,
+ ):
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ image: Union[np.ndarray, PIL.Image.Image] = None,
+ strength: float = 0.8,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: Optional[float] = 0.0,
+ generator: Optional[np.random.RandomState] = None,
+ prompt_embeds: Optional[np.ndarray] = None,
+ negative_prompt_embeds: Optional[np.ndarray] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
+ callback_steps: int = 1,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ image (`np.ndarray` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
+ process.
+ strength (`float`, *optional*, defaults to 0.8):
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
+ will be used as a starting point, adding more noise to it the larger the `strength`. The number of
+ denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
+ be maximum and the denoising process will run for the full number of iterations specified in
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference. This parameter will be modulated by `strength`.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`np.random.RandomState`, *optional*):
+ A np.random.RandomState to make generation deterministic.
+ prompt_embeds (`np.ndarray`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`np.ndarray`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+
+ # check inputs. Raise error if not correct
+ self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
+
+ # define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+ if generator is None:
+ generator = np.random
+
+ # set timesteps
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ image = preprocess(image).cpu().numpy()
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ )
+
+ latents_dtype = prompt_embeds.dtype
+ image = image.astype(latents_dtype)
+ # encode the init image into latents and scale the latents
+ init_latents = self.vae_encoder(sample=image)[0]
+ init_latents = 0.18215 * init_latents
+
+ if isinstance(prompt, str):
+ prompt = [prompt]
+ if len(prompt) > init_latents.shape[0] and len(prompt) % init_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ deprecation_message = (
+ f"You have passed {len(prompt)} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
+ " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
+ " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
+ " your script to pass as many initial images as text prompts to suppress this warning."
+ )
+ deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
+ additional_image_per_prompt = len(prompt) // init_latents.shape[0]
+ init_latents = np.concatenate([init_latents] * additional_image_per_prompt * num_images_per_prompt, axis=0)
+ elif len(prompt) > init_latents.shape[0] and len(prompt) % init_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {len(prompt)} text prompts."
+ )
+ else:
+ init_latents = np.concatenate([init_latents] * num_images_per_prompt, axis=0)
+
+ # get the original timestep using init_timestep
+ offset = self.scheduler.config.get("steps_offset", 0)
+ init_timestep = int(num_inference_steps * strength) + offset
+ init_timestep = min(init_timestep, num_inference_steps)
+
+ timesteps = self.scheduler.timesteps.numpy()[-init_timestep]
+ timesteps = np.array([timesteps] * batch_size * num_images_per_prompt)
+
+ # add noise to latents using the timesteps
+ noise = generator.randn(*init_latents.shape).astype(latents_dtype)
+ init_latents = self.scheduler.add_noise(
+ torch.from_numpy(init_latents), torch.from_numpy(noise), torch.from_numpy(timesteps)
+ )
+ init_latents = init_latents.numpy()
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ latents = init_latents
+
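+ # skip the timesteps that were already accounted for by the noise added according to `strength`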
+ t_start = max(num_inference_steps - init_timestep + offset, 0)
+ timesteps = self.scheduler.timesteps[t_start:].numpy()
+
+ timestep_dtype = next(
+ (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)"
+ )
+ timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
+
+ for i, t in enumerate(self.progress_bar(timesteps)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
+ latent_model_input = latent_model_input.cpu().numpy()
+
+ # predict the noise residual
+ timestep = np.array([t], dtype=timestep_dtype)
+ noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)[
+ 0
+ ]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ scheduler_output = self.scheduler.step(
+ torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
+ )
+ latents = scheduler_output.prev_sample.numpy()
+
+ # call the callback, if provided
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
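+ # 0.18215 is the Stable Diffusion VAE scaling factor; undo it before decoding the latents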
+ latents = 1 / 0.18215 * latents
+ # image = self.vae_decoder(latent_sample=latents)[0]
+ # the half-precision VAE decoder gives odd results when the batch size is > 1, so decode one latent at a time
+ image = np.concatenate(
+ [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
+ )
+
+ image = np.clip(image / 2 + 0.5, 0, 1)
+ image = image.transpose((0, 2, 3, 1))
+
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(
+ self.numpy_to_pil(image), return_tensors="np"
+ ).pixel_values.astype(image.dtype)
+ # safety_checker does not support batched inputs yet
+ images, has_nsfw_concept = [], []
+ for i in range(image.shape[0]):
+ image_i, has_nsfw_concept_i = self.safety_checker(
+ clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
+ )
+ images.append(image_i)
+ has_nsfw_concept.append(has_nsfw_concept_i[0])
+ image = np.concatenate(images)
+ else:
+ has_nsfw_concept = None
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py b/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py
new file mode 100644
index 0000000000000000000000000000000000000000..0bb39c4b1c617ea07e71355364f6476f6178e806
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py
@@ -0,0 +1,560 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Callable, List, Optional, Union
+
+import numpy as np
+import PIL
+import torch
+from transformers import CLIPImageProcessor, CLIPTokenizer
+
+from ...configuration_utils import FrozenDict
+from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from ...utils import PIL_INTERPOLATION, deprecate, logging
+from ..onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
+from ..pipeline_utils import DiffusionPipeline
+from . import StableDiffusionPipelineOutput
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+NUM_UNET_INPUT_CHANNELS = 9
+NUM_LATENT_CHANNELS = 4
+
+
+def prepare_mask_and_masked_image(image, mask, latents_shape):
+ image = np.array(image.convert("RGB").resize((latents_shape[1] * 8, latents_shape[0] * 8)))
+ image = image[None].transpose(0, 3, 1, 2)
+ image = image.astype(np.float32) / 127.5 - 1.0
+
+ image_mask = np.array(mask.convert("L").resize((latents_shape[1] * 8, latents_shape[0] * 8)))
+ masked_image = image * (image_mask < 127.5)
+
+ mask = mask.resize((latents_shape[1], latents_shape[0]), PIL_INTERPOLATION["nearest"])
+ mask = np.array(mask.convert("L"))
+ mask = mask.astype(np.float32) / 255.0
+ mask = mask[None, None]
+ mask[mask < 0.5] = 0
+ mask[mask >= 0.5] = 1
+
+ return mask, masked_image
+
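+# Shape sketch (editor's note): for a 512x512 PIL image and mask with latents_shape=(64, 64), the helper above
+# returns `mask` of shape (1, 1, 64, 64) with values in {0, 1} (1 = repaint) and `masked_image` of shape
+# (1, 3, 512, 512) scaled to [-1, 1] with the to-be-repainted pixels zeroed out.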
+
+class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+ vae_encoder: OnnxRuntimeModel
+ vae_decoder: OnnxRuntimeModel
+ text_encoder: OnnxRuntimeModel
+ tokenizer: CLIPTokenizer
+ unet: OnnxRuntimeModel
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
+ safety_checker: OnnxRuntimeModel
+ feature_extractor: CLIPImageProcessor
+
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae_encoder: OnnxRuntimeModel,
+ vae_decoder: OnnxRuntimeModel,
+ text_encoder: OnnxRuntimeModel,
+ tokenizer: CLIPTokenizer,
+ unet: OnnxRuntimeModel,
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ safety_checker: OnnxRuntimeModel,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+ logger.info("`OnnxStableDiffusionInpaintPipeline` is experimental and will very likely change in the future.")
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might lead to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has `clip_sample` set to True, but"
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+ )
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["clip_sample"] = False
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend keeping the safety filter enabled in all public-facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ self.register_modules(
+ vae_encoder=vae_encoder,
+ vae_decoder=vae_decoder,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt
+ def _encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ num_images_per_prompt: Optional[int],
+ do_classifier_free_guidance: bool,
+ negative_prompt: Optional[str],
+ prompt_embeds: Optional[np.ndarray] = None,
+ negative_prompt_embeds: Optional[np.ndarray] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ prompt to be encoded
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ prompt_embeds (`np.ndarray`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`np.ndarray`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ """
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # get prompt text embeddings
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="np",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids
+
+ if not np.array_equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
+
+ prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt] * batch_size
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="np",
+ )
+ negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]
+
+ if do_classifier_free_guidance:
+ negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt: Union[str, List[str]],
+ height: Optional[int],
+ width: Optional[int],
+ callback_steps: int,
+ negative_prompt: Optional[str] = None,
+ prompt_embeds: Optional[np.ndarray] = None,
+ negative_prompt_embeds: Optional[np.ndarray] = None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ image: PIL.Image.Image,
+ mask_image: PIL.Image.Image,
+ height: Optional[int] = 512,
+ width: Optional[int] = 512,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[np.random.RandomState] = None,
+ latents: Optional[np.ndarray] = None,
+ prompt_embeds: Optional[np.ndarray] = None,
+ negative_prompt_embeds: Optional[np.ndarray] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
+ callback_steps: int = 1,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ image (`PIL.Image.Image`):
+ `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
+ be masked out with `mask_image` and repainted according to `prompt`.
+ mask_image (`PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
+ repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
+ to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
+ instead of 3, so the expected shape would be `(B, H, W, 1)`.
+ height (`int`, *optional*, defaults to 512):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to 512):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`np.random.RandomState`, *optional*):
+ A np.random.RandomState to make generation deterministic.
+ latents (`np.ndarray`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`np.ndarray`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`np.ndarray`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+
+ # check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
+ )
+
+ # define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if generator is None:
+ generator = np.random
+
+ # set timesteps
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ )
+
+ num_channels_latents = NUM_LATENT_CHANNELS
+ latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8)
+ latents_dtype = prompt_embeds.dtype
+ if latents is None:
+ latents = generator.randn(*latents_shape).astype(latents_dtype)
+ else:
+ if latents.shape != latents_shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
+
+ # prepare mask and masked_image
+ mask, masked_image = prepare_mask_and_masked_image(image, mask_image, latents_shape[-2:])
+ mask = mask.astype(latents.dtype)
+ masked_image = masked_image.astype(latents.dtype)
+
+ masked_image_latents = self.vae_encoder(sample=masked_image)[0]
+ masked_image_latents = 0.18215 * masked_image_latents
+
+ # duplicate mask and masked_image_latents for each generation per prompt
+ mask = mask.repeat(batch_size * num_images_per_prompt, 0)
+ masked_image_latents = masked_image_latents.repeat(batch_size * num_images_per_prompt, 0)
+
+ mask = np.concatenate([mask] * 2) if do_classifier_free_guidance else mask
+ masked_image_latents = (
+ np.concatenate([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
+ )
+
+ num_channels_mask = mask.shape[1]
+ num_channels_masked_image = masked_image_latents.shape[1]
+
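+ # the 9 expected input channels break down as 4 noisy latent channels + 1 mask channel + 4 masked-image latent channels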
+ unet_input_channels = NUM_UNET_INPUT_CHANNELS
+ if num_channels_latents + num_channels_mask + num_channels_masked_image != unet_input_channels:
+ raise ValueError(
+ "Incorrect configuration settings! The config of `pipeline.unet` expects"
+ f" {unet_input_channels} but received `num_channels_latents`: {num_channels_latents} +"
+ f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ " `pipeline.unet` or your `mask_image` or `image` input."
+ )
+
+ # set timesteps
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * np.float64(self.scheduler.init_noise_sigma)
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ timestep_dtype = next(
+ (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)"
+ )
+ timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
+
+ for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
+ # concat latents, mask, masked_image_latents in the channel dimension
+ latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
+ latent_model_input = latent_model_input.cpu().numpy()
+ latent_model_input = np.concatenate([latent_model_input, mask, masked_image_latents], axis=1)
+
+ # predict the noise residual
+ timestep = np.array([t], dtype=timestep_dtype)
+ noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)[
+ 0
+ ]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ scheduler_output = self.scheduler.step(
+ torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
+ )
+ latents = scheduler_output.prev_sample.numpy()
+
+ # call the callback, if provided
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ latents = 1 / 0.18215 * latents
+ # image = self.vae_decoder(latent_sample=latents)[0]
+ # the half-precision VAE decoder gives odd results when the batch size is > 1, so decode one latent at a time
+ image = np.concatenate(
+ [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
+ )
+
+ image = np.clip(image / 2 + 0.5, 0, 1)
+ image = image.transpose((0, 2, 3, 1))
+
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(
+ self.numpy_to_pil(image), return_tensors="np"
+ ).pixel_values.astype(image.dtype)
+ # safety_checker does not support batched inputs yet
+ images, has_nsfw_concept = [], []
+ for i in range(image.shape[0]):
+ image_i, has_nsfw_concept_i = self.safety_checker(
+ clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
+ )
+ images.append(image_i)
+ has_nsfw_concept.append(has_nsfw_concept_i[0])
+ image = np.concatenate(images)
+ else:
+ has_nsfw_concept = None
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
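+
+
+# Illustrative usage sketch (editor's note, kept as a comment so nothing runs on import). The checkpoint id and
+# execution provider below are placeholders, and the import assumes the stock `diffusers` package rather than
+# this vendored copy:
+#
+#     from diffusers import OnnxStableDiffusionInpaintPipeline
+#     from PIL import Image
+#
+#     pipe = OnnxStableDiffusionInpaintPipeline.from_pretrained(
+#         "stable-diffusion-inpainting-onnx", provider="CPUExecutionProvider"
+#     )
+#     init_image = Image.open("input.png").convert("RGB").resize((512, 512))
+#     mask = Image.open("mask.png").convert("L").resize((512, 512))  # white = repaint, black = keep
+#     result = pipe(prompt="a red brick wall", image=init_image, mask_image=mask, num_inference_steps=25)
+#     result.images[0].save("inpainted.png")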
diff --git a/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py b/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ef7a781451c2757e5657aba9c1ff24276890524
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py
@@ -0,0 +1,539 @@
+import inspect
+from typing import Callable, List, Optional, Union
+
+import numpy as np
+import PIL
+import torch
+from transformers import CLIPImageProcessor, CLIPTokenizer
+
+from ...configuration_utils import FrozenDict
+from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from ...utils import deprecate, logging
+from ..onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
+from ..pipeline_utils import DiffusionPipeline
+from . import StableDiffusionPipelineOutput
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def preprocess(image):
+ w, h = image.size
+ w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32
+ image = image.resize((w, h), resample=PIL.Image.LANCZOS)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image[None].transpose(0, 3, 1, 2)
+ return 2.0 * image - 1.0
+
+
+def preprocess_mask(mask, scale_factor=8):
+ mask = mask.convert("L")
+ w, h = mask.size
+ w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32
+ mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL.Image.NEAREST)
+ mask = np.array(mask).astype(np.float32) / 255.0
+ mask = np.tile(mask, (4, 1, 1))
+ mask = mask[None].transpose(0, 1, 2, 3) # add a batch dimension (the identity transpose is a no-op)
+ mask = 1 - mask # repaint white, keep black
+ return mask
+
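+# Shape sketch (editor's note): for a 512x512 PIL mask and scale_factor=8, preprocess_mask returns an array of
+# shape (1, 4, 64, 64); after the `1 - mask` inversion a value of 1 marks latents to keep and 0 marks latents to
+# repaint, matching the `init_latents_proper * mask + latents * (1 - mask)` blend in __call__ below.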
+
+class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
+ r"""
+ Pipeline for text-guided image inpainting using Stable Diffusion. This is a *legacy feature* for Onnx pipelines to
+ provide compatibility with StableDiffusionInpaintPipelineLegacy and may be removed in the future.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ vae_encoder: OnnxRuntimeModel
+ vae_decoder: OnnxRuntimeModel
+ text_encoder: OnnxRuntimeModel
+ tokenizer: CLIPTokenizer
+ unet: OnnxRuntimeModel
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
+ safety_checker: OnnxRuntimeModel
+ feature_extractor: CLIPImageProcessor
+
+ def __init__(
+ self,
+ vae_encoder: OnnxRuntimeModel,
+ vae_decoder: OnnxRuntimeModel,
+ text_encoder: OnnxRuntimeModel,
+ tokenizer: CLIPTokenizer,
+ unet: OnnxRuntimeModel,
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ safety_checker: OnnxRuntimeModel,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might lead to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has `clip_sample` set to True, but"
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+ )
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["clip_sample"] = False
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend keeping the safety filter enabled in all public-facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ self.register_modules(
+ vae_encoder=vae_encoder,
+ vae_decoder=vae_decoder,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt
+ def _encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ num_images_per_prompt: Optional[int],
+ do_classifier_free_guidance: bool,
+ negative_prompt: Optional[str],
+ prompt_embeds: Optional[np.ndarray] = None,
+ negative_prompt_embeds: Optional[np.ndarray] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ prompt to be encoded
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ prompt_embeds (`np.ndarray`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`np.ndarray`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ """
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # get prompt text embeddings
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="np",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids
+
+ if not np.array_equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
+
+ prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt] * batch_size
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="np",
+ )
+ negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]
+
+ if do_classifier_free_guidance:
+ negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ def check_inputs(
+ self,
+ prompt,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ image: Union[np.ndarray, PIL.Image.Image] = None,
+ mask_image: Union[np.ndarray, PIL.Image.Image] = None,
+ strength: float = 0.8,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: Optional[float] = 0.0,
+ generator: Optional[np.random.RandomState] = None,
+ prompt_embeds: Optional[np.ndarray] = None,
+ negative_prompt_embeds: Optional[np.ndarray] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
+ callback_steps: int = 1,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ image (`np.ndarray` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
+ process. This is the image whose masked region will be inpainted.
+ mask_image (`np.ndarray` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
+ replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
+ PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
+ contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
+ strength (`float`, *optional*, defaults to 0.8):
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
+ will be used as a starting point, adding more noise to it the larger the `strength`. The number of
+ denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
+ be maximum and the denoising process will run for the full number of iterations specified in
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference. This parameter will be modulated by `strength`.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`np.random.RandomState`, *optional*):
+ A np.random.RandomState to make generation deterministic.
+ prompt_embeds (`np.ndarray`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`np.ndarray`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+
+ # check inputs. Raise error if not correct
+ self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
+
+ # define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
+
+ if generator is None:
+ generator = np.random
+
+ # set timesteps
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ if isinstance(image, PIL.Image.Image):
+ image = preprocess(image)
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ )
+
+ latents_dtype = prompt_embeds.dtype
+ image = image.astype(latents_dtype)
+
+ # encode the init image into latents and scale the latents
+ init_latents = self.vae_encoder(sample=image)[0]
+ init_latents = 0.18215 * init_latents
+
+ # Expand init_latents for batch_size and num_images_per_prompt
+ init_latents = np.concatenate([init_latents] * num_images_per_prompt, axis=0)
+ init_latents_orig = init_latents
+
+ # preprocess mask
+ if not isinstance(mask_image, np.ndarray):
+ mask_image = preprocess_mask(mask_image, 8)
+ mask_image = mask_image.astype(latents_dtype)
+ mask = np.concatenate([mask_image] * num_images_per_prompt, axis=0)
+
+ # check sizes
+ if mask.shape != init_latents.shape:
+ raise ValueError("The mask and image should be the same size!")
+
+ # get the original timestep using init_timestep
+ offset = self.scheduler.config.get("steps_offset", 0)
+ init_timestep = int(num_inference_steps * strength) + offset
+ init_timestep = min(init_timestep, num_inference_steps)
+
+ timesteps = self.scheduler.timesteps.numpy()[-init_timestep]
+ timesteps = np.array([timesteps] * batch_size * num_images_per_prompt)
+
+ # add noise to latents using the timesteps
+ noise = generator.randn(*init_latents.shape).astype(latents_dtype)
+ init_latents = self.scheduler.add_noise(
+ torch.from_numpy(init_latents), torch.from_numpy(noise), torch.from_numpy(timesteps)
+ )
+ init_latents = init_latents.numpy()
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ latents = init_latents
+
+ t_start = max(num_inference_steps - init_timestep + offset, 0)
+ timesteps = self.scheduler.timesteps[t_start:].numpy()
+ timestep_dtype = next(
+ (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)"
+ )
+ timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
+
+ for i, t in enumerate(self.progress_bar(timesteps)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ timestep = np.array([t], dtype=timestep_dtype)
+ noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)[
+ 0
+ ]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(
+ torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
+ ).prev_sample
+
+ latents = latents.numpy()
+
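+ # re-noise the original latents to the current timestep and composite them with the freshly denoised
+ # latents, so the kept (mask == 1) region stays consistent with the input image at every step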
+ init_latents_proper = self.scheduler.add_noise(
+ torch.from_numpy(init_latents_orig), torch.from_numpy(noise), torch.from_numpy(np.array([t]))
+ )
+
+ init_latents_proper = init_latents_proper.numpy()
+
+ latents = (init_latents_proper * mask) + (latents * (1 - mask))
+
+ # call the callback, if provided
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ latents = 1 / 0.18215 * latents
+ # image = self.vae_decoder(latent_sample=latents)[0]
+ # the half-precision VAE decoder gives odd results when the batch size is > 1, so decode one latent at a time
+ image = np.concatenate(
+ [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
+ )
+
+ image = np.clip(image / 2 + 0.5, 0, 1)
+ image = image.transpose((0, 2, 3, 1))
+
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(
+ self.numpy_to_pil(image), return_tensors="np"
+ ).pixel_values.astype(image.dtype)
+ # the safety_checker throws an error on batched inputs, so run it one image at a time
+ images, has_nsfw_concept = [], []
+ for i in range(image.shape[0]):
+ image_i, has_nsfw_concept_i = self.safety_checker(
+ clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
+ )
+ images.append(image_i)
+ has_nsfw_concept.append(has_nsfw_concept_i[0])
+ image = np.concatenate(images)
+ else:
+ has_nsfw_concept = None
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py b/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py
new file mode 100644
index 0000000000000000000000000000000000000000..56681391aeeba7d0146cc4f296e4ead20204c33e
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py
@@ -0,0 +1,391 @@
+from logging import getLogger
+from typing import Any, Callable, List, Optional, Union
+
+import numpy as np
+import PIL
+import torch
+
+from ...schedulers import DDPMScheduler
+from ..onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
+from ..pipeline_utils import ImagePipelineOutput
+from . import StableDiffusionUpscalePipeline
+
+
+logger = getLogger(__name__)
+
+
+NUM_LATENT_CHANNELS = 4
+NUM_UNET_INPUT_CHANNELS = 7
+
+ORT_TO_PT_TYPE = {
+ "float16": torch.float16,
+ "float32": torch.float32,
+}
+
+
+def preprocess(image):
+ if isinstance(image, torch.Tensor):
+ return image
+ elif isinstance(image, PIL.Image.Image):
+ image = [image]
+
+ if isinstance(image[0], PIL.Image.Image):
+ w, h = image[0].size
+ w, h = (x - x % 64 for x in (w, h)) # resize to integer multiple of 64
+
+ image = [np.array(i.resize((w, h)))[None, :] for i in image]
+ image = np.concatenate(image, axis=0)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image.transpose(0, 3, 1, 2)
+ image = 2.0 * image - 1.0
+ image = torch.from_numpy(image)
+ elif isinstance(image[0], torch.Tensor):
+ image = torch.cat(image, dim=0)
+
+ return image
+
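+# Shape sketch (editor's note): preprocess() maps a single 200x200 PIL image to a torch tensor of shape
+# (1, 3, 192, 192) scaled to [-1, 1] (each side is snapped down to a multiple of 64); lists of images are
+# stacked along the batch dimension and torch tensors are returned unchanged.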
+
+class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
+ def __init__(
+ self,
+ vae: OnnxRuntimeModel,
+ text_encoder: OnnxRuntimeModel,
+ tokenizer: Any,
+ unet: OnnxRuntimeModel,
+ low_res_scheduler: DDPMScheduler,
+ scheduler: Any,
+ max_noise_level: int = 350,
+ ):
+ super().__init__(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ low_res_scheduler=low_res_scheduler,
+ scheduler=scheduler,
+ safety_checker=None,
+ feature_extractor=None,
+ watermarker=None,
+ max_noise_level=max_noise_level,
+ )
+
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]],
+ num_inference_steps: int = 75,
+ guidance_scale: float = 9.0,
+ noise_level: int = 20,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[np.ndarray] = None,
+ negative_prompt_embeds: Optional[np.ndarray] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: Optional[int] = 1,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ image (`torch.FloatTensor`, `PIL.Image.Image`, or `List[PIL.Image.Image]`):
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
+ process.
+ num_inference_steps (`int`, *optional*, defaults to 75):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 9.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ noise_level (`int`, *optional*, defaults to 20):
+ The amount of noise added to the low-resolution conditioning `image` (via `low_res_scheduler`) before
+ it is concatenated with the latents; must not exceed the pipeline's `max_noise_level`.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A torch.Generator to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`np.ndarray`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`np.ndarray`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+
+ # 1. Check inputs
+ self.check_inputs(prompt, image, noise_level, callback_steps)
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_embeddings = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ )
+
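+ # derive a torch dtype from the text encoder output so the noise and latents created below use the same
+ # precision as the exported models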
+ latents_dtype = ORT_TO_PT_TYPE[str(text_embeddings.dtype)]
+
+ # 4. Preprocess image
+ image = preprocess(image)
+ image = image.cpu()
+
+ # 5. Set timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 6. Add noise to image
+ noise_level = torch.tensor([noise_level], dtype=torch.long, device=device)
+ noise = torch.randn(image.shape, generator=generator, device=device, dtype=latents_dtype)
+ image = self.low_res_scheduler.add_noise(image, noise, noise_level)
+
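+ # the noisy low-res image is concatenated to the latents as unet conditioning, so duplicate it to match
+ # the unconditional/text halves of the classifier free guidance batch and `num_images_per_prompt`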
+ batch_multiplier = 2 if do_classifier_free_guidance else 1
+ image = np.concatenate([image] * batch_multiplier * num_images_per_prompt)
+ noise_level = np.concatenate([noise_level] * image.shape[0])
+
+ # 7. Prepare latent variables
+ height, width = image.shape[2:]
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ NUM_LATENT_CHANNELS,
+ height,
+ width,
+ latents_dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 8. Check that sizes of image and latents match
+ num_channels_image = image.shape[1]
+ if NUM_LATENT_CHANNELS + num_channels_image != NUM_UNET_INPUT_CHANNELS:
+ raise ValueError(
+ "Incorrect configuration settings! The config of `pipeline.unet` expects"
+ f" {NUM_UNET_INPUT_CHANNELS} but received `num_channels_latents`: {NUM_LATENT_CHANNELS} +"
+ f" `num_channels_image`: {num_channels_image} "
+ f" = {NUM_LATENT_CHANNELS+num_channels_image}. Please verify the config of"
+ " `pipeline.unet` or your `image` input."
+ )
+
+ # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
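+ # the exported ONNX unet may declare its `timestep` input as float or int64, so read the expected type
+ # from the model's input signature instead of hard-coding it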
+ timestep_dtype = next(
+ (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)"
+ )
+ timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
+
+ # 10. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
+
+ # concat latents, mask, masked_image_latents in the channel dimension
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+ latent_model_input = np.concatenate([latent_model_input, image], axis=1)
+
+ # timestep to tensor
+ timestep = np.array([t], dtype=timestep_dtype)
+
+ # predict the noise residual; the upscaler unet consumes the low-resolution noise level via `class_labels`
+ noise_pred = self.unet(
+ sample=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=text_embeddings,
+ class_labels=noise_level.astype(np.int64),
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(
+ torch.from_numpy(noise_pred), t, latents, **extra_step_kwargs
+ ).prev_sample
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ # 11. Post-processing
+ image = self.decode_latents(latents.float())
+
+ # 12. Convert to PIL
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
+
+ def decode_latents(self, latents):
+ # undo the scaling factor applied to the latents at encoding time (hard-coded to 0.08333 here)
+ latents = 1 / 0.08333 * latents
+ image = self.vae(latent_sample=latents)[0]
+ image = np.clip(image / 2 + 0.5, 0, 1)
+ image = image.transpose((0, 2, 3, 1))
+ return image
+
+ def _encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ device,
+ num_images_per_prompt: Optional[int],
+ do_classifier_free_guidance: bool,
+ negative_prompt: Optional[str],
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ):
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ # no positional arguments to text_encoder
+ prompt_embeds = self.text_encoder(
+ input_ids=text_input_ids.int().to(device),
+ # attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ # duplicate text embeddings for each generation per prompt (prompt_embeds is a numpy array here,
+ # so repeat along the batch axis rather than using the torch-style repeat/reshape)
+ prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = text_input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ # if hasattr(uncond_input, "attention_mask"):
+ # attention_mask = uncond_input.attention_mask.to(device)
+ # else:
+ # attention_mask = None
+
+ uncond_embeddings = self.text_encoder(
+ input_ids=uncond_input.input_ids.int().to(device),
+ # attention_mask=attention_mask,
+ )
+ uncond_embeddings = uncond_embeddings[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt (numpy array, repeat along the batch axis)
+ uncond_embeddings = np.repeat(uncond_embeddings, num_images_per_prompt, axis=0)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = np.concatenate([uncond_embeddings, prompt_embeds])
+
+ return prompt_embeds
diff --git a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..54927049571cadd73dfc4e2135f48baa34d0011e
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -0,0 +1,732 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import warnings
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import torch
+from packaging import version
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from ...configuration_utils import FrozenDict
+from ...image_processor import VaeImageProcessor
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...schedulers import KarrasDiffusionSchedulers
+from ...utils import (
+ deprecate,
+ is_accelerate_available,
+ is_accelerate_version,
+ logging,
+ randn_tensor,
+ replace_example_docstring,
+)
+from ..pipeline_utils import DiffusionPipeline
+from . import StableDiffusionPipelineOutput
+from .safety_checker import StableDiffusionSafetyChecker
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import StableDiffusionPipeline
+
+ >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
+ >>> pipe = pipe.to("cuda")
+
+ >>> prompt = "a photo of an astronaut riding a horse on mars"
+ >>> image = pipe(prompt).images[0]
+ ```
+"""
+
+
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ """
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
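+
+ In the paper's notation (a sketch; `phi` stands for `guidance_rescale`): the guided prediction is rescaled as
+ `x_rescaled = x_cfg * std(x_pos) / std(x_cfg)` and the returned value is `phi * x_rescaled + (1 - phi) * x_cfg`,
+ which is what the element-wise operations below compute per sample.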
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
+
+
+class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ In addition the pipeline inherits the following loading methods:
+ - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+ - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+ - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
+
+ as well as the following saving methods:
+ - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+ )
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["clip_sample"] = False
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+ version.parse(unet.config._diffusers_version).base_version
+ ) < version.parse("0.9.0.dev0")
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+ deprecation_message = (
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+ " the `unet/config.json` file"
+ )
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(unet.config)
+ new_config["sample_size"] = 64
+ unet._internal_dict = FrozenDict(new_config)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
+ steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
+ several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
+ """
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
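+
+ A minimal usage sketch (assuming `accelerate>=0.17.0` is installed and a CUDA device is available):
+
+ ```py
+ >>> import torch
+ >>> from diffusers import StableDiffusionPipeline
+
+ >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
+ >>> pipe.enable_model_cpu_offload()
+ >>> image = pipe("a photo of an astronaut riding a horse on mars").images[0]
+ ```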
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ hook = None
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+ if self.safety_checker is not None:
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ def decode_latents(self, latents):
+ warnings.warn(
+ "The decode_latents method is deprecated and will be removed in a future version. Please"
+ " use VaeImageProcessor instead",
+ FutureWarning,
+ )
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
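+ # e.g. a 512x512 request with the usual `vae_scale_factor` of 8 and 4 latent channels gives a latent
+ # shape of (batch_size, 4, 64, 64)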
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16. of
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ if output_type != "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py
new file mode 100644
index 0000000000000000000000000000000000000000..15a5d7eb13624524d944de6d45aa92a208d70ac0
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py
@@ -0,0 +1,1032 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import math
+import warnings
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from torch.nn import functional as F
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from ...image_processor import VaeImageProcessor
+from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...models.attention_processor import Attention
+from ...schedulers import KarrasDiffusionSchedulers
+from ...utils import logging, randn_tensor, replace_example_docstring
+from ..pipeline_utils import DiffusionPipeline
+from . import StableDiffusionPipelineOutput
+from .safety_checker import StableDiffusionSafetyChecker
+
+
+logger = logging.get_logger(__name__)
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import StableDiffusionAttendAndExcitePipeline
+
+ >>> pipe = StableDiffusionAttendAndExcitePipeline.from_pretrained(
+ ... "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
+ ... ).to("cuda")
+
+
+ >>> prompt = "a cat and a frog"
+
+ >>> # use get_indices function to find out indices of the tokens you want to alter
+ >>> pipe.get_indices(prompt)
+ {0: '<|startoftext|>', 1: 'a', 2: 'cat', 3: 'and', 4: 'a', 5: 'frog', 6: '<|endoftext|>'}
+
+ >>> token_indices = [2, 5]
+ >>> seed = 6141
+ >>> generator = torch.Generator("cuda").manual_seed(seed)
+
+ >>> images = pipe(
+ ... prompt=prompt,
+ ... token_indices=token_indices,
+ ... guidance_scale=7.5,
+ ... generator=generator,
+ ... num_inference_steps=50,
+ ... max_iter_to_alter=25,
+ ... ).images
+
+ >>> image = images[0]
+ >>> image.save(f"../images/{prompt}_{seed}.png")
+ ```
+"""
+
+
+class AttentionStore:
+ @staticmethod
+ def get_empty_store():
+ return {"down": [], "mid": [], "up": []}
+
+ def __call__(self, attn, is_cross: bool, place_in_unet: str):
+ if self.cur_att_layer >= 0 and is_cross:
+ if attn.shape[1] == np.prod(self.attn_res):
+ self.step_store[place_in_unet].append(attn)
+
+ self.cur_att_layer += 1
+ if self.cur_att_layer == self.num_att_layers:
+ self.cur_att_layer = 0
+ self.between_steps()
+
+ def between_steps(self):
+ self.attention_store = self.step_store
+ self.step_store = self.get_empty_store()
+
+ def get_average_attention(self):
+ average_attention = self.attention_store
+ return average_attention
+
+ def aggregate_attention(self, from_where: List[str]) -> torch.Tensor:
+ """Aggregates the attention across the different layers and heads at the specified resolution."""
+ out = []
+ attention_maps = self.get_average_attention()
+ for location in from_where:
+ for item in attention_maps[location]:
+ cross_maps = item.reshape(-1, self.attn_res[0], self.attn_res[1], item.shape[-1])
+ out.append(cross_maps)
+ out = torch.cat(out, dim=0)
+ out = out.sum(0) / out.shape[0]
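+ # the averaged map has shape (attn_res[0], attn_res[1], num_text_tokens)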
+ return out
+
+ def reset(self):
+ self.cur_att_layer = 0
+ self.step_store = self.get_empty_store()
+ self.attention_store = {}
+
+ def __init__(self, attn_res):
+ """
+ Initialize an empty AttentionStore :param step_index: used to visualize only a specific step in the diffusion
+ process
+ """
+ self.num_att_layers = -1
+ self.cur_att_layer = 0
+ self.step_store = self.get_empty_store()
+ self.attention_store = {}
+ self.curr_step_index = 0
+ self.attn_res = attn_res
+
+
+class AttendExciteAttnProcessor:
+ def __init__(self, attnstore, place_in_unet):
+ super().__init__()
+ self.attnstore = attnstore
+ self.place_in_unet = place_in_unet
+
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+ batch_size, sequence_length, _ = hidden_states.shape
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ query = attn.to_q(hidden_states)
+
+ is_cross = encoder_hidden_states is not None
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ query = attn.head_to_batch_dim(query)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
+
+ # only need to store attention maps during the Attend and Excite process
+ if attention_probs.requires_grad:
+ self.attnstore(attention_probs, is_cross, self.place_in_unet)
+
+ hidden_states = torch.bmm(attention_probs, value)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ return hidden_states
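+
+ # A hedged sketch (names illustrative, not part of this file) of how the processor and store are usually
+ # wired together before sampling: every attention module of the unet gets its own processor instance that
+ # reports to one shared AttentionStore, and the averaged cross-attention map for a token can then be read
+ # back after a denoising step, e.g.
+ #
+ # store = AttentionStore(attn_res=(16, 16))
+ # unet.set_attn_processor(
+ # {
+ # name: AttendExciteAttnProcessor(attnstore=store, place_in_unet=name.split("_")[0])
+ # for name in unet.attn_processors
+ # }
+ # )
+ # maps = store.aggregate_attention(from_where=("up", "down", "mid")) # (res, res, num_tokens)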
+
+
+class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversionLoaderMixin):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion and Attend and Excite.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
+ steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
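+
+ # Layout note (illustrative, not part of the original code): with classifier-free guidance enabled,
+ # the returned tensor stacks the unconditional embeddings first, i.e. rows
+ # [uncond_0, ..., uncond_{B*N-1}, cond_0, ..., cond_{B*N-1}] for a prompt batch of size B and N images
+ # per prompt. For example, assuming the SD 1.x CLIP text encoder, batch_size=2 and
+ # num_images_per_prompt=1 give prompt_embeds.shape == (4, 77, 768).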
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ def decode_latents(self, latents):
+ warnings.warn(
+ "The decode_latents method is deprecated and will be removed in a future version. Please"
+ " use VaeImageProcessor instead",
+ FutureWarning,
+ )
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ indices,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ indices_is_list_ints = isinstance(indices, list) and isinstance(indices[0], int)
+ indices_is_list_list_ints = (
+ isinstance(indices, list) and isinstance(indices[0], list) and isinstance(indices[0][0], int)
+ )
+
+ if not indices_is_list_ints and not indices_is_list_list_ints:
+ raise TypeError("`indices` must be a list of ints or a list of a list of ints")
+
+ if indices_is_list_ints:
+ indices_batch_size = 1
+ elif indices_is_list_list_ints:
+ indices_batch_size = len(indices)
+
+ if prompt is not None and isinstance(prompt, str):
+ prompt_batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ prompt_batch_size = len(prompt)
+ elif prompt_embeds is not None:
+ prompt_batch_size = prompt_embeds.shape[0]
+
+ if indices_batch_size != prompt_batch_size:
+ raise ValueError(
+ f"indices batch size must be same as prompt batch size. indices batch size: {indices_batch_size}, prompt batch size: {prompt_batch_size}"
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ @staticmethod
+ def _compute_max_attention_per_index(
+ attention_maps: torch.Tensor,
+ indices: List[int],
+ ) -> List[torch.Tensor]:
+ """Computes the maximum attention value for each of the tokens we wish to alter."""
+ attention_for_text = attention_maps[:, :, 1:-1]
+ attention_for_text *= 100
+ attention_for_text = torch.nn.functional.softmax(attention_for_text, dim=-1)
+
+ # Shift indices since we removed the first token
+ indices = [index - 1 for index in indices]
+
+ # Extract the maximum values
+ max_indices_list = []
+ for i in indices:
+ image = attention_for_text[:, :, i]
+ smoothing = GaussianSmoothing().to(attention_maps.device)
+ input = F.pad(image.unsqueeze(0).unsqueeze(0), (1, 1, 1, 1), mode="reflect")
+ image = smoothing(input).squeeze(0).squeeze(0)
+ max_indices_list.append(image.max())
+ return max_indices_list
+
+ def _aggregate_and_get_max_attention_per_token(
+ self,
+ indices: List[int],
+ ):
+ """Aggregates the attention for each token and computes the max activation value for each token to alter."""
+ attention_maps = self.attention_store.aggregate_attention(
+ from_where=("up", "down", "mid"),
+ )
+ max_attention_per_index = self._compute_max_attention_per_index(
+ attention_maps=attention_maps,
+ indices=indices,
+ )
+ return max_attention_per_index
+
+ @staticmethod
+ def _compute_loss(max_attention_per_index: List[torch.Tensor]) -> torch.Tensor:
+ """Computes the attend-and-excite loss using the maximum attention value for each token."""
+ losses = [max(0, 1.0 - curr_max) for curr_max in max_attention_per_index]
+ loss = max(losses)
+ return loss
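+
+ # Worked example (illustrative): if the per-token maxima are [0.2, 0.7], the per-token losses are
+ # [max(0, 1 - 0.2), max(0, 1 - 0.7)] = [0.8, 0.3] and the returned loss is 0.8, so the loss is driven
+ # by the most-neglected token.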
+
+ @staticmethod
+ def _update_latent(latents: torch.Tensor, loss: torch.Tensor, step_size: float) -> torch.Tensor:
+ """Update the latent according to the computed loss."""
+ grad_cond = torch.autograd.grad(loss.requires_grad_(True), [latents], retain_graph=True)[0]
+ latents = latents - step_size * grad_cond
+ return latents
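+
+ # This is a single gradient-descent step on the latent itself: latents <- latents - step_size * dL/dlatents,
+ # with the step size taken from the `scale_factor * np.sqrt(scale_range)` schedule built in `__call__` below.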
+
+ def _perform_iterative_refinement_step(
+ self,
+ latents: torch.Tensor,
+ indices: List[int],
+ loss: torch.Tensor,
+ threshold: float,
+ text_embeddings: torch.Tensor,
+ step_size: float,
+ t: int,
+ max_refinement_steps: int = 20,
+ ):
+ """
+ Performs the iterative latent refinement introduced in the paper. Here, we continuously update the latent code
+ according to our loss objective until the given threshold is reached for all tokens.
+ """
+ iteration = 0
+ target_loss = max(0, 1.0 - threshold)
+ while loss > target_loss:
+ iteration += 1
+
+ latents = latents.clone().detach().requires_grad_(True)
+ self.unet(latents, t, encoder_hidden_states=text_embeddings).sample
+ self.unet.zero_grad()
+
+ # Get max activation value for each subject token
+ max_attention_per_index = self._aggregate_and_get_max_attention_per_token(
+ indices=indices,
+ )
+
+ loss = self._compute_loss(max_attention_per_index)
+
+ if loss != 0:
+ latents = self._update_latent(latents, loss, step_size)
+
+ logger.info(f"\t Try {iteration}. loss: {loss}")
+
+ if iteration >= max_refinement_steps:
+ logger.info(f"\t Exceeded max number of iterations ({max_refinement_steps})! ")
+ break
+
+ # Run one more time but don't compute gradients and update the latents.
+ # We just need to compute the new loss - the grad update will occur below
+ latents = latents.clone().detach().requires_grad_(True)
+ _ = self.unet(latents, t, encoder_hidden_states=text_embeddings).sample
+ self.unet.zero_grad()
+
+ # Get max activation value for each subject token
+ max_attention_per_index = self._aggregate_and_get_max_attention_per_token(
+ indices=indices,
+ )
+ loss = self._compute_loss(max_attention_per_index)
+ logger.info(f"\t Finished with loss of: {loss}")
+ return loss, latents, max_attention_per_index
+
+ def register_attention_control(self):
+ attn_procs = {}
+ cross_att_count = 0
+ for name in self.unet.attn_processors.keys():
+ if name.startswith("mid_block"):
+ place_in_unet = "mid"
+ elif name.startswith("up_blocks"):
+ place_in_unet = "up"
+ elif name.startswith("down_blocks"):
+ place_in_unet = "down"
+ else:
+ continue
+
+ cross_att_count += 1
+ attn_procs[name] = AttendExciteAttnProcessor(attnstore=self.attention_store, place_in_unet=place_in_unet)
+
+ self.unet.set_attn_processor(attn_procs)
+ self.attention_store.num_att_layers = cross_att_count
+
+ def get_indices(self, prompt: str) -> Dict[str, int]:
+ """Utility function to list the indices of the tokens you wish to alte"""
+ ids = self.tokenizer(prompt).input_ids
+ indices = {i: tok for tok, i in zip(self.tokenizer.convert_ids_to_tokens(ids), range(len(ids)))}
+ return indices
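+
+ # Usage sketch (illustrative; assumes a loaded pipeline instance `pipe`, and the exact token strings
+ # depend on the CLIP tokenizer):
+ #   >>> pipe.get_indices("a cat and a frog")
+ #   {0: '<|startoftext|>', 1: 'a</w>', 2: 'cat</w>', 3: 'and</w>', 4: 'a</w>', 5: 'frog</w>', 6: '<|endoftext|>'}
+ # The returned positions can then be passed as `token_indices` (e.g. [2, 5]) to `__call__`.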
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ token_indices: Union[List[int], List[List[int]]],
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: int = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ max_iter_to_alter: int = 25,
+ thresholds: dict = {0: 0.05, 10: 0.5, 20: 0.8},
+ scale_factor: int = 20,
+ attn_res: Optional[Tuple[int]] = (16, 16),
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ token_indices (`List[int]`):
+ The token indices to alter with attend-and-excite.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+ max_iter_to_alter (`int`, *optional*, defaults to `25`):
+ Number of denoising steps to apply attend-and-excite. Attend-and-excite is only applied during the
+ first denoising steps. For example, if `max_iter_to_alter` is `25` and there are a total of `30`
+ denoising steps, the first `25` denoising steps will apply attend-and-excite and the last `5` will
+ not.
+ thresholds (`dict`, *optional*, defaults to `{0: 0.05, 10: 0.5, 20: 0.8}`):
+ Dictionary defining the iterations and desired thresholds to apply iterative latent refinement in.
+ scale_factor (`int`, *optional*, defaults to 20):
+ Scale factor that controls the step size of each Attend and Excite update.
+ attn_res (`tuple`, *optional*, defaults to `(16, 16)`):
+ The 2D resolution of the semantic attention map. If set to `None`, it is computed from `width` and
+ `height`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ token_indices,
+ height,
+ width,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ if attn_res is None:
+ attn_res = int(np.ceil(width / 32)), int(np.ceil(height / 32))
+ self.attention_store = AttentionStore(attn_res)
+ self.register_attention_control()
+
+ # default config for step size from original repo
+ scale_range = np.linspace(1.0, 0.5, len(self.scheduler.timesteps))
+ step_size = scale_factor * np.sqrt(scale_range)
+
+ text_embeddings = (
+ prompt_embeds[batch_size * num_images_per_prompt :] if do_classifier_free_guidance else prompt_embeds
+ )
+
+ if isinstance(token_indices[0], int):
+ token_indices = [token_indices]
+
+ indices = []
+
+ for ind in token_indices:
+ indices = indices + [ind] * num_images_per_prompt
+
+ # 7. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # Attend and excite process
+ with torch.enable_grad():
+ latents = latents.clone().detach().requires_grad_(True)
+ updated_latents = []
+ for latent, index, text_embedding in zip(latents, indices, text_embeddings):
+ # Forward pass of denoising with text conditioning
+ latent = latent.unsqueeze(0)
+ text_embedding = text_embedding.unsqueeze(0)
+
+ self.unet(
+ latent,
+ t,
+ encoder_hidden_states=text_embedding,
+ cross_attention_kwargs=cross_attention_kwargs,
+ ).sample
+ self.unet.zero_grad()
+
+ # Get max activation value for each subject token
+ max_attention_per_index = self._aggregate_and_get_max_attention_per_token(
+ indices=index,
+ )
+
+ loss = self._compute_loss(max_attention_per_index=max_attention_per_index)
+
+ # If this is an iterative refinement step, verify we have reached the desired threshold for all
+ if i in thresholds.keys() and loss > 1.0 - thresholds[i]:
+ loss, latent, max_attention_per_index = self._perform_iterative_refinement_step(
+ latents=latent,
+ indices=index,
+ loss=loss,
+ threshold=thresholds[i],
+ text_embeddings=text_embedding,
+ step_size=step_size[i],
+ t=t,
+ )
+
+ # Perform gradient update
+ if i < max_iter_to_alter:
+ if loss != 0:
+ latent = self._update_latent(
+ latents=latent,
+ loss=loss,
+ step_size=step_size[i],
+ )
+ logger.info(f"Iteration {i} | Loss: {loss:0.4f}")
+
+ updated_latents.append(latent)
+
+ latents = torch.cat(updated_latents, dim=0)
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ ).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ # 8. Post-processing
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+
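+ # Usage sketch for the pipeline above (illustrative; the class and checkpoint names below are
+ # assumptions, e.g. the Attend-and-Excite pipeline exported by diffusers and a Stable Diffusion 1.x
+ # checkpoint):
+ #   >>> import torch
+ #   >>> from diffusers import StableDiffusionAttendAndExcitePipeline
+ #   >>> pipe = StableDiffusionAttendAndExcitePipeline.from_pretrained(
+ #   ...     "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
+ #   ... ).to("cuda")
+ #   >>> prompt = "a cat and a frog"
+ #   >>> image = pipe(prompt=prompt, token_indices=[2, 5], guidance_scale=7.5).images[0]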
+
+class GaussianSmoothing(torch.nn.Module):
+ """
+ Arguments:
+ Apply gaussian smoothing on a 1d, 2d or 3d tensor. Filtering is performed seperately for each channel in the input
+ using a depthwise convolution.
+ channels (int, sequence): Number of channels of the input tensors. Output will
+ have this number of channels as well.
+ kernel_size (int, sequence): Size of the gaussian kernel. sigma (float, sequence): Standard deviation of the
+ gaussian kernel. dim (int, optional): The number of dimensions of the data.
+ Default value is 2 (spatial).
+ """
+
+ # channels=1, kernel_size=kernel_size, sigma=sigma, dim=2
+ def __init__(
+ self,
+ channels: int = 1,
+ kernel_size: int = 3,
+ sigma: float = 0.5,
+ dim: int = 2,
+ ):
+ super().__init__()
+
+ if isinstance(kernel_size, int):
+ kernel_size = [kernel_size] * dim
+ if isinstance(sigma, float):
+ sigma = [sigma] * dim
+
+ # The gaussian kernel is the product of the
+ # gaussian function of each dimension.
+ kernel = 1
+ meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size])
+ for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
+ mean = (size - 1) / 2
+ kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-(((mgrid - mean) / (2 * std)) ** 2))
+
+ # Make sure sum of values in gaussian kernel equals 1.
+ kernel = kernel / torch.sum(kernel)
+
+ # Reshape to depthwise convolutional weight
+ kernel = kernel.view(1, 1, *kernel.size())
+ kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
+
+ self.register_buffer("weight", kernel)
+ self.groups = channels
+
+ if dim == 1:
+ self.conv = F.conv1d
+ elif dim == 2:
+ self.conv = F.conv2d
+ elif dim == 3:
+ self.conv = F.conv3d
+ else:
+ raise RuntimeError("Only 1, 2 and 3 dimensions are supported. Received {}.".format(dim))
+
+ def forward(self, input):
+ """
+ Arguments:
+ Apply gaussian filter to input.
+ input (torch.Tensor): Input to apply gaussian filter on.
+ Returns:
+ filtered (torch.Tensor): Filtered output.
+ """
+ return self.conv(input, weight=self.weight.to(input.dtype), groups=self.groups)
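+
+ # Usage sketch (illustrative): smooth a 16x16 attention map the same way `_compute_max_attention_per_index`
+ # does above, i.e. reflect-pad by 1 pixel so the 3x3 kernel preserves the spatial size:
+ #   >>> smoothing = GaussianSmoothing().to(attention_map.device)
+ #   >>> padded = F.pad(attention_map[None, None], (1, 1, 1, 1), mode="reflect")
+ #   >>> smoothed = smoothing(padded)[0, 0]  # back to shape (16, 16)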
diff --git a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7555e2ebad4c7f6045f3975b61f271a97ec8587
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py
@@ -0,0 +1,28 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# NOTE: This file is deprecated and will be removed in a future version.
+ # It only exists so that `from diffusers.pipelines import DiffusionPipeline` temporarily keeps working
+from ...utils import deprecate
+from ..controlnet.multicontrolnet import MultiControlNetModel # noqa: F401
+from ..controlnet.pipeline_controlnet import StableDiffusionControlNetPipeline # noqa: F401
+
+
+deprecate(
+ "stable diffusion controlnet",
+ "0.22.0",
+ "Importing `StableDiffusionControlNetPipeline` or `MultiControlNetModel` from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet is deprecated. Please import `from diffusers import StableDiffusionControlNetPipeline` instead.",
+ standard_warn=False,
+ stacklevel=3,
+)
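+
+ # As the deprecation message above states, the supported import path is the top-level one, e.g.:
+ #   >>> from diffusers import StableDiffusionControlNetPipeline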
diff --git a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py
new file mode 100644
index 0000000000000000000000000000000000000000..cae0f3a347de6248669a8508644c6ce7fe98d441
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py
@@ -0,0 +1,727 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import contextlib
+import inspect
+import warnings
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import PIL
+import torch
+from packaging import version
+from transformers import CLIPTextModel, CLIPTokenizer, DPTFeatureExtractor, DPTForDepthEstimation
+
+from ...configuration_utils import FrozenDict
+from ...image_processor import VaeImageProcessor
+from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...schedulers import KarrasDiffusionSchedulers
+from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
+def preprocess(image):
+ warnings.warn(
+ "The preprocess method is deprecated and will be removed in a future version. Please"
+ " use VaeImageProcessor.preprocess instead",
+ FutureWarning,
+ )
+ if isinstance(image, torch.Tensor):
+ return image
+ elif isinstance(image, PIL.Image.Image):
+ image = [image]
+
+ if isinstance(image[0], PIL.Image.Image):
+ w, h = image[0].size
+ w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
+
+ image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
+ image = np.concatenate(image, axis=0)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image.transpose(0, 3, 1, 2)
+ image = 2.0 * image - 1.0
+ image = torch.from_numpy(image)
+ elif isinstance(image[0], torch.Tensor):
+ image = torch.cat(image, dim=0)
+ return image
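+
+ # Worked example (illustrative): a 513x769 (width x height) PIL image is resized to 512x768 (the nearest
+ # multiples of 8), converted to a tensor of shape (1, 3, 768, 512), and rescaled from [0, 1] to [-1, 1].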
+
+
+class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
+ r"""
+ Pipeline for text-guided image to image generation using Stable Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ In addition the pipeline inherits the following loading methods:
+ - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+ - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+
+ as well as the following saving methods:
+ - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ depth_estimator: DPTForDepthEstimation,
+ feature_extractor: DPTFeatureExtractor,
+ ):
+ super().__init__()
+
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+ version.parse(unet.config._diffusers_version).base_version
+ ) < version.parse("0.9.0.dev0")
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+ deprecation_message = (
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
+ " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+ " the `unet/config.json` file"
+ )
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(unet.config)
+ new_config["sample_size"] = 64
+ unet._internal_dict = FrozenDict(new_config)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ depth_estimator=depth_estimator,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ def decode_latents(self, latents):
+ warnings.warn(
+ "The decode_latents method is deprecated and will be removed in a future version. Please"
+ " use VaeImageProcessor instead",
+ FutureWarning,
+ )
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs
+ def check_inputs(
+ self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
+ ):
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep, 0)
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+
+ return timesteps, num_inference_steps - t_start
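+
+ # Worked example (illustrative): with num_inference_steps=50 and strength=0.8,
+ # init_timestep = min(int(50 * 0.8), 50) = 40 and t_start = 50 - 40 = 10, so the pipeline skips the first
+ # 10 scheduler timesteps (scaled by `self.scheduler.order`) and runs the remaining 40 denoising steps.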
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.prepare_latents
+ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
+ raise ValueError(
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
+ )
+
+ image = image.to(device=device, dtype=dtype)
+
+ batch_size = batch_size * num_images_per_prompt
+
+ if image.shape[1] == 4:
+ init_latents = image
+
+ else:
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ elif isinstance(generator, list):
+ init_latents = [
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
+ ]
+ init_latents = torch.cat(init_latents, dim=0)
+ else:
+ init_latents = self.vae.encode(image).latent_dist.sample(generator)
+
+ init_latents = self.vae.config.scaling_factor * init_latents
+
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ deprecation_message = (
+ f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
+ " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
+ " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
+ " your script to pass as many initial images as text prompts to suppress this warning."
+ )
+ deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
+ additional_image_per_prompt = batch_size // init_latents.shape[0]
+ init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ init_latents = torch.cat([init_latents], dim=0)
+
+ shape = init_latents.shape
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+ # get latents
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+ latents = init_latents
+
+ return latents
+
+ def prepare_depth_map(self, image, depth_map, batch_size, do_classifier_free_guidance, dtype, device):
+ if isinstance(image, PIL.Image.Image):
+ image = [image]
+ else:
+ image = list(image)
+
+ if isinstance(image[0], PIL.Image.Image):
+ width, height = image[0].size
+ elif isinstance(image[0], np.ndarray):
+ width, height = image[0].shape[:-1]
+ else:
+ height, width = image[0].shape[-2:]
+
+ if depth_map is None:
+ pixel_values = self.feature_extractor(images=image, return_tensors="pt").pixel_values
+ pixel_values = pixel_values.to(device=device)
+ # The DPT-Hybrid model uses batch-norm layers which are not compatible with fp16.
+ # So we use `torch.autocast` here for half precision inference.
+ context_manager = torch.autocast("cuda", dtype=dtype) if device.type == "cuda" else contextlib.nullcontext()
+ with context_manager:
+ depth_map = self.depth_estimator(pixel_values).predicted_depth
+ else:
+ depth_map = depth_map.to(device=device, dtype=dtype)
+
+ depth_map = torch.nn.functional.interpolate(
+ depth_map.unsqueeze(1),
+ size=(height // self.vae_scale_factor, width // self.vae_scale_factor),
+ mode="bicubic",
+ align_corners=False,
+ )
+
+ depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
+ depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
+ depth_map = 2.0 * (depth_map - depth_min) / (depth_max - depth_min) - 1.0
+ depth_map = depth_map.to(dtype)
+
+ # duplicate depth_map for each generation per prompt, using mps friendly method
+ if depth_map.shape[0] < batch_size:
+ repeat_by = batch_size // depth_map.shape[0]
+ depth_map = depth_map.repeat(repeat_by, 1, 1, 1)
+
+ depth_map = torch.cat([depth_map] * 2) if do_classifier_free_guidance else depth_map
+ return depth_map
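+
+ # Worked example (illustrative): the min/max normalization above maps each depth map linearly into [-1, 1];
+ # e.g. for predicted depths spanning [0.0, 10.0], a value of 2.5 becomes 2.0 * 2.5 / 10.0 - 1.0 = -0.5.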
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ image: Union[
+ torch.FloatTensor,
+ PIL.Image.Image,
+ np.ndarray,
+ List[torch.FloatTensor],
+ List[PIL.Image.Image],
+ List[np.ndarray],
+ ] = None,
+ depth_map: Optional[torch.FloatTensor] = None,
+ strength: float = 0.8,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: Optional[float] = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
+ process. Can accept image latents as `image` only if `depth_map` is not `None`.
+ depth_map (`torch.FloatTensor`, *optional*):
+ depth prediction that will be used as additional conditioning for the image generation process. If not
+ defined, it will automatically predict the depth via `self.depth_estimator`.
+ strength (`float`, *optional*, defaults to 0.8):
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
+ will be used as a starting point, adding more noise to it the larger the `strength`. The number of
+ denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
+ be maximum and the denoising process will run for the full number of iterations specified in
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference. This parameter will be modulated by `strength`.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
+ is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+
+ Examples:
+
+ ```py
+ >>> import torch
+ >>> import requests
+ >>> from PIL import Image
+
+ >>> from diffusers import StableDiffusionDepth2ImgPipeline
+
+ >>> pipe = StableDiffusionDepth2ImgPipeline.from_pretrained(
+ ... "stabilityai/stable-diffusion-2-depth",
+ ... torch_dtype=torch.float16,
+ ... )
+ >>> pipe.to("cuda")
+
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> init_image = Image.open(requests.get(url, stream=True).raw)
+ >>> prompt = "two tigers"
+ >>> n_prompt = "bad, deformed, ugly, bad anatomy"
+ >>> image = pipe(prompt=prompt, image=init_image, negative_prompt=n_prompt, strength=0.7).images[0]
+ ```
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ # 1. Check inputs
+ self.check_inputs(
+ prompt,
+ strength,
+ callback_steps,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ )
+
+ if image is None:
+ raise ValueError("`image` input cannot be undefined.")
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+
+ # 4. Prepare depth mask
+ depth_mask = self.prepare_depth_map(
+ image,
+ depth_map,
+ batch_size * num_images_per_prompt,
+ do_classifier_free_guidance,
+ prompt_embeds.dtype,
+ device,
+ )
+
+ # 5. Preprocess image
+ image = self.image_processor.preprocess(image)
+
+ # 6. Set timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+ # 7. Prepare latent variables
+ latents = self.prepare_latents(
+ image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
+ )
+
+ # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 9. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
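+ # concatenate the depth mask channel-wise so the UNet receives the depth map as extra conditioning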
+ latent_model_input = torch.cat([latent_model_input, depth_mask], dim=1)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ if output_type != "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ else:
+ image = latents
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
diff --git a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2d5953808f173fa9addf3206d958b2f7ec5c056
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py
@@ -0,0 +1,1532 @@
+# Copyright 2023 DiffEdit Authors and Pix2Pix Zero Authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import warnings
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import PIL
+import torch
+from packaging import version
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from ...configuration_utils import FrozenDict
+from ...image_processor import VaeImageProcessor
+from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...schedulers import DDIMInverseScheduler, KarrasDiffusionSchedulers
+from ...utils import (
+ PIL_INTERPOLATION,
+ BaseOutput,
+ deprecate,
+ is_accelerate_available,
+ is_accelerate_version,
+ logging,
+ randn_tensor,
+ replace_example_docstring,
+)
+from ..pipeline_utils import DiffusionPipeline
+from . import StableDiffusionPipelineOutput
+from .safety_checker import StableDiffusionSafetyChecker
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class DiffEditInversionPipelineOutput(BaseOutput):
+ """
+ Output class for the latent inversion step of [`StableDiffusionDiffEditPipeline`].
+
+ Args:
+ latents (`torch.FloatTensor`)
+ The inverted latents tensor.
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
+ List of denoised PIL images of length `num_timesteps * batch_size` or numpy array of shape `(num_timesteps,
+ batch_size, height, width, num_channels)`. PIL images or numpy array represent the denoised images of the
+ diffusion pipeline.
+ """
+
+ latents: torch.FloatTensor
+ images: Union[List[PIL.Image.Image], np.ndarray]
+
+
+EXAMPLE_DOC_STRING = """
+
+ ```py
+ >>> import PIL
+ >>> import requests
+ >>> import torch
+ >>> from io import BytesIO
+
+ >>> from diffusers import DDIMScheduler, DDIMInverseScheduler, StableDiffusionDiffEditPipeline
+
+
+ >>> def download_image(url):
+ ... response = requests.get(url)
+ ... return PIL.Image.open(BytesIO(response.content)).convert("RGB")
+
+
+ >>> img_url = "https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png"
+
+ >>> init_image = download_image(img_url).resize((768, 768))
+
+ >>> pipe = StableDiffusionDiffEditPipeline.from_pretrained(
+ ... "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16
+ ... )
+ >>> pipe = pipe.to("cuda")
+
+ >>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+ >>> pipe.inverse_scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config)
+ >>> pipe.enable_model_cpu_offload()
+
+ >>> mask_prompt = "A bowl of fruits"
+ >>> prompt = "A bowl of pears"
+
+ >>> mask_image = pipe.generate_mask(image=init_image, source_prompt=prompt, target_prompt=mask_prompt)
+ >>> image_latents = pipe.invert(image=init_image, prompt=mask_prompt).latents
+ >>> image = pipe(prompt=prompt, mask_image=mask_image, image_latents=image_latents).images[0]
+ ```
+"""
+
+EXAMPLE_INVERT_DOC_STRING = """
+ ```py
+ >>> import PIL
+ >>> import requests
+ >>> import torch
+ >>> from io import BytesIO
+
+ >>> from diffusers import DDIMScheduler, DDIMInverseScheduler, StableDiffusionDiffEditPipeline
+
+
+ >>> def download_image(url):
+ ... response = requests.get(url)
+ ... return PIL.Image.open(BytesIO(response.content)).convert("RGB")
+
+
+ >>> img_url = "https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png"
+
+ >>> init_image = download_image(img_url).resize((768, 768))
+
+ >>> pipe = StableDiffusionDiffEditPipeline.from_pretrained(
+ ... "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16
+ ... )
+ >>> pipe = pipe.to("cuda")
+
+ >>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+ >>> pipe.inverse_scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config)
+ >>> pipe.enable_model_cpu_offload()
+
+ >>> prompt = "A bowl of fruits"
+
+ >>> inverted_latents = pipe.invert(image=init_image, prompt=prompt).latents
+ ```
+"""
+
+
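+# Auto-correlation regularizer borrowed from the Pix2Pix Zero pipeline (used by the regularization block
+# in `invert` below): it penalizes correlation between the predicted noise and spatially rolled copies of
+# itself across several average-pooled scales, pushing the prediction toward spatially uncorrelated noise.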
+def auto_corr_loss(hidden_states, generator=None):
+ reg_loss = 0.0
+ for i in range(hidden_states.shape[0]):
+ for j in range(hidden_states.shape[1]):
+ noise = hidden_states[i : i + 1, j : j + 1, :, :]
+ while True:
+ roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item()
+ reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2
+ reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2
+
+ if noise.shape[2] <= 8:
+ break
+ noise = torch.nn.functional.avg_pool2d(noise, kernel_size=2)
+ return reg_loss
+
+
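+# Twice the KL divergence between a Gaussian with the sample's mean and variance and the standard normal
+# N(0, 1) (up to the small epsilon added for numerical stability); used in `invert` to pull the predicted
+# noise toward unit-Gaussian statistics.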
+def kl_divergence(hidden_states):
+ return hidden_states.var() + hidden_states.mean() ** 2 - 1 - torch.log(hidden_states.var() + 1e-7)
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
+def preprocess(image):
+ warnings.warn(
+ "The preprocess method is deprecated and will be removed in a future version. Please"
+ " use VaeImageProcessor.preprocess instead",
+ FutureWarning,
+ )
+ if isinstance(image, torch.Tensor):
+ return image
+ elif isinstance(image, PIL.Image.Image):
+ image = [image]
+
+ if isinstance(image[0], PIL.Image.Image):
+ w, h = image[0].size
+ w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
+
+ image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
+ image = np.concatenate(image, axis=0)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image.transpose(0, 3, 1, 2)
+ image = 2.0 * image - 1.0
+ image = torch.from_numpy(image)
+ elif isinstance(image[0], torch.Tensor):
+ image = torch.cat(image, dim=0)
+ return image
+
+
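+# Normalizes `mask` (a PIL image, numpy array, tensor, or a list of those) into a binarized float tensor
+# of shape (batch_size, 1, H, W). Illustrative sketch, assuming `pil_mask` is a hypothetical 512x512
+# single-channel PIL image:
+#   >>> preprocess_mask(pil_mask, batch_size=2).shape
+#   torch.Size([2, 1, 512, 512])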
+def preprocess_mask(mask, batch_size: int = 1):
+ if not isinstance(mask, torch.Tensor):
+ # preprocess mask
+ if isinstance(mask, PIL.Image.Image) or isinstance(mask, np.ndarray):
+ mask = [mask]
+
+ if isinstance(mask, list):
+ if isinstance(mask[0], PIL.Image.Image):
+ mask = [np.array(m.convert("L")).astype(np.float32) / 255.0 for m in mask]
+ if isinstance(mask[0], np.ndarray):
+ mask = np.stack(mask, axis=0) if mask[0].ndim < 3 else np.concatenate(mask, axis=0)
+ mask = torch.from_numpy(mask)
+ elif isinstance(mask[0], torch.Tensor):
+ mask = torch.stack(mask, dim=0) if mask[0].ndim < 3 else torch.cat(mask, dim=0)
+
+ # Batch and add channel dim for single mask
+ if mask.ndim == 2:
+ mask = mask.unsqueeze(0).unsqueeze(0)
+
+ # Batch single mask or add channel dim
+ if mask.ndim == 3:
+ # Single batched mask, no channel dim or single mask not batched but channel dim
+ if mask.shape[0] == 1:
+ mask = mask.unsqueeze(0)
+
+ # Batched masks no channel dim
+ else:
+ mask = mask.unsqueeze(1)
+
+ # Check mask shape
+ if batch_size > 1:
+ if mask.shape[0] == 1:
+ mask = torch.cat([mask] * batch_size)
+ elif mask.shape[0] > 1 and mask.shape[0] != batch_size:
+ raise ValueError(
+ f"`mask_image` with batch size {mask.shape[0]} cannot be broadcasted to batch size {batch_size} "
+ f"inferred by prompt inputs"
+ )
+
+ if mask.shape[1] != 1:
+ raise ValueError(f"`mask_image` must have 1 channel, but has {mask.shape[1]} channels")
+
+ # Check mask is in [0, 1]
+ if mask.min() < 0 or mask.max() > 1:
+ raise ValueError("`mask_image` should be in [0, 1] range")
+
+ # Binarize mask
+ mask[mask < 0.5] = 0
+ mask[mask >= 0.5] = 1
+
+ return mask
+
+
+class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
+ r"""
+ Pipeline for text-guided image inpainting using Stable Diffusion with DiffEdit. *This is an experimental feature*.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ In addition the pipeline inherits the following loading methods:
+ - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+ - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+
+ as well as the following saving methods:
+ - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents.
+ inverse_scheduler ([`DDIMInverseScheduler`]):
+ A scheduler to be used in combination with `unet` to fill in the unmasked part of the input latents.
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+ _optional_components = ["safety_checker", "feature_extractor", "inverse_scheduler"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ inverse_scheduler: DDIMInverseScheduler,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration"
+ " `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make"
+ " sure to update the config accordingly as not setting `skip_prk_steps` in the config might lead to"
+ " incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face"
+ " Hub, it would be very nice if you could open a Pull request for the"
+ " `scheduler/scheduler_config.json` file"
+ )
+ deprecate("skip_prk_steps not set", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["skip_prk_steps"] = True
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+ version.parse(unet.config._diffusers_version).base_version
+ ) < version.parse("0.9.0.dev0")
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+ deprecation_message = (
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
+ " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+ " the `unet/config.json` file"
+ )
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(unet.config)
+ new_config["sample_size"] = 64
+ unet._internal_dict = FrozenDict(new_config)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ inverse_scheduler=inverse_scheduler,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
+ steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
+ several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
+ """
+ self.vae.enable_tiling()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ hook = None
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+ if self.safety_checker is not None:
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ def decode_latents(self, latents):
+ warnings.warn(
+ "The decode_latents method is deprecated and will be removed in a future version. Please"
+ " use VaeImageProcessor instead",
+ FutureWarning,
+ )
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ def check_inputs(
+ self,
+ prompt,
+ strength,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if (strength is None) or (strength is not None and (strength < 0 or strength > 1)):
+ raise ValueError(
+ f"The value of `strength` should in [0.0, 1.0] but is, but is {strength} of type {type(strength)}."
+ )
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ def check_source_inputs(
+ self,
+ source_prompt=None,
+ source_negative_prompt=None,
+ source_prompt_embeds=None,
+ source_negative_prompt_embeds=None,
+ ):
+ if source_prompt is not None and source_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `source_prompt`: {source_prompt} and `source_prompt_embeds`: {source_prompt_embeds}."
+ " Please make sure to only forward one of the two."
+ )
+ elif source_prompt is None and source_prompt_embeds is None:
+ raise ValueError(
+ "Provide either `source_image` or `source_prompt_embeds`. Cannot leave all both of the arguments undefined."
+ )
+ elif source_prompt is not None and (
+ not isinstance(source_prompt, str) and not isinstance(source_prompt, list)
+ ):
+ raise ValueError(f"`source_prompt` has to be of type `str` or `list` but is {type(source_prompt)}")
+
+ if source_negative_prompt is not None and source_negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `source_negative_prompt`: {source_negative_prompt} and `source_negative_prompt_embeds`:"
+ f" {source_negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if source_prompt_embeds is not None and source_negative_prompt_embeds is not None:
+ if source_prompt_embeds.shape != source_negative_prompt_embeds.shape:
+ raise ValueError(
+ "`source_prompt_embeds` and `source_negative_prompt_embeds` must have the same shape when passed"
+ f" directly, but got: `source_prompt_embeds` {source_prompt_embeds.shape} !="
+ f" `source_negative_prompt_embeds` {source_negative_prompt_embeds.shape}."
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep, 0)
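+ # e.g. num_inference_steps=50, strength=0.7 -> init_timestep=35, t_start=15: denoising runs over the
+ # last 35 scheduler timesteps (the offset is scaled by scheduler.order for higher-order schedulers)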
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+
+ return timesteps, num_inference_steps - t_start
+
+ def get_inverse_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep, 0)
+
+ # safety for t_start overflow to prevent empty timesteps slice
+ if t_start == 0:
+ return self.inverse_scheduler.timesteps, num_inference_steps
+ timesteps = self.inverse_scheduler.timesteps[:-t_start]
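+ # e.g. num_inference_steps=50, strength=0.8 -> init_timestep=40, t_start=10: keep the first 40 inverse timesteps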
+
+ return timesteps, num_inference_steps - t_start
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_pix2pix_zero.StableDiffusionPix2PixZeroPipeline.prepare_image_latents
+ def prepare_image_latents(self, image, batch_size, dtype, device, generator=None):
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
+ raise ValueError(
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
+ )
+
+ image = image.to(device=device, dtype=dtype)
+
+ if image.shape[1] == 4:
+ latents = image
+
+ else:
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if isinstance(generator, list):
+ latents = [
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
+ ]
+ latents = torch.cat(latents, dim=0)
+ else:
+ latents = self.vae.encode(image).latent_dist.sample(generator)
+
+ latents = self.vae.config.scaling_factor * latents
+
+ if batch_size != latents.shape[0]:
+ if batch_size % latents.shape[0] == 0:
+ # expand image_latents for batch_size
+ deprecation_message = (
+ f"You have passed {batch_size} text prompts (`prompt`), but only {latents.shape[0]} initial"
+ " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
+ " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
+ " your script to pass as many initial images as text prompts to suppress this warning."
+ )
+ deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
+ additional_latents_per_image = batch_size // latents.shape[0]
+ latents = torch.cat([latents] * additional_latents_per_image, dim=0)
+ else:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ latents = torch.cat([latents], dim=0)
+
+ return latents
+
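+ # Recovers an epsilon (noise) prediction from the model output, based on
+ # x_t = sqrt(alpha_prod_t) * x_0 + sqrt(1 - alpha_prod_t) * epsilon, for the "epsilon", "sample", and
+ # "v_prediction" parameterizations of the inverse scheduler.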
+ def get_epsilon(self, model_output: torch.Tensor, sample: torch.Tensor, timestep: int):
+ pred_type = self.inverse_scheduler.config.prediction_type
+ alpha_prod_t = self.inverse_scheduler.alphas_cumprod[timestep]
+
+ beta_prod_t = 1 - alpha_prod_t
+
+ if pred_type == "epsilon":
+ return model_output
+ elif pred_type == "sample":
+ return (sample - alpha_prod_t ** (0.5) * model_output) / beta_prod_t ** (0.5)
+ elif pred_type == "v_prediction":
+ return (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+ else:
+ raise ValueError(
+ f"prediction_type given as {pred_type} must be one of `epsilon`, `sample`, or `v_prediction`"
+ )
+
+ @torch.no_grad()
+ def generate_mask(
+ self,
+ image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+ target_prompt: Optional[Union[str, List[str]]] = None,
+ target_negative_prompt: Optional[Union[str, List[str]]] = None,
+ target_prompt_embeds: Optional[torch.FloatTensor] = None,
+ target_negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ source_prompt: Optional[Union[str, List[str]]] = None,
+ source_negative_prompt: Optional[Union[str, List[str]]] = None,
+ source_prompt_embeds: Optional[torch.FloatTensor] = None,
+ source_negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ num_maps_per_mask: Optional[int] = 10,
+ mask_encode_strength: Optional[float] = 0.5,
+ mask_thresholding_ratio: Optional[float] = 3.0,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ output_type: Optional[str] = "np",
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ):
+ r"""
+ Function used to generate a latent mask given a source prompt, a target prompt, and an image.
+
+ Args:
+ image (`PIL.Image.Image`):
+ `Image`, or tensor representing an image batch which will be used for computing the mask.
+ target_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the semantic mask generation. If not defined, one has to pass
+ `target_prompt_embeds` instead.
+ target_negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `target_negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
+ is less than `1`).
+ target_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ target_negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ source_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the semantic mask generation using the method in [DiffEdit:
+ Diffusion-Based Semantic Image Editing with Mask Guidance](https://arxiv.org/pdf/2210.11427.pdf). If
+ not defined, one has to pass `source_prompt_embeds` or `source_image` instead.
+ source_negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the semantic mask generation away from using the method in [DiffEdit:
+ Diffusion-Based Semantic Image Editing with Mask Guidance](https://arxiv.org/pdf/2210.11427.pdf). If
+ not defined, one has to pass `source_negative_prompt_embeds` or `source_image` instead.
+ source_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings to guide the semantic mask generation. Can be used to easily tweak text
+ inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from
+ `source_prompt` input argument.
+ source_negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings to negatively guide the semantic mask generation. Can be used to easily
+ tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from
+ `source_negative_prompt` input argument.
+ num_maps_per_mask (`int`, *optional*, defaults to 10):
+ The number of noise maps sampled to generate the semantic mask using the method in [DiffEdit:
+ Diffusion-Based Semantic Image Editing with Mask Guidance](https://arxiv.org/pdf/2210.11427.pdf).
+ mask_encode_strength (`float`, *optional*, defaults to 0.5):
+ Conceptually, the strength of the noise maps sampled to generate the semantic mask using the method in
+ [DiffEdit: Diffusion-Based Semantic Image Editing with Mask Guidance](
+ https://arxiv.org/pdf/2210.11427.pdf). Must be between 0 and 1.
+ mask_thresholding_ratio (`float`, *optional*, defaults to 3.0):
+ The maximum multiple of the mean absolute difference used to clamp the semantic guidance map before
+ mask binarization.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ output_type (`str`, *optional*, defaults to `"np"`):
+ The output format of the generated mask. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+
+ Examples:
+
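+ A minimal sketch, mirroring the class-level example above (`pipe` and `init_image` as defined there):
+
+ ```py
+ >>> mask_prompt = "A bowl of fruits"
+ >>> prompt = "A bowl of pears"
+ >>> mask_image = pipe.generate_mask(image=init_image, source_prompt=prompt, target_prompt=mask_prompt)
+ ```
+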
+ Returns:
+ `List[PIL.Image.Image]` or `np.array`: `List[PIL.Image.Image]` if `output_type` is `"pil"`, otherwise a
+ `np.array`. When returning a `List[PIL.Image.Image]`, the list will consist of a batch of single-channel
+ binary images with dimensions `(height // self.vae_scale_factor, width // self.vae_scale_factor)`, otherwise
+ the `np.array` will have shape `(batch_size, height // self.vae_scale_factor, width //
+ self.vae_scale_factor)`.
+ """
+
+ # 1. Check inputs (Provide dummy argument for callback_steps)
+ self.check_inputs(
+ target_prompt,
+ mask_encode_strength,
+ 1,
+ target_negative_prompt,
+ target_prompt_embeds,
+ target_negative_prompt_embeds,
+ )
+
+ self.check_source_inputs(
+ source_prompt,
+ source_negative_prompt,
+ source_prompt_embeds,
+ source_negative_prompt_embeds,
+ )
+
+ if (num_maps_per_mask is None) or (
+ num_maps_per_mask is not None and (not isinstance(num_maps_per_mask, int) or num_maps_per_mask <= 0)
+ ):
+ raise ValueError(
+ f"`num_maps_per_mask` has to be a positive integer but is {num_maps_per_mask} of type"
+ f" {type(num_maps_per_mask)}."
+ )
+
+ if mask_thresholding_ratio is None or mask_thresholding_ratio <= 0:
+ raise ValueError(
+ f"`mask_thresholding_ratio` has to be positive but is {mask_thresholding_ratio} of type"
+ f" {type(mask_thresholding_ratio)}."
+ )
+
+ # 2. Define call parameters
+ if target_prompt is not None and isinstance(target_prompt, str):
+ batch_size = 1
+ elif target_prompt is not None and isinstance(target_prompt, list):
+ batch_size = len(target_prompt)
+ else:
+ batch_size = target_prompt_embeds.shape[0]
+ if cross_attention_kwargs is None:
+ cross_attention_kwargs = {}
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompts
+ (cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None)
+ target_prompt_embeds = self._encode_prompt(
+ target_prompt,
+ device,
+ num_maps_per_mask,
+ do_classifier_free_guidance,
+ target_negative_prompt,
+ prompt_embeds=target_prompt_embeds,
+ negative_prompt_embeds=target_negative_prompt_embeds,
+ )
+
+ source_prompt_embeds = self._encode_prompt(
+ source_prompt,
+ device,
+ num_maps_per_mask,
+ do_classifier_free_guidance,
+ source_negative_prompt,
+ prompt_embeds=source_prompt_embeds,
+ negative_prompt_embeds=source_negative_prompt_embeds,
+ )
+
+ # 4. Preprocess image
+ image = self.image_processor.preprocess(image).repeat_interleave(num_maps_per_mask, dim=0)
+
+ # 5. Set timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps, _ = self.get_timesteps(num_inference_steps, mask_encode_strength, device)
+ encode_timestep = timesteps[0]
+
+ # 6. Prepare image latents and add noise with specified strength
+ image_latents = self.prepare_image_latents(
+ image, batch_size * num_maps_per_mask, self.vae.dtype, device, generator
+ )
+ noise = randn_tensor(image_latents.shape, generator=generator, device=device, dtype=self.vae.dtype)
+ image_latents = self.scheduler.add_noise(image_latents, noise, encode_timestep)
+
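+ # with classifier-free guidance, the source and target prompt embeddings each contain a negative and a
+ # positive half, so the noised latents are repeated 4x (2x otherwise) to line up with them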
+ latent_model_input = torch.cat([image_latents] * (4 if do_classifier_free_guidance else 2))
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, encode_timestep)
+
+ # 7. Predict the noise residual
+ prompt_embeds = torch.cat([source_prompt_embeds, target_prompt_embeds])
+ noise_pred = self.unet(
+ latent_model_input,
+ encode_timestep,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ ).sample
+
+ if do_classifier_free_guidance:
+ noise_pred_neg_src, noise_pred_source, noise_pred_uncond, noise_pred_target = noise_pred.chunk(4)
+ noise_pred_source = noise_pred_neg_src + guidance_scale * (noise_pred_source - noise_pred_neg_src)
+ noise_pred_target = noise_pred_uncond + guidance_scale * (noise_pred_target - noise_pred_uncond)
+ else:
+ noise_pred_source, noise_pred_target = noise_pred.chunk(2)
+
+ # 8. Compute the mask from the absolute difference of predicted noise residuals
+ # TODO: Consider smoothing mask guidance map
+ mask_guidance_map = (
+ torch.abs(noise_pred_target - noise_pred_source)
+ .reshape(batch_size, num_maps_per_mask, *noise_pred_target.shape[-3:])
+ .mean([1, 2])
+ )
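+ # clamp the averaged difference map to mask_thresholding_ratio * mean, rescale to [0, 1], and binarize at 0.5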
+ clamp_magnitude = mask_guidance_map.mean() * mask_thresholding_ratio
+ semantic_mask_image = mask_guidance_map.clamp(0, clamp_magnitude) / clamp_magnitude
+ semantic_mask_image = torch.where(semantic_mask_image <= 0.5, 0, 1)
+ mask_image = semantic_mask_image.cpu().numpy()
+
+ # 9. Convert to Numpy array or PIL.
+ if output_type == "pil":
+ mask_image = self.image_processor.numpy_to_pil(mask_image)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ return mask_image
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_INVERT_DOC_STRING)
+ def invert(
+ self,
+ prompt: Optional[Union[str, List[str]]] = None,
+ image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+ num_inference_steps: int = 50,
+ inpaint_strength: float = 0.8,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ decode_latents: bool = False,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: Optional[int] = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ lambda_auto_corr: float = 20.0,
+ lambda_kl: float = 20.0,
+ num_reg_steps: int = 0,
+ num_auto_corr_rolls: int = 5,
+ ):
+ r"""
+ Function used to generate inverted latents given a prompt and image.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ image (`PIL.Image.Image`):
+ `Image`, or tensor representing an image batch to produce the inverted latents, guided by `prompt`.
+ inpaint_strength (`float`, *optional*, defaults to 0.8):
+ Conceptually, indicates how far into the noising process to run latent inversion. Must be between 0 and
+ 1. When `inpaint_strength` is 1, the inversion process will be run for the full number of iterations specified
+ in `num_inference_steps`. `image` will be used as a reference for the inversion process, adding more
+ noise the larger the `inpaint_strength`. If `inpaint_strength` is 0, no inversion is performed.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
+ is less than `1`).
+ generator (`torch.Generator`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ decode_latents (`bool`, *optional*, defaults to `False`):
+ Whether or not to decode the inverted latents into a generated image. Setting this argument to `True`
+ will decode all inverted latents for each timestep into a list of generated images.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.DiffEditInversionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+ lambda_auto_corr (`float`, *optional*, defaults to 20.0):
+ Weight of the auto-correlation regularization applied to the predicted noise
+ lambda_kl (`float`, *optional*, defaults to 20.0):
+ Weight of the Kullback–Leibler divergence regularization applied to the predicted noise
+ num_reg_steps (`int`, *optional*, defaults to 0):
+ Number of regularization loss steps
+ num_auto_corr_rolls (`int`, *optional*, defaults to 5):
+ Number of rolls used per step of the auto-correlation regularization
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.pipeline_stable_diffusion_diffedit.DiffEditInversionPipelineOutput`] or
+ `tuple`: [`~pipelines.stable_diffusion.pipeline_stable_diffusion_diffedit.DiffEditInversionPipelineOutput`]
+ if `return_dict` is `True`, otherwise a `tuple`. When returning a tuple, the first element is the inverted
+ latents tensor ordered by increasing noise, and the second is the corresponding decoded images if
+ `decode_latents` is `True`, otherwise `None`.
+ """
+
+ # 1. Check inputs
+ self.check_inputs(
+ prompt,
+ inpaint_strength,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ )
+
+ if image is None:
+ raise ValueError("`image` input cannot be undefined.")
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+ if cross_attention_kwargs is None:
+ cross_attention_kwargs = {}
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Preprocess image
+ image = self.image_processor.preprocess(image)
+
+ # 4. Prepare latent variables
+ num_images_per_prompt = 1
+ latents = self.prepare_image_latents(
+ image, batch_size * num_images_per_prompt, self.vae.dtype, device, generator
+ )
+
+ # 5. Encode input prompt
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ )
+
+ # 6. Prepare timesteps
+ self.inverse_scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps, num_inference_steps = self.get_inverse_timesteps(num_inference_steps, inpaint_strength, device)
+
+ # 7. Noising loop where we obtain the intermediate noised latent image for each timestep.
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.inverse_scheduler.order
+ inverted_latents = []
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.inverse_scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ ).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # regularization of the noise prediction (not in original code or paper but borrowed from Pix2PixZero)
+ if num_reg_steps > 0:
+ with torch.enable_grad():
+ for _ in range(num_reg_steps):
+ if lambda_auto_corr > 0:
+ for _ in range(num_auto_corr_rolls):
+ var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True)
+
+ # Derive epsilon from model output before regularizing to IID standard normal
+ var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t)
+
+ l_ac = auto_corr_loss(var_epsilon, generator=generator)
+ l_ac.backward()
+
+ grad = var.grad.detach() / num_auto_corr_rolls
+ noise_pred = noise_pred - lambda_auto_corr * grad
+
+ if lambda_kl > 0:
+ var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True)
+
+ # Derive epsilon from model output before regularizing to IID standard normal
+ var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t)
+
+ l_kld = kl_divergence(var_epsilon)
+ l_kld.backward()
+
+ grad = var.grad.detach()
+ noise_pred = noise_pred - lambda_kl * grad
+
+ noise_pred = noise_pred.detach()
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.inverse_scheduler.step(noise_pred, t, latents).prev_sample
+ inverted_latents.append(latents.detach().clone())
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or (
+ (i + 1) > num_warmup_steps and (i + 1) % self.inverse_scheduler.order == 0
+ ):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ assert len(inverted_latents) == len(timesteps)
+ latents = torch.stack(list(reversed(inverted_latents)), 1)
+
+ # 8. Post-processing
+ image = None
+ if decode_latents:
+ image = self.decode_latents(latents.flatten(0, 1))
+
+ # 9. Convert to PIL.
+ if decode_latents and output_type == "pil":
+ image = self.image_processor.numpy_to_pil(image)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (latents, image)
+
+ return DiffEditInversionPipelineOutput(latents=latents, images=image)
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Optional[Union[str, List[str]]] = None,
+ mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+ image_latents: Union[torch.FloatTensor, PIL.Image.Image] = None,
+ inpaint_strength: Optional[float] = 0.8,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ mask_image (`PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, to mask the generated image. White pixels in the mask
+ will be repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be
+ converted to a single channel (luminance) before use. If it's a tensor, it should contain one color
+ channel (L) instead of 3, so the expected shape would be `(B, 1, H, W)`.
+ image_latents (`PIL.Image.Image` or `torch.FloatTensor`):
+ Partially noised image latents from the inversion process to be used as inputs for image generation.
+ inpaint_strength (`float`, *optional*, defaults to 0.8):
+ Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When
+ `inpaint_strength` is 1, the denoising process will be run on the masked area for the full number of
+ iterations specified in `num_inference_steps`. `image_latents` will be used as a reference for the
+ masked area, with more noise added to that region the larger the `inpaint_strength`. If
+ `inpaint_strength` is 0, no inpainting will occur.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2 of the [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text
+ `prompt`, usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
+ is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`] and is ignored for other schedulers.
+ generator (`torch.Generator`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+
+ # 1. Check inputs
+ self.check_inputs(
+ prompt,
+ inpaint_strength,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ )
+
+ if mask_image is None:
+ raise ValueError(
+ "`mask_image` input cannot be undefined. Use `generate_mask()` to compute `mask_image` from text prompts."
+ )
+ if image_latents is None:
+ raise ValueError(
+ "`image_latents` input cannot be undefined. Use `invert()` to compute `image_latents` from input images."
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+ if cross_attention_kwargs is None:
+ cross_attention_kwargs = {}
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier-free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+
+ # 4. Preprocess mask
+ mask_image = preprocess_mask(mask_image, batch_size)
+ latent_height, latent_width = mask_image.shape[-2:]
+ mask_image = torch.cat([mask_image] * num_images_per_prompt)
+ mask_image = mask_image.to(device=device, dtype=prompt_embeds.dtype)
+
+ # 5. Set timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, inpaint_strength, device)
+
+ # 6. Preprocess image latents
+ if isinstance(image_latents, list) and any(isinstance(l, torch.Tensor) and l.ndim == 5 for l in image_latents):
+ image_latents = torch.cat(image_latents).detach()
+ elif isinstance(image_latents, torch.Tensor) and image_latents.ndim == 5:
+ image_latents = image_latents.detach()
+ else:
+ image_latents = self.image_processor.preprocess(image_latents).detach()
+
+ latent_shape = (self.vae.config.latent_channels, latent_height, latent_width)
+ if image_latents.shape[-3:] != latent_shape:
+ raise ValueError(
+ f"Each latent image in `image_latents` must have shape {latent_shape}, "
+ f"but has shape {image_latents.shape[-3:]}"
+ )
+ if image_latents.ndim == 4:
+ image_latents = image_latents.reshape(batch_size, len(timesteps), *latent_shape)
+ if image_latents.shape[:2] != (batch_size, len(timesteps)):
+ raise ValueError(
+ f"`image_latents` must have batch size {batch_size} with latent images from {len(timesteps)}"
+ f" timesteps, but has batch size {image_latents.shape[0]} with latent images from"
+ f" {image_latents.shape[1]} timesteps."
+ )
+ image_latents = image_latents.transpose(0, 1).repeat_interleave(num_images_per_prompt, dim=1)
+ image_latents = image_latents.to(device=device, dtype=prompt_embeds.dtype)
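+
+ # Shape illustration (illustrative numbers, assuming a standard SD VAE with vae_scale_factor=8 and
+ # latent_channels=4): for a 768x768 input with batch_size=1, num_images_per_prompt=1 and N retained
+ # inversion timesteps, `image_latents` has shape (N, 1, 4, 96, 96) after the transpose above,
+ # i.e. one latent image per denoising timestep.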
+
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 8. Denoising loop
+ latents = image_latents[0].clone()
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ ).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # mask with inverted latents from appropriate timestep - use original image latent for last step
+ latents = latents * mask_image + image_latents[i] * (1 - mask_image)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
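+
+
+# Minimal end-to-end sketch of the DiffEdit workflow implemented above (illustrative only:
+# the concrete pipeline class name, checkpoint, file name and prompts are assumptions, and the
+# pipeline is assumed to have been constructed with a suitable scheduler / inverse_scheduler pair).
+#
+# >>> import torch
+# >>> from PIL import Image
+# >>> pipe = StableDiffusionDiffEditPipeline.from_pretrained(  # class name assumed
+# ...     "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16
+# ... ).to("cuda")
+# >>> init_image = Image.open("fruit_bowl.png").convert("RGB").resize((768, 768))
+# >>> mask_image = pipe.generate_mask(  # argument names assumed
+# ...     image=init_image, source_prompt="a bowl of fruits", target_prompt="a bowl of pears"
+# ... )
+# >>> image_latents = pipe.invert(image=init_image, prompt="a bowl of fruits").latents
+# >>> edited = pipe(
+# ...     prompt="a bowl of pears", mask_image=mask_image, image_latents=image_latents
+# ... ).images[0]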
diff --git a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebcb55a7cff90cc06d1c05476013c05526f19f19
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py
@@ -0,0 +1,394 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import warnings
+from typing import Callable, List, Optional, Union
+
+import PIL
+import torch
+from packaging import version
+from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
+
+from ...configuration_utils import FrozenDict
+from ...image_processor import VaeImageProcessor
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...schedulers import KarrasDiffusionSchedulers
+from ...utils import deprecate, logging, randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from . import StableDiffusionPipelineOutput
+from .safety_checker import StableDiffusionSafetyChecker
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class StableDiffusionImageVariationPipeline(DiffusionPipeline):
+ r"""
+ Pipeline to generate variations from an input image using Stable Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ image_encoder ([`CLIPVisionModelWithProjection`]):
+ Frozen CLIP image-encoder. Stable Diffusion Image Variation uses the vision portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModelWithProjection),
+ specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+ # TODO: feature_extractor is required to encode images (if they are in PIL format),
+ # we should give a descriptive message if the pipeline doesn't have one.
+ _optional_components = ["safety_checker"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ image_encoder: CLIPVisionModelWithProjection,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend keeping the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+ version.parse(unet.config._diffusers_version).base_version
+ ) < version.parse("0.9.0.dev0")
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+ deprecation_message = (
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
+ " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+ " the `unet/config.json` file"
+ )
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(unet.config)
+ new_config["sample_size"] = 64
+ unet._internal_dict = FrozenDict(new_config)
+
+ self.register_modules(
+ vae=vae,
+ image_encoder=image_encoder,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(images=image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ image_embeddings = self.image_encoder(image).image_embeds
+ image_embeddings = image_embeddings.unsqueeze(1)
+
+ # duplicate image embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = image_embeddings.shape
+ image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1)
+ image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ if do_classifier_free_guidance:
+ negative_prompt_embeds = torch.zeros_like(image_embeddings)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])
+
+ return image_embeddings
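+
+ # Shape illustration (assuming a ViT-L/14 image encoder with a 768-dim projection): a batch of
+ # 2 input images with num_images_per_prompt=3 yields image embeddings of shape (6, 1, 768);
+ # enabling classifier-free guidance prepends an all-zero "unconditional" copy, giving (12, 1, 768).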
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ def decode_latents(self, latents):
+ warnings.warn(
+ "The decode_latents method is deprecated and will be removed in a future version. Please"
+ " use VaeImageProcessor instead",
+ FutureWarning,
+ )
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(self, image, height, width, callback_steps):
+ if (
+ not isinstance(image, torch.Tensor)
+ and not isinstance(image, PIL.Image.Image)
+ and not isinstance(image, list)
+ ):
+ raise ValueError(
+ "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
+ f" {type(image)}"
+ )
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
+ The image or images to guide the image generation. If you provide a tensor, it needs to comply with the
+ configuration of
+ [this](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json)
+ `CLIPImageProcessor`
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2 of the [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text
+ `prompt`, usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`] and is ignored for other schedulers.
+ generator (`torch.Generator`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(image, height, width, callback_steps)
+
+ # 2. Define call parameters
+ if isinstance(image, PIL.Image.Image):
+ batch_size = 1
+ elif isinstance(image, list):
+ batch_size = len(image)
+ else:
+ batch_size = image.shape[0]
+ device = self._execution_device
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier-free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input image
+ image_embeddings = self._encode_image(image, device, num_images_per_prompt, do_classifier_free_guidance)
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ image_embeddings.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=image_embeddings).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
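+
+
+# Minimal usage sketch for the image-variation pipeline above (illustrative: the checkpoint
+# name and file paths are placeholders; any checkpoint that ships the CLIP image encoder and
+# feature extractor this pipeline expects should work).
+#
+# >>> import torch
+# >>> from PIL import Image
+# >>> pipe = StableDiffusionImageVariationPipeline.from_pretrained(
+# ...     "lambdalabs/sd-image-variations-diffusers", torch_dtype=torch.float16
+# ... ).to("cuda")
+# >>> init_image = Image.open("input.jpg").convert("RGB")
+# >>> variations = pipe(image=init_image, guidance_scale=3.0, num_images_per_prompt=2).images
+# >>> variations[0].save("variation.png")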
diff --git a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab4420f2838675d06b2558c37cd31a133915f423
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
@@ -0,0 +1,764 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import warnings
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import PIL
+import torch
+from packaging import version
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from ...configuration_utils import FrozenDict
+from ...image_processor import VaeImageProcessor
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...schedulers import KarrasDiffusionSchedulers
+from ...utils import (
+ PIL_INTERPOLATION,
+ deprecate,
+ is_accelerate_available,
+ is_accelerate_version,
+ logging,
+ randn_tensor,
+ replace_example_docstring,
+)
+from ..pipeline_utils import DiffusionPipeline
+from . import StableDiffusionPipelineOutput
+from .safety_checker import StableDiffusionSafetyChecker
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import requests
+ >>> import torch
+ >>> from PIL import Image
+ >>> from io import BytesIO
+
+ >>> from diffusers import StableDiffusionImg2ImgPipeline
+
+ >>> device = "cuda"
+ >>> model_id_or_path = "runwayml/stable-diffusion-v1-5"
+ >>> pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
+ >>> pipe = pipe.to(device)
+
+ >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
+
+ >>> response = requests.get(url)
+ >>> init_image = Image.open(BytesIO(response.content)).convert("RGB")
+ >>> init_image = init_image.resize((768, 512))
+
+ >>> prompt = "A fantasy landscape, trending on artstation"
+
+ >>> images = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images
+ >>> images[0].save("fantasy_landscape.png")
+ ```
+"""
+
+
+def preprocess(image):
+ warnings.warn(
+ "The preprocess method is deprecated and will be removed in a future version. Please"
+ " use VaeImageProcessor.preprocess instead",
+ FutureWarning,
+ )
+ if isinstance(image, torch.Tensor):
+ return image
+ elif isinstance(image, PIL.Image.Image):
+ image = [image]
+
+ if isinstance(image[0], PIL.Image.Image):
+ w, h = image[0].size
+ w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
+
+ image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
+ image = np.concatenate(image, axis=0)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image.transpose(0, 3, 1, 2)
+ image = 2.0 * image - 1.0
+ image = torch.from_numpy(image)
+ elif isinstance(image[0], torch.Tensor):
+ image = torch.cat(image, dim=0)
+ return image
+
+
+class StableDiffusionImg2ImgPipeline(
+ DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+):
+ r"""
+ Pipeline for text-guided image to image generation using Stable Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ In addition the pipeline inherits the following loading methods:
+ - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+ - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+ - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
+
+ as well as the following saving methods:
+ - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+ )
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["clip_sample"] = False
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+ version.parse(unet.config._diffusers_version).base_version
+ ) < version.parse("0.9.0.dev0")
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+ deprecation_message = (
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+ " the `unet/config.json` file"
+ )
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(unet.config)
+ new_config["sample_size"] = 64
+ unet._internal_dict = FrozenDict(new_config)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ hook = None
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+ if self.safety_checker is not None:
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
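+
+ # Typical usage (illustrative): call `pipe.enable_model_cpu_offload()` once right after
+ # constructing the pipeline and before the first generation; as enforced above, this requires
+ # accelerate v0.17.0 or higher.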
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
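+
+ # Shape illustration (assuming the SD 1.x text encoder: max_length=77 tokens, hidden size 768):
+ # a single prompt with num_images_per_prompt=2 produces prompt_embeds of shape (2, 77, 768);
+ # with classifier-free guidance the negative embeddings are stacked in front, giving (4, 77, 768).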
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ def decode_latents(self, latents):
+ warnings.warn(
+ "The decode_latents method is deprecated and will be removed in a future version. Please"
+ " use VaeImageProcessor instead",
+ FutureWarning,
+ )
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
+ ):
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep, 0)
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+
+ return timesteps, num_inference_steps - t_start
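+
+ # Worked example (illustrative): with num_inference_steps=50 and strength=0.8,
+ # init_timestep = min(int(50 * 0.8), 50) = 40 and t_start = 10, so the scheduler's last
+ # 40 timesteps are used and 40 is returned as the effective number of inference steps.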
+
+ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
+ raise ValueError(
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
+ )
+
+ image = image.to(device=device, dtype=dtype)
+
+ batch_size = batch_size * num_images_per_prompt
+
+ if image.shape[1] == 4:
+ init_latents = image
+
+ else:
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ elif isinstance(generator, list):
+ init_latents = [
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
+ ]
+ init_latents = torch.cat(init_latents, dim=0)
+ else:
+ init_latents = self.vae.encode(image).latent_dist.sample(generator)
+
+ init_latents = self.vae.config.scaling_factor * init_latents
+
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ deprecation_message = (
+ f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
+ " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
+ " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
+ " your script to pass as many initial images as text prompts to suppress this warning."
+ )
+ deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
+ additional_image_per_prompt = batch_size // init_latents.shape[0]
+ init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ init_latents = torch.cat([init_latents], dim=0)
+
+ shape = init_latents.shape
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+ # get latents
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+ latents = init_latents
+
+ return latents
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ image: Union[
+ torch.FloatTensor,
+ PIL.Image.Image,
+ np.ndarray,
+ List[torch.FloatTensor],
+ List[PIL.Image.Image],
+ List[np.ndarray],
+ ] = None,
+ strength: float = 0.8,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: Optional[float] = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
+ process. Can also accept image latents as `image`; if latents are passed directly, they will not be
+ encoded again.
+ strength (`float`, *optional*, defaults to 0.8):
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
+ will be used as a starting point, adding more noise to it the larger the `strength`. The number of
+ denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
+ be maximum and the denoising process will run for the full number of iterations specified in
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference. This parameter will be modulated by `strength`.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
+ is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+
+ # 4. Preprocess image
+ image = self.image_processor.preprocess(image)
+
+ # 5. set timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+ # 6. Prepare latent variables
+ latents = self.prepare_latents(
+ image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
+ )
+
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 8. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
new file mode 100644
index 0000000000000000000000000000000000000000..b11ebfb6cfc55c888fd1fdfe2cb68151a6567c88
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -0,0 +1,1030 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import warnings
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import PIL
+import torch
+from packaging import version
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from ...configuration_utils import FrozenDict
+from ...image_processor import VaeImageProcessor
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...schedulers import KarrasDiffusionSchedulers
+from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from . import StableDiffusionPipelineOutput
+from .safety_checker import StableDiffusionSafetyChecker
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False):
+ """
+ Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
+ converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
+ ``image`` and ``1`` for the ``mask``.
+
+ The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
+ binarized (``mask >= 0.5``) and cast to ``torch.float32`` too.
+
+ Args:
+ image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
+ It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
+ ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
+ mask (Union[np.array, PIL.Image, torch.Tensor]): The mask to apply to the image, i.e. regions to inpaint.
+ It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
+ ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
+
+
+ Raises:
+ ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range.
+ ValueError: ``torch.Tensor`` mask should be in the ``[0, 1]`` range.
+ ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
+ TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not (or the other way around).
+
+ Returns:
+ tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
+ dimensions: ``batch x channels x height x width``.
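+
+ Example (a minimal, illustrative sketch; the variable names and the 64x64 all-black mask are assumptions):
+
+ ```py
+ >>> import PIL.Image
+ >>> image = PIL.Image.new("RGB", (64, 64), color=(255, 0, 0))
+ >>> mask = PIL.Image.new("L", (64, 64), color=0)
+ >>> mask_t, masked_image_t = prepare_mask_and_masked_image(image, mask, 64, 64)
+ >>> tuple(mask_t.shape), tuple(masked_image_t.shape)
+ ((1, 1, 64, 64), (1, 3, 64, 64))
+ ```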
+ """
+
+ if image is None:
+ raise ValueError("`image` input cannot be undefined.")
+
+ if mask is None:
+ raise ValueError("`mask_image` input cannot be undefined.")
+
+ if isinstance(image, torch.Tensor):
+ if not isinstance(mask, torch.Tensor):
+ raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not")
+
+ # Batch single image
+ if image.ndim == 3:
+ assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
+ image = image.unsqueeze(0)
+
+ # Batch and add channel dim for single mask
+ if mask.ndim == 2:
+ mask = mask.unsqueeze(0).unsqueeze(0)
+
+ # Batch single mask or add channel dim
+ if mask.ndim == 3:
+ # Single batched mask, no channel dim or single mask not batched but channel dim
+ if mask.shape[0] == 1:
+ mask = mask.unsqueeze(0)
+
+ # Batched masks no channel dim
+ else:
+ mask = mask.unsqueeze(1)
+
+ assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
+ assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
+ assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
+
+ # Check image is in [-1, 1]
+ if image.min() < -1 or image.max() > 1:
+ raise ValueError("Image should be in [-1, 1] range")
+
+ # Check mask is in [0, 1]
+ if mask.min() < 0 or mask.max() > 1:
+ raise ValueError("Mask should be in [0, 1] range")
+
+ # Binarize mask
+ mask[mask < 0.5] = 0
+ mask[mask >= 0.5] = 1
+
+ # Image as float32
+ image = image.to(dtype=torch.float32)
+ elif isinstance(mask, torch.Tensor):
+ raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
+ else:
+ # preprocess image
+ if isinstance(image, (PIL.Image.Image, np.ndarray)):
+ image = [image]
+ if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
+ # resize all images w.r.t. the passed height and width
+ image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]
+ image = [np.array(i.convert("RGB"))[None, :] for i in image]
+ image = np.concatenate(image, axis=0)
+ elif isinstance(image, list) and isinstance(image[0], np.ndarray):
+ image = np.concatenate([i[None, :] for i in image], axis=0)
+
+ image = image.transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+
+ # preprocess mask
+ if isinstance(mask, (PIL.Image.Image, np.ndarray)):
+ mask = [mask]
+
+ if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
+ mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]
+ mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
+ mask = mask.astype(np.float32) / 255.0
+ elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
+ mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
+
+ mask[mask < 0.5] = 0
+ mask[mask >= 0.5] = 1
+ mask = torch.from_numpy(mask)
+
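+ # zero out the region to be repainted (mask == 1) so only the preserved pixels remain in masked_image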
+ masked_image = image * (mask < 0.5)
+
+ # n.b. ensure backwards compatibility as old function does not return image
+ if return_image:
+ return mask, masked_image, image
+
+ return mask, masked_image
+
+
+class StableDiffusionInpaintPipeline(
+ DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+):
+ r"""
+ Pipeline for text-guided image inpainting using Stable Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ In addition the pipeline inherits the following loading methods:
+ - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+ - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+
+ as well as the following saving methods:
+ - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
+
+
+ It is recommended to use this pipeline with checkpoints that have been specifically fine-tuned for inpainting, such
+ as [runwayml/stable-diffusion-inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting). Default
+ text-to-image stable diffusion checkpoints, such as
+ [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) are also compatible with
+ this pipeline, but might be less performant.
+
+
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration"
+ " `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make"
+ " sure to update the config accordingly as not setting `skip_prk_steps` in the config might lead to"
+ " incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face"
+ " Hub, it would be very nice if you could open a Pull request for the"
+ " `scheduler/scheduler_config.json` file"
+ )
+ deprecate("skip_prk_steps not set", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["skip_prk_steps"] = True
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+ version.parse(unet.config._diffusers_version).base_version
+ ) < version.parse("0.9.0.dev0")
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+ deprecation_message = (
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
+ " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+ " the `unet/config.json` file"
+ )
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(unet.config)
+ new_config["sample_size"] = 64
+ unet._internal_dict = FrozenDict(new_config)
+
+ # Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4
+ if unet.config.in_channels != 9:
+ logger.info(f"You have loaded a UNet with {unet.config.in_channels} input channels which.")
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ hook = None
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+ if self.safety_checker is not None:
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ def decode_latents(self, latents):
+ warnings.warn(
+ "The decode_latents method is deprecated and will be removed in a future version. Please"
+ " use VaeImageProcessor instead",
+ FutureWarning,
+ )
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ strength,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ image=None,
+ timestep=None,
+ is_strength_max=True,
+ return_noise=False,
+ return_image_latents=False,
+ ):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if (image is None or timestep is None) and not is_strength_max:
+ raise ValueError(
+ "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise."
+ "However, either the image or the noise timestep has not been provided."
+ )
+
+ if return_image_latents or (latents is None and not is_strength_max):
+ image = image.to(device=device, dtype=dtype)
+ image_latents = self._encode_vae_image(image=image, generator=generator)
+
+ if latents is None:
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ # if strength is 1., initialise the latents to pure noise; otherwise initialise them to image + noise
+ latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
+ # if pure noise then scale the initial latents by the Scheduler's init sigma
+ latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
+ else:
+ noise = latents.to(device)
+ latents = noise * self.scheduler.init_noise_sigma
+
+ outputs = (latents,)
+
+ if return_noise:
+ outputs += (noise,)
+
+ if return_image_latents:
+ outputs += (image_latents,)
+
+ return outputs
+
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
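+ # encode the image with the VAE; with a list of generators, each batch element is sampled with its own generator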
+ if isinstance(generator, list):
+ image_latents = [
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
+
+ image_latents = self.vae.config.scaling_factor * image_latents
+
+ return image_latents
+
+ def prepare_mask_latents(
+ self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
+ ):
+ # resize the mask to latents shape as we concatenate the mask to the latents
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+ # and half precision
+ mask = torch.nn.functional.interpolate(
+ mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
+ )
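+ # e.g. (illustrative) a 512x512 pixel-space mask becomes 64x64 here when vae_scale_factor is 8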
+ mask = mask.to(device=device, dtype=dtype)
+
+ masked_image = masked_image.to(device=device, dtype=dtype)
+ masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
+
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+ if mask.shape[0] < batch_size:
+ if not batch_size % mask.shape[0] == 0:
+ raise ValueError(
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+ " of masks that you pass is divisible by the total requested batch size."
+ )
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+ if masked_image_latents.shape[0] < batch_size:
+ if not batch_size % masked_image_latents.shape[0] == 0:
+ raise ValueError(
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
+ )
+ masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
+
+ mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
+ masked_image_latents = (
+ torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
+ )
+
+ # align device to prevent device errors when concatenating it with the latent model input
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+ return mask, masked_image_latents
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep, 0)
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+
+ return timesteps, num_inference_steps - t_start
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+ mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ strength: float = 1.0,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ image (`PIL.Image.Image`):
+ `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
+ be masked out with `mask_image` and repainted according to `prompt`.
+ mask_image (`PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
+ repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
+ to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
+ instead of 3, so the expected shape would be `(B, H, W, 1)`.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ strength (`float`, *optional*, defaults to 1.):
+ Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be
+ between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the
+ `strength`. The number of denoising steps depends on the amount of noise initially added. When
+ `strength` is 1, added noise will be maximum and the denoising process will run for the full number of
+ iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked
+ portion of the reference `image`.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
+ is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+ Examples:
+
+ ```py
+ >>> import PIL
+ >>> import requests
+ >>> import torch
+ >>> from io import BytesIO
+
+ >>> from diffusers import StableDiffusionInpaintPipeline
+
+
+ >>> def download_image(url):
+ ... response = requests.get(url)
+ ... return PIL.Image.open(BytesIO(response.content)).convert("RGB")
+
+
+ >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+ >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+
+ >>> init_image = download_image(img_url).resize((512, 512))
+ >>> mask_image = download_image(mask_url).resize((512, 512))
+
+ >>> pipe = StableDiffusionInpaintPipeline.from_pretrained(
+ ... "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
+ ... )
+ >>> pipe = pipe.to("cuda")
+
+ >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
+ >>> image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
+ ```
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ strength,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+
+ # 4. set timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps, num_inference_steps = self.get_timesteps(
+ num_inference_steps=num_inference_steps, strength=strength, device=device
+ )
+ # check that number of inference steps is not < 1 - as this doesn't make sense
+ if num_inference_steps < 1:
+ raise ValueError(
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
+ f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+ )
+ # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+ # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
+ is_strength_max = strength == 1.0
+
+ # 5. Preprocess mask and image
+ mask, masked_image, init_image = prepare_mask_and_masked_image(
+ image, mask_image, height, width, return_image=True
+ )
+
+ # 6. Prepare latent variables
+ num_channels_latents = self.vae.config.latent_channels
+ num_channels_unet = self.unet.config.in_channels
+ return_image_latents = num_channels_unet == 4
+
+ latents_outputs = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ image=init_image,
+ timestep=latent_timestep,
+ is_strength_max=is_strength_max,
+ return_noise=True,
+ return_image_latents=return_image_latents,
+ )
+
+ if return_image_latents:
+ latents, noise, image_latents = latents_outputs
+ else:
+ latents, noise = latents_outputs
+
+ # 7. Prepare mask latent variables
+ mask, masked_image_latents = self.prepare_mask_latents(
+ mask,
+ masked_image,
+ batch_size * num_images_per_prompt,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ do_classifier_free_guidance,
+ )
+
+ # 8. Check that sizes of mask, masked image and latents match
+ if num_channels_unet == 9:
+ # default case for runwayml/stable-diffusion-inpainting
+ num_channels_mask = mask.shape[1]
+ num_channels_masked_image = masked_image_latents.shape[1]
+ if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
+ raise ValueError(
+ f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
+ f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
+ f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ " `pipeline.unet` or your `mask_image` or `image` input."
+ )
+ elif num_channels_unet != 4:
+ raise ValueError(
+ f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
+ )
+
+ # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 10. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+ # concat latents, mask, masked_image_latents in the channel dimension
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ if num_channels_unet == 9:
+ latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ if num_channels_unet == 4:
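+ # a plain 4-channel text-to-image UNet has no mask conditioning, so the preserved (unmasked)
+ # region is blended back in from the (re-)noised original latents at every step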
+ init_latents_proper = image_latents[:1]
+ init_mask = mask[:1]
+
+ if i < len(timesteps) - 1:
+ noise_timestep = timesteps[i + 1]
+ init_latents_proper = self.scheduler.add_noise(
+ init_latents_proper, noise, torch.tensor([noise_timestep])
+ )
+
+ latents = (1 - init_mask) * init_latents_proper + init_mask * latents
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py
new file mode 100644
index 0000000000000000000000000000000000000000..049e3d18f3de867969af00e6e141335576286263
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py
@@ -0,0 +1,738 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import warnings
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import PIL
+import torch
+from packaging import version
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from ...configuration_utils import FrozenDict
+from ...image_processor import VaeImageProcessor
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...schedulers import KarrasDiffusionSchedulers
+from ...utils import (
+ PIL_INTERPOLATION,
+ deprecate,
+ is_accelerate_available,
+ is_accelerate_version,
+ logging,
+ randn_tensor,
+)
+from ..pipeline_utils import DiffusionPipeline
+from . import StableDiffusionPipelineOutput
+from .safety_checker import StableDiffusionSafetyChecker
+
+
+logger = logging.get_logger(__name__)
+
+
+def preprocess_image(image, batch_size):
+ w, h = image.size
+ w, h = (x - x % 8 for x in (w, h)) # round the size down to an integer multiple of 8
+ image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
+ image = np.array(image).astype(np.float32) / 255.0
+ image = np.vstack([image[None].transpose(0, 3, 1, 2)] * batch_size)
+ image = torch.from_numpy(image)
+ return 2.0 * image - 1.0
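+
+ # A minimal sketch of what `preprocess_image` produces (the 513x511 input size is illustrative):
+ # width and height are cropped down to multiples of 8, pixels are rescaled to [-1, 1], and the
+ # result is tiled along the batch dimension.
+ #   out = preprocess_image(pil_image, batch_size=2)
+ #   # out.shape == (2, 3, 504, 512) for a 513x511 RGB input; values lie in [-1.0, 1.0]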
+
+
+def preprocess_mask(mask, batch_size, scale_factor=8):
+ if not isinstance(mask, torch.FloatTensor):
+ mask = mask.convert("L")
+ w, h = mask.size
+ w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
+ mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
+ mask = np.array(mask).astype(np.float32) / 255.0
+ mask = np.tile(mask, (4, 1, 1))
+ mask = np.vstack([mask[None]] * batch_size)
+ mask = 1 - mask # repaint white, keep black
+ mask = torch.from_numpy(mask)
+ return mask
+
+ else:
+ valid_mask_channel_sizes = [1, 3]
+ # if mask channel is fourth tensor dimension, permute dimensions to pytorch standard (B, C, H, W)
+ if mask.shape[3] in valid_mask_channel_sizes:
+ mask = mask.permute(0, 3, 1, 2)
+ elif mask.shape[1] not in valid_mask_channel_sizes:
+ raise ValueError(
+ f"Mask channel dimension of size in {valid_mask_channel_sizes} should be second or fourth dimension,"
+ f" but received mask of shape {tuple(mask.shape)}"
+ )
+ # (potentially) reduce mask channel dimension from 3 to 1 for broadcasting to latent shape
+ mask = mask.mean(dim=1, keepdim=True)
+ h, w = mask.shape[-2:]
+ h, w = (x - x % 8 for x in (h, w)) # resize to integer multiple of 8
+ mask = torch.nn.functional.interpolate(mask, (h // scale_factor, w // scale_factor))
+ return mask
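+
+ # A minimal sketch of the PIL branch above (shapes assume a 512x512 mask and scale_factor=8):
+ # the mask is downsampled to latent resolution, repeated across the 4 latent channels and
+ # inverted, so 1 means "keep the original latents" and 0 means "repaint".
+ #   m = preprocess_mask(pil_mask, batch_size=1, scale_factor=8)
+ #   # m.shape == (1, 4, 64, 64); white input pixels become 0 here and are regenerated in __call__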
+
+
+class StableDiffusionInpaintPipelineLegacy(
+ DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+):
+ r"""
+ Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ In addition the pipeline inherits the following loading methods:
+ - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+ - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+ - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
+
+ as well as the following saving methods:
+ - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+ _optional_components = ["feature_extractor"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ deprecation_message = (
+ f"The class {self.__class__} is deprecated and will be removed in v1.0.0. You can achieve exactly the same functionality"
+ "by loading your model into `StableDiffusionInpaintPipeline` instead. See https://github.com/huggingface/diffusers/pull/3533"
+ "for more information."
+ )
+ deprecate("legacy is outdated", "1.0.0", deprecation_message, standard_warn=False)
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+ )
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["clip_sample"] = False
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+ version.parse(unet.config._diffusers_version).base_version
+ ) < version.parse("0.9.0.dev0")
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+ deprecation_message = (
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+ " the `unet/config.json` file"
+ )
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(unet.config)
+ new_config["sample_size"] = 64
+ unet._internal_dict = FrozenDict(new_config)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
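+
+ # Sanity-check sketch: for the standard Stable Diffusion VAE, `block_out_channels` has 4 entries,
+ # so `vae_scale_factor` computed above is 2 ** 3 = 8 and a 512x512 image maps to 64x64 latents
+ # (the `pipe` name below is illustrative):
+ #   assert pipe.vae_scale_factor == 8
+ #   assert 512 // pipe.vae_scale_factor == 64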
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ hook = None
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+ if self.safety_checker is not None:
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ def decode_latents(self, latents):
+ warnings.warn(
+ "The decode_latents method is deprecated and will be removed in a future version. Please"
+ " use VaeImageProcessor instead",
+ FutureWarning,
+ )
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs
+ def check_inputs(
+ self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
+ ):
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep, 0)
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+
+ return timesteps, num_inference_steps - t_start
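+
+ # Worked example (a scheduler of order 1 is assumed): with num_inference_steps=50 and strength=0.8,
+ # init_timestep = min(int(50 * 0.8), 50) = 40 and t_start = 10, so only the last 40 timesteps
+ # are returned and run; strength=1.0 keeps the full schedule.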
+
+ def prepare_latents(self, image, timestep, num_images_per_prompt, dtype, device, generator):
+ image = image.to(device=device, dtype=dtype)
+ init_latent_dist = self.vae.encode(image).latent_dist
+ init_latents = init_latent_dist.sample(generator=generator)
+ init_latents = self.vae.config.scaling_factor * init_latents
+
+ # Expand init_latents for batch_size and num_images_per_prompt
+ init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
+ init_latents_orig = init_latents
+
+ # add noise to latents using the timesteps
+ noise = randn_tensor(init_latents.shape, generator=generator, device=device, dtype=dtype)
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+ latents = init_latents
+ return latents, init_latents_orig, noise
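+
+ # Sketch of the return values (shapes are illustrative for a 512x512 input):
+ #   latents, init_latents_orig, noise = self.prepare_latents(image, latent_timestep, 1, dtype, device, generator)
+ # All three tensors share the same shape, e.g. (1, 4, 64, 64): `latents` is the noised starting
+ # point, `init_latents_orig` restores the unmasked region later, and `noise` is kept so it can be
+ # re-added at matching timesteps inside the denoising loop.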
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+ mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+ strength: float = 0.8,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ add_predicted_noise: Optional[bool] = False,
+ eta: Optional[float] = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
+ process. This is the image whose masked region will be inpainted.
+ mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
+ replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
+ PIL image, it will be converted to a single channel (luminance) before use. If mask is a tensor, the
+ expected shape should be either `(B, H, W, C)` or `(B, C, H, W)`, where C is 1 or 3.
+ strength (`float`, *optional*, defaults to 0.8):
+ Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
+ is 1, the denoising process will be run on the masked area for the full number of iterations specified
+ in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more noise to
+ that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
+ the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
+ is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ add_predicted_noise (`bool`, *optional*, defaults to `False`):
+ Use predicted noise instead of random noise when constructing noisy versions of the original image in
+ the reverse diffusion process.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ # 1. Check inputs
+ self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+
+ # 4. Preprocess image and mask
+ if not isinstance(image, torch.FloatTensor):
+ image = preprocess_image(image, batch_size)
+
+ mask_image = preprocess_mask(mask_image, batch_size, self.vae_scale_factor)
+
+ # 5. set timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+ # 6. Prepare latent variables
+ # encode the init image into latents and scale the latents
+ latents, init_latents_orig, noise = self.prepare_latents(
+ image, latent_timestep, num_images_per_prompt, prompt_embeds.dtype, device, generator
+ )
+
+ # 7. Prepare mask latent
+ mask = mask_image.to(device=device, dtype=latents.dtype)
+ mask = torch.cat([mask] * num_images_per_prompt)
+
+ # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 9. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+ # masking
+ if add_predicted_noise:
+ init_latents_proper = self.scheduler.add_noise(
+ init_latents_orig, noise_pred_uncond, torch.tensor([t])
+ )
+ else:
+ init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
+
+ latents = (init_latents_proper * mask) + (latents * (1 - mask))
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ # use original latents corresponding to unmasked portions of the image
+ latents = (init_latents_orig * mask) + (latents * (1 - mask))
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
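+
+ # Minimal usage sketch (not part of the upstream file; model id, file names and prompt are illustrative):
+ #   >>> import torch
+ #   >>> from PIL import Image
+ #   >>> from diffusers import StableDiffusionInpaintPipelineLegacy
+ #   >>> pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained(
+ #   ...     "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
+ #   ... ).to("cuda")
+ #   >>> init_image = Image.open("photo.png").convert("RGB").resize((512, 512))
+ #   >>> mask = Image.open("mask.png").convert("L").resize((512, 512))  # white = repaint
+ #   >>> result = pipe(prompt="a red sports car", image=init_image, mask_image=mask, strength=0.75).images[0]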
diff --git a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
new file mode 100644
index 0000000000000000000000000000000000000000..341ff8daad4299a05b89a6b5d985399c1a7fb429
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
@@ -0,0 +1,758 @@
+# Copyright 2023 The InstructPix2Pix Authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import warnings
+from typing import Callable, List, Optional, Union
+
+import numpy as np
+import PIL
+import torch
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from ...image_processor import VaeImageProcessor
+from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...schedulers import KarrasDiffusionSchedulers
+from ...utils import (
+ PIL_INTERPOLATION,
+ deprecate,
+ is_accelerate_available,
+ is_accelerate_version,
+ logging,
+ randn_tensor,
+)
+from ..pipeline_utils import DiffusionPipeline
+from . import StableDiffusionPipelineOutput
+from .safety_checker import StableDiffusionSafetyChecker
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
+def preprocess(image):
+ warnings.warn(
+ "The preprocess method is deprecated and will be removed in a future version. Please"
+ " use VaeImageProcessor.preprocess instead",
+ FutureWarning,
+ )
+ if isinstance(image, torch.Tensor):
+ return image
+ elif isinstance(image, PIL.Image.Image):
+ image = [image]
+
+ if isinstance(image[0], PIL.Image.Image):
+ w, h = image[0].size
+ w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
+
+ image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
+ image = np.concatenate(image, axis=0)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image.transpose(0, 3, 1, 2)
+ image = 2.0 * image - 1.0
+ image = torch.from_numpy(image)
+ elif isinstance(image[0], torch.Tensor):
+ image = torch.cat(image, dim=0)
+ return image
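+
+ # As the warning above says, new code should go through `VaeImageProcessor` instead of this helper,
+ # e.g. (a rough sketch; the scale factor of 8 assumes the standard SD VAE):
+ #   processed = VaeImageProcessor(vae_scale_factor=8).preprocess(pil_image)  # (1, 3, H, W) in [-1, 1]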
+
+
+class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
+ r"""
+ Pipeline for pixel-level image editing by following text instructions. Based on Stable Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ In addition the pipeline inherits the following loading methods:
+ - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+ - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+
+ as well as the following saving methods:
+ - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ image: Union[
+ torch.FloatTensor,
+ PIL.Image.Image,
+ np.ndarray,
+ List[torch.FloatTensor],
+ List[PIL.Image.Image],
+ List[np.ndarray],
+ ] = None,
+ num_inference_steps: int = 100,
+ guidance_scale: float = 7.5,
+ image_guidance_scale: float = 1.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ image (`torch.FloatTensor` `np.ndarray`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, or tensor representing an image batch which will be repainted according to `prompt`. Can also
+ accept image latents as `image`; if latents are passed directly, they will not be encoded again.
+ num_inference_steps (`int`, *optional*, defaults to 100):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality. This pipeline requires a value of at least `1`.
+ image_guidance_scale (`float`, *optional*, defaults to 1.5):
+ Image guidance scale pushes the generated image towards the initial image `image`. Image guidance
+ scale is enabled by setting `image_guidance_scale > 1`. Higher image guidance scale encourages to
+ generate images that are closely linked to the source image `image`, usually at the expense of lower
+ image quality. This pipeline requires a value of at least `1`.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
+ is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+
+ Examples:
+
+ ```py
+ >>> import PIL
+ >>> import requests
+ >>> import torch
+ >>> from io import BytesIO
+
+ >>> from diffusers import StableDiffusionInstructPix2PixPipeline
+
+
+ >>> def download_image(url):
+ ... response = requests.get(url)
+ ... return PIL.Image.open(BytesIO(response.content)).convert("RGB")
+
+
+ >>> img_url = "https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png"
+
+ >>> image = download_image(img_url).resize((512, 512))
+
+ >>> pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
+ ... "timbrooks/instruct-pix2pix", torch_dtype=torch.float16
+ ... )
+ >>> pipe = pipe.to("cuda")
+
+ >>> prompt = "make the mountains snowy"
+ >>> image = pipe(prompt=prompt, image=image).images[0]
+ ```
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ # 0. Check inputs
+ self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
+
+ if image is None:
+ raise ValueError("`image` input cannot be undefined.")
+
+ # 1. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0 and image_guidance_scale >= 1.0
+ # check if scheduler is in sigmas space
+ scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas")
+
+ # 2. Encode input prompt
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ )
+
+ # 3. Preprocess image
+ image = self.image_processor.preprocess(image)
+
+ # 4. set timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare Image latents
+ image_latents = self.prepare_image_latents(
+ image,
+ batch_size,
+ num_images_per_prompt,
+ prompt_embeds.dtype,
+ device,
+ do_classifier_free_guidance,
+ generator,
+ )
+
+ height, width = image_latents.shape[-2:]
+ height = height * self.vae_scale_factor
+ width = width * self.vae_scale_factor
+
+ # 6. Prepare latent variables
+ num_channels_latents = self.vae.config.latent_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 7. Check that shapes of latents and image match the UNet channels
+ num_channels_image = image_latents.shape[1]
+ if num_channels_latents + num_channels_image != self.unet.config.in_channels:
+ raise ValueError(
+ f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
+ f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
+ f" `num_channels_image`: {num_channels_image} "
+ f" = {num_channels_latents+num_channels_image}. Please verify the config of"
+ " `pipeline.unet` or your `image` input."
+ )
+
+ # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 9. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # Expand the latents if we are doing classifier free guidance.
+ # The latents are expanded 3 times because for pix2pix the guidance
+ # is applied for both the text and the input image.
+ latent_model_input = torch.cat([latents] * 3) if do_classifier_free_guidance else latents
+
+ # concat latents, image_latents in the channel dimension
+ scaled_latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+ scaled_latent_model_input = torch.cat([scaled_latent_model_input, image_latents], dim=1)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ scaled_latent_model_input, t, encoder_hidden_states=prompt_embeds, return_dict=False
+ )[0]
+
+ # Hack:
+ # For karras style schedulers the model does classifier free guidance using the
+ # predicted_original_sample instead of the noise_pred. So we need to compute the
+ # predicted_original_sample here if we are using a karras style scheduler.
+ if scheduler_is_in_sigma_space:
+ step_index = (self.scheduler.timesteps == t).nonzero()[0].item()
+ sigma = self.scheduler.sigmas[step_index]
+ noise_pred = latent_model_input - sigma * noise_pred
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3)
+ noise_pred = (
+ noise_pred_uncond
+ + guidance_scale * (noise_pred_text - noise_pred_image)
+ + image_guidance_scale * (noise_pred_image - noise_pred_uncond)
+ )
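+ # Note: with image_guidance_scale == 1 the expression above reduces to
+ # noise_pred_image + guidance_scale * (noise_pred_text - noise_pred_image),
+ # i.e. plain classifier-free guidance anchored on the image-conditioned prediction.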
+
+ # Hack:
+ # For karras style schedulers the model does classifier free guidance using the
+ # predicted_original_sample instead of the noise_pred. But the scheduler.step function
+ # expects the noise_pred and computes the predicted_original_sample internally. So we
+ # need to overwrite the noise_pred here such that the value of the computed
+ # predicted_original_sample is correct.
+ if scheduler_is_in_sigma_space:
+ noise_pred = (noise_pred - latents) / (-sigma)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ hook = None
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+ if self.safety_checker is not None:
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ """
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+            # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+            # pix2pix uses two negative embeddings, and unlike in other pipelines the embeddings are ordered [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds]
+ prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds, negative_prompt_embeds])
+
+ return prompt_embeds
+
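+    # NOTE: a minimal sketch of how the [prompt, negative, negative] ordering produced by
+    # `_encode_prompt` is typically consumed at guidance time. The chunk names and
+    # `image_guidance_scale` are assumptions based on the InstructPix2Pix formulation, since
+    # `__call__` is not part of this hunk:
+    #
+    #     noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3)
+    #     noise_pred = (
+    #         noise_pred_uncond
+    #         + guidance_scale * (noise_pred_text - noise_pred_image)
+    #         + image_guidance_scale * (noise_pred_image - noise_pred_uncond)
+    #     )
+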
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ def decode_latents(self, latents):
+ warnings.warn(
+ "The decode_latents method is deprecated and will be removed in a future version. Please"
+ " use VaeImageProcessor instead",
+ FutureWarning,
+ )
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ def check_inputs(
+ self, prompt, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
+ ):
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ def prepare_image_latents(
+ self, image, batch_size, num_images_per_prompt, dtype, device, do_classifier_free_guidance, generator=None
+ ):
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
+ raise ValueError(
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
+ )
+
+ image = image.to(device=device, dtype=dtype)
+
+ batch_size = batch_size * num_images_per_prompt
+
+ if image.shape[1] == 4:
+ image_latents = image
+ else:
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if isinstance(generator, list):
+ image_latents = [self.vae.encode(image[i : i + 1]).latent_dist.mode() for i in range(batch_size)]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = self.vae.encode(image).latent_dist.mode()
+
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+ # expand image_latents for batch_size
+ deprecation_message = (
+ f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial"
+                    " images (`image`). Initial images are now duplicated to match the number of text prompts. Note"
+                    " that this behavior is deprecated and will be removed in version 1.0.0. Please make sure to update"
+ " your script to pass as many initial images as text prompts to suppress this warning."
+ )
+ deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ image_latents = torch.cat([image_latents], dim=0)
+
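+        # NOTE: the zero "uncond" image latents added below, together with the prompt embeddings ordered as
+        # [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds], form the three conditioning pairs used by
+        # the pix2pix-style guidance: (text, image), (uncond text, image), and (uncond text, zero image).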
+ if do_classifier_free_guidance:
+ uncond_image_latents = torch.zeros_like(image_latents)
+ image_latents = torch.cat([image_latents, image_latents, uncond_image_latents], dim=0)
+
+ return image_latents
diff --git a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..29a57470a341f4cf1a155af9e3e023091f9e55e8
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py
@@ -0,0 +1,600 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib
+import inspect
+import warnings
+from typing import Callable, List, Optional, Union
+
+import torch
+from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser
+from k_diffusion.sampling import BrownianTreeNoiseSampler, get_sigmas_karras
+
+from ...image_processor import VaeImageProcessor
+from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from ...pipelines import DiffusionPipeline
+from ...schedulers import LMSDiscreteScheduler
+from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor
+from . import StableDiffusionPipelineOutput
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
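+# `ModelWrapper` is a thin adapter: k-diffusion's `CompVisDenoiser` / `CompVisVDenoiser` expect a CompVis-style
+# `apply_model(x, t, cond)` interface, whereas the diffusers `UNet2DConditionModel` takes `encoder_hidden_states`;
+# the wrapper translates between the two calling conventions.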
+class ModelWrapper:
+ def __init__(self, model, alphas_cumprod):
+ self.model = model
+ self.alphas_cumprod = alphas_cumprod
+
+ def apply_model(self, *args, **kwargs):
+ if len(args) == 3:
+ encoder_hidden_states = args[-1]
+ args = args[:2]
+ if kwargs.get("cond", None) is not None:
+ encoder_hidden_states = kwargs.pop("cond")
+ return self.model(*args, encoder_hidden_states=encoder_hidden_states, **kwargs).sample
+
+
+class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+
+
+ This is an experimental pipeline and is likely to change in the future.
+
+
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae,
+ text_encoder,
+ tokenizer,
+ unet,
+ scheduler,
+ safety_checker,
+ feature_extractor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ logger.info(
+            f"{self.__class__} is an experimental pipeline and is likely to change in the future. We recommend using"
+            " this pipeline for fast experimentation / iteration if needed, but advise relying on existing pipelines"
+ " as defined in https://huggingface.co/docs/diffusers/api/schedulers#implemented-schedulers for"
+ " production settings."
+ )
+
+ # get correct sigmas from LMS
+ scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+ model = ModelWrapper(unet, scheduler.alphas_cumprod)
+ if scheduler.config.prediction_type == "v_prediction":
+ self.k_diffusion_model = CompVisVDenoiser(model)
+ else:
+ self.k_diffusion_model = CompVisDenoiser(model)
+
+ def set_scheduler(self, scheduler_type: str):
+ library = importlib.import_module("k_diffusion")
+ sampling = getattr(library, "sampling")
+ self.sampler = getattr(sampling, scheduler_type)
+
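+    # A minimal usage sketch (the checkpoint id and prompt below are placeholders, and the `k-diffusion`
+    # package must be installed):
+    #
+    #     pipe = StableDiffusionKDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
+    #     pipe = pipe.to("cuda")
+    #     pipe.set_scheduler("sample_dpmpp_2m")  # any sampler from k_diffusion.sampling
+    #     image = pipe("an astronaut riding a horse", use_karras_sigmas=True).images[0]
+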
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ hook = None
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+ if self.safety_checker is not None:
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+            device (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+            # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+            # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ def decode_latents(self, latents):
+ warnings.warn(
+ "The decode_latents method is deprecated and will be removed in a future version. Please"
+ " use VaeImageProcessor instead",
+ FutureWarning,
+ )
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ if latents.shape != shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+ latents = latents.to(device)
+
+        # NOTE: unlike other pipelines, the initial noise is scaled later in `__call__` using the k-diffusion sigmas
+ return latents
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ use_karras_sigmas: Optional[bool] = False,
+ noise_sampler_seed: Optional[int] = None,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass
+                `prompt_embeds` instead.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. A higher guidance scale encourages the model to generate images that are closely linked to the
+                text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
+ is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+            use_karras_sigmas (`bool`, *optional*, defaults to `False`):
+                Whether to use Karras sigmas. For example, passing `sample_dpmpp_2m` to `set_scheduler` is equivalent
+                to `DPM++ 2M` in stable-diffusion-webui; additionally setting this option to `True` makes it
+                `DPM++ 2M Karras`.
+ noise_sampler_seed (`int`, *optional*, defaults to `None`):
+ The random seed to use for the noise sampler. If `None`, a random seed will be generated.
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = True
+ if guidance_scale <= 1.0:
+            raise ValueError("This pipeline always uses classifier-free guidance; `guidance_scale` has to be > 1.")
+
+ # 3. Encode input prompt
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device)
+
+ # 5. Prepare sigmas
+ if use_karras_sigmas:
+ sigma_min: float = self.k_diffusion_model.sigmas[0].item()
+ sigma_max: float = self.k_diffusion_model.sigmas[-1].item()
+ sigmas = get_sigmas_karras(n=num_inference_steps, sigma_min=sigma_min, sigma_max=sigma_max)
+ sigmas = sigmas.to(device)
+ else:
+ sigmas = self.scheduler.sigmas
+ sigmas = sigmas.to(prompt_embeds.dtype)
+
+ # 6. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+ latents = latents * sigmas[0]
+ self.k_diffusion_model.sigmas = self.k_diffusion_model.sigmas.to(latents.device)
+ self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to(latents.device)
+
+ # 7. Define model function
+ def model_fn(x, t):
+ latent_model_input = torch.cat([x] * 2)
+ t = torch.cat([t] * 2)
+
+ noise_pred = self.k_diffusion_model(latent_model_input, t, cond=prompt_embeds)
+
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+ return noise_pred
+
+ # 8. Run k-diffusion solver
+ sampler_kwargs = {}
+
+ if "noise_sampler" in inspect.signature(self.sampler).parameters:
+ min_sigma, max_sigma = sigmas[sigmas > 0].min(), sigmas.max()
+ noise_sampler = BrownianTreeNoiseSampler(latents, min_sigma, max_sigma, noise_sampler_seed)
+ sampler_kwargs["noise_sampler"] = noise_sampler
+
+ latents = self.sampler(model_fn, latents, sigmas, **sampler_kwargs)
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d5022c854a8f17dfbc09221177763a34f9d8d36
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py
@@ -0,0 +1,503 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import warnings
+from typing import Callable, List, Optional, Union
+
+import numpy as np
+import PIL
+import torch
+import torch.nn.functional as F
+from transformers import CLIPTextModel, CLIPTokenizer
+
+from ...image_processor import VaeImageProcessor
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...schedulers import EulerDiscreteScheduler
+from ...utils import logging, randn_tensor
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.preprocess
+def preprocess(image):
+ warnings.warn(
+ "The preprocess method is deprecated and will be removed in a future version. Please"
+ " use VaeImageProcessor.preprocess instead",
+ FutureWarning,
+ )
+ if isinstance(image, torch.Tensor):
+ return image
+ elif isinstance(image, PIL.Image.Image):
+ image = [image]
+
+ if isinstance(image[0], PIL.Image.Image):
+ w, h = image[0].size
+ w, h = (x - x % 64 for x in (w, h)) # resize to integer multiple of 64
+
+ image = [np.array(i.resize((w, h)))[None, :] for i in image]
+ image = np.concatenate(image, axis=0)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image.transpose(0, 3, 1, 2)
+ image = 2.0 * image - 1.0
+ image = torch.from_numpy(image)
+ elif isinstance(image[0], torch.Tensor):
+ image = torch.cat(image, dim=0)
+ return image
+
+
+class StableDiffusionLatentUpscalePipeline(DiffusionPipeline):
+ r"""
+ Pipeline to upscale the resolution of Stable Diffusion output images by a factor of 2.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/main/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`EulerDiscreteScheduler`].
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: EulerDiscreteScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, resample="bicubic")
+
+ def _encode_prompt(self, prompt, device, do_classifier_free_guidance, negative_prompt):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+            prompt (`str` or `List[str]`):
+ prompt to be encoded
+            device (`torch.device`):
+ torch device
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ """
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_length=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ text_encoder_out = self.text_encoder(
+ text_input_ids.to(device),
+ output_hidden_states=True,
+ )
+ text_embeddings = text_encoder_out.hidden_states[-1]
+ text_pooler_out = text_encoder_out.pooler_output
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = text_input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_length=True,
+ return_tensors="pt",
+ )
+
+ uncond_encoder_out = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ output_hidden_states=True,
+ )
+
+ uncond_embeddings = uncond_encoder_out.hidden_states[-1]
+ uncond_pooler_out = uncond_encoder_out.pooler_output
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+ text_pooler_out = torch.cat([uncond_pooler_out, text_pooler_out])
+
+ return text_embeddings, text_pooler_out
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ def decode_latents(self, latents):
+ warnings.warn(
+ "The decode_latents method is deprecated and will be removed in a future version. Please"
+ " use VaeImageProcessor instead",
+ FutureWarning,
+ )
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ def check_inputs(self, prompt, image, callback_steps):
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if (
+ not isinstance(image, torch.Tensor)
+ and not isinstance(image, PIL.Image.Image)
+ and not isinstance(image, list)
+ ):
+ raise ValueError(
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or `list` but is {type(image)}"
+ )
+
+ # verify batch size of prompt and image are same if image is a list or tensor
+ if isinstance(image, list) or isinstance(image, torch.Tensor):
+ if isinstance(prompt, str):
+ batch_size = 1
+ else:
+ batch_size = len(prompt)
+ if isinstance(image, list):
+ image_batch_size = len(image)
+ else:
+ image_batch_size = image.shape[0] if image.ndim == 4 else 1
+ if batch_size != image_batch_size:
+ raise ValueError(
+ f"`prompt` has batch size {batch_size} and `image` has batch size {image_batch_size}."
+ " Please make sure that passed `prompt` matches the batch size of `image`."
+ )
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height, width)
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ if latents.shape != shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ image: Union[
+ torch.FloatTensor,
+ PIL.Image.Image,
+ np.ndarray,
+ List[torch.FloatTensor],
+ List[PIL.Image.Image],
+ List[np.ndarray],
+ ] = None,
+ num_inference_steps: int = 75,
+ guidance_scale: float = 9.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image upscaling.
+ image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, or tensor representing an image batch which will be upscaled. If it's a tensor, it can be
+ either a latent output from a stable diffusion model, or an image tensor in the range `[-1, 1]`. It
+ will be considered a `latent` if `image.shape[1]` is `4`; otherwise, it will be considered to be an
+ image representation and encoded using this pipeline's `vae` encoder.
+            num_inference_steps (`int`, *optional*, defaults to 75):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+            guidance_scale (`float`, *optional*, defaults to 9.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. A higher guidance scale encourages the model to generate images that are closely linked to the
+                text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ generator (`torch.Generator`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+
+ Examples:
+ ```py
+ >>> from diffusers import StableDiffusionLatentUpscalePipeline, StableDiffusionPipeline
+ >>> import torch
+
+
+ >>> pipeline = StableDiffusionPipeline.from_pretrained(
+ ... "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
+ ... )
+ >>> pipeline.to("cuda")
+
+ >>> model_id = "stabilityai/sd-x2-latent-upscaler"
+ >>> upscaler = StableDiffusionLatentUpscalePipeline.from_pretrained(model_id, torch_dtype=torch.float16)
+ >>> upscaler.to("cuda")
+
+ >>> prompt = "a photo of an astronaut high resolution, unreal engine, ultra realistic"
+ >>> generator = torch.manual_seed(33)
+
+ >>> low_res_latents = pipeline(prompt, generator=generator, output_type="latent").images
+
+ >>> with torch.no_grad():
+ ... image = pipeline.decode_latents(low_res_latents)
+ >>> image = pipeline.numpy_to_pil(image)[0]
+
+ >>> image.save("../images/a1.png")
+
+ >>> upscaled_image = upscaler(
+ ... prompt=prompt,
+ ... image=low_res_latents,
+ ... num_inference_steps=20,
+ ... guidance_scale=0,
+ ... generator=generator,
+ ... ).images[0]
+
+ >>> upscaled_image.save("../images/a2.png")
+ ```
+
+ Returns:
+            [`~pipelines.ImagePipelineOutput`] or `tuple`:
+            [`~pipelines.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple,
+            the first element is a list with the generated images.
+ """
+
+ # 1. Check inputs
+ self.check_inputs(prompt, image, callback_steps)
+
+ # 2. Define call parameters
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ if guidance_scale == 0:
+ prompt = [""] * batch_size
+
+ # 3. Encode input prompt
+ text_embeddings, text_pooler_out = self._encode_prompt(
+ prompt, device, do_classifier_free_guidance, negative_prompt
+ )
+
+ # 4. Preprocess image
+ image = self.image_processor.preprocess(image)
+ image = image.to(dtype=text_embeddings.dtype, device=device)
+ if image.shape[1] == 3:
+ # encode image if not in latent-space yet
+ image = self.vae.encode(image).latent_dist.sample() * self.vae.config.scaling_factor
+
+ # 5. set timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ batch_multiplier = 2 if do_classifier_free_guidance else 1
+ image = image[None, :] if image.ndim == 3 else image
+ image = torch.cat([image] * batch_multiplier)
+
+        # 6. Add noise to image (the noise level is set to 0 here)
+        # (note from the original author):
+        # "This step theoretically can make the model work better on out-of-distribution inputs, but mostly just
+        # seems to make it match the input less, so it's turned off by default."
+ noise_level = torch.tensor([0.0], dtype=torch.float32, device=device)
+ noise_level = torch.cat([noise_level] * image.shape[0])
+ inv_noise_level = (noise_level**2 + 1) ** (-0.5)
+
+ image_cond = F.interpolate(image, scale_factor=2, mode="nearest") * inv_noise_level[:, None, None, None]
+ image_cond = image_cond.to(text_embeddings.dtype)
+
+ noise_level_embed = torch.cat(
+ [
+ torch.ones(text_pooler_out.shape[0], 64, dtype=text_pooler_out.dtype, device=device),
+ torch.zeros(text_pooler_out.shape[0], 64, dtype=text_pooler_out.dtype, device=device),
+ ],
+ dim=1,
+ )
+
+ timestep_condition = torch.cat([noise_level_embed, text_pooler_out], dim=1)
+
+        # 7. Prepare latent variables
+ height, width = image.shape[2:]
+ num_channels_latents = self.vae.config.latent_channels
+ latents = self.prepare_latents(
+ batch_size,
+ num_channels_latents,
+ height * 2, # 2x upscale
+ width * 2,
+ text_embeddings.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+        # 8. Check that sizes of image and latents match
+ num_channels_image = image.shape[1]
+ if num_channels_latents + num_channels_image != self.unet.config.in_channels:
+ raise ValueError(
+ f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
+ f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
+ f" `num_channels_image`: {num_channels_image} "
+ f" = {num_channels_latents+num_channels_image}. Please verify the config of"
+ " `pipeline.unet` or your `image` input."
+ )
+
+ # 9. Denoising loop
+ num_warmup_steps = 0
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ sigma = self.scheduler.sigmas[i]
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ scaled_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ scaled_model_input = torch.cat([scaled_model_input, image_cond], dim=1)
+ # preconditioning parameter based on Karras et al. (2022) (table 1)
+ timestep = torch.log(sigma) * 0.25
+
+ noise_pred = self.unet(
+ scaled_model_input,
+ timestep,
+ encoder_hidden_states=text_embeddings,
+ timestep_cond=timestep_condition,
+ ).sample
+
+ # in original repo, the output contains a variance channel that's not used
+ noise_pred = noise_pred[:, :-1]
+
+ # apply preconditioning, based on table 1 in Karras et al. (2022)
+ inv_sigma = 1 / (sigma**2 + 1)
+ noise_pred = inv_sigma * latent_model_input + self.scheduler.scale_model_input(sigma, t) * noise_pred
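+                # NOTE: this mirrors the preconditioning of Karras et al. (2022), table 1, assuming sigma_data = 1:
+                # `inv_sigma` plays the role of c_skip = 1 / (sigma**2 + 1), the Euler scheduler's
+                # `scale_model_input(sigma, t)` evaluates to sigma / (sigma**2 + 1) ** 0.5, i.e. c_out, and the
+                # `timestep` computed above is c_noise = log(sigma) / 4.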
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents).prev_sample
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ else:
+ image = latents
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
diff --git a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..95dd207f9d12b48a38b1934e46d42dc6dc03e1d5
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py
@@ -0,0 +1,675 @@
+# Copyright 2023 The Intel Labs Team Authors and the HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import PIL
+import torch
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from ...image_processor import VaeImageProcessorLDM3D
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...schedulers import KarrasDiffusionSchedulers
+from ...utils import (
+ BaseOutput,
+ is_accelerate_available,
+ is_accelerate_version,
+ logging,
+ randn_tensor,
+ replace_example_docstring,
+)
+from ..pipeline_utils import DiffusionPipeline
+from .safety_checker import StableDiffusionSafetyChecker
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import StableDiffusionLDM3DPipeline
+
+ >>> pipe = StableDiffusionLDM3DPipeline.from_pretrained("Intel/ldm3d")
+ >>> pipe = pipe.to("cuda")
+
+ >>> prompt = "a photo of an astronaut riding a horse on mars"
+ >>> output = pipe(prompt)
+ >>> rgb_image, depth_image = output.rgb, output.depth
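+ >>> # illustrative output paths
+ >>> rgb_image[0].save("ldm3d_rgb.jpg")
+ >>> depth_image[0].save("ldm3d_depth.png")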
+ ```
+"""
+
+
+@dataclass
+class LDM3DPipelineOutput(BaseOutput):
+ """
+ Output class for the LDM3D Stable Diffusion pipeline.
+
+ Args:
+ rgb (`List[PIL.Image.Image]` or `np.ndarray`)
+ List of denoised RGB PIL images of length `batch_size` or numpy array of shape `(batch_size, height,
+ width, num_channels)`.
+ depth (`List[PIL.Image.Image]` or `np.ndarray`)
+ List of denoised depth PIL images of length `batch_size` or numpy array of shape `(batch_size, height,
+ width, num_channels)`.
+ nsfw_content_detected (`List[bool]`)
+ List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, or `None` if safety checking could not be performed.
+ """
+
+ rgb: Union[List[PIL.Image.Image], np.ndarray]
+ depth: Union[List[PIL.Image.Image], np.ndarray]
+ nsfw_content_detected: Optional[List[bool]]
+
+
+class StableDiffusionLDM3DPipeline(
+ DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+):
+ r"""
+ Pipeline for joint text-to-image and depth generation using LDM3D (Latent Diffusion Model for 3D,
+ https://arxiv.org/abs/2305.10853).
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ In addition the pipeline inherits the following loading methods:
+ - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+ - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+ - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
+
+ as well as the following saving methods:
+ - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode rgb and depth images to and from latent
+ representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded rgb and depth latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
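+ # e.g. a VAE with four `block_out_channels` entries downsamples by 2**3 = 8, so 512x512 images map to 64x64 latents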
+ self.image_processor = VaeImageProcessorLDM3D(vae_scale_factor=self.vae_scale_factor)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
+ steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
+ several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
+ """
+ self.vae.enable_tiling()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ hook = None
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+ if self.safety_checker is not None:
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
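+ # only the RGB output (not the depth map) is fed to the safety checker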
+ rgb_feature_extractor_input = feature_extractor_input[0]
+ safety_checker_input = self.feature_extractor(rgb_feature_extractor_input, return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 49,
+ guidance_scale: float = 5.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 49):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 5.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`LDM3DPipelineOutput`] instead of a plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+
+ Examples:
+
+ Returns:
+ [`LDM3DPipelineOutput`] or `tuple`:
+ [`LDM3DPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first
+ element is a pair `(rgb, depth)` of lists with the generated RGB images and depth maps, and the second
+ element is a list of `bool`s denoting whether the corresponding generated image likely represents
+ "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ rgb, depth = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return ((rgb, depth), has_nsfw_concept)
+
+ return LDM3DPipelineOutput(rgb=rgb, depth=depth, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ecb3f9dbaf73b848552cd4aed386f40b59a1f87
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py
@@ -0,0 +1,770 @@
+# Copyright 2023 TIME Authors and The HuggingFace Team. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import inspect
+import warnings
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import torch
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+
+from ...image_processor import VaeImageProcessor
+from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...schedulers import PNDMScheduler
+from ...schedulers.scheduling_utils import SchedulerMixin
+from ...utils import logging, randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from . import StableDiffusionPipelineOutput
+from .safety_checker import StableDiffusionSafetyChecker
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+AUGS_CONST = ["A photo of ", "An image of ", "A picture of "]
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import StableDiffusionModelEditingPipeline
+
+ >>> model_ckpt = "CompVis/stable-diffusion-v1-4"
+ >>> pipe = StableDiffusionModelEditingPipeline.from_pretrained(model_ckpt)
+
+ >>> pipe = pipe.to("cuda")
+
+ >>> source_prompt = "A pack of roses"
+ >>> destination_prompt = "A pack of blue roses"
+ >>> pipe.edit_model(source_prompt, destination_prompt)
+
+ >>> prompt = "A field of roses"
+ >>> image = pipe(prompt).images[0]
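+ >>> # illustrative output path
+ >>> image.save("field_of_roses.png")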
+ ```
+"""
+
+
+class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
+ r"""
+ Pipeline for text-to-image model editing using the method from "Editing Implicit Assumptions in Text-to-Image
+ Diffusion Models" (TIME, https://arxiv.org/abs/2303.08084).
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.).
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents.
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPFeatureExtractor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ with_to_k ([`bool`]):
+ Whether to edit the key projection matrices along with the value projection matrices.
+ with_augs ([`list`]):
+ Textual augmentations to apply while editing the text-to-image model. Set to [] for no augmentations.
+ """
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: SchedulerMixin,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPFeatureExtractor,
+ requires_safety_checker: bool = True,
+ with_to_k: bool = True,
+ with_augs: list = AUGS_CONST,
+ ):
+ super().__init__()
+
+ if isinstance(scheduler, PNDMScheduler):
+ logger.error("PNDMScheduler for this pipeline is currently not supported.")
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ self.with_to_k = with_to_k
+ self.with_augs = with_augs
+
+ # get cross-attention layers
+ ca_layers = []
+
+ def append_ca(net_):
+ if net_.__class__.__name__ == "CrossAttention":
+ ca_layers.append(net_)
+ elif hasattr(net_, "children"):
+ for net__ in net_.children():
+ append_ca(net__)
+
+ # recursively find all cross-attention layers in unet
+ for net in self.unet.named_children():
+ if "down" in net[0]:
+ append_ca(net[1])
+ elif "up" in net[0]:
+ append_ca(net[1])
+ elif "mid" in net[0]:
+ append_ca(net[1])
+
+ # get projection matrices
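+ # only cross-attention layers whose value projection consumes the 768-dim CLIP text embedding are edited;
+ # self-attention layers, whose `to_v` operates on UNet feature widths, are left untouched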
+ self.ca_clip_layers = [l for l in ca_layers if l.to_v.in_features == 768]
+ self.projection_matrices = [l.to_v for l in self.ca_clip_layers]
+ self.og_matrices = [copy.deepcopy(l.to_v) for l in self.ca_clip_layers]
+ if self.with_to_k:
+ self.projection_matrices = self.projection_matrices + [l.to_k for l in self.ca_clip_layers]
+ self.og_matrices = self.og_matrices + [copy.deepcopy(l.to_k) for l in self.ca_clip_layers]
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
+ steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ def decode_latents(self, latents):
+ warnings.warn(
+ "The decode_latents method is deprecated and will be removed in a future version. Please"
+ " use VaeImageProcessor instead",
+ FutureWarning,
+ )
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ @torch.no_grad()
+ def edit_model(
+ self,
+ source_prompt: str,
+ destination_prompt: str,
+ lamb: float = 0.1,
+ restart_params: bool = True,
+ ):
+ r"""
+ Apply model editing via closed-form solution (see Eq. 5 in the TIME paper https://arxiv.org/abs/2303.08084)
+
+ Args:
+ source_prompt (`str`):
+ The source prompt containing the concept to be edited.
+ destination_prompt (`str`):
+ The destination prompt. Must contain all words from source_prompt with additional ones to specify the
+ target edit.
+ lamb (`float`, *optional*, defaults to 0.1):
+ The lambda parameter specifying the regularization intensity. Smaller values increase the editing power.
+ restart_params (`bool`, *optional*, defaults to True):
+ Reset the model parameters to their pre-trained version before applying the edit. This is done to avoid
+ edit compounding. When it is False, edits accumulate.
+ """
+
+ # restart LDM parameters
+ if restart_params:
+ num_ca_clip_layers = len(self.ca_clip_layers)
+ for idx_, l in enumerate(self.ca_clip_layers):
+ l.to_v = copy.deepcopy(self.og_matrices[idx_])
+ self.projection_matrices[idx_] = l.to_v
+ if self.with_to_k:
+ l.to_k = copy.deepcopy(self.og_matrices[num_ca_clip_layers + idx_])
+ self.projection_matrices[num_ca_clip_layers + idx_] = l.to_k
+
+ # set up sentences
+ old_texts = [source_prompt]
+ new_texts = [destination_prompt]
+ # add augmentations
+ base = old_texts[0] if old_texts[0][0:1] != "A" else "a" + old_texts[0][1:]
+ for aug in self.with_augs:
+ old_texts.append(aug + base)
+ base = new_texts[0] if new_texts[0][0:1] != "A" else "a" + new_texts[0][1:]
+ for aug in self.with_augs:
+ new_texts.append(aug + base)
+
+ # prepare input k* and v*
+ old_embs, new_embs = [], []
+ for old_text, new_text in zip(old_texts, new_texts):
+ text_input = self.tokenizer(
+ [old_text, new_text],
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
+ old_emb, new_emb = text_embeddings
+ old_embs.append(old_emb)
+ new_embs.append(new_emb)
+
+ # identify corresponding destinations for each token in old_emb
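+ # e.g. for "A pack of roses" -> "A pack of blue roses", each source token is mapped to the index of the
+ # matching token in the destination sequence, so inserted tokens such as "blue" are skipped over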
+ idxs_replaces = []
+ for old_text, new_text in zip(old_texts, new_texts):
+ tokens_a = self.tokenizer(old_text).input_ids
+ tokens_b = self.tokenizer(new_text).input_ids
+ tokens_a = [self.tokenizer.encode("a ")[1] if self.tokenizer.decode(t) == "an" else t for t in tokens_a]
+ tokens_b = [self.tokenizer.encode("a ")[1] if self.tokenizer.decode(t) == "an" else t for t in tokens_b]
+ num_orig_tokens = len(tokens_a)
+ idxs_replace = []
+ j = 0
+ for i in range(num_orig_tokens):
+ curr_token = tokens_a[i]
+ while tokens_b[j] != curr_token:
+ j += 1
+ idxs_replace.append(j)
+ j += 1
+ while j < 77:
+ idxs_replace.append(j)
+ j += 1
+ while len(idxs_replace) < 77:
+ idxs_replace.append(76)
+ idxs_replaces.append(idxs_replace)
+
+ # prepare batch: for each pair of sentences, old context and new values
+ contexts, valuess = [], []
+ for old_emb, new_emb, idxs_replace in zip(old_embs, new_embs, idxs_replaces):
+ context = old_emb.detach()
+ values = []
+ with torch.no_grad():
+ for layer in self.projection_matrices:
+ values.append(layer(new_emb[idxs_replace]).detach())
+ contexts.append(context)
+ valuess.append(values)
+
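+ # closed-form update (Eq. 5 of the TIME paper): minimizing
+ #   sum_i ||W k_i - v_i||^2 + lambda * ||W - W_old||_F^2
+ # over W gives W = (lambda * W_old + sum_i v_i k_i^T) @ (lambda * I + sum_i k_i k_i^T)^-1,
+ # assembled below as mat1 @ inverse(mat2)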
+ # edit the model
+ for layer_num in range(len(self.projection_matrices)):
+ # mat1 = \lambda W + \sum{v k^T}
+ mat1 = lamb * self.projection_matrices[layer_num].weight
+
+ # mat2 = \lambda I + \sum{k k^T}
+ mat2 = lamb * torch.eye(
+ self.projection_matrices[layer_num].weight.shape[1],
+ device=self.projection_matrices[layer_num].weight.device,
+ )
+
+ # aggregate sums for mat1, mat2
+ for context, values in zip(contexts, valuess):
+ context_vector = context.reshape(context.shape[0], context.shape[1], 1)
+ context_vector_T = context.reshape(context.shape[0], 1, context.shape[1])
+ value_vector = values[layer_num].reshape(values[layer_num].shape[0], values[layer_num].shape[1], 1)
+ for_mat1 = (value_vector @ context_vector_T).sum(dim=0)
+ for_mat2 = (context_vector @ context_vector_T).sum(dim=0)
+ mat1 += for_mat1
+ mat2 += for_mat2
+
+ # update projection matrix
+ self.projection_matrices[layer_num].weight = torch.nn.Parameter(mat1 @ torch.inverse(mat2))
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ ).sample
+
+ # perform guidance
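+ # the first half of the batch holds the unconditional predictions and the second half the
+ # text-conditioned ones, matching the [negative, positive] embedding order built in `_encode_prompt`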
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py
new file mode 100644
index 0000000000000000000000000000000000000000..37e705d1bc5a822a5ae6242139123f3a14b76a68
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py
@@ -0,0 +1,739 @@
+# Copyright 2023 MultiDiffusion Authors and The HuggingFace Team. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import inspect
+import warnings
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import torch
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from ...image_processor import VaeImageProcessor
+from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...schedulers import DDIMScheduler
+from ...utils import logging, randn_tensor, replace_example_docstring
+from ..pipeline_utils import DiffusionPipeline
+from . import StableDiffusionPipelineOutput
+from .safety_checker import StableDiffusionSafetyChecker
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import StableDiffusionPanoramaPipeline, DDIMScheduler
+
+ >>> model_ckpt = "stabilityai/stable-diffusion-2-base"
+ >>> scheduler = DDIMScheduler.from_pretrained(model_ckpt, subfolder="scheduler")
+ >>> pipe = StableDiffusionPanoramaPipeline.from_pretrained(
+ ... model_ckpt, scheduler=scheduler, torch_dtype=torch.float16
+ ... )
+
+ >>> pipe = pipe.to("cuda")
+
+ >>> prompt = "a photo of the dolomites"
+ >>> image = pipe(prompt).images[0]
+ ```
+"""
+
+
+class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
+ r"""
+ Pipeline for text-to-image generation using "MultiDiffusion: Fusing Diffusion Paths for Controlled Image
+ Generation".
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.).
+
+ To generate panorama-like images, make sure to pass the `width` parameter accordingly when using the
+ pipeline. We recommend a `width` value of 2048, which is the default for this pipeline.
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. The original work
+ on MultiDiffusion used the [`DDIMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: DDIMScheduler,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor into slices to compute decoding in several
+ steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ def decode_latents(self, latents):
+ warnings.warn(
+ "The decode_latents method is deprecated and will be removed in a future version. Please"
+ " use VaeImageProcessor instead",
+ FutureWarning,
+ )
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ def decode_latents_with_padding(self, latents, padding=8):
+ # Add padding to latents for circular inference
+ # padding is the number of latents to add on each side
+ # it would slightly increase the memory usage, but remove the boundary artifacts
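+ # e.g. with the default `padding` of 8 latent columns and the usual VAE scale factor of 8, an extra
+ # 64 pixels are decoded on each side and then cropped away below via `padding_pix`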
+ latents = 1 / self.vae.config.scaling_factor * latents
+ latents_left = latents[..., :padding]
+ latents_right = latents[..., -padding:]
+ latents = torch.cat((latents_right, latents, latents_left), axis=-1)
+ image = self.vae.decode(latents, return_dict=False)[0]
+ padding_pix = self.vae_scale_factor * padding
+ image = image[..., padding_pix:-padding_pix]
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ def get_views(self, panorama_height, panorama_width, window_size=64, stride=8, circular_padding=False):
+ # Here, we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113)
+ # if panorama's height/width < window_size, num_blocks of height/width should return 1
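+ # e.g. with the default 512x2048 output this yields a 1x25 grid of 64x64 latent windows, each shifted
+ # by `stride` latents horizontally (1x32 windows when `circular_padding` is enabled)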
+ panorama_height /= 8
+ panorama_width /= 8
+ num_blocks_height = (panorama_height - window_size) // stride + 1 if panorama_height > window_size else 1
+ if circular_padding:
+ num_blocks_width = panorama_width // stride if panorama_width > window_size else 1
+ else:
+ num_blocks_width = (panorama_width - window_size) // stride + 1 if panorama_width > window_size else 1
+ total_num_blocks = int(num_blocks_height * num_blocks_width)
+ views = []
+ for i in range(total_num_blocks):
+ h_start = int((i // num_blocks_width) * stride)
+ h_end = h_start + window_size
+ w_start = int((i % num_blocks_width) * stride)
+ w_end = w_start + window_size
+ views.append((h_start, h_end, w_start, w_end))
+ return views
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ height: Optional[int] = 512,
+ width: Optional[int] = 2048,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ view_batch_size: int = 1,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: Optional[int] = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ circular_padding: bool = False,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ height (`int`, *optional*, defaults to 512):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to 2048):
+ The width in pixels of the generated image. The width defaults to a large value because the
+ pipeline is meant to be used for generating panorama-like images.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of Eq. 2 of the [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ view_batch_size (`int`, *optional*, defaults to 1):
+ The batch size used to denoise split views. On high-performance GPUs, a higher view batch size can
+ speed up generation at the cost of increased VRAM usage.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+ circular_padding (`bool`, *optional*, defaults to `False`):
+ If set to `True`, circular padding is applied to ensure there are no stitching artifacts. Circular
+ padding allows the model to seamlessly generate a transition from the rightmost part of the image to
+ the leftmost part, maintaining consistency in a 360-degree sense.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Define panorama grid and initialize views for synthesis.
+ # prepare batch grid
+ views = self.get_views(height, width, circular_padding=circular_padding)
+ views_batch = [views[i : i + view_batch_size] for i in range(0, len(views), view_batch_size)]
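+ # each batch of views keeps its own copy of the scheduler state, so stateful schedulers are
+ # stepped independently per view batch (restored before and saved after every step below)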
+ views_scheduler_status = [copy.deepcopy(self.scheduler.__dict__)] * len(views_batch)
+ count = torch.zeros_like(latents)
+ value = torch.zeros_like(latents)
+
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 8. Denoising loop
+ # Each denoising step also includes refinement of the latents with respect to the
+ # views.
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ count.zero_()
+ value.zero_()
+
+ # generate views
+ # Here, we iterate through different spatial crops of the latents and denoise them. These
+ # denoised (latent) crops are then averaged to produce the final latent
+ # for the current timestep via MultiDiffusion. Please see Sec. 4.1 in the
+ # MultiDiffusion paper for more details: https://arxiv.org/abs/2302.08113
+ # Batch views denoise
+ for j, batch_view in enumerate(views_batch):
+ vb_size = len(batch_view)
+ # get the latents corresponding to the current view coordinates
+ if circular_padding:
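+ # views whose window runs past the right edge wrap around and borrow columns from the left
+ # edge of the latent, so the panorama closes seamlessly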
+ latents_for_view = []
+ for h_start, h_end, w_start, w_end in batch_view:
+ if w_end > latents.shape[3]:
+ # Add circular horizontal padding
+ latent_view = torch.cat(
+ (
+ latents[:, :, h_start:h_end, w_start:],
+ latents[:, :, h_start:h_end, : w_end - latents.shape[3]],
+ ),
+ axis=-1,
+ )
+ else:
+ latent_view = latents[:, :, h_start:h_end, w_start:w_end]
+ latents_for_view.append(latent_view)
+ latents_for_view = torch.cat(latents_for_view)
+ else:
+ latents_for_view = torch.cat(
+ [
+ latents[:, :, h_start:h_end, w_start:w_end]
+ for h_start, h_end, w_start, w_end in batch_view
+ ]
+ )
+
+ # rematch block's scheduler status
+ self.scheduler.__dict__.update(views_scheduler_status[j])
+
+ # expand the latents if we are doing classifier free guidance
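+ # `repeat_interleave` duplicates each view's latents back-to-back, so the unconditional and
+ # text-conditioned predictions come out interleaved and are split with [::2] / [1::2] below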
+ latent_model_input = (
+ latents_for_view.repeat_interleave(2, dim=0)
+ if do_classifier_free_guidance
+ else latents_for_view
+ )
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # repeat prompt_embeds for batch
+ prompt_embeds_input = torch.cat([prompt_embeds] * vb_size)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds_input,
+ cross_attention_kwargs=cross_attention_kwargs,
+ ).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred[::2], noise_pred[1::2]
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_denoised_batch = self.scheduler.step(
+ noise_pred, t, latents_for_view, **extra_step_kwargs
+ ).prev_sample
+
+ # save views scheduler status after sample
+ views_scheduler_status[j] = copy.deepcopy(self.scheduler.__dict__)
+
+ # extract value from batch
+ for latents_view_denoised, (h_start, h_end, w_start, w_end) in zip(
+ latents_denoised_batch.chunk(vb_size), batch_view
+ ):
+ if circular_padding and w_end > latents.shape[3]:
+ # Case for circular padding
+ value[:, :, h_start:h_end, w_start:] += latents_view_denoised[
+ :, :, h_start:h_end, : latents.shape[3] - w_start
+ ]
+ value[:, :, h_start:h_end, : w_end - latents.shape[3]] += latents_view_denoised[
+ :, :, h_start:h_end, latents.shape[3] - w_start :
+ ]
+ count[:, :, h_start:h_end, w_start:] += 1
+ count[:, :, h_start:h_end, : w_end - latents.shape[3]] += 1
+ else:
+ value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
+ count[:, :, h_start:h_end, w_start:w_end] += 1
+
+ # take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
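+ # each latent position becomes the average of the denoised views covering it; `torch.where`
+ # guards against division by zero for positions that no view touched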
+ latents = torch.where(count > 0, value / count, value)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ if not output_type == "latent":
+ if circular_padding:
+ image = self.decode_latents_with_padding(latents)
+ else:
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py
new file mode 100644
index 0000000000000000000000000000000000000000..073f02e8ee98a519f85b4437a46557f0e0f00c7b
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py
@@ -0,0 +1,787 @@
+# Copyright 2023 ParaDiGMS authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import torch
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from ...image_processor import VaeImageProcessor
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...schedulers import KarrasDiffusionSchedulers
+from ...utils import (
+ is_accelerate_available,
+ is_accelerate_version,
+ logging,
+ randn_tensor,
+ replace_example_docstring,
+)
+from ..pipeline_utils import DiffusionPipeline
+from . import StableDiffusionPipelineOutput
+from .safety_checker import StableDiffusionSafetyChecker
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import DDPMParallelScheduler
+ >>> from diffusers import StableDiffusionParadigmsPipeline
+
+ >>> scheduler = DDPMParallelScheduler.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
+
+ >>> pipe = StableDiffusionParadigmsPipeline.from_pretrained(
+ ... "runwayml/stable-diffusion-v1-5", scheduler=scheduler, torch_dtype=torch.float16
+ ... )
+ >>> pipe = pipe.to("cuda")
+
+ >>> ngpu, batch_per_device = torch.cuda.device_count(), 5
+ >>> pipe.wrapped_unet = torch.nn.DataParallel(pipe.unet, device_ids=[d for d in range(ngpu)])
+
+ >>> prompt = "a photo of an astronaut riding a horse on mars"
+ >>> image = pipe(prompt, parallel=ngpu * batch_per_device, num_inference_steps=1000).images[0]
+ ```
+"""
+
+
+class StableDiffusionParadigmsPipeline(
+ DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+):
+ r"""
+ Parallelized version of StableDiffusionPipeline, based on the paper https://arxiv.org/abs/2305.16317. This pipeline
+ parallelizes the denoising steps to generate a single image faster (more akin to model parallelism).
+
+ Pipeline for text-to-image generation using Stable Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ In addition the pipeline inherits the following loading methods:
+ - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+ - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+ - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
+
+ as well as the following saving methods:
+ - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ # attribute to wrap the unet with torch.nn.DataParallel when running multiple denoising steps on multiple GPUs
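+ # no wrapping is applied by default; users may replace `wrapped_unet` (e.g. with
+ # `torch.nn.DataParallel(self.unet)`, as in the example docstring above) to spread the
+ # parallel batch of denoising steps across multiple GPUs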
+ self.wrapped_unet = self.unet
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor into slices to compute decoding in several
+ steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
+ several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
+ """
+ self.vae.enable_tiling()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ hook = None
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+ if self.safety_checker is not None:
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ def _cumsum(self, input, dim, debug=False):
+ if debug:
+ # cumsum_cuda_kernel does not have a deterministic implementation
+ # so perform cumsum on cpu for debugging purposes
+ return torch.cumsum(input.cpu().float(), dim=dim).to(input.device)
+ else:
+ return torch.cumsum(input, dim=dim)
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ parallel: int = 10,
+ tolerance: float = 0.1,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ debug: bool = False,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ parallel (`int`, *optional*, defaults to 10):
+ The batch size to use when doing parallel sampling. More parallelism may lead to faster inference but
+ requires more memory and can also require more total FLOPs.
+ tolerance (`float`, *optional*, defaults to 0.1):
+ The error tolerance for determining when to slide the batch window forward for parallel sampling. Lower
+ tolerance usually leads to less/no degradation. Higher tolerance is faster but can risk degradation of
+ sample quality. The tolerance is specified as a ratio of the scheduler's noise magnitude.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+ debug (`bool`, *optional*, defaults to `False`):
+ Whether or not to run in debug mode. In debug mode, torch.cumsum is evaluated using the CPU.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+ extra_step_kwargs.pop("generator", None)
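+ # Note: the generator is dropped here because the per-timestep noise for stochastic schedulers
+ # is pre-sampled outside the denoising loop below, so the batched scheduler step itself must
+ # not draw any additional noise.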
+
+ # 7. Denoising loop
+ scheduler = self.scheduler
+ parallel = min(parallel, len(scheduler.timesteps))
+
+ begin_idx = 0
+ end_idx = parallel
+ latents_time_evolution_buffer = torch.stack([latents] * (len(scheduler.timesteps) + 1))
+
+ # We must make sure the noise of stochastic schedulers such as DDPM is sampled only once per timestep.
+ # Sampling inside the parallel denoising loop will mess this up, so we pre-sample the noise vectors outside the denoising loop.
+ noise_array = torch.zeros_like(latents_time_evolution_buffer)
+ for j in range(len(scheduler.timesteps)):
+ base_noise = randn_tensor(
+ shape=latents.shape, generator=generator, device=latents.device, dtype=prompt_embeds.dtype
+ )
+ noise = (self.scheduler._get_variance(scheduler.timesteps[j]) ** 0.5) * base_noise
+ noise_array[j] = noise.clone()
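+ # noise_array[j] now holds the stochastic term the scheduler would normally inject when stepping
+ # away from timesteps[j], already scaled by that timestep's standard deviation (variance ** 0.5).
+ # For ODE-like schedulers such as DDIM this term is discarded later on.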
+
+ # We specify the error tolerance as a ratio of the scheduler's noise magnitude. We similarly compute the error tolerance
+ # outside of the denoising loop to avoid recomputing it at every step.
+ # We will be dividing the norm of the noise, so we store its inverse here to avoid a division at every step.
+ inverse_variance_norm = 1.0 / torch.tensor(
+ [scheduler._get_variance(scheduler.timesteps[j]) for j in range(len(scheduler.timesteps))] + [0]
+ ).to(noise_array.device)
+ latent_dim = noise_array[0, 0].numel()
+ inverse_variance_norm = inverse_variance_norm[:, None] / latent_dim
+
+ scaled_tolerance = tolerance**2
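+ # Illustrative check: with the default tolerance of 0.1, scaled_tolerance is 0.01, so a window
+ # position counts as converged when ||new - old||^2 / (latent_dim * variance_j) stays at or
+ # below 0.01, i.e. the drift per latent element is small relative to the scheduler's own noise
+ # level at that timestep.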
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ steps = 0
+ while begin_idx < len(scheduler.timesteps):
+ # these have shape (parallel_dim, 2*batch_size, ...)
+ # parallel_len is at most parallel, but could be less if we are at the end of the timesteps
+ # we are processing a batch window of timesteps spanning [begin_idx, end_idx)
+ parallel_len = end_idx - begin_idx
+
+ block_prompt_embeds = torch.stack([prompt_embeds] * parallel_len)
+ block_latents = latents_time_evolution_buffer[begin_idx:end_idx]
+ block_t = scheduler.timesteps[begin_idx:end_idx, None].repeat(1, batch_size * num_images_per_prompt)
+ t_vec = block_t
+ if do_classifier_free_guidance:
+ t_vec = t_vec.repeat(1, 2)
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = (
+ torch.cat([block_latents] * 2, dim=1) if do_classifier_free_guidance else block_latents
+ )
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t_vec)
+
+ # if parallel_len is small, no need to use multiple GPUs
+ net = self.wrapped_unet if parallel_len > 3 else self.unet
+ # predict the noise residual, shape is now [parallel_len * 2 * batch_size * num_images_per_prompt, ...]
+ model_output = net(
+ latent_model_input.flatten(0, 1),
+ t_vec.flatten(0, 1),
+ encoder_hidden_states=block_prompt_embeds.flatten(0, 1),
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ per_latent_shape = model_output.shape[1:]
+ if do_classifier_free_guidance:
+ model_output = model_output.reshape(
+ parallel_len, 2, batch_size * num_images_per_prompt, *per_latent_shape
+ )
+ noise_pred_uncond, noise_pred_text = model_output[:, 0], model_output[:, 1]
+ model_output = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+ model_output = model_output.reshape(
+ parallel_len * batch_size * num_images_per_prompt, *per_latent_shape
+ )
+
+ block_latents_denoise = scheduler.batch_step_no_noise(
+ model_output=model_output,
+ timesteps=block_t.flatten(0, 1),
+ sample=block_latents.flatten(0, 1),
+ **extra_step_kwargs,
+ ).reshape(block_latents.shape)
+
+ # back to shape (parallel_dim, batch_size, ...)
+ # now we want to add the pre-sampled noise
+ # parallel sampling algorithm requires computing the cumulative drift from the beginning
+ # of the window, so we need to compute cumulative sum of the deltas and the pre-sampled noises.
+ delta = block_latents_denoise - block_latents
+ cumulative_delta = self._cumsum(delta, dim=0, debug=debug)
+ cumulative_noise = self._cumsum(noise_array[begin_idx:end_idx], dim=0, debug=debug)
+
+ # if we are using an ODE-like scheduler (like DDIM), we don't want to add noise
+ if scheduler._is_ode_scheduler:
+ cumulative_noise = 0
+
+ block_latents_new = (
+ latents_time_evolution_buffer[begin_idx][None,] + cumulative_delta + cumulative_noise
+ )
+ cur_error = torch.linalg.norm(
+ (block_latents_new - latents_time_evolution_buffer[begin_idx + 1 : end_idx + 1]).reshape(
+ parallel_len, batch_size * num_images_per_prompt, -1
+ ),
+ dim=-1,
+ ).pow(2)
+ error_ratio = cur_error * inverse_variance_norm[begin_idx + 1 : end_idx + 1]
+
+ # find the first index of the vector error_ratio that is greater than error tolerance
+ # we can shift the window for the next iteration up to this index
+ error_ratio = torch.nn.functional.pad(
+ error_ratio, (0, 0, 0, 1), value=1e9
+ ) # handle the case when everything is below ratio, by padding the end of parallel_len dimension
+ any_error_at_time = torch.max(error_ratio > scaled_tolerance, dim=1).values.int()
+ ind = torch.argmax(any_error_at_time).item()
+
+ # compute the new begin and end idxs for the window
+ new_begin_idx = begin_idx + min(1 + ind, parallel)
+ new_end_idx = min(new_begin_idx + parallel, len(scheduler.timesteps))
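+ # Illustrative example: with parallel=10 and begin_idx=0, if the first window position whose
+ # error exceeds the tolerance is ind=3, the window slides forward by 4 timesteps
+ # (new_begin_idx=4) and the next window spans [4, 14), clamped to the number of timesteps.
+ # If every position is within tolerance, ind lands on the padded sentinel and the window
+ # advances by the full parallel amount.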
+
+ # store the computed latents for the current window in the global buffer
+ latents_time_evolution_buffer[begin_idx + 1 : end_idx + 1] = block_latents_new
+ # initialize the new sliding window latents with the end of the current window,
+ # should be better than random initialization
+ latents_time_evolution_buffer[end_idx : new_end_idx + 1] = latents_time_evolution_buffer[end_idx][
+ None,
+ ]
+
+ steps += 1
+
+ progress_bar.update(new_begin_idx - begin_idx)
+ if callback is not None and steps % callback_steps == 0:
+ callback(begin_idx, block_t[begin_idx], latents_time_evolution_buffer[begin_idx])
+
+ begin_idx = new_begin_idx
+ end_idx = new_end_idx
+
+ latents = latents_time_evolution_buffer[-1]
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py
new file mode 100644
index 0000000000000000000000000000000000000000..960c4369e45ab8792e90c8aaa253f52ac65cef5c
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py
@@ -0,0 +1,1259 @@
+# Copyright 2023 Pix2Pix Zero Authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import warnings
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import PIL
+import torch
+import torch.nn.functional as F
+from transformers import (
+ BlipForConditionalGeneration,
+ BlipProcessor,
+ CLIPImageProcessor,
+ CLIPTextModel,
+ CLIPTokenizer,
+)
+
+from ...image_processor import VaeImageProcessor
+from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...models.attention_processor import Attention
+from ...schedulers import DDIMScheduler, DDPMScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler
+from ...schedulers.scheduling_ddim_inverse import DDIMInverseScheduler
+from ...utils import (
+ PIL_INTERPOLATION,
+ BaseOutput,
+ deprecate,
+ is_accelerate_available,
+ is_accelerate_version,
+ logging,
+ randn_tensor,
+ replace_example_docstring,
+)
+from ..pipeline_utils import DiffusionPipeline
+from . import StableDiffusionPipelineOutput
+from .safety_checker import StableDiffusionSafetyChecker
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class Pix2PixInversionPipelineOutput(BaseOutput, TextualInversionLoaderMixin):
+ """
+ Output class for Stable Diffusion pipelines.
+
+ Args:
+ latents (`torch.FloatTensor`)
+ inverted latents tensor
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+ num_channels)`. The PIL images or numpy array represent the denoised images of the diffusion pipeline.
+ """
+
+ latents: torch.FloatTensor
+ images: Union[List[PIL.Image.Image], np.ndarray]
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import requests
+ >>> import torch
+
+ >>> from diffusers import DDIMScheduler, StableDiffusionPix2PixZeroPipeline
+
+
+ >>> def download(embedding_url, local_filepath):
+ ... r = requests.get(embedding_url)
+ ... with open(local_filepath, "wb") as f:
+ ... f.write(r.content)
+
+
+ >>> model_ckpt = "CompVis/stable-diffusion-v1-4"
+ >>> pipeline = StableDiffusionPix2PixZeroPipeline.from_pretrained(model_ckpt, torch_dtype=torch.float16)
+ >>> pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
+ >>> pipeline.to("cuda")
+
+ >>> prompt = "a high resolution painting of a cat in the style of van gough"
+ >>> source_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/cat.pt"
+ >>> target_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/dog.pt"
+
+ >>> for url in [source_emb_url, target_emb_url]:
+ ... download(url, url.split("/")[-1])
+
+ >>> src_embeds = torch.load(source_emb_url.split("/")[-1])
+ >>> target_embeds = torch.load(target_emb_url.split("/")[-1])
+ >>> images = pipeline(
+ ... prompt,
+ ... source_embeds=src_embeds,
+ ... target_embeds=target_embeds,
+ ... num_inference_steps=50,
+ ... cross_attention_guidance_amount=0.15,
+ ... ).images
+
+ >>> images[0].save("edited_image_dog.png")
+ ```
+"""
+
+EXAMPLE_INVERT_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from transformers import BlipForConditionalGeneration, BlipProcessor
+ >>> from diffusers import DDIMScheduler, DDIMInverseScheduler, StableDiffusionPix2PixZeroPipeline
+
+ >>> import requests
+ >>> from PIL import Image
+
+ >>> captioner_id = "Salesforce/blip-image-captioning-base"
+ >>> processor = BlipProcessor.from_pretrained(captioner_id)
+ >>> model = BlipForConditionalGeneration.from_pretrained(
+ ... captioner_id, torch_dtype=torch.float16, low_cpu_mem_usage=True
+ ... )
+
+ >>> sd_model_ckpt = "CompVis/stable-diffusion-v1-4"
+ >>> pipeline = StableDiffusionPix2PixZeroPipeline.from_pretrained(
+ ... sd_model_ckpt,
+ ... caption_generator=model,
+ ... caption_processor=processor,
+ ... torch_dtype=torch.float16,
+ ... safety_checker=None,
+ ... )
+
+ >>> pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
+ >>> pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config)
+ >>> pipeline.enable_model_cpu_offload()
+
+ >>> img_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/test_images/cats/cat_6.png"
+
+ >>> raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB").resize((512, 512))
+ >>> # generate caption
+ >>> caption = pipeline.generate_caption(raw_image)
+
+ >>> # "a photography of a cat with flowers and dai dai daie - daie - daie kasaii"
+ >>> inv_latents = pipeline.invert(caption, image=raw_image).latents
+ >>> # we need to generate source and target embeds
+
+ >>> source_prompts = ["a cat sitting on the street", "a cat playing in the field", "a face of a cat"]
+
+ >>> target_prompts = ["a dog sitting on the street", "a dog playing in the field", "a face of a dog"]
+
+ >>> source_embeds = pipeline.get_embeds(source_prompts)
+ >>> target_embeds = pipeline.get_embeds(target_prompts)
+ >>> # the latents can then be used to edit a real image
+ >>> # when using Stable Diffusion 2 or other models that use v-prediction
+ >>> # set `cross_attention_guidance_amount` to 0.01 or less to avoid input latent gradient explosion
+
+ >>> image = pipeline(
+ ... caption,
+ ... source_embeds=source_embeds,
+ ... target_embeds=target_embeds,
+ ... num_inference_steps=50,
+ ... cross_attention_guidance_amount=0.15,
+ ... generator=generator,
+ ... latents=inv_latents,
+ ... negative_prompt=caption,
+ ... ).images[0]
+ >>> image.save("edited_image.png")
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
+def preprocess(image):
+ warnings.warn(
+ "The preprocess method is deprecated and will be removed in a future version. Please"
+ " use VaeImageProcessor.preprocess instead",
+ FutureWarning,
+ )
+ if isinstance(image, torch.Tensor):
+ return image
+ elif isinstance(image, PIL.Image.Image):
+ image = [image]
+
+ if isinstance(image[0], PIL.Image.Image):
+ w, h = image[0].size
+ w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
+
+ image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
+ image = np.concatenate(image, axis=0)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image.transpose(0, 3, 1, 2)
+ image = 2.0 * image - 1.0
+ image = torch.from_numpy(image)
+ elif isinstance(image[0], torch.Tensor):
+ image = torch.cat(image, dim=0)
+ return image
+
+
+def prepare_unet(unet: UNet2DConditionModel):
+ """Modifies the UNet (`unet`) to perform Pix2Pix Zero optimizations."""
+ pix2pix_zero_attn_procs = {}
+ for name in unet.attn_processors.keys():
+ module_name = name.replace(".processor", "")
+ module = unet.get_submodule(module_name)
+ if "attn2" in name:
+ pix2pix_zero_attn_procs[name] = Pix2PixZeroAttnProcessor(is_pix2pix_zero=True)
+ module.requires_grad_(True)
+ else:
+ pix2pix_zero_attn_procs[name] = Pix2PixZeroAttnProcessor(is_pix2pix_zero=False)
+ module.requires_grad_(False)
+
+ unet.set_attn_processor(pix2pix_zero_attn_procs)
+ return unet
+
+
+class Pix2PixZeroL2Loss:
+ def __init__(self):
+ self.loss = 0.0
+
+ def compute_loss(self, predictions, targets):
+ self.loss += ((predictions - targets) ** 2).sum((1, 2)).mean(0)
+
+
+class Pix2PixZeroAttnProcessor:
+ """An attention processor class to store the attention weights.
+ In Pix2Pix Zero, it happens during computations in the cross-attention blocks."""
+
+ def __init__(self, is_pix2pix_zero=False):
+ self.is_pix2pix_zero = is_pix2pix_zero
+ if self.is_pix2pix_zero:
+ self.reference_cross_attn_map = {}
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ timestep=None,
+ loss=None,
+ ):
+ batch_size, sequence_length, _ = hidden_states.shape
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ query = attn.head_to_batch_dim(query)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
+ if self.is_pix2pix_zero and timestep is not None:
+ # new bookkeeping to save the attention weights.
+ if loss is None:
+ self.reference_cross_attn_map[timestep.item()] = attention_probs.detach().cpu()
+ # compute loss
+ elif loss is not None:
+ prev_attn_probs = self.reference_cross_attn_map.pop(timestep.item())
+ loss.compute_loss(attention_probs, prev_attn_probs.to(attention_probs.device))
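+ # Bookkeeping summary: during the reference denoising pass `loss` is None and the
+ # cross-attention map is cached per timestep; during the editing pass a Pix2PixZeroL2Loss is
+ # supplied and the squared difference to the cached map is accumulated, which is the quantity
+ # the latent optimization in the pipeline's second denoising loop minimizes.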
+
+ hidden_states = torch.bmm(attention_probs, value)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ return hidden_states
+
+
+class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for pixel-level image editing using Pix2Pix Zero. Based on Stable Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`], or [`DDPMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ requires_safety_checker (bool):
+ Whether the pipeline requires a safety checker. We recommend setting it to True if you're using the
+ pipeline publicly.
+ """
+ _optional_components = [
+ "safety_checker",
+ "feature_extractor",
+ "caption_generator",
+ "caption_processor",
+ "inverse_scheduler",
+ ]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[DDPMScheduler, DDIMScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler],
+ feature_extractor: CLIPImageProcessor,
+ safety_checker: StableDiffusionSafetyChecker,
+ inverse_scheduler: DDIMInverseScheduler,
+ caption_generator: BlipForConditionalGeneration,
+ caption_processor: BlipProcessor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ caption_processor=caption_processor,
+ caption_generator=caption_generator,
+ inverse_scheduler=inverse_scheduler,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ hook = None
+ for cpu_offloaded_model in [self.vae, self.text_encoder, self.unet, self.vae]:
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+ if self.safety_checker is not None:
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ def decode_latents(self, latents):
+ warnings.warn(
+ "The decode_latents method is deprecated and will be removed in a future version. Please"
+ " use VaeImageProcessor instead",
+ FutureWarning,
+ )
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ source_embeds,
+ target_embeds,
+ callback_steps,
+ prompt_embeds=None,
+ ):
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+ if source_embeds is None and target_embeds is None:
+ raise ValueError("`source_embeds` and `target_embeds` cannot be undefined.")
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ @torch.no_grad()
+ def generate_caption(self, images):
+ """Generates caption for a given image."""
+ text = "a photography of"
+
+ prev_device = self.caption_generator.device
+
+ device = self._execution_device
+ inputs = self.caption_processor(images, text, return_tensors="pt").to(
+ device=device, dtype=self.caption_generator.dtype
+ )
+ self.caption_generator.to(device)
+ outputs = self.caption_generator.generate(**inputs, max_new_tokens=128)
+
+ # offload caption generator
+ self.caption_generator.to(prev_device)
+
+ caption = self.caption_processor.batch_decode(outputs, skip_special_tokens=True)[0]
+ return caption
+
+ def construct_direction(self, embs_source: torch.Tensor, embs_target: torch.Tensor):
+ """Constructs the edit direction to steer the image generation process semantically."""
+ return (embs_target.mean(0) - embs_source.mean(0)).unsqueeze(0)
+
+ @torch.no_grad()
+ def get_embeds(self, prompt: List[str], batch_size: int = 16) -> torch.FloatTensor:
+ num_prompts = len(prompt)
+ embeds = []
+ for i in range(0, num_prompts, batch_size):
+ prompt_slice = prompt[i : i + batch_size]
+
+ input_ids = self.tokenizer(
+ prompt_slice,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ ).input_ids
+
+ input_ids = input_ids.to(self.text_encoder.device)
+ embeds.append(self.text_encoder(input_ids)[0])
+
+ return torch.cat(embeds, dim=0).mean(0)[None]
+
+ def prepare_image_latents(self, image, batch_size, dtype, device, generator=None):
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
+ raise ValueError(
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
+ )
+
+ image = image.to(device=device, dtype=dtype)
+
+ if image.shape[1] == 4:
+ latents = image
+
+ else:
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if isinstance(generator, list):
+ latents = [
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
+ ]
+ latents = torch.cat(latents, dim=0)
+ else:
+ latents = self.vae.encode(image).latent_dist.sample(generator)
+
+ latents = self.vae.config.scaling_factor * latents
+
+ if batch_size != latents.shape[0]:
+ if batch_size % latents.shape[0] == 0:
+ # expand image_latents for batch_size
+ deprecation_message = (
+ f"You have passed {batch_size} text prompts (`prompt`), but only {latents.shape[0]} initial"
+ " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
+ " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
+ " your script to pass as many initial images as text prompts to suppress this warning."
+ )
+ deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
+ additional_latents_per_image = batch_size // latents.shape[0]
+ latents = torch.cat([latents] * additional_latents_per_image, dim=0)
+ else:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ latents = torch.cat([latents], dim=0)
+
+ return latents
+
+ def get_epsilon(self, model_output: torch.Tensor, sample: torch.Tensor, timestep: int):
+ pred_type = self.inverse_scheduler.config.prediction_type
+ alpha_prod_t = self.inverse_scheduler.alphas_cumprod[timestep]
+
+ beta_prod_t = 1 - alpha_prod_t
+
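+ # Conversion back to an epsilon prediction under the usual parameterization
+ # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps:
+ #   "sample" (x_0) prediction: eps = (x_t - sqrt(alpha_bar_t) * x_0) / sqrt(1 - alpha_bar_t)
+ #   "v_prediction":            eps = sqrt(alpha_bar_t) * v + sqrt(1 - alpha_bar_t) * x_t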
+ if pred_type == "epsilon":
+ return model_output
+ elif pred_type == "sample":
+ return (sample - alpha_prod_t ** (0.5) * model_output) / beta_prod_t ** (0.5)
+ elif pred_type == "v_prediction":
+ return (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+ else:
+ raise ValueError(
+ f"prediction_type given as {pred_type} must be one of `epsilon`, `sample`, or `v_prediction`"
+ )
+
+ def auto_corr_loss(self, hidden_states, generator=None):
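+ # Auto-correlation regularization for the inverted noise: the predicted noise map is rolled by
+ # a random offset along each spatial axis and its correlation with the rolled copy is penalized,
+ # repeating at progressively pooled scales until the map is 8 px or smaller. White Gaussian
+ # noise has (near) zero auto-correlation, so this nudges the inverted latents towards
+ # noise-like statistics.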
+ reg_loss = 0.0
+ for i in range(hidden_states.shape[0]):
+ for j in range(hidden_states.shape[1]):
+ noise = hidden_states[i : i + 1, j : j + 1, :, :]
+ while True:
+ roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item()
+ reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2
+ reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2
+
+ if noise.shape[2] <= 8:
+ break
+ noise = F.avg_pool2d(noise, kernel_size=2)
+ return reg_loss
+
+ def kl_divergence(self, hidden_states):
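+ # Matches the latent statistics to a standard normal: for a Gaussian with the observed mean and
+ # variance, KL(N(mean, var) || N(0, 1)) = 0.5 * (var + mean^2 - 1 - log var); the constant
+ # factor of 0.5 is dropped here since it only rescales lambda_kl.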
+ mean = hidden_states.mean()
+ var = hidden_states.var()
+ return var + mean**2 - 1 - torch.log(var + 1e-7)
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Optional[Union[str, List[str]]] = None,
+ source_embeds: torch.Tensor = None,
+ target_embeds: torch.Tensor = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ cross_attention_guidance_amount: float = 0.1,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: Optional[int] = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ source_embeds (`torch.Tensor`):
+ Source concept embeddings. Generation of the embeddings as per the [original
+ paper](https://arxiv.org/abs/2302.03027). Used in discovering the edit direction.
+ target_embeds (`torch.Tensor`):
+ Target concept embeddings. Generation of the embeddings as per the [original
+ paper](https://arxiv.org/abs/2302.03027). Used in discovering the edit direction.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ cross_attention_guidance_amount (`float`, defaults to 0.1):
+ Amount of guidance needed from the reference cross-attention maps.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ # 0. Define the spatial resolutions.
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ source_embeds,
+ target_embeds,
+ callback_steps,
+ prompt_embeds,
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+ if cross_attention_kwargs is None:
+ cross_attention_kwargs = {}
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Generate the inverted noise from the input image or any other image
+ # generated from the input prompt.
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+ latents_init = latents.clone()
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Rejig the UNet so that we can obtain the cross-attention maps and
+ # use them for guiding the subsequent image generation.
+ self.unet = prepare_unet(self.unet)
+
+ # 8. Denoising loop where we obtain the cross-attention maps.
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs={"timestep": t},
+ ).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ # 9. Compute the edit directions.
+ edit_direction = self.construct_direction(source_embeds, target_embeds).to(prompt_embeds.device)
+
+ # 10. Edit the prompt embeddings as per the edit directions discovered.
+ prompt_embeds_edit = prompt_embeds.clone()
+ prompt_embeds_edit[1:2] += edit_direction
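+ # prompt_embeds is laid out as [unconditional, conditional] for classifier-free guidance, so
+ # index 1 is the text-conditioned embedding; only that half is shifted along the edit direction
+ # while the unconditional half stays untouched.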
+
+ # 11. Second denoising loop to generate the edited image.
+ latents = latents_init
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # we want to learn the latent such that it steers the generation
+ # process towards the edited direction, so make the initial
+ # noise learnable
+ x_in = latent_model_input.detach().clone()
+ x_in.requires_grad = True
+
+ # optimizer
+ opt = torch.optim.SGD([x_in], lr=cross_attention_guidance_amount)
+
+ with torch.enable_grad():
+ # initialize loss
+ loss = Pix2PixZeroL2Loss()
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ x_in,
+ t,
+ encoder_hidden_states=prompt_embeds_edit.detach(),
+ cross_attention_kwargs={"timestep": t, "loss": loss},
+ ).sample
+
+ loss.loss.backward(retain_graph=False)
+ opt.step()
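+ # A single SGD step on the learnable copy of the latent input: the gradient of the accumulated
+ # L2 distance between the current cross-attention maps (computed with the edited embeddings)
+ # and the reference maps cached during the first loop nudges x_in so the edited generation
+ # preserves the original structure; the step size is cross_attention_guidance_amount.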
+
+ # recompute the noise
+ noise_pred = self.unet(
+ x_in.detach(),
+ t,
+ encoder_hidden_states=prompt_embeds_edit,
+ cross_attention_kwargs={"timestep": None},
+ ).sample
+
+ latents = x_in.detach().chunk(2)[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_INVERT_DOC_STRING)
+ def invert(
+ self,
+ prompt: Optional[str] = None,
+ image: Union[
+ torch.FloatTensor,
+ PIL.Image.Image,
+ np.ndarray,
+ List[torch.FloatTensor],
+ List[PIL.Image.Image],
+ List[np.ndarray],
+ ] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ cross_attention_guidance_amount: float = 0.1,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: Optional[int] = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ lambda_auto_corr: float = 20.0,
+ lambda_kl: float = 20.0,
+ num_reg_steps: int = 5,
+ num_auto_corr_rolls: int = 5,
+ ):
+ r"""
+ Function used to generate inverted latents given a prompt and image.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ image (`torch.FloatTensor` `np.ndarray`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, or tensor representing an image batch which will be used for conditioning. Can also accept
+ image latents as `image`; if latents are passed directly, they will not be encoded again.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 1):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ cross_attention_guidance_amount (`float`, defaults to 0.1):
+ Amount of guidance needed from the reference cross-attention maps.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ lambda_auto_corr (`float`, *optional*, defaults to 20.0):
+ Lambda parameter controlling the auto-correlation regularization.
+ lambda_kl (`float`, *optional*, defaults to 20.0):
+ Lambda parameter controlling the Kullback–Leibler divergence regularization.
+ num_reg_steps (`int`, *optional*, defaults to 5):
+ Number of regularization steps applied to the predicted noise.
+ num_auto_corr_rolls (`int`, *optional*, defaults to 5):
+ Number of rolls used when computing the auto-correlation regularization.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.pipeline_stable_diffusion_pix2pix_zero.Pix2PixInversionPipelineOutput`] or
+ `tuple`:
+ [`~pipelines.stable_diffusion.pipeline_stable_diffusion_pix2pix_zero.Pix2PixInversionPipelineOutput`] if
+ `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the inverted
+ latents tensor and the second is the corresponding decoded image.
+ """
+ # 1. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+ if cross_attention_kwargs is None:
+ cross_attention_kwargs = {}
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 2. Preprocess image
+ image = self.image_processor.preprocess(image)
+
+ # 3. Prepare latent variables
+ latents = self.prepare_image_latents(image, batch_size, self.vae.dtype, device, generator)
+
+ # 4. Encode input prompt
+ num_images_per_prompt = 1
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ prompt_embeds=prompt_embeds,
+ )
+
+ # 5. Prepare timesteps
+ self.inverse_scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.inverse_scheduler.timesteps
+
+ # 6. Rejig the UNet so that we can obtain the cross-attention maps and
+ # use them for guiding the subsequent image generation.
+ self.unet = prepare_unet(self.unet)
+
+ # 7. Denoising loop where we obtain the cross-attention maps.
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.inverse_scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.inverse_scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs={"timestep": t},
+ ).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # regularization of the noise prediction
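+ # the auto-correlation and KL terms below nudge the predicted noise towards
+ # IID standard normal, which stabilizes the inversion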
+ with torch.enable_grad():
+ for _ in range(num_reg_steps):
+ if lambda_auto_corr > 0:
+ for _ in range(num_auto_corr_rolls):
+ var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True)
+
+ # Derive epsilon from model output before regularizing to IID standard normal
+ var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t)
+
+ l_ac = self.auto_corr_loss(var_epsilon, generator=generator)
+ l_ac.backward()
+
+ grad = var.grad.detach() / num_auto_corr_rolls
+ noise_pred = noise_pred - lambda_auto_corr * grad
+
+ if lambda_kl > 0:
+ var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True)
+
+ # Derive epsilon from model output before regularizing to IID standard normal
+ var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t)
+
+ l_kld = self.kl_divergence(var_epsilon)
+ l_kld.backward()
+
+ grad = var.grad.detach()
+ noise_pred = noise_pred - lambda_kl * grad
+
+ noise_pred = noise_pred.detach()
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.inverse_scheduler.step(noise_pred, t, latents).prev_sample
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or (
+ (i + 1) > num_warmup_steps and (i + 1) % self.inverse_scheduler.order == 0
+ ):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ inverted_latents = latents.detach().clone()
+
+ # 8. Post-processing
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (inverted_latents, image)
+
+ return Pix2PixInversionPipelineOutput(latents=inverted_latents, images=image)
diff --git a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c583de9ca9c03713262c0e103e7f32eca9eb4f6
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py
@@ -0,0 +1,767 @@
+# Copyright 2023 Susung Hong and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import warnings
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import torch
+import torch.nn.functional as F
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from ...image_processor import VaeImageProcessor
+from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...schedulers import KarrasDiffusionSchedulers
+from ...utils import logging, randn_tensor, replace_example_docstring
+from ..pipeline_utils import DiffusionPipeline
+from . import StableDiffusionPipelineOutput
+from .safety_checker import StableDiffusionSafetyChecker
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import StableDiffusionSAGPipeline
+
+ >>> pipe = StableDiffusionSAGPipeline.from_pretrained(
+ ... "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
+ ... )
+ >>> pipe = pipe.to("cuda")
+
+ >>> prompt = "a photo of an astronaut riding a horse on mars"
+ >>> image = pipe(prompt, sag_scale=0.75).images[0]
+ ```
+"""
+
+
+# processes and stores attention probabilities
+class CrossAttnStoreProcessor:
+ def __init__(self):
+ self.attention_probs = None
+
+ def __call__(
+ self,
+ attn,
+ hidden_states,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ ):
+ batch_size, sequence_length, _ = hidden_states.shape
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ query = attn.head_to_batch_dim(query)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+
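+ # keep the attention probabilities around so the pipeline can later build the SAG mask from them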
+ self.attention_probs = attn.get_attention_scores(query, key, attention_mask)
+ hidden_states = torch.bmm(self.attention_probs, value)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ return hidden_states
+
+
+ # Modified to accept the self-attention guidance scale introduced in this paper (https://arxiv.org/pdf/2210.00939.pdf) as an input
+class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
+ steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ def decode_latents(self, latents):
+ warnings.warn(
+ "The decode_latents method is deprecated and will be removed in a future version. Please"
+ " use VaeImageProcessor instead",
+ FutureWarning,
+ )
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ sag_scale: float = 0.75,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: Optional[int] = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ sag_scale (`float`, *optional*, defaults to 0.75):
+ SAG scale as defined in [Improving Sample Quality of Diffusion Models Using Self-Attention Guidance]
+ (https://arxiv.org/abs/2210.00939). `sag_scale` is defined as `s_s` of equation (24) of SAG paper:
+ https://arxiv.org/pdf/2210.00939.pdf. Typically chosen between [0, 1.0] for better quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+ # and `sag_scale` is `s` of equation (16)
+ # of the self-attention guidance paper: https://arxiv.org/pdf/2210.00939.pdf
+ # `sag_scale = 0` means no self-attention guidance
+ do_self_attention_guidance = sag_scale > 0.0
+
+ # 3. Encode input prompt
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Denoising loop
+ store_processor = CrossAttnStoreProcessor()
+ self.unet.mid_block.attentions[0].transformer_blocks[0].attn1.processor = store_processor
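+ # only the self-attention (attn1) of the mid block is monitored; its stored
+ # probabilities drive the SAG masking below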
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+
+ map_size = None
+
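+ # forward hook that records the spatial size of the mid-block feature map,
+ # which `sag_masking` needs in order to reshape the stored attention map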
+ def get_map_size(module, input, output):
+ nonlocal map_size
+ map_size = output[0].shape[-2:]
+
+ with self.unet.mid_block.attentions[0].register_forward_hook(get_map_size):
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ ).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # perform self-attention guidance with the stored self-attention map
+ if do_self_attention_guidance:
+ # classifier-free guidance produces two chunks of attention map
+ # and we only use unconditional one according to equation (25)
+ # in https://arxiv.org/pdf/2210.00939.pdf
+ if do_classifier_free_guidance:
+ # DDIM-like prediction of x0
+ pred_x0 = self.pred_x0(latents, noise_pred_uncond, t)
+ # get the stored attention maps
+ uncond_attn, cond_attn = store_processor.attention_probs.chunk(2)
+ # self-attention-based degrading of latents
+ degraded_latents = self.sag_masking(
+ pred_x0, uncond_attn, map_size, t, self.pred_epsilon(latents, noise_pred_uncond, t)
+ )
+ uncond_emb, _ = prompt_embeds.chunk(2)
+ # forward and give guidance
+ degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=uncond_emb).sample
+ noise_pred += sag_scale * (noise_pred_uncond - degraded_pred)
+ else:
+ # DDIM-like prediction of x0
+ pred_x0 = self.pred_x0(latents, noise_pred, t)
+ # get the stored attention maps
+ cond_attn = store_processor.attention_probs
+ # self-attention-based degrading of latents
+ degraded_latents = self.sag_masking(
+ pred_x0, cond_attn, map_size, t, self.pred_epsilon(latents, noise_pred, t)
+ )
+ # forward and give guidance
+ degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=prompt_embeds).sample
+ noise_pred += sag_scale * (noise_pred - degraded_pred)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+
+ def sag_masking(self, original_latents, attn_map, map_size, t, eps):
+ # Same masking process as in SAG paper: https://arxiv.org/pdf/2210.00939.pdf
+ bh, hw1, hw2 = attn_map.shape
+ b, latent_channel, latent_h, latent_w = original_latents.shape
+ h = self.unet.config.attention_head_dim
+ if isinstance(h, list):
+ h = h[-1]
+
+ # Produce attention mask
+ attn_map = attn_map.reshape(b, h, hw1, hw2)
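+ # average over heads, sum the attention each position receives over all queries,
+ # and threshold to keep only the strongly attended positions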
+ attn_mask = attn_map.mean(1, keepdim=False).sum(1, keepdim=False) > 1.0
+ attn_mask = (
+ attn_mask.reshape(b, map_size[0], map_size[1])
+ .unsqueeze(1)
+ .repeat(1, latent_channel, 1, 1)
+ .type(attn_map.dtype)
+ )
+ attn_mask = F.interpolate(attn_mask, (latent_h, latent_w))
+
+ # Blur according to the self-attention mask
+ degraded_latents = gaussian_blur_2d(original_latents, kernel_size=9, sigma=1.0)
+ degraded_latents = degraded_latents * attn_mask + original_latents * (1 - attn_mask)
+
+ # Noise it again to match the noise level
+ degraded_latents = self.scheduler.add_noise(degraded_latents, noise=eps, timesteps=t)
+
+ return degraded_latents
+
+ # Modified from diffusers.schedulers.scheduling_ddim.DDIMScheduler.step
+ # Note: there are some schedulers that clip or do not return x_0 (PNDMScheduler, DDIMScheduler, etc.)
+ def pred_x0(self, sample, model_output, timestep):
+ alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
+
+ beta_prod_t = 1 - alpha_prod_t
+ if self.scheduler.config.prediction_type == "epsilon":
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+ elif self.scheduler.config.prediction_type == "sample":
+ pred_original_sample = model_output
+ elif self.scheduler.config.prediction_type == "v_prediction":
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
+ # predict V
+ model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`,"
+ " or `v_prediction`"
+ )
+
+ return pred_original_sample
+
+ def pred_epsilon(self, sample, model_output, timestep):
+ alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
+
+ beta_prod_t = 1 - alpha_prod_t
+ if self.scheduler.config.prediction_type == "epsilon":
+ pred_eps = model_output
+ elif self.scheduler.config.prediction_type == "sample":
+ pred_eps = (sample - (alpha_prod_t**0.5) * model_output) / (beta_prod_t**0.5)
+ elif self.scheduler.config.prediction_type == "v_prediction":
+ pred_eps = (beta_prod_t**0.5) * sample + (alpha_prod_t**0.5) * model_output
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`,"
+ " or `v_prediction`"
+ )
+
+ return pred_eps
+
+
+# Gaussian blur
+def gaussian_blur_2d(img, kernel_size, sigma):
+ ksize_half = (kernel_size - 1) * 0.5
+
+ x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
+
+ pdf = torch.exp(-0.5 * (x / sigma).pow(2))
+
+ x_kernel = pdf / pdf.sum()
+ x_kernel = x_kernel.to(device=img.device, dtype=img.dtype)
+
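+ # build the 2D kernel as the outer product of the separable 1D Gaussian,
+ # expanded to one filter per channel for a depthwise convolution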
+ kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :])
+ kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1])
+
+ padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
+
+ img = F.pad(img, padding, mode="reflect")
+ img = F.conv2d(img, kernel2d, groups=img.shape[-3])
+
+ return img
diff --git a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
new file mode 100644
index 0000000000000000000000000000000000000000..61b1419d5ced4e1c90d6ec2dcecadafce38b2d9e
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
@@ -0,0 +1,762 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import warnings
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import PIL
+import torch
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from ...image_processor import VaeImageProcessor
+from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...models.attention_processor import (
+ AttnProcessor2_0,
+ LoRAAttnProcessor2_0,
+ LoRAXFormersAttnProcessor,
+ XFormersAttnProcessor,
+)
+from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers
+from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from . import StableDiffusionPipelineOutput
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def preprocess(image):
+ warnings.warn(
+ "The preprocess method is deprecated and will be removed in a future version. Please"
+ " use VaeImageProcessor.preprocess instead",
+ FutureWarning,
+ )
+ if isinstance(image, torch.Tensor):
+ return image
+ elif isinstance(image, PIL.Image.Image):
+ image = [image]
+
+ if isinstance(image[0], PIL.Image.Image):
+ w, h = image[0].size
+ w, h = (x - x % 64 for x in (w, h)) # resize to integer multiple of 64
+
+ image = [np.array(i.resize((w, h)))[None, :] for i in image]
+ image = np.concatenate(image, axis=0)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image.transpose(0, 3, 1, 2)
+ image = 2.0 * image - 1.0
+ image = torch.from_numpy(image)
+ elif isinstance(image[0], torch.Tensor):
+ image = torch.cat(image, dim=0)
+ return image
+
+
+class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
+ r"""
+ Pipeline for text-guided image super-resolution using Stable Diffusion 2.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ low_res_scheduler ([`SchedulerMixin`]):
+ A scheduler used to add initial noise to the low res conditioning image. It must be an instance of
+ [`DDPMScheduler`].
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ """
+ _optional_components = ["watermarker", "safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ low_res_scheduler: DDPMScheduler,
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: Optional[Any] = None,
+ feature_extractor: Optional[CLIPImageProcessor] = None,
+ watermarker: Optional[Any] = None,
+ max_noise_level: int = 350,
+ ):
+ super().__init__()
+
+ if hasattr(
+ vae, "config"
+ ): # check if vae has a config attribute `scaling_factor` and if it is set to 0.08333, else set it to 0.08333 and deprecate
+ is_vae_scaling_factor_set_to_0_08333 = (
+ hasattr(vae.config, "scaling_factor") and vae.config.scaling_factor == 0.08333
+ )
+ if not is_vae_scaling_factor_set_to_0_08333:
+ deprecation_message = (
+ "The configuration file of the vae does not contain `scaling_factor` or it is set to"
+ f" {vae.config.scaling_factor}, which seems highly unlikely. If your checkpoint is a fine-tuned"
+ " version of `stabilityai/stable-diffusion-x4-upscaler` you should change 'scaling_factor' to"
+ " 0.08333 Please make sure to update the config accordingly, as not doing so might lead to"
+ " incorrect results in future versions. If you have downloaded this checkpoint from the Hugging"
+ " Face Hub, it would be very nice if you could open a Pull Request for the `vae/config.json` file"
+ )
+ deprecate("wrong scaling_factor", "1.0.0", deprecation_message, standard_warn=False)
+ vae.register_to_config(scaling_factor=0.08333)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ low_res_scheduler=low_res_scheduler,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ watermarker=watermarker,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, resample="bicubic")
+ self.register_to_config(max_noise_level=max_noise_level)
+
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ hook = None
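+ # chain the hooks so that running one module offloads the previously used module back to CPU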
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
+ if cpu_offloaded_model is not None:
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.run_safety_checker
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+ image, nsfw_detected, watermark_detected = self.safety_checker(
+ images=image,
+ clip_input=safety_checker_input.pixel_values.to(dtype=dtype),
+ )
+ else:
+ nsfw_detected = None
+ watermark_detected = None
+
+ if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
+ self.unet_offload_hook.offload()
+
+ return image, nsfw_detected, watermark_detected
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ def decode_latents(self, latents):
+ warnings.warn(
+ "The decode_latents method is deprecated and will be removed in a future version. Please"
+ " use VaeImageProcessor instead",
+ FutureWarning,
+ )
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ def check_inputs(
+ self,
+ prompt,
+ image,
+ noise_level,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if (
+ not isinstance(image, torch.Tensor)
+ and not isinstance(image, PIL.Image.Image)
+ and not isinstance(image, np.ndarray)
+ and not isinstance(image, list)
+ ):
+ raise ValueError(
+ f"`image` has to be of type `torch.Tensor`, `np.ndarray`, `PIL.Image.Image` or `list` but is {type(image)}"
+ )
+
+ # verify batch size of prompt and image are same if image is a list or tensor or numpy array
+ if isinstance(image, list) or isinstance(image, torch.Tensor) or isinstance(image, np.ndarray):
+ if isinstance(prompt, str):
+ batch_size = 1
+ else:
+ batch_size = len(prompt)
+ if isinstance(image, list):
+ image_batch_size = len(image)
+ else:
+ image_batch_size = image.shape[0]
+ if batch_size != image_batch_size:
+ raise ValueError(
+ f"`prompt` has batch size {batch_size} and `image` has batch size {image_batch_size}."
+ " Please make sure that passed `prompt` matches the batch size of `image`."
+ )
+
+ # check noise level
+ if noise_level > self.config.max_noise_level:
+ raise ValueError(f"`noise_level` has to be <= {self.config.max_noise_level} but is {noise_level}")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height, width)
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ if latents.shape != shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
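+ # (for example, `init_noise_sigma` is 1.0 for DDIM-style schedulers and is derived
+ # from the largest sigma of the noise schedule for Euler/Karras-style schedulers)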
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
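+ # The VAE can overflow in float16, so `upcast_vae` moves it to float32; when a
+ # memory-efficient attention processor (xformers or PyTorch 2.0 SDPA) is in use,
+ # parts of the decoder (post-quant conv, conv_in, mid block) are cast back to the
+ # original dtype to save memory.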
+ def upcast_vae(self):
+ dtype = self.vae.dtype
+ self.vae.to(dtype=torch.float32)
+ use_torch_2_0_or_xformers = isinstance(
+ self.vae.decoder.mid_block.attentions[0].processor,
+ (
+ AttnProcessor2_0,
+ XFormersAttnProcessor,
+ LoRAXFormersAttnProcessor,
+ LoRAAttnProcessor2_0,
+ ),
+ )
+ # if xformers or torch_2_0 is used attention block does not need
+ # to be in float32 which can save lots of memory
+ if use_torch_2_0_or_xformers:
+ self.vae.post_quant_conv.to(dtype)
+ self.vae.decoder.conv_in.to(dtype)
+ self.vae.decoder.mid_block.to(dtype)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ image: Union[
+ torch.FloatTensor,
+ PIL.Image.Image,
+ np.ndarray,
+ List[torch.FloatTensor],
+ List[PIL.Image.Image],
+ List[np.ndarray],
+ ] = None,
+ num_inference_steps: int = 75,
+ guidance_scale: float = 9.0,
+ noise_level: int = 20,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image` or tensor representing an image batch to be upscaled.
+ num_inference_steps (`int`, *optional*, defaults to 75):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 9.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
+ is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+
+ Examples:
+ ```py
+ >>> import requests
+ >>> from PIL import Image
+ >>> from io import BytesIO
+ >>> from diffusers import StableDiffusionUpscalePipeline
+ >>> import torch
+
+ >>> # load model and scheduler
+ >>> model_id = "stabilityai/stable-diffusion-x4-upscaler"
+ >>> pipeline = StableDiffusionUpscalePipeline.from_pretrained(
+ ... model_id, revision="fp16", torch_dtype=torch.float16
+ ... )
+ >>> pipeline = pipeline.to("cuda")
+
+ >>> # let's download an image
+ >>> url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-upscale/low_res_cat.png"
+ >>> response = requests.get(url)
+ >>> low_res_img = Image.open(BytesIO(response.content)).convert("RGB")
+ >>> low_res_img = low_res_img.resize((128, 128))
+ >>> prompt = "a white cat"
+
+ >>> upscaled_image = pipeline(prompt=prompt, image=low_res_img).images[0]
+ >>> upscaled_image.save("upsampled_cat.png")
+ ```
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+
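+ # A minimal sketch of the callback hook described above (the `log_step` helper is
+ # hypothetical, not part of this pipeline):
+ #   def log_step(step: int, timestep: int, latents: torch.FloatTensor) -> None:
+ #       print(f"step {step}, timestep {timestep}, latents {tuple(latents.shape)}")
+ #   pipeline(prompt="a white cat", image=low_res_img, callback=log_step, callback_steps=5)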
+ # 1. Check inputs
+ self.check_inputs(
+ prompt,
+ image,
+ noise_level,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ )
+
+ if image is None:
+ raise ValueError("`image` input cannot be undefined.")
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+
+ # 4. Preprocess image
+ image = self.image_processor.preprocess(image)
+ image = image.to(dtype=prompt_embeds.dtype, device=device)
+
+ # 5. set timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Add noise to image
+ noise_level = torch.tensor([noise_level], dtype=torch.long, device=device)
+ noise = randn_tensor(image.shape, generator=generator, device=device, dtype=prompt_embeds.dtype)
+ image = self.low_res_scheduler.add_noise(image, noise, noise_level)
+
+ batch_multiplier = 2 if do_classifier_free_guidance else 1
+ image = torch.cat([image] * batch_multiplier * num_images_per_prompt)
+ noise_level = torch.cat([noise_level] * image.shape[0])
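+ # the noised low-resolution image and its noise level are duplicated so that they line
+ # up with the [uncond, cond] latent batch used for classifier-free guidance and with
+ # `num_images_per_prompt`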
+
+ # 6. Prepare latent variables
+ height, width = image.shape[2:]
+ num_channels_latents = self.vae.config.latent_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 7. Check that sizes of image and latents match
+ num_channels_image = image.shape[1]
+ if num_channels_latents + num_channels_image != self.unet.config.in_channels:
+ raise ValueError(
+ f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
+ f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
+ f" `num_channels_image`: {num_channels_image} "
+ f" = {num_channels_latents+num_channels_image}. Please verify the config of"
+ " `pipeline.unet` or your `image` input."
+ )
+
+ # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 9. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+ # concat latents, mask, masked_image_latents in the channel dimension
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+ latent_model_input = torch.cat([latent_model_input, image], dim=1)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ class_labels=noise_level,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
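+ # i.e. eps = eps_uncond + guidance_scale * (eps_text - eps_uncond), the standard
+ # classifier-free guidance combination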
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ # 10. Post-processing
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
+ self.upcast_vae()
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+ # post-processing
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ image, has_nsfw_concept, _ = self.run_safety_checker(image, device, prompt_embeds.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ # 11. Apply watermark
+ if output_type == "pil" and self.watermarker is not None:
+ image = self.watermarker.apply_watermark(image)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py b/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c89bfedbd5903409d33c19cab47b69cd6dcc926
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py
@@ -0,0 +1,914 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import warnings
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
+from transformers.models.clip.modeling_clip import CLIPTextModelOutput
+
+from ...image_processor import VaeImageProcessor
+from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, PriorTransformer, UNet2DConditionModel
+from ...models.embeddings import get_timestep_embedding
+from ...schedulers import KarrasDiffusionSchedulers
+from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import StableUnCLIPPipeline
+
+ >>> pipe = StableUnCLIPPipeline.from_pretrained(
+ ... "fusing/stable-unclip-2-1-l", torch_dtype=torch.float16
+ ... ) # TODO update model path
+ >>> pipe = pipe.to("cuda")
+
+ >>> prompt = "a photo of an astronaut riding a horse on mars"
+ >>> images = pipe(prompt).images
+ >>> images[0].save("astronaut_horse.png")
+ ```
+"""
+
+
+class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
+ """
+ Pipeline for text-to-image generation using stable unCLIP.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ prior_tokenizer ([`CLIPTokenizer`]):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ prior_text_encoder ([`CLIPTextModelWithProjection`]):
+ Frozen text-encoder.
+ prior ([`PriorTransformer`]):
+ The canonical unCLIP prior to approximate the image embedding from the text embedding.
+ prior_scheduler ([`KarrasDiffusionSchedulers`]):
+ Scheduler used in the prior denoising process.
+ image_normalizer ([`StableUnCLIPImageNormalizer`]):
+ Used to normalize the predicted image embeddings before the noise is applied and un-normalize the image
+ embeddings after the noise has been applied.
+ image_noising_scheduler ([`KarrasDiffusionSchedulers`]):
+ Noise schedule for adding noise to the predicted image embeddings. The amount of noise to add is determined
+ by `noise_level` in `StableUnCLIPPipeline.__call__`.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder.
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`KarrasDiffusionSchedulers`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ """
+
+ _exclude_from_cpu_offload = ["prior", "image_normalizer"]
+
+ # prior components
+ prior_tokenizer: CLIPTokenizer
+ prior_text_encoder: CLIPTextModelWithProjection
+ prior: PriorTransformer
+ prior_scheduler: KarrasDiffusionSchedulers
+
+ # image noising components
+ image_normalizer: StableUnCLIPImageNormalizer
+ image_noising_scheduler: KarrasDiffusionSchedulers
+
+ # regular denoising components
+ tokenizer: CLIPTokenizer
+ text_encoder: CLIPTextModel
+ unet: UNet2DConditionModel
+ scheduler: KarrasDiffusionSchedulers
+
+ vae: AutoencoderKL
+
+ def __init__(
+ self,
+ # prior components
+ prior_tokenizer: CLIPTokenizer,
+ prior_text_encoder: CLIPTextModelWithProjection,
+ prior: PriorTransformer,
+ prior_scheduler: KarrasDiffusionSchedulers,
+ # image noising components
+ image_normalizer: StableUnCLIPImageNormalizer,
+ image_noising_scheduler: KarrasDiffusionSchedulers,
+ # regular denoising components
+ tokenizer: CLIPTokenizer,
+ text_encoder: CLIPTextModelWithProjection,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ # vae
+ vae: AutoencoderKL,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ prior_tokenizer=prior_tokenizer,
+ prior_text_encoder=prior_text_encoder,
+ prior=prior,
+ prior_scheduler=prior_scheduler,
+ image_normalizer=image_normalizer,
+ image_noising_scheduler=image_noising_scheduler,
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ unet=unet,
+ scheduler=scheduler,
+ vae=vae,
+ )
+
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
+ steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ hook = None
+ for cpu_offloaded_model in [self.text_encoder, self.prior_text_encoder, self.unet, self.vae]:
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
+ # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline._encode_prompt with _encode_prompt->_encode_prior_prompt, tokenizer->prior_tokenizer, text_encoder->prior_text_encoder
+ def _encode_prior_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None,
+ text_attention_mask: Optional[torch.Tensor] = None,
+ ):
+ if text_model_output is None:
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
+ # get prompt text embeddings
+ text_inputs = self.prior_tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.prior_tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ text_mask = text_inputs.attention_mask.bool().to(device)
+
+ untruncated_ids = self.prior_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.prior_tokenizer.batch_decode(
+ untruncated_ids[:, self.prior_tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.prior_tokenizer.model_max_length} tokens: {removed_text}"
+ )
+ text_input_ids = text_input_ids[:, : self.prior_tokenizer.model_max_length]
+
+ prior_text_encoder_output = self.prior_text_encoder(text_input_ids.to(device))
+
+ prompt_embeds = prior_text_encoder_output.text_embeds
+ prior_text_encoder_hidden_states = prior_text_encoder_output.last_hidden_state
+
+ else:
+ batch_size = text_model_output[0].shape[0]
+ prompt_embeds, prior_text_encoder_hidden_states = text_model_output[0], text_model_output[1]
+ text_mask = text_attention_mask
+
+ prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ prior_text_encoder_hidden_states = prior_text_encoder_hidden_states.repeat_interleave(
+ num_images_per_prompt, dim=0
+ )
+ text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
+
+ if do_classifier_free_guidance:
+ uncond_tokens = [""] * batch_size
+
+ uncond_input = self.prior_tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=self.prior_tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ uncond_text_mask = uncond_input.attention_mask.bool().to(device)
+ negative_prompt_embeds_prior_text_encoder_output = self.prior_text_encoder(
+ uncond_input.input_ids.to(device)
+ )
+
+ negative_prompt_embeds = negative_prompt_embeds_prior_text_encoder_output.text_embeds
+ uncond_prior_text_encoder_hidden_states = (
+ negative_prompt_embeds_prior_text_encoder_output.last_hidden_state
+ )
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+
+ seq_len = negative_prompt_embeds.shape[1]
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)
+
+ seq_len = uncond_prior_text_encoder_hidden_states.shape[1]
+ uncond_prior_text_encoder_hidden_states = uncond_prior_text_encoder_hidden_states.repeat(
+ 1, num_images_per_prompt, 1
+ )
+ uncond_prior_text_encoder_hidden_states = uncond_prior_text_encoder_hidden_states.view(
+ batch_size * num_images_per_prompt, seq_len, -1
+ )
+ uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)
+
+ # done duplicates
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+ prior_text_encoder_hidden_states = torch.cat(
+ [uncond_prior_text_encoder_hidden_states, prior_text_encoder_hidden_states]
+ )
+
+ text_mask = torch.cat([uncond_text_mask, text_mask])
+
+ return prompt_embeds, prior_text_encoder_hidden_states, text_mask
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ def decode_latents(self, latents):
+ warnings.warn(
+ "The decode_latents method is deprecated and will be removed in a future version. Please"
+ " use VaeImageProcessor instead",
+ FutureWarning,
+ )
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with prepare_extra_step_kwargs->prepare_prior_extra_step_kwargs, scheduler->prior_scheduler
+ def prepare_prior_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the prior_scheduler step, since not all prior_schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other prior_schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.prior_scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the prior_scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.prior_scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ callback_steps,
+ noise_level,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Please make sure to define only one of the two."
+ )
+
+ if prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+
+ if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ "Provide either `negative_prompt` or `negative_prompt_embeds`. Cannot leave both `negative_prompt` and `negative_prompt_embeds` undefined."
+ )
+
+ if prompt is not None and negative_prompt is not None:
+ if type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if noise_level < 0 or noise_level >= self.image_noising_scheduler.config.num_train_timesteps:
+ raise ValueError(
+ f"`noise_level` must be between 0 and {self.image_noising_scheduler.config.num_train_timesteps - 1}, inclusive."
+ )
+
+ # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
+ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ if latents.shape != shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+ latents = latents.to(device)
+
+ latents = latents * scheduler.init_noise_sigma
+ return latents
+
+ def noise_image_embeddings(
+ self,
+ image_embeds: torch.Tensor,
+ noise_level: int,
+ noise: Optional[torch.FloatTensor] = None,
+ generator: Optional[torch.Generator] = None,
+ ):
+ """
+ Add noise to the image embeddings. The amount of noise is controlled by a `noise_level` input. A higher
+ `noise_level` increases the variance in the final un-noised images.
+
+ The noise is applied in two ways:
+ 1. A noise schedule is applied directly to the embeddings.
+ 2. A vector of sinusoidal time embeddings is appended to the output.
+
+ In both cases, the amount of noise is controlled by the same `noise_level`.
+
+ The embeddings are normalized before the noise is applied and un-normalized after the noise is applied.
+ """
+ if noise is None:
+ noise = randn_tensor(
+ image_embeds.shape, generator=generator, device=image_embeds.device, dtype=image_embeds.dtype
+ )
+
+ noise_level = torch.tensor([noise_level] * image_embeds.shape[0], device=image_embeds.device)
+
+ self.image_normalizer.to(image_embeds.device)
+ image_embeds = self.image_normalizer.scale(image_embeds)
+
+ image_embeds = self.image_noising_scheduler.add_noise(image_embeds, timesteps=noise_level, noise=noise)
+
+ image_embeds = self.image_normalizer.unscale(image_embeds)
+
+ noise_level = get_timestep_embedding(
+ timesteps=noise_level, embedding_dim=image_embeds.shape[-1], flip_sin_to_cos=True, downscale_freq_shift=0
+ )
+
+ # `get_timestep_embedding` does not contain any weights and will always return f32 tensors,
+ # but we might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ noise_level = noise_level.to(image_embeds.dtype)
+
+ image_embeds = torch.cat((image_embeds, noise_level), 1)
+
+ return image_embeds
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ # regular denoising process args
+ prompt: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 20,
+ guidance_scale: float = 10.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ noise_level: int = 0,
+ # prior args
+ prior_num_inference_steps: int = 25,
+ prior_guidance_scale: float = 4.0,
+ prior_latents: Optional[torch.FloatTensor] = None,
+ ):
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 20):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 10.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+ noise_level (`int`, *optional*, defaults to `0`):
+ The amount of noise to add to the image embeddings. A higher `noise_level` increases the variance in
+ the final un-noised images. See `StableUnCLIPPipeline.noise_image_embeddings` for details.
+ prior_num_inference_steps (`int`, *optional*, defaults to 25):
+ The number of denoising steps in the prior denoising process. More denoising steps usually lead to a
+ higher quality image at the expense of slower inference.
+ prior_guidance_scale (`float`, *optional*, defaults to 4.0):
+ Guidance scale for the prior denoising process as defined in [Classifier-Free Diffusion
+ Guidance](https://arxiv.org/abs/2207.12598). `prior_guidance_scale` is defined as `w` of equation 2. of
+ [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting
+ `guidance_scale > 1`. A higher guidance scale encourages the model to generate images closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
+ prior_latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ embedding generation in the prior denoising process. Can be used to tweak the same generation with
+ different prompts. If not provided, a latents tensor will be generated by sampling using the supplied
+ random `generator`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.ImagePipelineOutput`] if `return_dict` is
+ True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt=prompt,
+ height=height,
+ width=width,
+ callback_steps=callback_steps,
+ noise_level=noise_level,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ batch_size = batch_size * num_images_per_prompt
+
+ device = self._execution_device
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ prior_do_classifier_free_guidance = prior_guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ prior_prompt_embeds, prior_text_encoder_hidden_states, prior_text_mask = self._encode_prior_prompt(
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=prior_do_classifier_free_guidance,
+ )
+
+ # 4. Prepare prior timesteps
+ self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=device)
+ prior_timesteps_tensor = self.prior_scheduler.timesteps
+
+ # 5. Prepare prior latent variables
+ embedding_dim = self.prior.config.embedding_dim
+ prior_latents = self.prepare_latents(
+ (batch_size, embedding_dim),
+ prior_prompt_embeds.dtype,
+ device,
+ generator,
+ prior_latents,
+ self.prior_scheduler,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ prior_extra_step_kwargs = self.prepare_prior_extra_step_kwargs(generator, eta)
+
+ # 7. Prior denoising loop
+ for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([prior_latents] * 2) if prior_do_classifier_free_guidance else prior_latents
+ latent_model_input = self.prior_scheduler.scale_model_input(latent_model_input, t)
+
+ predicted_image_embedding = self.prior(
+ latent_model_input,
+ timestep=t,
+ proj_embedding=prior_prompt_embeds,
+ encoder_hidden_states=prior_text_encoder_hidden_states,
+ attention_mask=prior_text_mask,
+ ).predicted_image_embedding
+
+ if prior_do_classifier_free_guidance:
+ predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2)
+ predicted_image_embedding = predicted_image_embedding_uncond + prior_guidance_scale * (
+ predicted_image_embedding_text - predicted_image_embedding_uncond
+ )
+
+ prior_latents = self.prior_scheduler.step(
+ predicted_image_embedding,
+ timestep=t,
+ sample=prior_latents,
+ **prior_extra_step_kwargs,
+ return_dict=False,
+ )[0]
+
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, prior_latents)
+
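+ # `post_process_latents` un-normalizes the prior output back into CLIP image-embedding
+ # space using the prior's stored statistics, so `image_embeds` below can condition the UNet.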
+ prior_latents = self.prior.post_process_latents(prior_latents)
+
+ image_embeds = prior_latents
+
+ # done prior
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 8. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ prompt_embeds = self._encode_prompt(
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+
+ # 9. Prepare image embeddings
+ image_embeds = self.noise_image_embeddings(
+ image_embeds=image_embeds,
+ noise_level=noise_level,
+ generator=generator,
+ )
+
+ if do_classifier_free_guidance:
+ negative_prompt_embeds = torch.zeros_like(image_embeds)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ image_embeds = torch.cat([negative_prompt_embeds, image_embeds])
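+ # the unconditional branch uses an all-zero image embedding, mirroring how image
+ # conditioning is dropped for classifier-free guidance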
+
+ # 10. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 11. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ latents = self.prepare_latents(
+ shape=shape,
+ dtype=prompt_embeds.dtype,
+ device=device,
+ generator=generator,
+ latents=latents,
+ scheduler=self.scheduler,
+ )
+
+ # 12. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 13. Denoising loop
+ for i, t in enumerate(self.progress_bar(timesteps)):
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ class_labels=image_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ else:
+ image = latents
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
diff --git a/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py b/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py
new file mode 100644
index 0000000000000000000000000000000000000000..003c82ff4f8aedaa9874fb18d9abef858195e4b9
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py
@@ -0,0 +1,810 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import warnings
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import PIL
+import torch
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
+
+from diffusers.utils.import_utils import is_accelerate_available
+
+from ...image_processor import VaeImageProcessor
+from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...models.embeddings import get_timestep_embedding
+from ...schedulers import KarrasDiffusionSchedulers
+from ...utils import is_accelerate_version, logging, randn_tensor, replace_example_docstring
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import requests
+ >>> import torch
+ >>> from PIL import Image
+ >>> from io import BytesIO
+
+ >>> from diffusers import StableUnCLIPImg2ImgPipeline
+
+ >>> pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(
+ ... "fusing/stable-unclip-2-1-l-img2img", torch_dtype=torch.float16
+ ... ) # TODO update model path
+ >>> pipe = pipe.to("cuda")
+
+ >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
+
+ >>> response = requests.get(url)
+ >>> init_image = Image.open(BytesIO(response.content)).convert("RGB")
+ >>> init_image = init_image.resize((768, 512))
+
+ >>> prompt = "A fantasy landscape, trending on artstation"
+
+ >>> images = pipe(prompt, init_image).images
+ >>> images[0].save("fantasy_landscape.png")
+ ```
+"""
+
+
+class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
+ """
+ Pipeline for text-guided image to image generation using stable unCLIP.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ feature_extractor ([`CLIPImageProcessor`]):
+ Feature extractor for pre-processing images before they are encoded.
+ image_encoder ([`CLIPVisionModelWithProjection`]):
+ CLIP vision model for encoding images.
+ image_normalizer ([`StableUnCLIPImageNormalizer`]):
+ Used to normalize the predicted image embeddings before the noise is applied and un-normalize the image
+ embeddings after the noise has been applied.
+ image_noising_scheduler ([`KarrasDiffusionSchedulers`]):
+ Noise schedule for adding noise to the predicted image embeddings. The amount of noise to add is determined
+ by `noise_level` in `StableUnCLIPPipeline.__call__`.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder.
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`KarrasDiffusionSchedulers`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ """
+
+ _exclude_from_cpu_offload = ["image_normalizer"]
+
+ # image encoding components
+ feature_extractor: CLIPImageProcessor
+ image_encoder: CLIPVisionModelWithProjection
+
+ # image noising components
+ image_normalizer: StableUnCLIPImageNormalizer
+ image_noising_scheduler: KarrasDiffusionSchedulers
+
+ # regular denoising components
+ tokenizer: CLIPTokenizer
+ text_encoder: CLIPTextModel
+ unet: UNet2DConditionModel
+ scheduler: KarrasDiffusionSchedulers
+
+ vae: AutoencoderKL
+
+ def __init__(
+ self,
+ # image encoding components
+ feature_extractor: CLIPImageProcessor,
+ image_encoder: CLIPVisionModelWithProjection,
+ # image noising components
+ image_normalizer: StableUnCLIPImageNormalizer,
+ image_noising_scheduler: KarrasDiffusionSchedulers,
+ # regular denoising components
+ tokenizer: CLIPTokenizer,
+ text_encoder: CLIPTextModel,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ # vae
+ vae: AutoencoderKL,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ feature_extractor=feature_extractor,
+ image_encoder=image_encoder,
+ image_normalizer=image_normalizer,
+ image_noising_scheduler=image_noising_scheduler,
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ unet=unet,
+ scheduler=scheduler,
+ vae=vae,
+ )
+
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
+ steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ hook = None
+ for cpu_offloaded_model in [self.text_encoder, self.image_encoder, self.unet, self.vae]:
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ def _encode_image(
+ self,
+ image,
+ device,
+ batch_size,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ noise_level,
+ generator,
+ image_embeds,
+ ):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if isinstance(image, PIL.Image.Image):
+ # the image embedding should be repeated so it matches the total batch size of the prompt
+ repeat_by = batch_size
+ else:
+ # assume the image input is already properly batched and just needs to be repeated so
+ # it matches the num_images_per_prompt.
+ #
+ # NOTE(will) this is probably missing a number of edge cases, e.g. batched/non-batched
+ # `image_embeds`. If those happen to be common use cases, let's think harder about
+ # what the expected dimensions of inputs should be and how we handle the encoding.
+ repeat_by = num_images_per_prompt
+
+ if image_embeds is None:
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(images=image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ image_embeds = self.image_encoder(image).image_embeds
+
+ image_embeds = self.noise_image_embeddings(
+ image_embeds=image_embeds,
+ noise_level=noise_level,
+ generator=generator,
+ )
+
+ # duplicate image embeddings for each generation per prompt, using mps friendly method
+ image_embeds = image_embeds.unsqueeze(1)
+ bs_embed, seq_len, _ = image_embeds.shape
+ image_embeds = image_embeds.repeat(1, repeat_by, 1)
+ image_embeds = image_embeds.view(bs_embed * repeat_by, seq_len, -1)
+ image_embeds = image_embeds.squeeze(1)
+
+ if do_classifier_free_guidance:
+ negative_prompt_embeds = torch.zeros_like(image_embeds)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ image_embeds = torch.cat([negative_prompt_embeds, image_embeds])
+
+ return image_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ def decode_latents(self, latents):
+ warnings.warn(
+ "The decode_latents method is deprecated and will be removed in a future version. Please"
+ " use VaeImageProcessor instead",
+ FutureWarning,
+ )
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ image,
+ height,
+ width,
+ callback_steps,
+ noise_level,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ image_embeds=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Please make sure to define only one of the two."
+ )
+
+ if prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+
+ if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ "Provide either `negative_prompt` or `negative_prompt_embeds`. Cannot leave both `negative_prompt` and `negative_prompt_embeds` undefined."
+ )
+
+ if prompt is not None and negative_prompt is not None:
+ if type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if noise_level < 0 or noise_level >= self.image_noising_scheduler.config.num_train_timesteps:
+ raise ValueError(
+ f"`noise_level` must be between 0 and {self.image_noising_scheduler.config.num_train_timesteps - 1}, inclusive."
+ )
+
+ if image is not None and image_embeds is not None:
+ raise ValueError(
+ "Provide either `image` or `image_embeds`. Please make sure to define only one of the two."
+ )
+
+ if image is None and image_embeds is None:
+ raise ValueError(
+ "Provide either `image` or `image_embeds`. Cannot leave both `image` and `image_embeds` undefined."
+ )
+
+ if image is not None:
+ if (
+ not isinstance(image, torch.Tensor)
+ and not isinstance(image, PIL.Image.Image)
+ and not isinstance(image, list)
+ ):
+ raise ValueError(
+ "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
+ f" {type(image)}"
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_unclip.StableUnCLIPPipeline.noise_image_embeddings
+ def noise_image_embeddings(
+ self,
+ image_embeds: torch.Tensor,
+ noise_level: int,
+ noise: Optional[torch.FloatTensor] = None,
+ generator: Optional[torch.Generator] = None,
+ ):
+ """
+ Add noise to the image embeddings. The amount of noise is controlled by a `noise_level` input. A higher
+ `noise_level` increases the variance in the final un-noised images.
+
+ The noise is applied in two ways:
+ 1. A noise schedule is applied directly to the embeddings.
+ 2. A vector of sinusoidal time embeddings is appended to the output.
+
+ In both cases, the amount of noise is controlled by the same `noise_level`.
+
+ The embeddings are normalized before the noise is applied and un-normalized after the noise is applied.
+ """
+ if noise is None:
+ noise = randn_tensor(
+ image_embeds.shape, generator=generator, device=image_embeds.device, dtype=image_embeds.dtype
+ )
+
+ noise_level = torch.tensor([noise_level] * image_embeds.shape[0], device=image_embeds.device)
+
+ self.image_normalizer.to(image_embeds.device)
+ image_embeds = self.image_normalizer.scale(image_embeds)
+
+ image_embeds = self.image_noising_scheduler.add_noise(image_embeds, timesteps=noise_level, noise=noise)
+
+ image_embeds = self.image_normalizer.unscale(image_embeds)
+
+ noise_level = get_timestep_embedding(
+ timesteps=noise_level, embedding_dim=image_embeds.shape[-1], flip_sin_to_cos=True, downscale_freq_shift=0
+ )
+
+ # `get_timestep_embeddings` does not contain any weights and will always return f32 tensors,
+ # but we might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ noise_level = noise_level.to(image_embeds.dtype)
+
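+ # append the sinusoidal embedding of the noise level so the unet is also
+ # conditioned on how much noise was added (step 2 in the docstring above)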
+ image_embeds = torch.cat((image_embeds, noise_level), 1)
+
+ return image_embeds
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+ prompt: Union[str, List[str]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 20,
+ guidance_scale: float = 10,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ noise_level: int = 0,
+ image_embeds: Optional[torch.FloatTensor] = None,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, either `prompt_embeds` will be
+ used or prompt is initialized to `""`.
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch. The image will be encoded to its CLIP embedding which
+ the unet will be conditioned on. Note that the image is _not_ encoded by the vae and then used as the
+ latents in the denoising process, as it is in the standard stable diffusion text-guided image variation
+ process.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 20):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 10.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+ noise_level (`int`, *optional*, defaults to `0`):
+ The amount of noise to add to the image embeddings. A higher `noise_level` increases the variance in
+ the final un-noised images. See `StableUnCLIPPipeline.noise_image_embeddings` for details.
+ image_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated CLIP embeddings to condition the unet on. Note that these are not latents to be used in
+ the denoising process. If you want to provide pre-generated latents, pass them to `__call__` as
+ `latents`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.ImagePipelineOutput`] if `return_dict` is
+ True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ if prompt is None and prompt_embeds is None:
+ prompt = len(image) * [""] if isinstance(image, list) else ""
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt=prompt,
+ image=image,
+ height=height,
+ width=width,
+ callback_steps=callback_steps,
+ noise_level=noise_level,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ image_embeds=image_embeds,
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ batch_size = batch_size * num_images_per_prompt
+
+ device = self._execution_device
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ prompt_embeds = self._encode_prompt(
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+
+ # 4. Encode input image
+ noise_level = torch.tensor([noise_level], device=device)
+ image_embeds = self._encode_image(
+ image=image,
+ device=device,
+ batch_size=batch_size,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ noise_level=noise_level,
+ generator=generator,
+ image_embeds=image_embeds,
+ )
+
+ # 5. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 6. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size=batch_size,
+ num_channels_latents=num_channels_latents,
+ height=height,
+ width=width,
+ dtype=prompt_embeds.dtype,
+ device=device,
+ generator=generator,
+ latents=latents,
+ )
+
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 8. Denoising loop
+ for i, t in enumerate(self.progress_bar(timesteps)):
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ class_labels=image_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ # 9. Post-processing
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ else:
+ image = latents
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
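The docstring example above covers basic text-guided variation; the sketch below additionally shows `enable_model_cpu_offload` and `noise_level`, which the code above documents. The checkpoint id and file paths are placeholders (the example docstring still carries a TODO for the final model path), and the top-level import assumes the vendored `diffusers` package exposes the same entry point as upstream.

```py
# Usage sketch for StableUnCLIPImg2ImgPipeline (checkpoint id and paths are assumptions).
import torch
from PIL import Image
from diffusers import StableUnCLIPImg2ImgPipeline

pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1-unclip",  # placeholder checkpoint id
    torch_dtype=torch.float16,
)
pipe.enable_model_cpu_offload()  # alternative to pipe.to("cuda"); trades speed for memory

init_image = Image.open("input.png").convert("RGB")  # placeholder path

# noise_level controls how strongly the CLIP image embedding is noised before
# conditioning; 0 stays closest to the input image, larger values add variation.
images = pipe(image=init_image, prompt="a watercolor painting", noise_level=0).images
images[0].save("variation.png")
```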
diff --git a/diffusers/pipelines/stable_diffusion/safety_checker.py b/diffusers/pipelines/stable_diffusion/safety_checker.py
new file mode 100644
index 0000000000000000000000000000000000000000..38c7b22d08d43ade5fe7979f5514ec973109fd82
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/safety_checker.py
@@ -0,0 +1,125 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import torch
+import torch.nn as nn
+from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel
+
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+def cosine_distance(image_embeds, text_embeds):
+ normalized_image_embeds = nn.functional.normalize(image_embeds)
+ normalized_text_embeds = nn.functional.normalize(text_embeds)
+ return torch.mm(normalized_image_embeds, normalized_text_embeds.t())
+
+
+class StableDiffusionSafetyChecker(PreTrainedModel):
+ config_class = CLIPConfig
+
+ _no_split_modules = ["CLIPEncoderLayer"]
+
+ def __init__(self, config: CLIPConfig):
+ super().__init__(config)
+
+ self.vision_model = CLIPVisionModel(config.vision_config)
+ self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False)
+
+ self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False)
+ self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False)
+
+ self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False)
+ self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False)
+
+ @torch.no_grad()
+ def forward(self, clip_input, images):
+ pooled_output = self.vision_model(clip_input)[1] # pooled_output
+ image_embeds = self.visual_projection(pooled_output)
+
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().float().numpy()
+ cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy()
+
+ result = []
+ batch_size = image_embeds.shape[0]
+ for i in range(batch_size):
+ result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []}
+
+ # increase this value to create a stronger `nsfw` filter
+ # at the cost of increasing the possibility of filtering benign images
+ adjustment = 0.0
+
+ for concept_idx in range(len(special_cos_dist[0])):
+ concept_cos = special_cos_dist[i][concept_idx]
+ concept_threshold = self.special_care_embeds_weights[concept_idx].item()
+ result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
+ if result_img["special_scores"][concept_idx] > 0:
+ result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]})
+ adjustment = 0.01
+
+ for concept_idx in range(len(cos_dist[0])):
+ concept_cos = cos_dist[i][concept_idx]
+ concept_threshold = self.concept_embeds_weights[concept_idx].item()
+ result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
+ if result_img["concept_scores"][concept_idx] > 0:
+ result_img["bad_concepts"].append(concept_idx)
+
+ result.append(result_img)
+
+ has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result]
+
+ for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
+ if has_nsfw_concept:
+ if torch.is_tensor(images) or torch.is_tensor(images[0]):
+ images[idx] = torch.zeros_like(images[idx]) # black image
+ else:
+ images[idx] = np.zeros(images[idx].shape) # black image
+
+ if any(has_nsfw_concepts):
+ logger.warning(
+ "Potential NSFW content was detected in one or more images. A black image will be returned instead."
+ " Try again with a different prompt and/or seed."
+ )
+
+ return images, has_nsfw_concepts
+
+ @torch.no_grad()
+ def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor):
+ pooled_output = self.vision_model(clip_input)[1] # pooled_output
+ image_embeds = self.visual_projection(pooled_output)
+
+ special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds)
+ cos_dist = cosine_distance(image_embeds, self.concept_embeds)
+
+ # increase this value to create a stronger `nsfw` filter
+ # at the cost of increasing the possibility of filtering benign images
+ adjustment = 0.0
+
+ special_scores = special_cos_dist - self.special_care_embeds_weights + adjustment
+ # special_scores = special_scores.round(decimals=3)
+ special_care = torch.any(special_scores > 0, dim=1)
+ special_adjustment = special_care * 0.01
+ special_adjustment = special_adjustment.unsqueeze(1).expand(-1, cos_dist.shape[1])
+
+ concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment
+ # concept_scores = concept_scores.round(decimals=3)
+ has_nsfw_concepts = torch.any(concept_scores > 0, dim=1)
+
+ images[has_nsfw_concepts] = 0.0 # black image
+
+ return images, has_nsfw_concepts
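The `forward_onnx` method above flags an image when any adjusted concept score crosses zero, and lowers every concept threshold by 0.01 once a special-care concept fires. A toy walk-through of that logic with invented cosine scores in place of real CLIP embeddings:

```py
# Toy walk-through of the thresholding in forward_onnx, with invented scores.
import torch

cos_dist = torch.tensor([[0.18, 0.05, 0.30]])        # image vs. three dummy concepts
concept_thresholds = torch.tensor([0.20, 0.20, 0.28])

special_cos_dist = torch.tensor([[0.26]])             # one dummy special-care concept
special_thresholds = torch.tensor([0.25])

# when any special-care concept fires, all concept thresholds are effectively
# lowered by 0.01, i.e. the filter becomes slightly stricter
special_care = torch.any(special_cos_dist - special_thresholds > 0, dim=1)
special_adjustment = (special_care * 0.01).unsqueeze(1).expand(-1, cos_dist.shape[1])

concept_scores = cos_dist - concept_thresholds + special_adjustment
has_nsfw_concepts = torch.any(concept_scores > 0, dim=1)
print(has_nsfw_concepts)  # tensor([True]) -- the third dummy concept exceeds its threshold
```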
diff --git a/diffusers/pipelines/stable_diffusion/safety_checker_flax.py b/diffusers/pipelines/stable_diffusion/safety_checker_flax.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a8c3167954016b3b89f16caf8348661cd3a27ef
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/safety_checker_flax.py
@@ -0,0 +1,112 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple
+
+import jax
+import jax.numpy as jnp
+from flax import linen as nn
+from flax.core.frozen_dict import FrozenDict
+from transformers import CLIPConfig, FlaxPreTrainedModel
+from transformers.models.clip.modeling_flax_clip import FlaxCLIPVisionModule
+
+
+def jax_cosine_distance(emb_1, emb_2, eps=1e-12):
+ norm_emb_1 = jnp.divide(emb_1.T, jnp.clip(jnp.linalg.norm(emb_1, axis=1), a_min=eps)).T
+ norm_emb_2 = jnp.divide(emb_2.T, jnp.clip(jnp.linalg.norm(emb_2, axis=1), a_min=eps)).T
+ return jnp.matmul(norm_emb_1, norm_emb_2.T)
+
+
+class FlaxStableDiffusionSafetyCheckerModule(nn.Module):
+ config: CLIPConfig
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ self.vision_model = FlaxCLIPVisionModule(self.config.vision_config)
+ self.visual_projection = nn.Dense(self.config.projection_dim, use_bias=False, dtype=self.dtype)
+
+ self.concept_embeds = self.param("concept_embeds", jax.nn.initializers.ones, (17, self.config.projection_dim))
+ self.special_care_embeds = self.param(
+ "special_care_embeds", jax.nn.initializers.ones, (3, self.config.projection_dim)
+ )
+
+ self.concept_embeds_weights = self.param("concept_embeds_weights", jax.nn.initializers.ones, (17,))
+ self.special_care_embeds_weights = self.param("special_care_embeds_weights", jax.nn.initializers.ones, (3,))
+
+ def __call__(self, clip_input):
+ pooled_output = self.vision_model(clip_input)[1]
+ image_embeds = self.visual_projection(pooled_output)
+
+ special_cos_dist = jax_cosine_distance(image_embeds, self.special_care_embeds)
+ cos_dist = jax_cosine_distance(image_embeds, self.concept_embeds)
+
+ # increase this value to create a stronger `nsfw` filter
+ # at the cost of increasing the possibility of filtering benign image inputs
+ adjustment = 0.0
+
+ special_scores = special_cos_dist - self.special_care_embeds_weights[None, :] + adjustment
+ special_scores = jnp.round(special_scores, 3)
+ is_special_care = jnp.any(special_scores > 0, axis=1, keepdims=True)
+ # Use a lower threshold if an image has any special care concept
+ special_adjustment = is_special_care * 0.01
+
+ concept_scores = cos_dist - self.concept_embeds_weights[None, :] + special_adjustment
+ concept_scores = jnp.round(concept_scores, 3)
+ has_nsfw_concepts = jnp.any(concept_scores > 0, axis=1)
+
+ return has_nsfw_concepts
+
+
+class FlaxStableDiffusionSafetyChecker(FlaxPreTrainedModel):
+ config_class = CLIPConfig
+ main_input_name = "clip_input"
+ module_class = FlaxStableDiffusionSafetyCheckerModule
+
+ def __init__(
+ self,
+ config: CLIPConfig,
+ input_shape: Optional[Tuple] = None,
+ seed: int = 0,
+ dtype: jnp.dtype = jnp.float32,
+ _do_init: bool = True,
+ **kwargs,
+ ):
+ if input_shape is None:
+ input_shape = (1, 224, 224, 3)
+ module = self.module_class(config=config, dtype=dtype, **kwargs)
+ super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
+
+ def init_weights(self, rng: jax.random.KeyArray, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
+ # init input tensor
+ clip_input = jax.random.normal(rng, input_shape)
+
+ params_rng, dropout_rng = jax.random.split(rng)
+ rngs = {"params": params_rng, "dropout": dropout_rng}
+
+ random_params = self.module.init(rngs, clip_input)["params"]
+
+ return random_params
+
+ def __call__(
+ self,
+ clip_input,
+ params: dict = None,
+ ):
+ clip_input = jnp.transpose(clip_input, (0, 2, 3, 1))
+
+ return self.module.apply(
+ {"params": params or self.params},
+ jnp.array(clip_input, dtype=jnp.float32),
+ rngs={},
+ )
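Despite its name, `jax_cosine_distance` above returns the cosine similarity of L2-normalised embeddings. A quick sketch confirming that on a hand-checkable pair of vectors:

```py
# jax_cosine_distance is the cosine similarity of L2-normalised embeddings.
import jax.numpy as jnp

emb_1 = jnp.array([[3.0, 4.0]])   # norm 5
emb_2 = jnp.array([[1.0, 0.0]])   # norm 1

norm_1 = emb_1 / jnp.clip(jnp.linalg.norm(emb_1, axis=1, keepdims=True), 1e-12)
norm_2 = emb_2 / jnp.clip(jnp.linalg.norm(emb_2, axis=1, keepdims=True), 1e-12)

print(jnp.matmul(norm_1, norm_2.T))  # [[0.6]] == (3*1 + 4*0) / (5*1)
```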
diff --git a/diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py b/diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..7362df7e80e72719133f1804600a618fe161f668
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py
@@ -0,0 +1,57 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Union
+
+import torch
+from torch import nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...models.modeling_utils import ModelMixin
+
+
+class StableUnCLIPImageNormalizer(ModelMixin, ConfigMixin):
+ """
+ This class is used to hold the mean and standard deviation of the CLIP embedder used in stable unCLIP.
+
+ It is used to normalize the image embeddings before the noise is applied and un-normalize the noised image
+ embeddings.
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ embedding_dim: int = 768,
+ ):
+ super().__init__()
+
+ self.mean = nn.Parameter(torch.zeros(1, embedding_dim))
+ self.std = nn.Parameter(torch.ones(1, embedding_dim))
+
+ def to(
+ self,
+ torch_device: Optional[Union[str, torch.device]] = None,
+ torch_dtype: Optional[torch.dtype] = None,
+ ):
+ self.mean = nn.Parameter(self.mean.to(torch_device).to(torch_dtype))
+ self.std = nn.Parameter(self.std.to(torch_device).to(torch_dtype))
+ return self
+
+ def scale(self, embeds):
+ embeds = (embeds - self.mean) * 1.0 / self.std
+ return embeds
+
+ def unscale(self, embeds):
+ embeds = (embeds * self.std) + self.mean
+ return embeds
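`scale` applies `(x - mean) / std` and `unscale` inverts it, so composing the two recovers the original embeddings regardless of the stored statistics. A minimal round-trip sketch, using a small `embedding_dim` and the module path from this diff:

```py
# scale()/unscale() round-trip for StableUnCLIPImageNormalizer (small dim for illustration).
import torch
from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import (
    StableUnCLIPImageNormalizer,
)

normalizer = StableUnCLIPImageNormalizer(embedding_dim=4)
embeds = torch.randn(2, 4)

recovered = normalizer.unscale(normalizer.scale(embeds))
assert torch.allclose(recovered, embeds, atol=1e-5)
```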
diff --git a/diffusers/pipelines/stable_diffusion_safe/__init__.py b/diffusers/pipelines/stable_diffusion_safe/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5aecfeac112e53b2fc49278c1acaa95a6c0c7257
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion_safe/__init__.py
@@ -0,0 +1,71 @@
+from dataclasses import dataclass
+from enum import Enum
+from typing import List, Optional, Union
+
+import numpy as np
+import PIL
+from PIL import Image
+
+from ...utils import BaseOutput, is_torch_available, is_transformers_available
+
+
+@dataclass
+class SafetyConfig(object):
+ WEAK = {
+ "sld_warmup_steps": 15,
+ "sld_guidance_scale": 20,
+ "sld_threshold": 0.0,
+ "sld_momentum_scale": 0.0,
+ "sld_mom_beta": 0.0,
+ }
+ MEDIUM = {
+ "sld_warmup_steps": 10,
+ "sld_guidance_scale": 1000,
+ "sld_threshold": 0.01,
+ "sld_momentum_scale": 0.3,
+ "sld_mom_beta": 0.4,
+ }
+ STRONG = {
+ "sld_warmup_steps": 7,
+ "sld_guidance_scale": 2000,
+ "sld_threshold": 0.025,
+ "sld_momentum_scale": 0.5,
+ "sld_mom_beta": 0.7,
+ }
+ MAX = {
+ "sld_warmup_steps": 0,
+ "sld_guidance_scale": 5000,
+ "sld_threshold": 1.0,
+ "sld_momentum_scale": 0.5,
+ "sld_mom_beta": 0.7,
+ }
+
+
+@dataclass
+class StableDiffusionSafePipelineOutput(BaseOutput):
+ """
+ Output class for Safe Stable Diffusion pipelines.
+
+ Args:
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+ num_channels)`. The PIL images or numpy array represent the denoised images of the diffusion pipeline.
+ nsfw_content_detected (`List[bool]`)
+ List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, or `None` if safety checking could not be performed.
+ unsafe_images (`List[PIL.Image.Image]` or `np.ndarray`)
+ List of denoised PIL images that were flagged by the safety checker and may contain "not-safe-for-work"
+ (nsfw) content, or `None` if no safety check was performed or no images were flagged.
+ applied_safety_concept (`str`)
+ The safety concept that was applied for safety guidance, or `None` if safety guidance was disabled
+ """
+
+ images: Union[List[PIL.Image.Image], np.ndarray]
+ nsfw_content_detected: Optional[List[bool]]
+ unsafe_images: Optional[Union[List[PIL.Image.Image], np.ndarray]]
+ applied_safety_concept: Optional[str]
+
+
+if is_transformers_available() and is_torch_available():
+ from .pipeline_stable_diffusion_safe import StableDiffusionPipelineSafe
+ from .safety_checker import SafeStableDiffusionSafetyChecker
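Each `SafetyConfig` preset above is a plain dict of the `sld_*` keyword arguments accepted by `StableDiffusionPipelineSafe.__call__`, so it can be unpacked directly into a pipeline call. A hedged sketch; the checkpoint id is an assumption:

```py
# Hedged sketch: unpacking a SafetyConfig preset into the safe pipeline call.
import torch
from diffusers.pipelines.stable_diffusion_safe import SafetyConfig, StableDiffusionPipelineSafe

pipe = StableDiffusionPipelineSafe.from_pretrained(
    "AIML-TUDA/stable-diffusion-safe",  # assumed checkpoint id
    torch_dtype=torch.float16,
).to("cuda")

# STRONG preset: 7 warmup steps, sld_guidance_scale=2000, threshold 0.025, momentum 0.5
out = pipe(prompt="portrait of an astronaut", **SafetyConfig.STRONG)
out.images[0].save("safe_output.png")
```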
diff --git a/diffusers/pipelines/stable_diffusion_safe/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/stable_diffusion_safe/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..56217772ec01126c22f6242741e3e78242cd7a94
Binary files /dev/null and b/diffusers/pipelines/stable_diffusion_safe/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/pipelines/stable_diffusion_safe/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/stable_diffusion_safe/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7ebc2d91b192ff22a75ebcf1933075d07f939c17
Binary files /dev/null and b/diffusers/pipelines/stable_diffusion_safe/__pycache__/__init__.cpython-38.pyc differ
diff --git a/diffusers/pipelines/stable_diffusion_safe/__pycache__/pipeline_stable_diffusion_safe.cpython-310.pyc b/diffusers/pipelines/stable_diffusion_safe/__pycache__/pipeline_stable_diffusion_safe.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..741a61b20498dd1ec9aa6c66413a64cf6ff4bbb9
Binary files /dev/null and b/diffusers/pipelines/stable_diffusion_safe/__pycache__/pipeline_stable_diffusion_safe.cpython-310.pyc differ
diff --git a/diffusers/pipelines/stable_diffusion_safe/__pycache__/pipeline_stable_diffusion_safe.cpython-38.pyc b/diffusers/pipelines/stable_diffusion_safe/__pycache__/pipeline_stable_diffusion_safe.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d8c56960c474dd51e5ed783345f629fb28bdc79d
Binary files /dev/null and b/diffusers/pipelines/stable_diffusion_safe/__pycache__/pipeline_stable_diffusion_safe.cpython-38.pyc differ
diff --git a/diffusers/pipelines/stable_diffusion_safe/__pycache__/safety_checker.cpython-310.pyc b/diffusers/pipelines/stable_diffusion_safe/__pycache__/safety_checker.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fa82e4bbe818c9c78f8df967aa1c0ab852cbfc53
Binary files /dev/null and b/diffusers/pipelines/stable_diffusion_safe/__pycache__/safety_checker.cpython-310.pyc differ
diff --git a/diffusers/pipelines/stable_diffusion_safe/__pycache__/safety_checker.cpython-38.pyc b/diffusers/pipelines/stable_diffusion_safe/__pycache__/safety_checker.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8150b284dd66d8c9982d83735c3ef37d092df515
Binary files /dev/null and b/diffusers/pipelines/stable_diffusion_safe/__pycache__/safety_checker.cpython-38.pyc differ
diff --git a/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py b/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py
new file mode 100644
index 0000000000000000000000000000000000000000..f172575bc6c737a35a6eaed79e4b8915d2b735e6
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py
@@ -0,0 +1,705 @@
+import inspect
+import warnings
+from typing import Callable, List, Optional, Union
+
+import numpy as np
+import torch
+from packaging import version
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from ...configuration_utils import FrozenDict
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...schedulers import KarrasDiffusionSchedulers
+from ...utils import deprecate, logging, randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from . import StableDiffusionSafePipelineOutput
+from .safety_checker import SafeStableDiffusionSafetyChecker
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class StableDiffusionPipelineSafe(DiffusionPipeline):
+ r"""
+ Pipeline for text-to-image generation using Safe Latent Diffusion.
+
+ The implementation is based on the [`StableDiffusionPipeline`]
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: SafeStableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+ safety_concept: Optional[str] = (
+ "an image showing hate, harassment, violence, suffering, humiliation, harm, suicide, sexual, nudity,"
+ " bodily fluids, blood, obscene gestures, illegal activity, drug use, theft, vandalism, weapons, child"
+ " abuse, brutality, cruelty"
+ )
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+ )
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["clip_sample"] = False
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+ version.parse(unet.config._diffusers_version).base_version
+ ) < version.parse("0.9.0.dev0")
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+ deprecation_message = (
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
+ " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+ " the `unet/config.json` file"
+ )
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(unet.config)
+ new_config["sample_size"] = 64
+ unet._internal_dict = FrozenDict(new_config)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self._safety_text_concept = safety_concept
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ @property
+ def safety_concept(self):
+ r"""
+ Getter method for the safety concept used with SLD
+
+ Returns:
+ `str`: The text describing the safety concept
+ """
+ return self._safety_text_concept
+
+ @safety_concept.setter
+ def safety_concept(self, concept):
+ r"""
+ Setter method for the safety concept used with SLD
+
+ Args:
+ concept (`str`):
+ The text of the new safety concept
+ """
+ self._safety_text_concept = concept
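+    # Hypothetical usage sketch (the variable name `pipe` is assumed, not taken from this file): the
+    # guiding concept can be swapped at runtime before generation, e.g.
+    #   pipe.safety_concept = "violence, blood, gore"
+    #   image = pipe("some prompt").images[0]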
+
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ enable_safety_guidance,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ """
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
+
+ if not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = text_input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # Encode the safety concept text
+ if enable_safety_guidance:
+ safety_concept_input = self.tokenizer(
+ [self._safety_text_concept],
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ safety_embeddings = self.text_encoder(safety_concept_input.input_ids.to(self.device))[0]
+
+ # duplicate safety embeddings for each generation per prompt, using mps friendly method
+ seq_len = safety_embeddings.shape[1]
+ safety_embeddings = safety_embeddings.repeat(batch_size, num_images_per_prompt, 1)
+ safety_embeddings = safety_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance + sld, we need to do three forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing three forward passes
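+            # The concatenation order is [negative/unconditional, prompt, safety concept]; the denoising
+            # loop later splits the corresponding noise predictions back apart with `.chunk(3)` in the
+            # same order.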
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, safety_embeddings])
+
+ else:
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ def run_safety_checker(self, image, device, dtype, enable_safety_guidance):
+ if self.safety_checker is not None:
+ images = image.copy()
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+            flagged_images = np.zeros(image.shape)  # filled below for any image that gets flagged
+ if any(has_nsfw_concept):
+ logger.warning(
+                    "Potential NSFW content was detected in one or more images. A black image will be returned"
+                    " instead."
+                    f" {'You may look at these images in the `unsafe_images` variable of the output at your own discretion.' if enable_safety_guidance else 'Try again with a different prompt and/or seed.'}"
+ )
+ for idx, has_nsfw_concept in enumerate(has_nsfw_concept):
+ if has_nsfw_concept:
+ flagged_images[idx] = images[idx]
+ image[idx] = np.zeros(image[idx].shape) # black image
+ else:
+ has_nsfw_concept = None
+ flagged_images = None
+ return image, has_nsfw_concept, flagged_images
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ def decode_latents(self, latents):
+ warnings.warn(
+ "The decode_latents method is deprecated and will be removed in a future version. Please"
+ " use VaeImageProcessor instead",
+ FutureWarning,
+ )
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ def perform_safety_guidance(
+ self,
+ enable_safety_guidance,
+ safety_momentum,
+ noise_guidance,
+ noise_pred_out,
+ i,
+ sld_guidance_scale,
+ sld_warmup_steps,
+ sld_threshold,
+ sld_momentum_scale,
+ sld_mom_beta,
+ ):
+ # Perform SLD guidance
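+        # Roughly, with eps_u / eps_p / eps_S the unconditional, prompt and safety-concept noise estimates
+        # (this summary mirrors the code below; symbols follow the Safe Latent Diffusion paper):
+        #   scale    = clamp(|eps_p - eps_S| * sld_guidance_scale, max=1)            (Eq. 6)
+        #   mu       = 0 where (eps_p - eps_S) >= sld_threshold, else scale          (Eq. 6)
+        #   gamma    = (eps_S - eps_u) * mu + sld_momentum_scale * momentum          (Eqs. 4, 7)
+        #   momentum = sld_mom_beta * momentum + (1 - sld_mom_beta) * gamma          (Eq. 8)
+        #   guidance = guidance - gamma   once i >= sld_warmup_steps                 (Eq. 3)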
+ if enable_safety_guidance:
+ if safety_momentum is None:
+ safety_momentum = torch.zeros_like(noise_guidance)
+ noise_pred_text, noise_pred_uncond = noise_pred_out[0], noise_pred_out[1]
+ noise_pred_safety_concept = noise_pred_out[2]
+
+ # Equation 6
+ scale = torch.clamp(torch.abs((noise_pred_text - noise_pred_safety_concept)) * sld_guidance_scale, max=1.0)
+
+ # Equation 6
+ safety_concept_scale = torch.where(
+ (noise_pred_text - noise_pred_safety_concept) >= sld_threshold, torch.zeros_like(scale), scale
+ )
+
+ # Equation 4
+ noise_guidance_safety = torch.mul((noise_pred_safety_concept - noise_pred_uncond), safety_concept_scale)
+
+ # Equation 7
+ noise_guidance_safety = noise_guidance_safety + sld_momentum_scale * safety_momentum
+
+ # Equation 8
+ safety_momentum = sld_mom_beta * safety_momentum + (1 - sld_mom_beta) * noise_guidance_safety
+
+ if i >= sld_warmup_steps: # Warmup
+ # Equation 3
+ noise_guidance = noise_guidance - noise_guidance_safety
+ return noise_guidance, safety_momentum
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ sld_guidance_scale: Optional[float] = 1000,
+ sld_warmup_steps: Optional[int] = 10,
+ sld_threshold: Optional[float] = 0.01,
+ sld_momentum_scale: Optional[float] = 0.3,
+ sld_mom_beta: Optional[float] = 0.4,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. A higher guidance scale encourages the model to generate images that are closely linked to the
+                text `prompt`, usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ sld_guidance_scale (`float`, *optional*, defaults to 1000):
+ Safe latent guidance as defined in [Safe Latent Diffusion](https://arxiv.org/abs/2211.05105).
+                `sld_guidance_scale` is defined as `s_S` of Eq. 6. If set to a value less than 1, safety guidance will be
+ disabled.
+ sld_warmup_steps (`int`, *optional*, defaults to 10):
+ Number of warmup steps for safety guidance. SLD will only be applied for diffusion steps greater than
+ `sld_warmup_steps`. `sld_warmup_steps` is defined as `delta` of [Safe Latent
+ Diffusion](https://arxiv.org/abs/2211.05105).
+ sld_threshold (`float`, *optional*, defaults to 0.01):
+ Threshold that separates the hyperplane between appropriate and inappropriate images. `sld_threshold`
+                is defined as `lambda` of Eq. 5 in [Safe Latent Diffusion](https://arxiv.org/abs/2211.05105).
+ sld_momentum_scale (`float`, *optional*, defaults to 0.3):
+ Scale of the SLD momentum to be added to the safety guidance at each diffusion step. If set to 0.0
+ momentum will be disabled. Momentum is already built up during warmup, i.e. for diffusion steps smaller
+                than `sld_warmup_steps`. `sld_momentum_scale` is defined as `s_m` of Eq. 7 in [Safe Latent
+ Diffusion](https://arxiv.org/abs/2211.05105).
+ sld_mom_beta (`float`, *optional*, defaults to 0.4):
+ Defines how safety guidance momentum builds up. `sld_mom_beta` indicates how much of the previous
+ momentum will be kept. Momentum is already built up during warmup, i.e. for diffusion steps smaller
+                than `sld_warmup_steps`. `sld_mom_beta` is defined as `beta_m` of Eq. 8 in [Safe Latent
+ Diffusion](https://arxiv.org/abs/2211.05105).
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(prompt, height, width, callback_steps)
+
+ # 2. Define call parameters
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
+ device = self._execution_device
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ enable_safety_guidance = sld_guidance_scale > 1.0 and do_classifier_free_guidance
+ if not enable_safety_guidance:
+            warnings.warn("Safety guidance is disabled (sld_guidance_scale <= 1 or no classifier-free guidance).")
+
+ # 3. Encode input prompt
+ prompt_embeds = self._encode_prompt(
+ prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, enable_safety_guidance
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs.
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ safety_momentum = None
+
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = (
+ torch.cat([latents] * (3 if enable_safety_guidance else 2))
+ if do_classifier_free_guidance
+ else latents
+ )
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_out = noise_pred.chunk((3 if enable_safety_guidance else 2))
+ noise_pred_uncond, noise_pred_text = noise_pred_out[0], noise_pred_out[1]
+
+ # default classifier free guidance
+ noise_guidance = noise_pred_text - noise_pred_uncond
+
+ # Perform SLD guidance
+ if enable_safety_guidance:
+ if safety_momentum is None:
+ safety_momentum = torch.zeros_like(noise_guidance)
+ noise_pred_safety_concept = noise_pred_out[2]
+
+ # Equation 6
+ scale = torch.clamp(
+ torch.abs((noise_pred_text - noise_pred_safety_concept)) * sld_guidance_scale, max=1.0
+ )
+
+ # Equation 6
+ safety_concept_scale = torch.where(
+ (noise_pred_text - noise_pred_safety_concept) >= sld_threshold,
+ torch.zeros_like(scale),
+ scale,
+ )
+
+ # Equation 4
+ noise_guidance_safety = torch.mul(
+ (noise_pred_safety_concept - noise_pred_uncond), safety_concept_scale
+ )
+
+ # Equation 7
+ noise_guidance_safety = noise_guidance_safety + sld_momentum_scale * safety_momentum
+
+ # Equation 8
+ safety_momentum = sld_mom_beta * safety_momentum + (1 - sld_mom_beta) * noise_guidance_safety
+
+ if i >= sld_warmup_steps: # Warmup
+ # Equation 3
+ noise_guidance = noise_guidance - noise_guidance_safety
+
+ noise_pred = noise_pred_uncond + guidance_scale * noise_guidance
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ # 8. Post-processing
+ image = self.decode_latents(latents)
+
+ # 9. Run safety checker
+ image, has_nsfw_concept, flagged_images = self.run_safety_checker(
+ image, device, prompt_embeds.dtype, enable_safety_guidance
+ )
+
+ # 10. Convert to PIL
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+ if flagged_images is not None:
+ flagged_images = self.numpy_to_pil(flagged_images)
+
+ if not return_dict:
+ return (
+ image,
+ has_nsfw_concept,
+ self._safety_text_concept if enable_safety_guidance else None,
+ flagged_images,
+ )
+
+ return StableDiffusionSafePipelineOutput(
+ images=image,
+ nsfw_content_detected=has_nsfw_concept,
+ applied_safety_concept=self._safety_text_concept if enable_safety_guidance else None,
+ unsafe_images=flagged_images,
+ )
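+
+
+# Minimal end-to-end sketch. The class name `StableDiffusionPipelineSafe` and the checkpoint id below are
+# assumptions (neither is visible in this hunk); substitute whatever this module actually exports and any
+# Stable-Diffusion-v1-compatible checkpoint:
+#
+#   import torch
+#
+#   pipe = StableDiffusionPipelineSafe.from_pretrained(
+#       "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
+#   ).to("cuda")
+#   out = pipe("portrait of an astronaut", sld_guidance_scale=2000)
+#   out.images[0].save("safe.png")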
diff --git a/diffusers/pipelines/stable_diffusion_safe/safety_checker.py b/diffusers/pipelines/stable_diffusion_safe/safety_checker.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b0c547496a0202dbfa1d8525a92565b3df62cbb
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion_safe/safety_checker.py
@@ -0,0 +1,109 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.nn as nn
+from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel
+
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
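+# Note: despite the name, this returns pairwise cosine *similarities* (not distances) between each image
+# embedding and each concept embedding, as a matrix of shape (num_images, num_concepts).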
+def cosine_distance(image_embeds, text_embeds):
+ normalized_image_embeds = nn.functional.normalize(image_embeds)
+ normalized_text_embeds = nn.functional.normalize(text_embeds)
+ return torch.mm(normalized_image_embeds, normalized_text_embeds.t())
+
+
+class SafeStableDiffusionSafetyChecker(PreTrainedModel):
+ config_class = CLIPConfig
+
+ _no_split_modules = ["CLIPEncoderLayer"]
+
+ def __init__(self, config: CLIPConfig):
+ super().__init__(config)
+
+ self.vision_model = CLIPVisionModel(config.vision_config)
+ self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False)
+
+ self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False)
+ self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False)
+
+ self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False)
+ self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False)
+
+ @torch.no_grad()
+ def forward(self, clip_input, images):
+ pooled_output = self.vision_model(clip_input)[1] # pooled_output
+ image_embeds = self.visual_projection(pooled_output)
+
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().float().numpy()
+ cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy()
+
+ result = []
+ batch_size = image_embeds.shape[0]
+ for i in range(batch_size):
+ result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []}
+
+            # increase this value to create a stronger `nsfw` filter
+ # at the cost of increasing the possibility of filtering benign images
+ adjustment = 0.0
+
+ for concept_idx in range(len(special_cos_dist[0])):
+ concept_cos = special_cos_dist[i][concept_idx]
+ concept_threshold = self.special_care_embeds_weights[concept_idx].item()
+ result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
+ if result_img["special_scores"][concept_idx] > 0:
+ result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]})
+ adjustment = 0.01
+
+ for concept_idx in range(len(cos_dist[0])):
+ concept_cos = cos_dist[i][concept_idx]
+ concept_threshold = self.concept_embeds_weights[concept_idx].item()
+ result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
+ if result_img["concept_scores"][concept_idx] > 0:
+ result_img["bad_concepts"].append(concept_idx)
+
+ result.append(result_img)
+
+ has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result]
+
+ return images, has_nsfw_concepts
+
+ @torch.no_grad()
+ def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor):
+ pooled_output = self.vision_model(clip_input)[1] # pooled_output
+ image_embeds = self.visual_projection(pooled_output)
+
+ special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds)
+ cos_dist = cosine_distance(image_embeds, self.concept_embeds)
+
+ # increase this value to create a stronger `nsfw` filter
+ # at the cost of increasing the possibility of filtering benign images
+ adjustment = 0.0
+
+ special_scores = special_cos_dist - self.special_care_embeds_weights + adjustment
+ # special_scores = special_scores.round(decimals=3)
+ special_care = torch.any(special_scores > 0, dim=1)
+ special_adjustment = special_care * 0.01
+ special_adjustment = special_adjustment.unsqueeze(1).expand(-1, cos_dist.shape[1])
+
+ concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment
+ # concept_scores = concept_scores.round(decimals=3)
+ has_nsfw_concepts = torch.any(concept_scores > 0, dim=1)
+
+ return images, has_nsfw_concepts
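+
+
+# Summary note: both `forward` and `forward_onnx` implement the same decision rule. An image is flagged
+# when, for any concept j, cosine(image, concept_j) - threshold_j + adjustment > 0, where the adjustment
+# is bumped to 0.01 once a "special care" concept fires, making the regular concept check stricter.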
diff --git a/diffusers/pipelines/stable_diffusion_xl/__init__.py b/diffusers/pipelines/stable_diffusion_xl/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b823cddd3af0915ce6c6afc6717024c34cff886
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion_xl/__init__.py
@@ -0,0 +1,27 @@
+from dataclasses import dataclass
+from typing import List, Optional, Union
+
+import numpy as np
+import PIL
+
+from ...utils import BaseOutput, is_invisible_watermark_available, is_torch_available, is_transformers_available
+
+
+@dataclass
+class StableDiffusionXLPipelineOutput(BaseOutput):
+ """
+ Output class for Stable Diffusion pipelines.
+
+ Args:
+        images (`List[PIL.Image.Image]` or `np.ndarray`):
+            List of denoised PIL images of length `batch_size` or a numpy array of shape `(batch_size, height, width,
+            num_channels)`. PIL images or numpy arrays represent the denoised images of the diffusion pipeline.
+ """
+
+ images: Union[List[PIL.Image.Image], np.ndarray]
+
+
+if is_transformers_available() and is_torch_available() and is_invisible_watermark_available():
+ from .pipeline_stable_diffusion_xl import StableDiffusionXLPipeline
+ from .pipeline_stable_diffusion_xl_img2img import StableDiffusionXLImg2ImgPipeline
+ from .pipeline_stable_diffusion_xl_inpaint import StableDiffusionXLInpaintPipeline
diff --git a/diffusers/pipelines/stable_diffusion_xl/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/stable_diffusion_xl/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9698b2cfae0d63542a8fef997eb18d3b5249338d
Binary files /dev/null and b/diffusers/pipelines/stable_diffusion_xl/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/pipelines/stable_diffusion_xl/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/stable_diffusion_xl/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5042985a37ca3373dc63354236b6da7b89a2ba6a
Binary files /dev/null and b/diffusers/pipelines/stable_diffusion_xl/__pycache__/__init__.cpython-38.pyc differ
diff --git a/diffusers/pipelines/stable_diffusion_xl/__pycache__/pipeline_stable_diffusion_xl.cpython-310.pyc b/diffusers/pipelines/stable_diffusion_xl/__pycache__/pipeline_stable_diffusion_xl.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..266ba7d3383c96e492cc106e9c94cb16c57941f7
Binary files /dev/null and b/diffusers/pipelines/stable_diffusion_xl/__pycache__/pipeline_stable_diffusion_xl.cpython-310.pyc differ
diff --git a/diffusers/pipelines/stable_diffusion_xl/__pycache__/pipeline_stable_diffusion_xl.cpython-38.pyc b/diffusers/pipelines/stable_diffusion_xl/__pycache__/pipeline_stable_diffusion_xl.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9ff60e5db97d24f42e27bcbdd97a2e4f30d2b182
Binary files /dev/null and b/diffusers/pipelines/stable_diffusion_xl/__pycache__/pipeline_stable_diffusion_xl.cpython-38.pyc differ
diff --git a/diffusers/pipelines/stable_diffusion_xl/__pycache__/pipeline_stable_diffusion_xl_img2img.cpython-310.pyc b/diffusers/pipelines/stable_diffusion_xl/__pycache__/pipeline_stable_diffusion_xl_img2img.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9cd0434f18b30d15a9e385c392b8fffcfca576a6
Binary files /dev/null and b/diffusers/pipelines/stable_diffusion_xl/__pycache__/pipeline_stable_diffusion_xl_img2img.cpython-310.pyc differ
diff --git a/diffusers/pipelines/stable_diffusion_xl/__pycache__/pipeline_stable_diffusion_xl_img2img.cpython-38.pyc b/diffusers/pipelines/stable_diffusion_xl/__pycache__/pipeline_stable_diffusion_xl_img2img.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..96fde8de78d8ef0add27c1edbc060db10b8e6868
Binary files /dev/null and b/diffusers/pipelines/stable_diffusion_xl/__pycache__/pipeline_stable_diffusion_xl_img2img.cpython-38.pyc differ
diff --git a/diffusers/pipelines/stable_diffusion_xl/__pycache__/pipeline_stable_diffusion_xl_inpaint.cpython-310.pyc b/diffusers/pipelines/stable_diffusion_xl/__pycache__/pipeline_stable_diffusion_xl_inpaint.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5b07704f35a52279818d9efcbce001193a7aafaa
Binary files /dev/null and b/diffusers/pipelines/stable_diffusion_xl/__pycache__/pipeline_stable_diffusion_xl_inpaint.cpython-310.pyc differ
diff --git a/diffusers/pipelines/stable_diffusion_xl/__pycache__/pipeline_stable_diffusion_xl_inpaint.cpython-38.pyc b/diffusers/pipelines/stable_diffusion_xl/__pycache__/pipeline_stable_diffusion_xl_inpaint.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c12b4e004ff5ae13ee4a8aaace69344802173d4d
Binary files /dev/null and b/diffusers/pipelines/stable_diffusion_xl/__pycache__/pipeline_stable_diffusion_xl_inpaint.cpython-38.pyc differ
diff --git a/diffusers/pipelines/stable_diffusion_xl/__pycache__/watermark.cpython-310.pyc b/diffusers/pipelines/stable_diffusion_xl/__pycache__/watermark.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..eaaafb3c3656f15921ce6e777fc38418b4772f4b
Binary files /dev/null and b/diffusers/pipelines/stable_diffusion_xl/__pycache__/watermark.cpython-310.pyc differ
diff --git a/diffusers/pipelines/stable_diffusion_xl/__pycache__/watermark.cpython-38.pyc b/diffusers/pipelines/stable_diffusion_xl/__pycache__/watermark.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7019841445a23b7aa7aac1de2094733e68ec77eb
Binary files /dev/null and b/diffusers/pipelines/stable_diffusion_xl/__pycache__/watermark.cpython-38.pyc differ
diff --git a/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
new file mode 100644
index 0000000000000000000000000000000000000000..9863c663910f99a51d4a635ed50a241bce31f5d0
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
@@ -0,0 +1,811 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
+
+from ...image_processor import VaeImageProcessor
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...models.attention_processor import (
+ AttnProcessor2_0,
+ LoRAAttnProcessor2_0,
+ LoRAXFormersAttnProcessor,
+ XFormersAttnProcessor,
+)
+from ...schedulers import KarrasDiffusionSchedulers
+from ...utils import (
+ is_accelerate_available,
+ is_accelerate_version,
+ logging,
+ randn_tensor,
+ replace_example_docstring,
+)
+from ..pipeline_utils import DiffusionPipeline
+from . import StableDiffusionXLPipelineOutput
+from .watermark import StableDiffusionXLWatermarker
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import StableDiffusionXLPipeline
+
+ >>> pipe = StableDiffusionXLPipeline.from_pretrained(
+ ... "stabilityai/stable-diffusion-xl-base-0.9", torch_dtype=torch.float16
+ ... )
+ >>> pipe = pipe.to("cuda")
+
+ >>> prompt = "a photo of an astronaut riding a horse on mars"
+ >>> image = pipe(prompt).images[0]
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ """
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
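+# Usage note (a sketch; the exact call site sits further down this file and is not shown in this hunk):
+# `rescale_noise_cfg` is typically applied to the combined prediction after classifier-free guidance, e.g.
+#   noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+# pulling the result back toward the per-sample std of the text-conditioned branch to reduce overexposure.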
+
+
+class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion XL.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ In addition the pipeline inherits the following loading methods:
+ - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+ - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+ - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
+
+ as well as the following saving methods:
+ - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion XL uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+        text_encoder_2 ([`CLIPTextModelWithProjection`]):
+ Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
+ specifically the
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
+ variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`CLIPTokenizer`):
+ Second Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ text_encoder_2: CLIPTextModelWithProjection,
+ tokenizer: CLIPTokenizer,
+ tokenizer_2: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ force_zeros_for_empty_prompt: bool = True,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ unet=unet,
+ scheduler=scheduler,
+ )
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.default_sample_size = self.unet.config.sample_size
+
+ self.watermark = StableDiffusionXLWatermarker()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
+ steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
+ several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
+ """
+ self.vae.enable_tiling()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
+        method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than with
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ model_sequence = (
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+ )
+ model_sequence.extend([self.unet, self.vae])
+
+ hook = None
+ for cpu_offloaded_model in model_sequence:
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
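+        # Hedged usage note: calling `pipe.enable_model_cpu_offload()` right after `from_pretrained`
+        # (instead of `pipe.to("cuda")`) keeps only the sub-model currently executing on the GPU.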
+
+ def encode_prompt(
+ self,
+ prompt,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ do_classifier_free_guidance: bool = True,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # Define tokenizers and text encoders
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+ text_encoders = (
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+ )
+
+ if prompt_embeds is None:
+            # textual inversion: process multi-vector tokens if necessary
+ prompt_embeds_list = []
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = text_encoder(
+ text_input_ids.to(device),
+ output_hidden_states=True,
+ )
+
+                # We are always interested only in the pooled output of the final text encoder
+ pooled_prompt_embeds = prompt_embeds[0]
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+
+ # get unconditional embeddings for classifier free guidance
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ uncond_tokens: List[str]
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ negative_prompt_embeds_list = []
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
+                # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ negative_prompt_embeds = text_encoder(
+ uncond_input.input_ids.to(device),
+ output_hidden_states=True,
+ )
+                # We are always interested only in the pooled output of the final text encoder
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(
+ batch_size * num_images_per_prompt, seq_len, -1
+ )
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
+
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+
+ bs_embed = pooled_prompt_embeds.shape[0]
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+
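+        # Shapes at this point (a sketch; exact dims depend on the loaded text encoders):
+        #   prompt_embeds / negative_prompt_embeds: (batch * num_images_per_prompt, seq_len, dim_1 + dim_2),
+        #     i.e. the penultimate hidden states of both text encoders concatenated on the feature axis.
+        #   pooled_prompt_embeds / negative_pooled_prompt_embeds: (batch * num_images_per_prompt, projection_dim),
+        #     taken from the last text encoder iterated above (text_encoder_2) when computed here.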
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ negative_pooled_prompt_embeds=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
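+    # Sketch of SDXL's size/crop micro-conditioning: `original_size + crops_coords_top_left + target_size`
+    # is flattened into six integers and packed below into a tensor of shape (1, 6), which feeds the
+    # UNet's added time embedding alongside the pooled text embedding.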
+ def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+
+ passed_add_embed_dim = (
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
+ )
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
+
+ if expected_add_embed_dim != passed_add_embed_dim:
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
+ )
+
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+ return add_time_ids
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
+ def upcast_vae(self):
+ dtype = self.vae.dtype
+ self.vae.to(dtype=torch.float32)
+ use_torch_2_0_or_xformers = isinstance(
+ self.vae.decoder.mid_block.attentions[0].processor,
+ (
+ AttnProcessor2_0,
+ XFormersAttnProcessor,
+ LoRAXFormersAttnProcessor,
+ LoRAAttnProcessor2_0,
+ ),
+ )
+ # if xformers or torch_2_0 is used attention block does not need
+ # to be in float32 which can save lots of memory
+ if use_torch_2_0_or_xformers:
+ self.vae.post_quant_conv.to(dtype)
+ self.vae.decoder.conv_in.to(dtype)
+ self.vae.decoder.mid_block.to(dtype)
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ denoising_end: Optional[float] = None,
+ guidance_scale: float = 5.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ original_size: Optional[Tuple[int, int]] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ target_size: Optional[Tuple[int, int]] = None,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ denoising_end (`float`, *optional*):
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
+ completed before it is intentionally prematurely terminated. For instance, if denoising_end is set to
+ 0.7 and `num_inference_steps` is fixed at 50, the process will execute only 35 (i.e., 0.7 * 50)
+ denoising steps. As a result, the returned sample will still retain a substantial amount of noise. The
+ denoising_end parameter should ideally be utilized when this pipeline forms a part of a "Mixture of
+ Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
+ guidance_scale (`float`, *optional*, defaults to 5.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages images that are closely linked to the text `prompt`, usually at
+ the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`]; it is ignored for other schedulers.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16. of
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ The original image size used for SDXL's size conditioning (passed to the UNet via `add_time_ids`). If
+ `original_size` differs from `target_size`, the image will appear up- or downsampled.
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ The top-left crop coordinates used for SDXL's crop conditioning. `(0, 0)` corresponds to an uncropped,
+ well-centered image.
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ The target image size used for SDXL's size conditioning. For most cases this should be set to the
+ requested output resolution, i.e. `(height, width)`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+ # 0. Default height and width to unet
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ original_size = original_size or (height, width)
+ target_size = target_size or (height, width)
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Prepare added time ids & embeddings
+ add_text_embeds = pooled_prompt_embeds
+ add_time_ids = self._get_add_time_ids(
+ original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
+ )
+
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+ add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
+
+ prompt_embeds = prompt_embeds.to(device)
+ add_text_embeds = add_text_embeds.to(device)
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+
+ # 8. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ # 8.1 Apply denoising_end
+ if denoising_end is not None:
+ num_inference_steps = int(round(denoising_end * num_inference_steps))
+ timesteps = timesteps[: num_warmup_steps + self.scheduler.order * num_inference_steps]
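+ # e.g. with `num_inference_steps=50` and `denoising_end=0.7`, only the first 35 timesteps above are kept;
+ # the still-noisy latents can then be finished by a refiner pipeline with `denoising_start=0.7`.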
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
+ self.upcast_vae()
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ else:
+ image = latents
+ return StableDiffusionXLPipelineOutput(images=image)
+
+ image = self.watermark.apply_watermark(image)
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image,)
+
+ return StableDiffusionXLPipelineOutput(images=image)
diff --git a/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
new file mode 100644
index 0000000000000000000000000000000000000000..a106ee7ac5f632aac8cddd93b31f8be30a767b29
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
@@ -0,0 +1,943 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL.Image
+import torch
+from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
+
+from ...image_processor import VaeImageProcessor
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...models.attention_processor import (
+ AttnProcessor2_0,
+ LoRAAttnProcessor2_0,
+ LoRAXFormersAttnProcessor,
+ XFormersAttnProcessor,
+)
+from ...schedulers import KarrasDiffusionSchedulers
+from ...utils import (
+ is_accelerate_available,
+ is_accelerate_version,
+ logging,
+ randn_tensor,
+ replace_example_docstring,
+)
+from ..pipeline_utils import DiffusionPipeline
+from . import StableDiffusionXLPipelineOutput
+from .watermark import StableDiffusionXLWatermarker
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import StableDiffusionXLImg2ImgPipeline
+ >>> from diffusers.utils import load_image
+
+ >>> pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
+ ... "stabilityai/stable-diffusion-xl-refiner-0.9", torch_dtype=torch.float16
+ ... )
+ >>> pipe = pipe.to("cuda")
+ >>> url = "https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/aa_xl/000000009.png"
+
+ >>> init_image = load_image(url).convert("RGB")
+ >>> prompt = "a photo of an astronaut riding a horse on mars"
+ >>> image = pipe(prompt, image=init_image).images[0]
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ """
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
+
+
+class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin):
+ r"""
+ Pipeline for text-guided image-to-image generation using Stable Diffusion XL.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ In addition the pipeline inherits the following loading methods:
+ - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+ - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+ - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
+
+ as well as the following saving methods:
+ - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion XL uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ text_encoder_2 ([`CLIPTextModelWithProjection`]):
+ Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
+ specifically the
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
+ variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`CLIPTokenizer`):
+ Second Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ """
+ _optional_components = ["tokenizer", "text_encoder"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ text_encoder_2: CLIPTextModelWithProjection,
+ tokenizer: CLIPTokenizer,
+ tokenizer_2: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ requires_aesthetics_score: bool = False,
+ force_zeros_for_empty_prompt: bool = True,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ unet=unet,
+ scheduler=scheduler,
+ )
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
+ self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+ self.watermark = StableDiffusionXLWatermarker()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
+ steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
+ several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
+ """
+ self.vae.enable_tiling()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ model_sequence = (
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+ )
+ model_sequence.extend([self.unet, self.vae])
+
+ hook = None
+ for cpu_offloaded_model in model_sequence:
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ do_classifier_free_guidance: bool = True,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # Define tokenizers and text encoders
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+ text_encoders = (
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+ )
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ prompt_embeds_list = []
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = text_encoder(
+ text_input_ids.to(device),
+ output_hidden_states=True,
+ )
+
+ # We only ever use the pooled output of the final text encoder
+ pooled_prompt_embeds = prompt_embeds[0]
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+
+ # get unconditional embeddings for classifier free guidance
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ uncond_tokens: List[str]
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ negative_prompt_embeds_list = []
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ negative_prompt_embeds = text_encoder(
+ uncond_input.input_ids.to(device),
+ output_hidden_states=True,
+ )
+ # We only ever use the pooled output of the final text encoder
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(
+ batch_size * num_images_per_prompt, seq_len, -1
+ )
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
+
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+
+ bs_embed = pooled_prompt_embeds.shape[0]
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ strength,
+ num_inference_steps,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+ if num_inference_steps is None:
+ raise ValueError("`num_inference_steps` cannot be None.")
+ elif not isinstance(num_inference_steps, int) or num_inference_steps <= 0:
+ raise ValueError(
+ f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type"
+ f" {type(num_inference_steps)}."
+ )
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
+ # get the original timestep using init_timestep
+ if denoising_start is None:
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+ t_start = max(num_inference_steps - init_timestep, 0)
+ else:
+ t_start = int(round(denoising_start * num_inference_steps))
+
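+ # e.g. with `num_inference_steps=50` and `strength=0.3`, `t_start` is 35 and only the last 15 timesteps
+ # are run; when `denoising_start` is given (refiner usage), the start index is derived from it directly
+ # and `strength` is ignored.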
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+
+ return timesteps, num_inference_steps - t_start
+
+ def prepare_latents(
+ self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True
+ ):
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
+ raise ValueError(
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
+ )
+
+ # Offload text encoder if `enable_model_cpu_offload` was enabled
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.text_encoder_2.to("cpu")
+ torch.cuda.empty_cache()
+
+ image = image.to(device=device, dtype=dtype)
+
+ batch_size = batch_size * num_images_per_prompt
+
+ if image.shape[1] == 4:
+ init_latents = image
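+ # a 4-channel input is assumed to already be VAE latents (e.g. the `output_type="latent"` result of the
+ # base pipeline), so VAE encoding is skipped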
+
+ else:
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ if self.vae.config.force_upcast:
+ image = image.float()
+ self.vae.to(dtype=torch.float32)
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ elif isinstance(generator, list):
+ init_latents = [
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
+ ]
+ init_latents = torch.cat(init_latents, dim=0)
+ else:
+ init_latents = self.vae.encode(image).latent_dist.sample(generator)
+
+ if self.vae.config.force_upcast:
+ self.vae.to(dtype)
+
+ init_latents = init_latents.to(dtype)
+ init_latents = self.vae.config.scaling_factor * init_latents
+
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ additional_image_per_prompt = batch_size // init_latents.shape[0]
+ init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ init_latents = torch.cat([init_latents], dim=0)
+
+ if add_noise:
+ shape = init_latents.shape
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ # get latents
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+
+ latents = init_latents
+
+ return latents
+
+ def _get_add_time_ids(
+ self, original_size, crops_coords_top_left, target_size, aesthetic_score, negative_aesthetic_score, dtype
+ ):
+ if self.config.requires_aesthetics_score:
+ add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
+ add_neg_time_ids = list(original_size + crops_coords_top_left + (negative_aesthetic_score,))
+ else:
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+ add_neg_time_ids = list(original_size + crops_coords_top_left + target_size)
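+ # with `requires_aesthetics_score=True` (e.g. the SDXL refiner), the size conditioning ends in an
+ # aesthetic score instead of `target_size`, giving 5 entries instead of 6; the checks below catch a
+ # mismatch with the UNet's `add_embedding` input width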
+
+ passed_add_embed_dim = (
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
+ )
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
+
+ if (
+ expected_add_embed_dim > passed_add_embed_dim
+ and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim
+ ):
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model."
+ )
+ elif (
+ expected_add_embed_dim < passed_add_embed_dim
+ and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim
+ ):
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model."
+ )
+ elif expected_add_embed_dim != passed_add_embed_dim:
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
+ )
+
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+ add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)
+
+ return add_time_ids, add_neg_time_ids
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
+ def upcast_vae(self):
+ dtype = self.vae.dtype
+ self.vae.to(dtype=torch.float32)
+ use_torch_2_0_or_xformers = isinstance(
+ self.vae.decoder.mid_block.attentions[0].processor,
+ (
+ AttnProcessor2_0,
+ XFormersAttnProcessor,
+ LoRAXFormersAttnProcessor,
+ LoRAAttnProcessor2_0,
+ ),
+ )
+ # if xformers or torch_2_0 is used attention block does not need
+ # to be in float32 which can save lots of memory
+ if use_torch_2_0_or_xformers:
+ self.vae.post_quant_conv.to(dtype)
+ self.vae.decoder.conv_in.to(dtype)
+ self.vae.decoder.mid_block.to(dtype)
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ image: Union[
+ torch.FloatTensor,
+ PIL.Image.Image,
+ np.ndarray,
+ List[torch.FloatTensor],
+ List[PIL.Image.Image],
+ List[np.ndarray],
+ ] = None,
+ strength: float = 0.3,
+ num_inference_steps: int = 50,
+ denoising_start: Optional[float] = None,
+ denoising_end: Optional[float] = None,
+ guidance_scale: float = 5.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ original_size: Tuple[int, int] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ target_size: Tuple[int, int] = None,
+ aesthetic_score: float = 6.0,
+ negative_aesthetic_score: float = 2.5,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ image (`torch.FloatTensor` or `PIL.Image.Image` or `np.ndarray` or `List[torch.FloatTensor]` or `List[PIL.Image.Image]` or `List[np.ndarray]`):
+ The image(s) to modify with the pipeline.
+ strength (`float`, *optional*, defaults to 0.3):
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
+ will be used as a starting point, adding more noise to it the larger the `strength`. The number of
+ denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
+ be maximum and the denoising process will run for the full number of iterations specified in
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ denoising_start (`float`, *optional*):
+ When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be
+ bypassed before it is initiated. For example, if `denoising_start` is set to 0.7 and
+ num_inference_steps is fixed at 50, the process will begin only from the 35th (i.e., 0.7 * 50)
+ denoising step. Consequently, the initial part of the denoising process is skipped and it is assumed
+ that the passed `image` is a partly denoised image. The `denoising_start` parameter is particularly
+ beneficial when this pipeline is integrated into a "Mixture of Denoisers" multi-pipeline setup, as
+ detailed in [**Refining the Image
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
+ denoising_end (`float`, *optional*):
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
+ completed before it is intentionally prematurely terminated. For instance, if denoising_end is set to
+ 0.7 and `num_inference_steps` is fixed at 50, the process will execute only 35 (i.e., 0.7 * 50)
+ denoising steps. As a result, the returned sample will still retain a substantial amount of noise (ca.
+ 30%) and should be denoised by a successor pipeline that has `denoising_start` set to 0.7 so that it
+ only denoises the final 30%. The denoising_end parameter should ideally be utilized when this pipeline
+ forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
+ guidance_scale (`float`, *optional*, defaults to 5.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages images that are closely linked to the text `prompt`, usually at
+ the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`]; it is ignored for other schedulers.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16. of
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ The original image size used for SDXL's size conditioning (passed to the UNet via `add_time_ids`). If
+ `original_size` differs from `target_size`, the image will appear up- or downsampled.
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ The top-left crop coordinates used for SDXL's crop conditioning. `(0, 0)` corresponds to an uncropped,
+ well-centered image.
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ The target image size used for SDXL's size conditioning. For most cases this should be set to the
+ requested output resolution, i.e. `(height, width)`.
+ aesthetic_score (`float`, *optional*, defaults to 6.0):
+ Aesthetic-score conditioning for the positive (text) branch. Only used when the pipeline is configured
+ with `requires_aesthetics_score=True` (e.g. the SDXL refiner).
+ negative_aesthetic_score (`float`, *optional*, defaults to 2.5):
+ Aesthetic-score conditioning for the negative (unconditional) branch. Only used when
+ `requires_aesthetics_score=True`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ strength,
+ num_inference_steps,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+
+ # 4. Preprocess image
+ image = self.image_processor.preprocess(image)
+
+ # 5. Prepare timesteps
+ original_num_steps = num_inference_steps # save for denoising_start/end later
+
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps, num_inference_steps = self.get_timesteps(
+ num_inference_steps, strength, device, denoising_start=denoising_start
+ )
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
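+ # when `denoising_start` is set, `image` is assumed to already be partially denoised latents handed over
+ # from a preceding pipeline, so no fresh noise is added before resuming the schedule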
+ add_noise = True if denoising_start is None else False
+ # 6. Prepare latent variables
+ latents = self.prepare_latents(
+ image,
+ latent_timestep,
+ batch_size,
+ num_images_per_prompt,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ add_noise,
+ )
+ # 7. Prepare extra step kwargs.
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ height, width = latents.shape[-2:]
+ height = height * self.vae_scale_factor
+ width = width * self.vae_scale_factor
+
+ original_size = original_size or (height, width)
+ target_size = target_size or (height, width)
+
+ # 8. Prepare added time ids & embeddings
+ add_text_embeds = pooled_prompt_embeds
+ add_time_ids, add_neg_time_ids = self._get_add_time_ids(
+ original_size,
+ crops_coords_top_left,
+ target_size,
+ aesthetic_score,
+ negative_aesthetic_score,
+ dtype=prompt_embeds.dtype,
+ )
+
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+ add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
+
+ prompt_embeds = prompt_embeds.to(device)
+ add_text_embeds = add_text_embeds.to(device)
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+
+ # 9. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ # 9.1 Apply denoising_end
+ if denoising_end is not None and denoising_start is not None:
+ if denoising_start >= denoising_end:
+ raise ValueError(
+ f"`denoising_end`: {denoising_end} cannot be larger than `denoising_start`: {denoising_start}."
+ )
+
+ skipped_final_steps = int(round((1 - denoising_end) * original_num_steps))
+ num_inference_steps = num_inference_steps - skipped_final_steps
+ timesteps = timesteps[: num_warmup_steps + self.scheduler.order * num_inference_steps]
+ elif denoising_end is not None:
+ num_inference_steps = int(round(denoising_end * num_inference_steps))
+ timesteps = timesteps[: num_warmup_steps + self.scheduler.order * num_inference_steps]
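+ # e.g. when both are set with 50 original steps, `denoising_start=0.7` and `denoising_end=0.9`:
+ # `get_timesteps` leaves 15 steps and round((1 - 0.9) * 50) = 5 of them are dropped here, so this
+ # pipeline denoises the 10 steps covering the 0.7-0.9 portion of the schedule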
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
+ self.upcast_vae()
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ else:
+ image = latents
+ return StableDiffusionXLPipelineOutput(images=image)
+
+ image = self.watermark.apply_watermark(image)
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image,)
+
+ return StableDiffusionXLPipelineOutput(images=image)
diff --git a/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
new file mode 100644
index 0000000000000000000000000000000000000000..55f20660afc7f7556e469c1c193dcee1c763b0a3
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
@@ -0,0 +1,1209 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL.Image
+import torch
+from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
+
+from ...image_processor import VaeImageProcessor
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...models.attention_processor import (
+ AttnProcessor2_0,
+ LoRAAttnProcessor2_0,
+ LoRAXFormersAttnProcessor,
+ XFormersAttnProcessor,
+)
+from ...schedulers import KarrasDiffusionSchedulers
+from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from . import StableDiffusionXLPipelineOutput
+from .watermark import StableDiffusionXLWatermarker
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ """
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
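+
+ # Illustrative note (not part of the upstream code): a minimal sketch of how `rescale_noise_cfg`
+ # slots into a classifier-free guidance step, assuming `noise_pred_uncond` / `noise_pred_text`
+ # are the two halves of a CFG batch:
+ #
+ #     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+ #     if guidance_rescale > 0.0:
+ #         noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)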
+
+
+def mask_pil_to_torch(mask, height, width):
+ # preprocess mask
+ if isinstance(mask, (PIL.Image.Image, np.ndarray)):
+ mask = [mask]
+
+ if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
+ mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]
+ mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
+ mask = mask.astype(np.float32) / 255.0
+ elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
+ mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
+
+ mask = torch.from_numpy(mask)
+ return mask
+
+
+def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False):
+ """
+ Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
+ converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
+ ``image`` and ``1`` for the ``mask``.
+
+ The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
+ binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
+
+ Args:
+ image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
+ It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
+ ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
+ mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
+ It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
+ ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
+
+
+ Raises:
+ ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
+ should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
+ TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
+ (or the other way around).
+
+ Returns:
+ tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
+ dimensions: ``batch x channels x height x width``.
+ """
+
+ # TODO (Yiyi): clean this up later
+ if image is None:
+ raise ValueError("`image` input cannot be undefined.")
+
+ if mask is None:
+ raise ValueError("`mask_image` input cannot be undefined.")
+
+ if isinstance(image, torch.Tensor):
+ if not isinstance(mask, torch.Tensor):
+ mask = mask_pil_to_torch(mask, height, width)
+
+ if image.ndim == 3:
+ image = image.unsqueeze(0)
+
+ # Batch and add channel dim for single mask
+ if mask.ndim == 2:
+ mask = mask.unsqueeze(0).unsqueeze(0)
+
+ # Batch single mask or add channel dim
+ if mask.ndim == 3:
+ # Single batched mask, no channel dim or single mask not batched but channel dim
+ if mask.shape[0] == 1:
+ mask = mask.unsqueeze(0)
+
+ # Batched masks no channel dim
+ else:
+ mask = mask.unsqueeze(1)
+
+ assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
+ # assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
+ assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
+
+ # Check image is in [-1, 1]
+ # if image.min() < -1 or image.max() > 1:
+ # raise ValueError("Image should be in [-1, 1] range")
+
+ # Check mask is in [0, 1]
+ if mask.min() < 0 or mask.max() > 1:
+ raise ValueError("Mask should be in [0, 1] range")
+
+ # Binarize mask
+ mask[mask < 0.5] = 0
+ mask[mask >= 0.5] = 1
+
+ # Image as float32
+ image = image.to(dtype=torch.float32)
+ elif isinstance(mask, torch.Tensor):
+ raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
+ else:
+ # preprocess image
+ if isinstance(image, (PIL.Image.Image, np.ndarray)):
+ image = [image]
+ if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
+ # resize all images w.r.t. the passed height and width
+ image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]
+ image = [np.array(i.convert("RGB"))[None, :] for i in image]
+ image = np.concatenate(image, axis=0)
+ elif isinstance(image, list) and isinstance(image[0], np.ndarray):
+ image = np.concatenate([i[None, :] for i in image], axis=0)
+
+ image = image.transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+
+ mask = mask_pil_to_torch(mask, height, width)
+ mask[mask < 0.5] = 0
+ mask[mask >= 0.5] = 1
+
+ if image.shape[1] == 4:
+ # images are already in latent space and thus can't be masked,
+ # so set masked_image to None. We assume that the checkpoint
+ # is not an inpainting checkpoint.
+ # TODO (Yiyi): clean this up later
+ masked_image = None
+ else:
+ masked_image = image * (mask < 0.5)
+
+ # n.b. ensure backwards compatibility as old function does not return image
+ if return_image:
+ return mask, masked_image, image
+
+ return mask, masked_image
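+
+ # Illustrative sketch (an assumed usage, not upstream documentation): given a PIL `init_image` and a
+ # PIL `mask_image`, the helper returns binarized tensors shaped `batch x channels x height x width`:
+ #
+ #     mask, masked_image, image = prepare_mask_and_masked_image(
+ #         init_image, mask_image, height=1024, width=1024, return_image=True
+ #     )
+ #     # mask: (1, 1, 1024, 1024) with values in {0, 1}
+ #     # image / masked_image: (1, 3, 1024, 1024) with values in [-1, 1]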
+
+
+class StableDiffusionXLInpaintPipeline(
+ DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+):
+ r"""
+ Pipeline for text-guided image inpainting using Stable Diffusion XL.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ In addition the pipeline inherits the following loading methods:
+ - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+ - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+ - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
+
+ as well as the following saving methods:
+ - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion XL uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ text_encoder_2 ([`CLIPTextModelWithProjection`]):
+ Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
+ specifically the
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
+ variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`CLIPTokenizer`):
+ Second Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ """
+ _optional_components = ["tokenizer", "text_encoder"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ text_encoder_2: CLIPTextModelWithProjection,
+ tokenizer: CLIPTokenizer,
+ tokenizer_2: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ requires_aesthetics_score: bool = False,
+ force_zeros_for_empty_prompt: bool = True,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ unet=unet,
+ scheduler=scheduler,
+ )
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
+ self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+ self.watermark = StableDiffusionXLWatermarker()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
+ steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
+ several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
+ """
+ self.vae.enable_tiling()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
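+
+ # Illustrative usage sketch (assumed caller code, not part of this file): the VAE memory savers can be
+ # toggled independently of the rest of the pipeline, e.g.
+ #
+ #     pipe.enable_vae_slicing()  # decode a batch one image at a time
+ #     pipe.enable_vae_tiling()   # decode/encode each image in overlapping tiles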
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.enable_model_cpu_offload
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ model_sequence = (
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+ )
+ model_sequence.extend([self.unet, self.vae])
+
+ hook = None
+ for cpu_offloaded_model in model_sequence:
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
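+
+ # Illustrative usage sketch (assumes `accelerate>=0.17.0` is installed and a single CUDA device;
+ # the checkpoint id is only an example):
+ #
+ #     pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
+ #         "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ #     )
+ #     pipe.enable_model_cpu_offload()  # sub-models hop to the GPU only for their forward pass
+ #     image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0]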
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ do_classifier_free_guidance: bool = True,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # Define tokenizers and text encoders
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+ text_encoders = (
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+ )
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ prompt_embeds_list = []
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = text_encoder(
+ text_input_ids.to(device),
+ output_hidden_states=True,
+ )
+
+ # We are always interested only in the pooled output of the final text encoder
+ pooled_prompt_embeds = prompt_embeds[0]
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+
+ # get unconditional embeddings for classifier free guidance
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ uncond_tokens: List[str]
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ negative_prompt_embeds_list = []
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ negative_prompt_embeds = text_encoder(
+ uncond_input.input_ids.to(device),
+ output_hidden_states=True,
+ )
+ # We are always interested only in the pooled output of the final text encoder
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(
+ batch_size * num_images_per_prompt, seq_len, -1
+ )
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
+
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+
+ bs_embed = pooled_prompt_embeds.shape[0]
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
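+
+ # Illustrative note (shapes are assumptions based on the default SDXL text encoders, CLIP ViT-L and
+ # OpenCLIP bigG): the per-encoder hidden states are concatenated along the feature axis, so for one
+ # prompt `prompt_embeds` is roughly (1, 77, 768 + 1280) = (1, 77, 2048), while the pooled embeddings
+ # come from the second encoder only, e.g. (1, 1280):
+ #
+ #     prompt_embeds, neg_embeds, pooled_embeds, neg_pooled_embeds = pipe.encode_prompt(
+ #         "a photo of a cat", device=device, num_images_per_prompt=1
+ #     )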
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
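+
+ # Illustrative note (assuming the stock diffusers schedulers): only arguments accepted by the
+ # scheduler's `step()` signature are forwarded, e.g.
+ #
+ #     extra_step_kwargs = pipe.prepare_extra_step_kwargs(generator, eta=0.0)
+ #     # DDIMScheduler          -> {"eta": 0.0, "generator": generator}
+ #     # EulerDiscreteScheduler -> {"generator": generator}  (its step() takes no `eta`)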
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ strength,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ image=None,
+ timestep=None,
+ is_strength_max=True,
+ add_noise=True,
+ return_noise=False,
+ return_image_latents=False,
+ ):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if (image is None or timestep is None) and not is_strength_max:
+ raise ValueError(
+ "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise."
+ "However, either the image or the noise timestep has not been provided."
+ )
+
+ if image.shape[1] == 4:
+ image_latents = image.to(device=device, dtype=dtype)
+ elif return_image_latents or (latents is None and not is_strength_max):
+ image = image.to(device=device, dtype=dtype)
+ image_latents = self._encode_vae_image(image=image, generator=generator)
+
+ if latents is None and add_noise:
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ # if strength is 1. then initialise the latents to noise, else initial to image + noise
+ latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
+ # if pure noise then scale the initial latents by the Scheduler's init sigma
+ latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
+ elif add_noise:
+ noise = latents.to(device)
+ latents = noise * self.scheduler.init_noise_sigma
+ else:
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = image_latents.to(device)
+
+ outputs = (latents,)
+
+ if return_noise:
+ outputs += (noise,)
+
+ if return_image_latents:
+ outputs += (image_latents,)
+
+ return outputs
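+
+ # Illustrative note: with `strength == 1.0` (and `add_noise=True`) the returned latents are pure noise
+ # scaled by `scheduler.init_noise_sigma`; with `strength < 1.0` they are
+ # `scheduler.add_noise(image_latents, noise, timestep)`, i.e. the encoded input image plus noise for
+ # the starting timestep selected by `get_timesteps`.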
+
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+ dtype = image.dtype
+ if self.vae.config.force_upcast:
+ image = image.float()
+ self.vae.to(dtype=torch.float32)
+
+ if isinstance(generator, list):
+ image_latents = [
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
+
+ if self.vae.config.force_upcast:
+ self.vae.to(dtype)
+
+ image_latents = image_latents.to(dtype)
+ image_latents = self.vae.config.scaling_factor * image_latents
+
+ return image_latents
+
+ def prepare_mask_latents(
+ self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
+ ):
+ # resize the mask to latents shape as we concatenate the mask to the latents
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+ # and half precision
+ mask = torch.nn.functional.interpolate(
+ mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
+ )
+ mask = mask.to(device=device, dtype=dtype)
+
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+ if mask.shape[0] < batch_size:
+ if not batch_size % mask.shape[0] == 0:
+ raise ValueError(
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+ " of masks that you pass is divisible by the total requested batch size."
+ )
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+
+ mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
+
+ masked_image_latents = None
+ if masked_image is not None:
+ masked_image = masked_image.to(device=device, dtype=dtype)
+ masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
+ if masked_image_latents.shape[0] < batch_size:
+ if not batch_size % masked_image_latents.shape[0] == 0:
+ raise ValueError(
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
+ )
+ masked_image_latents = masked_image_latents.repeat(
+ batch_size // masked_image_latents.shape[0], 1, 1, 1
+ )
+
+ masked_image_latents = (
+ torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
+ )
+
+ # aligning device to prevent device errors when concating it with the latent model input
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+
+ return mask, masked_image_latents
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
+ # get the original timestep using init_timestep
+ if denoising_start is None:
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+ t_start = max(num_inference_steps - init_timestep, 0)
+ else:
+ t_start = int(round(denoising_start * num_inference_steps))
+
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+
+ return timesteps, num_inference_steps - t_start
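+
+ # Worked example (illustrative): with `num_inference_steps=50`, `strength=0.6` and no `denoising_start`,
+ # `init_timestep = min(int(50 * 0.6), 50) = 30` and `t_start = 20`, so denoising runs over the last 30
+ # scheduler timesteps. With `denoising_start=0.7` the strength-based computation is bypassed and
+ # `t_start = int(round(0.7 * 50)) = 35`, leaving 15 steps.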
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids
+ def _get_add_time_ids(
+ self, original_size, crops_coords_top_left, target_size, aesthetic_score, negative_aesthetic_score, dtype
+ ):
+ if self.config.requires_aesthetics_score:
+ add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
+ add_neg_time_ids = list(original_size + crops_coords_top_left + (negative_aesthetic_score,))
+ else:
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+ add_neg_time_ids = list(original_size + crops_coords_top_left + target_size)
+
+ passed_add_embed_dim = (
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
+ )
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
+
+ if (
+ expected_add_embed_dim > passed_add_embed_dim
+ and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim
+ ):
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model."
+ )
+ elif (
+ expected_add_embed_dim < passed_add_embed_dim
+ and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim
+ ):
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model."
+ )
+ elif expected_add_embed_dim != passed_add_embed_dim:
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
+ )
+
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+ add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)
+
+ return add_time_ids, add_neg_time_ids
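+
+ # Worked example (illustrative): with `requires_aesthetics_score=True`, `original_size=(1024, 1024)`,
+ # `crops_coords_top_left=(0, 0)`, `aesthetic_score=6.0` and `negative_aesthetic_score=2.5`, the ids are
+ # `[1024, 1024, 0, 0, 6.0]` and `[1024, 1024, 0, 0, 2.5]` (5 entries each); the non-aesthetics branch
+ # appends `target_size` instead and yields 6 entries, e.g. `[1024, 1024, 0, 0, 1024, 1024]`.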
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
+ def upcast_vae(self):
+ dtype = self.vae.dtype
+ self.vae.to(dtype=torch.float32)
+ use_torch_2_0_or_xformers = isinstance(
+ self.vae.decoder.mid_block.attentions[0].processor,
+ (
+ AttnProcessor2_0,
+ XFormersAttnProcessor,
+ LoRAXFormersAttnProcessor,
+ LoRAAttnProcessor2_0,
+ ),
+ )
+ # if xformers or torch_2_0 is used attention block does not need
+ # to be in float32 which can save lots of memory
+ if use_torch_2_0_or_xformers:
+ self.vae.post_quant_conv.to(dtype)
+ self.vae.decoder.conv_in.to(dtype)
+ self.vae.decoder.mid_block.to(dtype)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+ mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ strength: float = 1.0,
+ num_inference_steps: int = 50,
+ denoising_start: Optional[float] = None,
+ denoising_end: Optional[float] = None,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ original_size: Tuple[int, int] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ target_size: Tuple[int, int] = None,
+ aesthetic_score: float = 6.0,
+ negative_aesthetic_score: float = 2.5,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ image (`PIL.Image.Image`):
+ `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
+ be masked out with `mask_image` and repainted according to `prompt`.
+ mask_image (`PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
+ repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
+ to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
+ instead of 3, so the expected shape would be `(B, H, W, 1)`.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ strength (`float`, *optional*, defaults to 1.):
+ Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be
+ between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the
+ `strength`. The number of denoising steps depends on the amount of noise initially added. When
+ `strength` is 1, added noise will be maximum and the denoising process will run for the full number of
+ iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked
+ portion of the reference `image`.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ denoising_start (`float`, *optional*):
+ When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be
+ bypassed before it is initiated. For example, if `denoising_start` is set to 0.7 and
+ num_inference_steps is fixed at 50, the process will begin only from the 35th (i.e., 0.7 * 50)
+ denoising step. Consequently, the initial part of the denoising process is skipped and it is assumed
+ that the passed `image` is a partly denoised image. The `denoising_start` parameter is particularly
+ beneficial when this pipeline is integrated into a "Mixture of Denoisers" multi-pipeline setup, as
+ detailed in [**Refining the Image
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
+ denoising_end (`float`, *optional*):
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
+ completed before it is intentionally prematurely terminated. For instance, if denoising_end is set to
+ 0.7 and `num_inference_steps` is fixed at 50, the process will execute only 35 (i.e., 0.7 * 50)
+ denoising steps. As a result, the returned sample will still retain a substantial amount of noise (ca.
+ 30%) and should be denoised by a successor pipeline that has `denoising_start` set to 0.7 so that it
+ only denoises the final 30%. The denoising_end parameter should ideally be utilized when this pipeline
+ forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+ Examples:
+
+ ```py
+ >>> import PIL
+ >>> import requests
+ >>> import torch
+ >>> from io import BytesIO
+
+ >>> from diffusers import StableDiffusionXLInpaintPipeline
+
+
+ >>> def download_image(url):
+ ... response = requests.get(url)
+ ... return PIL.Image.open(BytesIO(response.content)).convert("RGB")
+
+
+ >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+ >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+
+ >>> init_image = download_image(img_url).resize((1024, 1024))
+ >>> mask_image = download_image(mask_url).resize((1024, 1024))
+
+ >>> pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
+ ... "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ... )
+ >>> pipe = pipe.to("cuda")
+
+ >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
+ >>> image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
+ ```
+
+ Returns:
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ strength,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+
+ # 4. set timesteps
+ original_num_steps = num_inference_steps # save for denoising_start/end later
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps, num_inference_steps = self.get_timesteps(
+ num_inference_steps, strength, device, denoising_start=denoising_start
+ )
+ # check that number of inference steps is not < 1 - as this doesn't make sense
+ if num_inference_steps < 1:
+ raise ValueError(
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
+ f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+ )
+ # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+ # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
+ is_strength_max = strength == 1.0
+
+ # 5. Preprocess mask and image
+ mask, masked_image, init_image = prepare_mask_and_masked_image(
+ image, mask_image, height, width, return_image=True
+ )
+
+ # 6. Prepare latent variables
+ num_channels_latents = self.vae.config.latent_channels
+ num_channels_unet = self.unet.config.in_channels
+ return_image_latents = num_channels_unet == 4
+
+ add_noise = denoising_start is None
+ latents_outputs = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ image=init_image,
+ timestep=latent_timestep,
+ is_strength_max=is_strength_max,
+ add_noise=add_noise,
+ return_noise=True,
+ return_image_latents=return_image_latents,
+ )
+
+ if return_image_latents:
+ latents, noise, image_latents = latents_outputs
+ else:
+ latents, noise = latents_outputs
+
+ # 7. Prepare mask latent variables
+ mask, masked_image_latents = self.prepare_mask_latents(
+ mask,
+ masked_image,
+ batch_size * num_images_per_prompt,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ do_classifier_free_guidance,
+ )
+
+ # 8. Check that sizes of mask, masked image and latents match
+ if num_channels_unet == 9:
+ # default case for runwayml/stable-diffusion-inpainting
+ num_channels_mask = mask.shape[1]
+ num_channels_masked_image = masked_image_latents.shape[1]
+ if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
+ raise ValueError(
+ f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
+ f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
+ f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ " `pipeline.unet` or your `mask_image` or `image` input."
+ )
+ elif num_channels_unet != 4:
+ raise ValueError(
+ f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
+ )
+ # 8.1 Prepare extra step kwargs.
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 9. Recompute height and width from the latents for the added time ids. TODO: this logic should ideally be moved out of the pipeline
+ height, width = latents.shape[-2:]
+ height = height * self.vae_scale_factor
+ width = width * self.vae_scale_factor
+
+ original_size = original_size or (height, width)
+ target_size = target_size or (height, width)
+
+ # 10. Prepare added time ids & embeddings
+ add_text_embeds = pooled_prompt_embeds
+ add_time_ids, add_neg_time_ids = self._get_add_time_ids(
+ original_size,
+ crops_coords_top_left,
+ target_size,
+ aesthetic_score,
+ negative_aesthetic_score,
+ dtype=prompt_embeds.dtype,
+ )
+
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+ add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
+
+ prompt_embeds = prompt_embeds.to(device)
+ add_text_embeds = add_text_embeds.to(device)
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+
+ # 11. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ if denoising_end is not None and denoising_start is not None:
+ if denoising_start >= denoising_end:
+ raise ValueError(
+ f"`denoising_end`: {denoising_end} cannot be larger than `denoising_start`: {denoising_start}."
+ )
+
+ skipped_final_steps = int(round((1 - denoising_end) * original_num_steps))
+ num_inference_steps = num_inference_steps - skipped_final_steps
+ timesteps = timesteps[: num_warmup_steps + self.scheduler.order * num_inference_steps]
+ elif denoising_end is not None:
+ num_inference_steps = int(round(denoising_end * num_inference_steps))
+ timesteps = timesteps[: num_warmup_steps + self.scheduler.order * num_inference_steps]
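+
+ # Worked example (illustrative): with 50 scheduler steps, `strength=1.0` and `denoising_end=0.8`,
+ # the loop below runs int(round(0.8 * 50)) = 40 steps and returns still-noisy latents that a
+ # follow-up pipeline can finish with `denoising_start=0.8`.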
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+ # concat latents, mask, masked_image_latents in the channel dimension
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ if num_channels_unet == 9:
+ latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
+
+ # predict the noise residual
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ if num_channels_unet == 4:
+ init_latents_proper = image_latents[:1]
+ init_mask = mask[:1]
+
+ if i < len(timesteps) - 1:
+ noise_timestep = timesteps[i + 1]
+ init_latents_proper = self.scheduler.add_noise(
+ init_latents_proper, noise, torch.tensor([noise_timestep])
+ )
+
+ latents = (1 - init_mask) * init_latents_proper + init_mask * latents
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
+ self.upcast_vae()
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ else:
+ return StableDiffusionXLPipelineOutput(images=latents)
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image,)
+
+ return StableDiffusionXLPipelineOutput(images=image)
diff --git a/diffusers/pipelines/stable_diffusion_xl/watermark.py b/diffusers/pipelines/stable_diffusion_xl/watermark.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc6c9bf649b161fbc1ae7e59b3de6ba5c22884fa
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion_xl/watermark.py
@@ -0,0 +1,31 @@
+import numpy as np
+import torch
+from imwatermark import WatermarkEncoder
+
+
+# Copied from https://github.com/Stability-AI/generative-models/blob/613af104c6b85184091d42d374fef420eddb356d/scripts/demo/streamlit_helpers.py#L66
+WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
+# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
+WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
+
+
+class StableDiffusionXLWatermarker:
+ def __init__(self):
+ self.watermark = WATERMARK_BITS
+ self.encoder = WatermarkEncoder()
+
+ self.encoder.set_watermark("bits", self.watermark)
+
+ def apply_watermark(self, images: torch.FloatTensor):
+ # can't encode images that are smaller than 256
+ if images.shape[-1] < 256:
+ return images
+
+ images = (255 * (images / 2 + 0.5)).cpu().permute(0, 2, 3, 1).float().numpy()
+
+ images = [self.encoder.encode(image, "dwtDct") for image in images]
+
+ images = torch.from_numpy(np.array(images)).permute(0, 3, 1, 2)
+
+ images = torch.clamp(2 * (images / 255 - 0.5), min=-1.0, max=1.0)
+ return images
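+
+ # Illustrative note (assumed caller code, not part of this file): `apply_watermark` expects image
+ # tensors in [-1, 1] shaped `batch x 3 x height x width` and returns them in the same range with the
+ # invisible DWT-DCT watermark embedded; inputs narrower than 256 px are returned unchanged:
+ #
+ #     images = StableDiffusionXLWatermarker().apply_watermark(images)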
diff --git a/diffusers/pipelines/stochastic_karras_ve/__init__.py b/diffusers/pipelines/stochastic_karras_ve/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a63c1d24afb2c4f36b0e284f0985a3ff508f4c7
--- /dev/null
+++ b/diffusers/pipelines/stochastic_karras_ve/__init__.py
@@ -0,0 +1 @@
+from .pipeline_stochastic_karras_ve import KarrasVePipeline
diff --git a/diffusers/pipelines/stochastic_karras_ve/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/stochastic_karras_ve/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b45bf0b5a92059d18ea3e387ad086adcb11530a1
Binary files /dev/null and b/diffusers/pipelines/stochastic_karras_ve/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/pipelines/stochastic_karras_ve/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/stochastic_karras_ve/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..41d7a3fa3ac852e688e83f93b802fd473ce8c5dd
Binary files /dev/null and b/diffusers/pipelines/stochastic_karras_ve/__pycache__/__init__.cpython-38.pyc differ
diff --git a/diffusers/pipelines/stochastic_karras_ve/__pycache__/pipeline_stochastic_karras_ve.cpython-310.pyc b/diffusers/pipelines/stochastic_karras_ve/__pycache__/pipeline_stochastic_karras_ve.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c58ff51d48526a06030863fa616796649355c1f0
Binary files /dev/null and b/diffusers/pipelines/stochastic_karras_ve/__pycache__/pipeline_stochastic_karras_ve.cpython-310.pyc differ
diff --git a/diffusers/pipelines/stochastic_karras_ve/__pycache__/pipeline_stochastic_karras_ve.cpython-38.pyc b/diffusers/pipelines/stochastic_karras_ve/__pycache__/pipeline_stochastic_karras_ve.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9a757369efbbcf9927b89641a0a920c6ddd7f506
Binary files /dev/null and b/diffusers/pipelines/stochastic_karras_ve/__pycache__/pipeline_stochastic_karras_ve.cpython-38.pyc differ
diff --git a/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py b/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e0ab15eb9758c42116cf67aab6d9d8a5a6dad7d
--- /dev/null
+++ b/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py
@@ -0,0 +1,128 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Tuple, Union
+
+import torch
+
+from ...models import UNet2DModel
+from ...schedulers import KarrasVeScheduler
+from ...utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+
+
+class KarrasVePipeline(DiffusionPipeline):
+ r"""
+ Stochastic sampling from Karras et al. [1] tailored to the Variance Exploding (VE) models [2]. Use Algorithm 2 and
+ the VE column of Table 1 from [1] for reference.
+
+ [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models."
+ https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic
+ differential equations." https://arxiv.org/abs/2011.13456
+
+ Parameters:
+ unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image.
+ scheduler ([`KarrasVeScheduler`]):
+ Scheduler for the diffusion process to be used in combination with `unet` to denoise the encoded image.
+ """
+
+ # add type hints for linting
+ unet: UNet2DModel
+ scheduler: KarrasVeScheduler
+
+ def __init__(self, unet: UNet2DModel, scheduler: KarrasVeScheduler):
+ super().__init__()
+ self.register_modules(unet=unet, scheduler=scheduler)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ batch_size: int = 1,
+ num_inference_steps: int = 50,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ **kwargs,
+ ) -> Union[Tuple, ImagePipelineOutput]:
+ r"""
+ Args:
+ batch_size (`int`, *optional*, defaults to 1):
+ The number of images to generate.
+ generator (`torch.Generator`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is
+            True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+
+ img_size = self.unet.config.sample_size
+ shape = (batch_size, 3, img_size, img_size)
+
+ model = self.unet
+
+ # sample x_0 ~ N(0, sigma_0^2 * I)
+ sample = randn_tensor(shape, generator=generator, device=self.device) * self.scheduler.init_noise_sigma
+
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ for t in self.progress_bar(self.scheduler.timesteps):
+ # here sigma_t == t_i from the paper
+ sigma = self.scheduler.schedule[t]
+ sigma_prev = self.scheduler.schedule[t - 1] if t > 0 else 0
+
+ # 1. Select temporarily increased noise level sigma_hat
+ # 2. Add new noise to move from sample_i to sample_hat
+ sample_hat, sigma_hat = self.scheduler.add_noise_to_input(sample, sigma, generator=generator)
+
+ # 3. Predict the noise residual given the noise magnitude `sigma_hat`
+ # The model inputs and output are adjusted by following eq. (213) in [1].
+ model_output = (sigma_hat / 2) * model((sample_hat + 1) / 2, sigma_hat / 2).sample
+
+ # 4. Evaluate dx/dt at sigma_hat
+ # 5. Take Euler step from sigma to sigma_prev
+ step_output = self.scheduler.step(model_output, sigma_hat, sigma_prev, sample_hat)
+
+ if sigma_prev != 0:
+ # 6. Apply 2nd order correction
+ # The model inputs and output are adjusted by following eq. (213) in [1].
+ model_output = (sigma_prev / 2) * model((step_output.prev_sample + 1) / 2, sigma_prev / 2).sample
+ step_output = self.scheduler.step_correct(
+ model_output,
+ sigma_hat,
+ sigma_prev,
+ sample_hat,
+ step_output.prev_sample,
+ step_output["derivative"],
+ )
+ sample = step_output.prev_sample
+
+ sample = (sample / 2 + 0.5).clamp(0, 1)
+ image = sample.cpu().permute(0, 2, 3, 1).numpy()
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
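The pipeline above only needs an unconditional `UNet2DModel` and a `KarrasVeScheduler`. A minimal usage sketch, assuming `KarrasVePipeline` is re-exported from the top-level `diffusers` package and that a compatible unconditional VE checkpoint is available; the checkpoint id below is illustrative and not part of this diff:

```py
import torch
from diffusers import KarrasVePipeline, KarrasVeScheduler, UNet2DModel

# "google/ncsnpp-celebahq-256" is an illustrative unconditional score-based checkpoint id;
# loading details may differ depending on the repository layout.
unet = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256")
scheduler = KarrasVeScheduler()

pipe = KarrasVePipeline(unet=unet, scheduler=scheduler).to("cuda")

# Fixing the generator makes the stochastic sampling (Algorithm 2) reproducible.
generator = torch.Generator(device="cuda").manual_seed(0)
image = pipe(batch_size=1, num_inference_steps=50, generator=generator).images[0]
image.save("karras_ve_sample.png")
```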
diff --git a/diffusers/pipelines/t2i_adapter/__init__.py b/diffusers/pipelines/t2i_adapter/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4de661dbefab846cc72fa576675aad6cab1d134
--- /dev/null
+++ b/diffusers/pipelines/t2i_adapter/__init__.py
@@ -0,0 +1,14 @@
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
+else:
+ from .pipeline_stable_diffusion_adapter import StableDiffusionAdapterPipeline
diff --git a/diffusers/pipelines/t2i_adapter/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/t2i_adapter/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..50233c532f5cd2f75ca4613ac7105f1cc8e95bb4
Binary files /dev/null and b/diffusers/pipelines/t2i_adapter/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/pipelines/t2i_adapter/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/t2i_adapter/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ff267f1d748af7fc5eb7fa7c83b89a53cf55e497
Binary files /dev/null and b/diffusers/pipelines/t2i_adapter/__pycache__/__init__.cpython-38.pyc differ
diff --git a/diffusers/pipelines/t2i_adapter/__pycache__/pipeline_stable_diffusion_adapter.cpython-310.pyc b/diffusers/pipelines/t2i_adapter/__pycache__/pipeline_stable_diffusion_adapter.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7b9047f536177ed3d9489970c134433453fd7c0b
Binary files /dev/null and b/diffusers/pipelines/t2i_adapter/__pycache__/pipeline_stable_diffusion_adapter.cpython-310.pyc differ
diff --git a/diffusers/pipelines/t2i_adapter/__pycache__/pipeline_stable_diffusion_adapter.cpython-38.pyc b/diffusers/pipelines/t2i_adapter/__pycache__/pipeline_stable_diffusion_adapter.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..614a87a38195c4b67b9847885e4012fc4d4d0049
Binary files /dev/null and b/diffusers/pipelines/t2i_adapter/__pycache__/pipeline_stable_diffusion_adapter.cpython-38.pyc differ
diff --git a/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py b/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..97c6fb99157f78ad83b426e5cba98855916859f4
--- /dev/null
+++ b/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py
@@ -0,0 +1,774 @@
+# Copyright 2023 TencentARC and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import warnings
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import PIL
+import torch
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+
+from ...image_processor import VaeImageProcessor
+from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, MultiAdapter, T2IAdapter, UNet2DConditionModel
+from ...schedulers import KarrasDiffusionSchedulers
+from ...utils import (
+ PIL_INTERPOLATION,
+ BaseOutput,
+ is_accelerate_available,
+ is_accelerate_version,
+ logging,
+ randn_tensor,
+ replace_example_docstring,
+)
+from ..pipeline_utils import DiffusionPipeline
+from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+
+
+@dataclass
+class StableDiffusionAdapterPipelineOutput(BaseOutput):
+ """
+ Args:
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+            num_channels)`. PIL images or a NumPy array representing the denoised images of the diffusion pipeline.
+ nsfw_content_detected (`List[bool]`)
+ List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, or `None` if safety checking could not be performed.
+ """
+
+ images: Union[List[PIL.Image.Image], np.ndarray]
+ nsfw_content_detected: Optional[List[bool]]
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> from PIL import Image
+ >>> from diffusers.utils import load_image
+ >>> import torch
+ >>> from diffusers import StableDiffusionAdapterPipeline, T2IAdapter
+
+ >>> image = load_image(
+ ... "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/color_ref.png"
+ ... )
+
+ >>> color_palette = image.resize((8, 8))
+ >>> color_palette = color_palette.resize((512, 512), resample=Image.Resampling.NEAREST)
+
+ >>> adapter = T2IAdapter.from_pretrained("TencentARC/t2iadapter_color_sd14v1", torch_dtype=torch.float16)
+ >>> pipe = StableDiffusionAdapterPipeline.from_pretrained(
+ ... "CompVis/stable-diffusion-v1-4",
+ ... adapter=adapter,
+ ... torch_dtype=torch.float16,
+ ... )
+
+ >>> pipe.to("cuda")
+
+ >>> out_image = pipe(
+ ... "At night, glowing cubes in front of the beach",
+ ... image=color_palette,
+ ... ).images[0]
+ ```
+"""
+
+
+def _preprocess_adapter_image(image, height, width):
+ if isinstance(image, torch.Tensor):
+ return image
+ elif isinstance(image, PIL.Image.Image):
+ image = [image]
+
+ if isinstance(image[0], PIL.Image.Image):
+ image = [np.array(i.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])) for i in image]
+ image = [
+ i[None, ..., None] if i.ndim == 2 else i[None, ...] for i in image
+ ] # expand [h, w] or [h, w, c] to [b, h, w, c]
+ image = np.concatenate(image, axis=0)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image.transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image)
+ elif isinstance(image[0], torch.Tensor):
+ if image[0].ndim == 3:
+ image = torch.stack(image, dim=0)
+ elif image[0].ndim == 4:
+ image = torch.cat(image, dim=0)
+ else:
+ raise ValueError(
+                f"Invalid image tensor! Expecting an image tensor with 3 or 4 dimensions, but received: {image[0].ndim}"
+ )
+ return image
+
+
+class StableDiffusionAdapterPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion augmented with T2I-Adapter
+ https://arxiv.org/abs/2302.08453
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ adapter ([`T2IAdapter`] or [`MultiAdapter`] or `List[T2IAdapter]`):
+            Provides additional conditioning to the unet during the denoising process. If you pass multiple Adapters as
+            a list, the outputs from each Adapter are added together to create one combined additional conditioning.
+ adapter_weights (`List[float]`, *optional*, defaults to None):
+            List of floats representing the weight by which each adapter's output is multiplied before the outputs are
+            added together.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPFeatureExtractor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ adapter: Union[T2IAdapter, MultiAdapter, List[T2IAdapter]],
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPFeatureExtractor,
+ adapter_weights: Optional[List[float]] = None,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+                " that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered"
+                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+                " strongly recommend keeping the safety filter enabled in all public-facing circumstances, disabling"
+                " it only for use cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+                f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ if isinstance(adapter, (list, tuple)):
+ adapter = MultiAdapter(adapter, adapter_weights=adapter_weights)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ adapter=adapter,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
+ steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+        method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ hook = None
+ for cpu_offloaded_model in [self.text_encoder, self.adapter, self.unet, self.vae]:
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+ if self.safety_checker is not None:
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+            # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+            # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ def decode_latents(self, latents):
+ warnings.warn(
+ "The decode_latents method is deprecated and will be removed in a future version. Please"
+ " use VaeImageProcessor instead",
+ FutureWarning,
+ )
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ def _default_height_width(self, height, width, image):
+        # NOTE: It is possible that a list of images has different
+ # dimensions for each image, so just checking the first image
+ # is not _exactly_ correct, but it is simple.
+ while isinstance(image, list):
+ image = image[0]
+
+ if height is None:
+ if isinstance(image, PIL.Image.Image):
+ height = image.height
+ elif isinstance(image, torch.Tensor):
+ height = image.shape[-2]
+
+ # round down to nearest multiple of `self.adapter.total_downscale_factor`
+ height = (height // self.adapter.total_downscale_factor) * self.adapter.total_downscale_factor
+
+ if width is None:
+ if isinstance(image, PIL.Image.Image):
+ width = image.width
+ elif isinstance(image, torch.Tensor):
+ width = image.shape[-1]
+
+ # round down to nearest multiple of `self.adapter.total_downscale_factor`
+ width = (width // self.adapter.total_downscale_factor) * self.adapter.total_downscale_factor
+
+ return height, width
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ image: Union[torch.Tensor, PIL.Image.Image, List[PIL.Image.Image]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ adapter_conditioning_scale: Union[float, List[float]] = 1.0,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+                instead.
+ image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]` or `List[PIL.Image.Image]` or `List[List[PIL.Image.Image]]`):
+                The Adapter input condition. The Adapter uses this input condition to generate guidance for the UNet.
+                If the type is specified as `torch.FloatTensor`, it is passed to the Adapter as is. `PIL.Image.Image`
+                can also be accepted as an image. The control image is automatically resized to fit the output image.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages the model to generate images that are closely linked to the text
+                `prompt`, usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead.
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionAdapterPipelineOutput`] instead
+ of a plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
+ `self.processor` in
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+ adapter_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The outputs of the adapter are multiplied by `adapter_conditioning_scale` before they are added to the
+ residual in the original unet. If multiple adapters are specified in init, you can set the
+ corresponding scale as a list.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionAdapterPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionAdapterPipelineOutput`] if `return_dict` is True, otherwise a
+            `tuple`. When returning a tuple, the first element is a list with the generated images, and the second
+ element is a list of `bool`s denoting whether the corresponding generated image likely represents
+ "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
+ """
+ # 0. Default height and width to unet
+ height, width = self._default_height_width(height, width, image)
+ device = self._execution_device
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
+ )
+
+ is_multi_adapter = isinstance(self.adapter, MultiAdapter)
+ if is_multi_adapter:
+ adapter_input = [_preprocess_adapter_image(img, height, width).to(device) for img in image]
+ n, c, h, w = adapter_input[0].shape
+ adapter_input = torch.stack([x.reshape([n * c, h, w]) for x in adapter_input])
+ else:
+ adapter_input = _preprocess_adapter_image(image, height, width).to(device)
+ adapter_input = adapter_input.to(self.adapter.dtype)
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Denoising loop
+ adapter_state = self.adapter(adapter_input)
+ for k, v in enumerate(adapter_state):
+ adapter_state[k] = v * adapter_conditioning_scale
+ if num_images_per_prompt > 1:
+ for k, v in enumerate(adapter_state):
+ adapter_state[k] = v.repeat(num_images_per_prompt, 1, 1, 1)
+ if do_classifier_free_guidance:
+ for k, v in enumerate(adapter_state):
+ adapter_state[k] = torch.cat([v] * 2, dim=0)
+
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ down_block_additional_residuals=[state.clone() for state in adapter_state],
+ ).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ if output_type == "latent":
+ image = latents
+ has_nsfw_concept = None
+ elif output_type == "pil":
+ # 8. Post-processing
+ image = self.decode_latents(latents)
+
+ # 9. Run safety checker
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+
+ # 10. Convert to PIL
+ image = self.numpy_to_pil(image)
+ else:
+ # 8. Post-processing
+ image = self.decode_latents(latents)
+
+ # 9. Run safety checker
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionAdapterPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
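Since the constructor above wraps a list of adapters into a `MultiAdapter` and accepts per-adapter `adapter_weights`, several conditions can be combined in one call. A minimal sketch follows, assuming that `from_pretrained` forwards the `adapter` and `adapter_weights` arguments to this constructor, that the two TencentARC adapter checkpoint ids are available on the Hub (they are illustrative here), and that `color_image` and `sketch_image` are PIL conditioning images prepared by the caller:

```py
import torch
from diffusers import StableDiffusionAdapterPipeline, T2IAdapter

# Illustrative adapter checkpoints; any compatible T2IAdapter weights can be substituted.
adapters = [
    T2IAdapter.from_pretrained("TencentARC/t2iadapter_color_sd14v1", torch_dtype=torch.float16),
    T2IAdapter.from_pretrained("TencentARC/t2iadapter_sketch_sd14v1", torch_dtype=torch.float16),
]

pipe = StableDiffusionAdapterPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    adapter=adapters,            # the list is wrapped into a MultiAdapter in __init__
    adapter_weights=[0.8, 0.2],  # per-adapter weights applied before the outputs are summed
    torch_dtype=torch.float16,
)
pipe.enable_model_cpu_offload()  # moves one sub-model at a time to the GPU, as described above

out = pipe(
    "a cozy cabin in a snowy forest",
    image=[color_image, sketch_image],  # one conditioning image per adapter
    num_inference_steps=50,
).images[0]
```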
diff --git a/diffusers/pipelines/text_to_video_synthesis/__init__.py b/diffusers/pipelines/text_to_video_synthesis/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d70c1c2ea2a8af8d69aebb915c9d6eacc52c14f8
--- /dev/null
+++ b/diffusers/pipelines/text_to_video_synthesis/__init__.py
@@ -0,0 +1,33 @@
+from dataclasses import dataclass
+from typing import List, Optional, Union
+
+import numpy as np
+import torch
+
+from ...utils import BaseOutput, OptionalDependencyNotAvailable, is_torch_available, is_transformers_available
+
+
+@dataclass
+class TextToVideoSDPipelineOutput(BaseOutput):
+ """
+ Output class for text to video pipelines.
+
+ Args:
+ frames (`List[np.ndarray]` or `torch.FloatTensor`)
+ List of denoised frames (essentially images) as NumPy arrays of shape `(height, width, num_channels)` or as
+            a `torch` tensor. The NumPy arrays represent the denoised images of the diffusion pipeline. The length of
+            the list denotes the video length, i.e., the number of frames.
+ """
+
+ frames: Union[List[np.ndarray], torch.FloatTensor]
+
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
+else:
+ from .pipeline_text_to_video_synth import TextToVideoSDPipeline
+ from .pipeline_text_to_video_synth_img2img import VideoToVideoSDPipeline # noqa: F401
+ from .pipeline_text_to_video_zero import TextToVideoZeroPipeline
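`TextToVideoSDPipelineOutput.frames` defined above is, in the NumPy case, a list of `(height, width, num_channels)` uint8 frames. A minimal sketch of consuming that structure outside the library, assuming `frames` comes from one of the pipelines registered here; the Pillow-based GIF writer is an illustrative helper, not part of this diff:

```py
import numpy as np
from PIL import Image


def frames_to_gif(frames, path="video.gif", fps=8):
    """Write a list of (H, W, C) uint8 NumPy frames to an animated GIF."""
    images = [Image.fromarray(np.asarray(frame)) for frame in frames]
    images[0].save(
        path,
        save_all=True,
        append_images=images[1:],
        duration=int(1000 / fps),  # per-frame duration in milliseconds
        loop=0,
    )
    return path


# Hypothetical usage: frames = pipe("Spiderman is surfing").frames; frames_to_gif(frames)
```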
diff --git a/diffusers/pipelines/text_to_video_synthesis/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/text_to_video_synthesis/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ebd8a72dcf14b5cacfbaea569b5b26f719649559
Binary files /dev/null and b/diffusers/pipelines/text_to_video_synthesis/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/pipelines/text_to_video_synthesis/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/text_to_video_synthesis/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b12a4dda8aaa4e029d9702dfe03b073626cfa118
Binary files /dev/null and b/diffusers/pipelines/text_to_video_synthesis/__pycache__/__init__.cpython-38.pyc differ
diff --git a/diffusers/pipelines/text_to_video_synthesis/__pycache__/pipeline_text_to_video_synth.cpython-310.pyc b/diffusers/pipelines/text_to_video_synthesis/__pycache__/pipeline_text_to_video_synth.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9ba4a8d74f18ecbb3c873a45ce0a26eec2374f91
Binary files /dev/null and b/diffusers/pipelines/text_to_video_synthesis/__pycache__/pipeline_text_to_video_synth.cpython-310.pyc differ
diff --git a/diffusers/pipelines/text_to_video_synthesis/__pycache__/pipeline_text_to_video_synth.cpython-38.pyc b/diffusers/pipelines/text_to_video_synthesis/__pycache__/pipeline_text_to_video_synth.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9aa7cc5ce96ad129c3f96f363f5dcbfb597b8660
Binary files /dev/null and b/diffusers/pipelines/text_to_video_synthesis/__pycache__/pipeline_text_to_video_synth.cpython-38.pyc differ
diff --git a/diffusers/pipelines/text_to_video_synthesis/__pycache__/pipeline_text_to_video_synth_img2img.cpython-310.pyc b/diffusers/pipelines/text_to_video_synthesis/__pycache__/pipeline_text_to_video_synth_img2img.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4020b8e75dcc1cb257b638d91af03639a492af03
Binary files /dev/null and b/diffusers/pipelines/text_to_video_synthesis/__pycache__/pipeline_text_to_video_synth_img2img.cpython-310.pyc differ
diff --git a/diffusers/pipelines/text_to_video_synthesis/__pycache__/pipeline_text_to_video_synth_img2img.cpython-38.pyc b/diffusers/pipelines/text_to_video_synthesis/__pycache__/pipeline_text_to_video_synth_img2img.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ef7227c624de539ebc347ff50efaa76660049946
Binary files /dev/null and b/diffusers/pipelines/text_to_video_synthesis/__pycache__/pipeline_text_to_video_synth_img2img.cpython-38.pyc differ
diff --git a/diffusers/pipelines/text_to_video_synthesis/__pycache__/pipeline_text_to_video_zero.cpython-310.pyc b/diffusers/pipelines/text_to_video_synthesis/__pycache__/pipeline_text_to_video_zero.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dd87e8e7a841f0aef79dc2e24e5fa27ede71d228
Binary files /dev/null and b/diffusers/pipelines/text_to_video_synthesis/__pycache__/pipeline_text_to_video_zero.cpython-310.pyc differ
diff --git a/diffusers/pipelines/text_to_video_synthesis/__pycache__/pipeline_text_to_video_zero.cpython-38.pyc b/diffusers/pipelines/text_to_video_synthesis/__pycache__/pipeline_text_to_video_zero.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4cd0482c1785c7aca6631edbbfbd9eea62d3021e
Binary files /dev/null and b/diffusers/pipelines/text_to_video_synthesis/__pycache__/pipeline_text_to_video_zero.cpython-38.pyc differ
diff --git a/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py
new file mode 100644
index 0000000000000000000000000000000000000000..dad7d563989255efd589c65034397d5f1a33b5e4
--- /dev/null
+++ b/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py
@@ -0,0 +1,653 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import CLIPTextModel, CLIPTokenizer
+
+from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, UNet3DConditionModel
+from ...schedulers import KarrasDiffusionSchedulers
+from ...utils import (
+ is_accelerate_available,
+ is_accelerate_version,
+ logging,
+ randn_tensor,
+ replace_example_docstring,
+)
+from ..pipeline_utils import DiffusionPipeline
+from . import TextToVideoSDPipelineOutput
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import TextToVideoSDPipeline
+ >>> from diffusers.utils import export_to_video
+
+ >>> pipe = TextToVideoSDPipeline.from_pretrained(
+ ... "damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16"
+ ... )
+ >>> pipe.enable_model_cpu_offload()
+
+ >>> prompt = "Spiderman is surfing"
+ >>> video_frames = pipe(prompt).frames
+ >>> video_path = export_to_video(video_frames)
+ >>> video_path
+ ```
+"""
+
+
+def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]:
+ # This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
+ # reshape to ncfhw
+ mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1)
+ std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1)
+ # unnormalize back to [0,1]
+ video = video.mul_(std).add_(mean)
+ video.clamp_(0, 1)
+ # prepare the final outputs
+ i, c, f, h, w = video.shape
+ images = video.permute(2, 3, 0, 4, 1).reshape(
+ f, h, i * w, c
+    )  # permute -> (frames, h, batch_size, w, c); reshape -> (frames, h, batch_size * w, c)
+    images = images.unbind(dim=0)  # prepare a list of individual (consecutive) frames
+ images = [(image.cpu().numpy() * 255).astype("uint8") for image in images] # f h w c
+ return images
+
+
+class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
+ r"""
+ Pipeline for text-to-video generation.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Same as Stable Diffusion 2.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet3DConditionModel`]): Conditional U-Net architecture to denoise the encoded video latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet3DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
+ steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
+ several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
+ """
+ self.vae.enable_tiling()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+        method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ hook = None
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+            # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+            # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ def decode_latents(self, latents):
+ latents = 1 / self.vae.config.scaling_factor * latents
+
+ batch_size, channels, num_frames, height, width = latents.shape
+ latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
+
+ image = self.vae.decode(latents).sample
+ video = (
+ image[None, :]
+ .reshape(
+ (
+ batch_size,
+ num_frames,
+ -1,
+ )
+ + image.shape[2:]
+ )
+ .permute(0, 2, 1, 3, 4)
+ )
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ video = video.float()
+ return video
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ def prepare_latents(
+ self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
+ ):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ num_frames,
+ height // self.vae_scale_factor,
+ width // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_frames: int = 16,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 9.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "np",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated video.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated video.
+ num_frames (`int`, *optional*, defaults to 16):
+ The number of video frames that are generated. Defaults to 16 frames, which at 8 frames per second
+ amounts to 2 seconds of video.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to higher quality videos at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 9.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2 of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages generating videos that are closely linked to the text `prompt`,
+ usually at the expense of lower video quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the video generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`. Latents should be of shape
+ `(batch_size, num_channel, num_frames, height, width)`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"np"`):
+ The output format of the generated video. Choose between `torch.FloatTensor` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.TextToVideoSDPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.TextToVideoSDPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.TextToVideoSDPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated frames.
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ num_images_per_prompt = 1
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ num_frames,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # reshape latents
+ bsz, channel, frames, width, height = latents.shape
+ latents = latents.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height)
+ noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # reshape latents back
+ latents = latents[None, :].reshape(bsz, frames, channel, width, height).permute(0, 2, 1, 3, 4)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ if output_type == "latent":
+ return TextToVideoSDPipelineOutput(frames=latents)
+
+ video_tensor = self.decode_latents(latents)
+
+ if output_type == "pt":
+ video = video_tensor
+ else:
+ video = tensor2vid(video_tensor)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (video,)
+
+ return TextToVideoSDPipelineOutput(frames=video)
diff --git a/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py b/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py
new file mode 100644
index 0000000000000000000000000000000000000000..72a5b762d5048928a5bd5295acc303a60e5edf18
--- /dev/null
+++ b/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py
@@ -0,0 +1,731 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import PIL
+import torch
+from transformers import CLIPTextModel, CLIPTokenizer
+
+from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, UNet3DConditionModel
+from ...schedulers import KarrasDiffusionSchedulers
+from ...utils import (
+ is_accelerate_available,
+ is_accelerate_version,
+ logging,
+ randn_tensor,
+ replace_example_docstring,
+)
+from ..pipeline_utils import DiffusionPipeline
+from . import TextToVideoSDPipelineOutput
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
+ >>> from diffusers.utils import export_to_video
+ >>> from PIL import Image
+
+ >>> pipe = DiffusionPipeline.from_pretrained("cerspense/zeroscope_v2_576w", torch_dtype=torch.float16)
+ >>> pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
+ >>> pipe.to("cuda")
+
+ >>> prompt = "spiderman running in the desert"
+ >>> video_frames = pipe(prompt, num_inference_steps=40, height=320, width=576, num_frames=24).frames
+ >>> # save low-res video
+ >>> video_path = export_to_video(video_frames, output_video_path="./video_576_spiderman.mp4")
+
+ >>> # let's offload the text-to-image model
+ >>> pipe.to("cpu")
+
+ >>> # and load the image-to-image model
+ >>> pipe = DiffusionPipeline.from_pretrained(
+ ... "cerspense/zeroscope_v2_XL", torch_dtype=torch.float16, revision="refs/pr/15"
+ ... )
+ >>> pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
+ >>> pipe.enable_model_cpu_offload()
+
+ >>> # The VAE consumes A LOT of memory, let's make sure we run it in sliced mode
+ >>> pipe.vae.enable_slicing()
+
+ >>> # now let's upscale it
+ >>> video = [Image.fromarray(frame).resize((1024, 576)) for frame in video_frames]
+
+ >>> # and denoise it
+ >>> video_frames = pipe(prompt, video=video, strength=0.6).frames
+ >>> video_path = export_to_video(video_frames, output_video_path="./video_1024_spiderman.mp4")
+ >>> video_path
+ ```
+"""
+
+
+def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]:
+ # This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
+ # reshape to ncfhw
+ mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1)
+ std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1)
+ # unnormalize back to [0,1]
+ video = video.mul_(std).add_(mean)
+ video.clamp_(0, 1)
+ # prepare the final outputs
+ i, c, f, h, w = video.shape
+ images = video.permute(2, 3, 0, 4, 1).reshape(
+ f, h, i * w, c
+ ) # 1st (frames, h, batch_size, w, c) 2nd (frames, h, batch_size * w, c)
+ images = images.unbind(dim=0) # prepare a list of individual (consecutive) frames
+ images = [(image.cpu().numpy() * 255).astype("uint8") for image in images] # f h w c
+ return images
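+
+ # Usage sketch (illustrative comment only): given a decoded video tensor of shape
+ # (batch, channels, frames, height, width) with values roughly in [-1, 1], tensor2vid returns a list of
+ # `frames` uint8 arrays of shape (height, batch * width, channels), i.e. the batch is tiled along width.
+ #
+ # frames = tensor2vid(video_tensor)
+ # video_path = export_to_video(frames) # `export_to_video` as used in the example docstring above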
+
+
+def preprocess_video(video):
+ supported_formats = (np.ndarray, torch.Tensor, PIL.Image.Image)
+
+ if isinstance(video, supported_formats):
+ video = [video]
+ elif not (isinstance(video, list) and all(isinstance(i, supported_formats) for i in video)):
+ raise ValueError(
+ f"Input is in incorrect format: {[type(i) for i in video]}. Currently, we only support {', '.join(supported_formats)}"
+ )
+
+ if isinstance(video[0], PIL.Image.Image):
+ video = [np.array(frame) for frame in video]
+
+ if isinstance(video[0], np.ndarray):
+ video = np.concatenate(video, axis=0) if video[0].ndim == 5 else np.stack(video, axis=0)
+
+ if video.dtype == np.uint8:
+ video = np.array(video).astype(np.float32) / 255.0
+
+ if video.ndim == 4:
+ video = video[None, ...]
+
+ video = torch.from_numpy(video.transpose(0, 4, 1, 2, 3))
+
+ elif isinstance(video[0], torch.Tensor):
+ video = torch.cat(video, axis=0) if video[0].ndim == 5 else torch.stack(video, axis=0)
+
+ # don't need any preprocess if the video is latents
+ channel = video.shape[1]
+ if channel == 4:
+ return video
+
+ # move channels before num_frames
+ video = video.permute(0, 2, 1, 3, 4)
+
+ # normalize video
+ video = 2.0 * video - 1.0
+
+ return video
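+
+ # Input-handling sketch (comment only): a list of n PIL frames of size (W, H) becomes a float tensor of
+ # shape (1, 3, n, H, W) scaled to [-1, 1]; for torch inputs whose second dimension already has 4 channels,
+ # the input is treated as latents and returned unchanged.
+ #
+ # frames = [PIL.Image.open(p).convert("RGB") for p in paths] # `paths` is a hypothetical list of image files
+ # video = preprocess_video(frames) # -> torch.FloatTensor of shape (1, 3, len(paths), H, W)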
+
+
+class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
+ r"""
+ Pipeline for text-guided video-to-video generation.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Same as Stable Diffusion 2.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet3DConditionModel`]): Conditional U-Net architecture to denoise the encoded video latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet3DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
+ steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
+ several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
+ """
+ self.vae.enable_tiling()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ hook = None
+ for cpu_offloaded_model in [self.text_encoder, self.vae, self.unet]:
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
+ def decode_latents(self, latents):
+ latents = 1 / self.vae.config.scaling_factor * latents
+
+ batch_size, channels, num_frames, height, width = latents.shape
+ latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
+
+ image = self.vae.decode(latents).sample
+ video = (
+ image[None, :]
+ .reshape(
+ (
+ batch_size,
+ num_frames,
+ -1,
+ )
+ + image.shape[2:]
+ )
+ .permute(0, 2, 1, 3, 4)
+ )
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ video = video.float()
+ return video
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs
+ def check_inputs(
+ self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
+ ):
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep, 0)
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+
+ return timesteps, num_inference_steps - t_start
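+
+ # Worked example (comment only): with num_inference_steps=50 and strength=0.6, init_timestep =
+ # min(int(50 * 0.6), 50) = 30 and t_start = 50 - 30 = 20, so denoising runs over the last 30 scheduler
+ # timesteps (for a first-order scheduler) instead of the full 50; higher `strength` therefore runs more of
+ # the schedule and alters the input video more.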
+
+ def prepare_latents(self, video, timestep, batch_size, dtype, device, generator=None):
+ video = video.to(device=device, dtype=dtype)
+
+ # change from (b, c, f, h, w) -> (b * f, c, h, w)
+ bsz, channel, frames, height, width = video.shape
+ video = video.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, height, width)
+
+ if video.shape[1] == 4:
+ init_latents = video
+ else:
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ elif isinstance(generator, list):
+ init_latents = [
+ self.vae.encode(video[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
+ ]
+ init_latents = torch.cat(init_latents, dim=0)
+ else:
+ init_latents = self.vae.encode(video).latent_dist.sample(generator)
+
+ init_latents = self.vae.config.scaling_factor * init_latents
+
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `video` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ init_latents = torch.cat([init_latents], dim=0)
+
+ shape = init_latents.shape
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+ # get latents
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+ latents = init_latents
+
+ latents = latents[None, :].reshape((bsz, frames, latents.shape[1]) + latents.shape[2:]).permute(0, 2, 1, 3, 4)
+
+ return latents
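+
+ # Shape walkthrough (descriptive comment): the preprocessed video (b, c, f, h, w) is flattened to
+ # (b * f, c, h, w), encoded frame-by-frame with the 2D VAE (unless it already has 4 channels, i.e. is
+ # latents), noised to `timestep` with scheduler.add_noise, and reshaped back to a latent video of shape
+ # (b, 4, f, h', w'), where h' and w' are the VAE-downsampled spatial sizes.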
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ video: Union[List[np.ndarray], torch.FloatTensor] = None,
+ strength: float = 0.6,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 15.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "np",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ video (`List[np.ndarray]` or `torch.FloatTensor`):
+ `video` frames or tensor representing a video batch that will be used as the starting point for the
+ process. Can also accept video latents as `video`; if latents are passed directly, they will not be
+ encoded again.
+ strength (`float`, *optional*, defaults to 0.6):
+ Conceptually, indicates how much to transform the reference `video`. Must be between 0 and 1. `video`
+ will be used as a starting point, adding more noise to it the larger the `strength`. The number of
+ denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
+ be maximum and the denoising process will run for the full number of iterations specified in
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `video`.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to higher quality videos at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 15.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2 of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages generating videos that are closely linked to the text `prompt`,
+ usually at the expense of lower video quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the video generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`. Latents should be of shape
+ `(batch_size, num_channel, num_frames, height, width)`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"np"`):
+ The output format of the generated video. Choose between `torch.FloatTensor` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.TextToVideoSDPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.TextToVideoSDPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.TextToVideoSDPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated frames.
+ """
+ # 0. Default number of videos per prompt
+ num_images_per_prompt = 1
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+
+ # 4. Preprocess video
+ video = preprocess_video(video)
+
+ # 5. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+ # 6. Prepare latent variables
+ latents = self.prepare_latents(video, latent_timestep, batch_size, prompt_embeds.dtype, device, generator)
+
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 8. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # reshape latents
+ bsz, channel, frames, width, height = latents.shape
+ latents = latents.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height)
+ noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # reshape latents back
+ latents = latents[None, :].reshape(bsz, frames, channel, width, height).permute(0, 2, 1, 3, 4)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ if output_type == "latent":
+ return TextToVideoSDPipelineOutput(frames=latents)
+
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.unet.to("cpu")
+
+ video_tensor = self.decode_latents(latents)
+
+ if output_type == "pt":
+ video = video_tensor
+ else:
+ video = tensor2vid(video_tensor)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (video,)
+
+ return TextToVideoSDPipelineOutput(frames=video)
diff --git a/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py b/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe7207f904f08032c3f125d64bf5f024a6b89b60
--- /dev/null
+++ b/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py
@@ -0,0 +1,627 @@
+import copy
+from dataclasses import dataclass
+from typing import Callable, List, Optional, Union
+
+import numpy as np
+import PIL
+import torch
+import torch.nn.functional as F
+from torch.nn.functional import grid_sample
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline, StableDiffusionSafetyChecker
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import BaseOutput
+
+
+def rearrange_0(tensor, f):
+ F, C, H, W = tensor.size()
+ tensor = torch.permute(torch.reshape(tensor, (F // f, f, C, H, W)), (0, 2, 1, 3, 4))
+ return tensor
+
+
+def rearrange_1(tensor):
+ B, C, F, H, W = tensor.size()
+ return torch.reshape(torch.permute(tensor, (0, 2, 1, 3, 4)), (B * F, C, H, W))
+
+
+def rearrange_3(tensor, f):
+ F, D, C = tensor.size()
+ return torch.reshape(tensor, (F // f, f, D, C))
+
+
+def rearrange_4(tensor):
+ B, F, D, C = tensor.size()
+ return torch.reshape(tensor, (B * F, D, C))
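+
+ # Shape sketch (comment only): rearrange_3 / rearrange_4 round-trip a (batch * frames, seq, dim) tensor
+ # through (batch, frames, seq, dim), which is what lets the cross-frame processors below pick out the first
+ # frame of every batch element; rearrange_0 / rearrange_1 do the analogous conversion between
+ # (batch * frames, channels, h, w) and (batch, channels, frames, h, w).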
+
+
+class CrossFrameAttnProcessor:
+ """
+ Cross frame attention processor. Each frame attends to the first frame.
+
+ Args:
+ batch_size: The actual batch size, not counting frames.
+ For example, when calling the unet with a single prompt and num_images_per_prompt=1, batch_size should be
+ 2 because classifier-free guidance batches an unconditional and a conditional pass together.
+ """
+
+ def __init__(self, batch_size=2):
+ self.batch_size = batch_size
+
+ def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
+ batch_size, sequence_length, _ = hidden_states.shape
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ query = attn.to_q(hidden_states)
+
+ is_cross_attention = encoder_hidden_states is not None
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ # Cross Frame Attention
+ if not is_cross_attention:
+ video_length = key.size()[0] // self.batch_size
+ first_frame_index = [0] * video_length
+
+ # rearrange keys to have batch and frames in the 1st and 2nd dims respectively
+ key = rearrange_3(key, video_length)
+ key = key[:, first_frame_index]
+ # rearrange values to have batch and frames in the 1st and 2nd dims respectively
+ value = rearrange_3(value, video_length)
+ value = value[:, first_frame_index]
+
+ # rearrange back to original shape
+ key = rearrange_4(key)
+ value = rearrange_4(value)
+
+ query = attn.head_to_batch_dim(query)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
+ hidden_states = torch.bmm(attention_probs, value)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ return hidden_states
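+
+ # Usage sketch (illustrative comment): this processor is swapped into an existing Stable Diffusion UNet so
+ # that every frame's self-attention reads the keys/values of the first frame, e.g.
+ #
+ # pipe.unet.set_attn_processor(CrossFrameAttnProcessor(batch_size=2))
+ #
+ # which is what TextToVideoZeroPipeline.__init__ below does (choosing the PyTorch 2.0 variant when
+ # scaled_dot_product_attention is available).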
+
+
+class CrossFrameAttnProcessor2_0:
+ """
+ Cross frame attention processor using the scaled dot-product attention of PyTorch 2.0. Each frame attends
+ to the first frame.
+
+ Args:
+ batch_size: The actual batch size, not counting frames.
+ For example, when calling the unet with a single prompt and num_images_per_prompt=1, batch_size should be
+ 2 because classifier-free guidance batches an unconditional and a conditional pass together.
+ """
+
+ def __init__(self, batch_size=2):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+ self.batch_size = batch_size
+
+ def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+ inner_dim = hidden_states.shape[-1]
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ query = attn.to_q(hidden_states)
+
+ is_cross_attention = encoder_hidden_states is not None
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ # Cross Frame Attention
+ if not is_cross_attention:
+ video_length = key.size()[0] // self.batch_size
+ first_frame_index = [0] * video_length
+
+ # rearrange keys to have batch and frames in the 1st and 2nd dims respectively
+ key = rearrange_3(key, video_length)
+ key = key[:, first_frame_index]
+ # rearrange values to have batch and frames in the 1st and 2nd dims respectively
+ value = rearrange_3(value, video_length)
+ value = value[:, first_frame_index]
+
+ # rearrange back to original shape
+ key = rearrange_4(key)
+ value = rearrange_4(value)
+
+ head_dim = inner_dim // attn.heads
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+ return hidden_states
+
+
+@dataclass
+class TextToVideoPipelineOutput(BaseOutput):
+ images: Union[List[PIL.Image.Image], np.ndarray]
+ nsfw_content_detected: Optional[List[bool]]
+
+
+def coords_grid(batch, ht, wd, device):
+ # Adapted from https://github.com/princeton-vl/RAFT/blob/master/core/utils/utils.py
+ coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device))
+ coords = torch.stack(coords[::-1], dim=0).float()
+ return coords[None].repeat(batch, 1, 1, 1)
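+
+ # Small worked example (comment only): coords_grid(1, 2, 3, device) returns a (1, 2, 2, 3) tensor whose
+ # channel 0 holds the x (column) indices [[0, 1, 2], [0, 1, 2]] and channel 1 the y (row) indices
+ # [[0, 0, 0], [1, 1, 1]]; adding a flow field to this grid gives the sampling locations used by
+ # warp_single_latent below.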
+
+
+def warp_single_latent(latent, reference_flow):
+ """
+ Warp latent of a single frame with given flow
+
+ Args:
+ latent: latent code of a single frame
+ reference_flow: flow with which to warp the latent
+
+ Returns:
+ warped: warped latent
+ """
+ _, _, H, W = reference_flow.size()
+ _, _, h, w = latent.size()
+ coords0 = coords_grid(1, H, W, device=latent.device).to(latent.dtype)
+
+ coords_t0 = coords0 + reference_flow
+ coords_t0[:, 0] /= W
+ coords_t0[:, 1] /= H
+
+ coords_t0 = coords_t0 * 2.0 - 1.0
+ coords_t0 = F.interpolate(coords_t0, size=(h, w), mode="bilinear")
+ coords_t0 = torch.permute(coords_t0, (0, 2, 3, 1))
+
+ warped = grid_sample(latent, coords_t0, mode="nearest", padding_mode="reflection")
+ return warped
+
+
+def create_motion_field(motion_field_strength_x, motion_field_strength_y, frame_ids, device, dtype):
+ """
+ Create translation motion field
+
+ Args:
+ motion_field_strength_x: motion strength along x-axis
+ motion_field_strength_y: motion strength along y-axis
+ frame_ids: indices of the frames whose latents are being processed.
+ This is needed when we perform chunk-by-chunk inference.
+ device: device
+ dtype: dtype
+
+ Returns:
+ reference_flow: translation flow field of shape `(len(frame_ids), 2, 512, 512)`
+ """
+ seq_length = len(frame_ids)
+ reference_flow = torch.zeros((seq_length, 2, 512, 512), device=device, dtype=dtype)
+ for fr_idx in range(seq_length):
+ reference_flow[fr_idx, 0, :, :] = motion_field_strength_x * (frame_ids[fr_idx])
+ reference_flow[fr_idx, 1, :, :] = motion_field_strength_y * (frame_ids[fr_idx])
+ return reference_flow
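+
+ # Worked example (comment only): with motion_field_strength_x=12, motion_field_strength_y=0 and
+ # frame_ids=[0, 1, 2], frame k receives a constant horizontal displacement of 12 * k (0, 12, 24) on the
+ # 512x512 flow grid and no vertical displacement, i.e. a simple global panning motion across frames.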
+
+
+def create_motion_field_and_warp_latents(motion_field_strength_x, motion_field_strength_y, frame_ids, latents):
+ """
+ Creates translation motion and warps the latents accordingly
+
+ Args:
+ motion_field_strength_x: motion strength along x-axis
+ motion_field_strength_y: motion strength along y-axis
+ frame_ids: indices of the frames whose latents are being processed.
+ This is needed when we perform chunk-by-chunk inference.
+ latents: latent codes of frames
+
+ Returns:
+ warped_latents: warped latents
+ """
+ motion_field = create_motion_field(
+ motion_field_strength_x=motion_field_strength_x,
+ motion_field_strength_y=motion_field_strength_y,
+ frame_ids=frame_ids,
+ device=latents.device,
+ dtype=latents.dtype,
+ )
+ warped_latents = latents.clone().detach()
+ for i in range(len(warped_latents)):
+ warped_latents[i] = warp_single_latent(latents[i][None], motion_field[i][None])
+ return warped_latents
+
+
+class TextToVideoZeroPipeline(StableDiffusionPipeline):
+ r"""
+ Pipeline for zero-shot text-to-video generation using Stable Diffusion.
+
+ This model inherits from [`StableDiffusionPipeline`]. Check the superclass documentation for the generic methods
+ the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__(
+ vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
+ )
+ processor = (
+ CrossFrameAttnProcessor2_0(batch_size=2)
+ if hasattr(F, "scaled_dot_product_attention")
+ else CrossFrameAttnProcessor(batch_size=2)
+ )
+ self.unet.set_attn_processor(processor)
+
+ def forward_loop(self, x_t0, t0, t1, generator):
+ """
+ Perform the DDPM forward process from time t0 to t1. This is the same as adding noise with the corresponding variance.
+
+ Args:
+ x_t0: latent code at time t0
+ t0: t0
+ t1: t1
+ generator: torch.Generator object
+
+ Returns:
+ x_t1: forward process applied to x_t0 from time t0 to t1.
+ """
+ eps = torch.randn(x_t0.size(), generator=generator, dtype=x_t0.dtype, device=x_t0.device)
+ alpha_vec = torch.prod(self.scheduler.alphas[t0:t1])
+ x_t1 = torch.sqrt(alpha_vec) * x_t0 + torch.sqrt(1 - alpha_vec) * eps
+ return x_t1
+
+ def backward_loop(
+ self,
+ latents,
+ timesteps,
+ prompt_embeds,
+ guidance_scale,
+ callback,
+ callback_steps,
+ num_warmup_steps,
+ extra_step_kwargs,
+ cross_attention_kwargs=None,
+ ):
+ """
+ Perform the backward process over the given list of time steps.
+
+ Args:
+ latents: Latents at time timesteps[0].
+ timesteps: time steps along which to perform the backward process.
+ prompt_embeds: Pre-generated text embeddings.
+ guidance_scale:
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2 of the [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the
+ text `prompt`, usually at the expense of lower image quality.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ extra_step_kwargs: Extra keyword arguments forwarded to the scheduler step (e.g. `eta`, `generator`).
+ cross_attention_kwargs: Keyword arguments forwarded to the attention processors.
+ num_warmup_steps: number of warmup steps.
+
+ Returns:
+ latents: Latents produced by the backward process at time timesteps[-1].
+ """
+ do_classifier_free_guidance = guidance_scale > 1.0
+ num_steps = (len(timesteps) - num_warmup_steps) // self.scheduler.order
+ with self.progress_bar(total=num_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ ).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+ return latents.clone().detach()
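+
+ # Note: in the classifier-free-guidance branch above, the combined prediction is
+ #     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond),
+ # i.e. `w` from eq. (2) of the Imagen paper; setting `guidance_scale <= 1.0` skips the extra
+ # unconditional forward pass entirely.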
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ video_length: Optional[int] = 8,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_videos_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ motion_field_strength_x: float = 12,
+ motion_field_strength_y: float = 12,
+ output_type: Optional[str] = "tensor",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: Optional[int] = 1,
+ t0: int = 44,
+ t1: int = 47,
+ frame_ids: Optional[List[int]] = None,
+ ):
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the video generation.
+ video_length (`int`, *optional*, defaults to 8): The number of generated video frames.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2 of the [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the
+ text `prompt`, usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of videos to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"tensor"`):
+ The output format of the generated video. Choose `"latent"` to return the latent codes; any other
+ value returns the decoded frames as a NumPy `ndarray`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ motion_field_strength_x (`float`, *optional*, defaults to 12):
+ Strength of motion in generated video along x-axis. See the [paper](https://arxiv.org/abs/2303.13439),
+ Sect. 3.3.1.
+ motion_field_strength_y (`float`, *optional*, defaults to 12):
+ Strength of motion in generated video along y-axis. See the [paper](https://arxiv.org/abs/2303.13439),
+ Sect. 3.3.1.
+ t0 (`int`, *optional*, defaults to 44):
+ Timestep t0. Should be in the range [0, num_inference_steps - 1]. See the
+ [paper](https://arxiv.org/abs/2303.13439), Sect. 3.3.1.
+ t1 (`int`, *optional*, defaults to 47):
+ Timestep t1. Should be in the range [t0 + 1, num_inference_steps - 1]. See the
+ [paper](https://arxiv.org/abs/2303.13439), Sect. 3.3.1.
+ frame_ids (`List[int]`, *optional*):
+ Indexes of the frames that are being generated. This is used when generating longer videos
+ chunk-by-chunk.
+
+ Returns:
+ [`~pipelines.text_to_video_synthesis.TextToVideoPipelineOutput`]:
+ The output contains an `ndarray` of the generated video frames when `output_type` is not `"latent"`,
+ otherwise the latent codes of the generated frames, together with a list of `bool`s denoting whether
+ the corresponding frame likely represents "not-safe-for-work" (nsfw) content, according to the
+ `safety_checker`.
+ """
+ assert video_length > 0
+ if frame_ids is None:
+ frame_ids = list(range(video_length))
+ assert len(frame_ids) == video_length
+
+ assert num_videos_per_prompt == 1
+
+ if isinstance(prompt, str):
+ prompt = [prompt]
+ if isinstance(negative_prompt, str):
+ negative_prompt = [negative_prompt]
+
+ # Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # Check inputs. Raise error if not correct
+ self.check_inputs(prompt, height, width, callback_steps)
+
+ # Define call parameters
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # Encode input prompt
+ prompt_embeds = self._encode_prompt(
+ prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
+ )
+
+ # Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+ # Prepare extra step kwargs.
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+
+ # Perform the first backward process up to time T_1
+ x_1_t1 = self.backward_loop(
+ timesteps=timesteps[: -t1 - 1],
+ prompt_embeds=prompt_embeds,
+ latents=latents,
+ guidance_scale=guidance_scale,
+ callback=callback,
+ callback_steps=callback_steps,
+ extra_step_kwargs=extra_step_kwargs,
+ num_warmup_steps=num_warmup_steps,
+ )
+ scheduler_copy = copy.deepcopy(self.scheduler)
+
+ # Perform the second backward process up to time T_0
+ x_1_t0 = self.backward_loop(
+ timesteps=timesteps[-t1 - 1 : -t0 - 1],
+ prompt_embeds=prompt_embeds,
+ latents=x_1_t1,
+ guidance_scale=guidance_scale,
+ callback=callback,
+ callback_steps=callback_steps,
+ extra_step_kwargs=extra_step_kwargs,
+ num_warmup_steps=0,
+ )
+
+ # Propagate first frame latents at time T_0 to remaining frames
+ x_2k_t0 = x_1_t0.repeat(video_length - 1, 1, 1, 1)
+
+ # Add motion in latents at time T_0
+ x_2k_t0 = create_motion_field_and_warp_latents(
+ motion_field_strength_x=motion_field_strength_x,
+ motion_field_strength_y=motion_field_strength_y,
+ latents=x_2k_t0,
+ frame_ids=frame_ids[1:],
+ )
+
+ # Perform forward process up to time T_1
+ x_2k_t1 = self.forward_loop(
+ x_t0=x_2k_t0,
+ t0=timesteps[-t0 - 1].item(),
+ t1=timesteps[-t1 - 1].item(),
+ generator=generator,
+ )
+
+ # Perform backward process from time T_1 to 0
+ x_1k_t1 = torch.cat([x_1_t1, x_2k_t1])
+ b, l, d = prompt_embeds.size()
+ prompt_embeds = prompt_embeds[:, None].repeat(1, video_length, 1, 1).reshape(b * video_length, l, d)
+
+ self.scheduler = scheduler_copy
+ x_1k_0 = self.backward_loop(
+ timesteps=timesteps[-t1 - 1 :],
+ prompt_embeds=prompt_embeds,
+ latents=x_1k_t1,
+ guidance_scale=guidance_scale,
+ callback=callback,
+ callback_steps=callback_steps,
+ extra_step_kwargs=extra_step_kwargs,
+ num_warmup_steps=0,
+ )
+ latents = x_1k_0
+
+ # Manually offload the UNet to CPU for maximum memory savings
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.unet.to("cpu")
+ torch.cuda.empty_cache()
+
+ if output_type == "latent":
+ image = latents
+ has_nsfw_concept = None
+ else:
+ image = self.decode_latents(latents)
+ # Run safety checker
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return TextToVideoPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/pipelines/unclip/__init__.py b/diffusers/pipelines/unclip/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..075e66bb680aca294b36aa7ad0abb8d0f651cd92
--- /dev/null
+++ b/diffusers/pipelines/unclip/__init__.py
@@ -0,0 +1,17 @@
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ is_torch_available,
+ is_transformers_available,
+ is_transformers_version,
+)
+
+
+try:
+ if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import UnCLIPImageVariationPipeline, UnCLIPPipeline
+else:
+ from .pipeline_unclip import UnCLIPPipeline
+ from .pipeline_unclip_image_variation import UnCLIPImageVariationPipeline
+ from .text_proj import UnCLIPTextProjModel
diff --git a/diffusers/pipelines/unclip/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/unclip/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d2085b76dfabfae74a7c9cb418296e470b1af24a
Binary files /dev/null and b/diffusers/pipelines/unclip/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/pipelines/unclip/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/unclip/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3f2ede158e4989d91231581c4f453c6e6f4f42d3
Binary files /dev/null and b/diffusers/pipelines/unclip/__pycache__/__init__.cpython-38.pyc differ
diff --git a/diffusers/pipelines/unclip/__pycache__/pipeline_unclip.cpython-310.pyc b/diffusers/pipelines/unclip/__pycache__/pipeline_unclip.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..507fdd91c7f7a619e99a6c49f5642f56af39255e
Binary files /dev/null and b/diffusers/pipelines/unclip/__pycache__/pipeline_unclip.cpython-310.pyc differ
diff --git a/diffusers/pipelines/unclip/__pycache__/pipeline_unclip.cpython-38.pyc b/diffusers/pipelines/unclip/__pycache__/pipeline_unclip.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d93a6a4beda48656b367e5182f6792c218ec9ae9
Binary files /dev/null and b/diffusers/pipelines/unclip/__pycache__/pipeline_unclip.cpython-38.pyc differ
diff --git a/diffusers/pipelines/unclip/__pycache__/pipeline_unclip_image_variation.cpython-310.pyc b/diffusers/pipelines/unclip/__pycache__/pipeline_unclip_image_variation.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..be06f106c2689b68daa661ebcd28dbb563fac5c6
Binary files /dev/null and b/diffusers/pipelines/unclip/__pycache__/pipeline_unclip_image_variation.cpython-310.pyc differ
diff --git a/diffusers/pipelines/unclip/__pycache__/pipeline_unclip_image_variation.cpython-38.pyc b/diffusers/pipelines/unclip/__pycache__/pipeline_unclip_image_variation.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8300f16906ccd5eae78cb8d2e2a4a7732ac69d4e
Binary files /dev/null and b/diffusers/pipelines/unclip/__pycache__/pipeline_unclip_image_variation.cpython-38.pyc differ
diff --git a/diffusers/pipelines/unclip/__pycache__/text_proj.cpython-310.pyc b/diffusers/pipelines/unclip/__pycache__/text_proj.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bbd8e45c4a4d3661fadfe6e30f6da8de61380dfa
Binary files /dev/null and b/diffusers/pipelines/unclip/__pycache__/text_proj.cpython-310.pyc differ
diff --git a/diffusers/pipelines/unclip/__pycache__/text_proj.cpython-38.pyc b/diffusers/pipelines/unclip/__pycache__/text_proj.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..17d05790af5ac31577478551db9394f1203ebd8f
Binary files /dev/null and b/diffusers/pipelines/unclip/__pycache__/text_proj.cpython-38.pyc differ
diff --git a/diffusers/pipelines/unclip/pipeline_unclip.py b/diffusers/pipelines/unclip/pipeline_unclip.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3b67bebfc8dc9bbc530fdc1f283bd356662cd2a
--- /dev/null
+++ b/diffusers/pipelines/unclip/pipeline_unclip.py
@@ -0,0 +1,493 @@
+# Copyright 2023 Kakao Brain and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import List, Optional, Tuple, Union
+
+import torch
+from torch.nn import functional as F
+from transformers import CLIPTextModelWithProjection, CLIPTokenizer
+from transformers.models.clip.modeling_clip import CLIPTextModelOutput
+
+from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel
+from ...pipelines import DiffusionPipeline
+from ...pipelines.pipeline_utils import ImagePipelineOutput
+from ...schedulers import UnCLIPScheduler
+from ...utils import logging, randn_tensor
+from .text_proj import UnCLIPTextProjModel
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class UnCLIPPipeline(DiffusionPipeline):
+ """
+ Pipeline for text-to-image generation using unCLIP
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ text_encoder ([`CLIPTextModelWithProjection`]):
+ Frozen text-encoder.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ prior ([`PriorTransformer`]):
+ The canonical unCLIP prior to approximate the image embedding from the text embedding.
+ text_proj ([`UnCLIPTextProjModel`]):
+ Utility class to prepare and combine the embeddings before they are passed to the decoder.
+ decoder ([`UNet2DConditionModel`]):
+ The decoder to invert the image embedding into an image.
+ super_res_first ([`UNet2DModel`]):
+ Super resolution unet. Used in all but the last step of the super resolution diffusion process.
+ super_res_last ([`UNet2DModel`]):
+ Super resolution unet. Used in the last step of the super resolution diffusion process.
+ prior_scheduler ([`UnCLIPScheduler`]):
+ Scheduler used in the prior denoising process. Just a modified DDPMScheduler.
+ decoder_scheduler ([`UnCLIPScheduler`]):
+ Scheduler used in the decoder denoising process. Just a modified DDPMScheduler.
+ super_res_scheduler ([`UnCLIPScheduler`]):
+ Scheduler used in the super resolution denoising process. Just a modified DDPMScheduler.
+
+ """
+
+ _exclude_from_cpu_offload = ["prior"]
+
+ prior: PriorTransformer
+ decoder: UNet2DConditionModel
+ text_proj: UnCLIPTextProjModel
+ text_encoder: CLIPTextModelWithProjection
+ tokenizer: CLIPTokenizer
+ super_res_first: UNet2DModel
+ super_res_last: UNet2DModel
+
+ prior_scheduler: UnCLIPScheduler
+ decoder_scheduler: UnCLIPScheduler
+ super_res_scheduler: UnCLIPScheduler
+
+ def __init__(
+ self,
+ prior: PriorTransformer,
+ decoder: UNet2DConditionModel,
+ text_encoder: CLIPTextModelWithProjection,
+ tokenizer: CLIPTokenizer,
+ text_proj: UnCLIPTextProjModel,
+ super_res_first: UNet2DModel,
+ super_res_last: UNet2DModel,
+ prior_scheduler: UnCLIPScheduler,
+ decoder_scheduler: UnCLIPScheduler,
+ super_res_scheduler: UnCLIPScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ prior=prior,
+ decoder=decoder,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ text_proj=text_proj,
+ super_res_first=super_res_first,
+ super_res_last=super_res_last,
+ prior_scheduler=prior_scheduler,
+ decoder_scheduler=decoder_scheduler,
+ super_res_scheduler=super_res_scheduler,
+ )
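+
+ # Example usage (a sketch; the checkpoint id below is an assumption about where a compatible set of
+ # weights is hosted):
+ #
+ #     import torch
+ #     pipe = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha", torch_dtype=torch.float16).to("cuda")
+ #     image = pipe("a high-quality photo of a corgi wearing sunglasses").images[0]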
+
+ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ if latents.shape != shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+ latents = latents.to(device)
+
+ latents = latents * scheduler.init_noise_sigma
+ return latents
+
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None,
+ text_attention_mask: Optional[torch.Tensor] = None,
+ ):
+ if text_model_output is None:
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
+ # get prompt text embeddings
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ text_mask = text_inputs.attention_mask.bool().to(device)
+
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+ text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+
+ text_encoder_output = self.text_encoder(text_input_ids.to(device))
+
+ prompt_embeds = text_encoder_output.text_embeds
+ text_encoder_hidden_states = text_encoder_output.last_hidden_state
+
+ else:
+ batch_size = text_model_output[0].shape[0]
+ prompt_embeds, text_encoder_hidden_states = text_model_output[0], text_model_output[1]
+ text_mask = text_attention_mask
+
+ prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
+
+ if do_classifier_free_guidance:
+ uncond_tokens = [""] * batch_size
+
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ uncond_text_mask = uncond_input.attention_mask.bool().to(device)
+ negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))
+
+ negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds
+ uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+
+ seq_len = negative_prompt_embeds.shape[1]
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)
+
+ seq_len = uncond_text_encoder_hidden_states.shape[1]
+ uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
+ uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
+ batch_size * num_images_per_prompt, seq_len, -1
+ )
+ uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)
+
+ # done duplicates
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+ text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])
+
+ text_mask = torch.cat([uncond_text_mask, text_mask])
+
+ return prompt_embeds, text_encoder_hidden_states, text_mask
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: int = 1,
+ prior_num_inference_steps: int = 25,
+ decoder_num_inference_steps: int = 25,
+ super_res_num_inference_steps: int = 7,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ prior_latents: Optional[torch.FloatTensor] = None,
+ decoder_latents: Optional[torch.FloatTensor] = None,
+ super_res_latents: Optional[torch.FloatTensor] = None,
+ text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None,
+ text_attention_mask: Optional[torch.Tensor] = None,
+ prior_guidance_scale: float = 4.0,
+ decoder_guidance_scale: float = 8.0,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ ):
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation. This can only be left undefined if
+ `text_model_output` and `text_attention_mask` are passed.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ prior_num_inference_steps (`int`, *optional*, defaults to 25):
+ The number of denoising steps for the prior. More denoising steps usually lead to a higher quality
+ image at the expense of slower inference.
+ decoder_num_inference_steps (`int`, *optional*, defaults to 25):
+ The number of denoising steps for the decoder. More denoising steps usually lead to a higher quality
+ image at the expense of slower inference.
+ super_res_num_inference_steps (`int`, *optional*, defaults to 7):
+ The number of denoising steps for super resolution. More denoising steps usually lead to a higher
+ quality image at the expense of slower inference.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ prior_latents (`torch.FloatTensor` of shape (batch size, embeddings dimension), *optional*):
+ Pre-generated noisy latents to be used as inputs for the prior.
+ decoder_latents (`torch.FloatTensor` of shape (batch size, channels, height, width), *optional*):
+ Pre-generated noisy latents to be used as inputs for the decoder.
+ super_res_latents (`torch.FloatTensor` of shape (batch size, channels, super res height, super res width), *optional*):
+ Pre-generated noisy latents to be used as inputs for the super resolution step.
+ prior_guidance_scale (`float`, *optional*, defaults to 4.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2 of the [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the
+ text `prompt`, usually at the expense of lower image quality.
+ decoder_guidance_scale (`float`, *optional*, defaults to 8.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2 of the [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the
+ text `prompt`, usually at the expense of lower image quality.
+ text_model_output (`CLIPTextModelOutput`, *optional*):
+ Pre-defined CLIPTextModel outputs that can be derived from the text encoder. Pre-defined text outputs
+ can be passed for tasks like text embedding interpolations. Make sure to also pass
+ `text_attention_mask` in this case. `prompt` can then be left as `None`.
+ text_attention_mask (`torch.Tensor`, *optional*):
+ Pre-defined CLIP text attention mask that can be derived from the tokenizer. Pre-defined text attention
+ masks are necessary when passing `text_model_output`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
+ """
+ if prompt is not None:
+ if isinstance(prompt, str):
+ batch_size = 1
+ elif isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ else:
+ batch_size = text_model_output[0].shape[0]
+
+ device = self._execution_device
+
+ batch_size = batch_size * num_images_per_prompt
+
+ do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0
+
+ prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(
+ prompt, device, num_images_per_prompt, do_classifier_free_guidance, text_model_output, text_attention_mask
+ )
+
+ # prior
+
+ self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=device)
+ prior_timesteps_tensor = self.prior_scheduler.timesteps
+
+ embedding_dim = self.prior.config.embedding_dim
+
+ prior_latents = self.prepare_latents(
+ (batch_size, embedding_dim),
+ prompt_embeds.dtype,
+ device,
+ generator,
+ prior_latents,
+ self.prior_scheduler,
+ )
+
+ for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([prior_latents] * 2) if do_classifier_free_guidance else prior_latents
+
+ predicted_image_embedding = self.prior(
+ latent_model_input,
+ timestep=t,
+ proj_embedding=prompt_embeds,
+ encoder_hidden_states=text_encoder_hidden_states,
+ attention_mask=text_mask,
+ ).predicted_image_embedding
+
+ if do_classifier_free_guidance:
+ predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2)
+ predicted_image_embedding = predicted_image_embedding_uncond + prior_guidance_scale * (
+ predicted_image_embedding_text - predicted_image_embedding_uncond
+ )
+
+ if i + 1 == prior_timesteps_tensor.shape[0]:
+ prev_timestep = None
+ else:
+ prev_timestep = prior_timesteps_tensor[i + 1]
+
+ prior_latents = self.prior_scheduler.step(
+ predicted_image_embedding,
+ timestep=t,
+ sample=prior_latents,
+ generator=generator,
+ prev_timestep=prev_timestep,
+ ).prev_sample
+
+ prior_latents = self.prior.post_process_latents(prior_latents)
+
+ image_embeddings = prior_latents
+
+ # done prior
+
+ # decoder
+
+ text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj(
+ image_embeddings=image_embeddings,
+ prompt_embeds=prompt_embeds,
+ text_encoder_hidden_states=text_encoder_hidden_states,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ )
+
+ if device.type == "mps":
+ # HACK: MPS: There is a panic when padding bool tensors,
+ # so cast to int tensor for the pad and back to bool afterwards
+ text_mask = text_mask.type(torch.int)
+ decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1)
+ decoder_text_mask = decoder_text_mask.type(torch.bool)
+ else:
+ decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=True)
+
+ self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
+ decoder_timesteps_tensor = self.decoder_scheduler.timesteps
+
+ num_channels_latents = self.decoder.config.in_channels
+ height = self.decoder.config.sample_size
+ width = self.decoder.config.sample_size
+
+ decoder_latents = self.prepare_latents(
+ (batch_size, num_channels_latents, height, width),
+ text_encoder_hidden_states.dtype,
+ device,
+ generator,
+ decoder_latents,
+ self.decoder_scheduler,
+ )
+
+ for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([decoder_latents] * 2) if do_classifier_free_guidance else decoder_latents
+
+ noise_pred = self.decoder(
+ sample=latent_model_input,
+ timestep=t,
+ encoder_hidden_states=text_encoder_hidden_states,
+ class_labels=additive_clip_time_embeddings,
+ attention_mask=decoder_text_mask,
+ ).sample
+
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred_uncond, _ = noise_pred_uncond.split(latent_model_input.shape[1], dim=1)
+ noise_pred_text, predicted_variance = noise_pred_text.split(latent_model_input.shape[1], dim=1)
+ noise_pred = noise_pred_uncond + decoder_guidance_scale * (noise_pred_text - noise_pred_uncond)
+ noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
+
+ if i + 1 == decoder_timesteps_tensor.shape[0]:
+ prev_timestep = None
+ else:
+ prev_timestep = decoder_timesteps_tensor[i + 1]
+
+ # compute the previous noisy sample x_t -> x_t-1
+ decoder_latents = self.decoder_scheduler.step(
+ noise_pred, t, decoder_latents, prev_timestep=prev_timestep, generator=generator
+ ).prev_sample
+
+ decoder_latents = decoder_latents.clamp(-1, 1)
+
+ image_small = decoder_latents
+
+ # done decoder
+
+ # super res
+
+ self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device)
+ super_res_timesteps_tensor = self.super_res_scheduler.timesteps
+
+ channels = self.super_res_first.config.in_channels // 2
+ height = self.super_res_first.config.sample_size
+ width = self.super_res_first.config.sample_size
+
+ super_res_latents = self.prepare_latents(
+ (batch_size, channels, height, width),
+ image_small.dtype,
+ device,
+ generator,
+ super_res_latents,
+ self.super_res_scheduler,
+ )
+
+ if device.type == "mps":
+ # MPS does not support many interpolations
+ image_upscaled = F.interpolate(image_small, size=[height, width])
+ else:
+ interpolate_antialias = {}
+ if "antialias" in inspect.signature(F.interpolate).parameters:
+ interpolate_antialias["antialias"] = True
+
+ image_upscaled = F.interpolate(
+ image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias
+ )
+
+ for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)):
+ # no classifier free guidance
+
+ if i == super_res_timesteps_tensor.shape[0] - 1:
+ unet = self.super_res_last
+ else:
+ unet = self.super_res_first
+
+ latent_model_input = torch.cat([super_res_latents, image_upscaled], dim=1)
+
+ noise_pred = unet(
+ sample=latent_model_input,
+ timestep=t,
+ ).sample
+
+ if i + 1 == super_res_timesteps_tensor.shape[0]:
+ prev_timestep = None
+ else:
+ prev_timestep = super_res_timesteps_tensor[i + 1]
+
+ # compute the previous noisy sample x_t -> x_t-1
+ super_res_latents = self.super_res_scheduler.step(
+ noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator
+ ).prev_sample
+
+ image = super_res_latents
+ # done super res
+
+ # post processing
+
+ image = image * 0.5 + 0.5
+ image = image.clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
diff --git a/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py b/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py
new file mode 100644
index 0000000000000000000000000000000000000000..580417f517becdc159b727ea6e59d3430c6b1076
--- /dev/null
+++ b/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py
@@ -0,0 +1,420 @@
+# Copyright 2023 Kakao Brain and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import List, Optional, Union
+
+import PIL
+import torch
+from torch.nn import functional as F
+from transformers import (
+ CLIPImageProcessor,
+ CLIPTextModelWithProjection,
+ CLIPTokenizer,
+ CLIPVisionModelWithProjection,
+)
+
+from ...models import UNet2DConditionModel, UNet2DModel
+from ...pipelines import DiffusionPipeline, ImagePipelineOutput
+from ...schedulers import UnCLIPScheduler
+from ...utils import logging, randn_tensor
+from .text_proj import UnCLIPTextProjModel
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class UnCLIPImageVariationPipeline(DiffusionPipeline):
+ """
+ Pipeline to generate variations from an input image using unCLIP
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ text_encoder ([`CLIPTextModelWithProjection`]):
+ Frozen text-encoder.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `image_encoder`.
+ image_encoder ([`CLIPVisionModelWithProjection`]):
+ Frozen CLIP image-encoder. unCLIP Image Variation uses the vision portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModelWithProjection),
+ specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ text_proj ([`UnCLIPTextProjModel`]):
+ Utility class to prepare and combine the embeddings before they are passed to the decoder.
+ decoder ([`UNet2DConditionModel`]):
+ The decoder to invert the image embedding into an image.
+ super_res_first ([`UNet2DModel`]):
+ Super resolution unet. Used in all but the last step of the super resolution diffusion process.
+ super_res_last ([`UNet2DModel`]):
+ Super resolution unet. Used in the last step of the super resolution diffusion process.
+ decoder_scheduler ([`UnCLIPScheduler`]):
+ Scheduler used in the decoder denoising process. Just a modified DDPMScheduler.
+ super_res_scheduler ([`UnCLIPScheduler`]):
+ Scheduler used in the super resolution denoising process. Just a modified DDPMScheduler.
+
+ """
+
+ decoder: UNet2DConditionModel
+ text_proj: UnCLIPTextProjModel
+ text_encoder: CLIPTextModelWithProjection
+ tokenizer: CLIPTokenizer
+ feature_extractor: CLIPImageProcessor
+ image_encoder: CLIPVisionModelWithProjection
+ super_res_first: UNet2DModel
+ super_res_last: UNet2DModel
+
+ decoder_scheduler: UnCLIPScheduler
+ super_res_scheduler: UnCLIPScheduler
+
+ def __init__(
+ self,
+ decoder: UNet2DConditionModel,
+ text_encoder: CLIPTextModelWithProjection,
+ tokenizer: CLIPTokenizer,
+ text_proj: UnCLIPTextProjModel,
+ feature_extractor: CLIPImageProcessor,
+ image_encoder: CLIPVisionModelWithProjection,
+ super_res_first: UNet2DModel,
+ super_res_last: UNet2DModel,
+ decoder_scheduler: UnCLIPScheduler,
+ super_res_scheduler: UnCLIPScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ decoder=decoder,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ text_proj=text_proj,
+ feature_extractor=feature_extractor,
+ image_encoder=image_encoder,
+ super_res_first=super_res_first,
+ super_res_last=super_res_last,
+ decoder_scheduler=decoder_scheduler,
+ super_res_scheduler=super_res_scheduler,
+ )
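+
+ # Example usage (a sketch; the checkpoint id below is an assumption about where a compatible set of
+ # weights is hosted):
+ #
+ #     from PIL import Image
+ #     pipe = UnCLIPImageVariationPipeline.from_pretrained("kakaobrain/karlo-v1-alpha-image-variations").to("cuda")
+ #     variations = pipe(image=Image.open("input.png"), num_images_per_prompt=2).images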
+
+ # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
+ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ if latents.shape != shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+ latents = latents.to(device)
+
+ latents = latents * scheduler.init_noise_sigma
+ return latents
+
+ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance):
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+ # get prompt text embeddings
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ text_mask = text_inputs.attention_mask.bool().to(device)
+ text_encoder_output = self.text_encoder(text_input_ids.to(device))
+
+ prompt_embeds = text_encoder_output.text_embeds
+ text_encoder_hidden_states = text_encoder_output.last_hidden_state
+
+ prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
+
+ if do_classifier_free_guidance:
+ uncond_tokens = [""] * batch_size
+
+ max_length = text_input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ uncond_text_mask = uncond_input.attention_mask.bool().to(device)
+ negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))
+
+ negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds
+ uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+
+ seq_len = negative_prompt_embeds.shape[1]
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)
+
+ seq_len = uncond_text_encoder_hidden_states.shape[1]
+ uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
+ uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
+ batch_size * num_images_per_prompt, seq_len, -1
+ )
+ uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)
+
+ # done duplicates
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+ text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])
+
+ text_mask = torch.cat([uncond_text_mask, text_mask])
+
+ return prompt_embeds, text_encoder_hidden_states, text_mask
+
+ def _encode_image(self, image, device, num_images_per_prompt, image_embeddings: Optional[torch.Tensor] = None):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if image_embeddings is None:
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(images=image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ image_embeddings = self.image_encoder(image).image_embeds
+
+ image_embeddings = image_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
+
+ return image_embeddings
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ image: Optional[Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor]] = None,
+ num_images_per_prompt: int = 1,
+ decoder_num_inference_steps: int = 25,
+ super_res_num_inference_steps: int = 7,
+ generator: Optional[torch.Generator] = None,
+ decoder_latents: Optional[torch.FloatTensor] = None,
+ super_res_latents: Optional[torch.FloatTensor] = None,
+ image_embeddings: Optional[torch.Tensor] = None,
+ decoder_guidance_scale: float = 8.0,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ ):
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
+ The image or images to guide the image generation. If you provide a tensor, it needs to comply with the
+ configuration of
+ [this](https://huggingface.co/fusing/karlo-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json)
+ `CLIPImageProcessor`. Can be left to `None` only when `image_embeddings` are passed.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ decoder_num_inference_steps (`int`, *optional*, defaults to 25):
+ The number of denoising steps for the decoder. More denoising steps usually lead to a higher quality
+ image at the expense of slower inference.
+ super_res_num_inference_steps (`int`, *optional*, defaults to 7):
+ The number of denoising steps for super resolution. More denoising steps usually lead to a higher
+ quality image at the expense of slower inference.
+ generator (`torch.Generator`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ decoder_latents (`torch.FloatTensor` of shape (batch size, channels, height, width), *optional*):
+ Pre-generated noisy latents to be used as inputs for the decoder.
+ super_res_latents (`torch.FloatTensor` of shape (batch size, channels, super res height, super res width), *optional*):
+ Pre-generated noisy latents to be used as inputs for the super resolution step.
+ decoder_guidance_scale (`float`, *optional*, defaults to 8.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2 of the [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the
+ text `prompt`, usually at the expense of lower image quality.
+ image_embeddings (`torch.Tensor`, *optional*):
+ Pre-defined image embeddings that can be derived from the image encoder. Pre-defined image embeddings
+ can be passed for tasks like image interpolations. `image` can then be left as `None`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
+ """
+ if image is not None:
+ if isinstance(image, PIL.Image.Image):
+ batch_size = 1
+ elif isinstance(image, list):
+ batch_size = len(image)
+ else:
+ batch_size = image.shape[0]
+ else:
+ batch_size = image_embeddings.shape[0]
+
+ prompt = [""] * batch_size
+
+ device = self._execution_device
+
+ batch_size = batch_size * num_images_per_prompt
+
+ do_classifier_free_guidance = decoder_guidance_scale > 1.0
+
+ prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(
+ prompt, device, num_images_per_prompt, do_classifier_free_guidance
+ )
+
+ image_embeddings = self._encode_image(image, device, num_images_per_prompt, image_embeddings)
+
+ # decoder
+ text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj(
+ image_embeddings=image_embeddings,
+ prompt_embeds=prompt_embeds,
+ text_encoder_hidden_states=text_encoder_hidden_states,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ )
+
+ if device.type == "mps":
+ # HACK: MPS: There is a panic when padding bool tensors,
+ # so cast to int tensor for the pad and back to bool afterwards
+ text_mask = text_mask.type(torch.int)
+ decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1)
+ decoder_text_mask = decoder_text_mask.type(torch.bool)
+ else:
+ decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=True)
+
+ self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
+ decoder_timesteps_tensor = self.decoder_scheduler.timesteps
+
+ num_channels_latents = self.decoder.config.in_channels
+ height = self.decoder.config.sample_size
+ width = self.decoder.config.sample_size
+
+ if decoder_latents is None:
+ decoder_latents = self.prepare_latents(
+ (batch_size, num_channels_latents, height, width),
+ text_encoder_hidden_states.dtype,
+ device,
+ generator,
+ decoder_latents,
+ self.decoder_scheduler,
+ )
+
+ for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([decoder_latents] * 2) if do_classifier_free_guidance else decoder_latents
+
+ noise_pred = self.decoder(
+ sample=latent_model_input,
+ timestep=t,
+ encoder_hidden_states=text_encoder_hidden_states,
+ class_labels=additive_clip_time_embeddings,
+ attention_mask=decoder_text_mask,
+ ).sample
+
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred_uncond, _ = noise_pred_uncond.split(latent_model_input.shape[1], dim=1)
+ noise_pred_text, predicted_variance = noise_pred_text.split(latent_model_input.shape[1], dim=1)
+ noise_pred = noise_pred_uncond + decoder_guidance_scale * (noise_pred_text - noise_pred_uncond)
+ noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
+
+ if i + 1 == decoder_timesteps_tensor.shape[0]:
+ prev_timestep = None
+ else:
+ prev_timestep = decoder_timesteps_tensor[i + 1]
+
+ # compute the previous noisy sample x_t -> x_t-1
+ decoder_latents = self.decoder_scheduler.step(
+ noise_pred, t, decoder_latents, prev_timestep=prev_timestep, generator=generator
+ ).prev_sample
+
+ decoder_latents = decoder_latents.clamp(-1, 1)
+
+ image_small = decoder_latents
+
+ # done decoder
+
+ # super res
+
+ self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device)
+ super_res_timesteps_tensor = self.super_res_scheduler.timesteps
+
+ channels = self.super_res_first.config.in_channels // 2
+ height = self.super_res_first.config.sample_size
+ width = self.super_res_first.config.sample_size
+
+ if super_res_latents is None:
+ super_res_latents = self.prepare_latents(
+ (batch_size, channels, height, width),
+ image_small.dtype,
+ device,
+ generator,
+ super_res_latents,
+ self.super_res_scheduler,
+ )
+
+ if device.type == "mps":
+ # MPS does not support many interpolations
+ image_upscaled = F.interpolate(image_small, size=[height, width])
+ else:
+ interpolate_antialias = {}
+ if "antialias" in inspect.signature(F.interpolate).parameters:
+ interpolate_antialias["antialias"] = True
+
+ image_upscaled = F.interpolate(
+ image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias
+ )
+
+ for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)):
+ # no classifier free guidance
+
+ if i == super_res_timesteps_tensor.shape[0] - 1:
+ unet = self.super_res_last
+ else:
+ unet = self.super_res_first
+
+ latent_model_input = torch.cat([super_res_latents, image_upscaled], dim=1)
+
+ noise_pred = unet(
+ sample=latent_model_input,
+ timestep=t,
+ ).sample
+
+ if i + 1 == super_res_timesteps_tensor.shape[0]:
+ prev_timestep = None
+ else:
+ prev_timestep = super_res_timesteps_tensor[i + 1]
+
+ # compute the previous noisy sample x_t -> x_t-1
+ super_res_latents = self.super_res_scheduler.step(
+ noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator
+ ).prev_sample
+
+ image = super_res_latents
+
+ # done super res
+
+ # post processing
+
+ image = image * 0.5 + 0.5
+ image = image.clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
diff --git a/diffusers/pipelines/unclip/text_proj.py b/diffusers/pipelines/unclip/text_proj.py
new file mode 100644
index 0000000000000000000000000000000000000000..0414559500c16484dd326f72d04a5306dc14682e
--- /dev/null
+++ b/diffusers/pipelines/unclip/text_proj.py
@@ -0,0 +1,86 @@
+# Copyright 2023 Kakao Brain and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from torch import nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...models import ModelMixin
+
+
+class UnCLIPTextProjModel(ModelMixin, ConfigMixin):
+ """
+ Utility class for CLIP embeddings. Used to combine the image and text embeddings into a format usable by the
+ decoder.
+
+ For more details, see the original paper: https://arxiv.org/abs/2204.06125 section 2.1
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ *,
+ clip_extra_context_tokens: int = 4,
+ clip_embeddings_dim: int = 768,
+ time_embed_dim: int,
+ cross_attention_dim,
+ ):
+ super().__init__()
+
+ self.learned_classifier_free_guidance_embeddings = nn.Parameter(torch.zeros(clip_embeddings_dim))
+
+ # parameters for additional clip time embeddings
+ self.embedding_proj = nn.Linear(clip_embeddings_dim, time_embed_dim)
+ self.clip_image_embeddings_project_to_time_embeddings = nn.Linear(clip_embeddings_dim, time_embed_dim)
+
+ # parameters for encoder hidden states
+ self.clip_extra_context_tokens = clip_extra_context_tokens
+ self.clip_extra_context_tokens_proj = nn.Linear(
+ clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim
+ )
+ self.encoder_hidden_states_proj = nn.Linear(clip_embeddings_dim, cross_attention_dim)
+ self.text_encoder_hidden_states_norm = nn.LayerNorm(cross_attention_dim)
+
+ def forward(self, *, image_embeddings, prompt_embeds, text_encoder_hidden_states, do_classifier_free_guidance):
+ if do_classifier_free_guidance:
+ # Add the classifier free guidance embeddings to the image embeddings
+ image_embeddings_batch_size = image_embeddings.shape[0]
+ classifier_free_guidance_embeddings = self.learned_classifier_free_guidance_embeddings.unsqueeze(0)
+ classifier_free_guidance_embeddings = classifier_free_guidance_embeddings.expand(
+ image_embeddings_batch_size, -1
+ )
+ image_embeddings = torch.cat([classifier_free_guidance_embeddings, image_embeddings], dim=0)
+
+ # The image embeddings batch size and the text embeddings batch size are equal
+ assert image_embeddings.shape[0] == prompt_embeds.shape[0]
+
+ batch_size = prompt_embeds.shape[0]
+
+ # "Specifically, we modify the architecture described in Nichol et al. (2021) by projecting and
+ # adding CLIP embeddings to the existing timestep embedding, ...
+ time_projected_prompt_embeds = self.embedding_proj(prompt_embeds)
+ time_projected_image_embeddings = self.clip_image_embeddings_project_to_time_embeddings(image_embeddings)
+ additive_clip_time_embeddings = time_projected_image_embeddings + time_projected_prompt_embeds
+
+ # ... and by projecting CLIP embeddings into four
+ # extra tokens of context that are concatenated to the sequence of outputs from the GLIDE text encoder"
+ clip_extra_context_tokens = self.clip_extra_context_tokens_proj(image_embeddings)
+ clip_extra_context_tokens = clip_extra_context_tokens.reshape(batch_size, -1, self.clip_extra_context_tokens)
+ clip_extra_context_tokens = clip_extra_context_tokens.permute(0, 2, 1)
+
+ text_encoder_hidden_states = self.encoder_hidden_states_proj(text_encoder_hidden_states)
+ text_encoder_hidden_states = self.text_encoder_hidden_states_norm(text_encoder_hidden_states)
+ text_encoder_hidden_states = torch.cat([clip_extra_context_tokens, text_encoder_hidden_states], dim=1)
+
+ return text_encoder_hidden_states, additive_clip_time_embeddings
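+
+
+# A minimal shape sketch (illustrative only; the dimensions below are assumptions, not the values of any
+# particular checkpoint). With classifier-free guidance disabled, the model projects the CLIP embeddings
+# into additive time embeddings plus extra context tokens:
+#
+#     proj = UnCLIPTextProjModel(
+#         clip_extra_context_tokens=4,
+#         clip_embeddings_dim=768,
+#         time_embed_dim=1536,
+#         cross_attention_dim=2048,
+#     )
+#     hidden_states, time_embeds = proj(
+#         image_embeddings=torch.randn(2, 768),
+#         prompt_embeds=torch.randn(2, 768),
+#         text_encoder_hidden_states=torch.randn(2, 77, 768),
+#         do_classifier_free_guidance=False,
+#     )
+#     # hidden_states: (2, 4 + 77, 2048), time_embeds: (2, 1536)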
diff --git a/diffusers/pipelines/unidiffuser/__init__.py b/diffusers/pipelines/unidiffuser/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a774e3274030153d20618024b8c2bc6385ef367a
--- /dev/null
+++ b/diffusers/pipelines/unidiffuser/__init__.py
@@ -0,0 +1,20 @@
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ is_torch_available,
+ is_transformers_available,
+ is_transformers_version,
+)
+
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import (
+ ImageTextPipelineOutput,
+ UniDiffuserPipeline,
+ )
+else:
+ from .modeling_text_decoder import UniDiffuserTextDecoder
+ from .modeling_uvit import UniDiffuserModel, UTransformer2DModel
+ from .pipeline_unidiffuser import ImageTextPipelineOutput, UniDiffuserPipeline
diff --git a/diffusers/pipelines/unidiffuser/modeling_text_decoder.py b/diffusers/pipelines/unidiffuser/modeling_text_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b962f6e065621c8fc83775f555bbd732ccc8a26
--- /dev/null
+++ b/diffusers/pipelines/unidiffuser/modeling_text_decoder.py
@@ -0,0 +1,296 @@
+from typing import Optional
+
+import numpy as np
+import torch
+from torch import nn
+from transformers import GPT2Config, GPT2LMHeadModel
+from transformers.modeling_utils import ModuleUtilsMixin
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...models import ModelMixin
+
+
+# Modified from ClipCaptionModel in https://github.com/thu-ml/unidiffuser/blob/main/libs/caption_decoder.py
+class UniDiffuserTextDecoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
+ """
+ Text decoder model for an image-text [UniDiffuser](https://arxiv.org/pdf/2303.06555.pdf) model. This is used to
+ generate text from the UniDiffuser image-text embedding.
+
+ Parameters:
+ prefix_length (`int`):
+ Max number of prefix tokens that will be supplied to the model.
+ prefix_inner_dim (`int`):
+ The hidden size of the incoming prefix embeddings. For UniDiffuser, this would be the hidden dim of the
+ CLIP text encoder.
+ prefix_hidden_dim (`int`, *optional*):
+ Hidden dim of the MLP if we encode the prefix.
+ vocab_size (`int`, *optional*, defaults to 50257):
+ Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the
+ `input_ids` passed when calling [`GPT2Model`] or [`TFGPT2Model`].
+ n_positions (`int`, *optional*, defaults to 1024):
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+ n_embd (`int`, *optional*, defaults to 768):
+ Dimensionality of the embeddings and hidden states.
+ n_layer (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ n_head (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ n_inner (`int`, *optional*, defaults to None):
+ Dimensionality of the inner feed-forward layers. `None` will set it to 4 times `n_embd`.
+ activation_function (`str`, *optional*, defaults to `"gelu"`):
+ Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`.
+ resid_pdrop (`float`, *optional*, defaults to 0.1):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ embd_pdrop (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for the embeddings.
+ attn_pdrop (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for the attention.
+ layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
+ The epsilon to use in the layer normalization layers.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ scale_attn_weights (`bool`, *optional*, defaults to `True`):
+ Scale attention weights by dividing by sqrt(hidden_size).
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models).
+ scale_attn_by_inverse_layer_idx (`bool`, *optional*, defaults to `False`):
+ Whether to additionally scale attention weights by `1 / layer_idx + 1`.
+ reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`):
+ Whether to scale keys (K) prior to computing attention (dot-product) and upcast attention
+ dot-product/softmax to float() when training with mixed precision.
+ """
+
+ _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.bias", r"h\.\d+\.attn\.masked_bias"]
+
+ @register_to_config
+ def __init__(
+ self,
+ prefix_length: int,
+ prefix_inner_dim: int,
+ prefix_hidden_dim: Optional[int] = None,
+ vocab_size: int = 50257, # Start of GPT2 config args
+ n_positions: int = 1024,
+ n_embd: int = 768,
+ n_layer: int = 12,
+ n_head: int = 12,
+ n_inner: Optional[int] = None,
+ activation_function: str = "gelu_new",
+ resid_pdrop: float = 0.1,
+ embd_pdrop: float = 0.1,
+ attn_pdrop: float = 0.1,
+ layer_norm_epsilon: float = 1e-5,
+ initializer_range: float = 0.02,
+ scale_attn_weights: bool = True,
+ use_cache: bool = True,
+ scale_attn_by_inverse_layer_idx: bool = False,
+ reorder_and_upcast_attn: bool = False,
+ ):
+ super().__init__()
+
+ self.prefix_length = prefix_length
+
+ if prefix_inner_dim != n_embd and prefix_hidden_dim is None:
+ raise ValueError(
+ f"`prefix_hidden_dim` cannot be `None` when `prefix_inner_dim`: {prefix_hidden_dim} and"
+ f" `n_embd`: {n_embd} are not equal."
+ )
+
+ self.prefix_inner_dim = prefix_inner_dim
+ self.prefix_hidden_dim = prefix_hidden_dim
+
+ self.encode_prefix = (
+ nn.Linear(self.prefix_inner_dim, self.prefix_hidden_dim)
+ if self.prefix_hidden_dim is not None
+ else nn.Identity()
+ )
+ self.decode_prefix = (
+ nn.Linear(self.prefix_hidden_dim, n_embd) if self.prefix_hidden_dim is not None else nn.Identity()
+ )
+
+ gpt_config = GPT2Config(
+ vocab_size=vocab_size,
+ n_positions=n_positions,
+ n_embd=n_embd,
+ n_layer=n_layer,
+ n_head=n_head,
+ n_inner=n_inner,
+ activation_function=activation_function,
+ resid_pdrop=resid_pdrop,
+ embd_pdrop=embd_pdrop,
+ attn_pdrop=attn_pdrop,
+ layer_norm_epsilon=layer_norm_epsilon,
+ initializer_range=initializer_range,
+ scale_attn_weights=scale_attn_weights,
+ use_cache=use_cache,
+ scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx,
+ reorder_and_upcast_attn=reorder_and_upcast_attn,
+ )
+ self.transformer = GPT2LMHeadModel(gpt_config)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ prefix_embeds: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ ):
+ """
+ Args:
+ input_ids (`torch.Tensor` of shape `(N, max_seq_len)`):
+ Text tokens to use for inference.
+ prefix_embeds (`torch.Tensor` of shape `(N, prefix_length, 768)`):
+ Prefix embedding to prepend to the embedded tokens.
+ attention_mask (`torch.Tensor` of shape `(N, prefix_length + max_seq_len)`, *optional*):
+ Attention mask for the prefix embedding.
+ labels (`torch.Tensor`, *optional*):
+ Labels to use for language modeling.
+ """
+ embedding_text = self.transformer.transformer.wte(input_ids)
+ hidden = self.encode_prefix(prefix_embeds)
+ prefix_embeds = self.decode_prefix(hidden)
+ embedding_cat = torch.cat((prefix_embeds, embedding_text), dim=1)
+
+ if labels is not None:
+ dummy_token = self.get_dummy_token(input_ids.shape[0], input_ids.device)
+ labels = torch.cat((dummy_token, input_ids), dim=1)
+ out = self.transformer(inputs_embeds=embedding_cat, labels=labels, attention_mask=attention_mask)
+ if self.prefix_hidden_dim is not None:
+ return out, hidden
+ else:
+ return out
+
+ def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor:
+ return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)
+
+ def encode(self, prefix):
+ return self.encode_prefix(prefix)
+
+ @torch.no_grad()
+ def generate_captions(self, features, eos_token_id, device):
+ """
+ Generate captions given text embedding features.
+
+ Args:
+ features (`torch.Tensor` of shape `(B, L, D)`):
+ Text embedding features to generate captions from.
+ eos_token_id (`int`):
+ The token ID of the EOS token for the text decoder model.
+ device:
+ Device to perform text generation on.
+
+ Returns:
+ `Tuple(torch.Tensor, torch.Tensor)`: A tuple of the generated caption token sequences and the
+ corresponding sequence lengths.
+ """
+
+ features = torch.split(features, 1, dim=0)
+ generated_tokens = []
+ generated_seq_lengths = []
+ for feature in features:
+ feature = self.decode_prefix(feature.to(device)) # back to the clip feature
+ # Only support beam search for now
+ output_tokens, seq_lengths = self.generate_beam(
+ input_embeds=feature, device=device, eos_token_id=eos_token_id
+ )
+ generated_tokens.append(output_tokens[0])
+ generated_seq_lengths.append(seq_lengths[0])
+ generated_tokens = torch.stack(generated_tokens)
+ generated_seq_lengths = torch.stack(generated_seq_lengths)
+ return generated_tokens, generated_seq_lengths
+
+ @torch.no_grad()
+ def generate_beam(
+ self,
+ input_ids=None,
+ input_embeds=None,
+ device=None,
+ beam_size: int = 5,
+ entry_length: int = 67,
+ temperature: float = 1.0,
+ eos_token_id: Optional[int] = None,
+ ):
+ """
+ Generates text from the given token ids (`input_ids`) or prefix embedding (`input_embeds`) via beam search. This
+ implementation is based on the beam search implementation from the [original UniDiffuser
+ code](https://github.com/thu-ml/unidiffuser/blob/main/libs/caption_decoder.py#L89).
+
+ Args:
+ eos_token_id (`int`, *optional*):
+ The token ID of the EOS token for the text decoder model.
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
+ Tokenizer indices of input sequence tokens in the vocabulary. One of `input_ids` and `input_embeds`
+ must be supplied.
+ input_embeds (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*):
+ An embedded representation to directly pass to the transformer as a prefix for beam search. One of
+ `input_ids` and `input_embeds` must be supplied.
+ device:
+ The device to perform beam search on.
+ beam_size (`int`, *optional*, defaults to `5`):
+ The number of best states to store during beam search.
+ entry_length (`int`, *optional*, defaults to `67`):
+ The number of iterations to run beam search.
+ temperature (`float`, *optional*, defaults to 1.0):
+ The temperature to use when performing the softmax over logits from the decoding model.
+
+ Returns:
+ `Tuple(torch.Tensor, torch.Tensor)`: A tuple of tensors where the first element is a tensor of generated
+ token sequences sorted by score in descending order, and the second element is the sequence lengths
+ corresponding to those sequences.
+ """
+ # Generates text until stop_token is reached using beam search with the desired beam size.
+ stop_token_index = eos_token_id
+ tokens = None
+ scores = None
+ seq_lengths = torch.ones(beam_size, device=device, dtype=torch.int)
+ is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)
+
+ if input_embeds is not None:
+ generated = input_embeds
+ else:
+ generated = self.transformer.transformer.wte(input_ids)
+
+ for i in range(entry_length):
+ outputs = self.transformer(inputs_embeds=generated)
+ logits = outputs.logits
+ logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
+ logits = logits.softmax(-1).log()
+
+ if scores is None:
+ scores, next_tokens = logits.topk(beam_size, -1)
+ generated = generated.expand(beam_size, *generated.shape[1:])
+ next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
+ if tokens is None:
+ tokens = next_tokens
+ else:
+ tokens = tokens.expand(beam_size, *tokens.shape[1:])
+ tokens = torch.cat((tokens, next_tokens), dim=1)
+ else:
+ logits[is_stopped] = -float(np.inf)
+ logits[is_stopped, 0] = 0
+ scores_sum = scores[:, None] + logits
+ seq_lengths[~is_stopped] += 1
+ scores_sum_average = scores_sum / seq_lengths[:, None]
+ scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_size, -1)
+ next_tokens_source = next_tokens // scores_sum.shape[1]
+ seq_lengths = seq_lengths[next_tokens_source]
+ next_tokens = next_tokens % scores_sum.shape[1]
+ next_tokens = next_tokens.unsqueeze(1)
+ tokens = tokens[next_tokens_source]
+ tokens = torch.cat((tokens, next_tokens), dim=1)
+ generated = generated[next_tokens_source]
+ scores = scores_sum_average * seq_lengths
+ is_stopped = is_stopped[next_tokens_source]
+
+ next_token_embed = self.transformer.transformer.wte(next_tokens.squeeze()).view(generated.shape[0], 1, -1)
+ generated = torch.cat((generated, next_token_embed), dim=1)
+ is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze()
+ if is_stopped.all():
+ break
+
+ scores = scores / seq_lengths
+ order = scores.argsort(descending=True)
+ # tokens tensors are already padded to max_seq_length
+ output_texts = [tokens[i] for i in order]
+ output_texts = torch.stack(output_texts, dim=0)
+ seq_lengths = torch.tensor([seq_lengths[i] for i in order], dtype=seq_lengths.dtype)
+ return output_texts, seq_lengths
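+
+
+# A minimal usage sketch (the hyper-parameters below are small, arbitrary values for illustration and do not
+# correspond to a released UniDiffuser checkpoint). The decoder maps a prefix embedding back to the GPT-2
+# embedding space and generates caption tokens with beam search; 50256 is GPT-2's `<|endoftext|>` id:
+#
+#     decoder = UniDiffuserTextDecoder(
+#         prefix_length=77,
+#         prefix_inner_dim=768,
+#         prefix_hidden_dim=64,
+#         n_embd=768,
+#         n_layer=2,
+#         n_head=2,
+#     )
+#     # features are assumed to already live in the `prefix_hidden_dim` space
+#     features = torch.randn(1, 77, 64)
+#     tokens, lengths = decoder.generate_captions(features, eos_token_id=50256, device="cpu")
+#     # tokens: (1, generated_length), lengths: (1,)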
diff --git a/diffusers/pipelines/unidiffuser/modeling_uvit.py b/diffusers/pipelines/unidiffuser/modeling_uvit.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7829f76ec12f946490618e0d03857777efdf219
--- /dev/null
+++ b/diffusers/pipelines/unidiffuser/modeling_uvit.py
@@ -0,0 +1,1196 @@
+import math
+from typing import Optional, Union
+
+import torch
+from torch import nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...models import ModelMixin
+from ...models.attention import AdaLayerNorm, FeedForward
+from ...models.attention_processor import Attention
+from ...models.embeddings import TimestepEmbedding, Timesteps, get_2d_sincos_pos_embed
+from ...models.transformer_2d import Transformer2DModelOutput
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+ def norm_cdf(x):
+ # Computes standard normal cumulative distribution function
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
+
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
+ logger.warning(
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+ "The distribution of values may be incorrect."
+ )
+
+ with torch.no_grad():
+ # Values are generated by using a truncated uniform distribution and
+ # then using the inverse CDF for the normal distribution.
+ # Get upper and lower cdf values
+ l = norm_cdf((a - mean) / std)
+ u = norm_cdf((b - mean) / std)
+
+ # Uniformly fill tensor with values from [l, u], then translate to
+ # [2l-1, 2u-1].
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
+
+ # Use inverse cdf transform for normal distribution to get truncated
+ # standard normal
+ tensor.erfinv_()
+
+ # Transform to proper mean, std
+ tensor.mul_(std * math.sqrt(2.0))
+ tensor.add_(mean)
+
+ # Clamp to ensure it's in the proper range
+ tensor.clamp_(min=a, max=b)
+ return tensor
+
+
+def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
+ # type: (torch.Tensor, float, float, float, float) -> torch.Tensor
+ r"""Fills the input Tensor with values drawn from a truncated
+ normal distribution. The values are effectively drawn from the normal distribution :math:`\mathcal{N}(\text{mean},
+ \text{std}^2)` with values outside :math:`[a, b]` redrawn until they are within the bounds. The method used for
+ generating the random values works best when :math:`a \leq \text{mean} \leq b`.
+
+ Args:
+ tensor: an n-dimensional `torch.Tensor`
+ mean: the mean of the normal distribution
+ std: the standard deviation of the normal distribution
+ a: the minimum cutoff value
+ b: the maximum cutoff value
+ Examples:
+ >>> w = torch.empty(3, 5)
+ >>> nn.init.trunc_normal_(w)
+ """
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
+
+
+class PatchEmbed(nn.Module):
+ """2D Image to Patch Embedding"""
+
+ def __init__(
+ self,
+ height=224,
+ width=224,
+ patch_size=16,
+ in_channels=3,
+ embed_dim=768,
+ layer_norm=False,
+ flatten=True,
+ bias=True,
+ use_pos_embed=True,
+ ):
+ super().__init__()
+
+ num_patches = (height // patch_size) * (width // patch_size)
+ self.flatten = flatten
+ self.layer_norm = layer_norm
+
+ self.proj = nn.Conv2d(
+ in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
+ )
+ if layer_norm:
+ self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
+ else:
+ self.norm = None
+
+ self.use_pos_embed = use_pos_embed
+ if self.use_pos_embed:
+ pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5))
+ self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
+
+ def forward(self, latent):
+ latent = self.proj(latent)
+ if self.flatten:
+ latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
+ if self.layer_norm:
+ latent = self.norm(latent)
+ if self.use_pos_embed:
+ return latent + self.pos_embed
+ else:
+ return latent
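+
+
+# A small shape sketch (the values are illustrative): a 64x64 latent with 4 channels and a patch size of 2
+# yields (64 / 2) ** 2 = 1024 tokens of width `embed_dim`:
+#
+#     patch_embed = PatchEmbed(height=64, width=64, patch_size=2, in_channels=4, embed_dim=1536)
+#     tokens = patch_embed(torch.randn(1, 4, 64, 64))
+#     # tokens: (1, 1024, 1536)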
+
+
+class SkipBlock(nn.Module):
+ def __init__(self, dim: int):
+ super().__init__()
+
+ self.skip_linear = nn.Linear(2 * dim, dim)
+
+ # Use torch.nn.LayerNorm for now, following the original code
+ self.norm = nn.LayerNorm(dim)
+
+ def forward(self, x, skip):
+ x = self.skip_linear(torch.cat([x, skip], dim=-1))
+ x = self.norm(x)
+
+ return x
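+
+
+# A minimal sketch (dimensions are arbitrary): the skip branch and the main branch are concatenated along the
+# channel dimension, projected back down to `dim`, and normalized:
+#
+#     skip_block = SkipBlock(dim=1536)
+#     out = skip_block(torch.randn(1, 1024, 1536), torch.randn(1, 1024, 1536))
+#     # out: (1, 1024, 1536)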
+
+
+# Modified to support both pre-LayerNorm and post-LayerNorm configurations
+# Don't support AdaLayerNormZero for now
+# Modified from diffusers.models.attention.BasicTransformerBlock
+class UTransformerBlock(nn.Module):
+ r"""
+ A modification of BasicTransformerBlock which supports pre-LayerNorm and post-LayerNorm configurations.
+
+ Parameters:
+ dim (`int`): The number of channels in the input and output.
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`): The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`):
+ Activation function to be used in feed-forward.
+ num_embeds_ada_norm (:obj: `int`, *optional*):
+ The number of diffusion steps used during training. See `Transformer2DModel`.
+ attention_bias (:obj: `bool`, *optional*, defaults to `False`):
+ Configure if the attentions should contain a bias parameter.
+ only_cross_attention (`bool`, *optional*):
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
+ double_self_attention (`bool`, *optional*):
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
+ upcast_attention (`bool`, *optional*):
+ Whether to upcast the query and key to float32 when performing the attention calculation.
+ norm_elementwise_affine (`bool`, *optional*):
+ Whether to use learnable per-element affine parameters during layer normalization.
+ norm_type (`str`, defaults to `"layer_norm"`):
+ The layer norm implementation to use.
+ pre_layer_norm (`bool`, *optional*):
+ Whether to perform layer normalization before the attention and feedforward operations ("pre-LayerNorm"),
+ as opposed to after ("post-LayerNorm"). Note that `BasicTransformerBlock` uses pre-LayerNorm, i.e.
+ `pre_layer_norm = True`.
+ final_dropout (`bool`, *optional*):
+ Whether to use a final Dropout layer after the feedforward network.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ dropout=0.0,
+ cross_attention_dim: Optional[int] = None,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ attention_bias: bool = False,
+ only_cross_attention: bool = False,
+ double_self_attention: bool = False,
+ upcast_attention: bool = False,
+ norm_elementwise_affine: bool = True,
+ norm_type: str = "layer_norm",
+ pre_layer_norm: bool = True,
+ final_dropout: bool = False,
+ ):
+ super().__init__()
+ self.only_cross_attention = only_cross_attention
+
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
+
+ self.pre_layer_norm = pre_layer_norm
+
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
+ raise ValueError(
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
+ )
+
+ # 1. Self-Attn
+ self.attn1 = Attention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+ upcast_attention=upcast_attention,
+ )
+
+ # 2. Cross-Attn
+ if cross_attention_dim is not None or double_self_attention:
+ self.attn2 = Attention(
+ query_dim=dim,
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ upcast_attention=upcast_attention,
+ ) # is self-attn if encoder_hidden_states is none
+ else:
+ self.attn2 = None
+
+ if self.use_ada_layer_norm:
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
+ else:
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
+
+ if cross_attention_dim is not None or double_self_attention:
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
+ # the second cross attention block.
+ self.norm2 = (
+ AdaLayerNorm(dim, num_embeds_ada_norm)
+ if self.use_ada_layer_norm
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
+ )
+ else:
+ self.norm2 = None
+
+ # 3. Feed-forward
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ timestep=None,
+ cross_attention_kwargs=None,
+ class_labels=None,
+ ):
+ # Pre-LayerNorm
+ if self.pre_layer_norm:
+ if self.use_ada_layer_norm:
+ norm_hidden_states = self.norm1(hidden_states, timestep)
+ else:
+ norm_hidden_states = self.norm1(hidden_states)
+ else:
+ norm_hidden_states = hidden_states
+
+ # 1. Self-Attention
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+
+ # Post-LayerNorm
+ if not self.pre_layer_norm:
+ if self.use_ada_layer_norm:
+ attn_output = self.norm1(attn_output, timestep)
+ else:
+ attn_output = self.norm1(attn_output)
+
+ hidden_states = attn_output + hidden_states
+
+ if self.attn2 is not None:
+ # Pre-LayerNorm
+ if self.pre_layer_norm:
+ norm_hidden_states = (
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
+ )
+ else:
+ norm_hidden_states = hidden_states
+ # TODO (Birch-San): Here we should prepare the encoder_attention mask correctly
+ # prepare attention mask here
+
+ # 2. Cross-Attention
+ attn_output = self.attn2(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ **cross_attention_kwargs,
+ )
+
+ # Post-LayerNorm
+ if not self.pre_layer_norm:
+ attn_output = self.norm2(attn_output, timestep) if self.use_ada_layer_norm else self.norm2(attn_output)
+
+ hidden_states = attn_output + hidden_states
+
+ # 3. Feed-forward
+ # Pre-LayerNorm
+ if self.pre_layer_norm:
+ norm_hidden_states = self.norm3(hidden_states)
+ else:
+ norm_hidden_states = hidden_states
+
+ ff_output = self.ff(norm_hidden_states)
+
+ # Post-LayerNorm
+ if not self.pre_layer_norm:
+ ff_output = self.norm3(ff_output)
+
+ hidden_states = ff_output + hidden_states
+
+ return hidden_states
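+
+
+# A minimal sketch (dimensions are arbitrary) contrasting the two normalization layouts of this block: with
+# `pre_layer_norm=True` the inputs to attention/feed-forward are normalized (as in `BasicTransformerBlock`),
+# while with `pre_layer_norm=False` the branch outputs are normalized before being added back to the residual:
+#
+#     block = UTransformerBlock(dim=1536, num_attention_heads=24, attention_head_dim=64, pre_layer_norm=False)
+#     out = block(torch.randn(1, 1024, 1536))
+#     # out: (1, 1024, 1536)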
+
+
+# Like UTransformerBlock except with LayerNorms on the residual backbone of the block
+# Modified from diffusers.models.attention.BasicTransformerBlock
+class UniDiffuserBlock(nn.Module):
+ r"""
+ A modification of BasicTransformerBlock which supports pre-LayerNorm and post-LayerNorm configurations and puts the
+ LayerNorms on the residual backbone of the block. This matches the transformer block in the [original UniDiffuser
+ implementation](https://github.com/thu-ml/unidiffuser/blob/main/libs/uvit_multi_post_ln_v1.py#L104).
+
+ Parameters:
+ dim (`int`): The number of channels in the input and output.
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`): The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`):
+ Activation function to be used in feed-forward.
+ num_embeds_ada_norm (:obj: `int`, *optional*):
+ The number of diffusion steps used during training. See `Transformer2DModel`.
+ attention_bias (:obj: `bool`, *optional*, defaults to `False`):
+ Configure if the attentions should contain a bias parameter.
+ only_cross_attention (`bool`, *optional*):
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
+ double_self_attention (`bool`, *optional*):
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
+ upcast_attention (`bool`, *optional*):
+ Whether to upcast the query and key to float32 when performing the attention calculation.
+ norm_elementwise_affine (`bool`, *optional*):
+ Whether to use learnable per-element affine parameters during layer normalization.
+ norm_type (`str`, defaults to `"layer_norm"`):
+ The layer norm implementation to use.
+ pre_layer_norm (`bool`, *optional*):
+ Whether to perform layer normalization before the attention and feedforward operations ("pre-LayerNorm"),
+ as opposed to after ("post-LayerNorm"). The original UniDiffuser implementation is post-LayerNorm
+ (`pre_layer_norm = False`).
+ final_dropout (`bool`, *optional*):
+ Whether to use a final Dropout layer after the feedforward network.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ dropout=0.0,
+ cross_attention_dim: Optional[int] = None,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ attention_bias: bool = False,
+ only_cross_attention: bool = False,
+ double_self_attention: bool = False,
+ upcast_attention: bool = False,
+ norm_elementwise_affine: bool = True,
+ norm_type: str = "layer_norm",
+ pre_layer_norm: bool = False,
+ final_dropout: bool = True,
+ ):
+ super().__init__()
+ self.only_cross_attention = only_cross_attention
+
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
+
+ self.pre_layer_norm = pre_layer_norm
+
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
+ raise ValueError(
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
+ )
+
+ # 1. Self-Attn
+ self.attn1 = Attention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+ upcast_attention=upcast_attention,
+ )
+
+ # 2. Cross-Attn
+ if cross_attention_dim is not None or double_self_attention:
+ self.attn2 = Attention(
+ query_dim=dim,
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ upcast_attention=upcast_attention,
+ ) # is self-attn if encoder_hidden_states is none
+ else:
+ self.attn2 = None
+
+ if self.use_ada_layer_norm:
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
+ else:
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
+
+ if cross_attention_dim is not None or double_self_attention:
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
+ # the second cross attention block.
+ self.norm2 = (
+ AdaLayerNorm(dim, num_embeds_ada_norm)
+ if self.use_ada_layer_norm
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
+ )
+ else:
+ self.norm2 = None
+
+ # 3. Feed-forward
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ timestep=None,
+ cross_attention_kwargs=None,
+ class_labels=None,
+ ):
+ # Following the diffusers transformer block implementation, put the LayerNorm on the
+ # residual backbone
+ # Pre-LayerNorm
+ if self.pre_layer_norm:
+ if self.use_ada_layer_norm:
+ hidden_states = self.norm1(hidden_states, timestep)
+ else:
+ hidden_states = self.norm1(hidden_states)
+
+ # 1. Self-Attention
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+ attn_output = self.attn1(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+
+ hidden_states = attn_output + hidden_states
+
+ # Following the diffusers transformer block implementation, put the LayerNorm on the
+ # residual backbone
+ # Post-LayerNorm
+ if not self.pre_layer_norm:
+ if self.use_ada_layer_norm:
+ hidden_states = self.norm1(hidden_states, timestep)
+ else:
+ hidden_states = self.norm1(hidden_states)
+
+ if self.attn2 is not None:
+ # Pre-LayerNorm
+ if self.pre_layer_norm:
+ hidden_states = (
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
+ )
+ # TODO (Birch-San): Here we should prepare the encoder_attention mask correctly
+ # prepare attention mask here
+
+ # 2. Cross-Attention
+ attn_output = self.attn2(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ **cross_attention_kwargs,
+ )
+
+ hidden_states = attn_output + hidden_states
+
+ # Post-LayerNorm
+ if not self.pre_layer_norm:
+ hidden_states = (
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
+ )
+
+ # 3. Feed-forward
+ # Pre-LayerNorm
+ if self.pre_layer_norm:
+ hidden_states = self.norm3(hidden_states)
+
+ ff_output = self.ff(hidden_states)
+
+ hidden_states = ff_output + hidden_states
+
+ # Post-LayerNorm
+ if not self.pre_layer_norm:
+ hidden_states = self.norm3(hidden_states)
+
+ return hidden_states
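+
+
+# For `pre_layer_norm=False` (the UniDiffuser default), the data flow of this block is, schematically
+# (a sketch of the code above, not an alternative implementation):
+#
+#     x = x + attn1(x);           x = norm1(x)
+#     x = x + attn2(x, context);  x = norm2(x)   # only if cross-attention is configured
+#     x = x + ff(x);              x = norm3(x)
+#
+# i.e. the LayerNorms act on the residual backbone itself, whereas `UTransformerBlock` normalizes only the
+# attention/feed-forward branch outputs.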
+
+
+# Modified from diffusers.models.transformer_2d.Transformer2DModel
+# Modify the transformer block structure to be U-Net like following U-ViT
+# Only supports patch-style input and torch.nn.LayerNorm currently
+# https://github.com/baofff/U-ViT
+class UTransformer2DModel(ModelMixin, ConfigMixin):
+ """
+ Transformer model based on the [U-ViT](https://github.com/baofff/U-ViT) architecture for image-like data. Compared
+ to [`Transformer2DModel`], this model has skip connections between transformer blocks in a "U"-shaped fashion,
+ similar to a U-Net. Supports only continuous (actual embeddings) inputs, which are embedded via a [`PatchEmbed`]
+ layer and then reshaped to (b, t, d).
+
+ Parameters:
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
+ in_channels (`int`, *optional*):
+ Pass if the input is continuous. The number of channels in the input.
+ out_channels (`int`, *optional*):
+ The number of output channels; if `None`, defaults to `in_channels`.
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ norm_num_groups (`int`, *optional*, defaults to `32`):
+ The number of groups to use when performing Group Normalization.
+ cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
+ attention_bias (`bool`, *optional*):
+ Configure if the TransformerBlocks' attention should contain a bias parameter.
+ sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
+ Note that this is fixed at training time as it is used for learning a number of position embeddings. See
+ `ImagePositionalEmbeddings`.
+ num_vector_embeds (`int`, *optional*):
+ Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
+ Includes the class for the masked latent pixel.
+ patch_size (`int`, *optional*, defaults to 2):
+ The patch size to use in the patch embedding.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+ num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
+ The number of diffusion steps used during training. Note that this is fixed at training time as it is used
+ to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
+ up to but not more steps than `num_embeds_ada_norm`.
+ use_linear_projection (int, *optional*): TODO: Not used
+ only_cross_attention (`bool`, *optional*):
+ Whether to use only cross-attention layers. In this case two cross attention layers are used in each
+ transformer block.
+ upcast_attention (`bool`, *optional*):
+ Whether to upcast the query and key to float32 when performing the attention calculation.
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
+ The Layer Normalization implementation to use. Defaults to `torch.nn.LayerNorm`.
+ block_type (`str`, *optional*, defaults to `"unidiffuser"`):
+ The transformer block implementation to use. If `"unidiffuser"`, has the LayerNorms on the residual
+ backbone of each transformer block; otherwise has them in the attention/feedforward branches (the standard
+ behavior in `diffusers`.)
+ pre_layer_norm (`bool`, *optional*):
+ Whether to perform layer normalization before the attention and feedforward operations ("pre-LayerNorm"),
+ as opposed to after ("post-LayerNorm"). The original UniDiffuser implementation is post-LayerNorm
+ (`pre_layer_norm = False`).
+ norm_elementwise_affine (`bool`, *optional*):
+ Whether to use learnable per-element affine parameters during layer normalization.
+ use_patch_pos_embed (`bool`, *optional*):
+ Whether to use position embeddings inside the patch embedding layer (`PatchEmbed`).
+ final_dropout (`bool`, *optional*):
+ Whether to use a final Dropout layer after the feedforward network.
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ num_attention_heads: int = 16,
+ attention_head_dim: int = 88,
+ in_channels: Optional[int] = None,
+ out_channels: Optional[int] = None,
+ num_layers: int = 1,
+ dropout: float = 0.0,
+ norm_num_groups: int = 32,
+ cross_attention_dim: Optional[int] = None,
+ attention_bias: bool = False,
+ sample_size: Optional[int] = None,
+ num_vector_embeds: Optional[int] = None,
+ patch_size: Optional[int] = 2,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ use_linear_projection: bool = False,
+ only_cross_attention: bool = False,
+ upcast_attention: bool = False,
+ norm_type: str = "layer_norm",
+ block_type: str = "unidiffuser",
+ pre_layer_norm: bool = False,
+ norm_elementwise_affine: bool = True,
+ use_patch_pos_embed=False,
+ ff_final_dropout: bool = False,
+ ):
+ super().__init__()
+ self.use_linear_projection = use_linear_projection
+ self.num_attention_heads = num_attention_heads
+ self.attention_head_dim = attention_head_dim
+ inner_dim = num_attention_heads * attention_head_dim
+
+ # 1. Input
+ # Only support patch input of shape (batch_size, num_channels, height, width) for now
+ assert in_channels is not None and patch_size is not None, "Patch input requires in_channels and patch_size."
+
+ assert sample_size is not None, "UTransformer2DModel over patched input must provide sample_size"
+
+ # 2. Define input layers
+ self.height = sample_size
+ self.width = sample_size
+
+ self.patch_size = patch_size
+ self.pos_embed = PatchEmbed(
+ height=sample_size,
+ width=sample_size,
+ patch_size=patch_size,
+ in_channels=in_channels,
+ embed_dim=inner_dim,
+ use_pos_embed=use_patch_pos_embed,
+ )
+
+ # 3. Define transformers blocks
+ # Modify this to have in_blocks ("downsample" blocks, even though we don't actually downsample), a mid_block,
+ # and out_blocks ("upsample" blocks). Like a U-Net, there are skip connections from in_blocks to out_blocks in
+ # a "U"-shaped fashion (e.g. first in_block to last out_block, etc.).
+ # Quick hack to make the transformer block type configurable
+ if block_type == "unidiffuser":
+ block_cls = UniDiffuserBlock
+ else:
+ block_cls = UTransformerBlock
+ self.transformer_in_blocks = nn.ModuleList(
+ [
+ block_cls(
+ inner_dim,
+ num_attention_heads,
+ attention_head_dim,
+ dropout=dropout,
+ cross_attention_dim=cross_attention_dim,
+ activation_fn=activation_fn,
+ num_embeds_ada_norm=num_embeds_ada_norm,
+ attention_bias=attention_bias,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ norm_type=norm_type,
+ pre_layer_norm=pre_layer_norm,
+ norm_elementwise_affine=norm_elementwise_affine,
+ final_dropout=ff_final_dropout,
+ )
+ for d in range(num_layers // 2)
+ ]
+ )
+
+ self.transformer_mid_block = block_cls(
+ inner_dim,
+ num_attention_heads,
+ attention_head_dim,
+ dropout=dropout,
+ cross_attention_dim=cross_attention_dim,
+ activation_fn=activation_fn,
+ num_embeds_ada_norm=num_embeds_ada_norm,
+ attention_bias=attention_bias,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ norm_type=norm_type,
+ pre_layer_norm=pre_layer_norm,
+ norm_elementwise_affine=norm_elementwise_affine,
+ final_dropout=ff_final_dropout,
+ )
+
+ # For each skip connection, we use a SkipBlock (concatenation + Linear + LayerNorm) to process the inputs
+ # before each transformer out_block.
+ self.transformer_out_blocks = nn.ModuleList(
+ [
+ nn.ModuleDict(
+ {
+ "skip": SkipBlock(
+ inner_dim,
+ ),
+ "block": block_cls(
+ inner_dim,
+ num_attention_heads,
+ attention_head_dim,
+ dropout=dropout,
+ cross_attention_dim=cross_attention_dim,
+ activation_fn=activation_fn,
+ num_embeds_ada_norm=num_embeds_ada_norm,
+ attention_bias=attention_bias,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ norm_type=norm_type,
+ pre_layer_norm=pre_layer_norm,
+ norm_elementwise_affine=norm_elementwise_affine,
+ final_dropout=ff_final_dropout,
+ ),
+ }
+ )
+ for d in range(num_layers // 2)
+ ]
+ )
+
+ # 4. Define output layers
+ self.out_channels = in_channels if out_channels is None else out_channels
+
+ # Following the UniDiffuser U-ViT implementation, we process the transformer output with
+ # a LayerNorm layer with per-element affine params
+ self.norm_out = nn.LayerNorm(inner_dim)
+
+ def forward(
+ self,
+ hidden_states,
+ encoder_hidden_states=None,
+ timestep=None,
+ class_labels=None,
+ cross_attention_kwargs=None,
+ return_dict: bool = True,
+ hidden_states_is_embedding: bool = False,
+ unpatchify: bool = True,
+ ):
+ """
+ Args:
+ hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
+ When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
+ hidden_states
+ encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
+ self-attention.
+ timestep ( `torch.long`, *optional*):
+ Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
+ Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels
+ conditioning.
+ cross_attention_kwargs (*optional*):
+ Keyword arguments to supply to the cross attention layers, if used.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
+ hidden_states_is_embedding (`bool`, *optional*, defaults to `False`):
+ Whether or not hidden_states is an embedding directly usable by the transformer. In this case we will
+ ignore input handling (e.g. continuous, vectorized, etc.) and directly feed hidden_states into the
+ transformer blocks.
+ unpatchify (`bool`, *optional*, defaults to `True`):
+ Whether to unpatchify the transformer output.
+
+ Returns:
+ [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
+ [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
+ """
+ # 0. Check inputs
+
+ if not unpatchify and return_dict:
+ raise ValueError(
+ f"Cannot both define `unpatchify`: {unpatchify} and `return_dict`: {return_dict} since when"
+ f" `unpatchify` is {unpatchify} the returned output is of shape (batch_size, seq_len, hidden_dim)"
+ " rather than (batch_size, num_channels, height, width)."
+ )
+
+ # 1. Input
+ if not hidden_states_is_embedding:
+ hidden_states = self.pos_embed(hidden_states)
+
+ # 2. Blocks
+
+ # In ("downsample") blocks
+ skips = []
+ for in_block in self.transformer_in_blocks:
+ hidden_states = in_block(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ timestep=timestep,
+ cross_attention_kwargs=cross_attention_kwargs,
+ class_labels=class_labels,
+ )
+ skips.append(hidden_states)
+
+ # Mid block
+ hidden_states = self.transformer_mid_block(hidden_states)
+
+ # Out ("upsample") blocks
+ for out_block in self.transformer_out_blocks:
+ hidden_states = out_block["skip"](hidden_states, skips.pop())
+ hidden_states = out_block["block"](
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ timestep=timestep,
+ cross_attention_kwargs=cross_attention_kwargs,
+ class_labels=class_labels,
+ )
+
+ # 3. Output
+ # Don't support AdaLayerNorm for now, so no conditioning/scale/shift logic
+ hidden_states = self.norm_out(hidden_states)
+ # hidden_states = self.proj_out(hidden_states)
+
+ if unpatchify:
+ # unpatchify
+ height = width = int(hidden_states.shape[1] ** 0.5)
+ hidden_states = hidden_states.reshape(
+ shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
+ )
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
+ output = hidden_states.reshape(
+ shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
+ )
+ else:
+ output = hidden_states
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer2DModelOutput(sample=output)
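+
+
+# A minimal forward-pass sketch (a small, arbitrary configuration for illustration; real UniDiffuser
+# checkpoints use much larger settings). With `unpatchify=False` the model returns one token per patch,
+# still at the transformer width, which is how `UniDiffuserModel` below consumes it before applying its
+# own output heads:
+#
+#     model = UTransformer2DModel(
+#         num_attention_heads=8,
+#         attention_head_dim=32,
+#         in_channels=4,
+#         num_layers=3,
+#         patch_size=2,
+#         sample_size=16,
+#     )
+#     (hidden,) = model(torch.randn(1, 4, 16, 16), unpatchify=False, return_dict=False)
+#     # hidden: (1, (16 // 2) ** 2, 8 * 32) == (1, 64, 256)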
+
+
+class UniDiffuserModel(ModelMixin, ConfigMixin):
+ """
+ Transformer model for a image-text [UniDiffuser](https://arxiv.org/pdf/2303.06555.pdf) model. This is a
+ modification of [`UTransformer2DModel`] with input and output heads for the VAE-embedded latent image, the
+ CLIP-embedded image, and the CLIP-embedded prompt (see paper for more details).
+
+ Parameters:
+ text_dim (`int`): The hidden dimension of the CLIP text model used to embed prompts.
+ clip_img_dim (`int`): The hidden dimension of the CLIP vision model used to embed images.
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
+ in_channels (`int`, *optional*):
+ Pass if the input is continuous. The number of channels in the input.
+ out_channels (`int`, *optional*):
+ The number of output channels; if `None`, defaults to `in_channels`.
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ norm_num_groups (`int`, *optional*, defaults to `32`):
+ The number of groups to use when performing Group Normalization.
+ cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
+ attention_bias (`bool`, *optional*):
+ Configure if the TransformerBlocks' attention should contain a bias parameter.
+ sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
+ Note that this is fixed at training time as it is used for learning a number of position embeddings. See
+ `ImagePositionalEmbeddings`.
+ num_vector_embeds (`int`, *optional*):
+ Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
+ Includes the class for the masked latent pixel.
+ patch_size (`int`, *optional*, defaults to 2):
+ The patch size to use in the patch embedding.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+ num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
+ The number of diffusion steps used during training. Note that this is fixed at training time as it is used
+ to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
+ up to but not more steps than `num_embeds_ada_norm`.
+ use_linear_projection (int, *optional*): TODO: Not used
+ only_cross_attention (`bool`, *optional*):
+ Whether to use only cross-attention layers. In this case two cross attention layers are used in each
+ transformer block.
+ upcast_attention (`bool`, *optional*):
+ Whether to upcast the query and key to float32 when performing the attention calculation.
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
+ The Layer Normalization implementation to use. Defaults to `torch.nn.LayerNorm`.
+ block_type (`str`, *optional*, defaults to `"unidiffuser"`):
+ The transformer block implementation to use. If `"unidiffuser"`, has the LayerNorms on the residual
+ backbone of each transformer block; otherwise has them in the attention/feedforward branches (the standard
+ behavior in `diffusers`.)
+ pre_layer_norm (`bool`, *optional*):
+ Whether to perform layer normalization before the attention and feedforward operations ("pre-LayerNorm"),
+ as opposed to after ("post-LayerNorm"). The original UniDiffuser implementation is post-LayerNorm
+ (`pre_layer_norm = False`).
+ norm_elementwise_affine (`bool`, *optional*):
+ Whether to use learnable per-element affine parameters during layer normalization.
+ use_patch_pos_embed (`bool`, *optional*):
+ Whether to use position embeddings inside the patch embedding layer (`PatchEmbed`).
+ ff_final_dropout (`bool`, *optional*):
+ Whether to use a final Dropout layer after the feedforward network.
+ use_data_type_embedding (`bool`, *optional*):
+ Whether to use a data type embedding. This is only relevant for UniDiffuser-v1-style models; UniDiffuser-v1
+ is further trained from UniDiffuser-v0 on non-publicly-available data and accepts a `data_type`
+ argument, which can be either `1` to use the weights trained on the non-publicly-available data or `0`
+ otherwise. This argument is subsequently embedded by the data type embedding, if used.
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ text_dim: int = 768,
+ clip_img_dim: int = 512,
+ num_text_tokens: int = 77,
+ num_attention_heads: int = 16,
+ attention_head_dim: int = 88,
+ in_channels: Optional[int] = None,
+ out_channels: Optional[int] = None,
+ num_layers: int = 1,
+ dropout: float = 0.0,
+ norm_num_groups: int = 32,
+ cross_attention_dim: Optional[int] = None,
+ attention_bias: bool = False,
+ sample_size: Optional[int] = None,
+ num_vector_embeds: Optional[int] = None,
+ patch_size: Optional[int] = None,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ use_linear_projection: bool = False,
+ only_cross_attention: bool = False,
+ upcast_attention: bool = False,
+ norm_type: str = "layer_norm",
+ block_type: str = "unidiffuser",
+ pre_layer_norm: bool = False,
+ use_timestep_embedding=False,
+ norm_elementwise_affine: bool = True,
+ use_patch_pos_embed=False,
+ ff_final_dropout: bool = True,
+ use_data_type_embedding: bool = False,
+ ):
+ super().__init__()
+
+ # 0. Handle dimensions
+ self.inner_dim = num_attention_heads * attention_head_dim
+
+ assert sample_size is not None, "UniDiffuserModel over patched input must provide sample_size"
+ self.sample_size = sample_size
+ self.in_channels = in_channels
+ self.out_channels = in_channels if out_channels is None else out_channels
+
+ self.patch_size = patch_size
+ # Assume image is square...
+ self.num_patches = (self.sample_size // patch_size) * (self.sample_size // patch_size)
+
+ # 1. Define input layers
+ # 1.1 Input layers for text and image input
+ # For now, only support patch input for VAE latent image input
+ self.vae_img_in = PatchEmbed(
+ height=sample_size,
+ width=sample_size,
+ patch_size=patch_size,
+ in_channels=in_channels,
+ embed_dim=self.inner_dim,
+ use_pos_embed=use_patch_pos_embed,
+ )
+ self.clip_img_in = nn.Linear(clip_img_dim, self.inner_dim)
+ self.text_in = nn.Linear(text_dim, self.inner_dim)
+
+ # 1.2. Timestep embeddings for t_img, t_text
+ self.timestep_img_proj = Timesteps(
+ self.inner_dim,
+ flip_sin_to_cos=True,
+ downscale_freq_shift=0,
+ )
+ self.timestep_img_embed = (
+ TimestepEmbedding(
+ self.inner_dim,
+ 4 * self.inner_dim,
+ out_dim=self.inner_dim,
+ )
+ if use_timestep_embedding
+ else nn.Identity()
+ )
+
+ self.timestep_text_proj = Timesteps(
+ self.inner_dim,
+ flip_sin_to_cos=True,
+ downscale_freq_shift=0,
+ )
+ self.timestep_text_embed = (
+ TimestepEmbedding(
+ self.inner_dim,
+ 4 * self.inner_dim,
+ out_dim=self.inner_dim,
+ )
+ if use_timestep_embedding
+ else nn.Identity()
+ )
+
+ # 1.3. Positional embedding
+ self.num_text_tokens = num_text_tokens
+ self.num_tokens = 1 + 1 + num_text_tokens + 1 + self.num_patches
+ self.pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens, self.inner_dim))
+ self.pos_embed_drop = nn.Dropout(p=dropout)
+ trunc_normal_(self.pos_embed, std=0.02)
+
+ # 1.4. Handle data type token embeddings for UniDiffuser-V1, if necessary
+ self.use_data_type_embedding = use_data_type_embedding
+ if self.use_data_type_embedding:
+ self.data_type_token_embedding = nn.Embedding(2, self.inner_dim)
+ self.data_type_pos_embed_token = nn.Parameter(torch.zeros(1, 1, self.inner_dim))
+
+ # 2. Define transformer blocks
+ self.transformer = UTransformer2DModel(
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ num_layers=num_layers,
+ dropout=dropout,
+ norm_num_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim,
+ attention_bias=attention_bias,
+ sample_size=sample_size,
+ num_vector_embeds=num_vector_embeds,
+ patch_size=patch_size,
+ activation_fn=activation_fn,
+ num_embeds_ada_norm=num_embeds_ada_norm,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ norm_type=norm_type,
+ block_type=block_type,
+ pre_layer_norm=pre_layer_norm,
+ norm_elementwise_affine=norm_elementwise_affine,
+ use_patch_pos_embed=use_patch_pos_embed,
+ ff_final_dropout=ff_final_dropout,
+ )
+
+ # 3. Define output layers
+ patch_dim = (patch_size**2) * out_channels
+ self.vae_img_out = nn.Linear(self.inner_dim, patch_dim)
+ self.clip_img_out = nn.Linear(self.inner_dim, clip_img_dim)
+ self.text_out = nn.Linear(self.inner_dim, text_dim)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {"pos_embed"}
+
+ def forward(
+ self,
+ latent_image_embeds: torch.FloatTensor,
+ image_embeds: torch.FloatTensor,
+ prompt_embeds: torch.FloatTensor,
+ timestep_img: Union[torch.Tensor, float, int],
+ timestep_text: Union[torch.Tensor, float, int],
+ data_type: Optional[Union[torch.Tensor, float, int]] = 1,
+ encoder_hidden_states=None,
+ cross_attention_kwargs=None,
+ ):
+ """
+ Args:
+ latent_image_embeds (`torch.FloatTensor` of shape `(batch size, latent channels, height, width)`):
+ Latent image representation from the VAE encoder.
+ image_embeds (`torch.FloatTensor` of shape `(batch size, 1, clip_img_dim)`):
+ CLIP-embedded image representation (unsqueezed in the first dimension).
+ prompt_embeds (`torch.FloatTensor` of shape `(batch size, seq_len, text_dim)`):
+ CLIP-embedded text representation.
+ timestep_img (`torch.long` or `float` or `int`):
+ Current denoising step for the image.
+ timestep_text (`torch.long` or `float` or `int`):
+ Current denoising step for the text.
+ data_type (`torch.int` or `float` or `int`, *optional*, defaults to `1`):
+ Only used in UniDiffuser-v1-style models. Can be either `1`, to use weights trained on nonpublic data,
+ or `0` otherwise.
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
+ self-attention.
+ cross_attention_kwargs (*optional*):
+ Keyword arguments to supply to the cross attention layers, if used.
+
+
+ Returns:
+ `tuple`: Returns relevant parts of the model's noise prediction: the first element of the tuple is the VAE
+ image embedding, the second element is the CLIP image embedding, and the third element is the CLIP text
+ embedding.
+ """
+ batch_size = latent_image_embeds.shape[0]
+
+ # 1. Input
+ # 1.1. Map inputs to shape (B, N, inner_dim)
+ vae_hidden_states = self.vae_img_in(latent_image_embeds)
+ clip_hidden_states = self.clip_img_in(image_embeds)
+ text_hidden_states = self.text_in(prompt_embeds)
+
+ num_text_tokens, num_img_tokens = text_hidden_states.size(1), vae_hidden_states.size(1)
+
+ # 1.2. Encode image timesteps to single token (B, 1, inner_dim)
+ if not torch.is_tensor(timestep_img):
+ timestep_img = torch.tensor([timestep_img], dtype=torch.long, device=vae_hidden_states.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep_img = timestep_img * torch.ones(batch_size, dtype=timestep_img.dtype, device=timestep_img.device)
+
+ timestep_img_token = self.timestep_img_proj(timestep_img)
+ # t_img_token does not contain any weights and will always return f32 tensors
+ # but time_embedding might be fp16, so we need to cast here.
+ timestep_img_token = timestep_img_token.to(dtype=self.dtype)
+ timestep_img_token = self.timestep_img_embed(timestep_img_token)
+ timestep_img_token = timestep_img_token.unsqueeze(dim=1)
+
+ # 1.3. Encode text timesteps to single token (B, 1, inner_dim)
+ if not torch.is_tensor(timestep_text):
+ timestep_text = torch.tensor([timestep_text], dtype=torch.long, device=vae_hidden_states.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep_text = timestep_text * torch.ones(batch_size, dtype=timestep_text.dtype, device=timestep_text.device)
+
+ timestep_text_token = self.timestep_text_proj(timestep_text)
+ # t_text_token does not contain any weights and will always return f32 tensors
+ # but time_embedding might be fp16, so we need to cast here.
+ timestep_text_token = timestep_text_token.to(dtype=self.dtype)
+ timestep_text_token = self.timestep_text_embed(timestep_text_token)
+ timestep_text_token = timestep_text_token.unsqueeze(dim=1)
+
+ # 1.4. Concatenate all of the embeddings together.
+ if self.use_data_type_embedding:
+ assert data_type is not None, "data_type must be supplied if the model uses a data type embedding"
+ if not torch.is_tensor(data_type):
+ data_type = torch.tensor([data_type], dtype=torch.int, device=vae_hidden_states.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ data_type = data_type * torch.ones(batch_size, dtype=data_type.dtype, device=data_type.device)
+
+ data_type_token = self.data_type_token_embedding(data_type).unsqueeze(dim=1)
+ hidden_states = torch.cat(
+ [
+ timestep_img_token,
+ timestep_text_token,
+ data_type_token,
+ text_hidden_states,
+ clip_hidden_states,
+ vae_hidden_states,
+ ],
+ dim=1,
+ )
+ else:
+ hidden_states = torch.cat(
+ [timestep_img_token, timestep_text_token, text_hidden_states, clip_hidden_states, vae_hidden_states],
+ dim=1,
+ )
+
+ # 1.5. Prepare the positional embeddings and add to hidden states
+ # Note: I think img_vae should always have the proper shape, so there's no need to interpolate
+ # the position embeddings.
+ if self.use_data_type_embedding:
+ pos_embed = torch.cat(
+ [self.pos_embed[:, : 1 + 1, :], self.data_type_pos_embed_token, self.pos_embed[:, 1 + 1 :, :]], dim=1
+ )
+ else:
+ pos_embed = self.pos_embed
+ hidden_states = hidden_states + pos_embed
+ hidden_states = self.pos_embed_drop(hidden_states)
+
+ # 2. Blocks
+ hidden_states = self.transformer(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ timestep=None,
+ class_labels=None,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
+ hidden_states_is_embedding=True,
+ unpatchify=False,
+ )[0]
+
+ # 3. Output
+ # Split out the predicted noise representation.
+ if self.use_data_type_embedding:
+ (
+ t_img_token_out,
+ t_text_token_out,
+ data_type_token_out,
+ text_out,
+ img_clip_out,
+ img_vae_out,
+ ) = hidden_states.split((1, 1, 1, num_text_tokens, 1, num_img_tokens), dim=1)
+ else:
+ t_img_token_out, t_text_token_out, text_out, img_clip_out, img_vae_out = hidden_states.split(
+ (1, 1, num_text_tokens, 1, num_img_tokens), dim=1
+ )
+
+ img_vae_out = self.vae_img_out(img_vae_out)
+
+ # unpatchify
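+ # (B, num_patches, patch_size**2 * out_channels) -> (B, out_channels, height * patch_size, width * patch_size)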
+ height = width = int(img_vae_out.shape[1] ** 0.5)
+ img_vae_out = img_vae_out.reshape(
+ shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
+ )
+ img_vae_out = torch.einsum("nhwpqc->nchpwq", img_vae_out)
+ img_vae_out = img_vae_out.reshape(
+ shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
+ )
+
+ img_clip_out = self.clip_img_out(img_clip_out)
+
+ text_out = self.text_out(text_out)
+
+ return img_vae_out, img_clip_out, text_out
diff --git a/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py b/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py
new file mode 100644
index 0000000000000000000000000000000000000000..3632d74d1c12d2d0aab4bff2ad812e163ea48ee1
--- /dev/null
+++ b/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py
@@ -0,0 +1,1382 @@
+import inspect
+import warnings
+from dataclasses import dataclass
+from typing import Callable, List, Optional, Union
+
+import numpy as np
+import PIL
+import torch
+from transformers import (
+ CLIPImageProcessor,
+ CLIPTextModel,
+ CLIPTokenizer,
+ CLIPVisionModelWithProjection,
+ GPT2Tokenizer,
+)
+
+from ...models import AutoencoderKL
+from ...schedulers import KarrasDiffusionSchedulers
+from ...utils import (
+ PIL_INTERPOLATION,
+ deprecate,
+ is_accelerate_available,
+ is_accelerate_version,
+ logging,
+ randn_tensor,
+)
+from ...utils.outputs import BaseOutput
+from ..pipeline_utils import DiffusionPipeline
+from .modeling_text_decoder import UniDiffuserTextDecoder
+from .modeling_uvit import UniDiffuserModel
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
+def preprocess(image):
+ warnings.warn(
+ "The preprocess method is deprecated and will be removed in a future version. Please"
+ " use VaeImageProcessor.preprocess instead",
+ FutureWarning,
+ )
+ if isinstance(image, torch.Tensor):
+ return image
+ elif isinstance(image, PIL.Image.Image):
+ image = [image]
+
+ if isinstance(image[0], PIL.Image.Image):
+ w, h = image[0].size
+ w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
+
+ image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
+ image = np.concatenate(image, axis=0)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image.transpose(0, 3, 1, 2)
+ image = 2.0 * image - 1.0
+ image = torch.from_numpy(image)
+ elif isinstance(image[0], torch.Tensor):
+ image = torch.cat(image, dim=0)
+ return image
+
+
+# New BaseOutput child class for joint image-text output
+@dataclass
+class ImageTextPipelineOutput(BaseOutput):
+ """
+ Output class for joint image-text pipelines.
+
+ Args:
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
+ List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
+ num_channels)`.
+ text (`List[str]` or `List[List[str]]`)
+ List of generated text strings of length `batch_size` or a list of list of strings whose outer list has
+ length `batch_size`.
+ """
+
+ images: Optional[Union[List[PIL.Image.Image], np.ndarray]]
+ text: Optional[Union[List[str], List[List[str]]]]
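+
+ # Access sketch: for a pipeline output `result`, `result.images[0]` is a `PIL.Image.Image` (or a NumPy
+ # array, depending on `output_type`) and `result.text[0]` is a generated string; either field may be
+ # `None` depending on the generation mode.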
+
+
+class UniDiffuserPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for a bimodal image-text [UniDiffuser](https://arxiv.org/pdf/2303.06555.pdf) model, which supports
+ unconditional text and image generation, text-conditioned image generation, image-conditioned text generation, and
+ joint image-text generation.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. This
+ is part of the UniDiffuser image representation, along with the CLIP vision encoding.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Similar to Stable Diffusion, UniDiffuser uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel) to encode text
+ prompts.
+ image_encoder ([`CLIPVisionModel`]):
+ UniDiffuser uses the vision portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModel) to encode
+ images as part of its image representation, along with the VAE latent representation.
+ image_processor ([`CLIPImageProcessor`]):
+ CLIP image processor of class
+ [CLIPImageProcessor](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPImageProcessor),
+ used to preprocess the image before CLIP encoding it with `image_encoder`.
+ clip_tokenizer ([`CLIPTokenizer`]):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTokenizer) which
+ is used to tokenize a prompt before encoding it with `text_encoder`.
+ text_decoder ([`UniDiffuserTextDecoder`]):
+ Frozen text decoder. This is a GPT-style model which is used to generate text from the UniDiffuser
+ embedding.
+ text_tokenizer ([`GPT2Tokenizer`]):
+ Tokenizer of class
+ [GPT2Tokenizer](https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2Tokenizer) which
+ is used along with the `text_decoder` to decode text for text generation.
+ unet ([`UniDiffuserModel`]):
+ UniDiffuser uses a [U-ViT](https://github.com/baofff/U-ViT) model architecture, which is similar to a
+ [`Transformer2DModel`] with U-Net-style skip connections between transformer layers.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image and/or text latents. The
+ original UniDiffuser paper uses the [`DPMSolverMultistepScheduler`] scheduler.
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ image_encoder: CLIPVisionModelWithProjection,
+ image_processor: CLIPImageProcessor,
+ clip_tokenizer: CLIPTokenizer,
+ text_decoder: UniDiffuserTextDecoder,
+ text_tokenizer: GPT2Tokenizer,
+ unet: UniDiffuserModel,
+ scheduler: KarrasDiffusionSchedulers,
+ ):
+ super().__init__()
+
+ if text_encoder.config.hidden_size != text_decoder.prefix_inner_dim:
+ raise ValueError(
+ f"The text encoder hidden size and text decoder prefix inner dim must be the same, but"
+ f" `text_encoder.config.hidden_size`: {text_encoder.config.hidden_size} and `text_decoder.prefix_inner_dim`: {text_decoder.prefix_inner_dim}"
+ )
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ image_encoder=image_encoder,
+ image_processor=image_processor,
+ clip_tokenizer=clip_tokenizer,
+ text_decoder=text_decoder,
+ text_tokenizer=text_tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ )
+
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+
+ self.num_channels_latents = vae.config.latent_channels
+ self.text_encoder_seq_len = text_encoder.config.max_position_embeddings
+ self.text_encoder_hidden_size = text_encoder.config.hidden_size
+ self.image_encoder_projection_dim = image_encoder.config.projection_dim
+ self.unet_resolution = unet.config.sample_size
+
+ self.text_intermediate_dim = self.text_encoder_hidden_size
+ if self.text_decoder.prefix_hidden_dim is not None:
+ self.text_intermediate_dim = self.text_decoder.prefix_hidden_dim
+
+ self.mode = None
+
+ # TODO: handle safety checking?
+ self.safety_checker = None
+
+ # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload
+ # Add self.image_encoder, self.text_decoder to cpu_offloaded_models list
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ hook = None
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae, self.image_encoder, self.text_decoder]:
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+ if self.safety_checker is not None:
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def _infer_mode(self, prompt, prompt_embeds, image, latents, prompt_latents, vae_latents, clip_latents):
+ r"""
+ Infer the generation task ('mode') from the inputs to `__call__`. If the mode has been manually set, the set
+ mode will be used.
+ """
+ prompt_available = (prompt is not None) or (prompt_embeds is not None)
+ image_available = image is not None
+ input_available = prompt_available or image_available
+
+ prompt_latents_available = prompt_latents is not None
+ vae_latents_available = vae_latents is not None
+ clip_latents_available = clip_latents is not None
+ full_latents_available = latents is not None
+ image_latents_available = vae_latents_available and clip_latents_available
+ all_indv_latents_available = prompt_latents_available and image_latents_available
+
+ if self.mode is not None:
+ # Preferentially use the mode set by the user
+ mode = self.mode
+ elif prompt_available:
+ mode = "text2img"
+ elif image_available:
+ mode = "img2text"
+ else:
+ # Neither prompt nor image supplied, infer based on availability of latents
+ if full_latents_available or all_indv_latents_available:
+ mode = "joint"
+ elif prompt_latents_available:
+ mode = "text"
+ elif image_latents_available:
+ mode = "img"
+ else:
+ # No inputs or latents available
+ mode = "joint"
+
+ # Give warnings for ambiguous cases
+ if self.mode is None and prompt_available and image_available:
+ logger.warning(
+ f"You have supplied both a text prompt and image to the pipeline and mode has not been set manually,"
+ f" defaulting to mode '{mode}'."
+ )
+
+ if self.mode is None and not input_available:
+ if vae_latents_available != clip_latents_available:
+ # Exactly one of vae_latents and clip_latents is supplied
+ logger.warning(
+ f"You have supplied exactly one of `vae_latents` and `clip_latents`, whereas either both or none"
+ f" are expected to be supplied. Defaulting to mode '{mode}'."
+ )
+ elif not prompt_latents_available and not vae_latents_available and not clip_latents_available:
+ # No inputs or latents supplied
+ logger.warning(
+ f"No inputs or latents have been supplied, and mode has not been manually set,"
+ f" defaulting to mode '{mode}'."
+ )
+
+ return mode
+
+ # Functions to manually set the mode
+ def set_text_mode(self):
+ r"""Manually set the generation mode to unconditional ("marginal") text generation."""
+ self.mode = "text"
+
+ def set_image_mode(self):
+ r"""Manually set the generation mode to unconditional ("marginal") image generation."""
+ self.mode = "img"
+
+ def set_text_to_image_mode(self):
+ r"""Manually set the generation mode to text-conditioned image generation."""
+ self.mode = "text2img"
+
+ def set_image_to_text_mode(self):
+ r"""Manually set the generation mode to image-conditioned text generation."""
+ self.mode = "img2text"
+
+ def set_joint_mode(self):
+ r"""Manually set the generation mode to unconditional joint image-text generation."""
+ self.mode = "joint"
+
+ def reset_mode(self):
+ r"""Removes a manually set mode; after calling this, the pipeline will infer the mode from inputs."""
+ self.mode = None
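+
+ # Example (sketch, assuming an instantiated pipeline `pipe` and a PIL image `init_image`):
+ # `pipe(prompt="...")` infers "text2img" and `pipe(image=init_image)` infers "img2text";
+ # `pipe.set_image_to_text_mode()` forces image captioning regardless of the supplied inputs, and
+ # `pipe.reset_mode()` restores automatic mode inference on the next call.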
+
+ def _infer_batch_size(
+ self,
+ mode,
+ prompt,
+ prompt_embeds,
+ image,
+ num_images_per_prompt,
+ num_prompts_per_image,
+ latents,
+ prompt_latents,
+ vae_latents,
+ clip_latents,
+ ):
+ r"""Infers the batch size and multiplier depending on mode and supplied arguments to `__call__`."""
+ if num_images_per_prompt is None:
+ num_images_per_prompt = 1
+ if num_prompts_per_image is None:
+ num_prompts_per_image = 1
+
+ assert num_images_per_prompt > 0, "num_images_per_prompt must be a positive integer"
+ assert num_prompts_per_image > 0, "num_prompts_per_image must be a positive integer"
+
+ if mode in ["text2img"]:
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ # Either prompt or prompt_embeds must be present for text2img.
+ batch_size = prompt_embeds.shape[0]
+ multiplier = num_images_per_prompt
+ elif mode in ["img2text"]:
+ if isinstance(image, PIL.Image.Image):
+ batch_size = 1
+ else:
+ # Image must be available and of type either PIL.Image.Image or torch.FloatTensor.
+ # Not currently supporting something like image_embeds.
+ batch_size = image.shape[0]
+ multiplier = num_prompts_per_image
+ elif mode in ["img"]:
+ if vae_latents is not None:
+ batch_size = vae_latents.shape[0]
+ elif clip_latents is not None:
+ batch_size = clip_latents.shape[0]
+ else:
+ batch_size = 1
+ multiplier = num_images_per_prompt
+ elif mode in ["text"]:
+ if prompt_latents is not None:
+ batch_size = prompt_latents.shape[0]
+ else:
+ batch_size = 1
+ multiplier = num_prompts_per_image
+ elif mode in ["joint"]:
+ if latents is not None:
+ batch_size = latents.shape[0]
+ elif prompt_latents is not None:
+ batch_size = prompt_latents.shape[0]
+ elif vae_latents is not None:
+ batch_size = vae_latents.shape[0]
+ elif clip_latents is not None:
+ batch_size = clip_latents.shape[0]
+ else:
+ batch_size = 1
+
+ if num_images_per_prompt == num_prompts_per_image:
+ multiplier = num_images_per_prompt
+ else:
+ multiplier = min(num_images_per_prompt, num_prompts_per_image)
+ logger.warning(
+ f"You are using mode `{mode}` and `num_images_per_prompt`: {num_images_per_prompt} and"
+ f" num_prompts_per_image: {num_prompts_per_image} are not equal. Using batch size equal to"
+ f" `min(num_images_per_prompt, num_prompts_per_image) = {batch_size}."
+ )
+ return batch_size, multiplier
+
+ # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+ # self.tokenizer => self.clip_tokenizer
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead.
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ """
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ text_inputs = self.clip_tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.clip_tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.clip_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.clip_tokenizer.batch_decode(
+ untruncated_ids[:, self.clip_tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.clip_tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.clip_tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_instruct_pix2pix.StableDiffusionInstructPix2PixPipeline.prepare_image_latents
+ # Add num_prompts_per_image argument, sample from autoencoder moment distribution
+ def encode_image_vae_latents(
+ self,
+ image,
+ batch_size,
+ num_prompts_per_image,
+ dtype,
+ device,
+ do_classifier_free_guidance,
+ generator=None,
+ ):
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
+ raise ValueError(
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
+ )
+
+ image = image.to(device=device, dtype=dtype)
+
+ batch_size = batch_size * num_prompts_per_image
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if isinstance(generator, list):
+ image_latents = [
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
+ * self.vae.config.scaling_factor
+ for i in range(batch_size)
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
+ # Scale image_latents by the VAE's scaling factor
+ image_latents = image_latents * self.vae.config.scaling_factor
+
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+ # expand image_latents for batch_size
+ deprecation_message = (
+ f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial"
+ " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
+ " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
+ " your script to pass as many initial images as text prompts to suppress this warning."
+ )
+ deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ image_latents = torch.cat([image_latents], dim=0)
+
+ if do_classifier_free_guidance:
+ uncond_image_latents = torch.zeros_like(image_latents)
+ image_latents = torch.cat([image_latents, image_latents, uncond_image_latents], dim=0)
+
+ return image_latents
+
+ def encode_image_clip_latents(
+ self,
+ image,
+ batch_size,
+ num_prompts_per_image,
+ dtype,
+ device,
+ generator=None,
+ ):
+ # Map image to CLIP embedding.
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
+ raise ValueError(
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
+ )
+
+ preprocessed_image = self.image_processor.preprocess(
+ image,
+ return_tensors="pt",
+ )
+ preprocessed_image = preprocessed_image.to(device=device, dtype=dtype)
+
+ batch_size = batch_size * num_prompts_per_image
+ if isinstance(generator, list):
+ image_latents = [
+ self.image_encoder(**preprocessed_image[i : i + 1]).image_embeds for i in range(batch_size)
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = self.image_encoder(**preprocessed_image).image_embeds
+
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+ # expand image_latents for batch_size
+ deprecation_message = (
+ f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial"
+ " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
+ " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
+ " your script to pass as many initial images as text prompts to suppress this warning."
+ )
+ deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ image_latents = torch.cat([image_latents], dim=0)
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ return image_latents
+
+ # Note that the CLIP latents are not decoded for image generation.
+ # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ # Rename: decode_latents -> decode_image_latents
+ def decode_image_latents(self, latents):
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ def prepare_text_latents(
+ self, batch_size, num_images_per_prompt, seq_len, hidden_size, dtype, device, generator, latents=None
+ ):
+ # Prepare latents for the CLIP embedded prompt.
+ shape = (batch_size * num_images_per_prompt, seq_len, hidden_size)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ # latents is assumed to have shape (B, L, D)
+ latents = latents.repeat(num_images_per_prompt, 1, 1)
+ latents = latents.to(device=device, dtype=dtype)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ # Rename prepare_latents -> prepare_image_vae_latents and add num_prompts_per_image argument.
+ def prepare_image_vae_latents(
+ self,
+ batch_size,
+ num_prompts_per_image,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ shape = (
+ batch_size * num_prompts_per_image,
+ num_channels_latents,
+ height // self.vae_scale_factor,
+ width // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ # latents is assumed to have shape (B, C, H, W)
+ latents = latents.repeat(num_prompts_per_image, 1, 1, 1)
+ latents = latents.to(device=device, dtype=dtype)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ def prepare_image_clip_latents(
+ self, batch_size, num_prompts_per_image, clip_img_dim, dtype, device, generator, latents=None
+ ):
+ # Prepare latents for the CLIP embedded image.
+ shape = (batch_size * num_prompts_per_image, 1, clip_img_dim)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ # latents is assumed to have shape (B, L, D)
+ latents = latents.repeat(num_prompts_per_image, 1, 1)
+ latents = latents.to(device=device, dtype=dtype)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ def _split(self, x, height, width):
+ r"""
+ Splits a flattened embedding x of shape (B, C * H * W + clip_img_dim) into two tensors of shape (B, C, H, W)
+ and (B, 1, clip_img_dim)
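+
+ For example (illustrative values, not read from this file: C=4, H=W=512 pixels, vae_scale_factor=8,
+ clip_img_dim=512), x has shape (B, 4 * 64 * 64 + 512) = (B, 16896).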
+ """
+ batch_size = x.shape[0]
+ latent_height = height // self.vae_scale_factor
+ latent_width = width // self.vae_scale_factor
+ img_vae_dim = self.num_channels_latents * latent_height * latent_width
+
+ img_vae, img_clip = x.split([img_vae_dim, self.image_encoder_projection_dim], dim=1)
+
+ img_vae = torch.reshape(img_vae, (batch_size, self.num_channels_latents, latent_height, latent_width))
+ img_clip = torch.reshape(img_clip, (batch_size, 1, self.image_encoder_projection_dim))
+ return img_vae, img_clip
+
+ def _combine(self, img_vae, img_clip):
+ r"""
+ Combines a latent image img_vae of shape (B, C, H, W) and a CLIP-embedded image img_clip of shape (B, 1,
+ clip_img_dim) into a single tensor of shape (B, C * H * W + clip_img_dim).
+ """
+ img_vae = torch.reshape(img_vae, (img_vae.shape[0], -1))
+ img_clip = torch.reshape(img_clip, (img_clip.shape[0], -1))
+ return torch.concat([img_vae, img_clip], dim=-1)
+
+ def _split_joint(self, x, height, width):
+ r"""
+ Splits a flattened embedding x of shape (B, C * H * W + clip_img_dim + text_seq_len * text_dim) into (img_vae,
+ img_clip, text) where img_vae is of shape (B, C, H, W), img_clip is of shape (B, 1, clip_img_dim), and text is
+ of shape (B, text_seq_len, text_dim).
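+
+ For example (illustrative values, not read from this file: C=4, H=W=512 pixels, vae_scale_factor=8,
+ clip_img_dim=512, text_seq_len=77, text_dim=64), x has shape (B, 16384 + 512 + 4928) = (B, 21824).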
+ """
+ batch_size = x.shape[0]
+ latent_height = height // self.vae_scale_factor
+ latent_width = width // self.vae_scale_factor
+ img_vae_dim = self.num_channels_latents * latent_height * latent_width
+ text_dim = self.text_encoder_seq_len * self.text_intermediate_dim
+
+ img_vae, img_clip, text = x.split([img_vae_dim, self.image_encoder_projection_dim, text_dim], dim=1)
+
+ img_vae = torch.reshape(img_vae, (batch_size, self.num_channels_latents, latent_height, latent_width))
+ img_clip = torch.reshape(img_clip, (batch_size, 1, self.image_encoder_projection_dim))
+ text = torch.reshape(text, (batch_size, self.text_encoder_seq_len, self.text_intermediate_dim))
+ return img_vae, img_clip, text
+
+ def _combine_joint(self, img_vae, img_clip, text):
+ r"""
+ Combines a latent image img_vae of shape (B, C, H, W), a CLIP-embedded image img_clip of shape (B, L_img,
+ clip_img_dim), and a text embedding text of shape (B, L_text, text_dim) into a single embedding x of shape (B,
+ C * H * W + L_img * clip_img_dim + L_text * text_dim).
+ """
+ img_vae = torch.reshape(img_vae, (img_vae.shape[0], -1))
+ img_clip = torch.reshape(img_clip, (img_clip.shape[0], -1))
+ text = torch.reshape(text, (text.shape[0], -1))
+ return torch.concat([img_vae, img_clip, text], dim=-1)
+
+ def _get_noise_pred(
+ self,
+ mode,
+ latents,
+ t,
+ prompt_embeds,
+ img_vae,
+ img_clip,
+ max_timestep,
+ data_type,
+ guidance_scale,
+ generator,
+ device,
+ height,
+ width,
+ ):
+ r"""
+ Gets the noise prediction using the `unet` and performs classifier-free guidance, if necessary.
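+ For the guided modes, the combined prediction is `guidance_scale * x_cond + (1 - guidance_scale) * x_uncond`,
+ which is algebraically equivalent to the common `x_uncond + guidance_scale * (x_cond - x_uncond)` formulation.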
+ """
+ if mode == "joint":
+ # Joint text-image generation
+ img_vae_latents, img_clip_latents, text_latents = self._split_joint(latents, height, width)
+
+ img_vae_out, img_clip_out, text_out = self.unet(
+ img_vae_latents, img_clip_latents, text_latents, timestep_img=t, timestep_text=t, data_type=data_type
+ )
+
+ x_out = self._combine_joint(img_vae_out, img_clip_out, text_out)
+
+ if guidance_scale <= 1.0:
+ return x_out
+
+ # Classifier-free guidance
+ img_vae_T = randn_tensor(img_vae.shape, generator=generator, device=device, dtype=img_vae.dtype)
+ img_clip_T = randn_tensor(img_clip.shape, generator=generator, device=device, dtype=img_clip.dtype)
+ text_T = randn_tensor(prompt_embeds.shape, generator=generator, device=device, dtype=prompt_embeds.dtype)
+
+ _, _, text_out_uncond = self.unet(
+ img_vae_T, img_clip_T, text_latents, timestep_img=max_timestep, timestep_text=t, data_type=data_type
+ )
+
+ img_vae_out_uncond, img_clip_out_uncond, _ = self.unet(
+ img_vae_latents,
+ img_clip_latents,
+ text_T,
+ timestep_img=t,
+ timestep_text=max_timestep,
+ data_type=data_type,
+ )
+
+ x_out_uncond = self._combine_joint(img_vae_out_uncond, img_clip_out_uncond, text_out_uncond)
+
+ return guidance_scale * x_out + (1.0 - guidance_scale) * x_out_uncond
+ elif mode == "text2img":
+ # Text-conditioned image generation
+ img_vae_latents, img_clip_latents = self._split(latents, height, width)
+
+ img_vae_out, img_clip_out, text_out = self.unet(
+ img_vae_latents, img_clip_latents, prompt_embeds, timestep_img=t, timestep_text=0, data_type=data_type
+ )
+
+ img_out = self._combine(img_vae_out, img_clip_out)
+
+ if guidance_scale <= 1.0:
+ return img_out
+
+ # Classifier-free guidance
+ text_T = randn_tensor(prompt_embeds.shape, generator=generator, device=device, dtype=prompt_embeds.dtype)
+
+ img_vae_out_uncond, img_clip_out_uncond, text_out_uncond = self.unet(
+ img_vae_latents,
+ img_clip_latents,
+ text_T,
+ timestep_img=t,
+ timestep_text=max_timestep,
+ data_type=data_type,
+ )
+
+ img_out_uncond = self._combine(img_vae_out_uncond, img_clip_out_uncond)
+
+ return guidance_scale * img_out + (1.0 - guidance_scale) * img_out_uncond
+ elif mode == "img2text":
+ # Image-conditioned text generation
+ img_vae_out, img_clip_out, text_out = self.unet(
+ img_vae, img_clip, latents, timestep_img=0, timestep_text=t, data_type=data_type
+ )
+
+ if guidance_scale <= 1.0:
+ return text_out
+
+ # Classifier-free guidance
+ img_vae_T = randn_tensor(img_vae.shape, generator=generator, device=device, dtype=img_vae.dtype)
+ img_clip_T = randn_tensor(img_clip.shape, generator=generator, device=device, dtype=img_clip.dtype)
+
+ img_vae_out_uncond, img_clip_out_uncond, text_out_uncond = self.unet(
+ img_vae_T, img_clip_T, latents, timestep_img=max_timestep, timestep_text=t, data_type=data_type
+ )
+
+ return guidance_scale * text_out + (1.0 - guidance_scale) * text_out_uncond
+ elif mode == "text":
+ # Unconditional ("marginal") text generation (no CFG)
+ img_vae_out, img_clip_out, text_out = self.unet(
+ img_vae, img_clip, latents, timestep_img=max_timestep, timestep_text=t, data_type=data_type
+ )
+
+ return text_out
+ elif mode == "img":
+ # Unconditional ("marginal") image generation (no CFG)
+ img_vae_latents, img_clip_latents = self._split(latents, height, width)
+
+ img_vae_out, img_clip_out, text_out = self.unet(
+ img_vae_latents,
+ img_clip_latents,
+ prompt_embeds,
+ timestep_img=t,
+ timestep_text=max_timestep,
+ data_type=data_type,
+ )
+
+ img_out = self._combine(img_vae_out, img_clip_out)
+ return img_out
+
+ def check_latents_shape(self, latents_name, latents, expected_shape):
+ latents_shape = latents.shape
+ expected_num_dims = len(expected_shape) + 1 # expected dimensions plus the batch dimension
+ expected_shape_str = ", ".join(str(dim) for dim in expected_shape)
+ if len(latents_shape) != expected_num_dims:
+ raise ValueError(
+ f"`{latents_name}` should have shape (batch_size, {expected_shape_str}), but the current shape"
+ f" {latents_shape} has {len(latents_shape)} dimensions."
+ )
+ for i in range(1, expected_num_dims):
+ if latents_shape[i] != expected_shape[i - 1]:
+ raise ValueError(
+ f"`{latents_name}` should have shape (batch_size, {expected_shape_str}), but the current shape"
+ f" {latents_shape} has {latents_shape[i]} != {expected_shape[i - 1]} at dimension {i}."
+ )
+
+ def check_inputs(
+ self,
+ mode,
+ prompt,
+ image,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ latents=None,
+ prompt_latents=None,
+ vae_latents=None,
+ clip_latents=None,
+ ):
+ # Check inputs before running the generative process.
+ if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
+ raise ValueError(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}."
+ )
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if mode == "text2img":
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if mode == "img2text":
+ if image is None:
+ raise ValueError("`img2text` mode requires an image to be provided.")
+
+ # Check provided latents
+ latent_height = height // self.vae_scale_factor
+ latent_width = width // self.vae_scale_factor
+ full_latents_available = latents is not None
+ prompt_latents_available = prompt_latents is not None
+ vae_latents_available = vae_latents is not None
+ clip_latents_available = clip_latents is not None
+
+ if full_latents_available:
+ individual_latents_available = (
+ prompt_latents is not None or vae_latents is not None or clip_latents is not None
+ )
+ if individual_latents_available:
+ logger.warning(
+ "You have supplied both `latents` and at least one of `prompt_latents`, `vae_latents`, and"
+ " `clip_latents`. The value of `latents` will override the value of any individually supplied latents."
+ )
+ # Check shape of full latents
+ img_vae_dim = self.num_channels_latents * latent_height * latent_width
+ text_dim = self.text_encoder_seq_len * self.text_encoder_hidden_size
+ latents_dim = img_vae_dim + self.image_encoder_projection_dim + text_dim
+ latents_expected_shape = (latents_dim,)
+ self.check_latents_shape("latents", latents, latents_expected_shape)
+
+ # Check individual latent shapes, if present
+ if prompt_latents_available:
+ prompt_latents_expected_shape = (self.text_encoder_seq_len, self.text_encoder_hidden_size)
+ self.check_latents_shape("prompt_latents", prompt_latents, prompt_latents_expected_shape)
+
+ if vae_latents_available:
+ vae_latents_expected_shape = (self.num_channels_latents, latent_height, latent_width)
+ self.check_latents_shape("vae_latents", vae_latents, vae_latents_expected_shape)
+
+ if clip_latents_available:
+ clip_latents_expected_shape = (1, self.image_encoder_projection_dim)
+ self.check_latents_shape("clip_latents", clip_latents, clip_latents_expected_shape)
+
+ if mode in ["text2img", "img"] and vae_latents_available and clip_latents_available:
+ if vae_latents.shape[0] != clip_latents.shape[0]:
+ raise ValueError(
+ f"Both `vae_latents` and `clip_latents` are supplied, but their batch dimensions are not equal:"
+ f" {vae_latents.shape[0]} != {clip_latents.shape[0]}."
+ )
+
+ if mode == "joint" and prompt_latents_available and vae_latents_available and clip_latents_available:
+ if prompt_latents.shape[0] != vae_latents.shape[0] or prompt_latents.shape[0] != clip_latents.shape[0]:
+ raise ValueError(
+ f"All of `prompt_latents`, `vae_latents`, and `clip_latents` are supplied, but their batch"
+ f" dimensions are not equal: {prompt_latents.shape[0]} != {vae_latents.shape[0]}"
+ f" != {clip_latents.shape[0]}."
+ )
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Optional[Union[str, List[str]]] = None,
+ image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ data_type: Optional[int] = 1,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 8.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ num_prompts_per_image: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_latents: Optional[torch.FloatTensor] = None,
+ vae_latents: Optional[torch.FloatTensor] = None,
+ clip_latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead. Required for text-conditioned image generation (`text2img`) mode.
+ image (`torch.FloatTensor` or `PIL.Image.Image`, *optional*):
+ `Image`, or tensor representing an image batch. Required for image-conditioned text generation
+ (`img2text`) mode.
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The width in pixels of the generated image.
+ data_type (`int`, *optional*, defaults to 1):
+ The data type (either 0 or 1). Only used if you are loading a checkpoint which supports a data type
+ embedding; this is added for compatibility with the UniDiffuser-v1 checkpoint.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 8.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2 of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality. Note that the original [UniDiffuser
+ paper](https://arxiv.org/pdf/2303.06555.pdf) uses a different definition of the guidance scale `w'`,
+ which satisfies `w = w' + 1`.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`). Used in text-conditioned image generation (`text2img`) mode.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt. Used in `text2img` (text-conditioned image generation) and
+ `img` mode. If the mode is joint and both `num_images_per_prompt` and `num_prompts_per_image` are
+ supplied, `min(num_images_per_prompt, num_prompts_per_image)` samples will be generated.
+ num_prompts_per_image (`int`, *optional*, defaults to 1):
+ The number of prompts to generate per image. Used in `img2text` (image-conditioned text generation) and
+ `text` mode. If the mode is joint and both `num_images_per_prompt` and `num_prompts_per_image` are
+ supplied, `min(num_images_per_prompt, num_prompts_per_image)` samples will be generated.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for joint
+ image-text generation. Can be used to tweak the same generation with different prompts. If not
+ provided, a latents tensor will be generated by sampling using the supplied random `generator`. Note
+ that this is assumed to be a full set of VAE, CLIP, and text latents; if supplied, this will override
+ the value of `prompt_latents`, `vae_latents`, and `clip_latents`.
+ prompt_latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for text
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ vae_latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ clip_latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument. Used in text-conditioned
+ image generation (`text2img`) mode.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument. Used in text-conditioned image generation (`text2img`) mode.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.unidiffuser.ImageTextPipelineOutput`] instead of a plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
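+
+ Examples:
+ A minimal usage sketch. The checkpoint id `"thu-ml/unidiffuser-v1"` and the top-level
+ `UniDiffuserPipeline` import are assumptions about how this pipeline is packaged and may need to be
+ adapted to your setup:
+
+ ```py
+ >>> import torch
+ >>> from diffusers import UniDiffuserPipeline
+
+ >>> pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1", torch_dtype=torch.float16)
+ >>> pipe = pipe.to("cuda")
+
+ >>> # text-conditioned image generation; the mode is inferred from the supplied inputs
+ >>> image = pipe(prompt="an astronaut riding a horse on the moon").images[0]
+
+ >>> # unconditional joint image-text generation
+ >>> result = pipe()
+ >>> image, text = result.images[0], result.text[0]
+ ```
+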
+ Returns:
+ [`~pipelines.unidiffuser.ImageTextPipelineOutput`] or `tuple`:
+ [`~pipelines.unidiffuser.ImageTextPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is a list with the generated images, and the second element is a list
+ of generated texts.
+ """
+
+ # 0. Default height and width to unet
+ height = height or self.unet_resolution * self.vae_scale_factor
+ width = width or self.unet_resolution * self.vae_scale_factor
+
+ # 1. Check inputs
+ # Recalculate mode for each call to the pipeline.
+ mode = self._infer_mode(prompt, prompt_embeds, image, latents, prompt_latents, vae_latents, clip_latents)
+ self.check_inputs(
+ mode,
+ prompt,
+ image,
+ height,
+ width,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ latents,
+ prompt_latents,
+ vae_latents,
+ clip_latents,
+ )
+
+ # 2. Define call parameters
+ batch_size, multiplier = self._infer_batch_size(
+ mode,
+ prompt,
+ prompt_embeds,
+ image,
+ num_images_per_prompt,
+ num_prompts_per_image,
+ latents,
+ prompt_latents,
+ vae_latents,
+ clip_latents,
+ )
+ device = self._execution_device
+ reduce_text_emb_dim = self.text_intermediate_dim < self.text_encoder_hidden_size or self.mode != "text2img"
+
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ # Note that this differs from the formulation in the unidiffusers paper!
+ # do_classifier_free_guidance = guidance_scale > 1.0
+
+ # check if scheduler is in sigmas space
+ # scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas")
+
+ # 3. Encode input prompt, if available; otherwise prepare text latents
+ if latents is not None:
+ # Overwrite individual latents
+ vae_latents, clip_latents, prompt_latents = self._split_joint(latents, height, width)
+
+ if mode in ["text2img"]:
+ # 3.1. Encode input prompt, if available
+ assert prompt is not None or prompt_embeds is not None
+ prompt_embeds = self._encode_prompt(
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=multiplier,
+ do_classifier_free_guidance=False, # don't support standard classifier-free guidance for now
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ )
+ else:
+ # 3.2. Prepare text latent variables, if input not available
+ prompt_embeds = self.prepare_text_latents(
+ batch_size=batch_size,
+ num_images_per_prompt=multiplier,
+ seq_len=self.text_encoder_seq_len,
+ hidden_size=self.text_encoder_hidden_size,
+ dtype=self.text_encoder.dtype, # Should work with both full precision and mixed precision
+ device=device,
+ generator=generator,
+ latents=prompt_latents,
+ )
+
+ if reduce_text_emb_dim:
+ prompt_embeds = self.text_decoder.encode(prompt_embeds)
+
+ # 4. Encode image, if available; otherwise prepare image latents
+ if mode in ["img2text"]:
+ # 4.1. Encode images, if available
+ assert image is not None, "`img2text` requires a conditioning image"
+ # Encode image using VAE
+ image_vae = preprocess(image)
+ height, width = image_vae.shape[-2:]
+ image_vae_latents = self.encode_image_vae_latents(
+ image=image_vae,
+ batch_size=batch_size,
+ num_prompts_per_image=multiplier,
+ dtype=prompt_embeds.dtype,
+ device=device,
+ do_classifier_free_guidance=False, # Copied from InstructPix2Pix, don't use their version of CFG
+ generator=generator,
+ )
+
+ # Encode image using CLIP
+ image_clip_latents = self.encode_image_clip_latents(
+ image=image,
+ batch_size=batch_size,
+ num_prompts_per_image=multiplier,
+ dtype=prompt_embeds.dtype,
+ device=device,
+ generator=generator,
+ )
+ # (batch_size, clip_hidden_size) => (batch_size, 1, clip_hidden_size)
+ image_clip_latents = image_clip_latents.unsqueeze(1)
+ else:
+ # 4.2. Prepare image latent variables, if input not available
+ # Prepare image VAE latents in latent space
+ image_vae_latents = self.prepare_image_vae_latents(
+ batch_size=batch_size,
+ num_prompts_per_image=multiplier,
+ num_channels_latents=self.num_channels_latents,
+ height=height,
+ width=width,
+ dtype=prompt_embeds.dtype,
+ device=device,
+ generator=generator,
+ latents=vae_latents,
+ )
+
+ # Prepare image CLIP latents
+ image_clip_latents = self.prepare_image_clip_latents(
+ batch_size=batch_size,
+ num_prompts_per_image=multiplier,
+ clip_img_dim=self.image_encoder_projection_dim,
+ dtype=prompt_embeds.dtype,
+ device=device,
+ generator=generator,
+ latents=clip_latents,
+ )
+
+ # 5. Set timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+ # max_timestep = timesteps[0]
+ max_timestep = self.scheduler.config.num_train_timesteps
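+ # In UniDiffuser-style classifier-free guidance the conditioning modality is replaced by pure noise
+ # at the maximum timestep, so the scheduler's `num_train_timesteps` is forwarded to `_get_noise_pred`.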
+
+ # 6. Prepare latent variables
+ if mode == "joint":
+ latents = self._combine_joint(image_vae_latents, image_clip_latents, prompt_embeds)
+ elif mode in ["text2img", "img"]:
+ latents = self._combine(image_vae_latents, image_clip_latents)
+ elif mode in ["img2text", "text"]:
+ latents = prompt_embeds
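+ # Only the modalities being generated live in `latents`: joint mode carries image (VAE + CLIP) and
+ # text latents together, the image modes carry only the combined image latents, and the text modes
+ # carry only the text latents.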
+
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ logger.debug(f"Scheduler extra step kwargs: {extra_step_kwargs}")
+
+ # 8. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # predict the noise residual
+ # Also applies classifier-free guidance as described in the UniDiffuser paper
+ noise_pred = self._get_noise_pred(
+ mode,
+ latents,
+ t,
+ prompt_embeds,
+ image_vae_latents,
+ image_clip_latents,
+ max_timestep,
+ data_type,
+ guidance_scale,
+ generator,
+ device,
+ height,
+ width,
+ )
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ # 9. Post-processing
+ gen_image = None
+ gen_text = None
+ if mode == "joint":
+ image_vae_latents, image_clip_latents, text_latents = self._split_joint(latents, height, width)
+
+ # Map latent VAE image back to pixel space
+ gen_image = self.decode_image_latents(image_vae_latents)
+
+ # Generate text using the text decoder
+ output_token_list, seq_lengths = self.text_decoder.generate_captions(
+ text_latents, self.text_tokenizer.eos_token_id, device=device
+ )
+ output_list = output_token_list.cpu().numpy()
+ gen_text = [
+ self.text_tokenizer.decode(output[: int(length)], skip_special_tokens=True)
+ for output, length in zip(output_list, seq_lengths)
+ ]
+ elif mode in ["text2img", "img"]:
+ image_vae_latents, image_clip_latents = self._split(latents, height, width)
+ gen_image = self.decode_image_latents(image_vae_latents)
+ elif mode in ["img2text", "text"]:
+ text_latents = latents
+ output_token_list, seq_lengths = self.text_decoder.generate_captions(
+ text_latents, self.text_tokenizer.eos_token_id, device=device
+ )
+ output_list = output_token_list.cpu().numpy()
+ gen_text = [
+ self.text_tokenizer.decode(output[: int(length)], skip_special_tokens=True)
+ for output, length in zip(output_list, seq_lengths)
+ ]
+
+ # 10. Convert to PIL
+ if output_type == "pil" and gen_image is not None:
+ gen_image = self.numpy_to_pil(gen_image)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (gen_image, gen_text)
+
+ return ImageTextPipelineOutput(images=gen_image, text=gen_text)
diff --git a/diffusers/pipelines/versatile_diffusion/__init__.py b/diffusers/pipelines/versatile_diffusion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..abf9dcff59dbc922dcc7063a1e73560679a23696
--- /dev/null
+++ b/diffusers/pipelines/versatile_diffusion/__init__.py
@@ -0,0 +1,24 @@
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ is_torch_available,
+ is_transformers_available,
+ is_transformers_version,
+)
+
+
+try:
+ if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import (
+ VersatileDiffusionDualGuidedPipeline,
+ VersatileDiffusionImageVariationPipeline,
+ VersatileDiffusionPipeline,
+ VersatileDiffusionTextToImagePipeline,
+ )
+else:
+ from .modeling_text_unet import UNetFlatConditionModel
+ from .pipeline_versatile_diffusion import VersatileDiffusionPipeline
+ from .pipeline_versatile_diffusion_dual_guided import VersatileDiffusionDualGuidedPipeline
+ from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline
+ from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline
diff --git a/diffusers/pipelines/versatile_diffusion/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/versatile_diffusion/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0ee0ffa87ae55f15d22abd747d2c05adf0a520e9
Binary files /dev/null and b/diffusers/pipelines/versatile_diffusion/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/pipelines/versatile_diffusion/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/versatile_diffusion/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0c089885c4847a05ce6d911fa0daf7c0d21e739e
Binary files /dev/null and b/diffusers/pipelines/versatile_diffusion/__pycache__/__init__.cpython-38.pyc differ
diff --git a/diffusers/pipelines/versatile_diffusion/__pycache__/modeling_text_unet.cpython-310.pyc b/diffusers/pipelines/versatile_diffusion/__pycache__/modeling_text_unet.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1e898886cbc4826e03e33dfcc66967db4eb80a6a
Binary files /dev/null and b/diffusers/pipelines/versatile_diffusion/__pycache__/modeling_text_unet.cpython-310.pyc differ
diff --git a/diffusers/pipelines/versatile_diffusion/__pycache__/modeling_text_unet.cpython-38.pyc b/diffusers/pipelines/versatile_diffusion/__pycache__/modeling_text_unet.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e479452e3eac4394a8f7c5b6fc0d60f8e81fcca8
Binary files /dev/null and b/diffusers/pipelines/versatile_diffusion/__pycache__/modeling_text_unet.cpython-38.pyc differ
diff --git a/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion.cpython-310.pyc b/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..de794825198274c1b623392c3227eb6861fc1df6
Binary files /dev/null and b/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion.cpython-310.pyc differ
diff --git a/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion.cpython-38.pyc b/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..23c1c4a3932b6b123e00ef2cfa4bbdeec18e2200
Binary files /dev/null and b/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion.cpython-38.pyc differ
diff --git a/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_dual_guided.cpython-310.pyc b/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_dual_guided.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cc2f9ce1bd55c76ff4dd32cc0ed6e5ea4232d1f4
Binary files /dev/null and b/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_dual_guided.cpython-310.pyc differ
diff --git a/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_dual_guided.cpython-38.pyc b/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_dual_guided.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..be2e2d93a44b9fb566f30f1bb8a8ee82d531eb13
Binary files /dev/null and b/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_dual_guided.cpython-38.pyc differ
diff --git a/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_image_variation.cpython-310.pyc b/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_image_variation.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ec606f5452f82490738289d7612ac12e85e0b906
Binary files /dev/null and b/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_image_variation.cpython-310.pyc differ
diff --git a/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_image_variation.cpython-38.pyc b/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_image_variation.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fa6cb7369180dd15121d0ff00972111e789a09a9
Binary files /dev/null and b/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_image_variation.cpython-38.pyc differ
diff --git a/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_text_to_image.cpython-310.pyc b/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_text_to_image.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..29cda28d5ce8c8bbd7357cfe688c763eb1b1fcba
Binary files /dev/null and b/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_text_to_image.cpython-310.pyc differ
diff --git a/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_text_to_image.cpython-38.pyc b/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_text_to_image.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ec0ea3eeb194cacb7b41cae359f5ccbd821f4ef1
Binary files /dev/null and b/diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_text_to_image.cpython-38.pyc differ
diff --git a/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..82628104eba2a30b99487868e6bb6db5f4b7214c
--- /dev/null
+++ b/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
@@ -0,0 +1,1931 @@
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...models import ModelMixin
+from ...models.activations import get_activation
+from ...models.attention import Attention
+from ...models.attention_processor import (
+ AttentionProcessor,
+ AttnAddedKVProcessor,
+ AttnAddedKVProcessor2_0,
+ AttnProcessor,
+)
+from ...models.dual_transformer_2d import DualTransformer2DModel
+from ...models.embeddings import (
+ GaussianFourierProjection,
+ ImageHintTimeEmbedding,
+ ImageProjection,
+ ImageTimeEmbedding,
+ TextImageProjection,
+ TextImageTimeEmbedding,
+ TextTimeEmbedding,
+ TimestepEmbedding,
+ Timesteps,
+)
+from ...models.transformer_2d import Transformer2DModel
+from ...models.unet_2d_condition import UNet2DConditionOutput
+from ...utils import is_torch_version, logging
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def get_down_block(
+ down_block_type,
+ num_layers,
+ in_channels,
+ out_channels,
+ temb_channels,
+ add_downsample,
+ resnet_eps,
+ resnet_act_fn,
+ num_attention_heads,
+ resnet_groups=None,
+ cross_attention_dim=None,
+ downsample_padding=None,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
+ resnet_time_scale_shift="default",
+ resnet_skip_time_act=False,
+ resnet_out_scale_factor=1.0,
+ cross_attention_norm=None,
+):
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
+ if down_block_type == "DownBlockFlat":
+ return DownBlockFlat(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif down_block_type == "CrossAttnDownBlockFlat":
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockFlat")
+ return CrossAttnDownBlockFlat(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ raise ValueError(f"{down_block_type} is not supported.")
+
+
+def get_up_block(
+ up_block_type,
+ num_layers,
+ in_channels,
+ out_channels,
+ prev_output_channel,
+ temb_channels,
+ add_upsample,
+ resnet_eps,
+ resnet_act_fn,
+ num_attention_heads,
+ resnet_groups=None,
+ cross_attention_dim=None,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
+ resnet_time_scale_shift="default",
+ resnet_skip_time_act=False,
+ resnet_out_scale_factor=1.0,
+ cross_attention_norm=None,
+):
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
+ if up_block_type == "UpBlockFlat":
+ return UpBlockFlat(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif up_block_type == "CrossAttnUpBlockFlat":
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockFlat")
+ return CrossAttnUpBlockFlat(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ raise ValueError(f"{up_block_type} is not supported.")
+
+
+# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel with UNet2DConditionModel->UNetFlatConditionModel, nn.Conv2d->LinearMultiDim, Block2D->BlockFlat
+class UNetFlatConditionModel(ModelMixin, ConfigMixin):
+ r"""
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
+ shaped output.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
+ for all models (such as downloading or saving).
+
+ Parameters:
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
+ Height and width of input/output sample.
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
+ Whether to flip the sin to cos in the time embedding.
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "DownBlockFlat")`):
+ The tuple of downsample blocks to use.
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlockFlatCrossAttn"`):
+ Block type for middle of UNet, it can be either `UNetMidBlockFlatCrossAttn` or
+ `UNetMidBlockFlatSimpleCrossAttn`. If `None`, the mid block layer is skipped.
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat")`):
+ The tuple of upsample blocks to use.
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
+ Whether to include self-attention in the basic transformer blocks, see
+ [`~models.attention.BasicTransformerBlock`].
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
+ The tuple of output channels for each block.
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
+ If `None`, normalization and activation layers are skipped in post-processing.
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
+ The dimension of the cross attention features.
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
+ [`~models.unet_2d_blocks.CrossAttnDownBlockFlat`], [`~models.unet_2d_blocks.CrossAttnUpBlockFlat`],
+ [`~models.unet_2d_blocks.UNetMidBlockFlatCrossAttn`].
+ encoder_hid_dim (`int`, *optional*, defaults to None):
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
+ dimension to `cross_attention_dim`.
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
+ embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
+ num_attention_heads (`int`, *optional*):
+ The number of attention heads. If not defined, defaults to `attention_head_dim`.
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
+ for ResNet blocks (see [`~models.resnet.ResnetBlockFlat`]). Choose from `default` or `scale_shift`.
+ class_embed_type (`str`, *optional*, defaults to `None`):
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
+ addition_embed_type (`str`, *optional*, defaults to `None`):
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
+ "text". "text" will use the `TextTimeEmbedding` layer.
+ addition_time_embed_dim (`int`, *optional*, defaults to `None`):
+ Dimension for the timestep embeddings.
+ num_class_embeds (`int`, *optional*, defaults to `None`):
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
+ class conditioning with `class_embed_type` equal to `None`.
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
+ An optional override for the dimension of the projected time embedding.
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
+ timestep_post_act (`str`, *optional*, defaults to `None`):
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
+ The dimension of `cond_proj` layer in the timestep embedding.
+ conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
+ conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
+ embeddings with the class embeddings.
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
+ Whether to use cross attention with the mid block when using the `UNetMidBlockFlatSimpleCrossAttn`. If
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Defaults to `False`
+ otherwise.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: Optional[int] = None,
+ in_channels: int = 4,
+ out_channels: int = 4,
+ center_input_sample: bool = False,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ down_block_types: Tuple[str] = (
+ "CrossAttnDownBlockFlat",
+ "CrossAttnDownBlockFlat",
+ "CrossAttnDownBlockFlat",
+ "DownBlockFlat",
+ ),
+ mid_block_type: Optional[str] = "UNetMidBlockFlatCrossAttn",
+ up_block_types: Tuple[str] = (
+ "UpBlockFlat",
+ "CrossAttnUpBlockFlat",
+ "CrossAttnUpBlockFlat",
+ "CrossAttnUpBlockFlat",
+ ),
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ layers_per_block: Union[int, Tuple[int]] = 2,
+ downsample_padding: int = 1,
+ mid_block_scale_factor: float = 1,
+ act_fn: str = "silu",
+ norm_num_groups: Optional[int] = 32,
+ norm_eps: float = 1e-5,
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+ encoder_hid_dim: Optional[int] = None,
+ encoder_hid_dim_type: Optional[str] = None,
+ attention_head_dim: Union[int, Tuple[int]] = 8,
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = False,
+ class_embed_type: Optional[str] = None,
+ addition_embed_type: Optional[str] = None,
+ addition_time_embed_dim: Optional[int] = None,
+ num_class_embeds: Optional[int] = None,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+ resnet_skip_time_act: bool = False,
+ resnet_out_scale_factor: int = 1.0,
+ time_embedding_type: str = "positional",
+ time_embedding_dim: Optional[int] = None,
+ time_embedding_act_fn: Optional[str] = None,
+ timestep_post_act: Optional[str] = None,
+ time_cond_proj_dim: Optional[int] = None,
+ conv_in_kernel: int = 3,
+ conv_out_kernel: int = 3,
+ projection_class_embeddings_input_dim: Optional[int] = None,
+ class_embeddings_concat: bool = False,
+ mid_block_only_cross_attention: Optional[bool] = None,
+ cross_attention_norm: Optional[str] = None,
+ addition_embed_type_num_heads=64,
+ ):
+ super().__init__()
+
+ self.sample_size = sample_size
+
+ if num_attention_heads is not None:
+ raise ValueError(
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads`"
+ " because of a naming issue as described in"
+ " https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing"
+ " `num_attention_heads` will only be supported in diffusers v0.19."
+ )
+
+ # If `num_attention_heads` is not defined (which is the case for most models)
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
+ # which is why we correct for the naming here.
+ num_attention_heads = num_attention_heads or attention_head_dim
+
+ # Check inputs
+ if len(down_block_types) != len(up_block_types):
+ raise ValueError(
+ "Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`:"
+ f" {down_block_types}. `up_block_types`: {up_block_types}."
+ )
+
+ if len(block_out_channels) != len(down_block_types):
+ raise ValueError(
+ "Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`:"
+ f" {block_out_channels}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
+ raise ValueError(
+ "Must provide the same number of `only_cross_attention` as `down_block_types`."
+ f" `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
+ raise ValueError(
+ "Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`:"
+ f" {num_attention_heads}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
+ raise ValueError(
+ "Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`:"
+ f" {attention_head_dim}. `down_block_types`: {down_block_types}."
+ )
+
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
+ raise ValueError(
+ "Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`:"
+ f" {cross_attention_dim}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
+ raise ValueError(
+ "Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`:"
+ f" {layers_per_block}. `down_block_types`: {down_block_types}."
+ )
+
+ # input
+ conv_in_padding = (conv_in_kernel - 1) // 2
+ self.conv_in = LinearMultiDim(
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
+ )
+
+ # time
+ if time_embedding_type == "fourier":
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
+ if time_embed_dim % 2 != 0:
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
+ self.time_proj = GaussianFourierProjection(
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
+ )
+ timestep_input_dim = time_embed_dim
+ elif time_embedding_type == "positional":
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
+
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+ timestep_input_dim = block_out_channels[0]
+ else:
+ raise ValueError(
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
+ )
+
+ self.time_embedding = TimestepEmbedding(
+ timestep_input_dim,
+ time_embed_dim,
+ act_fn=act_fn,
+ post_act_fn=timestep_post_act,
+ cond_proj_dim=time_cond_proj_dim,
+ )
+
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
+ encoder_hid_dim_type = "text_proj"
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
+
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
+ raise ValueError(
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
+ )
+
+ if encoder_hid_dim_type == "text_proj":
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
+ elif encoder_hid_dim_type == "text_image_proj":
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much,
+ # it is set to `cross_attention_dim` here, as this is exactly the required dimension for the currently
+ # only use case (`encoder_hid_dim_type == "text_image_proj"`, Kandinsky 2.1).
+ self.encoder_hid_proj = TextImageProjection(
+ text_embed_dim=encoder_hid_dim,
+ image_embed_dim=cross_attention_dim,
+ cross_attention_dim=cross_attention_dim,
+ )
+ elif encoder_hid_dim_type == "image_proj":
+ # Kandinsky 2.2
+ self.encoder_hid_proj = ImageProjection(
+ image_embed_dim=encoder_hid_dim,
+ cross_attention_dim=cross_attention_dim,
+ )
+ elif encoder_hid_dim_type is not None:
+ raise ValueError(
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
+ )
+ else:
+ self.encoder_hid_proj = None
+
+ # class embedding
+ if class_embed_type is None and num_class_embeds is not None:
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+ elif class_embed_type == "timestep":
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
+ elif class_embed_type == "identity":
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
+ elif class_embed_type == "projection":
+ if projection_class_embeddings_input_dim is None:
+ raise ValueError(
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
+ )
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
+ # 2. it projects from an arbitrary input dimension.
+ #
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+ elif class_embed_type == "simple_projection":
+ if projection_class_embeddings_input_dim is None:
+ raise ValueError(
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
+ )
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
+ else:
+ self.class_embedding = None
+
+ if addition_embed_type == "text":
+ if encoder_hid_dim is not None:
+ text_time_embedding_from_dim = encoder_hid_dim
+ else:
+ text_time_embedding_from_dim = cross_attention_dim
+
+ self.add_embedding = TextTimeEmbedding(
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
+ )
+ elif addition_embed_type == "text_image":
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much,
+ # they are set to `cross_attention_dim` here, as this is exactly the required dimension for the currently
+ # only use case (`addition_embed_type == "text_image"`, Kandinsky 2.1).
+ self.add_embedding = TextImageTimeEmbedding(
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
+ )
+ elif addition_embed_type == "text_time":
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+ elif addition_embed_type == "image":
+ # Kandinsky 2.2
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
+ elif addition_embed_type == "image_hint":
+ # Kandinsky 2.2 ControlNet
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
+ elif addition_embed_type is not None:
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
+
+ if time_embedding_act_fn is None:
+ self.time_embed_act = None
+ else:
+ self.time_embed_act = get_activation(time_embedding_act_fn)
+
+ self.down_blocks = nn.ModuleList([])
+ self.up_blocks = nn.ModuleList([])
+
+ if isinstance(only_cross_attention, bool):
+ if mid_block_only_cross_attention is None:
+ mid_block_only_cross_attention = only_cross_attention
+
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
+
+ if mid_block_only_cross_attention is None:
+ mid_block_only_cross_attention = False
+
+ if isinstance(num_attention_heads, int):
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
+
+ if isinstance(attention_head_dim, int):
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
+
+ if isinstance(cross_attention_dim, int):
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
+
+ if isinstance(layers_per_block, int):
+ layers_per_block = [layers_per_block] * len(down_block_types)
+
+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+
+ if class_embeddings_concat:
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
+ # regular time embeddings
+ blocks_time_embed_dim = time_embed_dim * 2
+ else:
+ blocks_time_embed_dim = time_embed_dim
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block[i],
+ transformer_layers_per_block=transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=blocks_time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim[i],
+ num_attention_heads=num_attention_heads[i],
+ downsample_padding=downsample_padding,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ resnet_skip_time_act=resnet_skip_time_act,
+ resnet_out_scale_factor=resnet_out_scale_factor,
+ cross_attention_norm=cross_attention_norm,
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ if mid_block_type == "UNetMidBlockFlatCrossAttn":
+ self.mid_block = UNetMidBlockFlatCrossAttn(
+ transformer_layers_per_block=transformer_layers_per_block[-1],
+ in_channels=block_out_channels[-1],
+ temb_channels=blocks_time_embed_dim,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ cross_attention_dim=cross_attention_dim[-1],
+ num_attention_heads=num_attention_heads[-1],
+ resnet_groups=norm_num_groups,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+ )
+ elif mid_block_type == "UNetMidBlockFlatSimpleCrossAttn":
+ self.mid_block = UNetMidBlockFlatSimpleCrossAttn(
+ in_channels=block_out_channels[-1],
+ temb_channels=blocks_time_embed_dim,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ cross_attention_dim=cross_attention_dim[-1],
+ attention_head_dim=attention_head_dim[-1],
+ resnet_groups=norm_num_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ skip_time_act=resnet_skip_time_act,
+ only_cross_attention=mid_block_only_cross_attention,
+ cross_attention_norm=cross_attention_norm,
+ )
+ elif mid_block_type is None:
+ self.mid_block = None
+ else:
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
+
+ # count how many layers upsample the images
+ self.num_upsamplers = 0
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
+ reversed_layers_per_block = list(reversed(layers_per_block))
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
+ only_cross_attention = list(reversed(only_cross_attention))
+
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ is_final_block = i == len(block_out_channels) - 1
+
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+ # add upsample block for all BUT final layer
+ if not is_final_block:
+ add_upsample = True
+ self.num_upsamplers += 1
+ else:
+ add_upsample = False
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=reversed_layers_per_block[i] + 1,
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ prev_output_channel=prev_output_channel,
+ temb_channels=blocks_time_embed_dim,
+ add_upsample=add_upsample,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=reversed_cross_attention_dim[i],
+ num_attention_heads=reversed_num_attention_heads[i],
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ resnet_skip_time_act=resnet_skip_time_act,
+ resnet_out_scale_factor=resnet_out_scale_factor,
+ cross_attention_norm=cross_attention_norm,
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # out
+ if norm_num_groups is not None:
+ self.conv_norm_out = nn.GroupNorm(
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
+ )
+
+ self.conv_act = get_activation(act_fn)
+
+ else:
+ self.conv_norm_out = None
+ self.conv_act = None
+
+ conv_out_padding = (conv_out_kernel - 1) // 2
+ self.conv_out = LinearMultiDim(
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
+ )
+
+ @property
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
+ indexed by their weight names.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "set_processor"):
+ processors[f"{name}.processor"] = module.processor
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ def set_default_attn_processor(self):
+ """
+ Disables custom attention processors and sets the default attention implementation.
+ """
+ self.set_attn_processor(AttnProcessor())
+
+ def set_attention_slice(self, slice_size):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
+
+ Args:
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+ must be a multiple of `slice_size`.
+ """
+ sliceable_head_dims = []
+
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
+ if hasattr(module, "set_attention_slice"):
+ sliceable_head_dims.append(module.sliceable_head_dim)
+
+ for child in module.children():
+ fn_recursive_retrieve_sliceable_dims(child)
+
+ # retrieve number of attention layers
+ for module in self.children():
+ fn_recursive_retrieve_sliceable_dims(module)
+
+ num_sliceable_layers = len(sliceable_head_dims)
+
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
+ elif slice_size == "max":
+ # make smallest slice possible
+ slice_size = num_sliceable_layers * [1]
+
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
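+ # e.g. with sliceable_head_dims == [8, 8, 16] and slice_size == "auto", the per-layer slice sizes
+ # become [4, 4, 8], so each attention module computes its result in two chunks.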
+
+ if len(slice_size) != len(sliceable_head_dims):
+ raise ValueError(
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
+ )
+
+ for i in range(len(slice_size)):
+ size = slice_size[i]
+ dim = sliceable_head_dims[i]
+ if size is not None and size > dim:
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
+
+ # Recursively walk through all the children.
+ # Any children which exposes the set_attention_slice method
+ # gets the message
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+ if hasattr(module, "set_attention_slice"):
+ module.set_attention_slice(slice_size.pop())
+
+ for child in module.children():
+ fn_recursive_set_attention_slice(child, slice_size)
+
+ reversed_slice_size = list(reversed(slice_size))
+ for module in self.children():
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, (CrossAttnDownBlockFlat, DownBlockFlat, CrossAttnUpBlockFlat, UpBlockFlat)):
+ module.gradient_checkpointing = value
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ class_labels: Optional[torch.Tensor] = None,
+ timestep_cond: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ ) -> Union[UNet2DConditionOutput, Tuple]:
+ r"""
+ The [`UNetFlatConditionModel`] forward method.
+
+ Args:
+ sample (`torch.FloatTensor`):
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
+ timestep (`torch.FloatTensor` or `float` or `int`): The timestep at which to denoise the input.
+ encoder_hidden_states (`torch.FloatTensor`):
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
+ encoder_attention_mask (`torch.Tensor`):
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
+ tuple.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
+ added_cond_kwargs (`dict`, *optional*):
+ A kwargs dictionary containing additional embeddings that, if specified, are added to the embeddings
+ that are passed along to the UNet blocks.
+
+ Returns:
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
+ If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
+ a `tuple` is returned where the first element is the sample tensor.
+ """
+ # By default samples have to be at least a multiple of the overall upsampling factor.
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
+ # on the fly if necessary.
+ default_overall_up_factor = 2**self.num_upsamplers
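+ # e.g. with the default four down/up blocks there are three upsamplers, so the spatial dims of
+ # `sample` should normally be multiples of 2**3 = 8.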
+
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+ forward_upsample_size = False
+ upsample_size = None
+
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
+ logger.info("Forward upsample size to force interpolation output size.")
+ forward_upsample_size = True
+
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
+ # expects mask of shape:
+ # [batch, key_tokens]
+ # adds singleton query_tokens dimension:
+ # [batch, 1, key_tokens]
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+ if attention_mask is not None:
+ # assume that mask is expressed as:
+ # (1 = keep, 0 = discard)
+ # convert mask into a bias that can be added to attention scores:
+ # (keep = +0, discard = -10000.0)
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
+ if encoder_attention_mask is not None:
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+ # 0. center input if necessary
+ if self.config.center_input_sample:
+ sample = 2 * sample - 1.0
+
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps.expand(sample.shape[0])
+
+ t_emb = self.time_proj(timesteps)
+
+ # `Timesteps` does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=sample.dtype)
+
+ emb = self.time_embedding(t_emb, timestep_cond)
+ aug_emb = None
+
+ if self.class_embedding is not None:
+ if class_labels is None:
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
+
+ if self.config.class_embed_type == "timestep":
+ class_labels = self.time_proj(class_labels)
+
+ # `Timesteps` does not contain any weights and will always return f32 tensors
+ # there might be better ways to encapsulate this.
+ class_labels = class_labels.to(dtype=sample.dtype)
+
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
+
+ if self.config.class_embeddings_concat:
+ emb = torch.cat([emb, class_emb], dim=-1)
+ else:
+ emb = emb + class_emb
+
+ if self.config.addition_embed_type == "text":
+ aug_emb = self.add_embedding(encoder_hidden_states)
+ elif self.config.addition_embed_type == "text_image":
+ # Kandinsky 2.1 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires"
+ " the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+ )
+
+ image_embs = added_cond_kwargs.get("image_embeds")
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
+ aug_emb = self.add_embedding(text_embs, image_embs)
+ elif self.config.addition_embed_type == "text_time":
+ if "text_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires"
+ " the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
+ )
+ text_embeds = added_cond_kwargs.get("text_embeds")
+ if "time_ids" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires"
+ " the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
+ )
+ time_ids = added_cond_kwargs.get("time_ids")
+ time_embeds = self.add_time_proj(time_ids.flatten())
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
+
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
+ add_embeds = add_embeds.to(emb.dtype)
+ aug_emb = self.add_embedding(add_embeds)
+ elif self.config.addition_embed_type == "image":
+ # Kandinsky 2.2 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the"
+ " keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+ )
+ image_embs = added_cond_kwargs.get("image_embeds")
+ aug_emb = self.add_embedding(image_embs)
+ elif self.config.addition_embed_type == "image_hint":
+ # Kandinsky 2.2 - style
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires"
+ " the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
+ )
+ image_embs = added_cond_kwargs.get("image_embeds")
+ hint = added_cond_kwargs.get("hint")
+ aug_emb, hint = self.add_embedding(image_embs, hint)
+ sample = torch.cat([sample, hint], dim=1)
+
+ emb = emb + aug_emb if aug_emb is not None else emb
+
+ if self.time_embed_act is not None:
+ emb = self.time_embed_act(emb)
+
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
+ # Kandinsky 2.1 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which"
+ " requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+ )
+
+ image_embeds = added_cond_kwargs.get("image_embeds")
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
+ # Kandinsky 2.2 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires"
+ " the keyword argument `image_embeds` to be passed in `added_conditions`"
+ )
+ image_embeds = added_cond_kwargs.get("image_embeds")
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
+ # 2. pre-process
+ sample = self.conv_in(sample)
+
+ # 3. down
+
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
+ is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
+
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ # For t2i-adapter CrossAttnDownBlockFlat
+ additional_residuals = {}
+ if is_adapter and len(down_block_additional_residuals) > 0:
+ additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)
+
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ **additional_residuals,
+ )
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+
+ if is_adapter and len(down_block_additional_residuals) > 0:
+ sample += down_block_additional_residuals.pop(0)
+
+ down_block_res_samples += res_samples
+
+ if is_controlnet:
+ new_down_block_res_samples = ()
+
+ for down_block_res_sample, down_block_additional_residual in zip(
+ down_block_res_samples, down_block_additional_residuals
+ ):
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
+
+ down_block_res_samples = new_down_block_res_samples
+
+ # 4. mid
+ if self.mid_block is not None:
+ sample = self.mid_block(
+ sample,
+ emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+
+ if is_controlnet:
+ sample = sample + mid_block_additional_residual
+
+ # 5. up
+ for i, upsample_block in enumerate(self.up_blocks):
+ is_final_block = i == len(self.up_blocks) - 1
+
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+ # if we have not reached the final block and need to forward the
+ # upsample size, we do it here
+ if not is_final_block and forward_upsample_size:
+ upsample_size = down_block_res_samples[-1].shape[2:]
+
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ upsample_size=upsample_size,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+ else:
+ sample = upsample_block(
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
+ )
+
+ # 6. post-process
+ if self.conv_norm_out:
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ if not return_dict:
+ return (sample,)
+
+ return UNet2DConditionOutput(sample=sample)
+
+
+class LinearMultiDim(nn.Linear):
+ def __init__(self, in_features, out_features=None, second_dim=4, *args, **kwargs):
+ in_features = [in_features, second_dim, 1] if isinstance(in_features, int) else list(in_features)
+ if out_features is None:
+ out_features = in_features
+ out_features = [out_features, second_dim, 1] if isinstance(out_features, int) else list(out_features)
+ self.in_features_multidim = in_features
+ self.out_features_multidim = out_features
+ super().__init__(np.array(in_features).prod(), np.array(out_features).prod())
+
+ def forward(self, input_tensor, *args, **kwargs):
+ shape = input_tensor.shape
+ n_dim = len(self.in_features_multidim)
+ input_tensor = input_tensor.reshape(*shape[0:-n_dim], self.in_features)
+ output_tensor = super().forward(input_tensor)
+ output_tensor = output_tensor.view(*shape[0:-n_dim], *self.out_features_multidim)
+ return output_tensor
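+
+# Illustrative sketch (an editor's note, not part of the original file): LinearMultiDim flattens the
+# trailing (features, second_dim, 1) axes, applies an ordinary nn.Linear, and restores the
+# multi-dimensional shape, so it can stand in for Downsample2D/Upsample2D in the "flat" UNet, e.g.:
+#
+# layer = LinearMultiDim(320, 640) # weight acts on 320 * 4 * 1 = 1280 flattened features
+# x = torch.randn(2, 320, 4, 1) # (..., *in_features_multidim); extra leading dims are allowed
+# y = layer(x) # -> torch.Size([2, 640, 4, 1])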
+
+
+class ResnetBlockFlat(nn.Module):
+ def __init__(
+ self,
+ *,
+ in_channels,
+ out_channels=None,
+ dropout=0.0,
+ temb_channels=512,
+ groups=32,
+ groups_out=None,
+ pre_norm=True,
+ eps=1e-6,
+ time_embedding_norm="default",
+ use_in_shortcut=None,
+ second_dim=4,
+ **kwargs,
+ ):
+ super().__init__()
+ self.pre_norm = pre_norm
+ self.pre_norm = True
+
+ in_channels = [in_channels, second_dim, 1] if isinstance(in_channels, int) else list(in_channels)
+ self.in_channels_prod = np.array(in_channels).prod()
+ self.channels_multidim = in_channels
+
+ if out_channels is not None:
+ out_channels = [out_channels, second_dim, 1] if isinstance(out_channels, int) else list(out_channels)
+ out_channels_prod = np.array(out_channels).prod()
+ self.out_channels_multidim = out_channels
+ else:
+ out_channels_prod = self.in_channels_prod
+ self.out_channels_multidim = self.channels_multidim
+ self.time_embedding_norm = time_embedding_norm
+
+ if groups_out is None:
+ groups_out = groups
+
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=self.in_channels_prod, eps=eps, affine=True)
+ self.conv1 = torch.nn.Conv2d(self.in_channels_prod, out_channels_prod, kernel_size=1, padding=0)
+
+ if temb_channels is not None:
+ self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels_prod)
+ else:
+ self.time_emb_proj = None
+
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels_prod, eps=eps, affine=True)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = torch.nn.Conv2d(out_channels_prod, out_channels_prod, kernel_size=1, padding=0)
+
+ self.nonlinearity = nn.SiLU()
+
+ self.use_in_shortcut = (
+ self.in_channels_prod != out_channels_prod if use_in_shortcut is None else use_in_shortcut
+ )
+
+ self.conv_shortcut = None
+ if self.use_in_shortcut:
+ self.conv_shortcut = torch.nn.Conv2d(
+ self.in_channels_prod, out_channels_prod, kernel_size=1, stride=1, padding=0
+ )
+
+ def forward(self, input_tensor, temb):
+ shape = input_tensor.shape
+ n_dim = len(self.channels_multidim)
+ input_tensor = input_tensor.reshape(*shape[0:-n_dim], self.in_channels_prod, 1, 1)
+ input_tensor = input_tensor.view(-1, self.in_channels_prod, 1, 1)
+
+ hidden_states = input_tensor
+
+ hidden_states = self.norm1(hidden_states)
+ hidden_states = self.nonlinearity(hidden_states)
+ hidden_states = self.conv1(hidden_states)
+
+ if temb is not None:
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
+ hidden_states = hidden_states + temb
+
+ hidden_states = self.norm2(hidden_states)
+ hidden_states = self.nonlinearity(hidden_states)
+
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.conv2(hidden_states)
+
+ if self.conv_shortcut is not None:
+ input_tensor = self.conv_shortcut(input_tensor)
+
+ output_tensor = input_tensor + hidden_states
+
+ output_tensor = output_tensor.view(*shape[0:-n_dim], -1)
+ output_tensor = output_tensor.view(*shape[0:-n_dim], *self.out_channels_multidim)
+
+ return output_tensor
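+
+# Illustrative sketch (an editor's note, not part of the original file): because the trailing axes are
+# flattened into a (prod, 1, 1) "image", the GroupNorm / 1x1-conv stack above behaves like a per-sample
+# MLP residual block, e.g.:
+#
+# block = ResnetBlockFlat(in_channels=320, temb_channels=1280)
+# x = torch.randn(2, 320, 4, 1)
+# temb = torch.randn(2, 1280)
+# y = block(x, temb) # -> torch.Size([2, 320, 4, 1])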
+
+
+# Copied from diffusers.models.unet_2d_blocks.DownBlock2D with DownBlock2D->DownBlockFlat, ResnetBlock2D->ResnetBlockFlat, Downsample2D->LinearMultiDim
+class DownBlockFlat(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor=1.0,
+ add_downsample=True,
+ downsample_padding=1,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlockFlat(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ LinearMultiDim(
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states, temb=None):
+ output_states = ()
+
+ for resnet in self.resnets:
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ if is_torch_version(">=", "1.11.0"):
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+ )
+ else:
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb)
+
+ output_states = output_states + (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states = output_states + (hidden_states,)
+
+ return hidden_states, output_states
+
+
+# Copied from diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D with CrossAttnDownBlock2D->CrossAttnDownBlockFlat, ResnetBlock2D->ResnetBlockFlat, Downsample2D->LinearMultiDim
+class CrossAttnDownBlockFlat(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ transformer_layers_per_block: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ num_attention_heads=1,
+ cross_attention_dim=1280,
+ output_scale_factor=1.0,
+ downsample_padding=1,
+ add_downsample=True,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.has_cross_attention = True
+ self.num_attention_heads = num_attention_heads
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlockFlat(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ if not dual_cross_attention:
+ attentions.append(
+ Transformer2DModel(
+ num_attention_heads,
+ out_channels // num_attention_heads,
+ in_channels=out_channels,
+ num_layers=transformer_layers_per_block,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ )
+ )
+ else:
+ attentions.append(
+ DualTransformer2DModel(
+ num_attention_heads,
+ out_channels // num_attention_heads,
+ in_channels=out_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ )
+ )
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ LinearMultiDim(
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ additional_residuals=None,
+ ):
+ output_states = ()
+
+ blocks = list(zip(self.resnets, self.attentions))
+
+ for i, (resnet, attn) in enumerate(blocks):
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ **ckpt_kwargs,
+ )
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(attn, return_dict=False),
+ hidden_states,
+ encoder_hidden_states,
+ None, # timestep
+ None, # class_labels
+ cross_attention_kwargs,
+ attention_mask,
+ encoder_attention_mask,
+ **ckpt_kwargs,
+ )[0]
+ else:
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+
+ # apply additional residuals to the output of the last pair of resnet and attention blocks
+ if i == len(blocks) - 1 and additional_residuals is not None:
+ hidden_states = hidden_states + additional_residuals
+
+ output_states = output_states + (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states = output_states + (hidden_states,)
+
+ return hidden_states, output_states
+
+
+# Copied from diffusers.models.unet_2d_blocks.UpBlock2D with UpBlock2D->UpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim
+class UpBlockFlat(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlockFlat(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([LinearMultiDim(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
+ for resnet in self.resnets:
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ if is_torch_version(">=", "1.11.0"):
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+ )
+ else:
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
+
+
+# Copied from diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D with CrossAttnUpBlock2D->CrossAttnUpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim
+class CrossAttnUpBlockFlat(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ prev_output_channel: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ transformer_layers_per_block: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ num_attention_heads=1,
+ cross_attention_dim=1280,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.has_cross_attention = True
+ self.num_attention_heads = num_attention_heads
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlockFlat(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ if not dual_cross_attention:
+ attentions.append(
+ Transformer2DModel(
+ num_attention_heads,
+ out_channels // num_attention_heads,
+ in_channels=out_channels,
+ num_layers=transformer_layers_per_block,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ )
+ )
+ else:
+ attentions.append(
+ DualTransformer2DModel(
+ num_attention_heads,
+ out_channels // num_attention_heads,
+ in_channels=out_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ )
+ )
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([LinearMultiDim(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ upsample_size: Optional[int] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ ):
+ for resnet, attn in zip(self.resnets, self.attentions):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ **ckpt_kwargs,
+ )
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(attn, return_dict=False),
+ hidden_states,
+ encoder_hidden_states,
+ None, # timestep
+ None, # class_labels
+ cross_attention_kwargs,
+ attention_mask,
+ encoder_attention_mask,
+ **ckpt_kwargs,
+ )[0]
+ else:
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
+
+
+# Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2DCrossAttn with UNetMidBlock2DCrossAttn->UNetMidBlockFlatCrossAttn, ResnetBlock2D->ResnetBlockFlat
+class UNetMidBlockFlatCrossAttn(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ transformer_layers_per_block: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ num_attention_heads=1,
+ output_scale_factor=1.0,
+ cross_attention_dim=1280,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ upcast_attention=False,
+ ):
+ super().__init__()
+
+ self.has_cross_attention = True
+ self.num_attention_heads = num_attention_heads
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+
+ # there is always at least one resnet
+ resnets = [
+ ResnetBlockFlat(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ ]
+ attentions = []
+
+ for _ in range(num_layers):
+ if not dual_cross_attention:
+ attentions.append(
+ Transformer2DModel(
+ num_attention_heads,
+ in_channels // num_attention_heads,
+ in_channels=in_channels,
+ num_layers=transformer_layers_per_block,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+ )
+ )
+ else:
+ attentions.append(
+ DualTransformer2DModel(
+ num_attention_heads,
+ in_channels // num_attention_heads,
+ in_channels=in_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ )
+ )
+ resnets.append(
+ ResnetBlockFlat(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ ) -> torch.FloatTensor:
+ hidden_states = self.resnets[0](hidden_states, temb)
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+ hidden_states = resnet(hidden_states, temb)
+
+ return hidden_states
+
+
+# Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2DSimpleCrossAttn with UNetMidBlock2DSimpleCrossAttn->UNetMidBlockFlatSimpleCrossAttn, ResnetBlock2D->ResnetBlockFlat
+class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attention_head_dim=1,
+ output_scale_factor=1.0,
+ cross_attention_dim=1280,
+ skip_time_act=False,
+ only_cross_attention=False,
+ cross_attention_norm=None,
+ ):
+ super().__init__()
+
+ self.has_cross_attention = True
+
+ self.attention_head_dim = attention_head_dim
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+
+ self.num_heads = in_channels // self.attention_head_dim
+
+ # there is always at least one resnet
+ resnets = [
+ ResnetBlockFlat(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ skip_time_act=skip_time_act,
+ )
+ ]
+ attentions = []
+
+ for _ in range(num_layers):
+ processor = (
+ AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
+ )
+
+ attentions.append(
+ Attention(
+ query_dim=in_channels,
+ cross_attention_dim=in_channels,
+ heads=self.num_heads,
+ dim_head=self.attention_head_dim,
+ added_kv_proj_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ bias=True,
+ upcast_softmax=True,
+ only_cross_attention=only_cross_attention,
+ cross_attention_norm=cross_attention_norm,
+ processor=processor,
+ )
+ )
+ resnets.append(
+ ResnetBlockFlat(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ skip_time_act=skip_time_act,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ ):
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+
+ if attention_mask is None:
+ # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
+ mask = None if encoder_hidden_states is None else encoder_attention_mask
+ else:
+ # when attention_mask is defined: we don't even check for encoder_attention_mask.
+ # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks.
+ # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask.
+ # then we can simplify this whole if/else block to:
+ # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
+ mask = attention_mask
+
+ hidden_states = self.resnets[0](hidden_states, temb)
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
+ # attn
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=mask,
+ **cross_attention_kwargs,
+ )
+
+ # resnet
+ hidden_states = resnet(hidden_states, temb)
+
+ return hidden_states
diff --git a/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py b/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d6b5e7863ebb9b53ba741138b0829eab509888c
--- /dev/null
+++ b/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py
@@ -0,0 +1,434 @@
+import inspect
+from typing import Callable, List, Optional, Union
+
+import PIL.Image
+import torch
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModel
+
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...schedulers import KarrasDiffusionSchedulers
+from ...utils import logging
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_versatile_diffusion_dual_guided import VersatileDiffusionDualGuidedPipeline
+from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline
+from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class VersatileDiffusionPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-to-image generation, image variation, and dual-guided generation using Versatile Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ image_feature_extractor ([`CLIPImageProcessor`]):
+ Image processor that prepares image prompts as inputs for the `image_encoder`.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Versatile Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
+ image_encoder ([`CLIPVisionModel`]):
+ Frozen vision-encoder used to embed image prompts.
+ image_unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ text_unet ([`UNet2DConditionModel`]):
+ U-Net with text-conditioned attention blocks, used together with `image_unet` for the text-to-image and
+ dual-guided tasks.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `image_unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ """
+
+ tokenizer: CLIPTokenizer
+ image_feature_extractor: CLIPImageProcessor
+ text_encoder: CLIPTextModel
+ image_encoder: CLIPVisionModel
+ image_unet: UNet2DConditionModel
+ text_unet: UNet2DConditionModel
+ vae: AutoencoderKL
+ scheduler: KarrasDiffusionSchedulers
+
+ def __init__(
+ self,
+ tokenizer: CLIPTokenizer,
+ image_feature_extractor: CLIPImageProcessor,
+ text_encoder: CLIPTextModel,
+ image_encoder: CLIPVisionModel,
+ image_unet: UNet2DConditionModel,
+ text_unet: UNet2DConditionModel,
+ vae: AutoencoderKL,
+ scheduler: KarrasDiffusionSchedulers,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ tokenizer=tokenizer,
+ image_feature_extractor=image_feature_extractor,
+ text_encoder=text_encoder,
+ image_encoder=image_encoder,
+ image_unet=image_unet,
+ text_unet=text_unet,
+ vae=vae,
+ scheduler=scheduler,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
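+ # e.g. a UNet sample_size of 64 with the usual 4-level VAE (2**3 = 8) gives 512x512 outputs by default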
+
+ @torch.no_grad()
+ def image_variation(
+ self,
+ image: Union[torch.FloatTensor, PIL.Image.Image],
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ image (`PIL.Image.Image`, `List[PIL.Image.Image]` or `torch.Tensor`):
+ The image prompt or prompts to guide the image generation.
+ height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+
+ Examples:
+
+ ```py
+ >>> from diffusers import VersatileDiffusionPipeline
+ >>> import torch
+ >>> import requests
+ >>> from io import BytesIO
+ >>> from PIL import Image
+
+ >>> # let's download an initial image
+ >>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg"
+
+ >>> response = requests.get(url)
+ >>> image = Image.open(BytesIO(response.content)).convert("RGB")
+
+ >>> pipe = VersatileDiffusionPipeline.from_pretrained(
+ ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16
+ ... )
+ >>> pipe = pipe.to("cuda")
+
+ >>> generator = torch.Generator(device="cuda").manual_seed(0)
+ >>> image = pipe.image_variation(image, generator=generator).images[0]
+ >>> image.save("./car_variation.png")
+ ```
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ expected_components = inspect.signature(VersatileDiffusionImageVariationPipeline.__init__).parameters.keys()
+ components = {name: component for name, component in self.components.items() if name in expected_components}
+ return VersatileDiffusionImageVariationPipeline(**components)(
+ image=image,
+ height=height,
+ width=width,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ eta=eta,
+ generator=generator,
+ latents=latents,
+ output_type=output_type,
+ return_dict=return_dict,
+ callback=callback,
+ callback_steps=callback_steps,
+ )
+
+ @torch.no_grad()
+ def text_to_image(
+ self,
+ prompt: Union[str, List[str]],
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+
+ Examples:
+
+ ```py
+ >>> from diffusers import VersatileDiffusionPipeline
+ >>> import torch
+
+ >>> pipe = VersatileDiffusionPipeline.from_pretrained(
+ ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16
+ ... )
+ >>> pipe = pipe.to("cuda")
+
+ >>> generator = torch.Generator(device="cuda").manual_seed(0)
+ >>> image = pipe.text_to_image("an astronaut riding on a horse on mars", generator=generator).images[0]
+ >>> image.save("./astronaut.png")
+ ```
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ expected_components = inspect.signature(VersatileDiffusionTextToImagePipeline.__init__).parameters.keys()
+ components = {name: component for name, component in self.components.items() if name in expected_components}
+ temp_pipeline = VersatileDiffusionTextToImagePipeline(**components)
+ output = temp_pipeline(
+ prompt=prompt,
+ height=height,
+ width=width,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ eta=eta,
+ generator=generator,
+ latents=latents,
+ output_type=output_type,
+ return_dict=return_dict,
+ callback=callback,
+ callback_steps=callback_steps,
+ )
+ # swap the attention blocks back to the original state
+ temp_pipeline._swap_unet_attention_blocks()
+
+ return output
+
+ @torch.no_grad()
+ def dual_guided(
+ self,
+ prompt: Union[str, List[str]],
+ image: Union[PIL.Image.Image, List[PIL.Image.Image]],
+ text_to_image_strength: float = 0.5,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The text prompt or prompts to guide the image generation.
+ image (`PIL.Image.Image` or `List[PIL.Image.Image]`):
+ The image prompt or prompts to guide the image generation.
+ text_to_image_strength (`float`, *optional*, defaults to 0.5):
+ The relative weight of the text guidance versus the image guidance when mixing the two conditioning
+ streams.
+ height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+
+ Examples:
+
+ ```py
+ >>> from diffusers import VersatileDiffusionPipeline
+ >>> import torch
+ >>> import requests
+ >>> from io import BytesIO
+ >>> from PIL import Image
+
+ >>> # let's download an initial image
+ >>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg"
+
+ >>> response = requests.get(url)
+ >>> image = Image.open(BytesIO(response.content)).convert("RGB")
+ >>> text = "a red car in the sun"
+
+ >>> pipe = VersatileDiffusionPipeline.from_pretrained(
+ ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16
+ ... )
+ >>> pipe = pipe.to("cuda")
+
+ >>> generator = torch.Generator(device="cuda").manual_seed(0)
+ >>> text_to_image_strength = 0.75
+
+ >>> image = pipe.dual_guided(
+ ... prompt=text, image=image, text_to_image_strength=text_to_image_strength, generator=generator
+ ... ).images[0]
+ >>> image.save("./car_variation.png")
+ ```
+
+ Returns:
+ [`~pipelines.stable_diffusion.ImagePipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is a list with the generated images.
+ """
+
+ expected_components = inspect.signature(VersatileDiffusionDualGuidedPipeline.__init__).parameters.keys()
+ components = {name: component for name, component in self.components.items() if name in expected_components}
+ temp_pipeline = VersatileDiffusionDualGuidedPipeline(**components)
+ output = temp_pipeline(
+ prompt=prompt,
+ image=image,
+ text_to_image_strength=text_to_image_strength,
+ height=height,
+ width=width,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ num_images_per_prompt=num_images_per_prompt,
+ eta=eta,
+ generator=generator,
+ latents=latents,
+ output_type=output_type,
+ return_dict=return_dict,
+ callback=callback,
+ callback_steps=callback_steps,
+ )
+ temp_pipeline._revert_dual_attention()
+
+ return output
diff --git a/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py b/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py
new file mode 100644
index 0000000000000000000000000000000000000000..5986d66a61e79e90dd7b86884f621f956cd52cba
--- /dev/null
+++ b/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py
@@ -0,0 +1,557 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import warnings
+from typing import Callable, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL
+import torch
+import torch.utils.checkpoint
+from transformers import (
+ CLIPImageProcessor,
+ CLIPTextModelWithProjection,
+ CLIPTokenizer,
+ CLIPVisionModelWithProjection,
+)
+
+from ...image_processor import VaeImageProcessor
+from ...models import AutoencoderKL, DualTransformer2DModel, Transformer2DModel, UNet2DConditionModel
+from ...schedulers import KarrasDiffusionSchedulers
+from ...utils import logging, randn_tensor
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from .modeling_text_unet import UNetFlatConditionModel
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for image-text dual-guided generation using Versatile Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Parameters:
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTokenizer).
+ image_feature_extractor ([`CLIPImageProcessor`]):
+ Image processor that prepares image prompts as inputs for the `image_encoder`.
+ text_encoder ([`CLIPTextModelWithProjection`]):
+ Frozen CLIP text-encoder used to embed text prompts.
+ image_encoder ([`CLIPVisionModelWithProjection`]):
+ Frozen CLIP vision-encoder used to embed image prompts.
+ image_unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ text_unet ([`UNetFlatConditionModel`]):
+ Flattened-text U-Net whose transformer blocks are merged into `image_unet` for dual-guided generation.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `image_unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ """
+ tokenizer: CLIPTokenizer
+ image_feature_extractor: CLIPImageProcessor
+ text_encoder: CLIPTextModelWithProjection
+ image_encoder: CLIPVisionModelWithProjection
+ image_unet: UNet2DConditionModel
+ text_unet: UNetFlatConditionModel
+ vae: AutoencoderKL
+ scheduler: KarrasDiffusionSchedulers
+
+ _optional_components = ["text_unet"]
+
+ def __init__(
+ self,
+ tokenizer: CLIPTokenizer,
+ image_feature_extractor: CLIPImageProcessor,
+ text_encoder: CLIPTextModelWithProjection,
+ image_encoder: CLIPVisionModelWithProjection,
+ image_unet: UNet2DConditionModel,
+ text_unet: UNetFlatConditionModel,
+ vae: AutoencoderKL,
+ scheduler: KarrasDiffusionSchedulers,
+ ):
+ super().__init__()
+ self.register_modules(
+ tokenizer=tokenizer,
+ image_feature_extractor=image_feature_extractor,
+ text_encoder=text_encoder,
+ image_encoder=image_encoder,
+ image_unet=image_unet,
+ text_unet=text_unet,
+ vae=vae,
+ scheduler=scheduler,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+ if self.text_unet is not None and (
+ "dual_cross_attention" not in self.image_unet.config or not self.image_unet.config.dual_cross_attention
+ ):
+ # if loading from a universal checkpoint rather than a saved dual-guided pipeline
+ self._convert_to_dual_attention()
+
+ def remove_unused_weights(self):
+ self.register_modules(text_unet=None)
+
+ def _convert_to_dual_attention(self):
+ """
+ Replace image_unet's `Transformer2DModel` blocks with `DualTransformer2DModel` blocks that contain the
+ transformer blocks from both `image_unet` and `text_unet`.
+ """
+ for name, module in self.image_unet.named_modules():
+ if isinstance(module, Transformer2DModel):
+ parent_name, index = name.rsplit(".", 1)
+ index = int(index)
+
+ image_transformer = self.image_unet.get_submodule(parent_name)[index]
+ text_transformer = self.text_unet.get_submodule(parent_name)[index]
+
+ config = image_transformer.config
+ dual_transformer = DualTransformer2DModel(
+ num_attention_heads=config.num_attention_heads,
+ attention_head_dim=config.attention_head_dim,
+ in_channels=config.in_channels,
+ num_layers=config.num_layers,
+ dropout=config.dropout,
+ norm_num_groups=config.norm_num_groups,
+ cross_attention_dim=config.cross_attention_dim,
+ attention_bias=config.attention_bias,
+ sample_size=config.sample_size,
+ num_vector_embeds=config.num_vector_embeds,
+ activation_fn=config.activation_fn,
+ num_embeds_ada_norm=config.num_embeds_ada_norm,
+ )
+ dual_transformer.transformers[0] = image_transformer
+ dual_transformer.transformers[1] = text_transformer
+
+ self.image_unet.get_submodule(parent_name)[index] = dual_transformer
+ self.image_unet.register_to_config(dual_cross_attention=True)
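+ # After this swap each DualTransformer2DModel holds the image transformer at index 0 and the text
+ # transformer at index 1; at inference time their outputs are mixed according to the text-to-image
+ # strength configured elsewhere by the pipeline.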
+
+ def _revert_dual_attention(self):
+ """
+ Revert the image_unet `DualTransformer2DModel` blocks back to `Transformer2DModel` with image_unet weights.
+ Call this function if you reuse `image_unet` in another pipeline, e.g. `VersatileDiffusionPipeline`.
+ """
+ for name, module in self.image_unet.named_modules():
+ if isinstance(module, DualTransformer2DModel):
+ parent_name, index = name.rsplit(".", 1)
+ index = int(index)
+ self.image_unet.get_submodule(parent_name)[index] = module.transformers[0]
+
+ self.image_unet.register_to_config(dual_cross_attention=False)
+
+ def _encode_text_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ """
+
+ def normalize_embeddings(encoder_output):
+ embeds = self.text_encoder.text_projection(encoder_output.last_hidden_state)
+ embeds_pooled = encoder_output.text_embeds
+ embeds = embeds / torch.norm(embeds_pooled.unsqueeze(1), dim=-1, keepdim=True)
+ return embeds
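+ # note: unlike the plain Stable Diffusion text path, the projected per-token embeddings are rescaled
+ # by the norm of the pooled embedding, presumably to keep the text and image conditioning streams on
+ # a comparable scale for dual guidance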
+
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
+
+ if not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = normalize_embeddings(prompt_embeds)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ uncond_tokens = [""] * batch_size
+ max_length = text_input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = normalize_embeddings(negative_prompt_embeds)
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ def _encode_image_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance):
+ r"""
+ Encodes the image prompt into image encoder hidden states.
+
+ Args:
+ prompt (`PIL.Image.Image` or `List[PIL.Image.Image]`):
+ image prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ """
+
+ def normalize_embeddings(encoder_output):
+ embeds = self.image_encoder.vision_model.post_layernorm(encoder_output.last_hidden_state)
+ embeds = self.image_encoder.visual_projection(embeds)
+ embeds_pooled = embeds[:, 0:1]
+ embeds = embeds / torch.norm(embeds_pooled, dim=-1, keepdim=True)
+ return embeds
+
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+ # get prompt text embeddings
+ image_input = self.image_feature_extractor(images=prompt, return_tensors="pt")
+ pixel_values = image_input.pixel_values.to(device).to(self.image_encoder.dtype)
+ image_embeddings = self.image_encoder(pixel_values)
+ image_embeddings = normalize_embeddings(image_embeddings)
+
+ # duplicate image embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = image_embeddings.shape
+ image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1)
+ image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
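+ # A flat mid-gray image (all pixels 0.5) stands in for the "empty" image prompt in the
+ # unconditional branch of classifier-free guidance.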
+ uncond_images = [np.zeros((512, 512, 3)) + 0.5] * batch_size
+ uncond_images = self.image_feature_extractor(images=uncond_images, return_tensors="pt")
+ pixel_values = uncond_images.pixel_values.to(device).to(self.image_encoder.dtype)
+ negative_prompt_embeds = self.image_encoder(pixel_values)
+ negative_prompt_embeds = normalize_embeddings(negative_prompt_embeds)
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and conditional embeddings into a single batch
+ # to avoid doing two forward passes
+ image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])
+
+ return image_embeddings
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ def decode_latents(self, latents):
+ warnings.warn(
+ "The decode_latents method is deprecated and will be removed in a future version. Please"
+ " use VaeImageProcessor instead",
+ FutureWarning,
+ )
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(self, prompt, image, height, width, callback_steps):
+ if not isinstance(prompt, str) and not isinstance(prompt, PIL.Image.Image) and not isinstance(prompt, list):
+ raise ValueError(f"`prompt` has to be of type `str`, `PIL.Image`, or `list` but is {type(prompt)}")
+ if not isinstance(image, str) and not isinstance(image, PIL.Image.Image) and not isinstance(image, list):
+ raise ValueError(f"`image` has to be of type `str`, `PIL.Image`, or `list` but is {type(image)}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
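+ # Each DualTransformer2DModel in the image UNet runs one transformer per modality and blends
+ # their outputs; with condition_types ("text", "image") the blend is roughly
+ # mix_ratio * text_branch + (1 - mix_ratio) * image_branch, so `mix_ratio` acts as the
+ # text-to-image strength. The image stream is 257 tokens long (256 CLIP patch tokens plus the
+ # class token), hence the hard-coded condition length below.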
+ def set_transformer_params(self, mix_ratio: float = 0.5, condition_types: Tuple = ("text", "image")):
+ for name, module in self.image_unet.named_modules():
+ if isinstance(module, DualTransformer2DModel):
+ module.mix_ratio = mix_ratio
+
+ for i, condition_type in enumerate(condition_types):
+ if condition_type == "text":
+ module.condition_lengths[i] = self.text_encoder.config.max_position_embeddings
+ module.transformer_index_for_condition[i] = 1 # use the second (text) transformer
+ else:
+ module.condition_lengths[i] = 257
+ module.transformer_index_for_condition[i] = 0 # use the first (image) transformer
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ image: Union[PIL.Image.Image, List[PIL.Image.Image]],
+ text_to_image_strength: float = 0.5,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The text prompt or prompts to guide the image generation.
+ image (`PIL.Image.Image` or `List[PIL.Image.Image]`):
+ The image prompt or prompts to guide the image generation.
+ text_to_image_strength (`float`, *optional*, defaults to 0.5):
+ How strongly the text condition is weighted against the image condition when the two are mixed
+ in the dual attention blocks.
+ height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2 of the [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the
+ text `prompt`, usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+
+ Examples:
+
+ ```py
+ >>> from diffusers import VersatileDiffusionDualGuidedPipeline
+ >>> import torch
+ >>> import requests
+ >>> from io import BytesIO
+ >>> from PIL import Image
+
+ >>> # let's download an initial image
+ >>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg"
+
+ >>> response = requests.get(url)
+ >>> image = Image.open(BytesIO(response.content)).convert("RGB")
+ >>> text = "a red car in the sun"
+
+ >>> pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained(
+ ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16
+ ... )
+ >>> pipe.remove_unused_weights()
+ >>> pipe = pipe.to("cuda")
+
+ >>> generator = torch.Generator(device="cuda").manual_seed(0)
+ >>> text_to_image_strength = 0.75
+
+ >>> image = pipe(
+ ... prompt=text, image=image, text_to_image_strength=text_to_image_strength, generator=generator
+ ... ).images[0]
+ >>> image.save("./car_variation.png")
+ ```
+
+ Returns:
+ [`~pipelines.ImagePipelineOutput`] or `tuple`:
+ [`~pipelines.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning
+ a tuple, the first element is a list with the generated images.
+ """
+ # 0. Default height and width to unet
+ height = height or self.image_unet.config.sample_size * self.vae_scale_factor
+ width = width or self.image_unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(prompt, image, height, width, callback_steps)
+
+ # 2. Define call parameters
+ prompt = [prompt] if not isinstance(prompt, list) else prompt
+ image = [image] if not isinstance(image, list) else image
+ batch_size = len(prompt)
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompts
+ prompt_embeds = self._encode_text_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance)
+ image_embeddings = self._encode_image_prompt(image, device, num_images_per_prompt, do_classifier_free_guidance)
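+ # The 77-token text embedding and the 257-token image embedding are concatenated along the
+ # sequence dimension; the condition lengths configured in `set_transformer_params` (step 7
+ # below) tell each dual attention block where to split them again.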
+ dual_prompt_embeddings = torch.cat([prompt_embeds, image_embeddings], dim=1)
+ prompt_types = ("text", "image")
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.image_unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ dual_prompt_embeddings.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs.
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Configure how the dual attention blocks mix the text and image conditionings
+ self.set_transformer_params(text_to_image_strength, prompt_types)
+
+ # 8. Denoising loop
+ for i, t in enumerate(self.progress_bar(timesteps)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.image_unet(latent_model_input, t, encoder_hidden_states=dual_prompt_embeddings).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
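+ # the unconditional embeddings were concatenated first, so the first half of the batch is
+ # the unconditional prediction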
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ if output_type != "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ else:
+ image = latents
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
diff --git a/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py b/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py
new file mode 100644
index 0000000000000000000000000000000000000000..154548df7542576dd00663ffa5d3e3e45d60da0b
--- /dev/null
+++ b/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py
@@ -0,0 +1,399 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import warnings
+from typing import Callable, List, Optional, Union
+
+import numpy as np
+import PIL
+import torch
+import torch.utils.checkpoint
+from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
+
+from ...image_processor import VaeImageProcessor
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...schedulers import KarrasDiffusionSchedulers
+from ...utils import logging, randn_tensor
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
+ r"""
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Parameters:
+ image_feature_extractor ([`CLIPImageProcessor`]):
+ Feature extractor that preprocesses the conditioning image for the CLIP vision encoder.
+ image_encoder ([`CLIPVisionModelWithProjection`]):
+ Frozen CLIP vision encoder used to embed the conditioning image.
+ image_unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `image_unet` to denoise the encoded image latents. Can be
+ one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ """
+ image_feature_extractor: CLIPImageProcessor
+ image_encoder: CLIPVisionModelWithProjection
+ image_unet: UNet2DConditionModel
+ vae: AutoencoderKL
+ scheduler: KarrasDiffusionSchedulers
+
+ def __init__(
+ self,
+ image_feature_extractor: CLIPImageProcessor,
+ image_encoder: CLIPVisionModelWithProjection,
+ image_unet: UNet2DConditionModel,
+ vae: AutoencoderKL,
+ scheduler: KarrasDiffusionSchedulers,
+ ):
+ super().__init__()
+ self.register_modules(
+ image_feature_extractor=image_feature_extractor,
+ image_encoder=image_encoder,
+ image_unet=image_unet,
+ vae=vae,
+ scheduler=scheduler,
+ )
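+ # with the default 4-level VAE this gives a scale factor of 8, i.e. a 512x512 image maps to a
+ # 64x64 latent grid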
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
+ r"""
+ Encodes the image prompt into image encoder hidden states.
+
+ Args:
+ prompt (`PIL.Image.Image`, `List[PIL.Image.Image]` or `torch.Tensor`):
+ image prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`PIL.Image.Image` or `List[PIL.Image.Image]`):
+ The image prompt or prompts not to guide the image generation. Ignored when not using guidance
+ (i.e., ignored if `guidance_scale` is less than `1`).
+ """
+
+ def normalize_embeddings(encoder_output):
+ embeds = self.image_encoder.vision_model.post_layernorm(encoder_output.last_hidden_state)
+ embeds = self.image_encoder.visual_projection(embeds)
+ embeds_pooled = embeds[:, 0:1]
+ embeds = embeds / torch.norm(embeds_pooled, dim=-1, keepdim=True)
+ return embeds
+
+ if isinstance(prompt, torch.Tensor) and len(prompt.shape) == 4:
+ prompt = list(prompt)
+
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+ # get prompt text embeddings
+ image_input = self.image_feature_extractor(images=prompt, return_tensors="pt")
+ pixel_values = image_input.pixel_values.to(device).to(self.image_encoder.dtype)
+ image_embeddings = self.image_encoder(pixel_values)
+ image_embeddings = normalize_embeddings(image_embeddings)
+
+ # duplicate image embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = image_embeddings.shape
+ image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1)
+ image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ uncond_images: List[str]
+ if negative_prompt is None:
+ uncond_images = [np.zeros((512, 512, 3)) + 0.5] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, PIL.Image.Image):
+ uncond_images = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_images = negative_prompt
+
+ uncond_images = self.image_feature_extractor(images=uncond_images, return_tensors="pt")
+ pixel_values = uncond_images.pixel_values.to(device).to(self.image_encoder.dtype)
+ negative_prompt_embeds = self.image_encoder(pixel_values)
+ negative_prompt_embeds = normalize_embeddings(negative_prompt_embeds)
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and conditional embeddings into a single batch
+ # to avoid doing two forward passes
+ image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])
+
+ return image_embeddings
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ def decode_latents(self, latents):
+ warnings.warn(
+ "The decode_latents method is deprecated and will be removed in a future version. Please"
+ " use VaeImageProcessor instead",
+ FutureWarning,
+ )
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_image_variation.StableDiffusionImageVariationPipeline.check_inputs
+ def check_inputs(self, image, height, width, callback_steps):
+ if (
+ not isinstance(image, torch.Tensor)
+ and not isinstance(image, PIL.Image.Image)
+ and not isinstance(image, list)
+ ):
+ raise ValueError(
+ "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
+ f" {type(image)}"
+ )
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor],
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[PIL.Image.Image, List[PIL.Image.Image]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ image (`PIL.Image.Image`, `List[PIL.Image.Image]` or `torch.Tensor`):
+ The image prompt or prompts to guide the image generation.
+ height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2 of the [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the
+ text `prompt`, usually at the expense of lower image quality.
+ negative_prompt (`PIL.Image.Image` or `List[PIL.Image.Image]`, *optional*):
+ The image prompt or prompts not to guide the image generation. Ignored when not using guidance
+ (i.e., ignored if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+
+ Examples:
+
+ ```py
+ >>> from diffusers import VersatileDiffusionImageVariationPipeline
+ >>> import torch
+ >>> import requests
+ >>> from io import BytesIO
+ >>> from PIL import Image
+
+ >>> # let's download an initial image
+ >>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg"
+
+ >>> response = requests.get(url)
+ >>> image = Image.open(BytesIO(response.content)).convert("RGB")
+
+ >>> pipe = VersatileDiffusionImageVariationPipeline.from_pretrained(
+ ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16
+ ... )
+ >>> pipe = pipe.to("cuda")
+
+ >>> generator = torch.Generator(device="cuda").manual_seed(0)
+ >>> image = pipe(image, generator=generator).images[0]
+ >>> image.save("./car_variation.png")
+ ```
+
+ Returns:
+ [`~pipelines.ImagePipelineOutput`] or `tuple`:
+ [`~pipelines.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a
+ tuple, the first element is a list with the generated images.
+ """
+ # 0. Default height and width to unet
+ height = height or self.image_unet.config.sample_size * self.vae_scale_factor
+ width = width or self.image_unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(image, height, width, callback_steps)
+
+ # 2. Define call parameters
+ batch_size = 1 if isinstance(image, PIL.Image.Image) else len(image)
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ image_embeddings = self._encode_prompt(
+ image, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.image_unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ image_embeddings.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs.
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Denoising loop
+ for i, t in enumerate(self.progress_bar(timesteps)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.image_unet(latent_model_input, t, encoder_hidden_states=image_embeddings).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ if output_type != "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ else:
+ image = latents
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
diff --git a/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py b/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ba5d8451f2eb1f655bbf0d44edaa5b35edb2e87
--- /dev/null
+++ b/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py
@@ -0,0 +1,473 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import warnings
+from typing import Callable, List, Optional, Union
+
+import torch
+import torch.utils.checkpoint
+from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer
+
+from ...image_processor import VaeImageProcessor
+from ...models import AutoencoderKL, Transformer2DModel, UNet2DConditionModel
+from ...schedulers import KarrasDiffusionSchedulers
+from ...utils import logging, randn_tensor
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from .modeling_text_unet import UNetFlatConditionModel
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
+ r"""
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Parameters:
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTokenizer).
+ text_encoder ([`CLIPTextModelWithProjection`]):
+ Frozen CLIP text encoder used to embed the text prompt.
+ image_unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ text_unet ([`UNetFlatConditionModel`]):
+ Flat U-Net that stores the text-conditioned attention blocks swapped into `image_unet` at load time.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `image_unet` to denoise the encoded image latents. Can be
+ one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ """
+ tokenizer: CLIPTokenizer
+ image_feature_extractor: CLIPImageProcessor
+ text_encoder: CLIPTextModelWithProjection
+ image_unet: UNet2DConditionModel
+ text_unet: UNetFlatConditionModel
+ vae: AutoencoderKL
+ scheduler: KarrasDiffusionSchedulers
+
+ _optional_components = ["text_unet"]
+
+ def __init__(
+ self,
+ tokenizer: CLIPTokenizer,
+ text_encoder: CLIPTextModelWithProjection,
+ image_unet: UNet2DConditionModel,
+ text_unet: UNetFlatConditionModel,
+ vae: AutoencoderKL,
+ scheduler: KarrasDiffusionSchedulers,
+ ):
+ super().__init__()
+ self.register_modules(
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ image_unet=image_unet,
+ text_unet=text_unet,
+ vae=vae,
+ scheduler=scheduler,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+ if self.text_unet is not None:
+ self._swap_unet_attention_blocks()
+
+ def _swap_unet_attention_blocks(self):
+ """
+ Swap the `Transformer2DModel` blocks between the image and text UNets
+ """
+ for name, module in self.image_unet.named_modules():
+ if isinstance(module, Transformer2DModel):
+ parent_name, index = name.rsplit(".", 1)
+ index = int(index)
+ self.image_unet.get_submodule(parent_name)[index], self.text_unet.get_submodule(parent_name)[index] = (
+ self.text_unet.get_submodule(parent_name)[index],
+ self.image_unet.get_submodule(parent_name)[index],
+ )
+
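+ # Once the attention blocks have been swapped in, `text_unet` is no longer needed for
+ # inference; de-registering it avoids keeping a second UNet in memory.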
+ def remove_unused_weights(self):
+ self.register_modules(text_unet=None)
+
+ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ """
+
+ def normalize_embeddings(encoder_output):
+ embeds = self.text_encoder.text_projection(encoder_output.last_hidden_state)
+ embeds_pooled = encoder_output.text_embeds
+ embeds = embeds / torch.norm(embeds_pooled.unsqueeze(1), dim=-1, keepdim=True)
+ return embeds
+
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
+
+ if not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = normalize_embeddings(prompt_embeds)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = text_input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = normalize_embeddings(negative_prompt_embeds)
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ def decode_latents(self, latents):
+ warnings.warn(
+ "The decode_latents method is deprecated and will be removed in a future version. Please"
+ " use VaeImageProcessor instead",
+ FutureWarning,
+ )
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2 of the [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the
+ text `prompt`, usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+
+ Examples:
+
+ ```py
+ >>> from diffusers import VersatileDiffusionTextToImagePipeline
+ >>> import torch
+
+ >>> pipe = VersatileDiffusionTextToImagePipeline.from_pretrained(
+ ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16
+ ... )
+ >>> pipe.remove_unused_weights()
+ >>> pipe = pipe.to("cuda")
+
+ >>> generator = torch.Generator(device="cuda").manual_seed(0)
+ >>> image = pipe("an astronaut riding on a horse on mars", generator=generator).images[0]
+ >>> image.save("./astronaut.png")
+ ```
+
+ Returns:
+ [`~pipelines.ImagePipelineOutput`] or `tuple`:
+ [`~pipelines.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a
+ tuple, the first element is a list with the generated images.
+ """
+ # 0. Default height and width to unet
+ height = height or self.image_unet.config.sample_size * self.vae_scale_factor
+ width = width or self.image_unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(prompt, height, width, callback_steps)
+
+ # 2. Define call parameters
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ prompt_embeds = self._encode_prompt(
+ prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.image_unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs.
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Denoising loop
+ for i, t in enumerate(self.progress_bar(timesteps)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.image_unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ if output_type != "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ else:
+ image = latents
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
diff --git a/diffusers/pipelines/vq_diffusion/__init__.py b/diffusers/pipelines/vq_diffusion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c9f14f000648347fe75a5bec0cb45d08c7d2ff9
--- /dev/null
+++ b/diffusers/pipelines/vq_diffusion/__init__.py
@@ -0,0 +1,5 @@
+from ...utils import is_torch_available, is_transformers_available
+
+
+if is_transformers_available() and is_torch_available():
+ from .pipeline_vq_diffusion import LearnedClassifierFreeSamplingEmbeddings, VQDiffusionPipeline
diff --git a/diffusers/pipelines/vq_diffusion/__pycache__/__init__.cpython-310.pyc b/diffusers/pipelines/vq_diffusion/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..09dde4c595e70e9cb132ed3fb912db4c3bd6df04
Binary files /dev/null and b/diffusers/pipelines/vq_diffusion/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/pipelines/vq_diffusion/__pycache__/__init__.cpython-38.pyc b/diffusers/pipelines/vq_diffusion/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d3aa532e7fc4fd0468cbb26aee2414d8a47f03ea
Binary files /dev/null and b/diffusers/pipelines/vq_diffusion/__pycache__/__init__.cpython-38.pyc differ
diff --git a/diffusers/pipelines/vq_diffusion/__pycache__/pipeline_vq_diffusion.cpython-310.pyc b/diffusers/pipelines/vq_diffusion/__pycache__/pipeline_vq_diffusion.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..485bd7fb224c1729864848edc4748f221a6a14da
Binary files /dev/null and b/diffusers/pipelines/vq_diffusion/__pycache__/pipeline_vq_diffusion.cpython-310.pyc differ
diff --git a/diffusers/pipelines/vq_diffusion/__pycache__/pipeline_vq_diffusion.cpython-38.pyc b/diffusers/pipelines/vq_diffusion/__pycache__/pipeline_vq_diffusion.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8626f0de557481b0aa64736c366aefd78de8a955
Binary files /dev/null and b/diffusers/pipelines/vq_diffusion/__pycache__/pipeline_vq_diffusion.cpython-38.pyc differ
diff --git a/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py b/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..9147afe127e4b24366249c4a6e058abae9501050
--- /dev/null
+++ b/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py
@@ -0,0 +1,330 @@
+# Copyright 2023 Microsoft and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable, List, Optional, Tuple, Union
+
+import torch
+from transformers import CLIPTextModel, CLIPTokenizer
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...models import ModelMixin, Transformer2DModel, VQModel
+from ...schedulers import VQDiffusionScheduler
+from ...utils import logging
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class LearnedClassifierFreeSamplingEmbeddings(ModelMixin, ConfigMixin):
+ """
+ Utility class for storing learned text embeddings for classifier free sampling
+ """
+
+ @register_to_config
+ def __init__(self, learnable: bool, hidden_size: Optional[int] = None, length: Optional[int] = None):
+ super().__init__()
+
+ self.learnable = learnable
+
+ if self.learnable:
+ assert hidden_size is not None, "learnable=True requires `hidden_size` to be set"
+ assert length is not None, "learnable=True requires `length` to be set"
+
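+ # when learnable, a (length, hidden_size) embedding is optimized directly and used as the
+ # "null" prompt for classifier-free guidance instead of encoding an empty string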
+ embeddings = torch.zeros(length, hidden_size)
+ else:
+ embeddings = None
+
+ self.embeddings = torch.nn.Parameter(embeddings)
+
+
+class VQDiffusionPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-to-image generation using VQ Diffusion
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vqvae ([`VQModel`]):
+ Vector Quantized Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent
+ representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. VQ Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ transformer ([`Transformer2DModel`]):
+ Conditional transformer to denoise the encoded image latents.
+ scheduler ([`VQDiffusionScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ """
+
+ vqvae: VQModel
+ text_encoder: CLIPTextModel
+ tokenizer: CLIPTokenizer
+ transformer: Transformer2DModel
+ learned_classifier_free_sampling_embeddings: LearnedClassifierFreeSamplingEmbeddings
+ scheduler: VQDiffusionScheduler
+
+ def __init__(
+ self,
+ vqvae: VQModel,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ transformer: Transformer2DModel,
+ scheduler: VQDiffusionScheduler,
+ learned_classifier_free_sampling_embeddings: LearnedClassifierFreeSamplingEmbeddings,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vqvae=vqvae,
+ transformer=transformer,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ scheduler=scheduler,
+ learned_classifier_free_sampling_embeddings=learned_classifier_free_sampling_embeddings,
+ )
+
+ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance):
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+ # get prompt text embeddings
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+
+ if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
+ removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+ text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+ prompt_embeds = self.text_encoder(text_input_ids.to(self.device))[0]
+
+ # NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion.
+ # While CLIP does normalize the pooled output of the text transformer when combining
+ # the image and text embeddings, CLIP does not directly normalize the last hidden state.
+ #
+ # CLIP normalizing the pooled output.
+ # https://github.com/huggingface/transformers/blob/d92e22d1f28324f513f3080e5c47c071a3916721/src/transformers/models/clip/modeling_clip.py#L1052-L1053
+ prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
+
+ # duplicate text embeddings for each generation per prompt
+ prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+
+ if do_classifier_free_guidance:
+ if self.learned_classifier_free_sampling_embeddings.learnable:
+ negative_prompt_embeds = self.learned_classifier_free_sampling_embeddings.embeddings
+ negative_prompt_embeds = negative_prompt_embeds.unsqueeze(0).repeat(batch_size, 1, 1)
+ else:
+ uncond_tokens = [""] * batch_size
+
+ max_length = text_input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ negative_prompt_embeds = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+ # See comment for normalizing text embeddings
+ negative_prompt_embeds = negative_prompt_embeds / negative_prompt_embeds.norm(dim=-1, keepdim=True)
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ num_inference_steps: int = 100,
+ guidance_scale: float = 5.0,
+ truncation_rate: float = 1.0,
+ num_images_per_prompt: int = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ ) -> Union[ImagePipelineOutput, Tuple]:
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ num_inference_steps (`int`, *optional*, defaults to 100):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+            guidance_scale (`float`, *optional*, defaults to 5.0):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2 of the [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. A higher guidance scale encourages the model to generate images that are closely linked to the
+                text `prompt`, usually at the expense of lower image quality.
+ truncation_rate (`float`, *optional*, defaults to 1.0 (equivalent to no truncation)):
+ Used to "truncate" the predicted classes for x_0 such that the cumulative probability for a pixel is at
+ most `truncation_rate`. The lowest probabilities that would increase the cumulative probability above
+ `truncation_rate` are set to zero.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+            latents (`torch.FloatTensor` of shape `(batch_size, num_latent_pixels)`, *optional*):
+                Pre-generated noisy latents to be used as inputs for image generation. Must be valid embedding indices.
+                Can be used to tweak the same generation with different prompts. If not provided, a latents tensor of
+                completely masked latent pixels will be generated.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+
+ Returns:
+            [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.ImagePipelineOutput`] if `return_dict` is
+            True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+ if isinstance(prompt, str):
+ batch_size = 1
+ elif isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ batch_size = batch_size * num_images_per_prompt
+
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ prompt_embeds = self._encode_prompt(prompt, num_images_per_prompt, do_classifier_free_guidance)
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+        # get the initial, completely masked latents unless the user supplied them
+
+ latents_shape = (batch_size, self.transformer.num_latent_pixels)
+ if latents is None:
+ mask_class = self.transformer.num_vector_embeds - 1
+ latents = torch.full(latents_shape, mask_class).to(self.device)
+ else:
+ if latents.shape != latents_shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
+ if (latents < 0).any() or (latents >= self.transformer.num_vector_embeds).any():
+ raise ValueError(
+ "Unexpected latents value(s). All latents be valid embedding indices i.e. in the range 0,"
+ f" {self.transformer.num_vector_embeds - 1} (inclusive)."
+ )
+ latents = latents.to(self.device)
+
+ # set timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=self.device)
+
+ timesteps_tensor = self.scheduler.timesteps.to(self.device)
+
+ sample = latents
+
+ for i, t in enumerate(self.progress_bar(timesteps_tensor)):
+ # expand the sample if we are doing classifier free guidance
+ latent_model_input = torch.cat([sample] * 2) if do_classifier_free_guidance else sample
+
+ # predict the un-noised image
+ # model_output == `log_p_x_0`
+ model_output = self.transformer(latent_model_input, encoder_hidden_states=prompt_embeds, timestep=t).sample
+
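+            # `model_output` holds log-probabilities over the codebook classes, so classifier-free guidance is
+            # applied in log space; the logsumexp subtraction re-normalizes the guided logits so that each
+            # latent pixel's class distribution still sums to one.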
+ if do_classifier_free_guidance:
+ model_output_uncond, model_output_text = model_output.chunk(2)
+ model_output = model_output_uncond + guidance_scale * (model_output_text - model_output_uncond)
+ model_output -= torch.logsumexp(model_output, dim=1, keepdim=True)
+
+ model_output = self.truncate(model_output, truncation_rate)
+
+ # remove `log(0)`'s (`-inf`s)
+ model_output = model_output.clamp(-70)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ sample = self.scheduler.step(model_output, timestep=t, sample=sample, generator=generator).prev_sample
+
+ # call the callback, if provided
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, sample)
+
+ embedding_channels = self.vqvae.config.vq_embed_dim
+ embeddings_shape = (batch_size, self.transformer.height, self.transformer.width, embedding_channels)
+ embeddings = self.vqvae.quantize.get_codebook_entry(sample, shape=embeddings_shape)
+ image = self.vqvae.decode(embeddings, force_not_quantize=True).sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
+
+ def truncate(self, log_p_x_0: torch.FloatTensor, truncation_rate: float) -> torch.FloatTensor:
+ """
+        Truncates `log_p_x_0` such that, for each column vector, the cumulative probability is at most
+        `truncation_rate`. The lowest probabilities that would increase the cumulative probability above
+        `truncation_rate` are set to zero.
+ """
+ sorted_log_p_x_0, indices = torch.sort(log_p_x_0, 1, descending=True)
+ sorted_p_x_0 = torch.exp(sorted_log_p_x_0)
+ keep_mask = sorted_p_x_0.cumsum(dim=1) < truncation_rate
+
+ # Ensure that at least the largest probability is not zeroed out
+ all_true = torch.full_like(keep_mask[:, 0:1, :], True)
+ keep_mask = torch.cat((all_true, keep_mask), dim=1)
+ keep_mask = keep_mask[:, :-1, :]
+
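+        # `indices.argsort(1)` is the inverse permutation of the sort, so gathering with it maps the keep mask
+        # from sorted order back to the original class ordering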
+ keep_mask = keep_mask.gather(1, indices.argsort(1))
+
+ rv = log_p_x_0.clone()
+
+ rv[~keep_mask] = -torch.inf # -inf = log(0)
+
+ return rv
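+
+
+if __name__ == "__main__":
+    # Minimal usage sketch (not part of the pipeline itself): it only shows how the pipeline above is typically
+    # driven end to end, and is meant to be run as a module so the relative imports resolve. The checkpoint id
+    # "microsoft/vq-diffusion-ithq" and the prompt are illustrative assumptions; any VQ-Diffusion checkpoint that
+    # provides the components registered above would work.
+    pipe = VQDiffusionPipeline.from_pretrained("microsoft/vq-diffusion-ithq")
+    pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")
+
+    generator = torch.Generator(device=pipe.device).manual_seed(0)
+    output = pipe(
+        "teddy bear playing in the pool",
+        num_inference_steps=100,
+        guidance_scale=5.0,
+        truncation_rate=1.0,
+        generator=generator,
+    )
+    output.images[0].save("vq_diffusion_sample.png")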
diff --git a/diffusers/schedulers/README.md b/diffusers/schedulers/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..31ad27793e34783faabc222adf98691fb396a0d8
--- /dev/null
+++ b/diffusers/schedulers/README.md
@@ -0,0 +1,3 @@
+# Schedulers
+
+For more information on the schedulers, please refer to the [docs](https://huggingface.co/docs/diffusers/api/schedulers/overview).
\ No newline at end of file
diff --git a/diffusers/schedulers/__init__.py b/diffusers/schedulers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a07ce4baed20904b85b577aa3e4e38f6a47e945
--- /dev/null
+++ b/diffusers/schedulers/__init__.py
@@ -0,0 +1,92 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from ..utils import (
+ OptionalDependencyNotAvailable,
+ is_flax_available,
+ is_scipy_available,
+ is_torch_available,
+ is_torchsde_available,
+)
+
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ..utils.dummy_pt_objects import * # noqa F403
+else:
+ from .scheduling_consistency_models import CMStochasticIterativeScheduler
+ from .scheduling_ddim import DDIMScheduler
+ from .scheduling_ddim_inverse import DDIMInverseScheduler
+ from .scheduling_ddim_parallel import DDIMParallelScheduler
+ from .scheduling_ddpm import DDPMScheduler
+ from .scheduling_ddpm_parallel import DDPMParallelScheduler
+ from .scheduling_deis_multistep import DEISMultistepScheduler
+ from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
+ from .scheduling_dpmsolver_multistep_inverse import DPMSolverMultistepInverseScheduler
+ from .scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler
+ from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
+ from .scheduling_euler_discrete import EulerDiscreteScheduler
+ from .scheduling_heun_discrete import HeunDiscreteScheduler
+ from .scheduling_ipndm import IPNDMScheduler
+ from .scheduling_k_dpm_2_ancestral_discrete import KDPM2AncestralDiscreteScheduler
+ from .scheduling_k_dpm_2_discrete import KDPM2DiscreteScheduler
+ from .scheduling_karras_ve import KarrasVeScheduler
+ from .scheduling_pndm import PNDMScheduler
+ from .scheduling_repaint import RePaintScheduler
+ from .scheduling_sde_ve import ScoreSdeVeScheduler
+ from .scheduling_sde_vp import ScoreSdeVpScheduler
+ from .scheduling_unclip import UnCLIPScheduler
+ from .scheduling_unipc_multistep import UniPCMultistepScheduler
+ from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
+ from .scheduling_vq_diffusion import VQDiffusionScheduler
+
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ..utils.dummy_flax_objects import * # noqa F403
+else:
+ from .scheduling_ddim_flax import FlaxDDIMScheduler
+ from .scheduling_ddpm_flax import FlaxDDPMScheduler
+ from .scheduling_dpmsolver_multistep_flax import FlaxDPMSolverMultistepScheduler
+ from .scheduling_karras_ve_flax import FlaxKarrasVeScheduler
+ from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler
+ from .scheduling_pndm_flax import FlaxPNDMScheduler
+ from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler
+ from .scheduling_utils_flax import (
+ FlaxKarrasDiffusionSchedulers,
+ FlaxSchedulerMixin,
+ FlaxSchedulerOutput,
+ broadcast_to_shape_from_left,
+ )
+
+
+try:
+ if not (is_torch_available() and is_scipy_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ..utils.dummy_torch_and_scipy_objects import * # noqa F403
+else:
+ from .scheduling_lms_discrete import LMSDiscreteScheduler
+
+try:
+ if not (is_torch_available() and is_torchsde_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ..utils.dummy_torch_and_torchsde_objects import * # noqa F403
+else:
+ from .scheduling_dpmsolver_sde import DPMSolverSDEScheduler
diff --git a/diffusers/schedulers/__pycache__/__init__.cpython-310.pyc b/diffusers/schedulers/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..362762968a25a04396fde2a7a4860bbc20611fe7
Binary files /dev/null and b/diffusers/schedulers/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/schedulers/__pycache__/__init__.cpython-38.pyc b/diffusers/schedulers/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c17d8944e5467f173b751019543179fda88048b3
Binary files /dev/null and b/diffusers/schedulers/__pycache__/__init__.cpython-38.pyc differ
diff --git a/diffusers/schedulers/__pycache__/scheduling_consistency_models.cpython-310.pyc b/diffusers/schedulers/__pycache__/scheduling_consistency_models.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8b64b74a63e9fa5b0e3b05398329935278c2a803
Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_consistency_models.cpython-310.pyc differ
diff --git a/diffusers/schedulers/__pycache__/scheduling_consistency_models.cpython-38.pyc b/diffusers/schedulers/__pycache__/scheduling_consistency_models.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..39c1a91c269880d0102d079b01120e37c757c2d0
Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_consistency_models.cpython-38.pyc differ
diff --git a/diffusers/schedulers/__pycache__/scheduling_ddim.cpython-310.pyc b/diffusers/schedulers/__pycache__/scheduling_ddim.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b975de57edfed7090056442a7e4724cda7684301
Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_ddim.cpython-310.pyc differ
diff --git a/diffusers/schedulers/__pycache__/scheduling_ddim.cpython-38.pyc b/diffusers/schedulers/__pycache__/scheduling_ddim.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ecb5f2378795cdcad48535aafd97f87ce12db7be
Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_ddim.cpython-38.pyc differ
diff --git a/diffusers/schedulers/__pycache__/scheduling_ddim_inverse.cpython-310.pyc b/diffusers/schedulers/__pycache__/scheduling_ddim_inverse.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..17fec920c13eedf2d7c0328d2e0693333b1db134
Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_ddim_inverse.cpython-310.pyc differ
diff --git a/diffusers/schedulers/__pycache__/scheduling_ddim_inverse.cpython-38.pyc b/diffusers/schedulers/__pycache__/scheduling_ddim_inverse.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4886057f297b1712501cb5d0ccd2f1e405428d42
Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_ddim_inverse.cpython-38.pyc differ
diff --git a/diffusers/schedulers/__pycache__/scheduling_ddim_parallel.cpython-310.pyc b/diffusers/schedulers/__pycache__/scheduling_ddim_parallel.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..46a27798da4b2df456c8d8d0d015517d87273e7a
Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_ddim_parallel.cpython-310.pyc differ
diff --git a/diffusers/schedulers/__pycache__/scheduling_ddim_parallel.cpython-38.pyc b/diffusers/schedulers/__pycache__/scheduling_ddim_parallel.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ffb0bb1da36378f3765d0ae27c2e2c7591d24750
Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_ddim_parallel.cpython-38.pyc differ
diff --git a/diffusers/schedulers/__pycache__/scheduling_ddpm.cpython-310.pyc b/diffusers/schedulers/__pycache__/scheduling_ddpm.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6e1eead3241e8a93088835b30cd6a2d0630396dd
Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_ddpm.cpython-310.pyc differ
diff --git a/diffusers/schedulers/__pycache__/scheduling_ddpm.cpython-38.pyc b/diffusers/schedulers/__pycache__/scheduling_ddpm.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..98baefe397c397858316e924d379623d10c858b9
Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_ddpm.cpython-38.pyc differ
diff --git a/diffusers/schedulers/__pycache__/scheduling_ddpm_parallel.cpython-310.pyc b/diffusers/schedulers/__pycache__/scheduling_ddpm_parallel.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fd40929b0787c6de89a12a8d6f066be4f6376a63
Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_ddpm_parallel.cpython-310.pyc differ
diff --git a/diffusers/schedulers/__pycache__/scheduling_ddpm_parallel.cpython-38.pyc b/diffusers/schedulers/__pycache__/scheduling_ddpm_parallel.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..372138392ebc3a92c60f7c0cfdd6f40f60ebc090
Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_ddpm_parallel.cpython-38.pyc differ
diff --git a/diffusers/schedulers/__pycache__/scheduling_deis_multistep.cpython-310.pyc b/diffusers/schedulers/__pycache__/scheduling_deis_multistep.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ab24c03522a608a4ef734479ea6bacdcb69992a1
Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_deis_multistep.cpython-310.pyc differ
diff --git a/diffusers/schedulers/__pycache__/scheduling_deis_multistep.cpython-38.pyc b/diffusers/schedulers/__pycache__/scheduling_deis_multistep.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..20bfffc6387c9225771f23b276c711bca2f9462b
Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_deis_multistep.cpython-38.pyc differ
diff --git a/diffusers/schedulers/__pycache__/scheduling_dpmsolver_multistep.cpython-310.pyc b/diffusers/schedulers/__pycache__/scheduling_dpmsolver_multistep.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..33af0455a39cb19426c580ebc3fedda2f9701dc9
Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_dpmsolver_multistep.cpython-310.pyc differ
diff --git a/diffusers/schedulers/__pycache__/scheduling_dpmsolver_multistep.cpython-38.pyc b/diffusers/schedulers/__pycache__/scheduling_dpmsolver_multistep.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d8c5f9230bee39346c5e36b26e803233f14d1147
Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_dpmsolver_multistep.cpython-38.pyc differ
diff --git a/diffusers/schedulers/__pycache__/scheduling_dpmsolver_multistep_inverse.cpython-310.pyc b/diffusers/schedulers/__pycache__/scheduling_dpmsolver_multistep_inverse.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ff68dbd7dab59d0e06612679586d0c5573f16d1d
Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_dpmsolver_multistep_inverse.cpython-310.pyc differ
diff --git a/diffusers/schedulers/__pycache__/scheduling_dpmsolver_multistep_inverse.cpython-38.pyc b/diffusers/schedulers/__pycache__/scheduling_dpmsolver_multistep_inverse.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1bacdfacb5b5ff3acc288299a64a779845e02e12
Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_dpmsolver_multistep_inverse.cpython-38.pyc differ
diff --git a/diffusers/schedulers/__pycache__/scheduling_dpmsolver_singlestep.cpython-310.pyc b/diffusers/schedulers/__pycache__/scheduling_dpmsolver_singlestep.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4eb6787cacfee292a06081ca6cc0798334927bb1
Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_dpmsolver_singlestep.cpython-310.pyc differ
diff --git a/diffusers/schedulers/__pycache__/scheduling_dpmsolver_singlestep.cpython-38.pyc b/diffusers/schedulers/__pycache__/scheduling_dpmsolver_singlestep.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bd562fd887d9cd0013e2362ff18b0be8246876eb
Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_dpmsolver_singlestep.cpython-38.pyc differ
diff --git a/diffusers/schedulers/__pycache__/scheduling_euler_ancestral_discrete.cpython-310.pyc b/diffusers/schedulers/__pycache__/scheduling_euler_ancestral_discrete.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c655806c4ba18fd049945ce87e6de3b6dfecb0e4
Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_euler_ancestral_discrete.cpython-310.pyc differ
diff --git a/diffusers/schedulers/__pycache__/scheduling_euler_ancestral_discrete.cpython-38.pyc b/diffusers/schedulers/__pycache__/scheduling_euler_ancestral_discrete.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..db174e15986be6e0e45d0666ea49c053d2f1bf30
Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_euler_ancestral_discrete.cpython-38.pyc differ
diff --git a/diffusers/schedulers/__pycache__/scheduling_euler_discrete.cpython-310.pyc b/diffusers/schedulers/__pycache__/scheduling_euler_discrete.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bfc122f4f88cf9d633c59de7b80eada8b9406814
Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_euler_discrete.cpython-310.pyc differ
diff --git a/diffusers/schedulers/__pycache__/scheduling_euler_discrete.cpython-38.pyc b/diffusers/schedulers/__pycache__/scheduling_euler_discrete.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..38fcc52b2651d95e4a3bc3948a616eddb75f9019
Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_euler_discrete.cpython-38.pyc differ
diff --git a/diffusers/schedulers/__pycache__/scheduling_heun_discrete.cpython-310.pyc b/diffusers/schedulers/__pycache__/scheduling_heun_discrete.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..546cf379e7807e9fbb93527853aad76a650c68f9
Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_heun_discrete.cpython-310.pyc differ
diff --git a/diffusers/schedulers/__pycache__/scheduling_heun_discrete.cpython-38.pyc b/diffusers/schedulers/__pycache__/scheduling_heun_discrete.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d24b1be603ff7f5b29281d7af88ab20ebeb2e65a
Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_heun_discrete.cpython-38.pyc differ
diff --git a/diffusers/schedulers/__pycache__/scheduling_ipndm.cpython-310.pyc b/diffusers/schedulers/__pycache__/scheduling_ipndm.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7872ee39c72e643020387a52adb3a73a744e34a3
Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_ipndm.cpython-310.pyc differ
diff --git a/diffusers/schedulers/__pycache__/scheduling_ipndm.cpython-38.pyc b/diffusers/schedulers/__pycache__/scheduling_ipndm.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0e7e53da127d8e0eb101871d1817af6783aa67b4
Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_ipndm.cpython-38.pyc differ
diff --git a/diffusers/schedulers/__pycache__/scheduling_k_dpm_2_ancestral_discrete.cpython-310.pyc b/diffusers/schedulers/__pycache__/scheduling_k_dpm_2_ancestral_discrete.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..40852bd769015ee44bbbc88feb72871bbe77f9be
Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_k_dpm_2_ancestral_discrete.cpython-310.pyc differ
diff --git a/diffusers/schedulers/__pycache__/scheduling_k_dpm_2_ancestral_discrete.cpython-38.pyc b/diffusers/schedulers/__pycache__/scheduling_k_dpm_2_ancestral_discrete.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..088432df988b8a1ea3f02c00960e7f1a4f0c6174
Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_k_dpm_2_ancestral_discrete.cpython-38.pyc differ
diff --git a/diffusers/schedulers/__pycache__/scheduling_k_dpm_2_discrete.cpython-310.pyc b/diffusers/schedulers/__pycache__/scheduling_k_dpm_2_discrete.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..21861dc07327a721cdd0e2b592de288e3f73ff1f
Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_k_dpm_2_discrete.cpython-310.pyc differ
diff --git a/diffusers/schedulers/__pycache__/scheduling_k_dpm_2_discrete.cpython-38.pyc b/diffusers/schedulers/__pycache__/scheduling_k_dpm_2_discrete.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..96998ea9b7d978917c0d4d73c24f37c28005242b
Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_k_dpm_2_discrete.cpython-38.pyc differ
diff --git a/diffusers/schedulers/__pycache__/scheduling_karras_ve.cpython-310.pyc b/diffusers/schedulers/__pycache__/scheduling_karras_ve.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f86170ad577b0e9effb01f1d92e959e40307f21c
Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_karras_ve.cpython-310.pyc differ
diff --git a/diffusers/schedulers/__pycache__/scheduling_karras_ve.cpython-38.pyc b/diffusers/schedulers/__pycache__/scheduling_karras_ve.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7bdd88b6d6b9915c7b7ad66c2a314c7470ab6bd7
Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_karras_ve.cpython-38.pyc differ
diff --git a/diffusers/schedulers/__pycache__/scheduling_lms_discrete.cpython-310.pyc b/diffusers/schedulers/__pycache__/scheduling_lms_discrete.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..950f9b500065d717bd4783cb98b1b52d4dcffba1
Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_lms_discrete.cpython-310.pyc differ
diff --git a/diffusers/schedulers/__pycache__/scheduling_lms_discrete.cpython-38.pyc b/diffusers/schedulers/__pycache__/scheduling_lms_discrete.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..09c016bf022cb8dda80ef0d5f12770d04d11784d
Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_lms_discrete.cpython-38.pyc differ
diff --git a/diffusers/schedulers/__pycache__/scheduling_pndm.cpython-310.pyc b/diffusers/schedulers/__pycache__/scheduling_pndm.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f067a472f53466953c26bf0d6dff2cbf22cd8594
Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_pndm.cpython-310.pyc differ
diff --git a/diffusers/schedulers/__pycache__/scheduling_pndm.cpython-38.pyc b/diffusers/schedulers/__pycache__/scheduling_pndm.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4ac775ea4d4a39e8341e3c085e4c2ff1315a2160
Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_pndm.cpython-38.pyc differ
diff --git a/diffusers/schedulers/__pycache__/scheduling_repaint.cpython-310.pyc b/diffusers/schedulers/__pycache__/scheduling_repaint.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..563f6ce8a3584cb3c4952f4a6e2d672dd97cbc06
Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_repaint.cpython-310.pyc differ
diff --git a/diffusers/schedulers/__pycache__/scheduling_repaint.cpython-38.pyc b/diffusers/schedulers/__pycache__/scheduling_repaint.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9b2a26de7b3b3e32c60d2288b51b074f9672d12b
Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_repaint.cpython-38.pyc differ
diff --git a/diffusers/schedulers/__pycache__/scheduling_sde_ve.cpython-310.pyc b/diffusers/schedulers/__pycache__/scheduling_sde_ve.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..54a6f61a4841dcde2cc317da46dcbf60b9ed126c
Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_sde_ve.cpython-310.pyc differ
diff --git a/diffusers/schedulers/__pycache__/scheduling_sde_ve.cpython-38.pyc b/diffusers/schedulers/__pycache__/scheduling_sde_ve.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..db247291040dba901942c21a4fa31f80aa19ec25
Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_sde_ve.cpython-38.pyc differ
diff --git a/diffusers/schedulers/__pycache__/scheduling_sde_vp.cpython-310.pyc b/diffusers/schedulers/__pycache__/scheduling_sde_vp.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..aabb3c5ec409487e0379eb0fb452f48e96e5ea8d
Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_sde_vp.cpython-310.pyc differ
diff --git a/diffusers/schedulers/__pycache__/scheduling_sde_vp.cpython-38.pyc b/diffusers/schedulers/__pycache__/scheduling_sde_vp.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c56cf91e2c0521e7506ba243ea2d15bbcf47dac4
Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_sde_vp.cpython-38.pyc differ
diff --git a/diffusers/schedulers/__pycache__/scheduling_unclip.cpython-310.pyc b/diffusers/schedulers/__pycache__/scheduling_unclip.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..703c3a8c734fc7c6db66110c8f08b07c4111f448
Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_unclip.cpython-310.pyc differ
diff --git a/diffusers/schedulers/__pycache__/scheduling_unclip.cpython-38.pyc b/diffusers/schedulers/__pycache__/scheduling_unclip.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4e759eb5e508a28c65538856ad8d53e4303b9901
Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_unclip.cpython-38.pyc differ
diff --git a/diffusers/schedulers/__pycache__/scheduling_unipc_multistep.cpython-310.pyc b/diffusers/schedulers/__pycache__/scheduling_unipc_multistep.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..161c265132aa504908d5f99641a2551223fed859
Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_unipc_multistep.cpython-310.pyc differ
diff --git a/diffusers/schedulers/__pycache__/scheduling_unipc_multistep.cpython-38.pyc b/diffusers/schedulers/__pycache__/scheduling_unipc_multistep.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..095240969bda27cf9fd8c904d513397a3cb2b7e3
Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_unipc_multistep.cpython-38.pyc differ
diff --git a/diffusers/schedulers/__pycache__/scheduling_utils.cpython-310.pyc b/diffusers/schedulers/__pycache__/scheduling_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e556205119a8b1ddb53249243d9eeb18483b3941
Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_utils.cpython-310.pyc differ
diff --git a/diffusers/schedulers/__pycache__/scheduling_utils.cpython-38.pyc b/diffusers/schedulers/__pycache__/scheduling_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7266bc74eca09ddba385bfc43382be45d8e5128a
Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_utils.cpython-38.pyc differ
diff --git a/diffusers/schedulers/__pycache__/scheduling_vq_diffusion.cpython-310.pyc b/diffusers/schedulers/__pycache__/scheduling_vq_diffusion.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f125b214f79a89ecf89cc177a9cd4d3d2c82a2a2
Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_vq_diffusion.cpython-310.pyc differ
diff --git a/diffusers/schedulers/__pycache__/scheduling_vq_diffusion.cpython-38.pyc b/diffusers/schedulers/__pycache__/scheduling_vq_diffusion.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e8ed6a5124fddce0c1e78705e8663c9d06998f45
Binary files /dev/null and b/diffusers/schedulers/__pycache__/scheduling_vq_diffusion.cpython-38.pyc differ
diff --git a/diffusers/schedulers/scheduling_consistency_models.py b/diffusers/schedulers/scheduling_consistency_models.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb296054d65b804af281dc99d940c8f0ba50e01b
--- /dev/null
+++ b/diffusers/schedulers/scheduling_consistency_models.py
@@ -0,0 +1,380 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import BaseOutput, logging, randn_tensor
+from .scheduling_utils import SchedulerMixin
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class CMStochasticIterativeSchedulerOutput(BaseOutput):
+ """
+ Output class for the scheduler's step function output.
+
+ Args:
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ """
+
+ prev_sample: torch.FloatTensor
+
+
+class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
+ """
+ Multistep and onestep sampling for consistency models from Song et al. 2023 [1]. This implements Algorithm 1 in the
+ paper [1].
+
+ [1] Song, Yang and Dhariwal, Prafulla and Chen, Mark and Sutskever, Ilya. "Consistency Models"
+ https://arxiv.org/pdf/2303.01469 [2] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based
+ Generative Models." https://arxiv.org/abs/2206.00364
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
+ sigma_min (`float`):
+ Minimum noise magnitude in the sigma schedule. This was set to 0.002 in the original implementation.
+ sigma_max (`float`):
+ Maximum noise magnitude in the sigma schedule. This was set to 80.0 in the original implementation.
+ sigma_data (`float`):
+ The standard deviation of the data distribution, following the EDM paper [2]. This was set to 0.5 in the
+ original implementation, which is also the original value suggested in the EDM paper.
+ s_noise (`float`):
+ The amount of additional noise to counteract loss of detail during sampling. A reasonable range is [1.000,
+ 1.011]. This was set to 1.0 in the original implementation.
+ rho (`float`):
+ The rho parameter used for calculating the Karras sigma schedule, introduced in the EDM paper [2]. This was
+ set to 7.0 in the original implementation, which is also the original value suggested in the EDM paper.
+ clip_denoised (`bool`):
+ Whether to clip the denoised outputs to `(-1, 1)`. Defaults to `True`.
+        timesteps (`List` or `np.ndarray` or `torch.Tensor`, *optional*):
+            Optionally, an explicit timestep schedule can be passed to `set_timesteps`. The timesteps are expected to
+            be in decreasing order.
+ """
+
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 40,
+ sigma_min: float = 0.002,
+ sigma_max: float = 80.0,
+ sigma_data: float = 0.5,
+ s_noise: float = 1.0,
+ rho: float = 7.0,
+ clip_denoised: bool = True,
+ ):
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = sigma_max
+
+ ramp = np.linspace(0, 1, num_train_timesteps)
+ sigmas = self._convert_to_karras(ramp)
+ timesteps = self.sigma_to_t(sigmas)
+
+ # setable values
+ self.num_inference_steps = None
+ self.sigmas = torch.from_numpy(sigmas)
+ self.timesteps = torch.from_numpy(timesteps)
+ self.custom_timesteps = False
+ self.is_scale_input_called = False
+
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
+ if schedule_timesteps is None:
+ schedule_timesteps = self.timesteps
+
+ indices = (schedule_timesteps == timestep).nonzero()
+ return indices.item()
+
+ def scale_model_input(
+ self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
+ ) -> torch.FloatTensor:
+ """
+ Scales the consistency model input by `(sigma**2 + sigma_data**2) ** 0.5`, following the EDM model.
+
+ Args:
+ sample (`torch.FloatTensor`): input sample
+ timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain
+ Returns:
+ `torch.FloatTensor`: scaled input sample
+ """
+ # Get sigma corresponding to timestep
+ if isinstance(timestep, torch.Tensor):
+ timestep = timestep.to(self.timesteps.device)
+ step_idx = self.index_for_timestep(timestep)
+ sigma = self.sigmas[step_idx]
+
+ sample = sample / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
+
+ self.is_scale_input_called = True
+ return sample
+
+ def sigma_to_t(self, sigmas: Union[float, np.ndarray]):
+ """
+ Gets scaled timesteps from the Karras sigmas, for input to the consistency model.
+
+ Args:
+ sigmas (`float` or `np.ndarray`): single Karras sigma or array of Karras sigmas
+ Returns:
+ `float` or `np.ndarray`: scaled input timestep or scaled input timestep array
+ """
+ if not isinstance(sigmas, np.ndarray):
+ sigmas = np.array(sigmas, dtype=np.float64)
+
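+        # scaled timestep: 1000 * 0.25 * ln(sigma) (the EDM-style noise conditioning, scaled by 1000);
+        # the 1e-44 term avoids log(0) when sigma == 0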
+ timesteps = 1000 * 0.25 * np.log(sigmas + 1e-44)
+
+ return timesteps
+
+ def set_timesteps(
+ self,
+ num_inference_steps: Optional[int] = None,
+ device: Union[str, torch.device] = None,
+ timesteps: Optional[List[int]] = None,
+ ):
+ """
+ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ device (`str` or `torch.device`, optional):
+                the device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, optional):
+ custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
+ timestep spacing strategy of equal spacing between timesteps is used. If passed, `num_inference_steps`
+ must be `None`.
+ """
+ if num_inference_steps is None and timesteps is None:
+ raise ValueError("Exactly one of `num_inference_steps` or `timesteps` must be supplied.")
+
+ if num_inference_steps is not None and timesteps is not None:
+ raise ValueError("Can only pass one of `num_inference_steps` or `timesteps`.")
+
+ # Follow DDPMScheduler custom timesteps logic
+ if timesteps is not None:
+ for i in range(1, len(timesteps)):
+ if timesteps[i] >= timesteps[i - 1]:
+ raise ValueError("`timesteps` must be in descending order.")
+
+ if timesteps[0] >= self.config.num_train_timesteps:
+ raise ValueError(
+ f"`timesteps` must start before `self.config.train_timesteps`:"
+ f" {self.config.num_train_timesteps}."
+ )
+
+ timesteps = np.array(timesteps, dtype=np.int64)
+ self.custom_timesteps = True
+ else:
+ if num_inference_steps > self.config.num_train_timesteps:
+ raise ValueError(
+ f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
+ f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
+ f" maximal {self.config.num_train_timesteps} timesteps."
+ )
+
+ self.num_inference_steps = num_inference_steps
+
+ step_ratio = self.config.num_train_timesteps // self.num_inference_steps
+ timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
+ self.custom_timesteps = False
+
+ # Map timesteps to Karras sigmas directly for multistep sampling
+ # See https://github.com/openai/consistency_models/blob/main/cm/karras_diffusion.py#L675
+ num_train_timesteps = self.config.num_train_timesteps
+ ramp = timesteps[::-1].copy()
+ ramp = ramp / (num_train_timesteps - 1)
+ sigmas = self._convert_to_karras(ramp)
+ timesteps = self.sigma_to_t(sigmas)
+
+ sigmas = np.concatenate([sigmas, [self.sigma_min]]).astype(np.float32)
+ self.sigmas = torch.from_numpy(sigmas).to(device=device)
+
+ if str(device).startswith("mps"):
+ # mps does not support float64
+ self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
+ else:
+ self.timesteps = torch.from_numpy(timesteps).to(device=device)
+
+ # Modified _convert_to_karras implementation that takes in ramp as argument
+ def _convert_to_karras(self, ramp):
+ """Constructs the noise schedule of Karras et al. (2022)."""
+
+ sigma_min: float = self.config.sigma_min
+ sigma_max: float = self.config.sigma_max
+
+ rho = self.config.rho
+ min_inv_rho = sigma_min ** (1 / rho)
+ max_inv_rho = sigma_max ** (1 / rho)
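+        # interpolate linearly in sigma**(1/rho) space: ramp == 0 gives sigma_max, ramp == 1 gives sigma_min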
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+ return sigmas
+
+ def get_scalings(self, sigma):
+ sigma_data = self.config.sigma_data
+
+ c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
+ c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
+ return c_skip, c_out
+
+ def get_scalings_for_boundary_condition(self, sigma):
+ """
+ Gets the scalings used in the consistency model parameterization, following Appendix C of the original paper.
+ This enforces the consistency model boundary condition.
+
+ Note that `epsilon` in the equations for c_skip and c_out is set to sigma_min.
+
+ Args:
+ sigma (`torch.FloatTensor`):
+ The current sigma in the Karras sigma schedule.
+ Returns:
+ `tuple`:
+ A two-element tuple where c_skip (which weights the current sample) is the first element and c_out
+ (which weights the consistency model output) is the second element.
+ """
+ sigma_min = self.config.sigma_min
+ sigma_data = self.config.sigma_data
+
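+        # at sigma == sigma_min, c_skip == 1 and c_out == 0, so the parameterization reduces to the identity,
+        # which is exactly the consistency model boundary condition f(x, sigma_min) = x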
+ c_skip = sigma_data**2 / ((sigma - sigma_min) ** 2 + sigma_data**2)
+ c_out = (sigma - sigma_min) * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
+ return c_skip, c_out
+
+ def step(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: Union[float, torch.FloatTensor],
+ sample: torch.FloatTensor,
+ generator: Optional[torch.Generator] = None,
+ return_dict: bool = True,
+ ) -> Union[CMStochasticIterativeSchedulerOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+ timestep (`float`): current timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+ generator (`torch.Generator`, *optional*): Random number generator.
+            return_dict (`bool`): option for returning a tuple rather than a CMStochasticIterativeSchedulerOutput class
+ Returns:
+ [`~schedulers.scheduling_utils.CMStochasticIterativeSchedulerOutput`] or `tuple`:
+ [`~schedulers.scheduling_utils.CMStochasticIterativeSchedulerOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is the sample tensor.
+ """
+
+ if (
+ isinstance(timestep, int)
+ or isinstance(timestep, torch.IntTensor)
+ or isinstance(timestep, torch.LongTensor)
+ ):
+ raise ValueError(
+ (
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
+ f" `{self.__class__}.step()` is not supported. Make sure to pass"
+ " one of the `scheduler.timesteps` as a timestep."
+ ),
+ )
+
+ if not self.is_scale_input_called:
+ logger.warning(
+ "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
+ "See `StableDiffusionPipeline` for a usage example."
+ )
+
+ if isinstance(timestep, torch.Tensor):
+ timestep = timestep.to(self.timesteps.device)
+
+ sigma_min = self.config.sigma_min
+ sigma_max = self.config.sigma_max
+
+ step_index = self.index_for_timestep(timestep)
+
+ # sigma_next corresponds to next_t in original implementation
+ sigma = self.sigmas[step_index]
+ if step_index + 1 < self.config.num_train_timesteps:
+ sigma_next = self.sigmas[step_index + 1]
+ else:
+ # Set sigma_next to sigma_min
+ sigma_next = self.sigmas[-1]
+
+ # Get scalings for boundary conditions
+ c_skip, c_out = self.get_scalings_for_boundary_condition(sigma)
+
+ # 1. Denoise model output using boundary conditions
+ denoised = c_out * model_output + c_skip * sample
+ if self.config.clip_denoised:
+ denoised = denoised.clamp(-1, 1)
+
+ # 2. Sample z ~ N(0, s_noise^2 * I)
+ # Noise is not used for onestep sampling.
+ if len(self.timesteps) > 1:
+ noise = randn_tensor(
+ model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
+ )
+ else:
+ noise = torch.zeros_like(model_output)
+ z = noise * self.config.s_noise
+
+ sigma_hat = sigma_next.clamp(min=sigma_min, max=sigma_max)
+
+ # 3. Return noisy sample
+ # tau = sigma_hat, eps = sigma_min
+ prev_sample = denoised + z * (sigma_hat**2 - sigma_min**2) ** 0.5
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return CMStochasticIterativeSchedulerOutput(prev_sample=prev_sample)
+
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
+ def add_noise(
+ self,
+ original_samples: torch.FloatTensor,
+ noise: torch.FloatTensor,
+ timesteps: torch.FloatTensor,
+ ) -> torch.FloatTensor:
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
+ sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
+ if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
+ # mps does not support float64
+ schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
+ timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
+ else:
+ schedule_timesteps = self.timesteps.to(original_samples.device)
+ timesteps = timesteps.to(original_samples.device)
+
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < len(original_samples.shape):
+ sigma = sigma.unsqueeze(-1)
+
+ noisy_samples = original_samples + noise * sigma
+ return noisy_samples
+
+ def __len__(self):
+ return self.config.num_train_timesteps
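+
+
+if __name__ == "__main__":
+    # Minimal multistep-sampling sketch (not part of the scheduler itself): the "model" output below is a
+    # zero tensor stand-in, used only so the scheduler API above (set_timesteps / scale_model_input / step)
+    # can be exercised end to end; a real consistency model would be called in its place.
+    scheduler = CMStochasticIterativeScheduler()
+    scheduler.set_timesteps(num_inference_steps=2, device="cpu")
+
+    generator = torch.Generator().manual_seed(0)
+    sample = randn_tensor((1, 3, 64, 64), generator=generator) * scheduler.init_noise_sigma
+
+    for t in scheduler.timesteps:
+        scaled_sample = scheduler.scale_model_input(sample, t)
+        model_output = torch.zeros_like(scaled_sample)  # stand-in for model(scaled_sample, t)
+        sample = scheduler.step(model_output, t, sample, generator=generator).prev_sample
+
+    print(sample.shape)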
diff --git a/diffusers/schedulers/scheduling_ddim.py b/diffusers/schedulers/scheduling_ddim.py
new file mode 100644
index 0000000000000000000000000000000000000000..a93255ca600ef34da1b6c1691c4c5e9f7f86c2ed
--- /dev/null
+++ b/diffusers/schedulers/scheduling_ddim.py
@@ -0,0 +1,515 @@
+# Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
+# and https://github.com/hojonathanho/diffusion
+
+import math
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import BaseOutput, randn_tensor
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
+
+
+@dataclass
+# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
+class DDIMSchedulerOutput(BaseOutput):
+ """
+ Output class for the scheduler's step function output.
+
+ Args:
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ The predicted denoised sample (x_{0}) based on the model output from the current timestep.
+ `pred_original_sample` can be used to preview progress or for guidance.
+ """
+
+ prev_sample: torch.FloatTensor
+ pred_original_sample: Optional[torch.FloatTensor] = None
+
+
+# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
+def betas_for_alpha_bar(
+ num_diffusion_timesteps,
+ max_beta=0.999,
+ alpha_transform_type="cosine",
+):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+        alpha_transform_type (`str`, *optional*, defaults to `cosine`): the type of noise schedule for alpha_bar.
+            Choose from `cosine` or `exp`.
+
+ Returns:
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+ if alpha_transform_type == "cosine":
+
+ def alpha_bar_fn(t):
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ elif alpha_transform_type == "exp":
+
+ def alpha_bar_fn(t):
+ return math.exp(t * -12.0)
+
+ else:
+ raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
+
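+    # beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), clipped at max_beta, so that the cumulative product
+    # of (1 - beta) follows alpha_bar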
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
+ return torch.tensor(betas, dtype=torch.float32)
+
+
+def rescale_zero_terminal_snr(betas):
+ """
+    Rescales betas to have zero terminal SNR, based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1).
+
+
+ Args:
+ betas (`torch.FloatTensor`):
+ the betas that the scheduler is being initialized with.
+
+ Returns:
+ `torch.FloatTensor`: rescaled betas with zero terminal SNR
+ """
+ # Convert betas to alphas_bar_sqrt
+ alphas = 1.0 - betas
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
+
+ # Store old values.
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+
+ # Shift so the last timestep is zero.
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
+
+ # Scale so the first timestep is back to the old value.
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+ # Convert alphas_bar_sqrt to betas
+ alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
+ alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
+ alphas = torch.cat([alphas_bar[0:1], alphas])
+ betas = 1 - alphas
+
+ return betas
+
+
+class DDIMScheduler(SchedulerMixin, ConfigMixin):
+ """
+    Denoising Diffusion Implicit Models (DDIM) is a scheduler that extends the denoising procedure introduced in
+    denoising diffusion probabilistic models (DDPMs) with non-Markovian guidance.
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
+
+ For more details, see the original paper: https://arxiv.org/abs/2010.02502
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
+ beta_start (`float`): the starting `beta` value of inference.
+ beta_end (`float`): the final `beta` value.
+ beta_schedule (`str`):
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+ trained_betas (`np.ndarray`, optional):
+ option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
+ clip_sample (`bool`, default `True`):
+ option to clip predicted sample for numerical stability.
+ clip_sample_range (`float`, default `1.0`):
+ the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
+ set_alpha_to_one (`bool`, default `True`):
+ each diffusion step uses the value of alphas product at that step and at the previous one. For the final
+ step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
+ otherwise it uses the value of alpha at step 0.
+ steps_offset (`int`, default `0`):
+ an offset added to the inference steps. You can use a combination of `offset=1` and
+ `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
+ stable diffusion.
+ prediction_type (`str`, default `epsilon`, optional):
+ prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
+ process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4
+ https://imagen.research.google/video/paper.pdf)
+ thresholding (`bool`, default `False`):
+ whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
+ Note that the thresholding method is unsuitable for latent-space diffusion models (such as
+ stable-diffusion).
+ dynamic_thresholding_ratio (`float`, default `0.995`):
+ the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
+ (https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`.
+ sample_max_value (`float`, default `1.0`):
+ the threshold value for dynamic thresholding. Valid only when `thresholding=True`.
+ timestep_spacing (`str`, default `"leading"`):
+ The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
+ Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
+ rescale_betas_zero_snr (`bool`, default `False`):
+ whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf).
+ This can enable the model to generate very bright and dark samples instead of limiting it to samples with
+ medium brightness. Loosely related to
+ [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
+ """
+
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+ clip_sample: bool = True,
+ set_alpha_to_one: bool = True,
+ steps_offset: int = 0,
+ prediction_type: str = "epsilon",
+ thresholding: bool = False,
+ dynamic_thresholding_ratio: float = 0.995,
+ clip_sample_range: float = 1.0,
+ sample_max_value: float = 1.0,
+ timestep_spacing: str = "leading",
+ rescale_betas_zero_snr: bool = False,
+ ):
+ if trained_betas is not None:
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
+ elif beta_schedule == "linear":
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ )
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ else:
+ raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+ # Rescale for zero SNR
+ if rescale_betas_zero_snr:
+ self.betas = rescale_zero_terminal_snr(self.betas)
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+
+ # At every step in ddim, we are looking into the previous alphas_cumprod
+ # For the final step, there is no previous alphas_cumprod because we are already at 0
+ # `set_alpha_to_one` decides whether we set this parameter simply to one or
+ # whether we use the final alpha of the "non-previous" one.
+ self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = 1.0
+
+ # setable values
+ self.num_inference_steps = None
+ self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
+
+ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+
+ Args:
+ sample (`torch.FloatTensor`): input sample
+ timestep (`int`, optional): current timestep
+
+ Returns:
+ `torch.FloatTensor`: scaled input sample
+ """
+ return sample
+
+ def _get_variance(self, timestep, prev_timestep):
+ alpha_prod_t = self.alphas_cumprod[timestep]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+ variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
+
+ return variance
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
+ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
+ """
+ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
+ prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
+ s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
+ pixels from saturation at each step. We find that dynamic thresholding results in significantly better
+ photorealism as well as better image-text alignment, especially when using very large guidance weights."
+
+ https://arxiv.org/abs/2205.11487
+ """
+ dtype = sample.dtype
+ batch_size, channels, height, width = sample.shape
+
+ if dtype not in (torch.float32, torch.float64):
+ sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
+
+ # Flatten sample for doing quantile calculation along each image
+ sample = sample.reshape(batch_size, channels * height * width)
+
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
+
+ s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
+ s = torch.clamp(
+ s, min=1, max=self.config.sample_max_value
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
+
+ s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
+ sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
+
+ sample = sample.reshape(batch_size, channels, height, width)
+ sample = sample.to(dtype)
+
+ return sample
+
+ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+ """
+ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ """
+
+ if num_inference_steps > self.config.num_train_timesteps:
+ raise ValueError(
+ f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
+ f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
+ f" maximal {self.config.num_train_timesteps} timesteps."
+ )
+
+ self.num_inference_steps = num_inference_steps
+
+ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+ if self.config.timestep_spacing == "linspace":
+ timesteps = (
+ np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
+ .round()[::-1]
+ .copy()
+ .astype(np.int64)
+ )
+ elif self.config.timestep_spacing == "leading":
+ step_ratio = self.config.num_train_timesteps // self.num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_steps is a power of 3
+ timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
+ timesteps += self.config.steps_offset
+ elif self.config.timestep_spacing == "trailing":
+ step_ratio = self.config.num_train_timesteps / self.num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_steps is a power of 3
+ timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
+ timesteps -= 1
+ else:
+ raise ValueError(
+ f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'."
+ )
+
+ self.timesteps = torch.from_numpy(timesteps).to(device)
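+
+ # Worked example (illustrative note, not from the upstream file): with the default
+ # num_train_timesteps=1000 and num_inference_steps=10, the spacings above yield
+ # "leading": [900, 800, 700, 600, 500, 400, 300, 200, 100, 0] (plus steps_offset),
+ # "trailing": [999, 899, 799, 699, 599, 499, 399, 299, 199, 99],
+ # "linspace": [999, 888, 777, 666, 555, 444, 333, 222, 111, 0].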
+
+ def step(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: int,
+ sample: torch.FloatTensor,
+ eta: float = 0.0,
+ use_clipped_model_output: bool = False,
+ generator=None,
+ variance_noise: Optional[torch.FloatTensor] = None,
+ return_dict: bool = True,
+ ) -> Union[DDIMSchedulerOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+ eta (`float`): weight of noise for added noise in diffusion step.
+ use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
+ predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when
+ `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would
+ coincide with the one provided as input and `use_clipped_model_output` will have no effect.
+ generator: random number generator.
+ variance_noise (`torch.FloatTensor`): instead of generating noise for the variance using `generator`, we
+ can directly provide the noise for the variance itself. This is useful for methods such as
+ CycleDiffusion. (https://arxiv.org/abs/2210.05559)
+ return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class
+
+ Returns:
+ [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`:
+ [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
+
+ """
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
+ # Ideally, read the DDIM paper for a detailed understanding
+
+ # Notation (<variable name> -> <name in paper>)
+ # - pred_noise_t -> e_theta(x_t, t)
+ # - pred_original_sample -> f_theta(x_t, t) or x_0
+ # - std_dev_t -> sigma_t
+ # - eta -> η
+ # - pred_sample_direction -> "direction pointing to x_t"
+ # - pred_prev_sample -> "x_t-1"
+
+ # 1. get previous step value (=t-1)
+ prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
+
+ # 2. compute alphas, betas
+ alpha_prod_t = self.alphas_cumprod[timestep]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+
+ beta_prod_t = 1 - alpha_prod_t
+
+ # 3. compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ if self.config.prediction_type == "epsilon":
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+ pred_epsilon = model_output
+ elif self.config.prediction_type == "sample":
+ pred_original_sample = model_output
+ pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
+ elif self.config.prediction_type == "v_prediction":
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
+ pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
+ " `v_prediction`"
+ )
+
+ # 4. Clip or threshold "predicted x_0"
+ if self.config.thresholding:
+ pred_original_sample = self._threshold_sample(pred_original_sample)
+ elif self.config.clip_sample:
+ pred_original_sample = pred_original_sample.clamp(
+ -self.config.clip_sample_range, self.config.clip_sample_range
+ )
+
+ # 5. compute variance: "sigma_t(η)" -> see formula (16)
+ # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
+ variance = self._get_variance(timestep, prev_timestep)
+ std_dev_t = eta * variance ** (0.5)
+
+ if use_clipped_model_output:
+ # the pred_epsilon is always re-derived from the clipped x_0 in Glide
+ pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
+
+ # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon
+
+ # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
+
+ if eta > 0:
+ if variance_noise is not None and generator is not None:
+ raise ValueError(
+ "Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
+ " `variance_noise` stays `None`."
+ )
+
+ if variance_noise is None:
+ variance_noise = randn_tensor(
+ model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
+ )
+ variance = std_dev_t * variance_noise
+
+ prev_sample = prev_sample + variance
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
+ def add_noise(
+ self,
+ original_samples: torch.FloatTensor,
+ noise: torch.FloatTensor,
+ timesteps: torch.IntTensor,
+ ) -> torch.FloatTensor:
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
+ alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+ timesteps = timesteps.to(original_samples.device)
+
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+ return noisy_samples
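+
+ # Note (illustrative, not from the upstream file): `add_noise` is the closed-form forward
+ # process x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise, e.g.
+ # scheduler.add_noise(x0, torch.randn_like(x0), torch.tensor([t])) yields the noisy sample
+ # used as the model input when training on timestep t.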
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
+ def get_velocity(
+ self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
+ ) -> torch.FloatTensor:
+ # Make sure alphas_cumprod and timestep have same device and dtype as sample
+ alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
+ timesteps = timesteps.to(sample.device)
+
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(sample.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+ velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
+ return velocity
+
+ def __len__(self):
+ return self.config.num_train_timesteps
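+
+
+# Minimal usage sketch (illustrative, not part of the upstream diffusers file): a plain
+# DDIM sampling loop with a zero "epsilon" prediction standing in for a trained UNet. The
+# latent shape (1, 4, 64, 64) and the 50 inference steps are arbitrary assumptions.
+if __name__ == "__main__":
+ scheduler = DDIMScheduler()
+ scheduler.set_timesteps(50)
+ sample = torch.randn(1, 4, 64, 64) * scheduler.init_noise_sigma
+ for t in scheduler.timesteps:
+ model_output = torch.zeros_like(sample) # a real pipeline would call the UNet here
+ sample = scheduler.step(model_output, t, sample, eta=0.0).prev_sample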
diff --git a/diffusers/schedulers/scheduling_ddim_flax.py b/diffusers/schedulers/scheduling_ddim_flax.py
new file mode 100644
index 0000000000000000000000000000000000000000..db248c33077bf502e31cb2ab97141744b828b514
--- /dev/null
+++ b/diffusers/schedulers/scheduling_ddim_flax.py
@@ -0,0 +1,305 @@
+# Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
+# and https://github.com/hojonathanho/diffusion
+
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import flax
+import jax.numpy as jnp
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from .scheduling_utils_flax import (
+ CommonSchedulerState,
+ FlaxKarrasDiffusionSchedulers,
+ FlaxSchedulerMixin,
+ FlaxSchedulerOutput,
+ add_noise_common,
+ get_velocity_common,
+)
+
+
+@flax.struct.dataclass
+class DDIMSchedulerState:
+ common: CommonSchedulerState
+ final_alpha_cumprod: jnp.ndarray
+
+ # setable values
+ init_noise_sigma: jnp.ndarray
+ timesteps: jnp.ndarray
+ num_inference_steps: Optional[int] = None
+
+ @classmethod
+ def create(
+ cls,
+ common: CommonSchedulerState,
+ final_alpha_cumprod: jnp.ndarray,
+ init_noise_sigma: jnp.ndarray,
+ timesteps: jnp.ndarray,
+ ):
+ return cls(
+ common=common,
+ final_alpha_cumprod=final_alpha_cumprod,
+ init_noise_sigma=init_noise_sigma,
+ timesteps=timesteps,
+ )
+
+
+@dataclass
+class FlaxDDIMSchedulerOutput(FlaxSchedulerOutput):
+ state: DDIMSchedulerState
+
+
+class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
+ """
+ Denoising diffusion implicit models (DDIM) extend the denoising procedure introduced in denoising diffusion
+ probabilistic models (DDPMs) with non-Markovian guidance.
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
+
+ For more details, see the original paper: https://arxiv.org/abs/2010.02502
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
+ beta_start (`float`): the starting `beta` value of inference.
+ beta_end (`float`): the final `beta` value.
+ beta_schedule (`str`):
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+ trained_betas (`jnp.ndarray`, optional):
+ option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
+ clip_sample (`bool`, default `True`):
+ option to clip predicted sample between -1 and 1 for numerical stability.
+ set_alpha_to_one (`bool`, default `True`):
+ each diffusion step uses the value of alphas product at that step and at the previous one. For the final
+ step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
+ otherwise it uses the value of alpha at step 0.
+ steps_offset (`int`, default `0`):
+ an offset added to the inference steps. You can use a combination of `offset=1` and
+ `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
+ stable diffusion.
+ prediction_type (`str`, default `epsilon`):
+ indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`,
+ or `v_prediction` (see section 2.4 of https://imagen.research.google/video/paper.pdf).
+ dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
+ the `dtype` used for params and computation.
+ """
+
+ _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]
+
+ dtype: jnp.dtype
+
+ @property
+ def has_state(self):
+ return True
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[jnp.ndarray] = None,
+ set_alpha_to_one: bool = True,
+ steps_offset: int = 0,
+ prediction_type: str = "epsilon",
+ dtype: jnp.dtype = jnp.float32,
+ ):
+ self.dtype = dtype
+
+ def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDIMSchedulerState:
+ if common is None:
+ common = CommonSchedulerState.create(self)
+
+ # At every step in ddim, we are looking into the previous alphas_cumprod
+ # For the final step, there is no previous alphas_cumprod because we are already at 0
+ # `set_alpha_to_one` decides whether we set this parameter simply to one or
+ # whether we use the final alpha of the "non-previous" one.
+ final_alpha_cumprod = (
+ jnp.array(1.0, dtype=self.dtype) if self.config.set_alpha_to_one else common.alphas_cumprod[0]
+ )
+
+ # standard deviation of the initial noise distribution
+ init_noise_sigma = jnp.array(1.0, dtype=self.dtype)
+
+ timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1]
+
+ return DDIMSchedulerState.create(
+ common=common,
+ final_alpha_cumprod=final_alpha_cumprod,
+ init_noise_sigma=init_noise_sigma,
+ timesteps=timesteps,
+ )
+
+ def scale_model_input(
+ self, state: DDIMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
+ ) -> jnp.ndarray:
+ """
+ Args:
+ state (`DDIMSchedulerState`): the `FlaxDDIMScheduler` state data class instance.
+ sample (`jnp.ndarray`): input sample
+ timestep (`int`, optional): current timestep
+
+ Returns:
+ `jnp.ndarray`: scaled input sample
+ """
+ return sample
+
+ def set_timesteps(
+ self, state: DDIMSchedulerState, num_inference_steps: int, shape: Tuple = ()
+ ) -> DDIMSchedulerState:
+ """
+ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ state (`DDIMSchedulerState`):
+ the `FlaxDDIMScheduler` state data class instance.
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ """
+ step_ratio = self.config.num_train_timesteps // num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # rounding to avoid issues when num_inference_steps is a power of 3
+ timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1] + self.config.steps_offset
+
+ return state.replace(
+ num_inference_steps=num_inference_steps,
+ timesteps=timesteps,
+ )
+
+ def _get_variance(self, state: DDIMSchedulerState, timestep, prev_timestep):
+ alpha_prod_t = state.common.alphas_cumprod[timestep]
+ alpha_prod_t_prev = jnp.where(
+ prev_timestep >= 0, state.common.alphas_cumprod[prev_timestep], state.final_alpha_cumprod
+ )
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+ variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
+
+ return variance
+
+ def step(
+ self,
+ state: DDIMSchedulerState,
+ model_output: jnp.ndarray,
+ timestep: int,
+ sample: jnp.ndarray,
+ eta: float = 0.0,
+ return_dict: bool = True,
+ ) -> Union[FlaxDDIMSchedulerOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ state (`DDIMSchedulerState`): the `FlaxDDIMScheduler` state data class instance.
+ model_output (`jnp.ndarray`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`jnp.ndarray`):
+ current instance of sample being created by diffusion process.
+ return_dict (`bool`): option for returning tuple rather than FlaxDDIMSchedulerOutput class
+
+ Returns:
+ [`FlaxDDIMSchedulerOutput`] or `tuple`: [`FlaxDDIMSchedulerOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is the sample tensor.
+
+ """
+ if state.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
+ # Ideally, read the DDIM paper for a detailed understanding
+
+ # Notation (<variable name> -> <name in paper>)
+ # - pred_noise_t -> e_theta(x_t, t)
+ # - pred_original_sample -> f_theta(x_t, t) or x_0
+ # - std_dev_t -> sigma_t
+ # - eta -> η
+ # - pred_sample_direction -> "direction pointing to x_t"
+ # - pred_prev_sample -> "x_t-1"
+
+ # 1. get previous step value (=t-1)
+ prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps
+
+ alphas_cumprod = state.common.alphas_cumprod
+ final_alpha_cumprod = state.final_alpha_cumprod
+
+ # 2. compute alphas, betas
+ alpha_prod_t = alphas_cumprod[timestep]
+ alpha_prod_t_prev = jnp.where(prev_timestep >= 0, alphas_cumprod[prev_timestep], final_alpha_cumprod)
+
+ beta_prod_t = 1 - alpha_prod_t
+
+ # 3. compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ if self.config.prediction_type == "epsilon":
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+ pred_epsilon = model_output
+ elif self.config.prediction_type == "sample":
+ pred_original_sample = model_output
+ pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
+ elif self.config.prediction_type == "v_prediction":
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
+ pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
+ " `v_prediction`"
+ )
+
+ # 4. compute variance: "sigma_t(η)" -> see formula (16)
+ # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
+ variance = self._get_variance(state, timestep, prev_timestep)
+ std_dev_t = eta * variance ** (0.5)
+
+ # 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon
+
+ # 6. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
+
+ if not return_dict:
+ return (prev_sample, state)
+
+ return FlaxDDIMSchedulerOutput(prev_sample=prev_sample, state=state)
+
+ def add_noise(
+ self,
+ state: DDIMSchedulerState,
+ original_samples: jnp.ndarray,
+ noise: jnp.ndarray,
+ timesteps: jnp.ndarray,
+ ) -> jnp.ndarray:
+ return add_noise_common(state.common, original_samples, noise, timesteps)
+
+ def get_velocity(
+ self,
+ state: DDIMSchedulerState,
+ sample: jnp.ndarray,
+ noise: jnp.ndarray,
+ timesteps: jnp.ndarray,
+ ) -> jnp.ndarray:
+ return get_velocity_common(state.common, sample, noise, timesteps)
+
+ def __len__(self):
+ return self.config.num_train_timesteps
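+
+
+# Minimal usage sketch (illustrative, not part of the upstream diffusers file): the Flax
+# scheduler is stateless, so the state object is threaded through set_timesteps and step.
+# The zero "model output" and the (1, 4, 64, 64) shape are arbitrary assumptions.
+if __name__ == "__main__":
+ scheduler = FlaxDDIMScheduler()
+ state = scheduler.create_state()
+ state = scheduler.set_timesteps(state, num_inference_steps=50)
+ sample = jnp.zeros((1, 4, 64, 64), dtype=jnp.float32)
+ for t in state.timesteps:
+ output = scheduler.step(state, jnp.zeros_like(sample), t, sample)
+ sample, state = output.prev_sample, output.state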
diff --git a/diffusers/schedulers/scheduling_ddim_inverse.py b/diffusers/schedulers/scheduling_ddim_inverse.py
new file mode 100644
index 0000000000000000000000000000000000000000..9db83eb992022c80b4fd1cc3769eedca219c96f4
--- /dev/null
+++ b/diffusers/schedulers/scheduling_ddim_inverse.py
@@ -0,0 +1,349 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
+# and https://github.com/hojonathanho/diffusion
+import math
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.schedulers.scheduling_utils import SchedulerMixin
+from diffusers.utils import BaseOutput, deprecate
+
+
+@dataclass
+# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
+class DDIMSchedulerOutput(BaseOutput):
+ """
+ Output class for the scheduler's step function output.
+
+ Args:
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ The predicted denoised sample (x_{0}) based on the model output from the current timestep.
+ `pred_original_sample` can be used to preview progress or for guidance.
+ """
+
+ prev_sample: torch.FloatTensor
+ pred_original_sample: Optional[torch.FloatTensor] = None
+
+
+# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
+def betas_for_alpha_bar(
+ num_diffusion_timesteps,
+ max_beta=0.999,
+ alpha_transform_type="cosine",
+):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ alpha_transform_type (`str`, *optional*, defaults to `cosine`): the type of noise schedule for alpha_bar.
+ Choose from `cosine` or `exp`
+
+ Returns:
+ betas (`torch.Tensor`): the betas used by the scheduler to step the model outputs
+ """
+ if alpha_transform_type == "cosine":
+
+ def alpha_bar_fn(t):
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ elif alpha_transform_type == "exp":
+
+ def alpha_bar_fn(t):
+ return math.exp(t * -12.0)
+
+ else:
+ raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
+ return torch.tensor(betas, dtype=torch.float32)
+
+
+# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
+def rescale_zero_terminal_snr(betas):
+ """
+ Rescales betas to have zero terminal SNR, based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1).
+
+
+ Args:
+ betas (`torch.FloatTensor`):
+ the betas that the scheduler is being initialized with.
+
+ Returns:
+ `torch.FloatTensor`: rescaled betas with zero terminal SNR
+ """
+ # Convert betas to alphas_bar_sqrt
+ alphas = 1.0 - betas
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
+
+ # Store old values.
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+
+ # Shift so the last timestep is zero.
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
+
+ # Scale so the first timestep is back to the old value.
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+ # Convert alphas_bar_sqrt to betas
+ alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
+ alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
+ alphas = torch.cat([alphas_bar[0:1], alphas])
+ betas = 1 - alphas
+
+ return betas
+
+
+class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
+ """
+ DDIMInverseScheduler is the reverse scheduler of [`DDIMScheduler`].
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
+
+ For more details, see the original paper: https://arxiv.org/abs/2010.02502
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
+ beta_start (`float`): the starting `beta` value of inference.
+ beta_end (`float`): the final `beta` value.
+ beta_schedule (`str`):
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+ trained_betas (`np.ndarray`, optional):
+ option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
+ clip_sample (`bool`, default `True`):
+ option to clip predicted sample for numerical stability.
+ clip_sample_range (`float`, default `1.0`):
+ the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
+ set_alpha_to_one (`bool`, default `True`):
+ each inverted diffusion step uses the value of alphas product at that step and at the next one. For the
+ initial step there is no current alpha. When this option is `True` the initial alpha product is fixed to
+ `1`, otherwise it uses the value of alpha at step 0. (Replaces the deprecated `set_alpha_to_zero`.)
+ steps_offset (`int`, default `0`):
+ an offset added to the inference steps. You can use a combination of `offset=1` and
+ `set_alpha_to_one=False`, to make the last step use step `num_train_timesteps - 1` for the previous alpha
+ product.
+ prediction_type (`str`, default `epsilon`, optional):
+ prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
+ process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4
+ https://imagen.research.google/video/paper.pdf)
+ timestep_spacing (`str`, default `"leading"`):
+ The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
+ Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
+ rescale_betas_zero_snr (`bool`, default `False`):
+ whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf).
+ This can enable the model to generate very bright and dark samples instead of limiting it to samples with
+ medium brightness. Loosely related to
+ [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
+ """
+
+ order = 1
+ ignore_for_config = ["kwargs"]
+ _deprecated_kwargs = ["set_alpha_to_zero"]
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+ clip_sample: bool = True,
+ set_alpha_to_one: bool = True,
+ steps_offset: int = 0,
+ prediction_type: str = "epsilon",
+ clip_sample_range: float = 1.0,
+ timestep_spacing: str = "leading",
+ rescale_betas_zero_snr: bool = False,
+ **kwargs,
+ ):
+ if kwargs.get("set_alpha_to_zero", None) is not None:
+ deprecation_message = (
+ "The `set_alpha_to_zero` argument is deprecated. Please use `set_alpha_to_one` instead."
+ )
+ deprecate("set_alpha_to_zero", "1.0.0", deprecation_message, standard_warn=False)
+ set_alpha_to_one = kwargs["set_alpha_to_zero"]
+ if trained_betas is not None:
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
+ elif beta_schedule == "linear":
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ )
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ else:
+ raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+ # Rescale for zero SNR
+ if rescale_betas_zero_snr:
+ self.betas = rescale_zero_terminal_snr(self.betas)
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+
+ # At every step in inverted ddim, we are looking into the next alphas_cumprod
+ # For the initial step, there is no current alphas_cumprod, and the index is out of bounds
+ # `set_alpha_to_one` decides whether we set this parameter simply to one
+ # (in this case, self.step() just outputs the predicted noise)
+ # or whether we use the initial alpha used in training the diffusion model.
+ self.initial_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = 1.0
+
+ # setable values
+ self.num_inference_steps = None
+ self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps).copy().astype(np.int64))
+
+ # Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler.scale_model_input
+ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+
+ Args:
+ sample (`torch.FloatTensor`): input sample
+ timestep (`int`, optional): current timestep
+
+ Returns:
+ `torch.FloatTensor`: scaled input sample
+ """
+ return sample
+
+ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+ """
+ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ """
+
+ if num_inference_steps > self.config.num_train_timesteps:
+ raise ValueError(
+ f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
+ f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
+ f" maximal {self.config.num_train_timesteps} timesteps."
+ )
+
+ self.num_inference_steps = num_inference_steps
+
+ # "leading" and "trailing" corresponds to annotation of Table 1. of https://arxiv.org/abs/2305.08891
+ if self.config.timestep_spacing == "leading":
+ step_ratio = self.config.num_train_timesteps // self.num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_steps is a power of 3
+ timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().copy().astype(np.int64)
+ timesteps += self.config.steps_offset
+ elif self.config.timestep_spacing == "trailing":
+ step_ratio = self.config.num_train_timesteps / self.num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_steps is a power of 3
+ timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)[::-1]).astype(np.int64)
+ timesteps -= 1
+ else:
+ raise ValueError(
+ f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'."
+ )
+
+ # Roll timesteps array by one to reflect reversed origin and destination semantics for each step
+ timesteps = np.roll(timesteps, 1)
+ timesteps[0] = int(timesteps[1] - step_ratio)
+ self.timesteps = torch.from_numpy(timesteps).to(device)
+
+ def step(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: int,
+ sample: torch.FloatTensor,
+ eta: float = 0.0,
+ use_clipped_model_output: bool = False,
+ variance_noise: Optional[torch.FloatTensor] = None,
+ return_dict: bool = True,
+ ) -> Union[DDIMSchedulerOutput, Tuple]:
+ # 1. get previous step value (=t+1)
+ prev_timestep = timestep + self.config.num_train_timesteps // self.num_inference_steps
+
+ # 2. compute alphas, betas
+ # change original implementation to exactly match noise levels for analogous forward process
+ alpha_prod_t = self.alphas_cumprod[timestep] if timestep >= 0 else self.initial_alpha_cumprod
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep]
+
+ beta_prod_t = 1 - alpha_prod_t
+
+ # 3. compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ if self.config.prediction_type == "epsilon":
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+ pred_epsilon = model_output
+ elif self.config.prediction_type == "sample":
+ pred_original_sample = model_output
+ pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
+ elif self.config.prediction_type == "v_prediction":
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
+ pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
+ " `v_prediction`"
+ )
+
+ # 4. Clip or threshold "predicted x_0"
+ if self.config.clip_sample:
+ pred_original_sample = pred_original_sample.clamp(
+ -self.config.clip_sample_range, self.config.clip_sample_range
+ )
+
+ # 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * pred_epsilon
+
+ # 6. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
+
+ if not return_dict:
+ return (prev_sample, pred_original_sample)
+ return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
+
+ def __len__(self):
+ return self.config.num_train_timesteps
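+
+
+# Minimal usage sketch (illustrative, not part of the upstream diffusers file): DDIM
+# inversion runs the timesteps in increasing order, mapping a clean latent towards noise.
+# The zero noise prediction stands in for a trained UNet and the latent shape is an
+# arbitrary assumption.
+if __name__ == "__main__":
+ inverse_scheduler = DDIMInverseScheduler()
+ inverse_scheduler.set_timesteps(50)
+ latents = torch.randn(1, 4, 64, 64) # e.g. the VAE-encoded latent of a real image
+ for t in inverse_scheduler.timesteps:
+ noise_pred = torch.zeros_like(latents) # a real pipeline would call the UNet here
+ latents = inverse_scheduler.step(noise_pred, t, latents).prev_sample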
diff --git a/diffusers/schedulers/scheduling_ddim_parallel.py b/diffusers/schedulers/scheduling_ddim_parallel.py
new file mode 100644
index 0000000000000000000000000000000000000000..db3ea0e1cca55f88d0a81d0311158929516cb038
--- /dev/null
+++ b/diffusers/schedulers/scheduling_ddim_parallel.py
@@ -0,0 +1,642 @@
+# Copyright 2023 ParaDiGMS authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
+# and https://github.com/hojonathanho/diffusion
+
+import math
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import BaseOutput, randn_tensor
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
+
+
+@dataclass
+# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput
+class DDIMParallelSchedulerOutput(BaseOutput):
+ """
+ Output class for the scheduler's step function output.
+
+ Args:
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ The predicted denoised sample (x_{0}) based on the model output from the current timestep.
+ `pred_original_sample` can be used to preview progress or for guidance.
+ """
+
+ prev_sample: torch.FloatTensor
+ pred_original_sample: Optional[torch.FloatTensor] = None
+
+
+# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
+def betas_for_alpha_bar(
+ num_diffusion_timesteps,
+ max_beta=0.999,
+ alpha_transform_type="cosine",
+):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ alpha_transform_type (`str`, *optional*, defaults to `cosine`): the type of noise schedule for alpha_bar.
+ Choose from `cosine` or `exp`
+
+ Returns:
+ betas (`torch.Tensor`): the betas used by the scheduler to step the model outputs
+ """
+ if alpha_transform_type == "cosine":
+
+ def alpha_bar_fn(t):
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ elif alpha_transform_type == "exp":
+
+ def alpha_bar_fn(t):
+ return math.exp(t * -12.0)
+
+ else:
+ raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
+ return torch.tensor(betas, dtype=torch.float32)
+
+
+# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
+def rescale_zero_terminal_snr(betas):
+ """
+ Rescales betas to have zero terminal SNR, based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1).
+
+
+ Args:
+ betas (`torch.FloatTensor`):
+ the betas that the scheduler is being initialized with.
+
+ Returns:
+ `torch.FloatTensor`: rescaled betas with zero terminal SNR
+ """
+ # Convert betas to alphas_bar_sqrt
+ alphas = 1.0 - betas
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
+
+ # Store old values.
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+
+ # Shift so the last timestep is zero.
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
+
+ # Scale so the first timestep is back to the old value.
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+ # Convert alphas_bar_sqrt to betas
+ alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
+ alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
+ alphas = torch.cat([alphas_bar[0:1], alphas])
+ betas = 1 - alphas
+
+ return betas
+
+
+class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
+ """
+ Denoising diffusion implicit models (DDIM) extend the denoising procedure introduced in denoising diffusion
+ probabilistic models (DDPMs) with non-Markovian guidance.
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
+
+ For more details, see the original paper: https://arxiv.org/abs/2010.02502
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
+ beta_start (`float`): the starting `beta` value of inference.
+ beta_end (`float`): the final `beta` value.
+ beta_schedule (`str`):
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+ trained_betas (`np.ndarray`, optional):
+ option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
+ clip_sample (`bool`, default `True`):
+ option to clip predicted sample for numerical stability.
+ clip_sample_range (`float`, default `1.0`):
+ the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
+ set_alpha_to_one (`bool`, default `True`):
+ each diffusion step uses the value of alphas product at that step and at the previous one. For the final
+ step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
+ otherwise it uses the value of alpha at step 0.
+ steps_offset (`int`, default `0`):
+ an offset added to the inference steps. You can use a combination of `offset=1` and
+ `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
+ stable diffusion.
+ prediction_type (`str`, default `epsilon`, optional):
+ prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
+ process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4
+ https://imagen.research.google/video/paper.pdf)
+ thresholding (`bool`, default `False`):
+ whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
+ Note that the thresholding method is unsuitable for latent-space diffusion models (such as
+ stable-diffusion).
+ dynamic_thresholding_ratio (`float`, default `0.995`):
+ the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
+ (https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`.
+ sample_max_value (`float`, default `1.0`):
+ the threshold value for dynamic thresholding. Valid only when `thresholding=True`.
+ timestep_spacing (`str`, default `"leading"`):
+ The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
+ Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
+ rescale_betas_zero_snr (`bool`, default `False`):
+ whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf).
+ This can enable the model to generate very bright and dark samples instead of limiting it to samples with
+ medium brightness. Loosely related to
+ [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
+ """
+
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
+ order = 1
+ _is_ode_scheduler = True
+
+ @register_to_config
+ # Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler.__init__
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+ clip_sample: bool = True,
+ set_alpha_to_one: bool = True,
+ steps_offset: int = 0,
+ prediction_type: str = "epsilon",
+ thresholding: bool = False,
+ dynamic_thresholding_ratio: float = 0.995,
+ clip_sample_range: float = 1.0,
+ sample_max_value: float = 1.0,
+ timestep_spacing: str = "leading",
+ rescale_betas_zero_snr: bool = False,
+ ):
+ if trained_betas is not None:
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
+ elif beta_schedule == "linear":
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ )
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ else:
+ raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+ # Rescale for zero SNR
+ if rescale_betas_zero_snr:
+ self.betas = rescale_zero_terminal_snr(self.betas)
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+
+ # At every step in ddim, we are looking into the previous alphas_cumprod
+ # For the final step, there is no previous alphas_cumprod because we are already at 0
+ # `set_alpha_to_one` decides whether we set this parameter simply to one or
+ # whether we use the final alpha of the "non-previous" one.
+ self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = 1.0
+
+ # setable values
+ self.num_inference_steps = None
+ self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
+
+ # Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler.scale_model_input
+ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+
+ Args:
+ sample (`torch.FloatTensor`): input sample
+ timestep (`int`, optional): current timestep
+
+ Returns:
+ `torch.FloatTensor`: scaled input sample
+ """
+ return sample
+
+ def _get_variance(self, timestep, prev_timestep=None):
+ if prev_timestep is None:
+ prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
+
+ alpha_prod_t = self.alphas_cumprod[timestep]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+ variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
+
+ return variance
+
+ def _batch_get_variance(self, t, prev_t):
+ alpha_prod_t = self.alphas_cumprod[t]
+ alpha_prod_t_prev = self.alphas_cumprod[torch.clip(prev_t, min=0)]
+ alpha_prod_t_prev[prev_t < 0] = torch.tensor(1.0)
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+ variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
+
+ return variance
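+
+ # Note (illustrative, not from the upstream file): `_batch_get_variance` is the vectorized
+ # counterpart of `_get_variance`, evaluating sigma_t^2 for a whole batch of (t, prev_t)
+ # pairs at once, with out-of-range previous timesteps (prev_t < 0) mapped to an alpha
+ # product of 1, which is what allows several denoising steps to be evaluated in parallel.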
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
+ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
+ """
+ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
+ prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
+ s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
+ pixels from saturation at each step. We find that dynamic thresholding results in significantly better
+ photorealism as well as better image-text alignment, especially when using very large guidance weights."
+
+ https://arxiv.org/abs/2205.11487
+ """
+ dtype = sample.dtype
+ batch_size, channels, height, width = sample.shape
+
+ if dtype not in (torch.float32, torch.float64):
+ sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
+
+ # Flatten sample for doing quantile calculation along each image
+ sample = sample.reshape(batch_size, channels * height * width)
+
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
+
+ s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
+ s = torch.clamp(
+ s, min=1, max=self.config.sample_max_value
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
+
+ s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
+ sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
+
+ sample = sample.reshape(batch_size, channels, height, width)
+ sample = sample.to(dtype)
+
+ return sample
+
+ # Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler.set_timesteps
+ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+ """
+ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
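+
+ Example (an illustrative sketch; it assumes the vendored `diffusers` package in this repository
+ is importable and that the enclosing class is `DDIMParallelScheduler`):
+
+ >>> from diffusers.schedulers.scheduling_ddim_parallel import DDIMParallelScheduler
+ >>> scheduler = DDIMParallelScheduler(num_train_timesteps=1000)
+ >>> scheduler.set_timesteps(4)  # default "leading" spacing
+ >>> scheduler.timesteps.tolist()
+ [750, 500, 250, 0]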
+ """
+
+ if num_inference_steps > self.config.num_train_timesteps:
+ raise ValueError(
+ f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
+ f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
+ f" maximal {self.config.num_train_timesteps} timesteps."
+ )
+
+ self.num_inference_steps = num_inference_steps
+
+ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+ if self.config.timestep_spacing == "linspace":
+ timesteps = (
+ np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
+ .round()[::-1]
+ .copy()
+ .astype(np.int64)
+ )
+ elif self.config.timestep_spacing == "leading":
+ step_ratio = self.config.num_train_timesteps // self.num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
+ timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
+ timesteps += self.config.steps_offset
+ elif self.config.timestep_spacing == "trailing":
+ step_ratio = self.config.num_train_timesteps / self.num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
+ timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
+ timesteps -= 1
+ else:
+ raise ValueError(
+ f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'."
+ )
+
+ self.timesteps = torch.from_numpy(timesteps).to(device)
+
+ def step(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: int,
+ sample: torch.FloatTensor,
+ eta: float = 0.0,
+ use_clipped_model_output: bool = False,
+ generator=None,
+ variance_noise: Optional[torch.FloatTensor] = None,
+ return_dict: bool = True,
+ ) -> Union[DDIMParallelSchedulerOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+ eta (`float`): weight of noise for added noise in diffusion step.
+ use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
+ predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when
+ `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would
+ coincide with the one provided as input and `use_clipped_model_output` will have no effect.
+ generator: random number generator.
+ variance_noise (`torch.FloatTensor`): instead of generating noise for the variance using `generator`, we
+ can directly provide the noise for the variance itself. This is useful for methods such as
+ CycleDiffusion. (https://arxiv.org/abs/2210.05559)
+ return_dict (`bool`): option for returning tuple rather than DDIMParallelSchedulerOutput class
+
+ Returns:
+ [`~schedulers.scheduling_utils.DDIMParallelSchedulerOutput`] or `tuple`:
+ [`~schedulers.scheduling_utils.DDIMParallelSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is the sample tensor.
+
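+ Example (an illustrative sketch; it assumes the vendored `diffusers` package in this repository
+ is importable, that the enclosing class is `DDIMParallelScheduler`, and that the random tensors
+ below stand in for latents and a UNet epsilon prediction):
+
+ >>> import torch
+ >>> from diffusers.schedulers.scheduling_ddim_parallel import DDIMParallelScheduler
+ >>> scheduler = DDIMParallelScheduler(num_train_timesteps=1000)
+ >>> scheduler.set_timesteps(50)
+ >>> sample = torch.randn(1, 4, 64, 64)        # current noisy latents
+ >>> model_output = torch.randn(1, 4, 64, 64)  # stand-in for the UNet's epsilon prediction
+ >>> out = scheduler.step(model_output, scheduler.timesteps[0], sample)
+ >>> tuple(out.prev_sample.shape)
+ (1, 4, 64, 64)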
+ """
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
+ # Ideally, read the DDIM paper in detail to understand the notation used below
+
+ # Notation (<variable name> -> <name in DDIM paper>)
+ # - pred_noise_t -> e_theta(x_t, t)
+ # - pred_original_sample -> f_theta(x_t, t) or x_0
+ # - std_dev_t -> sigma_t
+ # - eta -> η
+ # - pred_sample_direction -> "direction pointing to x_t"
+ # - pred_prev_sample -> "x_t-1"
+
+ # 1. get previous step value (=t-1)
+ prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
+
+ # 2. compute alphas, betas
+ alpha_prod_t = self.alphas_cumprod[timestep]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+
+ beta_prod_t = 1 - alpha_prod_t
+
+ # 3. compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ if self.config.prediction_type == "epsilon":
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+ pred_epsilon = model_output
+ elif self.config.prediction_type == "sample":
+ pred_original_sample = model_output
+ pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
+ elif self.config.prediction_type == "v_prediction":
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
+ pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
+ " `v_prediction`"
+ )
+
+ # 4. Clip or threshold "predicted x_0"
+ if self.config.thresholding:
+ pred_original_sample = self._threshold_sample(pred_original_sample)
+ elif self.config.clip_sample:
+ pred_original_sample = pred_original_sample.clamp(
+ -self.config.clip_sample_range, self.config.clip_sample_range
+ )
+
+ # 5. compute variance: "sigma_t(η)" -> see formula (16)
+ # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
+ variance = self._get_variance(timestep, prev_timestep)
+ std_dev_t = eta * variance ** (0.5)
+
+ if use_clipped_model_output:
+ # the pred_epsilon is always re-derived from the clipped x_0 in Glide
+ pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
+
+ # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon
+
+ # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
+
+ if eta > 0:
+ if variance_noise is not None and generator is not None:
+ raise ValueError(
+ "Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
+ " `variance_noise` stays `None`."
+ )
+
+ if variance_noise is None:
+ variance_noise = randn_tensor(
+ model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
+ )
+ variance = std_dev_t * variance_noise
+
+ prev_sample = prev_sample + variance
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return DDIMParallelSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
+
+ def batch_step_no_noise(
+ self,
+ model_output: torch.FloatTensor,
+ timesteps: List[int],
+ sample: torch.FloatTensor,
+ eta: float = 0.0,
+ use_clipped_model_output: bool = False,
+ ) -> torch.FloatTensor:
+ """
+ Batched version of the `step` function, to be able to reverse the SDE for multiple samples/timesteps at once.
+ Also, does not add any noise to the predicted sample, which is necessary for parallel sampling where the noise
+ is pre-sampled by the pipeline.
+
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+ timesteps (`List[int]`):
+ the current discrete timesteps in the diffusion chain, one per sample; in practice a 1-D integer tensor.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+ eta (`float`): weight of noise for added noise in diffusion step.
+ use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
+ predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when
+ `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would
+ coincide with the one provided as input and `use_clipped_model_output` will have no effect.
+
+ Returns:
+ `torch.FloatTensor`: sample tensor at previous timestep.
+
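+ Example (an illustrative sketch; it assumes the vendored `diffusers` package in this repository
+ is importable and that the enclosing class is `DDIMParallelScheduler`; the random tensors stand
+ in for a small batch of latents and model outputs):
+
+ >>> import torch
+ >>> from diffusers.schedulers.scheduling_ddim_parallel import DDIMParallelScheduler
+ >>> scheduler = DDIMParallelScheduler(num_train_timesteps=1000)
+ >>> scheduler.set_timesteps(50)
+ >>> t_batch = scheduler.timesteps[:3]          # three consecutive timesteps, as a 1-D tensor
+ >>> samples = torch.randn(3, 4, 64, 64)
+ >>> model_outputs = torch.randn(3, 4, 64, 64)
+ >>> prev = scheduler.batch_step_no_noise(model_outputs, t_batch, samples)
+ >>> tuple(prev.shape)
+ (3, 4, 64, 64)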
+ """
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ assert eta == 0.0
+
+ # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
+ # Ideally, read the DDIM paper in detail to understand the notation used below
+
+ # Notation (<variable name> -> <name in DDIM paper>)
+ # - pred_noise_t -> e_theta(x_t, t)
+ # - pred_original_sample -> f_theta(x_t, t) or x_0
+ # - std_dev_t -> sigma_t
+ # - eta -> η
+ # - pred_sample_direction -> "direction pointing to x_t"
+ # - pred_prev_sample -> "x_t-1"
+
+ # 1. get previous step value (=t-1)
+ t = timesteps
+ prev_t = t - self.config.num_train_timesteps // self.num_inference_steps
+
+ t = t.view(-1, *([1] * (model_output.ndim - 1)))
+ prev_t = prev_t.view(-1, *([1] * (model_output.ndim - 1)))
+
+ # 1. compute alphas, betas
+ self.alphas_cumprod = self.alphas_cumprod.to(model_output.device)
+ self.final_alpha_cumprod = self.final_alpha_cumprod.to(model_output.device)
+ alpha_prod_t = self.alphas_cumprod[t]
+ alpha_prod_t_prev = self.alphas_cumprod[torch.clip(prev_t, min=0)]
+ alpha_prod_t_prev[prev_t < 0] = torch.tensor(1.0)
+
+ beta_prod_t = 1 - alpha_prod_t
+
+ # 3. compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ if self.config.prediction_type == "epsilon":
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+ pred_epsilon = model_output
+ elif self.config.prediction_type == "sample":
+ pred_original_sample = model_output
+ pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
+ elif self.config.prediction_type == "v_prediction":
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
+ pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
+ " `v_prediction`"
+ )
+
+ # 4. Clip or threshold "predicted x_0"
+ if self.config.thresholding:
+ pred_original_sample = self._threshold_sample(pred_original_sample)
+ elif self.config.clip_sample:
+ pred_original_sample = pred_original_sample.clamp(
+ -self.config.clip_sample_range, self.config.clip_sample_range
+ )
+
+ # 5. compute variance: "sigma_t(η)" -> see formula (16)
+ # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
+ variance = self._batch_get_variance(t, prev_t).to(model_output.device).view(*alpha_prod_t_prev.shape)
+ std_dev_t = eta * variance ** (0.5)
+
+ if use_clipped_model_output:
+ # the pred_epsilon is always re-derived from the clipped x_0 in Glide
+ pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
+
+ # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon
+
+ # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
+
+ return prev_sample
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
+ def add_noise(
+ self,
+ original_samples: torch.FloatTensor,
+ noise: torch.FloatTensor,
+ timesteps: torch.IntTensor,
+ ) -> torch.FloatTensor:
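+ # Forward (noising) process: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise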
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
+ alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+ timesteps = timesteps.to(original_samples.device)
+
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+ return noisy_samples
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
+ def get_velocity(
+ self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
+ ) -> torch.FloatTensor:
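+ # Velocity target used for "v_prediction": v_t = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * x_0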
+ # Make sure alphas_cumprod and timestep have same device and dtype as sample
+ alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
+ timesteps = timesteps.to(sample.device)
+
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(sample.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+ velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
+ return velocity
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/diffusers/schedulers/scheduling_ddpm.py b/diffusers/schedulers/scheduling_ddpm.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1b7d7aaa9c22a3a768d1aed131794e810400936
--- /dev/null
+++ b/diffusers/schedulers/scheduling_ddpm.py
@@ -0,0 +1,513 @@
+# Copyright 2023 UC Berkeley Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
+
+import math
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import BaseOutput, randn_tensor
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
+
+
+@dataclass
+class DDPMSchedulerOutput(BaseOutput):
+ """
+ Output class for the scheduler's step function output.
+
+ Args:
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ The predicted denoised sample (x_{0}) based on the model output from the current timestep.
+ `pred_original_sample` can be used to preview progress or for guidance.
+ """
+
+ prev_sample: torch.FloatTensor
+ pred_original_sample: Optional[torch.FloatTensor] = None
+
+
+def betas_for_alpha_bar(
+ num_diffusion_timesteps,
+ max_beta=0.999,
+ alpha_transform_type="cosine",
+):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
+ Choose from `cosine` or `exp`
+
+ Returns:
+ betas (`torch.Tensor`): the betas used by the scheduler to step the model outputs
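+
+ Example (an illustrative sketch; the module path assumes the vendored `diffusers` package in this
+ repository is importable):
+
+ >>> from diffusers.schedulers.scheduling_ddpm import betas_for_alpha_bar
+ >>> betas = betas_for_alpha_bar(1000)
+ >>> betas.shape
+ torch.Size([1000])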
+ """
+ if alpha_transform_type == "cosine":
+
+ def alpha_bar_fn(t):
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ elif alpha_transform_type == "exp":
+
+ def alpha_bar_fn(t):
+ return math.exp(t * -12.0)
+
+ else:
+ raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
+ return torch.tensor(betas, dtype=torch.float32)
+
+
+class DDPMScheduler(SchedulerMixin, ConfigMixin):
+ """
+ Denoising diffusion probabilistic models (DDPMs) explore the connections between denoising score matching and
+ Langevin dynamics sampling.
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
+
+ For more details, see the original paper: https://arxiv.org/abs/2006.11239
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
+ beta_start (`float`): the starting `beta` value of inference.
+ beta_end (`float`): the final `beta` value.
+ beta_schedule (`str`):
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear`, `scaled_linear`, `squaredcos_cap_v2` or `sigmoid`.
+ trained_betas (`np.ndarray`, optional):
+ option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
+ variance_type (`str`):
+ options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
+ `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
+ clip_sample (`bool`, default `True`):
+ option to clip predicted sample for numerical stability.
+ clip_sample_range (`float`, default `1.0`):
+ the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
+ prediction_type (`str`, default `epsilon`, optional):
+ prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
+ process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4
+ https://imagen.research.google/video/paper.pdf)
+ thresholding (`bool`, default `False`):
+ whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
+ Note that the thresholding method is unsuitable for latent-space diffusion models (such as
+ stable-diffusion).
+ dynamic_thresholding_ratio (`float`, default `0.995`):
+ the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
+ (https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`.
+ sample_max_value (`float`, default `1.0`):
+ the threshold value for dynamic thresholding. Valid only when `thresholding=True`.
+ timestep_spacing (`str`, default `"leading"`):
+ The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
+ Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
+ steps_offset (`int`, default `0`):
+ an offset added to the inference steps. You can use a combination of `offset=1` and
+ `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
+ stable diffusion.
+ """
+
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+ variance_type: str = "fixed_small",
+ clip_sample: bool = True,
+ prediction_type: str = "epsilon",
+ thresholding: bool = False,
+ dynamic_thresholding_ratio: float = 0.995,
+ clip_sample_range: float = 1.0,
+ sample_max_value: float = 1.0,
+ timestep_spacing: str = "leading",
+ steps_offset: int = 0,
+ ):
+ if trained_betas is not None:
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
+ elif beta_schedule == "linear":
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ )
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ elif beta_schedule == "sigmoid":
+ # GeoDiff sigmoid schedule
+ betas = torch.linspace(-6, 6, num_train_timesteps)
+ self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
+ else:
+ raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+ self.one = torch.tensor(1.0)
+
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = 1.0
+
+ # setable values
+ self.custom_timesteps = False
+ self.num_inference_steps = None
+ self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
+
+ self.variance_type = variance_type
+
+ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+
+ Args:
+ sample (`torch.FloatTensor`): input sample
+ timestep (`int`, optional): current timestep
+
+ Returns:
+ `torch.FloatTensor`: scaled input sample
+ """
+ return sample
+
+ def set_timesteps(
+ self,
+ num_inference_steps: Optional[int] = None,
+ device: Union[str, torch.device] = None,
+ timesteps: Optional[List[int]] = None,
+ ):
+ """
+ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`Optional[int]`):
+ the number of diffusion steps used when generating samples with a pre-trained model. If passed, then
+ `timesteps` must be `None`.
+ device (`str` or `torch.device`, optional):
+ the device to which the timesteps are moved to.
+ timesteps (`List[int]`, optional):
+ custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
+ timestep spacing strategy of equal spacing between timesteps is used. If passed, `num_inference_steps`
+ must be `None`.
+
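+ Example (an illustrative sketch; it assumes the vendored `diffusers` package in this repository
+ is importable):
+
+ >>> from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
+ >>> scheduler = DDPMScheduler(num_train_timesteps=1000)
+ >>> scheduler.set_timesteps(num_inference_steps=4)
+ >>> scheduler.timesteps.tolist()
+ [750, 500, 250, 0]
+ >>> scheduler.set_timesteps(timesteps=[999, 500, 100, 0])  # custom, strictly descending
+ >>> scheduler.timesteps.tolist()
+ [999, 500, 100, 0]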
+ """
+ if num_inference_steps is not None and timesteps is not None:
+ raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.")
+
+ if timesteps is not None:
+ for i in range(1, len(timesteps)):
+ if timesteps[i] >= timesteps[i - 1]:
+ raise ValueError("`custom_timesteps` must be in descending order.")
+
+ if timesteps[0] >= self.config.num_train_timesteps:
+ raise ValueError(
+ f"`timesteps` must start before `self.config.train_timesteps`:"
+ f" {self.config.num_train_timesteps}."
+ )
+
+ timesteps = np.array(timesteps, dtype=np.int64)
+ self.custom_timesteps = True
+ else:
+ if num_inference_steps > self.config.num_train_timesteps:
+ raise ValueError(
+ f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
+ f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
+ f" maximal {self.config.num_train_timesteps} timesteps."
+ )
+
+ self.num_inference_steps = num_inference_steps
+ self.custom_timesteps = False
+
+ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+ if self.config.timestep_spacing == "linspace":
+ timesteps = (
+ np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
+ .round()[::-1]
+ .copy()
+ .astype(np.int64)
+ )
+ elif self.config.timestep_spacing == "leading":
+ step_ratio = self.config.num_train_timesteps // self.num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
+ timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
+ timesteps += self.config.steps_offset
+ elif self.config.timestep_spacing == "trailing":
+ step_ratio = self.config.num_train_timesteps / self.num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
+ timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
+ timesteps -= 1
+ else:
+ raise ValueError(
+ f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
+ )
+
+ self.timesteps = torch.from_numpy(timesteps).to(device)
+
+ def _get_variance(self, t, predicted_variance=None, variance_type=None):
+ prev_t = self.previous_timestep(t)
+
+ alpha_prod_t = self.alphas_cumprod[t]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
+ current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev
+
+ # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
+ # and sample from it to get previous sample
+ # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
+ variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t
+
+ # we always take the log of variance, so clamp it to ensure it's not 0
+ variance = torch.clamp(variance, min=1e-20)
+
+ if variance_type is None:
+ variance_type = self.config.variance_type
+
+ # hacks - were probably added for training stability
+ if variance_type == "fixed_small":
+ variance = variance
+ # for rl-diffuser https://arxiv.org/abs/2205.09991
+ elif variance_type == "fixed_small_log":
+ variance = torch.log(variance)
+ variance = torch.exp(0.5 * variance)
+ elif variance_type == "fixed_large":
+ variance = current_beta_t
+ elif variance_type == "fixed_large_log":
+ # Glide max_log
+ variance = torch.log(current_beta_t)
+ elif variance_type == "learned":
+ return predicted_variance
+ elif variance_type == "learned_range":
+ min_log = torch.log(variance)
+ max_log = torch.log(current_beta_t)
+ frac = (predicted_variance + 1) / 2
+ variance = frac * max_log + (1 - frac) * min_log
+
+ return variance
+
+ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
+ """
+ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
+ prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
+ s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
+ pixels from saturation at each step. We find that dynamic thresholding results in significantly better
+ photorealism as well as better image-text alignment, especially when using very large guidance weights."
+
+ https://arxiv.org/abs/2205.11487
+ """
+ dtype = sample.dtype
+ batch_size, channels, height, width = sample.shape
+
+ if dtype not in (torch.float32, torch.float64):
+ sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
+
+ # Flatten sample for doing quantile calculation along each image
+ sample = sample.reshape(batch_size, channels * height * width)
+
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
+
+ s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
+ s = torch.clamp(
+ s, min=1, max=self.config.sample_max_value
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
+
+ s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
+ sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
+
+ sample = sample.reshape(batch_size, channels, height, width)
+ sample = sample.to(dtype)
+
+ return sample
+
+ def step(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: int,
+ sample: torch.FloatTensor,
+ generator=None,
+ return_dict: bool = True,
+ ) -> Union[DDPMSchedulerOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+ generator: random number generator.
+ return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class
+
+ Returns:
+ [`~schedulers.scheduling_utils.DDPMSchedulerOutput`] or `tuple`:
+ [`~schedulers.scheduling_utils.DDPMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
+
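+ Example (an illustrative sketch; it assumes the vendored `diffusers` package in this repository
+ is importable; the random tensors stand in for latents and a UNet epsilon prediction):
+
+ >>> import torch
+ >>> from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
+ >>> scheduler = DDPMScheduler(num_train_timesteps=1000)
+ >>> scheduler.set_timesteps(50)
+ >>> sample = torch.randn(1, 4, 64, 64)
+ >>> model_output = torch.randn(1, 4, 64, 64)
+ >>> generator = torch.Generator().manual_seed(0)
+ >>> out = scheduler.step(model_output, scheduler.timesteps[0], sample, generator=generator)
+ >>> tuple(out.prev_sample.shape)
+ (1, 4, 64, 64)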
+ """
+ t = timestep
+
+ prev_t = self.previous_timestep(t)
+
+ if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
+ model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
+ else:
+ predicted_variance = None
+
+ # 1. compute alphas, betas
+ alpha_prod_t = self.alphas_cumprod[t]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+ current_alpha_t = alpha_prod_t / alpha_prod_t_prev
+ current_beta_t = 1 - current_alpha_t
+
+ # 2. compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
+ if self.config.prediction_type == "epsilon":
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+ elif self.config.prediction_type == "sample":
+ pred_original_sample = model_output
+ elif self.config.prediction_type == "v_prediction":
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
+ " `v_prediction` for the DDPMScheduler."
+ )
+
+ # 3. Clip or threshold "predicted x_0"
+ if self.config.thresholding:
+ pred_original_sample = self._threshold_sample(pred_original_sample)
+ elif self.config.clip_sample:
+ pred_original_sample = pred_original_sample.clamp(
+ -self.config.clip_sample_range, self.config.clip_sample_range
+ )
+
+ # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
+ # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+ pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
+ current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t
+
+ # 5. Compute predicted previous sample µ_t
+ # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+ pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample
+
+ # 6. Add noise
+ variance = 0
+ if t > 0:
+ device = model_output.device
+ variance_noise = randn_tensor(
+ model_output.shape, generator=generator, device=device, dtype=model_output.dtype
+ )
+ if self.variance_type == "fixed_small_log":
+ variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise
+ elif self.variance_type == "learned_range":
+ variance = self._get_variance(t, predicted_variance=predicted_variance)
+ variance = torch.exp(0.5 * variance) * variance_noise
+ else:
+ variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * variance_noise
+
+ pred_prev_sample = pred_prev_sample + variance
+
+ if not return_dict:
+ return (pred_prev_sample,)
+
+ return DDPMSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)
+
+ def add_noise(
+ self,
+ original_samples: torch.FloatTensor,
+ noise: torch.FloatTensor,
+ timesteps: torch.IntTensor,
+ ) -> torch.FloatTensor:
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
+ alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+ timesteps = timesteps.to(original_samples.device)
+
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+ return noisy_samples
+
+ def get_velocity(
+ self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
+ ) -> torch.FloatTensor:
+ # Make sure alphas_cumprod and timestep have same device and dtype as sample
+ alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
+ timesteps = timesteps.to(sample.device)
+
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(sample.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+ velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
+ return velocity
+
+ def __len__(self):
+ return self.config.num_train_timesteps
+
+ def previous_timestep(self, timestep):
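+ # Returns the timestep preceding `timestep` in the active schedule: the next entry of
+ # `self.timesteps` when custom timesteps are used, otherwise `timestep` minus a fixed stride.
+ # The result becomes negative once the end of the schedule is reached.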
+ if self.custom_timesteps:
+ index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
+ if index == self.timesteps.shape[0] - 1:
+ prev_t = torch.tensor(-1)
+ else:
+ prev_t = self.timesteps[index + 1]
+ else:
+ num_inference_steps = (
+ self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
+ )
+ prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
+
+ return prev_t
diff --git a/diffusers/schedulers/scheduling_ddpm_flax.py b/diffusers/schedulers/scheduling_ddpm_flax.py
new file mode 100644
index 0000000000000000000000000000000000000000..529d2bd03a75403e298ec7a30808689a48cf5301
--- /dev/null
+++ b/diffusers/schedulers/scheduling_ddpm_flax.py
@@ -0,0 +1,299 @@
+# Copyright 2023 UC Berkeley Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
+
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import flax
+import jax
+import jax.numpy as jnp
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from .scheduling_utils_flax import (
+ CommonSchedulerState,
+ FlaxKarrasDiffusionSchedulers,
+ FlaxSchedulerMixin,
+ FlaxSchedulerOutput,
+ add_noise_common,
+ get_velocity_common,
+)
+
+
+@flax.struct.dataclass
+class DDPMSchedulerState:
+ common: CommonSchedulerState
+
+ # setable values
+ init_noise_sigma: jnp.ndarray
+ timesteps: jnp.ndarray
+ num_inference_steps: Optional[int] = None
+
+ @classmethod
+ def create(cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray):
+ return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps)
+
+
+@dataclass
+class FlaxDDPMSchedulerOutput(FlaxSchedulerOutput):
+ state: DDPMSchedulerState
+
+
+class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
+ """
+ Denoising diffusion probabilistic models (DDPMs) explore the connections between denoising score matching and
+ Langevin dynamics sampling.
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
+
+ For more details, see the original paper: https://arxiv.org/abs/2006.11239
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
+ beta_start (`float`): the starting `beta` value of inference.
+ beta_end (`float`): the final `beta` value.
+ beta_schedule (`str`):
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+ trained_betas (`np.ndarray`, optional):
+ option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
+ variance_type (`str`):
+ options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
+ `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
+ clip_sample (`bool`, default `True`):
+ option to clip predicted sample between -1 and 1 for numerical stability.
+ prediction_type (`str`, default `epsilon`):
+ indicates whether the model predicts the noise (epsilon), the sample, or the velocity. One of
+ `epsilon`, `sample` or `v_prediction`.
+ dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
+ the `dtype` used for params and computation.
+ """
+
+ _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]
+
+ dtype: jnp.dtype
+
+ @property
+ def has_state(self):
+ return True
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[jnp.ndarray] = None,
+ variance_type: str = "fixed_small",
+ clip_sample: bool = True,
+ prediction_type: str = "epsilon",
+ dtype: jnp.dtype = jnp.float32,
+ ):
+ self.dtype = dtype
+
+ def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDPMSchedulerState:
+ if common is None:
+ common = CommonSchedulerState.create(self)
+
+ # standard deviation of the initial noise distribution
+ init_noise_sigma = jnp.array(1.0, dtype=self.dtype)
+
+ timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1]
+
+ return DDPMSchedulerState.create(
+ common=common,
+ init_noise_sigma=init_noise_sigma,
+ timesteps=timesteps,
+ )
+
+ def scale_model_input(
+ self, state: DDPMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
+ ) -> jnp.ndarray:
+ """
+ Args:
+ state (`DDPMSchedulerState`): the `FlaxDDPMScheduler` state data class instance.
+ sample (`jnp.ndarray`): input sample
+ timestep (`int`, optional): current timestep
+
+ Returns:
+ `jnp.ndarray`: scaled input sample
+ """
+ return sample
+
+ def set_timesteps(
+ self, state: DDPMSchedulerState, num_inference_steps: int, shape: Tuple = ()
+ ) -> DDPMSchedulerState:
+ """
+ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ state (`DDPMSchedulerState`):
+ the `FlaxDDPMScheduler` state data class instance.
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
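+
+ Example (an illustrative sketch; it assumes `jax`/`flax` are installed and that the vendored
+ `diffusers` package in this repository is importable):
+
+ >>> from diffusers.schedulers.scheduling_ddpm_flax import FlaxDDPMScheduler
+ >>> scheduler = FlaxDDPMScheduler(num_train_timesteps=1000)
+ >>> state = scheduler.create_state()
+ >>> state = scheduler.set_timesteps(state, num_inference_steps=4)
+ >>> [int(t) for t in state.timesteps]
+ [750, 500, 250, 0]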
+ """
+
+ step_ratio = self.config.num_train_timesteps // num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # rounding to avoid issues when num_inference_step is power of 3
+ timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1]
+
+ return state.replace(
+ num_inference_steps=num_inference_steps,
+ timesteps=timesteps,
+ )
+
+ def _get_variance(self, state: DDPMSchedulerState, t, predicted_variance=None, variance_type=None):
+ alpha_prod_t = state.common.alphas_cumprod[t]
+ alpha_prod_t_prev = jnp.where(t > 0, state.common.alphas_cumprod[t - 1], jnp.array(1.0, dtype=self.dtype))
+
+ # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
+ # and sample from it to get previous sample
+ # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
+ variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * state.common.betas[t]
+
+ if variance_type is None:
+ variance_type = self.config.variance_type
+
+ # hacks - were probably added for training stability
+ if variance_type == "fixed_small":
+ variance = jnp.clip(variance, a_min=1e-20)
+ # for rl-diffuser https://arxiv.org/abs/2205.09991
+ elif variance_type == "fixed_small_log":
+ variance = jnp.log(jnp.clip(variance, a_min=1e-20))
+ elif variance_type == "fixed_large":
+ variance = state.common.betas[t]
+ elif variance_type == "fixed_large_log":
+ # Glide max_log
+ variance = jnp.log(state.common.betas[t])
+ elif variance_type == "learned":
+ return predicted_variance
+ elif variance_type == "learned_range":
+ min_log = variance
+ max_log = state.common.betas[t]
+ frac = (predicted_variance + 1) / 2
+ variance = frac * max_log + (1 - frac) * min_log
+
+ return variance
+
+ def step(
+ self,
+ state: DDPMSchedulerState,
+ model_output: jnp.ndarray,
+ timestep: int,
+ sample: jnp.ndarray,
+ key: Optional[jax.random.KeyArray] = None,
+ return_dict: bool = True,
+ ) -> Union[FlaxDDPMSchedulerOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ state (`DDPMSchedulerState`): the `FlaxDDPMScheduler` state data class instance.
+ model_output (`jnp.ndarray`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`jnp.ndarray`):
+ current instance of sample being created by diffusion process.
+ key (`jax.random.KeyArray`): a PRNG key.
+ return_dict (`bool`): option for returning tuple rather than FlaxDDPMSchedulerOutput class
+
+ Returns:
+ [`FlaxDDPMSchedulerOutput`] or `tuple`: [`FlaxDDPMSchedulerOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is the sample tensor.
+
+ """
+ t = timestep
+
+ if key is None:
+ key = jax.random.PRNGKey(0)
+
+ if model_output.shape[1] == sample.shape[1] * 2 and self.config.variance_type in ["learned", "learned_range"]:
+ model_output, predicted_variance = jnp.split(model_output, [sample.shape[1]], axis=1)  # split at channel index, matching torch.split(..., C, dim=1)
+ else:
+ predicted_variance = None
+
+ # 1. compute alphas, betas
+ alpha_prod_t = state.common.alphas_cumprod[t]
+ alpha_prod_t_prev = jnp.where(t > 0, state.common.alphas_cumprod[t - 1], jnp.array(1.0, dtype=self.dtype))
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+ # 2. compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
+ if self.config.prediction_type == "epsilon":
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+ elif self.config.prediction_type == "sample":
+ pred_original_sample = model_output
+ elif self.config.prediction_type == "v_prediction":
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` "
+ " for the FlaxDDPMScheduler."
+ )
+
+ # 3. Clip "predicted x_0"
+ if self.config.clip_sample:
+ pred_original_sample = jnp.clip(pred_original_sample, -1, 1)
+
+ # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
+ # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+ pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * state.common.betas[t]) / beta_prod_t
+ current_sample_coeff = state.common.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t
+
+ # 5. Compute predicted previous sample µ_t
+ # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+ pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample
+
+ # 6. Add noise
+ def random_variance():
+ split_key = jax.random.split(key, num=1)[0]  # take the single fresh key produced by the split
+ noise = jax.random.normal(split_key, shape=model_output.shape, dtype=self.dtype)
+ return (self._get_variance(state, t, predicted_variance=predicted_variance) ** 0.5) * noise
+
+ variance = jnp.where(t > 0, random_variance(), jnp.zeros(model_output.shape, dtype=self.dtype))
+
+ pred_prev_sample = pred_prev_sample + variance
+
+ if not return_dict:
+ return (pred_prev_sample, state)
+
+ return FlaxDDPMSchedulerOutput(prev_sample=pred_prev_sample, state=state)
+
+ def add_noise(
+ self,
+ state: DDPMSchedulerState,
+ original_samples: jnp.ndarray,
+ noise: jnp.ndarray,
+ timesteps: jnp.ndarray,
+ ) -> jnp.ndarray:
+ return add_noise_common(state.common, original_samples, noise, timesteps)
+
+ def get_velocity(
+ self,
+ state: DDPMSchedulerState,
+ sample: jnp.ndarray,
+ noise: jnp.ndarray,
+ timesteps: jnp.ndarray,
+ ) -> jnp.ndarray:
+ return get_velocity_common(state.common, sample, noise, timesteps)
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/diffusers/schedulers/scheduling_ddpm_parallel.py b/diffusers/schedulers/scheduling_ddpm_parallel.py
new file mode 100644
index 0000000000000000000000000000000000000000..a92e175877d24057e49bf405e88185fd4297e6d2
--- /dev/null
+++ b/diffusers/schedulers/scheduling_ddpm_parallel.py
@@ -0,0 +1,604 @@
+# Copyright 2023 ParaDiGMS authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
+
+import math
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import BaseOutput, randn_tensor
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
+
+
+@dataclass
+# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput
+class DDPMParallelSchedulerOutput(BaseOutput):
+ """
+ Output class for the scheduler's step function output.
+
+ Args:
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ The predicted denoised sample (x_{0}) based on the model output from the current timestep.
+ `pred_original_sample` can be used to preview progress or for guidance.
+ """
+
+ prev_sample: torch.FloatTensor
+ pred_original_sample: Optional[torch.FloatTensor] = None
+
+
+# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
+def betas_for_alpha_bar(
+ num_diffusion_timesteps,
+ max_beta=0.999,
+ alpha_transform_type="cosine",
+):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
+ Choose from `cosine` or `exp`
+
+ Returns:
+ betas (`torch.Tensor`): the betas used by the scheduler to step the model outputs
+ """
+ if alpha_transform_type == "cosine":
+
+ def alpha_bar_fn(t):
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ elif alpha_transform_type == "exp":
+
+ def alpha_bar_fn(t):
+ return math.exp(t * -12.0)
+
+ else:
+ raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
+ return torch.tensor(betas, dtype=torch.float32)
+
+
+class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
+ """
+ Denoising diffusion probabilistic models (DDPMs) explore the connections between denoising score matching and
+ Langevin dynamics sampling.
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
+
+ For more details, see the original paper: https://arxiv.org/abs/2006.11239
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
+ beta_start (`float`): the starting `beta` value of inference.
+ beta_end (`float`): the final `beta` value.
+ beta_schedule (`str`):
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear`, `scaled_linear`, `squaredcos_cap_v2` or `sigmoid`.
+ trained_betas (`np.ndarray`, optional):
+ option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
+ variance_type (`str`):
+ options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
+ `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
+ clip_sample (`bool`, default `True`):
+ option to clip predicted sample for numerical stability.
+ clip_sample_range (`float`, default `1.0`):
+ the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
+ prediction_type (`str`, default `epsilon`, optional):
+ prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
+ process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4
+ https://imagen.research.google/video/paper.pdf)
+ thresholding (`bool`, default `False`):
+ whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
+ Note that the thresholding method is unsuitable for latent-space diffusion models (such as
+ stable-diffusion).
+ dynamic_thresholding_ratio (`float`, default `0.995`):
+ the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
+ (https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`.
+ sample_max_value (`float`, default `1.0`):
+ the threshold value for dynamic thresholding. Valid only when `thresholding=True`.
+ timestep_spacing (`str`, default `"leading"`):
+ The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
+ Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
+ steps_offset (`int`, default `0`):
+ an offset added to the inference steps. You can use a combination of `offset=1` and
+ `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
+ stable diffusion.
+ """
+
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
+ order = 1
+ _is_ode_scheduler = False
+
+ @register_to_config
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.__init__
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+ variance_type: str = "fixed_small",
+ clip_sample: bool = True,
+ prediction_type: str = "epsilon",
+ thresholding: bool = False,
+ dynamic_thresholding_ratio: float = 0.995,
+ clip_sample_range: float = 1.0,
+ sample_max_value: float = 1.0,
+ timestep_spacing: str = "leading",
+ steps_offset: int = 0,
+ ):
+ if trained_betas is not None:
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
+ elif beta_schedule == "linear":
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ )
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ elif beta_schedule == "sigmoid":
+ # GeoDiff sigmoid schedule
+ betas = torch.linspace(-6, 6, num_train_timesteps)
+ self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
+ else:
+ raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+ self.one = torch.tensor(1.0)
+
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = 1.0
+
+ # setable values
+ self.custom_timesteps = False
+ self.num_inference_steps = None
+ self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
+
+ self.variance_type = variance_type
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.scale_model_input
+ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+
+ Args:
+ sample (`torch.FloatTensor`): input sample
+ timestep (`int`, optional): current timestep
+
+ Returns:
+ `torch.FloatTensor`: scaled input sample
+ """
+ return sample
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.set_timesteps
+ def set_timesteps(
+ self,
+ num_inference_steps: Optional[int] = None,
+ device: Union[str, torch.device] = None,
+ timesteps: Optional[List[int]] = None,
+ ):
+ """
+ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`Optional[int]`):
+ the number of diffusion steps used when generating samples with a pre-trained model. If passed, then
+ `timesteps` must be `None`.
+ device (`str` or `torch.device`, optional):
+ the device to which the timesteps should be moved.
+ timesteps (`List[int]`, optional):
+ custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
+ timestep spacing strategy of equal spacing between timesteps is used. If passed, `num_inference_steps`
+ must be `None`.
+
+ """
+ if num_inference_steps is not None and timesteps is not None:
+ raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.")
+
+ if timesteps is not None:
+ for i in range(1, len(timesteps)):
+ if timesteps[i] >= timesteps[i - 1]:
+ raise ValueError("`custom_timesteps` must be in descending order.")
+
+ if timesteps[0] >= self.config.num_train_timesteps:
+ raise ValueError(
+ f"`timesteps` must start before `self.config.train_timesteps`:"
+ f" {self.config.num_train_timesteps}."
+ )
+
+ timesteps = np.array(timesteps, dtype=np.int64)
+ self.custom_timesteps = True
+ else:
+ if num_inference_steps > self.config.num_train_timesteps:
+ raise ValueError(
+ f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
+ f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
+ f" maximal {self.config.num_train_timesteps} timesteps."
+ )
+
+ self.num_inference_steps = num_inference_steps
+ self.custom_timesteps = False
+
+ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+ if self.config.timestep_spacing == "linspace":
+ timesteps = (
+ np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
+ .round()[::-1]
+ .copy()
+ .astype(np.int64)
+ )
+ elif self.config.timestep_spacing == "leading":
+ step_ratio = self.config.num_train_timesteps // self.num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
+ timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
+ timesteps += self.config.steps_offset
+ elif self.config.timestep_spacing == "trailing":
+ step_ratio = self.config.num_train_timesteps / self.num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
+ timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
+ timesteps -= 1
+ else:
+ raise ValueError(
+ f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
+ )
+
+ self.timesteps = torch.from_numpy(timesteps).to(device)
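+
+ # Illustrative note (not part of the scheduler, values assumed): with num_train_timesteps=1000 and
+ # num_inference_steps=50, "leading" gives step_ratio = 1000 // 50 = 20 and timesteps
+ # [980, 960, ..., 20, 0] (plus steps_offset), while "trailing" gives [999, 979, ..., 39, 19],
+ # i.e. "trailing" starts the chain at the final training step t=999 as recommended by
+ # https://arxiv.org/abs/2305.08891.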
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._get_variance
+ def _get_variance(self, t, predicted_variance=None, variance_type=None):
+ prev_t = self.previous_timestep(t)
+
+ alpha_prod_t = self.alphas_cumprod[t]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
+ current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev
+
+ # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
+ # and sample from it to get previous sample
+ # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
+ variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t
+
+ # we always take the log of variance, so clamp it to ensure it's not 0
+ variance = torch.clamp(variance, min=1e-20)
+
+ if variance_type is None:
+ variance_type = self.config.variance_type
+
+ # hacks - were probably added for training stability
+ if variance_type == "fixed_small":
+ variance = variance
+ # for rl-diffuser https://arxiv.org/abs/2205.09991
+ elif variance_type == "fixed_small_log":
+ variance = torch.log(variance)
+ variance = torch.exp(0.5 * variance)
+ elif variance_type == "fixed_large":
+ variance = current_beta_t
+ elif variance_type == "fixed_large_log":
+ # Glide max_log
+ variance = torch.log(current_beta_t)
+ elif variance_type == "learned":
+ return predicted_variance
+ elif variance_type == "learned_range":
+ min_log = torch.log(variance)
+ max_log = torch.log(current_beta_t)
+ frac = (predicted_variance + 1) / 2
+ variance = frac * max_log + (1 - frac) * min_log
+
+ return variance
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
+ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
+ """
+ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
+ prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
+ s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
+ pixels from saturation at each step. We find that dynamic thresholding results in significantly better
+ photorealism as well as better image-text alignment, especially when using very large guidance weights."
+
+ https://arxiv.org/abs/2205.11487
+ """
+ dtype = sample.dtype
+ batch_size, channels, height, width = sample.shape
+
+ if dtype not in (torch.float32, torch.float64):
+ sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
+
+ # Flatten sample for doing quantile calculation along each image
+ sample = sample.reshape(batch_size, channels * height * width)
+
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
+
+ s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
+ s = torch.clamp(
+ s, min=1, max=self.config.sample_max_value
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
+
+ s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
+ sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
+
+ sample = sample.reshape(batch_size, channels, height, width)
+ sample = sample.to(dtype)
+
+ return sample
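+
+ # Worked example (illustrative, values assumed): with dynamic_thresholding_ratio=0.995 and
+ # sample_max_value=1.5, an image whose 99.5th-percentile absolute value is 2.3 gets
+ # s = clamp(2.3, min=1, max=1.5) = 1.5, so x0 is clamped to [-1.5, 1.5] and then divided by 1.5.
+ # With the default sample_max_value=1.0, s is always clamped to exactly 1 and dynamic thresholding
+ # reduces to plain clipping to [-1, 1].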
+
+ def step(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: int,
+ sample: torch.FloatTensor,
+ generator=None,
+ return_dict: bool = True,
+ ) -> Union[DDPMParallelSchedulerOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+ generator: random number generator.
+ return_dict (`bool`): option for returning tuple rather than DDPMParallelSchedulerOutput class
+
+ Returns:
+ [`~schedulers.scheduling_utils.DDPMParallelSchedulerOutput`] or `tuple`:
+ [`~schedulers.scheduling_utils.DDPMParallelSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is the sample tensor.
+
+ """
+ t = timestep
+
+ prev_t = self.previous_timestep(t)
+
+ if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
+ model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
+ else:
+ predicted_variance = None
+
+ # 1. compute alphas, betas
+ alpha_prod_t = self.alphas_cumprod[t]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+ current_alpha_t = alpha_prod_t / alpha_prod_t_prev
+ current_beta_t = 1 - current_alpha_t
+
+ # 2. compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
+ if self.config.prediction_type == "epsilon":
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+ elif self.config.prediction_type == "sample":
+ pred_original_sample = model_output
+ elif self.config.prediction_type == "v_prediction":
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
+ " `v_prediction` for the DDPMScheduler."
+ )
+
+ # 3. Clip or threshold "predicted x_0"
+ if self.config.thresholding:
+ pred_original_sample = self._threshold_sample(pred_original_sample)
+ elif self.config.clip_sample:
+ pred_original_sample = pred_original_sample.clamp(
+ -self.config.clip_sample_range, self.config.clip_sample_range
+ )
+
+ # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
+ # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+ pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
+ current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t
+
+ # 5. Compute predicted previous sample µ_t
+ # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+ pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample
+
+ # 6. Add noise
+ variance = 0
+ if t > 0:
+ device = model_output.device
+ variance_noise = randn_tensor(
+ model_output.shape, generator=generator, device=device, dtype=model_output.dtype
+ )
+ if self.variance_type == "fixed_small_log":
+ variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise
+ elif self.variance_type == "learned_range":
+ variance = self._get_variance(t, predicted_variance=predicted_variance)
+ variance = torch.exp(0.5 * variance) * variance_noise
+ else:
+ variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * variance_noise
+
+ pred_prev_sample = pred_prev_sample + variance
+
+ if not return_dict:
+ return (pred_prev_sample,)
+
+ return DDPMParallelSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)
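+
+ # Minimal usage sketch (comments only; `unet` is a stand-in for any epsilon-predicting model and is
+ # not defined in this file):
+ #
+ #   scheduler.set_timesteps(50, device="cuda")
+ #   sample = torch.randn(shape, device="cuda") * scheduler.init_noise_sigma
+ #   for t in scheduler.timesteps:
+ #       model_output = unet(sample, t).sample
+ #       sample = scheduler.step(model_output, t, sample).prev_sample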
+
+ def batch_step_no_noise(
+ self,
+ model_output: torch.FloatTensor,
+ timesteps: List[int],
+ sample: torch.FloatTensor,
+ ) -> torch.FloatTensor:
+ """
+ Batched version of the `step` function, to be able to reverse the SDE for multiple samples/timesteps at once.
+ Also, does not add any noise to the predicted sample, which is necessary for parallel sampling where the noise
+ is pre-sampled by the pipeline.
+
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+ timesteps (`List[int]`):
+ current discrete timesteps in the diffusion chain. This is now a list of integers.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+
+ Returns:
+ `torch.FloatTensor`: sample tensor at previous timestep.
+ """
+ t = timesteps
+ num_inference_steps = self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
+ prev_t = t - self.config.num_train_timesteps // num_inference_steps
+
+ t = t.view(-1, *([1] * (model_output.ndim - 1)))
+ prev_t = prev_t.view(-1, *([1] * (model_output.ndim - 1)))
+
+ if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
+ model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
+ else:
+ pass
+
+ # 1. compute alphas, betas
+ self.alphas_cumprod = self.alphas_cumprod.to(model_output.device)
+ alpha_prod_t = self.alphas_cumprod[t]
+ alpha_prod_t_prev = self.alphas_cumprod[torch.clip(prev_t, min=0)]
+ alpha_prod_t_prev[prev_t < 0] = torch.tensor(1.0)
+
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+ current_alpha_t = alpha_prod_t / alpha_prod_t_prev
+ current_beta_t = 1 - current_alpha_t
+
+ # 2. compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
+ if self.config.prediction_type == "epsilon":
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+ elif self.config.prediction_type == "sample":
+ pred_original_sample = model_output
+ elif self.config.prediction_type == "v_prediction":
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
+ " `v_prediction` for the DDPMParallelScheduler."
+ )
+
+ # 3. Clip or threshold "predicted x_0"
+ if self.config.thresholding:
+ pred_original_sample = self._threshold_sample(pred_original_sample)
+ elif self.config.clip_sample:
+ pred_original_sample = pred_original_sample.clamp(
+ -self.config.clip_sample_range, self.config.clip_sample_range
+ )
+
+ # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
+ # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+ pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
+ current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t
+
+ # 5. Compute predicted previous sample µ_t
+ # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+ pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample
+
+ return pred_prev_sample
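+
+ # Note (illustrative): unlike `step`, this method consumes a whole batch of timesteps at once, e.g. a
+ # 1-D tensor such as torch.tensor([801, 781, 761, 741]) with matching leading dimensions in
+ # `model_output` and `sample`, and returns only the deterministic part of the update; the parallel
+ # sampling pipeline is expected to add its pre-sampled noise afterwards. Despite the `List[int]`
+ # annotation, the `.view(...)` calls above mean a `torch.Tensor` of timesteps is what actually works.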
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
+ def add_noise(
+ self,
+ original_samples: torch.FloatTensor,
+ noise: torch.FloatTensor,
+ timesteps: torch.IntTensor,
+ ) -> torch.FloatTensor:
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
+ alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+ timesteps = timesteps.to(original_samples.device)
+
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+ return noisy_samples
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
+ def get_velocity(
+ self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
+ ) -> torch.FloatTensor:
+ # Make sure alphas_cumprod and timestep have same device and dtype as sample
+ alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
+ timesteps = timesteps.to(sample.device)
+
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(sample.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+ velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
+ return velocity
+
+ def __len__(self):
+ return self.config.num_train_timesteps
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
+ def previous_timestep(self, timestep):
+ if self.custom_timesteps:
+ index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
+ if index == self.timesteps.shape[0] - 1:
+ prev_t = torch.tensor(-1)
+ else:
+ prev_t = self.timesteps[index + 1]
+ else:
+ num_inference_steps = (
+ self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
+ )
+ prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
+
+ return prev_t
diff --git a/diffusers/schedulers/scheduling_deis_multistep.py b/diffusers/schedulers/scheduling_deis_multistep.py
new file mode 100644
index 0000000000000000000000000000000000000000..36947294922b6cc0ecdc5bf7dc9c0772a056d03a
--- /dev/null
+++ b/diffusers/schedulers/scheduling_deis_multistep.py
@@ -0,0 +1,568 @@
+# Copyright 2023 FLAIR Lab and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: check https://arxiv.org/abs/2204.13902 and https://github.com/qsh-zh/deis for more info
+# The codebase is modified based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
+
+import math
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
+
+
+# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
+def betas_for_alpha_bar(
+ num_diffusion_timesteps,
+ max_beta=0.999,
+ alpha_transform_type="cosine",
+):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ alpha_transform_type (`str`, *optional*, defaults to `cosine`): the type of noise schedule for alpha_bar.
+ Choose from `cosine` or `exp`
+
+ Returns:
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+ if alpha_transform_type == "cosine":
+
+ def alpha_bar_fn(t):
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ elif alpha_transform_type == "exp":
+
+ def alpha_bar_fn(t):
+ return math.exp(t * -12.0)
+
+ else:
+ raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
+ return torch.tensor(betas, dtype=torch.float32)
+
+
+class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
+ """
+ DEIS (https://arxiv.org/abs/2204.13902) is a fast high-order solver for diffusion ODEs. We slightly modify the
+ polynomial fitting formula in log-rho space instead of the original linear t space in the DEIS paper. The
+ modification enjoys closed-form coefficients for the exponential multistep update instead of relying on a
+ numerical solver. More variants of DEIS can be found in https://github.com/qsh-zh/deis.
+
+ Currently, we support the log-rho multistep DEIS. We recommend using `solver_order=2` or `solver_order=3`;
+ `solver_order=1` reduces to DDIM.
+
+ We also support the "dynamic thresholding" method in Imagen (https://arxiv.org/abs/2205.11487). For pixel-space
+ diffusion models, you can set `thresholding=True` to use the dynamic thresholding.
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
+ beta_start (`float`): the starting `beta` value of inference.
+ beta_end (`float`): the final `beta` value.
+ beta_schedule (`str`):
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+ trained_betas (`np.ndarray`, optional):
+ option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
+ solver_order (`int`, default `2`):
+ the order of DEIS; can be `1`, `2`, or `3`. We recommend using `solver_order=2` for guided sampling, and
+ `solver_order=3` for unconditional sampling.
+ prediction_type (`str`, default `epsilon`):
+ indicates whether the model predicts the noise (epsilon), or the data / `x0`. One of `epsilon`, `sample`,
+ or `v_prediction`.
+ thresholding (`bool`, default `False`):
+ whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
+ Note that the thresholding method is unsuitable for latent-space diffusion models (such as
+ stable-diffusion).
+ dynamic_thresholding_ratio (`float`, default `0.995`):
+ the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
+ (https://arxiv.org/abs/2205.11487).
+ sample_max_value (`float`, default `1.0`):
+ the threshold value for dynamic thresholding. Valid only when `thresholding=True`
+ algorithm_type (`str`, default `deis`):
+ the algorithm type for the solver. Currently only the multistep DEIS is supported; other variants of DEIS
+ may be added in the future.
+ lower_order_final (`bool`, default `True`):
+ whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically
+ find this trick can stabilize the sampling of DEIS for steps < 15, especially for steps <= 10.
+ use_karras_sigmas (`bool`, *optional*, defaults to `False`):
+ This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the
+ noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence
+ of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf.
+ timestep_spacing (`str`, default `"linspace"`):
+ The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
+ Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
+ steps_offset (`int`, default `0`):
+ an offset added to the inference steps. You can use a combination of `offset=1` and
+ `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
+ stable diffusion.
+ """
+
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[np.ndarray] = None,
+ solver_order: int = 2,
+ prediction_type: str = "epsilon",
+ thresholding: bool = False,
+ dynamic_thresholding_ratio: float = 0.995,
+ sample_max_value: float = 1.0,
+ algorithm_type: str = "deis",
+ solver_type: str = "logrho",
+ lower_order_final: bool = True,
+ use_karras_sigmas: Optional[bool] = False,
+ timestep_spacing: str = "linspace",
+ steps_offset: int = 0,
+ ):
+ if trained_betas is not None:
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
+ elif beta_schedule == "linear":
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ )
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ else:
+ raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+ # Currently we only support VP-type noise schedule
+ self.alpha_t = torch.sqrt(self.alphas_cumprod)
+ self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
+ self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
+
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = 1.0
+
+ # settings for DEIS
+ if algorithm_type not in ["deis"]:
+ if algorithm_type in ["dpmsolver", "dpmsolver++"]:
+ self.register_to_config(algorithm_type="deis")
+ else:
+ raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}")
+
+ if solver_type not in ["logrho"]:
+ if solver_type in ["midpoint", "heun", "bh1", "bh2"]:
+ self.register_to_config(solver_type="logrho")
+ else:
+ raise NotImplementedError(f"solver type {solver_type} does is not implemented for {self.__class__}")
+
+ # setable values
+ self.num_inference_steps = None
+ timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
+ self.timesteps = torch.from_numpy(timesteps)
+ self.model_outputs = [None] * solver_order
+ self.lower_order_nums = 0
+
+ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+ """
+ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ device (`str` or `torch.device`, optional):
+ the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ """
+ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+ if self.config.timestep_spacing == "linspace":
+ timesteps = (
+ np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
+ .round()[::-1][:-1]
+ .copy()
+ .astype(np.int64)
+ )
+ elif self.config.timestep_spacing == "leading":
+ step_ratio = self.config.num_train_timesteps // (num_inference_steps + 1)
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
+ timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
+ timesteps += self.config.steps_offset
+ elif self.config.timestep_spacing == "trailing":
+ step_ratio = self.config.num_train_timesteps / num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
+ timesteps = np.arange(self.config.num_train_timesteps, 0, -step_ratio).round().copy().astype(np.int64)
+ timesteps -= 1
+ else:
+ raise ValueError(
+ f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
+ )
+
+ sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+ if self.config.use_karras_sigmas:
+ log_sigmas = np.log(sigmas)
+ sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+ timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
+ timesteps = np.flip(timesteps).copy().astype(np.int64)
+
+ self.sigmas = torch.from_numpy(sigmas)
+
+ # when num_inference_steps == num_train_timesteps, we can end up with
+ # duplicates in timesteps.
+ _, unique_indices = np.unique(timesteps, return_index=True)
+ timesteps = timesteps[np.sort(unique_indices)]
+
+ self.timesteps = torch.from_numpy(timesteps).to(device)
+
+ self.num_inference_steps = len(timesteps)
+
+ self.model_outputs = [
+ None,
+ ] * self.config.solver_order
+ self.lower_order_nums = 0
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
+ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
+ """
+ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
+ prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
+ s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
+ pixels from saturation at each step. We find that dynamic thresholding results in significantly better
+ photorealism as well as better image-text alignment, especially when using very large guidance weights."
+
+ https://arxiv.org/abs/2205.11487
+ """
+ dtype = sample.dtype
+ batch_size, channels, height, width = sample.shape
+
+ if dtype not in (torch.float32, torch.float64):
+ sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
+
+ # Flatten sample for doing quantile calculation along each image
+ sample = sample.reshape(batch_size, channels * height * width)
+
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
+
+ s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
+ s = torch.clamp(
+ s, min=1, max=self.config.sample_max_value
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
+
+ s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
+ sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
+
+ sample = sample.reshape(batch_size, channels, height, width)
+ sample = sample.to(dtype)
+
+ return sample
+
+ def convert_model_output(
+ self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
+ ) -> torch.FloatTensor:
+ """
+ Convert the model output to the corresponding type that the algorithm DEIS needs.
+
+ Args:
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+
+ Returns:
+ `torch.FloatTensor`: the converted model output.
+ """
+ if self.config.prediction_type == "epsilon":
+ alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
+ x0_pred = (sample - sigma_t * model_output) / alpha_t
+ elif self.config.prediction_type == "sample":
+ x0_pred = model_output
+ elif self.config.prediction_type == "v_prediction":
+ alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
+ x0_pred = alpha_t * sample - sigma_t * model_output
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
+ " `v_prediction` for the DEISMultistepScheduler."
+ )
+
+ if self.config.thresholding:
+ x0_pred = self._threshold_sample(x0_pred)
+
+ if self.config.algorithm_type == "deis":
+ alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
+ return (sample - alpha_t * x0_pred) / sigma_t
+ else:
+ raise NotImplementedError("only support log-rho multistep deis now")
+
+ def deis_first_order_update(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: int,
+ prev_timestep: int,
+ sample: torch.FloatTensor,
+ ) -> torch.FloatTensor:
+ """
+ One step for the first-order DEIS (equivalent to DDIM).
+
+ Args:
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ prev_timestep (`int`): previous discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+
+ Returns:
+ `torch.FloatTensor`: the sample tensor at the previous timestep.
+ """
+ lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep]
+ alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep]
+ sigma_t, _ = self.sigma_t[prev_timestep], self.sigma_t[timestep]
+ h = lambda_t - lambda_s
+ if self.config.algorithm_type == "deis":
+ x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
+ else:
+ raise NotImplementedError("only support log-rho multistep deis now")
+ return x_t
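+
+ # Sketch (assuming an epsilon-type model_output from `convert_model_output`): with h = lambda_t - lambda_s
+ # and lambda = log(alpha/sigma), we have sigma_t * (exp(h) - 1) = alpha_t * sigma_s / alpha_s - sigma_t,
+ # so x_t = (alpha_t / alpha_s) * x_s - sigma_t * (exp(h) - 1) * eps = alpha_t * x0_pred + sigma_t * eps,
+ # i.e. exactly the (eta=0) DDIM update mentioned in the docstring.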
+
+ def multistep_deis_second_order_update(
+ self,
+ model_output_list: List[torch.FloatTensor],
+ timestep_list: List[int],
+ prev_timestep: int,
+ sample: torch.FloatTensor,
+ ) -> torch.FloatTensor:
+ """
+ One step for the second-order multistep DEIS.
+
+ Args:
+ model_output_list (`List[torch.FloatTensor]`):
+ direct outputs from learned diffusion model at current and latter timesteps.
+ timestep (`int`): current and latter discrete timestep in the diffusion chain.
+ prev_timestep (`int`): previous discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+
+ Returns:
+ `torch.FloatTensor`: the sample tensor at the previous timestep.
+ """
+ t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
+ m0, m1 = model_output_list[-1], model_output_list[-2]
+ alpha_t, alpha_s0, alpha_s1 = self.alpha_t[t], self.alpha_t[s0], self.alpha_t[s1]
+ sigma_t, sigma_s0, sigma_s1 = self.sigma_t[t], self.sigma_t[s0], self.sigma_t[s1]
+
+ rho_t, rho_s0, rho_s1 = sigma_t / alpha_t, sigma_s0 / alpha_s0, sigma_s1 / alpha_s1
+
+ if self.config.algorithm_type == "deis":
+
+ def ind_fn(t, b, c):
+ # Integrate[(log(t) - log(c)) / (log(b) - log(c)), {t}]
+ return t * (-np.log(c) + np.log(t) - 1) / (np.log(b) - np.log(c))
+
+ coef1 = ind_fn(rho_t, rho_s0, rho_s1) - ind_fn(rho_s0, rho_s0, rho_s1)
+ coef2 = ind_fn(rho_t, rho_s1, rho_s0) - ind_fn(rho_s0, rho_s1, rho_s0)
+
+ x_t = alpha_t * (sample / alpha_s0 + coef1 * m0 + coef2 * m1)
+ return x_t
+ else:
+ raise NotImplementedError("only support log-rho multistep deis now")
+
+ def multistep_deis_third_order_update(
+ self,
+ model_output_list: List[torch.FloatTensor],
+ timestep_list: List[int],
+ prev_timestep: int,
+ sample: torch.FloatTensor,
+ ) -> torch.FloatTensor:
+ """
+ One step for the third-order multistep DEIS.
+
+ Args:
+ model_output_list (`List[torch.FloatTensor]`):
+ direct outputs from learned diffusion model at current and latter timesteps.
+ timestep (`int`): current and latter discrete timestep in the diffusion chain.
+ prev_timestep (`int`): previous discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+
+ Returns:
+ `torch.FloatTensor`: the sample tensor at the previous timestep.
+ """
+ t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3]
+ m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
+ alpha_t, alpha_s0, alpha_s1, alpha_s2 = self.alpha_t[t], self.alpha_t[s0], self.alpha_t[s1], self.alpha_t[s2]
+ sigma_t, sigma_s0, sigma_s1, sigma_s2 = self.sigma_t[t], self.sigma_t[s0], self.sigma_t[s1], self.sigma_t[s2]
+ rho_t, rho_s0, rho_s1, rho_s2 = (
+ sigma_t / alpha_t,
+ sigma_s0 / alpha_s0,
+ sigma_s1 / alpha_s1,
+ sigma_s2 / alpha_s2,
+ )
+
+ if self.config.algorithm_type == "deis":
+
+ def ind_fn(t, b, c, d):
+ # Integrate[(log(t) - log(c))(log(t) - log(d)) / (log(b) - log(c))(log(b) - log(d)), {t}]
+ numerator = t * (
+ np.log(c) * (np.log(d) - np.log(t) + 1)
+ - np.log(d) * np.log(t)
+ + np.log(d)
+ + np.log(t) ** 2
+ - 2 * np.log(t)
+ + 2
+ )
+ denominator = (np.log(b) - np.log(c)) * (np.log(b) - np.log(d))
+ return numerator / denominator
+
+ coef1 = ind_fn(rho_t, rho_s0, rho_s1, rho_s2) - ind_fn(rho_s0, rho_s0, rho_s1, rho_s2)
+ coef2 = ind_fn(rho_t, rho_s1, rho_s2, rho_s0) - ind_fn(rho_s0, rho_s1, rho_s2, rho_s0)
+ coef3 = ind_fn(rho_t, rho_s2, rho_s0, rho_s1) - ind_fn(rho_s0, rho_s2, rho_s0, rho_s1)
+
+ x_t = alpha_t * (sample / alpha_s0 + coef1 * m0 + coef2 * m1 + coef3 * m2)
+
+ return x_t
+ else:
+ raise NotImplementedError("only support log-rho multistep deis now")
+
+ def step(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: int,
+ sample: torch.FloatTensor,
+ return_dict: bool = True,
+ ) -> Union[SchedulerOutput, Tuple]:
+ """
+ Step function propagating the sample with the multistep DEIS.
+
+ Args:
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+ return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+ Returns:
+ [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is
+ True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
+
+ """
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ if isinstance(timestep, torch.Tensor):
+ timestep = timestep.to(self.timesteps.device)
+ step_index = (self.timesteps == timestep).nonzero()
+ if len(step_index) == 0:
+ step_index = len(self.timesteps) - 1
+ else:
+ step_index = step_index.item()
+ prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]
+ lower_order_final = (
+ (step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15
+ )
+ lower_order_second = (
+ (step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
+ )
+
+ model_output = self.convert_model_output(model_output, timestep, sample)
+ for i in range(self.config.solver_order - 1):
+ self.model_outputs[i] = self.model_outputs[i + 1]
+ self.model_outputs[-1] = model_output
+
+ if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
+ prev_sample = self.deis_first_order_update(model_output, timestep, prev_timestep, sample)
+ elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
+ timestep_list = [self.timesteps[step_index - 1], timestep]
+ prev_sample = self.multistep_deis_second_order_update(
+ self.model_outputs, timestep_list, prev_timestep, sample
+ )
+ else:
+ timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep]
+ prev_sample = self.multistep_deis_third_order_update(
+ self.model_outputs, timestep_list, prev_timestep, sample
+ )
+
+ if self.lower_order_nums < self.config.solver_order:
+ self.lower_order_nums += 1
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return SchedulerOutput(prev_sample=prev_sample)
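+
+ # Note (illustrative): `lower_order_nums` starts at 0 and is incremented after every call, so with
+ # solver_order=3 the first call uses the first-order update, the second call the second-order update,
+ # and the full third-order update is only used from the third call onward; with `lower_order_final`
+ # and fewer than 15 inference steps the final step(s) fall back to lower orders again.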
+
+ def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+
+ Args:
+ sample (`torch.FloatTensor`): input sample
+
+ Returns:
+ `torch.FloatTensor`: scaled input sample
+ """
+ return sample
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
+ def add_noise(
+ self,
+ original_samples: torch.FloatTensor,
+ noise: torch.FloatTensor,
+ timesteps: torch.IntTensor,
+ ) -> torch.FloatTensor:
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
+ alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+ timesteps = timesteps.to(original_samples.device)
+
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+ return noisy_samples
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/diffusers/schedulers/scheduling_dpmsolver_multistep.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7516fa601e17cdd5661039c181804d687a66f0e
--- /dev/null
+++ b/diffusers/schedulers/scheduling_dpmsolver_multistep.py
@@ -0,0 +1,749 @@
+# Copyright 2023 TSAIL Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver
+
+import math
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import randn_tensor
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
+
+
+# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
+def betas_for_alpha_bar(
+ num_diffusion_timesteps,
+ max_beta=0.999,
+ alpha_transform_type="cosine",
+):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ alpha_transform_type (`str`, *optional*, defaults to `cosine`): the type of noise schedule for alpha_bar.
+ Choose from `cosine` or `exp`
+
+ Returns:
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+ if alpha_transform_type == "cosine":
+
+ def alpha_bar_fn(t):
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ elif alpha_transform_type == "exp":
+
+ def alpha_bar_fn(t):
+ return math.exp(t * -12.0)
+
+ else:
+ raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
+ return torch.tensor(betas, dtype=torch.float32)
+
+
+class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
+ """
+ DPM-Solver (and the improved version DPM-Solver++) is a fast dedicated high-order solver for diffusion ODEs with
+ the convergence order guarantee. Empirically, sampling by DPM-Solver with only 20 steps can generate high-quality
+ samples, and it can generate quite good samples even in only 10 steps.
+
+ For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095
+
+ Currently, we support the multistep DPM-Solver for both noise prediction models and data prediction models. We
+ recommend using `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling.
+
+ We also support the "dynamic thresholding" method in Imagen (https://arxiv.org/abs/2205.11487). For pixel-space
+ diffusion models, you can set both `algorithm_type="dpmsolver++"` and `thresholding=True` to use the dynamic
+ thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as
+ stable-diffusion).
+
+ We also support the SDE variant of DPM-Solver and DPM-Solver++, which is a fast SDE solver for the reverse
+ diffusion SDE. Currently we only support the first-order and second-order solvers. We recommend using the
+ second-order `sde-dpmsolver++`.
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
+ beta_start (`float`): the starting `beta` value of inference.
+ beta_end (`float`): the final `beta` value.
+ beta_schedule (`str`):
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+ trained_betas (`np.ndarray`, optional):
+ option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
+ solver_order (`int`, default `2`):
+ the order of DPM-Solver; can be `1`, `2`, or `3`. We recommend using `solver_order=2` for guided
+ sampling, and `solver_order=3` for unconditional sampling.
+ prediction_type (`str`, default `epsilon`, optional):
+ prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
+ process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4
+ https://imagen.research.google/video/paper.pdf)
+ thresholding (`bool`, default `False`):
+ whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
+ For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to
+ use the dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion
+ models (such as stable-diffusion).
+ dynamic_thresholding_ratio (`float`, default `0.995`):
+ the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
+ (https://arxiv.org/abs/2205.11487).
+ sample_max_value (`float`, default `1.0`):
+ the threshold value for dynamic thresholding. Valid only when `thresholding=True` and
+ `algorithm_type="dpmsolver++`.
+ algorithm_type (`str`, default `dpmsolver++`):
+ the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++` or `sde-dpmsolver` or
+ `sde-dpmsolver++`. The `dpmsolver` type implements the algorithms in https://arxiv.org/abs/2206.00927, and
+ the `dpmsolver++` type implements the algorithms in https://arxiv.org/abs/2211.01095. We recommend using
+ `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling (e.g. stable-diffusion).
+ solver_type (`str`, default `midpoint`):
+ the solver type for the second-order solver. Either `midpoint` or `heun`. The solver type slightly affects
+ the sample quality, especially for a small number of steps. We empirically find that `midpoint` solvers are
+ slightly better, so we recommend using the `midpoint` type.
+ lower_order_final (`bool`, default `True`):
+ whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically
+ find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10.
+ use_karras_sigmas (`bool`, *optional*, defaults to `False`):
+ This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the
+ noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence
+ of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf.
+ lambda_min_clipped (`float`, default `-inf`):
+ the clipping threshold for the minimum value of lambda(t) for numerical stability. This is critical for
+ cosine (squaredcos_cap_v2) noise schedule.
+ variance_type (`str`, *optional*):
+ Set to "learned" or "learned_range" for diffusion models that predict variance. For example, OpenAI's
+ guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the
+ Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based on
+ diffusion ODEs. whether the model's output contains the predicted Gaussian variance. For example, OpenAI's
+ guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the
+ Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based on
+ diffusion ODEs.
+ timestep_spacing (`str`, default `"linspace"`):
+ The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
+ Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
+ steps_offset (`int`, default `0`):
+ an offset added to the inference steps. You can use a combination of `offset=1` and
+ `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
+ stable diffusion.
+ """
+
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+ solver_order: int = 2,
+ prediction_type: str = "epsilon",
+ thresholding: bool = False,
+ dynamic_thresholding_ratio: float = 0.995,
+ sample_max_value: float = 1.0,
+ algorithm_type: str = "dpmsolver++",
+ solver_type: str = "midpoint",
+ lower_order_final: bool = True,
+ use_karras_sigmas: Optional[bool] = False,
+ lambda_min_clipped: float = -float("inf"),
+ variance_type: Optional[str] = None,
+ timestep_spacing: str = "linspace",
+ steps_offset: int = 0,
+ ):
+ if trained_betas is not None:
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
+ elif beta_schedule == "linear":
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ )
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ else:
+ raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+ # Currently we only support VP-type noise schedule
+ self.alpha_t = torch.sqrt(self.alphas_cumprod)
+ self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
+ self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
+
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = 1.0
+
+ # settings for DPM-Solver
+ if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]:
+ if algorithm_type == "deis":
+ self.register_to_config(algorithm_type="dpmsolver++")
+ else:
+ raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}")
+
+ if solver_type not in ["midpoint", "heun"]:
+ if solver_type in ["logrho", "bh1", "bh2"]:
+ self.register_to_config(solver_type="midpoint")
+ else:
+ raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}")
+
+ # setable values
+ self.num_inference_steps = None
+ timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
+ self.timesteps = torch.from_numpy(timesteps)
+ self.model_outputs = [None] * solver_order
+ self.lower_order_nums = 0
+
+ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
+ """
+ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ device (`str` or `torch.device`, optional):
+ the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ """
+ # Clipping the minimum of all lambda(t) for numerical stability.
+ # This is critical for cosine (squaredcos_cap_v2) noise schedule.
+ clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped)
+ last_timestep = ((self.config.num_train_timesteps - clipped_idx).numpy()).item()
+
+ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+ if self.config.timestep_spacing == "linspace":
+ timesteps = (
+ np.linspace(0, last_timestep - 1, num_inference_steps + 1).round()[::-1][:-1].copy().astype(np.int64)
+ )
+ elif self.config.timestep_spacing == "leading":
+ step_ratio = last_timestep // (num_inference_steps + 1)
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
+ timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
+ timesteps += self.config.steps_offset
+ elif self.config.timestep_spacing == "trailing":
+ step_ratio = self.config.num_train_timesteps / num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
+ timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64)
+ timesteps -= 1
+ else:
+ raise ValueError(
+ f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
+ )
+
+ sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+ if self.config.use_karras_sigmas:
+ log_sigmas = np.log(sigmas)
+ sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+ timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
+ timesteps = np.flip(timesteps).copy().astype(np.int64)
+
+ self.sigmas = torch.from_numpy(sigmas)
+
+ # when num_inference_steps == num_train_timesteps, we can end up with
+ # duplicates in timesteps.
+ _, unique_indices = np.unique(timesteps, return_index=True)
+ timesteps = timesteps[np.sort(unique_indices)]
+
+ self.timesteps = torch.from_numpy(timesteps).to(device)
+
+ self.num_inference_steps = len(timesteps)
+
+ self.model_outputs = [
+ None,
+ ] * self.config.solver_order
+ self.lower_order_nums = 0
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
+ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
+ """
+ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
+ prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
+ s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
+ pixels from saturation at each step. We find that dynamic thresholding results in significantly better
+ photorealism as well as better image-text alignment, especially when using very large guidance weights."
+
+ https://arxiv.org/abs/2205.11487
+ """
+ dtype = sample.dtype
+ batch_size, channels, height, width = sample.shape
+
+ if dtype not in (torch.float32, torch.float64):
+ sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
+
+ # Flatten sample for doing quantile calculation along each image
+ sample = sample.reshape(batch_size, channels * height * width)
+
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
+
+ s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
+ s = torch.clamp(
+ s, min=1, max=self.config.sample_max_value
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
+
+ s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
+ sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
+
+ sample = sample.reshape(batch_size, channels, height, width)
+ sample = sample.to(dtype)
+
+ return sample
+
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
+ def _sigma_to_t(self, sigma, log_sigmas):
+ # get log sigma
+ log_sigma = np.log(sigma)
+
+ # get distribution
+ dists = log_sigma - log_sigmas[:, np.newaxis]
+
+ # get sigmas range
+ low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
+ high_idx = low_idx + 1
+
+ low = log_sigmas[low_idx]
+ high = log_sigmas[high_idx]
+
+ # interpolate sigmas
+ w = (low - log_sigma) / (low - high)
+ w = np.clip(w, 0, 1)
+
+ # transform interpolation to time range
+ t = (1 - w) * low_idx + w * high_idx
+ t = t.reshape(sigma.shape)
+ return t
+
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
+ def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
+ """Constructs the noise schedule of Karras et al. (2022)."""
+
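+ # Karras et al. (2022), eq. (5): noise levels are spaced uniformly in sigma**(1/rho),
+ #     sigma_i = (sigma_max**(1/rho) + i / (n - 1) * (sigma_min**(1/rho) - sigma_max**(1/rho)))**rho,
+ # running from sigma_max at i=0 down to sigma_min at i=n-1.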
+ sigma_min: float = in_sigmas[-1].item()
+ sigma_max: float = in_sigmas[0].item()
+
+ rho = 7.0 # 7.0 is the value used in the paper
+ ramp = np.linspace(0, 1, num_inference_steps)
+ min_inv_rho = sigma_min ** (1 / rho)
+ max_inv_rho = sigma_max ** (1 / rho)
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+ return sigmas
+
+ def convert_model_output(
+ self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
+ ) -> torch.FloatTensor:
+ """
+ Convert the model output to the corresponding type that the algorithm (DPM-Solver / DPM-Solver++) needs.
+
+ DPM-Solver is designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to
+ discretize an integral of the data prediction model. So we need to first convert the model output to the
+ corresponding type to match the algorithm.
+
+ Note that the algorithm type and the model type are decoupled. That is, we can use either DPM-Solver or
+ DPM-Solver++ for both noise prediction and data prediction models.
+
+ Args:
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+
+ Returns:
+ `torch.FloatTensor`: the converted model output.
+ """
+
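+ # With the VP-type parameterization x_t = alpha_t * x0 + sigma_t * eps, the predictions are
+ # interchangeable: x0 = (x_t - sigma_t * eps) / alpha_t and eps = (x_t - alpha_t * x0) / sigma_t,
+ # so any `prediction_type` can be converted to whichever quantity the chosen algorithm integrates.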
+ # DPM-Solver++ needs to solve an integral of the data prediction model.
+ if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
+ if self.config.prediction_type == "epsilon":
+ # DPM-Solver and DPM-Solver++ only need the "mean" output.
+ if self.config.variance_type in ["learned", "learned_range"]:
+ model_output = model_output[:, :3]
+ alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
+ x0_pred = (sample - sigma_t * model_output) / alpha_t
+ elif self.config.prediction_type == "sample":
+ x0_pred = model_output
+ elif self.config.prediction_type == "v_prediction":
+ alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
+ x0_pred = alpha_t * sample - sigma_t * model_output
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
+ " `v_prediction` for the DPMSolverMultistepScheduler."
+ )
+
+ if self.config.thresholding:
+ x0_pred = self._threshold_sample(x0_pred)
+
+ return x0_pred
+
+ # DPM-Solver needs to solve an integral of the noise prediction model.
+ elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
+ if self.config.prediction_type == "epsilon":
+ # DPM-Solver and DPM-Solver++ only need the "mean" output.
+ if self.config.variance_type in ["learned", "learned_range"]:
+ epsilon = model_output[:, :3]
+ else:
+ epsilon = model_output
+ elif self.config.prediction_type == "sample":
+ alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
+ epsilon = (sample - alpha_t * model_output) / sigma_t
+ elif self.config.prediction_type == "v_prediction":
+ alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
+ epsilon = alpha_t * model_output + sigma_t * sample
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
+ " `v_prediction` for the DPMSolverMultistepScheduler."
+ )
+
+ if self.config.thresholding:
+ alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
+ x0_pred = (sample - sigma_t * epsilon) / alpha_t
+ x0_pred = self._threshold_sample(x0_pred)
+ epsilon = (sample - alpha_t * x0_pred) / sigma_t
+
+ return epsilon
+
+ def dpm_solver_first_order_update(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: int,
+ prev_timestep: int,
+ sample: torch.FloatTensor,
+ noise: Optional[torch.FloatTensor] = None,
+ ) -> torch.FloatTensor:
+ """
+ One step for the first-order DPM-Solver (equivalent to DDIM).
+
+ See https://arxiv.org/abs/2206.00927 for the detailed derivation.
+
+ Args:
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ prev_timestep (`int`): previous discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+
+ Returns:
+ `torch.FloatTensor`: the sample tensor at the previous timestep.
+ """
+ lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep]
+ alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep]
+ sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep]
+ h = lambda_t - lambda_s
+ if self.config.algorithm_type == "dpmsolver++":
+ x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
+ elif self.config.algorithm_type == "dpmsolver":
+ x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
+ elif self.config.algorithm_type == "sde-dpmsolver++":
+ assert noise is not None
+ x_t = (
+ (sigma_t / sigma_s * torch.exp(-h)) * sample
+ + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
+ + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
+ )
+ elif self.config.algorithm_type == "sde-dpmsolver":
+ assert noise is not None
+ x_t = (
+ (alpha_t / alpha_s) * sample
+ - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output
+ + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
+ )
+ return x_t
+
+ def multistep_dpm_solver_second_order_update(
+ self,
+ model_output_list: List[torch.FloatTensor],
+ timestep_list: List[int],
+ prev_timestep: int,
+ sample: torch.FloatTensor,
+ noise: Optional[torch.FloatTensor] = None,
+ ) -> torch.FloatTensor:
+ """
+ One step for the second-order multistep DPM-Solver.
+
+ Args:
+ model_output_list (`List[torch.FloatTensor]`):
+ direct outputs from learned diffusion model at current and latter timesteps.
+ timestep_list (`List[int]`): the current and previous discrete timesteps in the diffusion chain.
+ prev_timestep (`int`): previous discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+
+ Returns:
+ `torch.FloatTensor`: the sample tensor at the previous timestep.
+ """
+ t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
+ m0, m1 = model_output_list[-1], model_output_list[-2]
+ lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1]
+ alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
+ sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
+ h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
+ r0 = h_0 / h
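+ # D0 is the newest converted model output; D1 = (m0 - m1) / r0 approximates h times the
+ # derivative of the model output with respect to lambda (a scaled first-order difference).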
+ D0, D1 = m0, (1.0 / r0) * (m0 - m1)
+ if self.config.algorithm_type == "dpmsolver++":
+ # See https://arxiv.org/abs/2211.01095 for detailed derivations
+ if self.config.solver_type == "midpoint":
+ x_t = (
+ (sigma_t / sigma_s0) * sample
+ - (alpha_t * (torch.exp(-h) - 1.0)) * D0
+ - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1
+ )
+ elif self.config.solver_type == "heun":
+ x_t = (
+ (sigma_t / sigma_s0) * sample
+ - (alpha_t * (torch.exp(-h) - 1.0)) * D0
+ + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
+ )
+ elif self.config.algorithm_type == "dpmsolver":
+ # See https://arxiv.org/abs/2206.00927 for detailed derivations
+ if self.config.solver_type == "midpoint":
+ x_t = (
+ (alpha_t / alpha_s0) * sample
+ - (sigma_t * (torch.exp(h) - 1.0)) * D0
+ - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1
+ )
+ elif self.config.solver_type == "heun":
+ x_t = (
+ (alpha_t / alpha_s0) * sample
+ - (sigma_t * (torch.exp(h) - 1.0)) * D0
+ - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
+ )
+ elif self.config.algorithm_type == "sde-dpmsolver++":
+ assert noise is not None
+ if self.config.solver_type == "midpoint":
+ x_t = (
+ (sigma_t / sigma_s0 * torch.exp(-h)) * sample
+ + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
+ + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
+ + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
+ )
+ elif self.config.solver_type == "heun":
+ x_t = (
+ (sigma_t / sigma_s0 * torch.exp(-h)) * sample
+ + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
+ + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
+ + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
+ )
+ elif self.config.algorithm_type == "sde-dpmsolver":
+ assert noise is not None
+ if self.config.solver_type == "midpoint":
+ x_t = (
+ (alpha_t / alpha_s0) * sample
+ - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
+ - (sigma_t * (torch.exp(h) - 1.0)) * D1
+ + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
+ )
+ elif self.config.solver_type == "heun":
+ x_t = (
+ (alpha_t / alpha_s0) * sample
+ - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
+ - 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
+ + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
+ )
+ return x_t
+
+ def multistep_dpm_solver_third_order_update(
+ self,
+ model_output_list: List[torch.FloatTensor],
+ timestep_list: List[int],
+ prev_timestep: int,
+ sample: torch.FloatTensor,
+ ) -> torch.FloatTensor:
+ """
+ One step for the third-order multistep DPM-Solver.
+
+ Args:
+ model_output_list (`List[torch.FloatTensor]`):
+ direct outputs from learned diffusion model at current and latter timesteps.
+ timestep_list (`List[int]`): the current and previous discrete timesteps in the diffusion chain.
+ prev_timestep (`int`): previous discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+
+ Returns:
+ `torch.FloatTensor`: the sample tensor at the previous timestep.
+ """
+ t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3]
+ m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
+ lambda_t, lambda_s0, lambda_s1, lambda_s2 = (
+ self.lambda_t[t],
+ self.lambda_t[s0],
+ self.lambda_t[s1],
+ self.lambda_t[s2],
+ )
+ alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
+ sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
+ h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
+ r0, r1 = h_0 / h, h_1 / h
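+ # D1_0 and D1_1 are scaled first differences over the two most recent intervals; D1
+ # extrapolates them to roughly h * dm/dlambda at the current point, and D2 is roughly
+ # (h**2 / 2) * d^2m/dlambda^2 (a scaled second difference).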
+ D0 = m0
+ D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
+ D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
+ D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
+ if self.config.algorithm_type == "dpmsolver++":
+ # See https://arxiv.org/abs/2206.00927 for detailed derivations
+ x_t = (
+ (sigma_t / sigma_s0) * sample
+ - (alpha_t * (torch.exp(-h) - 1.0)) * D0
+ + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
+ - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
+ )
+ elif self.config.algorithm_type == "dpmsolver":
+ # See https://arxiv.org/abs/2206.00927 for detailed derivations
+ x_t = (
+ (alpha_t / alpha_s0) * sample
+ - (sigma_t * (torch.exp(h) - 1.0)) * D0
+ - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
+ - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
+ )
+ return x_t
+
+ def step(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: int,
+ sample: torch.FloatTensor,
+ generator=None,
+ return_dict: bool = True,
+ ) -> Union[SchedulerOutput, Tuple]:
+ """
+ Step function propagating the sample with the multistep DPM-Solver.
+
+ Args:
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+ return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+ Returns:
+ [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is
+ True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
+
+ """
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ if isinstance(timestep, torch.Tensor):
+ timestep = timestep.to(self.timesteps.device)
+ step_index = (self.timesteps == timestep).nonzero()
+ if len(step_index) == 0:
+ step_index = len(self.timesteps) - 1
+ else:
+ step_index = step_index.item()
+ prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]
+ lower_order_final = (
+ (step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15
+ )
+ lower_order_second = (
+ (step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
+ )
+
+ model_output = self.convert_model_output(model_output, timestep, sample)
+ for i in range(self.config.solver_order - 1):
+ self.model_outputs[i] = self.model_outputs[i + 1]
+ self.model_outputs[-1] = model_output
+
+ if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
+ noise = randn_tensor(
+ model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
+ )
+ else:
+ noise = None
+
+ if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
+ prev_sample = self.dpm_solver_first_order_update(
+ model_output, timestep, prev_timestep, sample, noise=noise
+ )
+ elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
+ timestep_list = [self.timesteps[step_index - 1], timestep]
+ prev_sample = self.multistep_dpm_solver_second_order_update(
+ self.model_outputs, timestep_list, prev_timestep, sample, noise=noise
+ )
+ else:
+ timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep]
+ prev_sample = self.multistep_dpm_solver_third_order_update(
+ self.model_outputs, timestep_list, prev_timestep, sample
+ )
+
+ if self.lower_order_nums < self.config.solver_order:
+ self.lower_order_nums += 1
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return SchedulerOutput(prev_sample=prev_sample)
+
+ def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+
+ Args:
+ sample (`torch.FloatTensor`): input sample
+
+ Returns:
+ `torch.FloatTensor`: scaled input sample
+ """
+ return sample
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
+ def add_noise(
+ self,
+ original_samples: torch.FloatTensor,
+ noise: torch.FloatTensor,
+ timesteps: torch.IntTensor,
+ ) -> torch.FloatTensor:
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
+ alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+ timesteps = timesteps.to(original_samples.device)
+
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+ return noisy_samples
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py b/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b4ee67a7f5dbf8384eaedc0ede322284a413edd
--- /dev/null
+++ b/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py
@@ -0,0 +1,622 @@
+# Copyright 2023 TSAIL Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver
+
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import flax
+import jax
+import jax.numpy as jnp
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from .scheduling_utils_flax import (
+ CommonSchedulerState,
+ FlaxKarrasDiffusionSchedulers,
+ FlaxSchedulerMixin,
+ FlaxSchedulerOutput,
+ add_noise_common,
+)
+
+
+@flax.struct.dataclass
+class DPMSolverMultistepSchedulerState:
+ common: CommonSchedulerState
+ alpha_t: jnp.ndarray
+ sigma_t: jnp.ndarray
+ lambda_t: jnp.ndarray
+
+ # setable values
+ init_noise_sigma: jnp.ndarray
+ timesteps: jnp.ndarray
+ num_inference_steps: Optional[int] = None
+
+ # running values
+ model_outputs: Optional[jnp.ndarray] = None
+ lower_order_nums: Optional[jnp.int32] = None
+ prev_timestep: Optional[jnp.int32] = None
+ cur_sample: Optional[jnp.ndarray] = None
+
+ @classmethod
+ def create(
+ cls,
+ common: CommonSchedulerState,
+ alpha_t: jnp.ndarray,
+ sigma_t: jnp.ndarray,
+ lambda_t: jnp.ndarray,
+ init_noise_sigma: jnp.ndarray,
+ timesteps: jnp.ndarray,
+ ):
+ return cls(
+ common=common,
+ alpha_t=alpha_t,
+ sigma_t=sigma_t,
+ lambda_t=lambda_t,
+ init_noise_sigma=init_noise_sigma,
+ timesteps=timesteps,
+ )
+
+
+@dataclass
+class FlaxDPMSolverMultistepSchedulerOutput(FlaxSchedulerOutput):
+ state: DPMSolverMultistepSchedulerState
+
+
+class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
+ """
+ DPM-Solver (and the improved version DPM-Solver++) is a fast dedicated high-order solver for diffusion ODEs with
+ the convergence order guarantee. Empirically, sampling by DPM-Solver with only 20 steps can generate high-quality
+ samples, and it can generate quite good samples even in only 10 steps.
+
+ For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095
+
+ Currently, we support the multistep DPM-Solver for both noise prediction models and data prediction models. We
+ recommend using `solver_order=2` for guided sampling and `solver_order=3` for unconditional sampling.
+
+ We also support the "dynamic thresholding" method in Imagen (https://arxiv.org/abs/2205.11487). For pixel-space
+ diffusion models, you can set both `algorithm_type="dpmsolver++"` and `thresholding=True` to use the dynamic
+ thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as
+ stable-diffusion).
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
+ beta_start (`float`): the starting `beta` value of inference.
+ beta_end (`float`): the final `beta` value.
+ beta_schedule (`str`):
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+ trained_betas (`np.ndarray`, optional):
+ option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
+ solver_order (`int`, default `2`):
+ the order of DPM-Solver; can be `1`, `2`, or `3`. We recommend using `solver_order=2` for guided
+ sampling and `solver_order=3` for unconditional sampling.
+ prediction_type (`str`, default `epsilon`):
+ indicates whether the model predicts the noise (epsilon), or the data / `x0`. One of `epsilon`, `sample`,
+ or `v_prediction`.
+ thresholding (`bool`, default `False`):
+ whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
+ For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to
+ use the dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion
+ models (such as stable-diffusion).
+ dynamic_thresholding_ratio (`float`, default `0.995`):
+ the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
+ (https://arxiv.org/abs/2205.11487).
+ sample_max_value (`float`, default `1.0`):
+ the threshold value for dynamic thresholding. Valid only when `thresholding=True` and
+ `algorithm_type="dpmsolver++"`.
+ algorithm_type (`str`, default `dpmsolver++`):
+ the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++`. The `dpmsolver` type implements the
+ algorithms in https://arxiv.org/abs/2206.00927, and the `dpmsolver++` type implements the algorithms in
+ https://arxiv.org/abs/2211.01095. We recommend using `dpmsolver++` with `solver_order=2` for guided
+ sampling (e.g. stable-diffusion).
+ solver_type (`str`, default `midpoint`):
+ the solver type for the second-order solver. Either `midpoint` or `heun`. The solver type slightly affects
+ the sample quality, especially for a small number of steps. We empirically find that `midpoint` solvers are
+ slightly better, so we recommend using the `midpoint` type.
+ lower_order_final (`bool`, default `True`):
+ whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically
+ find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10.
+ dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
+ the `dtype` used for params and computation.
+ """
+
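+ # Usage sketch (assumed caller-side code, not part of this class): the Flax API is stateless,
+ # so every call threads an explicit scheduler state:
+ #     scheduler = FlaxDPMSolverMultistepScheduler()
+ #     state = scheduler.create_state()
+ #     state = scheduler.set_timesteps(state, num_inference_steps=20, shape=latents.shape)
+ #     for t in state.timesteps:
+ #         noise_pred = unet.apply(params, latents, t)  # assumed model call
+ #         latents, state = scheduler.step(state, noise_pred, t, latents, return_dict=False)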
+ _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]
+
+ dtype: jnp.dtype
+
+ @property
+ def has_state(self):
+ return True
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[jnp.ndarray] = None,
+ solver_order: int = 2,
+ prediction_type: str = "epsilon",
+ thresholding: bool = False,
+ dynamic_thresholding_ratio: float = 0.995,
+ sample_max_value: float = 1.0,
+ algorithm_type: str = "dpmsolver++",
+ solver_type: str = "midpoint",
+ lower_order_final: bool = True,
+ dtype: jnp.dtype = jnp.float32,
+ ):
+ self.dtype = dtype
+
+ def create_state(self, common: Optional[CommonSchedulerState] = None) -> DPMSolverMultistepSchedulerState:
+ if common is None:
+ common = CommonSchedulerState.create(self)
+
+ # Currently we only support VP-type noise schedule
+ alpha_t = jnp.sqrt(common.alphas_cumprod)
+ sigma_t = jnp.sqrt(1 - common.alphas_cumprod)
+ lambda_t = jnp.log(alpha_t) - jnp.log(sigma_t)
+
+ # settings for DPM-Solver
+ if self.config.algorithm_type not in ["dpmsolver", "dpmsolver++"]:
+ raise NotImplementedError(f"{self.config.algorithm_type} is not implemented for {self.__class__}")
+ if self.config.solver_type not in ["midpoint", "heun"]:
+ raise NotImplementedError(f"{self.config.solver_type} is not implemented for {self.__class__}")
+
+ # standard deviation of the initial noise distribution
+ init_noise_sigma = jnp.array(1.0, dtype=self.dtype)
+
+ timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1]
+
+ return DPMSolverMultistepSchedulerState.create(
+ common=common,
+ alpha_t=alpha_t,
+ sigma_t=sigma_t,
+ lambda_t=lambda_t,
+ init_noise_sigma=init_noise_sigma,
+ timesteps=timesteps,
+ )
+
+ def set_timesteps(
+ self, state: DPMSolverMultistepSchedulerState, num_inference_steps: int, shape: Tuple
+ ) -> DPMSolverMultistepSchedulerState:
+ """
+ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ state (`DPMSolverMultistepSchedulerState`):
+ the `FlaxDPMSolverMultistepScheduler` state data class instance.
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ shape (`Tuple`):
+ the shape of the samples to be generated.
+ """
+
+ timesteps = (
+ jnp.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
+ .round()[::-1][:-1]
+ .astype(jnp.int32)
+ )
+
+ # initial running values
+
+ model_outputs = jnp.zeros((self.config.solver_order,) + shape, dtype=self.dtype)
+ lower_order_nums = jnp.int32(0)
+ prev_timestep = jnp.int32(-1)
+ cur_sample = jnp.zeros(shape, dtype=self.dtype)
+
+ return state.replace(
+ num_inference_steps=num_inference_steps,
+ timesteps=timesteps,
+ model_outputs=model_outputs,
+ lower_order_nums=lower_order_nums,
+ prev_timestep=prev_timestep,
+ cur_sample=cur_sample,
+ )
+
+ def convert_model_output(
+ self,
+ state: DPMSolverMultistepSchedulerState,
+ model_output: jnp.ndarray,
+ timestep: int,
+ sample: jnp.ndarray,
+ ) -> jnp.ndarray:
+ """
+ Convert the model output to the corresponding type that the algorithm (DPM-Solver / DPM-Solver++) needs.
+
+ DPM-Solver is designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to
+ discretize an integral of the data prediction model. So we need to first convert the model output to the
+ corresponding type to match the algorithm.
+
+ Note that the algorithm type and the model type are decoupled. That is, we can use either DPM-Solver or
+ DPM-Solver++ for both noise prediction and data prediction models.
+
+ Args:
+ model_output (`jnp.ndarray`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`jnp.ndarray`):
+ current instance of sample being created by diffusion process.
+
+ Returns:
+ `jnp.ndarray`: the converted model output.
+ """
+ # DPM-Solver++ needs to solve an integral of the data prediction model.
+ if self.config.algorithm_type == "dpmsolver++":
+ if self.config.prediction_type == "epsilon":
+ alpha_t, sigma_t = state.alpha_t[timestep], state.sigma_t[timestep]
+ x0_pred = (sample - sigma_t * model_output) / alpha_t
+ elif self.config.prediction_type == "sample":
+ x0_pred = model_output
+ elif self.config.prediction_type == "v_prediction":
+ alpha_t, sigma_t = state.alpha_t[timestep], state.sigma_t[timestep]
+ x0_pred = alpha_t * sample - sigma_t * model_output
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
+ " or `v_prediction` for the FlaxDPMSolverMultistepScheduler."
+ )
+
+ if self.config.thresholding:
+ # Dynamic thresholding in https://arxiv.org/abs/2205.11487
+ dynamic_max_val = jnp.percentile(
+ jnp.abs(x0_pred), self.config.dynamic_thresholding_ratio, axis=tuple(range(1, x0_pred.ndim))
+ )
+ dynamic_max_val = jnp.maximum(
+ dynamic_max_val, self.config.sample_max_value * jnp.ones_like(dynamic_max_val)
+ )
+ x0_pred = jnp.clip(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val
+ return x0_pred
+ # DPM-Solver needs to solve an integral of the noise prediction model.
+ elif self.config.algorithm_type == "dpmsolver":
+ if self.config.prediction_type == "epsilon":
+ return model_output
+ elif self.config.prediction_type == "sample":
+ alpha_t, sigma_t = state.alpha_t[timestep], state.sigma_t[timestep]
+ epsilon = (sample - alpha_t * model_output) / sigma_t
+ return epsilon
+ elif self.config.prediction_type == "v_prediction":
+ alpha_t, sigma_t = state.alpha_t[timestep], state.sigma_t[timestep]
+ epsilon = alpha_t * model_output + sigma_t * sample
+ return epsilon
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
+ " or `v_prediction` for the FlaxDPMSolverMultistepScheduler."
+ )
+
+ def dpm_solver_first_order_update(
+ self,
+ state: DPMSolverMultistepSchedulerState,
+ model_output: jnp.ndarray,
+ timestep: int,
+ prev_timestep: int,
+ sample: jnp.ndarray,
+ ) -> jnp.ndarray:
+ """
+ One step for the first-order DPM-Solver (equivalent to DDIM).
+
+ See https://arxiv.org/abs/2206.00927 for the detailed derivation.
+
+ Args:
+ model_output (`jnp.ndarray`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ prev_timestep (`int`): previous discrete timestep in the diffusion chain.
+ sample (`jnp.ndarray`):
+ current instance of sample being created by diffusion process.
+
+ Returns:
+ `jnp.ndarray`: the sample tensor at the previous timestep.
+ """
+ t, s0 = prev_timestep, timestep
+ m0 = model_output
+ lambda_t, lambda_s = state.lambda_t[t], state.lambda_t[s0]
+ alpha_t, alpha_s = state.alpha_t[t], state.alpha_t[s0]
+ sigma_t, sigma_s = state.sigma_t[t], state.sigma_t[s0]
+ h = lambda_t - lambda_s
+ if self.config.algorithm_type == "dpmsolver++":
+ x_t = (sigma_t / sigma_s) * sample - (alpha_t * (jnp.exp(-h) - 1.0)) * m0
+ elif self.config.algorithm_type == "dpmsolver":
+ x_t = (alpha_t / alpha_s) * sample - (sigma_t * (jnp.exp(h) - 1.0)) * m0
+ return x_t
+
+ def multistep_dpm_solver_second_order_update(
+ self,
+ state: DPMSolverMultistepSchedulerState,
+ model_output_list: jnp.ndarray,
+ timestep_list: List[int],
+ prev_timestep: int,
+ sample: jnp.ndarray,
+ ) -> jnp.ndarray:
+ """
+ One step for the second-order multistep DPM-Solver.
+
+ Args:
+ model_output_list (`List[jnp.ndarray]`):
+ direct outputs from learned diffusion model at current and latter timesteps.
+ timestep_list (`List[int]`): the current and previous discrete timesteps in the diffusion chain.
+ prev_timestep (`int`): previous discrete timestep in the diffusion chain.
+ sample (`jnp.ndarray`):
+ current instance of sample being created by diffusion process.
+
+ Returns:
+ `jnp.ndarray`: the sample tensor at the previous timestep.
+ """
+ t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
+ m0, m1 = model_output_list[-1], model_output_list[-2]
+ lambda_t, lambda_s0, lambda_s1 = state.lambda_t[t], state.lambda_t[s0], state.lambda_t[s1]
+ alpha_t, alpha_s0 = state.alpha_t[t], state.alpha_t[s0]
+ sigma_t, sigma_s0 = state.sigma_t[t], state.sigma_t[s0]
+ h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
+ r0 = h_0 / h
+ D0, D1 = m0, (1.0 / r0) * (m0 - m1)
+ if self.config.algorithm_type == "dpmsolver++":
+ # See https://arxiv.org/abs/2211.01095 for detailed derivations
+ if self.config.solver_type == "midpoint":
+ x_t = (
+ (sigma_t / sigma_s0) * sample
+ - (alpha_t * (jnp.exp(-h) - 1.0)) * D0
+ - 0.5 * (alpha_t * (jnp.exp(-h) - 1.0)) * D1
+ )
+ elif self.config.solver_type == "heun":
+ x_t = (
+ (sigma_t / sigma_s0) * sample
+ - (alpha_t * (jnp.exp(-h) - 1.0)) * D0
+ + (alpha_t * ((jnp.exp(-h) - 1.0) / h + 1.0)) * D1
+ )
+ elif self.config.algorithm_type == "dpmsolver":
+ # See https://arxiv.org/abs/2206.00927 for detailed derivations
+ if self.config.solver_type == "midpoint":
+ x_t = (
+ (alpha_t / alpha_s0) * sample
+ - (sigma_t * (jnp.exp(h) - 1.0)) * D0
+ - 0.5 * (sigma_t * (jnp.exp(h) - 1.0)) * D1
+ )
+ elif self.config.solver_type == "heun":
+ x_t = (
+ (alpha_t / alpha_s0) * sample
+ - (sigma_t * (jnp.exp(h) - 1.0)) * D0
+ - (sigma_t * ((jnp.exp(h) - 1.0) / h - 1.0)) * D1
+ )
+ return x_t
+
+ def multistep_dpm_solver_third_order_update(
+ self,
+ state: DPMSolverMultistepSchedulerState,
+ model_output_list: jnp.ndarray,
+ timestep_list: List[int],
+ prev_timestep: int,
+ sample: jnp.ndarray,
+ ) -> jnp.ndarray:
+ """
+ One step for the third-order multistep DPM-Solver.
+
+ Args:
+ model_output_list (`List[jnp.ndarray]`):
+ direct outputs from learned diffusion model at current and latter timesteps.
+ timestep_list (`List[int]`): the current and previous discrete timesteps in the diffusion chain.
+ prev_timestep (`int`): previous discrete timestep in the diffusion chain.
+ sample (`jnp.ndarray`):
+ current instance of sample being created by diffusion process.
+
+ Returns:
+ `jnp.ndarray`: the sample tensor at the previous timestep.
+ """
+ t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3]
+ m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
+ lambda_t, lambda_s0, lambda_s1, lambda_s2 = (
+ state.lambda_t[t],
+ state.lambda_t[s0],
+ state.lambda_t[s1],
+ state.lambda_t[s2],
+ )
+ alpha_t, alpha_s0 = state.alpha_t[t], state.alpha_t[s0]
+ sigma_t, sigma_s0 = state.sigma_t[t], state.sigma_t[s0]
+ h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
+ r0, r1 = h_0 / h, h_1 / h
+ D0 = m0
+ D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
+ D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
+ D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
+ if self.config.algorithm_type == "dpmsolver++":
+ # See https://arxiv.org/abs/2206.00927 for detailed derivations
+ x_t = (
+ (sigma_t / sigma_s0) * sample
+ - (alpha_t * (jnp.exp(-h) - 1.0)) * D0
+ + (alpha_t * ((jnp.exp(-h) - 1.0) / h + 1.0)) * D1
+ - (alpha_t * ((jnp.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
+ )
+ elif self.config.algorithm_type == "dpmsolver":
+ # See https://arxiv.org/abs/2206.00927 for detailed derivations
+ x_t = (
+ (alpha_t / alpha_s0) * sample
+ - (sigma_t * (jnp.exp(h) - 1.0)) * D0
+ - (sigma_t * ((jnp.exp(h) - 1.0) / h - 1.0)) * D1
+ - (sigma_t * ((jnp.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
+ )
+ return x_t
+
+ def step(
+ self,
+ state: DPMSolverMultistepSchedulerState,
+ model_output: jnp.ndarray,
+ timestep: int,
+ sample: jnp.ndarray,
+ return_dict: bool = True,
+ ) -> Union[FlaxDPMSolverMultistepSchedulerOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by DPM-Solver. Core function to propagate the diffusion process
+ from the learned model outputs (most often the predicted noise).
+
+ Args:
+ state (`DPMSolverMultistepSchedulerState`):
+ the `FlaxDPMSolverMultistepScheduler` state data class instance.
+ model_output (`jnp.ndarray`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`jnp.ndarray`):
+ current instance of sample being created by diffusion process.
+ return_dict (`bool`): option for returning tuple rather than FlaxDPMSolverMultistepSchedulerOutput class
+
+ Returns:
+ [`FlaxDPMSolverMultistepSchedulerOutput`] or `tuple`: [`FlaxDPMSolverMultistepSchedulerOutput`] if
+ `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
+
+ """
+ if state.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ (step_index,) = jnp.where(state.timesteps == timestep, size=1)
+ step_index = step_index[0]
+
+ prev_timestep = jax.lax.select(step_index == len(state.timesteps) - 1, 0, state.timesteps[step_index + 1])
+
+ model_output = self.convert_model_output(state, model_output, timestep, sample)
+
+ model_outputs_new = jnp.roll(state.model_outputs, -1, axis=0)
+ model_outputs_new = model_outputs_new.at[-1].set(model_output)
+ state = state.replace(
+ model_outputs=model_outputs_new,
+ prev_timestep=prev_timestep,
+ cur_sample=sample,
+ )
+
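+ # Under `jax.jit`, `step_index` and `state.lower_order_nums` are traced values, so Python
+ # `if`s cannot branch on them; the candidate updates below are all computed and
+ # `jax.lax.select` picks the one that applies, mirroring the (lower-)order logic of the
+ # PyTorch scheduler.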
+ def step_1(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray:
+ return self.dpm_solver_first_order_update(
+ state,
+ state.model_outputs[-1],
+ state.timesteps[step_index],
+ state.prev_timestep,
+ state.cur_sample,
+ )
+
+ def step_23(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray:
+ def step_2(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray:
+ timestep_list = jnp.array([state.timesteps[step_index - 1], state.timesteps[step_index]])
+ return self.multistep_dpm_solver_second_order_update(
+ state,
+ state.model_outputs,
+ timestep_list,
+ state.prev_timestep,
+ state.cur_sample,
+ )
+
+ def step_3(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray:
+ timestep_list = jnp.array(
+ [
+ state.timesteps[step_index - 2],
+ state.timesteps[step_index - 1],
+ state.timesteps[step_index],
+ ]
+ )
+ return self.multistep_dpm_solver_third_order_update(
+ state,
+ state.model_outputs,
+ timestep_list,
+ state.prev_timestep,
+ state.cur_sample,
+ )
+
+ step_2_output = step_2(state)
+ step_3_output = step_3(state)
+
+ if self.config.solver_order == 2:
+ return step_2_output
+ elif self.config.lower_order_final and len(state.timesteps) < 15:
+ return jax.lax.select(
+ state.lower_order_nums < 2,
+ step_2_output,
+ jax.lax.select(
+ step_index == len(state.timesteps) - 2,
+ step_2_output,
+ step_3_output,
+ ),
+ )
+ else:
+ return jax.lax.select(
+ state.lower_order_nums < 2,
+ step_2_output,
+ step_3_output,
+ )
+
+ step_1_output = step_1(state)
+ step_23_output = step_23(state)
+
+ if self.config.solver_order == 1:
+ prev_sample = step_1_output
+
+ elif self.config.lower_order_final and len(state.timesteps) < 15:
+ prev_sample = jax.lax.select(
+ state.lower_order_nums < 1,
+ step_1_output,
+ jax.lax.select(
+ step_index == len(state.timesteps) - 1,
+ step_1_output,
+ step_23_output,
+ ),
+ )
+
+ else:
+ prev_sample = jax.lax.select(
+ state.lower_order_nums < 1,
+ step_1_output,
+ step_23_output,
+ )
+
+ state = state.replace(
+ lower_order_nums=jnp.minimum(state.lower_order_nums + 1, self.config.solver_order),
+ )
+
+ if not return_dict:
+ return (prev_sample, state)
+
+ return FlaxDPMSolverMultistepSchedulerOutput(prev_sample=prev_sample, state=state)
+
+ def scale_model_input(
+ self, state: DPMSolverMultistepSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
+ ) -> jnp.ndarray:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+
+ Args:
+ state (`DPMSolverMultistepSchedulerState`):
+ the `FlaxDPMSolverMultistepScheduler` state data class instance.
+ sample (`jnp.ndarray`): input sample
+ timestep (`int`, optional): current timestep
+
+ Returns:
+ `jnp.ndarray`: scaled input sample
+ """
+ return sample
+
+ def add_noise(
+ self,
+ state: DPMSolverMultistepSchedulerState,
+ original_samples: jnp.ndarray,
+ noise: jnp.ndarray,
+ timesteps: jnp.ndarray,
+ ) -> jnp.ndarray:
+ return add_noise_common(state.common, original_samples, noise, timesteps)
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5c98a551d665a05d4cbab8ccbdef6785fa2ed09
--- /dev/null
+++ b/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py
@@ -0,0 +1,707 @@
+# Copyright 2023 TSAIL Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver
+
+import math
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import randn_tensor
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
+
+
+# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
+def betas_for_alpha_bar(
+ num_diffusion_timesteps,
+ max_beta=0.999,
+ alpha_transform_type="cosine",
+):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ alpha_transform_type (`str`, *optional*, defaults to `cosine`): the type of noise schedule for alpha_bar.
+ Choose from `cosine` or `exp`.
+
+ Returns:
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+ if alpha_transform_type == "cosine":
+
+ def alpha_bar_fn(t):
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ elif alpha_transform_type == "exp":
+
+ def alpha_bar_fn(t):
+ return math.exp(t * -12.0)
+
+ else:
+ raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
+ return torch.tensor(betas, dtype=torch.float32)
+
+
+class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
+ """
+ DPMSolverMultistepInverseScheduler is the reverse scheduler of [`DPMSolverMultistepScheduler`].
+
+ We also support the "dynamic thresholding" method in Imagen (https://arxiv.org/abs/2205.11487). For pixel-space
+ diffusion models, you can set both `algorithm_type="dpmsolver++"` and `thresholding=True` to use the dynamic
+ thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as
+ stable-diffusion).
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
+ beta_start (`float`): the starting `beta` value of inference.
+ beta_end (`float`): the final `beta` value.
+ beta_schedule (`str`):
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+ trained_betas (`np.ndarray`, optional):
+ option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
+ solver_order (`int`, default `2`):
+ the order of DPM-Solver; can be `1`, `2`, or `3`. We recommend using `solver_order=2` for guided
+ sampling and `solver_order=3` for unconditional sampling.
+ prediction_type (`str`, default `epsilon`, optional):
+ prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
+ process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4
+ https://imagen.research.google/video/paper.pdf)
+ thresholding (`bool`, default `False`):
+ whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
+ For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to
+ use the dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion
+ models (such as stable-diffusion).
+ dynamic_thresholding_ratio (`float`, default `0.995`):
+ the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
+ (https://arxiv.org/abs/2205.11487).
+ sample_max_value (`float`, default `1.0`):
+ the threshold value for dynamic thresholding. Valid only when `thresholding=True` and
+ `algorithm_type="dpmsolver++"`.
+ algorithm_type (`str`, default `dpmsolver++`):
+ the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++` or `sde-dpmsolver` or
+ `sde-dpmsolver++`. The `dpmsolver` type implements the algorithms in https://arxiv.org/abs/2206.00927, and
+ the `dpmsolver++` type implements the algorithms in https://arxiv.org/abs/2211.01095. We recommend using
+ `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling (e.g. stable-diffusion).
+ solver_type (`str`, default `midpoint`):
+ the solver type for the second-order solver. Either `midpoint` or `heun`. The solver type slightly affects
+ the sample quality, especially for a small number of steps. We empirically find that `midpoint` solvers are
+ slightly better, so we recommend using the `midpoint` type.
+ lower_order_final (`bool`, default `True`):
+ whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically
+ find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10.
+ use_karras_sigmas (`bool`, *optional*, defaults to `False`):
+ This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the
+ noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence
+ of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf.
+ lambda_min_clipped (`float`, default `-inf`):
+ the clipping threshold for the minimum value of lambda(t) for numerical stability. This is critical for
+ cosine (squaredcos_cap_v2) noise schedule.
+ variance_type (`str`, *optional*):
+ Set to "learned" or "learned_range" for diffusion models that predict variance. For example, OpenAI's
+ guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the
+ Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is
+ based on diffusion ODEs.
+ timestep_spacing (`str`, default `"linspace"`):
+ The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
+ Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
+ steps_offset (`int`, default `0`):
+ an offset added to the inference steps. You can use a combination of `offset=1` and
+ `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
+ stable diffusion.
+ """
+
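+ # Note: unlike the forward sampler, `set_timesteps` below builds an *increasing* timestep
+ # sequence (0 -> noisiest timestep), so stepping runs the diffusion ODE in the inverse
+ # direction and maps a sample back towards its noisy latent (approximate inversion).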
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+ solver_order: int = 2,
+ prediction_type: str = "epsilon",
+ thresholding: bool = False,
+ dynamic_thresholding_ratio: float = 0.995,
+ sample_max_value: float = 1.0,
+ algorithm_type: str = "dpmsolver++",
+ solver_type: str = "midpoint",
+ lower_order_final: bool = True,
+ use_karras_sigmas: Optional[bool] = False,
+ lambda_min_clipped: float = -float("inf"),
+ variance_type: Optional[str] = None,
+ timestep_spacing: str = "linspace",
+ steps_offset: int = 0,
+ ):
+ if trained_betas is not None:
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
+ elif beta_schedule == "linear":
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ )
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ else:
+ raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+ # Currently we only support VP-type noise schedule
+ self.alpha_t = torch.sqrt(self.alphas_cumprod)
+ self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
+ self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
+
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = 1.0
+
+ # settings for DPM-Solver
+ if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]:
+ if algorithm_type == "deis":
+ self.register_to_config(algorithm_type="dpmsolver++")
+ else:
+ raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")
+
+ if solver_type not in ["midpoint", "heun"]:
+ if solver_type in ["logrho", "bh1", "bh2"]:
+ self.register_to_config(solver_type="midpoint")
+ else:
+ raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
+
+ # setable values
+ self.num_inference_steps = None
+ timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32).copy()
+ self.timesteps = torch.from_numpy(timesteps)
+ self.model_outputs = [None] * solver_order
+ self.lower_order_nums = 0
+ self.use_karras_sigmas = use_karras_sigmas
+
+ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
+ """
+ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ device (`str` or `torch.device`, optional):
+ the device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+ """
+ # Clipping the minimum of all lambda(t) for numerical stability.
+ # This is critical for cosine (squaredcos_cap_v2) noise schedule.
+ clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped).item()
+ self.noisiest_timestep = self.config.num_train_timesteps - 1 - clipped_idx
+
+ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+ if self.config.timestep_spacing == "linspace":
+ timesteps = (
+ np.linspace(0, self.noisiest_timestep, num_inference_steps + 1).round()[:-1].copy().astype(np.int64)
+ )
+ elif self.config.timestep_spacing == "leading":
+ step_ratio = (self.noisiest_timestep + 1) // (num_inference_steps + 1)
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_steps is a power of 3
+ timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[:-1].copy().astype(np.int64)
+ timesteps += self.config.steps_offset
+ elif self.config.timestep_spacing == "trailing":
+ step_ratio = self.config.num_train_timesteps / num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_steps is a power of 3
+ timesteps = np.arange(self.noisiest_timestep + 1, 0, -step_ratio).round()[::-1].copy().astype(np.int64)
+ timesteps -= 1
+ else:
+ raise ValueError(
+ f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', "
+ "'leading' or 'trailing'."
+ )
+
+ sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+ if self.config.use_karras_sigmas:
+ log_sigmas = np.log(sigmas)
+ sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+ timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
+ timesteps = timesteps.copy().astype(np.int64)
+
+ self.sigmas = torch.from_numpy(sigmas)
+
+ # when num_inference_steps == num_train_timesteps, we can end up with
+ # duplicates in timesteps.
+ _, unique_indices = np.unique(timesteps, return_index=True)
+ timesteps = timesteps[np.sort(unique_indices)]
+
+ self.timesteps = torch.from_numpy(timesteps).to(device)
+
+ self.num_inference_steps = len(timesteps)
+
+ self.model_outputs = [
+ None,
+ ] * self.config.solver_order
+ self.lower_order_nums = 0
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
+ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
+ """
+ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
+ prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
+ s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
+ pixels from saturation at each step. We find that dynamic thresholding results in significantly better
+ photorealism as well as better image-text alignment, especially when using very large guidance weights."
+
+ https://arxiv.org/abs/2205.11487
+ """
+ dtype = sample.dtype
+ batch_size, channels, height, width = sample.shape
+
+ if dtype not in (torch.float32, torch.float64):
+ sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
+
+ # Flatten sample for doing quantile calculation along each image
+ sample = sample.reshape(batch_size, channels * height * width)
+
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
+
+ s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
+ s = torch.clamp(
+ s, min=1, max=self.config.sample_max_value
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
+
+ s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
+ sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
+
+ sample = sample.reshape(batch_size, channels, height, width)
+ sample = sample.to(dtype)
+
+ return sample
+
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
+ def _sigma_to_t(self, sigma, log_sigmas):
+ # get log sigma
+ log_sigma = np.log(sigma)
+
+ # get distribution
+ dists = log_sigma - log_sigmas[:, np.newaxis]
+
+ # get sigmas range
+ low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
+ high_idx = low_idx + 1
+
+ low = log_sigmas[low_idx]
+ high = log_sigmas[high_idx]
+
+ # interpolate sigmas
+ w = (low - log_sigma) / (low - high)
+ w = np.clip(w, 0, 1)
+
+ # transform interpolation to time range
+ t = (1 - w) * low_idx + w * high_idx
+ t = t.reshape(sigma.shape)
+ return t
+
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
+ def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
+ """Constructs the noise schedule of Karras et al. (2022)."""
+
+ sigma_min: float = in_sigmas[-1].item()
+ sigma_max: float = in_sigmas[0].item()
+
+ rho = 7.0 # 7.0 is the value used in the paper
+ ramp = np.linspace(0, 1, num_inference_steps)
+ min_inv_rho = sigma_min ** (1 / rho)
+ max_inv_rho = sigma_max ** (1 / rho)
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+ return sigmas
+
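+    # Illustrative sketch (standalone, assumed values): the Karras et al. (2022) schedule
+    # interpolates sigma**(1/rho) linearly between sigma_max and sigma_min:
+    #
+    #   >>> import numpy as np
+    #   >>> rho, sigma_min, sigma_max, n = 7.0, 0.1, 10.0, 5
+    #   >>> ramp = np.linspace(0, 1, n)
+    #   >>> karras = (sigma_max ** (1 / rho) + ramp * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
+    #   >>> # karras decreases monotonically from sigma_max (10.0) to sigma_min (0.1)
+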
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.convert_model_output
+ def convert_model_output(
+ self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
+ ) -> torch.FloatTensor:
+ """
+ Convert the model output to the corresponding type that the algorithm (DPM-Solver / DPM-Solver++) needs.
+
+ DPM-Solver is designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to
+ discretize an integral of the data prediction model. So we need to first convert the model output to the
+ corresponding type to match the algorithm.
+
+ Note that the algorithm type and the model type is decoupled. That is to say, we can use either DPM-Solver or
+ DPM-Solver++ for both noise prediction model and data prediction model.
+
+ Args:
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+
+ Returns:
+ `torch.FloatTensor`: the converted model output.
+ """
+
+ # DPM-Solver++ needs to solve an integral of the data prediction model.
+ if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
+ if self.config.prediction_type == "epsilon":
+ # DPM-Solver and DPM-Solver++ only need the "mean" output.
+ if self.config.variance_type in ["learned", "learned_range"]:
+ model_output = model_output[:, :3]
+ alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
+ x0_pred = (sample - sigma_t * model_output) / alpha_t
+ elif self.config.prediction_type == "sample":
+ x0_pred = model_output
+ elif self.config.prediction_type == "v_prediction":
+ alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
+ x0_pred = alpha_t * sample - sigma_t * model_output
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
+ " `v_prediction` for the DPMSolverMultistepScheduler."
+ )
+
+ if self.config.thresholding:
+ x0_pred = self._threshold_sample(x0_pred)
+
+ return x0_pred
+
+ # DPM-Solver needs to solve an integral of the noise prediction model.
+ elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
+ if self.config.prediction_type == "epsilon":
+ # DPM-Solver and DPM-Solver++ only need the "mean" output.
+ if self.config.variance_type in ["learned", "learned_range"]:
+ epsilon = model_output[:, :3]
+ else:
+ epsilon = model_output
+ elif self.config.prediction_type == "sample":
+ alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
+ epsilon = (sample - alpha_t * model_output) / sigma_t
+ elif self.config.prediction_type == "v_prediction":
+ alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
+ epsilon = alpha_t * model_output + sigma_t * sample
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
+ " `v_prediction` for the DPMSolverMultistepScheduler."
+ )
+
+ if self.config.thresholding:
+ alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
+ x0_pred = (sample - sigma_t * epsilon) / alpha_t
+ x0_pred = self._threshold_sample(x0_pred)
+ epsilon = (sample - alpha_t * x0_pred) / sigma_t
+
+ return epsilon
+
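+    # Worked note (illustrative): under the VP parameterization x_t = alpha_t * x_0 + sigma_t * eps,
+    # the three prediction types are interchangeable:
+    #   epsilon      -> x_0 = (x_t - sigma_t * eps) / alpha_t
+    #   sample       -> x_0 is the model output itself
+    #   v_prediction -> x_0 = alpha_t * x_t - sigma_t * v   and   eps = alpha_t * v + sigma_t * x_t
+    # which is exactly the algebra applied above before the optional thresholding step.
+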
+ def dpm_solver_first_order_update(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: int,
+ prev_timestep: int,
+ sample: torch.FloatTensor,
+ noise: Optional[torch.FloatTensor] = None,
+ ) -> torch.FloatTensor:
+ """
+ One step for the first-order DPM-Solver (equivalent to DDIM).
+
+ See https://arxiv.org/abs/2206.00927 for the detailed derivation.
+
+ Args:
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ prev_timestep (`int`): previous discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+
+ Returns:
+ `torch.FloatTensor`: the sample tensor at the previous timestep.
+ """
+ lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep]
+ alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep]
+ sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep]
+ h = lambda_t - lambda_s
+ if self.config.algorithm_type == "dpmsolver++":
+ x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
+ elif self.config.algorithm_type == "dpmsolver":
+ x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
+ elif "sde" in self.config.algorithm_type:
+ raise NotImplementedError(
+ f"Inversion step is not yet implemented for algorithm type {self.config.algorithm_type}."
+ )
+ return x_t
+
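+    # Worked note (illustrative): with lambda = log(alpha / sigma) and h = lambda_t - lambda_s,
+    # the DPM-Solver++ branch above is the exact solution of the data-prediction ODE over one
+    # interval, x_t = (sigma_t / sigma_s) * x_s - alpha_t * (exp(-h) - 1) * x0_pred, which for a
+    # single step reduces to a DDIM-style update, as noted in the docstring.
+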
+ def multistep_dpm_solver_second_order_update(
+ self,
+ model_output_list: List[torch.FloatTensor],
+ timestep_list: List[int],
+ prev_timestep: int,
+ sample: torch.FloatTensor,
+ noise: Optional[torch.FloatTensor] = None,
+ ) -> torch.FloatTensor:
+ """
+ One step for the second-order multistep DPM-Solver.
+
+ Args:
+ model_output_list (`List[torch.FloatTensor]`):
+ direct outputs from learned diffusion model at current and latter timesteps.
+            timestep_list (`List[int]`): current and latter discrete timesteps in the diffusion chain.
+ prev_timestep (`int`): previous discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+
+ Returns:
+ `torch.FloatTensor`: the sample tensor at the previous timestep.
+ """
+ t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
+ m0, m1 = model_output_list[-1], model_output_list[-2]
+ lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1]
+ alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
+ sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
+ h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
+ r0 = h_0 / h
+ D0, D1 = m0, (1.0 / r0) * (m0 - m1)
+ if self.config.algorithm_type == "dpmsolver++":
+ # See https://arxiv.org/abs/2211.01095 for detailed derivations
+ if self.config.solver_type == "midpoint":
+ x_t = (
+ (sigma_t / sigma_s0) * sample
+ - (alpha_t * (torch.exp(-h) - 1.0)) * D0
+ - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1
+ )
+ elif self.config.solver_type == "heun":
+ x_t = (
+ (sigma_t / sigma_s0) * sample
+ - (alpha_t * (torch.exp(-h) - 1.0)) * D0
+ + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
+ )
+ elif self.config.algorithm_type == "dpmsolver":
+ # See https://arxiv.org/abs/2206.00927 for detailed derivations
+ if self.config.solver_type == "midpoint":
+ x_t = (
+ (alpha_t / alpha_s0) * sample
+ - (sigma_t * (torch.exp(h) - 1.0)) * D0
+ - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1
+ )
+ elif self.config.solver_type == "heun":
+ x_t = (
+ (alpha_t / alpha_s0) * sample
+ - (sigma_t * (torch.exp(h) - 1.0)) * D0
+ - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
+ )
+ elif "sde" in self.config.algorithm_type:
+ raise NotImplementedError(
+ f"Inversion step is not yet implemented for algorithm type {self.config.algorithm_type}."
+ )
+ return x_t
+
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_third_order_update
+ def multistep_dpm_solver_third_order_update(
+ self,
+ model_output_list: List[torch.FloatTensor],
+ timestep_list: List[int],
+ prev_timestep: int,
+ sample: torch.FloatTensor,
+ ) -> torch.FloatTensor:
+ """
+ One step for the third-order multistep DPM-Solver.
+
+ Args:
+ model_output_list (`List[torch.FloatTensor]`):
+ direct outputs from learned diffusion model at current and latter timesteps.
+            timestep_list (`List[int]`): current and latter discrete timesteps in the diffusion chain.
+ prev_timestep (`int`): previous discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+
+ Returns:
+ `torch.FloatTensor`: the sample tensor at the previous timestep.
+ """
+ t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3]
+ m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
+ lambda_t, lambda_s0, lambda_s1, lambda_s2 = (
+ self.lambda_t[t],
+ self.lambda_t[s0],
+ self.lambda_t[s1],
+ self.lambda_t[s2],
+ )
+ alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
+ sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
+ h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
+ r0, r1 = h_0 / h, h_1 / h
+ D0 = m0
+ D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
+ D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
+ D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
+ if self.config.algorithm_type == "dpmsolver++":
+ # See https://arxiv.org/abs/2206.00927 for detailed derivations
+ x_t = (
+ (sigma_t / sigma_s0) * sample
+ - (alpha_t * (torch.exp(-h) - 1.0)) * D0
+ + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
+ - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
+ )
+ elif self.config.algorithm_type == "dpmsolver":
+ # See https://arxiv.org/abs/2206.00927 for detailed derivations
+ x_t = (
+ (alpha_t / alpha_s0) * sample
+ - (sigma_t * (torch.exp(h) - 1.0)) * D0
+ - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
+ - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
+ )
+ return x_t
+
+ def step(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: int,
+ sample: torch.FloatTensor,
+ generator=None,
+ return_dict: bool = True,
+ ) -> Union[SchedulerOutput, Tuple]:
+ """
+ Step function propagating the sample with the multistep DPM-Solver.
+
+ Args:
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+ return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+ Returns:
+ [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is
+ True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
+
+ """
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ if isinstance(timestep, torch.Tensor):
+ timestep = timestep.to(self.timesteps.device)
+ step_index = (self.timesteps == timestep).nonzero()
+ if len(step_index) == 0:
+ step_index = len(self.timesteps) - 1
+ else:
+ step_index = step_index.item()
+ prev_timestep = (
+ self.noisiest_timestep if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]
+ )
+ lower_order_final = (
+ (step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15
+ )
+ lower_order_second = (
+ (step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
+ )
+
+ model_output = self.convert_model_output(model_output, timestep, sample)
+ for i in range(self.config.solver_order - 1):
+ self.model_outputs[i] = self.model_outputs[i + 1]
+ self.model_outputs[-1] = model_output
+
+ if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
+ noise = randn_tensor(
+ model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
+ )
+ else:
+ noise = None
+
+ if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
+ prev_sample = self.dpm_solver_first_order_update(
+ model_output, timestep, prev_timestep, sample, noise=noise
+ )
+ elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
+ timestep_list = [self.timesteps[step_index - 1], timestep]
+ prev_sample = self.multistep_dpm_solver_second_order_update(
+ self.model_outputs, timestep_list, prev_timestep, sample, noise=noise
+ )
+ else:
+ timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep]
+ prev_sample = self.multistep_dpm_solver_third_order_update(
+ self.model_outputs, timestep_list, prev_timestep, sample
+ )
+
+ if self.lower_order_nums < self.config.solver_order:
+ self.lower_order_nums += 1
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return SchedulerOutput(prev_sample=prev_sample)
+
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.scale_model_input
+ def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+
+ Args:
+ sample (`torch.FloatTensor`): input sample
+
+ Returns:
+ `torch.FloatTensor`: scaled input sample
+ """
+ return sample
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
+ def add_noise(
+ self,
+ original_samples: torch.FloatTensor,
+ noise: torch.FloatTensor,
+ timesteps: torch.IntTensor,
+ ) -> torch.FloatTensor:
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
+ alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+ timesteps = timesteps.to(original_samples.device)
+
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+ return noisy_samples
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/diffusers/schedulers/scheduling_dpmsolver_sde.py b/diffusers/schedulers/scheduling_dpmsolver_sde.py
new file mode 100644
index 0000000000000000000000000000000000000000..a31e97b6965169823634afe8984866a9f7d03ba3
--- /dev/null
+++ b/diffusers/schedulers/scheduling_dpmsolver_sde.py
@@ -0,0 +1,509 @@
+# Copyright 2023 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from collections import defaultdict
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torchsde
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
+
+
+class BatchedBrownianTree:
+ """A wrapper around torchsde.BrownianTree that enables batches of entropy."""
+
+ def __init__(self, x, t0, t1, seed=None, **kwargs):
+ t0, t1, self.sign = self.sort(t0, t1)
+ w0 = kwargs.get("w0", torch.zeros_like(x))
+ if seed is None:
+ seed = torch.randint(0, 2**63 - 1, []).item()
+ self.batched = True
+ try:
+ assert len(seed) == x.shape[0]
+ w0 = w0[0]
+ except TypeError:
+ seed = [seed]
+ self.batched = False
+ self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]
+
+ @staticmethod
+ def sort(a, b):
+ return (a, b, 1) if a < b else (b, a, -1)
+
+ def __call__(self, t0, t1):
+ t0, t1, sign = self.sort(t0, t1)
+ w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
+ return w if self.batched else w[0]
+
+
+class BrownianTreeNoiseSampler:
+ """A noise sampler backed by a torchsde.BrownianTree.
+
+ Args:
+ x (Tensor): The tensor whose shape, device and dtype to use to generate
+ random samples.
+ sigma_min (float): The low end of the valid interval.
+ sigma_max (float): The high end of the valid interval.
+ seed (int or List[int]): The random seed. If a list of seeds is
+ supplied instead of a single integer, then the noise sampler will use one BrownianTree per batch item, each
+ with its own seed.
+ transform (callable): A function that maps sigma to the sampler's
+ internal timestep.
+ """
+
+ def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x):
+ self.transform = transform
+ t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max))
+ self.tree = BatchedBrownianTree(x, t0, t1, seed)
+
+ def __call__(self, sigma, sigma_next):
+ t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next))
+ return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
+
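+# Usage sketch (illustrative, assumed shapes and values): the sampler returns noise whose variance
+# matches the Brownian increment between two sigma levels, rescaled to unit standard deviation:
+#
+#   >>> import torch
+#   >>> x = torch.zeros(1, 4, 8, 8)
+#   >>> sampler = BrownianTreeNoiseSampler(x, sigma_min=0.1, sigma_max=10.0, seed=0)
+#   >>> eps = sampler(torch.tensor(5.0), torch.tensor(4.0))   # same shape as x, roughly N(0, 1)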
+
+# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
+def betas_for_alpha_bar(
+ num_diffusion_timesteps,
+ max_beta=0.999,
+ alpha_transform_type="cosine",
+):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
+ Choose from `cosine` or `exp`
+
+ Returns:
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+ if alpha_transform_type == "cosine":
+
+ def alpha_bar_fn(t):
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ elif alpha_transform_type == "exp":
+
+ def alpha_bar_fn(t):
+ return math.exp(t * -12.0)
+
+ else:
+        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
+ return torch.tensor(betas, dtype=torch.float32)
+
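+# Illustrative check (standalone, toy size): for the cosine transform the betas grow towards the
+# end of the schedule and are capped at max_beta:
+#
+#   >>> betas = betas_for_alpha_bar(10)
+#   >>> betas.shape
+#   torch.Size([10])
+#   >>> bool(betas[0] < betas[-1])
+#   True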
+
+class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
+ """
+ Implements Stochastic Sampler (Algorithm 2) from Karras et al. (2022). Based on the original k-diffusion
+ implementation by Katherine Crowson:
+ https://github.com/crowsonkb/k-diffusion/blob/41b4cb6df0506694a7776af31349acf082bf6091/k_diffusion/sampling.py#L543
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
+
+ Args:
+        num_train_timesteps (`int`): number of diffusion steps used to train the model.
+        beta_start (`float`): the starting `beta` value of inference.
+        beta_end (`float`): the final `beta` value.
+        beta_schedule (`str`):
+            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+            `linear` or `scaled_linear`.
+ trained_betas (`np.ndarray`, optional):
+ option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
+ prediction_type (`str`, default `epsilon`, optional):
+ prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
+            process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4
+ https://imagen.research.google/video/paper.pdf)
+ use_karras_sigmas (`bool`, *optional*, defaults to `False`):
+ This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the
+ noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence
+ of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf.
+ noise_sampler_seed (`int`, *optional*, defaults to `None`):
+ The random seed to use for the noise sampler. If `None`, a random seed will be generated.
+ timestep_spacing (`str`, default `"linspace"`):
+            The way the timesteps should be scaled. Refer to Table 2 of [Common Diffusion Noise Schedules and Sample
+ Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
+ steps_offset (`int`, default `0`):
+ an offset added to the inference steps. You can use a combination of `offset=1` and
+ `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
+ stable diffusion.
+ """
+
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
+ order = 2
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.00085, # sensible defaults
+ beta_end: float = 0.012,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+ prediction_type: str = "epsilon",
+ use_karras_sigmas: Optional[bool] = False,
+ noise_sampler_seed: Optional[int] = None,
+ timestep_spacing: str = "linspace",
+ steps_offset: int = 0,
+ ):
+ if trained_betas is not None:
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
+ elif beta_schedule == "linear":
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ )
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ else:
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+
+ # set all values
+ self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
+ self.use_karras_sigmas = use_karras_sigmas
+ self.noise_sampler = None
+ self.noise_sampler_seed = noise_sampler_seed
+
+ # Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
+ if schedule_timesteps is None:
+ schedule_timesteps = self.timesteps
+
+ indices = (schedule_timesteps == timestep).nonzero()
+
+ # The sigma index that is taken for the **very** first `step`
+ # is always the second index (or the last index if there is only 1)
+ # This way we can ensure we don't accidentally skip a sigma in
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
+ if len(self._index_counter) == 0:
+ pos = 1 if len(indices) > 1 else 0
+ else:
+ timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
+ pos = self._index_counter[timestep_int]
+
+ return indices[pos].item()
+
+ @property
+ def init_noise_sigma(self):
+ # standard deviation of the initial noise distribution
+ if self.config.timestep_spacing in ["linspace", "trailing"]:
+ return self.sigmas.max()
+
+ return (self.sigmas.max() ** 2 + 1) ** 0.5
+
+ def scale_model_input(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[float, torch.FloatTensor],
+ ) -> torch.FloatTensor:
+        """
+        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+        current timestep.
+
+        Args:
+            sample (`torch.FloatTensor`): input sample
+            timestep (`int`, optional): current timestep
+
+        Returns:
+            `torch.FloatTensor`: scaled input sample
+        """
+ step_index = self.index_for_timestep(timestep)
+
+ sigma = self.sigmas[step_index]
+ sigma_input = sigma if self.state_in_first_order else self.mid_point_sigma
+ sample = sample / ((sigma_input**2 + 1) ** 0.5)
+ return sample
+
+ def set_timesteps(
+ self,
+ num_inference_steps: int,
+ device: Union[str, torch.device] = None,
+ num_train_timesteps: Optional[int] = None,
+ ):
+ """
+ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ device (`str` or `torch.device`, optional):
+                the device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+            num_train_timesteps (`int`, optional):
+                the number of diffusion steps used to train the model. If `None`, defaults to
+                `self.config.num_train_timesteps`.
+        """
+ self.num_inference_steps = num_inference_steps
+
+ num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps
+
+        # "linspace", "leading", "trailing" correspond to the annotations in Table 2 of https://arxiv.org/abs/2305.08891
+ if self.config.timestep_spacing == "linspace":
+ timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
+ elif self.config.timestep_spacing == "leading":
+ step_ratio = num_train_timesteps // self.num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
+ timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
+ timesteps += self.config.steps_offset
+ elif self.config.timestep_spacing == "trailing":
+ step_ratio = num_train_timesteps / self.num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
+ timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(float)
+ timesteps -= 1
+ else:
+ raise ValueError(
+ f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
+ )
+
+ sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+ log_sigmas = np.log(sigmas)
+ sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
+
+ if self.use_karras_sigmas:
+ sigmas = self._convert_to_karras(in_sigmas=sigmas)
+ timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+
+ second_order_timesteps = self._second_order_timesteps(sigmas, log_sigmas)
+
+ sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
+ sigmas = torch.from_numpy(sigmas).to(device=device)
+ self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]])
+
+ timesteps = torch.from_numpy(timesteps)
+ second_order_timesteps = torch.from_numpy(second_order_timesteps)
+ timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)])
+ timesteps[1::2] = second_order_timesteps
+
+ if str(device).startswith("mps"):
+ # mps does not support float64
+ self.timesteps = timesteps.to(device, dtype=torch.float32)
+ else:
+ self.timesteps = timesteps.to(device=device)
+
+ # empty first order variables
+ self.sample = None
+ self.mid_point_sigma = None
+
+ # for exp beta schedules, such as the one for `pipeline_shap_e.py`
+ # we need an index counter
+ self._index_counter = defaultdict(int)
+
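+    # Worked note (illustrative, toy schedule): because the sampler is second order, every interior
+    # sigma is visited twice, once for the first-order proposal and once for the midpoint correction.
+    # For example, sigmas [3, 2, 1, 0] become [3, 2, 2, 1, 1, 0] after the repeat_interleave above,
+    # and the odd positions of `timesteps` are filled with the midpoint timesteps.
+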
+ def _second_order_timesteps(self, sigmas, log_sigmas):
+ def sigma_fn(_t):
+ return np.exp(-_t)
+
+ def t_fn(_sigma):
+ return -np.log(_sigma)
+
+ midpoint_ratio = 0.5
+ t = t_fn(sigmas)
+ delta_time = np.diff(t)
+ t_proposed = t[:-1] + delta_time * midpoint_ratio
+ sig_proposed = sigma_fn(t_proposed)
+ timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sig_proposed])
+ return timesteps
+
+ # copied from diffusers.schedulers.scheduling_euler_discrete._sigma_to_t
+ def _sigma_to_t(self, sigma, log_sigmas):
+ # get log sigma
+ log_sigma = np.log(sigma)
+
+ # get distribution
+ dists = log_sigma - log_sigmas[:, np.newaxis]
+
+ # get sigmas range
+ low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
+ high_idx = low_idx + 1
+
+ low = log_sigmas[low_idx]
+ high = log_sigmas[high_idx]
+
+ # interpolate sigmas
+ w = (low - log_sigma) / (low - high)
+ w = np.clip(w, 0, 1)
+
+ # transform interpolation to time range
+ t = (1 - w) * low_idx + w * high_idx
+ t = t.reshape(sigma.shape)
+ return t
+
+ # copied from diffusers.schedulers.scheduling_euler_discrete._convert_to_karras
+ def _convert_to_karras(self, in_sigmas: torch.FloatTensor) -> torch.FloatTensor:
+ """Constructs the noise schedule of Karras et al. (2022)."""
+
+ sigma_min: float = in_sigmas[-1].item()
+ sigma_max: float = in_sigmas[0].item()
+
+ rho = 7.0 # 7.0 is the value used in the paper
+ ramp = np.linspace(0, 1, self.num_inference_steps)
+ min_inv_rho = sigma_min ** (1 / rho)
+ max_inv_rho = sigma_max ** (1 / rho)
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+ return sigmas
+
+ @property
+ def state_in_first_order(self):
+ return self.sample is None
+
+ def step(
+ self,
+ model_output: Union[torch.FloatTensor, np.ndarray],
+ timestep: Union[float, torch.FloatTensor],
+ sample: Union[torch.FloatTensor, np.ndarray],
+ return_dict: bool = True,
+ s_noise: float = 1.0,
+ ) -> Union[SchedulerOutput, Tuple]:
+        """
+        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+        process from the learned model outputs (most often the predicted noise).
+
+        Args:
+            model_output (Union[torch.FloatTensor, np.ndarray]): Direct output from learned diffusion model.
+            timestep (Union[float, torch.FloatTensor]): Current discrete timestep in the diffusion chain.
+            sample (Union[torch.FloatTensor, np.ndarray]): Current instance of sample being created by diffusion process.
+            return_dict (bool, optional): Option for returning tuple rather than SchedulerOutput class. Defaults to True.
+            s_noise (float, optional): Scaling factor for the noise added to the sample. Defaults to 1.0.
+
+        Returns:
+ [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
+ [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
+ """
+ step_index = self.index_for_timestep(timestep)
+
+ # advance index counter by 1
+ timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
+ self._index_counter[timestep_int] += 1
+
+ # Create a noise sampler if it hasn't been created yet
+ if self.noise_sampler is None:
+ min_sigma, max_sigma = self.sigmas[self.sigmas > 0].min(), self.sigmas.max()
+ self.noise_sampler = BrownianTreeNoiseSampler(sample, min_sigma, max_sigma, self.noise_sampler_seed)
+
+ # Define functions to compute sigma and t from each other
+ def sigma_fn(_t: torch.FloatTensor) -> torch.FloatTensor:
+ return _t.neg().exp()
+
+ def t_fn(_sigma: torch.FloatTensor) -> torch.FloatTensor:
+ return _sigma.log().neg()
+
+ if self.state_in_first_order:
+ sigma = self.sigmas[step_index]
+ sigma_next = self.sigmas[step_index + 1]
+ else:
+ # 2nd order
+ sigma = self.sigmas[step_index - 1]
+ sigma_next = self.sigmas[step_index]
+
+ # Set the midpoint and step size for the current step
+ midpoint_ratio = 0.5
+ t, t_next = t_fn(sigma), t_fn(sigma_next)
+ delta_time = t_next - t
+ t_proposed = t + delta_time * midpoint_ratio
+
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
+ if self.config.prediction_type == "epsilon":
+ sigma_input = sigma if self.state_in_first_order else sigma_fn(t_proposed)
+ pred_original_sample = sample - sigma_input * model_output
+ elif self.config.prediction_type == "v_prediction":
+ sigma_input = sigma if self.state_in_first_order else sigma_fn(t_proposed)
+ pred_original_sample = model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5) + (
+ sample / (sigma_input**2 + 1)
+ )
+ elif self.config.prediction_type == "sample":
+ raise NotImplementedError("prediction_type not implemented yet: sample")
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
+ )
+
+ if sigma_next == 0:
+ derivative = (sample - pred_original_sample) / sigma
+ dt = sigma_next - sigma
+ prev_sample = sample + derivative * dt
+ else:
+ if self.state_in_first_order:
+ t_next = t_proposed
+ else:
+ sample = self.sample
+
+ sigma_from = sigma_fn(t)
+ sigma_to = sigma_fn(t_next)
+ sigma_up = min(sigma_to, (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5)
+ sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
+ ancestral_t = t_fn(sigma_down)
+ prev_sample = (sigma_fn(ancestral_t) / sigma_fn(t)) * sample - (
+ t - ancestral_t
+ ).expm1() * pred_original_sample
+ prev_sample = prev_sample + self.noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * sigma_up
+
+ if self.state_in_first_order:
+ # store for 2nd order step
+ self.sample = sample
+ self.mid_point_sigma = sigma_fn(t_next)
+ else:
+ # free for "first order mode"
+ self.sample = None
+ self.mid_point_sigma = None
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return SchedulerOutput(prev_sample=prev_sample)
+
+ # Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.add_noise
+ def add_noise(
+ self,
+ original_samples: torch.FloatTensor,
+ noise: torch.FloatTensor,
+ timesteps: torch.FloatTensor,
+ ) -> torch.FloatTensor:
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
+ sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
+ if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
+ # mps does not support float64
+ schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
+ timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
+ else:
+ schedule_timesteps = self.timesteps.to(original_samples.device)
+ timesteps = timesteps.to(original_samples.device)
+
+ step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < len(original_samples.shape):
+ sigma = sigma.unsqueeze(-1)
+
+ noisy_samples = original_samples + noise * sigma
+ return noisy_samples
+
+ def __len__(self):
+ return self.config.num_train_timesteps
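+
+
+# Usage sketch (illustrative, hypothetical wiring): the scheduler is driven like other diffusers
+# schedulers, except that `step` is called twice per sigma pair (first-order proposal, then
+# midpoint correction) because `order == 2`:
+#
+#   >>> scheduler = DPMSolverSDEScheduler(use_karras_sigmas=True)
+#   >>> scheduler.set_timesteps(num_inference_steps=20, device="cpu")
+#   >>> # for t in scheduler.timesteps:
+#   >>> #     model_input = scheduler.scale_model_input(latents, t)
+#   >>> #     noise_pred = denoiser(model_input, t)          # hypothetical model call
+#   >>> #     latents = scheduler.step(noise_pred, t, latents).prev_sample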
diff --git a/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
new file mode 100644
index 0000000000000000000000000000000000000000..93975a27fc6e3899c009b5576ed74753ea62abbb
--- /dev/null
+++ b/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
@@ -0,0 +1,737 @@
+# Copyright 2023 TSAIL Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver
+
+import math
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import logging
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
+def betas_for_alpha_bar(
+ num_diffusion_timesteps,
+ max_beta=0.999,
+ alpha_transform_type="cosine",
+):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
+ Choose from `cosine` or `exp`
+
+ Returns:
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+ if alpha_transform_type == "cosine":
+
+ def alpha_bar_fn(t):
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ elif alpha_transform_type == "exp":
+
+ def alpha_bar_fn(t):
+ return math.exp(t * -12.0)
+
+ else:
+        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
+ return torch.tensor(betas, dtype=torch.float32)
+
+
+class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
+ """
+ DPM-Solver (and the improved version DPM-Solver++) is a fast dedicated high-order solver for diffusion ODEs with
+ the convergence order guarantee. Empirically, sampling by DPM-Solver with only 20 steps can generate high-quality
+ samples, and it can generate quite good samples even in only 10 steps.
+
+ For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095
+
+ Currently, we support the singlestep DPM-Solver for both noise prediction models and data prediction models. We
+ recommend to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling.
+
+ We also support the "dynamic thresholding" method in Imagen (https://arxiv.org/abs/2205.11487). For pixel-space
+ diffusion models, you can set both `algorithm_type="dpmsolver++"` and `thresholding=True` to use the dynamic
+ thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as
+ stable-diffusion).
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
+ beta_start (`float`): the starting `beta` value of inference.
+ beta_end (`float`): the final `beta` value.
+ beta_schedule (`str`):
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+ trained_betas (`np.ndarray`, optional):
+ option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
+ solver_order (`int`, default `2`):
+ the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided
+ sampling, and `solver_order=3` for unconditional sampling.
+ prediction_type (`str`, default `epsilon`):
+            indicates whether the model predicts the noise (epsilon), or the data / `x0`. One of `epsilon`, `sample`,
+            or `v_prediction`.
+ thresholding (`bool`, default `False`):
+ whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
+ For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to
+ use the dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion
+ models (such as stable-diffusion).
+ dynamic_thresholding_ratio (`float`, default `0.995`):
+ the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
+ (https://arxiv.org/abs/2205.11487).
+ sample_max_value (`float`, default `1.0`):
+            the threshold value for dynamic thresholding. Valid only when `thresholding=True` and
+            `algorithm_type="dpmsolver++"`.
+ algorithm_type (`str`, default `dpmsolver++`):
+ the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++`. The `dpmsolver` type implements the
+ algorithms in https://arxiv.org/abs/2206.00927, and the `dpmsolver++` type implements the algorithms in
+ https://arxiv.org/abs/2211.01095. We recommend to use `dpmsolver++` with `solver_order=2` for guided
+ sampling (e.g. stable-diffusion).
+ solver_type (`str`, default `midpoint`):
+ the solver type for the second-order solver. Either `midpoint` or `heun`. The solver type slightly affects
+ the sample quality, especially for small number of steps. We empirically find that `midpoint` solvers are
+ slightly better, so we recommend to use the `midpoint` type.
+ lower_order_final (`bool`, default `True`):
+ whether to use lower-order solvers in the final steps. For singlestep schedulers, we recommend to enable
+ this to use up all the function evaluations.
+ use_karras_sigmas (`bool`, *optional*, defaults to `False`):
+ This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the
+ noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence
+ of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf.
+ lambda_min_clipped (`float`, default `-inf`):
+ the clipping threshold for the minimum value of lambda(t) for numerical stability. This is critical for
+ cosine (squaredcos_cap_v2) noise schedule.
+ variance_type (`str`, *optional*):
+            Set to "learned" or "learned_range" for diffusion models that predict variance. For example, OpenAI's
+            guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the
+            Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based
+            on diffusion ODEs.
+
+ """
+
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[np.ndarray] = None,
+ solver_order: int = 2,
+ prediction_type: str = "epsilon",
+ thresholding: bool = False,
+ dynamic_thresholding_ratio: float = 0.995,
+ sample_max_value: float = 1.0,
+ algorithm_type: str = "dpmsolver++",
+ solver_type: str = "midpoint",
+ lower_order_final: bool = True,
+ use_karras_sigmas: Optional[bool] = False,
+ lambda_min_clipped: float = -float("inf"),
+ variance_type: Optional[str] = None,
+ ):
+ if trained_betas is not None:
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
+ elif beta_schedule == "linear":
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ )
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ else:
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+ # Currently we only support VP-type noise schedule
+ self.alpha_t = torch.sqrt(self.alphas_cumprod)
+ self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
+ self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
+
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = 1.0
+
+ # settings for DPM-Solver
+ if algorithm_type not in ["dpmsolver", "dpmsolver++"]:
+ if algorithm_type == "deis":
+ self.register_to_config(algorithm_type="dpmsolver++")
+ else:
+                raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")
+ if solver_type not in ["midpoint", "heun"]:
+ if solver_type in ["logrho", "bh1", "bh2"]:
+ self.register_to_config(solver_type="midpoint")
+ else:
+                raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
+
+ # setable values
+ self.num_inference_steps = None
+ timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
+ self.timesteps = torch.from_numpy(timesteps)
+ self.model_outputs = [None] * solver_order
+ self.sample = None
+ self.order_list = self.get_order_list(num_train_timesteps)
+
+ def get_order_list(self, num_inference_steps: int) -> List[int]:
+ """
+ Computes the solver order at each time step.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ """
+ steps = num_inference_steps
+ order = self.config.solver_order
+ if self.config.lower_order_final:
+ if order == 3:
+ if steps % 3 == 0:
+ orders = [1, 2, 3] * (steps // 3 - 1) + [1, 2] + [1]
+ elif steps % 3 == 1:
+ orders = [1, 2, 3] * (steps // 3) + [1]
+ else:
+ orders = [1, 2, 3] * (steps // 3) + [1, 2]
+ elif order == 2:
+ if steps % 2 == 0:
+ orders = [1, 2] * (steps // 2)
+ else:
+ orders = [1, 2] * (steps // 2) + [1]
+ elif order == 1:
+ orders = [1] * steps
+ else:
+ if order == 3:
+ orders = [1, 2, 3] * (steps // 3)
+ elif order == 2:
+ orders = [1, 2] * (steps // 2)
+ elif order == 1:
+ orders = [1] * steps
+ return orders
+
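+    # Illustrative sketch (assumed config): with `solver_order=2`, `lower_order_final=True` and
+    # 5 inference steps, the orders alternate and fall back to first order at the end so every
+    # model evaluation is used: get_order_list(5) -> [1, 2, 1, 2, 1].
+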
+ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+ """
+ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ device (`str` or `torch.device`, optional):
+                the device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+ """
+ self.num_inference_steps = num_inference_steps
+ # Clipping the minimum of all lambda(t) for numerical stability.
+ # This is critical for cosine (squaredcos_cap_v2) noise schedule.
+ clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped)
+ timesteps = (
+ np.linspace(0, self.config.num_train_timesteps - 1 - clipped_idx, num_inference_steps + 1)
+ .round()[::-1][:-1]
+ .copy()
+ .astype(np.int64)
+ )
+
+ sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+ if self.config.use_karras_sigmas:
+ log_sigmas = np.log(sigmas)
+ sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+ timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
+ timesteps = np.flip(timesteps).copy().astype(np.int64)
+
+ self.sigmas = torch.from_numpy(sigmas)
+
+ self.timesteps = torch.from_numpy(timesteps).to(device)
+ self.model_outputs = [None] * self.config.solver_order
+ self.sample = None
+
+ if not self.config.lower_order_final and num_inference_steps % self.config.solver_order != 0:
+            logger.warning(
+                f"Changing scheduler {self.config} to have `lower_order_final` set to True to handle a number of "
+                "inference steps that is not divisible by `solver_order`. Please make sure to use a value of "
+                "`num_inference_steps` that is divisible by `solver_order` when setting `lower_order_final=False`."
+            )
+ self.register_to_config(lower_order_final=True)
+
+ self.order_list = self.get_order_list(num_inference_steps)
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
+ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
+ """
+ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
+ prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
+ s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
+ pixels from saturation at each step. We find that dynamic thresholding results in significantly better
+ photorealism as well as better image-text alignment, especially when using very large guidance weights."
+
+ https://arxiv.org/abs/2205.11487
+ """
+ dtype = sample.dtype
+ batch_size, channels, height, width = sample.shape
+
+ if dtype not in (torch.float32, torch.float64):
+ sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
+
+ # Flatten sample for doing quantile calculation along each image
+ sample = sample.reshape(batch_size, channels * height * width)
+
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
+
+ s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
+ s = torch.clamp(
+ s, min=1, max=self.config.sample_max_value
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
+
+ s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
+ sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
+
+ sample = sample.reshape(batch_size, channels, height, width)
+ sample = sample.to(dtype)
+
+ return sample
+
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
+ def _sigma_to_t(self, sigma, log_sigmas):
+ # get log sigma
+ log_sigma = np.log(sigma)
+
+ # get distribution
+ dists = log_sigma - log_sigmas[:, np.newaxis]
+
+ # get sigmas range
+ low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
+ high_idx = low_idx + 1
+
+ low = log_sigmas[low_idx]
+ high = log_sigmas[high_idx]
+
+ # interpolate sigmas
+ w = (low - log_sigma) / (low - high)
+ w = np.clip(w, 0, 1)
+
+ # transform interpolation to time range
+ t = (1 - w) * low_idx + w * high_idx
+ t = t.reshape(sigma.shape)
+ return t
+
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
+ def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
+ """Constructs the noise schedule of Karras et al. (2022)."""
+
+ sigma_min: float = in_sigmas[-1].item()
+ sigma_max: float = in_sigmas[0].item()
+
+ rho = 7.0 # 7.0 is the value used in the paper
+ ramp = np.linspace(0, 1, num_inference_steps)
+ min_inv_rho = sigma_min ** (1 / rho)
+ max_inv_rho = sigma_max ** (1 / rho)
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+ return sigmas
+
+ def convert_model_output(
+ self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
+ ) -> torch.FloatTensor:
+ """
+ Convert the model output to the corresponding type that the algorithm (DPM-Solver / DPM-Solver++) needs.
+
+ DPM-Solver is designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to
+ discretize an integral of the data prediction model. So we need to first convert the model output to the
+ corresponding type to match the algorithm.
+
+ Note that the algorithm type and the model type is decoupled. That is to say, we can use either DPM-Solver or
+ DPM-Solver++ for both noise prediction model and data prediction model.
+
+ Args:
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+
+ Returns:
+ `torch.FloatTensor`: the converted model output.
+ """
+ # DPM-Solver++ needs to solve an integral of the data prediction model.
+ if self.config.algorithm_type == "dpmsolver++":
+ if self.config.prediction_type == "epsilon":
+ # DPM-Solver and DPM-Solver++ only need the "mean" output.
+ if self.config.variance_type in ["learned_range"]:
+ model_output = model_output[:, :3]
+ alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
+ x0_pred = (sample - sigma_t * model_output) / alpha_t
+ elif self.config.prediction_type == "sample":
+ x0_pred = model_output
+ elif self.config.prediction_type == "v_prediction":
+ alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
+ x0_pred = alpha_t * sample - sigma_t * model_output
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
+ " `v_prediction` for the DPMSolverSinglestepScheduler."
+ )
+
+ if self.config.thresholding:
+ x0_pred = self._threshold_sample(x0_pred)
+
+ return x0_pred
+ # DPM-Solver needs to solve an integral of the noise prediction model.
+ elif self.config.algorithm_type == "dpmsolver":
+ if self.config.prediction_type == "epsilon":
+ # DPM-Solver and DPM-Solver++ only need the "mean" output.
+ if self.config.variance_type in ["learned_range"]:
+ model_output = model_output[:, :3]
+ return model_output
+ elif self.config.prediction_type == "sample":
+ alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
+ epsilon = (sample - alpha_t * model_output) / sigma_t
+ return epsilon
+ elif self.config.prediction_type == "v_prediction":
+ alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
+ epsilon = alpha_t * model_output + sigma_t * sample
+ return epsilon
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
+ " `v_prediction` for the DPMSolverSinglestepScheduler."
+ )
+
+ def dpm_solver_first_order_update(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: int,
+ prev_timestep: int,
+ sample: torch.FloatTensor,
+ ) -> torch.FloatTensor:
+ """
+ One step for the first-order DPM-Solver (equivalent to DDIM).
+
+ See https://arxiv.org/abs/2206.00927 for the detailed derivation.
+
+ Args:
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ prev_timestep (`int`): previous discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+
+ Returns:
+ `torch.FloatTensor`: the sample tensor at the previous timestep.
+ """
+ lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep]
+ alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep]
+ sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep]
+ h = lambda_t - lambda_s
+ if self.config.algorithm_type == "dpmsolver++":
+ x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
+ elif self.config.algorithm_type == "dpmsolver":
+ x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
+ return x_t
+
+ def singlestep_dpm_solver_second_order_update(
+ self,
+ model_output_list: List[torch.FloatTensor],
+ timestep_list: List[int],
+ prev_timestep: int,
+ sample: torch.FloatTensor,
+ ) -> torch.FloatTensor:
+ """
+ One step for the second-order singlestep DPM-Solver.
+
+ It computes the solution at time `prev_timestep` from the time `timestep_list[-2]`.
+
+ Args:
+ model_output_list (`List[torch.FloatTensor]`):
+ direct outputs from learned diffusion model at current and latter timesteps.
+            timestep_list (`List[int]`): current and latter discrete timesteps in the diffusion chain.
+ prev_timestep (`int`): previous discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+
+ Returns:
+ `torch.FloatTensor`: the sample tensor at the previous timestep.
+ """
+ t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
+ m0, m1 = model_output_list[-1], model_output_list[-2]
+ lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1]
+ alpha_t, alpha_s1 = self.alpha_t[t], self.alpha_t[s1]
+ sigma_t, sigma_s1 = self.sigma_t[t], self.sigma_t[s1]
+ h, h_0 = lambda_t - lambda_s1, lambda_s0 - lambda_s1
+ r0 = h_0 / h
+ D0, D1 = m1, (1.0 / r0) * (m0 - m1)
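+        # D0 is the model output at the older timestep s1; D1 is the first-order finite difference of the
+        # two cached outputs in lambda-space, rescaled by r0 = h_0 / h to the full step size h.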
+ if self.config.algorithm_type == "dpmsolver++":
+ # See https://arxiv.org/abs/2211.01095 for detailed derivations
+ if self.config.solver_type == "midpoint":
+ x_t = (
+ (sigma_t / sigma_s1) * sample
+ - (alpha_t * (torch.exp(-h) - 1.0)) * D0
+ - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1
+ )
+ elif self.config.solver_type == "heun":
+ x_t = (
+ (sigma_t / sigma_s1) * sample
+ - (alpha_t * (torch.exp(-h) - 1.0)) * D0
+ + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
+ )
+ elif self.config.algorithm_type == "dpmsolver":
+ # See https://arxiv.org/abs/2206.00927 for detailed derivations
+ if self.config.solver_type == "midpoint":
+ x_t = (
+ (alpha_t / alpha_s1) * sample
+ - (sigma_t * (torch.exp(h) - 1.0)) * D0
+ - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1
+ )
+ elif self.config.solver_type == "heun":
+ x_t = (
+ (alpha_t / alpha_s1) * sample
+ - (sigma_t * (torch.exp(h) - 1.0)) * D0
+ - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
+ )
+ return x_t
+
+ def singlestep_dpm_solver_third_order_update(
+ self,
+ model_output_list: List[torch.FloatTensor],
+ timestep_list: List[int],
+ prev_timestep: int,
+ sample: torch.FloatTensor,
+ ) -> torch.FloatTensor:
+ """
+ One step for the third-order singlestep DPM-Solver.
+
+ It computes the solution at time `prev_timestep` from the time `timestep_list[-3]`.
+
+ Args:
+ model_output_list (`List[torch.FloatTensor]`):
+ direct outputs from learned diffusion model at current and latter timesteps.
+            timestep_list (`List[int]`): current and latter discrete timesteps in the diffusion chain.
+ prev_timestep (`int`): previous discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+
+ Returns:
+ `torch.FloatTensor`: the sample tensor at the previous timestep.
+ """
+ t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3]
+ m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
+ lambda_t, lambda_s0, lambda_s1, lambda_s2 = (
+ self.lambda_t[t],
+ self.lambda_t[s0],
+ self.lambda_t[s1],
+ self.lambda_t[s2],
+ )
+ alpha_t, alpha_s2 = self.alpha_t[t], self.alpha_t[s2]
+ sigma_t, sigma_s2 = self.sigma_t[t], self.sigma_t[s2]
+ h, h_0, h_1 = lambda_t - lambda_s2, lambda_s0 - lambda_s2, lambda_s1 - lambda_s2
+ r0, r1 = h_0 / h, h_1 / h
+ D0 = m2
+ D1_0, D1_1 = (1.0 / r1) * (m1 - m2), (1.0 / r0) * (m0 - m2)
+ D1 = (r0 * D1_0 - r1 * D1_1) / (r0 - r1)
+ D2 = 2.0 * (D1_1 - D1_0) / (r0 - r1)
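+        # D1 and D2 combine the three cached outputs into first- and second-order finite-difference
+        # estimates (in lambda-space) that feed the third-order correction terms below.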
+ if self.config.algorithm_type == "dpmsolver++":
+ # See https://arxiv.org/abs/2206.00927 for detailed derivations
+ if self.config.solver_type == "midpoint":
+ x_t = (
+ (sigma_t / sigma_s2) * sample
+ - (alpha_t * (torch.exp(-h) - 1.0)) * D0
+ + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1_1
+ )
+ elif self.config.solver_type == "heun":
+ x_t = (
+ (sigma_t / sigma_s2) * sample
+ - (alpha_t * (torch.exp(-h) - 1.0)) * D0
+ + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
+ - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
+ )
+ elif self.config.algorithm_type == "dpmsolver":
+ # See https://arxiv.org/abs/2206.00927 for detailed derivations
+ if self.config.solver_type == "midpoint":
+ x_t = (
+ (alpha_t / alpha_s2) * sample
+ - (sigma_t * (torch.exp(h) - 1.0)) * D0
+ - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1_1
+ )
+ elif self.config.solver_type == "heun":
+ x_t = (
+ (alpha_t / alpha_s2) * sample
+ - (sigma_t * (torch.exp(h) - 1.0)) * D0
+ - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
+ - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
+ )
+ return x_t
+
+ def singlestep_dpm_solver_update(
+ self,
+ model_output_list: List[torch.FloatTensor],
+ timestep_list: List[int],
+ prev_timestep: int,
+ sample: torch.FloatTensor,
+ order: int,
+ ) -> torch.FloatTensor:
+ """
+ One step for the singlestep DPM-Solver.
+
+ Args:
+ model_output_list (`List[torch.FloatTensor]`):
+ direct outputs from learned diffusion model at current and latter timesteps.
+            timestep_list (`List[int]`): current and latter discrete timesteps in the diffusion chain.
+ prev_timestep (`int`): previous discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+ order (`int`):
+ the solver order at this step.
+
+ Returns:
+ `torch.FloatTensor`: the sample tensor at the previous timestep.
+ """
+ if order == 1:
+ return self.dpm_solver_first_order_update(model_output_list[-1], timestep_list[-1], prev_timestep, sample)
+ elif order == 2:
+ return self.singlestep_dpm_solver_second_order_update(
+ model_output_list, timestep_list, prev_timestep, sample
+ )
+ elif order == 3:
+ return self.singlestep_dpm_solver_third_order_update(
+ model_output_list, timestep_list, prev_timestep, sample
+ )
+ else:
+ raise ValueError(f"Order must be 1, 2, 3, got {order}")
+
+ def step(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: int,
+ sample: torch.FloatTensor,
+ return_dict: bool = True,
+ ) -> Union[SchedulerOutput, Tuple]:
+ """
+ Step function propagating the sample with the singlestep DPM-Solver.
+
+ Args:
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+ return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+ Returns:
+ [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is
+ True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
+
+ """
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ if isinstance(timestep, torch.Tensor):
+ timestep = timestep.to(self.timesteps.device)
+ step_index = (self.timesteps == timestep).nonzero()
+ if len(step_index) == 0:
+ step_index = len(self.timesteps) - 1
+ else:
+ step_index = step_index.item()
+ prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]
+
+ model_output = self.convert_model_output(model_output, timestep, sample)
+ for i in range(self.config.solver_order - 1):
+ self.model_outputs[i] = self.model_outputs[i + 1]
+ self.model_outputs[-1] = model_output
+
+ order = self.order_list[step_index]
+
+        # For img2img, denoising might start at an order > 1, which is not possible because the required
+        # earlier model outputs are missing; fall back to the highest order for which outputs are cached.
+ while self.model_outputs[-order] is None:
+ order -= 1
+
+ # For single-step solvers, we use the initial value at each time with order = 1.
+ if order == 1:
+ self.sample = sample
+
+ timestep_list = [self.timesteps[step_index - i] for i in range(order - 1, 0, -1)] + [timestep]
+ prev_sample = self.singlestep_dpm_solver_update(
+ self.model_outputs, timestep_list, prev_timestep, self.sample, order
+ )
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return SchedulerOutput(prev_sample=prev_sample)
+
+ def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+
+ Args:
+ sample (`torch.FloatTensor`): input sample
+
+ Returns:
+ `torch.FloatTensor`: scaled input sample
+ """
+ return sample
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
+ def add_noise(
+ self,
+ original_samples: torch.FloatTensor,
+ noise: torch.FloatTensor,
+ timesteps: torch.IntTensor,
+ ) -> torch.FloatTensor:
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
+ alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+ timesteps = timesteps.to(original_samples.device)
+
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+ return noisy_samples
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/diffusers/schedulers/scheduling_euler_ancestral_discrete.py
new file mode 100644
index 0000000000000000000000000000000000000000..065f657032e6ef21bd022f938a3b1e7ada334436
--- /dev/null
+++ b/diffusers/schedulers/scheduling_euler_ancestral_discrete.py
@@ -0,0 +1,358 @@
+# Copyright 2023 Katherine Crowson and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import BaseOutput, logging, randn_tensor
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerAncestralDiscrete
+class EulerAncestralDiscreteSchedulerOutput(BaseOutput):
+ """
+ Output class for the scheduler's step function output.
+
+ Args:
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ The predicted denoised sample (x_{0}) based on the model output from the current timestep.
+ `pred_original_sample` can be used to preview progress or for guidance.
+ """
+
+ prev_sample: torch.FloatTensor
+ pred_original_sample: Optional[torch.FloatTensor] = None
+
+
+# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
+def betas_for_alpha_bar(
+ num_diffusion_timesteps,
+ max_beta=0.999,
+ alpha_transform_type="cosine",
+):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
+ Choose from `cosine` or `exp`
+
+ Returns:
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+ if alpha_transform_type == "cosine":
+
+ def alpha_bar_fn(t):
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ elif alpha_transform_type == "exp":
+
+ def alpha_bar_fn(t):
+ return math.exp(t * -12.0)
+
+ else:
+        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
+ return torch.tensor(betas, dtype=torch.float32)
+
+
+class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
+ """
+ Ancestral sampling with Euler method steps. Based on the original k-diffusion implementation by Katherine Crowson:
+ https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
+ beta_start (`float`): the starting `beta` value of inference.
+ beta_end (`float`): the final `beta` value.
+ beta_schedule (`str`):
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear` or `scaled_linear`.
+ trained_betas (`np.ndarray`, optional):
+ option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
+ prediction_type (`str`, default `epsilon`, optional):
+ prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
+            process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4
+ https://imagen.research.google/video/paper.pdf)
+ timestep_spacing (`str`, default `"linspace"`):
+ The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
+ Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
+ steps_offset (`int`, default `0`):
+ an offset added to the inference steps. You can use a combination of `offset=1` and
+ `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
+ stable diffusion.
+ """
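+    # A minimal usage sketch (names such as `unet`, `shape` and the step count are placeholders, not part
+    # of this file):
+    #
+    #     scheduler = EulerAncestralDiscreteScheduler(beta_schedule="scaled_linear")
+    #     scheduler.set_timesteps(num_inference_steps=30, device="cuda")
+    #     sample = torch.randn(shape, device="cuda") * scheduler.init_noise_sigma
+    #     for t in scheduler.timesteps:
+    #         model_input = scheduler.scale_model_input(sample, t)
+    #         noise_pred = unet(model_input, t).sample
+    #         sample = scheduler.step(noise_pred, t, sample).prev_sample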
+
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+ prediction_type: str = "epsilon",
+ timestep_spacing: str = "linspace",
+ steps_offset: int = 0,
+ ):
+ if trained_betas is not None:
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
+ elif beta_schedule == "linear":
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ )
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ else:
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+
+ sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+ sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
+ self.sigmas = torch.from_numpy(sigmas)
+
+ # setable values
+ self.num_inference_steps = None
+ timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
+ self.timesteps = torch.from_numpy(timesteps)
+ self.is_scale_input_called = False
+
+ @property
+ def init_noise_sigma(self):
+ # standard deviation of the initial noise distribution
+ if self.config.timestep_spacing in ["linspace", "trailing"]:
+ return self.sigmas.max()
+
+ return (self.sigmas.max() ** 2 + 1) ** 0.5
+
+ def scale_model_input(
+ self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
+ ) -> torch.FloatTensor:
+ """
+ Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
+
+ Args:
+ sample (`torch.FloatTensor`): input sample
+ timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain
+
+ Returns:
+ `torch.FloatTensor`: scaled input sample
+ """
+ if isinstance(timestep, torch.Tensor):
+ timestep = timestep.to(self.timesteps.device)
+ step_index = (self.timesteps == timestep).nonzero().item()
+ sigma = self.sigmas[step_index]
+ sample = sample / ((sigma**2 + 1) ** 0.5)
+ self.is_scale_input_called = True
+ return sample
+
+ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+ """
+ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ device (`str` or `torch.device`, optional):
+                the device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+ """
+ self.num_inference_steps = num_inference_steps
+
+ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+ if self.config.timestep_spacing == "linspace":
+ timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[
+ ::-1
+ ].copy()
+ elif self.config.timestep_spacing == "leading":
+ step_ratio = self.config.num_train_timesteps // self.num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
+ timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
+ timesteps += self.config.steps_offset
+ elif self.config.timestep_spacing == "trailing":
+ step_ratio = self.config.num_train_timesteps / self.num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
+ timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(float)
+ timesteps -= 1
+ else:
+ raise ValueError(
+ f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
+ )
+
+ sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+ sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
+ sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
+ self.sigmas = torch.from_numpy(sigmas).to(device=device)
+ if str(device).startswith("mps"):
+ # mps does not support float64
+ self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
+ else:
+ self.timesteps = torch.from_numpy(timesteps).to(device=device)
+
+ def step(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: Union[float, torch.FloatTensor],
+ sample: torch.FloatTensor,
+ generator: Optional[torch.Generator] = None,
+ return_dict: bool = True,
+ ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+ timestep (`float`): current timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+ generator (`torch.Generator`, optional): Random number generator.
+ return_dict (`bool`): option for returning tuple rather than EulerAncestralDiscreteSchedulerOutput class
+
+ Returns:
+ [`~schedulers.scheduling_utils.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
+ [`~schedulers.scheduling_utils.EulerAncestralDiscreteSchedulerOutput`] if `return_dict` is True, otherwise
+ a `tuple`. When returning a tuple, the first element is the sample tensor.
+
+ """
+
+ if (
+ isinstance(timestep, int)
+ or isinstance(timestep, torch.IntTensor)
+ or isinstance(timestep, torch.LongTensor)
+ ):
+ raise ValueError(
+ (
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
+                    " `EulerAncestralDiscreteScheduler.step()` is not supported. Make sure to pass"
+ " one of the `scheduler.timesteps` as a timestep."
+ ),
+ )
+
+ if not self.is_scale_input_called:
+ logger.warning(
+ "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
+ "See `StableDiffusionPipeline` for a usage example."
+ )
+
+ if isinstance(timestep, torch.Tensor):
+ timestep = timestep.to(self.timesteps.device)
+
+ step_index = (self.timesteps == timestep).nonzero().item()
+ sigma = self.sigmas[step_index]
+
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
+ if self.config.prediction_type == "epsilon":
+ pred_original_sample = sample - sigma * model_output
+ elif self.config.prediction_type == "v_prediction":
+ # * c_out + input * c_skip
+ pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
+ elif self.config.prediction_type == "sample":
+ raise NotImplementedError("prediction_type not implemented yet: sample")
+ else:
+ raise ValueError(
+                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon` or `v_prediction`"
+ )
+
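+        # Ancestral step: split the transition into a deterministic move down to `sigma_down` plus fresh
+        # noise with standard deviation `sigma_up`, chosen so that sigma_down**2 + sigma_up**2 == sigma_to**2
+        # (the same split as k-diffusion's `get_ancestral_step`).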
+ sigma_from = self.sigmas[step_index]
+ sigma_to = self.sigmas[step_index + 1]
+ sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
+ sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
+
+ # 2. Convert to an ODE derivative
+ derivative = (sample - pred_original_sample) / sigma
+
+ dt = sigma_down - sigma
+
+ prev_sample = sample + derivative * dt
+
+ device = model_output.device
+ noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator)
+
+ prev_sample = prev_sample + noise * sigma_up
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return EulerAncestralDiscreteSchedulerOutput(
+ prev_sample=prev_sample, pred_original_sample=pred_original_sample
+ )
+
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
+ def add_noise(
+ self,
+ original_samples: torch.FloatTensor,
+ noise: torch.FloatTensor,
+ timesteps: torch.FloatTensor,
+ ) -> torch.FloatTensor:
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
+ sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
+ if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
+ # mps does not support float64
+ schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
+ timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
+ else:
+ schedule_timesteps = self.timesteps.to(original_samples.device)
+ timesteps = timesteps.to(original_samples.device)
+
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < len(original_samples.shape):
+ sigma = sigma.unsqueeze(-1)
+
+ noisy_samples = original_samples + noise * sigma
+ return noisy_samples
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/diffusers/schedulers/scheduling_euler_discrete.py b/diffusers/schedulers/scheduling_euler_discrete.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb126d4b953cd28e23d048c4f1e2cf8ed90cdac0
--- /dev/null
+++ b/diffusers/schedulers/scheduling_euler_discrete.py
@@ -0,0 +1,432 @@
+# Copyright 2023 Katherine Crowson and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import BaseOutput, logging, randn_tensor
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerDiscrete
+class EulerDiscreteSchedulerOutput(BaseOutput):
+ """
+ Output class for the scheduler's step function output.
+
+ Args:
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ The predicted denoised sample (x_{0}) based on the model output from the current timestep.
+ `pred_original_sample` can be used to preview progress or for guidance.
+ """
+
+ prev_sample: torch.FloatTensor
+ pred_original_sample: Optional[torch.FloatTensor] = None
+
+
+# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
+def betas_for_alpha_bar(
+ num_diffusion_timesteps,
+ max_beta=0.999,
+ alpha_transform_type="cosine",
+):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
+ Choose from `cosine` or `exp`
+
+ Returns:
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+ if alpha_transform_type == "cosine":
+
+ def alpha_bar_fn(t):
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ elif alpha_transform_type == "exp":
+
+ def alpha_bar_fn(t):
+ return math.exp(t * -12.0)
+
+ else:
+        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
+ return torch.tensor(betas, dtype=torch.float32)
+
+
+class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
+ """
+    Euler scheduler (Algorithm 2) from Karras et al. (2022) https://arxiv.org/abs/2206.00364. Based on the original
+ k-diffusion implementation by Katherine Crowson:
+ https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L51
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
+ beta_start (`float`): the starting `beta` value of inference.
+ beta_end (`float`): the final `beta` value.
+ beta_schedule (`str`):
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear` or `scaled_linear`.
+ trained_betas (`np.ndarray`, optional):
+ option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
+ prediction_type (`str`, default `"epsilon"`, optional):
+ prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
+            process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4
+ https://imagen.research.google/video/paper.pdf)
+ interpolation_type (`str`, default `"linear"`, optional):
+ interpolation type to compute intermediate sigmas for the scheduler denoising steps. Should be one of
+ [`"linear"`, `"log_linear"`].
+ use_karras_sigmas (`bool`, *optional*, defaults to `False`):
+ This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the
+ noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence
+ of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf.
+ timestep_spacing (`str`, default `"linspace"`):
+ The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
+ Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
+ steps_offset (`int`, default `0`):
+ an offset added to the inference steps. You can use a combination of `offset=1` and
+ `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
+ stable diffusion.
+ """
+
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+ prediction_type: str = "epsilon",
+ interpolation_type: str = "linear",
+ use_karras_sigmas: Optional[bool] = False,
+ timestep_spacing: str = "linspace",
+ steps_offset: int = 0,
+ ):
+ if trained_betas is not None:
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
+ elif beta_schedule == "linear":
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ )
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ else:
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+
+ sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+ sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
+ self.sigmas = torch.from_numpy(sigmas)
+
+ # setable values
+ self.num_inference_steps = None
+ timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
+ self.timesteps = torch.from_numpy(timesteps)
+ self.is_scale_input_called = False
+ self.use_karras_sigmas = use_karras_sigmas
+
+ @property
+ def init_noise_sigma(self):
+ # standard deviation of the initial noise distribution
+ if self.config.timestep_spacing in ["linspace", "trailing"]:
+ return self.sigmas.max()
+
+ return (self.sigmas.max() ** 2 + 1) ** 0.5
+
+ def scale_model_input(
+ self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
+ ) -> torch.FloatTensor:
+ """
+ Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
+
+ Args:
+ sample (`torch.FloatTensor`): input sample
+ timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain
+
+ Returns:
+ `torch.FloatTensor`: scaled input sample
+ """
+ if isinstance(timestep, torch.Tensor):
+ timestep = timestep.to(self.timesteps.device)
+ step_index = (self.timesteps == timestep).nonzero().item()
+ sigma = self.sigmas[step_index]
+
+ sample = sample / ((sigma**2 + 1) ** 0.5)
+
+ self.is_scale_input_called = True
+ return sample
+
+ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+ """
+ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ device (`str` or `torch.device`, optional):
+                the device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+ """
+ self.num_inference_steps = num_inference_steps
+
+ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+ if self.config.timestep_spacing == "linspace":
+ timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[
+ ::-1
+ ].copy()
+ elif self.config.timestep_spacing == "leading":
+ step_ratio = self.config.num_train_timesteps // self.num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
+ timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
+ timesteps += self.config.steps_offset
+ elif self.config.timestep_spacing == "trailing":
+ step_ratio = self.config.num_train_timesteps / self.num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
+ timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(float)
+ timesteps -= 1
+ else:
+ raise ValueError(
+ f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
+ )
+
+ sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+ log_sigmas = np.log(sigmas)
+
+ if self.config.interpolation_type == "linear":
+ sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
+ elif self.config.interpolation_type == "log_linear":
+ sigmas = torch.linspace(np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1).exp()
+ else:
+ raise ValueError(
+ f"{self.config.interpolation_type} is not implemented. Please specify interpolation_type to either"
+ " 'linear' or 'log_linear'"
+ )
+
+ if self.use_karras_sigmas:
+ sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+ timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+
+ sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
+ self.sigmas = torch.from_numpy(sigmas).to(device=device)
+ if str(device).startswith("mps"):
+ # mps does not support float64
+ self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
+ else:
+ self.timesteps = torch.from_numpy(timesteps).to(device=device)
+
+ def _sigma_to_t(self, sigma, log_sigmas):
+ # get log sigma
+ log_sigma = np.log(sigma)
+
+ # get distribution
+ dists = log_sigma - log_sigmas[:, np.newaxis]
+
+ # get sigmas range
+ low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
+ high_idx = low_idx + 1
+
+ low = log_sigmas[low_idx]
+ high = log_sigmas[high_idx]
+
+ # interpolate sigmas
+ w = (low - log_sigma) / (low - high)
+ w = np.clip(w, 0, 1)
+
+ # transform interpolation to time range
+ t = (1 - w) * low_idx + w * high_idx
+ t = t.reshape(sigma.shape)
+ return t
+
+ # Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
+ def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
+ """Constructs the noise schedule of Karras et al. (2022)."""
+
+ sigma_min: float = in_sigmas[-1].item()
+ sigma_max: float = in_sigmas[0].item()
+
+ rho = 7.0 # 7.0 is the value used in the paper
+ ramp = np.linspace(0, 1, num_inference_steps)
+ min_inv_rho = sigma_min ** (1 / rho)
+ max_inv_rho = sigma_max ** (1 / rho)
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+ return sigmas
+
+ def step(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: Union[float, torch.FloatTensor],
+ sample: torch.FloatTensor,
+ s_churn: float = 0.0,
+ s_tmin: float = 0.0,
+ s_tmax: float = float("inf"),
+ s_noise: float = 1.0,
+ generator: Optional[torch.Generator] = None,
+ return_dict: bool = True,
+ ) -> Union[EulerDiscreteSchedulerOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+ timestep (`float`): current timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+ s_churn (`float`)
+ s_tmin (`float`)
+ s_tmax (`float`)
+ s_noise (`float`)
+ generator (`torch.Generator`, optional): Random number generator.
+ return_dict (`bool`): option for returning tuple rather than EulerDiscreteSchedulerOutput class
+
+ Returns:
+ [`~schedulers.scheduling_utils.EulerDiscreteSchedulerOutput`] or `tuple`:
+ [`~schedulers.scheduling_utils.EulerDiscreteSchedulerOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is the sample tensor.
+
+ """
+
+ if (
+ isinstance(timestep, int)
+ or isinstance(timestep, torch.IntTensor)
+ or isinstance(timestep, torch.LongTensor)
+ ):
+ raise ValueError(
+ (
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
+ " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
+ " one of the `scheduler.timesteps` as a timestep."
+ ),
+ )
+
+ if not self.is_scale_input_called:
+ logger.warning(
+ "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
+ "See `StableDiffusionPipeline` for a usage example."
+ )
+
+ if isinstance(timestep, torch.Tensor):
+ timestep = timestep.to(self.timesteps.device)
+
+ step_index = (self.timesteps == timestep).nonzero().item()
+ sigma = self.sigmas[step_index]
+
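+        # Stochastic "churn" from Karras et al. (2022), Algorithm 2: when sigma lies in [s_tmin, s_tmax],
+        # gamma temporarily raises the noise level to sigma_hat before stepping; with the default
+        # s_churn=0.0 the step is purely deterministic Euler.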
+ gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
+
+ noise = randn_tensor(
+ model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
+ )
+
+ eps = noise * s_noise
+ sigma_hat = sigma * (gamma + 1)
+
+ if gamma > 0:
+ sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
+
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
+ # NOTE: "original_sample" should not be an expected prediction_type but is left in for
+ # backwards compatibility
+ if self.config.prediction_type == "original_sample" or self.config.prediction_type == "sample":
+ pred_original_sample = model_output
+ elif self.config.prediction_type == "epsilon":
+ pred_original_sample = sample - sigma_hat * model_output
+ elif self.config.prediction_type == "v_prediction":
+ # * c_out + input * c_skip
+ pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
+ else:
+ raise ValueError(
+                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or `v_prediction`"
+ )
+
+ # 2. Convert to an ODE derivative
+ derivative = (sample - pred_original_sample) / sigma_hat
+
+ dt = self.sigmas[step_index + 1] - sigma_hat
+
+ prev_sample = sample + derivative * dt
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return EulerDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
+
+ def add_noise(
+ self,
+ original_samples: torch.FloatTensor,
+ noise: torch.FloatTensor,
+ timesteps: torch.FloatTensor,
+ ) -> torch.FloatTensor:
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
+ sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
+ if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
+ # mps does not support float64
+ schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
+ timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
+ else:
+ schedule_timesteps = self.timesteps.to(original_samples.device)
+ timesteps = timesteps.to(original_samples.device)
+
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < len(original_samples.shape):
+ sigma = sigma.unsqueeze(-1)
+
+ noisy_samples = original_samples + noise * sigma
+ return noisy_samples
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/diffusers/schedulers/scheduling_heun_discrete.py b/diffusers/schedulers/scheduling_heun_discrete.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f694fd60fc9f7f596f0d28d19cc231a26712fd1
--- /dev/null
+++ b/diffusers/schedulers/scheduling_heun_discrete.py
@@ -0,0 +1,426 @@
+# Copyright 2023 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from collections import defaultdict
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
+
+
+# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
+def betas_for_alpha_bar(
+ num_diffusion_timesteps,
+ max_beta=0.999,
+ alpha_transform_type="cosine",
+):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
+ Choose from `cosine` or `exp`
+
+ Returns:
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+ if alpha_transform_type == "cosine":
+
+ def alpha_bar_fn(t):
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ elif alpha_transform_type == "exp":
+
+ def alpha_bar_fn(t):
+ return math.exp(t * -12.0)
+
+ else:
+        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
+ return torch.tensor(betas, dtype=torch.float32)
+
+
+class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
+ """
+ Implements Algorithm 2 (Heun steps) from Karras et al. (2022). for discrete beta schedules. Based on the original
+ k-diffusion implementation by Katherine Crowson:
+ https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L90
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
+
+ Args:
+        num_train_timesteps (`int`): number of diffusion steps used to train the model.
+        beta_start (`float`): the starting `beta` value of inference.
+        beta_end (`float`): the final `beta` value.
+        beta_schedule (`str`):
+            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+            `linear` or `scaled_linear`.
+ trained_betas (`np.ndarray`, optional):
+ option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
+ prediction_type (`str`, default `epsilon`, optional):
+ prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
+            process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4
+ https://imagen.research.google/video/paper.pdf).
+ clip_sample (`bool`, default `True`):
+ option to clip predicted sample for numerical stability.
+ clip_sample_range (`float`, default `1.0`):
+ the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
+ use_karras_sigmas (`bool`, *optional*, defaults to `False`):
+ This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the
+ noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence
+ of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf.
+ timestep_spacing (`str`, default `"linspace"`):
+ The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
+ Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
+ steps_offset (`int`, default `0`):
+ an offset added to the inference steps. You can use a combination of `offset=1` and
+ `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
+ stable diffusion.
+ """
+
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
+ order = 2
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.00085, # sensible defaults
+ beta_end: float = 0.012,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+ prediction_type: str = "epsilon",
+ use_karras_sigmas: Optional[bool] = False,
+ clip_sample: Optional[bool] = False,
+ clip_sample_range: float = 1.0,
+ timestep_spacing: str = "linspace",
+ steps_offset: int = 0,
+ ):
+ if trained_betas is not None:
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
+ elif beta_schedule == "linear":
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ )
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="cosine")
+ elif beta_schedule == "exp":
+ self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="exp")
+ else:
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+
+ # set all values
+ self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
+ self.use_karras_sigmas = use_karras_sigmas
+
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
+ if schedule_timesteps is None:
+ schedule_timesteps = self.timesteps
+
+ indices = (schedule_timesteps == timestep).nonzero()
+
+ # The sigma index that is taken for the **very** first `step`
+ # is always the second index (or the last index if there is only 1)
+ # This way we can ensure we don't accidentally skip a sigma in
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
+ if len(self._index_counter) == 0:
+ pos = 1 if len(indices) > 1 else 0
+ else:
+ timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
+ pos = self._index_counter[timestep_int]
+
+ return indices[pos].item()
+
+ @property
+ def init_noise_sigma(self):
+ # standard deviation of the initial noise distribution
+ if self.config.timestep_spacing in ["linspace", "trailing"]:
+ return self.sigmas.max()
+
+ return (self.sigmas.max() ** 2 + 1) ** 0.5
+
+ def scale_model_input(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[float, torch.FloatTensor],
+ ) -> torch.FloatTensor:
+ """
+        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+        current timestep.
+
+        Args:
+            sample (`torch.FloatTensor`): input sample
+            timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain
+
+ Returns:
+ `torch.FloatTensor`: scaled input sample
+ """
+ step_index = self.index_for_timestep(timestep)
+
+ sigma = self.sigmas[step_index]
+ sample = sample / ((sigma**2 + 1) ** 0.5)
+ return sample
+
+ def set_timesteps(
+ self,
+ num_inference_steps: int,
+ device: Union[str, torch.device] = None,
+ num_train_timesteps: Optional[int] = None,
+ ):
+ """
+ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ device (`str` or `torch.device`, optional):
+                the device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+ """
+ self.num_inference_steps = num_inference_steps
+
+ num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps
+
+ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+ if self.config.timestep_spacing == "linspace":
+ timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
+ elif self.config.timestep_spacing == "leading":
+ step_ratio = num_train_timesteps // self.num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
+ timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
+ timesteps += self.config.steps_offset
+ elif self.config.timestep_spacing == "trailing":
+ step_ratio = num_train_timesteps / self.num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
+ timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(float)
+ timesteps -= 1
+ else:
+ raise ValueError(
+ f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
+ )
+
+ sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+ log_sigmas = np.log(sigmas)
+ sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
+
+ if self.config.use_karras_sigmas:
+ sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+ timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+
+ sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
+ sigmas = torch.from_numpy(sigmas).to(device=device)
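+        # Duplicate every interior sigma (and, below, every timestep after the first) so that each Heun step
+        # visits its interval twice: once for the Euler predictor and once for the corrector.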
+ self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]])
+
+ timesteps = torch.from_numpy(timesteps)
+ timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)])
+
+ if str(device).startswith("mps"):
+ # mps does not support float64
+ self.timesteps = timesteps.to(device, dtype=torch.float32)
+ else:
+ self.timesteps = timesteps.to(device=device)
+
+ # empty dt and derivative
+ self.prev_derivative = None
+ self.dt = None
+
+ # for exp beta schedules, such as the one for `pipeline_shap_e.py`
+ # we need an index counter
+ self._index_counter = defaultdict(int)
+
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
+ def _sigma_to_t(self, sigma, log_sigmas):
+ # get log sigma
+ log_sigma = np.log(sigma)
+
+ # get distribution
+ dists = log_sigma - log_sigmas[:, np.newaxis]
+
+ # get sigmas range
+ low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
+ high_idx = low_idx + 1
+
+ low = log_sigmas[low_idx]
+ high = log_sigmas[high_idx]
+
+ # interpolate sigmas
+ w = (low - log_sigma) / (low - high)
+ w = np.clip(w, 0, 1)
+
+ # transform interpolation to time range
+ t = (1 - w) * low_idx + w * high_idx
+ t = t.reshape(sigma.shape)
+ return t
+
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
+ def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
+ """Constructs the noise schedule of Karras et al. (2022)."""
+
+ sigma_min: float = in_sigmas[-1].item()
+ sigma_max: float = in_sigmas[0].item()
+
+ rho = 7.0 # 7.0 is the value used in the paper
+ ramp = np.linspace(0, 1, num_inference_steps)
+ min_inv_rho = sigma_min ** (1 / rho)
+ max_inv_rho = sigma_max ** (1 / rho)
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+ return sigmas
+
+ @property
+ def state_in_first_order(self):
+ return self.dt is None
+
+ def step(
+ self,
+ model_output: Union[torch.FloatTensor, np.ndarray],
+ timestep: Union[float, torch.FloatTensor],
+ sample: Union[torch.FloatTensor, np.ndarray],
+ return_dict: bool = True,
+ ) -> Union[SchedulerOutput, Tuple]:
+ """
+        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+        process from the learned model outputs (most often the predicted noise).
+
+        Args:
+            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+            timestep (`float` or `torch.FloatTensor`): current timestep in the diffusion chain.
+            sample (`torch.FloatTensor` or `np.ndarray`):
+                current instance of sample being created by diffusion process.
+            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+ Returns:
+ [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
+ [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
+ """
+ step_index = self.index_for_timestep(timestep)
+
+ # advance index counter by 1
+ timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
+ self._index_counter[timestep_int] += 1
+
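+        # Because every timestep appears twice in the schedule (see `set_timesteps`), the first visit runs the
+        # Euler predictor from sigma to sigma_next and the second visit runs the Heun corrector over the same
+        # interval, averaging the two derivatives.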
+ if self.state_in_first_order:
+ sigma = self.sigmas[step_index]
+ sigma_next = self.sigmas[step_index + 1]
+ else:
+ # 2nd order / Heun's method
+ sigma = self.sigmas[step_index - 1]
+ sigma_next = self.sigmas[step_index]
+
+ # currently only gamma=0 is supported. This usually works best anyways.
+ # We can support gamma in the future but then need to scale the timestep before
+ # passing it to the model which requires a change in API
+ gamma = 0
+ sigma_hat = sigma * (gamma + 1) # Note: sigma_hat == sigma for now
+
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
+ if self.config.prediction_type == "epsilon":
+ sigma_input = sigma_hat if self.state_in_first_order else sigma_next
+ pred_original_sample = sample - sigma_input * model_output
+ elif self.config.prediction_type == "v_prediction":
+ sigma_input = sigma_hat if self.state_in_first_order else sigma_next
+ pred_original_sample = model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5) + (
+ sample / (sigma_input**2 + 1)
+ )
+ elif self.config.prediction_type == "sample":
+ pred_original_sample = model_output
+ else:
+ raise ValueError(
+                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or `v_prediction`"
+ )
+
+ if self.config.clip_sample:
+ pred_original_sample = pred_original_sample.clamp(
+ -self.config.clip_sample_range, self.config.clip_sample_range
+ )
+
+ if self.state_in_first_order:
+ # 2. Convert to an ODE derivative for 1st order
+ derivative = (sample - pred_original_sample) / sigma_hat
+ # 3. delta timestep
+ dt = sigma_next - sigma_hat
+
+ # store for 2nd order step
+ self.prev_derivative = derivative
+ self.dt = dt
+ self.sample = sample
+ else:
+ # 2. 2nd order / Heun's method
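+            # Heun's correction: average the derivative at sigma_next with the derivative stored from the
+            # first-order (Euler) step, then apply the trapezoidal update over the stored dt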
+ derivative = (sample - pred_original_sample) / sigma_next
+ derivative = (self.prev_derivative + derivative) / 2
+
+ # 3. take prev timestep & sample
+ dt = self.dt
+ sample = self.sample
+
+ # free dt and derivative
+ # Note, this puts the scheduler in "first order mode"
+ self.prev_derivative = None
+ self.dt = None
+ self.sample = None
+
+ prev_sample = sample + derivative * dt
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return SchedulerOutput(prev_sample=prev_sample)
+
+ def add_noise(
+ self,
+ original_samples: torch.FloatTensor,
+ noise: torch.FloatTensor,
+ timesteps: torch.FloatTensor,
+ ) -> torch.FloatTensor:
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
+ sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
+ if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
+ # mps does not support float64
+ schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
+ timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
+ else:
+ schedule_timesteps = self.timesteps.to(original_samples.device)
+ timesteps = timesteps.to(original_samples.device)
+
+ step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < len(original_samples.shape):
+ sigma = sigma.unsqueeze(-1)
+
+ noisy_samples = original_samples + noise * sigma
+ return noisy_samples
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/diffusers/schedulers/scheduling_ipndm.py b/diffusers/schedulers/scheduling_ipndm.py
new file mode 100644
index 0000000000000000000000000000000000000000..80e521590782de6bc14e9b8c29642c7595fafc93
--- /dev/null
+++ b/diffusers/schedulers/scheduling_ipndm.py
@@ -0,0 +1,161 @@
+# Copyright 2023 Zhejiang University Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from .scheduling_utils import SchedulerMixin, SchedulerOutput
+
+
+class IPNDMScheduler(SchedulerMixin, ConfigMixin):
+ """
+    Improved pseudo numerical methods for diffusion models (iPNDM), ported from @crowsonkb's v-diffusion-pytorch
+    [library](https://github.com/crowsonkb/v-diffusion-pytorch/blob/987f8985e38208345c1959b0ea767a625831cc9b/diffusion/sampling.py#L296)
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
+
+ For more details, see the original paper: https://arxiv.org/abs/2202.09778
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
+ """
+
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self, num_train_timesteps: int = 1000, trained_betas: Optional[Union[np.ndarray, List[float]]] = None
+ ):
+ # set `betas`, `alphas`, `timesteps`
+ self.set_timesteps(num_train_timesteps)
+
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = 1.0
+
+ # For now we only support F-PNDM, i.e. the runge-kutta method
+ # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
+ # mainly at formula (9), (12), (13) and the Algorithm 2.
+ self.pndm_order = 4
+
+ # running values
+ self.ets = []
+
+ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+ """
+ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ """
+ self.num_inference_steps = num_inference_steps
+ steps = torch.linspace(1, 0, num_inference_steps + 1)[:-1]
+ steps = torch.cat([steps, torch.tensor([0.0])])
+
+ if self.config.trained_betas is not None:
+ self.betas = torch.tensor(self.config.trained_betas, dtype=torch.float32)
+ else:
+ self.betas = torch.sin(steps * math.pi / 2) ** 2
+
+ self.alphas = (1.0 - self.betas**2) ** 0.5
+
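+        # recover a normalized timestep in [0, 1) for each (beta, alpha) pair via atan2; the final
+        # entry (corresponding to t=0) is dropped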
+ timesteps = (torch.atan2(self.betas, self.alphas) / math.pi * 2)[:-1]
+ self.timesteps = timesteps.to(device)
+
+ self.ets = []
+
+ def step(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: int,
+ sample: torch.FloatTensor,
+ return_dict: bool = True,
+ ) -> Union[SchedulerOutput, Tuple]:
+ """
+        Step function propagating the sample with the linear multi-step method. Each call performs a single forward
+        pass and combines it with the stored outputs of previous steps to approximate the solution.
+
+ Args:
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+ return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+ Returns:
+ [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is
+ True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
+
+ """
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ timestep_index = (self.timesteps == timestep).nonzero().item()
+ prev_timestep_index = timestep_index + 1
+
+ ets = sample * self.betas[timestep_index] + model_output * self.alphas[timestep_index]
+ self.ets.append(ets)
+
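+        # linear multi-step (Adams-Bashforth) update: the branches below apply the standard
+        # 1st- to 4th-order coefficients to the stored history of noise estimates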
+ if len(self.ets) == 1:
+ ets = self.ets[-1]
+ elif len(self.ets) == 2:
+ ets = (3 * self.ets[-1] - self.ets[-2]) / 2
+ elif len(self.ets) == 3:
+ ets = (23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]) / 12
+ else:
+ ets = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])
+
+ prev_sample = self._get_prev_sample(sample, timestep_index, prev_timestep_index, ets)
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return SchedulerOutput(prev_sample=prev_sample)
+
+ def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+
+ Args:
+ sample (`torch.FloatTensor`): input sample
+
+ Returns:
+ `torch.FloatTensor`: scaled input sample
+ """
+ return sample
+
+ def _get_prev_sample(self, sample, timestep_index, prev_timestep_index, ets):
+ alpha = self.alphas[timestep_index]
+ sigma = self.betas[timestep_index]
+
+ next_alpha = self.alphas[prev_timestep_index]
+ next_sigma = self.betas[prev_timestep_index]
+
+ pred = (sample - sigma * ets) / max(alpha, 1e-8)
+ prev_sample = next_alpha * pred + ets * next_sigma
+
+ return prev_sample
+
+ def __len__(self):
+ return self.config.num_train_timesteps
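+
+
+# Minimal usage sketch (illustrative only, not part of the library API): `model` is a hypothetical
+# callable mapping (noisy sample, timestep) to a noise prediction.
+#
+#     scheduler = IPNDMScheduler()
+#     scheduler.set_timesteps(50)
+#     sample = torch.randn(1, 3, 64, 64) * scheduler.init_noise_sigma
+#     for t in scheduler.timesteps:
+#         model_input = scheduler.scale_model_input(sample, t)
+#         noise_pred = model(model_input, t)
+#         sample = scheduler.step(noise_pred, t, sample).prev_sample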
diff --git a/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py b/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py
new file mode 100644
index 0000000000000000000000000000000000000000..bdf9379b9b90a53e3c8aad20a69e9ab7bffc691e
--- /dev/null
+++ b/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py
@@ -0,0 +1,420 @@
+# Copyright 2023 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from collections import defaultdict
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import randn_tensor
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
+
+
+# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
+def betas_for_alpha_bar(
+ num_diffusion_timesteps,
+ max_beta=0.999,
+ alpha_transform_type="cosine",
+):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
+ Choose from `cosine` or `exp`
+
+ Returns:
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+ if alpha_transform_type == "cosine":
+
+ def alpha_bar_fn(t):
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ elif alpha_transform_type == "exp":
+
+ def alpha_bar_fn(t):
+ return math.exp(t * -12.0)
+
+ else:
+        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
+ return torch.tensor(betas, dtype=torch.float32)
+
+
+class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
+ """
+ Scheduler created by @crowsonkb in [k_diffusion](https://github.com/crowsonkb/k-diffusion), see:
+ https://github.com/crowsonkb/k-diffusion/blob/5b3af030dd83e0297272d861c19477735d0317ec/k_diffusion/sampling.py#L188
+
+    Scheduler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022).
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
+
+ Args:
+        num_train_timesteps (`int`): number of diffusion steps used to train the model.
+        beta_start (`float`): the starting `beta` value of inference.
+        beta_end (`float`): the final `beta` value.
+        beta_schedule (`str`):
+            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+        trained_betas (`np.ndarray`, optional):
+            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end`, etc.
+        prediction_type (`str`, default `epsilon`, optional):
+            prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
+            process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4 of
+            https://imagen.research.google/video/paper.pdf)
+        timestep_spacing (`str`, default `"linspace"`):
+            The way the timesteps should be scaled. Refer to Table 2 of [Common Diffusion Noise Schedules and Sample
+            Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
+        steps_offset (`int`, default `0`):
+            an offset added to the inference steps. You can use a combination of `offset=1` and
+            `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product, as done in
+            stable diffusion.
+ """
+
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
+ order = 2
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.00085, # sensible defaults
+ beta_end: float = 0.012,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+ prediction_type: str = "epsilon",
+ timestep_spacing: str = "linspace",
+ steps_offset: int = 0,
+ ):
+ if trained_betas is not None:
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
+ elif beta_schedule == "linear":
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ )
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ else:
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+
+ # set all values
+ self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
+
+ # Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
+ if schedule_timesteps is None:
+ schedule_timesteps = self.timesteps
+
+ indices = (schedule_timesteps == timestep).nonzero()
+
+ # The sigma index that is taken for the **very** first `step`
+ # is always the second index (or the last index if there is only 1)
+ # This way we can ensure we don't accidentally skip a sigma in
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
+ if len(self._index_counter) == 0:
+ pos = 1 if len(indices) > 1 else 0
+ else:
+ timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
+ pos = self._index_counter[timestep_int]
+
+ return indices[pos].item()
+
+ @property
+ def init_noise_sigma(self):
+ # standard deviation of the initial noise distribution
+ if self.config.timestep_spacing in ["linspace", "trailing"]:
+ return self.sigmas.max()
+
+ return (self.sigmas.max() ** 2 + 1) ** 0.5
+
+ def scale_model_input(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[float, torch.FloatTensor],
+ ) -> torch.FloatTensor:
+ """
+        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+        current timestep.
+
+        Args:
+            sample (`torch.FloatTensor`): input sample
+            timestep (`float` or `torch.FloatTensor`): current timestep
+
+ Returns:
+ `torch.FloatTensor`: scaled input sample
+ """
+ step_index = self.index_for_timestep(timestep)
+
+ if self.state_in_first_order:
+ sigma = self.sigmas[step_index]
+ else:
+ sigma = self.sigmas_interpol[step_index - 1]
+
+ sample = sample / ((sigma**2 + 1) ** 0.5)
+ return sample
+
+ def set_timesteps(
+ self,
+ num_inference_steps: int,
+ device: Union[str, torch.device] = None,
+ num_train_timesteps: Optional[int] = None,
+ ):
+ """
+ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ device (`str` or `torch.device`, optional):
+ the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ """
+ self.num_inference_steps = num_inference_steps
+
+ num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps
+
+ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+ if self.config.timestep_spacing == "linspace":
+ timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
+ elif self.config.timestep_spacing == "leading":
+ step_ratio = num_train_timesteps // self.num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
+ timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
+ timesteps += self.config.steps_offset
+ elif self.config.timestep_spacing == "trailing":
+ step_ratio = num_train_timesteps / self.num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
+ timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(float)
+ timesteps -= 1
+ else:
+ raise ValueError(
+ f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
+ )
+
+ sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+ self.log_sigmas = torch.from_numpy(np.log(sigmas)).to(device)
+
+ sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
+ sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
+ sigmas = torch.from_numpy(sigmas).to(device=device)
+
+ # compute up and down sigmas
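+        # ancestral split of each step: the deterministic update integrates down to sigma_down and
+        # fresh noise with standard deviation sigma_up is added back (k-diffusion's get_ancestral_step with eta=1)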
+ sigmas_next = sigmas.roll(-1)
+ sigmas_next[-1] = 0.0
+ sigmas_up = (sigmas_next**2 * (sigmas**2 - sigmas_next**2) / sigmas**2) ** 0.5
+ sigmas_down = (sigmas_next**2 - sigmas_up**2) ** 0.5
+ sigmas_down[-1] = 0.0
+
+ # compute interpolated sigmas
+ sigmas_interpol = sigmas.log().lerp(sigmas_down.log(), 0.5).exp()
+ sigmas_interpol[-2:] = 0.0
+
+ # set sigmas
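+        # every interior sigma is repeated so that each inference step is split into the two
+        # sub-steps of this 2nd-order scheduler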
+ self.sigmas = torch.cat([sigmas[:1], sigmas[1:].repeat_interleave(2), sigmas[-1:]])
+ self.sigmas_interpol = torch.cat(
+ [sigmas_interpol[:1], sigmas_interpol[1:].repeat_interleave(2), sigmas_interpol[-1:]]
+ )
+ self.sigmas_up = torch.cat([sigmas_up[:1], sigmas_up[1:].repeat_interleave(2), sigmas_up[-1:]])
+ self.sigmas_down = torch.cat([sigmas_down[:1], sigmas_down[1:].repeat_interleave(2), sigmas_down[-1:]])
+
+ if str(device).startswith("mps"):
+ # mps does not support float64
+ timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
+ else:
+ timesteps = torch.from_numpy(timesteps).to(device)
+
+ timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device, dtype=timesteps.dtype)
+ interleaved_timesteps = torch.stack((timesteps_interpol[:-2, None], timesteps[1:, None]), dim=-1).flatten()
+
+ self.timesteps = torch.cat([timesteps[:1], interleaved_timesteps])
+
+ self.sample = None
+
+ # for exp beta schedules, such as the one for `pipeline_shap_e.py`
+ # we need an index counter
+ self._index_counter = defaultdict(int)
+
+ def sigma_to_t(self, sigma):
+ # get log sigma
+ log_sigma = sigma.log()
+
+ # get distribution
+ dists = log_sigma - self.log_sigmas[:, None]
+
+ # get sigmas range
+ low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
+ high_idx = low_idx + 1
+
+ low = self.log_sigmas[low_idx]
+ high = self.log_sigmas[high_idx]
+
+ # interpolate sigmas
+ w = (low - log_sigma) / (low - high)
+ w = w.clamp(0, 1)
+
+ # transform interpolation to time range
+ t = (1 - w) * low_idx + w * high_idx
+ t = t.view(sigma.shape)
+ return t
+
+ @property
+ def state_in_first_order(self):
+ return self.sample is None
+
+ def step(
+ self,
+ model_output: Union[torch.FloatTensor, np.ndarray],
+ timestep: Union[float, torch.FloatTensor],
+ sample: Union[torch.FloatTensor, np.ndarray],
+ generator: Optional[torch.Generator] = None,
+ return_dict: bool = True,
+ ) -> Union[SchedulerOutput, Tuple]:
+ """
+        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+        process from the learned model outputs (most often the predicted noise).
+
+        Args:
+            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+            timestep (`float` or `torch.FloatTensor`): current timestep in the diffusion chain.
+            sample (`torch.FloatTensor` or `np.ndarray`):
+                current instance of sample being created by diffusion process.
+            generator (`torch.Generator`, optional): random number generator used for the ancestral noise.
+            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+ Returns:
+ [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
+ [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
+ """
+ step_index = self.index_for_timestep(timestep)
+
+ # advance index counter by 1
+ timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
+ self._index_counter[timestep_int] += 1
+
+ if self.state_in_first_order:
+ sigma = self.sigmas[step_index]
+ sigma_interpol = self.sigmas_interpol[step_index]
+ sigma_up = self.sigmas_up[step_index]
+ sigma_down = self.sigmas_down[step_index - 1]
+ else:
+            # 2nd order / KDPM2's method
+ sigma = self.sigmas[step_index - 1]
+ sigma_interpol = self.sigmas_interpol[step_index - 1]
+ sigma_up = self.sigmas_up[step_index - 1]
+ sigma_down = self.sigmas_down[step_index - 1]
+
+ # currently only gamma=0 is supported. This usually works best anyways.
+ # We can support gamma in the future but then need to scale the timestep before
+ # passing it to the model which requires a change in API
+ gamma = 0
+ sigma_hat = sigma * (gamma + 1) # Note: sigma_hat == sigma for now
+
+ device = model_output.device
+ noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator)
+
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
+ if self.config.prediction_type == "epsilon":
+ sigma_input = sigma_hat if self.state_in_first_order else sigma_interpol
+ pred_original_sample = sample - sigma_input * model_output
+ elif self.config.prediction_type == "v_prediction":
+ sigma_input = sigma_hat if self.state_in_first_order else sigma_interpol
+ pred_original_sample = model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5) + (
+ sample / (sigma_input**2 + 1)
+ )
+ elif self.config.prediction_type == "sample":
+ raise NotImplementedError("prediction_type not implemented yet: sample")
+ else:
+ raise ValueError(
+                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon` or `v_prediction`"
+ )
+
+ if self.state_in_first_order:
+ # 2. Convert to an ODE derivative for 1st order
+ derivative = (sample - pred_original_sample) / sigma_hat
+ # 3. delta timestep
+ dt = sigma_interpol - sigma_hat
+
+ # store for 2nd order step
+ self.sample = sample
+ self.dt = dt
+ prev_sample = sample + derivative * dt
+ else:
+ # DPM-Solver-2
+ # 2. Convert to an ODE derivative for 2nd order
+ derivative = (sample - pred_original_sample) / sigma_interpol
+ # 3. delta timestep
+ dt = sigma_down - sigma_hat
+
+ sample = self.sample
+ self.sample = None
+
+ prev_sample = sample + derivative * dt
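+            # ancestral noise injection: add fresh noise scaled by sigma_up on top of the deterministic update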
+ prev_sample = prev_sample + noise * sigma_up
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return SchedulerOutput(prev_sample=prev_sample)
+
+ # Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.add_noise
+ def add_noise(
+ self,
+ original_samples: torch.FloatTensor,
+ noise: torch.FloatTensor,
+ timesteps: torch.FloatTensor,
+ ) -> torch.FloatTensor:
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
+ sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
+ if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
+ # mps does not support float64
+ schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
+ timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
+ else:
+ schedule_timesteps = self.timesteps.to(original_samples.device)
+ timesteps = timesteps.to(original_samples.device)
+
+ step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < len(original_samples.shape):
+ sigma = sigma.unsqueeze(-1)
+
+ noisy_samples = original_samples + noise * sigma
+ return noisy_samples
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/diffusers/schedulers/scheduling_k_dpm_2_discrete.py b/diffusers/schedulers/scheduling_k_dpm_2_discrete.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6a1b4e6640d1bc10ef6475bde39b5f39a87ec80
--- /dev/null
+++ b/diffusers/schedulers/scheduling_k_dpm_2_discrete.py
@@ -0,0 +1,401 @@
+# Copyright 2023 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from collections import defaultdict
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
+
+
+# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
+def betas_for_alpha_bar(
+ num_diffusion_timesteps,
+ max_beta=0.999,
+ alpha_transform_type="cosine",
+):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
+ Choose from `cosine` or `exp`
+
+ Returns:
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+ if alpha_transform_type == "cosine":
+
+ def alpha_bar_fn(t):
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ elif alpha_transform_type == "exp":
+
+ def alpha_bar_fn(t):
+ return math.exp(t * -12.0)
+
+ else:
+        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
+ return torch.tensor(betas, dtype=torch.float32)
+
+
+class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
+ """
+ Scheduler created by @crowsonkb in [k_diffusion](https://github.com/crowsonkb/k-diffusion), see:
+ https://github.com/crowsonkb/k-diffusion/blob/5b3af030dd83e0297272d861c19477735d0317ec/k_diffusion/sampling.py#L188
+
+    Scheduler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022).
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
+
+ Args:
+        num_train_timesteps (`int`): number of diffusion steps used to train the model.
+        beta_start (`float`): the starting `beta` value of inference.
+        beta_end (`float`): the final `beta` value.
+        beta_schedule (`str`):
+            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+        trained_betas (`np.ndarray`, optional):
+            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end`, etc.
+        prediction_type (`str`, default `epsilon`, optional):
+            prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
+            process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4 of
+            https://imagen.research.google/video/paper.pdf)
+        timestep_spacing (`str`, default `"linspace"`):
+            The way the timesteps should be scaled. Refer to Table 2 of [Common Diffusion Noise Schedules and Sample
+            Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
+        steps_offset (`int`, default `0`):
+            an offset added to the inference steps. You can use a combination of `offset=1` and
+            `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product, as done in
+            stable diffusion.
+ """
+
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
+ order = 2
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.00085, # sensible defaults
+ beta_end: float = 0.012,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+ prediction_type: str = "epsilon",
+ timestep_spacing: str = "linspace",
+ steps_offset: int = 0,
+ ):
+ if trained_betas is not None:
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
+ elif beta_schedule == "linear":
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ )
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ else:
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+
+ # set all values
+ self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
+
+ # Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
+ if schedule_timesteps is None:
+ schedule_timesteps = self.timesteps
+
+ indices = (schedule_timesteps == timestep).nonzero()
+
+ # The sigma index that is taken for the **very** first `step`
+ # is always the second index (or the last index if there is only 1)
+ # This way we can ensure we don't accidentally skip a sigma in
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
+ if len(self._index_counter) == 0:
+ pos = 1 if len(indices) > 1 else 0
+ else:
+ timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
+ pos = self._index_counter[timestep_int]
+
+ return indices[pos].item()
+
+ @property
+ def init_noise_sigma(self):
+ # standard deviation of the initial noise distribution
+ if self.config.timestep_spacing in ["linspace", "trailing"]:
+ return self.sigmas.max()
+
+ return (self.sigmas.max() ** 2 + 1) ** 0.5
+
+ def scale_model_input(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[float, torch.FloatTensor],
+ ) -> torch.FloatTensor:
+ """
+        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+        current timestep.
+
+        Args:
+            sample (`torch.FloatTensor`): input sample
+            timestep (`float` or `torch.FloatTensor`): current timestep
+
+ Returns:
+ `torch.FloatTensor`: scaled input sample
+ """
+ step_index = self.index_for_timestep(timestep)
+
+ if self.state_in_first_order:
+ sigma = self.sigmas[step_index]
+ else:
+ sigma = self.sigmas_interpol[step_index]
+
+ sample = sample / ((sigma**2 + 1) ** 0.5)
+ return sample
+
+ def set_timesteps(
+ self,
+ num_inference_steps: int,
+ device: Union[str, torch.device] = None,
+ num_train_timesteps: Optional[int] = None,
+ ):
+ """
+ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ device (`str` or `torch.device`, optional):
+ the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ """
+ self.num_inference_steps = num_inference_steps
+
+ num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps
+
+ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+ if self.config.timestep_spacing == "linspace":
+ timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
+ elif self.config.timestep_spacing == "leading":
+ step_ratio = num_train_timesteps // self.num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
+ timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
+ timesteps += self.config.steps_offset
+ elif self.config.timestep_spacing == "trailing":
+ step_ratio = num_train_timesteps / self.num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
+ timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(float)
+ timesteps -= 1
+ else:
+ raise ValueError(
+ f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
+ )
+
+ sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+ self.log_sigmas = torch.from_numpy(np.log(sigmas)).to(device)
+
+ sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
+ sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
+ sigmas = torch.from_numpy(sigmas).to(device=device)
+
+ # interpolate sigmas
+ sigmas_interpol = sigmas.log().lerp(sigmas.roll(1).log(), 0.5).exp()
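+        # geometric mean of adjacent sigmas (the midpoint in log-sigma space), used as the noise level
+        # for the second model evaluation of every 2nd-order step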
+
+ self.sigmas = torch.cat([sigmas[:1], sigmas[1:].repeat_interleave(2), sigmas[-1:]])
+ self.sigmas_interpol = torch.cat(
+ [sigmas_interpol[:1], sigmas_interpol[1:].repeat_interleave(2), sigmas_interpol[-1:]]
+ )
+
+ if str(device).startswith("mps"):
+ # mps does not support float64
+ timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
+ else:
+ timesteps = torch.from_numpy(timesteps).to(device)
+
+ # interpolate timesteps
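+        # the interpolated timesteps are interleaved with the regular ones below so that step()
+        # alternates between the first- and second-order model evaluations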
+ timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device, dtype=timesteps.dtype)
+ interleaved_timesteps = torch.stack((timesteps_interpol[1:-1, None], timesteps[1:, None]), dim=-1).flatten()
+
+ self.timesteps = torch.cat([timesteps[:1], interleaved_timesteps])
+
+ self.sample = None
+
+ # for exp beta schedules, such as the one for `pipeline_shap_e.py`
+ # we need an index counter
+ self._index_counter = defaultdict(int)
+
+ def sigma_to_t(self, sigma):
+ # get log sigma
+ log_sigma = sigma.log()
+
+ # get distribution
+ dists = log_sigma - self.log_sigmas[:, None]
+
+ # get sigmas range
+ low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
+ high_idx = low_idx + 1
+
+ low = self.log_sigmas[low_idx]
+ high = self.log_sigmas[high_idx]
+
+ # interpolate sigmas
+ w = (low - log_sigma) / (low - high)
+ w = w.clamp(0, 1)
+
+ # transform interpolation to time range
+ t = (1 - w) * low_idx + w * high_idx
+ t = t.view(sigma.shape)
+ return t
+
+ @property
+ def state_in_first_order(self):
+ return self.sample is None
+
+ def step(
+ self,
+ model_output: Union[torch.FloatTensor, np.ndarray],
+ timestep: Union[float, torch.FloatTensor],
+ sample: Union[torch.FloatTensor, np.ndarray],
+ return_dict: bool = True,
+ ) -> Union[SchedulerOutput, Tuple]:
+ """
+        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+        process from the learned model outputs (most often the predicted noise).
+
+        Args:
+            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+            timestep (`float` or `torch.FloatTensor`): current timestep in the diffusion chain.
+            sample (`torch.FloatTensor` or `np.ndarray`):
+                current instance of sample being created by diffusion process.
+            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+ Returns:
+ [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
+ [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
+ """
+ step_index = self.index_for_timestep(timestep)
+
+ # advance index counter by 1
+ timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
+ self._index_counter[timestep_int] += 1
+
+ if self.state_in_first_order:
+ sigma = self.sigmas[step_index]
+ sigma_interpol = self.sigmas_interpol[step_index + 1]
+ sigma_next = self.sigmas[step_index + 1]
+ else:
+ # 2nd order / KDPM2's method
+ sigma = self.sigmas[step_index - 1]
+ sigma_interpol = self.sigmas_interpol[step_index]
+ sigma_next = self.sigmas[step_index]
+
+ # currently only gamma=0 is supported. This usually works best anyways.
+ # We can support gamma in the future but then need to scale the timestep before
+ # passing it to the model which requires a change in API
+ gamma = 0
+ sigma_hat = sigma * (gamma + 1) # Note: sigma_hat == sigma for now
+
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
+ if self.config.prediction_type == "epsilon":
+ sigma_input = sigma_hat if self.state_in_first_order else sigma_interpol
+ pred_original_sample = sample - sigma_input * model_output
+ elif self.config.prediction_type == "v_prediction":
+ sigma_input = sigma_hat if self.state_in_first_order else sigma_interpol
+ pred_original_sample = model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5) + (
+ sample / (sigma_input**2 + 1)
+ )
+ elif self.config.prediction_type == "sample":
+ raise NotImplementedError("prediction_type not implemented yet: sample")
+ else:
+ raise ValueError(
+                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon` or `v_prediction`"
+ )
+
+ if self.state_in_first_order:
+ # 2. Convert to an ODE derivative for 1st order
+ derivative = (sample - pred_original_sample) / sigma_hat
+ # 3. delta timestep
+ dt = sigma_interpol - sigma_hat
+
+ # store for 2nd order step
+ self.sample = sample
+ else:
+ # DPM-Solver-2
+ # 2. Convert to an ODE derivative for 2nd order
+ derivative = (sample - pred_original_sample) / sigma_interpol
+
+ # 3. delta timestep
+ dt = sigma_next - sigma_hat
+
+ sample = self.sample
+ self.sample = None
+
+ prev_sample = sample + derivative * dt
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return SchedulerOutput(prev_sample=prev_sample)
+
+ # Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.add_noise
+ def add_noise(
+ self,
+ original_samples: torch.FloatTensor,
+ noise: torch.FloatTensor,
+ timesteps: torch.FloatTensor,
+ ) -> torch.FloatTensor:
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
+ sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
+ if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
+ # mps does not support float64
+ schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
+ timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
+ else:
+ schedule_timesteps = self.timesteps.to(original_samples.device)
+ timesteps = timesteps.to(original_samples.device)
+
+ step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < len(original_samples.shape):
+ sigma = sigma.unsqueeze(-1)
+
+ noisy_samples = original_samples + noise * sigma
+ return noisy_samples
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/diffusers/schedulers/scheduling_karras_ve.py b/diffusers/schedulers/scheduling_karras_ve.py
new file mode 100644
index 0000000000000000000000000000000000000000..87f6514a4e93e4a75bd6228ed852306b8c005c3d
--- /dev/null
+++ b/diffusers/schedulers/scheduling_karras_ve.py
@@ -0,0 +1,232 @@
+# Copyright 2023 NVIDIA and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import BaseOutput, randn_tensor
+from .scheduling_utils import SchedulerMixin
+
+
+@dataclass
+class KarrasVeOutput(BaseOutput):
+ """
+ Output class for the scheduler's step function output.
+
+ Args:
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ derivative (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Derivative of predicted original image sample (x_0).
+ pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ The predicted denoised sample (x_{0}) based on the model output from the current timestep.
+ `pred_original_sample` can be used to preview progress or for guidance.
+ """
+
+ prev_sample: torch.FloatTensor
+ derivative: torch.FloatTensor
+ pred_original_sample: Optional[torch.FloatTensor] = None
+
+
+class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
+ """
+ Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and
+ the VE column of Table 1 from [1] for reference.
+
+ [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models."
+ https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic
+ differential equations." https://arxiv.org/abs/2011.13456
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
+
+ For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of
+ Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the
+ optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper.
+
+ Args:
+ sigma_min (`float`): minimum noise magnitude
+ sigma_max (`float`): maximum noise magnitude
+ s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling.
+ A reasonable range is [1.000, 1.011].
+ s_churn (`float`): the parameter controlling the overall amount of stochasticity.
+ A reasonable range is [0, 100].
+ s_min (`float`): the start value of the sigma range where we add noise (enable stochasticity).
+ A reasonable range is [0, 10].
+ s_max (`float`): the end value of the sigma range where we add noise.
+ A reasonable range is [0.2, 80].
+
+ """
+
+ order = 2
+
+ @register_to_config
+ def __init__(
+ self,
+ sigma_min: float = 0.02,
+ sigma_max: float = 100,
+ s_noise: float = 1.007,
+ s_churn: float = 80,
+ s_min: float = 0.05,
+ s_max: float = 50,
+ ):
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = sigma_max
+
+ # setable values
+        self.num_inference_steps: Optional[int] = None
+        self.timesteps: Optional[torch.Tensor] = None
+        self.schedule: Optional[torch.FloatTensor] = None  # sigma(t_i)
+
+ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+
+ Args:
+ sample (`torch.FloatTensor`): input sample
+ timestep (`int`, optional): current timestep
+
+ Returns:
+ `torch.FloatTensor`: scaled input sample
+ """
+ return sample
+
+ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+ """
+ Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+
+ """
+ self.num_inference_steps = num_inference_steps
+ timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
+ self.timesteps = torch.from_numpy(timesteps).to(device)
+ schedule = [
+ (
+ self.config.sigma_max**2
+ * (self.config.sigma_min**2 / self.config.sigma_max**2) ** (i / (num_inference_steps - 1))
+ )
+ for i in self.timesteps
+ ]
+ self.schedule = torch.tensor(schedule, dtype=torch.float32, device=device)
+
+ def add_noise_to_input(
+ self, sample: torch.FloatTensor, sigma: float, generator: Optional[torch.Generator] = None
+ ) -> Tuple[torch.FloatTensor, float]:
+ """
+ Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a
+ higher noise level sigma_hat = sigma_i + gamma_i*sigma_i.
+
+        Args:
+            sample (`torch.FloatTensor`): the current sample to be noised.
+            sigma (`float`): the current noise level sigma_i.
+            generator (`torch.Generator`, optional): random number generator for the added noise.
+
+        Returns:
+            `Tuple[torch.FloatTensor, float]`: the noised sample `sample_hat` and the increased noise level
+            `sigma_hat`.
+ """
+ if self.config.s_min <= sigma <= self.config.s_max:
+ gamma = min(self.config.s_churn / self.num_inference_steps, 2**0.5 - 1)
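+            # the cap of 2**0.5 - 1 keeps sigma_hat <= sqrt(2) * sigma, as in Algorithm 2 of Karras et al. (2022)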
+ else:
+ gamma = 0
+
+ # sample eps ~ N(0, S_noise^2 * I)
+ eps = self.config.s_noise * randn_tensor(sample.shape, generator=generator).to(sample.device)
+ sigma_hat = sigma + gamma * sigma
+ sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps)
+
+ return sample_hat, sigma_hat
+
+ def step(
+ self,
+ model_output: torch.FloatTensor,
+ sigma_hat: float,
+ sigma_prev: float,
+ sample_hat: torch.FloatTensor,
+ return_dict: bool = True,
+ ) -> Union[KarrasVeOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+ sigma_hat (`float`): TODO
+ sigma_prev (`float`): TODO
+ sample_hat (`torch.FloatTensor`): TODO
+ return_dict (`bool`): option for returning tuple rather than KarrasVeOutput class
+
+        Returns:
+            [`~schedulers.scheduling_karras_ve.KarrasVeOutput`] or `tuple`: the updated sample in the diffusion chain
+            and its derivative. [`~schedulers.scheduling_karras_ve.KarrasVeOutput`] if `return_dict` is True,
+            otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
+
+ """
+
+ pred_original_sample = sample_hat + sigma_hat * model_output
+ derivative = (sample_hat - pred_original_sample) / sigma_hat
+ sample_prev = sample_hat + (sigma_prev - sigma_hat) * derivative
+
+ if not return_dict:
+ return (sample_prev, derivative)
+
+ return KarrasVeOutput(
+ prev_sample=sample_prev, derivative=derivative, pred_original_sample=pred_original_sample
+ )
+
+ def step_correct(
+ self,
+ model_output: torch.FloatTensor,
+ sigma_hat: float,
+ sigma_prev: float,
+ sample_hat: torch.FloatTensor,
+ sample_prev: torch.FloatTensor,
+ derivative: torch.FloatTensor,
+ return_dict: bool = True,
+ ) -> Union[KarrasVeOutput, Tuple]:
+ """
+ Correct the predicted sample based on the output model_output of the network. TODO complete description
+
+ Args:
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+ sigma_hat (`float`): TODO
+ sigma_prev (`float`): TODO
+ sample_hat (`torch.FloatTensor`): TODO
+ sample_prev (`torch.FloatTensor`): TODO
+ derivative (`torch.FloatTensor`): TODO
+ return_dict (`bool`): option for returning tuple rather than KarrasVeOutput class
+
+        Returns:
+            [`~schedulers.scheduling_karras_ve.KarrasVeOutput`] or `tuple`: the corrected sample in the diffusion
+            chain and its derivative. When returning a tuple, the first element is the sample tensor.
+
+ """
+ pred_original_sample = sample_prev + sigma_prev * model_output
+ derivative_corr = (sample_prev - pred_original_sample) / sigma_prev
+ sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr)
+
+ if not return_dict:
+ return (sample_prev, derivative)
+
+ return KarrasVeOutput(
+ prev_sample=sample_prev, derivative=derivative, pred_original_sample=pred_original_sample
+ )
+
+ def add_noise(self, original_samples, noise, timesteps):
+ raise NotImplementedError()
diff --git a/diffusers/schedulers/scheduling_karras_ve_flax.py b/diffusers/schedulers/scheduling_karras_ve_flax.py
new file mode 100644
index 0000000000000000000000000000000000000000..45c0dbddf7efd22df21cc9859e68d62b54aa8609
--- /dev/null
+++ b/diffusers/schedulers/scheduling_karras_ve_flax.py
@@ -0,0 +1,237 @@
+# Copyright 2023 NVIDIA and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import flax
+import jax.numpy as jnp
+from jax import random
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import BaseOutput
+from .scheduling_utils_flax import FlaxSchedulerMixin
+
+
+@flax.struct.dataclass
+class KarrasVeSchedulerState:
+ # setable values
+ num_inference_steps: Optional[int] = None
+ timesteps: Optional[jnp.ndarray] = None
+ schedule: Optional[jnp.ndarray] = None # sigma(t_i)
+
+ @classmethod
+ def create(cls):
+ return cls()
+
+
+@dataclass
+class FlaxKarrasVeOutput(BaseOutput):
+ """
+ Output class for the scheduler's step function output.
+
+ Args:
+ prev_sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ derivative (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images):
+ Derivative of predicted original image sample (x_0).
+ state (`KarrasVeSchedulerState`): the `FlaxKarrasVeScheduler` state data class.
+ """
+
+ prev_sample: jnp.ndarray
+ derivative: jnp.ndarray
+ state: KarrasVeSchedulerState
+
+
+class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin):
+ """
+ Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and
+ the VE column of Table 1 from [1] for reference.
+
+ [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models."
+ https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic
+ differential equations." https://arxiv.org/abs/2011.13456
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
+
+ For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of
+ Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the
+ optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper.
+
+ Args:
+ sigma_min (`float`): minimum noise magnitude
+ sigma_max (`float`): maximum noise magnitude
+ s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling.
+ A reasonable range is [1.000, 1.011].
+ s_churn (`float`): the parameter controlling the overall amount of stochasticity.
+ A reasonable range is [0, 100].
+ s_min (`float`): the start value of the sigma range where we add noise (enable stochasticity).
+ A reasonable range is [0, 10].
+ s_max (`float`): the end value of the sigma range where we add noise.
+ A reasonable range is [0.2, 80].
+ """
+
+ @property
+ def has_state(self):
+ return True
+
+ @register_to_config
+ def __init__(
+ self,
+ sigma_min: float = 0.02,
+ sigma_max: float = 100,
+ s_noise: float = 1.007,
+ s_churn: float = 80,
+ s_min: float = 0.05,
+ s_max: float = 50,
+ ):
+ pass
+
+ def create_state(self):
+ return KarrasVeSchedulerState.create()
+
+ def set_timesteps(
+ self, state: KarrasVeSchedulerState, num_inference_steps: int, shape: Tuple = ()
+ ) -> KarrasVeSchedulerState:
+ """
+ Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ state (`KarrasVeSchedulerState`):
+ the `FlaxKarrasVeScheduler` state data class.
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+
+ """
+ timesteps = jnp.arange(0, num_inference_steps)[::-1].copy()
+ schedule = [
+ (
+ self.config.sigma_max**2
+ * (self.config.sigma_min**2 / self.config.sigma_max**2) ** (i / (num_inference_steps - 1))
+ )
+ for i in timesteps
+ ]
+
+ return state.replace(
+ num_inference_steps=num_inference_steps,
+ schedule=jnp.array(schedule, dtype=jnp.float32),
+ timesteps=timesteps,
+ )
+
+ def add_noise_to_input(
+ self,
+ state: KarrasVeSchedulerState,
+ sample: jnp.ndarray,
+ sigma: float,
+ key: random.KeyArray,
+ ) -> Tuple[jnp.ndarray, float]:
+ """
+ Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a
+ higher noise level sigma_hat = sigma_i + gamma_i*sigma_i.
+
+        Args:
+            state (`KarrasVeSchedulerState`): the `FlaxKarrasVeScheduler` state data class.
+            sample (`jnp.ndarray`): the current sample to be noised.
+            sigma (`float`): the current noise level sigma_i.
+            key (`random.KeyArray`): a PRNG key for the added noise.
+
+        Returns:
+            `Tuple[jnp.ndarray, float]`: the noised sample `sample_hat` and the increased noise level `sigma_hat`.
+ """
+ if self.config.s_min <= sigma <= self.config.s_max:
+ gamma = min(self.config.s_churn / state.num_inference_steps, 2**0.5 - 1)
+ else:
+ gamma = 0
+
+ # sample eps ~ N(0, S_noise^2 * I)
+ key = random.split(key, num=1)
+ eps = self.config.s_noise * random.normal(key=key, shape=sample.shape)
+ sigma_hat = sigma + gamma * sigma
+ sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps)
+
+ return sample_hat, sigma_hat
+
+ def step(
+ self,
+ state: KarrasVeSchedulerState,
+ model_output: jnp.ndarray,
+ sigma_hat: float,
+ sigma_prev: float,
+ sample_hat: jnp.ndarray,
+ return_dict: bool = True,
+ ) -> Union[FlaxKarrasVeOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ state (`KarrasVeSchedulerState`): the `FlaxKarrasVeScheduler` state data class.
+ model_output (`jnp.ndarray`): direct output from learned diffusion model.
+ sigma_hat (`float`): the increased noise level returned by `add_noise_to_input`.
+ sigma_prev (`float`): the noise level of the next (lower-noise) point in the schedule.
+ sample_hat (`jnp.ndarray`): the noised sample returned by `add_noise_to_input`.
+ return_dict (`bool`): option for returning tuple rather than FlaxKarrasVeOutput class
+
+ Returns:
+ [`~schedulers.scheduling_karras_ve_flax.FlaxKarrasVeOutput`] or `tuple`: Updated sample in the diffusion
+ chain and derivative. [`~schedulers.scheduling_karras_ve_flax.FlaxKarrasVeOutput`] if `return_dict` is
+ True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
+ """
+
+ pred_original_sample = sample_hat + sigma_hat * model_output
+ derivative = (sample_hat - pred_original_sample) / sigma_hat
+ sample_prev = sample_hat + (sigma_prev - sigma_hat) * derivative
+
+ if not return_dict:
+ return (sample_prev, derivative, state)
+
+ return FlaxKarrasVeOutput(prev_sample=sample_prev, derivative=derivative, state=state)
+
+ def step_correct(
+ self,
+ state: KarrasVeSchedulerState,
+ model_output: jnp.ndarray,
+ sigma_hat: float,
+ sigma_prev: float,
+ sample_hat: jnp.ndarray,
+ sample_prev: jnp.ndarray,
+ derivative: jnp.ndarray,
+ return_dict: bool = True,
+ ) -> Union[FlaxKarrasVeOutput, Tuple]:
+ """
+ Correct the predicted sample based on the model output of the network. This is the second-order (Heun)
+ correction of Algorithm 2 in Karras et al. (2022): the derivative is re-evaluated at the sample predicted by
+ `step` and averaged with the original derivative.
+
+ Args:
+ state (`KarrasVeSchedulerState`): the `FlaxKarrasVeScheduler` state data class.
+ model_output (`jnp.ndarray`): direct output from learned diffusion model, evaluated at `sample_prev`.
+ sigma_hat (`float`): the increased noise level of the current step.
+ sigma_prev (`float`): the noise level of the next (lower-noise) point in the schedule.
+ sample_hat (`jnp.ndarray`): the noised sample the preceding Euler step started from.
+ sample_prev (`jnp.ndarray`): the sample predicted by the preceding `step` call.
+ derivative (`jnp.ndarray`): the derivative computed by the preceding `step` call.
+ return_dict (`bool`): option for returning tuple rather than FlaxKarrasVeOutput class
+
+ Returns:
+ [`~schedulers.scheduling_karras_ve_flax.FlaxKarrasVeOutput`] or `tuple`: the corrected sample in the diffusion
+ chain and its derivative, together with the scheduler state. When returning a tuple, the first element is the
+ sample tensor.
+
+ """
+ pred_original_sample = sample_prev + sigma_prev * model_output
+ derivative_corr = (sample_prev - pred_original_sample) / sigma_prev
+ sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr)
+
+ if not return_dict:
+ return (sample_prev, derivative, state)
+
+ return FlaxKarrasVeOutput(prev_sample=sample_prev, derivative=derivative, state=state)
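+
+ # Taken together, `step` and `step_correct` implement the second-order (Heun) sampler of
+ # Algorithm 2 in Karras et al. (2022): `step` performs the Euler update from sigma_hat to
+ # sigma_prev, and `step_correct` re-evaluates the model at the predicted sample and averages
+ # the two derivatives. Rough sketch of one sampling iteration (the `model` call and its
+ # arguments are placeholders, not defined in this file):
+ #
+ #   sample_hat, sigma_hat = scheduler.add_noise_to_input(state, sample, sigma, key)
+ #   model_output = model(sample_hat, sigma_hat)
+ #   output = scheduler.step(state, model_output, sigma_hat, sigma_prev, sample_hat)
+ #   if sigma_prev != 0:
+ #       model_output = model(output.prev_sample, sigma_prev)
+ #       output = scheduler.step_correct(
+ #           state, model_output, sigma_hat, sigma_prev, sample_hat,
+ #           output.prev_sample, output.derivative,
+ #       )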
+
+ def add_noise(self, state: KarrasVeSchedulerState, original_samples, noise, timesteps):
+ raise NotImplementedError()
diff --git a/diffusers/schedulers/scheduling_lms_discrete.py b/diffusers/schedulers/scheduling_lms_discrete.py
new file mode 100644
index 0000000000000000000000000000000000000000..d58d4ce45bd17645b86905c1ae36ce937015fc29
--- /dev/null
+++ b/diffusers/schedulers/scheduling_lms_discrete.py
@@ -0,0 +1,413 @@
+# Copyright 2023 Katherine Crowson and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+import warnings
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from scipy import integrate
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import BaseOutput
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
+
+
+@dataclass
+# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->LMSDiscrete
+class LMSDiscreteSchedulerOutput(BaseOutput):
+ """
+ Output class for the scheduler's step function output.
+
+ Args:
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ The predicted denoised sample (x_{0}) based on the model output from the current timestep.
+ `pred_original_sample` can be used to preview progress or for guidance.
+ """
+
+ prev_sample: torch.FloatTensor
+ pred_original_sample: Optional[torch.FloatTensor] = None
+
+
+# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
+def betas_for_alpha_bar(
+ num_diffusion_timesteps,
+ max_beta=0.999,
+ alpha_transform_type="cosine",
+):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ alpha_transform_type (`str`, *optional*, defaults to `cosine`): the type of noise schedule for alpha_bar.
+ Choose from `cosine` or `exp`
+
+ Returns:
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+ if alpha_transform_type == "cosine":
+
+ def alpha_bar_fn(t):
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ elif alpha_transform_type == "exp":
+
+ def alpha_bar_fn(t):
+ return math.exp(t * -12.0)
+
+ else:
+ raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
+ return torch.tensor(betas, dtype=torch.float32)
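+
+# Illustrative use (assuming the defaults above): `betas_for_alpha_bar(1000)` returns 1000
+# monotonically increasing betas whose cumulative product of (1 - beta) follows the squared
+# cosine curve alpha_bar(t) = cos((t + 0.008) / 1.008 * pi / 2) ** 2, with each beta capped at
+# `max_beta` near the end of the schedule where the cosine approaches zero.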
+
+
+class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
+ """
+ Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by
+ Katherine Crowson:
+ https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
+ beta_start (`float`): the starting `beta` value of inference.
+ beta_end (`float`): the final `beta` value.
+ beta_schedule (`str`):
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear` or `scaled_linear`.
+ trained_betas (`np.ndarray`, optional):
+ option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
+ use_karras_sigmas (`bool`, *optional*, defaults to `False`):
+ This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the
+ noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence
+ of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf.
+ prediction_type (`str`, default `epsilon`, optional):
+ prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
+ process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4
+ https://imagen.research.google/video/paper.pdf)
+ timestep_spacing (`str`, default `"linspace"`):
+ The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
+ Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
+ steps_offset (`int`, default `0`):
+ an offset added to the inference steps. You can use a combination of `offset=1` and
+ `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
+ stable diffusion.
+ """
+
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+ use_karras_sigmas: Optional[bool] = False,
+ prediction_type: str = "epsilon",
+ timestep_spacing: str = "linspace",
+ steps_offset: int = 0,
+ ):
+ if trained_betas is not None:
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
+ elif beta_schedule == "linear":
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ )
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ else:
+ raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+
+ sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+ sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
+ self.sigmas = torch.from_numpy(sigmas)
+
+ # setable values
+ self.num_inference_steps = None
+ self.use_karras_sigmas = use_karras_sigmas
+ self.set_timesteps(num_train_timesteps, None)
+ self.derivatives = []
+ self.is_scale_input_called = False
+
+ @property
+ def init_noise_sigma(self):
+ # standard deviation of the initial noise distribution
+ if self.config.timestep_spacing in ["linspace", "trailing"]:
+ return self.sigmas.max()
+
+ return (self.sigmas.max() ** 2 + 1) ** 0.5
+
+ def scale_model_input(
+ self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
+ ) -> torch.FloatTensor:
+ """
+ Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm.
+
+ Args:
+ sample (`torch.FloatTensor`): input sample
+ timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain
+
+ Returns:
+ `torch.FloatTensor`: scaled input sample
+ """
+ if isinstance(timestep, torch.Tensor):
+ timestep = timestep.to(self.timesteps.device)
+ step_index = (self.timesteps == timestep).nonzero().item()
+ sigma = self.sigmas[step_index]
+ sample = sample / ((sigma**2 + 1) ** 0.5)
+ self.is_scale_input_called = True
+ return sample
+
+ def get_lms_coefficient(self, order, t, current_order):
+ """
+ Compute a linear multistep coefficient.
+
+ Args:
+ order (`int`): the order of the linear multistep method.
+ t (`int`): the index of the current timestep in the sigma schedule.
+ current_order (`int`): the index of the stored derivative that this coefficient multiplies.
+ """
+
+ def lms_derivative(tau):
+ prod = 1.0
+ for k in range(order):
+ if current_order == k:
+ continue
+ prod *= (tau - self.sigmas[t - k]) / (self.sigmas[t - current_order] - self.sigmas[t - k])
+ return prod
+
+ integrated_coeff = integrate.quad(lms_derivative, self.sigmas[t], self.sigmas[t + 1], epsrel=1e-4)[0]
+
+ return integrated_coeff
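+
+ # Sanity check of the coefficient above (a worked special case, not used by the code): for
+ # order == 1 the product in `lms_derivative` is empty, so the integrand is 1 and the
+ # coefficient reduces to sigmas[t + 1] - sigmas[t]; the multistep update then degenerates to
+ # a plain Euler step, sample + (sigmas[t + 1] - sigmas[t]) * derivative.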
+
+ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+ """
+ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ device (`str` or `torch.device`, optional):
+ the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ """
+ self.num_inference_steps = num_inference_steps
+
+ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+ if self.config.timestep_spacing == "linspace":
+ timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[
+ ::-1
+ ].copy()
+ elif self.config.timestep_spacing == "leading":
+ step_ratio = self.config.num_train_timesteps // self.num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
+ timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
+ timesteps += self.config.steps_offset
+ elif self.config.timestep_spacing == "trailing":
+ step_ratio = self.config.num_train_timesteps / self.num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
+ timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(float)
+ timesteps -= 1
+ else:
+ raise ValueError(
+ f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
+ )
+
+ sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+ log_sigmas = np.log(sigmas)
+ sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
+
+ if self.use_karras_sigmas:
+ sigmas = self._convert_to_karras(in_sigmas=sigmas)
+ timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+
+ sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
+
+ self.sigmas = torch.from_numpy(sigmas).to(device=device)
+ if str(device).startswith("mps"):
+ # mps does not support float64
+ self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
+ else:
+ self.timesteps = torch.from_numpy(timesteps).to(device=device)
+
+ self.derivatives = []
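+
+ # Worked example of the three spacings above, assuming num_train_timesteps=1000,
+ # num_inference_steps=10 and steps_offset=0 (before interpolation to sigmas):
+ #   "linspace": [999, 888, 777, 666, 555, 444, 333, 222, 111, 0]
+ #   "leading":  [900, 800, 700, 600, 500, 400, 300, 200, 100, 0]
+ #   "trailing": [999, 899, 799, 699, 599, 499, 399, 299, 199, 99]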
+
+ # copied from diffusers.schedulers.scheduling_euler_discrete._sigma_to_t
+ def _sigma_to_t(self, sigma, log_sigmas):
+ # get log sigma
+ log_sigma = np.log(sigma)
+
+ # get distribution
+ dists = log_sigma - log_sigmas[:, np.newaxis]
+
+ # get sigmas range
+ low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
+ high_idx = low_idx + 1
+
+ low = log_sigmas[low_idx]
+ high = log_sigmas[high_idx]
+
+ # interpolate sigmas
+ w = (low - log_sigma) / (low - high)
+ w = np.clip(w, 0, 1)
+
+ # transform interpolation to time range
+ t = (1 - w) * low_idx + w * high_idx
+ t = t.reshape(sigma.shape)
+ return t
+
+ # copied from diffusers.schedulers.scheduling_euler_discrete._convert_to_karras
+ def _convert_to_karras(self, in_sigmas: torch.FloatTensor) -> torch.FloatTensor:
+ """Constructs the noise schedule of Karras et al. (2022)."""
+
+ sigma_min: float = in_sigmas[-1].item()
+ sigma_max: float = in_sigmas[0].item()
+
+ rho = 7.0 # 7.0 is the value used in the paper
+ ramp = np.linspace(0, 1, self.num_inference_steps)
+ min_inv_rho = sigma_min ** (1 / rho)
+ max_inv_rho = sigma_max ** (1 / rho)
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+ return sigmas
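+
+ # The schedule above implements Eq. (5) of Karras et al. (2022):
+ #   sigma_i = (sigma_max**(1/rho) + i/(n-1) * (sigma_min**(1/rho) - sigma_max**(1/rho))) ** rho
+ # for i = 0, ..., n-1, so sigma_0 == sigma_max and sigma_{n-1} == sigma_min, with rho=7
+ # concentrating most of the steps at the low-noise end of the range.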
+
+ def step(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: Union[float, torch.FloatTensor],
+ sample: torch.FloatTensor,
+ order: int = 4,
+ return_dict: bool = True,
+ ) -> Union[LMSDiscreteSchedulerOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+ timestep (`float`): current timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+ order (`int`): the order of the linear multistep method; defaults to 4.
+ return_dict (`bool`): option for returning tuple rather than LMSDiscreteSchedulerOutput class
+
+ Returns:
+ [`~schedulers.scheduling_lms_discrete.LMSDiscreteSchedulerOutput`] or `tuple`:
+ [`~schedulers.scheduling_lms_discrete.LMSDiscreteSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is the sample tensor.
+
+ """
+ if not self.is_scale_input_called:
+ warnings.warn(
+ "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
+ "See `StableDiffusionPipeline` for a usage example."
+ )
+
+ if isinstance(timestep, torch.Tensor):
+ timestep = timestep.to(self.timesteps.device)
+ step_index = (self.timesteps == timestep).nonzero().item()
+ sigma = self.sigmas[step_index]
+
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
+ if self.config.prediction_type == "epsilon":
+ pred_original_sample = sample - sigma * model_output
+ elif self.config.prediction_type == "v_prediction":
+ # * c_out + input * c_skip
+ pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
+ elif self.config.prediction_type == "sample":
+ pred_original_sample = model_output
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
+ )
+
+ # 2. Convert to an ODE derivative
+ derivative = (sample - pred_original_sample) / sigma
+ self.derivatives.append(derivative)
+ if len(self.derivatives) > order:
+ self.derivatives.pop(0)
+
+ # 3. Compute linear multistep coefficients
+ order = min(step_index + 1, order)
+ lms_coeffs = [self.get_lms_coefficient(order, step_index, curr_order) for curr_order in range(order)]
+
+ # 4. Compute previous sample based on the derivatives path
+ prev_sample = sample + sum(
+ coeff * derivative for coeff, derivative in zip(lms_coeffs, reversed(self.derivatives))
+ )
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return LMSDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
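+
+ # Rough usage sketch for the K-LMS step above (`unet`, `latents` and `num_steps` are
+ # placeholders standing in for a typical denoising loop, not defined in this file):
+ #
+ #   scheduler = LMSDiscreteScheduler()
+ #   scheduler.set_timesteps(num_steps)
+ #   latents = latents * scheduler.init_noise_sigma
+ #   for t in scheduler.timesteps:
+ #       model_input = scheduler.scale_model_input(latents, t)
+ #       noise_pred = unet(model_input, t)
+ #       latents = scheduler.step(noise_pred, t, latents).prev_sample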
+
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
+ def add_noise(
+ self,
+ original_samples: torch.FloatTensor,
+ noise: torch.FloatTensor,
+ timesteps: torch.FloatTensor,
+ ) -> torch.FloatTensor:
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
+ sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
+ if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
+ # mps does not support float64
+ schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
+ timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
+ else:
+ schedule_timesteps = self.timesteps.to(original_samples.device)
+ timesteps = timesteps.to(original_samples.device)
+
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < len(original_samples.shape):
+ sigma = sigma.unsqueeze(-1)
+
+ noisy_samples = original_samples + noise * sigma
+ return noisy_samples
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/diffusers/schedulers/scheduling_lms_discrete_flax.py b/diffusers/schedulers/scheduling_lms_discrete_flax.py
new file mode 100644
index 0000000000000000000000000000000000000000..f96e602afe121a09876b0ff7db1d3192e441e32a
--- /dev/null
+++ b/diffusers/schedulers/scheduling_lms_discrete_flax.py
@@ -0,0 +1,283 @@
+# Copyright 2023 Katherine Crowson and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import flax
+import jax.numpy as jnp
+from scipy import integrate
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from .scheduling_utils_flax import (
+ CommonSchedulerState,
+ FlaxKarrasDiffusionSchedulers,
+ FlaxSchedulerMixin,
+ FlaxSchedulerOutput,
+ broadcast_to_shape_from_left,
+)
+
+
+@flax.struct.dataclass
+class LMSDiscreteSchedulerState:
+ common: CommonSchedulerState
+
+ # setable values
+ init_noise_sigma: jnp.ndarray
+ timesteps: jnp.ndarray
+ sigmas: jnp.ndarray
+ num_inference_steps: Optional[int] = None
+
+ # running values
+ derivatives: Optional[jnp.ndarray] = None
+
+ @classmethod
+ def create(
+ cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray, sigmas: jnp.ndarray
+ ):
+ return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps, sigmas=sigmas)
+
+
+@dataclass
+class FlaxLMSSchedulerOutput(FlaxSchedulerOutput):
+ state: LMSDiscreteSchedulerState
+
+
+class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
+ """
+ Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by
+ Katherine Crowson:
+ https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
+ beta_start (`float`): the starting `beta` value of inference.
+ beta_end (`float`): the final `beta` value.
+ beta_schedule (`str`):
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear` or `scaled_linear`.
+ trained_betas (`jnp.ndarray`, optional):
+ option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
+ prediction_type (`str`, default `epsilon`, optional):
+ prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
+ process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4
+ https://imagen.research.google/video/paper.pdf)
+ dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
+ the `dtype` used for params and computation.
+ """
+
+ _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]
+
+ dtype: jnp.dtype
+
+ @property
+ def has_state(self):
+ return True
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[jnp.ndarray] = None,
+ prediction_type: str = "epsilon",
+ dtype: jnp.dtype = jnp.float32,
+ ):
+ self.dtype = dtype
+
+ def create_state(self, common: Optional[CommonSchedulerState] = None) -> LMSDiscreteSchedulerState:
+ if common is None:
+ common = CommonSchedulerState.create(self)
+
+ timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1]
+ sigmas = ((1 - common.alphas_cumprod) / common.alphas_cumprod) ** 0.5
+
+ # standard deviation of the initial noise distribution
+ init_noise_sigma = sigmas.max()
+
+ return LMSDiscreteSchedulerState.create(
+ common=common,
+ init_noise_sigma=init_noise_sigma,
+ timesteps=timesteps,
+ sigmas=sigmas,
+ )
+
+ def scale_model_input(self, state: LMSDiscreteSchedulerState, sample: jnp.ndarray, timestep: int) -> jnp.ndarray:
+ """
+ Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm.
+
+ Args:
+ state (`LMSDiscreteSchedulerState`):
+ the `FlaxLMSDiscreteScheduler` state data class instance.
+ sample (`jnp.ndarray`):
+ current instance of sample being created by diffusion process.
+ timestep (`int`):
+ current discrete timestep in the diffusion chain.
+
+ Returns:
+ `jnp.ndarray`: scaled input sample
+ """
+ (step_index,) = jnp.where(state.timesteps == timestep, size=1)
+ step_index = step_index[0]
+
+ sigma = state.sigmas[step_index]
+ sample = sample / ((sigma**2 + 1) ** 0.5)
+ return sample
+
+ def get_lms_coefficient(self, state: LMSDiscreteSchedulerState, order, t, current_order):
+ """
+ Compute a linear multistep coefficient.
+
+ Args:
+ state (`LMSDiscreteSchedulerState`): the `FlaxLMSDiscreteScheduler` state data class instance.
+ order (`int`): the order of the linear multistep method.
+ t (`int`): the index of the current timestep in the sigma schedule.
+ current_order (`int`): the index of the stored derivative that this coefficient multiplies.
+ """
+
+ def lms_derivative(tau):
+ prod = 1.0
+ for k in range(order):
+ if current_order == k:
+ continue
+ prod *= (tau - state.sigmas[t - k]) / (state.sigmas[t - current_order] - state.sigmas[t - k])
+ return prod
+
+ integrated_coeff = integrate.quad(lms_derivative, state.sigmas[t], state.sigmas[t + 1], epsrel=1e-4)[0]
+
+ return integrated_coeff
+
+ def set_timesteps(
+ self, state: LMSDiscreteSchedulerState, num_inference_steps: int, shape: Tuple = ()
+ ) -> LMSDiscreteSchedulerState:
+ """
+ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ state (`LMSDiscreteSchedulerState`):
+ the `FlaxLMSDiscreteScheduler` state data class instance.
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ """
+
+ timesteps = jnp.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=self.dtype)
+
+ low_idx = jnp.floor(timesteps).astype(jnp.int32)
+ high_idx = jnp.ceil(timesteps).astype(jnp.int32)
+
+ frac = jnp.mod(timesteps, 1.0)
+
+ sigmas = ((1 - state.common.alphas_cumprod) / state.common.alphas_cumprod) ** 0.5
+ sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx]
+ sigmas = jnp.concatenate([sigmas, jnp.array([0.0], dtype=self.dtype)])
+
+ timesteps = timesteps.astype(jnp.int32)
+
+ # initial running values
+ derivatives = jnp.zeros((0,) + shape, dtype=self.dtype)
+
+ return state.replace(
+ timesteps=timesteps,
+ sigmas=sigmas,
+ num_inference_steps=num_inference_steps,
+ derivatives=derivatives,
+ )
+
+ def step(
+ self,
+ state: LMSDiscreteSchedulerState,
+ model_output: jnp.ndarray,
+ timestep: int,
+ sample: jnp.ndarray,
+ order: int = 4,
+ return_dict: bool = True,
+ ) -> Union[FlaxLMSSchedulerOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ state (`LMSDiscreteSchedulerState`): the `FlaxLMSDiscreteScheduler` state data class instance.
+ model_output (`jnp.ndarray`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`jnp.ndarray`):
+ current instance of sample being created by diffusion process.
+ order (`int`): the order of the linear multistep method; defaults to 4.
+ return_dict (`bool`): option for returning tuple rather than FlaxLMSSchedulerOutput class
+
+ Returns:
+ [`FlaxLMSSchedulerOutput`] or `tuple`: [`FlaxLMSSchedulerOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is the sample tensor.
+
+ """
+ if state.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ sigma = state.sigmas[timestep]
+
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
+ if self.config.prediction_type == "epsilon":
+ pred_original_sample = sample - sigma * model_output
+ elif self.config.prediction_type == "v_prediction":
+ # * c_out + input * c_skip
+ pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
+ )
+
+ # 2. Convert to an ODE derivative
+ derivative = (sample - pred_original_sample) / sigma
+ state = state.replace(derivatives=jnp.append(state.derivatives, derivative))
+ if len(state.derivatives) > order:
+ state = state.replace(derivatives=jnp.delete(state.derivatives, 0))
+
+ # 3. Compute linear multistep coefficients
+ order = min(timestep + 1, order)
+ lms_coeffs = [self.get_lms_coefficient(state, order, timestep, curr_order) for curr_order in range(order)]
+
+ # 4. Compute previous sample based on the derivatives path
+ prev_sample = sample + sum(
+ coeff * derivative for coeff, derivative in zip(lms_coeffs, reversed(state.derivatives))
+ )
+
+ if not return_dict:
+ return (prev_sample, state)
+
+ return FlaxLMSSchedulerOutput(prev_sample=prev_sample, state=state)
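+
+ # Unlike the PyTorch scheduler above, this class holds no mutable attributes: every method
+ # threads an explicit `LMSDiscreteSchedulerState` through and returns an updated copy, which
+ # keeps the scheduler usable inside `jax.jit`-compiled sampling loops. Rough sketch of the
+ # state handling (the shape below is a placeholder, not a full pipeline):
+ #
+ #   scheduler = FlaxLMSDiscreteScheduler()
+ #   state = scheduler.create_state()
+ #   state = scheduler.set_timesteps(state, num_inference_steps=50, shape=(1, 4, 64, 64))
+ #   # each subsequent `scheduler.step(...)` call returns both the new sample and a new state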
+
+ def add_noise(
+ self,
+ state: LMSDiscreteSchedulerState,
+ original_samples: jnp.ndarray,
+ noise: jnp.ndarray,
+ timesteps: jnp.ndarray,
+ ) -> jnp.ndarray:
+ sigma = state.sigmas[timesteps].flatten()
+ sigma = broadcast_to_shape_from_left(sigma, noise.shape)
+
+ noisy_samples = original_samples + noise * sigma
+
+ return noisy_samples
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/diffusers/schedulers/scheduling_pndm.py b/diffusers/schedulers/scheduling_pndm.py
new file mode 100644
index 0000000000000000000000000000000000000000..794eb3674c1bb5533b938b00b08d48cd5192c317
--- /dev/null
+++ b/diffusers/schedulers/scheduling_pndm.py
@@ -0,0 +1,462 @@
+# Copyright 2023 Zhejiang University Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
+
+import math
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
+
+
+# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
+def betas_for_alpha_bar(
+ num_diffusion_timesteps,
+ max_beta=0.999,
+ alpha_transform_type="cosine",
+):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ alpha_transform_type (`str`, *optional*, defaults to `cosine`): the type of noise schedule for alpha_bar.
+ Choose from `cosine` or `exp`
+
+ Returns:
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+ if alpha_transform_type == "cosine":
+
+ def alpha_bar_fn(t):
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ elif alpha_transform_type == "exp":
+
+ def alpha_bar_fn(t):
+ return math.exp(t * -12.0)
+
+ else:
+ raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
+ return torch.tensor(betas, dtype=torch.float32)
+
+
+class PNDMScheduler(SchedulerMixin, ConfigMixin):
+ """
+ Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques,
+ namely Runge-Kutta method and a linear multi-step method.
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
+
+ For more details, see the original paper: https://arxiv.org/abs/2202.09778
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
+ beta_start (`float`): the starting `beta` value of inference.
+ beta_end (`float`): the final `beta` value.
+ beta_schedule (`str`):
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+ trained_betas (`np.ndarray`, optional):
+ option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
+ skip_prk_steps (`bool`):
+ allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required
+ before plms steps; defaults to `False`.
+ set_alpha_to_one (`bool`, default `False`):
+ each diffusion step uses the value of alphas product at that step and at the previous one. For the final
+ step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
+ otherwise it uses the value of alpha at step 0.
+ prediction_type (`str`, default `epsilon`, optional):
+ prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process)
+ or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf)
+ timestep_spacing (`str`, default `"leading"`):
+ The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
+ Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
+ steps_offset (`int`, default `0`):
+ an offset added to the inference steps. You can use a combination of `offset=1` and
+ `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
+ stable diffusion.
+ """
+
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+ skip_prk_steps: bool = False,
+ set_alpha_to_one: bool = False,
+ prediction_type: str = "epsilon",
+ timestep_spacing: str = "leading",
+ steps_offset: int = 0,
+ ):
+ if trained_betas is not None:
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
+ elif beta_schedule == "linear":
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ )
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ else:
+ raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+
+ self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = 1.0
+
+ # For now we only support F-PNDM, i.e. the runge-kutta method
+ # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
+ # mainly at formulas (9), (12), (13) and Algorithm 2.
+ self.pndm_order = 4
+
+ # running values
+ self.cur_model_output = 0
+ self.counter = 0
+ self.cur_sample = None
+ self.ets = []
+
+ # setable values
+ self.num_inference_steps = None
+ self._timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
+ self.prk_timesteps = None
+ self.plms_timesteps = None
+ self.timesteps = None
+
+ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+ """
+ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ """
+
+ self.num_inference_steps = num_inference_steps
+ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+ if self.config.timestep_spacing == "linspace":
+ self._timesteps = (
+ np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps).round().astype(np.int64)
+ )
+ elif self.config.timestep_spacing == "leading":
+ step_ratio = self.config.num_train_timesteps // self.num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
+ self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()
+ self._timesteps += self.config.steps_offset
+ elif self.config.timestep_spacing == "trailing":
+ step_ratio = self.config.num_train_timesteps / self.num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
+ self._timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio))[::-1].astype(
+ np.int64
+ )
+ self._timesteps -= 1
+ else:
+ raise ValueError(
+ f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
+ )
+
+ if self.config.skip_prk_steps:
+ # for some models like stable diffusion the prk steps can/should be skipped to
+ # produce better results. When using PNDM with `self.config.skip_prk_steps` the implementation
+ # is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51
+ self.prk_timesteps = np.array([])
+ self.plms_timesteps = np.concatenate([self._timesteps[:-1], self._timesteps[-2:-1], self._timesteps[-1:]])[
+ ::-1
+ ].copy()
+ else:
+ prk_timesteps = np.array(self._timesteps[-self.pndm_order :]).repeat(2) + np.tile(
+ np.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order
+ )
+ self.prk_timesteps = (prk_timesteps[:-1].repeat(2)[1:-1])[::-1].copy()
+ self.plms_timesteps = self._timesteps[:-3][
+ ::-1
+ ].copy() # we copy to avoid having negative strides which are not supported by torch.from_numpy
+
+ timesteps = np.concatenate([self.prk_timesteps, self.plms_timesteps]).astype(np.int64)
+ self.timesteps = torch.from_numpy(timesteps).to(device)
+
+ self.ets = []
+ self.counter = 0
+ self.cur_model_output = 0
+
+ def step(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: int,
+ sample: torch.FloatTensor,
+ return_dict: bool = True,
+ ) -> Union[SchedulerOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ This function calls `step_prk()` or `step_plms()` depending on the internal variable `counter`.
+
+ Args:
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+ return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+ Returns:
+ [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
+ [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
+
+ """
+ if self.counter < len(self.prk_timesteps) and not self.config.skip_prk_steps:
+ return self.step_prk(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict)
+ else:
+ return self.step_plms(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict)
+
+ def step_prk(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: int,
+ sample: torch.FloatTensor,
+ return_dict: bool = True,
+ ) -> Union[SchedulerOutput, Tuple]:
+ """
+ Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
+ solution to the differential equation.
+
+ Args:
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+ return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+ Returns:
+ [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is
+ True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
+
+ """
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ diff_to_prev = 0 if self.counter % 2 else self.config.num_train_timesteps // self.num_inference_steps // 2
+ prev_timestep = timestep - diff_to_prev
+ timestep = self.prk_timesteps[self.counter // 4 * 4]
+
+ if self.counter % 4 == 0:
+ self.cur_model_output += 1 / 6 * model_output
+ self.ets.append(model_output)
+ self.cur_sample = sample
+ elif (self.counter - 1) % 4 == 0:
+ self.cur_model_output += 1 / 3 * model_output
+ elif (self.counter - 2) % 4 == 0:
+ self.cur_model_output += 1 / 3 * model_output
+ elif (self.counter - 3) % 4 == 0:
+ model_output = self.cur_model_output + 1 / 6 * model_output
+ self.cur_model_output = 0
+
+ # cur_sample should not be `None`
+ cur_sample = self.cur_sample if self.cur_sample is not None else sample
+
+ prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output)
+ self.counter += 1
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return SchedulerOutput(prev_sample=prev_sample)
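+
+ # The 1/6, 1/3, 1/3, 1/6 weights accumulated above are the classical fourth-order
+ # Runge-Kutta weights: the first three of every four calls produce intermediate proposals
+ # used to query the model again, and the fourth call applies the accumulated combination
+ # starting from the cached `cur_sample`, completing one pseudo Runge-Kutta (PRK) step.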
+
+ def step_plms(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: int,
+ sample: torch.FloatTensor,
+ return_dict: bool = True,
+ ) -> Union[SchedulerOutput, Tuple]:
+ """
+ Step function propagating the sample with the linear multi-step method. Only one model forward pass is needed
+ per step; outputs from previous steps are reused to approximate the solution.
+
+ Args:
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+ return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+ Returns:
+ [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is
+ True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
+
+ """
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ if not self.config.skip_prk_steps and len(self.ets) < 3:
+ raise ValueError(
+ f"{self.__class__} can only be run AFTER scheduler has been run "
+ "in 'prk' mode for at least 12 iterations "
+ "See: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py "
+ "for more information."
+ )
+
+ prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
+
+ if self.counter != 1:
+ self.ets = self.ets[-3:]
+ self.ets.append(model_output)
+ else:
+ prev_timestep = timestep
+ timestep = timestep + self.config.num_train_timesteps // self.num_inference_steps
+
+ if len(self.ets) == 1 and self.counter == 0:
+ model_output = model_output
+ self.cur_sample = sample
+ elif len(self.ets) == 1 and self.counter == 1:
+ model_output = (model_output + self.ets[-1]) / 2
+ sample = self.cur_sample
+ self.cur_sample = None
+ elif len(self.ets) == 2:
+ model_output = (3 * self.ets[-1] - self.ets[-2]) / 2
+ elif len(self.ets) == 3:
+ model_output = (23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]) / 12
+ else:
+ model_output = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])
+
+ prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
+ self.counter += 1
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return SchedulerOutput(prev_sample=prev_sample)
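+
+ # The blending rules above are the standard Adams-Bashforth multistep coefficients applied to
+ # the cached model outputs `ets`:
+ #   1 stored output : e_t                                       (Euler)
+ #   2 stored outputs: (3*e_t - e_{t-1}) / 2                     (AB2)
+ #   3 stored outputs: (23*e_t - 16*e_{t-1} + 5*e_{t-2}) / 12    (AB3)
+ #   4 stored outputs: (55*e_t - 59*e_{t-1} + 37*e_{t-2} - 9*e_{t-3}) / 24  (AB4)
+ # This is the linear multistep ("PLMS") part of the pseudo numerical method.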
+
+ def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+
+ Args:
+ sample (`torch.FloatTensor`): input sample
+
+ Returns:
+ `torch.FloatTensor`: scaled input sample
+ """
+ return sample
+
+ def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
+ # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
+ # this function computes x_(t−δ) using the formula of (9)
+ # Note that x_t needs to be added to both sides of the equation
+
+ # Notation ( ->
+ # alpha_prod_t -> α_t
+ # alpha_prod_t_prev -> α_(t−δ)
+ # beta_prod_t -> (1 - α_t)
+ # beta_prod_t_prev -> (1 - α_(t−δ))
+ # sample -> x_t
+ # model_output -> e_θ(x_t, t)
+ # prev_sample -> x_(t−δ)
+ alpha_prod_t = self.alphas_cumprod[timestep]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+ if self.config.prediction_type == "v_prediction":
+ model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+ elif self.config.prediction_type != "epsilon":
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon` or `v_prediction`"
+ )
+
+ # corresponds to (α_(t−δ) - α_t) divided by
+ # denominator of x_t in formula (9) and plus 1
+ # Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqrt(α_t))) + 1 =
+ # sqrt(α_(t−δ)) / sqrt(α_t)
+ sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** (0.5)
+
+ # corresponds to denominator of e_θ(x_t, t) in formula (9)
+ model_output_denom_coeff = alpha_prod_t * beta_prod_t_prev ** (0.5) + (
+ alpha_prod_t * beta_prod_t * alpha_prod_t_prev
+ ) ** (0.5)
+
+ # full formula (9)
+ prev_sample = (
+ sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * model_output / model_output_denom_coeff
+ )
+
+ return prev_sample
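+
+ # Consistency check for the formula above: for `prediction_type="epsilon"` the update is
+ # algebraically identical to the deterministic DDIM transfer
+ #   x_(t−δ) = sqrt(α_(t−δ)/α_t) * x_t
+ #             + (sqrt(1 - α_(t−δ)) - sqrt(α_(t−δ) * (1 - α_t) / α_t)) * e_θ,
+ # because rationalizing that e_θ coefficient yields exactly
+ # (α_t - α_(t−δ)) / model_output_denom_coeff; PNDM differs from DDIM only in which
+ # (combined) e_θ is plugged into this transfer.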
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
+ def add_noise(
+ self,
+ original_samples: torch.FloatTensor,
+ noise: torch.FloatTensor,
+ timesteps: torch.IntTensor,
+ ) -> torch.FloatTensor:
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
+ alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+ timesteps = timesteps.to(original_samples.device)
+
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+ return noisy_samples
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/diffusers/schedulers/scheduling_pndm_flax.py b/diffusers/schedulers/scheduling_pndm_flax.py
new file mode 100644
index 0000000000000000000000000000000000000000..c654f2de8dd3e4f96403cce4b9db8f8b7b69861f
--- /dev/null
+++ b/diffusers/schedulers/scheduling_pndm_flax.py
@@ -0,0 +1,511 @@
+# Copyright 2023 Zhejiang University Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
+
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import flax
+import jax
+import jax.numpy as jnp
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from .scheduling_utils_flax import (
+ CommonSchedulerState,
+ FlaxKarrasDiffusionSchedulers,
+ FlaxSchedulerMixin,
+ FlaxSchedulerOutput,
+ add_noise_common,
+)
+
+
+@flax.struct.dataclass
+class PNDMSchedulerState:
+ common: CommonSchedulerState
+ final_alpha_cumprod: jnp.ndarray
+
+ # setable values
+ init_noise_sigma: jnp.ndarray
+ timesteps: jnp.ndarray
+ num_inference_steps: Optional[int] = None
+ prk_timesteps: Optional[jnp.ndarray] = None
+ plms_timesteps: Optional[jnp.ndarray] = None
+
+ # running values
+ cur_model_output: Optional[jnp.ndarray] = None
+ counter: Optional[jnp.int32] = None
+ cur_sample: Optional[jnp.ndarray] = None
+ ets: Optional[jnp.ndarray] = None
+
+ @classmethod
+ def create(
+ cls,
+ common: CommonSchedulerState,
+ final_alpha_cumprod: jnp.ndarray,
+ init_noise_sigma: jnp.ndarray,
+ timesteps: jnp.ndarray,
+ ):
+ return cls(
+ common=common,
+ final_alpha_cumprod=final_alpha_cumprod,
+ init_noise_sigma=init_noise_sigma,
+ timesteps=timesteps,
+ )
+
+
+@dataclass
+class FlaxPNDMSchedulerOutput(FlaxSchedulerOutput):
+ state: PNDMSchedulerState
+
+
+class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
+ """
+ Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques,
+ namely Runge-Kutta method and a linear multi-step method.
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
+
+ For more details, see the original paper: https://arxiv.org/abs/2202.09778
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
+ beta_start (`float`): the starting `beta` value of inference.
+ beta_end (`float`): the final `beta` value.
+ beta_schedule (`str`):
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+ trained_betas (`jnp.ndarray`, optional):
+ option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
+ skip_prk_steps (`bool`):
+ allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required
+ before plms steps; defaults to `False`.
+ set_alpha_to_one (`bool`, default `False`):
+ each diffusion step uses the value of alphas product at that step and at the previous one. For the final
+ step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
+ otherwise it uses the value of alpha at step 0.
+ steps_offset (`int`, default `0`):
+ an offset added to the inference steps. You can use a combination of `offset=1` and
+ `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
+ stable diffusion.
+ prediction_type (`str`, default `epsilon`, optional):
+ prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
+ process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4
+ https://imagen.research.google/video/paper.pdf)
+ dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
+ the `dtype` used for params and computation.
+ """
+
+ _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]
+
+ dtype: jnp.dtype
+ pndm_order: int
+
+ @property
+ def has_state(self):
+ return True
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[jnp.ndarray] = None,
+ skip_prk_steps: bool = False,
+ set_alpha_to_one: bool = False,
+ steps_offset: int = 0,
+ prediction_type: str = "epsilon",
+ dtype: jnp.dtype = jnp.float32,
+ ):
+ self.dtype = dtype
+
+ # For now we only support F-PNDM, i.e. the runge-kutta method
+ # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
+ # mainly at formulas (9), (12), (13) and Algorithm 2.
+ self.pndm_order = 4
+
+ def create_state(self, common: Optional[CommonSchedulerState] = None) -> PNDMSchedulerState:
+ if common is None:
+ common = CommonSchedulerState.create(self)
+
+ # At every step in ddim, we are looking into the previous alphas_cumprod
+ # For the final step, there is no previous alphas_cumprod because we are already at 0
+ # `set_alpha_to_one` decides whether we set this parameter simply to one or
+ # whether we use the final alpha of the "non-previous" one.
+ final_alpha_cumprod = (
+ jnp.array(1.0, dtype=self.dtype) if self.config.set_alpha_to_one else common.alphas_cumprod[0]
+ )
+
+ # standard deviation of the initial noise distribution
+ init_noise_sigma = jnp.array(1.0, dtype=self.dtype)
+
+ timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1]
+
+ return PNDMSchedulerState.create(
+ common=common,
+ final_alpha_cumprod=final_alpha_cumprod,
+ init_noise_sigma=init_noise_sigma,
+ timesteps=timesteps,
+ )
+
+ def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int, shape: Tuple) -> PNDMSchedulerState:
+ """
+ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ state (`PNDMSchedulerState`):
+ the `FlaxPNDMScheduler` state data class instance.
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ shape (`Tuple`):
+ the shape of the samples to be generated.
+ """
+
+ step_ratio = self.config.num_train_timesteps // num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # rounding to avoid issues when num_inference_step is power of 3
+ _timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round() + self.config.steps_offset
+
+ if self.config.skip_prk_steps:
+ # for some models like stable diffusion the prk steps can/should be skipped to
+ # produce better results. When using PNDM with `self.config.skip_prk_steps` the implementation
+ # is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51
+
+ prk_timesteps = jnp.array([], dtype=jnp.int32)
+ plms_timesteps = jnp.concatenate([_timesteps[:-1], _timesteps[-2:-1], _timesteps[-1:]])[::-1]
+
+ else:
+ prk_timesteps = _timesteps[-self.pndm_order :].repeat(2) + jnp.tile(
+ jnp.array([0, self.config.num_train_timesteps // num_inference_steps // 2], dtype=jnp.int32),
+ self.pndm_order,
+ )
+
+ prk_timesteps = (prk_timesteps[:-1].repeat(2)[1:-1])[::-1]
+ plms_timesteps = _timesteps[:-3][::-1]
+
+ timesteps = jnp.concatenate([prk_timesteps, plms_timesteps])
+
+ # initial running values
+
+ cur_model_output = jnp.zeros(shape, dtype=self.dtype)
+ counter = jnp.int32(0)
+ cur_sample = jnp.zeros(shape, dtype=self.dtype)
+ ets = jnp.zeros((4,) + shape, dtype=self.dtype)
+
+ return state.replace(
+ timesteps=timesteps,
+ num_inference_steps=num_inference_steps,
+ prk_timesteps=prk_timesteps,
+ plms_timesteps=plms_timesteps,
+ cur_model_output=cur_model_output,
+ counter=counter,
+ cur_sample=cur_sample,
+ ets=ets,
+ )
+
+ def scale_model_input(
+ self, state: PNDMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
+ ) -> jnp.ndarray:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+
+ Args:
+ state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance.
+ sample (`jnp.ndarray`): input sample
+ timestep (`int`, optional): current timestep
+
+ Returns:
+ `jnp.ndarray`: scaled input sample
+ """
+ return sample
+
+ def step(
+ self,
+ state: PNDMSchedulerState,
+ model_output: jnp.ndarray,
+ timestep: int,
+ sample: jnp.ndarray,
+ return_dict: bool = True,
+ ) -> Union[FlaxPNDMSchedulerOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ This function calls `step_prk()` or `step_plms()` depending on the internal variable `counter`.
+
+ Args:
+ state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance.
+ model_output (`jnp.ndarray`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`jnp.ndarray`):
+ current instance of sample being created by diffusion process.
+ return_dict (`bool`): option for returning tuple rather than FlaxPNDMSchedulerOutput class
+
+ Returns:
+ [`FlaxPNDMSchedulerOutput`] or `tuple`: [`FlaxPNDMSchedulerOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is the sample tensor.
+
+ """
+
+ if state.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ if self.config.skip_prk_steps:
+ prev_sample, state = self.step_plms(state, model_output, timestep, sample)
+ else:
+ prk_prev_sample, prk_state = self.step_prk(state, model_output, timestep, sample)
+ plms_prev_sample, plms_state = self.step_plms(state, model_output, timestep, sample)
+
+ cond = state.counter < len(state.prk_timesteps)
+
+ prev_sample = jax.lax.select(cond, prk_prev_sample, plms_prev_sample)
+
+ state = state.replace(
+ cur_model_output=jax.lax.select(cond, prk_state.cur_model_output, plms_state.cur_model_output),
+ ets=jax.lax.select(cond, prk_state.ets, plms_state.ets),
+ cur_sample=jax.lax.select(cond, prk_state.cur_sample, plms_state.cur_sample),
+ counter=jax.lax.select(cond, prk_state.counter, plms_state.counter),
+ )
+
+ if not return_dict:
+ return (prev_sample, state)
+
+ return FlaxPNDMSchedulerOutput(prev_sample=prev_sample, state=state)
+
+ def step_prk(
+ self,
+ state: PNDMSchedulerState,
+ model_output: jnp.ndarray,
+ timestep: int,
+ sample: jnp.ndarray,
+ ) -> Union[FlaxPNDMSchedulerOutput, Tuple]:
+ """
+ Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
+ solution to the differential equation.
+
+ Args:
+ state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance.
+ model_output (`jnp.ndarray`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`jnp.ndarray`):
+ current instance of sample being created by diffusion process.
+
+ Returns:
+ `tuple`: a `(prev_sample, state)` tuple, where `prev_sample` is the sample tensor and `state` is the
+ updated `PNDMSchedulerState`.
+
+ """
+
+ if state.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ diff_to_prev = jnp.where(
+ state.counter % 2, 0, self.config.num_train_timesteps // state.num_inference_steps // 2
+ )
+ prev_timestep = timestep - diff_to_prev
+ timestep = state.prk_timesteps[state.counter // 4 * 4]
+
+ model_output = jax.lax.select(
+ (state.counter % 4) != 3,
+ model_output, # remainder 0, 1, 2
+ state.cur_model_output + 1 / 6 * model_output, # remainder 3
+ )
+
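+ # Note (added for clarity): the four branches below accumulate the classic Runge-Kutta
+ # combination (1/6, 1/3, 1/3, 1/6) of the model evaluations; on the fourth call (remainder 3)
+ # the accumulated output is consumed above and the accumulator is reset for the next RK cycle.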
+ state = state.replace(
+ cur_model_output=jax.lax.select_n(
+ state.counter % 4,
+ state.cur_model_output + 1 / 6 * model_output, # remainder 0
+ state.cur_model_output + 1 / 3 * model_output, # remainder 1
+ state.cur_model_output + 1 / 3 * model_output, # remainder 2
+ jnp.zeros_like(state.cur_model_output), # remainder 3
+ ),
+ ets=jax.lax.select(
+ (state.counter % 4) == 0,
+ state.ets.at[0:3].set(state.ets[1:4]).at[3].set(model_output), # remainder 0
+ state.ets, # remainder 1, 2, 3
+ ),
+ cur_sample=jax.lax.select(
+ (state.counter % 4) == 0,
+ sample, # remainder 0
+ state.cur_sample, # remainder 1, 2, 3
+ ),
+ )
+
+ cur_sample = state.cur_sample
+ prev_sample = self._get_prev_sample(state, cur_sample, timestep, prev_timestep, model_output)
+ state = state.replace(counter=state.counter + 1)
+
+ return (prev_sample, state)
+
+ def step_plms(
+ self,
+ state: PNDMSchedulerState,
+ model_output: jnp.ndarray,
+ timestep: int,
+ sample: jnp.ndarray,
+ ) -> Union[FlaxPNDMSchedulerOutput, Tuple]:
+ """
+ Step function propagating the sample with the linear multi-step (PLMS) method. Only one forward pass is
+ needed per step; previously stored model outputs in `ets` are reused to approximate the solution.
+
+ Args:
+ state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance.
+ model_output (`jnp.ndarray`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`jnp.ndarray`):
+ current instance of sample being created by diffusion process.
+
+ Returns:
+ `tuple`: a `(prev_sample, state)` tuple, where `prev_sample` is the sample tensor and `state` is the
+ updated `PNDMSchedulerState`.
+
+ """
+
+ if state.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ # NOTE: There is no way to check in the jitted runtime if the prk mode was run before
+
+ prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps
+ prev_timestep = jnp.where(prev_timestep > 0, prev_timestep, 0)
+
+ # Reference:
+ # if state.counter != 1:
+ # state.ets.append(model_output)
+ # else:
+ # prev_timestep = timestep
+ # timestep = timestep + self.config.num_train_timesteps // state.num_inference_steps
+
+ prev_timestep = jnp.where(state.counter == 1, timestep, prev_timestep)
+ timestep = jnp.where(
+ state.counter == 1, timestep + self.config.num_train_timesteps // state.num_inference_steps, timestep
+ )
+
+ # Reference:
+ # if len(state.ets) == 1 and state.counter == 0:
+ # model_output = model_output
+ # state.cur_sample = sample
+ # elif len(state.ets) == 1 and state.counter == 1:
+ # model_output = (model_output + state.ets[-1]) / 2
+ # sample = state.cur_sample
+ # state.cur_sample = None
+ # elif len(state.ets) == 2:
+ # model_output = (3 * state.ets[-1] - state.ets[-2]) / 2
+ # elif len(state.ets) == 3:
+ # model_output = (23 * state.ets[-1] - 16 * state.ets[-2] + 5 * state.ets[-3]) / 12
+ # else:
+ # model_output = (1 / 24) * (55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4])
+
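+ # Note (added for clarity): the warm-up branches below (counter 0-3) use progressively
+ # higher-order formulas until four past outputs are stored in `ets`; from counter 4 onwards the
+ # update uses the fourth-order Adams-Bashforth weights (55, -59, 37, -9) / 24.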
+ state = state.replace(
+ ets=jax.lax.select(
+ state.counter != 1,
+ state.ets.at[0:3].set(state.ets[1:4]).at[3].set(model_output), # counter != 1
+ state.ets, # counter 1
+ ),
+ cur_sample=jax.lax.select(
+ state.counter != 1,
+ sample, # counter != 1
+ state.cur_sample, # counter 1
+ ),
+ )
+
+ state = state.replace(
+ cur_model_output=jax.lax.select_n(
+ jnp.clip(state.counter, 0, 4),
+ model_output, # counter 0
+ (model_output + state.ets[-1]) / 2, # counter 1
+ (3 * state.ets[-1] - state.ets[-2]) / 2, # counter 2
+ (23 * state.ets[-1] - 16 * state.ets[-2] + 5 * state.ets[-3]) / 12, # counter 3
+ (1 / 24)
+ * (55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4]), # counter >= 4
+ ),
+ )
+
+ sample = state.cur_sample
+ model_output = state.cur_model_output
+ prev_sample = self._get_prev_sample(state, sample, timestep, prev_timestep, model_output)
+ state = state.replace(counter=state.counter + 1)
+
+ return (prev_sample, state)
+
+ def _get_prev_sample(self, state: PNDMSchedulerState, sample, timestep, prev_timestep, model_output):
+ # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
+ # this function computes x_(t−δ) using the formula of (9)
+ # Note that x_t needs to be added to both sides of the equation
+
+ # Notation ( ->
+ # alpha_prod_t -> α_t
+ # alpha_prod_t_prev -> α_(t−δ)
+ # beta_prod_t -> (1 - α_t)
+ # beta_prod_t_prev -> (1 - α_(t−δ))
+ # sample -> x_t
+ # model_output -> e_θ(x_t, t)
+ # prev_sample -> x_(t−δ)
+ alpha_prod_t = state.common.alphas_cumprod[timestep]
+ alpha_prod_t_prev = jnp.where(
+ prev_timestep >= 0, state.common.alphas_cumprod[prev_timestep], state.final_alpha_cumprod
+ )
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+ if self.config.prediction_type == "v_prediction":
+ model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+ elif self.config.prediction_type != "epsilon":
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon` or `v_prediction`"
+ )
+
+ # corresponds to (α_(t−δ) - α_t) divided by
+ # denominator of x_t in formula (9) and plus 1
+ # Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqrt(α_t))) + 1 =
+ # sqrt(α_(t−δ)) / sqrt(α_t)
+ sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** (0.5)
+
+ # corresponds to denominator of e_θ(x_t, t) in formula (9)
+ model_output_denom_coeff = alpha_prod_t * beta_prod_t_prev ** (0.5) + (
+ alpha_prod_t * beta_prod_t * alpha_prod_t_prev
+ ) ** (0.5)
+
+ # full formula (9)
+ prev_sample = (
+ sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * model_output / model_output_denom_coeff
+ )
+
+ return prev_sample
+
+ def add_noise(
+ self,
+ state: PNDMSchedulerState,
+ original_samples: jnp.ndarray,
+ noise: jnp.ndarray,
+ timesteps: jnp.ndarray,
+ ) -> jnp.ndarray:
+ return add_noise_common(state.common, original_samples, noise, timesteps)
+
+ def __len__(self):
+ return self.config.num_train_timesteps
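+
+ # --- Illustrative usage sketch (added; not part of the upstream diffusers file) ---
+ # A minimal sketch of how the stateless Flax API above is typically driven: create a state,
+ # register the inference timesteps for a given sample shape, then thread the state through
+ # `step` in the denoising loop. The zero arrays below are hypothetical stand-ins for real
+ # latents and a real model's noise prediction.
+ if __name__ == "__main__":
+     scheduler = FlaxPNDMScheduler(skip_prk_steps=True)
+     sample_shape = (1, 4, 8, 8)
+
+     state = scheduler.create_state()
+     state = scheduler.set_timesteps(state, num_inference_steps=10, shape=sample_shape)
+
+     sample = jnp.zeros(sample_shape, dtype=jnp.float32)  # stand-in for the initial noisy latents
+     for t in state.timesteps:
+         model_output = jnp.zeros_like(sample)  # stand-in for the model's epsilon prediction
+         output = scheduler.step(state, model_output, t, sample)
+         sample, state = output.prev_sample, output.state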
diff --git a/diffusers/schedulers/scheduling_repaint.py b/diffusers/schedulers/scheduling_repaint.py
new file mode 100644
index 0000000000000000000000000000000000000000..41e7450d2df68c40c3b4f49669513832e443c5e3
--- /dev/null
+++ b/diffusers/schedulers/scheduling_repaint.py
@@ -0,0 +1,344 @@
+# Copyright 2023 ETH Zurich Computer Vision Lab and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import BaseOutput, randn_tensor
+from .scheduling_utils import SchedulerMixin
+
+
+@dataclass
+class RePaintSchedulerOutput(BaseOutput):
+ """
+ Output class for the scheduler's step function output.
+
+ Args:
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ The predicted denoised sample (x_{0}) based on the model output from
+ the current timestep. `pred_original_sample` can be used to preview progress or for guidance.
+ """
+
+ prev_sample: torch.FloatTensor
+ pred_original_sample: torch.FloatTensor
+
+
+# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
+def betas_for_alpha_bar(
+ num_diffusion_timesteps,
+ max_beta=0.999,
+ alpha_transform_type="cosine",
+):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ alpha_transform_type (`str`, *optional*, defaults to `cosine`): the type of noise schedule for alpha_bar.
+ Choose from `cosine` or `exp`
+
+ Returns:
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+ if alpha_transform_type == "cosine":
+
+ def alpha_bar_fn(t):
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ elif alpha_transform_type == "exp":
+
+ def alpha_bar_fn(t):
+ return math.exp(t * -12.0)
+
+ else:
+ raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
+ return torch.tensor(betas, dtype=torch.float32)
+
+
+class RePaintScheduler(SchedulerMixin, ConfigMixin):
+ """
+ RePaint is a scheduler for DDPM inpainting inside a given mask.
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
+
+ For more details, see the original paper: https://arxiv.org/pdf/2201.09865.pdf
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
+ beta_start (`float`): the starting `beta` value of inference.
+ beta_end (`float`): the final `beta` value.
+ beta_schedule (`str`):
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear`, `scaled_linear`, `squaredcos_cap_v2` or `sigmoid`.
+ eta (`float`):
+ The weight of the added noise in a diffusion step. Its value is between 0.0 and 1.0; 0.0 corresponds to
+ the DDIM scheduler and 1.0 to the DDPM scheduler.
+ trained_betas (`np.ndarray`, optional):
+ option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
+ variance_type (`str`):
+ options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
+ `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
+ clip_sample (`bool`, default `True`):
+ option to clip predicted sample between -1 and 1 for numerical stability.
+
+ """
+
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ eta: float = 0.0,
+ trained_betas: Optional[np.ndarray] = None,
+ clip_sample: bool = True,
+ ):
+ if trained_betas is not None:
+ self.betas = torch.from_numpy(trained_betas)
+ elif beta_schedule == "linear":
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ )
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ elif beta_schedule == "sigmoid":
+ # GeoDiff sigmoid schedule
+ betas = torch.linspace(-6, 6, num_train_timesteps)
+ self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
+ else:
+ raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+ self.one = torch.tensor(1.0)
+
+ self.final_alpha_cumprod = torch.tensor(1.0)
+
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = 1.0
+
+ # setable values
+ self.num_inference_steps = None
+ self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
+
+ self.eta = eta
+
+ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+
+ Args:
+ sample (`torch.FloatTensor`): input sample
+ timestep (`int`, optional): current timestep
+
+ Returns:
+ `torch.FloatTensor`: scaled input sample
+ """
+ return sample
+
+ def set_timesteps(
+ self,
+ num_inference_steps: int,
+ jump_length: int = 10,
+ jump_n_sample: int = 10,
+ device: Union[str, torch.device] = None,
+ ):
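+ # Worked example (added for clarity, not from upstream): with num_inference_steps=9,
+ # jump_length=3 and jump_n_sample=2, the loop below produces
+ # t = 8,7,6,5,4,3, 4,5,6, 5,4,3,2,1,0, 1,2,3, 2,1,0
+ # i.e. the schedule walks back up by jump_length steps (the sampling loop re-noises with
+ # `undo_step` during these segments) jump_n_sample - 1 times at each registered jump point,
+ # before the values are rescaled into the training-timestep range.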
+ num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)
+ self.num_inference_steps = num_inference_steps
+
+ timesteps = []
+
+ jumps = {}
+ for j in range(0, num_inference_steps - jump_length, jump_length):
+ jumps[j] = jump_n_sample - 1
+
+ t = num_inference_steps
+ while t >= 1:
+ t = t - 1
+ timesteps.append(t)
+
+ if jumps.get(t, 0) > 0:
+ jumps[t] = jumps[t] - 1
+ for _ in range(jump_length):
+ t = t + 1
+ timesteps.append(t)
+
+ timesteps = np.array(timesteps) * (self.config.num_train_timesteps // self.num_inference_steps)
+ self.timesteps = torch.from_numpy(timesteps).to(device)
+
+ def _get_variance(self, t):
+ prev_timestep = t - self.config.num_train_timesteps // self.num_inference_steps
+
+ alpha_prod_t = self.alphas_cumprod[t]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+ # For t > 0, compute predicted variance βt (see formula (6) and (7) from
+ # https://arxiv.org/pdf/2006.11239.pdf) and sample from it to get
+ # previous sample x_{t-1} ~ N(pred_prev_sample, variance) == add
+ # variance to pred_sample
+ # Is equivalent to formula (16) in https://arxiv.org/pdf/2010.02502.pdf
+ # without eta.
+ # variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[t]
+ variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
+
+ return variance
+
+ def step(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: int,
+ sample: torch.FloatTensor,
+ original_image: torch.FloatTensor,
+ mask: torch.FloatTensor,
+ generator: Optional[torch.Generator] = None,
+ return_dict: bool = True,
+ ) -> Union[RePaintSchedulerOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.FloatTensor`): direct output from learned
+ diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+ original_image (`torch.FloatTensor`):
+ the original image to inpaint on.
+ mask (`torch.FloatTensor`):
+ the mask where 0.0 values define which part of the original image to inpaint (change).
+ generator (`torch.Generator`, *optional*): random number generator.
+ return_dict (`bool`): option for returning tuple rather than
+ RePaintSchedulerOutput class
+
+ Returns:
+ [`~schedulers.scheduling_utils.RePaintSchedulerOutput`] or `tuple`:
+ [`~schedulers.scheduling_utils.RePaintSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
+
+ """
+ t = timestep
+ prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
+
+ # 1. compute alphas, betas
+ alpha_prod_t = self.alphas_cumprod[t]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+ beta_prod_t = 1 - alpha_prod_t
+
+ # 2. compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
+ pred_original_sample = (sample - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
+
+ # 3. Clip "predicted x_0"
+ if self.config.clip_sample:
+ pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
+
+ # We choose to follow RePaint Algorithm 1 to get x_{t-1}, however we
+ # substitute formula (7) in the algorithm coming from DDPM paper
+ # (formula (4) Algorithm 2 - Sampling) with formula (12) from DDIM paper.
+ # DDIM schedule gives the same results as DDPM with eta = 1.0
+ # Noise is being reused in 7. and 8., but no impact on quality has
+ # been observed.
+
+ # 5. Add noise
+ device = model_output.device
+ noise = randn_tensor(model_output.shape, generator=generator, device=device, dtype=model_output.dtype)
+ std_dev_t = self.eta * self._get_variance(timestep) ** 0.5
+
+ variance = 0
+ if t > 0 and self.eta > 0:
+ variance = std_dev_t * noise
+
+ # 6. compute "direction pointing to x_t" of formula (12)
+ # from https://arxiv.org/pdf/2010.02502.pdf
+ pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** 0.5 * model_output
+
+ # 7. compute x_{t-1} of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ prev_unknown_part = alpha_prod_t_prev**0.5 * pred_original_sample + pred_sample_direction + variance
+
+ # 8. Algorithm 1 Line 5 https://arxiv.org/pdf/2201.09865.pdf
+ prev_known_part = (alpha_prod_t_prev**0.5) * original_image + ((1 - alpha_prod_t_prev) ** 0.5) * noise
+
+ # 9. Algorithm 1 Line 8 https://arxiv.org/pdf/2201.09865.pdf
+ pred_prev_sample = mask * prev_known_part + (1.0 - mask) * prev_unknown_part
+
+ if not return_dict:
+ return (
+ pred_prev_sample,
+ pred_original_sample,
+ )
+
+ return RePaintSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)
+
+ def undo_step(self, sample, timestep, generator=None):
+ n = self.config.num_train_timesteps // self.num_inference_steps
+
+ for i in range(n):
+ beta = self.betas[timestep + i]
+ if sample.device.type == "mps":
+ # randn does not work reproducibly on mps
+ noise = randn_tensor(sample.shape, dtype=sample.dtype, generator=generator)
+ noise = noise.to(sample.device)
+ else:
+ noise = randn_tensor(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype)
+
+ # 10. Algorithm 1 Line 10 https://arxiv.org/pdf/2201.09865.pdf
+ sample = (1 - beta) ** 0.5 * sample + beta**0.5 * noise
+
+ return sample
+
+ def add_noise(
+ self,
+ original_samples: torch.FloatTensor,
+ noise: torch.FloatTensor,
+ timesteps: torch.IntTensor,
+ ) -> torch.FloatTensor:
+ raise NotImplementedError("Use `DDPMScheduler.add_noise()` to train for sampling with RePaint.")
+
+ def __len__(self):
+ return self.config.num_train_timesteps
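+
+ # --- Illustrative usage sketch (added; not part of the upstream diffusers file) ---
+ # Rough outline of the RePaint sampling loop: denoise with `step` and, whenever the timestep
+ # schedule jumps back up, re-noise the running sample with `undo_step`. `original_image`,
+ # `mask` and `fake_denoiser` are hypothetical stand-ins for a real image, mask and UNet.
+ if __name__ == "__main__":
+     scheduler = RePaintScheduler()
+     scheduler.set_timesteps(num_inference_steps=250, jump_length=10, jump_n_sample=10)
+
+     original_image = torch.zeros(1, 3, 64, 64)  # image to inpaint (stand-in)
+     mask = torch.ones_like(original_image)  # 1.0 = keep original content
+     mask[..., 16:48, 16:48] = 0.0  # 0.0 = region to inpaint
+     sample = torch.randn_like(original_image)
+
+     def fake_denoiser(x, t):  # stand-in for a trained noise-prediction model
+         return torch.zeros_like(x)
+
+     t_last = scheduler.timesteps[0] + 1
+     for t in scheduler.timesteps:
+         if t < t_last:  # Algorithm 1, lines 4-8: one denoising step
+             model_output = fake_denoiser(sample, t)
+             sample = scheduler.step(model_output, t, sample, original_image, mask).prev_sample
+         else:  # Algorithm 1, line 10: the schedule jumped back up, re-noise
+             sample = scheduler.undo_step(sample, t_last)
+         t_last = t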
diff --git a/diffusers/schedulers/scheduling_sde_ve.py b/diffusers/schedulers/scheduling_sde_ve.py
new file mode 100644
index 0000000000000000000000000000000000000000..339edfbb02eb6ac0f79b3969004418bb29e212b5
--- /dev/null
+++ b/diffusers/schedulers/scheduling_sde_ve.py
@@ -0,0 +1,288 @@
+# Copyright 2023 Google Brain and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
+
+import math
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import BaseOutput, randn_tensor
+from .scheduling_utils import SchedulerMixin, SchedulerOutput
+
+
+@dataclass
+class SdeVeOutput(BaseOutput):
+ """
+ Output class for the ScoreSdeVeScheduler's step function output.
+
+ Args:
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ prev_sample_mean (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Mean averaged `prev_sample`. Same as `prev_sample`, only mean-averaged over previous timesteps.
+ """
+
+ prev_sample: torch.FloatTensor
+ prev_sample_mean: torch.FloatTensor
+
+
+class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
+ """
+ The variance exploding stochastic differential equation (SDE) scheduler.
+
+ For more information, see the original paper: https://arxiv.org/abs/2011.13456
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
+ snr (`float`):
+ coefficient weighting the step from the model_output sample (from the network) to the random noise.
+ sigma_min (`float`):
+ initial noise scale for sigma sequence in sampling procedure. The minimum sigma should mirror the
+ distribution of the data.
+ sigma_max (`float`): maximum value used for the range of continuous timesteps passed into the model.
+ sampling_eps (`float`): the end value of sampling, where timesteps decrease progressively from 1 to
+ epsilon.
+ correct_steps (`int`): number of correction steps performed on a produced sample.
+ """
+
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 2000,
+ snr: float = 0.15,
+ sigma_min: float = 0.01,
+ sigma_max: float = 1348.0,
+ sampling_eps: float = 1e-5,
+ correct_steps: int = 1,
+ ):
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = sigma_max
+
+ # setable values
+ self.timesteps = None
+
+ self.set_sigmas(num_train_timesteps, sigma_min, sigma_max, sampling_eps)
+
+ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+
+ Args:
+ sample (`torch.FloatTensor`): input sample
+ timestep (`int`, optional): current timestep
+
+ Returns:
+ `torch.FloatTensor`: scaled input sample
+ """
+ return sample
+
+ def set_timesteps(
+ self, num_inference_steps: int, sampling_eps: float = None, device: Union[str, torch.device] = None
+ ):
+ """
+ Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ sampling_eps (`float`, optional):
+ final timestep value (overrides value given at Scheduler instantiation).
+
+ """
+ sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
+
+ self.timesteps = torch.linspace(1, sampling_eps, num_inference_steps, device=device)
+
+ def set_sigmas(
+ self, num_inference_steps: int, sigma_min: float = None, sigma_max: float = None, sampling_eps: float = None
+ ):
+ """
+ Sets the noise scales used for the diffusion chain. Supporting function to be run before inference.
+
+ The sigmas control the weight of the `drift` and `diffusion` components of sample update.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ sigma_min (`float`, optional):
+ initial noise scale value (overrides value given at Scheduler instantiation).
+ sigma_max (`float`, optional):
+ final noise scale value (overrides value given at Scheduler instantiation).
+ sampling_eps (`float`, optional):
+ final timestep value (overrides value given at Scheduler instantiation).
+
+ """
+ sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min
+ sigma_max = sigma_max if sigma_max is not None else self.config.sigma_max
+ sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
+ if self.timesteps is None:
+ self.set_timesteps(num_inference_steps, sampling_eps)
+
+ self.discrete_sigmas = torch.exp(torch.linspace(math.log(sigma_min), math.log(sigma_max), num_inference_steps))
+ self.sigmas = torch.tensor([sigma_min * (sigma_max / sigma_min) ** t for t in self.timesteps])
+
+ def get_adjacent_sigma(self, timesteps, t):
+ return torch.where(
+ timesteps == 0,
+ torch.zeros_like(t.to(timesteps.device)),
+ self.discrete_sigmas[timesteps - 1].to(timesteps.device),
+ )
+
+ def step_pred(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: int,
+ sample: torch.FloatTensor,
+ generator: Optional[torch.Generator] = None,
+ return_dict: bool = True,
+ ) -> Union[SdeVeOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+ generator: random number generator.
+ return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+ Returns:
+ [`~schedulers.scheduling_sde_ve.SdeVeOutput`] or `tuple`: [`~schedulers.scheduling_sde_ve.SdeVeOutput`] if
+ `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
+
+ """
+ if self.timesteps is None:
+ raise ValueError(
+ "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ timestep = timestep * torch.ones(
+ sample.shape[0], device=sample.device
+ ) # torch.repeat_interleave(timestep, sample.shape[0])
+ timesteps = (timestep * (len(self.timesteps) - 1)).long()
+
+ # mps requires the indices and the indexed tensor to be on the same device, so we move the
+ # indices to the device of `discrete_sigmas` (cpu by default)
+ timesteps = timesteps.to(self.discrete_sigmas.device)
+
+ sigma = self.discrete_sigmas[timesteps].to(sample.device)
+ adjacent_sigma = self.get_adjacent_sigma(timesteps, timestep).to(sample.device)
+ drift = torch.zeros_like(sample)
+ diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5
+
+ # equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x)
+ # also equation 47 shows the analog from SDE models to ancestral sampling methods
+ diffusion = diffusion.flatten()
+ while len(diffusion.shape) < len(sample.shape):
+ diffusion = diffusion.unsqueeze(-1)
+ drift = drift - diffusion**2 * model_output
+
+ # equation 6: sample noise for the diffusion term of the SDE
+ noise = randn_tensor(
+ sample.shape, layout=sample.layout, generator=generator, device=sample.device, dtype=sample.dtype
+ )
+ prev_sample_mean = sample - drift # subtract because `dt` is a small negative timestep
+ # TODO is the variable diffusion the correct scaling term for the noise?
+ prev_sample = prev_sample_mean + diffusion * noise # add impact of diffusion field g
+
+ if not return_dict:
+ return (prev_sample, prev_sample_mean)
+
+ return SdeVeOutput(prev_sample=prev_sample, prev_sample_mean=prev_sample_mean)
+
+ def step_correct(
+ self,
+ model_output: torch.FloatTensor,
+ sample: torch.FloatTensor,
+ generator: Optional[torch.Generator] = None,
+ return_dict: bool = True,
+ ) -> Union[SchedulerOutput, Tuple]:
+ """
+ Correct the predicted sample based on the output model_output of the network. This is often run repeatedly
+ after making the prediction for the previous timestep.
+
+ Args:
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+ generator: random number generator.
+ return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+ Returns:
+ [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: [`~schedulers.scheduling_utils.SchedulerOutput`] if
+ `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
+
+ """
+ if self.timesteps is None:
+ raise ValueError(
+ "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ # For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z"
+ # sample noise for correction
+ noise = randn_tensor(sample.shape, layout=sample.layout, generator=generator).to(sample.device)
+
+ # compute step size from the model_output, the noise, and the snr
+ grad_norm = torch.norm(model_output.reshape(model_output.shape[0], -1), dim=-1).mean()
+ noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
+ step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2
+ step_size = step_size * torch.ones(sample.shape[0]).to(sample.device)
+ # self.repeat_scalar(step_size, sample.shape[0])
+
+ # compute corrected sample: model_output term and noise term
+ step_size = step_size.flatten()
+ while len(step_size.shape) < len(sample.shape):
+ step_size = step_size.unsqueeze(-1)
+ prev_sample_mean = sample + step_size * model_output
+ prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5) * noise
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return SchedulerOutput(prev_sample=prev_sample)
+
+ def add_noise(
+ self,
+ original_samples: torch.FloatTensor,
+ noise: torch.FloatTensor,
+ timesteps: torch.FloatTensor,
+ ) -> torch.FloatTensor:
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
+ timesteps = timesteps.to(original_samples.device)
+ sigmas = self.discrete_sigmas.to(original_samples.device)[timesteps]
+ noise = (
+ noise * sigmas[:, None, None, None]
+ if noise is not None
+ else torch.randn_like(original_samples) * sigmas[:, None, None, None]
+ )
+ noisy_samples = noise + original_samples
+ return noisy_samples
+
+ def __len__(self):
+ return self.config.num_train_timesteps
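+
+ # --- Illustrative usage sketch (added; not part of the upstream diffusers file) ---
+ # Minimal predictor-corrector loop: a few Langevin `step_correct` updates per timestep followed
+ # by one reverse-diffusion `step_pred` update. `fake_score_model` is a hypothetical stand-in
+ # for a trained score network.
+ if __name__ == "__main__":
+     scheduler = ScoreSdeVeScheduler()
+     scheduler.set_timesteps(num_inference_steps=100)
+     scheduler.set_sigmas(num_inference_steps=100)
+
+     sample = torch.randn(1, 3, 32, 32) * scheduler.init_noise_sigma
+
+     def fake_score_model(x, t):  # crude stand-in for grad_x log p_t(x)
+         return -x / scheduler.init_noise_sigma**2
+
+     for t in scheduler.timesteps:
+         for _ in range(scheduler.config.correct_steps):
+             score = fake_score_model(sample, t)
+             sample = scheduler.step_correct(score, sample).prev_sample
+         score = fake_score_model(sample, t)
+         sample = scheduler.step_pred(score, t, sample).prev_sample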
diff --git a/diffusers/schedulers/scheduling_sde_ve_flax.py b/diffusers/schedulers/scheduling_sde_ve_flax.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6240559fc88fa45e4612dc3005ba66e10d3269d
--- /dev/null
+++ b/diffusers/schedulers/scheduling_sde_ve_flax.py
@@ -0,0 +1,279 @@
+# Copyright 2023 Google Brain and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
+
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import flax
+import jax.numpy as jnp
+from jax import random
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
+
+
+@flax.struct.dataclass
+class ScoreSdeVeSchedulerState:
+ # setable values
+ timesteps: Optional[jnp.ndarray] = None
+ discrete_sigmas: Optional[jnp.ndarray] = None
+ sigmas: Optional[jnp.ndarray] = None
+
+ @classmethod
+ def create(cls):
+ return cls()
+
+
+@dataclass
+class FlaxSdeVeOutput(FlaxSchedulerOutput):
+ """
+ Output class for the ScoreSdeVeScheduler's step function output.
+
+ Args:
+ state (`ScoreSdeVeSchedulerState`):
+ prev_sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ prev_sample_mean (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images):
+ Mean averaged `prev_sample`. Same as `prev_sample`, only mean-averaged over previous timesteps.
+ """
+
+ state: ScoreSdeVeSchedulerState
+ prev_sample: jnp.ndarray
+ prev_sample_mean: Optional[jnp.ndarray] = None
+
+
+class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
+ """
+ The variance exploding stochastic differential equation (SDE) scheduler.
+
+ For more information, see the original paper: https://arxiv.org/abs/2011.13456
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
+ snr (`float`):
+ coefficient weighting the step from the model_output sample (from the network) to the random noise.
+ sigma_min (`float`):
+ initial noise scale for sigma sequence in sampling procedure. The minimum sigma should mirror the
+ distribution of the data.
+ sigma_max (`float`): maximum value used for the range of continuous timesteps passed into the model.
+ sampling_eps (`float`): the end value of sampling, where timesteps decrease progressively from 1 to
+ epsilon.
+ correct_steps (`int`): number of correction steps performed on a produced sample.
+ """
+
+ @property
+ def has_state(self):
+ return True
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 2000,
+ snr: float = 0.15,
+ sigma_min: float = 0.01,
+ sigma_max: float = 1348.0,
+ sampling_eps: float = 1e-5,
+ correct_steps: int = 1,
+ ):
+ pass
+
+ def create_state(self):
+ state = ScoreSdeVeSchedulerState.create()
+ return self.set_sigmas(
+ state,
+ self.config.num_train_timesteps,
+ self.config.sigma_min,
+ self.config.sigma_max,
+ self.config.sampling_eps,
+ )
+
+ def set_timesteps(
+ self, state: ScoreSdeVeSchedulerState, num_inference_steps: int, shape: Tuple = (), sampling_eps: float = None
+ ) -> ScoreSdeVeSchedulerState:
+ """
+ Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance.
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ sampling_eps (`float`, optional):
+ final timestep value (overrides value given at Scheduler instantiation).
+
+ """
+ sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
+
+ timesteps = jnp.linspace(1, sampling_eps, num_inference_steps)
+ return state.replace(timesteps=timesteps)
+
+ def set_sigmas(
+ self,
+ state: ScoreSdeVeSchedulerState,
+ num_inference_steps: int,
+ sigma_min: float = None,
+ sigma_max: float = None,
+ sampling_eps: float = None,
+ ) -> ScoreSdeVeSchedulerState:
+ """
+ Sets the noise scales used for the diffusion chain. Supporting function to be run before inference.
+
+ The sigmas control the weight of the `drift` and `diffusion` components of sample update.
+
+ Args:
+ state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance.
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ sigma_min (`float`, optional):
+ initial noise scale value (overrides value given at Scheduler instantiation).
+ sigma_max (`float`, optional):
+ final noise scale value (overrides value given at Scheduler instantiation).
+ sampling_eps (`float`, optional):
+ final timestep value (overrides value given at Scheduler instantiation).
+ """
+ sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min
+ sigma_max = sigma_max if sigma_max is not None else self.config.sigma_max
+ sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
+ if state.timesteps is None:
+ state = self.set_timesteps(state, num_inference_steps, sampling_eps)
+
+ discrete_sigmas = jnp.exp(jnp.linspace(jnp.log(sigma_min), jnp.log(sigma_max), num_inference_steps))
+ sigmas = jnp.array([sigma_min * (sigma_max / sigma_min) ** t for t in state.timesteps])
+
+ return state.replace(discrete_sigmas=discrete_sigmas, sigmas=sigmas)
+
+ def get_adjacent_sigma(self, state, timesteps, t):
+ return jnp.where(timesteps == 0, jnp.zeros_like(t), state.discrete_sigmas[timesteps - 1])
+
+ def step_pred(
+ self,
+ state: ScoreSdeVeSchedulerState,
+ model_output: jnp.ndarray,
+ timestep: int,
+ sample: jnp.ndarray,
+ key: random.KeyArray,
+ return_dict: bool = True,
+ ) -> Union[FlaxSdeVeOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance.
+ model_output (`jnp.ndarray`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`jnp.ndarray`):
+ current instance of sample being created by diffusion process.
+ key (`random.KeyArray`): a PRNG key used to sample the noise.
+ return_dict (`bool`): option for returning tuple rather than FlaxSdeVeOutput class
+
+ Returns:
+ [`FlaxSdeVeOutput`] or `tuple`: [`FlaxSdeVeOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
+
+ """
+ if state.timesteps is None:
+ raise ValueError(
+ "`state.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ timestep = timestep * jnp.ones(
+ sample.shape[0],
+ )
+ timesteps = (timestep * (len(state.timesteps) - 1)).astype(jnp.int32)  # jnp arrays have no `.long()`
+
+ sigma = state.discrete_sigmas[timesteps]
+ adjacent_sigma = self.get_adjacent_sigma(state, timesteps, timestep)
+ drift = jnp.zeros_like(sample)
+ diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5
+
+ # equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x)
+ # also equation 47 shows the analog from SDE models to ancestral sampling methods
+ diffusion = diffusion.flatten()
+ diffusion = broadcast_to_shape_from_left(diffusion, sample.shape)
+ drift = drift - diffusion**2 * model_output
+
+ # equation 6: sample noise for the diffusion term of the SDE
+ key = random.split(key, num=1)
+ noise = random.normal(key=key, shape=sample.shape)
+ prev_sample_mean = sample - drift # subtract because `dt` is a small negative timestep
+ # TODO is the variable diffusion the correct scaling term for the noise?
+ prev_sample = prev_sample_mean + diffusion * noise # add impact of diffusion field g
+
+ if not return_dict:
+ return (prev_sample, prev_sample_mean, state)
+
+ return FlaxSdeVeOutput(prev_sample=prev_sample, prev_sample_mean=prev_sample_mean, state=state)
+
+ def step_correct(
+ self,
+ state: ScoreSdeVeSchedulerState,
+ model_output: jnp.ndarray,
+ sample: jnp.ndarray,
+ key: random.KeyArray,
+ return_dict: bool = True,
+ ) -> Union[FlaxSdeVeOutput, Tuple]:
+ """
+ Correct the predicted sample based on the output model_output of the network. This is often run repeatedly
+ after making the prediction for the previous timestep.
+
+ Args:
+ state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance.
+ model_output (`jnp.ndarray`): direct output from learned diffusion model.
+ sample (`jnp.ndarray`):
+ current instance of sample being created by diffusion process.
+ key (`random.KeyArray`): a PRNG key used to sample the noise.
+ return_dict (`bool`): option for returning tuple rather than FlaxSdeVeOutput class
+
+ Returns:
+ [`FlaxSdeVeOutput`] or `tuple`: [`FlaxSdeVeOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
+
+ """
+ if state.timesteps is None:
+ raise ValueError(
+ "`state.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ # For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z"
+ # sample noise for correction
+ key = random.split(key, num=1)
+ noise = random.normal(key=key, shape=sample.shape)
+
+ # compute step size from the model_output, the noise, and the snr
+ grad_norm = jnp.linalg.norm(model_output)
+ noise_norm = jnp.linalg.norm(noise)
+ step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2
+ step_size = step_size * jnp.ones(sample.shape[0])
+
+ # compute corrected sample: model_output term and noise term
+ step_size = step_size.flatten()
+ step_size = broadcast_to_shape_from_left(step_size, sample.shape)
+ prev_sample_mean = sample + step_size * model_output
+ prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5) * noise
+
+ if not return_dict:
+ return (prev_sample, state)
+
+ return FlaxSdeVeOutput(prev_sample=prev_sample, state=state)
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/diffusers/schedulers/scheduling_sde_vp.py b/diffusers/schedulers/scheduling_sde_vp.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e2ead90edb57cd1eb1d270695e222d404064180
--- /dev/null
+++ b/diffusers/schedulers/scheduling_sde_vp.py
@@ -0,0 +1,90 @@
+# Copyright 2023 Google Brain and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
+
+import math
+from typing import Union
+
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import randn_tensor
+from .scheduling_utils import SchedulerMixin
+
+
+class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
+ """
+ The variance preserving stochastic differential equation (SDE) scheduler.
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
+
+ For more information, see the original paper: https://arxiv.org/abs/2011.13456
+
+ UNDER CONSTRUCTION
+
+ """
+
+ order = 1
+
+ @register_to_config
+ def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3):
+ self.sigmas = None
+ self.discrete_sigmas = None
+ self.timesteps = None
+
+ def set_timesteps(self, num_inference_steps, device: Union[str, torch.device] = None):
+ self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps, device=device)
+
+ def step_pred(self, score, x, t, generator=None):
+ if self.timesteps is None:
+ raise ValueError(
+ "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ # TODO(Patrick) better comments + non-PyTorch
+ # postprocess model score
+ log_mean_coeff = (
+ -0.25 * t**2 * (self.config.beta_max - self.config.beta_min) - 0.5 * t * self.config.beta_min
+ )
+ std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff))
+ std = std.flatten()
+ while len(std.shape) < len(score.shape):
+ std = std.unsqueeze(-1)
+ score = -score / std
+
+ # compute
+ dt = -1.0 / len(self.timesteps)
+
+ beta_t = self.config.beta_min + t * (self.config.beta_max - self.config.beta_min)
+ beta_t = beta_t.flatten()
+ while len(beta_t.shape) < len(x.shape):
+ beta_t = beta_t.unsqueeze(-1)
+ drift = -0.5 * beta_t * x
+
+ diffusion = torch.sqrt(beta_t)
+ drift = drift - diffusion**2 * score
+ x_mean = x + drift * dt
+
+ # add noise
+ noise = randn_tensor(x.shape, layout=x.layout, generator=generator, device=x.device, dtype=x.dtype)
+ x = x_mean + diffusion * math.sqrt(-dt) * noise
+
+ return x, x_mean
+
+ def __len__(self):
+ return self.config.num_train_timesteps
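+
+ # --- Illustrative usage sketch (added; not part of the upstream diffusers file) ---
+ # The class is marked UNDER CONSTRUCTION above, so this is only a rough outline of how
+ # `step_pred` consumes a score estimate. `fake_score_model` is a hypothetical stand-in.
+ if __name__ == "__main__":
+     scheduler = ScoreSdeVpScheduler()
+     scheduler.set_timesteps(num_inference_steps=1000)
+
+     x = torch.randn(1, 3, 32, 32)
+
+     def fake_score_model(x, t):  # stand-in for a trained score network
+         return torch.zeros_like(x)
+
+     for t in scheduler.timesteps:
+         score = fake_score_model(x, t)
+         x, x_mean = scheduler.step_pred(score, x, t)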
diff --git a/diffusers/schedulers/scheduling_unclip.py b/diffusers/schedulers/scheduling_unclip.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd23e48bad00d16a1086f31b6584ff9df03129fb
--- /dev/null
+++ b/diffusers/schedulers/scheduling_unclip.py
@@ -0,0 +1,348 @@
+# Copyright 2023 Kakao Brain and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import BaseOutput, randn_tensor
+from .scheduling_utils import SchedulerMixin
+
+
+@dataclass
+# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->UnCLIP
+class UnCLIPSchedulerOutput(BaseOutput):
+ """
+ Output class for the scheduler's step function output.
+
+ Args:
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ The predicted denoised sample (x_{0}) based on the model output from the current timestep.
+ `pred_original_sample` can be used to preview progress or for guidance.
+ """
+
+ prev_sample: torch.FloatTensor
+ pred_original_sample: Optional[torch.FloatTensor] = None
+
+
+# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
+def betas_for_alpha_bar(
+ num_diffusion_timesteps,
+ max_beta=0.999,
+ alpha_transform_type="cosine",
+):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ alpha_transform_type (`str`, *optional*, defaults to `cosine`): the type of noise schedule for alpha_bar.
+ Choose from `cosine` or `exp`
+
+ Returns:
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+ if alpha_transform_type == "cosine":
+
+ def alpha_bar_fn(t):
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ elif alpha_transform_type == "exp":
+
+ def alpha_bar_fn(t):
+ return math.exp(t * -12.0)
+
+ else:
+ raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
+ return torch.tensor(betas, dtype=torch.float32)
+
+
+class UnCLIPScheduler(SchedulerMixin, ConfigMixin):
+ """
+ NOTE: do not use this scheduler. The DDPM scheduler has been updated to support the changes made here. This
+ scheduler will be removed and replaced with DDPM.
+
+ This is a modified DDPM Scheduler specifically for the karlo unCLIP model.
+
+ This scheduler has some minor variations in how it calculates the learned range variance and dynamically
+ re-calculates betas based off the timesteps it is skipping.
+
+ The scheduler also uses a slightly different step ratio when computing timesteps to use for inference.
+
+ See [`~DDPMScheduler`] for more information on DDPM scheduling
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
+ variance_type (`str`):
+ options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small_log`
+ or `learned_range`.
+ clip_sample (`bool`, default `True`):
+ option to clip predicted sample between `-clip_sample_range` and `clip_sample_range` for numerical
+ stability.
+ clip_sample_range (`float`, default `1.0`):
+ The range to clip the sample between. See `clip_sample`.
+ prediction_type (`str`, default `epsilon`, optional):
+ prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process)
+ or `sample` (directly predicting the noisy sample)
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ variance_type: str = "fixed_small_log",
+ clip_sample: bool = True,
+ clip_sample_range: Optional[float] = 1.0,
+ prediction_type: str = "epsilon",
+ beta_schedule: str = "squaredcos_cap_v2",
+ ):
+ if beta_schedule != "squaredcos_cap_v2":
+ raise ValueError("UnCLIPScheduler only supports `beta_schedule`: 'squaredcos_cap_v2'")
+
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+ self.one = torch.tensor(1.0)
+
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = 1.0
+
+ # setable values
+ self.num_inference_steps = None
+ self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
+
+ self.variance_type = variance_type
+
+ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+
+ Args:
+ sample (`torch.FloatTensor`): input sample
+ timestep (`int`, optional): current timestep
+
+ Returns:
+ `torch.FloatTensor`: scaled input sample
+ """
+ return sample
+
+ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+ """
+ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Note that this scheduler uses a slightly different step ratio than the other diffusers schedulers. The
+ different step ratio is to mimic the original karlo implementation and does not affect the quality or accuracy
+ of the results.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ """
+ self.num_inference_steps = num_inference_steps
+ step_ratio = (self.config.num_train_timesteps - 1) / (self.num_inference_steps - 1)
+ timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
+ self.timesteps = torch.from_numpy(timesteps).to(device)
+
+ def _get_variance(self, t, prev_timestep=None, predicted_variance=None, variance_type=None):
+ if prev_timestep is None:
+ prev_timestep = t - 1
+
+ alpha_prod_t = self.alphas_cumprod[t]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.one
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+ if prev_timestep == t - 1:
+ beta = self.betas[t]
+ else:
+ beta = 1 - alpha_prod_t / alpha_prod_t_prev
+
+ # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
+ # and sample from it to get previous sample
+ # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
+ variance = beta_prod_t_prev / beta_prod_t * beta
+
+ if variance_type is None:
+ variance_type = self.config.variance_type
+
+ # hacks - were probably added for training stability
+ if variance_type == "fixed_small_log":
+ variance = torch.log(torch.clamp(variance, min=1e-20))
+ variance = torch.exp(0.5 * variance)
+ elif variance_type == "learned_range":
+ # NOTE difference with DDPM scheduler
+ min_log = variance.log()
+ max_log = beta.log()
+
+ frac = (predicted_variance + 1) / 2
+ variance = frac * max_log + (1 - frac) * min_log
+
+ return variance
+
+ def step(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: int,
+ sample: torch.FloatTensor,
+ prev_timestep: Optional[int] = None,
+ generator=None,
+ return_dict: bool = True,
+ ) -> Union[UnCLIPSchedulerOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+ prev_timestep (`int`, *optional*): The previous timestep to predict the previous sample at.
+ Used to dynamically compute beta. If not given, `t-1` is used together with the pre-computed beta.
+ generator: random number generator.
+ return_dict (`bool`): option for returning tuple rather than UnCLIPSchedulerOutput class
+
+ Returns:
+ [`~schedulers.scheduling_utils.UnCLIPSchedulerOutput`] or `tuple`:
+ [`~schedulers.scheduling_utils.UnCLIPSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
+
+ """
+ t = timestep
+
+ if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type == "learned_range":
+ model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
+ else:
+ predicted_variance = None
+
+ # 1. compute alphas, betas
+ if prev_timestep is None:
+ prev_timestep = t - 1
+
+ alpha_prod_t = self.alphas_cumprod[t]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.one
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+ if prev_timestep == t - 1:
+ beta = self.betas[t]
+ alpha = self.alphas[t]
+ else:
+ beta = 1 - alpha_prod_t / alpha_prod_t_prev
+ alpha = 1 - beta
+
+ # 2. compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
+ if self.config.prediction_type == "epsilon":
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+ elif self.config.prediction_type == "sample":
+ pred_original_sample = model_output
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon` or `sample`"
+ " for the UnCLIPScheduler."
+ )
+
+ # 3. Clip "predicted x_0"
+ if self.config.clip_sample:
+ pred_original_sample = torch.clamp(
+ pred_original_sample, -self.config.clip_sample_range, self.config.clip_sample_range
+ )
+
+ # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
+ # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+ pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * beta) / beta_prod_t
+ current_sample_coeff = alpha ** (0.5) * beta_prod_t_prev / beta_prod_t
+
+ # 5. Compute predicted previous sample µ_t
+ # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+ pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample
+
+ # 6. Add noise
+ variance = 0
+ if t > 0:
+ variance_noise = randn_tensor(
+ model_output.shape, dtype=model_output.dtype, generator=generator, device=model_output.device
+ )
+
+ variance = self._get_variance(
+ t,
+ predicted_variance=predicted_variance,
+ prev_timestep=prev_timestep,
+ )
+
+ if self.variance_type == "fixed_small_log":
+ variance = variance
+ elif self.variance_type == "learned_range":
+ variance = (0.5 * variance).exp()
+ else:
+ raise ValueError(
+ f"variance_type given as {self.variance_type} must be one of `fixed_small_log` or `learned_range`"
+ " for the UnCLIPScheduler."
+ )
+
+ variance = variance * variance_noise
+
+ pred_prev_sample = pred_prev_sample + variance
+
+ if not return_dict:
+ return (pred_prev_sample,)
+
+ return UnCLIPSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
+ def add_noise(
+ self,
+ original_samples: torch.FloatTensor,
+ noise: torch.FloatTensor,
+ timesteps: torch.IntTensor,
+ ) -> torch.FloatTensor:
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
+ alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+ timesteps = timesteps.to(original_samples.device)
+
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+ return noisy_samples
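To make the control flow above concrete, here is a minimal, hedged sketch of a denoising loop against this scheduler's API. `decoder` is a hypothetical callable standing in for the karlo unCLIP decoder (it is not part of this diff); with `variance_type="fixed_small_log"` it is assumed to return only the predicted noise:

```python
import torch

scheduler = UnCLIPScheduler(variance_type="fixed_small_log")
scheduler.set_timesteps(num_inference_steps=25, device="cpu")

# Start from pure Gaussian noise scaled by the initial noise sigma (1.0 here).
sample = torch.randn(1, 3, 64, 64) * scheduler.init_noise_sigma
for t in scheduler.timesteps:
    model_output = decoder(sample, t)  # hypothetical model; returns predicted epsilon
    sample = scheduler.step(model_output, t, sample).prev_sample
```

With `variance_type="learned_range"` the model output would instead carry twice the sample's channels (noise plus a variance prediction), which `step` splits as shown above.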
diff --git a/diffusers/schedulers/scheduling_unipc_multistep.py b/diffusers/schedulers/scheduling_unipc_multistep.py
new file mode 100644
index 0000000000000000000000000000000000000000..3caa01a58562f5f12d46354ef6112a64875da79d
--- /dev/null
+++ b/diffusers/schedulers/scheduling_unipc_multistep.py
@@ -0,0 +1,681 @@
+# Copyright 2023 TSAIL Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: check https://arxiv.org/abs/2302.04867 and https://github.com/wl-zhao/UniPC for more info
+# The codebase is modified based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
+
+import math
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+
+ Returns:
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+
+ def alpha_bar(time_step):
+ return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return torch.tensor(betas, dtype=torch.float32)
+
+
+class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
+ """
+ UniPC is a training-free framework designed for the fast sampling of diffusion models. It consists of a
+ corrector (UniC) and a predictor (UniP) that share a unified analytical form and support arbitrary orders. UniPC is
+ model-agnostic by design, supporting pixel-space/latent-space DPMs for both unconditional and conditional sampling.
+ It can be applied to both noise prediction and data prediction models. The corrector UniC can also be applied
+ after any off-the-shelf solver to increase the order of accuracy.
+
+ For more details, see the original paper: https://arxiv.org/abs/2302.04867
+
+ Currently, we support the multistep UniPC for both noise prediction models and data prediction models. We recommend
+ `solver_order=2` for guided sampling and `solver_order=3` for unconditional sampling.
+
+ We also support the "dynamic thresholding" method in Imagen (https://arxiv.org/abs/2205.11487). For pixel-space
+ diffusion models, you can set both `predict_x0=True` and `thresholding=True` to use the dynamic thresholding. Note
+ that the thresholding method is unsuitable for latent-space diffusion models (such as stable-diffusion).
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
+ beta_start (`float`): the starting `beta` value of inference.
+ beta_end (`float`): the final `beta` value.
+ beta_schedule (`str`):
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+ trained_betas (`np.ndarray`, optional):
+ option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
+ solver_order (`int`, default `2`):
+ the order of UniPC, also the p in UniPC-p; can be any positive integer. Note that the effective order of
+ accuracy is `solver_order + 1` due to the UniC. We recommend `solver_order=2` for guided sampling
+ and `solver_order=3` for unconditional sampling.
+ prediction_type (`str`, default `epsilon`, optional):
+ prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
+ process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4
+ https://imagen.research.google/video/paper.pdf)
+ thresholding (`bool`, default `False`):
+ whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
+ For pixel-space diffusion models, you can set both `predict_x0=True` and `thresholding=True` to use the
+ dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models
+ (such as stable-diffusion).
+ dynamic_thresholding_ratio (`float`, default `0.995`):
+ the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
+ (https://arxiv.org/abs/2205.11487).
+ sample_max_value (`float`, default `1.0`):
+ the threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`.
+ predict_x0 (`bool`, default `True`):
+ whether to use the updating algorithm on the predicted x0. See https://arxiv.org/abs/2211.01095 for details
+ solver_type (`str`, default `bh2`):
+ the solver type of UniPC. We recommend `bh1` for unconditional sampling when steps < 10, and `bh2`
+ otherwise.
+ lower_order_final (`bool`, default `True`):
+ whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically
+ find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10.
+ disable_corrector (`list`, default `[]`):
+ decides at which steps to disable the corrector. For large guidance scales, the misalignment between
+ `epsilon_theta(x_t, c)` and `epsilon_theta(x_t^c, c)` might influence convergence. This can be mitigated
+ by disabling the corrector at the first few steps (e.g., `disable_corrector=[0]`)
+ solver_p (`SchedulerMixin`, default `None`):
+ can be any other scheduler. If specified, the algorithm will become solver_p + UniC.
+ use_karras_sigmas (`bool`, *optional*, defaults to `False`):
+ This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the
+ noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence
+ of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf.
+ timestep_spacing (`str`, default `"linspace"`):
+ The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
+ Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
+ steps_offset (`int`, default `0`):
+ an offset added to the inference steps. You can use a combination of `offset=1` and
+ `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
+ stable diffusion.
+ """
+
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+ solver_order: int = 2,
+ prediction_type: str = "epsilon",
+ thresholding: bool = False,
+ dynamic_thresholding_ratio: float = 0.995,
+ sample_max_value: float = 1.0,
+ predict_x0: bool = True,
+ solver_type: str = "bh2",
+ lower_order_final: bool = True,
+ disable_corrector: List[int] = [],
+ solver_p: SchedulerMixin = None,
+ use_karras_sigmas: Optional[bool] = False,
+ timestep_spacing: str = "linspace",
+ steps_offset: int = 0,
+ ):
+ if trained_betas is not None:
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
+ elif beta_schedule == "linear":
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ )
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ else:
+ raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+ # Currently we only support VP-type noise schedule
+ self.alpha_t = torch.sqrt(self.alphas_cumprod)
+ self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
+ self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
+
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = 1.0
+
+ if solver_type not in ["bh1", "bh2"]:
+ if solver_type in ["midpoint", "heun", "logrho"]:
+ self.register_to_config(solver_type="bh2")
+ else:
+ raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
+
+ self.predict_x0 = predict_x0
+ # setable values
+ self.num_inference_steps = None
+ timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
+ self.timesteps = torch.from_numpy(timesteps)
+ self.model_outputs = [None] * solver_order
+ self.timestep_list = [None] * solver_order
+ self.lower_order_nums = 0
+ self.disable_corrector = disable_corrector
+ self.solver_p = solver_p
+ self.last_sample = None
+
+ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+ """
+ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ device (`str` or `torch.device`, optional):
+ the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ """
+ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+ if self.config.timestep_spacing == "linspace":
+ timesteps = (
+ np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
+ .round()[::-1][:-1]
+ .copy()
+ .astype(np.int64)
+ )
+ elif self.config.timestep_spacing == "leading":
+ step_ratio = self.config.num_train_timesteps // (num_inference_steps + 1)
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
+ timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
+ timesteps += self.config.steps_offset
+ elif self.config.timestep_spacing == "trailing":
+ step_ratio = self.config.num_train_timesteps / num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
+ timesteps = np.arange(self.config.num_train_timesteps, 0, -step_ratio).round().copy().astype(np.int64)
+ timesteps -= 1
+ else:
+ raise ValueError(
+ f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
+ )
+
+ sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+ if self.config.use_karras_sigmas:
+ log_sigmas = np.log(sigmas)
+ sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+ timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
+ timesteps = np.flip(timesteps).copy().astype(np.int64)
+
+ self.sigmas = torch.from_numpy(sigmas)
+
+ # when num_inference_steps == num_train_timesteps, we can end up with
+ # duplicates in timesteps.
+ _, unique_indices = np.unique(timesteps, return_index=True)
+ timesteps = timesteps[np.sort(unique_indices)]
+
+ self.timesteps = torch.from_numpy(timesteps).to(device)
+
+ self.num_inference_steps = len(timesteps)
+
+ self.model_outputs = [
+ None,
+ ] * self.config.solver_order
+ self.lower_order_nums = 0
+ self.last_sample = None
+ if self.solver_p:
+ self.solver_p.set_timesteps(self.num_inference_steps, device=device)
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
+ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
+ """
+ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
+ prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
+ s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
+ pixels from saturation at each step. We find that dynamic thresholding results in significantly better
+ photorealism as well as better image-text alignment, especially when using very large guidance weights."
+
+ https://arxiv.org/abs/2205.11487
+ """
+ dtype = sample.dtype
+ batch_size, channels, height, width = sample.shape
+
+ if dtype not in (torch.float32, torch.float64):
+ sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
+
+ # Flatten sample for doing quantile calculation along each image
+ sample = sample.reshape(batch_size, channels * height * width)
+
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
+
+ s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
+ s = torch.clamp(
+ s, min=1, max=self.config.sample_max_value
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
+
+ s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
+ sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
+
+ sample = sample.reshape(batch_size, channels, height, width)
+ sample = sample.to(dtype)
+
+ return sample
+
+ def convert_model_output(
+ self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
+ ) -> torch.FloatTensor:
+ r"""
+ Convert the model output to the corresponding type that the predictor-corrector (UniPC) algorithm needs.
+
+ Args:
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+
+ Returns:
+ `torch.FloatTensor`: the converted model output.
+ """
+ if self.predict_x0:
+ if self.config.prediction_type == "epsilon":
+ alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
+ x0_pred = (sample - sigma_t * model_output) / alpha_t
+ elif self.config.prediction_type == "sample":
+ x0_pred = model_output
+ elif self.config.prediction_type == "v_prediction":
+ alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
+ x0_pred = alpha_t * sample - sigma_t * model_output
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
+ " `v_prediction` for the UniPCMultistepScheduler."
+ )
+
+ if self.config.thresholding:
+ x0_pred = self._threshold_sample(x0_pred)
+
+ return x0_pred
+ else:
+ if self.config.prediction_type == "epsilon":
+ return model_output
+ elif self.config.prediction_type == "sample":
+ alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
+ epsilon = (sample - alpha_t * model_output) / sigma_t
+ return epsilon
+ elif self.config.prediction_type == "v_prediction":
+ alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
+ epsilon = alpha_t * model_output + sigma_t * sample
+ return epsilon
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
+ " `v_prediction` for the UniPCMultistepScheduler."
+ )
+
+ def multistep_uni_p_bh_update(
+ self,
+ model_output: torch.FloatTensor,
+ prev_timestep: int,
+ sample: torch.FloatTensor,
+ order: int,
+ ) -> torch.FloatTensor:
+ """
+ One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if it is specified.
+
+ Args:
+ model_output (`torch.FloatTensor`):
+ direct outputs from learned diffusion model at the current timestep.
+ prev_timestep (`int`): previous discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+ order (`int`): the order of UniP at this step, also the p in UniPC-p.
+
+ Returns:
+ `torch.FloatTensor`: the sample tensor at the previous timestep.
+ """
+ timestep_list = self.timestep_list
+ model_output_list = self.model_outputs
+
+ s0, t = self.timestep_list[-1], prev_timestep
+ m0 = model_output_list[-1]
+ x = sample
+
+ if self.solver_p:
+ x_t = self.solver_p.step(model_output, s0, x).prev_sample
+ return x_t
+
+ lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0]
+ alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
+ sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
+
+ h = lambda_t - lambda_s0
+ device = sample.device
+
+ rks = []
+ D1s = []
+ for i in range(1, order):
+ si = timestep_list[-(i + 1)]
+ mi = model_output_list[-(i + 1)]
+ lambda_si = self.lambda_t[si]
+ rk = (lambda_si - lambda_s0) / h
+ rks.append(rk)
+ D1s.append((mi - m0) / rk)
+
+ rks.append(1.0)
+ rks = torch.tensor(rks, device=device)
+
+ R = []
+ b = []
+
+ hh = -h if self.predict_x0 else h
+ h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
+ h_phi_k = h_phi_1 / hh - 1
+
+ factorial_i = 1
+
+ if self.config.solver_type == "bh1":
+ B_h = hh
+ elif self.config.solver_type == "bh2":
+ B_h = torch.expm1(hh)
+ else:
+ raise NotImplementedError()
+
+ for i in range(1, order + 1):
+ R.append(torch.pow(rks, i - 1))
+ b.append(h_phi_k * factorial_i / B_h)
+ factorial_i *= i + 1
+ h_phi_k = h_phi_k / hh - 1 / factorial_i
+
+ R = torch.stack(R)
+ b = torch.tensor(b, device=device)
+
+ if len(D1s) > 0:
+ D1s = torch.stack(D1s, dim=1) # (B, K)
+ # for order 2, we use a simplified version
+ if order == 2:
+ rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
+ else:
+ rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
+ else:
+ D1s = None
+
+ if self.predict_x0:
+ x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
+ if D1s is not None:
+ pred_res = torch.einsum("k,bkchw->bchw", rhos_p, D1s)
+ else:
+ pred_res = 0
+ x_t = x_t_ - alpha_t * B_h * pred_res
+ else:
+ x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
+ if D1s is not None:
+ pred_res = torch.einsum("k,bkchw->bchw", rhos_p, D1s)
+ else:
+ pred_res = 0
+ x_t = x_t_ - sigma_t * B_h * pred_res
+
+ x_t = x_t.to(x.dtype)
+ return x_t
+
+ def multistep_uni_c_bh_update(
+ self,
+ this_model_output: torch.FloatTensor,
+ this_timestep: int,
+ last_sample: torch.FloatTensor,
+ this_sample: torch.FloatTensor,
+ order: int,
+ ) -> torch.FloatTensor:
+ """
+ One step for the UniC (B(h) version).
+
+ Args:
+ this_model_output (`torch.FloatTensor`): the model outputs at `x_t`
+ this_timestep (`int`): the current timestep `t`
+ last_sample (`torch.FloatTensor`): the generated sample before the last predictor: `x_{t-1}`
+ this_sample (`torch.FloatTensor`): the generated sample after the last predictor: `x_{t}`
+ order (`int`): the `p` of UniC-p at this step. Note that the effective order of accuracy
+ should be order + 1
+
+ Returns:
+ `torch.FloatTensor`: the corrected sample tensor at the current timestep.
+ """
+ timestep_list = self.timestep_list
+ model_output_list = self.model_outputs
+
+ s0, t = timestep_list[-1], this_timestep
+ m0 = model_output_list[-1]
+ x = last_sample
+ x_t = this_sample
+ model_t = this_model_output
+
+ lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0]
+ alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
+ sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
+
+ h = lambda_t - lambda_s0
+ device = this_sample.device
+
+ rks = []
+ D1s = []
+ for i in range(1, order):
+ si = timestep_list[-(i + 1)]
+ mi = model_output_list[-(i + 1)]
+ lambda_si = self.lambda_t[si]
+ rk = (lambda_si - lambda_s0) / h
+ rks.append(rk)
+ D1s.append((mi - m0) / rk)
+
+ rks.append(1.0)
+ rks = torch.tensor(rks, device=device)
+
+ R = []
+ b = []
+
+ hh = -h if self.predict_x0 else h
+ h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
+ h_phi_k = h_phi_1 / hh - 1
+
+ factorial_i = 1
+
+ if self.config.solver_type == "bh1":
+ B_h = hh
+ elif self.config.solver_type == "bh2":
+ B_h = torch.expm1(hh)
+ else:
+ raise NotImplementedError()
+
+ for i in range(1, order + 1):
+ R.append(torch.pow(rks, i - 1))
+ b.append(h_phi_k * factorial_i / B_h)
+ factorial_i *= i + 1
+ h_phi_k = h_phi_k / hh - 1 / factorial_i
+
+ R = torch.stack(R)
+ b = torch.tensor(b, device=device)
+
+ if len(D1s) > 0:
+ D1s = torch.stack(D1s, dim=1)
+ else:
+ D1s = None
+
+ # for order 1, we use a simplified version
+ if order == 1:
+ rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
+ else:
+ rhos_c = torch.linalg.solve(R, b)
+
+ if self.predict_x0:
+ x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
+ if D1s is not None:
+ corr_res = torch.einsum("k,bkchw->bchw", rhos_c[:-1], D1s)
+ else:
+ corr_res = 0
+ D1_t = model_t - m0
+ x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)
+ else:
+ x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
+ if D1s is not None:
+ corr_res = torch.einsum("k,bkchw->bchw", rhos_c[:-1], D1s)
+ else:
+ corr_res = 0
+ D1_t = model_t - m0
+ x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t)
+ x_t = x_t.to(x.dtype)
+ return x_t
+
+ def step(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: int,
+ sample: torch.FloatTensor,
+ return_dict: bool = True,
+ ) -> Union[SchedulerOutput, Tuple]:
+ """
+ Step function propagating the sample with the multistep UniPC.
+
+ Args:
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+ return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+ Returns:
+ [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is
+ True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
+
+ """
+
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ if isinstance(timestep, torch.Tensor):
+ timestep = timestep.to(self.timesteps.device)
+ step_index = (self.timesteps == timestep).nonzero()
+ if len(step_index) == 0:
+ step_index = len(self.timesteps) - 1
+ else:
+ step_index = step_index.item()
+
+ use_corrector = (
+ step_index > 0 and step_index - 1 not in self.disable_corrector and self.last_sample is not None
+ )
+
+ model_output_convert = self.convert_model_output(model_output, timestep, sample)
+ if use_corrector:
+ sample = self.multistep_uni_c_bh_update(
+ this_model_output=model_output_convert,
+ this_timestep=timestep,
+ last_sample=self.last_sample,
+ this_sample=sample,
+ order=self.this_order,
+ )
+
+ # now prepare to run the predictor
+ prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]
+
+ for i in range(self.config.solver_order - 1):
+ self.model_outputs[i] = self.model_outputs[i + 1]
+ self.timestep_list[i] = self.timestep_list[i + 1]
+
+ self.model_outputs[-1] = model_output_convert
+ self.timestep_list[-1] = timestep
+
+ if self.config.lower_order_final:
+ this_order = min(self.config.solver_order, len(self.timesteps) - step_index)
+ else:
+ this_order = self.config.solver_order
+
+ self.this_order = min(this_order, self.lower_order_nums + 1) # warmup for multistep
+ assert self.this_order > 0
+
+ self.last_sample = sample
+ prev_sample = self.multistep_uni_p_bh_update(
+ model_output=model_output, # pass the original non-converted model output, in case solver-p is used
+ prev_timestep=prev_timestep,
+ sample=sample,
+ order=self.this_order,
+ )
+
+ if self.lower_order_nums < self.config.solver_order:
+ self.lower_order_nums += 1
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return SchedulerOutput(prev_sample=prev_sample)
+
+ def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+
+ Args:
+ sample (`torch.FloatTensor`): input sample
+
+ Returns:
+ `torch.FloatTensor`: scaled input sample
+ """
+ return sample
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
+ def add_noise(
+ self,
+ original_samples: torch.FloatTensor,
+ noise: torch.FloatTensor,
+ timesteps: torch.IntTensor,
+ ) -> torch.FloatTensor:
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
+ alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+ timesteps = timesteps.to(original_samples.device)
+
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+ return noisy_samples
+
+ def __len__(self):
+ return self.config.num_train_timesteps
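A hedged sketch of how the predictor-corrector loop above is typically driven. `unet` is a hypothetical epsilon-prediction model and is not part of this diff; the loop only illustrates the `set_timesteps` / `scale_model_input` / `step` API defined in this file:

```python
import torch

scheduler = UniPCMultistepScheduler(solver_order=2, prediction_type="epsilon")
scheduler.set_timesteps(num_inference_steps=20, device="cpu")

# Start from Gaussian noise in latent space (shape is illustrative).
latents = torch.randn(1, 4, 64, 64) * scheduler.init_noise_sigma
for t in scheduler.timesteps:
    latent_input = scheduler.scale_model_input(latents, t)
    noise_pred = unet(latent_input, t)  # hypothetical model; returns predicted epsilon
    latents = scheduler.step(noise_pred, t, latents).prev_sample
```

In practice one would usually swap this scheduler into an existing pipeline instead, e.g. `pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)`; the bare loop above just makes the step order explicit.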
diff --git a/diffusers/schedulers/scheduling_utils.py b/diffusers/schedulers/scheduling_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f95beb022ac042b6e1ef588a72365b2623338de
--- /dev/null
+++ b/diffusers/schedulers/scheduling_utils.py
@@ -0,0 +1,177 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import importlib
+import os
+from dataclasses import dataclass
+from enum import Enum
+from typing import Any, Dict, Optional, Union
+
+import torch
+
+from ..utils import BaseOutput
+
+
+SCHEDULER_CONFIG_NAME = "scheduler_config.json"
+
+
+# NOTE: We make this type an enum because it simplifies usage in docs and prevents
+# circular imports when used for `_compatibles` within the schedulers module.
+# When it's used as a type in pipelines, it really is a Union because the actual
+# scheduler instance is passed in.
+class KarrasDiffusionSchedulers(Enum):
+ DDIMScheduler = 1
+ DDPMScheduler = 2
+ PNDMScheduler = 3
+ LMSDiscreteScheduler = 4
+ EulerDiscreteScheduler = 5
+ HeunDiscreteScheduler = 6
+ EulerAncestralDiscreteScheduler = 7
+ DPMSolverMultistepScheduler = 8
+ DPMSolverSinglestepScheduler = 9
+ KDPM2DiscreteScheduler = 10
+ KDPM2AncestralDiscreteScheduler = 11
+ DEISMultistepScheduler = 12
+ UniPCMultistepScheduler = 13
+ DPMSolverSDEScheduler = 14
+
+
+@dataclass
+class SchedulerOutput(BaseOutput):
+ """
+ Base class for the scheduler's step function output.
+
+ Args:
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ """
+
+ prev_sample: torch.FloatTensor
+
+
+class SchedulerMixin:
+ """
+ Mixin containing common functions for the schedulers.
+
+ Class attributes:
+ - **_compatibles** (`List[str]`) -- A list of classes that are compatible with the parent class, so that
+ `from_config` can be used from a class different than the one used to save the config (should be overridden
+ by parent class).
+ """
+
+ config_name = SCHEDULER_CONFIG_NAME
+ _compatibles = []
+ has_compatibles = True
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ pretrained_model_name_or_path: Dict[str, Any] = None,
+ subfolder: Optional[str] = None,
+ return_unused_kwargs=False,
+ **kwargs,
+ ):
+ r"""
+ Instantiate a Scheduler class from a pre-defined JSON configuration file inside a directory or Hub repo.
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+ Can be either:
+
+ - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
+ organization name, like `google/ddpm-celebahq-256`.
+ - A path to a *directory* containing the scheduler configurations saved using
+ [`~SchedulerMixin.save_pretrained`], e.g., `./my_model_directory/`.
+ subfolder (`str`, *optional*):
+ In case the relevant files are located inside a subfolder of the model repo (either remote in
+ huggingface.co or downloaded locally), you can specify the folder name here.
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
+ Whether kwargs that are not consumed by the Python class should be returned or not.
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
+ standard cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+ file exists.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ output_loading_info(`bool`, *optional*, defaults to `False`):
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ local_files_only(`bool`, *optional*, defaults to `False`):
+ Whether or not to only look at local files (i.e., do not try to download the model).
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `transformers-cli login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+
+
+
+ It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
+ models](https://huggingface.co/docs/hub/models-gated#gated-models).
+
+
+
+
+
+ Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
+ use this method in a firewalled environment.
+
+
+
+ """
+ config, kwargs, commit_hash = cls.load_config(
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
+ subfolder=subfolder,
+ return_unused_kwargs=True,
+ return_commit_hash=True,
+ **kwargs,
+ )
+ return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs)
+
+ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
+ """
+ Save a scheduler configuration object to the directory `save_directory`, so that it can be re-loaded using the
+ [`~SchedulerMixin.from_pretrained`] class method.
+
+ Args:
+ save_directory (`str` or `os.PathLike`):
+ Directory where the configuration JSON file will be saved (will be created if it does not exist).
+ """
+ self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
+
+ @property
+ def compatibles(self):
+ """
+ Returns all schedulers that are compatible with this scheduler
+
+ Returns:
+ `List[SchedulerMixin]`: List of compatible schedulers
+ """
+ return self._get_compatibles()
+
+ @classmethod
+ def _get_compatibles(cls):
+ compatible_classes_str = list(set([cls.__name__] + cls._compatibles))
+ diffusers_library = importlib.import_module(__name__.split(".")[0])
+ compatible_classes = [
+ getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c)
+ ]
+ return compatible_classes
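Roughly how these mixin helpers are used in practice. This is a sketch only: the repo id follows the docstring's own example, and whether that checkpoint ships its scheduler config under a `scheduler` subfolder is an assumption here.

```python
from diffusers import UniPCMultistepScheduler

# Load a scheduler config from a Hub repo (or local directory), save it, and reload it.
scheduler = UniPCMultistepScheduler.from_pretrained(
    "google/ddpm-celebahq-256", subfolder="scheduler"
)
scheduler.save_pretrained("./my_scheduler")  # writes scheduler_config.json
reloaded = UniPCMultistepScheduler.from_pretrained("./my_scheduler")

# `compatibles` lists the other scheduler classes that can be built from the same config.
print([cls.__name__ for cls in reloaded.compatibles])
```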
diff --git a/diffusers/schedulers/scheduling_utils_flax.py b/diffusers/schedulers/scheduling_utils_flax.py
new file mode 100644
index 0000000000000000000000000000000000000000..19ce5b8360b9be5bb4b4ec46fbeac0715d6b5869
--- /dev/null
+++ b/diffusers/schedulers/scheduling_utils_flax.py
@@ -0,0 +1,284 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import importlib
+import math
+import os
+from dataclasses import dataclass
+from enum import Enum
+from typing import Any, Dict, Optional, Tuple, Union
+
+import flax
+import jax.numpy as jnp
+
+from ..utils import BaseOutput
+
+
+SCHEDULER_CONFIG_NAME = "scheduler_config.json"
+
+
+# NOTE: We make this type an enum because it simplifies usage in docs and prevents
+# circular imports when used for `_compatibles` within the schedulers module.
+# When it's used as a type in pipelines, it really is a Union because the actual
+# scheduler instance is passed in.
+class FlaxKarrasDiffusionSchedulers(Enum):
+ FlaxDDIMScheduler = 1
+ FlaxDDPMScheduler = 2
+ FlaxPNDMScheduler = 3
+ FlaxLMSDiscreteScheduler = 4
+ FlaxDPMSolverMultistepScheduler = 5
+
+
+@dataclass
+class FlaxSchedulerOutput(BaseOutput):
+ """
+ Base class for the scheduler's step function output.
+
+ Args:
+ prev_sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ """
+
+ prev_sample: jnp.ndarray
+
+
+class FlaxSchedulerMixin:
+ """
+ Mixin containing common functions for the schedulers.
+
+ Class attributes:
+ - **_compatibles** (`List[str]`) -- A list of classes that are compatible with the parent class, so that
+ `from_config` can be used from a class different than the one used to save the config (should be overridden
+ by parent class).
+ """
+
+ config_name = SCHEDULER_CONFIG_NAME
+ ignore_for_config = ["dtype"]
+ _compatibles = []
+ has_compatibles = True
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ pretrained_model_name_or_path: Dict[str, Any] = None,
+ subfolder: Optional[str] = None,
+ return_unused_kwargs=False,
+ **kwargs,
+ ):
+ r"""
+ Instantiate a Scheduler class from a pre-defined JSON-file.
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+ Can be either:
+
+ - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
+ organization name, like `google/ddpm-celebahq-256`.
+ - A path to a *directory* containing model weights saved using [`~SchedulerMixin.save_pretrained`],
+ e.g., `./my_model_directory/`.
+ subfolder (`str`, *optional*):
+ In case the relevant files are located inside a subfolder of the model repo (either remote in
+ huggingface.co or downloaded locally), you can specify the folder name here.
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
+ Whether kwargs that are not consumed by the Python class should be returned or not.
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
+ standard cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+ file exists.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ output_loading_info(`bool`, *optional*, defaults to `False`):
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ local_files_only(`bool`, *optional*, defaults to `False`):
+ Whether or not to only look at local files (i.e., do not try to download the model).
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `transformers-cli login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+
+
+
+ It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
+ models](https://huggingface.co/docs/hub/models-gated#gated-models).
+
+
+
+
+
+ Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
+ use this method in a firewalled environment.
+
+
+
+ """
+ config, kwargs = cls.load_config(
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
+ subfolder=subfolder,
+ return_unused_kwargs=True,
+ **kwargs,
+ )
+ scheduler, unused_kwargs = cls.from_config(config, return_unused_kwargs=True, **kwargs)
+
+ if hasattr(scheduler, "create_state") and getattr(scheduler, "has_state", False):
+ state = scheduler.create_state()
+
+ if return_unused_kwargs:
+ return scheduler, state, unused_kwargs
+
+ return scheduler, state
+
+ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
+ """
+ Save a scheduler configuration object to the directory `save_directory`, so that it can be re-loaded using the
+ [`~FlaxSchedulerMixin.from_pretrained`] class method.
+
+ Args:
+ save_directory (`str` or `os.PathLike`):
+ Directory where the configuration JSON file will be saved (will be created if it does not exist).
+ """
+ self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
+
+ @property
+ def compatibles(self):
+ """
+ Returns all schedulers that are compatible with this scheduler
+
+ Returns:
+ `List[SchedulerMixin]`: List of compatible schedulers
+ """
+ return self._get_compatibles()
+
+ @classmethod
+ def _get_compatibles(cls):
+ compatible_classes_str = list(set([cls.__name__] + cls._compatibles))
+ diffusers_library = importlib.import_module(__name__.split(".")[0])
+ compatible_classes = [
+ getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c)
+ ]
+ return compatible_classes
+
+
+def broadcast_to_shape_from_left(x: jnp.ndarray, shape: Tuple[int]) -> jnp.ndarray:
+ assert len(shape) >= x.ndim
+ return jnp.broadcast_to(x.reshape(x.shape + (1,) * (len(shape) - x.ndim)), shape)
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999, dtype=jnp.float32) -> jnp.ndarray:
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+
+ Returns:
+ betas (`jnp.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+
+ def alpha_bar(time_step):
+ return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return jnp.array(betas, dtype=dtype)
+
+
+@flax.struct.dataclass
+class CommonSchedulerState:
+ alphas: jnp.ndarray
+ betas: jnp.ndarray
+ alphas_cumprod: jnp.ndarray
+
+ @classmethod
+ def create(cls, scheduler):
+ config = scheduler.config
+
+ if config.trained_betas is not None:
+ betas = jnp.asarray(config.trained_betas, dtype=scheduler.dtype)
+ elif config.beta_schedule == "linear":
+ betas = jnp.linspace(config.beta_start, config.beta_end, config.num_train_timesteps, dtype=scheduler.dtype)
+ elif config.beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ betas = (
+ jnp.linspace(
+ config.beta_start**0.5, config.beta_end**0.5, config.num_train_timesteps, dtype=scheduler.dtype
+ )
+ ** 2
+ )
+ elif config.beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ betas = betas_for_alpha_bar(config.num_train_timesteps, dtype=scheduler.dtype)
+ else:
+ raise NotImplementedError(
+ f"beta_schedule {config.beta_schedule} is not implemented for scheduler {scheduler.__class__.__name__}"
+ )
+
+ alphas = 1.0 - betas
+
+ alphas_cumprod = jnp.cumprod(alphas, axis=0)
+
+ return cls(
+ alphas=alphas,
+ betas=betas,
+ alphas_cumprod=alphas_cumprod,
+ )
+
+
+def get_sqrt_alpha_prod(
+ state: CommonSchedulerState, original_samples: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray
+):
+ alphas_cumprod = state.alphas_cumprod
+
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape)
+
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape)
+
+ return sqrt_alpha_prod, sqrt_one_minus_alpha_prod
+
+
+def add_noise_common(
+ state: CommonSchedulerState, original_samples: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray
+):
+ sqrt_alpha_prod, sqrt_one_minus_alpha_prod = get_sqrt_alpha_prod(state, original_samples, noise, timesteps)
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+ return noisy_samples
+
+
+def get_velocity_common(state: CommonSchedulerState, sample: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray):
+ sqrt_alpha_prod, sqrt_one_minus_alpha_prod = get_sqrt_alpha_prod(state, sample, noise, timesteps)
+ velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
+ return velocity
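A small, hedged sketch of the Flax-side helpers above, building a `CommonSchedulerState` by hand rather than through `CommonSchedulerState.create` (which expects a scheduler instance); shapes and timesteps are illustrative:

```python
import jax.numpy as jnp

betas = betas_for_alpha_bar(1000)
alphas = 1.0 - betas
state = CommonSchedulerState(alphas=alphas, betas=betas, alphas_cumprod=jnp.cumprod(alphas))

samples = jnp.zeros((2, 3, 8, 8))   # clean samples
noise = jnp.ones((2, 3, 8, 8))      # Gaussian noise would be used in practice
timesteps = jnp.array([10, 500])    # one timestep per batch element

noisy = add_noise_common(state, samples, noise, timesteps)          # forward-diffused samples
velocity = get_velocity_common(state, samples, noise, timesteps)    # v-prediction target
print(noisy.shape, velocity.shape)
```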
diff --git a/diffusers/schedulers/scheduling_vq_diffusion.py b/diffusers/schedulers/scheduling_vq_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..b92722e4d462ca675bbf11230c1c39810de48b6e
--- /dev/null
+++ b/diffusers/schedulers/scheduling_vq_diffusion.py
@@ -0,0 +1,496 @@
+# Copyright 2023 Microsoft and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import BaseOutput
+from .scheduling_utils import SchedulerMixin
+
+
+@dataclass
+class VQDiffusionSchedulerOutput(BaseOutput):
+ """
+ Output class for the scheduler's step function output.
+
+ Args:
+ prev_sample (`torch.LongTensor` of shape `(batch size, num latent pixels)`):
+ Computed sample x_{t-1} of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ """
+
+ prev_sample: torch.LongTensor
+
+
+def index_to_log_onehot(x: torch.LongTensor, num_classes: int) -> torch.FloatTensor:
+ """
+ Convert batch of vector of class indices into batch of log onehot vectors
+
+ Args:
+ x (`torch.LongTensor` of shape `(batch size, vector length)`):
+ Batch of class indices
+
+ num_classes (`int`):
+ number of classes to be used for the onehot vectors
+
+ Returns:
+ `torch.FloatTensor` of shape `(batch size, num classes, vector length)`:
+ Log onehot vectors
+ """
+ x_onehot = F.one_hot(x, num_classes)
+ x_onehot = x_onehot.permute(0, 2, 1)
+ log_x = torch.log(x_onehot.float().clamp(min=1e-30))
+ return log_x
+
+
+def gumbel_noised(logits: torch.FloatTensor, generator: Optional[torch.Generator]) -> torch.FloatTensor:
+ """
+ Apply gumbel noise to `logits`
+ """
+ uniform = torch.rand(logits.shape, device=logits.device, generator=generator)
+ gumbel_noise = -torch.log(-torch.log(uniform + 1e-30) + 1e-30)
+ noised = gumbel_noise + logits
+ return noised
+
+
+def alpha_schedules(num_diffusion_timesteps: int, alpha_cum_start=0.99999, alpha_cum_end=0.000009):
+ """
+ Cumulative and non-cumulative alpha schedules.
+
+ See section 4.1.
+ """
+ att = (
+ np.arange(0, num_diffusion_timesteps) / (num_diffusion_timesteps - 1) * (alpha_cum_end - alpha_cum_start)
+ + alpha_cum_start
+ )
+ att = np.concatenate(([1], att))
+ at = att[1:] / att[:-1]
+ att = np.concatenate((att[1:], [1]))
+ return at, att
+
+
+def gamma_schedules(num_diffusion_timesteps: int, gamma_cum_start=0.000009, gamma_cum_end=0.99999):
+ """
+ Cumulative and non-cumulative gamma schedules.
+
+ See section 4.1.
+ """
+ ctt = (
+ np.arange(0, num_diffusion_timesteps) / (num_diffusion_timesteps - 1) * (gamma_cum_end - gamma_cum_start)
+ + gamma_cum_start
+ )
+ ctt = np.concatenate(([0], ctt))
+ one_minus_ctt = 1 - ctt
+ one_minus_ct = one_minus_ctt[1:] / one_minus_ctt[:-1]
+ ct = 1 - one_minus_ct
+ ctt = np.concatenate((ctt[1:], [0]))
+ return ct, ctt
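A quick, hedged illustration of the two schedules above for a 100-step setup; the 4096-entry codebook size below is purely illustrative and mirrors how the class below derives `bt`/`btt` from `at`/`att` and `ct`/`ctt`:

```python
at, att = alpha_schedules(100)   # per-step and cumulative "keep" probabilities
ct, ctt = gamma_schedules(100)   # per-step and cumulative "mask" probabilities

num_non_mask_classes = 4096      # illustrative codebook size (excludes the mask class)
bt = (1 - at - ct) / num_non_mask_classes     # per-step uniform-resample probability
btt = (1 - att - ctt) / num_non_mask_classes  # cumulative uniform-resample probability
print(at.shape, att.shape, bt[:3], ctt[:3])
```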
+
+
+class VQDiffusionScheduler(SchedulerMixin, ConfigMixin):
+ """
+ The VQ-diffusion transformer outputs predicted probabilities of the initial unnoised image.
+
+ The VQ-diffusion scheduler converts the transformer's output into a sample for the unnoised image at the previous
+ diffusion timestep.
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
+
+ For more details, see the original paper: https://arxiv.org/abs/2111.14822
+
+ Args:
+ num_vec_classes (`int`):
+ The number of classes of the vector embeddings of the latent pixels. Includes the class for the masked
+ latent pixel.
+
+ num_train_timesteps (`int`):
+ Number of diffusion steps used to train the model.
+
+ alpha_cum_start (`float`):
+ The starting cumulative alpha value.
+
+ alpha_cum_end (`float`):
+ The ending cumulative alpha value.
+
+ gamma_cum_start (`float`):
+ The starting cumulative gamma value.
+
+ gamma_cum_end (`float`):
+ The ending cumulative gamma value.
+ """
+
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_vec_classes: int,
+ num_train_timesteps: int = 100,
+ alpha_cum_start: float = 0.99999,
+ alpha_cum_end: float = 0.000009,
+ gamma_cum_start: float = 0.000009,
+ gamma_cum_end: float = 0.99999,
+ ):
+ self.num_embed = num_vec_classes
+
+ # By convention, the index for the mask class is the last class index
+ self.mask_class = self.num_embed - 1
+
+ at, att = alpha_schedules(num_train_timesteps, alpha_cum_start=alpha_cum_start, alpha_cum_end=alpha_cum_end)
+ ct, ctt = gamma_schedules(num_train_timesteps, gamma_cum_start=gamma_cum_start, gamma_cum_end=gamma_cum_end)
+
+ num_non_mask_classes = self.num_embed - 1
+ bt = (1 - at - ct) / num_non_mask_classes
+ btt = (1 - att - ctt) / num_non_mask_classes
+
+ at = torch.tensor(at.astype("float64"))
+ bt = torch.tensor(bt.astype("float64"))
+ ct = torch.tensor(ct.astype("float64"))
+ log_at = torch.log(at)
+ log_bt = torch.log(bt)
+ log_ct = torch.log(ct)
+
+ att = torch.tensor(att.astype("float64"))
+ btt = torch.tensor(btt.astype("float64"))
+ ctt = torch.tensor(ctt.astype("float64"))
+ log_cumprod_at = torch.log(att)
+ log_cumprod_bt = torch.log(btt)
+ log_cumprod_ct = torch.log(ctt)
+
+ self.log_at = log_at.float()
+ self.log_bt = log_bt.float()
+ self.log_ct = log_ct.float()
+ self.log_cumprod_at = log_cumprod_at.float()
+ self.log_cumprod_bt = log_cumprod_bt.float()
+ self.log_cumprod_ct = log_cumprod_ct.float()
+
+ # setable values
+ self.num_inference_steps = None
+ self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
+
+ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+ """
+ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+
+ device (`str` or `torch.device`):
+ device to place the timesteps and the diffusion process parameters (alpha, beta, gamma) on.
+ """
+ self.num_inference_steps = num_inference_steps
+ timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
+ self.timesteps = torch.from_numpy(timesteps).to(device)
+
+ self.log_at = self.log_at.to(device)
+ self.log_bt = self.log_bt.to(device)
+ self.log_ct = self.log_ct.to(device)
+ self.log_cumprod_at = self.log_cumprod_at.to(device)
+ self.log_cumprod_bt = self.log_cumprod_bt.to(device)
+ self.log_cumprod_ct = self.log_cumprod_ct.to(device)
+
+ def step(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: torch.long,
+ sample: torch.LongTensor,
+ generator: Optional[torch.Generator] = None,
+ return_dict: bool = True,
+ ) -> Union[VQDiffusionSchedulerOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep via the reverse transition distribution, i.e. Equation (11). See
+ the docstring for `self.q_posterior` for more in-depth documentation on how Equation (11) is computed.
+
+ Args:
+ model_output (`torch.FloatTensor` of shape `(batch size, num classes - 1, num latent pixels)`):
+ The log probabilities for the predicted classes of the initial latent pixels. Does not include a
+ prediction for the masked class as the initial unnoised image cannot be masked.
+
+ timestep (`torch.long`):
+ The timestep that determines which transition matrices are used.
+
+ sample (`torch.LongTensor` of shape `(batch size, num latent pixels)`):
+ The classes of each latent pixel at time `timestep`.
+
+ generator (`torch.Generator` or `None`):
+ A random number generator for the Gumbel noise applied to `p(x_{t-1} | x_t)` before it is sampled from.
+
+ return_dict (`bool`):
+ Whether to return a [`~schedulers.scheduling_utils.VQDiffusionSchedulerOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~schedulers.scheduling_utils.VQDiffusionSchedulerOutput`] or `tuple`:
+ [`~schedulers.scheduling_utils.VQDiffusionSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is the sample tensor.
+ """
+ if timestep == 0:
+ log_p_x_t_min_1 = model_output
+ else:
+ log_p_x_t_min_1 = self.q_posterior(model_output, sample, timestep)
+
+ log_p_x_t_min_1 = gumbel_noised(log_p_x_t_min_1, generator)
+
+ x_t_min_1 = log_p_x_t_min_1.argmax(dim=1)
+
+ if not return_dict:
+ return (x_t_min_1,)
+
+ return VQDiffusionSchedulerOutput(prev_sample=x_t_min_1)
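+
+ # A minimal sampling-loop sketch for this scheduler (illustrative only; `transformer` and
+ # the latent shape are hypothetical stand-ins for the VQ-Diffusion pipeline components):
+ #
+ #     scheduler = VQDiffusionScheduler(num_vec_classes=4097)
+ #     scheduler.set_timesteps(100)
+ #     # start from fully masked latents: every latent pixel is the mask class
+ #     sample = torch.full((1, 1024), scheduler.mask_class, dtype=torch.long)
+ #     for t in scheduler.timesteps:
+ #         log_p_x_0 = transformer(sample, t)  # (1, num_vec_classes - 1, 1024) log-probabilities
+ #         sample = scheduler.step(log_p_x_0, t, sample).prev_sample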
+
+ def q_posterior(self, log_p_x_0, x_t, t):
+ """
+ Calculates the log probabilities for the predicted classes of the image at timestep `t-1`. I.e. Equation (11).
+
+ Instead of directly computing equation (11), we use Equation (5) to restate Equation (11) in terms of only
+ forward probabilities.
+
+ Equation (11) stated in terms of forward probabilities via Equation (5):
+
+ p(x_{t-1} | x_t) = sum( q(x_t | x_{t-1}) * q(x_{t-1} | x_0) * p(x_0) / q(x_t | x_0) )
+
+ where the sum is over x_0 = {C_0 ... C_{k-1}} (the classes for x_0).
+
+ Args:
+ log_p_x_0: (`torch.FloatTensor` of shape `(batch size, num classes - 1, num latent pixels)`):
+ The log probabilities for the predicted classes of the initial latent pixels. Does not include a
+ prediction for the masked class as the initial unnoised image cannot be masked.
+
+ x_t: (`torch.LongTensor` of shape `(batch size, num latent pixels)`):
+ The classes of each latent pixel at time `t`
+
+ t (`torch.long`):
+ The timestep that determines which transition matrix is used.
+
+ Returns:
+ `torch.FloatTensor` of shape `(batch size, num classes, num latent pixels)`:
+ The log probabilities for the predicted classes of the image at timestep `t-1`. I.e. Equation (11).
+ """
+ log_onehot_x_t = index_to_log_onehot(x_t, self.num_embed)
+
+ log_q_x_t_given_x_0 = self.log_Q_t_transitioning_to_known_class(
+ t=t, x_t=x_t, log_onehot_x_t=log_onehot_x_t, cumulative=True
+ )
+
+ log_q_t_given_x_t_min_1 = self.log_Q_t_transitioning_to_known_class(
+ t=t, x_t=x_t, log_onehot_x_t=log_onehot_x_t, cumulative=False
+ )
+
+ # p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) ... p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0)
+ # . . .
+ # . . .
+ # . . .
+ # p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) ... p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1})
+ q = log_p_x_0 - log_q_x_t_given_x_0
+
+ # sum_0 = p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) + ... + p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}), ... ,
+ # sum_n = p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) + ... + p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1})
+ q_log_sum_exp = torch.logsumexp(q, dim=1, keepdim=True)
+
+ # p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_0 ... p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_n
+ # . . .
+ # . . .
+ # . . .
+ # p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_0 ... p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_n
+ q = q - q_log_sum_exp
+
+ # (p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1} ... (p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1}
+ # . . .
+ # . . .
+ # . . .
+ # (p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1} ... (p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1}
+ # c_cumulative_{t-1} ... c_cumulative_{t-1}
+ q = self.apply_cumulative_transitions(q, t - 1)
+
+ # ((p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_0) * sum_0 ... ((p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_0) * sum_n
+ # . . .
+ # . . .
+ # . . .
+ # ((p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_{k-1}) * sum_0 ... ((p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_{k-1}) * sum_n
+ # c_cumulative_{t-1} * q(x_t | x_{t-1}=C_k) * sum_0 ... c_cumulative_{t-1} * q(x_t | x_{t-1}=C_k) * sum_n
+ log_p_x_t_min_1 = q + log_q_t_given_x_t_min_1 + q_log_sum_exp
+
+ # For each column, there are two possible cases.
+ #
+ # Where:
+ # - sum(p_n(x_0))) is summing over all classes for x_0
+ # - C_i is the class transitioning from (not to be confused with c_t and c_cumulative_t being used for gamma's)
+ # - C_j is the class transitioning to
+ #
+ # 1. x_t is masked i.e. x_t = c_k
+ #
+ # Simplifying the expression, the column vector is:
+ # .
+ # .
+ # .
+ # (c_t / c_cumulative_t) * (a_cumulative_{t-1} * p_n(x_0 = C_i | x_t) + b_cumulative_{t-1} * sum(p_n(x_0)))
+ # .
+ # .
+ # .
+ # (c_cumulative_{t-1} / c_cumulative_t) * sum(p_n(x_0))
+ #
+ # From equation (11) stated in terms of forward probabilities, the last row is trivially verified.
+ #
+ # For the other rows, we can state the equation as ...
+ #
+ # (c_t / c_cumulative_t) * [b_cumulative_{t-1} * p(x_0=c_0) + ... + (a_cumulative_{t-1} + b_cumulative_{t-1}) * p(x_0=C_i) + ... + b_cumulative_{t-1} * p(x_0=c_{k-1})]
+ #
+ # This verifies the other rows.
+ #
+ # 2. x_t is not masked
+ #
+ # Simplifying the expression, there are two cases for the rows of the column vector, where C_j = C_i and where C_j != C_i:
+ # .
+ # .
+ # .
+ # C_j != C_i: b_t * ((b_cumulative_{t-1} / b_cumulative_t) * p_n(x_0 = c_0) + ... + ((a_cumulative_{t-1} + b_cumulative_{t-1}) / b_cumulative_t) * p_n(x_0 = C_i) + ... + (b_cumulative_{t-1} / (a_cumulative_t + b_cumulative_t)) * p_n(x_0 = C_j) + ... + (b_cumulative_{t-1} / b_cumulative_t) * p_n(x_0 = c_{k-1}))
+ # .
+ # .
+ # .
+ # C_j = C_i: (a_t + b_t) * ((b_cumulative_{t-1} / b_cumulative_t) * p_n(x_0 = c_0) + ... + ((a_cumulative_{t-1} + b_cumulative_{t-1}) / (a_cumulative_t + b_cumulative_t)) * p_n(x_0 = C_i = C_j) + ... + (b_cumulative_{t-1} / b_cumulative_t) * p_n(x_0 = c_{k-1}))
+ # .
+ # .
+ # .
+ # 0
+ #
+ # The last row is trivially verified. The other rows can be verified by directly expanding equation (11) stated in terms of forward probabilities.
+ return log_p_x_t_min_1
+
+ def log_Q_t_transitioning_to_known_class(
+ self, *, t: torch.int, x_t: torch.LongTensor, log_onehot_x_t: torch.FloatTensor, cumulative: bool
+ ):
+ """
+ Returns the log probabilities of the rows from the (cumulative or non-cumulative) transition matrix for each
+ latent pixel in `x_t`.
+
+ See equation (7) for the complete non-cumulative transition matrix. The complete cumulative transition matrix
+ is the same structure except the parameters (alpha, beta, gamma) are the cumulative analogs.
+
+ Args:
+ t (`torch.long`):
+ The timestep that determines which transition matrix is used.
+
+ x_t (`torch.LongTensor` of shape `(batch size, num latent pixels)`):
+ The classes of each latent pixel at time `t`.
+
+ log_onehot_x_t (`torch.FloatTensor` of shape `(batch size, num classes, num latent pixels)`):
+ The log one-hot vectors of `x_t`
+
+ cumulative (`bool`):
+ If cumulative is `False`, we use the single step transition matrix `t-1`->`t`. If cumulative is `True`,
+ we use the cumulative transition matrix `0`->`t`.
+
+ Returns:
+ `torch.FloatTensor` of shape `(batch size, num classes - 1, num latent pixels)` if `cumulative` is `True`,
+ otherwise `(batch size, num classes, num latent pixels)`:
+ Each _column_ of the returned matrix is a _row_ of log probabilities of the complete probability
+ transition matrix.
+
+ When cumulative, the row for transitions from the masked class is omitted because the initial
+ (unnoised) latent pixel cannot be masked, leaving `num classes - 1` rows. When non-cumulative, that
+ row is kept, giving `num classes` rows.
+
+ Where:
+ - `q_n` is the probability distribution for the forward process of the `n`th latent pixel.
+ - C_0 is a class of a latent pixel embedding
+ - C_k is the class of the masked latent pixel
+
+ non-cumulative result (omitting logarithms):
+ ```
+ q_0(x_t | x_{t-1} = C_0) ... q_n(x_t | x_{t-1} = C_0)
+ . . .
+ . . .
+ . . .
+ q_0(x_t | x_{t-1} = C_k) ... q_n(x_t | x_{t-1} = C_k)
+ ```
+
+ cumulative result (omitting logarithms):
+ ```
+ q_0_cumulative(x_t | x_0 = C_0) ... q_n_cumulative(x_t | x_0 = C_0)
+ . . .
+ . . .
+ . . .
+ q_0_cumulative(x_t | x_0 = C_{k-1}) ... q_n_cumulative(x_t | x_0 = C_{k-1})
+ ```
+ """
+ if cumulative:
+ a = self.log_cumprod_at[t]
+ b = self.log_cumprod_bt[t]
+ c = self.log_cumprod_ct[t]
+ else:
+ a = self.log_at[t]
+ b = self.log_bt[t]
+ c = self.log_ct[t]
+
+ if not cumulative:
+ # The values in the onehot vector can also be used as the logprobs for transitioning
+ # from masked latent pixels. If we are not calculating the cumulative transitions,
+ # we need to save these vectors to be re-appended to the final matrix so the values
+ # aren't overwritten.
+ #
+ # `P(x_t != mask | x_{t-1} = mask) = 0` and 0 will be the value of the last row of the onehot vector
+ # if x_t is not masked
+ #
+ # `P(x_t = mask | x_{t-1} = mask) = 1` and 1 will be the value of the last row of the onehot vector
+ # if x_t is masked
+ log_onehot_x_t_transitioning_from_masked = log_onehot_x_t[:, -1, :].unsqueeze(1)
+
+ # `index_to_log_onehot` will add onehot vectors for masked pixels,
+ # so the default one hot matrix has one too many rows. See the doc string
+ # for an explanation of the dimensionality of the returned matrix.
+ log_onehot_x_t = log_onehot_x_t[:, :-1, :]
+
+ # This is a cheeky trick to produce the transition probabilities using log one-hot vectors.
+ #
+ # Don't worry about what values this sets in the columns that mark transitions
+ # to masked latent pixels. They are overwritten later with the `mask_class_mask`.
+ #
+ # Looking at the below logspace formula in non-logspace, each value will evaluate to either
+ # `1 * a + b = a + b` where `log_Q_t` has the one hot value in the column
+ # or
+ # `0 * a + b = b` where `log_Q_t` has the 0 values in the column.
+ #
+ # See equation 7 for more details.
+ log_Q_t = (log_onehot_x_t + a).logaddexp(b)
+
+ # The whole column of each masked pixel is `c`
+ mask_class_mask = x_t == self.mask_class
+ mask_class_mask = mask_class_mask.unsqueeze(1).expand(-1, self.num_embed - 1, -1)
+ log_Q_t[mask_class_mask] = c
+
+ if not cumulative:
+ log_Q_t = torch.cat((log_Q_t, log_onehot_x_t_transitioning_from_masked), dim=1)
+
+ return log_Q_t
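+
+ # Shape sanity-check sketch (illustrative, reusing the `scheduler` from the sketch above):
+ # the cumulative matrix has `num_embed - 1` rows, the single-step matrix has `num_embed`
+ # rows (the extra row covers transitions from the masked class):
+ #
+ #     x_t = torch.full((1, 8), scheduler.mask_class, dtype=torch.long)
+ #     log_onehot = index_to_log_onehot(x_t, scheduler.num_embed)
+ #     q_cum = scheduler.log_Q_t_transitioning_to_known_class(
+ #         t=5, x_t=x_t, log_onehot_x_t=log_onehot, cumulative=True
+ #     )  # shape (1, num_embed - 1, 8)
+ #     q_step = scheduler.log_Q_t_transitioning_to_known_class(
+ #         t=5, x_t=x_t, log_onehot_x_t=log_onehot, cumulative=False
+ #     )  # shape (1, num_embed, 8)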
+
+ def apply_cumulative_transitions(self, q, t):
+ bsz = q.shape[0]
+ a = self.log_cumprod_at[t]
+ b = self.log_cumprod_bt[t]
+ c = self.log_cumprod_ct[t]
+
+ num_latent_pixels = q.shape[2]
+ c = c.expand(bsz, 1, num_latent_pixels)
+
+ q = (q + a).logaddexp(b)
+ q = torch.cat((q, c), dim=1)
+
+ return q
diff --git a/diffusers/training_utils.py b/diffusers/training_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..eaa9ed64554bf8830e35efd220a77bd2de207f18
--- /dev/null
+++ b/diffusers/training_utils.py
@@ -0,0 +1,314 @@
+import contextlib
+import copy
+import random
+from typing import Any, Dict, Iterable, Optional, Union
+
+import numpy as np
+import torch
+
+from .utils import deprecate, is_transformers_available
+
+
+if is_transformers_available():
+ import transformers
+
+
+def set_seed(seed: int):
+ """
+ Helper function for reproducible behavior: sets the seed in `random`, `numpy` and `torch`.
+
+ Args:
+ seed (`int`): The seed to set.
+ """
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ # ^^ safe to call this function even if cuda is not available
+
+
+# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
+class EMAModel:
+ """
+ Exponential Moving Average of model weights.
+ """
+
+ def __init__(
+ self,
+ parameters: Iterable[torch.nn.Parameter],
+ decay: float = 0.9999,
+ min_decay: float = 0.0,
+ update_after_step: int = 0,
+ use_ema_warmup: bool = False,
+ inv_gamma: Union[float, int] = 1.0,
+ power: Union[float, int] = 2 / 3,
+ model_cls: Optional[Any] = None,
+ model_config: Dict[str, Any] = None,
+ **kwargs,
+ ):
+ """
+ Args:
+ parameters (Iterable[torch.nn.Parameter]): The parameters to track.
+ decay (float): The decay factor for the exponential moving average.
+ min_decay (float): The minimum decay factor for the exponential moving average.
+ update_after_step (int): The number of steps to wait before starting to update the EMA weights.
+ use_ema_warmup (bool): Whether to use EMA warmup.
+ inv_gamma (float):
+ Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True.
+ power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True.
+ device (Optional[Union[str, torch.device]]): The device to store the EMA weights on. If None, the EMA
+ weights will be stored on CPU.
+
+ @crowsonkb's notes on EMA Warmup:
+ If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
+ to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
+ gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
+ at 215.4k steps).
+ """
+
+ if isinstance(parameters, torch.nn.Module):
+ deprecation_message = (
+ "Passing a `torch.nn.Module` to `ExponentialMovingAverage` is deprecated. "
+ "Please pass the parameters of the module instead."
+ )
+ deprecate(
+ "passing a `torch.nn.Module` to `ExponentialMovingAverage`",
+ "1.0.0",
+ deprecation_message,
+ standard_warn=False,
+ )
+ parameters = parameters.parameters()
+
+ # set use_ema_warmup to True if a torch.nn.Module is passed for backwards compatibility
+ use_ema_warmup = True
+
+ if kwargs.get("max_value", None) is not None:
+ deprecation_message = "The `max_value` argument is deprecated. Please use `decay` instead."
+ deprecate("max_value", "1.0.0", deprecation_message, standard_warn=False)
+ decay = kwargs["max_value"]
+
+ if kwargs.get("min_value", None) is not None:
+ deprecation_message = "The `min_value` argument is deprecated. Please use `min_decay` instead."
+ deprecate("min_value", "1.0.0", deprecation_message, standard_warn=False)
+ min_decay = kwargs["min_value"]
+
+ parameters = list(parameters)
+ self.shadow_params = [p.clone().detach() for p in parameters]
+
+ if kwargs.get("device", None) is not None:
+ deprecation_message = "The `device` argument is deprecated. Please use `to` instead."
+ deprecate("device", "1.0.0", deprecation_message, standard_warn=False)
+ self.to(device=kwargs["device"])
+
+ self.temp_stored_params = None
+
+ self.decay = decay
+ self.min_decay = min_decay
+ self.update_after_step = update_after_step
+ self.use_ema_warmup = use_ema_warmup
+ self.inv_gamma = inv_gamma
+ self.power = power
+ self.optimization_step = 0
+ self.cur_decay_value = None # set in `step()`
+
+ self.model_cls = model_cls
+ self.model_config = model_config
+
+ @classmethod
+ def from_pretrained(cls, path, model_cls) -> "EMAModel":
+ _, ema_kwargs = model_cls.load_config(path, return_unused_kwargs=True)
+ model = model_cls.from_pretrained(path)
+
+ ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config)
+
+ ema_model.load_state_dict(ema_kwargs)
+ return ema_model
+
+ def save_pretrained(self, path):
+ if self.model_cls is None:
+ raise ValueError("`save_pretrained` can only be used if `model_cls` was defined at __init__.")
+
+ if self.model_config is None:
+ raise ValueError("`save_pretrained` can only be used if `model_config` was defined at __init__.")
+
+ model = self.model_cls.from_config(self.model_config)
+ state_dict = self.state_dict()
+ state_dict.pop("shadow_params", None)
+
+ model.register_to_config(**state_dict)
+ self.copy_to(model.parameters())
+ model.save_pretrained(path)
+
+ def get_decay(self, optimization_step: int) -> float:
+ """
+ Compute the decay factor for the exponential moving average.
+ """
+ step = max(0, optimization_step - self.update_after_step - 1)
+
+ if step <= 0:
+ return 0.0
+
+ if self.use_ema_warmup:
+ cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power
+ else:
+ cur_decay_value = (1 + step) / (10 + step)
+
+ cur_decay_value = min(cur_decay_value, self.decay)
+ # make sure decay is not smaller than min_decay
+ cur_decay_value = max(cur_decay_value, self.min_decay)
+ return cur_decay_value
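+
+ # For example (a sketch, assuming `inv_gamma=1.0`, `power=2/3` and `use_ema_warmup=True`):
+ # 1 - (1 + 10) ** (-2 / 3) ~= 0.798 at step 10, and 1 - (1 + 31_600) ** (-2 / 3) ~= 0.999
+ # around 31.6K steps, matching @crowsonkb's notes in the `__init__` docstring above.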
+
+ @torch.no_grad()
+ def step(self, parameters: Iterable[torch.nn.Parameter]):
+ if isinstance(parameters, torch.nn.Module):
+ deprecation_message = (
+ "Passing a `torch.nn.Module` to `ExponentialMovingAverage.step` is deprecated. "
+ "Please pass the parameters of the module instead."
+ )
+ deprecate(
+ "passing a `torch.nn.Module` to `ExponentialMovingAverage.step`",
+ "1.0.0",
+ deprecation_message,
+ standard_warn=False,
+ )
+ parameters = parameters.parameters()
+
+ parameters = list(parameters)
+
+ self.optimization_step += 1
+
+ # Compute the decay factor for the exponential moving average.
+ decay = self.get_decay(self.optimization_step)
+ self.cur_decay_value = decay
+ one_minus_decay = 1 - decay
+
+ context_manager = contextlib.nullcontext
+ if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled():
+ import deepspeed
+
+ for s_param, param in zip(self.shadow_params, parameters):
+ if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled():
+ context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None)
+
+ with context_manager():
+ if param.requires_grad:
+ s_param.sub_(one_minus_decay * (s_param - param))
+ else:
+ s_param.copy_(param)
+
+ def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
+ """
+ Copy current averaged parameters into given collection of parameters.
+
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ updated with the stored moving averages. If `None`, the parameters with which this
+ `ExponentialMovingAverage` was initialized will be used.
+ """
+ parameters = list(parameters)
+ for s_param, param in zip(self.shadow_params, parameters):
+ param.data.copy_(s_param.to(param.device).data)
+
+ def to(self, device=None, dtype=None) -> None:
+ r"""Move internal buffers of the ExponentialMovingAverage to `device`.
+
+ Args:
+ device: like the `device` argument to `torch.Tensor.to`
+ dtype: like the `dtype` argument to `torch.Tensor.to`; only applied to floating-point buffers
+ """
+ # .to() on the tensors handles None correctly
+ self.shadow_params = [
+ p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
+ for p in self.shadow_params
+ ]
+
+ def state_dict(self) -> dict:
+ r"""
+ Returns the state of the ExponentialMovingAverage as a dict. This method is used by accelerate during
+ checkpointing to save the ema state dict.
+ """
+ # Following PyTorch conventions, references to tensors are returned:
+ # "returns a reference to the state and not its copy!" -
+ # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
+ return {
+ "decay": self.decay,
+ "min_decay": self.min_decay,
+ "optimization_step": self.optimization_step,
+ "update_after_step": self.update_after_step,
+ "use_ema_warmup": self.use_ema_warmup,
+ "inv_gamma": self.inv_gamma,
+ "power": self.power,
+ "shadow_params": self.shadow_params,
+ }
+
+ def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
+ r"""
+ Save the current parameters for restoring later.
+
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be temporarily stored.
+ """
+ self.temp_stored_params = [param.detach().cpu().clone() for param in parameters]
+
+ def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
+ r"""
+ Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters
+ without affecting the original optimization process. Store the parameters before the `copy_to()` method.
+ After validation (or model saving), use this to restore the former parameters.
+
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be updated with the stored parameters.
+ If `None`, the parameters with which this `ExponentialMovingAverage` was initialized will be used.
+ """
+ if self.temp_stored_params is None:
+ raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights " "to `restore()`")
+ for c_param, param in zip(self.temp_stored_params, parameters):
+ param.data.copy_(c_param.data)
+
+ # Better memory-wise.
+ self.temp_stored_params = None
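+
+ # Typical training-loop usage sketch (illustrative; `model`, `optimizer`, `dataloader`,
+ # `compute_loss` and `evaluate` are hypothetical stand-ins):
+ #
+ #     ema = EMAModel(model.parameters(), decay=0.9999)
+ #     for batch in dataloader:
+ #         loss = compute_loss(model, batch)
+ #         loss.backward()
+ #         optimizer.step()
+ #         optimizer.zero_grad()
+ #         ema.step(model.parameters())   # update the shadow (EMA) weights
+ #
+ #     ema.store(model.parameters())      # stash the raw weights
+ #     ema.copy_to(model.parameters())    # evaluate with the EMA weights
+ #     evaluate(model)
+ #     ema.restore(model.parameters())    # put the raw weights back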
+
+ def load_state_dict(self, state_dict: dict) -> None:
+ r"""
+ Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to load the
+ EMA state dict.
+
+ Args:
+ state_dict (dict): EMA state. Should be an object returned from a call to :meth:`state_dict`.
+ """
+ # deepcopy, to be consistent with module API
+ state_dict = copy.deepcopy(state_dict)
+
+ self.decay = state_dict.get("decay", self.decay)
+ if self.decay < 0.0 or self.decay > 1.0:
+ raise ValueError("Decay must be between 0 and 1")
+
+ self.min_decay = state_dict.get("min_decay", self.min_decay)
+ if not isinstance(self.min_decay, float):
+ raise ValueError("Invalid min_decay")
+
+ self.optimization_step = state_dict.get("optimization_step", self.optimization_step)
+ if not isinstance(self.optimization_step, int):
+ raise ValueError("Invalid optimization_step")
+
+ self.update_after_step = state_dict.get("update_after_step", self.update_after_step)
+ if not isinstance(self.update_after_step, int):
+ raise ValueError("Invalid update_after_step")
+
+ self.use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup)
+ if not isinstance(self.use_ema_warmup, bool):
+ raise ValueError("Invalid use_ema_warmup")
+
+ self.inv_gamma = state_dict.get("inv_gamma", self.inv_gamma)
+ if not isinstance(self.inv_gamma, (float, int)):
+ raise ValueError("Invalid inv_gamma")
+
+ self.power = state_dict.get("power", self.power)
+ if not isinstance(self.power, (float, int)):
+ raise ValueError("Invalid power")
+
+ shadow_params = state_dict.get("shadow_params", None)
+ if shadow_params is not None:
+ self.shadow_params = shadow_params
+ if not isinstance(self.shadow_params, list):
+ raise ValueError("shadow_params must be a list")
+ if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
+ raise ValueError("shadow_params must all be Tensors")
diff --git a/diffusers/utils/__init__.py b/diffusers/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..98fac64497e7efa4a881124dd778c4f3084402e8
--- /dev/null
+++ b/diffusers/utils/__init__.py
@@ -0,0 +1,122 @@
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+
+from packaging import version
+
+from .. import __version__
+from .accelerate_utils import apply_forward_hook
+from .constants import (
+ CONFIG_NAME,
+ DEPRECATED_REVISION_ARGS,
+ DIFFUSERS_CACHE,
+ DIFFUSERS_DYNAMIC_MODULE_NAME,
+ FLAX_WEIGHTS_NAME,
+ HF_MODULES_CACHE,
+ HUGGINGFACE_CO_RESOLVE_ENDPOINT,
+ ONNX_EXTERNAL_WEIGHTS_NAME,
+ ONNX_WEIGHTS_NAME,
+ SAFETENSORS_WEIGHTS_NAME,
+ WEIGHTS_NAME,
+)
+from .deprecation_utils import deprecate
+from .doc_utils import replace_example_docstring
+from .dynamic_modules_utils import get_class_from_dynamic_module
+from .hub_utils import (
+ HF_HUB_OFFLINE,
+ _add_variant,
+ _get_model_file,
+ extract_commit_hash,
+ http_user_agent,
+)
+from .import_utils import (
+ BACKENDS_MAPPING,
+ ENV_VARS_TRUE_AND_AUTO_VALUES,
+ ENV_VARS_TRUE_VALUES,
+ USE_JAX,
+ USE_TF,
+ USE_TORCH,
+ DummyObject,
+ OptionalDependencyNotAvailable,
+ is_accelerate_available,
+ is_accelerate_version,
+ is_bs4_available,
+ is_flax_available,
+ is_ftfy_available,
+ is_inflect_available,
+ is_invisible_watermark_available,
+ is_k_diffusion_available,
+ is_k_diffusion_version,
+ is_librosa_available,
+ is_note_seq_available,
+ is_omegaconf_available,
+ is_onnx_available,
+ is_safetensors_available,
+ is_scipy_available,
+ is_tensorboard_available,
+ is_tf_available,
+ is_torch_available,
+ is_torch_version,
+ is_torchsde_available,
+ is_transformers_available,
+ is_transformers_version,
+ is_unidecode_available,
+ is_wandb_available,
+ is_xformers_available,
+ requires_backends,
+)
+from .logging import get_logger
+from .outputs import BaseOutput
+from .pil_utils import PIL_INTERPOLATION, numpy_to_pil, pt_to_pil
+from .torch_utils import is_compiled_module, randn_tensor
+
+
+if is_torch_available():
+ from .testing_utils import (
+ floats_tensor,
+ load_hf_numpy,
+ load_image,
+ load_numpy,
+ load_pt,
+ nightly,
+ parse_flag_from_env,
+ print_tensor_test,
+ require_torch_2,
+ require_torch_gpu,
+ skip_mps,
+ slow,
+ torch_all_close,
+ torch_device,
+ )
+ from .torch_utils import maybe_allow_in_graph
+
+from .testing_utils import export_to_gif, export_to_video
+
+
+logger = get_logger(__name__)
+
+
+def check_min_version(min_version):
+ if version.parse(__version__) < version.parse(min_version):
+ if "dev" in min_version:
+ error_message = (
+ "This example requires a source install from HuggingFace diffusers (see "
+ "`https://huggingface.co/docs/diffusers/installation#install-from-source`),"
+ )
+ else:
+ error_message = f"This example requires a minimum version of {min_version},"
+ error_message += f" but the version found is {__version__}.\n"
+ raise ImportError(error_message)
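+
+
+# Usage sketch: example scripts typically call this at import time to fail fast on stale
+# installs, e.g. `check_min_version("0.19.0.dev0")` (the version string here is illustrative).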
diff --git a/diffusers/utils/__pycache__/__init__.cpython-310.pyc b/diffusers/utils/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e310135040760e678de97f01b045e5b456e2892f
Binary files /dev/null and b/diffusers/utils/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffusers/utils/__pycache__/__init__.cpython-38.pyc b/diffusers/utils/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1ed1374b0d9691344cd9a04cae2ef91d26423a79
Binary files /dev/null and b/diffusers/utils/__pycache__/__init__.cpython-38.pyc differ
diff --git a/diffusers/utils/__pycache__/accelerate_utils.cpython-310.pyc b/diffusers/utils/__pycache__/accelerate_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..10d05a0d26ad25743da806d07a163c95c3880789
Binary files /dev/null and b/diffusers/utils/__pycache__/accelerate_utils.cpython-310.pyc differ
diff --git a/diffusers/utils/__pycache__/accelerate_utils.cpython-38.pyc b/diffusers/utils/__pycache__/accelerate_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b8dab26b80c017ad19e911dc0890f7dd85d68417
Binary files /dev/null and b/diffusers/utils/__pycache__/accelerate_utils.cpython-38.pyc differ
diff --git a/diffusers/utils/__pycache__/constants.cpython-310.pyc b/diffusers/utils/__pycache__/constants.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e980e4d8349a7c3ea515179fceff00ca1b2a1ee8
Binary files /dev/null and b/diffusers/utils/__pycache__/constants.cpython-310.pyc differ
diff --git a/diffusers/utils/__pycache__/constants.cpython-38.pyc b/diffusers/utils/__pycache__/constants.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..76bdcbfd60909c4a3303eb95f69fb0e7ec435293
Binary files /dev/null and b/diffusers/utils/__pycache__/constants.cpython-38.pyc differ
diff --git a/diffusers/utils/__pycache__/deprecation_utils.cpython-310.pyc b/diffusers/utils/__pycache__/deprecation_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..df4b7753e7404d9cc591e339b470637a2917083f
Binary files /dev/null and b/diffusers/utils/__pycache__/deprecation_utils.cpython-310.pyc differ
diff --git a/diffusers/utils/__pycache__/deprecation_utils.cpython-38.pyc b/diffusers/utils/__pycache__/deprecation_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..de58ba991d9ae78013c1553d731bf03eceef760a
Binary files /dev/null and b/diffusers/utils/__pycache__/deprecation_utils.cpython-38.pyc differ
diff --git a/diffusers/utils/__pycache__/doc_utils.cpython-310.pyc b/diffusers/utils/__pycache__/doc_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7f129c84c0fbbac511ae7533cef7b5be4056350b
Binary files /dev/null and b/diffusers/utils/__pycache__/doc_utils.cpython-310.pyc differ
diff --git a/diffusers/utils/__pycache__/doc_utils.cpython-38.pyc b/diffusers/utils/__pycache__/doc_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..28f4c6ec6bba7a5d925f5d32bd115faf79f1943f
Binary files /dev/null and b/diffusers/utils/__pycache__/doc_utils.cpython-38.pyc differ
diff --git a/diffusers/utils/__pycache__/dummy_flax_and_transformers_objects.cpython-310.pyc b/diffusers/utils/__pycache__/dummy_flax_and_transformers_objects.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ba514a0e2bf5d2980575b28a242d7599aaa6bdb5
Binary files /dev/null and b/diffusers/utils/__pycache__/dummy_flax_and_transformers_objects.cpython-310.pyc differ
diff --git a/diffusers/utils/__pycache__/dummy_flax_and_transformers_objects.cpython-38.pyc b/diffusers/utils/__pycache__/dummy_flax_and_transformers_objects.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cef6bd6fdbbcaedf57f4dc2c181861fab7a614db
Binary files /dev/null and b/diffusers/utils/__pycache__/dummy_flax_and_transformers_objects.cpython-38.pyc differ
diff --git a/diffusers/utils/__pycache__/dummy_flax_objects.cpython-310.pyc b/diffusers/utils/__pycache__/dummy_flax_objects.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a5e68c7855ea036ae74664d258b3ff82db52d9a4
Binary files /dev/null and b/diffusers/utils/__pycache__/dummy_flax_objects.cpython-310.pyc differ
diff --git a/diffusers/utils/__pycache__/dummy_flax_objects.cpython-38.pyc b/diffusers/utils/__pycache__/dummy_flax_objects.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d1316c453f9831487ced9d28fc940060524c0161
Binary files /dev/null and b/diffusers/utils/__pycache__/dummy_flax_objects.cpython-38.pyc differ
diff --git a/diffusers/utils/__pycache__/dummy_note_seq_objects.cpython-310.pyc b/diffusers/utils/__pycache__/dummy_note_seq_objects.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dcf2cff2c2a1b5fff7243c697ff3353659f534af
Binary files /dev/null and b/diffusers/utils/__pycache__/dummy_note_seq_objects.cpython-310.pyc differ
diff --git a/diffusers/utils/__pycache__/dummy_note_seq_objects.cpython-38.pyc b/diffusers/utils/__pycache__/dummy_note_seq_objects.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f632fbc289333a27a2799923f0e7b60697e9715c
Binary files /dev/null and b/diffusers/utils/__pycache__/dummy_note_seq_objects.cpython-38.pyc differ
diff --git a/diffusers/utils/__pycache__/dummy_onnx_objects.cpython-310.pyc b/diffusers/utils/__pycache__/dummy_onnx_objects.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..72149f1f4a5b1f34f42febc8821681a82d72bc30
Binary files /dev/null and b/diffusers/utils/__pycache__/dummy_onnx_objects.cpython-310.pyc differ
diff --git a/diffusers/utils/__pycache__/dummy_onnx_objects.cpython-38.pyc b/diffusers/utils/__pycache__/dummy_onnx_objects.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c93d3c999f715c5afc190d9baf886dc299c90770
Binary files /dev/null and b/diffusers/utils/__pycache__/dummy_onnx_objects.cpython-38.pyc differ
diff --git a/diffusers/utils/__pycache__/dummy_torch_and_librosa_objects.cpython-310.pyc b/diffusers/utils/__pycache__/dummy_torch_and_librosa_objects.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0fad3b97c7874d7193b90dda87912a25e3125e07
Binary files /dev/null and b/diffusers/utils/__pycache__/dummy_torch_and_librosa_objects.cpython-310.pyc differ
diff --git a/diffusers/utils/__pycache__/dummy_torch_and_librosa_objects.cpython-38.pyc b/diffusers/utils/__pycache__/dummy_torch_and_librosa_objects.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..03d85bc5dec6fb476bf950a08fe0dd7f3489e467
Binary files /dev/null and b/diffusers/utils/__pycache__/dummy_torch_and_librosa_objects.cpython-38.pyc differ
diff --git a/diffusers/utils/__pycache__/dummy_torch_and_torchsde_objects.cpython-310.pyc b/diffusers/utils/__pycache__/dummy_torch_and_torchsde_objects.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..10f5a6e9146ff91bab5fefcca698cb085f8ba6f6
Binary files /dev/null and b/diffusers/utils/__pycache__/dummy_torch_and_torchsde_objects.cpython-310.pyc differ
diff --git a/diffusers/utils/__pycache__/dummy_torch_and_torchsde_objects.cpython-38.pyc b/diffusers/utils/__pycache__/dummy_torch_and_torchsde_objects.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..442526b89d7d4aeaaaa2552df0c427132070b18d
Binary files /dev/null and b/diffusers/utils/__pycache__/dummy_torch_and_torchsde_objects.cpython-38.pyc differ
diff --git a/diffusers/utils/__pycache__/dummy_torch_and_transformers_and_invisible_watermark_objects.cpython-38.pyc b/diffusers/utils/__pycache__/dummy_torch_and_transformers_and_invisible_watermark_objects.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a028ce8b900009303ca9eae3e8023a986bf97519
Binary files /dev/null and b/diffusers/utils/__pycache__/dummy_torch_and_transformers_and_invisible_watermark_objects.cpython-38.pyc differ
diff --git a/diffusers/utils/__pycache__/dummy_torch_and_transformers_and_k_diffusion_objects.cpython-310.pyc b/diffusers/utils/__pycache__/dummy_torch_and_transformers_and_k_diffusion_objects.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8e20efeffe6da968d906a44cae8a341ca1aceb0b
Binary files /dev/null and b/diffusers/utils/__pycache__/dummy_torch_and_transformers_and_k_diffusion_objects.cpython-310.pyc differ
diff --git a/diffusers/utils/__pycache__/dummy_torch_and_transformers_and_k_diffusion_objects.cpython-38.pyc b/diffusers/utils/__pycache__/dummy_torch_and_transformers_and_k_diffusion_objects.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d88c82928776cf41bc2c36f72939a665ac2c4b3f
Binary files /dev/null and b/diffusers/utils/__pycache__/dummy_torch_and_transformers_and_k_diffusion_objects.cpython-38.pyc differ
diff --git a/diffusers/utils/__pycache__/dummy_torch_and_transformers_and_onnx_objects.cpython-310.pyc b/diffusers/utils/__pycache__/dummy_torch_and_transformers_and_onnx_objects.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..987ccbaa93fdf4a23aaeaa1d1521d7c7c96e83ef
Binary files /dev/null and b/diffusers/utils/__pycache__/dummy_torch_and_transformers_and_onnx_objects.cpython-310.pyc differ
diff --git a/diffusers/utils/__pycache__/dummy_torch_and_transformers_and_onnx_objects.cpython-38.pyc b/diffusers/utils/__pycache__/dummy_torch_and_transformers_and_onnx_objects.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e35e45dabcf2437f7800bebd4d6d781da9c367d1
Binary files /dev/null and b/diffusers/utils/__pycache__/dummy_torch_and_transformers_and_onnx_objects.cpython-38.pyc differ
diff --git a/diffusers/utils/__pycache__/dummy_torch_and_transformers_objects.cpython-310.pyc b/diffusers/utils/__pycache__/dummy_torch_and_transformers_objects.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..157e64060d57b362f1f6013262a64f8bf1bd37bc
Binary files /dev/null and b/diffusers/utils/__pycache__/dummy_torch_and_transformers_objects.cpython-310.pyc differ
diff --git a/diffusers/utils/__pycache__/dummy_transformers_and_torch_and_note_seq_objects.cpython-310.pyc b/diffusers/utils/__pycache__/dummy_transformers_and_torch_and_note_seq_objects.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fe8d329eec939b64e5ffeca2a939beb5e3d8d990
Binary files /dev/null and b/diffusers/utils/__pycache__/dummy_transformers_and_torch_and_note_seq_objects.cpython-310.pyc differ
diff --git a/diffusers/utils/__pycache__/dummy_transformers_and_torch_and_note_seq_objects.cpython-38.pyc b/diffusers/utils/__pycache__/dummy_transformers_and_torch_and_note_seq_objects.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..caf5da69df96565b1a708a718e2c1a8de07970ab
Binary files /dev/null and b/diffusers/utils/__pycache__/dummy_transformers_and_torch_and_note_seq_objects.cpython-38.pyc differ
diff --git a/diffusers/utils/__pycache__/dynamic_modules_utils.cpython-310.pyc b/diffusers/utils/__pycache__/dynamic_modules_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2d7927f515b35ab0785380733985c90ff47ed3e1
Binary files /dev/null and b/diffusers/utils/__pycache__/dynamic_modules_utils.cpython-310.pyc differ
diff --git a/diffusers/utils/__pycache__/dynamic_modules_utils.cpython-38.pyc b/diffusers/utils/__pycache__/dynamic_modules_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..60c4b66fe6c3b949f0d3add87a4a13cb42aa5a5d
Binary files /dev/null and b/diffusers/utils/__pycache__/dynamic_modules_utils.cpython-38.pyc differ
diff --git a/diffusers/utils/__pycache__/hub_utils.cpython-310.pyc b/diffusers/utils/__pycache__/hub_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..509d816b91cdcaa9f8293c73d938994d7b2d4442
Binary files /dev/null and b/diffusers/utils/__pycache__/hub_utils.cpython-310.pyc differ
diff --git a/diffusers/utils/__pycache__/hub_utils.cpython-38.pyc b/diffusers/utils/__pycache__/hub_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e2d85317db4cb3fb29fce9981f3c34c4940a039b
Binary files /dev/null and b/diffusers/utils/__pycache__/hub_utils.cpython-38.pyc differ
diff --git a/diffusers/utils/__pycache__/import_utils.cpython-310.pyc b/diffusers/utils/__pycache__/import_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..74b8269dff3936b8b5f441bc8e02b70ca2a12363
Binary files /dev/null and b/diffusers/utils/__pycache__/import_utils.cpython-310.pyc differ
diff --git a/diffusers/utils/__pycache__/import_utils.cpython-38.pyc b/diffusers/utils/__pycache__/import_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f8eac94029ce0f02690721de3e96a431bdb422ff
Binary files /dev/null and b/diffusers/utils/__pycache__/import_utils.cpython-38.pyc differ
diff --git a/diffusers/utils/__pycache__/logging.cpython-310.pyc b/diffusers/utils/__pycache__/logging.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8f071c47dcbc9316b467b4010e4fe33015f95d93
Binary files /dev/null and b/diffusers/utils/__pycache__/logging.cpython-310.pyc differ
diff --git a/diffusers/utils/__pycache__/logging.cpython-38.pyc b/diffusers/utils/__pycache__/logging.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0e29dc506db303bf13c90171387c3f119f0fb7f3
Binary files /dev/null and b/diffusers/utils/__pycache__/logging.cpython-38.pyc differ
diff --git a/diffusers/utils/__pycache__/outputs.cpython-310.pyc b/diffusers/utils/__pycache__/outputs.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c2acb989510b087333ed6bc924e0df165bfe1ae1
Binary files /dev/null and b/diffusers/utils/__pycache__/outputs.cpython-310.pyc differ
diff --git a/diffusers/utils/__pycache__/outputs.cpython-38.pyc b/diffusers/utils/__pycache__/outputs.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8df1d6a0faf156add97ecd8c7238f40d886d7982
Binary files /dev/null and b/diffusers/utils/__pycache__/outputs.cpython-38.pyc differ
diff --git a/diffusers/utils/__pycache__/pil_utils.cpython-310.pyc b/diffusers/utils/__pycache__/pil_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..05b64bee2f2ab9152d2ce1b1018af376553ceedc
Binary files /dev/null and b/diffusers/utils/__pycache__/pil_utils.cpython-310.pyc differ
diff --git a/diffusers/utils/__pycache__/pil_utils.cpython-38.pyc b/diffusers/utils/__pycache__/pil_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4b655c24119aab52e8b361b6084c91ce8f5b9926
Binary files /dev/null and b/diffusers/utils/__pycache__/pil_utils.cpython-38.pyc differ
diff --git a/diffusers/utils/__pycache__/testing_utils.cpython-310.pyc b/diffusers/utils/__pycache__/testing_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d703eef3788601acccc5c34c642d16d51fc875db
Binary files /dev/null and b/diffusers/utils/__pycache__/testing_utils.cpython-310.pyc differ
diff --git a/diffusers/utils/__pycache__/testing_utils.cpython-38.pyc b/diffusers/utils/__pycache__/testing_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1fa118df35a47eb9424e41a7ee687586ea5adb65
Binary files /dev/null and b/diffusers/utils/__pycache__/testing_utils.cpython-38.pyc differ
diff --git a/diffusers/utils/__pycache__/torch_utils.cpython-310.pyc b/diffusers/utils/__pycache__/torch_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..79e119de6e598661628ff58b395c8e215d096f97
Binary files /dev/null and b/diffusers/utils/__pycache__/torch_utils.cpython-310.pyc differ
diff --git a/diffusers/utils/__pycache__/torch_utils.cpython-38.pyc b/diffusers/utils/__pycache__/torch_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e907758a309392e44565a1f44f19a7726f99a833
Binary files /dev/null and b/diffusers/utils/__pycache__/torch_utils.cpython-38.pyc differ
diff --git a/diffusers/utils/accelerate_utils.py b/diffusers/utils/accelerate_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..10a83e1dd209cca198f4038d0d7e7228f9671859
--- /dev/null
+++ b/diffusers/utils/accelerate_utils.py
@@ -0,0 +1,48 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Accelerate utilities: Utilities related to accelerate
+"""
+
+from packaging import version
+
+from .import_utils import is_accelerate_available
+
+
+if is_accelerate_available():
+ import accelerate
+
+
+def apply_forward_hook(method):
+ """
+ Decorator that applies a registered CpuOffload hook to an arbitrary function rather than `forward`. This is useful
+ for cases where a PyTorch module provides functions other than `forward` that should trigger a move to the
+ appropriate acceleration device. This is the case for `encode` and `decode` in [`AutoencoderKL`].
+
+ This decorator looks inside the internal `_hf_hook` property to find a registered offload hook.
+
+ :param method: The method to decorate. This method should be a method of a PyTorch module.
+ """
+ if not is_accelerate_available():
+ return method
+ accelerate_version = version.parse(accelerate.__version__).base_version
+ if version.parse(accelerate_version) < version.parse("0.17.0"):
+ return method
+
+ def wrapper(self, *args, **kwargs):
+ if hasattr(self, "_hf_hook") and hasattr(self._hf_hook, "pre_forward"):
+ self._hf_hook.pre_forward(self)
+ return method(self, *args, **kwargs)
+
+ return wrapper
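+
+
+# Usage sketch (illustrative; `MyAutoencoder` is a hypothetical module, not part of this diff):
+#
+#     class MyAutoencoder(torch.nn.Module):
+#         @apply_forward_hook
+#         def encode(self, x):
+#             ...
+#
+# When accelerate's CPU-offload hook is attached to the module, calling `encode` first runs
+# `_hf_hook.pre_forward(self)` to move the weights to the execution device, just like `forward`.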
diff --git a/diffusers/utils/constants.py b/diffusers/utils/constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9e60a2a873b29a7d3adffbd7179be1670b3b417
--- /dev/null
+++ b/diffusers/utils/constants.py
@@ -0,0 +1,32 @@
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+
+from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE, hf_cache_home
+
+
+default_cache_path = HUGGINGFACE_HUB_CACHE
+
+
+CONFIG_NAME = "config.json"
+WEIGHTS_NAME = "diffusion_pytorch_model.bin"
+FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack"
+ONNX_WEIGHTS_NAME = "model.onnx"
+SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors"
+ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb"
+HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
+DIFFUSERS_CACHE = default_cache_path
+DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
+HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
+DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
diff --git a/diffusers/utils/deprecation_utils.py b/diffusers/utils/deprecation_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f482deddd2f46b8d2e29d5229faa0e9a21f2fd98
--- /dev/null
+++ b/diffusers/utils/deprecation_utils.py
@@ -0,0 +1,49 @@
+import inspect
+import warnings
+from typing import Any, Dict, Optional, Union
+
+from packaging import version
+
+
+def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn=True, stacklevel=2):
+ from .. import __version__
+
+ deprecated_kwargs = take_from
+ values = ()
+ if not isinstance(args[0], tuple):
+ args = (args,)
+
+ for attribute, version_name, message in args:
+ if version.parse(version.parse(__version__).base_version) >= version.parse(version_name):
+ raise ValueError(
+ f"The deprecation tuple {(attribute, version_name, message)} should be removed since diffusers'"
+ f" version {__version__} is >= {version_name}"
+ )
+
+ warning = None
+ if isinstance(deprecated_kwargs, dict) and attribute in deprecated_kwargs:
+ values += (deprecated_kwargs.pop(attribute),)
+ warning = f"The `{attribute}` argument is deprecated and will be removed in version {version_name}."
+ elif hasattr(deprecated_kwargs, attribute):
+ values += (getattr(deprecated_kwargs, attribute),)
+ warning = f"The `{attribute}` attribute is deprecated and will be removed in version {version_name}."
+ elif deprecated_kwargs is None:
+ warning = f"`{attribute}` is deprecated and will be removed in version {version_name}."
+
+ if warning is not None:
+ warning = warning + " " if standard_warn else ""
+ warnings.warn(warning + message, FutureWarning, stacklevel=stacklevel)
+
+ if isinstance(deprecated_kwargs, dict) and len(deprecated_kwargs) > 0:
+ call_frame = inspect.getouterframes(inspect.currentframe())[1]
+ filename = call_frame.filename
+ line_number = call_frame.lineno
+ function = call_frame.function
+ key, value = next(iter(deprecated_kwargs.items()))
+ raise TypeError(f"{function} in {filename} line {line_number-1} got an unexpected keyword argument `{key}`")
+
+ if len(values) == 0:
+ return
+ elif len(values) == 1:
+ return values[0]
+ return values
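+
+
+# Usage sketch (illustrative; the function and argument names below are hypothetical):
+#
+#     def resize(image, size=None, **kwargs):
+#         old_size = deprecate("old_size", "1.0.0", "Please use `size` instead.", take_from=kwargs)
+#         if old_size is not None:
+#             size = old_size
+#         ...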
diff --git a/diffusers/utils/doc_utils.py b/diffusers/utils/doc_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1f87743f99802931334bd51bf99985775116d59
--- /dev/null
+++ b/diffusers/utils/doc_utils.py
@@ -0,0 +1,38 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Doc utilities: Utilities related to documentation
+"""
+import re
+
+
+def replace_example_docstring(example_docstring):
+ def docstring_decorator(fn):
+ func_doc = fn.__doc__
+ lines = func_doc.split("\n")
+ i = 0
+ while i < len(lines) and re.search(r"^\s*Examples?:\s*$", lines[i]) is None:
+ i += 1
+ if i < len(lines):
+ lines[i] = example_docstring
+ func_doc = "\n".join(lines)
+ else:
+ raise ValueError(
+ f"The function {fn} should have an empty 'Examples:' in its docstring as placeholder, "
+ f"current docstring is:\n{func_doc}"
+ )
+ fn.__doc__ = func_doc
+ return fn
+
+ return docstring_decorator
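+
+
+# Usage sketch (illustrative): the decorated function must have a bare `Examples:` line in its
+# docstring, which this decorator replaces with the supplied example block:
+#
+#     EXAMPLE_DOC_STRING = "        Examples:\n            >>> pipe = MyPipeline.from_pretrained(...)"
+#
+#     @replace_example_docstring(EXAMPLE_DOC_STRING)
+#     def __call__(self, prompt):
+#         '''
+#         Generate an image from `prompt`.
+#
+#         Examples:
+#         '''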
diff --git a/diffusers/utils/dummy_flax_and_transformers_objects.py b/diffusers/utils/dummy_flax_and_transformers_objects.py
new file mode 100644
index 0000000000000000000000000000000000000000..162bac1c4331149c4b5abde1eadd8013ab0cda99
--- /dev/null
+++ b/diffusers/utils/dummy_flax_and_transformers_objects.py
@@ -0,0 +1,62 @@
+# This file is autogenerated by the command `make fix-copies`, do not edit.
+from ..utils import DummyObject, requires_backends
+
+
+class FlaxStableDiffusionControlNetPipeline(metaclass=DummyObject):
+ _backends = ["flax", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["flax", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["flax", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["flax", "transformers"])
+
+
+class FlaxStableDiffusionImg2ImgPipeline(metaclass=DummyObject):
+ _backends = ["flax", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["flax", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["flax", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["flax", "transformers"])
+
+
+class FlaxStableDiffusionInpaintPipeline(metaclass=DummyObject):
+ _backends = ["flax", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["flax", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["flax", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["flax", "transformers"])
+
+
+class FlaxStableDiffusionPipeline(metaclass=DummyObject):
+ _backends = ["flax", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["flax", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["flax", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["flax", "transformers"])
diff --git a/diffusers/utils/dummy_flax_objects.py b/diffusers/utils/dummy_flax_objects.py
new file mode 100644
index 0000000000000000000000000000000000000000..2bb80d136f338d193c67773266355956afd1d98a
--- /dev/null
+++ b/diffusers/utils/dummy_flax_objects.py
@@ -0,0 +1,197 @@
+# This file is autogenerated by the command `make fix-copies`, do not edit.
+from ..utils import DummyObject, requires_backends
+
+
+class FlaxControlNetModel(metaclass=DummyObject):
+ _backends = ["flax"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["flax"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["flax"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["flax"])
+
+
+class FlaxModelMixin(metaclass=DummyObject):
+ _backends = ["flax"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["flax"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["flax"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["flax"])
+
+
+class FlaxUNet2DConditionModel(metaclass=DummyObject):
+ _backends = ["flax"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["flax"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["flax"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["flax"])
+
+
+class FlaxAutoencoderKL(metaclass=DummyObject):
+ _backends = ["flax"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["flax"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["flax"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["flax"])
+
+
+class FlaxDiffusionPipeline(metaclass=DummyObject):
+ _backends = ["flax"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["flax"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["flax"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["flax"])
+
+
+class FlaxDDIMScheduler(metaclass=DummyObject):
+ _backends = ["flax"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["flax"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["flax"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["flax"])
+
+
+class FlaxDDPMScheduler(metaclass=DummyObject):
+ _backends = ["flax"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["flax"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["flax"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["flax"])
+
+
+class FlaxDPMSolverMultistepScheduler(metaclass=DummyObject):
+ _backends = ["flax"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["flax"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["flax"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["flax"])
+
+
+class FlaxKarrasVeScheduler(metaclass=DummyObject):
+ _backends = ["flax"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["flax"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["flax"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["flax"])
+
+
+class FlaxLMSDiscreteScheduler(metaclass=DummyObject):
+ _backends = ["flax"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["flax"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["flax"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["flax"])
+
+
+class FlaxPNDMScheduler(metaclass=DummyObject):
+ _backends = ["flax"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["flax"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["flax"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["flax"])
+
+
+class FlaxSchedulerMixin(metaclass=DummyObject):
+ _backends = ["flax"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["flax"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["flax"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["flax"])
+
+
+class FlaxScoreSdeVeScheduler(metaclass=DummyObject):
+ _backends = ["flax"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["flax"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["flax"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["flax"])
diff --git a/diffusers/utils/dummy_note_seq_objects.py b/diffusers/utils/dummy_note_seq_objects.py
new file mode 100644
index 0000000000000000000000000000000000000000..c02d0b015aedc37c01fb3b843bc79547aae5da68
--- /dev/null
+++ b/diffusers/utils/dummy_note_seq_objects.py
@@ -0,0 +1,17 @@
+# This file is autogenerated by the command `make fix-copies`, do not edit.
+from ..utils import DummyObject, requires_backends
+
+
+class MidiProcessor(metaclass=DummyObject):
+ _backends = ["note_seq"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["note_seq"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["note_seq"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["note_seq"])
diff --git a/diffusers/utils/dummy_onnx_objects.py b/diffusers/utils/dummy_onnx_objects.py
new file mode 100644
index 0000000000000000000000000000000000000000..bde5f6ad0793e2d81bc638600b46ff81748d09ee
--- /dev/null
+++ b/diffusers/utils/dummy_onnx_objects.py
@@ -0,0 +1,17 @@
+# This file is autogenerated by the command `make fix-copies`, do not edit.
+from ..utils import DummyObject, requires_backends
+
+
+class OnnxRuntimeModel(metaclass=DummyObject):
+ _backends = ["onnx"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["onnx"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["onnx"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["onnx"])
diff --git a/diffusers/utils/dummy_pt_objects.py b/diffusers/utils/dummy_pt_objects.py
new file mode 100644
index 0000000000000000000000000000000000000000..b955ec5320de4973a4e8eaf2f039953189d83228
--- /dev/null
+++ b/diffusers/utils/dummy_pt_objects.py
@@ -0,0 +1,810 @@
+# This file is autogenerated by the command `make fix-copies`, do not edit.
+from ..utils import DummyObject, requires_backends
+
+
+class AutoencoderKL(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class ControlNetModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class ModelMixin(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class MultiAdapter(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class PriorTransformer(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class T2IAdapter(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class T5FilmDecoder(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class Transformer2DModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class UNet1DModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class UNet2DConditionModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class UNet2DModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class UNet3DConditionModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class VQModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+def get_constant_schedule(*args, **kwargs):
+ requires_backends(get_constant_schedule, ["torch"])
+
+
+def get_constant_schedule_with_warmup(*args, **kwargs):
+ requires_backends(get_constant_schedule_with_warmup, ["torch"])
+
+
+def get_cosine_schedule_with_warmup(*args, **kwargs):
+ requires_backends(get_cosine_schedule_with_warmup, ["torch"])
+
+
+def get_cosine_with_hard_restarts_schedule_with_warmup(*args, **kwargs):
+ requires_backends(get_cosine_with_hard_restarts_schedule_with_warmup, ["torch"])
+
+
+def get_linear_schedule_with_warmup(*args, **kwargs):
+ requires_backends(get_linear_schedule_with_warmup, ["torch"])
+
+
+def get_polynomial_decay_schedule_with_warmup(*args, **kwargs):
+ requires_backends(get_polynomial_decay_schedule_with_warmup, ["torch"])
+
+
+def get_scheduler(*args, **kwargs):
+ requires_backends(get_scheduler, ["torch"])
+
+
+class AudioPipelineOutput(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class ConsistencyModelPipeline(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class DanceDiffusionPipeline(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class DDIMPipeline(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class DDPMPipeline(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class DiffusionPipeline(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class DiTPipeline(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class ImagePipelineOutput(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class KarrasVePipeline(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class LDMPipeline(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class LDMSuperResolutionPipeline(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class PNDMPipeline(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class RePaintPipeline(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class ScoreSdeVePipeline(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class CMStochasticIterativeScheduler(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class DDIMInverseScheduler(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class DDIMParallelScheduler(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class DDIMScheduler(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class DDPMParallelScheduler(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class DDPMScheduler(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class DEISMultistepScheduler(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class DPMSolverMultistepInverseScheduler(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class DPMSolverMultistepScheduler(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class DPMSolverSinglestepScheduler(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class EulerAncestralDiscreteScheduler(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class EulerDiscreteScheduler(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class HeunDiscreteScheduler(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class IPNDMScheduler(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class KarrasVeScheduler(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class KDPM2AncestralDiscreteScheduler(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class KDPM2DiscreteScheduler(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class PNDMScheduler(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class RePaintScheduler(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class SchedulerMixin(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class ScoreSdeVeScheduler(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class UnCLIPScheduler(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class UniPCMultistepScheduler(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class VQDiffusionScheduler(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class EMAModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
diff --git a/diffusers/utils/dummy_torch_and_librosa_objects.py b/diffusers/utils/dummy_torch_and_librosa_objects.py
new file mode 100644
index 0000000000000000000000000000000000000000..2088bc4a744198284f22fe54e6f1055cf3568566
--- /dev/null
+++ b/diffusers/utils/dummy_torch_and_librosa_objects.py
@@ -0,0 +1,32 @@
+# This file is autogenerated by the command `make fix-copies`, do not edit.
+from ..utils import DummyObject, requires_backends
+
+
+class AudioDiffusionPipeline(metaclass=DummyObject):
+ _backends = ["torch", "librosa"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "librosa"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "librosa"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "librosa"])
+
+
+class Mel(metaclass=DummyObject):
+ _backends = ["torch", "librosa"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "librosa"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "librosa"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "librosa"])
diff --git a/diffusers/utils/dummy_torch_and_scipy_objects.py b/diffusers/utils/dummy_torch_and_scipy_objects.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1ff25863822b04971d2c6dfdc17f5b28774cf05
--- /dev/null
+++ b/diffusers/utils/dummy_torch_and_scipy_objects.py
@@ -0,0 +1,17 @@
+# This file is autogenerated by the command `make fix-copies`, do not edit.
+from ..utils import DummyObject, requires_backends
+
+
+class LMSDiscreteScheduler(metaclass=DummyObject):
+ _backends = ["torch", "scipy"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "scipy"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "scipy"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "scipy"])
diff --git a/diffusers/utils/dummy_torch_and_torchsde_objects.py b/diffusers/utils/dummy_torch_and_torchsde_objects.py
new file mode 100644
index 0000000000000000000000000000000000000000..a81bbb316f32267c31b06598519f1eef9ddde643
--- /dev/null
+++ b/diffusers/utils/dummy_torch_and_torchsde_objects.py
@@ -0,0 +1,17 @@
+# This file is autogenerated by the command `make fix-copies`, do not edit.
+from ..utils import DummyObject, requires_backends
+
+
+class DPMSolverSDEScheduler(metaclass=DummyObject):
+ _backends = ["torch", "torchsde"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "torchsde"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "torchsde"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "torchsde"])
diff --git a/diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py b/diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b4d59c537fa3dbaa76699836d9024823c4e36bd
--- /dev/null
+++ b/diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py
@@ -0,0 +1,62 @@
+# This file is autogenerated by the command `make fix-copies`, do not edit.
+from ..utils import DummyObject, requires_backends
+
+
+class StableDiffusionXLControlNetPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers", "invisible_watermark"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers", "invisible_watermark"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers", "invisible_watermark"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers", "invisible_watermark"])
+
+
+class StableDiffusionXLImg2ImgPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers", "invisible_watermark"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers", "invisible_watermark"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers", "invisible_watermark"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers", "invisible_watermark"])
+
+
+class StableDiffusionXLInpaintPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers", "invisible_watermark"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers", "invisible_watermark"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers", "invisible_watermark"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers", "invisible_watermark"])
+
+
+class StableDiffusionXLPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers", "invisible_watermark"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers", "invisible_watermark"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers", "invisible_watermark"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers", "invisible_watermark"])
diff --git a/diffusers/utils/dummy_torch_and_transformers_and_k_diffusion_objects.py b/diffusers/utils/dummy_torch_and_transformers_and_k_diffusion_objects.py
new file mode 100644
index 0000000000000000000000000000000000000000..56836f0b6d77b8daa25e956101694863e418339f
--- /dev/null
+++ b/diffusers/utils/dummy_torch_and_transformers_and_k_diffusion_objects.py
@@ -0,0 +1,17 @@
+# This file is autogenerated by the command `make fix-copies`, do not edit.
+from ..utils import DummyObject, requires_backends
+
+
+class StableDiffusionKDiffusionPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers", "k_diffusion"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers", "k_diffusion"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers", "k_diffusion"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers", "k_diffusion"])
diff --git a/diffusers/utils/dummy_torch_and_transformers_and_onnx_objects.py b/diffusers/utils/dummy_torch_and_transformers_and_onnx_objects.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7afad8226b87292100270e3e7daad6885be0e7f
--- /dev/null
+++ b/diffusers/utils/dummy_torch_and_transformers_and_onnx_objects.py
@@ -0,0 +1,92 @@
+# This file is autogenerated by the command `make fix-copies`, do not edit.
+from ..utils import DummyObject, requires_backends
+
+
+class OnnxStableDiffusionImg2ImgPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers", "onnx"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers", "onnx"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers", "onnx"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers", "onnx"])
+
+
+class OnnxStableDiffusionInpaintPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers", "onnx"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers", "onnx"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers", "onnx"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers", "onnx"])
+
+
+class OnnxStableDiffusionInpaintPipelineLegacy(metaclass=DummyObject):
+ _backends = ["torch", "transformers", "onnx"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers", "onnx"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers", "onnx"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers", "onnx"])
+
+
+class OnnxStableDiffusionPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers", "onnx"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers", "onnx"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers", "onnx"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers", "onnx"])
+
+
+class OnnxStableDiffusionUpscalePipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers", "onnx"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers", "onnx"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers", "onnx"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers", "onnx"])
+
+
+class StableDiffusionOnnxPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers", "onnx"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers", "onnx"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers", "onnx"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers", "onnx"])
diff --git a/diffusers/utils/dummy_torch_and_transformers_objects.py b/diffusers/utils/dummy_torch_and_transformers_objects.py
new file mode 100644
index 0000000000000000000000000000000000000000..016760337c696890e1ec7033c00129d195d8f6a3
--- /dev/null
+++ b/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -0,0 +1,962 @@
+# This file is autogenerated by the command `make fix-copies`, do not edit.
+from ..utils import DummyObject, requires_backends
+
+
+class AltDiffusionImg2ImgPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class AltDiffusionPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class AudioLDMPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class CycleDiffusionPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class IFImg2ImgPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class IFImg2ImgSuperResolutionPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class IFInpaintingPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class IFInpaintingSuperResolutionPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class IFPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class IFSuperResolutionPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class ImageTextPipelineOutput(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class KandinskyImg2ImgPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class KandinskyInpaintPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class KandinskyPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class KandinskyPriorPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class KandinskyV22ControlnetImg2ImgPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class KandinskyV22ControlnetPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class KandinskyV22Img2ImgPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class KandinskyV22InpaintPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class KandinskyV22Pipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class KandinskyV22PriorEmb2EmbPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class KandinskyV22PriorPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class LDMTextToImagePipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class PaintByExamplePipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class SemanticStableDiffusionPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class ShapEImg2ImgPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class ShapEPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class StableDiffusionAdapterPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class StableDiffusionAttendAndExcitePipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class StableDiffusionControlNetImg2ImgPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class StableDiffusionControlNetInpaintPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class StableDiffusionControlNetPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class StableDiffusionDepth2ImgPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class StableDiffusionDiffEditPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class StableDiffusionImageVariationPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class StableDiffusionImg2ImgPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class StableDiffusionInpaintPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class StableDiffusionInpaintPipelineLegacy(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class StableDiffusionInstructPix2PixPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class StableDiffusionLatentUpscalePipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class StableDiffusionLDM3DPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class StableDiffusionModelEditingPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class StableDiffusionPanoramaPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class StableDiffusionParadigmsPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class StableDiffusionPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class StableDiffusionPipelineSafe(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class StableDiffusionPix2PixZeroPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class StableDiffusionSAGPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class StableDiffusionUpscalePipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class StableUnCLIPImg2ImgPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class StableUnCLIPPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class TextToVideoSDPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class TextToVideoZeroPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class UnCLIPImageVariationPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class UnCLIPPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class UniDiffuserModel(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class UniDiffuserPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class UniDiffuserTextDecoder(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class VersatileDiffusionDualGuidedPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class VersatileDiffusionImageVariationPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class VersatileDiffusionPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class VersatileDiffusionTextToImagePipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class VideoToVideoSDPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class VQDiffusionPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
diff --git a/diffusers/utils/dummy_transformers_and_torch_and_note_seq_objects.py b/diffusers/utils/dummy_transformers_and_torch_and_note_seq_objects.py
new file mode 100644
index 0000000000000000000000000000000000000000..fbde04e33f0abd86d12f3dee048a4f0585c9f19d
--- /dev/null
+++ b/diffusers/utils/dummy_transformers_and_torch_and_note_seq_objects.py
@@ -0,0 +1,17 @@
+# This file is autogenerated by the command `make fix-copies`, do not edit.
+from ..utils import DummyObject, requires_backends
+
+
+class SpectrogramDiffusionPipeline(metaclass=DummyObject):
+ _backends = ["transformers", "torch", "note_seq"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["transformers", "torch", "note_seq"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["transformers", "torch", "note_seq"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["transformers", "torch", "note_seq"])
diff --git a/diffusers/utils/dynamic_modules_utils.py b/diffusers/utils/dynamic_modules_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b0952f0b514cb52e63fdac8a780ddc9482a5b9d
--- /dev/null
+++ b/diffusers/utils/dynamic_modules_utils.py
@@ -0,0 +1,456 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Utilities to dynamically load objects from the Hub."""
+
+import importlib
+import inspect
+import json
+import os
+import re
+import shutil
+import sys
+from pathlib import Path
+from typing import Dict, Optional, Union
+from urllib import request
+
+from huggingface_hub import HfFolder, cached_download, hf_hub_download, model_info
+from packaging import version
+
+from .. import __version__
+from . import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging
+
+
+COMMUNITY_PIPELINES_URL = (
+ "https://raw.githubusercontent.com/huggingface/diffusers/{revision}/examples/community/{pipeline}.py"
+)
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def get_diffusers_versions():
+ url = "https://pypi.org/pypi/diffusers/json"
+ releases = json.loads(request.urlopen(url).read())["releases"].keys()
+ return sorted(releases, key=lambda x: version.Version(x))
+
+
+def init_hf_modules():
+ """
+ Creates the cache directory for modules with an init, and adds it to the Python path.
+ """
+ # This function has already been executed if HF_MODULES_CACHE already is in the Python path.
+ if HF_MODULES_CACHE in sys.path:
+ return
+
+ sys.path.append(HF_MODULES_CACHE)
+ os.makedirs(HF_MODULES_CACHE, exist_ok=True)
+ init_path = Path(HF_MODULES_CACHE) / "__init__.py"
+ if not init_path.exists():
+ init_path.touch()
+
+
+def create_dynamic_module(name: Union[str, os.PathLike]):
+ """
+ Creates a dynamic module in the cache directory for modules.
+ """
+ init_hf_modules()
+ dynamic_module_path = Path(HF_MODULES_CACHE) / name
+ # If the parent module does not exist yet, recursively create it.
+ if not dynamic_module_path.parent.exists():
+ create_dynamic_module(dynamic_module_path.parent)
+ os.makedirs(dynamic_module_path, exist_ok=True)
+ init_path = dynamic_module_path / "__init__.py"
+ if not init_path.exists():
+ init_path.touch()
+
+
+def get_relative_imports(module_file):
+ """
+ Get the list of modules that are relatively imported in a module file.
+
+ Args:
+ module_file (`str` or `os.PathLike`): The module file to inspect.
+ """
+ with open(module_file, "r", encoding="utf-8") as f:
+ content = f.read()
+
+ # Imports of the form `import .xxx`
+ relative_imports = re.findall(r"^\s*import\s+\.(\S+)\s*$", content, flags=re.MULTILINE)
+ # Imports of the form `from .xxx import yyy`
+ relative_imports += re.findall(r"^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE)
+ # Unique-ify
+ return list(set(relative_imports))
+
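+# Illustrative, doctest-style sketch (not part of the patched module): what
+# `get_relative_imports` returns for a hypothetical file "my_pipeline.py" whose
+# only relative import is the line `from .pipeline_helpers import rescale_latents`.
+#
+#   >>> get_relative_imports("my_pipeline.py")
+#   ['pipeline_helpers']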
+
+def get_relative_import_files(module_file):
+ """
+ Get the list of all files that are needed for a given module. Note that this function recurses through the relative
+ imports (if a imports b and b imports c, it will return module files for b and c).
+
+ Args:
+ module_file (`str` or `os.PathLike`): The module file to inspect.
+ """
+ no_change = False
+ files_to_check = [module_file]
+ all_relative_imports = []
+
+ # Let's recurse through all relative imports
+ while not no_change:
+ new_imports = []
+ for f in files_to_check:
+ new_imports.extend(get_relative_imports(f))
+
+ module_path = Path(module_file).parent
+ new_import_files = [str(module_path / m) for m in new_imports]
+ new_import_files = [f for f in new_import_files if f not in all_relative_imports]
+ files_to_check = [f"{f}.py" for f in new_import_files]
+
+ no_change = len(new_import_files) == 0
+ all_relative_imports.extend(files_to_check)
+
+ return all_relative_imports
+
+
+def check_imports(filename):
+ """
+ Check if the current Python environment contains all the libraries that are imported in a file.
+ """
+ with open(filename, "r", encoding="utf-8") as f:
+ content = f.read()
+
+ # Imports of the form `import xxx`
+ imports = re.findall(r"^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE)
+ # Imports of the form `from xxx import yyy`
+ imports += re.findall(r"^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE)
+ # Only keep the top-level module
+ imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")]
+
+ # Unique-ify and test we got them all
+ imports = list(set(imports))
+ missing_packages = []
+ for imp in imports:
+ try:
+ importlib.import_module(imp)
+ except ImportError:
+ missing_packages.append(imp)
+
+ if len(missing_packages) > 0:
+ raise ImportError(
+ "This modeling file requires the following packages that were not found in your environment: "
+ f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`"
+ )
+
+ return get_relative_imports(filename)
+
+
+def get_class_in_module(class_name, module_path):
+ """
+ Import a module on the cache directory for modules and extract a class from it.
+ """
+ module_path = module_path.replace(os.path.sep, ".")
+ module = importlib.import_module(module_path)
+
+ if class_name is None:
+ return find_pipeline_class(module)
+ return getattr(module, class_name)
+
+
+def find_pipeline_class(loaded_module):
+ """
+ Retrieve the pipeline class that inherits from `DiffusionPipeline`. Note that there has to be exactly one class
+ inheriting from `DiffusionPipeline`.
+ """
+ from ..pipelines import DiffusionPipeline
+
+ cls_members = dict(inspect.getmembers(loaded_module, inspect.isclass))
+
+ pipeline_class = None
+ for cls_name, cls in cls_members.items():
+ if (
+ cls_name != DiffusionPipeline.__name__
+ and issubclass(cls, DiffusionPipeline)
+ and cls.__module__.split(".")[0] != "diffusers"
+ ):
+ if pipeline_class is not None:
+ raise ValueError(
+ f"Multiple classes that inherit from {DiffusionPipeline.__name__} have been found:"
+ f" {pipeline_class.__name__}, and {cls_name}. Please make sure to define only one in"
+ f" {loaded_module}."
+ )
+ pipeline_class = cls
+
+ return pipeline_class
+
+
+def get_cached_module_file(
+ pretrained_model_name_or_path: Union[str, os.PathLike],
+ module_file: str,
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
+ force_download: bool = False,
+ resume_download: bool = False,
+ proxies: Optional[Dict[str, str]] = None,
+ use_auth_token: Optional[Union[bool, str]] = None,
+ revision: Optional[str] = None,
+ local_files_only: bool = False,
+):
+ """
+ Downloads a module from a local folder or a distant repo and returns its path inside the dynamic modules
+ cache.
+
+ Args:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ This can be either:
+
+ - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
+ huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced
+ under a user or organization name, like `dbmdz/bert-base-german-cased`.
+ - a path to a *directory* containing a configuration file saved using the
+ [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
+
+ module_file (`str`):
+ The name of the module file containing the class to look for.
+ cache_dir (`str` or `os.PathLike`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
+ cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force to (re-)download the configuration files and override the cached versions if they
+ exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to resume downloading an incompletely received file instead of deleting it and starting over.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `huggingface-cli login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ If `True`, will only try to load the tokenizer configuration from local files.
+
+
+
+ You may pass a token in `use_auth_token` if you are not logged in (`huggingface-cli login`) and want to use private
+ or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models).
+
+
+
+ Returns:
+ `str`: The path to the module inside the cache.
+ """
+ # Download and cache module_file from the repo `pretrained_model_name_or_path`, or grab it if it's a local file.
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+
+ module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file)
+
+ if os.path.isfile(module_file_or_url):
+ resolved_module_file = module_file_or_url
+ submodule = "local"
+ elif pretrained_model_name_or_path.count("/") == 0:
+ available_versions = get_diffusers_versions()
+ # cut ".dev0"
+ latest_version = "v" + ".".join(__version__.split(".")[:3])
+
+ # retrieve github version that matches
+ if revision is None:
+ revision = latest_version if latest_version[1:] in available_versions else "main"
+ logger.info(f"Defaulting to latest_version: {revision}.")
+ elif revision in available_versions:
+ revision = f"v{revision}"
+ elif revision == "main":
+ revision = revision
+ else:
+ raise ValueError(
+ f"`custom_revision`: {revision} does not exist. Please make sure to choose one of"
+ f" {', '.join(available_versions + ['main'])}."
+ )
+
+ # community pipeline on GitHub
+ github_url = COMMUNITY_PIPELINES_URL.format(revision=revision, pipeline=pretrained_model_name_or_path)
+ try:
+ resolved_module_file = cached_download(
+ github_url,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ use_auth_token=False,
+ )
+ submodule = "git"
+ module_file = pretrained_model_name_or_path + ".py"
+ except EnvironmentError:
+ logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
+ raise
+ else:
+ try:
+ # Load from URL or cache if already cached
+ resolved_module_file = hf_hub_download(
+ pretrained_model_name_or_path,
+ module_file,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ )
+ submodule = os.path.join("local", "--".join(pretrained_model_name_or_path.split("/")))
+ except EnvironmentError:
+ logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
+ raise
+
+ # Check we have all the requirements in our environment
+ modules_needed = check_imports(resolved_module_file)
+
+ # Now we move the module inside our cached dynamic modules.
+ full_submodule = DIFFUSERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
+ create_dynamic_module(full_submodule)
+ submodule_path = Path(HF_MODULES_CACHE) / full_submodule
+ if submodule == "local" or submodule == "git":
+ # We always copy local files (we could hash the file to see if there was a change, and give them the name of
+ # that hash, to only copy when there is a modification but it seems overkill for now).
+ # The only reason we do the copy is to avoid putting too many folders in sys.path.
+ shutil.copy(resolved_module_file, submodule_path / module_file)
+ for module_needed in modules_needed:
+ module_needed = f"{module_needed}.py"
+ shutil.copy(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed)
+ else:
+ # Get the commit hash
+ # TODO: we will get this info in the etag soon, so retrieve it from there and not here.
+ if isinstance(use_auth_token, str):
+ token = use_auth_token
+ elif use_auth_token is True:
+ token = HfFolder.get_token()
+ else:
+ token = None
+
+ commit_hash = model_info(pretrained_model_name_or_path, revision=revision, token=token).sha
+
+ # The module file will end up being placed in a subfolder with the git hash of the repo. This way we get the
+ # benefit of versioning.
+ submodule_path = submodule_path / commit_hash
+ full_submodule = full_submodule + os.path.sep + commit_hash
+ create_dynamic_module(full_submodule)
+
+ if not (submodule_path / module_file).exists():
+ shutil.copy(resolved_module_file, submodule_path / module_file)
+ # Make sure we also have every file pulled in through relative imports
+ for module_needed in modules_needed:
+ if not (submodule_path / module_needed).exists():
+ get_cached_module_file(
+ pretrained_model_name_or_path,
+ f"{module_needed}.py",
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ local_files_only=local_files_only,
+ )
+ return os.path.join(full_submodule, module_file)
+
+
+def get_class_from_dynamic_module(
+ pretrained_model_name_or_path: Union[str, os.PathLike],
+ module_file: str,
+ class_name: Optional[str] = None,
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
+ force_download: bool = False,
+ resume_download: bool = False,
+ proxies: Optional[Dict[str, str]] = None,
+ use_auth_token: Optional[Union[bool, str]] = None,
+ revision: Optional[str] = None,
+ local_files_only: bool = False,
+ **kwargs,
+):
+ """
+ Extracts a class from a module file, present in the local folder or repository of a model.
+
+
+
+ Calling this function will execute the code in the module file found locally or downloaded from the Hub. It should
+ therefore only be called on trusted repos.
+
+
+
+ Args:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ This can be either:
+
+ - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
+ huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced
+ under a user or organization name, like `dbmdz/bert-base-german-cased`.
+ - a path to a *directory* containing a configuration file saved using the
+ [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
+
+ module_file (`str`):
+ The name of the module file containing the class to look for.
+ class_name (`str`, *optional*):
+ The name of the class to import in the module. If not provided, the single class inheriting from `DiffusionPipeline` is looked up automatically.
+ cache_dir (`str` or `os.PathLike`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
+ cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force to (re-)download the configuration files and override the cached versions if they
+ exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to resume downloading an incompletely received file instead of deleting it and starting over.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
+ use_auth_token (`str` or `bool`, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `huggingface-cli login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ If `True`, will only try to load the tokenizer configuration from local files.
+
+
+
+ You may pass a token in `use_auth_token` if you are not logged in (`huggingface-cli login`) and want to use private
+ or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models).
+
+
+
+ Returns:
+ `type`: The class, dynamically imported from the module.
+
+ Examples:
+
+ ```python
+ # Download module `modeling.py` from huggingface.co, cache it, then extract the class `MyBertModel` from this
+ # module.
+ cls = get_class_from_dynamic_module("sgugger/my-bert-model", "modeling.py", "MyBertModel")
+ ```"""
+ # And lastly we get the class inside our newly created module
+ final_module = get_cached_module_file(
+ pretrained_model_name_or_path,
+ module_file,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ local_files_only=local_files_only,
+ )
+ return get_class_in_module(class_name, final_module.replace(".py", ""))
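+# Illustrative sketch (not part of the patched module): resolving a community
+# pipeline hosted in the diffusers GitHub repo. "my_community_pipeline" is a
+# placeholder name; with `class_name=None`, the single class inheriting from
+# `DiffusionPipeline` in that file is picked up via `find_pipeline_class`.
+#
+#   >>> pipeline_cls = get_class_from_dynamic_module(  # doctest: +SKIP
+#   ...     "my_community_pipeline", "my_community_pipeline.py", class_name=None
+#   ... )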
diff --git a/diffusers/utils/hub_utils.py b/diffusers/utils/hub_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f0cf00a5c5d0d303ba53f62fbf027c0bc31ad49
--- /dev/null
+++ b/diffusers/utils/hub_utils.py
@@ -0,0 +1,361 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import re
+import sys
+import traceback
+import warnings
+from pathlib import Path
+from typing import Dict, Optional, Union
+from uuid import uuid4
+
+from huggingface_hub import HfFolder, ModelCard, ModelCardData, hf_hub_download, whoami
+from huggingface_hub.file_download import REGEX_COMMIT_HASH
+from huggingface_hub.utils import (
+ EntryNotFoundError,
+ RepositoryNotFoundError,
+ RevisionNotFoundError,
+ is_jinja_available,
+)
+from packaging import version
+from requests import HTTPError
+
+from .. import __version__
+from .constants import (
+ DEPRECATED_REVISION_ARGS,
+ DIFFUSERS_CACHE,
+ HUGGINGFACE_CO_RESOLVE_ENDPOINT,
+ SAFETENSORS_WEIGHTS_NAME,
+ WEIGHTS_NAME,
+)
+from .import_utils import (
+ ENV_VARS_TRUE_VALUES,
+ _flax_version,
+ _jax_version,
+ _onnxruntime_version,
+ _torch_version,
+ is_flax_available,
+ is_onnx_available,
+ is_torch_available,
+)
+from .logging import get_logger
+
+
+logger = get_logger(__name__)
+
+
+MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "model_card_template.md"
+SESSION_ID = uuid4().hex
+HF_HUB_OFFLINE = os.getenv("HF_HUB_OFFLINE", "").upper() in ENV_VARS_TRUE_VALUES
+DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", "").upper() in ENV_VARS_TRUE_VALUES
+HUGGINGFACE_CO_TELEMETRY = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/api/telemetry/"
+
+
+def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
+ """
+ Formats a user-agent string with basic info about a request.
+ """
+ ua = f"diffusers/{__version__}; python/{sys.version.split()[0]}; session_id/{SESSION_ID}"
+ if DISABLE_TELEMETRY or HF_HUB_OFFLINE:
+ return ua + "; telemetry/off"
+ if is_torch_available():
+ ua += f"; torch/{_torch_version}"
+ if is_flax_available():
+ ua += f"; jax/{_jax_version}"
+ ua += f"; flax/{_flax_version}"
+ if is_onnx_available():
+ ua += f"; onnxruntime/{_onnxruntime_version}"
+ # CI will set this value to True
+ if os.environ.get("DIFFUSERS_IS_CI", "").upper() in ENV_VARS_TRUE_VALUES:
+ ua += "; is_ci/true"
+ if isinstance(user_agent, dict):
+ ua += "; " + "; ".join(f"{k}/{v}" for k, v in user_agent.items())
+ elif isinstance(user_agent, str):
+ ua += "; " + user_agent
+ return ua
+
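+# Illustrative sketch (not part of the patched module): the shape of the string
+# `http_user_agent` builds; all version numbers and the session id below are made up.
+#
+#   >>> http_user_agent({"pipeline_class": "StableDiffusionPipeline"})  # doctest: +SKIP
+#   'diffusers/0.18.0; python/3.10.12; session_id/0123abcd; torch/2.0.1; pipeline_class/StableDiffusionPipeline'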
+
+def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
+ if token is None:
+ token = HfFolder.get_token()
+ if organization is None:
+ username = whoami(token)["name"]
+ return f"{username}/{model_id}"
+ else:
+ return f"{organization}/{model_id}"
+
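+# Illustrative sketch (not part of the patched module): "user" stands for whatever
+# `whoami` reports for the current token, and "my-org" is a placeholder organization.
+#
+#   >>> get_full_repo_name("controlnet-canny")  # doctest: +SKIP
+#   'user/controlnet-canny'
+#   >>> get_full_repo_name("controlnet-canny", organization="my-org")  # doctest: +SKIP
+#   'my-org/controlnet-canny'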
+
+def create_model_card(args, model_name):
+ if not is_jinja_available():
+ raise ValueError(
+ "Modelcard rendering is based on Jinja templates."
+ " Please make sure to have `jinja` installed before using `create_model_card`."
+ " To install it, please run `pip install Jinja2`."
+ )
+
+ if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
+ return
+
+ hub_token = args.hub_token if hasattr(args, "hub_token") else None
+ repo_name = get_full_repo_name(model_name, token=hub_token)
+
+ model_card = ModelCard.from_template(
+ card_data=ModelCardData( # Card metadata object that will be converted to YAML block
+ language="en",
+ license="apache-2.0",
+ library_name="diffusers",
+ tags=[],
+ datasets=args.dataset_name,
+ metrics=[],
+ ),
+ template_path=MODEL_CARD_TEMPLATE_PATH,
+ model_name=model_name,
+ repo_name=repo_name,
+ dataset_name=args.dataset_name if hasattr(args, "dataset_name") else None,
+ learning_rate=args.learning_rate,
+ train_batch_size=args.train_batch_size,
+ eval_batch_size=args.eval_batch_size,
+ gradient_accumulation_steps=(
+ args.gradient_accumulation_steps if hasattr(args, "gradient_accumulation_steps") else None
+ ),
+ adam_beta1=args.adam_beta1 if hasattr(args, "adam_beta1") else None,
+ adam_beta2=args.adam_beta2 if hasattr(args, "adam_beta2") else None,
+ adam_weight_decay=args.adam_weight_decay if hasattr(args, "adam_weight_decay") else None,
+ adam_epsilon=args.adam_epsilon if hasattr(args, "adam_epsilon") else None,
+ lr_scheduler=args.lr_scheduler if hasattr(args, "lr_scheduler") else None,
+ lr_warmup_steps=args.lr_warmup_steps if hasattr(args, "lr_warmup_steps") else None,
+ ema_inv_gamma=args.ema_inv_gamma if hasattr(args, "ema_inv_gamma") else None,
+ ema_power=args.ema_power if hasattr(args, "ema_power") else None,
+ ema_max_decay=args.ema_max_decay if hasattr(args, "ema_max_decay") else None,
+ mixed_precision=args.mixed_precision,
+ )
+
+ card_path = os.path.join(args.output_dir, "README.md")
+ model_card.save(card_path)
+
+
+def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str] = None):
+ """
+ Extracts the commit hash from the resolved path of a cached file.
+ """
+ if resolved_file is None or commit_hash is not None:
+ return commit_hash
+ resolved_file = str(Path(resolved_file).as_posix())
+ search = re.search(r"snapshots/([^/]+)/", resolved_file)
+ if search is None:
+ return None
+ commit_hash = search.groups()[0]
+ return commit_hash if REGEX_COMMIT_HASH.match(commit_hash) else None
+
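+# Illustrative sketch (not part of the patched module): the cache path below is made
+# up; a 40-character hex segment following "snapshots/" is returned as the commit
+# hash, anything else yields None.
+#
+#   >>> extract_commit_hash("hub/models--org--repo/snapshots/" + "a" * 40 + "/unet/config.json")
+#   'aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa'
+#   >>> extract_commit_hash("some/local/path/config.json") is None
+#   True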
+
+# Old default cache path, potentially to be migrated.
+# This logic was more or less taken from `transformers`, with the following differences:
+# - Diffusers doesn't use custom environment variables to specify the cache path.
+# - There is no need to migrate the cache format, just move the files to the new location.
+hf_cache_home = os.path.expanduser(
+ os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
+)
+old_diffusers_cache = os.path.join(hf_cache_home, "diffusers")
+
+
+def move_cache(old_cache_dir: Optional[str] = None, new_cache_dir: Optional[str] = None) -> None:
+ if new_cache_dir is None:
+ new_cache_dir = DIFFUSERS_CACHE
+ if old_cache_dir is None:
+ old_cache_dir = old_diffusers_cache
+
+ old_cache_dir = Path(old_cache_dir).expanduser()
+ new_cache_dir = Path(new_cache_dir).expanduser()
+ for old_blob_path in old_cache_dir.glob("**/blobs/*"):
+ if old_blob_path.is_file() and not old_blob_path.is_symlink():
+ new_blob_path = new_cache_dir / old_blob_path.relative_to(old_cache_dir)
+ new_blob_path.parent.mkdir(parents=True, exist_ok=True)
+ os.replace(old_blob_path, new_blob_path)
+ try:
+ os.symlink(new_blob_path, old_blob_path)
+ except OSError:
+ logger.warning(
+ "Could not create symlink between old cache and new cache. If you use an older version of diffusers again, files will be re-downloaded."
+ )
+ # At this point, old_cache_dir contains symlinks to the new cache (it can still be used).
+
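+# Illustrative sketch (not part of the patched module): the migration below is run
+# automatically by the version check that follows, but `move_cache` can also be
+# called manually with explicit (made-up) directories.
+#
+#   >>> move_cache(  # doctest: +SKIP
+#   ...     old_cache_dir="~/.cache/huggingface/diffusers",
+#   ...     new_cache_dir="~/.cache/huggingface/hub",
+#   ... )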
+
+cache_version_file = os.path.join(DIFFUSERS_CACHE, "version_diffusers_cache.txt")
+if not os.path.isfile(cache_version_file):
+ cache_version = 0
+else:
+ with open(cache_version_file) as f:
+ try:
+ cache_version = int(f.read())
+ except ValueError:
+ cache_version = 0
+
+if cache_version < 1:
+ old_cache_is_not_empty = os.path.isdir(old_diffusers_cache) and len(os.listdir(old_diffusers_cache)) > 0
+ if old_cache_is_not_empty:
+ logger.warning(
+ "The cache for model files in Diffusers v0.14.0 has moved to a new location. Moving your "
+ "existing cached models. This is a one-time operation, you can interrupt it or run it "
+ "later by calling `diffusers.utils.hub_utils.move_cache()`."
+ )
+ try:
+ move_cache()
+ except Exception as e:
+ trace = "\n".join(traceback.format_tb(e.__traceback__))
+ logger.error(
+ f"There was a problem when trying to move your cache:\n\n{trace}\n{e.__class__.__name__}: {e}\n\nPlease "
+ "file an issue at https://github.com/huggingface/diffusers/issues/new/choose, copy paste this whole "
+ "message and we will do our best to help."
+ )
+
+if cache_version < 1:
+ try:
+ os.makedirs(DIFFUSERS_CACHE, exist_ok=True)
+ with open(cache_version_file, "w") as f:
+ f.write("1")
+ except Exception:
+ logger.warning(
+ f"There was a problem when trying to write in your cache folder ({DIFFUSERS_CACHE}). Please, ensure "
+ "the directory exists and can be written to."
+ )
+
+
+def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
+ if variant is not None:
+ splits = weights_name.split(".")
+ splits = splits[:-1] + [variant] + splits[-1:]
+ weights_name = ".".join(splits)
+
+ return weights_name
+
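+# Illustrative sketch (not part of the patched module): the variant is spliced in
+# just before the file extension; with no variant the name is returned unchanged.
+#
+#   >>> _add_variant("diffusion_pytorch_model.bin", "fp16")
+#   'diffusion_pytorch_model.fp16.bin'
+#   >>> _add_variant("diffusion_pytorch_model.bin")
+#   'diffusion_pytorch_model.bin'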
+
+def _get_model_file(
+ pretrained_model_name_or_path,
+ *,
+ weights_name,
+ subfolder,
+ cache_dir,
+ force_download,
+ proxies,
+ resume_download,
+ local_files_only,
+ use_auth_token,
+ user_agent,
+ revision,
+ commit_hash=None,
+):
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+ if os.path.isfile(pretrained_model_name_or_path):
+ return pretrained_model_name_or_path
+ elif os.path.isdir(pretrained_model_name_or_path):
+ if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)):
+ # Load from a PyTorch checkpoint
+ model_file = os.path.join(pretrained_model_name_or_path, weights_name)
+ return model_file
+ elif subfolder is not None and os.path.isfile(
+ os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
+ ):
+ model_file = os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
+ return model_file
+ else:
+ raise EnvironmentError(
+ f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}."
+ )
+ else:
+ # 1. First check if deprecated way of loading from branches is used
+ if (
+ revision in DEPRECATED_REVISION_ARGS
+ and (weights_name == WEIGHTS_NAME or weights_name == SAFETENSORS_WEIGHTS_NAME)
+ and version.parse(version.parse(__version__).base_version) >= version.parse("0.20.0")
+ ):
+ try:
+ model_file = hf_hub_download(
+ pretrained_model_name_or_path,
+ filename=_add_variant(weights_name, revision),
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ user_agent=user_agent,
+ subfolder=subfolder,
+ revision=revision or commit_hash,
+ )
+ warnings.warn(
+ f"Loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` is deprecated. Loading instead from `revision='main'` with `variant={revision}`. Loading model variants via `revision='{revision}'` will be removed in diffusers v1. Please use `variant='{revision}'` instead.",
+ FutureWarning,
+ )
+ return model_file
+ except: # noqa: E722
+ warnings.warn(
+ f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have a {_add_variant(weights_name, revision)} file in the 'main' branch of {pretrained_model_name_or_path}. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {_add_variant(weights_name, revision)}' so that the correct variant file can be added.",
+ FutureWarning,
+ )
+ try:
+ # 2. Load model file as usual
+ model_file = hf_hub_download(
+ pretrained_model_name_or_path,
+ filename=weights_name,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ user_agent=user_agent,
+ subfolder=subfolder,
+ revision=revision or commit_hash,
+ )
+ return model_file
+
+ except RepositoryNotFoundError:
+ raise EnvironmentError(
+ f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
+ "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
+ "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
+ "login`."
+ )
+ except RevisionNotFoundError:
+ raise EnvironmentError(
+ f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
+ "this model name. Check the model page at "
+ f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
+ )
+ except EntryNotFoundError:
+ raise EnvironmentError(
+ f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}."
+ )
+ except HTTPError as err:
+ raise EnvironmentError(
+ f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
+ )
+ except ValueError:
+ raise EnvironmentError(
+ f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
+ f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
+ f" directory containing a file named {weights_name} or"
+ " \nCheckout your internet connection or see how to run the library in"
+ " offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
+ )
+ except EnvironmentError:
+ raise EnvironmentError(
+ f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
+ "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
+ f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
+ f"containing a file named {weights_name}"
+ )
diff --git a/diffusers/utils/import_utils.py b/diffusers/utils/import_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f40d939a71194ee4325964838640f142fd1c964
--- /dev/null
+++ b/diffusers/utils/import_utils.py
@@ -0,0 +1,655 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Import utilities: Utilities related to imports and our lazy inits.
+"""
+import importlib.util
+import operator as op
+import os
+import sys
+from collections import OrderedDict
+from typing import Union
+
+from huggingface_hub.utils import is_jinja_available # noqa: F401
+from packaging import version
+from packaging.version import Version, parse
+
+from . import logging
+
+
+# The package importlib_metadata is in a different place, depending on the python version.
+if sys.version_info < (3, 8):
+ import importlib_metadata
+else:
+ import importlib.metadata as importlib_metadata
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
+ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
+
+USE_TF = os.environ.get("USE_TF", "AUTO").upper()
+USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
+USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
+USE_SAFETENSORS = os.environ.get("USE_SAFETENSORS", "AUTO").upper()
+
+STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}
+
+_torch_version = "N/A"
+if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
+ _torch_available = importlib.util.find_spec("torch") is not None
+ if _torch_available:
+ try:
+ _torch_version = importlib_metadata.version("torch")
+ logger.info(f"PyTorch version {_torch_version} available.")
+ except importlib_metadata.PackageNotFoundError:
+ _torch_available = False
+else:
+ logger.info("Disabling PyTorch because USE_TORCH is set")
+ _torch_available = False
+
+
+_tf_version = "N/A"
+if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
+ _tf_available = importlib.util.find_spec("tensorflow") is not None
+ if _tf_available:
+ candidates = (
+ "tensorflow",
+ "tensorflow-cpu",
+ "tensorflow-gpu",
+ "tf-nightly",
+ "tf-nightly-cpu",
+ "tf-nightly-gpu",
+ "intel-tensorflow",
+ "intel-tensorflow-avx512",
+ "tensorflow-rocm",
+ "tensorflow-macos",
+ "tensorflow-aarch64",
+ )
+ _tf_version = None
+ # For the metadata, we have to look for both tensorflow and tensorflow-cpu
+ for pkg in candidates:
+ try:
+ _tf_version = importlib_metadata.version(pkg)
+ break
+ except importlib_metadata.PackageNotFoundError:
+ pass
+ _tf_available = _tf_version is not None
+ if _tf_available:
+ if version.parse(_tf_version) < version.parse("2"):
+ logger.info(f"TensorFlow found but with version {_tf_version}. Diffusers requires version 2 minimum.")
+ _tf_available = False
+ else:
+ logger.info(f"TensorFlow version {_tf_version} available.")
+else:
+ logger.info("Disabling Tensorflow because USE_TORCH is set")
+ _tf_available = False
+
+_jax_version = "N/A"
+_flax_version = "N/A"
+if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
+ _flax_available = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("flax") is not None
+ if _flax_available:
+ try:
+ _jax_version = importlib_metadata.version("jax")
+ _flax_version = importlib_metadata.version("flax")
+ logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.")
+ except importlib_metadata.PackageNotFoundError:
+ _flax_available = False
+else:
+ _flax_available = False
+
+if USE_SAFETENSORS in ENV_VARS_TRUE_AND_AUTO_VALUES:
+ _safetensors_available = importlib.util.find_spec("safetensors") is not None
+ if _safetensors_available:
+ try:
+ _safetensors_version = importlib_metadata.version("safetensors")
+ logger.info(f"Safetensors version {_safetensors_version} available.")
+ except importlib_metadata.PackageNotFoundError:
+ _safetensors_available = False
+else:
+ logger.info("Disabling Safetensors because USE_TF is set")
+ _safetensors_available = False
+
+_transformers_available = importlib.util.find_spec("transformers") is not None
+try:
+ _transformers_version = importlib_metadata.version("transformers")
+ logger.debug(f"Successfully imported transformers version {_transformers_version}")
+except importlib_metadata.PackageNotFoundError:
+ _transformers_available = False
+
+
+_inflect_available = importlib.util.find_spec("inflect") is not None
+try:
+ _inflect_version = importlib_metadata.version("inflect")
+ logger.debug(f"Successfully imported inflect version {_inflect_version}")
+except importlib_metadata.PackageNotFoundError:
+ _inflect_available = False
+
+
+_unidecode_available = importlib.util.find_spec("unidecode") is not None
+try:
+ _unidecode_version = importlib_metadata.version("unidecode")
+ logger.debug(f"Successfully imported unidecode version {_unidecode_version}")
+except importlib_metadata.PackageNotFoundError:
+ _unidecode_available = False
+
+
+_onnxruntime_version = "N/A"
+_onnx_available = importlib.util.find_spec("onnxruntime") is not None
+if _onnx_available:
+ candidates = (
+ "onnxruntime",
+ "onnxruntime-gpu",
+ "ort_nightly_gpu",
+ "onnxruntime-directml",
+ "onnxruntime-openvino",
+ "ort_nightly_directml",
+ "onnxruntime-rocm",
+ "onnxruntime-training",
+ )
+ _onnxruntime_version = None
+ # For the metadata, we have to look for both onnxruntime and onnxruntime-gpu
+ for pkg in candidates:
+ try:
+ _onnxruntime_version = importlib_metadata.version(pkg)
+ break
+ except importlib_metadata.PackageNotFoundError:
+ pass
+ _onnx_available = _onnxruntime_version is not None
+ if _onnx_available:
+ logger.debug(f"Successfully imported onnxruntime version {_onnxruntime_version}")
+
+# (sayakpaul): importlib.util.find_spec("opencv-python") returns None even when it's installed.
+# _opencv_available = importlib.util.find_spec("opencv-python") is not None
+try:
+ candidates = (
+ "opencv-python",
+ "opencv-contrib-python",
+ "opencv-python-headless",
+ "opencv-contrib-python-headless",
+ )
+ _opencv_version = None
+ for pkg in candidates:
+ try:
+ _opencv_version = importlib_metadata.version(pkg)
+ break
+ except importlib_metadata.PackageNotFoundError:
+ pass
+ _opencv_available = _opencv_version is not None
+ if _opencv_available:
+ logger.debug(f"Successfully imported cv2 version {_opencv_version}")
+except importlib_metadata.PackageNotFoundError:
+ _opencv_available = False
+
+_scipy_available = importlib.util.find_spec("scipy") is not None
+try:
+ _scipy_version = importlib_metadata.version("scipy")
+ logger.debug(f"Successfully imported scipy version {_scipy_version}")
+except importlib_metadata.PackageNotFoundError:
+ _scipy_available = False
+
+_librosa_available = importlib.util.find_spec("librosa") is not None
+try:
+ _librosa_version = importlib_metadata.version("librosa")
+ logger.debug(f"Successfully imported librosa version {_librosa_version}")
+except importlib_metadata.PackageNotFoundError:
+ _librosa_available = False
+
+_accelerate_available = importlib.util.find_spec("accelerate") is not None
+try:
+ _accelerate_version = importlib_metadata.version("accelerate")
+ logger.debug(f"Successfully imported accelerate version {_accelerate_version}")
+except importlib_metadata.PackageNotFoundError:
+ _accelerate_available = False
+
+_xformers_available = importlib.util.find_spec("xformers") is not None
+try:
+ _xformers_version = importlib_metadata.version("xformers")
+ if _torch_available:
+ import torch
+
+ if version.Version(torch.__version__) < version.Version("1.12"):
+ raise ValueError("PyTorch should be >= 1.12")
+ logger.debug(f"Successfully imported xformers version {_xformers_version}")
+except importlib_metadata.PackageNotFoundError:
+ _xformers_available = False
+
+_k_diffusion_available = importlib.util.find_spec("k_diffusion") is not None
+try:
+ _k_diffusion_version = importlib_metadata.version("k_diffusion")
+ logger.debug(f"Successfully imported k-diffusion version {_k_diffusion_version}")
+except importlib_metadata.PackageNotFoundError:
+ _k_diffusion_available = False
+
+_note_seq_available = importlib.util.find_spec("note_seq") is not None
+try:
+ _note_seq_version = importlib_metadata.version("note_seq")
+ logger.debug(f"Successfully imported note-seq version {_note_seq_version}")
+except importlib_metadata.PackageNotFoundError:
+ _note_seq_available = False
+
+_wandb_available = importlib.util.find_spec("wandb") is not None
+try:
+ _wandb_version = importlib_metadata.version("wandb")
+ logger.debug(f"Successfully imported wandb version {_wandb_version }")
+except importlib_metadata.PackageNotFoundError:
+ _wandb_available = False
+
+_omegaconf_available = importlib.util.find_spec("omegaconf") is not None
+try:
+ _omegaconf_version = importlib_metadata.version("omegaconf")
+ logger.debug(f"Successfully imported omegaconf version {_omegaconf_version}")
+except importlib_metadata.PackageNotFoundError:
+ _omegaconf_available = False
+
+_tensorboard_available = importlib.util.find_spec("tensorboard")
+try:
+ _tensorboard_version = importlib_metadata.version("tensorboard")
+ logger.debug(f"Successfully imported tensorboard version {_tensorboard_version}")
+except importlib_metadata.PackageNotFoundError:
+ _tensorboard_available = False
+
+
+_compel_available = importlib.util.find_spec("compel")
+try:
+ _compel_version = importlib_metadata.version("compel")
+ logger.debug(f"Successfully imported compel version {_compel_version}")
+except importlib_metadata.PackageNotFoundError:
+ _compel_available = False
+
+
+_ftfy_available = importlib.util.find_spec("ftfy") is not None
+try:
+ _ftfy_version = importlib_metadata.version("ftfy")
+ logger.debug(f"Successfully imported ftfy version {_ftfy_version}")
+except importlib_metadata.PackageNotFoundError:
+ _ftfy_available = False
+
+
+_bs4_available = importlib.util.find_spec("bs4") is not None
+try:
+ # importlib metadata under different name
+ _bs4_version = importlib_metadata.version("beautifulsoup4")
+ logger.debug(f"Successfully imported ftfy version {_bs4_version}")
+except importlib_metadata.PackageNotFoundError:
+ _bs4_available = False
+
+_torchsde_available = importlib.util.find_spec("torchsde") is not None
+try:
+ _torchsde_version = importlib_metadata.version("torchsde")
+ logger.debug(f"Successfully imported torchsde version {_torchsde_version}")
+except importlib_metadata.PackageNotFoundError:
+ _torchsde_available = False
+
+_invisible_watermark_available = importlib.util.find_spec("imwatermark") is not None
+try:
+ _invisible_watermark_version = importlib_metadata.version("invisible-watermark")
+ logger.debug(f"Successfully imported invisible-watermark version {_invisible_watermark_version}")
+except importlib_metadata.PackageNotFoundError:
+ _invisible_watermark_available = False
+
+
+def is_torch_available():
+ return _torch_available
+
+
+def is_safetensors_available():
+ return _safetensors_available
+
+
+def is_tf_available():
+ return _tf_available
+
+
+def is_flax_available():
+ return _flax_available
+
+
+def is_transformers_available():
+ return _transformers_available
+
+
+def is_inflect_available():
+ return _inflect_available
+
+
+def is_unidecode_available():
+ return _unidecode_available
+
+
+def is_onnx_available():
+ return _onnx_available
+
+
+def is_opencv_available():
+ return _opencv_available
+
+
+def is_scipy_available():
+ return _scipy_available
+
+
+def is_librosa_available():
+ return _librosa_available
+
+
+def is_xformers_available():
+ return _xformers_available
+
+
+def is_accelerate_available():
+ return _accelerate_available
+
+
+def is_k_diffusion_available():
+ return _k_diffusion_available
+
+
+def is_note_seq_available():
+ return _note_seq_available
+
+
+def is_wandb_available():
+ return _wandb_available
+
+
+def is_omegaconf_available():
+ return _omegaconf_available
+
+
+def is_tensorboard_available():
+ return _tensorboard_available
+
+
+def is_compel_available():
+ return _compel_available
+
+
+def is_ftfy_available():
+ return _ftfy_available
+
+
+def is_bs4_available():
+ return _bs4_available
+
+
+def is_torchsde_available():
+ return _torchsde_available
+
+
+def is_invisible_watermark_available():
+ return _invisible_watermark_available
+
+
+# docstyle-ignore
+FLAX_IMPORT_ERROR = """
+{0} requires the FLAX library but it was not found in your environment. Check out the instructions on the
+installation page: https://github.com/google/flax and follow the ones that match your environment.
+"""
+
+# docstyle-ignore
+INFLECT_IMPORT_ERROR = """
+{0} requires the inflect library but it was not found in your environment. You can install it with pip: `pip install
+inflect`
+"""
+
+# docstyle-ignore
+PYTORCH_IMPORT_ERROR = """
+{0} requires the PyTorch library but it was not found in your environment. Check out the instructions on the
+installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment.
+"""
+
+# docstyle-ignore
+ONNX_IMPORT_ERROR = """
+{0} requires the onnxruntime library but it was not found in your environment. You can install it with pip: `pip
+install onnxruntime`
+"""
+
+# docstyle-ignore
+OPENCV_IMPORT_ERROR = """
+{0} requires the OpenCV library but it was not found in your environment. You can install it with pip: `pip
+install opencv-python`
+"""
+
+# docstyle-ignore
+SCIPY_IMPORT_ERROR = """
+{0} requires the scipy library but it was not found in your environment. You can install it with pip: `pip install
+scipy`
+"""
+
+# docstyle-ignore
+LIBROSA_IMPORT_ERROR = """
+{0} requires the librosa library but it was not found in your environment. Check out the instructions on the
+installation page: https://librosa.org/doc/latest/install.html and follow the ones that match your environment.
+"""
+
+# docstyle-ignore
+TRANSFORMERS_IMPORT_ERROR = """
+{0} requires the transformers library but it was not found in your environment. You can install it with pip: `pip
+install transformers`
+"""
+
+# docstyle-ignore
+UNIDECODE_IMPORT_ERROR = """
+{0} requires the unidecode library but it was not found in your environment. You can install it with pip: `pip install
+Unidecode`
+"""
+
+# docstyle-ignore
+K_DIFFUSION_IMPORT_ERROR = """
+{0} requires the k-diffusion library but it was not found in your environment. You can install it with pip: `pip
+install k-diffusion`
+"""
+
+# docstyle-ignore
+NOTE_SEQ_IMPORT_ERROR = """
+{0} requires the note-seq library but it was not found in your environment. You can install it with pip: `pip
+install note-seq`
+"""
+
+# docstyle-ignore
+WANDB_IMPORT_ERROR = """
+{0} requires the wandb library but it was not found in your environment. You can install it with pip: `pip
+install wandb`
+"""
+
+# docstyle-ignore
+OMEGACONF_IMPORT_ERROR = """
+{0} requires the omegaconf library but it was not found in your environment. You can install it with pip: `pip
+install omegaconf`
+"""
+
+# docstyle-ignore
+TENSORBOARD_IMPORT_ERROR = """
+{0} requires the tensorboard library but it was not found in your environment. You can install it with pip: `pip
+install tensorboard`
+"""
+
+
+# docstyle-ignore
+COMPEL_IMPORT_ERROR = """
+{0} requires the compel library but it was not found in your environment. You can install it with pip: `pip install compel`
+"""
+
+# docstyle-ignore
+BS4_IMPORT_ERROR = """
+{0} requires the Beautiful Soup library but it was not found in your environment. You can install it with pip:
+`pip install beautifulsoup4`. Please note that you may need to restart your runtime after installation.
+"""
+
+# docstyle-ignore
+FTFY_IMPORT_ERROR = """
+{0} requires the ftfy library but it was not found in your environment. Check out the instructions on the
+installation section: https://github.com/rspeer/python-ftfy/tree/master#installing and follow the ones
+that match your environment. Please note that you may need to restart your runtime after installation.
+"""
+
+# docstyle-ignore
+TORCHSDE_IMPORT_ERROR = """
+{0} requires the torchsde library but it was not found in your environment. You can install it with pip: `pip install torchsde`
+"""
+
+# docstyle-ignore
+INVISIBLE_WATERMARK_IMPORT_ERROR = """
+{0} requires the invisible-watermark library but it was not found in your environment. You can install it with pip: `pip install invisible-watermark>=0.2.0`
+"""
+
+
+BACKENDS_MAPPING = OrderedDict(
+ [
+ ("bs4", (is_bs4_available, BS4_IMPORT_ERROR)),
+ ("flax", (is_flax_available, FLAX_IMPORT_ERROR)),
+ ("inflect", (is_inflect_available, INFLECT_IMPORT_ERROR)),
+ ("onnx", (is_onnx_available, ONNX_IMPORT_ERROR)),
+ ("opencv", (is_opencv_available, OPENCV_IMPORT_ERROR)),
+ ("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)),
+ ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)),
+ ("transformers", (is_transformers_available, TRANSFORMERS_IMPORT_ERROR)),
+ ("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)),
+ ("librosa", (is_librosa_available, LIBROSA_IMPORT_ERROR)),
+ ("k_diffusion", (is_k_diffusion_available, K_DIFFUSION_IMPORT_ERROR)),
+ ("note_seq", (is_note_seq_available, NOTE_SEQ_IMPORT_ERROR)),
+ ("wandb", (is_wandb_available, WANDB_IMPORT_ERROR)),
+ ("omegaconf", (is_omegaconf_available, OMEGACONF_IMPORT_ERROR)),
+ ("tensorboard", (is_tensorboard_available, TENSORBOARD_IMPORT_ERROR)),
+ ("compel", (is_compel_available, COMPEL_IMPORT_ERROR)),
+ ("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)),
+ ("torchsde", (is_torchsde_available, TORCHSDE_IMPORT_ERROR)),
+ ("invisible_watermark", (is_invisible_watermark_available, INVISIBLE_WATERMARK_IMPORT_ERROR)),
+ ]
+)
+
+
+def requires_backends(obj, backends):
+ if not isinstance(backends, (list, tuple)):
+ backends = [backends]
+
+ name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
+ checks = (BACKENDS_MAPPING[backend] for backend in backends)
+ failed = [msg.format(name) for available, msg in checks if not available()]
+ if failed:
+ raise ImportError("".join(failed))
+
+ if name in [
+ "VersatileDiffusionTextToImagePipeline",
+ "VersatileDiffusionPipeline",
+ "VersatileDiffusionDualGuidedPipeline",
+ "StableDiffusionImageVariationPipeline",
+ "UnCLIPPipeline",
+ ] and is_transformers_version("<", "4.25.0"):
+ raise ImportError(
+ f"You need to install `transformers>=4.25` in order to use {name}: \n```\n pip install"
+ " --upgrade transformers \n```"
+ )
+
+ if name in ["StableDiffusionDepth2ImgPipeline", "StableDiffusionPix2PixZeroPipeline"] and is_transformers_version(
+ "<", "4.26.0"
+ ):
+ raise ImportError(
+ f"You need to install `transformers>=4.26` in order to use {name}: \n```\n pip install"
+ " --upgrade transformers \n```"
+ )
+
+
+class DummyObject(type):
+ """
+ Metaclass for the dummy objects. Any class using it as a metaclass will raise the ImportError generated by
+ `requires_backends` each time a user tries to access any method of that class.
+ """
+
+ def __getattr__(cls, key):
+ if key.startswith("_"):
+ return super().__getattr__(cls, key)
+ requires_backends(cls, cls._backends)
+
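+# Illustrative sketch (not part of the patched module): a hypothetical dummy class,
+# as generated in the dummy_*_objects.py files above. When the listed backends are
+# missing, any public attribute access raises the corresponding ImportError lazily
+# instead of failing at import time.
+#
+#   >>> class FakePipeline(metaclass=DummyObject):  # doctest: +SKIP
+#   ...     _backends = ["torch", "transformers"]
+#   >>> FakePipeline.from_pretrained("some/repo")  # doctest: +SKIP
+#   ImportError: FakePipeline requires the PyTorch library but it was not found in your environment. ...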
+
+# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L319
+def compare_versions(library_or_version: Union[str, Version], operation: str, requirement_version: str):
+ """
+ Args:
+ Compares a library version to some requirement using a given operation.
+ library_or_version (`str` or `packaging.version.Version`):
+ A library name or a version to check.
+ operation (`str`):
+ A string representation of an operator, such as `">"` or `"<="`.
+ requirement_version (`str`):
+ The version to compare the library version against
+ """
+ if operation not in STR_OPERATION_TO_FUNC.keys():
+ raise ValueError(f"`operation` must be one of {list(STR_OPERATION_TO_FUNC.keys())}, received {operation}")
+ operation = STR_OPERATION_TO_FUNC[operation]
+ if isinstance(library_or_version, str):
+ library_or_version = parse(importlib_metadata.version(library_or_version))
+ return operation(library_or_version, parse(requirement_version))
+
+
+# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L338
+def is_torch_version(operation: str, version: str):
+ """
+ Args:
+ Compares the current PyTorch version to a given reference with an operation.
+ operation (`str`):
+ A string representation of an operator, such as `">"` or `"<="`
+ version (`str`):
+ A string version of PyTorch
+ """
+ return compare_versions(parse(_torch_version), operation, version)
+
+
+def is_transformers_version(operation: str, version: str):
+ """
+ Args:
+ Compares the current Transformers version to a given reference with an operation.
+ operation (`str`):
+ A string representation of an operator, such as `">"` or `"<="`
+ version (`str`):
+ A version string
+ """
+ if not _transformers_available:
+ return False
+ return compare_versions(parse(_transformers_version), operation, version)
+
+
+def is_accelerate_version(operation: str, version: str):
+ """
+ Args:
+ Compares the current Accelerate version to a given reference with an operation.
+ operation (`str`):
+ A string representation of an operator, such as `">"` or `"<="`
+ version (`str`):
+ A version string
+ """
+ if not _accelerate_available:
+ return False
+ return compare_versions(parse(_accelerate_version), operation, version)
+
+
+def is_k_diffusion_version(operation: str, version: str):
+ """
+ Args:
+ Compares the current k-diffusion version to a given reference with an operation.
+ operation (`str`):
+ A string representation of an operator, such as `">"` or `"<="`
+ version (`str`):
+ A version string
+ """
+ if not _k_diffusion_available:
+ return False
+ return compare_versions(parse(_k_diffusion_version), operation, version)
+
+
+class OptionalDependencyNotAvailable(BaseException):
+ """An error indicating that an optional dependency of Diffusers was not found in the environment."""
diff --git a/diffusers/utils/logging.py b/diffusers/utils/logging.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ccc57cd69d57e9bd999e35320cb98416f000522
--- /dev/null
+++ b/diffusers/utils/logging.py
@@ -0,0 +1,339 @@
+# coding=utf-8
+# Copyright 2023 Optuna, Hugging Face
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Logging utilities."""
+
+import logging
+import os
+import sys
+import threading
+from logging import (
+ CRITICAL, # NOQA
+ DEBUG, # NOQA
+ ERROR, # NOQA
+ FATAL, # NOQA
+ INFO, # NOQA
+ NOTSET, # NOQA
+ WARN, # NOQA
+ WARNING, # NOQA
+)
+from typing import Optional
+
+from tqdm import auto as tqdm_lib
+
+
+_lock = threading.Lock()
+_default_handler: Optional[logging.Handler] = None
+
+log_levels = {
+ "debug": logging.DEBUG,
+ "info": logging.INFO,
+ "warning": logging.WARNING,
+ "error": logging.ERROR,
+ "critical": logging.CRITICAL,
+}
+
+_default_log_level = logging.WARNING
+
+_tqdm_active = True
+
+
+def _get_default_logging_level():
+ """
+ If DIFFUSERS_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is
+ not - fall back to `_default_log_level`
+ """
+ env_level_str = os.getenv("DIFFUSERS_VERBOSITY", None)
+ if env_level_str:
+ if env_level_str in log_levels:
+ return log_levels[env_level_str]
+ else:
+ logging.getLogger().warning(
+ f"Unknown option DIFFUSERS_VERBOSITY={env_level_str}, "
+ f"has to be one of: { ', '.join(log_levels.keys()) }"
+ )
+ return _default_log_level
+
+
+def _get_library_name() -> str:
+ return __name__.split(".")[0]
+
+
+def _get_library_root_logger() -> logging.Logger:
+ return logging.getLogger(_get_library_name())
+
+
+def _configure_library_root_logger() -> None:
+ global _default_handler
+
+ with _lock:
+ if _default_handler:
+ # This library has already configured the library root logger.
+ return
+ _default_handler = logging.StreamHandler() # Set sys.stderr as stream.
+ _default_handler.flush = sys.stderr.flush
+
+ # Apply our default configuration to the library root logger.
+ library_root_logger = _get_library_root_logger()
+ library_root_logger.addHandler(_default_handler)
+ library_root_logger.setLevel(_get_default_logging_level())
+ library_root_logger.propagate = False
+
+
+def _reset_library_root_logger() -> None:
+ global _default_handler
+
+ with _lock:
+ if not _default_handler:
+ return
+
+ library_root_logger = _get_library_root_logger()
+ library_root_logger.removeHandler(_default_handler)
+ library_root_logger.setLevel(logging.NOTSET)
+ _default_handler = None
+
+
+def get_log_levels_dict():
+ return log_levels
+
+
+def get_logger(name: Optional[str] = None) -> logging.Logger:
+ """
+ Return a logger with the specified name.
+
+ This function is not supposed to be directly accessed unless you are writing a custom diffusers module.
+ """
+
+ if name is None:
+ name = _get_library_name()
+
+ _configure_library_root_logger()
+ return logging.getLogger(name)
+
+
+def get_verbosity() -> int:
+ """
+ Return the current level for the 🤗 Diffusers' root logger as an `int`.
+
+ Returns:
+ `int`:
+ Logging level integers which can be one of:
+
+ - `50`: `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL`
+ - `40`: `diffusers.logging.ERROR`
+ - `30`: `diffusers.logging.WARNING` or `diffusers.logging.WARN`
+ - `20`: `diffusers.logging.INFO`
+ - `10`: `diffusers.logging.DEBUG`
+
+ """
+
+ _configure_library_root_logger()
+ return _get_library_root_logger().getEffectiveLevel()
+
+
+def set_verbosity(verbosity: int) -> None:
+ """
+ Set the verbosity level for the 🤗 Diffusers' root logger.
+
+ Args:
+ verbosity (`int`):
+ Logging level which can be one of:
+
+ - `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL`
+ - `diffusers.logging.ERROR`
+ - `diffusers.logging.WARNING` or `diffusers.logging.WARN`
+ - `diffusers.logging.INFO`
+ - `diffusers.logging.DEBUG`
+ """
+
+ _configure_library_root_logger()
+ _get_library_root_logger().setLevel(verbosity)
+
+
+def set_verbosity_info():
+ """Set the verbosity to the `INFO` level."""
+ return set_verbosity(INFO)
+
+
+def set_verbosity_warning():
+ """Set the verbosity to the `WARNING` level."""
+ return set_verbosity(WARNING)
+
+
+def set_verbosity_debug():
+ """Set the verbosity to the `DEBUG` level."""
+ return set_verbosity(DEBUG)
+
+
+def set_verbosity_error():
+ """Set the verbosity to the `ERROR` level."""
+ return set_verbosity(ERROR)
+
+
+def disable_default_handler() -> None:
+ """Disable the default handler of the 🤗 Diffusers' root logger."""
+
+ _configure_library_root_logger()
+
+ assert _default_handler is not None
+ _get_library_root_logger().removeHandler(_default_handler)
+
+
+def enable_default_handler() -> None:
+ """Enable the default handler of the 🤗 Diffusers' root logger."""
+
+ _configure_library_root_logger()
+
+ assert _default_handler is not None
+ _get_library_root_logger().addHandler(_default_handler)
+
+
+def add_handler(handler: logging.Handler) -> None:
+ """adds a handler to the HuggingFace Diffusers' root logger."""
+
+ _configure_library_root_logger()
+
+ assert handler is not None
+ _get_library_root_logger().addHandler(handler)
+
+
+def remove_handler(handler: logging.Handler) -> None:
+ """removes given handler from the HuggingFace Diffusers' root logger."""
+
+ _configure_library_root_logger()
+
+ assert handler is not None and handler in _get_library_root_logger().handlers
+ _get_library_root_logger().removeHandler(handler)
+
+
+def disable_propagation() -> None:
+ """
+ Disable propagation of the library log outputs. Note that log propagation is disabled by default.
+ """
+
+ _configure_library_root_logger()
+ _get_library_root_logger().propagate = False
+
+
+def enable_propagation() -> None:
+ """
+ Enable propagation of the library log outputs. Please disable the HuggingFace Diffusers' default handler to prevent
+ double logging if the root logger has been configured.
+ """
+
+ _configure_library_root_logger()
+ _get_library_root_logger().propagate = True
+
+
+def enable_explicit_format() -> None:
+ """
+ Enable explicit formatting for every 🤗 Diffusers' logger. The explicit formatter is as follows:
+ ```
+ [LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE
+ ```
+ All handlers currently bound to the root logger are affected by this method.
+ """
+ handlers = _get_library_root_logger().handlers
+
+ for handler in handlers:
+ formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s")
+ handler.setFormatter(formatter)
+
+
+def reset_format() -> None:
+ """
+ Resets the formatting for 🤗 Diffusers' loggers.
+
+ All handlers currently bound to the root logger are affected by this method.
+ """
+ handlers = _get_library_root_logger().handlers
+
+ for handler in handlers:
+ handler.setFormatter(None)
+
+
+def warning_advice(self, *args, **kwargs):
+ """
+ This method is identical to `logger.warning()`, but if env var DIFFUSERS_NO_ADVISORY_WARNINGS=1 is set, this
+ warning will not be printed
+ """
+ no_advisory_warnings = os.getenv("DIFFUSERS_NO_ADVISORY_WARNINGS", False)
+ if no_advisory_warnings:
+ return
+ self.warning(*args, **kwargs)
+
+
+logging.Logger.warning_advice = warning_advice
+
+
+class EmptyTqdm:
+ """Dummy tqdm which doesn't do anything."""
+
+ def __init__(self, *args, **kwargs): # pylint: disable=unused-argument
+ self._iterator = args[0] if args else None
+
+ def __iter__(self):
+ return iter(self._iterator)
+
+ def __getattr__(self, _):
+ """Return empty function."""
+
+ def empty_fn(*args, **kwargs): # pylint: disable=unused-argument
+ return
+
+ return empty_fn
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, type_, value, traceback):
+ return
+
+
+class _tqdm_cls:
+ def __call__(self, *args, **kwargs):
+ if _tqdm_active:
+ return tqdm_lib.tqdm(*args, **kwargs)
+ else:
+ return EmptyTqdm(*args, **kwargs)
+
+ def set_lock(self, *args, **kwargs):
+ self._lock = None
+ if _tqdm_active:
+ return tqdm_lib.tqdm.set_lock(*args, **kwargs)
+
+ def get_lock(self):
+ if _tqdm_active:
+ return tqdm_lib.tqdm.get_lock()
+
+
+tqdm = _tqdm_cls()
+
+
+def is_progress_bar_enabled() -> bool:
+ """Return a boolean indicating whether tqdm progress bars are enabled."""
+ global _tqdm_active
+ return bool(_tqdm_active)
+
+
+def enable_progress_bar():
+ """Enable tqdm progress bar."""
+ global _tqdm_active
+ _tqdm_active = True
+
+
+def disable_progress_bar():
+ """Disable tqdm progress bar."""
+ global _tqdm_active
+ _tqdm_active = False
diff --git a/diffusers/utils/model_card_template.md b/diffusers/utils/model_card_template.md
new file mode 100644
index 0000000000000000000000000000000000000000..f19c85b0fcf2f7b07e9c3f950a9657b3f2053f21
--- /dev/null
+++ b/diffusers/utils/model_card_template.md
@@ -0,0 +1,50 @@
+---
+{{ card_data }}
+---
+
+
+
+# {{ model_name | default("Diffusion Model") }}
+
+## Model description
+
+This diffusion model is trained with the [🤗 Diffusers](https://github.com/huggingface/diffusers) library
+on the `{{ dataset_name }}` dataset.
+
+## Intended uses & limitations
+
+#### How to use
+
+```python
+# TODO: add an example code snippet for running this diffusion pipeline
+```
+
+#### Limitations and bias
+
+[TODO: provide examples of latent issues and potential remediations]
+
+## Training data
+
+[TODO: describe the data used to train the model]
+
+### Training hyperparameters
+
+The following hyperparameters were used during training:
+- learning_rate: {{ learning_rate }}
+- train_batch_size: {{ train_batch_size }}
+- eval_batch_size: {{ eval_batch_size }}
+- gradient_accumulation_steps: {{ gradient_accumulation_steps }}
+- optimizer: AdamW with betas=({{ adam_beta1 }}, {{ adam_beta2 }}), weight_decay={{ adam_weight_decay }} and epsilon={{ adam_epsilon }}
+- lr_scheduler: {{ lr_scheduler }}
+- lr_warmup_steps: {{ lr_warmup_steps }}
+- ema_inv_gamma: {{ ema_inv_gamma }}
+- ema_power: {{ ema_power }}
+- ema_max_decay: {{ ema_max_decay }}
+- mixed_precision: {{ mixed_precision }}
+
+### Training results
+
+📈 [TensorBoard logs](https://huggingface.co/{{ repo_name }}/tensorboard?#scalars)
+
+
diff --git a/diffusers/utils/outputs.py b/diffusers/utils/outputs.py
new file mode 100644
index 0000000000000000000000000000000000000000..37b11561d1e1ee5d5cb40c7630b132e1f451c5b0
--- /dev/null
+++ b/diffusers/utils/outputs.py
@@ -0,0 +1,108 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Generic utilities
+"""
+
+from collections import OrderedDict
+from dataclasses import fields
+from typing import Any, Tuple
+
+import numpy as np
+
+from .import_utils import is_torch_available
+
+
+def is_tensor(x):
+ """
+ Tests if `x` is a `torch.Tensor` or `np.ndarray`.
+ """
+ if is_torch_available():
+ import torch
+
+ if isinstance(x, torch.Tensor):
+ return True
+
+ return isinstance(x, np.ndarray)
+
+
+class BaseOutput(OrderedDict):
+ """
+ Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a
+ tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular
+ Python dictionary.
+
+
+
+ You can't unpack a [`BaseOutput`] directly. Use the [`~utils.BaseOutput.to_tuple`] method to convert it to a tuple
+ first.
+
+
+ """
+
+ def __post_init__(self):
+ class_fields = fields(self)
+
+ # Safety and consistency checks
+ if not len(class_fields):
+ raise ValueError(f"{self.__class__.__name__} has no fields.")
+
+ first_field = getattr(self, class_fields[0].name)
+ other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])
+
+ if other_fields_are_none and isinstance(first_field, dict):
+ for key, value in first_field.items():
+ self[key] = value
+ else:
+ for field in class_fields:
+ v = getattr(self, field.name)
+ if v is not None:
+ self[field.name] = v
+
+ def __delitem__(self, *args, **kwargs):
+ raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
+
+ def setdefault(self, *args, **kwargs):
+ raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
+
+ def pop(self, *args, **kwargs):
+ raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
+
+ def update(self, *args, **kwargs):
+ raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
+
+ def __getitem__(self, k):
+ if isinstance(k, str):
+ inner_dict = dict(self.items())
+ return inner_dict[k]
+ else:
+ return self.to_tuple()[k]
+
+ def __setattr__(self, name, value):
+ if name in self.keys() and value is not None:
+ # Don't call self.__setitem__ to avoid recursion errors
+ super().__setitem__(name, value)
+ super().__setattr__(name, value)
+
+ def __setitem__(self, key, value):
+ # Will raise a KeyException if needed
+ super().__setitem__(key, value)
+ # Don't call self.__setattr__ to avoid recursion errors
+ super().__setattr__(key, value)
+
+ def to_tuple(self) -> Tuple[Any]:
+ """
+ Convert self to a tuple containing all the attributes/keys that are not `None`.
+ """
+ return tuple(self[k] for k in self.keys())
diff --git a/diffusers/utils/pil_utils.py b/diffusers/utils/pil_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..15b97c73dcb7f85b22fcae95c641dde0123b5f05
--- /dev/null
+++ b/diffusers/utils/pil_utils.py
@@ -0,0 +1,48 @@
+import PIL.Image
+import PIL.ImageOps
+from packaging import version
+from PIL import Image
+
+
+if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
+ PIL_INTERPOLATION = {
+ "linear": PIL.Image.Resampling.BILINEAR,
+ "bilinear": PIL.Image.Resampling.BILINEAR,
+ "bicubic": PIL.Image.Resampling.BICUBIC,
+ "lanczos": PIL.Image.Resampling.LANCZOS,
+ "nearest": PIL.Image.Resampling.NEAREST,
+ }
+else:
+ PIL_INTERPOLATION = {
+ "linear": PIL.Image.LINEAR,
+ "bilinear": PIL.Image.BILINEAR,
+ "bicubic": PIL.Image.BICUBIC,
+ "lanczos": PIL.Image.LANCZOS,
+ "nearest": PIL.Image.NEAREST,
+ }
+
+
+def pt_to_pil(images):
+ """
+ Convert a torch image to a PIL image.
+ """
+ images = (images / 2 + 0.5).clamp(0, 1)
+ images = images.cpu().permute(0, 2, 3, 1).float().numpy()
+ images = numpy_to_pil(images)
+ return images
+
+
+def numpy_to_pil(images):
+ """
+ Convert a numpy image or a batch of images to a PIL image.
+ """
+ if images.ndim == 3:
+ images = images[None, ...]
+ images = (images * 255).round().astype("uint8")
+ if images.shape[-1] == 1:
+ # special case for grayscale (single channel) images
+ pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
+ else:
+ pil_images = [Image.fromarray(image) for image in images]
+
+ return pil_images
diff --git a/diffusers/utils/testing_utils.py b/diffusers/utils/testing_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..64eb3ac925e9240d30766547880c5dea2e0aeb43
--- /dev/null
+++ b/diffusers/utils/testing_utils.py
@@ -0,0 +1,602 @@
+import inspect
+import logging
+import multiprocessing
+import os
+import random
+import re
+import tempfile
+import unittest
+import urllib.parse
+from distutils.util import strtobool
+from io import BytesIO, StringIO
+from pathlib import Path
+from typing import List, Optional, Union
+
+import numpy as np
+import PIL.Image
+import PIL.ImageOps
+import requests
+from packaging import version
+
+from .import_utils import (
+ BACKENDS_MAPPING,
+ is_compel_available,
+ is_flax_available,
+ is_note_seq_available,
+ is_onnx_available,
+ is_opencv_available,
+ is_torch_available,
+ is_torch_version,
+ is_torchsde_available,
+)
+from .logging import get_logger
+
+
+global_rng = random.Random()
+
+logger = get_logger(__name__)
+
+if is_torch_available():
+ import torch
+
+ if "DIFFUSERS_TEST_DEVICE" in os.environ:
+ torch_device = os.environ["DIFFUSERS_TEST_DEVICE"]
+
+ available_backends = ["cuda", "cpu", "mps"]
+ if torch_device not in available_backends:
+ raise ValueError(
+ f"unknown torch backend for diffusers tests: {torch_device}. Available backends are:"
+ f" {available_backends}"
+ )
+ logger.info(f"torch_device overrode to {torch_device}")
+ else:
+ torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+ is_torch_higher_equal_than_1_12 = version.parse(
+ version.parse(torch.__version__).base_version
+ ) >= version.parse("1.12")
+
+ if is_torch_higher_equal_than_1_12:
+ # Some builds of torch 1.12 don't have the mps backend registered. See #892 for more details
+ mps_backend_registered = hasattr(torch.backends, "mps")
+ torch_device = "mps" if (mps_backend_registered and torch.backends.mps.is_available()) else torch_device
+
+
+def torch_all_close(a, b, *args, **kwargs):
+ if not is_torch_available():
+ raise ValueError("PyTorch needs to be installed to use this function.")
+ if not torch.allclose(a, b, *args, **kwargs):
+ assert False, f"Max diff is absolute {(a - b).abs().max()}. Diff tensor is {(a - b).abs()}."
+ return True
+
+
+def print_tensor_test(tensor, filename="test_corrections.txt", expected_tensor_name="expected_slice"):
+ test_name = os.environ.get("PYTEST_CURRENT_TEST")
+ if not torch.is_tensor(tensor):
+ tensor = torch.from_numpy(tensor)
+
+ tensor_str = str(tensor.detach().cpu().flatten().to(torch.float32)).replace("\n", "")
+ # format is usually:
+ # expected_slice = np.array([-0.5713, -0.3018, -0.9814, 0.04663, -0.879, 0.76, -1.734, 0.1044, 1.161])
+ output_str = tensor_str.replace("tensor", f"{expected_tensor_name} = np.array")
+ test_file, test_class, test_fn = test_name.split("::")
+ test_fn = test_fn.split()[0]
+ with open(filename, "a") as f:
+ print(";".join([test_file, test_class, test_fn, output_str]), file=f)
+
+
+def get_tests_dir(append_path=None):
+ """
+ Args:
+ append_path: optional path to append to the tests dir path
+ Return:
+ The full path to the `tests` dir, so that the tests can be invoked from anywhere. Optionally `append_path` is
+ joined after the `tests` dir if the former is provided.
+ """
+ # this function caller's __file__
+ caller__file__ = inspect.stack()[1][1]
+ tests_dir = os.path.abspath(os.path.dirname(caller__file__))
+
+ while not tests_dir.endswith("tests"):
+ tests_dir = os.path.dirname(tests_dir)
+
+ if append_path:
+ return os.path.join(tests_dir, append_path)
+ else:
+ return tests_dir
+
+
+def parse_flag_from_env(key, default=False):
+ try:
+ value = os.environ[key]
+ except KeyError:
+ # KEY isn't set, default to `default`.
+ _value = default
+ else:
+ # KEY is set, convert it to True or False.
+ try:
+ _value = strtobool(value)
+ except ValueError:
+ # More values are supported, but let's keep the message simple.
+ raise ValueError(f"If set, {key} must be yes or no.")
+ return _value
+
+
+_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
+_run_nightly_tests = parse_flag_from_env("RUN_NIGHTLY", default=False)
+
+
+def floats_tensor(shape, scale=1.0, rng=None, name=None):
+ """Creates a random float32 tensor"""
+ if rng is None:
+ rng = global_rng
+
+ total_dims = 1
+ for dim in shape:
+ total_dims *= dim
+
+ values = []
+ for _ in range(total_dims):
+ values.append(rng.random() * scale)
+
+ return torch.tensor(data=values, dtype=torch.float).view(shape).contiguous()
+
+
+def slow(test_case):
+ """
+ Decorator marking a test as slow.
+
+ Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them.
+
+ """
+ return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
+
+
+def nightly(test_case):
+ """
+ Decorator marking a test that runs nightly in the diffusers CI.
+
+ Nightly tests are skipped by default. Set the RUN_NIGHTLY environment variable to a truthy value to run them.
+
+ """
+ return unittest.skipUnless(_run_nightly_tests, "test is nightly")(test_case)
+
+
+def require_torch(test_case):
+ """
+ Decorator marking a test that requires PyTorch. These tests are skipped when PyTorch isn't installed.
+ """
+ return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case)
+
+
+def require_torch_2(test_case):
+ """
+ Decorator marking a test that requires PyTorch 2. These tests are skipped when it isn't installed.
+ """
+ return unittest.skipUnless(is_torch_available() and is_torch_version(">=", "2.0.0"), "test requires PyTorch 2")(
+ test_case
+ )
+
+
+def require_torch_gpu(test_case):
+ """Decorator marking a test that requires CUDA and PyTorch."""
+ return unittest.skipUnless(is_torch_available() and torch_device == "cuda", "test requires PyTorch+CUDA")(
+ test_case
+ )
+
+
+def skip_mps(test_case):
+ """Decorator marking a test to skip if torch_device is 'mps'"""
+ return unittest.skipUnless(torch_device != "mps", "test requires non 'mps' device")(test_case)
+
+
+def require_flax(test_case):
+ """
+ Decorator marking a test that requires JAX & Flax. These tests are skipped when one / both are not installed
+ """
+ return unittest.skipUnless(is_flax_available(), "test requires JAX & Flax")(test_case)
+
+
+def require_compel(test_case):
+ """
+ Decorator marking a test that requires compel: https://github.com/damian0815/compel. These tests are skipped when
+ the library is not installed.
+ """
+ return unittest.skipUnless(is_compel_available(), "test requires compel")(test_case)
+
+
+def require_onnxruntime(test_case):
+ """
+ Decorator marking a test that requires onnxruntime. These tests are skipped when onnxruntime isn't installed.
+ """
+ return unittest.skipUnless(is_onnx_available(), "test requires onnxruntime")(test_case)
+
+
+def require_note_seq(test_case):
+ """
+ Decorator marking a test that requires note_seq. These tests are skipped when note_seq isn't installed.
+ """
+ return unittest.skipUnless(is_note_seq_available(), "test requires note_seq")(test_case)
+
+
+def require_torchsde(test_case):
+ """
+ Decorator marking a test that requires torchsde. These tests are skipped when torchsde isn't installed.
+ """
+ return unittest.skipUnless(is_torchsde_available(), "test requires torchsde")(test_case)
+
+
+def load_numpy(arry: Union[str, np.ndarray], local_path: Optional[str] = None) -> np.ndarray:
+ if isinstance(arry, str):
+ # local_path = "/home/patrick_huggingface_co/"
+ if local_path is not None:
+ # local_path can be passed to correct images of tests
+ return os.path.join(local_path, "/".join([arry.split("/")[-5], arry.split("/")[-2], arry.split("/")[-1]]))
+ elif arry.startswith("http://") or arry.startswith("https://"):
+ response = requests.get(arry)
+ response.raise_for_status()
+ arry = np.load(BytesIO(response.content))
+ elif os.path.isfile(arry):
+ arry = np.load(arry)
+ else:
+ raise ValueError(
+ f"Incorrect path or url, URLs must start with `http://` or `https://`, and {arry} is not a valid path"
+ )
+ elif isinstance(arry, np.ndarray):
+ pass
+ else:
+ raise ValueError(
+ "Incorrect format used for numpy ndarray. Should be an url linking to an image, a local path, or a"
+ " ndarray."
+ )
+
+ return arry
+
+
+def load_pt(url: str):
+ response = requests.get(url)
+ response.raise_for_status()
+ arry = torch.load(BytesIO(response.content))
+ return arry
+
+
+def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
+ """
+ Loads `image` to a PIL Image.
+
+ Args:
+ image (`str` or `PIL.Image.Image`):
+ The image to convert to the PIL Image format.
+ Returns:
+ `PIL.Image.Image`:
+ A PIL Image.
+ """
+ if isinstance(image, str):
+ if image.startswith("http://") or image.startswith("https://"):
+ image = PIL.Image.open(requests.get(image, stream=True).raw)
+ elif os.path.isfile(image):
+ image = PIL.Image.open(image)
+ else:
+ raise ValueError(
+ f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path"
+ )
+ elif isinstance(image, PIL.Image.Image):
+ image = image
+ else:
+ raise ValueError(
+ "Incorrect format used for image. Should be an url linking to an image, a local path, or a PIL image."
+ )
+ image = PIL.ImageOps.exif_transpose(image)
+ image = image.convert("RGB")
+ return image
+
+
+def preprocess_image(image: PIL.Image, batch_size: int):
+ w, h = image.size
+ w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
+ image = image.resize((w, h), resample=PIL.Image.LANCZOS)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = np.vstack([image[None].transpose(0, 3, 1, 2)] * batch_size)
+ image = torch.from_numpy(image)
+ return 2.0 * image - 1.0
+
+
+def export_to_gif(image: List[PIL.Image.Image], output_gif_path: str = None) -> str:
+ if output_gif_path is None:
+ output_gif_path = tempfile.NamedTemporaryFile(suffix=".gif").name
+
+ image[0].save(
+ output_gif_path,
+ save_all=True,
+ append_images=image[1:],
+ optimize=False,
+ duration=100,
+ loop=0,
+ )
+ return output_gif_path
+
+
+def export_to_video(video_frames: List[np.ndarray], output_video_path: str = None) -> str:
+ if is_opencv_available():
+ import cv2
+ else:
+ raise ImportError(BACKENDS_MAPPING["opencv"][1].format("export_to_video"))
+ if output_video_path is None:
+ output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name
+
+ fourcc = cv2.VideoWriter_fourcc(*"mp4v")
+ h, w, c = video_frames[0].shape
+ video_writer = cv2.VideoWriter(output_video_path, fourcc, fps=8, frameSize=(w, h))
+ for i in range(len(video_frames)):
+ img = cv2.cvtColor(video_frames[i], cv2.COLOR_RGB2BGR)
+ video_writer.write(img)
+ return output_video_path
+
+
+def load_hf_numpy(path) -> np.ndarray:
+ if not path.startswith("http://") or path.startswith("https://"):
+ path = os.path.join(
+ "https://huggingface.co/datasets/fusing/diffusers-testing/resolve/main", urllib.parse.quote(path)
+ )
+
+ return load_numpy(path)
+
+
+# --- pytest conf functions --- #
+
+# to avoid multiple invocation from tests/conftest.py and examples/conftest.py - make sure it's called only once
+pytest_opt_registered = {}
+
+
+def pytest_addoption_shared(parser):
+ """
+ This function is to be called from `conftest.py` via `pytest_addoption` wrapper that has to be defined there.
+
+ It allows loading both `conftest.py` files at once without causing a failure due to adding the same `pytest`
+ option.
+
+ """
+ option = "--make-reports"
+ if option not in pytest_opt_registered:
+ parser.addoption(
+ option,
+ action="store",
+ default=False,
+ help="generate report files. The value of this option is used as a prefix to report names",
+ )
+ pytest_opt_registered[option] = 1
+
+
+def pytest_terminal_summary_main(tr, id):
+ """
+ Generate multiple reports at the end of test suite run - each report goes into a dedicated file in the current
+ directory. The report files are prefixed with the test suite name.
+
+ This function emulates --duration and -rA pytest arguments.
+
+ This function is to be called from `conftest.py` via `pytest_terminal_summary` wrapper that has to be defined
+ there.
+
+ Args:
+ - tr: `terminalreporter` passed from `conftest.py`
+ - id: unique id like `tests` or `examples` that will be incorporated into the final reports filenames - this is
+ needed as some jobs have multiple runs of pytest, so we can't have them overwrite each other.
+
+ NB: this function taps into a private _pytest API and while unlikely, it could break should
+ pytest do internal changes - also it calls default internal methods of terminalreporter which
+ can be hijacked by various `pytest-` plugins and interfere.
+
+ """
+ from _pytest.config import create_terminal_writer
+
+ if not len(id):
+ id = "tests"
+
+ config = tr.config
+ orig_writer = config.get_terminal_writer()
+ orig_tbstyle = config.option.tbstyle
+ orig_reportchars = tr.reportchars
+
+ dir = "reports"
+ Path(dir).mkdir(parents=True, exist_ok=True)
+ report_files = {
+ k: f"{dir}/{id}_{k}.txt"
+ for k in [
+ "durations",
+ "errors",
+ "failures_long",
+ "failures_short",
+ "failures_line",
+ "passes",
+ "stats",
+ "summary_short",
+ "warnings",
+ ]
+ }
+
+ # custom durations report
+ # note: there is no need to call pytest --durations=XX to get this separate report
+ # adapted from https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/runner.py#L66
+ dlist = []
+ for replist in tr.stats.values():
+ for rep in replist:
+ if hasattr(rep, "duration"):
+ dlist.append(rep)
+ if dlist:
+ dlist.sort(key=lambda x: x.duration, reverse=True)
+ with open(report_files["durations"], "w") as f:
+ durations_min = 0.05 # sec
+ f.write("slowest durations\n")
+ for i, rep in enumerate(dlist):
+ if rep.duration < durations_min:
+ f.write(f"{len(dlist)-i} durations < {durations_min} secs were omitted")
+ break
+ f.write(f"{rep.duration:02.2f}s {rep.when:<8} {rep.nodeid}\n")
+
+ def summary_failures_short(tr):
+ # expecting that the reports were --tb=long (default) so we chop them off here to the last frame
+ reports = tr.getreports("failed")
+ if not reports:
+ return
+ tr.write_sep("=", "FAILURES SHORT STACK")
+ for rep in reports:
+ msg = tr._getfailureheadline(rep)
+ tr.write_sep("_", msg, red=True, bold=True)
+ # chop off the optional leading extra frames, leaving only the last one
+ longrepr = re.sub(r".*_ _ _ (_ ){10,}_ _ ", "", rep.longreprtext, 0, re.M | re.S)
+ tr._tw.line(longrepr)
+ # note: not printing out any rep.sections to keep the report short
+
+ # use ready-made report funcs, we are just hijacking the filehandle to log to a dedicated file each
+ # adapted from https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/terminal.py#L814
+ # note: some pytest plugins may interfere by hijacking the default `terminalreporter` (e.g.
+ # pytest-instafail does that)
+
+ # report failures with line/short/long styles
+ config.option.tbstyle = "auto" # full tb
+ with open(report_files["failures_long"], "w") as f:
+ tr._tw = create_terminal_writer(config, f)
+ tr.summary_failures()
+
+ # config.option.tbstyle = "short" # short tb
+ with open(report_files["failures_short"], "w") as f:
+ tr._tw = create_terminal_writer(config, f)
+ summary_failures_short(tr)
+
+ config.option.tbstyle = "line" # one line per error
+ with open(report_files["failures_line"], "w") as f:
+ tr._tw = create_terminal_writer(config, f)
+ tr.summary_failures()
+
+ with open(report_files["errors"], "w") as f:
+ tr._tw = create_terminal_writer(config, f)
+ tr.summary_errors()
+
+ with open(report_files["warnings"], "w") as f:
+ tr._tw = create_terminal_writer(config, f)
+ tr.summary_warnings() # normal warnings
+ tr.summary_warnings() # final warnings
+
+ tr.reportchars = "wPpsxXEf" # emulate -rA (used in summary_passes() and short_test_summary())
+ with open(report_files["passes"], "w") as f:
+ tr._tw = create_terminal_writer(config, f)
+ tr.summary_passes()
+
+ with open(report_files["summary_short"], "w") as f:
+ tr._tw = create_terminal_writer(config, f)
+ tr.short_test_summary()
+
+ with open(report_files["stats"], "w") as f:
+ tr._tw = create_terminal_writer(config, f)
+ tr.summary_stats()
+
+ # restore:
+ tr._tw = orig_writer
+ tr.reportchars = orig_reportchars
+ config.option.tbstyle = orig_tbstyle
+
+
+# Taken from: https://github.com/huggingface/transformers/blob/3658488ff77ff8d45101293e749263acf437f4d5/src/transformers/testing_utils.py#L1787
+def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None):
+ """
+ To run a test in a subprocess. In particular, this can avoid (GPU) memory issue.
+
+ Args:
+ test_case (`unittest.TestCase`):
+ The test that will run `target_func`.
+ target_func (`Callable`):
+ The function implementing the actual testing logic.
+ inputs (`dict`, *optional*, defaults to `None`):
+ The inputs that will be passed to `target_func` through an (input) queue.
+ timeout (`int`, *optional*, defaults to `None`):
+ The timeout (in seconds) that will be passed to the input and output queues. If not specified, the env.
+ variable `PYTEST_TIMEOUT` will be checked. If still `None`, its value will be set to `600`.
+ """
+ if timeout is None:
+ timeout = int(os.environ.get("PYTEST_TIMEOUT", 600))
+
+ start_methohd = "spawn"
+ ctx = multiprocessing.get_context(start_methohd)
+
+ input_queue = ctx.Queue(1)
+ output_queue = ctx.JoinableQueue(1)
+
+ # We can't send `unittest.TestCase` to the child, otherwise we get issues regarding pickle.
+ input_queue.put(inputs, timeout=timeout)
+
+ process = ctx.Process(target=target_func, args=(input_queue, output_queue, timeout))
+ process.start()
+ # Kill the child process if we can't get outputs from it in time: otherwise, the hanging subprocess prevents
+ # the test to exit properly.
+ try:
+ results = output_queue.get(timeout=timeout)
+ output_queue.task_done()
+ except Exception as e:
+ process.terminate()
+ test_case.fail(e)
+ process.join(timeout=timeout)
+
+ if results["error"] is not None:
+ test_case.fail(f'{results["error"]}')
+
+
+class CaptureLogger:
+ """
+ Args:
+ Context manager to capture `logging` streams
+ logger: 'logging` logger object
+ Returns:
+ The captured output is available via `self.out`
+ Example:
+ ```python
+ >>> from diffusers import logging
+ >>> from diffusers.testing_utils import CaptureLogger
+
+ >>> msg = "Testing 1, 2, 3"
+ >>> logging.set_verbosity_info()
+ >>> logger = logging.get_logger("diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.py")
+ >>> with CaptureLogger(logger) as cl:
+ ... logger.info(msg)
+ >>> assert cl.out, msg + "\n"
+ ```
+ """
+
+ def __init__(self, logger):
+ self.logger = logger
+ self.io = StringIO()
+ self.sh = logging.StreamHandler(self.io)
+ self.out = ""
+
+ def __enter__(self):
+ self.logger.addHandler(self.sh)
+ return self
+
+ def __exit__(self, *exc):
+ self.logger.removeHandler(self.sh)
+ self.out = self.io.getvalue()
+
+ def __repr__(self):
+ return f"captured: {self.out}\n"
+
+
+def enable_full_determinism():
+ """
+ Helper function for reproducible behavior during distributed training. See
+ - https://pytorch.org/docs/stable/notes/randomness.html for pytorch
+ """
+ # Enable PyTorch deterministic mode. This potentially requires either the environment
+ # variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set,
+ # depending on the CUDA version, so we set them both here
+ os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+ os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
+ torch.use_deterministic_algorithms(True)
+
+ # Enable CUDNN deterministic mode
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+ torch.backends.cuda.matmul.allow_tf32 = False
+
+
+def disable_full_determinism():
+ os.environ["CUDA_LAUNCH_BLOCKING"] = "0"
+ os.environ["CUBLAS_WORKSPACE_CONFIG"] = ""
+ torch.use_deterministic_algorithms(False)
diff --git a/diffusers/utils/torch_utils.py b/diffusers/utils/torch_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f64bce25e78d5212696f4b06b767d338599670a
--- /dev/null
+++ b/diffusers/utils/torch_utils.py
@@ -0,0 +1,84 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+PyTorch utilities: Utilities related to PyTorch
+"""
+from typing import List, Optional, Tuple, Union
+
+from . import logging
+from .import_utils import is_torch_available, is_torch_version
+
+
+if is_torch_available():
+ import torch
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+try:
+ from torch._dynamo import allow_in_graph as maybe_allow_in_graph
+except (ImportError, ModuleNotFoundError):
+
+ def maybe_allow_in_graph(cls):
+ return cls
+
+
+def randn_tensor(
+ shape: Union[Tuple, List],
+ generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
+ device: Optional["torch.device"] = None,
+ dtype: Optional["torch.dtype"] = None,
+ layout: Optional["torch.layout"] = None,
+):
+ """A helper function to create random tensors on the desired `device` with the desired `dtype`. When
+ passing a list of generators, you can seed each batch size individually. If CPU generators are passed, the tensor
+ is always created on the CPU.
+ """
+ # device on which tensor is created defaults to device
+ rand_device = device
+ batch_size = shape[0]
+
+ layout = layout or torch.strided
+ device = device or torch.device("cpu")
+
+ if generator is not None:
+ gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type
+ if gen_device_type != device.type and gen_device_type == "cpu":
+ rand_device = "cpu"
+ if device != "mps":
+ logger.info(
+ f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
+ f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
+ f" slighly speed up this function by passing a generator that was created on the {device} device."
+ )
+ elif gen_device_type != device.type and gen_device_type == "cuda":
+ raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")
+
+ if isinstance(generator, list):
+ shape = (1,) + shape[1:]
+ latents = [
+ torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout)
+ for i in range(batch_size)
+ ]
+ latents = torch.cat(latents, dim=0).to(device)
+ else:
+ latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)
+
+ return latents
+
+
+def is_compiled_module(module):
+ """Check whether the module was compiled with torch.compile()"""
+ if is_torch_version("<", "2.0.0") or not hasattr(torch, "_dynamo"):
+ return False
+ return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)
diff --git a/docs/annotator.md b/docs/annotator.md
new file mode 100644
index 0000000000000000000000000000000000000000..5111678ff11407da19a920e0c500ee0480bbb155
--- /dev/null
+++ b/docs/annotator.md
@@ -0,0 +1,49 @@
+# Automatic Annotations
+
+We provide gradio examples to obtain annotations that are aligned to our pretrained production-ready models.
+
+Just run
+
+ python gradio_annotator.py
+
+Since everyone has different habits for organizing their datasets, we do not hard-code any scripts for batch processing. But "gradio_annotator.py" is written in a very readable way, and modifying it to annotate your own images should be easy.
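+
+For example, a minimal batch-processing sketch for the Canny annotator could look like this (purely illustrative: the folder names, file pattern, and thresholds are assumptions, and you can swap the raw OpenCV call for the detector classes used in "gradio_annotator.py"):
+
+```python
+import glob
+import os
+
+import cv2
+
+input_dir = './my_images'        # assumed input folder
+output_dir = './my_annotations'  # assumed output folder
+os.makedirs(output_dir, exist_ok=True)
+
+for path in sorted(glob.glob(os.path.join(input_dir, '*.png'))):
+    img = cv2.imread(path)
+    # Canny edge map; 100/200 are illustrative thresholds.
+    edge = cv2.Canny(img, 100, 200)
+    cv2.imwrite(os.path.join(output_dir, os.path.basename(path)), edge)
+```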
+
+In the gradio UI of "gradio_annotator.py" we have the following interfaces:
+
+### Canny Edge
+
+Be careful about "black edge and white background" or "white edge and black background".
+
+
+
+### HED Edge
+
+Be careful about "black edge and white background" or "white edge and black background".
+
+
+
+### MLSD Edge
+
+Be careful about "black edge and white background" or "white edge and black background".
+
+
+
+### MIDAS Depth and Normal
+
+Be careful about RGB or BGR in normal maps.
+
+
+
+### Openpose
+
+Be careful about RGB or BGR in pose maps.
+
+For our production-ready model, the hand pose option is turned off.
+
+
+
+### Uniformer Segmentation
+
+Be careful about RGB or BGR in segmentation maps.
+
+
diff --git a/docs/faq.md b/docs/faq.md
new file mode 100644
index 0000000000000000000000000000000000000000..07afd7aeacb51cac4c8bac3b601fe23a2842c4d3
--- /dev/null
+++ b/docs/faq.md
@@ -0,0 +1,21 @@
+# FAQs
+
+**Q:** If the weight of a conv layer is zero, the gradient will also be zero, and the network will not learn anything. Why does "zero convolution" work?
+
+**A:** This is wrong. Let us consider a very simple function
+
+$$y=wx+b$$
+
+and we have
+
+$$\partial y/\partial w=x, \partial y/\partial x=w, \partial y/\partial b=1$$
+
+and if $w=0$ and $x \neq 0$, then
+
+$$\partial y/\partial w \neq 0, \partial y/\partial x=0, \partial y/\partial b\neq 0$$
+
+which means as long as $x \neq 0$, one gradient descent iteration will make $w$ non-zero. Then
+
+$$\partial y/\partial x\neq 0$$
+
+so that the zero convolutions will progressively become a common conv layer with non-zero weights.
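+
+If you want to check this numerically, here is a minimal sketch (purely illustrative, not the actual ControlNet code):
+
+```python
+import torch
+
+x = torch.tensor(3.0)                   # non-zero input
+w = torch.zeros(1, requires_grad=True)  # "zero" weight
+b = torch.zeros(1, requires_grad=True)  # "zero" bias
+
+y = w * x + b
+y.backward()
+
+print(w.grad)  # tensor([3.]) -> non-zero, so one gradient step makes w non-zero
+print(b.grad)  # tensor([1.]) -> also non-zero
+```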
diff --git a/docs/low_vram.md b/docs/low_vram.md
new file mode 100644
index 0000000000000000000000000000000000000000..784964c78d5074ed9d318456d7c35f30a81f04ed
--- /dev/null
+++ b/docs/low_vram.md
@@ -0,0 +1,15 @@
+# Enable Low VRAM Mode
+
+If you are using an 8GB GPU (or if you want a larger batch size), please open "config.py", and then set
+
+```python
+save_memory = True
+```
+
+This feature is still being tested - not all graphics cards are guaranteed to work.
+
+But it should be neat, as I can now diffuse at a batch size of 12.
+
+(prompt "man")
+
+
diff --git a/docs/train.md b/docs/train.md
new file mode 100644
index 0000000000000000000000000000000000000000..fa773925e28c4100df4bf74f0536d432554db806
--- /dev/null
+++ b/docs/train.md
@@ -0,0 +1,276 @@
+# Train a ControlNet to Control SD
+
+You are here because you want to control SD in your own way, maybe you have an idea for your perfect research project, and you will annotate some data or have already annotated your own dataset automatically or manually. Herein, the control can be anything that can be converted to images, such as edges, keypoints, segments, etc.
+
+Before moving on to your own dataset, we highly recommend first trying the toy dataset, Fill50K, as a sanity check. This will help you get a "feeling" for the training: you will learn how long it takes for the model to converge, whether your device can complete the training in an acceptable amount of time, and what it "feels" like when the model converges.
+
+We hope that after you read this page, you will find that training a ControlNet is as easy as (or easier than) training a pix2pix.
+
+## Step 0 - Design your control
+
+Let us take a look at a very simple task to control SD to fill color in circles.
+
+
+
+This is simple: we want to control SD to fill a circle with colors, and the prompt contains some description of our target.
+
+Stable diffusion is trained on billions of images, and it already knows what is "cyan", what is "circle", what is "pink", and what is "background".
+
+But it does not know the meaning of that "Control Image (Source Image)". Our target is to let it know.
+
+## Step 1 - Get a dataset
+
+Just download the Fill50K dataset from [our huggingface page](https://huggingface.co/lllyasviel/ControlNet) (training/fill50k.zip, the file is only 200M!). Make sure that the data is decompressed as
+
+ ControlNet/training/fill50k/prompt.json
+ ControlNet/training/fill50k/source/X.png
+ ControlNet/training/fill50k/target/X.png
+
+In the folder "fill50k/source", you will have 50k images of circle lines.
+
+
+
+In the folder "fill50k/target", you will have 50k images of filled circles.
+
+
+
+In the "fill50k/prompt.json", you will have their filenames and prompts. Each prompt is like "a balabala color circle in some other color background."
+
+
+
+## Step 2 - Load the dataset
+
+Then you need to write a simple script to read this dataset for pytorch. (In fact we have written it for you in "tutorial_dataset.py".)
+
+```python
+import json
+import cv2
+import numpy as np
+
+from torch.utils.data import Dataset
+
+
+class MyDataset(Dataset):
+ def __init__(self):
+ self.data = []
+ with open('./training/fill50k/prompt.json', 'rt') as f:
+ for line in f:
+ self.data.append(json.loads(line))
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, idx):
+ item = self.data[idx]
+
+ source_filename = item['source']
+ target_filename = item['target']
+ prompt = item['prompt']
+
+ source = cv2.imread('./training/fill50k/' + source_filename)
+ target = cv2.imread('./training/fill50k/' + target_filename)
+
+ # Do not forget that OpenCV read images in BGR order.
+ source = cv2.cvtColor(source, cv2.COLOR_BGR2RGB)
+ target = cv2.cvtColor(target, cv2.COLOR_BGR2RGB)
+
+ # Normalize source images to [0, 1].
+ source = source.astype(np.float32) / 255.0
+
+ # Normalize target images to [-1, 1].
+ target = (target.astype(np.float32) / 127.5) - 1.0
+
+ return dict(jpg=target, txt=prompt, hint=source)
+
+```
+
+This will make your dataset into an array-like object in python. You can test this dataset simply by accessing the array, like this
+
+```python
+from tutorial_dataset import MyDataset
+
+dataset = MyDataset()
+print(len(dataset))
+
+item = dataset[1234]
+jpg = item['jpg']
+txt = item['txt']
+hint = item['hint']
+print(txt)
+print(jpg.shape)
+print(hint.shape)
+
+```
+
+The outputs of this simple test on my machine are
+
+ 50000
+ burly wood circle with orange background
+ (512, 512, 3)
+ (512, 512, 3)
+
+And this code is in "tutorial_dataset_test.py".
+
+In this way, the dataset is an array-like object with 50000 items. Each item is a dict with three entries: "jpg", "txt", and "hint". The "jpg" is the target image, the "hint" is the control image, and the "txt" is the prompt.
+
+Do not ask us why we use these three names - this is related to the dark history of a library called LDM.
+
+## Step 3 - What SD model do you want to control?
+
+Then you need to decide which Stable Diffusion Model you want to control. In this example, we will just use standard SD1.5. You can download it from the [official page of Stability](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main). You want the file ["v1-5-pruned.ckpt"](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main).
+
+(Or ["v2-1_512-ema-pruned.ckpt"](https://huggingface.co/stabilityai/stable-diffusion-2-1-base/tree/main) if you are using SD2.)
+
+Then you need to attach a control net to the SD model. The architecture is
+
+
+
+Note that all weights inside the ControlNet are also copied from SD so that no layer is trained from scratch, and you are still finetuning the entire model.
+
+We provide a simple script for you to achieve this easily. If your SD filename is "./models/v1-5-pruned.ckpt" and you want the script to save the processed model (SD+ControlNet) at location "./models/control_sd15_ini.ckpt", you can just run:
+
+ python tool_add_control.py ./models/v1-5-pruned.ckpt ./models/control_sd15_ini.ckpt
+
+Or if you are using SD2:
+
+ python tool_add_control_sd21.py ./models/v2-1_512-ema-pruned.ckpt ./models/control_sd21_ini.ckpt
+
+You may also use other filenames as long as the command is "python tool_add_control.py input_path output_path".
+
+This is the correct output from my machine:
+
+
+
+## Step 4 - Train!
+
+Happy! We finally come to the most exciting part: training!
+
+The training code in "tutorial_train.py" is actually surprisingly simple:
+
+```python
+import pytorch_lightning as pl
+from torch.utils.data import DataLoader
+from tutorial_dataset import MyDataset
+from cldm.logger import ImageLogger
+from cldm.model import create_model, load_state_dict
+
+
+# Configs
+resume_path = './models/control_sd15_ini.ckpt'
+batch_size = 4
+logger_freq = 300
+learning_rate = 1e-5
+sd_locked = True
+only_mid_control = False
+
+
+# First use cpu to load models. Pytorch Lightning will automatically move it to GPUs.
+model = create_model('./models/cldm_v15.yaml').cpu()
+model.load_state_dict(load_state_dict(resume_path, location='cpu'))
+model.learning_rate = learning_rate
+model.sd_locked = sd_locked
+model.only_mid_control = only_mid_control
+
+
+# Misc
+dataset = MyDataset()
+dataloader = DataLoader(dataset, num_workers=0, batch_size=batch_size, shuffle=True)
+logger = ImageLogger(batch_frequency=logger_freq)
+trainer = pl.Trainer(gpus=1, precision=32, callbacks=[logger])
+
+
+# Train!
+trainer.fit(model, dataloader)
+
+```
+(or "tutorial_train_sd21.py" if you are using SD2)
+
+Thanks to our organized dataset pytorch object and the power of pytorch_lightning, the entire code is just super short.
+
+Now, you may take a look at [Pytorch Lightning Official DOC](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.trainer.trainer.Trainer.html#trainer) to find out how to enable many useful features like gradient accumulation, multiple GPU training, accelerated dataset loading, flexible checkpoint saving, etc. All these only need about one line of code. Great!
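+
+For instance, here is a hedged sketch that adds multi-GPU training and periodic checkpoint saving on top of the snippet above (it reuses `model`, `dataloader`, and the `ImageLogger` instance `logger`; the argument names are from pytorch_lightning 1.5 as pinned in "environment.yaml", and every concrete value is just an illustration):
+
+```python
+from pytorch_lightning.callbacks import ModelCheckpoint
+
+# Save a checkpoint every 1000 training steps (illustrative values).
+checkpoint_callback = ModelCheckpoint(
+    dirpath='./checkpoints',
+    every_n_train_steps=1000,
+    save_top_k=-1,
+)
+
+trainer = pl.Trainer(
+    gpus=2,            # multiple GPU training
+    precision=32,
+    strategy='ddp',    # distributed data parallel across the two GPUs
+    callbacks=[logger, checkpoint_callback],
+)
+trainer.fit(model, dataloader)
+```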
+
+Note that if you run into OOM errors, perhaps you need to enable [Low VRAM mode](low_vram.md), and perhaps you also need to use a smaller batch size with gradient accumulation. Or you may also want to use some “advanced” tricks like sliced attention or xformers. For example:
+
+```python
+# Configs
+batch_size = 1
+
+# Misc
+trainer = pl.Trainer(gpus=1, precision=32, callbacks=[logger], accumulate_grad_batches=4) # But this will be 4x slower
+```
+
+Note that training with an 8 GB laptop GPU is challenging. We will need some GPU memory optimization at least as good as automatic1111’s UI. This may require expert modifications to the code.
+
+### Screenshots
+
+The training is fast. After 4000 steps (batch size 4, learning rate 1e-5, about 50 minutes on PCIE 40G), the results on my machine (in an output folder "image_log") are
+
+Control:
+
+
+
+Prompt:
+
+
+
+Prediction:
+
+
+
+Ground Truth:
+
+
+
+Note that the SD's capability is preserved. Even when training on this highly aligned dataset, it still draws some random textures and those snow decorations. (Besides, note that the ground truth looks a bit modified because it is converted from SD's latent image.)
+
+Larger batch size and longer training will further improve this. Adequate training will make the filling perfect.
+
+Of course, training SD to fill circles is meaningless, but this is a successful beginning of your story.
+
+Let us work together to control large models more and more.
+
+## Other options
+
+Beyond standard things, we also provide two important parameters "sd_locked" and "only_mid_control" that you need to know.
+
+### only_mid_control
+
+By default, only_mid_control is False. When it is True, you will train the below architecture.
+
+
+
+This can be helpful when your computation power is limited and you want to speed up training, or when you want to facilitate "global" context learning. Note that you can pause training, set it to True, resume training, then pause again, set it back, and resume again.
+
+If your computation device is good, perhaps you do not need this. But I also know some artists are willing to train a model on their laptop for a month - in that case, perhaps this option can be useful.
+
+### sd_locked
+
+By default, sd_locked is True. When it is False, you will train the below architecture.
+
+
+
+This will unlock some layers in SD and you will train them as a whole.
+
+This option is DANGEROUS! If your dataset is not good enough, this may downgrade the capability of your SD model.
+
+However, this option is also very useful when you are training on images with some specific style, or when you are training with special datasets (like a medical dataset with X-ray images or a geographic dataset with lots of Google Maps imagery). You can understand this as simultaneously training the ControlNet and something like a DreamBooth.
+
+Also, if your dataset is large, you may want to end the training with a few thousand steps with those layers unlocked. This usually improves the "problem-specific" solutions a little. You may try it yourself to feel the difference.
+
+Also, if you unlock some original layers, you may want a lower learning rate, like 2e-6.
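+
+As a sketch, the "Configs" block of "tutorial_train.py" under these settings would look something like this (the learning rate is just the value suggested above, not a tuned recommendation):
+
+```python
+# Configs (when unlocking some original SD layers)
+sd_locked = False
+only_mid_control = False
+learning_rate = 2e-6  # lower learning rate for the unlocked layers
+
+model.sd_locked = sd_locked
+model.only_mid_control = only_mid_control
+model.learning_rate = learning_rate
+```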
+
+## More Consideration: Sudden Converge Phenomenon and Gradient Accumulation
+
+
+
+Because we use zero convolutions, the SD should always be able to predict meaningful images. (If it cannot, the training has already failed.)
+
+You will always find that at some iterations, the model "suddenly" becomes able to fit some training conditions. This means that you will get a basically usable model at about 3k to 7k steps (further training will improve it, but the model after the first "sudden converge" should be basically functional).
+
+Note that 3k to 7k steps is not very many, and you should consider a larger batch size rather than more training steps. If you can observe the "sudden converge" at the 3k step using batch size 4, then, rather than training it for 300k further steps, a better idea is to use 100× gradient accumulation to re-train those 3k steps with a 100× batch size. Note that perhaps we should not do this *too* extremely (perhaps 100× accumulation is too extreme), but you should consider that, since the "sudden converge" will *always* happen at that certain point, getting a better convergence there is more important.
+
+Because that "sudden converge" always happens, lets say "sudden converge" will happen at 3k step and our money can optimize 90k step, then we have two options: (1) train 3k steps, sudden converge, then train 87k steps. (2) 30x gradient accumulation, train 3k steps (90k real computation steps), then sudden converge.
+
+In my experiments, (2) is usually better than (1). However, in real cases, perhaps you may need to balance the steps before and after the "sudden converge" on your own to find a balance. The training after "sudden converge" is also important.
+
+But usually, if your logic batch size is already bigger than 256, then further extending the batch size is not very meaningful. In that case, perhaps a better idea is to train more steps. I tried some "common" logic batch size at 64 or 96 or 128 (by gradient accumulation), it seems that many complicated conditions can be solved very well already.
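+
+As a concrete example of the arithmetic (illustrative numbers only): with a per-step batch size of 4 and 16× gradient accumulation, the logical batch size is 4 × 16 = 64:
+
+```python
+batch_size = 4                # per-step batch size
+accumulate_grad_batches = 16  # gradient accumulation factor
+
+# Logical batch size per optimizer update: 4 * 16 = 64
+trainer = pl.Trainer(gpus=1, precision=32, callbacks=[logger],
+                     accumulate_grad_batches=accumulate_grad_batches)
+```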
diff --git a/environment.yaml b/environment.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..91463f0fb17145d7649e298819811ac9a21d6b93
--- /dev/null
+++ b/environment.yaml
@@ -0,0 +1,35 @@
+name: control
+channels:
+ - pytorch
+ - defaults
+dependencies:
+ - python=3.8.5
+ - pip=20.3
+ - cudatoolkit=11.3
+ - pytorch=1.12.1
+ - torchvision=0.13.1
+ - numpy=1.23.1
+ - pip:
+ - gradio==3.16.2
+ - albumentations==1.3.0
+ - opencv-contrib-python==4.3.0.36
+ - imageio==2.9.0
+ - imageio-ffmpeg==0.4.2
+ - pytorch-lightning==1.5.0
+ - omegaconf==2.1.1
+ - test-tube>=0.7.5
+ - streamlit==1.12.1
+ - einops==0.3.0
+ - transformers==4.19.2
+ - webdataset==0.2.5
+ - kornia==0.6
+ - open_clip_torch==2.0.2
+ - invisible-watermark>=0.1.5
+ - streamlit-drawable-canvas==0.8.0
+ - torchmetrics==0.6.0
+ - timm==0.6.12
+ - addict==2.4.0
+ - yapf==0.32.0
+ - prettytable==3.6.0
+ - safetensors==0.2.7
+ - basicsr==1.4.2
diff --git a/font/DejaVuSans.ttf b/font/DejaVuSans.ttf
new file mode 100644
index 0000000000000000000000000000000000000000..e5f7eecce43be41ff0703ed99e1553029b849f14
Binary files /dev/null and b/font/DejaVuSans.ttf differ
diff --git a/github_page/a1.png b/github_page/a1.png
new file mode 100644
index 0000000000000000000000000000000000000000..cb3312c9db173924d4bf99d13fc8d06086f6cb79
Binary files /dev/null and b/github_page/a1.png differ
diff --git a/github_page/a2.png b/github_page/a2.png
new file mode 100644
index 0000000000000000000000000000000000000000..efa1ee82e5c76ec6b5fe9f19fdddbe48575c2924
Binary files /dev/null and b/github_page/a2.png differ
diff --git a/github_page/a3.png b/github_page/a3.png
new file mode 100644
index 0000000000000000000000000000000000000000..521a093b08d587b8dec4f7f3d7fbd2db32ee2c43
Binary files /dev/null and b/github_page/a3.png differ
diff --git a/github_page/a4.png b/github_page/a4.png
new file mode 100644
index 0000000000000000000000000000000000000000..d9fc6ee817d79370ddce9d7e57a7f9a63010df73
Binary files /dev/null and b/github_page/a4.png differ
diff --git a/github_page/a5.png b/github_page/a5.png
new file mode 100644
index 0000000000000000000000000000000000000000..1f4fdddb30a3dedb1b2f194ca8009c2704313cd2
Binary files /dev/null and b/github_page/a5.png differ
diff --git a/github_page/a6.png b/github_page/a6.png
new file mode 100644
index 0000000000000000000000000000000000000000..75555245336b3ba6de7d6abe554008e9da199b40
Binary files /dev/null and b/github_page/a6.png differ
diff --git a/github_page/control.pdf b/github_page/control.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..a4eb41d88516a815097da3a8e2e0b99d0dca6794
--- /dev/null
+++ b/github_page/control.pdf
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ed468ba5e8e20f250a6949643398c18987dd43e48ebb5f1d6b81a134ec10f90e
+size 21869744
diff --git a/github_page/ex1.jpg b/github_page/ex1.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..89df766784b0e62481ca90d713db1574f8c7990f
Binary files /dev/null and b/github_page/ex1.jpg differ
diff --git a/github_page/he.png b/github_page/he.png
new file mode 100644
index 0000000000000000000000000000000000000000..7af7f18ecbedd3449f2dd4f92c825bf2a2c9c8f4
Binary files /dev/null and b/github_page/he.png differ
diff --git a/github_page/multi.png b/github_page/multi.png
new file mode 100644
index 0000000000000000000000000000000000000000..b7e8c6103c82d11137ec646588fa899dc66c2ea4
Binary files /dev/null and b/github_page/multi.png differ
diff --git a/github_page/multi2.png b/github_page/multi2.png
new file mode 100644
index 0000000000000000000000000000000000000000..10d6b17655a828ec1165feca89193d9cb5566057
Binary files /dev/null and b/github_page/multi2.png differ
diff --git a/github_page/p1.png b/github_page/p1.png
new file mode 100644
index 0000000000000000000000000000000000000000..c9f8cc8e27977f146ea44b6d2fe1757acaa4e091
Binary files /dev/null and b/github_page/p1.png differ
diff --git a/github_page/p10.png b/github_page/p10.png
new file mode 100644
index 0000000000000000000000000000000000000000..ca8d5f6175bc6bcb7f251b056b08858813470dd1
--- /dev/null
+++ b/github_page/p10.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ea30b36f988db5e787e3d061a1745bc94a2991bb76e38325459e79326ae449be
+size 1558223
diff --git a/github_page/p11.png b/github_page/p11.png
new file mode 100644
index 0000000000000000000000000000000000000000..eb0ecfc2f81a4a65ce3422c5c968e2222bed8824
--- /dev/null
+++ b/github_page/p11.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2bba5160851111c6758fd0e34fa543369deb443ed6530b93b84d547384a19aef
+size 1025272
diff --git a/github_page/p12.png b/github_page/p12.png
new file mode 100644
index 0000000000000000000000000000000000000000..1fd4bc0005eff38ec6d90e9b8dc36eba975bffba
--- /dev/null
+++ b/github_page/p12.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a609ddf0e88c4efbb5ba324935a521f57de94a6ebbc2b24353a110a2cdfbef3f
+size 1037234
diff --git a/github_page/p13.png b/github_page/p13.png
new file mode 100644
index 0000000000000000000000000000000000000000..8ac6bc4400bf71c7d709bfcda0e049f8a4f13115
--- /dev/null
+++ b/github_page/p13.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e5381027bb60cecb3d20a5a33c614c2bbf313090384e0f3b20117ac58d926062
+size 1140361
diff --git a/github_page/p14.png b/github_page/p14.png
new file mode 100644
index 0000000000000000000000000000000000000000..88c1834789cfe1eb67acb395a227e330ff946217
--- /dev/null
+++ b/github_page/p14.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8db031dce5ce4be02cc4d3f82fedd5fd4ac1a595c7fc8c9ba36978d475315aad
+size 1398887
diff --git a/github_page/p15.png b/github_page/p15.png
new file mode 100644
index 0000000000000000000000000000000000000000..1e4b8e4152a326f2bf6cff4a7f95611007a4c5a0
--- /dev/null
+++ b/github_page/p15.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:748a7879c839e329ba0921aa4015753c5879811cee2976dcb6636c42fee674db
+size 1366530
diff --git a/github_page/p16b.png b/github_page/p16b.png
new file mode 100644
index 0000000000000000000000000000000000000000..6556de32514b3c5c57c238ee59577071450e210a
--- /dev/null
+++ b/github_page/p16b.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6d07b59462d44e9d442a208bb0d457cf7e5ba41059fd1ebd5ca30d23e982e430
+size 1399940
diff --git a/github_page/p17.png b/github_page/p17.png
new file mode 100644
index 0000000000000000000000000000000000000000..de60192873a6e7c1963c7350121b56260390c405
--- /dev/null
+++ b/github_page/p17.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b39ffbc341b431121ac82ae8cd235afbcd7c5e30f57eef58f00d4bb8dcbda6af
+size 1446761
diff --git a/github_page/p18.png b/github_page/p18.png
new file mode 100644
index 0000000000000000000000000000000000000000..3278afee8800111bec6c246515de76086b8864cd
--- /dev/null
+++ b/github_page/p18.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:063bea46c7d9eb7757d5b5674e0cb84f054c96472a38553d4358ad8ab86358b8
+size 1057593
diff --git a/github_page/p19.png b/github_page/p19.png
new file mode 100644
index 0000000000000000000000000000000000000000..3648956ee0370a748e1f2adfdc62b7182a8a568a
--- /dev/null
+++ b/github_page/p19.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:502a2242ad1703e47a5be9398fbd301d3c287f21b85d92f5be314477673b62ac
+size 1113879
diff --git a/github_page/p2.png b/github_page/p2.png
new file mode 100644
index 0000000000000000000000000000000000000000..de2c6353ff834c82bc63ebbfacc8954d9e8ebec7
--- /dev/null
+++ b/github_page/p2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:150a4298489a26844f44b51e365b07c454102f83f571133326c1121b1a939f40
+size 1318561
diff --git a/github_page/p20.png b/github_page/p20.png
new file mode 100644
index 0000000000000000000000000000000000000000..13b9c645331b0ee8d0f9a69122f26074caa6246e
--- /dev/null
+++ b/github_page/p20.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9b0511aaba8b1dfcfcc87a684c0c1f30339d88519aba475be356d0dc30090deb
+size 1341061
diff --git a/github_page/p21.png b/github_page/p21.png
new file mode 100644
index 0000000000000000000000000000000000000000..14bdbd5ff8e4c159ebb54b787306f3147948346d
--- /dev/null
+++ b/github_page/p21.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:709cf6a47616a292c72a03fe046d6158a3004cd8f972223f24a867328908a6a6
+size 1810899
diff --git a/github_page/p3.png b/github_page/p3.png
new file mode 100644
index 0000000000000000000000000000000000000000..927efb593523cb1554f1c736ae9e539b5cd6b130
--- /dev/null
+++ b/github_page/p3.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f5de3f7e26f159237f15439c7dc4b13b85f8f2214b4ec5d3577018bc4272e27b
+size 1203738
diff --git a/github_page/p4.png b/github_page/p4.png
new file mode 100644
index 0000000000000000000000000000000000000000..6d5e845b046771398f0d2fa3d8d475f5df92fcfc
--- /dev/null
+++ b/github_page/p4.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a038d132732106731ce899e6d924a34bbcf19d89318b768548683478c1a42590
+size 1159029
diff --git a/github_page/p5.png b/github_page/p5.png
new file mode 100644
index 0000000000000000000000000000000000000000..aaa151beb0abfbd87b46b0d663ada50ae4af1f81
--- /dev/null
+++ b/github_page/p5.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b0584509963f00842a98a42b4edcd9c84c0310753e6e8861979f87d7f0f27fc2
+size 1268753
diff --git a/github_page/p6.png b/github_page/p6.png
new file mode 100644
index 0000000000000000000000000000000000000000..77eb22c8e7762ede1786196bcb7634fc522ed128
--- /dev/null
+++ b/github_page/p6.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4258437186ecc95d4ec63728afdc43af902b1afaf3719e5cccba2fdfc755c0c4
+size 1454815
diff --git a/github_page/p7.png b/github_page/p7.png
new file mode 100644
index 0000000000000000000000000000000000000000..3241d45777fe8d6416cb97e051e74a203d34d0f2
--- /dev/null
+++ b/github_page/p7.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:557e0f50015a0a681a29eb7d9747a1a251c87196eb4bdfa0bfb4cbede97d3123
+size 1671023
diff --git a/github_page/p8.png b/github_page/p8.png
new file mode 100644
index 0000000000000000000000000000000000000000..ab72d0dde930fbd3b173f3efc3e4a64532e22693
--- /dev/null
+++ b/github_page/p8.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5d53b77b5dc70ddc5adc0a3ffa354cefdfd422637bf6c6c48e99f7f2c2598851
+size 1555201
diff --git a/github_page/p9.png b/github_page/p9.png
new file mode 100644
index 0000000000000000000000000000000000000000..03e1233f888a74f9edbd440136174cad2d2909d2
--- /dev/null
+++ b/github_page/p9.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fcf3b888b15e2d469b204659e9bcb9d8c12b7d5ca9be4411fb3b25c2e958c074
+size 1428057
diff --git a/github_page/ram12.jpg b/github_page/ram12.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..208026eda25f97d09dcf80a020a16e5de1e77230
Binary files /dev/null and b/github_page/ram12.jpg differ
diff --git a/github_page/sd.png b/github_page/sd.png
new file mode 100644
index 0000000000000000000000000000000000000000..2eb3a8567e5278cae0d7a29e65e965d09cba4797
Binary files /dev/null and b/github_page/sd.png differ
diff --git a/github_page/t/gt.png b/github_page/t/gt.png
new file mode 100644
index 0000000000000000000000000000000000000000..df81bcec5be141763af561dded39e766a7de137d
Binary files /dev/null and b/github_page/t/gt.png differ
diff --git a/github_page/t/ip.png b/github_page/t/ip.png
new file mode 100644
index 0000000000000000000000000000000000000000..cc980a22d1e14903d9bd556b9a6c21fbfecbcc63
Binary files /dev/null and b/github_page/t/ip.png differ
diff --git a/github_page/t/op.png b/github_page/t/op.png
new file mode 100644
index 0000000000000000000000000000000000000000..0e51e87cec204f1697e755c8c644f53c7f3dd8a4
--- /dev/null
+++ b/github_page/t/op.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c9833659efcb7353b150896028bf70b6123ca132b658ee65d74d2b40e233aa6b
+size 1644715
diff --git a/github_page/t/t.png b/github_page/t/t.png
new file mode 100644
index 0000000000000000000000000000000000000000..404fad5d89abb07da21686e27b5cd2603381fdfc
Binary files /dev/null and b/github_page/t/t.png differ
diff --git a/github_page/t1.png b/github_page/t1.png
new file mode 100644
index 0000000000000000000000000000000000000000..4e6387a279cc48dce605c3732a08bb30360b6116
Binary files /dev/null and b/github_page/t1.png differ
diff --git a/github_page/t2.png b/github_page/t2.png
new file mode 100644
index 0000000000000000000000000000000000000000..a914399475f0d5ba585e0b2c4e565a4b6a4fe42a
Binary files /dev/null and b/github_page/t2.png differ
diff --git a/github_page/t3.png b/github_page/t3.png
new file mode 100644
index 0000000000000000000000000000000000000000..5d27510f96216187b6c08277dacf9cc5b7ab6094
Binary files /dev/null and b/github_page/t3.png differ
diff --git a/github_page/t4.png b/github_page/t4.png
new file mode 100644
index 0000000000000000000000000000000000000000..accc7f76d921a64360c98c59f575332dab2202ea
Binary files /dev/null and b/github_page/t4.png differ
diff --git a/github_page/t5.png b/github_page/t5.png
new file mode 100644
index 0000000000000000000000000000000000000000..714af4b811a7934ca6eda46788aff660a7569bd5
Binary files /dev/null and b/github_page/t5.png differ
diff --git a/github_page/t6.png b/github_page/t6.png
new file mode 100644
index 0000000000000000000000000000000000000000..db8c1d4ce649e130b4fb83c6412efa15ecc71b42
Binary files /dev/null and b/github_page/t6.png differ
diff --git a/github_page/t7.png b/github_page/t7.png
new file mode 100644
index 0000000000000000000000000000000000000000..43a3dac046c25fd18ee1ce89389a16bae4633b3d
Binary files /dev/null and b/github_page/t7.png differ
diff --git a/github_page/uc1.png b/github_page/uc1.png
new file mode 100644
index 0000000000000000000000000000000000000000..e3b874e0497103be64127586aeb89a9f4b261b6d
Binary files /dev/null and b/github_page/uc1.png differ
diff --git a/github_page/uc2a.png b/github_page/uc2a.png
new file mode 100644
index 0000000000000000000000000000000000000000..3a6861226a641b1952445587bf0cec4441fa8167
--- /dev/null
+++ b/github_page/uc2a.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:96d60afb5cb0975f9b21e7b88aa506c92d2a0fcfcac82aeb58f94d3624355c64
+size 1100385
diff --git a/github_page/uc2b.png b/github_page/uc2b.png
new file mode 100644
index 0000000000000000000000000000000000000000..715008a3e93f093d655e4891bf1ec7a4d24d586e
--- /dev/null
+++ b/github_page/uc2b.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:73a23761f4843978657c65d62380fec01641f590d69f36dcb73267e9a28f4544
+size 1555138
diff --git a/github_page/uc3.png b/github_page/uc3.png
new file mode 100644
index 0000000000000000000000000000000000000000..c177def0fc8d375d21d7e355a2fe378009bb53e5
--- /dev/null
+++ b/github_page/uc3.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:56d2ce31584612d38a6bb94af473f261fc29a6013fa9da60f77fa52857f34e53
+size 1333667
diff --git a/github_page/uc4.png b/github_page/uc4.png
new file mode 100644
index 0000000000000000000000000000000000000000..b662265088fcbc69affd9ca31dcec5b88499e52e
--- /dev/null
+++ b/github_page/uc4.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5a6b6f035ddfcef6b96c3e869113a436f11affa8d3cfef2ff7c7050a856b0fe4
+size 1294259
diff --git a/github_page/uc6.png b/github_page/uc6.png
new file mode 100644
index 0000000000000000000000000000000000000000..e3009ffc3309c21d987b392f91f1eda0038349a2
--- /dev/null
+++ b/github_page/uc6.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:905fd75afb29e6148e3f962e55620483a17120d6999d0b2bcae092511465157e
+size 1830853
diff --git a/github_page/uci1.png b/github_page/uci1.png
new file mode 100644
index 0000000000000000000000000000000000000000..5e34aad358862bbabc775f624d152318dc10e5ee
--- /dev/null
+++ b/github_page/uci1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fb6335f749ac28fd1f6639c7bb2fa03b21eb6a0794ac6b3cc1df6c3b592c09e7
+size 1091153
diff --git a/github_page/uci2.png b/github_page/uci2.png
new file mode 100644
index 0000000000000000000000000000000000000000..27088cc6fd9899c43fbdf9d9e70a3e8f94917b4d
--- /dev/null
+++ b/github_page/uci2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:57ad677ff9af49ed2dbe75732af7e9fc9d1b357293f978d94fbb99278a84b500
+size 15100883
diff --git a/github_page/uci3.png b/github_page/uci3.png
new file mode 100644
index 0000000000000000000000000000000000000000..01d980b8206e2ed09afedc46c249c0ee5606242e
--- /dev/null
+++ b/github_page/uci3.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3ec2c3e416d4ac25b2ec3b82ef08924244fbfff0e7794f193a34fe121d07ed2e
+size 1462584
diff --git a/github_page/uci4.png b/github_page/uci4.png
new file mode 100644
index 0000000000000000000000000000000000000000..9a83f6c4fa2849c8b4f103a2883cceb01518cefc
--- /dev/null
+++ b/github_page/uci4.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ee895a826b206893f7c0cac56bae3442021062df5f0b7eeb7c1bf72419711fb5
+size 7914066
diff --git a/gradio_annotator.py b/gradio_annotator.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b1a29ebbec24073a9e4357b700e0577a17a9379
--- /dev/null
+++ b/gradio_annotator.py
@@ -0,0 +1,160 @@
+import gradio as gr
+
+from annotator.util import resize_image, HWC3
+
+
+model_canny = None
+
+
+def canny(img, res, l, h):
+ img = resize_image(HWC3(img), res)
+ global model_canny
+ if model_canny is None:
+ from annotator.canny import CannyDetector
+ model_canny = CannyDetector()
+ result = model_canny(img, l, h)
+ return [result]
+
+
+model_hed = None
+
+
+def hed(img, res):
+ img = resize_image(HWC3(img), res)
+ global model_hed
+ if model_hed is None:
+ from annotator.hed import HEDdetector
+ model_hed = HEDdetector()
+ result = model_hed(img)
+ return [result]
+
+
+model_mlsd = None
+
+
+def mlsd(img, res, thr_v, thr_d):
+ img = resize_image(HWC3(img), res)
+ global model_mlsd
+ if model_mlsd is None:
+ from annotator.mlsd import MLSDdetector
+ model_mlsd = MLSDdetector()
+ result = model_mlsd(img, thr_v, thr_d)
+ return [result]
+
+
+model_midas = None
+
+
+def midas(img, res, a):
+ img = resize_image(HWC3(img), res)
+ global model_midas
+ if model_midas is None:
+ from annotator.midas import MidasDetector
+ model_midas = MidasDetector()
+ results = model_midas(img, a)
+ return results
+
+
+model_openpose = None
+
+
+def openpose(img, res, has_hand):
+ img = resize_image(HWC3(img), res)
+ global model_openpose
+ if model_openpose is None:
+ from annotator.openpose import OpenposeDetector
+ model_openpose = OpenposeDetector()
+ result, _ = model_openpose(img, has_hand)
+ return [result]
+
+
+model_uniformer = None
+
+
+def uniformer(img, res):
+ img = resize_image(HWC3(img), res)
+ global model_uniformer
+ if model_uniformer is None:
+ from annotator.uniformer import UniformerDetector
+ model_uniformer = UniformerDetector()
+ result = model_uniformer(img)
+ return [result]
+
+
+block = gr.Blocks().queue()
+with block:
+ with gr.Row():
+ gr.Markdown("## Canny Edge")
+ with gr.Row():
+ with gr.Column():
+ input_image = gr.Image(source='upload', type="numpy")
+ low_threshold = gr.Slider(label="low_threshold", minimum=1, maximum=255, value=100, step=1)
+ high_threshold = gr.Slider(label="high_threshold", minimum=1, maximum=255, value=200, step=1)
+ resolution = gr.Slider(label="resolution", minimum=256, maximum=1024, value=512, step=64)
+ run_button = gr.Button(label="Run")
+ with gr.Column():
+ gallery = gr.Gallery(label="Generated images", show_label=False).style(height="auto")
+ run_button.click(fn=canny, inputs=[input_image, resolution, low_threshold, high_threshold], outputs=[gallery])
+
+ with gr.Row():
+ gr.Markdown("## HED Edge")
+ with gr.Row():
+ with gr.Column():
+ input_image = gr.Image(source='upload', type="numpy")
+ resolution = gr.Slider(label="resolution", minimum=256, maximum=1024, value=512, step=64)
+ run_button = gr.Button(label="Run")
+ with gr.Column():
+ gallery = gr.Gallery(label="Generated images", show_label=False).style(height="auto")
+ run_button.click(fn=hed, inputs=[input_image, resolution], outputs=[gallery])
+
+ with gr.Row():
+ gr.Markdown("## MLSD Edge")
+ with gr.Row():
+ with gr.Column():
+ input_image = gr.Image(source='upload', type="numpy")
+ value_threshold = gr.Slider(label="value_threshold", minimum=0.01, maximum=2.0, value=0.1, step=0.01)
+ distance_threshold = gr.Slider(label="distance_threshold", minimum=0.01, maximum=20.0, value=0.1, step=0.01)
+ resolution = gr.Slider(label="resolution", minimum=256, maximum=1024, value=384, step=64)
+ run_button = gr.Button(label="Run")
+ with gr.Column():
+ gallery = gr.Gallery(label="Generated images", show_label=False).style(height="auto")
+ run_button.click(fn=mlsd, inputs=[input_image, resolution, value_threshold, distance_threshold], outputs=[gallery])
+
+ with gr.Row():
+ gr.Markdown("## MIDAS Depth and Normal")
+ with gr.Row():
+ with gr.Column():
+ input_image = gr.Image(source='upload', type="numpy")
+ alpha = gr.Slider(label="alpha", minimum=0.1, maximum=20.0, value=6.2, step=0.01)
+ resolution = gr.Slider(label="resolution", minimum=256, maximum=1024, value=384, step=64)
+ run_button = gr.Button(label="Run")
+ with gr.Column():
+ gallery = gr.Gallery(label="Generated images", show_label=False).style(height="auto")
+ run_button.click(fn=midas, inputs=[input_image, resolution, alpha], outputs=[gallery])
+
+ with gr.Row():
+ gr.Markdown("## Openpose")
+ with gr.Row():
+ with gr.Column():
+ input_image = gr.Image(source='upload', type="numpy")
+ hand = gr.Checkbox(label='detect hand', value=False)
+ resolution = gr.Slider(label="resolution", minimum=256, maximum=1024, value=512, step=64)
+ run_button = gr.Button(label="Run")
+ with gr.Column():
+ gallery = gr.Gallery(label="Generated images", show_label=False).style(height="auto")
+ run_button.click(fn=openpose, inputs=[input_image, resolution, hand], outputs=[gallery])
+
+
+ with gr.Row():
+ gr.Markdown("## Uniformer Segmentation")
+ with gr.Row():
+ with gr.Column():
+ input_image = gr.Image(source='upload', type="numpy")
+ resolution = gr.Slider(label="resolution", minimum=256, maximum=1024, value=512, step=64)
+ run_button = gr.Button(label="Run")
+ with gr.Column():
+ gallery = gr.Gallery(label="Generated images", show_label=False).style(height="auto")
+ run_button.click(fn=uniformer, inputs=[input_image, resolution], outputs=[gallery])
+
+
+block.launch(server_name='0.0.0.0')
diff --git a/gradio_canny2image.py b/gradio_canny2image.py
new file mode 100644
index 0000000000000000000000000000000000000000..9866cac5b35925576c20ef4b9ee8b1b1cca235b2
--- /dev/null
+++ b/gradio_canny2image.py
@@ -0,0 +1,97 @@
+from share import *
+import config
+
+import cv2
+import einops
+import gradio as gr
+import numpy as np
+import torch
+import random
+
+from pytorch_lightning import seed_everything
+from annotator.util import resize_image, HWC3
+from annotator.canny import CannyDetector
+from cldm.model import create_model, load_state_dict
+from cldm.ddim_hacked import DDIMSampler
+
+
+apply_canny = CannyDetector()
+
+model = create_model('./models/cldm_v15.yaml').cpu()
+model.load_state_dict(load_state_dict('./models/control_sd15_canny.pth', location='cuda'))
+model = model.cuda()
+ddim_sampler = DDIMSampler(model)
+
+
+def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, low_threshold, high_threshold):
+ with torch.no_grad():
+ img = resize_image(HWC3(input_image), image_resolution)
+ H, W, C = img.shape
+
+ detected_map = apply_canny(img, low_threshold, high_threshold)
+ detected_map = HWC3(detected_map)
+
+ control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
+ control = torch.stack([control for _ in range(num_samples)], dim=0)
+ control = einops.rearrange(control, 'b h w c -> b c h w').clone()
+
+ if seed == -1:
+ seed = random.randint(0, 65535)
+ seed_everything(seed)
+
+ if config.save_memory:
+ model.low_vram_shift(is_diffusing=False)
+
+ cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
+ un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
+ shape = (4, H // 8, W // 8)
+
+ if config.save_memory:
+ model.low_vram_shift(is_diffusing=True)
+
+ model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
+ samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
+ shape, cond, verbose=False, eta=eta,
+ unconditional_guidance_scale=scale,
+ unconditional_conditioning=un_cond)
+
+ if config.save_memory:
+ model.low_vram_shift(is_diffusing=False)
+
+ x_samples = model.decode_first_stage(samples)
+ x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
+
+ results = [x_samples[i] for i in range(num_samples)]
+ return [255 - detected_map] + results
+
+
+block = gr.Blocks().queue()
+with block:
+ with gr.Row():
+ gr.Markdown("## Control Stable Diffusion with Canny Edge Maps")
+ with gr.Row():
+ with gr.Column():
+ input_image = gr.Image(source='upload', type="numpy")
+ prompt = gr.Textbox(label="Prompt")
+ run_button = gr.Button(label="Run")
+ with gr.Accordion("Advanced options", open=False):
+ num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
+ image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
+ strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
+ guess_mode = gr.Checkbox(label='Guess Mode', value=False)
+ low_threshold = gr.Slider(label="Canny low threshold", minimum=1, maximum=255, value=100, step=1)
+ high_threshold = gr.Slider(label="Canny high threshold", minimum=1, maximum=255, value=200, step=1)
+ ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
+ scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
+ seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
+ eta = gr.Number(label="eta (DDIM)", value=0.0)
+ a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
+ n_prompt = gr.Textbox(label="Negative Prompt",
+ value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
+ with gr.Column():
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
+ ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, low_threshold, high_threshold]
+ run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
+
+
+block.launch(server_name='0.0.0.0')
diff --git a/gradio_depth2image.py b/gradio_depth2image.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee678999ae6033c18a5026bc5f6286d0364c7851
--- /dev/null
+++ b/gradio_depth2image.py
@@ -0,0 +1,98 @@
+from share import *
+import config
+
+import cv2
+import einops
+import gradio as gr
+import numpy as np
+import torch
+import random
+
+from pytorch_lightning import seed_everything
+from annotator.util import resize_image, HWC3
+from annotator.midas import MidasDetector
+from cldm.model import create_model, load_state_dict
+from cldm.ddim_hacked import DDIMSampler
+
+
+apply_midas = MidasDetector()
+
+model = create_model('./models/cldm_v15.yaml').cpu()
+model.load_state_dict(load_state_dict('./models/control_sd15_depth.pth', location='cuda'))
+model = model.cuda()
+ddim_sampler = DDIMSampler(model)
+
+
+def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta):
+ with torch.no_grad():
+ input_image = HWC3(input_image)
+ detected_map, _ = apply_midas(resize_image(input_image, detect_resolution))
+ detected_map = HWC3(detected_map)
+ img = resize_image(input_image, image_resolution)
+ H, W, C = img.shape
+
+ detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
+
+ control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
+ control = torch.stack([control for _ in range(num_samples)], dim=0)
+ control = einops.rearrange(control, 'b h w c -> b c h w').clone()
+
+ if seed == -1:
+ seed = random.randint(0, 65535)
+ seed_everything(seed)
+
+ if config.save_memory:
+ model.low_vram_shift(is_diffusing=False)
+
+ cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
+ un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
+ shape = (4, H // 8, W // 8)
+
+ if config.save_memory:
+ model.low_vram_shift(is_diffusing=True)
+
+ model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
+ samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
+ shape, cond, verbose=False, eta=eta,
+ unconditional_guidance_scale=scale,
+ unconditional_conditioning=un_cond)
+
+ if config.save_memory:
+ model.low_vram_shift(is_diffusing=False)
+
+ x_samples = model.decode_first_stage(samples)
+ x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
+
+ results = [x_samples[i] for i in range(num_samples)]
+ return [detected_map] + results
+
+
+block = gr.Blocks().queue()
+with block:
+ with gr.Row():
+ gr.Markdown("## Control Stable Diffusion with Depth Maps")
+ with gr.Row():
+ with gr.Column():
+ input_image = gr.Image(source='upload', type="numpy")
+ prompt = gr.Textbox(label="Prompt")
+ run_button = gr.Button(label="Run")
+ with gr.Accordion("Advanced options", open=False):
+ num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
+ image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
+ strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
+ guess_mode = gr.Checkbox(label='Guess Mode', value=False)
+ detect_resolution = gr.Slider(label="Depth Resolution", minimum=128, maximum=1024, value=384, step=1)
+ ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
+ scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
+ seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
+ eta = gr.Number(label="eta (DDIM)", value=0.0)
+ a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
+ n_prompt = gr.Textbox(label="Negative Prompt",
+ value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
+ with gr.Column():
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
+ ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta]
+ run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
+
+
+block.launch(server_name='0.0.0.0')
diff --git a/gradio_fake_scribble2image.py b/gradio_fake_scribble2image.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7cd375f7589c3f7c43b7df91802eb4bf87ea0e0
--- /dev/null
+++ b/gradio_fake_scribble2image.py
@@ -0,0 +1,102 @@
+from share import *
+import config
+
+import cv2
+import einops
+import gradio as gr
+import numpy as np
+import torch
+import random
+
+from pytorch_lightning import seed_everything
+from annotator.util import resize_image, HWC3
+from annotator.hed import HEDdetector, nms
+from cldm.model import create_model, load_state_dict
+from cldm.ddim_hacked import DDIMSampler
+
+
+apply_hed = HEDdetector()
+
+model = create_model('./models/cldm_v15.yaml').cpu()
+model.load_state_dict(load_state_dict('./models/control_sd15_scribble.pth', location='cuda'))
+model = model.cuda()
+ddim_sampler = DDIMSampler(model)
+
+
+def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta):
+ with torch.no_grad():
+ input_image = HWC3(input_image)
+ detected_map = apply_hed(resize_image(input_image, detect_resolution))
+ detected_map = HWC3(detected_map)
+ img = resize_image(input_image, image_resolution)
+ H, W, C = img.shape
+
+ detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
+ detected_map = nms(detected_map, 127, 3.0)
+ detected_map = cv2.GaussianBlur(detected_map, (0, 0), 3.0)
+ detected_map[detected_map > 4] = 255
+ detected_map[detected_map < 255] = 0
+
+ control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
+ control = torch.stack([control for _ in range(num_samples)], dim=0)
+ control = einops.rearrange(control, 'b h w c -> b c h w').clone()
+
+ if seed == -1:
+ seed = random.randint(0, 65535)
+ seed_everything(seed)
+
+ if config.save_memory:
+ model.low_vram_shift(is_diffusing=False)
+
+ cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
+ un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
+ shape = (4, H // 8, W // 8)
+
+ if config.save_memory:
+ model.low_vram_shift(is_diffusing=True)
+
+ model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
+ samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
+ shape, cond, verbose=False, eta=eta,
+ unconditional_guidance_scale=scale,
+ unconditional_conditioning=un_cond)
+
+ if config.save_memory:
+ model.low_vram_shift(is_diffusing=False)
+
+ x_samples = model.decode_first_stage(samples)
+ x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
+
+ results = [x_samples[i] for i in range(num_samples)]
+ return [255 - detected_map] + results
+
+
+block = gr.Blocks().queue()
+with block:
+ with gr.Row():
+ gr.Markdown("## Control Stable Diffusion with Fake Scribble Maps")
+ with gr.Row():
+ with gr.Column():
+ input_image = gr.Image(source='upload', type="numpy")
+ prompt = gr.Textbox(label="Prompt")
+ run_button = gr.Button(label="Run")
+ with gr.Accordion("Advanced options", open=False):
+ num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
+ image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
+ strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
+ guess_mode = gr.Checkbox(label='Guess Mode', value=False)
+ detect_resolution = gr.Slider(label="HED Resolution", minimum=128, maximum=1024, value=512, step=1)
+ ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
+ scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
+ seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
+ eta = gr.Number(label="eta (DDIM)", value=0.0)
+ a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
+ n_prompt = gr.Textbox(label="Negative Prompt",
+ value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
+ with gr.Column():
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
+ ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta]
+ run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
+
+
+block.launch(server_name='0.0.0.0')
diff --git a/gradio_hed2image.py b/gradio_hed2image.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ceff67969b7c64a0adcf0557f922c71dd4bfab7
--- /dev/null
+++ b/gradio_hed2image.py
@@ -0,0 +1,98 @@
+from share import *
+import config
+
+import cv2
+import einops
+import gradio as gr
+import numpy as np
+import torch
+import random
+
+from pytorch_lightning import seed_everything
+from annotator.util import resize_image, HWC3
+from annotator.hed import HEDdetector
+from cldm.model import create_model, load_state_dict
+from cldm.ddim_hacked import DDIMSampler
+
+
+apply_hed = HEDdetector()
+
+model = create_model('./models/cldm_v15.yaml').cpu()
+model.load_state_dict(load_state_dict('./models/control_sd15_hed.pth', location='cuda'))
+model = model.cuda()
+ddim_sampler = DDIMSampler(model)
+
+
+def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta):
+ with torch.no_grad():
+ input_image = HWC3(input_image)
+ detected_map = apply_hed(resize_image(input_image, detect_resolution))
+ detected_map = HWC3(detected_map)
+ img = resize_image(input_image, image_resolution)
+ H, W, C = img.shape
+
+ detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
+
+ control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
+ control = torch.stack([control for _ in range(num_samples)], dim=0)
+ control = einops.rearrange(control, 'b h w c -> b c h w').clone()
+
+ if seed == -1:
+ seed = random.randint(0, 65535)
+ seed_everything(seed)
+
+ if config.save_memory:
+ model.low_vram_shift(is_diffusing=False)
+
+ cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
+ un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
+ shape = (4, H // 8, W // 8)
+
+ if config.save_memory:
+ model.low_vram_shift(is_diffusing=True)
+
+ model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
+ samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
+ shape, cond, verbose=False, eta=eta,
+ unconditional_guidance_scale=scale,
+ unconditional_conditioning=un_cond)
+
+ if config.save_memory:
+ model.low_vram_shift(is_diffusing=False)
+
+ x_samples = model.decode_first_stage(samples)
+ x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
+
+ results = [x_samples[i] for i in range(num_samples)]
+ return [detected_map] + results
+
+
+block = gr.Blocks().queue()
+with block:
+ with gr.Row():
+ gr.Markdown("## Control Stable Diffusion with HED Maps")
+ with gr.Row():
+ with gr.Column():
+ input_image = gr.Image(source='upload', type="numpy")
+ prompt = gr.Textbox(label="Prompt")
+ run_button = gr.Button(label="Run")
+ with gr.Accordion("Advanced options", open=False):
+ num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
+ image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
+ strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
+ guess_mode = gr.Checkbox(label='Guess Mode', value=False)
+ detect_resolution = gr.Slider(label="HED Resolution", minimum=128, maximum=1024, value=512, step=1)
+ ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
+ scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
+ seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
+ eta = gr.Number(label="eta (DDIM)", value=0.0)
+ a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
+ n_prompt = gr.Textbox(label="Negative Prompt",
+ value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
+ with gr.Column():
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
+ ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta]
+ run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
+
+
+block.launch(server_name='0.0.0.0')
diff --git a/gradio_hough2image.py b/gradio_hough2image.py
new file mode 100644
index 0000000000000000000000000000000000000000..6095eeb6767e005a155ee72057b3537021b09f31
--- /dev/null
+++ b/gradio_hough2image.py
@@ -0,0 +1,100 @@
+from share import *
+import config
+
+import cv2
+import einops
+import gradio as gr
+import numpy as np
+import torch
+import random
+
+from pytorch_lightning import seed_everything
+from annotator.util import resize_image, HWC3
+from annotator.mlsd import MLSDdetector
+from cldm.model import create_model, load_state_dict
+from cldm.ddim_hacked import DDIMSampler
+
+
+apply_mlsd = MLSDdetector()
+
+model = create_model('./models/cldm_v15.yaml').cpu()
+model.load_state_dict(load_state_dict('./models/control_sd15_mlsd.pth', location='cuda'))
+model = model.cuda()
+ddim_sampler = DDIMSampler(model)
+
+
+def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, value_threshold, distance_threshold):
+ with torch.no_grad():
+ input_image = HWC3(input_image)
+ detected_map = apply_mlsd(resize_image(input_image, detect_resolution), value_threshold, distance_threshold)
+ detected_map = HWC3(detected_map)
+ img = resize_image(input_image, image_resolution)
+ H, W, C = img.shape
+
+ detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_NEAREST)
+
+ control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
+ control = torch.stack([control for _ in range(num_samples)], dim=0)
+ control = einops.rearrange(control, 'b h w c -> b c h w').clone()
+
+ if seed == -1:
+ seed = random.randint(0, 65535)
+ seed_everything(seed)
+
+ if config.save_memory:
+ model.low_vram_shift(is_diffusing=False)
+
+ cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
+ un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
+ shape = (4, H // 8, W // 8)
+
+ if config.save_memory:
+ model.low_vram_shift(is_diffusing=True)
+
+ model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
+ samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
+ shape, cond, verbose=False, eta=eta,
+ unconditional_guidance_scale=scale,
+ unconditional_conditioning=un_cond)
+
+ if config.save_memory:
+ model.low_vram_shift(is_diffusing=False)
+
+ x_samples = model.decode_first_stage(samples)
+ x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
+
+ results = [x_samples[i] for i in range(num_samples)]
+ return [255 - cv2.dilate(detected_map, np.ones(shape=(3, 3), dtype=np.uint8), iterations=1)] + results
+
+
+block = gr.Blocks().queue()
+with block:
+ with gr.Row():
+ gr.Markdown("## Control Stable Diffusion with Hough Line Maps")
+ with gr.Row():
+ with gr.Column():
+ input_image = gr.Image(source='upload', type="numpy")
+ prompt = gr.Textbox(label="Prompt")
+ run_button = gr.Button(label="Run")
+ with gr.Accordion("Advanced options", open=False):
+ num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
+ image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
+ strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
+ guess_mode = gr.Checkbox(label='Guess Mode', value=False)
+ detect_resolution = gr.Slider(label="Hough Resolution", minimum=128, maximum=1024, value=512, step=1)
+ value_threshold = gr.Slider(label="Hough value threshold (MLSD)", minimum=0.01, maximum=2.0, value=0.1, step=0.01)
+ distance_threshold = gr.Slider(label="Hough distance threshold (MLSD)", minimum=0.01, maximum=20.0, value=0.1, step=0.01)
+ ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
+ scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
+ seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
+ eta = gr.Number(label="eta (DDIM)", value=0.0)
+ a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
+ n_prompt = gr.Textbox(label="Negative Prompt",
+ value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
+ with gr.Column():
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
+ ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, value_threshold, distance_threshold]
+ run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
+
+
+block.launch(server_name='0.0.0.0')
diff --git a/gradio_humanpose2image.py b/gradio_humanpose2image.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f3e715e68439d6688531cd4c840c2d512092f09
--- /dev/null
+++ b/gradio_humanpose2image.py
@@ -0,0 +1,115 @@
+from share import *
+import config
+
+import cv2
+import einops
+import gradio as gr
+import numpy as np
+import torch
+import random
+
+from pytorch_lightning import seed_everything
+from annotator.util import resize_image, HWC3
+from annotator.openpose import OpenposeDetector
+from cldm.model import create_model, load_state_dict
+from cldm.ddim_hacked import DDIMSampler
+from aagenerator import Generator
+
+# apply_openpose = OpenposeDetector()
+# model = create_model('./models/cldm_v15.yaml').cpu()
+# model.load_state_dict(load_state_dict('./models/control_sd15_openpose.pth', location='cuda'))
+# model = model.cuda()
+# ddim_sampler = DDIMSampler(model)
+
+model = Generator()
+
+
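+# NOTE: process() below still references apply_openpose and ddim_sampler, which are
+# commented out above; it is kept for reference only and is not wired to the UI
+# (run_button is connected to human_gen instead).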
+def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps,
+ guess_mode, strength, scale, seed, eta):
+ with torch.no_grad():
+ input_image = HWC3(input_image)
+ detected_map, _ = apply_openpose(resize_image(input_image, detect_resolution))
+ detected_map = HWC3(detected_map)
+ img = resize_image(input_image, image_resolution)
+ H, W, C = img.shape
+
+ detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_NEAREST)
+
+ control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
+ control = torch.stack([control for _ in range(num_samples)], dim=0)
+ control = einops.rearrange(control, 'b h w c -> b c h w').clone()
+
+ if seed == -1:
+ seed = random.randint(0, 65535)
+ seed_everything(seed)
+
+ if config.save_memory:
+ model.low_vram_shift(is_diffusing=False)
+
+ cond = {"c_concat": [control],
+ "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
+ un_cond = {"c_concat": None if guess_mode else [control],
+ "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
+ shape = (4, H // 8, W // 8)
+
+ if config.save_memory:
+ model.low_vram_shift(is_diffusing=True)
+
+ model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else (
+ [strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
+ samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
+ shape, cond, verbose=False, eta=eta,
+ unconditional_guidance_scale=scale,
+ unconditional_conditioning=un_cond)
+
+ if config.save_memory:
+ model.low_vram_shift(is_diffusing=False)
+
+ x_samples = model.decode_first_stage(samples)
+ x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0,
+ 255).astype(
+ np.uint8)
+
+ results = [x_samples[i] for i in range(num_samples)]
+ return [detected_map] + results
+
+
+def human_gen(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps,
+ guess_mode, strength, scale, seed, eta):
+ with torch.no_grad():
+ seed_everything(seed)
+ body_numpy, rgb, midas_depth_image, normal_image, rgb2 = model.run(input_image, prompt, steps=ddim_steps)
+ return [body_numpy, rgb, midas_depth_image, normal_image, rgb2]
+
+
+block = gr.Blocks().queue()
+with block:
+ with gr.Row():
+ gr.Markdown("## Control Stable Diffusion with Human Pose")
+ with gr.Row():
+ with gr.Column():
+ input_image = gr.Image(source='upload', type="numpy")
+ prompt = gr.Textbox(label="Prompt")
+ run_button = gr.Button(label="Run")
+ with gr.Accordion("Advanced options", open=False):
+ num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
+ image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
+ strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
+ guess_mode = gr.Checkbox(label='Guess Mode', value=False)
+ detect_resolution = gr.Slider(label="OpenPose Resolution", minimum=128, maximum=1024, value=512, step=1)
+ ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
+ scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
+ seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
+ eta = gr.Number(label="eta (DDIM)", value=0.0)
+ a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
+ n_prompt = gr.Textbox(label="Negative Prompt",
+ value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
+ with gr.Column():
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2,
+ height='auto')
+ ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps,
+ guess_mode, strength, scale, seed, eta]
+ run_button.click(fn=human_gen, inputs=ips, outputs=[result_gallery])
+
+block.launch(server_name='0.0.0.0',
+ share=True)
diff --git a/gradio_normal2image.py b/gradio_normal2image.py
new file mode 100644
index 0000000000000000000000000000000000000000..30aea2f8d4a7956a609cd003f2fee23d2ab162b5
--- /dev/null
+++ b/gradio_normal2image.py
@@ -0,0 +1,99 @@
+from share import *
+import config
+
+import cv2
+import einops
+import gradio as gr
+import numpy as np
+import torch
+import random
+
+from pytorch_lightning import seed_everything
+from annotator.util import resize_image, HWC3
+from annotator.midas import MidasDetector
+from cldm.model import create_model, load_state_dict
+from cldm.ddim_hacked import DDIMSampler
+
+
+apply_midas = MidasDetector()
+
+model = create_model('./models/cldm_v15.yaml').cpu()
+model.load_state_dict(load_state_dict('./models/control_sd15_normal.pth', location='cuda'))
+model = model.cuda()
+ddim_sampler = DDIMSampler(model)
+
+
+def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, bg_threshold):
+ with torch.no_grad():
+ input_image = HWC3(input_image)
+ _, detected_map = apply_midas(resize_image(input_image, detect_resolution), bg_th=bg_threshold)
+ detected_map = HWC3(detected_map)
+ img = resize_image(input_image, image_resolution)
+ H, W, C = img.shape
+
+ detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
+
+ control = torch.from_numpy(detected_map[:, :, ::-1].copy()).float().cuda() / 255.0
+ control = torch.stack([control for _ in range(num_samples)], dim=0)
+ control = einops.rearrange(control, 'b h w c -> b c h w').clone()
+
+ if seed == -1:
+ seed = random.randint(0, 65535)
+ seed_everything(seed)
+
+ if config.save_memory:
+ model.low_vram_shift(is_diffusing=False)
+
+ cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
+ un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
+ shape = (4, H // 8, W // 8)
+
+ if config.save_memory:
+ model.low_vram_shift(is_diffusing=True)
+
+ model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
+ samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
+ shape, cond, verbose=False, eta=eta,
+ unconditional_guidance_scale=scale,
+ unconditional_conditioning=un_cond)
+
+ if config.save_memory:
+ model.low_vram_shift(is_diffusing=False)
+
+ x_samples = model.decode_first_stage(samples)
+ x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
+
+ results = [x_samples[i] for i in range(num_samples)]
+ return [detected_map] + results
+
+
+block = gr.Blocks().queue()
+with block:
+ with gr.Row():
+ gr.Markdown("## Control Stable Diffusion with Normal Maps")
+ with gr.Row():
+ with gr.Column():
+ input_image = gr.Image(source='upload', type="numpy")
+ prompt = gr.Textbox(label="Prompt")
+ run_button = gr.Button(label="Run")
+ with gr.Accordion("Advanced options", open=False):
+ num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
+ image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
+ strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
+ guess_mode = gr.Checkbox(label='Guess Mode', value=False)
+ detect_resolution = gr.Slider(label="Normal Resolution", minimum=128, maximum=1024, value=384, step=1)
+ bg_threshold = gr.Slider(label="Normal background threshold", minimum=0.0, maximum=1.0, value=0.4, step=0.01)
+ ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
+ scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
+ seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
+ eta = gr.Number(label="eta (DDIM)", value=0.0)
+ a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
+ n_prompt = gr.Textbox(label="Negative Prompt",
+ value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
+ with gr.Column():
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
+ ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, bg_threshold]
+ run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
+
+
+block.launch(server_name='0.0.0.0')
diff --git a/gradio_pose2image.py b/gradio_pose2image.py
new file mode 100644
index 0000000000000000000000000000000000000000..76f0a0b285b5a9ecdc77c09fe7705c98038b89fb
--- /dev/null
+++ b/gradio_pose2image.py
@@ -0,0 +1,99 @@
+from share import *
+import config
+
+import cv2
+import einops
+import gradio as gr
+import numpy as np
+import torch
+import random
+
+from pytorch_lightning import seed_everything
+from annotator.util import resize_image, HWC3
+from annotator.openpose import OpenposeDetector
+from cldm.model import create_model, load_state_dict
+from cldm.ddim_hacked import DDIMSampler
+
+
+apply_openpose = OpenposeDetector()
+
+model = create_model('./models/cldm_v15.yaml').cpu()
+model.load_state_dict(load_state_dict('./models/control_sd15_openpose.pth', location='cuda'))
+model = model.cuda()
+ddim_sampler = DDIMSampler(model)
+
+
+def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta):
+ with torch.no_grad():
+ input_image = HWC3(input_image)
+ detected_map, _ = apply_openpose(resize_image(input_image, detect_resolution))
+ detected_map = HWC3(detected_map)
+ img = resize_image(input_image, image_resolution)
+ H, W, C = img.shape
+
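+        # Resize the detected pose map to the output resolution; nearest-neighbor interpolation keeps the skeleton lines sharp.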
+ detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_NEAREST)
+
+ control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
+ control = torch.stack([control for _ in range(num_samples)], dim=0)
+ control = einops.rearrange(control, 'b h w c -> b c h w').clone()
+
+ if seed == -1:
+ seed = random.randint(0, 65535)
+ seed_everything(seed)
+
+ if config.save_memory:
+ model.low_vram_shift(is_diffusing=False)
+
+ cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
+ un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
+ shape = (4, H // 8, W // 8)
+
+ if config.save_memory:
+ model.low_vram_shift(is_diffusing=True)
+
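+        # In guess mode the 13 control scales decay exponentially: the first entries shrink to roughly 1% of strength,
+        # while the last entry keeps the full strength value.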
+ model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
+ samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
+ shape, cond, verbose=False, eta=eta,
+ unconditional_guidance_scale=scale,
+ unconditional_conditioning=un_cond)
+
+ if config.save_memory:
+ model.low_vram_shift(is_diffusing=False)
+
+ x_samples = model.decode_first_stage(samples)
+ x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
+
+ results = [x_samples[i] for i in range(num_samples)]
+ return [detected_map] + results
+
+
+block = gr.Blocks().queue()
+with block:
+ with gr.Row():
+ gr.Markdown("## Control Stable Diffusion with Human Pose")
+ with gr.Row():
+ with gr.Column():
+ input_image = gr.Image(source='upload', type="numpy")
+ prompt = gr.Textbox(label="Prompt")
+ run_button = gr.Button(label="Run")
+ with gr.Accordion("Advanced options", open=False):
+ num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
+ image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
+ strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
+ guess_mode = gr.Checkbox(label='Guess Mode', value=False)
+ detect_resolution = gr.Slider(label="OpenPose Resolution", minimum=128, maximum=1024, value=512, step=1)
+ ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
+ scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
+ seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
+ eta = gr.Number(label="eta (DDIM)", value=0.0)
+ a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
+ n_prompt = gr.Textbox(label="Negative Prompt",
+ value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
+ with gr.Column():
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
+ ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta]
+ run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
+
+
+block.launch(server_name='0.0.0.0',
+ share=True)
diff --git a/gradio_scribble2image.py b/gradio_scribble2image.py
new file mode 100644
index 0000000000000000000000000000000000000000..8abbc25bdeaeae23b9032101336682657825ead4
--- /dev/null
+++ b/gradio_scribble2image.py
@@ -0,0 +1,92 @@
+from share import *
+import config
+
+import cv2
+import einops
+import gradio as gr
+import numpy as np
+import torch
+import random
+
+from pytorch_lightning import seed_everything
+from annotator.util import resize_image, HWC3
+from cldm.model import create_model, load_state_dict
+from cldm.ddim_hacked import DDIMSampler
+
+
+model = create_model('./models/cldm_v15.yaml').cpu()
+model.load_state_dict(load_state_dict('./models/control_sd15_scribble.pth', location='cuda'))
+model = model.cuda()
+ddim_sampler = DDIMSampler(model)
+
+
+def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta):
+ with torch.no_grad():
+ img = resize_image(HWC3(input_image), image_resolution)
+ H, W, C = img.shape
+
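+        # Binarize the scribble: any pixel whose darkest channel falls below mid-gray becomes an active (255) control pixel.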
+ detected_map = np.zeros_like(img, dtype=np.uint8)
+ detected_map[np.min(img, axis=2) < 127] = 255
+
+ control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
+ control = torch.stack([control for _ in range(num_samples)], dim=0)
+ control = einops.rearrange(control, 'b h w c -> b c h w').clone()
+
+ if seed == -1:
+ seed = random.randint(0, 65535)
+ seed_everything(seed)
+
+ if config.save_memory:
+ model.low_vram_shift(is_diffusing=False)
+
+ cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
+ un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
+ shape = (4, H // 8, W // 8)
+
+ if config.save_memory:
+ model.low_vram_shift(is_diffusing=True)
+
+ model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
+ samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
+ shape, cond, verbose=False, eta=eta,
+ unconditional_guidance_scale=scale,
+ unconditional_conditioning=un_cond)
+
+ if config.save_memory:
+ model.low_vram_shift(is_diffusing=False)
+
+ x_samples = model.decode_first_stage(samples)
+ x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
+
+ results = [x_samples[i] for i in range(num_samples)]
+ return [255 - detected_map] + results
+
+
+block = gr.Blocks().queue()
+with block:
+ with gr.Row():
+ gr.Markdown("## Control Stable Diffusion with Scribble Maps")
+ with gr.Row():
+ with gr.Column():
+ input_image = gr.Image(source='upload', type="numpy")
+ prompt = gr.Textbox(label="Prompt")
+ run_button = gr.Button(label="Run")
+ with gr.Accordion("Advanced options", open=False):
+ num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
+ image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
+ strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
+ guess_mode = gr.Checkbox(label='Guess Mode', value=False)
+ ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
+ scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
+ seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
+ eta = gr.Number(label="eta (DDIM)", value=0.0)
+ a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
+ n_prompt = gr.Textbox(label="Negative Prompt",
+ value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
+ with gr.Column():
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
+ ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta]
+ run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
+
+
+block.launch(server_name='0.0.0.0')
diff --git a/gradio_scribble2image_interactive.py b/gradio_scribble2image_interactive.py
new file mode 100644
index 0000000000000000000000000000000000000000..7308bcc1bb8387bba10c026495e0dcddae91c2db
--- /dev/null
+++ b/gradio_scribble2image_interactive.py
@@ -0,0 +1,102 @@
+from share import *
+import config
+
+import cv2
+import einops
+import gradio as gr
+import numpy as np
+import torch
+import random
+
+from pytorch_lightning import seed_everything
+from annotator.util import resize_image, HWC3
+from cldm.model import create_model, load_state_dict
+from cldm.ddim_hacked import DDIMSampler
+
+
+model = create_model('./models/cldm_v15.yaml').cpu()
+model.load_state_dict(load_state_dict('./models/control_sd15_scribble.pth', location='cuda'))
+model = model.cuda()
+ddim_sampler = DDIMSampler(model)
+
+
+def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta):
+ with torch.no_grad():
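+        # Gradio's sketch tool returns a dict; the drawn strokes are read from its 'mask' channel.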
+ img = resize_image(HWC3(input_image['mask'][:, :, 0]), image_resolution)
+ H, W, C = img.shape
+
+ detected_map = np.zeros_like(img, dtype=np.uint8)
+ detected_map[np.min(img, axis=2) > 127] = 255
+
+ control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
+ control = torch.stack([control for _ in range(num_samples)], dim=0)
+ control = einops.rearrange(control, 'b h w c -> b c h w').clone()
+
+ if seed == -1:
+ seed = random.randint(0, 65535)
+ seed_everything(seed)
+
+ if config.save_memory:
+ model.low_vram_shift(is_diffusing=False)
+
+ cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
+ un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
+ shape = (4, H // 8, W // 8)
+
+ if config.save_memory:
+ model.low_vram_shift(is_diffusing=True)
+
+ model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
+ samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
+ shape, cond, verbose=False, eta=eta,
+ unconditional_guidance_scale=scale,
+ unconditional_conditioning=un_cond)
+
+ if config.save_memory:
+ model.low_vram_shift(is_diffusing=False)
+
+ x_samples = model.decode_first_stage(samples)
+ x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
+
+ results = [x_samples[i] for i in range(num_samples)]
+ return [255 - detected_map] + results
+
+
+def create_canvas(w, h):
+ return np.zeros(shape=(h, w, 3), dtype=np.uint8) + 255
+
+
+block = gr.Blocks().queue()
+with block:
+ with gr.Row():
+ gr.Markdown("## Control Stable Diffusion with Interactive Scribbles")
+ with gr.Row():
+ with gr.Column():
+ canvas_width = gr.Slider(label="Canvas Width", minimum=256, maximum=1024, value=512, step=1)
+ canvas_height = gr.Slider(label="Canvas Height", minimum=256, maximum=1024, value=512, step=1)
+ create_button = gr.Button(label="Start", value='Open drawing canvas!')
+ input_image = gr.Image(source='upload', type='numpy', tool='sketch')
+            gr.Markdown(value='Do not forget to change your brush width to make it thinner. (Gradio does not allow developers to set the brush width, so you need to do it manually.) '
+                              'Just click on the small pencil icon in the upper right corner of the above block.')
+ create_button.click(fn=create_canvas, inputs=[canvas_width, canvas_height], outputs=[input_image])
+ prompt = gr.Textbox(label="Prompt")
+ run_button = gr.Button(label="Run")
+ with gr.Accordion("Advanced options", open=False):
+ num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
+ image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
+ strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
+ guess_mode = gr.Checkbox(label='Guess Mode', value=False)
+ ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
+ scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
+ seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
+ eta = gr.Number(label="eta (DDIM)", value=0.0)
+ a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
+ n_prompt = gr.Textbox(label="Negative Prompt",
+ value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
+ with gr.Column():
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
+ ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta]
+ run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
+
+
+block.launch(server_name='0.0.0.0')
diff --git a/gradio_seg2image.py b/gradio_seg2image.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3854dc7624ed6a0a68f059c5001e4973da27587
--- /dev/null
+++ b/gradio_seg2image.py
@@ -0,0 +1,97 @@
+from share import *
+import config
+
+import cv2
+import einops
+import gradio as gr
+import numpy as np
+import torch
+import random
+
+from pytorch_lightning import seed_everything
+from annotator.util import resize_image, HWC3
+from annotator.uniformer import UniformerDetector
+from cldm.model import create_model, load_state_dict
+from cldm.ddim_hacked import DDIMSampler
+
+
+apply_uniformer = UniformerDetector()
+
+model = create_model('./models/cldm_v15.yaml').cpu()
+model.load_state_dict(load_state_dict('./models/control_sd15_seg.pth', location='cuda'))
+model = model.cuda()
+ddim_sampler = DDIMSampler(model)
+
+
+def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta):
+ with torch.no_grad():
+ input_image = HWC3(input_image)
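+        # The Uniformer annotator returns a color-coded semantic segmentation map, which is used directly as the control image.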
+ detected_map = apply_uniformer(resize_image(input_image, detect_resolution))
+ img = resize_image(input_image, image_resolution)
+ H, W, C = img.shape
+
+ detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_NEAREST)
+
+ control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
+ control = torch.stack([control for _ in range(num_samples)], dim=0)
+ control = einops.rearrange(control, 'b h w c -> b c h w').clone()
+
+ if seed == -1:
+ seed = random.randint(0, 65535)
+ seed_everything(seed)
+
+ if config.save_memory:
+ model.low_vram_shift(is_diffusing=False)
+
+ cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
+ un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
+ shape = (4, H // 8, W // 8)
+
+ if config.save_memory:
+ model.low_vram_shift(is_diffusing=True)
+
+ model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
+ samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
+ shape, cond, verbose=False, eta=eta,
+ unconditional_guidance_scale=scale,
+ unconditional_conditioning=un_cond)
+
+ if config.save_memory:
+ model.low_vram_shift(is_diffusing=False)
+
+ x_samples = model.decode_first_stage(samples)
+ x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
+
+ results = [x_samples[i] for i in range(num_samples)]
+ return [detected_map] + results
+
+
+block = gr.Blocks().queue()
+with block:
+ with gr.Row():
+ gr.Markdown("## Control Stable Diffusion with Segmentation Maps")
+ with gr.Row():
+ with gr.Column():
+ input_image = gr.Image(source='upload', type="numpy")
+ prompt = gr.Textbox(label="Prompt")
+ run_button = gr.Button(label="Run")
+ with gr.Accordion("Advanced options", open=False):
+ num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
+ image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
+ strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
+ guess_mode = gr.Checkbox(label='Guess Mode', value=False)
+ detect_resolution = gr.Slider(label="Segmentation Resolution", minimum=128, maximum=1024, value=512, step=1)
+ ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
+ scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
+ seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
+ eta = gr.Number(label="eta (DDIM)", value=0.0)
+ a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
+ n_prompt = gr.Textbox(label="Negative Prompt",
+ value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
+ with gr.Column():
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
+ ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta]
+ run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
+
+
+block.launch(server_name='0.0.0.0')
diff --git a/ldm/__pycache__/util.cpython-38.pyc b/ldm/__pycache__/util.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e4fbd012942e95aad302686d8b38b4a0134636d5
Binary files /dev/null and b/ldm/__pycache__/util.cpython-38.pyc differ
diff --git a/ldm/data/__init__.py b/ldm/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ldm/data/util.py b/ldm/data/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b60ceb2349e3bd7900ff325740e2022d2903b1c
--- /dev/null
+++ b/ldm/data/util.py
@@ -0,0 +1,24 @@
+import torch
+
+from ldm.modules.midas.api import load_midas_transform
+
+
+class AddMiDaS(object):
+ def __init__(self, model_type):
+ super().__init__()
+ self.transform = load_midas_transform(model_type)
+
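+    # pt2np/np2pt convert between torch tensors in [-1, 1] and numpy arrays in [0, 1] around the MiDaS transform.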
+ def pt2np(self, x):
+ x = ((x + 1.0) * .5).detach().cpu().numpy()
+ return x
+
+ def np2pt(self, x):
+ x = torch.from_numpy(x) * 2 - 1.
+ return x
+
+ def __call__(self, sample):
+ # sample['jpg'] is tensor hwc in [-1, 1] at this point
+ x = self.pt2np(sample['jpg'])
+ x = self.transform({"image": x})["image"]
+ sample['midas_in'] = x
+ return sample
\ No newline at end of file
diff --git a/ldm/models/__pycache__/autoencoder.cpython-38.pyc b/ldm/models/__pycache__/autoencoder.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b3827a9a7006f41c9d8d1bd67e29793fa194c0e1
Binary files /dev/null and b/ldm/models/__pycache__/autoencoder.cpython-38.pyc differ
diff --git a/ldm/models/autoencoder.py b/ldm/models/autoencoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..d122549995ce2cd64092c81a58419ed4a15a02fd
--- /dev/null
+++ b/ldm/models/autoencoder.py
@@ -0,0 +1,219 @@
+import torch
+import pytorch_lightning as pl
+import torch.nn.functional as F
+from contextlib import contextmanager
+
+from ldm.modules.diffusionmodules.model import Encoder, Decoder
+from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
+
+from ldm.util import instantiate_from_config
+from ldm.modules.ema import LitEma
+
+
+class AutoencoderKL(pl.LightningModule):
+ def __init__(self,
+ ddconfig,
+ lossconfig,
+ embed_dim,
+ ckpt_path=None,
+ ignore_keys=[],
+ image_key="image",
+ colorize_nlabels=None,
+ monitor=None,
+ ema_decay=None,
+ learn_logvar=False
+ ):
+ super().__init__()
+ self.learn_logvar = learn_logvar
+ self.image_key = image_key
+ self.encoder = Encoder(**ddconfig)
+ self.decoder = Decoder(**ddconfig)
+ self.loss = instantiate_from_config(lossconfig)
+ assert ddconfig["double_z"]
+ self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+ self.embed_dim = embed_dim
+ if colorize_nlabels is not None:
+ assert type(colorize_nlabels)==int
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
+ if monitor is not None:
+ self.monitor = monitor
+
+ self.use_ema = ema_decay is not None
+ if self.use_ema:
+ self.ema_decay = ema_decay
+ assert 0. < ema_decay < 1.
+ self.model_ema = LitEma(self, decay=ema_decay)
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+
+ def init_from_ckpt(self, path, ignore_keys=list()):
+ sd = torch.load(path, map_location="cpu")["state_dict"]
+ keys = list(sd.keys())
+ for k in keys:
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+ self.load_state_dict(sd, strict=False)
+ print(f"Restored from {path}")
+
+ @contextmanager
+ def ema_scope(self, context=None):
+ if self.use_ema:
+ self.model_ema.store(self.parameters())
+ self.model_ema.copy_to(self)
+ if context is not None:
+ print(f"{context}: Switched to EMA weights")
+ try:
+ yield None
+ finally:
+ if self.use_ema:
+ self.model_ema.restore(self.parameters())
+ if context is not None:
+ print(f"{context}: Restored training weights")
+
+ def on_train_batch_end(self, *args, **kwargs):
+ if self.use_ema:
+ self.model_ema(self)
+
+ def encode(self, x):
+ h = self.encoder(x)
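+        # quant_conv outputs 2 * embed_dim channels, interpreted as the mean and log-variance of a diagonal Gaussian posterior.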
+ moments = self.quant_conv(h)
+ posterior = DiagonalGaussianDistribution(moments)
+ return posterior
+
+ def decode(self, z):
+ z = self.post_quant_conv(z)
+ dec = self.decoder(z)
+ return dec
+
+ def forward(self, input, sample_posterior=True):
+ posterior = self.encode(input)
+ if sample_posterior:
+ z = posterior.sample()
+ else:
+ z = posterior.mode()
+ dec = self.decode(z)
+ return dec, posterior
+
+ def get_input(self, batch, k):
+ x = batch[k]
+ if len(x.shape) == 3:
+ x = x[..., None]
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
+ return x
+
+ def training_step(self, batch, batch_idx, optimizer_idx):
+ inputs = self.get_input(batch, self.image_key)
+ reconstructions, posterior = self(inputs)
+
+ if optimizer_idx == 0:
+ # train encoder+decoder+logvar
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
+ last_layer=self.get_last_layer(), split="train")
+ self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
+ return aeloss
+
+ if optimizer_idx == 1:
+ # train the discriminator
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
+ last_layer=self.get_last_layer(), split="train")
+
+ self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
+ return discloss
+
+ def validation_step(self, batch, batch_idx):
+ log_dict = self._validation_step(batch, batch_idx)
+ with self.ema_scope():
+ log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
+ return log_dict
+
+ def _validation_step(self, batch, batch_idx, postfix=""):
+ inputs = self.get_input(batch, self.image_key)
+ reconstructions, posterior = self(inputs)
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
+ last_layer=self.get_last_layer(), split="val"+postfix)
+
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
+ last_layer=self.get_last_layer(), split="val"+postfix)
+
+ self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
+ self.log_dict(log_dict_ae)
+ self.log_dict(log_dict_disc)
+ return self.log_dict
+
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(
+ self.quant_conv.parameters()) + list(self.post_quant_conv.parameters())
+ if self.learn_logvar:
+ print(f"{self.__class__.__name__}: Learning logvar")
+ ae_params_list.append(self.loss.logvar)
+ opt_ae = torch.optim.Adam(ae_params_list,
+ lr=lr, betas=(0.5, 0.9))
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
+ lr=lr, betas=(0.5, 0.9))
+ return [opt_ae, opt_disc], []
+
+ def get_last_layer(self):
+ return self.decoder.conv_out.weight
+
+ @torch.no_grad()
+ def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
+ log = dict()
+ x = self.get_input(batch, self.image_key)
+ x = x.to(self.device)
+ if not only_inputs:
+ xrec, posterior = self(x)
+ if x.shape[1] > 3:
+ # colorize with random projection
+ assert xrec.shape[1] > 3
+ x = self.to_rgb(x)
+ xrec = self.to_rgb(xrec)
+ log["samples"] = self.decode(torch.randn_like(posterior.sample()))
+ log["reconstructions"] = xrec
+ if log_ema or self.use_ema:
+ with self.ema_scope():
+ xrec_ema, posterior_ema = self(x)
+ if x.shape[1] > 3:
+ # colorize with random projection
+ assert xrec_ema.shape[1] > 3
+ xrec_ema = self.to_rgb(xrec_ema)
+ log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample()))
+ log["reconstructions_ema"] = xrec_ema
+ log["inputs"] = x
+ return log
+
+ def to_rgb(self, x):
+ assert self.image_key == "segmentation"
+ if not hasattr(self, "colorize"):
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
+ x = F.conv2d(x, weight=self.colorize)
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
+ return x
+
+
+class IdentityFirstStage(torch.nn.Module):
+ def __init__(self, *args, vq_interface=False, **kwargs):
+ self.vq_interface = vq_interface
+ super().__init__()
+
+ def encode(self, x, *args, **kwargs):
+ return x
+
+ def decode(self, x, *args, **kwargs):
+ return x
+
+ def quantize(self, x, *args, **kwargs):
+ if self.vq_interface:
+ return x, None, [None, None, None]
+ return x
+
+ def forward(self, x, *args, **kwargs):
+ return x
+
diff --git a/ldm/models/diffusion/__init__.py b/ldm/models/diffusion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc b/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0192327003ad4bcc6ba051dcdbaea00f264e3bd0
Binary files /dev/null and b/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc differ
diff --git a/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc b/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0ae14ea3c76ec37cc66ac10e2938544b0ea82b7f
Binary files /dev/null and b/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc differ
diff --git a/ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc b/ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b1626f50bb5dd0a2f78de3e34e76081cd9155939
Binary files /dev/null and b/ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc differ
diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py
new file mode 100644
index 0000000000000000000000000000000000000000..27ead0ea914c64c747b64e690662899fb3801144
--- /dev/null
+++ b/ldm/models/diffusion/ddim.py
@@ -0,0 +1,336 @@
+"""SAMPLING ONLY."""
+
+import torch
+import numpy as np
+from tqdm import tqdm
+
+from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
+
+
+class DDIMSampler(object):
+ def __init__(self, model, schedule="linear", **kwargs):
+ super().__init__()
+ self.model = model
+ self.ddpm_num_timesteps = model.num_timesteps
+ self.schedule = schedule
+
+ def register_buffer(self, name, attr):
+ if type(attr) == torch.Tensor:
+ if attr.device != torch.device("cuda"):
+ attr = attr.to(torch.device("cuda"))
+ setattr(self, name, attr)
+
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
+ alphas_cumprod = self.model.alphas_cumprod
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
+
+ self.register_buffer('betas', to_torch(self.model.betas))
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
+
+ # ddim sampling parameters
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
+ ddim_timesteps=self.ddim_timesteps,
+ eta=ddim_eta,verbose=verbose)
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
+ self.register_buffer('ddim_alphas', ddim_alphas)
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
+
+ @torch.no_grad()
+ def sample(self,
+ S,
+ batch_size,
+ shape,
+ conditioning=None,
+ callback=None,
+ normals_sequence=None,
+ img_callback=None,
+ quantize_x0=False,
+ eta=0.,
+ mask=None,
+ x0=None,
+ temperature=1.,
+ noise_dropout=0.,
+ score_corrector=None,
+ corrector_kwargs=None,
+ verbose=True,
+ x_T=None,
+ log_every_t=100,
+ unconditional_guidance_scale=1.,
+ unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+ dynamic_threshold=None,
+ ucg_schedule=None,
+ **kwargs
+ ):
+ if conditioning is not None:
+ if isinstance(conditioning, dict):
+ ctmp = conditioning[list(conditioning.keys())[0]]
+ while isinstance(ctmp, list): ctmp = ctmp[0]
+ cbs = ctmp.shape[0]
+ if cbs != batch_size:
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+
+ elif isinstance(conditioning, list):
+ for ctmp in conditioning:
+ if ctmp.shape[0] != batch_size:
+                        print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}")
+
+ else:
+ if conditioning.shape[0] != batch_size:
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+ # sampling
+ C, H, W = shape
+ size = (batch_size, C, H, W)
+ print(f'Data shape for DDIM sampling is {size}, eta {eta}')
+
+ samples, intermediates = self.ddim_sampling(conditioning, size,
+ callback=callback,
+ img_callback=img_callback,
+ quantize_denoised=quantize_x0,
+ mask=mask, x0=x0,
+ ddim_use_original_steps=False,
+ noise_dropout=noise_dropout,
+ temperature=temperature,
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ x_T=x_T,
+ log_every_t=log_every_t,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ dynamic_threshold=dynamic_threshold,
+ ucg_schedule=ucg_schedule
+ )
+ return samples, intermediates
+
+ @torch.no_grad()
+ def ddim_sampling(self, cond, shape,
+ x_T=None, ddim_use_original_steps=False,
+ callback=None, timesteps=None, quantize_denoised=False,
+ mask=None, x0=None, img_callback=None, log_every_t=100,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
+ ucg_schedule=None):
+ device = self.model.betas.device
+ b = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=device)
+ else:
+ img = x_T
+
+ if timesteps is None:
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
+ elif timesteps is not None and not ddim_use_original_steps:
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
+ timesteps = self.ddim_timesteps[:subset_end]
+
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
+ time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+ iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
+
+ for i, step in enumerate(iterator):
+ index = total_steps - i - 1
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
+
+ if mask is not None:
+ assert x0 is not None
+ img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
+ img = img_orig * mask + (1. - mask) * img
+
+ if ucg_schedule is not None:
+ assert len(ucg_schedule) == len(time_range)
+ unconditional_guidance_scale = ucg_schedule[i]
+
+ outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
+ quantize_denoised=quantize_denoised, temperature=temperature,
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ dynamic_threshold=dynamic_threshold)
+ img, pred_x0 = outs
+ if callback: callback(i)
+ if img_callback: img_callback(pred_x0, i)
+
+ if index % log_every_t == 0 or index == total_steps - 1:
+ intermediates['x_inter'].append(img)
+ intermediates['pred_x0'].append(pred_x0)
+
+ return img, intermediates
+
+ @torch.no_grad()
+ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None,
+ dynamic_threshold=None):
+ b, *_, device = *x.shape, x.device
+
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+ model_output = self.model.apply_model(x, t, c)
+ else:
+ x_in = torch.cat([x] * 2)
+ t_in = torch.cat([t] * 2)
+ if isinstance(c, dict):
+ assert isinstance(unconditional_conditioning, dict)
+ c_in = dict()
+ for k in c:
+ if isinstance(c[k], list):
+ c_in[k] = [torch.cat([
+ unconditional_conditioning[k][i],
+ c[k][i]]) for i in range(len(c[k]))]
+ else:
+ c_in[k] = torch.cat([
+ unconditional_conditioning[k],
+ c[k]])
+ elif isinstance(c, list):
+ c_in = list()
+ assert isinstance(unconditional_conditioning, list)
+ for i in range(len(c)):
+ c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
+ else:
+ c_in = torch.cat([unconditional_conditioning, c])
+ model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
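+            # Classifier-free guidance: push the prediction away from the unconditional output toward the conditional one.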
+ model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
+
+ if self.model.parameterization == "v":
+ e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
+ else:
+ e_t = model_output
+
+ if score_corrector is not None:
+ assert self.model.parameterization == "eps", 'not implemented'
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+ # select parameters corresponding to the currently considered timestep
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
+
+ # current prediction for x_0
+ if self.model.parameterization != "v":
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+ else:
+ pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
+
+ if quantize_denoised:
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+
+ if dynamic_threshold is not None:
+ raise NotImplementedError()
+
+ # direction pointing to x_t
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+ return x_prev, pred_x0
+
+ @torch.no_grad()
+ def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
+ unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
+ num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]
+
+ assert t_enc <= num_reference_steps
+ num_steps = t_enc
+
+ if use_original_steps:
+ alphas_next = self.alphas_cumprod[:num_steps]
+ alphas = self.alphas_cumprod_prev[:num_steps]
+ else:
+ alphas_next = self.ddim_alphas[:num_steps]
+ alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
+
+ x_next = x0
+ intermediates = []
+ inter_steps = []
+ for i in tqdm(range(num_steps), desc='Encoding Image'):
+ t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
+ if unconditional_guidance_scale == 1.:
+ noise_pred = self.model.apply_model(x_next, t, c)
+ else:
+ assert unconditional_conditioning is not None
+ e_t_uncond, noise_pred = torch.chunk(
+ self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
+ torch.cat((unconditional_conditioning, c))), 2)
+ noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
+
+ xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
+ weighted_noise_pred = alphas_next[i].sqrt() * (
+ (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
+ x_next = xt_weighted + weighted_noise_pred
+ if return_intermediates and i % (
+ num_steps // return_intermediates) == 0 and i < num_steps - 1:
+ intermediates.append(x_next)
+ inter_steps.append(i)
+ elif return_intermediates and i >= num_steps - 2:
+ intermediates.append(x_next)
+ inter_steps.append(i)
+ if callback: callback(i)
+
+ out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
+ if return_intermediates:
+ out.update({'intermediates': intermediates})
+ return x_next, out
+
+ @torch.no_grad()
+ def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
+ # fast, but does not allow for exact reconstruction
+ # t serves as an index to gather the correct alphas
+ if use_original_steps:
+ sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
+ sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
+ else:
+ sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
+ sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
+
+ if noise is None:
+ noise = torch.randn_like(x0)
+ return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
+
+ @torch.no_grad()
+ def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
+ use_original_steps=False, callback=None):
+
+ timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
+ timesteps = timesteps[:t_start]
+
+ time_range = np.flip(timesteps)
+ total_steps = timesteps.shape[0]
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+ iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
+ x_dec = x_latent
+ for i, step in enumerate(iterator):
+ index = total_steps - i - 1
+ ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
+ x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning)
+ if callback: callback(i)
+ return x_dec
\ No newline at end of file
diff --git a/ldm/models/diffusion/ddpm.py b/ldm/models/diffusion/ddpm.py
new file mode 100644
index 0000000000000000000000000000000000000000..f71a44af48c8cba8e97849b7e6813b3e6f9fe83c
--- /dev/null
+++ b/ldm/models/diffusion/ddpm.py
@@ -0,0 +1,1797 @@
+"""
+wild mixture of
+https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
+https://github.com/CompVis/taming-transformers
+-- merci
+"""
+
+import torch
+import torch.nn as nn
+import numpy as np
+import pytorch_lightning as pl
+from torch.optim.lr_scheduler import LambdaLR
+from einops import rearrange, repeat
+from contextlib import contextmanager, nullcontext
+from functools import partial
+import itertools
+from tqdm import tqdm
+from torchvision.utils import make_grid
+from pytorch_lightning.utilities.distributed import rank_zero_only
+from omegaconf import ListConfig
+
+from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
+from ldm.modules.ema import LitEma
+from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
+from ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL
+from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
+from ldm.models.diffusion.ddim import DDIMSampler
+
+
+__conditioning_keys__ = {'concat': 'c_concat',
+ 'crossattn': 'c_crossattn',
+ 'adm': 'y'}
+
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+
+def uniform_on_device(r1, r2, shape, device):
+ return (r1 - r2) * torch.rand(*shape, device=device) + r2
+
+
+class DDPM(pl.LightningModule):
+ # classic DDPM with Gaussian diffusion, in image space
+ def __init__(self,
+ unet_config,
+ timesteps=1000,
+ beta_schedule="linear",
+ loss_type="l2",
+ ckpt_path=None,
+ ignore_keys=[],
+ load_only_unet=False,
+ monitor="val/loss",
+ use_ema=True,
+ first_stage_key="image",
+ image_size=256,
+ channels=3,
+ log_every_t=100,
+ clip_denoised=True,
+ linear_start=1e-4,
+ linear_end=2e-2,
+ cosine_s=8e-3,
+ given_betas=None,
+ original_elbo_weight=0.,
+ v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
+ l_simple_weight=1.,
+ conditioning_key=None,
+ parameterization="eps", # all assuming fixed variance schedules
+ scheduler_config=None,
+ use_positional_encodings=False,
+ learn_logvar=False,
+ logvar_init=0.,
+ make_it_fit=False,
+ ucg_training=None,
+ reset_ema=False,
+ reset_num_ema_updates=False,
+ ):
+ super().__init__()
+ assert parameterization in ["eps", "x0", "v"], 'currently only supporting "eps" and "x0" and "v"'
+ self.parameterization = parameterization
+ print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
+ self.cond_stage_model = None
+ self.clip_denoised = clip_denoised
+ self.log_every_t = log_every_t
+ self.first_stage_key = first_stage_key
+ self.image_size = image_size # try conv?
+ self.channels = channels
+ self.use_positional_encodings = use_positional_encodings
+ self.model = DiffusionWrapper(unet_config, conditioning_key)
+ count_params(self.model, verbose=True)
+ self.use_ema = use_ema
+ if self.use_ema:
+ self.model_ema = LitEma(self.model)
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+ self.use_scheduler = scheduler_config is not None
+ if self.use_scheduler:
+ self.scheduler_config = scheduler_config
+
+ self.v_posterior = v_posterior
+ self.original_elbo_weight = original_elbo_weight
+ self.l_simple_weight = l_simple_weight
+
+ if monitor is not None:
+ self.monitor = monitor
+ self.make_it_fit = make_it_fit
+ if reset_ema: assert exists(ckpt_path)
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
+ if reset_ema:
+ assert self.use_ema
+ print(f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
+ self.model_ema = LitEma(self.model)
+ if reset_num_ema_updates:
+ print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
+ assert self.use_ema
+ self.model_ema.reset_num_updates()
+
+ self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
+ linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
+
+ self.loss_type = loss_type
+
+ self.learn_logvar = learn_logvar
+ logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
+ if self.learn_logvar:
+            self.logvar = nn.Parameter(logvar, requires_grad=True)
+ else:
+ self.register_buffer('logvar', logvar)
+
+ self.ucg_training = ucg_training or dict()
+ if self.ucg_training:
+ self.ucg_prng = np.random.RandomState()
+
+ def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+ if exists(given_betas):
+ betas = given_betas
+ else:
+ betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
+ cosine_s=cosine_s)
+ alphas = 1. - betas
+ alphas_cumprod = np.cumprod(alphas, axis=0)
+ alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
+
+ timesteps, = betas.shape
+ self.num_timesteps = int(timesteps)
+ self.linear_start = linear_start
+ self.linear_end = linear_end
+ assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
+
+ to_torch = partial(torch.tensor, dtype=torch.float32)
+
+ self.register_buffer('betas', to_torch(betas))
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+ self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
+
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
+ posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
+ 1. - alphas_cumprod) + self.v_posterior * betas
+ # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
+ self.register_buffer('posterior_variance', to_torch(posterior_variance))
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+ self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
+ self.register_buffer('posterior_mean_coef1', to_torch(
+ betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
+ self.register_buffer('posterior_mean_coef2', to_torch(
+ (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
+
+ if self.parameterization == "eps":
+ lvlb_weights = self.betas ** 2 / (
+ 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
+ elif self.parameterization == "x0":
+ lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
+ elif self.parameterization == "v":
+ lvlb_weights = torch.ones_like(self.betas ** 2 / (
+ 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)))
+ else:
+ raise NotImplementedError("mu not supported")
+ lvlb_weights[0] = lvlb_weights[1]
+ self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
+ assert not torch.isnan(self.lvlb_weights).all()
+
+ @contextmanager
+ def ema_scope(self, context=None):
+ if self.use_ema:
+ self.model_ema.store(self.model.parameters())
+ self.model_ema.copy_to(self.model)
+ if context is not None:
+ print(f"{context}: Switched to EMA weights")
+ try:
+ yield None
+ finally:
+ if self.use_ema:
+ self.model_ema.restore(self.model.parameters())
+ if context is not None:
+ print(f"{context}: Restored training weights")
+
+ @torch.no_grad()
+ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
+ sd = torch.load(path, map_location="cpu")
+ if "state_dict" in list(sd.keys()):
+ sd = sd["state_dict"]
+ keys = list(sd.keys())
+ for k in keys:
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+ if self.make_it_fit:
+ n_params = len([name for name, _ in
+ itertools.chain(self.named_parameters(),
+ self.named_buffers())])
+ for name, param in tqdm(
+ itertools.chain(self.named_parameters(),
+ self.named_buffers()),
+ desc="Fitting old weights to new weights",
+ total=n_params
+ ):
+ if not name in sd:
+ continue
+ old_shape = sd[name].shape
+ new_shape = param.shape
+ assert len(old_shape) == len(new_shape)
+ if len(new_shape) > 2:
+ # we only modify first two axes
+ assert new_shape[2:] == old_shape[2:]
+ # assumes first axis corresponds to output dim
+ if not new_shape == old_shape:
+ new_param = param.clone()
+ old_param = sd[name]
+ if len(new_shape) == 1:
+ for i in range(new_param.shape[0]):
+ new_param[i] = old_param[i % old_shape[0]]
+ elif len(new_shape) >= 2:
+ for i in range(new_param.shape[0]):
+ for j in range(new_param.shape[1]):
+ new_param[i, j] = old_param[i % old_shape[0], j % old_shape[1]]
+
+ n_used_old = torch.ones(old_shape[1])
+ for j in range(new_param.shape[1]):
+ n_used_old[j % old_shape[1]] += 1
+ n_used_new = torch.zeros(new_shape[1])
+ for j in range(new_param.shape[1]):
+ n_used_new[j] = n_used_old[j % old_shape[1]]
+
+ n_used_new = n_used_new[None, :]
+ while len(n_used_new.shape) < len(new_shape):
+ n_used_new = n_used_new.unsqueeze(-1)
+ new_param /= n_used_new
+
+ sd[name] = new_param
+
+ missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
+ sd, strict=False)
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
+ if len(missing) > 0:
+ print(f"Missing Keys:\n {missing}")
+ if len(unexpected) > 0:
+ print(f"\nUnexpected Keys:\n {unexpected}")
+
+ def q_mean_variance(self, x_start, t):
+ """
+ Get the distribution q(x_t | x_0).
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
+ """
+ mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
+ variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
+ log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
+ return mean, variance, log_variance
+
+ def predict_start_from_noise(self, x_t, t, noise):
+ return (
+ extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
+ extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
+ )
+
+ def predict_start_from_z_and_v(self, x_t, t, v):
+ # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
+ # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
+ return (
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
+ )
+
+ def predict_eps_from_z_and_v(self, x_t, t, v):
+ return (
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v +
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t
+ )
+
+ def q_posterior(self, x_start, x_t, t):
+ posterior_mean = (
+ extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
+ extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
+ )
+ posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
+ posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+ def p_mean_variance(self, x, t, clip_denoised: bool):
+ model_out = self.model(x, t)
+ if self.parameterization == "eps":
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
+ elif self.parameterization == "x0":
+ x_recon = model_out
+ if clip_denoised:
+ x_recon.clamp_(-1., 1.)
+
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
+ return model_mean, posterior_variance, posterior_log_variance
+
+ @torch.no_grad()
+ def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
+ b, *_, device = *x.shape, x.device
+ model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
+ noise = noise_like(x.shape, device, repeat_noise)
+ # no noise when t == 0
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+
+ @torch.no_grad()
+ def p_sample_loop(self, shape, return_intermediates=False):
+ device = self.betas.device
+ b = shape[0]
+ img = torch.randn(shape, device=device)
+ intermediates = [img]
+ for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
+ img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
+ clip_denoised=self.clip_denoised)
+ if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
+ intermediates.append(img)
+ if return_intermediates:
+ return img, intermediates
+ return img
+
+ @torch.no_grad()
+ def sample(self, batch_size=16, return_intermediates=False):
+ image_size = self.image_size
+ channels = self.channels
+ return self.p_sample_loop((batch_size, channels, image_size, image_size),
+ return_intermediates=return_intermediates)
+
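+    # q_sample diffuses x_start to timestep t: a schedule-weighted mix of the clean input and Gaussian noise.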
+ def q_sample(self, x_start, t, noise=None):
+ noise = default(noise, lambda: torch.randn_like(x_start))
+ return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
+
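+    # v-prediction target: v = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * x.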
+ def get_v(self, x, noise, t):
+ return (
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise -
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
+ )
+
+ def get_loss(self, pred, target, mean=True):
+ if self.loss_type == 'l1':
+ loss = (target - pred).abs()
+ if mean:
+ loss = loss.mean()
+ elif self.loss_type == 'l2':
+ if mean:
+ loss = torch.nn.functional.mse_loss(target, pred)
+ else:
+ loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
+ else:
+            raise NotImplementedError(f"unknown loss type '{self.loss_type}'")
+
+ return loss
+
+ def p_losses(self, x_start, t, noise=None):
+ noise = default(noise, lambda: torch.randn_like(x_start))
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+ model_out = self.model(x_noisy, t)
+
+ loss_dict = {}
+ if self.parameterization == "eps":
+ target = noise
+ elif self.parameterization == "x0":
+ target = x_start
+ elif self.parameterization == "v":
+ target = self.get_v(x_start, noise, t)
+ else:
+ raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")
+
+ loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
+
+ log_prefix = 'train' if self.training else 'val'
+
+ loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
+ loss_simple = loss.mean() * self.l_simple_weight
+
+ loss_vlb = (self.lvlb_weights[t] * loss).mean()
+ loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})
+
+ loss = loss_simple + self.original_elbo_weight * loss_vlb
+
+ loss_dict.update({f'{log_prefix}/loss': loss})
+
+ return loss, loss_dict
+
+ def forward(self, x, *args, **kwargs):
+ # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
+ # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
+ t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
+ return self.p_losses(x, t, *args, **kwargs)
+
+ def get_input(self, batch, k):
+ x = batch[k]
+ if len(x.shape) == 3:
+ x = x[..., None]
+ x = rearrange(x, 'b h w c -> b c h w')
+ x = x.to(memory_format=torch.contiguous_format).float()
+ return x
+
+ def shared_step(self, batch):
+ x = self.get_input(batch, self.first_stage_key)
+ loss, loss_dict = self(x)
+ return loss, loss_dict
+
+ def training_step(self, batch, batch_idx):
+ for k in self.ucg_training:
+ p = self.ucg_training[k]["p"]
+ val = self.ucg_training[k]["val"]
+ if val is None:
+ val = ""
+ for i in range(len(batch[k])):
+ if self.ucg_prng.choice(2, p=[1 - p, p]):
+ batch[k][i] = val
+
+ loss, loss_dict = self.shared_step(batch)
+
+ self.log_dict(loss_dict, prog_bar=True,
+ logger=True, on_step=True, on_epoch=True)
+
+ self.log("global_step", self.global_step,
+ prog_bar=True, logger=True, on_step=True, on_epoch=False)
+
+ if self.use_scheduler:
+ lr = self.optimizers().param_groups[0]['lr']
+ self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
+
+ return loss
+
+ @torch.no_grad()
+ def validation_step(self, batch, batch_idx):
+ _, loss_dict_no_ema = self.shared_step(batch)
+ with self.ema_scope():
+ _, loss_dict_ema = self.shared_step(batch)
+ loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
+ self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
+ self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
+
+ def on_train_batch_end(self, *args, **kwargs):
+ if self.use_ema:
+ self.model_ema(self.model)
+
+ def _get_rows_from_list(self, samples):
+ n_imgs_per_row = len(samples)
+ denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
+ denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
+ return denoise_grid
+
+ @torch.no_grad()
+ def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
+ log = dict()
+ x = self.get_input(batch, self.first_stage_key)
+ N = min(x.shape[0], N)
+ n_row = min(x.shape[0], n_row)
+ x = x.to(self.device)[:N]
+ log["inputs"] = x
+
+ # get diffusion row
+ diffusion_row = list()
+ x_start = x[:n_row]
+
+ for t in range(self.num_timesteps):
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+ t = t.to(self.device).long()
+ noise = torch.randn_like(x_start)
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+ diffusion_row.append(x_noisy)
+
+ log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
+
+ if sample:
+ # get denoise row
+ with self.ema_scope("Plotting"):
+ samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)
+
+ log["samples"] = samples
+ log["denoise_row"] = self._get_rows_from_list(denoise_row)
+
+ if return_keys:
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
+ return log
+ else:
+ return {key: log[key] for key in return_keys}
+ return log
+
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ params = list(self.model.parameters())
+ if self.learn_logvar:
+ params = params + [self.logvar]
+ opt = torch.optim.AdamW(params, lr=lr)
+ return opt
+
+
+class LatentDiffusion(DDPM):
+ """main class"""
+
+ def __init__(self,
+ first_stage_config,
+ cond_stage_config,
+ num_timesteps_cond=None,
+ cond_stage_key="image",
+ cond_stage_trainable=False,
+ concat_mode=True,
+ cond_stage_forward=None,
+ conditioning_key=None,
+ scale_factor=1.0,
+ scale_by_std=False,
+ force_null_conditioning=False,
+ *args, **kwargs):
+ self.force_null_conditioning = force_null_conditioning
+ self.num_timesteps_cond = default(num_timesteps_cond, 1)
+ self.scale_by_std = scale_by_std
+ assert self.num_timesteps_cond <= kwargs['timesteps']
+ # for backwards compatibility after implementation of DiffusionWrapper
+ if conditioning_key is None:
+ conditioning_key = 'concat' if concat_mode else 'crossattn'
+ if cond_stage_config == '__is_unconditional__' and not self.force_null_conditioning:
+ conditioning_key = None
+ ckpt_path = kwargs.pop("ckpt_path", None)
+ reset_ema = kwargs.pop("reset_ema", False)
+ reset_num_ema_updates = kwargs.pop("reset_num_ema_updates", False)
+ ignore_keys = kwargs.pop("ignore_keys", [])
+ super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
+ self.concat_mode = concat_mode
+ self.cond_stage_trainable = cond_stage_trainable
+ self.cond_stage_key = cond_stage_key
+ try:
+ self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
+ except Exception:
+ self.num_downs = 0
+ if not scale_by_std:
+ self.scale_factor = scale_factor
+ else:
+ self.register_buffer('scale_factor', torch.tensor(scale_factor))
+ self.instantiate_first_stage(first_stage_config)
+ self.instantiate_cond_stage(cond_stage_config)
+ self.cond_stage_forward = cond_stage_forward
+ self.clip_denoised = False
+ self.bbox_tokenizer = None
+
+ self.restarted_from_ckpt = False
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys)
+ self.restarted_from_ckpt = True
+ if reset_ema:
+ assert self.use_ema
+ print(
+ "Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
+ self.model_ema = LitEma(self.model)
+ if reset_num_ema_updates:
+ print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
+ assert self.use_ema
+ self.model_ema.reset_num_updates()
+
+ def make_cond_schedule(self):
+ self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
+ ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
+ self.cond_ids[:self.num_timesteps_cond] = ids
+
+ @rank_zero_only
+ @torch.no_grad()
+ def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
+ # only for very first batch
+ if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
+ assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
+ # set rescale weight to 1./std of encodings
+ print("### USING STD-RESCALING ###")
+ x = super().get_input(batch, self.first_stage_key)
+ x = x.to(self.device)
+ encoder_posterior = self.encode_first_stage(x)
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
+ del self.scale_factor
+ self.register_buffer('scale_factor', 1. / z.flatten().std())
+ print(f"setting self.scale_factor to {self.scale_factor}")
+ print("### USING STD-RESCALING ###")
+
+ def register_schedule(self,
+ given_betas=None, beta_schedule="linear", timesteps=1000,
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+ super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
+
+ self.shorten_cond_schedule = self.num_timesteps_cond > 1
+ if self.shorten_cond_schedule:
+ self.make_cond_schedule()
+
+ def instantiate_first_stage(self, config):
+ model = instantiate_from_config(config)
+ self.first_stage_model = model.eval()
+ self.first_stage_model.train = disabled_train
+ for param in self.first_stage_model.parameters():
+ param.requires_grad = False
+
+ def instantiate_cond_stage(self, config):
+ if not self.cond_stage_trainable:
+ if config == "__is_first_stage__":
+ print("Using first stage also as cond stage.")
+ self.cond_stage_model = self.first_stage_model
+ elif config == "__is_unconditional__":
+ print(f"Training {self.__class__.__name__} as an unconditional model.")
+ self.cond_stage_model = None
+ # self.be_unconditional = True
+ else:
+ model = instantiate_from_config(config)
+ self.cond_stage_model = model.eval()
+ self.cond_stage_model.train = disabled_train
+ for param in self.cond_stage_model.parameters():
+ param.requires_grad = False
+ else:
+ assert config != '__is_first_stage__'
+ assert config != '__is_unconditional__'
+ model = instantiate_from_config(config)
+ self.cond_stage_model = model
+
+ def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
+ denoise_row = []
+ for zd in tqdm(samples, desc=desc):
+ denoise_row.append(self.decode_first_stage(zd.to(self.device),
+ force_not_quantize=force_no_decoder_quantization))
+ n_imgs_per_row = len(denoise_row)
+ denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
+ denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
+ denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
+ return denoise_grid
+
+ def get_first_stage_encoding(self, encoder_posterior):
+ if isinstance(encoder_posterior, DiagonalGaussianDistribution):
+ z = encoder_posterior.sample()
+ elif isinstance(encoder_posterior, torch.Tensor):
+ z = encoder_posterior
+ else:
+ raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
+ return self.scale_factor * z
+
+ def get_learned_conditioning(self, c):
+ if self.cond_stage_forward is None:
+ if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
+ c = self.cond_stage_model.encode(c)
+ if isinstance(c, DiagonalGaussianDistribution):
+ c = c.mode()
+ else:
+ c = self.cond_stage_model(c)
+ else:
+ assert hasattr(self.cond_stage_model, self.cond_stage_forward)
+ c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
+ return c
+
+ def meshgrid(self, h, w):
+ y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
+ x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
+
+ arr = torch.cat([y, x], dim=-1)
+ return arr
+
+ def delta_border(self, h, w):
+ """
+ :param h: height
+ :param w: width
+ :return: normalized distance to image border,
+ with min distance = 0 at the border and max distance = 0.5 at the image center
+ """
+ lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
+ arr = self.meshgrid(h, w) / lower_right_corner
+ dist_left_up = torch.min(arr, dim=-1, keepdim=True)[0]
+ dist_right_down = torch.min(1 - arr, dim=-1, keepdim=True)[0]
+ edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
+ return edge_dist
+
+ def get_weighting(self, h, w, Ly, Lx, device):
+ weighting = self.delta_border(h, w)
+ weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
+ self.split_input_params["clip_max_weight"], )
+ weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
+
+ if self.split_input_params["tie_braker"]:
+ L_weighting = self.delta_border(Ly, Lx)
+ L_weighting = torch.clip(L_weighting,
+ self.split_input_params["clip_min_tie_weight"],
+ self.split_input_params["clip_max_tie_weight"])
+
+ L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
+ weighting = weighting * L_weighting
+ return weighting
+
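+ # Builds unfold/fold operators plus a weighting map and its fold-normalization so
+ # an image can be split into (optionally up-/down-scaled by uf/df) overlapping
+ # crops, processed per crop, and blended back together.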
+ def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code
+ """
+ :param x: img of size (bs, c, h, w)
+ :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
+ """
+ bs, nc, h, w = x.shape
+
+ # number of crops in image
+ Ly = (h - kernel_size[0]) // stride[0] + 1
+ Lx = (w - kernel_size[1]) // stride[1] + 1
+
+ if uf == 1 and df == 1:
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+ unfold = torch.nn.Unfold(**fold_params)
+
+ fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
+
+ weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
+ normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap
+ weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
+
+ elif uf > 1 and df == 1:
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+ unfold = torch.nn.Unfold(**fold_params)
+
+ fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
+ dilation=1, padding=0,
+ stride=(stride[0] * uf, stride[1] * uf))
+ fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
+
+ weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
+ normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap
+ weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
+
+ elif df > 1 and uf == 1:
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+ unfold = torch.nn.Unfold(**fold_params)
+
+ fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
+ dilation=1, padding=0,
+ stride=(stride[0] // df, stride[1] // df))
+ fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
+
+ weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
+ normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap
+ weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
+
+ else:
+ raise NotImplementedError
+
+ return fold, unfold, normalization, weighting
+
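+ # Encodes the image batch into first-stage latents z and gathers the conditioning
+ # selected by cond_stage_key (text, class label, or another image), optionally
+ # returning first-stage reconstructions and the raw conditioning as well.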
+ @torch.no_grad()
+ def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
+ cond_key=None, return_original_cond=False, bs=None, return_x=False):
+ x = super().get_input(batch, k)
+ if bs is not None:
+ x = x[:bs]
+ x = x.to(self.device)
+ encoder_posterior = self.encode_first_stage(x)
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
+
+ if self.model.conditioning_key is not None and not self.force_null_conditioning:
+ if cond_key is None:
+ cond_key = self.cond_stage_key
+ if cond_key != self.first_stage_key:
+ if cond_key in ['caption', 'coordinates_bbox', "txt"]:
+ xc = batch[cond_key]
+ elif cond_key in ['class_label', 'cls']:
+ xc = batch
+ else:
+ xc = super().get_input(batch, cond_key).to(self.device)
+ else:
+ xc = x
+ if not self.cond_stage_trainable or force_c_encode:
+ if isinstance(xc, dict) or isinstance(xc, list):
+ c = self.get_learned_conditioning(xc)
+ else:
+ c = self.get_learned_conditioning(xc.to(self.device))
+ else:
+ c = xc
+ if bs is not None:
+ c = c[:bs]
+
+ if self.use_positional_encodings:
+ pos_x, pos_y = self.compute_latent_shifts(batch)
+ ckey = __conditioning_keys__[self.model.conditioning_key]
+ c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y}
+
+ else:
+ c = None
+ xc = None
+ if self.use_positional_encodings:
+ pos_x, pos_y = self.compute_latent_shifts(batch)
+ c = {'pos_x': pos_x, 'pos_y': pos_y}
+ out = [z, c]
+ if return_first_stage_outputs:
+ xrec = self.decode_first_stage(z)
+ out.extend([x, xrec])
+ if return_x:
+ out.extend([x])
+ if return_original_cond:
+ out.append(xc)
+ return out
+
+ @torch.no_grad()
+ def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
+ if predict_cids:
+ if z.dim() == 4:
+ z = torch.argmax(z.exp(), dim=1).long()
+ z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
+ z = rearrange(z, 'b h w c -> b c h w').contiguous()
+
+ z = 1. / self.scale_factor * z
+ return self.first_stage_model.decode(z)
+
+ @torch.no_grad()
+ def encode_first_stage(self, x):
+ return self.first_stage_model.encode(x)
+
+ def shared_step(self, batch, **kwargs):
+ x, c = self.get_input(batch, self.first_stage_key)
+ loss = self(x, c)
+ return loss
+
+ def forward(self, x, c, *args, **kwargs):
+ t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
+ if self.model.conditioning_key is not None:
+ assert c is not None
+ if self.cond_stage_trainable:
+ c = self.get_learned_conditioning(c)
+ if self.shorten_cond_schedule: # TODO: drop this option
+ tc = self.cond_ids[t].to(self.device)
+ c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
+ return self.p_losses(x, c, t, *args, **kwargs)
+
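+ # Normalizes the conditioning into the {'c_concat': [...], 'c_crossattn': [...]}
+ # dict format expected by DiffusionWrapper before invoking the denoising model.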
+ def apply_model(self, x_noisy, t, cond, return_ids=False):
+ if isinstance(cond, dict):
+ # hybrid case, cond is expected to be a dict
+ pass
+ else:
+ if not isinstance(cond, list):
+ cond = [cond]
+ key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
+ cond = {key: cond}
+
+ x_recon = self.model(x_noisy, t, **cond)
+
+ if isinstance(x_recon, tuple) and not return_ids:
+ return x_recon[0]
+ else:
+ return x_recon
+
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
+ return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
+ extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+
+ def _prior_bpd(self, x_start):
+ """
+ Get the prior KL term for the variational lower-bound, measured in
+ bits-per-dim.
+ This term can't be optimized, as it only depends on the encoder.
+ :param x_start: the [N x C x ...] tensor of inputs.
+ :return: a batch of [N] KL values (in bits), one per batch element.
+ """
+ batch_size = x_start.shape[0]
+ t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
+ kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
+ return mean_flat(kl_prior) / np.log(2.0)
+
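+ # Same objective as DDPM.p_losses, but computed on latents and reweighted by the
+ # (optionally learned) per-timestep logvar: loss = loss_simple / exp(logvar_t) + logvar_t.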
+ def p_losses(self, x_start, cond, t, noise=None):
+ noise = default(noise, lambda: torch.randn_like(x_start))
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+ model_output = self.apply_model(x_noisy, t, cond)
+
+ loss_dict = {}
+ prefix = 'train' if self.training else 'val'
+
+ if self.parameterization == "x0":
+ target = x_start
+ elif self.parameterization == "eps":
+ target = noise
+ elif self.parameterization == "v":
+ target = self.get_v(x_start, noise, t)
+ else:
+ raise NotImplementedError()
+
+ loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
+ loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
+
+ logvar_t = self.logvar[t].to(self.device)
+ loss = loss_simple / torch.exp(logvar_t) + logvar_t
+ # loss = loss_simple / torch.exp(self.logvar) + self.logvar
+ if self.learn_logvar:
+ loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
+ loss_dict.update({'logvar': self.logvar.data.mean()})
+
+ loss = self.l_simple_weight * loss.mean()
+
+ loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
+ loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
+ loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
+ loss += (self.original_elbo_weight * loss_vlb)
+ loss_dict.update({f'{prefix}/loss': loss})
+
+ return loss, loss_dict
+
+ def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
+ return_x0=False, score_corrector=None, corrector_kwargs=None):
+ t_in = t
+ model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
+
+ if score_corrector is not None:
+ assert self.parameterization == "eps"
+ model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)
+
+ if return_codebook_ids:
+ model_out, logits = model_out
+
+ if self.parameterization == "eps":
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
+ elif self.parameterization == "x0":
+ x_recon = model_out
+ else:
+ raise NotImplementedError()
+
+ if clip_denoised:
+ x_recon.clamp_(-1., 1.)
+ if quantize_denoised:
+ x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
+ if return_codebook_ids:
+ return model_mean, posterior_variance, posterior_log_variance, logits
+ elif return_x0:
+ return model_mean, posterior_variance, posterior_log_variance, x_recon
+ else:
+ return model_mean, posterior_variance, posterior_log_variance
+
+ @torch.no_grad()
+ def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
+ return_codebook_ids=False, quantize_denoised=False, return_x0=False,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
+ b, *_, device = *x.shape, x.device
+ outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
+ return_codebook_ids=return_codebook_ids,
+ quantize_denoised=quantize_denoised,
+ return_x0=return_x0,
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
+ if return_codebook_ids:
+ raise DeprecationWarning("Support dropped.")
+ model_mean, _, model_log_variance, logits = outputs
+ elif return_x0:
+ model_mean, _, model_log_variance, x0 = outputs
+ else:
+ model_mean, _, model_log_variance = outputs
+
+ noise = noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ # no noise when t == 0
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
+
+ if return_codebook_ids:
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
+ if return_x0:
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
+ else:
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+
+ @torch.no_grad()
+ def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
+ img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
+ score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
+ log_every_t=None):
+ if not log_every_t:
+ log_every_t = self.log_every_t
+ timesteps = self.num_timesteps
+ if batch_size is not None:
+ b = batch_size
+ shape = [batch_size] + list(shape)
+ else:
+ b = batch_size = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=self.device)
+ else:
+ img = x_T
+ intermediates = []
+ if cond is not None:
+ if isinstance(cond, dict):
+ cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
+ list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+ else:
+ cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
+
+ if start_T is not None:
+ timesteps = min(timesteps, start_T)
+ iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
+ total=timesteps) if verbose else reversed(
+ range(0, timesteps))
+ if type(temperature) == float:
+ temperature = [temperature] * timesteps
+
+ for i in iterator:
+ ts = torch.full((b,), i, device=self.device, dtype=torch.long)
+ if self.shorten_cond_schedule:
+ assert self.model.conditioning_key != 'hybrid'
+ tc = self.cond_ids[ts].to(cond.device)
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
+
+ img, x0_partial = self.p_sample(img, cond, ts,
+ clip_denoised=self.clip_denoised,
+ quantize_denoised=quantize_denoised, return_x0=True,
+ temperature=temperature[i], noise_dropout=noise_dropout,
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
+ if mask is not None:
+ assert x0 is not None
+ img_orig = self.q_sample(x0, ts)
+ img = img_orig * mask + (1. - mask) * img
+
+ if i % log_every_t == 0 or i == timesteps - 1:
+ intermediates.append(x0_partial)
+ if callback: callback(i)
+ if img_callback: img_callback(img, i)
+ return img, intermediates
+
+ @torch.no_grad()
+ def p_sample_loop(self, cond, shape, return_intermediates=False,
+ x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
+ mask=None, x0=None, img_callback=None, start_T=None,
+ log_every_t=None):
+
+ if not log_every_t:
+ log_every_t = self.log_every_t
+ device = self.betas.device
+ b = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=device)
+ else:
+ img = x_T
+
+ intermediates = [img]
+ if timesteps is None:
+ timesteps = self.num_timesteps
+
+ if start_T is not None:
+ timesteps = min(timesteps, start_T)
+ iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
+ range(0, timesteps))
+
+ if mask is not None:
+ assert x0 is not None
+ assert x0.shape[2:] == mask.shape[2:] # spatial size has to match
+
+ for i in iterator:
+ ts = torch.full((b,), i, device=device, dtype=torch.long)
+ if self.shorten_cond_schedule:
+ assert self.model.conditioning_key != 'hybrid'
+ tc = self.cond_ids[ts].to(cond.device)
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
+
+ img = self.p_sample(img, cond, ts,
+ clip_denoised=self.clip_denoised,
+ quantize_denoised=quantize_denoised)
+ if mask is not None:
+ img_orig = self.q_sample(x0, ts)
+ img = img_orig * mask + (1. - mask) * img
+
+ if i % log_every_t == 0 or i == timesteps - 1:
+ intermediates.append(img)
+ if callback: callback(i)
+ if img_callback: img_callback(img, i)
+
+ if return_intermediates:
+ return img, intermediates
+ return img
+
+ @torch.no_grad()
+ def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
+ verbose=True, timesteps=None, quantize_denoised=False,
+ mask=None, x0=None, shape=None, **kwargs):
+ if shape is None:
+ shape = (batch_size, self.channels, self.image_size, self.image_size)
+ if cond is not None:
+ if isinstance(cond, dict):
+ cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
+ list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+ else:
+ cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
+ return self.p_sample_loop(cond,
+ shape,
+ return_intermediates=return_intermediates, x_T=x_T,
+ verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
+ mask=mask, x0=x0)
+
+ @torch.no_grad()
+ def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
+ if ddim:
+ ddim_sampler = DDIMSampler(self)
+ shape = (self.channels, self.image_size, self.image_size)
+ samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size,
+ shape, cond, verbose=False, **kwargs)
+
+ else:
+ samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
+ return_intermediates=True, **kwargs)
+
+ return samples, intermediates
+
+ @torch.no_grad()
+ def get_unconditional_conditioning(self, batch_size, null_label=None):
+ if null_label is not None:
+ xc = null_label
+ if isinstance(xc, ListConfig):
+ xc = list(xc)
+ if isinstance(xc, dict) or isinstance(xc, list):
+ c = self.get_learned_conditioning(xc)
+ else:
+ if hasattr(xc, "to"):
+ xc = xc.to(self.device)
+ c = self.get_learned_conditioning(xc)
+ else:
+ if self.cond_stage_key in ["class_label", "cls"]:
+ xc = self.cond_stage_model.get_unconditional_conditioning(batch_size, device=self.device)
+ return self.get_learned_conditioning(xc)
+ else:
+ raise NotImplementedError("todo")
+ if isinstance(c, list): # in case the encoder gives us a list
+ for i in range(len(c)):
+ c[i] = repeat(c[i], '1 ... -> b ...', b=batch_size).to(self.device)
+ else:
+ c = repeat(c, '1 ... -> b ...', b=batch_size).to(self.device)
+ return c
+
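+ # Assembles a dict of visualization grids: inputs, reconstructions, forward-diffusion
+ # rows, samples (optionally with CFG, quantized denoising, in-/outpainting) and
+ # progressive denoising rows.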
+ @torch.no_grad()
+ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=50, ddim_eta=0., return_keys=None,
+ quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
+ plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
+ use_ema_scope=True,
+ **kwargs):
+ ema_scope = self.ema_scope if use_ema_scope else nullcontext
+ use_ddim = ddim_steps is not None
+
+ log = dict()
+ z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
+ return_first_stage_outputs=True,
+ force_c_encode=True,
+ return_original_cond=True,
+ bs=N)
+ N = min(x.shape[0], N)
+ n_row = min(x.shape[0], n_row)
+ log["inputs"] = x
+ log["reconstruction"] = xrec
+ if self.model.conditioning_key is not None:
+ if hasattr(self.cond_stage_model, "decode"):
+ xc = self.cond_stage_model.decode(c)
+ log["conditioning"] = xc
+ elif self.cond_stage_key in ["caption", "txt"]:
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
+ log["conditioning"] = xc
+ elif self.cond_stage_key in ['class_label', "cls"]:
+ try:
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
+ log['conditioning'] = xc
+ except KeyError:
+ # probably no "human_label" in batch
+ pass
+ elif isimage(xc):
+ log["conditioning"] = xc
+ if ismap(xc):
+ log["original_conditioning"] = self.to_rgb(xc)
+
+ if plot_diffusion_rows:
+ # get diffusion row
+ diffusion_row = list()
+ z_start = z[:n_row]
+ for t in range(self.num_timesteps):
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+ t = t.to(self.device).long()
+ noise = torch.randn_like(z_start)
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
+ diffusion_row.append(self.decode_first_stage(z_noisy))
+
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
+ diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
+ diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
+ log["diffusion_row"] = diffusion_grid
+
+ if sample:
+ # get denoise row
+ with ema_scope("Sampling"):
+ samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+ ddim_steps=ddim_steps, eta=ddim_eta)
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
+ x_samples = self.decode_first_stage(samples)
+ log["samples"] = x_samples
+ if plot_denoise_rows:
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+ log["denoise_row"] = denoise_grid
+
+ if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
+ self.first_stage_model, IdentityFirstStage):
+ # also display when quantizing x0 while sampling
+ with ema_scope("Plotting Quantized Denoised"):
+ samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+ ddim_steps=ddim_steps, eta=ddim_eta,
+ quantize_denoised=True)
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
+ # quantize_denoised=True)
+ x_samples = self.decode_first_stage(samples.to(self.device))
+ log["samples_x0_quantized"] = x_samples
+
+ if unconditional_guidance_scale > 1.0:
+ uc = self.get_unconditional_conditioning(N, unconditional_guidance_label)
+ if self.model.conditioning_key == "crossattn-adm":
+ uc = {"c_crossattn": [uc], "c_adm": c["c_adm"]}
+ with ema_scope("Sampling with classifier-free guidance"):
+ samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+ ddim_steps=ddim_steps, eta=ddim_eta,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=uc,
+ )
+ x_samples_cfg = self.decode_first_stage(samples_cfg)
+ log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
+
+ if inpaint:
+ # make a simple center square
+ b, h, w = z.shape[0], z.shape[2], z.shape[3]
+ mask = torch.ones(N, h, w).to(self.device)
+ # zeros will be filled in
+ mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
+ mask = mask[:, None, ...]
+ with ema_scope("Plotting Inpaint"):
+ samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
+ x_samples = self.decode_first_stage(samples.to(self.device))
+ log["samples_inpainting"] = x_samples
+ log["mask"] = mask
+
+ # outpaint
+ mask = 1. - mask
+ with ema_scope("Plotting Outpaint"):
+ samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
+ x_samples = self.decode_first_stage(samples.to(self.device))
+ log["samples_outpainting"] = x_samples
+
+ if plot_progressive_rows:
+ with ema_scope("Plotting Progressives"):
+ img, progressives = self.progressive_denoising(c,
+ shape=(self.channels, self.image_size, self.image_size),
+ batch_size=N)
+ prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
+ log["progressive_row"] = prog_row
+
+ if return_keys:
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
+ return log
+ else:
+ return {key: log[key] for key in return_keys}
+ return log
+
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ params = list(self.model.parameters())
+ if self.cond_stage_trainable:
+ print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
+ params = params + list(self.cond_stage_model.parameters())
+ if self.learn_logvar:
+ print('Diffusion model optimizing logvar')
+ params.append(self.logvar)
+ opt = torch.optim.AdamW(params, lr=lr)
+ if self.use_scheduler:
+ assert 'target' in self.scheduler_config
+ scheduler = instantiate_from_config(self.scheduler_config)
+
+ print("Setting up LambdaLR scheduler...")
+ scheduler = [
+ {
+ 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
+ 'interval': 'step',
+ 'frequency': 1
+ }]
+ return [opt], scheduler
+ return opt
+
+ @torch.no_grad()
+ def to_rgb(self, x):
+ x = x.float()
+ if not hasattr(self, "colorize"):
+ self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
+ x = nn.functional.conv2d(x, weight=self.colorize)
+ x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
+ return x
+
+
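+# DiffusionWrapper routes the conditioning to the underlying diffusion model:
+# 'concat' conditions are stacked along the channel axis, 'crossattn' conditions are
+# passed as attention context, and 'adm'-style conditions go in via the `y` argument.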
+class DiffusionWrapper(pl.LightningModule):
+ def __init__(self, diff_model_config, conditioning_key):
+ super().__init__()
+ self.sequential_cross_attn = diff_model_config.pop("sequential_crossattn", False)
+ self.diffusion_model = instantiate_from_config(diff_model_config)
+ self.conditioning_key = conditioning_key
+ assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']
+
+ def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None):
+ if self.conditioning_key is None:
+ out = self.diffusion_model(x, t)
+ elif self.conditioning_key == 'concat':
+ xc = torch.cat([x] + c_concat, dim=1)
+ out = self.diffusion_model(xc, t)
+ elif self.conditioning_key == 'crossattn':
+ if not self.sequential_cross_attn:
+ cc = torch.cat(c_crossattn, 1)
+ else:
+ cc = c_crossattn
+ out = self.diffusion_model(x, t, context=cc)
+ elif self.conditioning_key == 'hybrid':
+ xc = torch.cat([x] + c_concat, dim=1)
+ cc = torch.cat(c_crossattn, 1)
+ out = self.diffusion_model(xc, t, context=cc)
+ elif self.conditioning_key == 'hybrid-adm':
+ assert c_adm is not None
+ xc = torch.cat([x] + c_concat, dim=1)
+ cc = torch.cat(c_crossattn, 1)
+ out = self.diffusion_model(xc, t, context=cc, y=c_adm)
+ elif self.conditioning_key == 'crossattn-adm':
+ assert c_adm is not None
+ cc = torch.cat(c_crossattn, 1)
+ out = self.diffusion_model(x, t, context=cc, y=c_adm)
+ elif self.conditioning_key == 'adm':
+ cc = c_crossattn[0]
+ out = self.diffusion_model(x, t, y=cc)
+ else:
+ raise NotImplementedError()
+
+ return out
+
+
+class LatentUpscaleDiffusion(LatentDiffusion):
+ def __init__(self, *args, low_scale_config, low_scale_key="LR", noise_level_key=None, **kwargs):
+ super().__init__(*args, **kwargs)
+ # assumes that neither the cond_stage nor the low_scale_model contain trainable params
+ assert not self.cond_stage_trainable
+ self.instantiate_low_stage(low_scale_config)
+ self.low_scale_key = low_scale_key
+ self.noise_level_key = noise_level_key
+
+ def instantiate_low_stage(self, config):
+ model = instantiate_from_config(config)
+ self.low_scale_model = model.eval()
+ self.low_scale_model.train = disabled_train
+ for param in self.low_scale_model.parameters():
+ param.requires_grad = False
+
+ @torch.no_grad()
+ def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False):
+ if not log_mode:
+ z, c = super().get_input(batch, k, force_c_encode=True, bs=bs)
+ else:
+ z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
+ force_c_encode=True, return_original_cond=True, bs=bs)
+ x_low = batch[self.low_scale_key][:bs]
+ x_low = rearrange(x_low, 'b h w c -> b c h w')
+ x_low = x_low.to(memory_format=torch.contiguous_format).float()
+ zx, noise_level = self.low_scale_model(x_low)
+ if self.noise_level_key is not None:
+ # get noise level from batch instead, e.g. when extracting a custom noise level for bsr
+ raise NotImplementedError('TODO')
+
+ all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level}
+ if log_mode:
+ # TODO: maybe disable if too expensive
+ x_low_rec = self.low_scale_model.decode(zx)
+ return z, all_conds, x, xrec, xc, x_low, x_low_rec, noise_level
+ return z, all_conds
+
+ @torch.no_grad()
+ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
+ plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True,
+ unconditional_guidance_scale=1., unconditional_guidance_label=None, use_ema_scope=True,
+ **kwargs):
+ ema_scope = self.ema_scope if use_ema_scope else nullcontext
+ use_ddim = ddim_steps is not None
+
+ log = dict()
+ z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input(batch, self.first_stage_key, bs=N,
+ log_mode=True)
+ N = min(x.shape[0], N)
+ n_row = min(x.shape[0], n_row)
+ log["inputs"] = x
+ log["reconstruction"] = xrec
+ log["x_lr"] = x_low
+ log[f"x_lr_rec_@noise_levels{'-'.join(map(lambda x: str(x), list(noise_level.cpu().numpy())))}"] = x_low_rec
+ if self.model.conditioning_key is not None:
+ if hasattr(self.cond_stage_model, "decode"):
+ xc = self.cond_stage_model.decode(c)
+ log["conditioning"] = xc
+ elif self.cond_stage_key in ["caption", "txt"]:
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
+ log["conditioning"] = xc
+ elif self.cond_stage_key in ['class_label', 'cls']:
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
+ log['conditioning'] = xc
+ elif isimage(xc):
+ log["conditioning"] = xc
+ if ismap(xc):
+ log["original_conditioning"] = self.to_rgb(xc)
+
+ if plot_diffusion_rows:
+ # get diffusion row
+ diffusion_row = list()
+ z_start = z[:n_row]
+ for t in range(self.num_timesteps):
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+ t = t.to(self.device).long()
+ noise = torch.randn_like(z_start)
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
+ diffusion_row.append(self.decode_first_stage(z_noisy))
+
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
+ diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
+ diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
+ log["diffusion_row"] = diffusion_grid
+
+ if sample:
+ # get denoise row
+ with ema_scope("Sampling"):
+ samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+ ddim_steps=ddim_steps, eta=ddim_eta)
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
+ x_samples = self.decode_first_stage(samples)
+ log["samples"] = x_samples
+ if plot_denoise_rows:
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+ log["denoise_row"] = denoise_grid
+
+ if unconditional_guidance_scale > 1.0:
+ uc_tmp = self.get_unconditional_conditioning(N, unconditional_guidance_label)
+ # TODO explore better "unconditional" choices for the other keys
+ # maybe guide away from empty text label and highest noise level and maximally degraded zx?
+ uc = dict()
+ for k in c:
+ if k == "c_crossattn":
+ assert isinstance(c[k], list) and len(c[k]) == 1
+ uc[k] = [uc_tmp]
+ elif k == "c_adm": # todo: only run with text-based guidance?
+ assert isinstance(c[k], torch.Tensor)
+ #uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level
+ uc[k] = c[k]
+ elif isinstance(c[k], list):
+ uc[k] = [c[k][i] for i in range(len(c[k]))]
+ else:
+ uc[k] = c[k]
+
+ with ema_scope("Sampling with classifier-free guidance"):
+ samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+ ddim_steps=ddim_steps, eta=ddim_eta,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=uc,
+ )
+ x_samples_cfg = self.decode_first_stage(samples_cfg)
+ log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
+
+ if plot_progressive_rows:
+ with ema_scope("Plotting Progressives"):
+ img, progressives = self.progressive_denoising(c,
+ shape=(self.channels, self.image_size, self.image_size),
+ batch_size=N)
+ prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
+ log["progressive_row"] = prog_row
+
+ return log
+
+
+class LatentFinetuneDiffusion(LatentDiffusion):
+ """
+ Basis for different finetunes, such as inpainting or depth2image.
+ To disable finetuning mode, set finetune_keys to None
+ """
+
+ def __init__(self,
+ concat_keys: tuple,
+ finetune_keys=("model.diffusion_model.input_blocks.0.0.weight",
+ "model_ema.diffusion_modelinput_blocks00weight"
+ ),
+ keep_finetune_dims=4,
+ # if model was trained without concat mode before and we would like to keep these channels
+ c_concat_log_start=None, # to log reconstruction of c_concat codes
+ c_concat_log_end=None,
+ *args, **kwargs
+ ):
+ ckpt_path = kwargs.pop("ckpt_path", None)
+ ignore_keys = kwargs.pop("ignore_keys", list())
+ super().__init__(*args, **kwargs)
+ self.finetune_keys = finetune_keys
+ self.concat_keys = concat_keys
+ self.keep_dims = keep_finetune_dims
+ self.c_concat_log_start = c_concat_log_start
+ self.c_concat_log_end = c_concat_log_end
+ if exists(self.finetune_keys): assert exists(ckpt_path), 'can only finetune from a given checkpoint'
+ if exists(ckpt_path):
+ self.init_from_ckpt(ckpt_path, ignore_keys)
+
+ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
+ sd = torch.load(path, map_location="cpu")
+ if "state_dict" in list(sd.keys()):
+ sd = sd["state_dict"]
+ keys = list(sd.keys())
+ for k in keys:
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+
+ # make it explicit, finetune by including extra input channels
+ if exists(self.finetune_keys) and k in self.finetune_keys:
+ new_entry = None
+ for name, param in self.named_parameters():
+ if name in self.finetune_keys:
+ print(
+ f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only")
+ new_entry = torch.zeros_like(param) # zero init
+ assert exists(new_entry), 'did not find matching parameter to modify'
+ new_entry[:, :self.keep_dims, ...] = sd[k]
+ sd[k] = new_entry
+
+ missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
+ sd, strict=False)
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
+ if len(missing) > 0:
+ print(f"Missing Keys: {missing}")
+ if len(unexpected) > 0:
+ print(f"Unexpected Keys: {unexpected}")
+
+ @torch.no_grad()
+ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
+ quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
+ plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
+ use_ema_scope=True,
+ **kwargs):
+ ema_scope = self.ema_scope if use_ema_scope else nullcontext
+ use_ddim = ddim_steps is not None
+
+ log = dict()
+ z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, bs=N, return_first_stage_outputs=True)
+ c_cat, c = c["c_concat"][0], c["c_crossattn"][0]
+ N = min(x.shape[0], N)
+ n_row = min(x.shape[0], n_row)
+ log["inputs"] = x
+ log["reconstruction"] = xrec
+ if self.model.conditioning_key is not None:
+ if hasattr(self.cond_stage_model, "decode"):
+ xc = self.cond_stage_model.decode(c)
+ log["conditioning"] = xc
+ elif self.cond_stage_key in ["caption", "txt"]:
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
+ log["conditioning"] = xc
+ elif self.cond_stage_key in ['class_label', 'cls']:
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
+ log['conditioning'] = xc
+ elif isimage(xc):
+ log["conditioning"] = xc
+ if ismap(xc):
+ log["original_conditioning"] = self.to_rgb(xc)
+
+ if not (self.c_concat_log_start is None and self.c_concat_log_end is None):
+ log["c_concat_decoded"] = self.decode_first_stage(c_cat[:, self.c_concat_log_start:self.c_concat_log_end])
+
+ if plot_diffusion_rows:
+ # get diffusion row
+ diffusion_row = list()
+ z_start = z[:n_row]
+ for t in range(self.num_timesteps):
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+ t = t.to(self.device).long()
+ noise = torch.randn_like(z_start)
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
+ diffusion_row.append(self.decode_first_stage(z_noisy))
+
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
+ diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
+ diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
+ log["diffusion_row"] = diffusion_grid
+
+ if sample:
+ # get denoise row
+ with ema_scope("Sampling"):
+ samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
+ batch_size=N, ddim=use_ddim,
+ ddim_steps=ddim_steps, eta=ddim_eta)
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
+ x_samples = self.decode_first_stage(samples)
+ log["samples"] = x_samples
+ if plot_denoise_rows:
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+ log["denoise_row"] = denoise_grid
+
+ if unconditional_guidance_scale > 1.0:
+ uc_cross = self.get_unconditional_conditioning(N, unconditional_guidance_label)
+ uc_cat = c_cat
+ uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
+ with ema_scope("Sampling with classifier-free guidance"):
+ samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
+ batch_size=N, ddim=use_ddim,
+ ddim_steps=ddim_steps, eta=ddim_eta,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=uc_full,
+ )
+ x_samples_cfg = self.decode_first_stage(samples_cfg)
+ log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
+
+ return log
+
+
+class LatentInpaintDiffusion(LatentFinetuneDiffusion):
+ """
+ can either run as pure inpainting model (only concat mode) or with mixed conditionings,
+ e.g. mask as concat and text via cross-attn.
+ To disable finetuning mode, set finetune_keys to None
+ """
+
+ def __init__(self,
+ concat_keys=("mask", "masked_image"),
+ masked_image_key="masked_image",
+ *args, **kwargs
+ ):
+ super().__init__(concat_keys, *args, **kwargs)
+ self.masked_image_key = masked_image_key
+ assert self.masked_image_key in concat_keys
+
+ @torch.no_grad()
+ def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
+ # note: restricted to non-trainable encoders currently
+ assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for inpainting'
+ z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
+ force_c_encode=True, return_original_cond=True, bs=bs)
+
+ assert exists(self.concat_keys)
+ c_cat = list()
+ for ck in self.concat_keys:
+ cc = rearrange(batch[ck], 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float()
+ if bs is not None:
+ cc = cc[:bs]
+ cc = cc.to(self.device)
+ bchw = z.shape
+ if ck != self.masked_image_key:
+ cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
+ else:
+ cc = self.get_first_stage_encoding(self.encode_first_stage(cc))
+ c_cat.append(cc)
+ c_cat = torch.cat(c_cat, dim=1)
+ all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
+ if return_first_stage_outputs:
+ return z, all_conds, x, xrec, xc
+ return z, all_conds
+
+ @torch.no_grad()
+ def log_images(self, *args, **kwargs):
+ log = super(LatentInpaintDiffusion, self).log_images(*args, **kwargs)
+ log["masked_image"] = rearrange(args[0]["masked_image"],
+ 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float()
+ return log
+
+
+class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion):
+ """
+ condition on monocular depth estimation
+ """
+
+ def __init__(self, depth_stage_config, concat_keys=("midas_in",), *args, **kwargs):
+ super().__init__(concat_keys=concat_keys, *args, **kwargs)
+ self.depth_model = instantiate_from_config(depth_stage_config)
+ self.depth_stage_key = concat_keys[0]
+
+ @torch.no_grad()
+ def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
+ # note: restricted to non-trainable encoders currently
+ assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for depth2img'
+ z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
+ force_c_encode=True, return_original_cond=True, bs=bs)
+
+ assert exists(self.concat_keys)
+ assert len(self.concat_keys) == 1
+ c_cat = list()
+ for ck in self.concat_keys:
+ cc = batch[ck]
+ if bs is not None:
+ cc = cc[:bs]
+ cc = cc.to(self.device)
+ cc = self.depth_model(cc)
+ cc = torch.nn.functional.interpolate(
+ cc,
+ size=z.shape[2:],
+ mode="bicubic",
+ align_corners=False,
+ )
+
+ depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax(cc, dim=[1, 2, 3],
+ keepdim=True)
+ cc = 2. * (cc - depth_min) / (depth_max - depth_min + 0.001) - 1.
+ c_cat.append(cc)
+ c_cat = torch.cat(c_cat, dim=1)
+ all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
+ if return_first_stage_outputs:
+ return z, all_conds, x, xrec, xc
+ return z, all_conds
+
+ @torch.no_grad()
+ def log_images(self, *args, **kwargs):
+ log = super().log_images(*args, **kwargs)
+ depth = self.depth_model(args[0][self.depth_stage_key])
+ depth_min, depth_max = torch.amin(depth, dim=[1, 2, 3], keepdim=True), \
+ torch.amax(depth, dim=[1, 2, 3], keepdim=True)
+ log["depth"] = 2. * (depth - depth_min) / (depth_max - depth_min) - 1.
+ return log
+
+
+class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
+ """
+ condition on low-res image (and optionally on some spatial noise augmentation)
+ """
+ def __init__(self, concat_keys=("lr",), reshuffle_patch_size=None,
+ low_scale_config=None, low_scale_key=None, *args, **kwargs):
+ super().__init__(concat_keys=concat_keys, *args, **kwargs)
+ self.reshuffle_patch_size = reshuffle_patch_size
+ self.low_scale_model = None
+ if low_scale_config is not None:
+ print("Initializing a low-scale model")
+ assert exists(low_scale_key)
+ self.instantiate_low_stage(low_scale_config)
+ self.low_scale_key = low_scale_key
+
+ def instantiate_low_stage(self, config):
+ model = instantiate_from_config(config)
+ self.low_scale_model = model.eval()
+ self.low_scale_model.train = disabled_train
+ for param in self.low_scale_model.parameters():
+ param.requires_grad = False
+
+ @torch.no_grad()
+ def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
+ # note: restricted to non-trainable encoders currently
+ assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for upscaling-ft'
+ z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
+ force_c_encode=True, return_original_cond=True, bs=bs)
+
+ assert exists(self.concat_keys)
+ assert len(self.concat_keys) == 1
+ # optionally make spatial noise_level here
+ c_cat = list()
+ noise_level = None
+ for ck in self.concat_keys:
+ cc = batch[ck]
+ cc = rearrange(cc, 'b h w c -> b c h w')
+ if exists(self.reshuffle_patch_size):
+ assert isinstance(self.reshuffle_patch_size, int)
+ cc = rearrange(cc, 'b c (p1 h) (p2 w) -> b (p1 p2 c) h w',
+ p1=self.reshuffle_patch_size, p2=self.reshuffle_patch_size)
+ if bs is not None:
+ cc = cc[:bs]
+ cc = cc.to(self.device)
+ if exists(self.low_scale_model) and ck == self.low_scale_key:
+ cc, noise_level = self.low_scale_model(cc)
+ c_cat.append(cc)
+ c_cat = torch.cat(c_cat, dim=1)
+ if exists(noise_level):
+ all_conds = {"c_concat": [c_cat], "c_crossattn": [c], "c_adm": noise_level}
+ else:
+ all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
+ if return_first_stage_outputs:
+ return z, all_conds, x, xrec, xc
+ return z, all_conds
+
+ @torch.no_grad()
+ def log_images(self, *args, **kwargs):
+ log = super().log_images(*args, **kwargs)
+ log["lr"] = rearrange(args[0]["lr"], 'b h w c -> b c h w')
+ return log
diff --git a/ldm/models/diffusion/dpm_solver/__init__.py b/ldm/models/diffusion/dpm_solver/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7427f38c07530afbab79154ea8aaf88c4bf70a08
--- /dev/null
+++ b/ldm/models/diffusion/dpm_solver/__init__.py
@@ -0,0 +1 @@
+from .sampler import DPMSolverSampler
\ No newline at end of file
diff --git a/ldm/models/diffusion/dpm_solver/dpm_solver.py b/ldm/models/diffusion/dpm_solver/dpm_solver.py
new file mode 100644
index 0000000000000000000000000000000000000000..095e5ba3ce0b1aa7f4b3f1e2e5d8fff7cfe6dc8c
--- /dev/null
+++ b/ldm/models/diffusion/dpm_solver/dpm_solver.py
@@ -0,0 +1,1154 @@
+import torch
+import torch.nn.functional as F
+import math
+from tqdm import tqdm
+
+
+class NoiseScheduleVP:
+ def __init__(
+ self,
+ schedule='discrete',
+ betas=None,
+ alphas_cumprod=None,
+ continuous_beta_0=0.1,
+ continuous_beta_1=20.,
+ ):
+ """Create a wrapper class for the forward SDE (VP type).
+ ***
+ Update: We support discrete-time diffusion models by implementing a piecewise linear interpolation for log_alpha_t.
+ We recommend using schedule='discrete' for discrete-time diffusion models, especially for high-resolution images.
+ ***
+ The forward SDE ensures that the conditional distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
+ We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
+ Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
+ log_alpha_t = self.marginal_log_mean_coeff(t)
+ sigma_t = self.marginal_std(t)
+ lambda_t = self.marginal_lambda(t)
+ Moreover, as lambda(t) is an invertible function, we also support its inverse function:
+ t = self.inverse_lambda(lambda_t)
+ ===============================================================
+ We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
+ 1. For discrete-time DPMs:
+ For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
+ t_i = (i + 1) / N
+ e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
+ We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
+ Args:
+ betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
+ alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
+ Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
+ **Important**: Please pay special attention to the argument `alphas_cumprod`:
+ The `alphas_cumprod` is the \hat{alpha_n} array in the notation of DDPM. Specifically, DDPMs assume that
+ q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
+ Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
+ alpha_{t_n} = \sqrt{\hat{alpha_n}},
+ and
+ log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
+ 2. For continuous-time DPMs:
+ We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
+ schedule are the default settings in DDPM and improved-DDPM:
+ Args:
+ beta_min: A `float` number. The smallest beta for the linear schedule.
+ beta_max: A `float` number. The largest beta for the linear schedule.
+ cosine_s: A `float` number. The hyperparameter in the cosine schedule.
+ cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
+ T: A `float` number. The ending time of the forward process.
+ ===============================================================
+ Args:
+ schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
+ 'linear' or 'cosine' for continuous-time DPMs.
+ Returns:
+ A wrapper object of the forward SDE (VP type).
+
+ ===============================================================
+ Example:
+ # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
+ >>> ns = NoiseScheduleVP('discrete', betas=betas)
+ # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
+ >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
+ # For continuous-time DPMs (VPSDE), linear schedule:
+ >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
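+ # (Illustrative) Once constructed, the schedule can be queried at any continuous time t in (0, T]:
+ >>> t = torch.tensor([0.5])
+ >>> alpha_t, sigma_t, lambda_t = ns.marginal_alpha(t), ns.marginal_std(t), ns.marginal_lambda(t)
+ >>> t_recovered = ns.inverse_lambda(lambda_t) # approximately recovers t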
+ """
+
+ if schedule not in ['discrete', 'linear', 'cosine']:
+ raise ValueError(
+ "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(
+ schedule))
+
+ self.schedule = schedule
+ if schedule == 'discrete':
+ if betas is not None:
+ log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
+ else:
+ assert alphas_cumprod is not None
+ log_alphas = 0.5 * torch.log(alphas_cumprod)
+ self.total_N = len(log_alphas)
+ self.T = 1.
+ self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1))
+ self.log_alpha_array = log_alphas.reshape((1, -1,))
+ else:
+ self.total_N = 1000
+ self.beta_0 = continuous_beta_0
+ self.beta_1 = continuous_beta_1
+ self.cosine_s = 0.008
+ self.cosine_beta_max = 999.
+ self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (
+ 1. + self.cosine_s) / math.pi - self.cosine_s
+ self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
+ self.schedule = schedule
+ if schedule == 'cosine':
+ # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
+ # Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
+ self.T = 0.9946
+ else:
+ self.T = 1.
+
+ def marginal_log_mean_coeff(self, t):
+ """
+ Compute log(alpha_t) of a given continuous-time label t in [0, T].
+ """
+ if self.schedule == 'discrete':
+ return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device),
+ self.log_alpha_array.to(t.device)).reshape((-1))
+ elif self.schedule == 'linear':
+ return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
+ elif self.schedule == 'cosine':
+ log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
+ log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
+ return log_alpha_t
+
+ def marginal_alpha(self, t):
+ """
+ Compute alpha_t of a given continuous-time label t in [0, T].
+ """
+ return torch.exp(self.marginal_log_mean_coeff(t))
+
+ def marginal_std(self, t):
+ """
+ Compute sigma_t of a given continuous-time label t in [0, T].
+ """
+ return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
+
+ def marginal_lambda(self, t):
+ """
+ Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
+ """
+ log_mean_coeff = self.marginal_log_mean_coeff(t)
+ log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
+ return log_mean_coeff - log_std
+
+ def inverse_lambda(self, lamb):
+ """
+ Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
+ """
+ if self.schedule == 'linear':
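+ # For the 'linear' schedule, lambda(t) can be inverted in closed form: the quadratic in t is
+ # solved below using its numerically stable root.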
+ tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
+ Delta = self.beta_0 ** 2 + tmp
+ return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
+ elif self.schedule == 'discrete':
+ log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
+ t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]),
+ torch.flip(self.t_array.to(lamb.device), [1]))
+ return t.reshape((-1,))
+ else:
+ log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
+ t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (
+ 1. + self.cosine_s) / math.pi - self.cosine_s
+ t = t_fn(log_alpha)
+ return t
+
+
+def model_wrapper(
+ model,
+ noise_schedule,
+ model_type="noise",
+ model_kwargs={},
+ guidance_type="uncond",
+ condition=None,
+ unconditional_condition=None,
+ guidance_scale=1.,
+ classifier_fn=None,
+ classifier_kwargs={},
+):
+ """Create a wrapper function for the noise prediction model.
+ DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we
+ first wrap the model function into a noise prediction model that accepts continuous time as input.
+ We support four types of diffusion models by setting `model_type`:
+ 1. "noise": noise prediction model. (Trained by predicting noise).
+ 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
+ 3. "v": velocity prediction model. (Trained by predicting the velocity).
+ The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].
+ [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
+ arXiv preprint arXiv:2202.00512 (2022).
+ [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
+ arXiv preprint arXiv:2210.02303 (2022).
+
+ 4. "score": marginal score function. (Trained by denoising score matching).
+ Note that the score function and the noise prediction model follow a simple relationship:
+ ```
+ noise(x_t, t) = -sigma_t * score(x_t, t)
+ ```
+ We support three types of guided sampling by DPMs by setting `guidance_type`:
+ 1. "uncond": unconditional sampling by DPMs.
+ The input `model` has the following format:
+ ``
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
+ ``
+ 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
+ The input `model` has the following format:
+ ``
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
+ ``
+ The input `classifier_fn` has the following format:
+ ``
+ classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
+ ``
+ [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
+ in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
+ 3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
+ The input `model` has the following format:
+ ``
+ model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
+ ``
+ And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
+ [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
+ arXiv preprint arXiv:2207.12598 (2022).
+
+ The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
+ or continuous-time labels (i.e. epsilon to T).
+ We wrap the model function to accept only `x` and `t_continuous` as inputs, and output the predicted noise:
+ ``
+ def model_fn(x, t_continuous) -> noise:
+ t_input = get_model_input_time(t_continuous)
+ return noise_pred(model, x, t_input, **model_kwargs)
+ ``
+ where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
+ ===============================================================
+ Args:
+ model: A diffusion model with the corresponding format described above.
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
+ model_type: A `str`. The parameterization type of the diffusion model.
+ "noise" or "x_start" or "v" or "score".
+ model_kwargs: A `dict`. A dict for the other inputs of the model function.
+ guidance_type: A `str`. The type of the guidance for sampling.
+ "uncond" or "classifier" or "classifier-free".
+ condition: A pytorch tensor. The condition for the guided sampling.
+ Only used for "classifier" or "classifier-free" guidance type.
+ unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
+ Only used for "classifier-free" guidance type.
+ guidance_scale: A `float`. The scale for the guided sampling.
+ classifier_fn: A classifier function. Only used for the classifier guidance.
+ classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
+ Returns:
+ A noise prediction model that accepts the noised data and the continuous time as the inputs.
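+ ===============================================================
+ Example (a minimal sketch; `my_model` is a hypothetical noise prediction network and `ns` is a NoiseScheduleVP):
+ >>> model_fn = model_wrapper(my_model, ns, model_type="noise", guidance_type="uncond")
+ >>> noise = model_fn(x, t_continuous) # predicted noise at the continuous time `t_continuous`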
+ """
+
+ def get_model_input_time(t_continuous):
+ """
+ Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
+ For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
+ For continuous-time DPMs, we just use `t_continuous`.
+ """
+ if noise_schedule.schedule == 'discrete':
+ return (t_continuous - 1. / noise_schedule.total_N) * 1000.
+ else:
+ return t_continuous
+
+ def noise_pred_fn(x, t_continuous, cond=None):
+ if t_continuous.reshape((-1,)).shape[0] == 1:
+ t_continuous = t_continuous.expand((x.shape[0]))
+ t_input = get_model_input_time(t_continuous)
+ if cond is None:
+ output = model(x, t_input, **model_kwargs)
+ else:
+ output = model(x, t_input, cond, **model_kwargs)
+ if model_type == "noise":
+ return output
+ elif model_type == "x_start":
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
+ dims = x.dim()
+ return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
+ elif model_type == "v":
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
+ dims = x.dim()
+ return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
+ elif model_type == "score":
+ sigma_t = noise_schedule.marginal_std(t_continuous)
+ dims = x.dim()
+ return -expand_dims(sigma_t, dims) * output
+
+ def cond_grad_fn(x, t_input):
+ """
+ Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
+ """
+ with torch.enable_grad():
+ x_in = x.detach().requires_grad_(True)
+ log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
+ return torch.autograd.grad(log_prob.sum(), x_in)[0]
+
+ def model_fn(x, t_continuous):
+ """
+ The noise prediction model function that is used for DPM-Solver.
+ """
+ if t_continuous.reshape((-1,)).shape[0] == 1:
+ t_continuous = t_continuous.expand((x.shape[0]))
+ if guidance_type == "uncond":
+ return noise_pred_fn(x, t_continuous)
+ elif guidance_type == "classifier":
+ assert classifier_fn is not None
+ t_input = get_model_input_time(t_continuous)
+ cond_grad = cond_grad_fn(x, t_input)
+ sigma_t = noise_schedule.marginal_std(t_continuous)
+ noise = noise_pred_fn(x, t_continuous)
+ return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
+ elif guidance_type == "classifier-free":
+ if guidance_scale == 1. or unconditional_condition is None:
+ return noise_pred_fn(x, t_continuous, cond=condition)
+ else:
+ x_in = torch.cat([x] * 2)
+ t_in = torch.cat([t_continuous] * 2)
+ c_in = torch.cat([unconditional_condition, condition])
+ noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
+ return noise_uncond + guidance_scale * (noise - noise_uncond)
+
+ assert model_type in ["noise", "x_start", "v", "score"]
+ assert guidance_type in ["uncond", "classifier", "classifier-free"]
+ return model_fn
+
+
+class DPM_Solver:
+ def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.):
+ """Construct a DPM-Solver.
+ We support both the noise prediction model ("predicting epsilon") and the data prediction model ("predicting x0").
+ If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver).
+ If `predict_x0` is True, we use the solver for the data prediction model (DPM-Solver++).
+ In such case, we further support the "dynamic thresholding" in [1] when `thresholding` is True.
+ The "dynamic thresholding" can greatly improve the sample quality for pixel-space DPMs with large guidance scales.
+ Args:
+ model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]):
+ ``
+ def model_fn(x, t_continuous):
+ return noise
+ ``
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
+ predict_x0: A `bool`. If true, use the data prediction model; else, use the noise prediction model.
+ thresholding: A `bool`. Valid when `predict_x0` is True. Whether to use the "dynamic thresholding" in [1].
+ max_val: A `float`. Valid when both `predict_x0` and `thresholding` are True. The max value for thresholding.
+
+ [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b.
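+
+ Example (a sketch; `model_fn` and `ns` are assumed to come from `model_wrapper` and `NoiseScheduleVP` above):
+ >>> dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)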
+ """
+ self.model = model_fn
+ self.noise_schedule = noise_schedule
+ self.predict_x0 = predict_x0
+ self.thresholding = thresholding
+ self.max_val = max_val
+
+ def noise_prediction_fn(self, x, t):
+ """
+ Return the noise prediction model.
+ """
+ return self.model(x, t)
+
+ def data_prediction_fn(self, x, t):
+ """
+ Return the data prediction model (with thresholding).
+ """
+ noise = self.noise_prediction_fn(x, t)
+ dims = x.dim()
+ alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
+ x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
+ if self.thresholding:
+ p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
+ s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
+ s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
+ x0 = torch.clamp(x0, -s, s) / s
+ return x0
+
+ def model_fn(self, x, t):
+ """
+ Convert the model to the noise prediction model or the data prediction model.
+ """
+ if self.predict_x0:
+ return self.data_prediction_fn(x, t)
+ else:
+ return self.noise_prediction_fn(x, t)
+
+ def get_time_steps(self, skip_type, t_T, t_0, N, device):
+ """Compute the intermediate time steps for sampling.
+ Args:
+ skip_type: A `str`. The type for the spacing of the time steps. We support three types:
+ - 'logSNR': uniform logSNR for the time steps.
+ - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolution data**.)
+ - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolution data.)
+ t_T: A `float`. The starting time of the sampling (default is T).
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
+ N: A `int`. The total number of the spacing of the time steps.
+ device: A torch device.
+ Returns:
+ A pytorch tensor of the time steps, with the shape (N + 1,).
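+ For example (illustrative), skip_type='time_uniform' with t_T=1., t_0=1e-3 and N=20 returns
+ 21 times uniformly spaced from 1. down to 1e-3.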
+ """
+ if skip_type == 'logSNR':
+ lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
+ lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
+ logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
+ return self.noise_schedule.inverse_lambda(logSNR_steps)
+ elif skip_type == 'time_uniform':
+ return torch.linspace(t_T, t_0, N + 1).to(device)
+ elif skip_type == 'time_quadratic':
+ t_order = 2
+ t = torch.linspace(t_T ** (1. / t_order), t_0 ** (1. / t_order), N + 1).pow(t_order).to(device)
+ return t
+ else:
+ raise ValueError(
+ "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
+
+ def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
+ """
+ Get the order of each step for sampling by the singlestep DPM-Solver.
+ We combine DPM-Solver-1, -2 and -3 to use all the function evaluations; this combination is named "DPM-Solver-fast".
+ Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is:
+ - If order == 1:
+ We take `steps` of DPM-Solver-1 (i.e. DDIM).
+ - If order == 2:
+ - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling.
+ - If steps % 2 == 0, we use K steps of DPM-Solver-2.
+ - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1.
+ - If order == 3:
+ - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
+ - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
+ - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
+ - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.
+ ============================================
+ Args:
+ order: A `int`. The max order for the solver (2 or 3).
+ steps: A `int`. The total number of function evaluations (NFE).
+ skip_type: A `str`. The type for the spacing of the time steps. We support three types:
+ - 'logSNR': uniform logSNR for the time steps.
+ - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolution data**.)
+ - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolution data.)
+ t_T: A `float`. The starting time of the sampling (default is T).
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
+ device: A torch device.
+ Returns:
+ orders: A list of the solver order of each step.
+ """
+ if order == 3:
+ K = steps // 3 + 1
+ if steps % 3 == 0:
+ orders = [3, ] * (K - 2) + [2, 1]
+ elif steps % 3 == 1:
+ orders = [3, ] * (K - 1) + [1]
+ else:
+ orders = [3, ] * (K - 1) + [2]
+ elif order == 2:
+ if steps % 2 == 0:
+ K = steps // 2
+ orders = [2, ] * K
+ else:
+ K = steps // 2 + 1
+ orders = [2, ] * (K - 1) + [1]
+ elif order == 1:
+ K = 1
+ orders = [1, ] * steps
+ else:
+ raise ValueError("'order' must be '1' or '2' or '3'.")
+ if skip_type == 'logSNR':
+ # To reproduce the results in DPM-Solver paper
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
+ else:
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[
+ torch.cumsum(torch.tensor([0, ] + orders), dim=0).to(device)]
+ return timesteps_outer, orders
+
+ def denoise_to_zero_fn(self, x, s):
+ """
+ Denoise at the final step, which is equivalent to solving the ODE from lambda_s to infty by first-order discretization.
+ """
+ return self.data_prediction_fn(x, s)
+
+ def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False):
+ """
+ DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`.
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
+ return_intermediate: A `bool`. If true, also return the model value at time `s`.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ ns = self.noise_schedule
+ dims = x.dim()
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
+ h = lambda_t - lambda_s
+ log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t)
+ sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
+ alpha_t = torch.exp(log_alpha_t)
+
+ if self.predict_x0:
+ phi_1 = torch.expm1(-h)
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ x_t = (
+ expand_dims(sigma_t / sigma_s, dims) * x
+ - expand_dims(alpha_t * phi_1, dims) * model_s
+ )
+ if return_intermediate:
+ return x_t, {'model_s': model_s}
+ else:
+ return x_t
+ else:
+ phi_1 = torch.expm1(h)
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+ - expand_dims(sigma_t * phi_1, dims) * model_s
+ )
+ if return_intermediate:
+ return x_t, {'model_s': model_s}
+ else:
+ return x_t
+
+ def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False,
+ solver_type='dpm_solver'):
+ """
+ Singlestep solver DPM-Solver-2 from time `s` to time `t`.
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ r1: A `float`. The hyperparameter of the second-order solver.
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
+ return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time).
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+ The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ if solver_type not in ['dpm_solver', 'taylor']:
+ raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
+ if r1 is None:
+ r1 = 0.5
+ ns = self.noise_schedule
+ dims = x.dim()
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
+ h = lambda_t - lambda_s
+ lambda_s1 = lambda_s + r1 * h
+ s1 = ns.inverse_lambda(lambda_s1)
+ log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(
+ s1), ns.marginal_log_mean_coeff(t)
+ sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t)
+ alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)
+
+ if self.predict_x0:
+ phi_11 = torch.expm1(-r1 * h)
+ phi_1 = torch.expm1(-h)
+
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ x_s1 = (
+ expand_dims(sigma_s1 / sigma_s, dims) * x
+ - expand_dims(alpha_s1 * phi_11, dims) * model_s
+ )
+ model_s1 = self.model_fn(x_s1, s1)
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(sigma_t / sigma_s, dims) * x
+ - expand_dims(alpha_t * phi_1, dims) * model_s
+ - (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s)
+ )
+ elif solver_type == 'taylor':
+ x_t = (
+ expand_dims(sigma_t / sigma_s, dims) * x
+ - expand_dims(alpha_t * phi_1, dims) * model_s
+ + (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * (
+ model_s1 - model_s)
+ )
+ else:
+ phi_11 = torch.expm1(r1 * h)
+ phi_1 = torch.expm1(h)
+
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ x_s1 = (
+ expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
+ - expand_dims(sigma_s1 * phi_11, dims) * model_s
+ )
+ model_s1 = self.model_fn(x_s1, s1)
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+ - expand_dims(sigma_t * phi_1, dims) * model_s
+ - (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * (model_s1 - model_s)
+ )
+ elif solver_type == 'taylor':
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+ - expand_dims(sigma_t * phi_1, dims) * model_s
+ - (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * (model_s1 - model_s)
+ )
+ if return_intermediate:
+ return x_t, {'model_s': model_s, 'model_s1': model_s1}
+ else:
+ return x_t
+
+ def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., model_s=None, model_s1=None,
+ return_intermediate=False, solver_type='dpm_solver'):
+ """
+ Singlestep solver DPM-Solver-3 from time `s` to time `t`.
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ r1: A `float`. The hyperparameter of the third-order solver.
+ r2: A `float`. The hyperparameter of the third-order solver.
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
+ model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`).
+ If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it.
+ return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+ The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ if solver_type not in ['dpm_solver', 'taylor']:
+ raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
+ if r1 is None:
+ r1 = 1. / 3.
+ if r2 is None:
+ r2 = 2. / 3.
+ ns = self.noise_schedule
+ dims = x.dim()
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
+ h = lambda_t - lambda_s
+ lambda_s1 = lambda_s + r1 * h
+ lambda_s2 = lambda_s + r2 * h
+ s1 = ns.inverse_lambda(lambda_s1)
+ s2 = ns.inverse_lambda(lambda_s2)
+ log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(
+ s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t)
+ sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(
+ s2), ns.marginal_std(t)
+ alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t)
+
+ if self.predict_x0:
+ phi_11 = torch.expm1(-r1 * h)
+ phi_12 = torch.expm1(-r2 * h)
+ phi_1 = torch.expm1(-h)
+ phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.
+ phi_2 = phi_1 / h + 1.
+ phi_3 = phi_2 / h - 0.5
+
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ if model_s1 is None:
+ x_s1 = (
+ expand_dims(sigma_s1 / sigma_s, dims) * x
+ - expand_dims(alpha_s1 * phi_11, dims) * model_s
+ )
+ model_s1 = self.model_fn(x_s1, s1)
+ x_s2 = (
+ expand_dims(sigma_s2 / sigma_s, dims) * x
+ - expand_dims(alpha_s2 * phi_12, dims) * model_s
+ + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s)
+ )
+ model_s2 = self.model_fn(x_s2, s2)
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(sigma_t / sigma_s, dims) * x
+ - expand_dims(alpha_t * phi_1, dims) * model_s
+ + (1. / r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s)
+ )
+ elif solver_type == 'taylor':
+ D1_0 = (1. / r1) * (model_s1 - model_s)
+ D1_1 = (1. / r2) * (model_s2 - model_s)
+ D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
+ D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
+ x_t = (
+ expand_dims(sigma_t / sigma_s, dims) * x
+ - expand_dims(alpha_t * phi_1, dims) * model_s
+ + expand_dims(alpha_t * phi_2, dims) * D1
+ - expand_dims(alpha_t * phi_3, dims) * D2
+ )
+ else:
+ phi_11 = torch.expm1(r1 * h)
+ phi_12 = torch.expm1(r2 * h)
+ phi_1 = torch.expm1(h)
+ phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.
+ phi_2 = phi_1 / h - 1.
+ phi_3 = phi_2 / h - 0.5
+
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ if model_s1 is None:
+ x_s1 = (
+ expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
+ - expand_dims(sigma_s1 * phi_11, dims) * model_s
+ )
+ model_s1 = self.model_fn(x_s1, s1)
+ x_s2 = (
+ expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x
+ - expand_dims(sigma_s2 * phi_12, dims) * model_s
+ - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s)
+ )
+ model_s2 = self.model_fn(x_s2, s2)
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+ - expand_dims(sigma_t * phi_1, dims) * model_s
+ - (1. / r2) * expand_dims(sigma_t * phi_2, dims) * (model_s2 - model_s)
+ )
+ elif solver_type == 'taylor':
+ D1_0 = (1. / r1) * (model_s1 - model_s)
+ D1_1 = (1. / r2) * (model_s2 - model_s)
+ D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
+ D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+ - expand_dims(sigma_t * phi_1, dims) * model_s
+ - expand_dims(sigma_t * phi_2, dims) * D1
+ - expand_dims(sigma_t * phi_3, dims) * D2
+ )
+
+ if return_intermediate:
+ return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2}
+ else:
+ return x_t
+
+ def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"):
+ """
+ Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.
+ Args:
+ x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+ The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ if solver_type not in ['dpm_solver', 'taylor']:
+ raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
+ ns = self.noise_schedule
+ dims = x.dim()
+ model_prev_1, model_prev_0 = model_prev_list
+ t_prev_1, t_prev_0 = t_prev_list
+ lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(
+ t_prev_0), ns.marginal_lambda(t)
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
+ alpha_t = torch.exp(log_alpha_t)
+
+ h_0 = lambda_prev_0 - lambda_prev_1
+ h = lambda_t - lambda_prev_0
+ r0 = h_0 / h
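+ # D1_0 approximates h * d(model)/dlambda at lambda_prev_0 via a backward finite difference
+ # over the two previous model evaluations.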
+ D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1)
+ if self.predict_x0:
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
+ - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
+ - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * D1_0
+ )
+ elif solver_type == 'taylor':
+ x_t = (
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
+ - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
+ + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1_0
+ )
+ else:
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
+ - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
+ - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * D1_0
+ )
+ elif solver_type == 'taylor':
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
+ - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
+ - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1_0
+ )
+ return x_t
+
+ def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpm_solver'):
+ """
+ Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.
+ Args:
+ x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+ The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ ns = self.noise_schedule
+ dims = x.dim()
+ model_prev_2, model_prev_1, model_prev_0 = model_prev_list
+ t_prev_2, t_prev_1, t_prev_0 = t_prev_list
+ lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(
+ t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
+ alpha_t = torch.exp(log_alpha_t)
+
+ h_1 = lambda_prev_1 - lambda_prev_2
+ h_0 = lambda_prev_0 - lambda_prev_1
+ h = lambda_t - lambda_prev_0
+ r0, r1 = h_0 / h, h_1 / h
+ D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1)
+ D1_1 = expand_dims(1. / r1, dims) * (model_prev_1 - model_prev_2)
+ D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1)
+ D2 = expand_dims(1. / (r0 + r1), dims) * (D1_0 - D1_1)
+ if self.predict_x0:
+ x_t = (
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
+ - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
+ + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1
+ - expand_dims(alpha_t * ((torch.exp(-h) - 1. + h) / h ** 2 - 0.5), dims) * D2
+ )
+ else:
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
+ - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
+ - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1
+ - expand_dims(sigma_t * ((torch.exp(h) - 1. - h) / h ** 2 - 0.5), dims) * D2
+ )
+ return x_t
+
+ def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None,
+ r2=None):
+ """
+ Singlestep DPM-Solver with the order `order` from time `s` to time `t`.
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
+ return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+ The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
+ r1: A `float`. The hyperparameter of the second-order or third-order solver.
+ r2: A `float`. The hyperparameter of the third-order solver.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ if order == 1:
+ return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate)
+ elif order == 2:
+ return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate,
+ solver_type=solver_type, r1=r1)
+ elif order == 3:
+ return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate,
+ solver_type=solver_type, r1=r1, r2=r2)
+ else:
+ raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
+
+ def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpm_solver'):
+ """
+ Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.
+ Args:
+ x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+ The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ if order == 1:
+ return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1])
+ elif order == 2:
+ return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
+ elif order == 3:
+ return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
+ else:
+ raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
+
+ def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5,
+ solver_type='dpm_solver'):
+ """
+ The adaptive step size solver based on singlestep DPM-Solver.
+ Args:
+ x: A pytorch tensor. The initial value at time `t_T`.
+ order: A `int`. The (higher) order of the solver. We only support order == 2 or 3.
+ t_T: A `float`. The starting time of the sampling (default is T).
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
+ h_init: A `float`. The initial step size (for logSNR).
+ atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1].
+ rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05.
+ theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1].
+ t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the
+ current time and `t_0` is less than `t_err`. The default setting is 1e-5.
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+ The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
+ Returns:
+ x_0: A pytorch tensor. The approximated solution at time `t_0`.
+ [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
+ """
+ ns = self.noise_schedule
+ s = t_T * torch.ones((x.shape[0],)).to(x)
+ lambda_s = ns.marginal_lambda(s)
+ lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x))
+ h = h_init * torch.ones_like(s).to(x)
+ x_prev = x
+ nfe = 0
+ if order == 2:
+ r1 = 0.5
+ lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True)
+ higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
+ solver_type=solver_type,
+ **kwargs)
+ elif order == 3:
+ r1, r2 = 1. / 3., 2. / 3.
+ lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
+ return_intermediate=True,
+ solver_type=solver_type)
+ higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2,
+ solver_type=solver_type,
+ **kwargs)
+ else:
+ raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order))
+ while torch.abs((s - t_0)).mean() > t_err:
+ t = ns.inverse_lambda(lambda_s + h)
+ x_lower, lower_noise_kwargs = lower_update(x, s, t)
+ x_higher = higher_update(x, s, t, **lower_noise_kwargs)
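+ # Embedded error estimate: normalize the difference between the lower- and higher-order
+ # solutions by a mixed absolute/relative tolerance, and accept the step only if the
+ # worst-case normalized error E is at most 1.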
+ delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)))
+ norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True))
+ E = norm_fn((x_higher - x_lower) / delta).max()
+ if torch.all(E <= 1.):
+ x = x_higher
+ s = t
+ x_prev = x_lower
+ lambda_s = ns.marginal_lambda(s)
+ h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s)
+ nfe += order
+ print('adaptive solver nfe', nfe)
+ return x
+
+ def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',
+ method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
+ atol=0.0078, rtol=0.05,
+ ):
+ """
+ Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
+ =====================================================
+ We support the following algorithms for both noise prediction model and data prediction model:
+ - 'singlestep':
+ Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver.
+ We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps).
+ The total number of function evaluations (NFE) == `steps`.
+ Given a fixed NFE == `steps`, the sampling procedure is:
+ - If `order` == 1:
+ - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM).
+ - If `order` == 2:
+ - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling.
+ - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2.
+ - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
+ - If `order` == 3:
+ - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
+ - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
+ - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1.
+ - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2.
+ - 'multistep':
+ Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`.
+ We initialize the first `order` values by lower order multistep solvers.
+ Given a fixed NFE == `steps`, the sampling procedure is:
+ Denote K = steps.
+ - If `order` == 1:
+ - We use K steps of DPM-Solver-1 (i.e. DDIM).
+ - If `order` == 2:
+ - We first use 1 step of DPM-Solver-1, then (K - 1) steps of multistep DPM-Solver-2.
+ - If `order` == 3:
+ - We first use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) steps of multistep DPM-Solver-3.
+ - 'singlestep_fixed':
+ Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3).
+ We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE.
+ - 'adaptive':
+ Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper).
+ We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`.
+ You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computation costs
+ (NFE) and the sample quality.
+ - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2.
+ - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3.
+ =====================================================
+ Some advice for choosing the algorithm:
+ - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs:
+ Use singlestep DPM-Solver ("DPM-Solver-fast" in the paper) with `order = 3`.
+ e.g.
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=False)
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
+ skip_type='time_uniform', method='singlestep')
+ - For **guided sampling with large guidance scale** by DPMs:
+ Use multistep DPM-Solver with `predict_x0 = True` and `order = 2`.
+ e.g.
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True)
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2,
+ skip_type='time_uniform', method='multistep')
+ We support three types of `skip_type`:
+ - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolution images**.
+ - 'time_uniform': uniform time for the time steps. **Recommended for high-resolution images**.
+ - 'time_quadratic': quadratic time for the time steps.
+ =====================================================
+ Args:
+ x: A pytorch tensor. The initial value at time `t_start`
+ e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution.
+ steps: A `int`. The total number of function evaluations (NFE).
+ t_start: A `float`. The starting time of the sampling.
+ If `t_start` is None, we use self.noise_schedule.T (default is 1.0).
+ t_end: A `float`. The ending time of the sampling.
+ If `t_end` is None, we use 1. / self.noise_schedule.total_N.
+ e.g. if total_N == 1000, we have `t_end` == 1e-3.
+ For discrete-time DPMs:
+ - We recommend `t_end` == 1. / self.noise_schedule.total_N.
+ For continuous-time DPMs:
+ - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15.
+ order: A `int`. The order of DPM-Solver.
+ skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'.
+ method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
+ denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step.
+ Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1).
+ This trick was first proposed by DDPM (https://arxiv.org/abs/2006.11239) and
+ score_sde (https://arxiv.org/abs/2011.13456). It can improve the FID when
+ sampling diffusion models via diffusion SDEs on low-resolution images
+ (such as CIFAR-10). However, we observed that it does not matter for
+ high-resolution images. As it needs an additional NFE, we do not recommend
+ it for high-resolution images.
+ lower_order_final: A `bool`. Whether to use lower order solvers at the final steps.
+ Only valid for `method=multistep` and `steps < 15`. We empirically find that
+ this trick is a key to stabilizing the sampling by DPM-Solver with very few steps
+ (especially for steps <= 10). So we recommend setting it to `True`.
+ solver_type: A `str`. The taylor expansion type for the solver. `dpm_solver` or `taylor`. We recommend `dpm_solver`.
+ atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
+ rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
+ Returns:
+ x_end: A pytorch tensor. The approximated solution at time `t_end`.
+ """
+ t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
+ t_T = self.noise_schedule.T if t_start is None else t_start
+ device = x.device
+ if method == 'adaptive':
+ with torch.no_grad():
+ x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol,
+ solver_type=solver_type)
+ elif method == 'multistep':
+ assert steps >= order
+ timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
+ assert timesteps.shape[0] - 1 == steps
+ with torch.no_grad():
+ vec_t = timesteps[0].expand((x.shape[0]))
+ model_prev_list = [self.model_fn(x, vec_t)]
+ t_prev_list = [vec_t]
+ # Init the first `order` values by lower order multistep DPM-Solver.
+ for init_order in tqdm(range(1, order), desc="DPM init order"):
+ vec_t = timesteps[init_order].expand(x.shape[0])
+ x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order,
+ solver_type=solver_type)
+ model_prev_list.append(self.model_fn(x, vec_t))
+ t_prev_list.append(vec_t)
+ # Compute the remaining values by `order`-th order multistep DPM-Solver.
+ for step in tqdm(range(order, steps + 1), desc="DPM multistep"):
+ vec_t = timesteps[step].expand(x.shape[0])
+ if lower_order_final and steps < 15:
+ step_order = min(order, steps + 1 - step)
+ else:
+ step_order = order
+ x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order,
+ solver_type=solver_type)
+ for i in range(order - 1):
+ t_prev_list[i] = t_prev_list[i + 1]
+ model_prev_list[i] = model_prev_list[i + 1]
+ t_prev_list[-1] = vec_t
+ # We do not need to evaluate the final model value.
+ if step < steps:
+ model_prev_list[-1] = self.model_fn(x, vec_t)
+ elif method in ['singlestep', 'singlestep_fixed']:
+ if method == 'singlestep':
+ timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order,
+ skip_type=skip_type,
+ t_T=t_T, t_0=t_0,
+ device=device)
+ elif method == 'singlestep_fixed':
+ K = steps // order
+ orders = [order, ] * K
+ timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
+ for i, order in enumerate(orders):
+ t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1]
+ timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(),
+ N=order, device=device)
+ lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner)
+ vec_s, vec_t = t_T_inner.tile(x.shape[0]), t_0_inner.tile(x.shape[0])
+ h = lambda_inner[-1] - lambda_inner[0]
+ r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
+ r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
+ x = self.singlestep_dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2)
+ if denoise_to_zero:
+ x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
+ return x
+
+
+#############################################################
+# other utility functions
+#############################################################
+
+def interpolate_fn(x, xp, yp):
+ """
+ A piecewise linear function y = f(x), using xp and yp as keypoints.
+ We implement f(x) in a differentiable way (i.e. applicable for autograd).
+ The function f(x) is well-defined for all x. (For x beyond the bounds of xp, we use the outermost points of xp to extend the linear function.)
+ Args:
+ x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
+ xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
+ yp: PyTorch tensor with shape [C, K].
+ Returns:
+ The function values f(x), with shape [N, C].
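+ Example (an illustrative doctest):
+ >>> xp = torch.tensor([[0., 1., 2.]])
+ >>> yp = torch.tensor([[0., 10., 20.]])
+ >>> interpolate_fn(torch.tensor([[0.5]]), xp, yp)
+ tensor([[5.]])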
+ """
+ N, K = x.shape[0], xp.shape[1]
+ all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
+ sorted_all_x, x_indices = torch.sort(all_x, dim=2)
+ x_idx = torch.argmin(x_indices, dim=2)
+ cand_start_idx = x_idx - 1
+ start_idx = torch.where(
+ torch.eq(x_idx, 0),
+ torch.tensor(1, device=x.device),
+ torch.where(
+ torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
+ ),
+ )
+ end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
+ start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
+ end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
+ start_idx2 = torch.where(
+ torch.eq(x_idx, 0),
+ torch.tensor(0, device=x.device),
+ torch.where(
+ torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
+ ),
+ )
+ y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
+ start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
+ end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
+ cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
+ return cand
+
+
+def expand_dims(v, dims):
+ """
+ Expand the tensor `v` to have `dims` dimensions.
+ Args:
+ `v`: a PyTorch tensor with shape [N].
+ `dims`: an `int`.
+ Returns:
+ a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
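+ For example (illustrative), `expand_dims(v, 4)` for `v` of shape [N] returns a view of shape [N, 1, 1, 1].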
+ """
+ return v[(...,) + (None,) * (dims - 1)]
\ No newline at end of file
diff --git a/ldm/models/diffusion/dpm_solver/sampler.py b/ldm/models/diffusion/dpm_solver/sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d137b8cf36718c1c58faa09f9dd919e5fb2977b
--- /dev/null
+++ b/ldm/models/diffusion/dpm_solver/sampler.py
@@ -0,0 +1,87 @@
+"""SAMPLING ONLY."""
+import torch
+
+from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
+
+
+MODEL_TYPES = {
+ "eps": "noise",
+ "v": "v"
+}
+
+
+class DPMSolverSampler(object):
+ def __init__(self, model, **kwargs):
+ super().__init__()
+ self.model = model
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
+ self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
+
+ def register_buffer(self, name, attr):
+ if type(attr) == torch.Tensor:
+ if attr.device != torch.device("cuda"):
+ attr = attr.to(torch.device("cuda"))
+ setattr(self, name, attr)
+
+ @torch.no_grad()
+ def sample(self,
+ S,
+ batch_size,
+ shape,
+ conditioning=None,
+ callback=None,
+ normals_sequence=None,
+ img_callback=None,
+ quantize_x0=False,
+ eta=0.,
+ mask=None,
+ x0=None,
+ temperature=1.,
+ noise_dropout=0.,
+ score_corrector=None,
+ corrector_kwargs=None,
+ verbose=True,
+ x_T=None,
+ log_every_t=100,
+ unconditional_guidance_scale=1.,
+ unconditional_conditioning=None,
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+ **kwargs
+ ):
+ if conditioning is not None:
+ if isinstance(conditioning, dict):
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
+ if cbs != batch_size:
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+ else:
+ if conditioning.shape[0] != batch_size:
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+ # sampling
+ C, H, W = shape
+ size = (batch_size, C, H, W)
+
+ print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}')
+
+ device = self.model.betas.device
+ if x_T is None:
+ img = torch.randn(size, device=device)
+ else:
+ img = x_T
+
+ ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
+
+ model_fn = model_wrapper(
+ lambda x, t, c: self.model.apply_model(x, t, c),
+ ns,
+ model_type=MODEL_TYPES[self.model.parameterization],
+ guidance_type="classifier-free",
+ condition=conditioning,
+ unconditional_condition=unconditional_conditioning,
+ guidance_scale=unconditional_guidance_scale,
+ )
+
+ dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
+ x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True)
+
+ return x.to(device), None
\ No newline at end of file
diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py
new file mode 100644
index 0000000000000000000000000000000000000000..7002a365d27168ced0a04e9a4d83e088f8284eae
--- /dev/null
+++ b/ldm/models/diffusion/plms.py
@@ -0,0 +1,244 @@
+"""SAMPLING ONLY."""
+
+import torch
+import numpy as np
+from tqdm import tqdm
+from functools import partial
+
+from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
+from ldm.models.diffusion.sampling_util import norm_thresholding
+
+
+class PLMSSampler(object):
+ def __init__(self, model, schedule="linear", **kwargs):
+ super().__init__()
+ self.model = model
+ self.ddpm_num_timesteps = model.num_timesteps
+ self.schedule = schedule
+
+ def register_buffer(self, name, attr):
+ if type(attr) == torch.Tensor:
+ if attr.device != torch.device("cuda"):
+ attr = attr.to(torch.device("cuda"))
+ setattr(self, name, attr)
+
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
+ if ddim_eta != 0:
+ raise ValueError('ddim_eta must be 0 for PLMS')
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
+ alphas_cumprod = self.model.alphas_cumprod
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
+
+ self.register_buffer('betas', to_torch(self.model.betas))
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
+
+ # ddim sampling parameters
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
+ ddim_timesteps=self.ddim_timesteps,
+ eta=ddim_eta,verbose=verbose)
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
+ self.register_buffer('ddim_alphas', ddim_alphas)
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
+
+ @torch.no_grad()
+ def sample(self,
+ S,
+ batch_size,
+ shape,
+ conditioning=None,
+ callback=None,
+ normals_sequence=None,
+ img_callback=None,
+ quantize_x0=False,
+ eta=0.,
+ mask=None,
+ x0=None,
+ temperature=1.,
+ noise_dropout=0.,
+ score_corrector=None,
+ corrector_kwargs=None,
+ verbose=True,
+ x_T=None,
+ log_every_t=100,
+ unconditional_guidance_scale=1.,
+ unconditional_conditioning=None,
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+ dynamic_threshold=None,
+ **kwargs
+ ):
+ if conditioning is not None:
+ if isinstance(conditioning, dict):
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
+ if cbs != batch_size:
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+ else:
+ if conditioning.shape[0] != batch_size:
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+ # sampling
+ C, H, W = shape
+ size = (batch_size, C, H, W)
+ print(f'Data shape for PLMS sampling is {size}')
+
+ samples, intermediates = self.plms_sampling(conditioning, size,
+ callback=callback,
+ img_callback=img_callback,
+ quantize_denoised=quantize_x0,
+ mask=mask, x0=x0,
+ ddim_use_original_steps=False,
+ noise_dropout=noise_dropout,
+ temperature=temperature,
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ x_T=x_T,
+ log_every_t=log_every_t,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ dynamic_threshold=dynamic_threshold,
+ )
+ return samples, intermediates
+
+ @torch.no_grad()
+ def plms_sampling(self, cond, shape,
+ x_T=None, ddim_use_original_steps=False,
+ callback=None, timesteps=None, quantize_denoised=False,
+ mask=None, x0=None, img_callback=None, log_every_t=100,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None,
+ dynamic_threshold=None):
+ device = self.model.betas.device
+ b = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=device)
+ else:
+ img = x_T
+
+ if timesteps is None:
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
+ elif timesteps is not None and not ddim_use_original_steps:
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
+ timesteps = self.ddim_timesteps[:subset_end]
+
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
+ time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
+ print(f"Running PLMS Sampling with {total_steps} timesteps")
+
+ iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
+ old_eps = []
+
+ for i, step in enumerate(iterator):
+ index = total_steps - i - 1
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
+ ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
+
+ if mask is not None:
+ assert x0 is not None
+ img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
+ img = img_orig * mask + (1. - mask) * img
+
+ outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
+ quantize_denoised=quantize_denoised, temperature=temperature,
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ old_eps=old_eps, t_next=ts_next,
+ dynamic_threshold=dynamic_threshold)
+ img, pred_x0, e_t = outs
+ old_eps.append(e_t)
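+ # keep only the three most recent eps predictions; together with the current
+ # e_t they feed the 4th-order multistep formula in p_sample_plms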
+ if len(old_eps) >= 4:
+ old_eps.pop(0)
+ if callback: callback(i)
+ if img_callback: img_callback(pred_x0, i)
+
+ if index % log_every_t == 0 or index == total_steps - 1:
+ intermediates['x_inter'].append(img)
+ intermediates['pred_x0'].append(pred_x0)
+
+ return img, intermediates
+
+ @torch.no_grad()
+ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,
+ dynamic_threshold=None):
+ b, *_, device = *x.shape, x.device
+
+ def get_model_output(x, t):
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+ e_t = self.model.apply_model(x, t, c)
+ else:
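+ # classifier-free guidance: run the unconditional and conditional passes in
+ # one batched forward, then extrapolate from the unconditional prediction
+ # toward the conditional one by the guidance scale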
+ x_in = torch.cat([x] * 2)
+ t_in = torch.cat([t] * 2)
+ c_in = torch.cat([unconditional_conditioning, c])
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+
+ if score_corrector is not None:
+ assert self.model.parameterization == "eps"
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+
+ return e_t
+
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+
+ def get_x_prev_and_pred_x0(e_t, index):
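+ # deterministic DDIM-style update (sigma is 0 for PLMS):
+ # pred_x0 = (x - sqrt(1 - a_t) * e_t) / sqrt(a_t)
+ # x_prev = sqrt(a_prev) * pred_x0 + sqrt(1 - a_prev - sigma_t^2) * e_t + sigma_t * noise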
+ # select parameters corresponding to the currently considered timestep
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
+
+ # current prediction for x_0
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+ if quantize_denoised:
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+ if dynamic_threshold is not None:
+ pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
+ # direction pointing to x_t
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+ return x_prev, pred_x0
+
+ e_t = get_model_output(x, t)
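+ # pseudo linear multistep (PLMS): combine the current eps prediction with
+ # previous ones using Adams-Bashforth coefficients; with no history yet, fall
+ # back to a 2nd-order pseudo improved Euler (Heun) step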
+ if len(old_eps) == 0:
+ # Pseudo Improved Euler (2nd order)
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
+ e_t_next = get_model_output(x_prev, t_next)
+ e_t_prime = (e_t + e_t_next) / 2
+ elif len(old_eps) == 1:
+ # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (3 * e_t - old_eps[-1]) / 2
+ elif len(old_eps) == 2:
+ # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
+ elif len(old_eps) >= 3:
+ # 4th order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
+
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
+
+ return x_prev, pred_x0, e_t
diff --git a/ldm/models/diffusion/sampling_util.py b/ldm/models/diffusion/sampling_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..7eff02be6d7c54d43ee6680636ac0698dd3b3f33
--- /dev/null
+++ b/ldm/models/diffusion/sampling_util.py
@@ -0,0 +1,22 @@
+import torch
+import numpy as np
+
+
+def append_dims(x, target_dims):
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions.
+ From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py"""
+ dims_to_append = target_dims - x.ndim
+ if dims_to_append < 0:
+ raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
+ return x[(...,) + (None,) * dims_to_append]
+
+
+def norm_thresholding(x0, value):
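+ # rescale each sample so its RMS norm does not exceed `value`; samples whose
+ # norm is already below the threshold are left unchanged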
+ s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim)
+ return x0 * (value / s)
+
+
+def spatial_norm_thresholding(x0, value):
+ # b c h w
+ s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value)
+ return x0 * (value / s)
\ No newline at end of file
diff --git a/ldm/modules/__pycache__/attention.cpython-38.pyc b/ldm/modules/__pycache__/attention.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6da2545015189f2d77c57041caa7fe84c5824666
Binary files /dev/null and b/ldm/modules/__pycache__/attention.cpython-38.pyc differ
diff --git a/ldm/modules/__pycache__/ema.cpython-38.pyc b/ldm/modules/__pycache__/ema.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5b581ba6d0acce1423cac6b3b0171b24cf7478b7
Binary files /dev/null and b/ldm/modules/__pycache__/ema.cpython-38.pyc differ
diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..509cd873768f0dd75a75ab3fcdd652822b12b59f
--- /dev/null
+++ b/ldm/modules/attention.py
@@ -0,0 +1,341 @@
+from inspect import isfunction
+import math
+import torch
+import torch.nn.functional as F
+from torch import nn, einsum
+from einops import rearrange, repeat
+from typing import Optional, Any
+
+from ldm.modules.diffusionmodules.util import checkpoint
+
+
+try:
+ import xformers
+ import xformers.ops
+ XFORMERS_IS_AVAILBLE = True
+except ImportError:
+ XFORMERS_IS_AVAILBLE = False
+
+# CrossAttn precision handling
+import os
+_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
+
+def exists(val):
+ return val is not None
+
+
+def uniq(arr):
+ return {el: True for el in arr}.keys()
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+def max_neg_value(t):
+ return -torch.finfo(t.dtype).max
+
+
+def init_(tensor):
+ dim = tensor.shape[-1]
+ std = 1 / math.sqrt(dim)
+ tensor.uniform_(-std, std)
+ return tensor
+
+
+# feedforward
+class GEGLU(nn.Module):
+ def __init__(self, dim_in, dim_out):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2)
+
+ def forward(self, x):
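+ # split the doubled projection into a value half and a gate half, then gate
+ # the values with GELU (the GEGLU variant of a gated linear unit)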
+ x, gate = self.proj(x).chunk(2, dim=-1)
+ return x * F.gelu(gate)
+
+
+class FeedForward(nn.Module):
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = default(dim_out, dim)
+ project_in = nn.Sequential(
+ nn.Linear(dim, inner_dim),
+ nn.GELU()
+ ) if not glu else GEGLU(dim, inner_dim)
+
+ self.net = nn.Sequential(
+ project_in,
+ nn.Dropout(dropout),
+ nn.Linear(inner_dim, dim_out)
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def Normalize(in_channels):
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class SpatialSelfAttention(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.k = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.v = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b,c,h,w = q.shape
+ q = rearrange(q, 'b c h w -> b (h w) c')
+ k = rearrange(k, 'b c h w -> b c (h w)')
+ w_ = torch.einsum('bij,bjk->bik', q, k)
+
+ w_ = w_ * (int(c)**(-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = rearrange(v, 'b c h w -> b c (h w)')
+ w_ = rearrange(w_, 'b i j -> b j i')
+ h_ = torch.einsum('bij,bjk->bik', v, w_)
+ h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
+ h_ = self.proj_out(h_)
+
+ return x+h_
+
+
+class CrossAttention(nn.Module):
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
+ super().__init__()
+ inner_dim = dim_head * heads
+ context_dim = default(context_dim, query_dim)
+
+ self.scale = dim_head ** -0.5
+ self.heads = heads
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+ self.to_out = nn.Sequential(
+ nn.Linear(inner_dim, query_dim),
+ nn.Dropout(dropout)
+ )
+
+ def forward(self, x, context=None, mask=None):
+ h = self.heads
+
+ q = self.to_q(x)
+ context = default(context, x)
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+ # force cast to fp32 to avoid overflowing
+ if _ATTN_PRECISION =="fp32":
+ with torch.autocast(enabled=False, device_type = 'cuda'):
+ q, k = q.float(), k.float()
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+ else:
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+
+ del q, k
+
+ if exists(mask):
+ mask = rearrange(mask, 'b ... -> b (...)')
+ max_neg_value = -torch.finfo(sim.dtype).max
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
+ sim.masked_fill_(~mask, max_neg_value)
+
+ # attention, what we cannot get enough of
+ sim = sim.softmax(dim=-1)
+
+ out = einsum('b i j, b j d -> b i d', sim, v)
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+ return self.to_out(out)
+
+
+class MemoryEfficientCrossAttention(nn.Module):
+ # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
+ super().__init__()
+ print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
+ f"{heads} heads.")
+ inner_dim = dim_head * heads
+ context_dim = default(context_dim, query_dim)
+
+ self.heads = heads
+ self.dim_head = dim_head
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+ self.attention_op: Optional[Any] = None
+
+ def forward(self, x, context=None, mask=None):
+ q = self.to_q(x)
+ context = default(context, x)
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ b, _, _ = q.shape
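+ # fold the heads into the batch dimension, (b, n, h*d) -> (b*h, n, d), which
+ # is the layout expected by xformers.ops.memory_efficient_attention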
+ q, k, v = map(
+ lambda t: t.unsqueeze(3)
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
+ .permute(0, 2, 1, 3)
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
+ .contiguous(),
+ (q, k, v),
+ )
+
+ # actually compute the attention, what we cannot get enough of
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
+
+ if exists(mask):
+ raise NotImplementedError
+ out = (
+ out.unsqueeze(0)
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
+ .permute(0, 2, 1, 3)
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
+ )
+ return self.to_out(out)
+
+
+class BasicTransformerBlock(nn.Module):
+ ATTENTION_MODES = {
+ "softmax": CrossAttention, # vanilla attention
+ "softmax-xformers": MemoryEfficientCrossAttention
+ }
+ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
+ disable_self_attn=False):
+ super().__init__()
+ attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
+ assert attn_mode in self.ATTENTION_MODES
+ attn_cls = self.ATTENTION_MODES[attn_mode]
+ self.disable_self_attn = disable_self_attn
+ self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
+ context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+ self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
+ heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+ self.norm3 = nn.LayerNorm(dim)
+ self.checkpoint = checkpoint
+
+ def forward(self, x, context=None):
+ return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
+
+ def _forward(self, x, context=None):
+ x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
+ x = self.attn2(self.norm2(x), context=context) + x
+ x = self.ff(self.norm3(x)) + x
+ return x
+
+
+class SpatialTransformer(nn.Module):
+ """
+ Transformer block for image-like data.
+ First, project the input (aka embedding)
+ and reshape to b, t, d.
+ Then apply standard transformer action.
+ Finally, reshape to image
+ NEW: use_linear for more efficiency instead of the 1x1 convs
+ """
+ def __init__(self, in_channels, n_heads, d_head,
+ depth=1, dropout=0., context_dim=None,
+ disable_self_attn=False, use_linear=False,
+ use_checkpoint=True):
+ super().__init__()
+ if exists(context_dim) and not isinstance(context_dim, list):
+ context_dim = [context_dim]
+ self.in_channels = in_channels
+ inner_dim = n_heads * d_head
+ self.norm = Normalize(in_channels)
+ if not use_linear:
+ self.proj_in = nn.Conv2d(in_channels,
+ inner_dim,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ else:
+ self.proj_in = nn.Linear(in_channels, inner_dim)
+
+ self.transformer_blocks = nn.ModuleList(
+ [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
+ disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
+ for d in range(depth)]
+ )
+ if not use_linear:
+ self.proj_out = zero_module(nn.Conv2d(inner_dim,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0))
+ else:
+ self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
+ self.use_linear = use_linear
+
+ def forward(self, x, context=None):
+ # note: if no context is given, cross-attention defaults to self-attention
+ if not isinstance(context, list):
+ context = [context]
+ b, c, h, w = x.shape
+ x_in = x
+ x = self.norm(x)
+ if not self.use_linear:
+ x = self.proj_in(x)
+ x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
+ if self.use_linear:
+ x = self.proj_in(x)
+ for i, block in enumerate(self.transformer_blocks):
+ x = block(x, context=context[i])
+ if self.use_linear:
+ x = self.proj_out(x)
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
+ if not self.use_linear:
+ x = self.proj_out(x)
+ return x + x_in
+
diff --git a/ldm/modules/diffusionmodules/__init__.py b/ldm/modules/diffusionmodules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc b/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d11401223222187d9c370c479e3d9b1272dd0258
Binary files /dev/null and b/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc differ
diff --git a/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc b/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d888102240952a6a380097c1c9f327624fb73cca
Binary files /dev/null and b/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc differ
diff --git a/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc b/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..46774977022f56ac1bc2222580d2e8729c6dfef2
Binary files /dev/null and b/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc differ
diff --git a/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc b/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0abac2732293e18274a596c96e0b7caa4b8eac9c
Binary files /dev/null and b/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc differ
diff --git a/ldm/modules/diffusionmodules/model.py b/ldm/modules/diffusionmodules/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..b089eebbe1676d8249005bb9def002ff5180715b
--- /dev/null
+++ b/ldm/modules/diffusionmodules/model.py
@@ -0,0 +1,852 @@
+# pytorch_diffusion + derived encoder decoder
+import math
+import torch
+import torch.nn as nn
+import numpy as np
+from einops import rearrange
+from typing import Optional, Any
+
+from ldm.modules.attention import MemoryEfficientCrossAttention
+
+try:
+ import xformers
+ import xformers.ops
+ XFORMERS_IS_AVAILBLE = True
+except ImportError:
+ XFORMERS_IS_AVAILBLE = False
+ print("No module 'xformers'. Proceeding without it.")
+
+
+def get_timestep_embedding(timesteps, embedding_dim):
+ """
+ Build sinusoidal timestep embeddings (originally from Fairseq).
+ This matches the implementation in Denoising Diffusion Probabilistic Models
+ and in tensor2tensor, but differs slightly from the description in
+ Section 3.5 of "Attention Is All You Need".
+ """
+ assert len(timesteps.shape) == 1
+
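+ # half of the channels carry sines and half cosines (zero-padded if
+ # embedding_dim is odd), with frequencies spaced geometrically between 1 and
+ # 1/10000 as in the transformer positional encoding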
+ half_dim = embedding_dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
+ emb = emb.to(device=timesteps.device)
+ emb = timesteps.float()[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0,1,0,0))
+ return emb
+
+
+def nonlinearity(x):
+ # swish
+ return x*torch.sigmoid(x)
+
+
+def Normalize(in_channels, num_groups=32):
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class Upsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ self.conv = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+ if self.with_conv:
+ x = self.conv(x)
+ return x
+
+
+class Downsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=3,
+ stride=2,
+ padding=0)
+
+ def forward(self, x):
+ if self.with_conv:
+ pad = (0,1,0,1)
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ else:
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
+ return x
+
+
+class ResnetBlock(nn.Module):
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
+ dropout, temb_channels=512):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = Normalize(in_channels)
+ self.conv1 = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ if temb_channels > 0:
+ self.temb_proj = torch.nn.Linear(temb_channels,
+ out_channels)
+ self.norm2 = Normalize(out_channels)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = torch.nn.Conv2d(out_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ else:
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x, temb):
+ h = x
+ h = self.norm1(h)
+ h = nonlinearity(h)
+ h = self.conv1(h)
+
+ if temb is not None:
+ h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
+
+ h = self.norm2(h)
+ h = nonlinearity(h)
+ h = self.dropout(h)
+ h = self.conv2(h)
+
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x)
+ else:
+ x = self.nin_shortcut(x)
+
+ return x+h
+
+
+class AttnBlock(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.k = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.v = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b,c,h,w = q.shape
+ q = q.reshape(b,c,h*w)
+ q = q.permute(0,2,1) # b,hw,c
+ k = k.reshape(b,c,h*w) # b,c,hw
+ w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+ w_ = w_ * (int(c)**(-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = v.reshape(b,c,h*w)
+ w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
+ h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+ h_ = h_.reshape(b,c,h,w)
+
+ h_ = self.proj_out(h_)
+
+ return x+h_
+
+class MemoryEfficientAttnBlock(nn.Module):
+ """
+ Uses xformers efficient implementation,
+ see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
+ Note: this is a single-head self-attention operation
+ """
+ #
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.k = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.v = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.attention_op: Optional[Any] = None
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ B, C, H, W = q.shape
+ q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))
+
+ q, k, v = map(
+ lambda t: t.unsqueeze(3)
+ .reshape(B, t.shape[1], 1, C)
+ .permute(0, 2, 1, 3)
+ .reshape(B * 1, t.shape[1], C)
+ .contiguous(),
+ (q, k, v),
+ )
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
+
+ out = (
+ out.unsqueeze(0)
+ .reshape(B, 1, out.shape[1], C)
+ .permute(0, 2, 1, 3)
+ .reshape(B, out.shape[1], C)
+ )
+ out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
+ out = self.proj_out(out)
+ return x+out
+
+
+class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
+ def forward(self, x, context=None, mask=None):
+ b, c, h, w = x.shape
+ x_in = x
+ x = rearrange(x, 'b c h w -> b (h w) c')
+ out = super().forward(x, context=context, mask=mask)
+ out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w, c=c)
+ return x_in + out
+
+
+def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
+ assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
+ if XFORMERS_IS_AVAILBLE and attn_type == "vanilla":
+ attn_type = "vanilla-xformers"
+ print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
+ if attn_type == "vanilla":
+ assert attn_kwargs is None
+ return AttnBlock(in_channels)
+ elif attn_type == "vanilla-xformers":
+ print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
+ return MemoryEfficientAttnBlock(in_channels)
+ elif attn_type == "memory-efficient-cross-attn":
+ attn_kwargs["query_dim"] = in_channels
+ return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
+ elif attn_type == "none":
+ return nn.Identity(in_channels)
+ else:
+ raise NotImplementedError()
+
+
+class Model(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
+ super().__init__()
+ if use_linear_attn: attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = self.ch*4
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ self.use_timestep = use_timestep
+ if self.use_timestep:
+ # timestep embedding
+ self.temb = nn.Module()
+ self.temb.dense = nn.ModuleList([
+ torch.nn.Linear(self.ch,
+ self.temb_ch),
+ torch.nn.Linear(self.temb_ch,
+ self.temb_ch),
+ ])
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(in_channels,
+ self.ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ curr_res = resolution
+ in_ch_mult = (1,)+tuple(ch_mult)
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch*in_ch_mult[i_level]
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions-1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch*ch_mult[i_level]
+ skip_in = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks+1):
+ if i_block == self.num_res_blocks:
+ skip_in = ch*in_ch_mult[i_level]
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x, t=None, context=None):
+ #assert x.shape[2] == x.shape[3] == self.resolution
+ if context is not None:
+ # assume aligned context, cat along channel axis
+ x = torch.cat((x, context), dim=1)
+ if self.use_timestep:
+ # timestep embedding
+ assert t is not None
+ temb = get_timestep_embedding(t, self.ch)
+ temb = self.temb.dense[0](temb)
+ temb = nonlinearity(temb)
+ temb = self.temb.dense[1](temb)
+ else:
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions-1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks+1):
+ h = self.up[i_level].block[i_block](
+ torch.cat([h, hs.pop()], dim=1), temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+ def get_last_layer(self):
+ return self.conv_out.weight
+
+
+class Encoder(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
+ **ignore_kwargs):
+ super().__init__()
+ if use_linear_attn: attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(in_channels,
+ self.ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ curr_res = resolution
+ in_ch_mult = (1,)+tuple(ch_mult)
+ self.in_ch_mult = in_ch_mult
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch*in_ch_mult[i_level]
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions-1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ 2*z_channels if double_z else z_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ # timestep embedding
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions-1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class Decoder(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
+ attn_type="vanilla", **ignorekwargs):
+ super().__init__()
+ if use_linear_attn: attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.give_pre_end = give_pre_end
+ self.tanh_out = tanh_out
+
+ # compute in_ch_mult, block_in and curr_res at lowest res
+ in_ch_mult = (1,)+tuple(ch_mult)
+ block_in = ch*ch_mult[self.num_resolutions-1]
+ curr_res = resolution // 2**(self.num_resolutions-1)
+ self.z_shape = (1,z_channels,curr_res,curr_res)
+ print("Working with z of shape {} = {} dimensions.".format(
+ self.z_shape, np.prod(self.z_shape)))
+
+ # z to block_in
+ self.conv_in = torch.nn.Conv2d(z_channels,
+ block_in,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks+1):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, z):
+ #assert z.shape[1:] == self.z_shape[1:]
+ self.last_z_shape = z.shape
+
+ # timestep embedding
+ temb = None
+
+ # z to block_in
+ h = self.conv_in(z)
+
+ # middle
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks+1):
+ h = self.up[i_level].block[i_block](h, temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ if self.give_pre_end:
+ return h
+
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ if self.tanh_out:
+ h = torch.tanh(h)
+ return h
+
+
+class SimpleDecoder(nn.Module):
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
+ super().__init__()
+ self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
+ ResnetBlock(in_channels=in_channels,
+ out_channels=2 * in_channels,
+ temb_channels=0, dropout=0.0),
+ ResnetBlock(in_channels=2 * in_channels,
+ out_channels=4 * in_channels,
+ temb_channels=0, dropout=0.0),
+ ResnetBlock(in_channels=4 * in_channels,
+ out_channels=2 * in_channels,
+ temb_channels=0, dropout=0.0),
+ nn.Conv2d(2*in_channels, in_channels, 1),
+ Upsample(in_channels, with_conv=True)])
+ # end
+ self.norm_out = Normalize(in_channels)
+ self.conv_out = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ for i, layer in enumerate(self.model):
+ if i in [1,2,3]:
+ x = layer(x, None)
+ else:
+ x = layer(x)
+
+ h = self.norm_out(x)
+ h = nonlinearity(h)
+ x = self.conv_out(h)
+ return x
+
+
+class UpsampleDecoder(nn.Module):
+ def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
+ ch_mult=(2,2), dropout=0.0):
+ super().__init__()
+ # upsampling
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ block_in = in_channels
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
+ self.res_blocks = nn.ModuleList()
+ self.upsample_blocks = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ res_block = []
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ res_block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ self.res_blocks.append(nn.ModuleList(res_block))
+ if i_level != self.num_resolutions - 1:
+ self.upsample_blocks.append(Upsample(block_in, True))
+ curr_res = curr_res * 2
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ # upsampling
+ h = x
+ for k, i_level in enumerate(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.res_blocks[i_level][i_block](h, None)
+ if i_level != self.num_resolutions - 1:
+ h = self.upsample_blocks[k](h)
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class LatentRescaler(nn.Module):
+ def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
+ super().__init__()
+ # residual block, interpolate, residual block
+ self.factor = factor
+ self.conv_in = nn.Conv2d(in_channels,
+ mid_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
+ out_channels=mid_channels,
+ temb_channels=0,
+ dropout=0.0) for _ in range(depth)])
+ self.attn = AttnBlock(mid_channels)
+ self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
+ out_channels=mid_channels,
+ temb_channels=0,
+ dropout=0.0) for _ in range(depth)])
+
+ self.conv_out = nn.Conv2d(mid_channels,
+ out_channels,
+ kernel_size=1,
+ )
+
+ def forward(self, x):
+ x = self.conv_in(x)
+ for block in self.res_block1:
+ x = block(x, None)
+ x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor))))
+ x = self.attn(x)
+ for block in self.res_block2:
+ x = block(x, None)
+ x = self.conv_out(x)
+ return x
+
+
+class MergedRescaleEncoder(nn.Module):
+ def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True,
+ ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
+ super().__init__()
+ intermediate_chn = ch * ch_mult[-1]
+ self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult,
+ z_channels=intermediate_chn, double_z=False, resolution=resolution,
+ attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,
+ out_ch=None)
+ self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn,
+ mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth)
+
+ def forward(self, x):
+ x = self.encoder(x)
+ x = self.rescaler(x)
+ return x
+
+
+class MergedRescaleDecoder(nn.Module):
+ def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
+ dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
+ super().__init__()
+ tmp_chn = z_channels*ch_mult[-1]
+ self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout,
+ resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks,
+ ch_mult=ch_mult, resolution=resolution, ch=ch)
+ self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn,
+ out_channels=tmp_chn, depth=rescale_module_depth)
+
+ def forward(self, x):
+ x = self.rescaler(x)
+ x = self.decoder(x)
+ return x
+
+
+class Upsampler(nn.Module):
+ def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
+ super().__init__()
+ assert out_size >= in_size
+ num_blocks = int(np.log2(out_size//in_size))+1
+ factor_up = 1.+ (out_size % in_size)
+ print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
+ self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
+ out_channels=in_channels)
+ self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
+ attn_resolutions=[], in_channels=None, ch=in_channels,
+ ch_mult=[ch_mult for _ in range(num_blocks)])
+
+ def forward(self, x):
+ x = self.rescaler(x)
+ x = self.decoder(x)
+ return x
+
+
+class Resize(nn.Module):
+ def __init__(self, in_channels=None, learned=False, mode="bilinear"):
+ super().__init__()
+ self.with_conv = learned
+ self.mode = mode
+ if self.with_conv:
+ print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode")
+ raise NotImplementedError()
+ assert in_channels is not None
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=4,
+ stride=2,
+ padding=1)
+
+ def forward(self, x, scale_factor=1.0):
+ if scale_factor==1.0:
+ return x
+ else:
+ x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
+ return x
diff --git a/ldm/modules/diffusionmodules/openaimodel.py b/ldm/modules/diffusionmodules/openaimodel.py
new file mode 100644
index 0000000000000000000000000000000000000000..7df6b5abfe8eff07f0c8e8703ba8aee90d45984b
--- /dev/null
+++ b/ldm/modules/diffusionmodules/openaimodel.py
@@ -0,0 +1,786 @@
+from abc import abstractmethod
+import math
+
+import numpy as np
+import torch as th
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ldm.modules.diffusionmodules.util import (
+ checkpoint,
+ conv_nd,
+ linear,
+ avg_pool_nd,
+ zero_module,
+ normalization,
+ timestep_embedding,
+)
+from ldm.modules.attention import SpatialTransformer
+from ldm.util import exists
+
+
+# dummy replace
+def convert_module_to_f16(x):
+ pass
+
+def convert_module_to_f32(x):
+ pass
+
+
+## go
+class AttentionPool2d(nn.Module):
+ """
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
+ """
+
+ def __init__(
+ self,
+ spacial_dim: int,
+ embed_dim: int,
+ num_heads_channels: int,
+ output_dim: int = None,
+ ):
+ super().__init__()
+ self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
+ self.num_heads = embed_dim // num_heads_channels
+ self.attention = QKVAttention(self.num_heads)
+
+ def forward(self, x):
+ b, c, *_spatial = x.shape
+ x = x.reshape(b, c, -1) # NC(HW)
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
+ x = self.qkv_proj(x)
+ x = self.attention(x)
+ x = self.c_proj(x)
+ return x[:, :, 0]
+
+
+class TimestepBlock(nn.Module):
+ """
+ Any module where forward() takes timestep embeddings as a second argument.
+ """
+
+ @abstractmethod
+ def forward(self, x, emb):
+ """
+ Apply the module to `x` given `emb` timestep embeddings.
+ """
+
+
+class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
+ """
+ A sequential module that passes timestep embeddings to the children that
+ support it as an extra input.
+ """
+
+ def forward(self, x, emb, context=None):
+ for layer in self:
+ if isinstance(layer, TimestepBlock):
+ x = layer(x, emb)
+ elif isinstance(layer, SpatialTransformer):
+ x = layer(x, context)
+ else:
+ x = layer(x)
+ return x
+
+
+class Upsample(nn.Module):
+ """
+ An upsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ upsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ if use_conv:
+ self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ if self.dims == 3:
+ x = F.interpolate(
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
+ )
+ else:
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
+ if self.use_conv:
+ x = self.conv(x)
+ return x
+
+class TransposedUpsample(nn.Module):
+ 'Learned 2x upsampling without padding'
+ def __init__(self, channels, out_channels=None, ks=5):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+
+ self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
+
+ def forward(self,x):
+ return self.up(x)
+
+
+class Downsample(nn.Module):
+ """
+ A downsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ downsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ stride = 2 if dims != 3 else (1, 2, 2)
+ if use_conv:
+ self.op = conv_nd(
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
+ )
+ else:
+ assert self.channels == self.out_channels
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ return self.op(x)
+
+
+class ResBlock(TimestepBlock):
+ """
+ A residual block that can optionally change the number of channels.
+ :param channels: the number of input channels.
+ :param emb_channels: the number of timestep embedding channels.
+ :param dropout: the rate of dropout.
+ :param out_channels: if specified, the number of out channels.
+ :param use_conv: if True and out_channels is specified, use a spatial
+ convolution instead of a smaller 1x1 convolution to change the
+ channels in the skip connection.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
+ :param up: if True, use this block for upsampling.
+ :param down: if True, use this block for downsampling.
+ """
+
+ def __init__(
+ self,
+ channels,
+ emb_channels,
+ dropout,
+ out_channels=None,
+ use_conv=False,
+ use_scale_shift_norm=False,
+ dims=2,
+ use_checkpoint=False,
+ up=False,
+ down=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.emb_channels = emb_channels
+ self.dropout = dropout
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_checkpoint = use_checkpoint
+ self.use_scale_shift_norm = use_scale_shift_norm
+
+ self.in_layers = nn.Sequential(
+ normalization(channels),
+ nn.SiLU(),
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
+ )
+
+ self.updown = up or down
+
+ if up:
+ self.h_upd = Upsample(channels, False, dims)
+ self.x_upd = Upsample(channels, False, dims)
+ elif down:
+ self.h_upd = Downsample(channels, False, dims)
+ self.x_upd = Downsample(channels, False, dims)
+ else:
+ self.h_upd = self.x_upd = nn.Identity()
+
+ self.emb_layers = nn.Sequential(
+ nn.SiLU(),
+ linear(
+ emb_channels,
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
+ ),
+ )
+ self.out_layers = nn.Sequential(
+ normalization(self.out_channels),
+ nn.SiLU(),
+ nn.Dropout(p=dropout),
+ zero_module(
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
+ ),
+ )
+
+ if self.out_channels == channels:
+ self.skip_connection = nn.Identity()
+ elif use_conv:
+ self.skip_connection = conv_nd(
+ dims, channels, self.out_channels, 3, padding=1
+ )
+ else:
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+
+ def forward(self, x, emb):
+ """
+ Apply the block to a Tensor, conditioned on a timestep embedding.
+ :param x: an [N x C x ...] Tensor of features.
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ return checkpoint(
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
+ )
+
+
+ def _forward(self, x, emb):
+ if self.updown:
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+ h = in_rest(x)
+ h = self.h_upd(h)
+ x = self.x_upd(x)
+ h = in_conv(h)
+ else:
+ h = self.in_layers(x)
+ emb_out = self.emb_layers(emb).type(h.dtype)
+ while len(emb_out.shape) < len(h.shape):
+ emb_out = emb_out[..., None]
+ if self.use_scale_shift_norm:
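+ # FiLM-like conditioning: the timestep embedding predicts a per-channel
+ # scale and shift applied right after the group normalization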
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+ scale, shift = th.chunk(emb_out, 2, dim=1)
+ h = out_norm(h) * (1 + scale) + shift
+ h = out_rest(h)
+ else:
+ h = h + emb_out
+ h = self.out_layers(h)
+ return self.skip_connection(x) + h
+
+
+class AttentionBlock(nn.Module):
+ """
+ An attention block that allows spatial positions to attend to each other.
+ Originally ported from here, but adapted to the N-d case.
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+ """
+
+ def __init__(
+ self,
+ channels,
+ num_heads=1,
+ num_head_channels=-1,
+ use_checkpoint=False,
+ use_new_attention_order=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ if num_head_channels == -1:
+ self.num_heads = num_heads
+ else:
+ assert (
+ channels % num_head_channels == 0
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+ self.num_heads = channels // num_head_channels
+ self.use_checkpoint = use_checkpoint
+ self.norm = normalization(channels)
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
+ if use_new_attention_order:
+ # split qkv before split heads
+ self.attention = QKVAttention(self.num_heads)
+ else:
+ # split heads before split qkv
+ self.attention = QKVAttentionLegacy(self.num_heads)
+
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
+
+ def forward(self, x):
+ return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
+ #return pt_checkpoint(self._forward, x) # pytorch
+
+ def _forward(self, x):
+ b, c, *spatial = x.shape
+ x = x.reshape(b, c, -1)
+ qkv = self.qkv(self.norm(x))
+ h = self.attention(qkv)
+ h = self.proj_out(h)
+ return (x + h).reshape(b, c, *spatial)
+
+
+def count_flops_attn(model, _x, y):
+ """
+ A counter for the `thop` package to count the operations in an
+ attention operation.
+ Meant to be used like:
+ macs, params = thop.profile(
+ model,
+ inputs=(inputs, timestamps),
+ custom_ops={QKVAttention: QKVAttention.count_flops},
+ )
+ """
+ b, c, *spatial = y[0].shape
+ num_spatial = int(np.prod(spatial))
+ # We perform two matmuls with the same number of ops.
+ # The first computes the weight matrix, the second computes
+ # the combination of the value vectors.
+ matmul_ops = 2 * b * (num_spatial ** 2) * c
+ model.total_ops += th.DoubleTensor([matmul_ops])
+
+
+class QKVAttentionLegacy(nn.Module):
+ """
+ A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
+ """
+
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv):
+ """
+ Apply QKV attention.
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
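+ # scaling q and k each by ch**-0.25 yields the usual 1/sqrt(ch) factor on
+ # their product while keeping intermediate values small in fp16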
+ weight = th.einsum(
+ "bct,bcs->bts", q * scale, k * scale
+ ) # More stable with f16 than dividing afterwards
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+ a = th.einsum("bts,bcs->bct", weight, v)
+ return a.reshape(bs, -1, length)
+
+ @staticmethod
+ def count_flops(model, _x, y):
+ return count_flops_attn(model, _x, y)
+
+
+class QKVAttention(nn.Module):
+ """
+ A module which performs QKV attention and splits in a different order.
+ """
+
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv):
+ """
+ Apply QKV attention.
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.chunk(3, dim=1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = th.einsum(
+ "bct,bcs->bts",
+ (q * scale).view(bs * self.n_heads, ch, length),
+ (k * scale).view(bs * self.n_heads, ch, length),
+ ) # More stable with f16 than dividing afterwards
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
+ return a.reshape(bs, -1, length)
+
+ @staticmethod
+ def count_flops(model, _x, y):
+ return count_flops_attn(model, _x, y)
+
+
+class UNetModel(nn.Module):
+ """
+ The full UNet model with attention and timestep embedding.
+ :param in_channels: channels in the input Tensor.
+ :param model_channels: base channel count for the model.
+ :param out_channels: channels in the output Tensor.
+ :param num_res_blocks: number of residual blocks per downsample.
+ :param attention_resolutions: a collection of downsample rates at which
+ attention will take place. May be a set, list, or tuple.
+ For example, if this contains 4, then at 4x downsampling, attention
+ will be used.
+ :param dropout: the dropout probability.
+ :param channel_mult: channel multiplier for each level of the UNet.
+ :param conv_resample: if True, use learned convolutions for upsampling and
+ downsampling.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param num_classes: if specified (as an int), then this model will be
+ class-conditional with `num_classes` classes.
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
+ :param num_heads: the number of attention heads in each attention layer.
+ :param num_head_channels: if specified, ignore num_heads and instead use
+ a fixed channel width per attention head.
+ :param num_heads_upsample: works with num_heads to set a different number
+ of heads for upsampling. Deprecated.
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
+ :param resblock_updown: use residual blocks for up/downsampling.
+ :param use_new_attention_order: use a different attention pattern for potentially
+ increased efficiency.
+ """
+
+ def __init__(
+ self,
+ image_size,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ num_classes=None,
+ use_checkpoint=False,
+ use_fp16=False,
+ num_heads=-1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ use_new_attention_order=False,
+ use_spatial_transformer=False, # custom transformer support
+ transformer_depth=1, # custom transformer support
+ context_dim=None, # custom transformer support
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
+ legacy=True,
+ disable_self_attentions=None,
+ num_attention_blocks=None,
+ disable_middle_self_attn=False,
+ use_linear_in_transformer=False,
+ ):
+ super().__init__()
+ if use_spatial_transformer:
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
+
+ if context_dim is not None:
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
+ from omegaconf.listconfig import ListConfig
+ if type(context_dim) == ListConfig:
+ context_dim = list(context_dim)
+
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ if num_heads == -1:
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
+
+ if num_head_channels == -1:
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
+
+ self.image_size = image_size
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ if isinstance(num_res_blocks, int):
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
+ else:
+ if len(num_res_blocks) != len(channel_mult):
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
+ "as a list/tuple (per-level) with the same length as channel_mult")
+ self.num_res_blocks = num_res_blocks
+ if disable_self_attentions is not None:
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
+ assert len(disable_self_attentions) == len(channel_mult)
+ if num_attention_blocks is not None:
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
+ assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
+ print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
+ f"attention will still not be set.")
+
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.num_classes = num_classes
+ self.use_checkpoint = use_checkpoint
+ self.dtype = th.float16 if use_fp16 else th.float32
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+ self.predict_codebook_ids = n_embed is not None
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+
+ if self.num_classes is not None:
+ if isinstance(self.num_classes, int):
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+ elif self.num_classes == "continuous":
+ print("setting up linear c_adm embedding layer")
+ self.label_emb = nn.Linear(1, time_embed_dim)
+ else:
+ raise ValueError()
+
+ self.input_blocks = nn.ModuleList(
+ [
+ TimestepEmbedSequential(
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
+ )
+ ]
+ )
+ self._feature_size = model_channels
+ input_block_chans = [model_channels]
+ ch = model_channels
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for nr in range(self.num_res_blocks[level]):
+ layers = [
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=mult * model_channels,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = mult * model_channels
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ if exists(disable_self_attentions):
+ disabled_sa = disable_self_attentions[level]
+ else:
+ disabled_sa = False
+
+ if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else SpatialTransformer(
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
+ use_checkpoint=use_checkpoint
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ )
+ if resblock_updown
+ else Downsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch
+ )
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+ self._feature_size += ch
+
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ self.middle_block = TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+ disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
+ use_checkpoint=use_checkpoint
+ ),
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ )
+ self._feature_size += ch
+
+ self.output_blocks = nn.ModuleList([])
+ for level, mult in list(enumerate(channel_mult))[::-1]:
+ for i in range(self.num_res_blocks[level] + 1):
+ ich = input_block_chans.pop()
+ layers = [
+ ResBlock(
+ ch + ich,
+ time_embed_dim,
+ dropout,
+ out_channels=model_channels * mult,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = model_channels * mult
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ if exists(disable_self_attentions):
+ disabled_sa = disable_self_attentions[level]
+ else:
+ disabled_sa = False
+
+ if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads_upsample,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else SpatialTransformer(
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
+ use_checkpoint=use_checkpoint
+ )
+ )
+ if level and i == self.num_res_blocks[level]:
+ out_ch = ch
+ layers.append(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ up=True,
+ )
+ if resblock_updown
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+ )
+ ds //= 2
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
+ )
+ if self.predict_codebook_ids:
+ self.id_predictor = nn.Sequential(
+ normalization(ch),
+ conv_nd(dims, model_channels, n_embed, 1),
+ #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
+ )
+
+ def convert_to_fp16(self):
+ """
+ Convert the torso of the model to float16.
+ """
+ self.input_blocks.apply(convert_module_to_f16)
+ self.middle_block.apply(convert_module_to_f16)
+ self.output_blocks.apply(convert_module_to_f16)
+
+ def convert_to_fp32(self):
+ """
+ Convert the torso of the model to float32.
+ """
+ self.input_blocks.apply(convert_module_to_f32)
+ self.middle_block.apply(convert_module_to_f32)
+ self.output_blocks.apply(convert_module_to_f32)
+
+ def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
+ """
+ Apply the model to an input batch.
+ :param x: an [N x C x ...] Tensor of inputs.
+ :param timesteps: a 1-D batch of timesteps.
+ :param context: conditioning plugged in via crossattn
+ :param y: an [N] Tensor of labels, if class-conditional.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ assert (y is not None) == (
+ self.num_classes is not None
+ ), "must specify y if and only if the model is class-conditional"
+ hs = []
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+ emb = self.time_embed(t_emb)
+
+ if self.num_classes is not None:
+ assert y.shape[0] == x.shape[0]
+ emb = emb + self.label_emb(y)
+
+ h = x.type(self.dtype)
+ for module in self.input_blocks:
+ h = module(h, emb, context)
+ hs.append(h)
+ h = self.middle_block(h, emb, context)
+ for module in self.output_blocks:
+ h = th.cat([h, hs.pop()], dim=1)
+ h = module(h, emb, context)
+ h = h.type(x.dtype)
+ if self.predict_codebook_ids:
+ return self.id_predictor(h)
+ else:
+ return self.out(h)
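
A rough usage sketch of the model class; the hyperparameters below are small, made-up values rather than any released checkpoint's configuration, and in this repository the UNet is normally instantiated from a YAML config via instantiate_from_config instead of by hand.

import torch

unet = UNetModel(
    image_size=32,
    in_channels=4,
    model_channels=64,
    out_channels=4,
    num_res_blocks=1,
    attention_resolutions=(2,),
    channel_mult=(1, 2),
    num_heads=4,
    use_spatial_transformer=True,
    transformer_depth=1,
    context_dim=768,
)
x = torch.randn(1, 4, 32, 32)             # latent input
t = torch.randint(0, 1000, (1,))          # diffusion timesteps
ctx = torch.randn(1, 77, 768)             # e.g. CLIP text embeddings
eps = unet(x, timesteps=t, context=ctx)   # -> [1, 4, 32, 32]
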
diff --git a/ldm/modules/diffusionmodules/upscaling.py b/ldm/modules/diffusionmodules/upscaling.py
new file mode 100644
index 0000000000000000000000000000000000000000..03816662098ce1ffac79bd939b892e867ab91988
--- /dev/null
+++ b/ldm/modules/diffusionmodules/upscaling.py
@@ -0,0 +1,81 @@
+import torch
+import torch.nn as nn
+import numpy as np
+from functools import partial
+
+from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule
+from ldm.util import default
+
+
+class AbstractLowScaleModel(nn.Module):
+ # for concatenating a downsampled image to the latent representation
+ def __init__(self, noise_schedule_config=None):
+ super(AbstractLowScaleModel, self).__init__()
+ if noise_schedule_config is not None:
+ self.register_schedule(**noise_schedule_config)
+
+ def register_schedule(self, beta_schedule="linear", timesteps=1000,
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+ betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
+ cosine_s=cosine_s)
+ alphas = 1. - betas
+ alphas_cumprod = np.cumprod(alphas, axis=0)
+ alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
+
+ timesteps, = betas.shape
+ self.num_timesteps = int(timesteps)
+ self.linear_start = linear_start
+ self.linear_end = linear_end
+ assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
+
+ to_torch = partial(torch.tensor, dtype=torch.float32)
+
+ self.register_buffer('betas', to_torch(betas))
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+ self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
+
+ def q_sample(self, x_start, t, noise=None):
+ noise = default(noise, lambda: torch.randn_like(x_start))
+ return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
+
+ def forward(self, x):
+ return x, None
+
+ def decode(self, x):
+ return x
+
+
+class SimpleImageConcat(AbstractLowScaleModel):
+ # no noise level conditioning
+ def __init__(self):
+ super(SimpleImageConcat, self).__init__(noise_schedule_config=None)
+ self.max_noise_level = 0
+
+ def forward(self, x):
+ # fix to constant noise level
+ return x, torch.zeros(x.shape[0], device=x.device).long()
+
+
+class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
+ def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False):
+ super().__init__(noise_schedule_config=noise_schedule_config)
+ self.max_noise_level = max_noise_level
+
+ def forward(self, x, noise_level=None):
+ if noise_level is None:
+ noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
+ else:
+ assert isinstance(noise_level, torch.Tensor)
+ z = self.q_sample(x, noise_level)
+ return z, noise_level
+
+
+
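
A standalone sketch of the noise-augmentation wrapper with an assumed linear schedule config (the values mirror the register_schedule defaults) and an illustrative max_noise_level; in practice both come from the model YAML.

import torch

aug = ImageConcatWithNoiseAugmentation(
    noise_schedule_config={"linear_start": 1e-4, "linear_end": 2e-2, "timesteps": 1000},
    max_noise_level=350,
)
x_low = torch.randn(4, 3, 64, 64)    # downsampled conditioning image
z, noise_level = aug(x_low)          # q_sample at a random level per batch item
print(z.shape, noise_level.shape)    # torch.Size([4, 3, 64, 64]) torch.Size([4])
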
diff --git a/ldm/modules/diffusionmodules/util.py b/ldm/modules/diffusionmodules/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..637363dfe34799e70cfdbcd11445212df9d9ca1f
--- /dev/null
+++ b/ldm/modules/diffusionmodules/util.py
@@ -0,0 +1,270 @@
+# adopted from
+# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+# and
+# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+# and
+# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
+#
+# thanks!
+
+
+import os
+import math
+import torch
+import torch.nn as nn
+import numpy as np
+from einops import repeat
+
+from ldm.util import instantiate_from_config
+
+
+def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+ if schedule == "linear":
+ betas = (
+ torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
+ )
+
+ elif schedule == "cosine":
+ timesteps = (
+ torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
+ )
+ alphas = timesteps / (1 + cosine_s) * np.pi / 2
+ alphas = torch.cos(alphas).pow(2)
+ alphas = alphas / alphas[0]
+ betas = 1 - alphas[1:] / alphas[:-1]
+ betas = np.clip(betas, a_min=0, a_max=0.999)
+
+ elif schedule == "sqrt_linear":
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
+ elif schedule == "sqrt":
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
+ else:
+ raise ValueError(f"schedule '{schedule}' unknown.")
+ return betas.numpy()
+
+
+def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
+ if ddim_discr_method == 'uniform':
+ c = num_ddpm_timesteps // num_ddim_timesteps
+ ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
+ elif ddim_discr_method == 'quad':
+ ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
+ else:
+ raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
+
+ # assert ddim_timesteps.shape[0] == num_ddim_timesteps
+ # add one to get the final alpha values right (the ones from first scale to data during sampling)
+ steps_out = ddim_timesteps + 1
+ if verbose:
+ print(f'Selected timesteps for ddim sampler: {steps_out}')
+ return steps_out
+
+
+def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
+ # select alphas for computing the variance schedule
+ alphas = alphacums[ddim_timesteps]
+ alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
+
+ # according to the formula provided in https://arxiv.org/abs/2010.02502
+ sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
+ if verbose:
+ print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
+ print(f'For the chosen value of eta, which is {eta}, '
+ f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
+ return sigmas, alphas, alphas_prev
+
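
A small worked example chaining the three helpers above: a 1000-step linear beta schedule reduced to a 50-step DDIM schedule (step counts are illustrative); with eta = 0 every sigma_t is zero, i.e. the deterministic DDIM sampler.

import numpy as np

betas = make_beta_schedule("linear", 1000)
alphas_cumprod = np.cumprod(1.0 - betas, axis=0)

ddim_steps = make_ddim_timesteps("uniform", 50, 1000, verbose=False)
sigmas, alphas, alphas_prev = make_ddim_sampling_parameters(
    alphas_cumprod, ddim_steps, eta=0.0, verbose=False
)
assert np.allclose(sigmas, 0.0)   # eta = 0 -> deterministic DDIM
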
+
+def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function,
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
+ :param num_diffusion_timesteps: the number of betas to produce.
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
+ produces the cumulative product of (1-beta) up to that
+ part of the diffusion process.
+ :param max_beta: the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ """
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return np.array(betas)
+
+
+def extract_into_tensor(a, t, x_shape):
+ b, *_ = t.shape
+ out = a.gather(-1, t)
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+
+def checkpoint(func, inputs, params, flag):
+ """
+ Evaluate a function without caching intermediate activations, allowing for
+ reduced memory at the expense of extra compute in the backward pass.
+ :param func: the function to evaluate.
+ :param inputs: the argument sequence to pass to `func`.
+ :param params: a sequence of parameters `func` depends on but does not
+ explicitly take as arguments.
+ :param flag: if False, disable gradient checkpointing.
+ """
+ if flag:
+ args = tuple(inputs) + tuple(params)
+ return CheckpointFunction.apply(func, len(inputs), *args)
+ else:
+ return func(*inputs)
+
+
+class CheckpointFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, run_function, length, *args):
+ ctx.run_function = run_function
+ ctx.input_tensors = list(args[:length])
+ ctx.input_params = list(args[length:])
+ ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
+ "dtype": torch.get_autocast_gpu_dtype(),
+ "cache_enabled": torch.is_autocast_cache_enabled()}
+ with torch.no_grad():
+ output_tensors = ctx.run_function(*ctx.input_tensors)
+ return output_tensors
+
+ @staticmethod
+ def backward(ctx, *output_grads):
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+ with torch.enable_grad(), \
+ torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
+ # Fixes a bug where the first op in run_function modifies the
+ # Tensor storage in place, which is not allowed for detach()'d
+ # Tensors.
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+ output_tensors = ctx.run_function(*shallow_copies)
+ input_grads = torch.autograd.grad(
+ output_tensors,
+ ctx.input_tensors + ctx.input_params,
+ output_grads,
+ allow_unused=True,
+ )
+ del ctx.input_tensors
+ del ctx.input_params
+ del output_tensors
+ return (None, None) + input_grads
+
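
A sketch of wrapping an arbitrary sub-module with the checkpoint helper above; the toy network is illustrative. With flag=True the activations inside the wrapped function are recomputed during the backward pass instead of being stored.

import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(16, 64), nn.SiLU(), nn.Linear(64, 16))
x = torch.randn(8, 16, requires_grad=True)

out = checkpoint(net, (x,), list(net.parameters()), flag=True)
out.sum().backward()
print(x.grad.shape)   # torch.Size([8, 16])
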
+
+def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
+ """
+ Create sinusoidal timestep embeddings.
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
+ """
+ if not repeat_only:
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+ ).to(device=timesteps.device)
+ args = timesteps[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ else:
+ embedding = repeat(timesteps, 'b -> b d', d=dim)
+ return embedding
+
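
A quick shape check of the sinusoidal embedding with illustrative values: four timesteps embedded into 128 dimensions, cosine terms in the first half and sine terms in the second.

import torch

t = torch.tensor([0, 250, 500, 999])
emb = timestep_embedding(t, dim=128)
print(emb.shape)                 # torch.Size([4, 128])
print(float(emb.abs().max()))    # all entries lie in [-1, 1]
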
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def scale_module(module, scale):
+ """
+ Scale the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().mul_(scale)
+ return module
+
+
+def mean_flat(tensor):
+ """
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def normalization(channels):
+ """
+ Make a standard normalization layer.
+ :param channels: number of input channels.
+ :return: an nn.Module for normalization.
+ """
+ return GroupNorm32(32, channels)
+
+
+# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
+class SiLU(nn.Module):
+ def forward(self, x):
+ return x * torch.sigmoid(x)
+
+
+class GroupNorm32(nn.GroupNorm):
+ def forward(self, x):
+ return super().forward(x.float()).type(x.dtype)
+
+def conv_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D convolution module.
+ """
+ if dims == 1:
+ return nn.Conv1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.Conv2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.Conv3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def linear(*args, **kwargs):
+ """
+ Create a linear module.
+ """
+ return nn.Linear(*args, **kwargs)
+
+
+def avg_pool_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D average pooling module.
+ """
+ if dims == 1:
+ return nn.AvgPool1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.AvgPool2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.AvgPool3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+class HybridConditioner(nn.Module):
+
+ def __init__(self, c_concat_config, c_crossattn_config):
+ super().__init__()
+ self.concat_conditioner = instantiate_from_config(c_concat_config)
+ self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
+
+ def forward(self, c_concat, c_crossattn):
+ c_concat = self.concat_conditioner(c_concat)
+ c_crossattn = self.crossattn_conditioner(c_crossattn)
+ return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
+
+
+def noise_like(shape, device, repeat=False):
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
+ noise = lambda: torch.randn(shape, device=device)
+ return repeat_noise() if repeat else noise()
\ No newline at end of file
diff --git a/ldm/modules/distributions/__init__.py b/ldm/modules/distributions/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc b/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8b51b9d91fcd3318267c855b1507b9b535a9e0e0
Binary files /dev/null and b/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc differ
diff --git a/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc b/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b39f71ae0a37a7de2d72fb480eb17c7c738389e7
Binary files /dev/null and b/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc differ
diff --git a/ldm/modules/distributions/distributions.py b/ldm/modules/distributions/distributions.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2b8ef901130efc171aa69742ca0244d94d3f2e9
--- /dev/null
+++ b/ldm/modules/distributions/distributions.py
@@ -0,0 +1,92 @@
+import torch
+import numpy as np
+
+
+class AbstractDistribution:
+ def sample(self):
+ raise NotImplementedError()
+
+ def mode(self):
+ raise NotImplementedError()
+
+
+class DiracDistribution(AbstractDistribution):
+ def __init__(self, value):
+ self.value = value
+
+ def sample(self):
+ return self.value
+
+ def mode(self):
+ return self.value
+
+
+class DiagonalGaussianDistribution(object):
+ def __init__(self, parameters, deterministic=False):
+ self.parameters = parameters
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+ self.deterministic = deterministic
+ self.std = torch.exp(0.5 * self.logvar)
+ self.var = torch.exp(self.logvar)
+ if self.deterministic:
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
+
+ def sample(self):
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
+ return x
+
+ def kl(self, other=None):
+ if self.deterministic:
+ return torch.Tensor([0.])
+ else:
+ if other is None:
+ return 0.5 * torch.sum(torch.pow(self.mean, 2)
+ + self.var - 1.0 - self.logvar,
+ dim=[1, 2, 3])
+ else:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean - other.mean, 2) / other.var
+ + self.var / other.var - 1.0 - self.logvar + other.logvar,
+ dim=[1, 2, 3])
+
+ def nll(self, sample, dims=[1,2,3]):
+ if self.deterministic:
+ return torch.Tensor([0.])
+ logtwopi = np.log(2.0 * np.pi)
+ return 0.5 * torch.sum(
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
+ dim=dims)
+
+ def mode(self):
+ return self.mean
+
+
+def normal_kl(mean1, logvar1, mean2, logvar2):
+ """
+ source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
+ Compute the KL divergence between two gaussians.
+ Shapes are automatically broadcasted, so batches can be compared to
+ scalars, among other use cases.
+ """
+ tensor = None
+ for obj in (mean1, logvar1, mean2, logvar2):
+ if isinstance(obj, torch.Tensor):
+ tensor = obj
+ break
+ assert tensor is not None, "at least one argument must be a Tensor"
+
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
+ # Tensors, but it does not work for torch.exp().
+ logvar1, logvar2 = [
+ x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
+ for x in (logvar1, logvar2)
+ ]
+
+ return 0.5 * (
+ -1.0
+ + logvar2
+ - logvar1
+ + torch.exp(logvar1 - logvar2)
+ + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
+ )
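
A sketch of the diagonal Gaussian posterior with illustrative shapes: the parameters tensor stacks mean and log-variance along the channel axis, so 16 parameter channels describe an 8-channel latent.

import torch

params = torch.randn(2, 16, 32, 32)
posterior = DiagonalGaussianDistribution(params)

z = posterior.sample()     # [2, 8, 32, 32] reparameterised sample
kl = posterior.kl()        # KL to a standard normal, one value per batch item
nll = posterior.nll(z)     # negative log-likelihood of the sample
print(z.shape, kl.shape, nll.shape)
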
diff --git a/ldm/modules/ema.py b/ldm/modules/ema.py
new file mode 100644
index 0000000000000000000000000000000000000000..bded25019b9bcbcd0260f0b8185f8c7859ca58c4
--- /dev/null
+++ b/ldm/modules/ema.py
@@ -0,0 +1,80 @@
+import torch
+from torch import nn
+
+
+class LitEma(nn.Module):
+ def __init__(self, model, decay=0.9999, use_num_upates=True):
+ super().__init__()
+ if decay < 0.0 or decay > 1.0:
+ raise ValueError('Decay must be between 0 and 1')
+
+ self.m_name2s_name = {}
+ self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
+ self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates
+ else torch.tensor(-1, dtype=torch.int))
+
+ for name, p in model.named_parameters():
+ if p.requires_grad:
+ # remove '.' since it is not allowed in buffer names
+ s_name = name.replace('.', '')
+ self.m_name2s_name.update({name: s_name})
+ self.register_buffer(s_name, p.clone().detach().data)
+
+ self.collected_params = []
+
+ def reset_num_updates(self):
+ del self.num_updates
+ self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))
+
+ def forward(self, model):
+ decay = self.decay
+
+ if self.num_updates >= 0:
+ self.num_updates += 1
+ decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
+
+ one_minus_decay = 1.0 - decay
+
+ with torch.no_grad():
+ m_param = dict(model.named_parameters())
+ shadow_params = dict(self.named_buffers())
+
+ for key in m_param:
+ if m_param[key].requires_grad:
+ sname = self.m_name2s_name[key]
+ shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
+ shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
+ else:
+ assert not key in self.m_name2s_name
+
+ def copy_to(self, model):
+ m_param = dict(model.named_parameters())
+ shadow_params = dict(self.named_buffers())
+ for key in m_param:
+ if m_param[key].requires_grad:
+ m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
+ else:
+ assert not key in self.m_name2s_name
+
+ def store(self, parameters):
+ """
+ Save the current parameters for restoring later.
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ temporarily stored.
+ """
+ self.collected_params = [param.clone() for param in parameters]
+
+ def restore(self, parameters):
+ """
+ Restore the parameters stored with the `store` method.
+ Useful to validate the model with EMA parameters without affecting the
+ original optimization process. Store the parameters before the
+ `copy_to` method. After validation (or model saving), use this to
+ restore the former parameters.
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ updated with the stored parameters.
+ """
+ for c_param, param in zip(self.collected_params, parameters):
+ param.data.copy_(c_param.data)
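
A sketch of the usual EMA workflow around a toy module (illustrative only): update the shadow weights after each optimizer step, then temporarily swap them in for validation.

import torch
import torch.nn as nn

model = nn.Linear(4, 4)
ema = LitEma(model, decay=0.999)

model.weight.data.add_(0.01)      # stand-in for an optimizer step
ema(model)                        # update the shadow (EMA) weights

ema.store(model.parameters())     # remember the raw training weights
ema.copy_to(model)                # evaluate with EMA weights ...
ema.restore(model.parameters())   # ... then switch back for training
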
diff --git a/ldm/modules/encoders/__init__.py b/ldm/modules/encoders/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc b/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e6e087c8a8d6e656f65dcd65fc6573935d7a339c
Binary files /dev/null and b/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc differ
diff --git a/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc b/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e7a310f70dc206567864d948e6d0823cb6339100
Binary files /dev/null and b/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc differ
diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..4edd5496b9e668ea72a5be39db9cca94b6a42f9b
--- /dev/null
+++ b/ldm/modules/encoders/modules.py
@@ -0,0 +1,213 @@
+import torch
+import torch.nn as nn
+from torch.utils.checkpoint import checkpoint
+
+from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
+
+import open_clip
+from ldm.util import default, count_params
+
+
+class AbstractEncoder(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def encode(self, *args, **kwargs):
+ raise NotImplementedError
+
+
+class IdentityEncoder(AbstractEncoder):
+
+ def encode(self, x):
+ return x
+
+
+class ClassEmbedder(nn.Module):
+ def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1):
+ super().__init__()
+ self.key = key
+ self.embedding = nn.Embedding(n_classes, embed_dim)
+ self.n_classes = n_classes
+ self.ucg_rate = ucg_rate
+
+ def forward(self, batch, key=None, disable_dropout=False):
+ if key is None:
+ key = self.key
+ # this is for use in crossattn
+ c = batch[key][:, None]
+ if self.ucg_rate > 0. and not disable_dropout:
+ mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
+ c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1)
+ c = c.long()
+ c = self.embedding(c)
+ return c
+
+ def get_unconditional_conditioning(self, bs, device="cuda"):
+ uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
+ uc = torch.ones((bs,), device=device) * uc_class
+ uc = {self.key: uc}
+ return uc
+
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+
+class FrozenT5Embedder(AbstractEncoder):
+ """Uses the T5 transformer encoder for text"""
+ def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
+ super().__init__()
+ self.tokenizer = T5Tokenizer.from_pretrained(version)
+ self.transformer = T5EncoderModel.from_pretrained(version)
+ self.device = device
+ self.max_length = max_length # TODO: typical value?
+ if freeze:
+ self.freeze()
+
+ def freeze(self):
+ self.transformer = self.transformer.eval()
+ #self.train = disabled_train
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+ tokens = batch_encoding["input_ids"].to(self.device)
+ outputs = self.transformer(input_ids=tokens)
+
+ z = outputs.last_hidden_state
+ return z
+
+ def encode(self, text):
+ return self(text)
+
+
+class FrozenCLIPEmbedder(AbstractEncoder):
+ """Uses the CLIP transformer encoder for text (from huggingface)"""
+ LAYERS = [
+ "last",
+ "pooled",
+ "hidden"
+ ]
+ def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
+ freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32
+ super().__init__()
+ assert layer in self.LAYERS
+ self.tokenizer = CLIPTokenizer.from_pretrained(version)
+ self.transformer = CLIPTextModel.from_pretrained(version)
+ self.device = device
+ self.max_length = max_length
+ if freeze:
+ self.freeze()
+ self.layer = layer
+ self.layer_idx = layer_idx
+ if layer == "hidden":
+ assert layer_idx is not None
+ assert 0 <= abs(layer_idx) <= 12
+
+ def freeze(self):
+ self.transformer = self.transformer.eval()
+ #self.train = disabled_train
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+ tokens = batch_encoding["input_ids"].to(self.device)
+ outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
+ if self.layer == "last":
+ z = outputs.last_hidden_state
+ elif self.layer == "pooled":
+ z = outputs.pooler_output[:, None, :]
+ else:
+ z = outputs.hidden_states[self.layer_idx]
+ return z
+
+ def encode(self, text):
+ return self(text)
+
+
+class FrozenOpenCLIPEmbedder(AbstractEncoder):
+ """
+ Uses the OpenCLIP transformer encoder for text
+ """
+ LAYERS = [
+ #"pooled",
+ "last",
+ "penultimate"
+ ]
+ def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
+ freeze=True, layer="last"):
+ super().__init__()
+ assert layer in self.LAYERS
+ model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
+ del model.visual
+ self.model = model
+
+ self.device = device
+ self.max_length = max_length
+ if freeze:
+ self.freeze()
+ self.layer = layer
+ if self.layer == "last":
+ self.layer_idx = 0
+ elif self.layer == "penultimate":
+ self.layer_idx = 1
+ else:
+ raise NotImplementedError()
+
+ def freeze(self):
+ self.model = self.model.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ tokens = open_clip.tokenize(text)
+ z = self.encode_with_transformer(tokens.to(self.device))
+ return z
+
+ def encode_with_transformer(self, text):
+ x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
+ x = x + self.model.positional_embedding
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
+ x = x.permute(1, 0, 2) # LND -> NLD
+ x = self.model.ln_final(x)
+ return x
+
+ def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
+ for i, r in enumerate(self.model.transformer.resblocks):
+ if i == len(self.model.transformer.resblocks) - self.layer_idx:
+ break
+ if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
+ x = checkpoint(r, x, attn_mask)
+ else:
+ x = r(x, attn_mask=attn_mask)
+ return x
+
+ def encode(self, text):
+ return self(text)
+
+
+class FrozenCLIPT5Encoder(AbstractEncoder):
+ def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda",
+ clip_max_length=77, t5_max_length=77):
+ super().__init__()
+ self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)
+ self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
+ print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, "
+ f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.")
+
+ def encode(self, text):
+ return self(text)
+
+ def forward(self, text):
+ clip_z = self.clip_encoder.encode(text)
+ t5_z = self.t5_encoder.encode(text)
+ return [clip_z, t5_z]
+
+
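
A sketch of calling the frozen CLIP text encoder; it needs the transformers package and a one-time download of the pretrained weights, and device="cpu" is used here purely for illustration (the class defaults to "cuda").

encoder = FrozenCLIPEmbedder(device="cpu")
z = encoder.encode(["a photograph of an astronaut riding a horse"])
print(z.shape)   # torch.Size([1, 77, 768]) for clip-vit-large-patch14
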
diff --git a/ldm/modules/image_degradation/__init__.py b/ldm/modules/image_degradation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7836cada81f90ded99c58d5942eea4c3477f58fc
--- /dev/null
+++ b/ldm/modules/image_degradation/__init__.py
@@ -0,0 +1,2 @@
+from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
+from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
diff --git a/ldm/modules/image_degradation/bsrgan.py b/ldm/modules/image_degradation/bsrgan.py
new file mode 100644
index 0000000000000000000000000000000000000000..32ef56169978e550090261cddbcf5eb611a6173b
--- /dev/null
+++ b/ldm/modules/image_degradation/bsrgan.py
@@ -0,0 +1,730 @@
+# -*- coding: utf-8 -*-
+"""
+# --------------------------------------------
+# Super-Resolution
+# --------------------------------------------
+#
+# Kai Zhang (cskaizhang@gmail.com)
+# https://github.com/cszn
+# From 2019/03--2021/08
+# --------------------------------------------
+"""
+
+import numpy as np
+import cv2
+import torch
+
+from functools import partial
+import random
+from scipy import ndimage
+import scipy
+import scipy.stats as ss
+from scipy.interpolate import interp2d
+from scipy.linalg import orth
+import albumentations
+
+import ldm.modules.image_degradation.utils_image as util
+
+
+def modcrop_np(img, sf):
+ '''
+ Args:
+ img: numpy image, WxH or WxHxC
+ sf: scale factor
+ Return:
+ cropped image
+ '''
+ w, h = img.shape[:2]
+ im = np.copy(img)
+ return im[:w - w % sf, :h - h % sf, ...]
+
+
+"""
+# --------------------------------------------
+# anisotropic Gaussian kernels
+# --------------------------------------------
+"""
+
+
+def analytic_kernel(k):
+ """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
+ k_size = k.shape[0]
+ # Calculate the big kernel's size
+ big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
+ # Loop over the small kernel to fill the big one
+ for r in range(k_size):
+ for c in range(k_size):
+ big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
+ # Crop the edges of the big kernel to ignore very small values and increase run time of SR
+ crop = k_size // 2
+ cropped_big_k = big_k[crop:-crop, crop:-crop]
+ # Normalize to 1
+ return cropped_big_k / cropped_big_k.sum()
+
+
+def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
+ """ generate an anisotropic Gaussian kernel
+ Args:
+ ksize : e.g., 15, kernel size
+ theta : [0, pi], rotation angle range
+ l1 : [0.1,50], scaling of eigenvalues
+ l2 : [0.1,l1], scaling of eigenvalues
+ If l1 = l2, will get an isotropic Gaussian kernel.
+ Returns:
+ k : kernel
+ """
+
+ v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
+ V = np.array([[v[0], v[1]], [v[1], -v[0]]])
+ D = np.array([[l1, 0], [0, l2]])
+ Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
+ k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
+
+ return k
+
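
An illustrative call: a 15x15 kernel rotated by 45 degrees with eigenvalue scalings 6 and 1; the returned kernel is normalised to sum to one.

import numpy as np

k = anisotropic_Gaussian(ksize=15, theta=np.pi / 4, l1=6.0, l2=1.0)
print(k.shape, bool(np.isclose(k.sum(), 1.0)))   # (15, 15) True
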
+
+def gm_blur_kernel(mean, cov, size=15):
+ center = size / 2.0 + 0.5
+ k = np.zeros([size, size])
+ for y in range(size):
+ for x in range(size):
+ cy = y - center + 1
+ cx = x - center + 1
+ k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
+
+ k = k / np.sum(k)
+ return k
+
+
+def shift_pixel(x, sf, upper_left=True):
+ """shift pixel for super-resolution with different scale factors
+ Args:
+ x: WxHxC or WxH
+ sf: scale factor
+ upper_left: shift direction
+ """
+ h, w = x.shape[:2]
+ shift = (sf - 1) * 0.5
+ xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
+ if upper_left:
+ x1 = xv + shift
+ y1 = yv + shift
+ else:
+ x1 = xv - shift
+ y1 = yv - shift
+
+ x1 = np.clip(x1, 0, w - 1)
+ y1 = np.clip(y1, 0, h - 1)
+
+ if x.ndim == 2:
+ x = interp2d(xv, yv, x)(x1, y1)
+ if x.ndim == 3:
+ for i in range(x.shape[-1]):
+ x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
+
+ return x
+
+
+def blur(x, k):
+ '''
+ x: image, NxcxHxW
+ k: kernel, Nx1xhxw
+ '''
+ n, c = x.shape[:2]
+ p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
+ x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
+ k = k.repeat(1, c, 1, 1)
+ k = k.view(-1, 1, k.shape[2], k.shape[3])
+ x = x.view(1, -1, x.shape[2], x.shape[3])
+ x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
+ x = x.view(n, c, x.shape[2], x.shape[3])
+
+ return x
+
+
+def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
+ """
+ # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
+ # Kai Zhang
+ # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
+ # max_var = 2.5 * sf
+ """
+ # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
+ lambda_1 = min_var + np.random.rand() * (max_var - min_var)
+ lambda_2 = min_var + np.random.rand() * (max_var - min_var)
+ theta = np.random.rand() * np.pi # random theta
+ noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
+
+ # Set COV matrix using Lambdas and Theta
+ LAMBDA = np.diag([lambda_1, lambda_2])
+ Q = np.array([[np.cos(theta), -np.sin(theta)],
+ [np.sin(theta), np.cos(theta)]])
+ SIGMA = Q @ LAMBDA @ Q.T
+ INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
+
+ # Set expectation position (shifting kernel for aligned image)
+ MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
+ MU = MU[None, None, :, None]
+
+ # Create meshgrid for Gaussian
+ [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
+ Z = np.stack([X, Y], 2)[:, :, :, None]
+
+ # Calculate the Gaussian for every pixel of the kernel
+ ZZ = Z - MU
+ ZZ_t = ZZ.transpose(0, 1, 3, 2)
+ raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
+
+ # shift the kernel so it will be centered
+ # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
+
+ # Normalize the kernel and return
+ # kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
+ kernel = raw_kernel / np.sum(raw_kernel)
+ return kernel
+
+
+def fspecial_gaussian(hsize, sigma):
+ hsize = [hsize, hsize]
+ siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
+ std = sigma
+ [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
+ arg = -(x * x + y * y) / (2 * std * std)
+ h = np.exp(arg)
+ h[h < scipy.finfo(float).eps * h.max()] = 0
+ sumh = h.sum()
+ if sumh != 0:
+ h = h / sumh
+ return h
+
+
+def fspecial_laplacian(alpha):
+ alpha = max([0, min([alpha, 1])])
+ h1 = alpha / (alpha + 1)
+ h2 = (1 - alpha) / (alpha + 1)
+ h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
+ h = np.array(h)
+ return h
+
+
+def fspecial(filter_type, *args, **kwargs):
+ '''
+ python code from:
+ https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
+ '''
+ if filter_type == 'gaussian':
+ return fspecial_gaussian(*args, **kwargs)
+ if filter_type == 'laplacian':
+ return fspecial_laplacian(*args, **kwargs)
+
+
+"""
+# --------------------------------------------
+# degradation models
+# --------------------------------------------
+"""
+
+
+def bicubic_degradation(x, sf=3):
+ '''
+ Args:
+ x: HxWxC image, [0, 1]
+ sf: down-scale factor
+ Return:
+ bicubicly downsampled LR image
+ '''
+ x = util.imresize_np(x, scale=1 / sf)
+ return x
+
+
+def srmd_degradation(x, k, sf=3):
+ ''' blur + bicubic downsampling
+ Args:
+ x: HxWxC image, [0, 1]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ Reference:
+ @inproceedings{zhang2018learning,
+ title={Learning a single convolutional super-resolution network for multiple degradations},
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+ pages={3262--3271},
+ year={2018}
+ }
+ '''
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
+ x = bicubic_degradation(x, sf=sf)
+ return x
+
+
+def dpsr_degradation(x, k, sf=3):
+ ''' bicubic downsampling + blur
+ Args:
+ x: HxWxC image, [0, 1]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ Reference:
+ @inproceedings{zhang2019deep,
+ title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+ pages={1671--1681},
+ year={2019}
+ }
+ '''
+ x = bicubic_degradation(x, sf=sf)
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+ return x
+
+
+def classical_degradation(x, k, sf=3):
+ ''' blur + downsampling
+ Args:
+ x: HxWxC image, [0, 1]/[0, 255]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ '''
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+ # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
+ st = 0
+ return x[st::sf, st::sf, ...]
+
+
+def add_sharpening(img, weight=0.5, radius=50, threshold=10):
+ """USM sharpening, borrowed from Real-ESRGAN.
+ Input image: I; Blurry image: B.
+ 1. K = I + weight * (I - B)
+ 2. Mask = 1 if abs(I - B) > threshold, else: 0
+ 3. Blur mask:
+ 4. Out = Mask * K + (1 - Mask) * I
+ Args:
+ img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
+ weight (float): Sharp weight. Default: 0.5.
+ radius (float): Kernel size of the Gaussian blur. Default: 50.
+ threshold (int): Mask threshold on the residual, on the [0, 255] scale. Default: 10.
+ """
+ if radius % 2 == 0:
+ radius += 1
+ blur = cv2.GaussianBlur(img, (radius, radius), 0)
+ residual = img - blur
+ mask = np.abs(residual) * 255 > threshold
+ mask = mask.astype('float32')
+ soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
+
+ K = img + weight * residual
+ K = np.clip(K, 0, 1)
+ return soft_mask * K + (1 - soft_mask) * img
+
+
+def add_blur(img, sf=4):
+ wd2 = 4.0 + sf
+ wd = 2.0 + 0.2 * sf
+ if random.random() < 0.5:
+ l1 = wd2 * random.random()
+ l2 = wd2 * random.random()
+ k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
+ else:
+ k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random())
+ img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
+
+ return img
+
+
+def add_resize(img, sf=4):
+ rnum = np.random.rand()
+ if rnum > 0.8: # up
+ sf1 = random.uniform(1, 2)
+ elif rnum < 0.7: # down
+ sf1 = random.uniform(0.5 / sf, 1)
+ else:
+ sf1 = 1.0
+ img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
+ img = np.clip(img, 0.0, 1.0)
+
+ return img
+
+
+# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+# noise_level = random.randint(noise_level1, noise_level2)
+# rnum = np.random.rand()
+# if rnum > 0.6: # add color Gaussian noise
+# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+# elif rnum < 0.4: # add grayscale Gaussian noise
+# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+# else: # add noise
+# L = noise_level2 / 255.
+# D = np.diag(np.random.rand(3))
+# U = orth(np.random.rand(3, 3))
+# conv = np.dot(np.dot(np.transpose(U), D), U)
+# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+# img = np.clip(img, 0.0, 1.0)
+# return img
+
+def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+ noise_level = random.randint(noise_level1, noise_level2)
+ rnum = np.random.rand()
+ if rnum > 0.6: # add color Gaussian noise
+ img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+ elif rnum < 0.4: # add grayscale Gaussian noise
+ img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+ else: # add noise
+ L = noise_level2 / 255.
+ D = np.diag(np.random.rand(3))
+ U = orth(np.random.rand(3, 3))
+ conv = np.dot(np.dot(np.transpose(U), D), U)
+ img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_speckle_noise(img, noise_level1=2, noise_level2=25):
+ noise_level = random.randint(noise_level1, noise_level2)
+ img = np.clip(img, 0.0, 1.0)
+ rnum = random.random()
+ if rnum > 0.6:
+ img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+ elif rnum < 0.4:
+ img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+ else:
+ L = noise_level2 / 255.
+ D = np.diag(np.random.rand(3))
+ U = orth(np.random.rand(3, 3))
+ conv = np.dot(np.dot(np.transpose(U), D), U)
+ img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_Poisson_noise(img):
+ img = np.clip((img * 255.0).round(), 0, 255) / 255.
+ vals = 10 ** (2 * random.random() + 2.0) # [2, 4]
+ if random.random() < 0.5:
+ img = np.random.poisson(img * vals).astype(np.float32) / vals
+ else:
+ img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
+ img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
+ noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
+ img += noise_gray[:, :, np.newaxis]
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_JPEG_noise(img):
+ quality_factor = random.randint(30, 95)
+ img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
+ result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
+ img = cv2.imdecode(encimg, 1)
+ img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
+ return img
+
+
+def random_crop(lq, hq, sf=4, lq_patchsize=64):
+ h, w = lq.shape[:2]
+ rnd_h = random.randint(0, h - lq_patchsize)
+ rnd_w = random.randint(0, w - lq_patchsize)
+ lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
+
+ rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
+ hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
+ return lq, hq
+
+
+def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
+ """
+ This is the degradation model of BSRGAN from the paper
+ "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+ ----------
+ img: HxWxC, [0, 1]; its size should be larger than (lq_patchsize x sf) x (lq_patchsize x sf)
+ sf: scale factor
+ isp_model: camera ISP model
+ Returns
+ -------
+ img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
+ hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
+ """
+ isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+ sf_ori = sf
+
+ h1, w1 = img.shape[:2]
+ img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
+ h, w = img.shape[:2]
+
+ if h < lq_patchsize * sf or w < lq_patchsize * sf:
+ raise ValueError(f'img size ({h1}X{w1}) is too small!')
+
+ hq = img.copy()
+
+ if sf == 4 and random.random() < scale2_prob: # downsample1
+ if np.random.rand() < 0.5:
+ img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ img = util.imresize_np(img, 1 / 2, True)
+ img = np.clip(img, 0.0, 1.0)
+ sf = 2
+
+ shuffle_order = random.sample(range(7), 7)
+ idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+ if idx1 > idx2: # keep downsample3 last
+ shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+
+ for i in shuffle_order:
+
+ if i == 0:
+ img = add_blur(img, sf=sf)
+
+ elif i == 1:
+ img = add_blur(img, sf=sf)
+
+ elif i == 2:
+ a, b = img.shape[1], img.shape[0]
+ # downsample2
+ if random.random() < 0.75:
+ sf1 = random.uniform(1, 2 * sf)
+ img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+ k_shifted = shift_pixel(k, sf)
+ k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
+ img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
+ img = img[0::sf, 0::sf, ...] # nearest downsampling
+ img = np.clip(img, 0.0, 1.0)
+
+ elif i == 3:
+ # downsample3
+ img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
+ img = np.clip(img, 0.0, 1.0)
+
+ elif i == 4:
+ # add Gaussian noise
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
+
+ elif i == 5:
+ # add JPEG noise
+ if random.random() < jpeg_prob:
+ img = add_JPEG_noise(img)
+
+ elif i == 6:
+ # add processed camera sensor noise
+ if random.random() < isp_prob and isp_model is not None:
+ with torch.no_grad():
+ img, hq = isp_model.forward(img.copy(), hq)
+
+ # add final JPEG compression noise
+ img = add_JPEG_noise(img)
+
+ # random crop
+ img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
+
+ return img, hq
+
+
+# todo no isp_model?
+def degradation_bsrgan_variant(image, sf=4, isp_model=None):
+ """
+ This is the degradation model of BSRGAN from the paper
+ "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+ ----------
+ sf: scale factor
+ isp_model: camera ISP model
+ Returns
+ -------
+ img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
+ hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
+ """
+ image = util.uint2single(image)
+ isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+ sf_ori = sf
+
+ h1, w1 = image.shape[:2]
+ image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
+ h, w = image.shape[:2]
+
+ hq = image.copy()
+
+ if sf == 4 and random.random() < scale2_prob: # downsample1
+ if np.random.rand() < 0.5:
+ image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ image = util.imresize_np(image, 1 / 2, True)
+ image = np.clip(image, 0.0, 1.0)
+ sf = 2
+
+ shuffle_order = random.sample(range(7), 7)
+ idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+ if idx1 > idx2: # keep downsample3 last
+ shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+
+ for i in shuffle_order:
+
+ if i == 0:
+ image = add_blur(image, sf=sf)
+
+ elif i == 1:
+ image = add_blur(image, sf=sf)
+
+ elif i == 2:
+ a, b = image.shape[1], image.shape[0]
+ # downsample2
+ if random.random() < 0.75:
+ sf1 = random.uniform(1, 2 * sf)
+ image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+ k_shifted = shift_pixel(k, sf)
+ k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
+                image = ndimage.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
+ image = image[0::sf, 0::sf, ...] # nearest downsampling
+ image = np.clip(image, 0.0, 1.0)
+
+ elif i == 3:
+ # downsample3
+ image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
+ image = np.clip(image, 0.0, 1.0)
+
+ elif i == 4:
+ # add Gaussian noise
+ image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25)
+
+ elif i == 5:
+ # add JPEG noise
+ if random.random() < jpeg_prob:
+ image = add_JPEG_noise(image)
+
+ # elif i == 6:
+ # # add processed camera sensor noise
+ # if random.random() < isp_prob and isp_model is not None:
+ # with torch.no_grad():
+ # img, hq = isp_model.forward(img.copy(), hq)
+
+ # add final JPEG compression noise
+ image = add_JPEG_noise(image)
+ image = util.single2uint(image)
+ example = {"image":image}
+ return example
+
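+# A minimal usage sketch for degradation_bsrgan_variant (the image path is an assumption; any
+# uint8 HxWxC RGB image works). The function converts to [0, 1] internally and returns a dict
+# whose "image" entry is the degraded uint8 result at roughly 1/sf resolution:
+#
+#   img = util.imread_uint('utils/test.png', 3)
+#   lq = degradation_bsrgan_variant(img, sf=4)["image"]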
+
+# TODO: in case there is a pickle error, replace `a += x` with `a = a + x` in add_speckle_noise etc.
+def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None):
+ """
+ This is an extended degradation model by combining
+ the degradation models of BSRGAN and Real-ESRGAN
+ ----------
+    img: HxWxC, [0, 1], its size should be larger than (lq_patchsize x sf) x (lq_patchsize x sf)
+    sf: scale factor
+    shuffle_prob: probability of shuffling the degradation order
+    use_sharp: whether to sharpen the image before degradation
+ Returns
+ -------
+ img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
+ hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
+ """
+
+ h1, w1 = img.shape[:2]
+    img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]  # mod crop
+ h, w = img.shape[:2]
+
+ if h < lq_patchsize * sf or w < lq_patchsize * sf:
+ raise ValueError(f'img size ({h1}X{w1}) is too small!')
+
+ if use_sharp:
+ img = add_sharpening(img)
+ hq = img.copy()
+
+ if random.random() < shuffle_prob:
+ shuffle_order = random.sample(range(13), 13)
+ else:
+ shuffle_order = list(range(13))
+ # local shuffle for noise, JPEG is always the last one
+ shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6)))
+ shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13)))
+
+ poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1
+
+ for i in shuffle_order:
+ if i == 0:
+ img = add_blur(img, sf=sf)
+ elif i == 1:
+ img = add_resize(img, sf=sf)
+ elif i == 2:
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
+ elif i == 3:
+ if random.random() < poisson_prob:
+ img = add_Poisson_noise(img)
+ elif i == 4:
+ if random.random() < speckle_prob:
+ img = add_speckle_noise(img)
+ elif i == 5:
+ if random.random() < isp_prob and isp_model is not None:
+ with torch.no_grad():
+ img, hq = isp_model.forward(img.copy(), hq)
+ elif i == 6:
+ img = add_JPEG_noise(img)
+ elif i == 7:
+ img = add_blur(img, sf=sf)
+ elif i == 8:
+ img = add_resize(img, sf=sf)
+ elif i == 9:
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
+ elif i == 10:
+ if random.random() < poisson_prob:
+ img = add_Poisson_noise(img)
+ elif i == 11:
+ if random.random() < speckle_prob:
+ img = add_speckle_noise(img)
+ elif i == 12:
+ if random.random() < isp_prob and isp_model is not None:
+ with torch.no_grad():
+ img, hq = isp_model.forward(img.copy(), hq)
+ else:
+ print('check the shuffle!')
+
+ # resize to desired size
+ img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+
+ # add final JPEG compression noise
+ img = add_JPEG_noise(img)
+
+ # random crop
+ img, hq = random_crop(img, hq, sf, lq_patchsize)
+
+ return img, hq
+
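+# A minimal usage sketch for degradation_bsrgan_plus (assumes an HxWxC float image in [0, 1]
+# that is at least lq_patchsize * sf pixels on each side); it returns a random LQ/HQ patch pair:
+#
+#   hr = util.uint2single(util.imread_uint('utils/test.png', 3))
+#   lq, hq = degradation_bsrgan_plus(hr, sf=4, lq_patchsize=64)   # lq: 64x64x3, hq: 256x256x3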
+
+if __name__ == '__main__':
+    print("hey")
+    img = util.imread_uint('utils/test.png', 3)
+    print(img)
+    img = img[:448, :448]
+    h = img.shape[0] // 4
+    print("resizing to", h)
+    sf = 4
+    deg_fn = partial(degradation_bsrgan_variant, sf=sf)
+    for i in range(20):
+        print(i)
+        img_hq = img
+        img_lq = deg_fn(img)["image"]
+        img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq)
+        print(img_lq)
+        img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"]
+        print(img_lq.shape)
+        print("bicubic", img_lq_bicubic.shape)
+        print(img_hq.shape)
+        lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+                                interpolation=0)
+        lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic),
+                                        (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+                                        interpolation=0)
+        img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
+        util.imsave(img_concat, str(i) + '.png')
+
+
diff --git a/ldm/modules/image_degradation/bsrgan_light.py b/ldm/modules/image_degradation/bsrgan_light.py
new file mode 100644
index 0000000000000000000000000000000000000000..808c7f882cb75e2ba2340d5b55881d11927351f0
--- /dev/null
+++ b/ldm/modules/image_degradation/bsrgan_light.py
@@ -0,0 +1,651 @@
+# -*- coding: utf-8 -*-
+import numpy as np
+import cv2
+import torch
+
+from functools import partial
+import random
+from scipy import ndimage
+import scipy
+import scipy.stats as ss
+from scipy.interpolate import interp2d
+from scipy.linalg import orth
+import albumentations
+
+import ldm.modules.image_degradation.utils_image as util
+
+"""
+# --------------------------------------------
+# Super-Resolution
+# --------------------------------------------
+#
+# Kai Zhang (cskaizhang@gmail.com)
+# https://github.com/cszn
+# From 2019/03--2021/08
+# --------------------------------------------
+"""
+
+def modcrop_np(img, sf):
+ '''
+ Args:
+ img: numpy image, WxH or WxHxC
+ sf: scale factor
+ Return:
+ cropped image
+ '''
+ w, h = img.shape[:2]
+ im = np.copy(img)
+ return im[:w - w % sf, :h - h % sf, ...]
+
+
+"""
+# --------------------------------------------
+# anisotropic Gaussian kernels
+# --------------------------------------------
+"""
+
+
+def analytic_kernel(k):
+ """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
+ k_size = k.shape[0]
+    # Calculate the big kernel's size
+ big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
+ # Loop over the small kernel to fill the big one
+ for r in range(k_size):
+ for c in range(k_size):
+ big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
+ # Crop the edges of the big kernel to ignore very small values and increase run time of SR
+ crop = k_size // 2
+ cropped_big_k = big_k[crop:-crop, crop:-crop]
+ # Normalize to 1
+ return cropped_big_k / cropped_big_k.sum()
+
+
+def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
+ """ generate an anisotropic Gaussian kernel
+ Args:
+ ksize : e.g., 15, kernel size
+ theta : [0, pi], rotation angle range
+ l1 : [0.1,50], scaling of eigenvalues
+ l2 : [0.1,l1], scaling of eigenvalues
+ If l1 = l2, will get an isotropic Gaussian kernel.
+ Returns:
+ k : kernel
+ """
+
+ v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
+ V = np.array([[v[0], v[1]], [v[1], -v[0]]])
+ D = np.array([[l1, 0], [0, l2]])
+ Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
+ k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
+
+ return k
+
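+# Example (a sketch, not called by the pipeline directly): sample a rotated anisotropic kernel
+# and apply it to a float image `img` in [0, 1]; the kernel sums to 1 by construction:
+#
+#   k = anisotropic_Gaussian(ksize=15, theta=np.pi / 4, l1=6, l2=2)
+#   blurred = ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror')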
+
+def gm_blur_kernel(mean, cov, size=15):
+ center = size / 2.0 + 0.5
+ k = np.zeros([size, size])
+ for y in range(size):
+ for x in range(size):
+ cy = y - center + 1
+ cx = x - center + 1
+ k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
+
+ k = k / np.sum(k)
+ return k
+
+
+def shift_pixel(x, sf, upper_left=True):
+ """shift pixel for super-resolution with different scale factors
+ Args:
+ x: WxHxC or WxH
+ sf: scale factor
+ upper_left: shift direction
+ """
+ h, w = x.shape[:2]
+ shift = (sf - 1) * 0.5
+ xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
+ if upper_left:
+ x1 = xv + shift
+ y1 = yv + shift
+ else:
+ x1 = xv - shift
+ y1 = yv - shift
+
+ x1 = np.clip(x1, 0, w - 1)
+ y1 = np.clip(y1, 0, h - 1)
+
+ if x.ndim == 2:
+ x = interp2d(xv, yv, x)(x1, y1)
+ if x.ndim == 3:
+ for i in range(x.shape[-1]):
+ x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
+
+ return x
+
+
+def blur(x, k):
+ '''
+ x: image, NxcxHxW
+ k: kernel, Nx1xhxw
+ '''
+ n, c = x.shape[:2]
+ p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
+ x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
+ k = k.repeat(1, c, 1, 1)
+ k = k.view(-1, 1, k.shape[2], k.shape[3])
+ x = x.view(1, -1, x.shape[2], x.shape[3])
+ x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
+ x = x.view(n, c, x.shape[2], x.shape[3])
+
+ return x
+
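+# blur() applies a different kernel to each sample of a batch by folding batch and channels into
+# the groups of one conv2d call. A shape-only sketch with hypothetical tensors:
+#
+#   x = torch.rand(2, 3, 64, 64)   # N x C x H x W
+#   k = torch.rand(2, 1, 7, 7)     # one kernel per sample
+#   y = blur(x, k)                 # -> 2 x 3 x 64 x 64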
+
+def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
+ """"
+ # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
+ # Kai Zhang
+ # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
+ # max_var = 2.5 * sf
+ """
+ # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
+ lambda_1 = min_var + np.random.rand() * (max_var - min_var)
+ lambda_2 = min_var + np.random.rand() * (max_var - min_var)
+ theta = np.random.rand() * np.pi # random theta
+ noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
+
+ # Set COV matrix using Lambdas and Theta
+ LAMBDA = np.diag([lambda_1, lambda_2])
+ Q = np.array([[np.cos(theta), -np.sin(theta)],
+ [np.sin(theta), np.cos(theta)]])
+ SIGMA = Q @ LAMBDA @ Q.T
+ INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
+
+ # Set expectation position (shifting kernel for aligned image)
+ MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
+ MU = MU[None, None, :, None]
+
+ # Create meshgrid for Gaussian
+ [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
+ Z = np.stack([X, Y], 2)[:, :, :, None]
+
+    # Calculate Gaussian for every pixel of the kernel
+ ZZ = Z - MU
+ ZZ_t = ZZ.transpose(0, 1, 3, 2)
+ raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
+
+ # shift the kernel so it will be centered
+ # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
+
+ # Normalize the kernel and return
+ # kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
+ kernel = raw_kernel / np.sum(raw_kernel)
+ return kernel
+
+
+def fspecial_gaussian(hsize, sigma):
+ hsize = [hsize, hsize]
+ siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
+ std = sigma
+ [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
+ arg = -(x * x + y * y) / (2 * std * std)
+ h = np.exp(arg)
+    h[h < np.finfo(float).eps * h.max()] = 0
+ sumh = h.sum()
+ if sumh != 0:
+ h = h / sumh
+ return h
+
+
+def fspecial_laplacian(alpha):
+ alpha = max([0, min([alpha, 1])])
+ h1 = alpha / (alpha + 1)
+ h2 = (1 - alpha) / (alpha + 1)
+ h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
+ h = np.array(h)
+ return h
+
+
+def fspecial(filter_type, *args, **kwargs):
+ '''
+ python code from:
+ https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
+ '''
+ if filter_type == 'gaussian':
+ return fspecial_gaussian(*args, **kwargs)
+ if filter_type == 'laplacian':
+ return fspecial_laplacian(*args, **kwargs)
+
+
+"""
+# --------------------------------------------
+# degradation models
+# --------------------------------------------
+"""
+
+
+def bicubic_degradation(x, sf=3):
+ '''
+ Args:
+ x: HxWxC image, [0, 1]
+ sf: down-scale factor
+ Return:
+ bicubicly downsampled LR image
+ '''
+ x = util.imresize_np(x, scale=1 / sf)
+ return x
+
+
+def srmd_degradation(x, k, sf=3):
+ ''' blur + bicubic downsampling
+ Args:
+ x: HxWxC image, [0, 1]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ Reference:
+ @inproceedings{zhang2018learning,
+ title={Learning a single convolutional super-resolution network for multiple degradations},
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+ pages={3262--3271},
+ year={2018}
+ }
+ '''
+ x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
+ x = bicubic_degradation(x, sf=sf)
+ return x
+
+
+def dpsr_degradation(x, k, sf=3):
+ ''' bicubic downsampling + blur
+ Args:
+ x: HxWxC image, [0, 1]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ Reference:
+ @inproceedings{zhang2019deep,
+ title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+ pages={1671--1681},
+ year={2019}
+ }
+ '''
+ x = bicubic_degradation(x, sf=sf)
+ x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+ return x
+
+
+def classical_degradation(x, k, sf=3):
+ ''' blur + downsampling
+ Args:
+ x: HxWxC image, [0, 1]/[0, 255]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ '''
+ x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+ # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
+ st = 0
+ return x[st::sf, st::sf, ...]
+
+
+def add_sharpening(img, weight=0.5, radius=50, threshold=10):
+ """USM sharpening. borrowed from real-ESRGAN
+ Input image: I; Blurry image: B.
+ 1. K = I + weight * (I - B)
+ 2. Mask = 1 if abs(I - B) > threshold, else: 0
+ 3. Blur mask:
+ 4. Out = Mask * K + (1 - Mask) * I
+ Args:
+ img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
+        weight (float): Sharp weight. Default: 0.5.
+        radius (float): Kernel size of Gaussian blur. Default: 50.
+        threshold (int): Mask threshold on the residual |I - B| (0-255 scale). Default: 10.
+ """
+ if radius % 2 == 0:
+ radius += 1
+ blur = cv2.GaussianBlur(img, (radius, radius), 0)
+ residual = img - blur
+ mask = np.abs(residual) * 255 > threshold
+ mask = mask.astype('float32')
+ soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
+
+ K = img + weight * residual
+ K = np.clip(K, 0, 1)
+ return soft_mask * K + (1 - soft_mask) * img
+
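+# Usage sketch for the USM sharpening above (assumes `img` is a float32 HxWxC image in [0, 1],
+# e.g. obtained via util.uint2single):
+#
+#   sharpened = add_sharpening(img, weight=0.5, radius=50, threshold=10)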
+
+def add_blur(img, sf=4):
+ wd2 = 4.0 + sf
+ wd = 2.0 + 0.2 * sf
+
+ wd2 = wd2/4
+ wd = wd/4
+
+ if random.random() < 0.5:
+ l1 = wd2 * random.random()
+ l2 = wd2 * random.random()
+ k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
+ else:
+ k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random())
+ img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
+
+ return img
+
+
+def add_resize(img, sf=4):
+ rnum = np.random.rand()
+ if rnum > 0.8: # up
+ sf1 = random.uniform(1, 2)
+ elif rnum < 0.7: # down
+ sf1 = random.uniform(0.5 / sf, 1)
+ else:
+ sf1 = 1.0
+ img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
+ img = np.clip(img, 0.0, 1.0)
+
+ return img
+
+
+# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+# noise_level = random.randint(noise_level1, noise_level2)
+# rnum = np.random.rand()
+# if rnum > 0.6: # add color Gaussian noise
+# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+# elif rnum < 0.4: # add grayscale Gaussian noise
+# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+# else: # add noise
+# L = noise_level2 / 255.
+# D = np.diag(np.random.rand(3))
+# U = orth(np.random.rand(3, 3))
+# conv = np.dot(np.dot(np.transpose(U), D), U)
+# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+# img = np.clip(img, 0.0, 1.0)
+# return img
+
+def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+ noise_level = random.randint(noise_level1, noise_level2)
+ rnum = np.random.rand()
+ if rnum > 0.6: # add color Gaussian noise
+ img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+ elif rnum < 0.4: # add grayscale Gaussian noise
+ img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+ else: # add noise
+ L = noise_level2 / 255.
+ D = np.diag(np.random.rand(3))
+ U = orth(np.random.rand(3, 3))
+ conv = np.dot(np.dot(np.transpose(U), D), U)
+ img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_speckle_noise(img, noise_level1=2, noise_level2=25):
+ noise_level = random.randint(noise_level1, noise_level2)
+ img = np.clip(img, 0.0, 1.0)
+ rnum = random.random()
+ if rnum > 0.6:
+ img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+ elif rnum < 0.4:
+ img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+ else:
+ L = noise_level2 / 255.
+ D = np.diag(np.random.rand(3))
+ U = orth(np.random.rand(3, 3))
+ conv = np.dot(np.dot(np.transpose(U), D), U)
+ img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_Poisson_noise(img):
+ img = np.clip((img * 255.0).round(), 0, 255) / 255.
+ vals = 10 ** (2 * random.random() + 2.0) # [2, 4]
+ if random.random() < 0.5:
+ img = np.random.poisson(img * vals).astype(np.float32) / vals
+ else:
+ img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
+ img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
+ noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
+ img += noise_gray[:, :, np.newaxis]
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_JPEG_noise(img):
+ quality_factor = random.randint(80, 95)
+ img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
+ result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
+ img = cv2.imdecode(encimg, 1)
+ img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
+ return img
+
+
+def random_crop(lq, hq, sf=4, lq_patchsize=64):
+ h, w = lq.shape[:2]
+ rnd_h = random.randint(0, h - lq_patchsize)
+ rnd_w = random.randint(0, w - lq_patchsize)
+ lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
+
+ rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
+ hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
+ return lq, hq
+
+
+def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
+ """
+ This is the degradation model of BSRGAN from the paper
+ "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+ ----------
+    img: HxWxC, [0, 1], its size should be larger than (lq_patchsize x sf) x (lq_patchsize x sf)
+ sf: scale factor
+ isp_model: camera ISP model
+ Returns
+ -------
+ img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
+ hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
+ """
+ isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+ sf_ori = sf
+
+ h1, w1 = img.shape[:2]
+    img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]  # mod crop
+ h, w = img.shape[:2]
+
+ if h < lq_patchsize * sf or w < lq_patchsize * sf:
+ raise ValueError(f'img size ({h1}X{w1}) is too small!')
+
+ hq = img.copy()
+
+ if sf == 4 and random.random() < scale2_prob: # downsample1
+ if np.random.rand() < 0.5:
+ img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ img = util.imresize_np(img, 1 / 2, True)
+ img = np.clip(img, 0.0, 1.0)
+ sf = 2
+
+ shuffle_order = random.sample(range(7), 7)
+ idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+ if idx1 > idx2: # keep downsample3 last
+ shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+
+ for i in shuffle_order:
+
+ if i == 0:
+ img = add_blur(img, sf=sf)
+
+ elif i == 1:
+ img = add_blur(img, sf=sf)
+
+ elif i == 2:
+ a, b = img.shape[1], img.shape[0]
+ # downsample2
+ if random.random() < 0.75:
+ sf1 = random.uniform(1, 2 * sf)
+ img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+ k_shifted = shift_pixel(k, sf)
+ k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
+ img = ndimage.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
+ img = img[0::sf, 0::sf, ...] # nearest downsampling
+ img = np.clip(img, 0.0, 1.0)
+
+ elif i == 3:
+ # downsample3
+ img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
+ img = np.clip(img, 0.0, 1.0)
+
+ elif i == 4:
+ # add Gaussian noise
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8)
+
+ elif i == 5:
+ # add JPEG noise
+ if random.random() < jpeg_prob:
+ img = add_JPEG_noise(img)
+
+ elif i == 6:
+ # add processed camera sensor noise
+ if random.random() < isp_prob and isp_model is not None:
+ with torch.no_grad():
+ img, hq = isp_model.forward(img.copy(), hq)
+
+ # add final JPEG compression noise
+ img = add_JPEG_noise(img)
+
+ # random crop
+ img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
+
+ return img, hq
+
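+# Usage sketch for the patch-pair interface (the path is an assumption; the input must be a float
+# HxWxC image in [0, 1] with both sides >= lq_patchsize * sf):
+#
+#   hr = util.uint2single(util.imread_uint('utils/test.png', 3))
+#   lq, hq = degradation_bsrgan(hr, sf=4, lq_patchsize=72)   # lq: 72x72x3, hq: 288x288x3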
+
+# todo no isp_model?
+def degradation_bsrgan_variant(image, sf=4, isp_model=None, up=False):
+ """
+ This is the degradation model of BSRGAN from the paper
+ "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+    ----------
+    image: HxWxC uint8 image
+    sf: scale factor
+    isp_model: camera ISP model
+    up: if True, resize the degraded image back to the input resolution
+    Returns
+    -------
+    example: dict with key "image" holding the degraded image (uint8)
+ """
+ image = util.uint2single(image)
+ isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+ sf_ori = sf
+
+ h1, w1 = image.shape[:2]
+    image = image.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]  # mod crop
+ h, w = image.shape[:2]
+
+ hq = image.copy()
+
+ if sf == 4 and random.random() < scale2_prob: # downsample1
+ if np.random.rand() < 0.5:
+ image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ image = util.imresize_np(image, 1 / 2, True)
+ image = np.clip(image, 0.0, 1.0)
+ sf = 2
+
+ shuffle_order = random.sample(range(7), 7)
+ idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+ if idx1 > idx2: # keep downsample3 last
+ shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+
+ for i in shuffle_order:
+
+ if i == 0:
+ image = add_blur(image, sf=sf)
+
+        elif i == 1:
+            # the second blur stage of the full BSRGAN pipeline is disabled in this light variant
+            pass
+
+ elif i == 2:
+ a, b = image.shape[1], image.shape[0]
+ # downsample2
+ if random.random() < 0.8:
+ sf1 = random.uniform(1, 2 * sf)
+ image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+ k_shifted = shift_pixel(k, sf)
+ k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
+ image = ndimage.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
+ image = image[0::sf, 0::sf, ...] # nearest downsampling
+
+ image = np.clip(image, 0.0, 1.0)
+
+ elif i == 3:
+ # downsample3
+ image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
+ image = np.clip(image, 0.0, 1.0)
+
+ elif i == 4:
+ # add Gaussian noise
+ image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2)
+
+ elif i == 5:
+ # add JPEG noise
+ if random.random() < jpeg_prob:
+ image = add_JPEG_noise(image)
+ #
+ # elif i == 6:
+ # # add processed camera sensor noise
+ # if random.random() < isp_prob and isp_model is not None:
+ # with torch.no_grad():
+ # img, hq = isp_model.forward(img.copy(), hq)
+
+ # add final JPEG compression noise
+ image = add_JPEG_noise(image)
+ image = util.single2uint(image)
+ if up:
+ image = cv2.resize(image, (w1, h1), interpolation=cv2.INTER_CUBIC) # todo: random, as above? want to condition on it then
+ example = {"image": image}
+ return example
+
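+# Usage sketch for the light variant (path is an assumption; the __main__ block below does the
+# same with extra debug output). With up=True the degraded image is resized back to the input
+# resolution, which is convenient when conditioning on a fixed-size image:
+#
+#   img = util.imread_uint('utils/test.png', 3)
+#   lq = degradation_bsrgan_variant(img, sf=4, up=True)["image"]   # uint8, same HxW as img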
+
+
+
+if __name__ == '__main__':
+ print("hey")
+ img = util.imread_uint('utils/test.png', 3)
+ img = img[:448, :448]
+ h = img.shape[0] // 4
+ print("resizing to", h)
+ sf = 4
+ deg_fn = partial(degradation_bsrgan_variant, sf=sf)
+ for i in range(20):
+ print(i)
+ img_hq = img
+ img_lq = deg_fn(img)["image"]
+ img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq)
+ print(img_lq)
+ img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"]
+ print(img_lq.shape)
+ print("bicubic", img_lq_bicubic.shape)
+ print(img_hq.shape)
+ lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+ interpolation=0)
+ lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic),
+ (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+ interpolation=0)
+ img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
+ util.imsave(img_concat, str(i) + '.png')
diff --git a/ldm/modules/image_degradation/utils/test.png b/ldm/modules/image_degradation/utils/test.png
new file mode 100644
index 0000000000000000000000000000000000000000..4249b43de0f22707758d13c240268a401642f6e6
Binary files /dev/null and b/ldm/modules/image_degradation/utils/test.png differ
diff --git a/ldm/modules/image_degradation/utils_image.py b/ldm/modules/image_degradation/utils_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..0175f155ad900ae33c3c46ed87f49b352e3faf98
--- /dev/null
+++ b/ldm/modules/image_degradation/utils_image.py
@@ -0,0 +1,916 @@
+import os
+import math
+import random
+import numpy as np
+import torch
+import cv2
+from torchvision.utils import make_grid
+from datetime import datetime
+#import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py
+
+
+os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
+
+
+'''
+# --------------------------------------------
+# Kai Zhang (github: https://github.com/cszn)
+# 03/Mar/2019
+# --------------------------------------------
+# https://github.com/twhui/SRGAN-pyTorch
+# https://github.com/xinntao/BasicSR
+# --------------------------------------------
+'''
+
+
+IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif']
+
+
+def is_image_file(filename):
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+
+def get_timestamp():
+ return datetime.now().strftime('%y%m%d-%H%M%S')
+
+
+def imshow(x, title=None, cbar=False, figsize=None):
+    import matplotlib.pyplot as plt  # local import: only these debug helpers need matplotlib
+    plt.figure(figsize=figsize)
+ plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray')
+ if title:
+ plt.title(title)
+ if cbar:
+ plt.colorbar()
+ plt.show()
+
+
+def surf(Z, cmap='rainbow', figsize=None):
+    import matplotlib.pyplot as plt  # local import: only these debug helpers need matplotlib
+    plt.figure(figsize=figsize)
+ ax3 = plt.axes(projection='3d')
+
+ w, h = Z.shape[:2]
+ xx = np.arange(0,w,1)
+ yy = np.arange(0,h,1)
+ X, Y = np.meshgrid(xx, yy)
+ ax3.plot_surface(X,Y,Z,cmap=cmap)
+ #ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap)
+ plt.show()
+
+
+'''
+# --------------------------------------------
+# get image paths
+# --------------------------------------------
+'''
+
+
+def get_image_paths(dataroot):
+ paths = None # return None if dataroot is None
+ if dataroot is not None:
+ paths = sorted(_get_paths_from_images(dataroot))
+ return paths
+
+
+def _get_paths_from_images(path):
+ assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
+ images = []
+ for dirpath, _, fnames in sorted(os.walk(path)):
+ for fname in sorted(fnames):
+ if is_image_file(fname):
+ img_path = os.path.join(dirpath, fname)
+ images.append(img_path)
+ assert images, '{:s} has no valid image file'.format(path)
+ return images
+
+
+'''
+# --------------------------------------------
+# split large images into small images
+# --------------------------------------------
+'''
+
+
+def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
+ w, h = img.shape[:2]
+ patches = []
+ if w > p_max and h > p_max:
+        w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=int))
+        h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=int))
+ w1.append(w-p_size)
+ h1.append(h-p_size)
+# print(w1)
+# print(h1)
+ for i in w1:
+ for j in h1:
+ patches.append(img[i:i+p_size, j:j+p_size,:])
+ else:
+ patches.append(img)
+
+ return patches
+
+
+def imssave(imgs, img_path):
+ """
+ imgs: list, N images of size WxHxC
+ """
+ img_name, ext = os.path.splitext(os.path.basename(img_path))
+
+ for i, img in enumerate(imgs):
+ if img.ndim == 3:
+ img = img[:, :, [2, 1, 0]]
+ new_path = os.path.join(os.path.dirname(img_path), img_name+str('_s{:04d}'.format(i))+'.png')
+ cv2.imwrite(new_path, img)
+
+
+def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800, p_overlap=96, p_max=1000):
+ """
+    split the large images from original_dataroot into small overlapped images of size (p_size)x(p_size),
+    and save them into taget_dataroot; only images larger than (p_max)x(p_max) will be split.
+    Args:
+        original_dataroot: source image folder
+        taget_dataroot: destination folder for the patches
+        p_size: size of the small images
+        p_overlap: overlap between patches (the training patch size is a good choice)
+        p_max: images smaller than (p_max)x(p_max) are kept unchanged.
+ """
+ paths = get_image_paths(original_dataroot)
+ for img_path in paths:
+ # img_name, ext = os.path.splitext(os.path.basename(img_path))
+ img = imread_uint(img_path, n_channels=n_channels)
+ patches = patches_from_image(img, p_size, p_overlap, p_max)
+ imssave(patches, os.path.join(taget_dataroot,os.path.basename(img_path)))
+ #if original_dataroot == taget_dataroot:
+ #del img_path
+
+'''
+# --------------------------------------------
+# makedir
+# --------------------------------------------
+'''
+
+
+def mkdir(path):
+ if not os.path.exists(path):
+ os.makedirs(path)
+
+
+def mkdirs(paths):
+ if isinstance(paths, str):
+ mkdir(paths)
+ else:
+ for path in paths:
+ mkdir(path)
+
+
+def mkdir_and_rename(path):
+ if os.path.exists(path):
+ new_name = path + '_archived_' + get_timestamp()
+ print('Path already exists. Rename it to [{:s}]'.format(new_name))
+ os.rename(path, new_name)
+ os.makedirs(path)
+
+
+'''
+# --------------------------------------------
+# read image from path
+# opencv is fast, but read BGR numpy image
+# --------------------------------------------
+'''
+
+
+# --------------------------------------------
+# get uint8 image of size HxWxn_channels (RGB)
+# --------------------------------------------
+def imread_uint(path, n_channels=3):
+ # input: path
+ # output: HxWx3(RGB or GGG), or HxWx1 (G)
+ if n_channels == 1:
+ img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE
+ img = np.expand_dims(img, axis=2) # HxWx1
+ elif n_channels == 3:
+ img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G
+ if img.ndim == 2:
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG
+ else:
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB
+ return img
+
+
+# --------------------------------------------
+# matlab's imwrite
+# --------------------------------------------
+def imsave(img, img_path):
+ img = np.squeeze(img)
+ if img.ndim == 3:
+ img = img[:, :, [2, 1, 0]]
+ cv2.imwrite(img_path, img)
+
+def imwrite(img, img_path):
+ img = np.squeeze(img)
+ if img.ndim == 3:
+ img = img[:, :, [2, 1, 0]]
+ cv2.imwrite(img_path, img)
+
+
+
+# --------------------------------------------
+# get single image of size HxWxn_channels (BGR)
+# --------------------------------------------
+def read_img(path):
+ # read image by cv2
+ # return: Numpy float32, HWC, BGR, [0,1]
+ img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE
+ img = img.astype(np.float32) / 255.
+ if img.ndim == 2:
+ img = np.expand_dims(img, axis=2)
+ # some images have 4 channels
+ if img.shape[2] > 3:
+ img = img[:, :, :3]
+ return img
+
+
+'''
+# --------------------------------------------
+# image format conversion
+# --------------------------------------------
+# numpy(single) <---> numpy(uint)
+# numpy(single) <---> tensor
+# numpy(uint) <---> tensor
+# --------------------------------------------
+'''
+
+
+# --------------------------------------------
+# numpy(single) [0, 1] <---> numpy(uint)
+# --------------------------------------------
+
+
+def uint2single(img):
+
+ return np.float32(img/255.)
+
+
+def single2uint(img):
+
+ return np.uint8((img.clip(0, 1)*255.).round())
+
+
+def uint162single(img):
+
+ return np.float32(img/65535.)
+
+
+def single2uint16(img):
+
+ return np.uint16((img.clip(0, 1)*65535.).round())
+
+
+# --------------------------------------------
+# numpy(uint) (HxWxC or HxW) <---> tensor
+# --------------------------------------------
+
+
+# convert uint to 4-dimensional torch tensor
+def uint2tensor4(img):
+ if img.ndim == 2:
+ img = np.expand_dims(img, axis=2)
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0)
+
+
+# convert uint to 3-dimensional torch tensor
+def uint2tensor3(img):
+ if img.ndim == 2:
+ img = np.expand_dims(img, axis=2)
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.)
+
+
+# convert 2/3/4-dimensional torch tensor to uint
+def tensor2uint(img):
+ img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy()
+ if img.ndim == 3:
+ img = np.transpose(img, (1, 2, 0))
+ return np.uint8((img*255.0).round())
+
+
+# --------------------------------------------
+# numpy(single) (HxWxC) <---> tensor
+# --------------------------------------------
+
+
+# convert single (HxWxC) to 3-dimensional torch tensor
+def single2tensor3(img):
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float()
+
+
+# convert single (HxWxC) to 4-dimensional torch tensor
+def single2tensor4(img):
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0)
+
+
+# convert torch tensor to single
+def tensor2single(img):
+ img = img.data.squeeze().float().cpu().numpy()
+ if img.ndim == 3:
+ img = np.transpose(img, (1, 2, 0))
+
+ return img
+
+# convert torch tensor to single
+def tensor2single3(img):
+ img = img.data.squeeze().float().cpu().numpy()
+ if img.ndim == 3:
+ img = np.transpose(img, (1, 2, 0))
+ elif img.ndim == 2:
+ img = np.expand_dims(img, axis=2)
+ return img
+
+
+def single2tensor5(img):
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0)
+
+
+def single32tensor5(img):
+ return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0)
+
+
+def single42tensor4(img):
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float()
+
+
+# from skimage.io import imread, imsave
+def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
+ '''
+ Converts a torch Tensor into an image Numpy array of BGR channel order
+ Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
+ Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
+ '''
+ tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp
+ tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1]
+ n_dim = tensor.dim()
+ if n_dim == 4:
+ n_img = len(tensor)
+ img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
+ img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
+ elif n_dim == 3:
+ img_np = tensor.numpy()
+ img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
+ elif n_dim == 2:
+ img_np = tensor.numpy()
+ else:
+ raise TypeError(
+ 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
+ if out_type == np.uint8:
+ img_np = (img_np * 255.0).round()
+        # Important. Unlike matlab, numpy.uint8() WILL NOT round by default.
+ return img_np.astype(out_type)
+
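+# Round-trip sketch for the conversion helpers above (hypothetical uint8 RGB input):
+#
+#   img = imread_uint('utils/test.png', 3)   # HxWx3 uint8
+#   t = uint2tensor4(img)                    # 1x3xHxW float in [0, 1]
+#   back = tensor2uint(t)                    # HxWx3 uint8 again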
+
+'''
+# --------------------------------------------
+# Augmentation, flip and/or rotate
+# --------------------------------------------
+# The following two are enough.
+# (1) augment_img: numpy image of WxHxC or WxH
+# (2) augment_img_tensor4: tensor image 1xCxWxH
+# --------------------------------------------
+'''
+
+
+def augment_img(img, mode=0):
+ '''Kai Zhang (github: https://github.com/cszn)
+ '''
+ if mode == 0:
+ return img
+ elif mode == 1:
+ return np.flipud(np.rot90(img))
+ elif mode == 2:
+ return np.flipud(img)
+ elif mode == 3:
+ return np.rot90(img, k=3)
+ elif mode == 4:
+ return np.flipud(np.rot90(img, k=2))
+ elif mode == 5:
+ return np.rot90(img)
+ elif mode == 6:
+ return np.rot90(img, k=2)
+ elif mode == 7:
+ return np.flipud(np.rot90(img, k=3))
+
+
+def augment_img_tensor4(img, mode=0):
+ '''Kai Zhang (github: https://github.com/cszn)
+ '''
+ if mode == 0:
+ return img
+ elif mode == 1:
+ return img.rot90(1, [2, 3]).flip([2])
+ elif mode == 2:
+ return img.flip([2])
+ elif mode == 3:
+ return img.rot90(3, [2, 3])
+ elif mode == 4:
+ return img.rot90(2, [2, 3]).flip([2])
+ elif mode == 5:
+ return img.rot90(1, [2, 3])
+ elif mode == 6:
+ return img.rot90(2, [2, 3])
+ elif mode == 7:
+ return img.rot90(3, [2, 3]).flip([2])
+
+
+def augment_img_tensor(img, mode=0):
+ '''Kai Zhang (github: https://github.com/cszn)
+ '''
+ img_size = img.size()
+ img_np = img.data.cpu().numpy()
+ if len(img_size) == 3:
+ img_np = np.transpose(img_np, (1, 2, 0))
+ elif len(img_size) == 4:
+ img_np = np.transpose(img_np, (2, 3, 1, 0))
+ img_np = augment_img(img_np, mode=mode)
+ img_tensor = torch.from_numpy(np.ascontiguousarray(img_np))
+ if len(img_size) == 3:
+ img_tensor = img_tensor.permute(2, 0, 1)
+ elif len(img_size) == 4:
+ img_tensor = img_tensor.permute(3, 2, 0, 1)
+
+ return img_tensor.type_as(img)
+
+
+def augment_img_np3(img, mode=0):
+ if mode == 0:
+ return img
+ elif mode == 1:
+ return img.transpose(1, 0, 2)
+ elif mode == 2:
+ return img[::-1, :, :]
+ elif mode == 3:
+ img = img[::-1, :, :]
+ img = img.transpose(1, 0, 2)
+ return img
+ elif mode == 4:
+ return img[:, ::-1, :]
+ elif mode == 5:
+ img = img[:, ::-1, :]
+ img = img.transpose(1, 0, 2)
+ return img
+ elif mode == 6:
+ img = img[:, ::-1, :]
+ img = img[::-1, :, :]
+ return img
+ elif mode == 7:
+ img = img[:, ::-1, :]
+ img = img[::-1, :, :]
+ img = img.transpose(1, 0, 2)
+ return img
+
+
+def augment_imgs(img_list, hflip=True, rot=True):
+ # horizontal flip OR rotate
+ hflip = hflip and random.random() < 0.5
+ vflip = rot and random.random() < 0.5
+ rot90 = rot and random.random() < 0.5
+
+ def _augment(img):
+ if hflip:
+ img = img[:, ::-1, :]
+ if vflip:
+ img = img[::-1, :, :]
+ if rot90:
+ img = img.transpose(1, 0, 2)
+ return img
+
+ return [_augment(img) for img in img_list]
+
+
+'''
+# --------------------------------------------
+# modcrop and shave
+# --------------------------------------------
+'''
+
+
+def modcrop(img_in, scale):
+ # img_in: Numpy, HWC or HW
+ img = np.copy(img_in)
+ if img.ndim == 2:
+ H, W = img.shape
+ H_r, W_r = H % scale, W % scale
+ img = img[:H - H_r, :W - W_r]
+ elif img.ndim == 3:
+ H, W, C = img.shape
+ H_r, W_r = H % scale, W % scale
+ img = img[:H - H_r, :W - W_r, :]
+ else:
+ raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
+ return img
+
+
+def shave(img_in, border=0):
+ # img_in: Numpy, HWC or HW
+ img = np.copy(img_in)
+ h, w = img.shape[:2]
+ img = img[border:h-border, border:w-border]
+ return img
+
+
+'''
+# --------------------------------------------
+# image processing process on numpy image
+# channel_convert(in_c, tar_type, img_list):
+# rgb2ycbcr(img, only_y=True):
+# bgr2ycbcr(img, only_y=True):
+# ycbcr2rgb(img):
+# --------------------------------------------
+'''
+
+
+def rgb2ycbcr(img, only_y=True):
+ '''same as matlab rgb2ycbcr
+ only_y: only return Y channel
+ Input:
+ uint8, [0, 255]
+ float, [0, 1]
+ '''
+ in_img_type = img.dtype
+    img = img.astype(np.float32)
+ if in_img_type != np.uint8:
+ img *= 255.
+ # convert
+ if only_y:
+ rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
+ else:
+ rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
+ [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
+ if in_img_type == np.uint8:
+ rlt = rlt.round()
+ else:
+ rlt /= 255.
+ return rlt.astype(in_img_type)
+
+
+def ycbcr2rgb(img):
+ '''same as matlab ycbcr2rgb
+ Input:
+ uint8, [0, 255]
+ float, [0, 1]
+ '''
+ in_img_type = img.dtype
+    img = img.astype(np.float32)
+ if in_img_type != np.uint8:
+ img *= 255.
+ # convert
+ rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
+ [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
+ if in_img_type == np.uint8:
+ rlt = rlt.round()
+ else:
+ rlt /= 255.
+ return rlt.astype(in_img_type)
+
+
+def bgr2ycbcr(img, only_y=True):
+ '''bgr version of rgb2ycbcr
+ only_y: only return Y channel
+ Input:
+ uint8, [0, 255]
+ float, [0, 1]
+ '''
+ in_img_type = img.dtype
+    img = img.astype(np.float32)
+ if in_img_type != np.uint8:
+ img *= 255.
+ # convert
+ if only_y:
+ rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
+ else:
+ rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
+ [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
+ if in_img_type == np.uint8:
+ rlt = rlt.round()
+ else:
+ rlt /= 255.
+ return rlt.astype(in_img_type)
+
+
+def channel_convert(in_c, tar_type, img_list):
+ # conversion among BGR, gray and y
+ if in_c == 3 and tar_type == 'gray': # BGR to gray
+ gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
+ return [np.expand_dims(img, axis=2) for img in gray_list]
+ elif in_c == 3 and tar_type == 'y': # BGR to y
+ y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]
+ return [np.expand_dims(img, axis=2) for img in y_list]
+ elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR
+ return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
+ else:
+ return img_list
+
+
+'''
+# --------------------------------------------
+# metric, PSNR and SSIM
+# --------------------------------------------
+'''
+
+
+# --------------------------------------------
+# PSNR
+# --------------------------------------------
+def calculate_psnr(img1, img2, border=0):
+ # img1 and img2 have range [0, 255]
+ #img1 = img1.squeeze()
+ #img2 = img2.squeeze()
+ if not img1.shape == img2.shape:
+ raise ValueError('Input images must have the same dimensions.')
+ h, w = img1.shape[:2]
+ img1 = img1[border:h-border, border:w-border]
+ img2 = img2[border:h-border, border:w-border]
+
+ img1 = img1.astype(np.float64)
+ img2 = img2.astype(np.float64)
+ mse = np.mean((img1 - img2)**2)
+ if mse == 0:
+ return float('inf')
+ return 20 * math.log10(255.0 / math.sqrt(mse))
+
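+# Quick sanity check (a sketch): two uint8 images that differ by 1 gray level everywhere have
+# MSE = 1, so PSNR = 20 * log10(255 / 1) ~ 48.13 dB:
+#
+#   a = np.full((8, 8), 100, dtype=np.uint8)
+#   b = np.full((8, 8), 101, dtype=np.uint8)
+#   calculate_psnr(a, b)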
+
+# --------------------------------------------
+# SSIM
+# --------------------------------------------
+def calculate_ssim(img1, img2, border=0):
+ '''calculate SSIM
+ the same outputs as MATLAB's
+ img1, img2: [0, 255]
+ '''
+ #img1 = img1.squeeze()
+ #img2 = img2.squeeze()
+ if not img1.shape == img2.shape:
+ raise ValueError('Input images must have the same dimensions.')
+ h, w = img1.shape[:2]
+ img1 = img1[border:h-border, border:w-border]
+ img2 = img2[border:h-border, border:w-border]
+
+ if img1.ndim == 2:
+ return ssim(img1, img2)
+ elif img1.ndim == 3:
+ if img1.shape[2] == 3:
+ ssims = []
+ for i in range(3):
+ ssims.append(ssim(img1[:,:,i], img2[:,:,i]))
+ return np.array(ssims).mean()
+ elif img1.shape[2] == 1:
+ return ssim(np.squeeze(img1), np.squeeze(img2))
+ else:
+ raise ValueError('Wrong input image dimensions.')
+
+
+def ssim(img1, img2):
+ C1 = (0.01 * 255)**2
+ C2 = (0.03 * 255)**2
+
+ img1 = img1.astype(np.float64)
+ img2 = img2.astype(np.float64)
+ kernel = cv2.getGaussianKernel(11, 1.5)
+ window = np.outer(kernel, kernel.transpose())
+
+ mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
+ mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
+ mu1_sq = mu1**2
+ mu2_sq = mu2**2
+ mu1_mu2 = mu1 * mu2
+ sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
+ sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
+ sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
+
+ ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
+ (sigma1_sq + sigma2_sq + C2))
+ return ssim_map.mean()
+
+
+'''
+# --------------------------------------------
+# matlab's bicubic imresize (numpy and torch) [0, 1]
+# --------------------------------------------
+'''
+
+
+# matlab 'imresize' function, now only support 'bicubic'
+def cubic(x):
+ absx = torch.abs(x)
+ absx2 = absx**2
+ absx3 = absx**3
+ return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \
+ (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx))
+
+
+def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
+ if (scale < 1) and (antialiasing):
+        # Use a modified kernel to simultaneously interpolate and antialias; this needs a larger kernel width
+ kernel_width = kernel_width / scale
+
+ # Output-space coordinates
+ x = torch.linspace(1, out_length, out_length)
+
+ # Input-space coordinates. Calculate the inverse mapping such that 0.5
+ # in output space maps to 0.5 in input space, and 0.5+scale in output
+ # space maps to 1.5 in input space.
+ u = x / scale + 0.5 * (1 - 1 / scale)
+
+ # What is the left-most pixel that can be involved in the computation?
+ left = torch.floor(u - kernel_width / 2)
+
+ # What is the maximum number of pixels that can be involved in the
+ # computation? Note: it's OK to use an extra pixel here; if the
+ # corresponding weights are all zero, it will be eliminated at the end
+ # of this function.
+ P = math.ceil(kernel_width) + 2
+
+ # The indices of the input pixels involved in computing the k-th output
+ # pixel are in row k of the indices matrix.
+ indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(
+ 1, P).expand(out_length, P)
+
+ # The weights used to compute the k-th output pixel are in row k of the
+ # weights matrix.
+ distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
+ # apply cubic kernel
+ if (scale < 1) and (antialiasing):
+ weights = scale * cubic(distance_to_center * scale)
+ else:
+ weights = cubic(distance_to_center)
+ # Normalize the weights matrix so that each row sums to 1.
+ weights_sum = torch.sum(weights, 1).view(out_length, 1)
+ weights = weights / weights_sum.expand(out_length, P)
+
+ # If a column in weights is all zero, get rid of it. only consider the first and last column.
+ weights_zero_tmp = torch.sum((weights == 0), 0)
+ if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
+ indices = indices.narrow(1, 1, P - 2)
+ weights = weights.narrow(1, 1, P - 2)
+ if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
+ indices = indices.narrow(1, 0, P - 2)
+ weights = weights.narrow(1, 0, P - 2)
+ weights = weights.contiguous()
+ indices = indices.contiguous()
+ sym_len_s = -indices.min() + 1
+ sym_len_e = indices.max() - in_length
+ indices = indices + sym_len_s - 1
+ return weights, indices, int(sym_len_s), int(sym_len_e)
+
+
+# --------------------------------------------
+# imresize for tensor image [0, 1]
+# --------------------------------------------
+def imresize(img, scale, antialiasing=True):
+ # Now the scale should be the same for H and W
+ # input: img: pytorch tensor, CHW or HW [0,1]
+ # output: CHW or HW [0,1] w/o round
+ need_squeeze = True if img.dim() == 2 else False
+ if need_squeeze:
+ img.unsqueeze_(0)
+ in_C, in_H, in_W = img.size()
+ out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
+ kernel_width = 4
+ kernel = 'cubic'
+
+ # Return the desired dimension order for performing the resize. The
+ # strategy is to perform the resize first along the dimension with the
+ # smallest scale factor.
+ # Now we do not support this.
+
+ # get weights and indices
+ weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
+ in_H, out_H, scale, kernel, kernel_width, antialiasing)
+ weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
+ in_W, out_W, scale, kernel, kernel_width, antialiasing)
+ # process H dimension
+ # symmetric copying
+ img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
+ img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)
+
+ sym_patch = img[:, :sym_len_Hs, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)
+
+ sym_patch = img[:, -sym_len_He:, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
+
+ out_1 = torch.FloatTensor(in_C, out_H, in_W)
+ kernel_width = weights_H.size(1)
+ for i in range(out_H):
+ idx = int(indices_H[i][0])
+ for j in range(out_C):
+ out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
+
+ # process W dimension
+ # symmetric copying
+ out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
+ out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)
+
+ sym_patch = out_1[:, :, :sym_len_Ws]
+ inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(2, inv_idx)
+ out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)
+
+ sym_patch = out_1[:, :, -sym_len_We:]
+ inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(2, inv_idx)
+ out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
+
+ out_2 = torch.FloatTensor(in_C, out_H, out_W)
+ kernel_width = weights_W.size(1)
+ for i in range(out_W):
+ idx = int(indices_W[i][0])
+ for j in range(out_C):
+ out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i])
+ if need_squeeze:
+ out_2.squeeze_()
+ return out_2
+
+
+# --------------------------------------------
+# imresize for numpy image [0, 1]
+# --------------------------------------------
+def imresize_np(img, scale, antialiasing=True):
+ # Now the scale should be the same for H and W
+ # input: img: Numpy, HWC or HW [0,1]
+ # output: HWC or HW [0,1] w/o round
+ img = torch.from_numpy(img)
+ need_squeeze = True if img.dim() == 2 else False
+ if need_squeeze:
+ img.unsqueeze_(2)
+
+ in_H, in_W, in_C = img.size()
+ out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
+ kernel_width = 4
+ kernel = 'cubic'
+
+ # Return the desired dimension order for performing the resize. The
+ # strategy is to perform the resize first along the dimension with the
+ # smallest scale factor.
+ # Now we do not support this.
+
+ # get weights and indices
+ weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
+ in_H, out_H, scale, kernel, kernel_width, antialiasing)
+ weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
+ in_W, out_W, scale, kernel, kernel_width, antialiasing)
+ # process H dimension
+ # symmetric copying
+ img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
+ img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)
+
+ sym_patch = img[:sym_len_Hs, :, :]
+ inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(0, inv_idx)
+ img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)
+
+ sym_patch = img[-sym_len_He:, :, :]
+ inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(0, inv_idx)
+ img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
+
+ out_1 = torch.FloatTensor(out_H, in_W, in_C)
+ kernel_width = weights_H.size(1)
+ for i in range(out_H):
+ idx = int(indices_H[i][0])
+ for j in range(out_C):
+ out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])
+
+ # process W dimension
+ # symmetric copying
+ out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
+ out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)
+
+ sym_patch = out_1[:, :sym_len_Ws, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)
+
+ sym_patch = out_1[:, -sym_len_We:, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
+
+ out_2 = torch.FloatTensor(out_H, out_W, in_C)
+ kernel_width = weights_W.size(1)
+ for i in range(out_W):
+ idx = int(indices_W[i][0])
+ for j in range(out_C):
+ out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i])
+ if need_squeeze:
+ out_2.squeeze_()
+
+ return out_2.numpy()
+
+
+if __name__ == '__main__':
+ print('---')
+# img = imread_uint('test.bmp', 3)
+# img = uint2single(img)
+# img_bicubic = imresize_np(img, 1/4)
\ No newline at end of file
diff --git a/ldm/modules/midas/__init__.py b/ldm/modules/midas/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ldm/modules/midas/api.py b/ldm/modules/midas/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..b58ebbffd942a2fc22264f0ab47e400c26b9f41c
--- /dev/null
+++ b/ldm/modules/midas/api.py
@@ -0,0 +1,170 @@
+# based on https://github.com/isl-org/MiDaS
+
+import cv2
+import torch
+import torch.nn as nn
+from torchvision.transforms import Compose
+
+from ldm.modules.midas.midas.dpt_depth import DPTDepthModel
+from ldm.modules.midas.midas.midas_net import MidasNet
+from ldm.modules.midas.midas.midas_net_custom import MidasNet_small
+from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet
+
+
+ISL_PATHS = {
+ "dpt_large": "midas_models/dpt_large-midas-2f21e586.pt",
+ "dpt_hybrid": "midas_models/dpt_hybrid-midas-501f0c75.pt",
+ "midas_v21": "",
+ "midas_v21_small": "",
+}
+
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+
+def load_midas_transform(model_type):
+ # https://github.com/isl-org/MiDaS/blob/master/run.py
+ # load transform only
+ if model_type == "dpt_large": # DPT-Large
+ net_w, net_h = 384, 384
+ resize_mode = "minimal"
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+ elif model_type == "dpt_hybrid": # DPT-Hybrid
+ net_w, net_h = 384, 384
+ resize_mode = "minimal"
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+ elif model_type == "midas_v21":
+ net_w, net_h = 384, 384
+ resize_mode = "upper_bound"
+ normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+ elif model_type == "midas_v21_small":
+ net_w, net_h = 256, 256
+ resize_mode = "upper_bound"
+ normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+ else:
+ assert False, f"model_type '{model_type}' not implemented, use: --model_type large"
+
+ transform = Compose(
+ [
+ Resize(
+ net_w,
+ net_h,
+ resize_target=None,
+ keep_aspect_ratio=True,
+ ensure_multiple_of=32,
+ resize_method=resize_mode,
+ image_interpolation_method=cv2.INTER_CUBIC,
+ ),
+ normalization,
+ PrepareForNet(),
+ ]
+ )
+
+ return transform
+
+
+def load_model(model_type):
+ # https://github.com/isl-org/MiDaS/blob/master/run.py
+ # load network
+ model_path = ISL_PATHS[model_type]
+ if model_type == "dpt_large": # DPT-Large
+ model = DPTDepthModel(
+ path=model_path,
+ backbone="vitl16_384",
+ non_negative=True,
+ )
+ net_w, net_h = 384, 384
+ resize_mode = "minimal"
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+ elif model_type == "dpt_hybrid": # DPT-Hybrid
+ model = DPTDepthModel(
+ path=model_path,
+ backbone="vitb_rn50_384",
+ non_negative=True,
+ )
+ net_w, net_h = 384, 384
+ resize_mode = "minimal"
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+ elif model_type == "midas_v21":
+ model = MidasNet(model_path, non_negative=True)
+ net_w, net_h = 384, 384
+ resize_mode = "upper_bound"
+ normalization = NormalizeImage(
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+ )
+
+ elif model_type == "midas_v21_small":
+ model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
+ non_negative=True, blocks={'expand': True})
+ net_w, net_h = 256, 256
+ resize_mode = "upper_bound"
+ normalization = NormalizeImage(
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+ )
+
+ else:
+        print(f"model_type '{model_type}' not implemented; expected one of {list(ISL_PATHS.keys())}")
+        assert False
+
+ transform = Compose(
+ [
+ Resize(
+ net_w,
+ net_h,
+ resize_target=None,
+ keep_aspect_ratio=True,
+ ensure_multiple_of=32,
+ resize_method=resize_mode,
+ image_interpolation_method=cv2.INTER_CUBIC,
+ ),
+ normalization,
+ PrepareForNet(),
+ ]
+ )
+
+ return model.eval(), transform
+
+
+class MiDaSInference(nn.Module):
+ MODEL_TYPES_TORCH_HUB = [
+ "DPT_Large",
+ "DPT_Hybrid",
+ "MiDaS_small"
+ ]
+ MODEL_TYPES_ISL = [
+ "dpt_large",
+ "dpt_hybrid",
+ "midas_v21",
+ "midas_v21_small",
+ ]
+
+ def __init__(self, model_type):
+ super().__init__()
+ assert (model_type in self.MODEL_TYPES_ISL)
+ model, _ = load_model(model_type)
+ self.model = model
+ self.model.train = disabled_train
+
+ def forward(self, x):
+        # x is expected to be preprocessed already by the MiDaS transform for this
+        # model type (applied during dataloading on a 0..1 float image); this
+        # wrapper does not store or apply the transform itself.
+ with torch.no_grad():
+ prediction = self.model(x)
+ prediction = torch.nn.functional.interpolate(
+ prediction.unsqueeze(1),
+ size=x.shape[2:],
+ mode="bicubic",
+ align_corners=False,
+ )
+ assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3])
+ return prediction
+
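+
+# Minimal usage sketch (assumes the dpt_hybrid checkpoint exists at the path in
+# ISL_PATHS and that the input batch was preprocessed with the transform from
+# load_midas_transform("dpt_hybrid")):
+#
+#   midas = MiDaSInference("dpt_hybrid").cuda()
+#   depth = midas(x)  # x: (B, 3, H, W) float tensor -> (B, 1, H, W) relative inverse depth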
diff --git a/ldm/modules/midas/midas/__init__.py b/ldm/modules/midas/midas/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ldm/modules/midas/midas/base_model.py b/ldm/modules/midas/midas/base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cf430239b47ec5ec07531263f26f5c24a2311cd
--- /dev/null
+++ b/ldm/modules/midas/midas/base_model.py
@@ -0,0 +1,16 @@
+import torch
+
+
+class BaseModel(torch.nn.Module):
+ def load(self, path):
+ """Load model from file.
+
+ Args:
+ path (str): file path
+ """
+ parameters = torch.load(path, map_location=torch.device('cpu'))
+
+ if "optimizer" in parameters:
+ parameters = parameters["model"]
+
+ self.load_state_dict(parameters)
diff --git a/ldm/modules/midas/midas/blocks.py b/ldm/modules/midas/midas/blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..2145d18fa98060a618536d9a64fe6589e9be4f78
--- /dev/null
+++ b/ldm/modules/midas/midas/blocks.py
@@ -0,0 +1,342 @@
+import torch
+import torch.nn as nn
+
+from .vit import (
+ _make_pretrained_vitb_rn50_384,
+ _make_pretrained_vitl16_384,
+ _make_pretrained_vitb16_384,
+ forward_vit,
+)
+
+def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
+ if backbone == "vitl16_384":
+ pretrained = _make_pretrained_vitl16_384(
+ use_pretrained, hooks=hooks, use_readout=use_readout
+ )
+ scratch = _make_scratch(
+ [256, 512, 1024, 1024], features, groups=groups, expand=expand
+ ) # ViT-L/16 - 85.0% Top1 (backbone)
+ elif backbone == "vitb_rn50_384":
+ pretrained = _make_pretrained_vitb_rn50_384(
+ use_pretrained,
+ hooks=hooks,
+ use_vit_only=use_vit_only,
+ use_readout=use_readout,
+ )
+ scratch = _make_scratch(
+ [256, 512, 768, 768], features, groups=groups, expand=expand
+        ) # ViT-Hybrid (ResNet-50 + ViT-B/16) backbone
+ elif backbone == "vitb16_384":
+ pretrained = _make_pretrained_vitb16_384(
+ use_pretrained, hooks=hooks, use_readout=use_readout
+ )
+ scratch = _make_scratch(
+ [96, 192, 384, 768], features, groups=groups, expand=expand
+ ) # ViT-B/16 - 84.6% Top1 (backbone)
+ elif backbone == "resnext101_wsl":
+ pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
+        scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # resnext101_wsl
+ elif backbone == "efficientnet_lite3":
+ pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
+ scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3
+ else:
+ print(f"Backbone '{backbone}' not implemented")
+ assert False
+
+ return pretrained, scratch
+
+
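+# "scratch" holds the reassembly layers: one 3x3 convolution per encoder stage
+# that maps the backbone channel counts in `in_shape` onto the decoder width
+# `out_shape` (progressively doubled per stage when `expand` is set, as used by
+# midas_v21_small).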
+def _make_scratch(in_shape, out_shape, groups=1, expand=False):
+ scratch = nn.Module()
+
+ out_shape1 = out_shape
+ out_shape2 = out_shape
+ out_shape3 = out_shape
+ out_shape4 = out_shape
+ if expand==True:
+ out_shape1 = out_shape
+ out_shape2 = out_shape*2
+ out_shape3 = out_shape*4
+ out_shape4 = out_shape*8
+
+ scratch.layer1_rn = nn.Conv2d(
+ in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ scratch.layer2_rn = nn.Conv2d(
+ in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ scratch.layer3_rn = nn.Conv2d(
+ in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ scratch.layer4_rn = nn.Conv2d(
+ in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+
+ return scratch
+
+
+def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
+ efficientnet = torch.hub.load(
+ "rwightman/gen-efficientnet-pytorch",
+ "tf_efficientnet_lite3",
+ pretrained=use_pretrained,
+ exportable=exportable
+ )
+ return _make_efficientnet_backbone(efficientnet)
+
+
+def _make_efficientnet_backbone(effnet):
+ pretrained = nn.Module()
+
+ pretrained.layer1 = nn.Sequential(
+ effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
+ )
+ pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
+ pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
+ pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
+
+ return pretrained
+
+
+def _make_resnet_backbone(resnet):
+ pretrained = nn.Module()
+ pretrained.layer1 = nn.Sequential(
+ resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
+ )
+
+ pretrained.layer2 = resnet.layer2
+ pretrained.layer3 = resnet.layer3
+ pretrained.layer4 = resnet.layer4
+
+ return pretrained
+
+
+def _make_pretrained_resnext101_wsl(use_pretrained):
+ resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
+ return _make_resnet_backbone(resnet)
+
+
+
+class Interpolate(nn.Module):
+ """Interpolation module.
+ """
+
+ def __init__(self, scale_factor, mode, align_corners=False):
+ """Init.
+
+ Args:
+ scale_factor (float): scaling
+ mode (str): interpolation mode
+ """
+ super(Interpolate, self).__init__()
+
+ self.interp = nn.functional.interpolate
+ self.scale_factor = scale_factor
+ self.mode = mode
+ self.align_corners = align_corners
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input
+
+ Returns:
+ tensor: interpolated data
+ """
+
+ x = self.interp(
+ x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
+ )
+
+ return x
+
+
+class ResidualConvUnit(nn.Module):
+ """Residual convolution module.
+ """
+
+ def __init__(self, features):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super().__init__()
+
+ self.conv1 = nn.Conv2d(
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
+ )
+
+ self.conv2 = nn.Conv2d(
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
+ )
+
+ self.relu = nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input
+
+ Returns:
+ tensor: output
+ """
+ out = self.relu(x)
+ out = self.conv1(out)
+ out = self.relu(out)
+ out = self.conv2(out)
+
+ return out + x
+
+
+class FeatureFusionBlock(nn.Module):
+ """Feature fusion block.
+ """
+
+ def __init__(self, features):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super(FeatureFusionBlock, self).__init__()
+
+ self.resConfUnit1 = ResidualConvUnit(features)
+ self.resConfUnit2 = ResidualConvUnit(features)
+
+ def forward(self, *xs):
+ """Forward pass.
+
+ Returns:
+ tensor: output
+ """
+ output = xs[0]
+
+ if len(xs) == 2:
+ output += self.resConfUnit1(xs[1])
+
+ output = self.resConfUnit2(output)
+
+ output = nn.functional.interpolate(
+ output, scale_factor=2, mode="bilinear", align_corners=True
+ )
+
+ return output
+
+
+
+
+class ResidualConvUnit_custom(nn.Module):
+ """Residual convolution module.
+ """
+
+ def __init__(self, features, activation, bn):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super().__init__()
+
+ self.bn = bn
+
+ self.groups=1
+
+ self.conv1 = nn.Conv2d(
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
+ )
+
+ self.conv2 = nn.Conv2d(
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
+ )
+
+ if self.bn==True:
+ self.bn1 = nn.BatchNorm2d(features)
+ self.bn2 = nn.BatchNorm2d(features)
+
+ self.activation = activation
+
+ self.skip_add = nn.quantized.FloatFunctional()
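+        # FloatFunctional performs the residual addition in a way that remains
+        # traceable/replaceable by torch.quantization, unlike a bare `out + x`.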
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input
+
+ Returns:
+ tensor: output
+ """
+
+ out = self.activation(x)
+ out = self.conv1(out)
+ if self.bn==True:
+ out = self.bn1(out)
+
+ out = self.activation(out)
+ out = self.conv2(out)
+ if self.bn==True:
+ out = self.bn2(out)
+
+ if self.groups > 1:
+ out = self.conv_merge(out)
+
+ return self.skip_add.add(out, x)
+
+ # return out + x
+
+
+class FeatureFusionBlock_custom(nn.Module):
+ """Feature fusion block.
+ """
+
+ def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super(FeatureFusionBlock_custom, self).__init__()
+
+ self.deconv = deconv
+ self.align_corners = align_corners
+
+ self.groups=1
+
+ self.expand = expand
+ out_features = features
+ if self.expand==True:
+ out_features = features//2
+
+ self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
+
+ self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
+ self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
+
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ def forward(self, *xs):
+ """Forward pass.
+
+ Returns:
+ tensor: output
+ """
+ output = xs[0]
+
+ if len(xs) == 2:
+ res = self.resConfUnit1(xs[1])
+ output = self.skip_add.add(output, res)
+ # output += res
+
+ output = self.resConfUnit2(output)
+
+ output = nn.functional.interpolate(
+ output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
+ )
+
+ output = self.out_conv(output)
+
+ return output
+
diff --git a/ldm/modules/midas/midas/dpt_depth.py b/ldm/modules/midas/midas/dpt_depth.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e9aab5d2767dffea39da5b3f30e2798688216f1
--- /dev/null
+++ b/ldm/modules/midas/midas/dpt_depth.py
@@ -0,0 +1,109 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .base_model import BaseModel
+from .blocks import (
+ FeatureFusionBlock,
+ FeatureFusionBlock_custom,
+ Interpolate,
+ _make_encoder,
+ forward_vit,
+)
+
+
+def _make_fusion_block(features, use_bn):
+ return FeatureFusionBlock_custom(
+ features,
+ nn.ReLU(False),
+ deconv=False,
+ bn=use_bn,
+ expand=False,
+ align_corners=True,
+ )
+
+
+class DPT(BaseModel):
+ def __init__(
+ self,
+ head,
+ features=256,
+ backbone="vitb_rn50_384",
+ readout="project",
+ channels_last=False,
+ use_bn=False,
+ ):
+
+ super(DPT, self).__init__()
+
+ self.channels_last = channels_last
+
+ hooks = {
+ "vitb_rn50_384": [0, 1, 8, 11],
+ "vitb16_384": [2, 5, 8, 11],
+ "vitl16_384": [5, 11, 17, 23],
+ }
+
+ # Instantiate backbone and reassemble blocks
+ self.pretrained, self.scratch = _make_encoder(
+ backbone,
+ features,
+            False, # use_pretrained; set to True to initialize the backbone with ImageNet weights when training from scratch
+ groups=1,
+ expand=False,
+ exportable=False,
+ hooks=hooks[backbone],
+ use_readout=readout,
+ )
+
+ self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
+
+ self.scratch.output_conv = head
+
+
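+    # The ViT backbone exposes four intermediate feature maps via forward hooks;
+    # the layerX_rn convolutions project them to a common width, and the refinenet
+    # blocks fuse them coarse-to-fine before the head predicts inverse depth.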
+ def forward(self, x):
+ if self.channels_last == True:
+            x = x.contiguous(memory_format=torch.channels_last)
+
+ layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
+
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+ path_4 = self.scratch.refinenet4(layer_4_rn)
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+ out = self.scratch.output_conv(path_1)
+
+ return out
+
+
+class DPTDepthModel(DPT):
+ def __init__(self, path=None, non_negative=True, **kwargs):
+ features = kwargs["features"] if "features" in kwargs else 256
+
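+        # The head halves the channel count, upsamples 2x, projects down to 32
+        # channels and finally to a single-channel map; the trailing ReLU keeps the
+        # predicted (relative, inverse) depth non-negative when non_negative=True.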
+ head = nn.Sequential(
+ nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
+ Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
+ nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(True),
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+ nn.ReLU(True) if non_negative else nn.Identity(),
+ nn.Identity(),
+ )
+
+ super().__init__(head, **kwargs)
+
+ if path is not None:
+ self.load(path)
+
+ def forward(self, x):
+ return super().forward(x).squeeze(dim=1)
+
diff --git a/ldm/modules/midas/midas/midas_net.py b/ldm/modules/midas/midas/midas_net.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a954977800b0a0f48807e80fa63041910e33c1f
--- /dev/null
+++ b/ldm/modules/midas/midas/midas_net.py
@@ -0,0 +1,76 @@
+"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
+This file contains code that is adapted from
+https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
+"""
+import torch
+import torch.nn as nn
+
+from .base_model import BaseModel
+from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
+
+
+class MidasNet(BaseModel):
+ """Network for monocular depth estimation.
+ """
+
+ def __init__(self, path=None, features=256, non_negative=True):
+ """Init.
+
+ Args:
+ path (str, optional): Path to saved model. Defaults to None.
+ features (int, optional): Number of features. Defaults to 256.
+            non_negative (bool, optional): If True, clamp the predicted depth to be non-negative. Defaults to True.
+ """
+ print("Loading weights: ", path)
+
+ super(MidasNet, self).__init__()
+
+ use_pretrained = False if path is None else True
+
+ self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
+
+ self.scratch.refinenet4 = FeatureFusionBlock(features)
+ self.scratch.refinenet3 = FeatureFusionBlock(features)
+ self.scratch.refinenet2 = FeatureFusionBlock(features)
+ self.scratch.refinenet1 = FeatureFusionBlock(features)
+
+ self.scratch.output_conv = nn.Sequential(
+ nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
+ Interpolate(scale_factor=2, mode="bilinear"),
+ nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(True),
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+ nn.ReLU(True) if non_negative else nn.Identity(),
+ )
+
+ if path:
+ self.load(path)
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input data (image)
+
+ Returns:
+ tensor: depth
+ """
+
+ layer_1 = self.pretrained.layer1(x)
+ layer_2 = self.pretrained.layer2(layer_1)
+ layer_3 = self.pretrained.layer3(layer_2)
+ layer_4 = self.pretrained.layer4(layer_3)
+
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+ path_4 = self.scratch.refinenet4(layer_4_rn)
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+ out = self.scratch.output_conv(path_1)
+
+ return torch.squeeze(out, dim=1)
diff --git a/ldm/modules/midas/midas/midas_net_custom.py b/ldm/modules/midas/midas/midas_net_custom.py
new file mode 100644
index 0000000000000000000000000000000000000000..50e4acb5e53d5fabefe3dde16ab49c33c2b7797c
--- /dev/null
+++ b/ldm/modules/midas/midas/midas_net_custom.py
@@ -0,0 +1,128 @@
+"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
+This file contains code that is adapted from
+https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
+"""
+import torch
+import torch.nn as nn
+
+from .base_model import BaseModel
+from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
+
+
+class MidasNet_small(BaseModel):
+ """Network for monocular depth estimation.
+ """
+
+ def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
+ blocks={'expand': True}):
+ """Init.
+
+ Args:
+ path (str, optional): Path to saved model. Defaults to None.
+            features (int, optional): Number of features. Defaults to 64.
+            backbone (str, optional): Backbone network for encoder. Defaults to efficientnet_lite3.
+ """
+ print("Loading weights: ", path)
+
+ super(MidasNet_small, self).__init__()
+
+ use_pretrained = False if path else True
+
+ self.channels_last = channels_last
+ self.blocks = blocks
+ self.backbone = backbone
+
+ self.groups = 1
+
+ features1=features
+ features2=features
+ features3=features
+ features4=features
+ self.expand = False
+ if "expand" in self.blocks and self.blocks['expand'] == True:
+ self.expand = True
+ features1=features
+ features2=features*2
+ features3=features*4
+ features4=features*8
+
+ self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)
+
+ self.scratch.activation = nn.ReLU(False)
+
+ self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
+ self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
+ self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
+ self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)
+
+
+ self.scratch.output_conv = nn.Sequential(
+ nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups),
+ Interpolate(scale_factor=2, mode="bilinear"),
+ nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1),
+ self.scratch.activation,
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+ nn.ReLU(True) if non_negative else nn.Identity(),
+ nn.Identity(),
+ )
+
+ if path:
+ self.load(path)
+
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input data (image)
+
+ Returns:
+ tensor: depth
+ """
+ if self.channels_last==True:
+ print("self.channels_last = ", self.channels_last)
+            x = x.contiguous(memory_format=torch.channels_last)
+
+
+ layer_1 = self.pretrained.layer1(x)
+ layer_2 = self.pretrained.layer2(layer_1)
+ layer_3 = self.pretrained.layer3(layer_2)
+ layer_4 = self.pretrained.layer4(layer_3)
+
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+
+ path_4 = self.scratch.refinenet4(layer_4_rn)
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+ out = self.scratch.output_conv(path_1)
+
+ return torch.squeeze(out, dim=1)
+
+
+
+def fuse_model(m):
+ prev_previous_type = nn.Identity()
+ prev_previous_name = ''
+ previous_type = nn.Identity()
+ previous_name = ''
+ for name, module in m.named_modules():
+ if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
+ # print("FUSED ", prev_previous_name, previous_name, name)
+ torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
+ elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
+ # print("FUSED ", prev_previous_name, previous_name)
+ torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
+ # elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
+ # print("FUSED ", previous_name, name)
+ # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
+
+ prev_previous_type = previous_type
+ prev_previous_name = previous_name
+ previous_type = type(module)
+ previous_name = name
\ No newline at end of file
diff --git a/ldm/modules/midas/midas/transforms.py b/ldm/modules/midas/midas/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..350cbc11662633ad7f8968eb10be2e7de6e384e9
--- /dev/null
+++ b/ldm/modules/midas/midas/transforms.py
@@ -0,0 +1,234 @@
+import numpy as np
+import cv2
+import math
+
+
+def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
+ """Rezise the sample to ensure the given size. Keeps aspect ratio.
+
+ Args:
+ sample (dict): sample
+ size (tuple): image size
+
+ Returns:
+ tuple: new size
+ """
+ shape = list(sample["disparity"].shape)
+
+ if shape[0] >= size[0] and shape[1] >= size[1]:
+ return sample
+
+ scale = [0, 0]
+ scale[0] = size[0] / shape[0]
+ scale[1] = size[1] / shape[1]
+
+ scale = max(scale)
+
+ shape[0] = math.ceil(scale * shape[0])
+ shape[1] = math.ceil(scale * shape[1])
+
+ # resize
+ sample["image"] = cv2.resize(
+ sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
+ )
+
+ sample["disparity"] = cv2.resize(
+ sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
+ )
+ sample["mask"] = cv2.resize(
+ sample["mask"].astype(np.float32),
+ tuple(shape[::-1]),
+ interpolation=cv2.INTER_NEAREST,
+ )
+ sample["mask"] = sample["mask"].astype(bool)
+
+ return tuple(shape)
+
+
+class Resize(object):
+ """Resize sample to given size (width, height).
+ """
+
+ def __init__(
+ self,
+ width,
+ height,
+ resize_target=True,
+ keep_aspect_ratio=False,
+ ensure_multiple_of=1,
+ resize_method="lower_bound",
+ image_interpolation_method=cv2.INTER_AREA,
+ ):
+ """Init.
+
+ Args:
+ width (int): desired output width
+ height (int): desired output height
+ resize_target (bool, optional):
+ True: Resize the full sample (image, mask, target).
+ False: Resize image only.
+ Defaults to True.
+ keep_aspect_ratio (bool, optional):
+ True: Keep the aspect ratio of the input sample.
+ Output sample might not have the given width and height, and
+ resize behaviour depends on the parameter 'resize_method'.
+ Defaults to False.
+ ensure_multiple_of (int, optional):
+                Output width and height are constrained to be a multiple of this parameter.
+ Defaults to 1.
+ resize_method (str, optional):
+ "lower_bound": Output will be at least as large as the given size.
+ "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
+ "minimal": Scale as least as possible. (Output size might be smaller than given size.)
+ Defaults to "lower_bound".
+ """
+ self.__width = width
+ self.__height = height
+
+ self.__resize_target = resize_target
+ self.__keep_aspect_ratio = keep_aspect_ratio
+ self.__multiple_of = ensure_multiple_of
+ self.__resize_method = resize_method
+ self.__image_interpolation_method = image_interpolation_method
+
+ def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
+ y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+ if max_val is not None and y > max_val:
+ y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+ if y < min_val:
+ y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+ return y
+
+ def get_size(self, width, height):
+ # determine new height and width
+ scale_height = self.__height / height
+ scale_width = self.__width / width
+
+ if self.__keep_aspect_ratio:
+ if self.__resize_method == "lower_bound":
+ # scale such that output size is lower bound
+ if scale_width > scale_height:
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ elif self.__resize_method == "upper_bound":
+ # scale such that output size is upper bound
+ if scale_width < scale_height:
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ elif self.__resize_method == "minimal":
+                # scale as little as possible
+ if abs(1 - scale_width) < abs(1 - scale_height):
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ else:
+ raise ValueError(
+ f"resize_method {self.__resize_method} not implemented"
+ )
+
+ if self.__resize_method == "lower_bound":
+ new_height = self.constrain_to_multiple_of(
+ scale_height * height, min_val=self.__height
+ )
+ new_width = self.constrain_to_multiple_of(
+ scale_width * width, min_val=self.__width
+ )
+ elif self.__resize_method == "upper_bound":
+ new_height = self.constrain_to_multiple_of(
+ scale_height * height, max_val=self.__height
+ )
+ new_width = self.constrain_to_multiple_of(
+ scale_width * width, max_val=self.__width
+ )
+ elif self.__resize_method == "minimal":
+ new_height = self.constrain_to_multiple_of(scale_height * height)
+ new_width = self.constrain_to_multiple_of(scale_width * width)
+ else:
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
+
+ return (new_width, new_height)
+
+ def __call__(self, sample):
+ width, height = self.get_size(
+ sample["image"].shape[1], sample["image"].shape[0]
+ )
+
+ # resize sample
+ sample["image"] = cv2.resize(
+ sample["image"],
+ (width, height),
+ interpolation=self.__image_interpolation_method,
+ )
+
+ if self.__resize_target:
+ if "disparity" in sample:
+ sample["disparity"] = cv2.resize(
+ sample["disparity"],
+ (width, height),
+ interpolation=cv2.INTER_NEAREST,
+ )
+
+ if "depth" in sample:
+ sample["depth"] = cv2.resize(
+ sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
+ )
+
+ sample["mask"] = cv2.resize(
+ sample["mask"].astype(np.float32),
+ (width, height),
+ interpolation=cv2.INTER_NEAREST,
+ )
+ sample["mask"] = sample["mask"].astype(bool)
+
+ return sample
+
+
+class NormalizeImage(object):
+ """Normlize image by given mean and std.
+ """
+
+ def __init__(self, mean, std):
+ self.__mean = mean
+ self.__std = std
+
+ def __call__(self, sample):
+ sample["image"] = (sample["image"] - self.__mean) / self.__std
+
+ return sample
+
+
+class PrepareForNet(object):
+ """Prepare sample for usage as network input.
+ """
+
+ def __init__(self):
+ pass
+
+ def __call__(self, sample):
+ image = np.transpose(sample["image"], (2, 0, 1))
+ sample["image"] = np.ascontiguousarray(image).astype(np.float32)
+
+ if "mask" in sample:
+ sample["mask"] = sample["mask"].astype(np.float32)
+ sample["mask"] = np.ascontiguousarray(sample["mask"])
+
+ if "disparity" in sample:
+ disparity = sample["disparity"].astype(np.float32)
+ sample["disparity"] = np.ascontiguousarray(disparity)
+
+ if "depth" in sample:
+ depth = sample["depth"].astype(np.float32)
+ sample["depth"] = np.ascontiguousarray(depth)
+
+ return sample
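+
+
+# Typical composition (mirrors load_midas_transform in ldm/modules/midas/api.py):
+# Resize -> NormalizeImage -> PrepareForNet, applied to a sample dict whose
+# "image" entry is an HxWx3 float array in 0..1; the result is a CxHxW float32
+# array ready to be batched for the network.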
diff --git a/ldm/modules/midas/midas/vit.py b/ldm/modules/midas/midas/vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea46b1be88b261b0dec04f3da0256f5f66f88a74
--- /dev/null
+++ b/ldm/modules/midas/midas/vit.py
@@ -0,0 +1,491 @@
+import torch
+import torch.nn as nn
+import timm
+import types
+import math
+import torch.nn.functional as F
+
+
+class Slice(nn.Module):
+ def __init__(self, start_index=1):
+ super(Slice, self).__init__()
+ self.start_index = start_index
+
+ def forward(self, x):
+ return x[:, self.start_index :]
+
+
+class AddReadout(nn.Module):
+ def __init__(self, start_index=1):
+ super(AddReadout, self).__init__()
+ self.start_index = start_index
+
+ def forward(self, x):
+ if self.start_index == 2:
+ readout = (x[:, 0] + x[:, 1]) / 2
+ else:
+ readout = x[:, 0]
+ return x[:, self.start_index :] + readout.unsqueeze(1)
+
+
+class ProjectReadout(nn.Module):
+ def __init__(self, in_features, start_index=1):
+ super(ProjectReadout, self).__init__()
+ self.start_index = start_index
+
+ self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
+
+ def forward(self, x):
+ readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
+ features = torch.cat((x[:, self.start_index :], readout), -1)
+
+ return self.project(features)
+
+
+class Transpose(nn.Module):
+ def __init__(self, dim0, dim1):
+ super(Transpose, self).__init__()
+ self.dim0 = dim0
+ self.dim1 = dim1
+
+ def forward(self, x):
+ x = x.transpose(self.dim0, self.dim1)
+ return x
+
+
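+# forward_vit runs the timm ViT once and reads the four hooked block outputs from
+# `pretrained.activations`; the act_postprocess modules then reshape the token
+# sequences back into 2D feature maps at four scales for the DPT decoder.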
+def forward_vit(pretrained, x):
+ b, c, h, w = x.shape
+
+ glob = pretrained.model.forward_flex(x)
+
+ layer_1 = pretrained.activations["1"]
+ layer_2 = pretrained.activations["2"]
+ layer_3 = pretrained.activations["3"]
+ layer_4 = pretrained.activations["4"]
+
+ layer_1 = pretrained.act_postprocess1[0:2](layer_1)
+ layer_2 = pretrained.act_postprocess2[0:2](layer_2)
+ layer_3 = pretrained.act_postprocess3[0:2](layer_3)
+ layer_4 = pretrained.act_postprocess4[0:2](layer_4)
+
+ unflatten = nn.Sequential(
+ nn.Unflatten(
+ 2,
+ torch.Size(
+ [
+ h // pretrained.model.patch_size[1],
+ w // pretrained.model.patch_size[0],
+ ]
+ ),
+ )
+ )
+
+ if layer_1.ndim == 3:
+ layer_1 = unflatten(layer_1)
+ if layer_2.ndim == 3:
+ layer_2 = unflatten(layer_2)
+ if layer_3.ndim == 3:
+ layer_3 = unflatten(layer_3)
+ if layer_4.ndim == 3:
+ layer_4 = unflatten(layer_4)
+
+ layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
+ layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
+ layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
+ layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
+
+ return layer_1, layer_2, layer_3, layer_4
+
+
+def _resize_pos_embed(self, posemb, gs_h, gs_w):
+ posemb_tok, posemb_grid = (
+ posemb[:, : self.start_index],
+ posemb[0, self.start_index :],
+ )
+
+ gs_old = int(math.sqrt(len(posemb_grid)))
+
+ posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
+ posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
+ posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
+
+ posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
+
+ return posemb
+
+
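+# forward_flex mirrors timm's VisionTransformer forward pass but first resizes
+# the position embedding to the current patch grid, so inputs other than the
+# 384x384 training resolution are accepted.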
+def forward_flex(self, x):
+ b, c, h, w = x.shape
+
+ pos_embed = self._resize_pos_embed(
+ self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
+ )
+
+ B = x.shape[0]
+
+ if hasattr(self.patch_embed, "backbone"):
+ x = self.patch_embed.backbone(x)
+ if isinstance(x, (list, tuple)):
+ x = x[-1] # last feature if backbone outputs list/tuple of features
+
+ x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
+
+ if getattr(self, "dist_token", None) is not None:
+ cls_tokens = self.cls_token.expand(
+ B, -1, -1
+ ) # stole cls_tokens impl from Phil Wang, thanks
+ dist_token = self.dist_token.expand(B, -1, -1)
+ x = torch.cat((cls_tokens, dist_token, x), dim=1)
+ else:
+ cls_tokens = self.cls_token.expand(
+ B, -1, -1
+ ) # stole cls_tokens impl from Phil Wang, thanks
+ x = torch.cat((cls_tokens, x), dim=1)
+
+ x = x + pos_embed
+ x = self.pos_drop(x)
+
+ for blk in self.blocks:
+ x = blk(x)
+
+ x = self.norm(x)
+
+ return x
+
+
+activations = {}
+
+
+def get_activation(name):
+ def hook(model, input, output):
+ activations[name] = output
+
+ return hook
+
+
+def get_readout_oper(vit_features, features, use_readout, start_index=1):
+ if use_readout == "ignore":
+ readout_oper = [Slice(start_index)] * len(features)
+ elif use_readout == "add":
+ readout_oper = [AddReadout(start_index)] * len(features)
+ elif use_readout == "project":
+ readout_oper = [
+ ProjectReadout(vit_features, start_index) for out_feat in features
+ ]
+ else:
+ assert (
+ False
+ ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
+
+ return readout_oper
+
+
+def _make_vit_b16_backbone(
+ model,
+ features=[96, 192, 384, 768],
+ size=[384, 384],
+ hooks=[2, 5, 8, 11],
+ vit_features=768,
+ use_readout="ignore",
+ start_index=1,
+):
+ pretrained = nn.Module()
+
+ pretrained.model = model
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
+ pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
+
+ pretrained.activations = activations
+
+ readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
+
+ # 32, 48, 136, 384
+ pretrained.act_postprocess1 = nn.Sequential(
+ readout_oper[0],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[0],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.ConvTranspose2d(
+ in_channels=features[0],
+ out_channels=features[0],
+ kernel_size=4,
+ stride=4,
+ padding=0,
+ bias=True,
+ dilation=1,
+ groups=1,
+ ),
+ )
+
+ pretrained.act_postprocess2 = nn.Sequential(
+ readout_oper[1],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[1],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.ConvTranspose2d(
+ in_channels=features[1],
+ out_channels=features[1],
+ kernel_size=2,
+ stride=2,
+ padding=0,
+ bias=True,
+ dilation=1,
+ groups=1,
+ ),
+ )
+
+ pretrained.act_postprocess3 = nn.Sequential(
+ readout_oper[2],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[2],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ )
+
+ pretrained.act_postprocess4 = nn.Sequential(
+ readout_oper[3],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[3],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.Conv2d(
+ in_channels=features[3],
+ out_channels=features[3],
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ ),
+ )
+
+ pretrained.model.start_index = start_index
+ pretrained.model.patch_size = [16, 16]
+
+ # We inject this function into the VisionTransformer instances so that
+ # we can use it with interpolated position embeddings without modifying the library source.
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
+ pretrained.model._resize_pos_embed = types.MethodType(
+ _resize_pos_embed, pretrained.model
+ )
+
+ return pretrained
+
+
+def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
+ model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
+
+ hooks = [5, 11, 17, 23] if hooks == None else hooks
+ return _make_vit_b16_backbone(
+ model,
+ features=[256, 512, 1024, 1024],
+ hooks=hooks,
+ vit_features=1024,
+ use_readout=use_readout,
+ )
+
+
+def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
+ model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
+
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
+ return _make_vit_b16_backbone(
+ model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
+ )
+
+
+def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
+ model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
+
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
+ return _make_vit_b16_backbone(
+ model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
+ )
+
+
+def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
+ model = timm.create_model(
+ "vit_deit_base_distilled_patch16_384", pretrained=pretrained
+ )
+
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
+ return _make_vit_b16_backbone(
+ model,
+ features=[96, 192, 384, 768],
+ hooks=hooks,
+ use_readout=use_readout,
+ start_index=2,
+ )
+
+
+def _make_vit_b_rn50_backbone(
+ model,
+ features=[256, 512, 768, 768],
+ size=[384, 384],
+ hooks=[0, 1, 8, 11],
+ vit_features=768,
+ use_vit_only=False,
+ use_readout="ignore",
+ start_index=1,
+):
+ pretrained = nn.Module()
+
+ pretrained.model = model
+
+ if use_vit_only == True:
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
+ else:
+ pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
+ get_activation("1")
+ )
+ pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
+ get_activation("2")
+ )
+
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
+ pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
+
+ pretrained.activations = activations
+
+ readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
+
+ if use_vit_only == True:
+ pretrained.act_postprocess1 = nn.Sequential(
+ readout_oper[0],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[0],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.ConvTranspose2d(
+ in_channels=features[0],
+ out_channels=features[0],
+ kernel_size=4,
+ stride=4,
+ padding=0,
+ bias=True,
+ dilation=1,
+ groups=1,
+ ),
+ )
+
+ pretrained.act_postprocess2 = nn.Sequential(
+ readout_oper[1],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[1],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.ConvTranspose2d(
+ in_channels=features[1],
+ out_channels=features[1],
+ kernel_size=2,
+ stride=2,
+ padding=0,
+ bias=True,
+ dilation=1,
+ groups=1,
+ ),
+ )
+ else:
+ pretrained.act_postprocess1 = nn.Sequential(
+ nn.Identity(), nn.Identity(), nn.Identity()
+ )
+ pretrained.act_postprocess2 = nn.Sequential(
+ nn.Identity(), nn.Identity(), nn.Identity()
+ )
+
+ pretrained.act_postprocess3 = nn.Sequential(
+ readout_oper[2],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[2],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ )
+
+ pretrained.act_postprocess4 = nn.Sequential(
+ readout_oper[3],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[3],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.Conv2d(
+ in_channels=features[3],
+ out_channels=features[3],
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ ),
+ )
+
+ pretrained.model.start_index = start_index
+ pretrained.model.patch_size = [16, 16]
+
+ # We inject this function into the VisionTransformer instances so that
+ # we can use it with interpolated position embeddings without modifying the library source.
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
+
+ # We inject this function into the VisionTransformer instances so that
+ # we can use it with interpolated position embeddings without modifying the library source.
+ pretrained.model._resize_pos_embed = types.MethodType(
+ _resize_pos_embed, pretrained.model
+ )
+
+ return pretrained
+
+
+def _make_pretrained_vitb_rn50_384(
+ pretrained, use_readout="ignore", hooks=None, use_vit_only=False
+):
+ model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
+
+ hooks = [0, 1, 8, 11] if hooks == None else hooks
+ return _make_vit_b_rn50_backbone(
+ model,
+ features=[256, 512, 768, 768],
+ size=[384, 384],
+ hooks=hooks,
+ use_vit_only=use_vit_only,
+ use_readout=use_readout,
+ )
diff --git a/ldm/modules/midas/utils.py b/ldm/modules/midas/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a9d3b5b66370fa98da9e067ba53ead848ea9a59
--- /dev/null
+++ b/ldm/modules/midas/utils.py
@@ -0,0 +1,189 @@
+"""Utils for monoDepth."""
+import sys
+import re
+import numpy as np
+import cv2
+import torch
+
+
+def read_pfm(path):
+ """Read pfm file.
+
+ Args:
+ path (str): path to file
+
+ Returns:
+ tuple: (data, scale)
+ """
+ with open(path, "rb") as file:
+
+ color = None
+ width = None
+ height = None
+ scale = None
+ endian = None
+
+ header = file.readline().rstrip()
+ if header.decode("ascii") == "PF":
+ color = True
+ elif header.decode("ascii") == "Pf":
+ color = False
+ else:
+ raise Exception("Not a PFM file: " + path)
+
+ dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
+ if dim_match:
+ width, height = list(map(int, dim_match.groups()))
+ else:
+ raise Exception("Malformed PFM header.")
+
+ scale = float(file.readline().decode("ascii").rstrip())
+ if scale < 0:
+ # little-endian
+ endian = "<"
+ scale = -scale
+ else:
+ # big-endian
+ endian = ">"
+
+ data = np.fromfile(file, endian + "f")
+ shape = (height, width, 3) if color else (height, width)
+
+ data = np.reshape(data, shape)
+ data = np.flipud(data)
+
+ return data, scale
+
+
+def write_pfm(path, image, scale=1):
+ """Write pfm file.
+
+ Args:
+        path (str): path to file
+ image (array): data
+ scale (int, optional): Scale. Defaults to 1.
+ """
+
+ with open(path, "wb") as file:
+ color = None
+
+ if image.dtype.name != "float32":
+ raise Exception("Image dtype must be float32.")
+
+ image = np.flipud(image)
+
+ if len(image.shape) == 3 and image.shape[2] == 3: # color image
+ color = True
+ elif (
+ len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1
+ ): # greyscale
+ color = False
+ else:
+ raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
+
+ file.write("PF\n" if color else "Pf\n".encode())
+ file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
+
+ endian = image.dtype.byteorder
+
+ if endian == "<" or endian == "=" and sys.byteorder == "little":
+ scale = -scale
+
+ file.write("%f\n".encode() % scale)
+
+ image.tofile(file)
+
+
+def read_image(path):
+ """Read image and output RGB image (0-1).
+
+ Args:
+ path (str): path to file
+
+ Returns:
+ array: RGB image (0-1)
+ """
+ img = cv2.imread(path)
+
+ if img.ndim == 2:
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
+
+ return img
+
+
+def resize_image(img):
+ """Resize image and make it fit for network.
+
+ Args:
+ img (array): image
+
+ Returns:
+ tensor: data ready for network
+ """
+ height_orig = img.shape[0]
+ width_orig = img.shape[1]
+
+ if width_orig > height_orig:
+ scale = width_orig / 384
+ else:
+ scale = height_orig / 384
+
+ height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
+ width = (np.ceil(width_orig / scale / 32) * 32).astype(int)
+
+ img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
+
+ img_resized = (
+ torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float()
+ )
+ img_resized = img_resized.unsqueeze(0)
+
+ return img_resized
+
+
+def resize_depth(depth, width, height):
+ """Resize depth map and bring to CPU (numpy).
+
+ Args:
+ depth (tensor): depth
+ width (int): image width
+ height (int): image height
+
+ Returns:
+ array: processed depth
+ """
+ depth = torch.squeeze(depth[0, :, :, :]).to("cpu")
+
+ depth_resized = cv2.resize(
+ depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC
+ )
+
+ return depth_resized
+
+def write_depth(path, depth, bits=1):
+ """Write depth map to pfm and png file.
+
+ Args:
+ path (str): filepath without extension
+ depth (array): depth
+ """
+ write_pfm(path + ".pfm", depth.astype(np.float32))
+
+ depth_min = depth.min()
+ depth_max = depth.max()
+
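+    # Normalize the relative depth to the full 8-bit (bits=1) or 16-bit (bits=2)
+    # range; a constant depth map is written as all zeros.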
+ max_val = (2**(8*bits))-1
+
+ if depth_max - depth_min > np.finfo("float").eps:
+ out = max_val * (depth - depth_min) / (depth_max - depth_min)
+ else:
+        out = np.zeros(depth.shape, dtype=depth.dtype)
+
+ if bits == 1:
+ cv2.imwrite(path + ".png", out.astype("uint8"))
+ elif bits == 2:
+ cv2.imwrite(path + ".png", out.astype("uint16"))
+
+ return
diff --git a/ldm/util.py b/ldm/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..45cb050ece6f401a22dde098ce3f1ff663c5eb6a
--- /dev/null
+++ b/ldm/util.py
@@ -0,0 +1,197 @@
+import importlib
+
+import torch
+from torch import optim
+import numpy as np
+
+from inspect import isfunction
+from PIL import Image, ImageDraw, ImageFont
+
+
+def log_txt_as_img(wh, xc, size=10):
+ # wh a tuple of (width, height)
+ # xc a list of captions to plot
+ b = len(xc)
+ txts = list()
+ for bi in range(b):
+ txt = Image.new("RGB", wh, color="white")
+ draw = ImageDraw.Draw(txt)
+ font = ImageFont.truetype('font/DejaVuSans.ttf', size=size)
+ nc = int(40 * (wh[0] / 256))
+ lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
+
+ try:
+ draw.text((0, 0), lines, fill="black", font=font)
+ except UnicodeEncodeError:
+ print("Cant encode string for logging. Skipping.")
+
+ txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
+ txts.append(txt)
+ txts = np.stack(txts)
+ txts = torch.tensor(txts)
+ return txts
+
+
+def ismap(x):
+ if not isinstance(x, torch.Tensor):
+ return False
+ return (len(x.shape) == 4) and (x.shape[1] > 3)
+
+
+def isimage(x):
+ if not isinstance(x,torch.Tensor):
+ return False
+ return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
+
+
+def exists(x):
+ return x is not None
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+def mean_flat(tensor):
+ """
+ https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def count_params(model, verbose=False):
+ total_params = sum(p.numel() for p in model.parameters())
+ if verbose:
+ print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
+ return total_params
+
+
+def instantiate_from_config(config):
+ if not "target" in config:
+ if config == '__is_first_stage__':
+ return None
+ elif config == "__is_unconditional__":
+ return None
+ raise KeyError("Expected key `target` to instantiate.")
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
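+
+# Example (hypothetical config; any importable "target" with matching "params"
+# behaves the same way):
+#   cfg = {"target": "ldm.modules.midas.api.MiDaSInference",
+#          "params": {"model_type": "dpt_hybrid"}}
+#   midas = instantiate_from_config(cfg)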
+
+
+def get_obj_from_str(string, reload=False):
+ module, cls = string.rsplit(".", 1)
+ if reload:
+ module_imp = importlib.import_module(module)
+ importlib.reload(module_imp)
+ return getattr(importlib.import_module(module, package=None), cls)
+
+
+class AdamWwithEMAandWings(optim.Optimizer):
+ # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
+ def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using
+ weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code
+ ema_power=1., param_names=()):
+ """AdamW that saves EMA versions of the parameters."""
+ if not 0.0 <= lr:
+ raise ValueError("Invalid learning rate: {}".format(lr))
+ if not 0.0 <= eps:
+ raise ValueError("Invalid epsilon value: {}".format(eps))
+ if not 0.0 <= betas[0] < 1.0:
+ raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+ if not 0.0 <= betas[1] < 1.0:
+ raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+ if not 0.0 <= weight_decay:
+ raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+ if not 0.0 <= ema_decay <= 1.0:
+ raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
+ defaults = dict(lr=lr, betas=betas, eps=eps,
+ weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay,
+ ema_power=ema_power, param_names=param_names)
+ super().__init__(params, defaults)
+
+ def __setstate__(self, state):
+ super().__setstate__(state)
+ for group in self.param_groups:
+ group.setdefault('amsgrad', False)
+
+ @torch.no_grad()
+ def step(self, closure=None):
+ """Performs a single optimization step.
+ Args:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ for group in self.param_groups:
+ params_with_grad = []
+ grads = []
+ exp_avgs = []
+ exp_avg_sqs = []
+ ema_params_with_grad = []
+ state_sums = []
+ max_exp_avg_sqs = []
+ state_steps = []
+ amsgrad = group['amsgrad']
+ beta1, beta2 = group['betas']
+ ema_decay = group['ema_decay']
+ ema_power = group['ema_power']
+
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ params_with_grad.append(p)
+ if p.grad.is_sparse:
+ raise RuntimeError('AdamW does not support sparse gradients')
+ grads.append(p.grad)
+
+ state = self.state[p]
+
+ # State initialization
+ if len(state) == 0:
+ state['step'] = 0
+ # Exponential moving average of gradient values
+ state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+ # Exponential moving average of squared gradient values
+ state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+ if amsgrad:
+ # Maintains max of all exp. moving avg. of sq. grad. values
+ state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+ # Exponential moving average of parameter values
+ state['param_exp_avg'] = p.detach().float().clone()
+
+ exp_avgs.append(state['exp_avg'])
+ exp_avg_sqs.append(state['exp_avg_sq'])
+ ema_params_with_grad.append(state['param_exp_avg'])
+
+ if amsgrad:
+ max_exp_avg_sqs.append(state['max_exp_avg_sq'])
+
+ # update the steps for each param group update
+ state['step'] += 1
+ # record the step after step update
+ state_steps.append(state['step'])
+
+ optim._functional.adamw(params_with_grad,
+ grads,
+ exp_avgs,
+ exp_avg_sqs,
+ max_exp_avg_sqs,
+ state_steps,
+ amsgrad=amsgrad,
+ beta1=beta1,
+ beta2=beta2,
+ lr=group['lr'],
+ weight_decay=group['weight_decay'],
+ eps=group['eps'],
+ maximize=False)
+
+ cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power)
+ for param, ema_param in zip(params_with_grad, ema_params_with_grad):
+ ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)
+
+ return loss
\ No newline at end of file
diff --git a/mmpose/__init__.py b/mmpose/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad7946470df9a9e16830af885aeee31e3aaee6ca
--- /dev/null
+++ b/mmpose/__init__.py
@@ -0,0 +1,27 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import mmcv
+import mmengine
+from mmengine.utils import digit_version
+
+from .version import __version__, short_version
+
+mmcv_minimum_version = '2.0.0rc4'
+mmcv_maximum_version = '2.1.0'
+mmcv_version = digit_version(mmcv.__version__)
+
+mmengine_minimum_version = '0.6.0'
+mmengine_maximum_version = '1.0.0'
+mmengine_version = digit_version(mmengine.__version__)
+
+assert (mmcv_version >= digit_version(mmcv_minimum_version)
+ and mmcv_version <= digit_version(mmcv_maximum_version)), \
+ f'MMCV=={mmcv.__version__} is used but incompatible. ' \
+ f'Please install mmcv>={mmcv_minimum_version}, <={mmcv_maximum_version}.'
+
+assert (mmengine_version >= digit_version(mmengine_minimum_version)
+ and mmengine_version <= digit_version(mmengine_maximum_version)), \
+ f'MMEngine=={mmengine.__version__} is used but incompatible. ' \
+ f'Please install mmengine>={mmengine_minimum_version}, ' \
+ f'<={mmengine_maximum_version}.'
+
+__all__ = ['__version__', 'short_version']
diff --git a/mmpose/__pycache__/__init__.cpython-38.pyc b/mmpose/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6684280e5e0d985451799dc57d98757885fde04e
Binary files /dev/null and b/mmpose/__pycache__/__init__.cpython-38.pyc differ
diff --git a/mmpose/__pycache__/registry.cpython-38.pyc b/mmpose/__pycache__/registry.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f95cca270581e9bf04dc43bfc6ada4b51d0f8b5c
Binary files /dev/null and b/mmpose/__pycache__/registry.cpython-38.pyc differ
diff --git a/mmpose/__pycache__/version.cpython-38.pyc b/mmpose/__pycache__/version.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3185f2085a525eefb6a80732305e29e842939e24
Binary files /dev/null and b/mmpose/__pycache__/version.cpython-38.pyc differ
diff --git a/mmpose/apis/__init__.py b/mmpose/apis/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff7149e453552aefe6c3beb35404f281df15943a
--- /dev/null
+++ b/mmpose/apis/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .inference import inference_bottomup, inference_topdown, init_model
+from .inferencers import MMPoseInferencer, Pose2DInferencer
+
+__all__ = [
+ 'init_model', 'inference_topdown', 'inference_bottomup',
+ 'Pose2DInferencer', 'MMPoseInferencer'
+]
diff --git a/mmpose/apis/__pycache__/__init__.cpython-38.pyc b/mmpose/apis/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bbe08892135200eb55616826d7f1ed87fb0745e7
Binary files /dev/null and b/mmpose/apis/__pycache__/__init__.cpython-38.pyc differ
diff --git a/mmpose/apis/__pycache__/inference.cpython-38.pyc b/mmpose/apis/__pycache__/inference.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e21aeea460cb62e28037714dc3d0f4362b7a0598
Binary files /dev/null and b/mmpose/apis/__pycache__/inference.cpython-38.pyc differ
diff --git a/mmpose/apis/inference.py b/mmpose/apis/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..6763d318d5d7fef3f267c235ce6f8210ac6078ce
--- /dev/null
+++ b/mmpose/apis/inference.py
@@ -0,0 +1,225 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+from pathlib import Path
+from typing import List, Optional, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+from mmengine.config import Config
+from mmengine.dataset import Compose, pseudo_collate
+from mmengine.model.utils import revert_sync_batchnorm
+from mmengine.registry import init_default_scope
+from mmengine.runner import load_checkpoint
+from PIL import Image
+
+from mmpose.datasets.datasets.utils import parse_pose_metainfo
+from mmpose.models.builder import build_pose_estimator
+from mmpose.structures import PoseDataSample
+from mmpose.structures.bbox import bbox_xywh2xyxy
+
+
+def dataset_meta_from_config(config: Config,
+ dataset_mode: str = 'train') -> Optional[dict]:
+ """Get dataset metainfo from the model config.
+
+ Args:
+ config (str, :obj:`Path`, or :obj:`mmengine.Config`): Config file path,
+ :obj:`Path`, or the config object.
+ dataset_mode (str): Specify the dataset of which to get the metainfo.
+ Options are ``'train'``, ``'val'`` and ``'test'``. Defaults to
+ ``'train'``
+
+ Returns:
+ dict, optional: The dataset metainfo. See
+ ``mmpose.datasets.datasets.utils.parse_pose_metainfo`` for details.
+ Return ``None`` if failing to get dataset metainfo from the config.
+ """
+ try:
+ if dataset_mode == 'train':
+ dataset_cfg = config.train_dataloader.dataset
+ elif dataset_mode == 'val':
+ dataset_cfg = config.val_dataloader.dataset
+ elif dataset_mode == 'test':
+ dataset_cfg = config.test_dataloader.dataset
+ else:
+ raise ValueError(
+ f'Invalid dataset {dataset_mode} to get metainfo. '
+ 'Should be one of "train", "val", or "test".')
+
+ if 'metainfo' in dataset_cfg:
+ metainfo = dataset_cfg.metainfo
+ else:
+ import mmpose.datasets.datasets # noqa: F401, F403
+ from mmpose.registry import DATASETS
+
+ dataset_class = DATASETS.get(dataset_cfg.type)
+ metainfo = dataset_class.METAINFO
+
+ metainfo = parse_pose_metainfo(metainfo)
+
+ except AttributeError:
+ metainfo = None
+
+ return metainfo
+
+
+def init_model(config: Union[str, Path, Config],
+ checkpoint: Optional[str] = None,
+ device: str = 'cuda:0',
+ cfg_options: Optional[dict] = None) -> nn.Module:
+ """Initialize a pose estimator from a config file.
+
+ Args:
+ config (str, :obj:`Path`, or :obj:`mmengine.Config`): Config file path,
+ :obj:`Path`, or the config object.
+ checkpoint (str, optional): Checkpoint path. If left as None, the model
+ will not load any weights. Defaults to ``None``
+ device (str): The device where the anchors will be put on.
+ Defaults to ``'cuda:0'``.
+ cfg_options (dict, optional): Options to override some settings in
+ the used config. Defaults to ``None``
+
+ Returns:
+ nn.Module: The constructed pose estimator.
+ """
+
+ if isinstance(config, (str, Path)):
+ config = Config.fromfile(config)
+ elif not isinstance(config, Config):
+ raise TypeError('config must be a filename or Config object, '
+ f'but got {type(config)}')
+ if cfg_options is not None:
+ config.merge_from_dict(cfg_options)
+ elif 'init_cfg' in config.model.backbone:
+ config.model.backbone.init_cfg = None
+ config.model.train_cfg = None
+
+ # register all modules in mmpose into the registries
+ init_default_scope(config.get('default_scope', 'mmpose'))
+
+ model = build_pose_estimator(config.model)
+ model = revert_sync_batchnorm(model)
+ # get dataset_meta in this priority: checkpoint > config > default (COCO)
+ dataset_meta = None
+
+ if checkpoint is not None:
+ ckpt = load_checkpoint(model, checkpoint, map_location='cpu')
+
+ if 'dataset_meta' in ckpt.get('meta', {}):
+ # checkpoint from mmpose 1.x
+ dataset_meta = ckpt['meta']['dataset_meta']
+
+ if dataset_meta is None:
+ dataset_meta = dataset_meta_from_config(config, dataset_mode='train')
+
+ if dataset_meta is None:
+ warnings.simplefilter('once')
+ warnings.warn('Cannot load dataset_meta from the checkpoint or the '
+ 'model config. Use COCO metainfo by default.')
+ dataset_meta = parse_pose_metainfo(
+ dict(from_file='configs/_base_/datasets/coco.py'))
+
+ model.dataset_meta = dataset_meta
+
+ model.cfg = config # save the config in the model for convenience
+ model.to(device)
+ model.eval()
+ return model
+
+
+def inference_topdown(model: nn.Module,
+ img: Union[np.ndarray, str],
+ bboxes: Optional[Union[List, np.ndarray]] = None,
+ bbox_format: str = 'xyxy') -> List[PoseDataSample]:
+ """Inference image with a top-down pose estimator.
+
+ Args:
+ model (nn.Module): The top-down pose estimator
+ img (np.ndarray | str): The loaded image or image file to inference
+ bboxes (np.ndarray, optional): The bboxes in shape (N, 4), each row
+ represents a bbox. If not given, the entire image will be regarded
+ as a single bbox area. Defaults to ``None``
+ bbox_format (str): The bbox format indicator. Options are ``'xywh'``
+ and ``'xyxy'``. Defaults to ``'xyxy'``
+
+ Returns:
+ List[:obj:`PoseDataSample`]: The inference results. Specifically, the
+ predicted keypoints and scores are saved at
+ ``data_sample.pred_instances.keypoints`` and
+ ``data_sample.pred_instances.keypoint_scores``.
+ """
+ init_default_scope(model.cfg.get('default_scope', 'mmpose'))
+ pipeline = Compose(model.cfg.test_dataloader.dataset.pipeline)
+
+ if bboxes is None:
+ # get bbox from the image size
+ if isinstance(img, str):
+ w, h = Image.open(img).size
+ else:
+ h, w = img.shape[:2]
+
+ bboxes = np.array([[0, 0, w, h]], dtype=np.float32)
+ else:
+ if isinstance(bboxes, list):
+ bboxes = np.array(bboxes)
+
+ assert bbox_format in {'xyxy', 'xywh'}, \
+ f'Invalid bbox_format "{bbox_format}".'
+
+ if bbox_format == 'xywh':
+ bboxes = bbox_xywh2xyxy(bboxes)
+
+ # construct batch data samples
+ data_list = []
+ for bbox in bboxes:
+ if isinstance(img, str):
+ data_info = dict(img_path=img)
+ else:
+ data_info = dict(img=img)
+ data_info['bbox'] = bbox[None] # shape (1, 4)
+ data_info['bbox_score'] = np.ones(1, dtype=np.float32) # shape (1,)
+ data_info.update(model.dataset_meta)
+ data_list.append(pipeline(data_info))
+
+ if data_list:
+ # collate data list into a batch, which is a dict with following keys:
+ # batch['inputs']: a list of input images
+ # batch['data_samples']: a list of :obj:`PoseDataSample`
+ batch = pseudo_collate(data_list)
+ with torch.no_grad():
+ results = model.test_step(batch)
+ else:
+ results = []
+
+ return results
+
+
+def inference_bottomup(model: nn.Module, img: Union[np.ndarray, str]):
+ """Inference image with a bottom-up pose estimator.
+
+ Args:
+ model (nn.Module): The bottom-up pose estimator
+ img (np.ndarray | str): The loaded image or image file to inference
+
+ Returns:
+ List[:obj:`PoseDataSample`]: The inference results. Specifically, the
+ predicted keypoints and scores are saved at
+ ``data_sample.pred_instances.keypoints`` and
+ ``data_sample.pred_instances.keypoint_scores``.
+ """
+ pipeline = Compose(model.cfg.test_dataloader.dataset.pipeline)
+
+ # prepare data batch
+ if isinstance(img, str):
+ data_info = dict(img_path=img)
+ else:
+ data_info = dict(img=img)
+ data_info.update(model.dataset_meta)
+ data = pipeline(data_info)
+ batch = pseudo_collate([data])
+
+ with torch.no_grad():
+ results = model.test_step(batch)
+
+ return results
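
A minimal usage sketch for the single-image API added above; the config, checkpoint, and image paths below are placeholders rather than files shipped in this commit.

```python
# Usage sketch for init_model / inference_topdown (paths are placeholders).
import numpy as np

from mmpose.apis.inference import inference_topdown, init_model

model = init_model(
    'configs/my_pose_config.py',           # hypothetical config path
    checkpoint='checkpoints/my_pose.pth',  # hypothetical checkpoint path
    device='cpu')

# Two person boxes in xyxy format; omit bboxes to use the whole image instead.
bboxes = np.array([[50, 60, 200, 400], [210, 80, 380, 420]], dtype=np.float32)
results = inference_topdown(model, 'demo.jpg', bboxes, bbox_format='xyxy')
for data_sample in results:
    print(data_sample.pred_instances.keypoints.shape)
```
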
diff --git a/mmpose/apis/inferencers/__init__.py b/mmpose/apis/inferencers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3db192da738d98acca3739a501701024cf2fcb02
--- /dev/null
+++ b/mmpose/apis/inferencers/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .mmpose_inferencer import MMPoseInferencer
+from .pose2d_inferencer import Pose2DInferencer
+from .utils import get_model_aliases
+
+__all__ = ['Pose2DInferencer', 'MMPoseInferencer', 'get_model_aliases']
diff --git a/mmpose/apis/inferencers/__pycache__/__init__.cpython-38.pyc b/mmpose/apis/inferencers/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e9086cfd4833008cfd13339e4bc768f6ba56c48d
Binary files /dev/null and b/mmpose/apis/inferencers/__pycache__/__init__.cpython-38.pyc differ
diff --git a/mmpose/apis/inferencers/__pycache__/base_mmpose_inferencer.cpython-38.pyc b/mmpose/apis/inferencers/__pycache__/base_mmpose_inferencer.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..931c890f2b84376e361bc6018ac62f7242ddeadb
Binary files /dev/null and b/mmpose/apis/inferencers/__pycache__/base_mmpose_inferencer.cpython-38.pyc differ
diff --git a/mmpose/apis/inferencers/__pycache__/mmpose_inferencer.cpython-38.pyc b/mmpose/apis/inferencers/__pycache__/mmpose_inferencer.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7fd9067ddaf07bf67e8700b0ba359643cfd93c45
Binary files /dev/null and b/mmpose/apis/inferencers/__pycache__/mmpose_inferencer.cpython-38.pyc differ
diff --git a/mmpose/apis/inferencers/__pycache__/pose2d_inferencer.cpython-38.pyc b/mmpose/apis/inferencers/__pycache__/pose2d_inferencer.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..936ad40424e6b4b0361d85063b13b1fc7bd9c411
Binary files /dev/null and b/mmpose/apis/inferencers/__pycache__/pose2d_inferencer.cpython-38.pyc differ
diff --git a/mmpose/apis/inferencers/base_mmpose_inferencer.py b/mmpose/apis/inferencers/base_mmpose_inferencer.py
new file mode 100644
index 0000000000000000000000000000000000000000..86e61463b698f7e01942d1b204f92f975ff472d1
--- /dev/null
+++ b/mmpose/apis/inferencers/base_mmpose_inferencer.py
@@ -0,0 +1,435 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import mimetypes
+import os
+import warnings
+from collections import defaultdict
+from typing import (Callable, Dict, Generator, Iterable, List, Optional,
+ Sequence, Union)
+
+import cv2
+import mmcv
+import mmengine
+import numpy as np
+import torch.nn as nn
+from mmengine.config import Config, ConfigDict
+from mmengine.dataset import Compose
+from mmengine.fileio import (get_file_backend, isdir, join_path,
+ list_dir_or_file)
+from mmengine.infer.infer import BaseInferencer
+from mmengine.runner.checkpoint import _load_checkpoint_to_model
+from mmengine.structures import InstanceData
+from mmengine.utils import mkdir_or_exist
+
+from mmpose.apis.inference import dataset_meta_from_config
+from mmpose.structures import PoseDataSample, split_instances
+
+InstanceList = List[InstanceData]
+InputType = Union[str, np.ndarray]
+InputsType = Union[InputType, Sequence[InputType]]
+PredType = Union[InstanceData, InstanceList]
+ImgType = Union[np.ndarray, Sequence[np.ndarray]]
+ConfigType = Union[Config, ConfigDict]
+ResType = Union[Dict, List[Dict], InstanceData, List[InstanceData]]
+
+
+class BaseMMPoseInferencer(BaseInferencer):
+ """The base class for MMPose inferencers."""
+
+ preprocess_kwargs: set = {'bbox_thr', 'nms_thr', 'bboxes'}
+ forward_kwargs: set = set()
+ visualize_kwargs: set = {
+ 'return_vis',
+ 'show',
+ 'wait_time',
+ 'draw_bbox',
+ 'radius',
+ 'thickness',
+ 'kpt_thr',
+ 'vis_out_dir',
+ }
+ postprocess_kwargs: set = {'pred_out_dir'}
+
+ def _load_weights_to_model(self, model: nn.Module,
+ checkpoint: Optional[dict],
+ cfg: Optional[ConfigType]) -> None:
+ """Loading model weights and meta information from cfg and checkpoint.
+
+ Subclasses could override this method to load extra meta information
+ from ``checkpoint`` and ``cfg`` to model.
+
+ Args:
+ model (nn.Module): Model to load weights and meta information.
+ checkpoint (dict, optional): The loaded checkpoint.
+ cfg (Config or ConfigDict, optional): The loaded config.
+ """
+ if checkpoint is not None:
+ _load_checkpoint_to_model(model, checkpoint)
+ checkpoint_meta = checkpoint.get('meta', {})
+ # save the dataset_meta in the model for convenience
+ if 'dataset_meta' in checkpoint_meta:
+ # mmpose 1.x
+ model.dataset_meta = checkpoint_meta['dataset_meta']
+ else:
+ warnings.warn(
+ 'dataset_meta is not saved in the checkpoint\'s '
+ 'meta data; loading it from the config instead.')
+ model.dataset_meta = dataset_meta_from_config(
+ cfg, dataset_mode='train')
+ else:
+ warnings.warn('Checkpoint is not loaded, and the inference '
+ 'result is calculated by the randomly initialized '
+ 'model!')
+ model.dataset_meta = dataset_meta_from_config(
+ cfg, dataset_mode='train')
+
+ def _inputs_to_list(self, inputs: InputsType) -> Iterable:
+ """Preprocess the inputs to a list.
+
+ Preprocess inputs to a list according to its type:
+
+ - list or tuple: return inputs
+ - str:
+ - Directory path: return all files in the directory
+ - other cases: return a list containing the string. The string
+ could be a path to file, a url or other types of string
+ according to the task.
+
+ Args:
+ inputs (InputsType): Inputs for the inferencer.
+
+ Returns:
+ list: List of input for the :meth:`preprocess`.
+ """
+ self._video_input = False
+
+ if isinstance(inputs, str):
+ backend = get_file_backend(inputs)
+ if hasattr(backend, 'isdir') and isdir(inputs):
+ # Backends like HttpsBackend do not implement `isdir`, so only
+ # those backends that implement `isdir` could accept the
+ # inputs as a directory
+ filepath_list = [
+ join_path(inputs, fname)
+ for fname in list_dir_or_file(inputs, list_dir=False)
+ ]
+ inputs = []
+ for filepath in filepath_list:
+ input_type = mimetypes.guess_type(filepath)[0].split(
+ '/')[0]
+ if input_type == 'image':
+ inputs.append(filepath)
+ inputs.sort()
+ else:
+ # if inputs is a path to a video file, it will be converted
+ # to a list containing separated frame filenames
+ input_type = mimetypes.guess_type(inputs)[0].split('/')[0]
+ if input_type == 'video':
+ self._video_input = True
+ video = mmcv.VideoReader(inputs)
+ self.video_info = dict(
+ fps=video.fps,
+ name=os.path.basename(inputs),
+ writer=None,
+ predictions=[])
+ inputs = video
+ elif input_type == 'image':
+ inputs = [inputs]
+ else:
+ raise ValueError(f'Expected input to be an image, video, '
+ f'or folder, but received {inputs} of '
+ f'type {input_type}.')
+
+ elif isinstance(inputs, np.ndarray):
+ inputs = [inputs]
+
+ return inputs
+
+ def _get_webcam_inputs(self, inputs: str) -> Generator:
+ """Sets up and returns a generator function that reads frames from a
+ webcam input. The generator function returns a new frame each time it
+ is iterated over.
+
+ Args:
+ inputs (str): A string describing the webcam input, in the format
+ "webcam:id".
+
+ Returns:
+ A generator that yields frames from the webcam input.
+
+ Raises:
+ ValueError: If the inputs string is not in the expected format.
+ """
+ assert getattr(self.visualizer, 'backend', None) == 'opencv', \
+ 'Visualizer must utilize the OpenCV backend in order to ' \
+ 'support webcam inputs.'
+
+ # Ensure the inputs string is in the expected format.
+ inputs = inputs.lower()
+ assert inputs.startswith('webcam'), f'Expected input to start with ' \
+ f'"webcam", but got "{inputs}"'
+
+ # Parse the camera ID from the inputs string.
+ inputs_ = inputs.split(':')
+ if len(inputs_) == 1:
+ camera_id = 0
+ elif len(inputs_) == 2 and str.isdigit(inputs_[1]):
+ camera_id = int(inputs_[1])
+ else:
+ raise ValueError(
+ f'Expected webcam input to have format "webcam:id", '
+ f'but got "{inputs}"')
+
+ # Attempt to open the video capture object.
+ vcap = cv2.VideoCapture(camera_id)
+ if not vcap.isOpened():
+ warnings.warn(f'Cannot open camera (ID={camera_id})')
+ return []
+
+ # Set video input flag and metadata.
+ self._video_input = True
+ self.video_info = dict(
+ fps=10, name='webcam.mp4', writer=None, predictions=[])
+
+ def _webcam_reader() -> Generator:
+ while True:
+ if cv2.waitKey(5) & 0xFF == 27:
+ vcap.release()
+ break
+
+ ret_val, frame = vcap.read()
+ if not ret_val:
+ break
+
+ yield frame
+
+ return _webcam_reader()
+
+ def _visualization_window_on_close(self, event):
+ self._window_closing = True
+
+ def _init_pipeline(self, cfg: ConfigType) -> Callable:
+ """Initialize the test pipeline.
+
+ Args:
+ cfg (ConfigType): model config path or dict
+
+ Returns:
+ A pipeline to handle various input data, such as ``str``,
+ ``np.ndarray``. The returned pipeline will be used to process
+ a single data.
+ """
+ return Compose(cfg.test_dataloader.dataset.pipeline)
+
+ def preprocess(self,
+ inputs: InputsType,
+ batch_size: int = 1,
+ bboxes: Optional[List] = None,
+ **kwargs):
+ """Process the inputs into a model-feedable format.
+
+ Args:
+ inputs (InputsType): Inputs given by user.
+ batch_size (int): batch size. Defaults to 1.
+
+ Yields:
+ Any: Data processed by the ``pipeline`` and ``collate_fn``.
+ List[str or np.ndarray]: List of original inputs in the batch
+ """
+
+ for i, input in enumerate(inputs):
+ bbox = bboxes[i] if bboxes is not None else []
+ data_infos = self.preprocess_single(
+ input, index=i, bboxes=bbox, **kwargs)
+ # only supports inference with batch size 1
+ yield self.collate_fn(data_infos), [input]
+
+ def visualize(self,
+ inputs: list,
+ preds: List[PoseDataSample],
+ return_vis: bool = False,
+ show: bool = False,
+ draw_bbox: bool = False,
+ wait_time: float = 0,
+ radius: int = 3,
+ thickness: int = 1,
+ kpt_thr: float = 0.3,
+ vis_out_dir: str = '',
+ window_name: str = '',
+ window_close_event_handler: Optional[Callable] = None
+ ) -> List[np.ndarray]:
+ """Visualize predictions.
+
+ Args:
+ inputs (list): Inputs preprocessed by :meth:`_inputs_to_list`.
+ preds (Any): Predictions of the model.
+ return_vis (bool): Whether to return images with predicted results.
+ show (bool): Whether to display the image in a popup window.
+ Defaults to False.
+ wait_time (float): The interval of show (ms). Defaults to 0
+ draw_bbox (bool): Whether to draw the bounding boxes.
+ Defaults to False
+ radius (int): Keypoint radius for visualization. Defaults to 3
+ thickness (int): Link thickness for visualization. Defaults to 1
+ kpt_thr (float): The threshold to visualize the keypoints.
+ Defaults to 0.3
+ vis_out_dir (str, optional): Directory to save visualization
+ results w/o predictions. If left as empty, no file will
+ be saved. Defaults to ''.
+ window_name (str, optional): Title of display window.
+ window_close_event_handler (callable, optional): The handler of the
+ visualization window close event. Defaults to None.
+
+ Returns:
+ List[np.ndarray]: Visualization results.
+ """
+ if (not return_vis) and (not show) and (not vis_out_dir):
+ return
+
+ if getattr(self, 'visualizer', None) is None:
+ raise ValueError('Visualization needs the "visualizer" term '
+ 'defined in the config, but got None.')
+
+ self.visualizer.radius = radius
+ self.visualizer.line_width = thickness
+
+ results = []
+
+ for single_input, pred in zip(inputs, preds):
+ if isinstance(single_input, str):
+ img = mmcv.imread(single_input, channel_order='rgb')
+ elif isinstance(single_input, np.ndarray):
+ img = mmcv.bgr2rgb(single_input)
+ else:
+ raise ValueError('Unsupported input type: '
+ f'{type(single_input)}')
+
+ img_name = os.path.basename(pred.metainfo['img_path'])
+ window_name = window_name if window_name else img_name
+
+ # since visualization and inference utilize the same process,
+ # the wait time is reduced when a video input is utilized,
+ # thereby eliminating the issue of inference getting stuck.
+ wait_time = 1e-5 if self._video_input else wait_time
+
+ visualization = self.visualizer.add_datasample(
+ window_name,
+ img,
+ pred,
+ draw_gt=False,
+ draw_bbox=draw_bbox,
+ draw_heatmap=True,
+ show=show,
+ wait_time=wait_time,
+ kpt_thr=kpt_thr)
+ results.append(visualization)
+
+ if vis_out_dir:
+ out_img = mmcv.rgb2bgr(visualization)
+
+ if self._video_input:
+
+ if self.video_info['writer'] is None:
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+ mkdir_or_exist(vis_out_dir)
+ out_file = join_path(
+ vis_out_dir,
+ os.path.basename(self.video_info['name']))
+ self.video_info['writer'] = cv2.VideoWriter(
+ out_file, fourcc, self.video_info['fps'],
+ (visualization.shape[1], visualization.shape[0]))
+ self.video_info['writer'].write(out_img)
+
+ else:
+ out_file = join_path(vis_out_dir, img_name)
+ mmcv.imwrite(out_img, out_file)
+
+ if return_vis:
+ return results
+ else:
+ return []
+
+ def postprocess(
+ self,
+ preds: List[PoseDataSample],
+ visualization: List[np.ndarray],
+ return_datasample=False,
+ pred_out_dir: str = '',
+ ) -> dict:
+ """Process the predictions and visualization results from ``forward``
+ and ``visualize``.
+
+ This method should be responsible for the following tasks:
+
+ 1. Convert datasamples into a json-serializable dict if needed.
+ 2. Pack the predictions and visualization results and return them.
+ 3. Dump or log the predictions.
+
+ Args:
+ preds (List[Dict]): Predictions of the model.
+ visualization (np.ndarray): Visualized predictions.
+ return_datasample (bool): Whether to return results as
+ datasamples. Defaults to False.
+ pred_out_dir (str): Directory to save the inference results w/o
+ visualization. If left as empty, no file will be saved.
+ Defaults to ''.
+
+ Returns:
+ dict: Inference and visualization results with key ``predictions``
+ and ``visualization``
+
+ - ``visualization (Any)``: Returned by :meth:`visualize`
+ - ``predictions`` (dict or DataSample): Returned by
+ :meth:`forward` and processed in :meth:`postprocess`.
+ If ``return_datasample=False``, it usually should be a
+ json-serializable dict containing only basic data elements such
+ as strings and numbers.
+ """
+
+ result_dict = defaultdict(list)
+
+ result_dict['visualization'] = visualization
+ for pred in preds:
+ if not return_datasample:
+ # convert datasamples to list of instance predictions
+ pred = split_instances(pred.pred_instances)
+ result_dict['predictions'].append(pred)
+
+ if pred_out_dir != '':
+ for pred, data_sample in zip(result_dict['predictions'], preds):
+ if self._video_input:
+ self.video_info['predictions'].append(pred)
+ else:
+ fname = os.path.splitext(
+ os.path.basename(
+ data_sample.metainfo['img_path']))[0] + '.json'
+ mmengine.dump(
+ pred, join_path(pred_out_dir, fname), indent=' ')
+
+ return result_dict
+
+ def _finalize_video_processing(
+ self,
+ pred_out_dir: str = '',
+ ):
+ """Finalize video processing by releasing the video writer and saving
+ predictions to a file.
+
+ This method should be called after completing the video processing. It
+ releases the video writer, if it exists, and saves the predictions to a
+ JSON file if a prediction output directory is provided.
+ """
+
+ # Release the video writer if it exists
+ if self.video_info['writer'] is not None:
+ self.video_info['writer'].release()
+
+ # Save predictions
+ if pred_out_dir:
+ fname = os.path.splitext(
+ os.path.basename(self.video_info['name']))[0] + '.json'
+ predictions = [
+ dict(frame_id=i, instances=pred)
+ for i, pred in enumerate(self.video_info['predictions'])
+ ]
+
+ mmengine.dump(
+ predictions, join_path(pred_out_dir, fname), indent=' ')
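
`_get_webcam_inputs` expects its input string in the form `"webcam:id"`. A standalone restatement of that parsing, useful for checking the convention (not part of the committed module):

```python
# Standalone restatement of the "webcam:id" parsing used above.
def parse_webcam_id(inputs: str) -> int:
    inputs = inputs.lower()
    assert inputs.startswith('webcam'), f'got "{inputs}"'
    parts = inputs.split(':')
    if len(parts) == 1:                 # "webcam"   -> default camera 0
        return 0
    if len(parts) == 2 and parts[1].isdigit():
        return int(parts[1])            # "webcam:1" -> camera 1
    raise ValueError(f'Expected "webcam:id", got "{inputs}"')


assert parse_webcam_id('webcam') == 0
assert parse_webcam_id('Webcam:2') == 2
```
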
diff --git a/mmpose/apis/inferencers/mmpose_inferencer.py b/mmpose/apis/inferencers/mmpose_inferencer.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d5ac222bb3a5c37b89531706c8ce2d94f5f55a1
--- /dev/null
+++ b/mmpose/apis/inferencers/mmpose_inferencer.py
@@ -0,0 +1,281 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+from typing import Dict, List, Optional, Sequence, Union
+
+import numpy as np
+import torch
+from mmengine.config import Config, ConfigDict
+from mmengine.infer.infer import ModelType
+from mmengine.structures import InstanceData
+from rich.progress import track
+
+from mmpose.structures import PoseDataSample
+from .base_mmpose_inferencer import BaseMMPoseInferencer
+from .pose2d_inferencer import Pose2DInferencer
+
+InstanceList = List[InstanceData]
+InputType = Union[str, np.ndarray]
+InputsType = Union[InputType, Sequence[InputType]]
+PredType = Union[InstanceData, InstanceList]
+ImgType = Union[np.ndarray, Sequence[np.ndarray]]
+ConfigType = Union[Config, ConfigDict]
+ResType = Union[Dict, List[Dict], InstanceData, List[InstanceData]]
+
+
+class MMPoseInferencer(BaseMMPoseInferencer):
+ """MMPose Inferencer. It's a unified inferencer interface for pose
+ estimation task, currently including: Pose2D. and it can be used to perform
+ 2D keypoint detection.
+
+ Args:
+ pose2d (str, optional): Pretrained 2D pose estimation algorithm.
+ It's the path to the config file or the model name defined in
+ metafile. For example, it could be:
+
+ - model alias, e.g. ``'body'``,
+ - config name, e.g. ``'simcc_res50_8xb64-210e_coco-256x192'``,
+ - config path
+
+ Defaults to ``None``.
+ pose2d_weights (str, optional): Path to the custom checkpoint file of
+ the selected pose2d model. If it is not specified and "pose2d" is
+ a model name of metafile, the weights will be loaded from
+ metafile. Defaults to None.
+ device (str, optional): Device to run inference. If None, the
+ available device will be automatically used. Defaults to None.
+ scope (str, optional): The scope of the model. Defaults to "mmpose".
+ det_model(str, optional): Config path or alias of detection model.
+ Defaults to None.
+ det_weights(str, optional): Path to the checkpoints of detection
+ model. Defaults to None.
+ det_cat_ids(int or list[int], optional): Category id for
+ detection model. Defaults to None.
+ output_heatmaps (bool, optional): Flag to visualize predicted
+ heatmaps. If set to None, the default setting from the model
+ config will be used. Default is None.
+ """
+
+ preprocess_kwargs: set = {'bbox_thr', 'nms_thr', 'bboxes'}
+ forward_kwargs: set = set()
+ visualize_kwargs: set = {
+ 'return_vis',
+ 'show',
+ 'wait_time',
+ 'draw_bbox',
+ 'radius',
+ 'thickness',
+ 'kpt_thr',
+ 'vis_out_dir',
+ }
+ postprocess_kwargs: set = {'pred_out_dir'}
+
+ def __init__(self,
+ pose2d: Optional[str] = None,
+ pose2d_weights: Optional[str] = None,
+ device: Optional[str] = None,
+ scope: str = 'mmpose',
+ det_model: Optional[Union[ModelType, str]] = None,
+ det_weights: Optional[str] = None,
+ det_cat_ids: Optional[Union[int, List]] = None,
+ output_heatmaps: Optional[bool] = None) -> None:
+
+ if pose2d is None:
+ raise ValueError('A 2D pose estimation algorithm must be provided.')
+
+ self.visualizer = None
+ self.inferencers = dict()
+ if pose2d is not None:
+ self.inferencers['pose2d'] = Pose2DInferencer(
+ pose2d, pose2d_weights, device, scope, det_model, det_weights,
+ det_cat_ids, output_heatmaps)
+
+ def preprocess(self, inputs: InputsType, batch_size: int = 1, **kwargs):
+ """Process the inputs into a model-feedable format.
+
+ Args:
+ inputs (InputsType): Inputs given by user.
+ batch_size (int): batch size. Defaults to 1.
+
+ Yields:
+ Any: Data processed by the ``pipeline`` and ``collate_fn``.
+ List[str or np.ndarray]: List of original inputs in the batch
+ """
+
+ for i, input in enumerate(inputs):
+ data_batch = {}
+ if 'pose2d' in self.inferencers:
+ data_infos = self.inferencers['pose2d'].preprocess_single(
+ input, index=i, **kwargs)
+ data_batch['pose2d'] = self.inferencers['pose2d'].collate_fn(
+ data_infos)
+ # only supports inference with batch size 1
+ yield data_batch, [input]
+
+ @torch.no_grad()
+ def forward(self, inputs: InputType, **forward_kwargs) -> PredType:
+ """Forward the inputs to the model.
+
+ Args:
+ inputs (InputsType): The inputs to be forwarded.
+
+ Returns:
+ Dict: The prediction results. Possibly with keys "pose2d".
+ """
+ result = {}
+ for mode, inferencer in self.inferencers.items():
+ result[mode] = inferencer.forward(inputs[mode], **forward_kwargs)
+
+ return result
+
+ def __call__(
+ self,
+ inputs: InputsType,
+ return_datasample: bool = False,
+ batch_size: int = 1,
+ out_dir: Optional[str] = None,
+ **kwargs,
+ ) -> dict:
+ """Call the inferencer.
+
+ Args:
+ inputs (InputsType): Inputs for the inferencer.
+ return_datasample (bool): Whether to return results as
+ :obj:`BaseDataElement`. Defaults to False.
+ batch_size (int): Batch size. Defaults to 1.
+ out_dir (str, optional): Directory to save visualization
+ results and predictions. Will be overridden if vis_out_dir or
+ pred_out_dir are given. Defaults to None
+ **kwargs: Key words arguments passed to :meth:`preprocess`,
+ :meth:`forward`, :meth:`visualize` and :meth:`postprocess`.
+ Each key in kwargs should be in the corresponding set of
+ ``preprocess_kwargs``, ``forward_kwargs``,
+ ``visualize_kwargs`` and ``postprocess_kwargs``.
+
+ Returns:
+ dict: Inference and visualization results.
+ """
+ if out_dir is not None:
+ if 'vis_out_dir' not in kwargs:
+ kwargs['vis_out_dir'] = f'{out_dir}/visualizations'
+ if 'pred_out_dir' not in kwargs:
+ kwargs['pred_out_dir'] = f'{out_dir}/predictions'
+ (
+ preprocess_kwargs,
+ forward_kwargs,
+ visualize_kwargs,
+ postprocess_kwargs,
+ ) = self._dispatch_kwargs(**kwargs)
+
+ # preprocessing
+ if isinstance(inputs, str) and inputs.startswith('webcam'):
+ inputs = self._get_webcam_inputs(inputs)
+ batch_size = 1
+ if not visualize_kwargs.get('show', False):
+ warnings.warn('The display mode is disabled when using webcam '
+ 'input; it will be turned on automatically.')
+ visualize_kwargs['show'] = True
+ else:
+ inputs = self._inputs_to_list(inputs)
+
+ inputs = self.preprocess(
+ inputs, batch_size=batch_size, **preprocess_kwargs)
+
+ # forward
+ forward_kwargs['bbox_thr'] = preprocess_kwargs.get('bbox_thr', -1)
+ for inferencer in self.inferencers.values():
+ inferencer._video_input = self._video_input
+ if self._video_input:
+ inferencer.video_info = self.video_info
+
+ preds = []
+ if 'pose2d' not in self.inferencers or not hasattr(
+ self.inferencers['pose2d'], 'detector'):
+ inputs = track(inputs, description='Inference')
+
+ for proc_inputs, ori_inputs in inputs:
+ preds = self.forward(proc_inputs, **forward_kwargs)
+
+ visualization = self.visualize(ori_inputs, preds,
+ **visualize_kwargs)
+ results = self.postprocess(preds, visualization, return_datasample,
+ **postprocess_kwargs)
+ yield results
+
+ if self._video_input:
+ self._finalize_video_processing(
+ postprocess_kwargs.get('pred_out_dir', ''))
+
+ def visualize(self, inputs: InputsType, preds: PredType,
+ **kwargs) -> List[np.ndarray]:
+ """Visualize predictions.
+
+ Args:
+ inputs (list): Inputs preprocessed by :meth:`_inputs_to_list`.
+ preds (Any): Predictions of the model.
+ return_vis (bool): Whether to return images with predicted results.
+ show (bool): Whether to display the image in a popup window.
+ Defaults to False.
+ wait_time (float): The interval of show (ms). Defaults to 0
+ radius (int): Keypoint radius for visualization. Defaults to 3
+ thickness (int): Link thickness for visualization. Defaults to 1
+ kpt_thr (float): The threshold to visualize the keypoints.
+ Defaults to 0.3
+ vis_out_dir (str, optional): directory to save visualization
+ results w/o predictions. If left as empty, no file will
+ be saved. Defaults to ''.
+
+ Returns:
+ List[np.ndarray]: Visualization results.
+ """
+
+ if 'pose2d' in self.inferencers:
+ window_name = ''
+ if self._video_input:
+ window_name = self.video_info['name']
+ return self.inferencers['pose2d'].visualize(
+ inputs,
+ preds['pose2d'],
+ window_name=window_name,
+ window_close_event_handler=self._visualization_window_on_close,
+ **kwargs)
+
+ def postprocess(
+ self,
+ preds: List[PoseDataSample],
+ visualization: List[np.ndarray],
+ return_datasample=False,
+ pred_out_dir: str = '',
+ ) -> dict:
+ """Process the predictions and visualization results from ``forward``
+ and ``visualize``.
+
+ This method should be responsible for the following tasks:
+
+ 1. Convert datasamples into a json-serializable dict if needed.
+ 2. Pack the predictions and visualization results and return them.
+ 3. Dump or log the predictions.
+
+ Args:
+ preds (List[Dict]): Predictions of the model.
+ visualization (np.ndarray): Visualized predictions.
+ return_datasample (bool): Whether to return results as
+ datasamples. Defaults to False.
+ pred_out_dir (str): Directory to save the inference results w/o
+ visualization. If left as empty, no file will be saved.
+ Defaults to ''.
+
+ Returns:
+ dict: Inference and visualization results with key ``predictions``
+ and ``visualization``
+
+ - ``visualization (Any)``: Returned by :meth:`visualize`
+ - ``predictions`` (dict or DataSample): Returned by
+ :meth:`forward` and processed in :meth:`postprocess`.
+ If ``return_datasample=False``, it usually should be a
+ json-serializable dict containing only basic data elements such
+ as strings and numbers.
+ """
+
+ if 'pose2d' in self.inferencers:
+ return super().postprocess(preds['pose2d'], visualization,
+ return_datasample, pred_out_dir)
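
Since `__call__` is a generator, results are consumed by iteration. A usage sketch, assuming the `'body'` alias mentioned in the docstring resolves through the installed metafiles; `demo.jpg` and the output directory are placeholder paths:

```python
# Usage sketch: MMPoseInferencer.__call__ yields one result dict per input.
from mmpose.apis.inferencers import MMPoseInferencer

inferencer = MMPoseInferencer(pose2d='body', device='cpu')

result_generator = inferencer('demo.jpg', out_dir='outputs')
for result in result_generator:
    # result['predictions'] holds per-instance keypoint dicts;
    # result['visualization'] holds the rendered images (when requested).
    print(len(result['predictions']))
```
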
diff --git a/mmpose/apis/inferencers/pose2d_inferencer.py b/mmpose/apis/inferencers/pose2d_inferencer.py
new file mode 100644
index 0000000000000000000000000000000000000000..b35abddb19aa35497d5df7da9d2f4c084dd38ecb
--- /dev/null
+++ b/mmpose/apis/inferencers/pose2d_inferencer.py
@@ -0,0 +1,289 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+import warnings
+from typing import Dict, List, Optional, Sequence, Tuple, Union
+
+import mmcv
+import numpy as np
+import torch
+from mmengine.config import Config, ConfigDict
+from mmengine.infer.infer import ModelType
+from mmengine.model import revert_sync_batchnorm
+from mmengine.registry import init_default_scope
+from mmengine.structures import InstanceData
+from rich.progress import track
+
+from mmpose.evaluation.functional import nms
+from mmpose.registry import DATASETS, INFERENCERS
+from mmpose.structures import merge_data_samples
+from .base_mmpose_inferencer import BaseMMPoseInferencer
+from .utils import default_det_models
+
+try:
+ from mmdet.apis.det_inferencer import DetInferencer
+ has_mmdet = True
+except (ImportError, ModuleNotFoundError):
+ has_mmdet = False
+
+InstanceList = List[InstanceData]
+InputType = Union[str, np.ndarray]
+InputsType = Union[InputType, Sequence[InputType]]
+PredType = Union[InstanceData, InstanceList]
+ImgType = Union[np.ndarray, Sequence[np.ndarray]]
+ConfigType = Union[Config, ConfigDict]
+ResType = Union[Dict, List[Dict], InstanceData, List[InstanceData]]
+
+
+@INFERENCERS.register_module(name='pose-estimation')
+@INFERENCERS.register_module()
+class Pose2DInferencer(BaseMMPoseInferencer):
+ """The inferencer for 2D pose estimation.
+
+ Args:
+ model (str, optional): Pretrained 2D pose estimation algorithm.
+ It's the path to the config file or the model name defined in
+ metafile. For example, it could be:
+
+ - model alias, e.g. ``'body'``,
+ - config name, e.g. ``'simcc_res50_8xb64-210e_coco-256x192'``,
+ - config path
+
+ Defaults to ``None``.
+ weights (str, optional): Path to the checkpoint. If it is not
+ specified and "model" is a model name of metafile, the weights
+ will be loaded from metafile. Defaults to None.
+ device (str, optional): Device to run inference. If None, the
+ available device will be automatically used. Defaults to None.
+ scope (str, optional): The scope of the model. Defaults to "mmpose".
+ det_model (str, optional): Config path or alias of detection model.
+ Defaults to None.
+ det_weights (str, optional): Path to the checkpoints of detection
+ model. Defaults to None.
+ det_cat_ids (int or list[int], optional): Category id for
+ detection model. Defaults to None.
+ output_heatmaps (bool, optional): Flag to visualize predicted
+ heatmaps. If set to None, the default setting from the model
+ config will be used. Default is None.
+ """
+
+ preprocess_kwargs: set = {'bbox_thr', 'nms_thr', 'bboxes'}
+ forward_kwargs: set = set()
+ visualize_kwargs: set = {
+ 'return_vis',
+ 'show',
+ 'wait_time',
+ 'draw_bbox',
+ 'radius',
+ 'thickness',
+ 'kpt_thr',
+ 'vis_out_dir',
+ }
+ postprocess_kwargs: set = {'pred_out_dir'}
+
+ def __init__(self,
+ model: Union[ModelType, str],
+ weights: Optional[str] = None,
+ device: Optional[str] = None,
+ scope: Optional[str] = 'mmpose',
+ det_model: Optional[Union[ModelType, str]] = None,
+ det_weights: Optional[str] = None,
+ det_cat_ids: Optional[Union[int, Tuple]] = None,
+ output_heatmaps: Optional[bool] = None) -> None:
+
+ init_default_scope(scope)
+ super().__init__(
+ model=model, weights=weights, device=device, scope=scope)
+ self.model = revert_sync_batchnorm(self.model)
+ if output_heatmaps is not None:
+ self.model.test_cfg['output_heatmaps'] = output_heatmaps
+
+ # assign dataset metainfo to self.visualizer
+ self.visualizer.set_dataset_meta(self.model.dataset_meta)
+
+ # initialize detector for top-down models
+ if self.cfg.data_mode == 'topdown':
+ object_type = DATASETS.get(self.cfg.dataset_type).__module__.split(
+ 'datasets.')[-1].split('.')[0].lower()
+
+ if det_model in ('whole_image', 'whole-image') or \
+ (det_model is None and
+ object_type not in default_det_models):
+ self.detector = None
+
+ else:
+ det_scope = 'mmdet'
+ if det_model is None:
+ det_info = default_det_models[object_type]
+ det_model, det_weights, det_cat_ids = det_info[
+ 'model'], det_info['weights'], det_info['cat_ids']
+ elif os.path.exists(det_model):
+ det_cfg = Config.fromfile(det_model)
+ det_scope = det_cfg.default_scope
+
+ if has_mmdet:
+ self.detector = DetInferencer(
+ det_model, det_weights, device=device, scope=det_scope)
+ else:
+ raise RuntimeError(
+ 'MMDetection (v3.0.0 or above) is required to build '
+ 'inferencers for top-down pose estimation models.')
+
+ if isinstance(det_cat_ids, (tuple, list)):
+ self.det_cat_ids = det_cat_ids
+ else:
+ self.det_cat_ids = (det_cat_ids, )
+
+ self._video_input = False
+
+ def preprocess_single(self,
+ input: InputType,
+ index: int,
+ bbox_thr: float = 0.3,
+ nms_thr: float = 0.3,
+ bboxes: Union[List[List], List[np.ndarray],
+ np.ndarray] = []):
+ """Process a single input into a model-feedable format.
+
+ Args:
+ input (InputType): Input given by user.
+ index (int): index of the input
+ bbox_thr (float): threshold for bounding box detection.
+ Defaults to 0.3.
+ nms_thr (float): IoU threshold for bounding box NMS.
+ Defaults to 0.3.
+
+ Yields:
+ Any: Data processed by the ``pipeline`` and ``collate_fn``.
+ """
+
+ if isinstance(input, str):
+ data_info = dict(img_path=input)
+ else:
+ data_info = dict(img=input, img_path=f'{index}.jpg'.rjust(10, '0'))
+ data_info.update(self.model.dataset_meta)
+
+ if self.cfg.data_mode == 'topdown':
+ if self.detector is not None:
+ det_results = self.detector(
+ input, return_datasample=True)['predictions']
+ pred_instance = det_results[0].pred_instances.cpu().numpy()
+ bboxes = np.concatenate(
+ (pred_instance.bboxes, pred_instance.scores[:, None]),
+ axis=1)
+
+ label_mask = np.zeros(len(bboxes), dtype=np.uint8)
+ for cat_id in self.det_cat_ids:
+ label_mask = np.logical_or(label_mask,
+ pred_instance.labels == cat_id)
+
+ bboxes = bboxes[np.logical_and(
+ label_mask, pred_instance.scores > bbox_thr)]
+ bboxes = bboxes[nms(bboxes, nms_thr)]
+
+ data_infos = []
+ if len(bboxes) > 0:
+ for bbox in bboxes:
+ inst = data_info.copy()
+ inst['bbox'] = bbox[None, :4]
+ inst['bbox_score'] = bbox[4:5]
+ data_infos.append(self.pipeline(inst))
+ else:
+ inst = data_info.copy()
+
+ # get bbox from the image size
+ if isinstance(input, str):
+ input = mmcv.imread(input)
+ h, w = input.shape[:2]
+
+ inst['bbox'] = np.array([[0, 0, w, h]], dtype=np.float32)
+ inst['bbox_score'] = np.ones(1, dtype=np.float32)
+ data_infos.append(self.pipeline(inst))
+
+ else: # bottom-up
+ data_infos = [self.pipeline(data_info)]
+
+ return data_infos
+
+ @torch.no_grad()
+ def forward(self, inputs: Union[dict, tuple], bbox_thr=-1):
+ data_samples = super().forward(inputs)
+ if self.cfg.data_mode == 'topdown':
+ data_samples = [merge_data_samples(data_samples)]
+ if bbox_thr > 0:
+ for ds in data_samples:
+ if 'bbox_scores' in ds.pred_instances:
+ ds.pred_instances = ds.pred_instances[
+ ds.pred_instances.bbox_scores > bbox_thr]
+ return data_samples
+
+ def __call__(
+ self,
+ inputs: InputsType,
+ return_datasample: bool = False,
+ batch_size: int = 1,
+ out_dir: Optional[str] = None,
+ **kwargs,
+ ) -> dict:
+ """Call the inferencer.
+
+ Args:
+ inputs (InputsType): Inputs for the inferencer.
+ return_datasample (bool): Whether to return results as
+ :obj:`BaseDataElement`. Defaults to False.
+ batch_size (int): Batch size. Defaults to 1.
+ out_dir (str, optional): Directory to save visualization
+ results and predictions. Will be overridden if vis_out_dir or
+ pred_out_dir are given. Defaults to None
+ **kwargs: Key words arguments passed to :meth:`preprocess`,
+ :meth:`forward`, :meth:`visualize` and :meth:`postprocess`.
+ Each key in kwargs should be in the corresponding set of
+ ``preprocess_kwargs``, ``forward_kwargs``,
+ ``visualize_kwargs`` and ``postprocess_kwargs``.
+
+ Returns:
+ dict: Inference and visualization results.
+ """
+ if out_dir is not None:
+ if 'vis_out_dir' not in kwargs:
+ kwargs['vis_out_dir'] = f'{out_dir}/visualizations'
+ if 'pred_out_dir' not in kwargs:
+ kwargs['pred_out_dir'] = f'{out_dir}/predictions'
+
+ (
+ preprocess_kwargs,
+ forward_kwargs,
+ visualize_kwargs,
+ postprocess_kwargs,
+ ) = self._dispatch_kwargs(**kwargs)
+
+ # preprocessing
+ if isinstance(inputs, str) and inputs.startswith('webcam'):
+ inputs = self._get_webcam_inputs(inputs)
+ batch_size = 1
+ if not visualize_kwargs.get('show', False):
+ warnings.warn('The display mode is disabled when using webcam '
+ 'input; it will be turned on automatically.')
+ visualize_kwargs['show'] = True
+ else:
+ inputs = self._inputs_to_list(inputs)
+
+ forward_kwargs['bbox_thr'] = preprocess_kwargs.get('bbox_thr', -1)
+ inputs = self.preprocess(
+ inputs, batch_size=batch_size, **preprocess_kwargs)
+
+ preds = []
+ if not hasattr(self, 'detector'):
+ inputs = track(inputs, description='Inference')
+
+ for proc_inputs, ori_inputs in inputs:
+ preds = self.forward(proc_inputs, **forward_kwargs)
+
+ visualization = self.visualize(ori_inputs, preds,
+ **visualize_kwargs)
+ results = self.postprocess(preds, visualization, return_datasample,
+ **postprocess_kwargs)
+ yield results
+
+ if self._video_input:
+ self._finalize_video_processing(
+ postprocess_kwargs.get('pred_out_dir', ''))
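
For top-down models, `preprocess_single` keeps only detections whose label is in `det_cat_ids` and whose score exceeds `bbox_thr` before NMS. A standalone NumPy restatement of that filtering step, with made-up boxes and thresholds:

```python
# Standalone restatement of the bbox filtering in preprocess_single.
import numpy as np

# Hypothetical detector output: [x1, y1, x2, y2, score] per row, plus labels.
bboxes = np.array([[10, 10, 100, 200, 0.90],
                   [20, 15, 110, 210, 0.40],
                   [300, 40, 380, 220, 0.95]], dtype=np.float32)
labels = np.array([0, 0, 56])   # e.g. person=0, chair=56 in COCO
det_cat_ids = (0, )             # keep persons only
bbox_thr = 0.5

label_mask = np.zeros(len(bboxes), dtype=bool)
for cat_id in det_cat_ids:
    label_mask |= labels == cat_id

keep = np.logical_and(label_mask, bboxes[:, 4] > bbox_thr)
print(bboxes[keep])             # only the first box survives
```
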
diff --git a/mmpose/apis/inferencers/utils/__init__.py b/mmpose/apis/inferencers/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cc40535b0d42a3b2ff41e97e26dcc30c440622b
--- /dev/null
+++ b/mmpose/apis/inferencers/utils/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .default_det_models import default_det_models
+from .get_model_alias import get_model_aliases
+
+__all__ = ['default_det_models', 'get_model_aliases']
diff --git a/mmpose/apis/inferencers/utils/__pycache__/__init__.cpython-38.pyc b/mmpose/apis/inferencers/utils/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8743eb7e1a7fa72b05cee58b23163c135412f60e
Binary files /dev/null and b/mmpose/apis/inferencers/utils/__pycache__/__init__.cpython-38.pyc differ
diff --git a/mmpose/apis/inferencers/utils/__pycache__/default_det_models.cpython-38.pyc b/mmpose/apis/inferencers/utils/__pycache__/default_det_models.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..994aa4db0e13b3ffabe4158bfda1592c1ac4c2bb
Binary files /dev/null and b/mmpose/apis/inferencers/utils/__pycache__/default_det_models.cpython-38.pyc differ
diff --git a/mmpose/apis/inferencers/utils/__pycache__/get_model_alias.cpython-38.pyc b/mmpose/apis/inferencers/utils/__pycache__/get_model_alias.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fc157a870b424def4cbac881fbc7283aa1a2d0c4
Binary files /dev/null and b/mmpose/apis/inferencers/utils/__pycache__/get_model_alias.cpython-38.pyc differ
diff --git a/mmpose/apis/inferencers/utils/default_det_models.py b/mmpose/apis/inferencers/utils/default_det_models.py
new file mode 100644
index 0000000000000000000000000000000000000000..aeb3076880a313109dbfc72e4d9a0b0e3d6a7059
--- /dev/null
+++ b/mmpose/apis/inferencers/utils/default_det_models.py
@@ -0,0 +1,29 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+
+from mmengine.config.utils import MODULE2PACKAGE
+from mmengine.utils import get_installed_path
+
+mmpose_path = get_installed_path(MODULE2PACKAGE['mmpose'])
+
+default_det_models = dict(
+ human=dict(model='rtmdet-m', weights="/fsx_laion/alvin/pretrain/ViTPose/rtmdet_m_8xb32-300e_coco_20220719_112220-229f527c.pth", cat_ids=(0, )),
+ face=dict(
+ model=osp.join(mmpose_path, '.mim',
+ 'demo/mmdetection_cfg/yolox-s_8xb8-300e_coco-face.py'),
+ weights='/fsx_laion/alvin/pretrain/ViTPose/yolo-x_8xb8-300e_coco-face_13274d7c.pth',
+ cat_ids=(0, )),
+ hand=dict(
+ model=osp.join(
+ mmpose_path, '.mim', 'demo/mmdetection_cfg/'
+ 'ssdlite_mobilenetv2_scratch_600e_onehand.py'),
+ weights='/fsx_laion/alvin/pretrain/ViTPose/ssdlite_mobilenetv2_scratch_600e_onehand-4f9f8686_20220523.pth',
+ cat_ids=(0, )),
+ animal=dict(
+ model='rtmdet-m',
+ weights=None,
+ cat_ids=(15, 16, 17, 18, 19, 20, 21, 22, 23)),
+)
+
+default_det_models['body'] = default_det_models['human']
+default_det_models['wholebody'] = default_det_models['human']
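
A short sketch of how this table is meant to be looked up by object type (illustrative only):

```python
# Illustrative lookup of the default detector table defined above.
from mmpose.apis.inferencers.utils import default_det_models

object_type = 'body'                  # aliased to the 'human' entry
det_info = default_det_models[object_type]
print(det_info['model'], det_info['cat_ids'])
# -> 'rtmdet-m' (0,)
```
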
diff --git a/mmpose/apis/inferencers/utils/get_model_alias.py b/mmpose/apis/inferencers/utils/get_model_alias.py
new file mode 100644
index 0000000000000000000000000000000000000000..49de6528d6ea0df58cf7ae987176defbd4953739
--- /dev/null
+++ b/mmpose/apis/inferencers/utils/get_model_alias.py
@@ -0,0 +1,37 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict
+
+from mmengine.infer import BaseInferencer
+
+
+def get_model_aliases(scope: str = 'mmpose') -> Dict[str, str]:
+ """Retrieve model aliases and their corresponding configuration names.
+
+ Args:
+ scope (str, optional): The scope for the model aliases. Defaults
+ to 'mmpose'.
+
+ Returns:
+ Dict[str, str]: A dictionary containing model aliases as keys and
+ their corresponding configuration names as values.
+ """
+
+ # Get a list of model configurations from the metafile
+ repo_or_mim_dir = BaseInferencer._get_repo_or_mim_dir(scope)
+ model_cfgs = BaseInferencer._get_models_from_metafile(repo_or_mim_dir)
+
+ model_alias_dict = dict()
+ for model_cfg in model_cfgs:
+ if 'Alias' in model_cfg:
+ if isinstance(model_cfg['Alias'], str):
+ model_alias_dict[model_cfg['Alias']] = model_cfg['Name']
+ elif isinstance(model_cfg['Alias'], list):
+ for alias in model_cfg['Alias']:
+ model_alias_dict[alias] = model_cfg['Name']
+ else:
+ raise ValueError(
+ 'Encountered an unexpected alias type. Please raise an '
+ 'issue at https://github.com/open-mmlab/mmpose/issues '
+ 'to let us know.')
+
+ return model_alias_dict
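
A quick way to inspect the aliases available locally; the output depends entirely on the metafiles of the installed mmpose package:

```python
# List the model aliases registered in the local mmpose metafiles.
from mmpose.apis.inferencers import get_model_aliases

aliases = get_model_aliases('mmpose')
for alias, config_name in sorted(aliases.items()):
    print(f'{alias:>12s} -> {config_name}')
```
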
diff --git a/mmpose/apis/webcam/__init__.py b/mmpose/apis/webcam/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..271b238c674479f7120e9f8ef0b1087cb06b42e4
--- /dev/null
+++ b/mmpose/apis/webcam/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .webcam_executor import WebcamExecutor
+
+__all__ = ['WebcamExecutor']
diff --git a/mmpose/apis/webcam/nodes/__init__.py b/mmpose/apis/webcam/nodes/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..50f7c899d3d5ee2fb06d62cd9fb3963fffbd99de
--- /dev/null
+++ b/mmpose/apis/webcam/nodes/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .base_visualizer_node import BaseVisualizerNode
+from .helper_nodes import MonitorNode, ObjectAssignerNode, RecorderNode
+from .model_nodes import DetectorNode, TopdownPoseEstimatorNode
+from .node import Node
+from .registry import NODES
+from .visualizer_nodes import (BigeyeEffectNode, NoticeBoardNode,
+ ObjectVisualizerNode, SunglassesEffectNode)
+
+__all__ = [
+ 'BaseVisualizerNode', 'NODES', 'MonitorNode', 'ObjectAssignerNode',
+ 'RecorderNode', 'DetectorNode', 'TopdownPoseEstimatorNode', 'Node',
+ 'BigeyeEffectNode', 'NoticeBoardNode', 'ObjectVisualizerNode',
+ 'ObjectAssignerNode', 'SunglassesEffectNode'
+]
diff --git a/mmpose/apis/webcam/nodes/base_visualizer_node.py b/mmpose/apis/webcam/nodes/base_visualizer_node.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e0ba397d4bf31fc44df86714ce828344d1fdb76
--- /dev/null
+++ b/mmpose/apis/webcam/nodes/base_visualizer_node.py
@@ -0,0 +1,65 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import abstractmethod
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+
+from ..utils import FrameMessage, Message
+from .node import Node
+
+
+class BaseVisualizerNode(Node):
+ """Base class for nodes whose function is to create visual effects, like
+ visualizing model predictions, showing graphics or showing text messages.
+
+ All subclasses should implement the method ``draw()``.
+
+ Args:
+ name (str): The node name (also thread name)
+ input_buffer (str): The name of the input buffer
+ output_buffer (str | list): The name(s) of the output buffer(s).
+ enable_key (str|int, optional): Set a hot-key to toggle enable/disable
+ of the node. If an int value is given, it will be treated as an
+ ascii code of a key. Please note: (1) If ``enable_key`` is set,
+ the ``bypass()`` method needs to be overridden to define the node
+ behavior when disabled; (2) Some hot-keys are reserved for
+ particular use. For example: 'q', 'Q' and 27 are used for exiting.
+ Default: ``None``
+ enable (bool): Default enable/disable status. Default: ``True``
+ """
+
+ def __init__(self,
+ name: str,
+ input_buffer: str,
+ output_buffer: Union[str, List[str]],
+ enable_key: Optional[Union[str, int]] = None,
+ enable: bool = True):
+
+ super().__init__(name=name, enable_key=enable_key, enable=enable)
+
+ # Register buffers
+ self.register_input_buffer(input_buffer, 'input', trigger=True)
+ self.register_output_buffer(output_buffer)
+
+ def process(self, input_msgs: Dict[str, Message]) -> Union[Message, None]:
+ input_msg = input_msgs['input']
+
+ img = self.draw(input_msg)
+ input_msg.set_image(img)
+
+ return input_msg
+
+ def bypass(self, input_msgs: Dict[str, Message]) -> Union[Message, None]:
+ return input_msgs['input']
+
+ @abstractmethod
+ def draw(self, input_msg: FrameMessage) -> np.ndarray:
+ """Draw on the frame image of the input FrameMessage.
+
+ Args:
+ input_msg (:obj:`FrameMessage`): The message of the frame to draw
+ on
+
+ Returns:
+ np.ndarray: The processed image.
+ """
diff --git a/mmpose/apis/webcam/nodes/helper_nodes/__init__.py b/mmpose/apis/webcam/nodes/helper_nodes/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8bb0ed9dd156bfaeb0f77c78b4afa82d877064cf
--- /dev/null
+++ b/mmpose/apis/webcam/nodes/helper_nodes/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .monitor_node import MonitorNode
+from .object_assigner_node import ObjectAssignerNode
+from .recorder_node import RecorderNode
+
+__all__ = ['MonitorNode', 'ObjectAssignerNode', 'RecorderNode']
diff --git a/mmpose/apis/webcam/nodes/helper_nodes/monitor_node.py b/mmpose/apis/webcam/nodes/helper_nodes/monitor_node.py
new file mode 100644
index 0000000000000000000000000000000000000000..305490dc52efd314e8319084b9e806e4a3edb01c
--- /dev/null
+++ b/mmpose/apis/webcam/nodes/helper_nodes/monitor_node.py
@@ -0,0 +1,167 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, List, Optional, Union
+
+import cv2
+import numpy as np
+from mmcv import color_val
+
+from ..node import Node
+from ..registry import NODES
+
+try:
+ import psutil
+ psutil_proc = psutil.Process()
+except (ImportError, ModuleNotFoundError):
+ psutil_proc = None
+
+
+@NODES.register_module()
+class MonitorNode(Node):
+ """Show diagnostic information.
+
+ Args:
+ name (str): The node name (also thread name)
+ input_buffer (str): The name of the input buffer
+ output_buffer (str|list): The name(s) of the output buffer(s)
+ enable_key (str|int, optional): Set a hot-key to toggle enable/disable
+ of the node. If an int value is given, it will be treated as an
+ ascii code of a key. Please note: (1) If ``enable_key`` is set,
+ the ``bypass()`` method needs to be overridden to define the node
+ behavior when disabled; (2) Some hot-keys are reserved for
+ particular use. For example: 'q', 'Q' and 27 are used for exiting.
+ Default: ``None``
+ enable (bool): Default enable/disable status. Default: ``True``
+ x_offset (int): The position of the text box's left border in
+ pixels. Default: 20
+ y_offset (int): The position of the text box's top border in
+ pixels. Default: 20
+ y_delta (int): The line height in pixels. Default: 15
+ text_color (str|tuple): The font color represented in a color name or
+ a BGR tuple. Default: ``'black'``
+ background_color (str|tuple): The background color represented in a
+ color name or a BGR tuple. Default: (255, 183, 0)
+ text_scale (float): The font scale factor that is multiplied by the
+ base size. Default: 0.4
+ ignore_items (list[str], optional): Specify the node information items
+ that will not be shown. See ``MonitorNode._default_ignore_items``
+ for the default setting.
+
+ Example::
+ >>> cfg = dict(
+ ... type='MonitorNode',
+ ... name='monitor',
+ ... enable_key='m',
+ ... enable=False,
+ ... input_buffer='vis_notice',
+ ... output_buffer='display')
+
+ >>> from mmpose.apis.webcam.nodes import NODES
+ >>> node = NODES.build(cfg)
+ """
+
+ _default_ignore_items = ['timestamp']
+
+ def __init__(self,
+ name: str,
+ input_buffer: str,
+ output_buffer: Union[str, List[str]],
+ enable_key: Optional[Union[str, int]] = None,
+ enable: bool = False,
+ x_offset=20,
+ y_offset=20,
+ y_delta=15,
+ text_color='black',
+ background_color=(255, 183, 0),
+ text_scale=0.4,
+ ignore_items: Optional[List[str]] = None):
+ super().__init__(name=name, enable_key=enable_key, enable=enable)
+
+ self.x_offset = x_offset
+ self.y_offset = y_offset
+ self.y_delta = y_delta
+ self.text_color = color_val(text_color)
+ self.background_color = color_val(background_color)
+ self.text_scale = text_scale
+ if ignore_items is None:
+ self.ignore_items = self._default_ignore_items
+ else:
+ self.ignore_items = ignore_items
+
+ self.register_input_buffer(input_buffer, 'input', trigger=True)
+ self.register_output_buffer(output_buffer)
+
+ def process(self, input_msgs):
+ input_msg = input_msgs['input']
+
+ input_msg.update_route_info(
+ node_name='System Info',
+ node_type='none',
+ info=self._get_system_info())
+
+ img = input_msg.get_image()
+ route_info = input_msg.get_route_info()
+ img = self._show_route_info(img, route_info)
+
+ input_msg.set_image(img)
+ return input_msg
+
+ def _get_system_info(self):
+ """Get the system information including CPU and memory usage.
+
+ Returns:
+ dict: The system information items.
+ """
+ sys_info = {}
+ if psutil_proc is not None:
+ sys_info['CPU(%)'] = psutil_proc.cpu_percent()
+ sys_info['Memory(%)'] = psutil_proc.memory_percent()
+ return sys_info
+
+ def _show_route_info(self, img: np.ndarray,
+ route_info: List[Dict]) -> np.ndarray:
+ """Show the route information in the frame.
+
+ Args:
+ img (np.ndarray): The frame image.
+ route_info (list[dict]): The route information of the frame.
+
+ Returns:
+ np.ndarray: The processed image.
+ """
+ canvas = np.full(img.shape, self.background_color, dtype=img.dtype)
+
+ x = self.x_offset
+ y = self.y_offset
+
+ max_len = 0
+
+ def _put_line(line=''):
+ nonlocal y, max_len
+ cv2.putText(canvas, line, (x, y), cv2.FONT_HERSHEY_DUPLEX,
+ self.text_scale, self.text_color, 1)
+ y += self.y_delta
+ max_len = max(max_len, len(line))
+
+ for node_info in route_info:
+ title = f'{node_info["node"]}({node_info["node_type"]})'
+ _put_line(title)
+ for k, v in node_info['info'].items():
+ if k in self.ignore_items:
+ continue
+ if isinstance(v, float):
+ v = f'{v:.1f}'
+ _put_line(f' {k}: {v}')
+
+ x1 = max(0, self.x_offset)
+ x2 = min(img.shape[1], int(x + max_len * self.text_scale * 20))
+ y1 = max(0, self.y_offset - self.y_delta)
+ y2 = min(img.shape[0], y)
+
+ src1 = canvas[y1:y2, x1:x2]
+ src2 = img[y1:y2, x1:x2]
+ img[y1:y2, x1:x2] = cv2.addWeighted(src1, 0.5, src2, 0.5, 0)
+
+ return img
+
+ def bypass(self, input_msgs):
+ return input_msgs['input']
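
The route information rendered by `_show_route_info` is a list of per-node dicts; a synthetic example of the expected shape (all values made up):

```python
# Synthetic route_info in the shape consumed by _show_route_info.
route_info = [
    dict(node='detector', node_type='DetectorNode',
         info={'fps': 24.7, 'timestamp': 1700000000.0}),  # 'timestamp' is ignored by default
    dict(node='System Info', node_type='none',
         info={'CPU(%)': 12.5, 'Memory(%)': 3.1}),
]
```
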
diff --git a/mmpose/apis/webcam/nodes/helper_nodes/object_assigner_node.py b/mmpose/apis/webcam/nodes/helper_nodes/object_assigner_node.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1a7804ab426fd27f29055aa3d15e0ffdfa359cd
--- /dev/null
+++ b/mmpose/apis/webcam/nodes/helper_nodes/object_assigner_node.py
@@ -0,0 +1,139 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import time
+from typing import List, Union
+
+from mmpose.utils.timer import RunningAverage
+from ..node import Node
+from ..registry import NODES
+
+
+@NODES.register_module()
+class ObjectAssignerNode(Node):
+ """Assign the object information to the frame message.
+
+ :class:`ObjectAssignerNode` enables asynchronous processing of model
+ inference and video I/O, so the video will be captured and displayed
+ smoothly regardless of the model inference speed. Specifically,
+ :class:`ObjectAssignerNode` takes messages from both model branch and
+ video I/O branch as its input, indicated as "object message" and "frame
+ message" respectively. When an object message arrives it will update the
+ latest object information; and when a frame message arrives, it will be
+ assigned with the latest object information and output.
+
+ In particular, if the webcam executor is set to synchronous mode, the
+ behavior of :class:`ObjectAssignerNode` will be different: When an object
+ message arrives, it will trigger an output of itself; and the frame
+ messages will be ignored.
+
+ Args:
+ name (str): The node name (also thread name)
+ frame_buffer (str): Buffer name for frame messages
+ object_buffer (str): Buffer name for object messages
+ output_buffer (str): The name(s) of the output buffer(s)
+
+ Example::
+ >>> cfg =dict(
+ ... type='ObjectAssignerNode',
+ ... name='object assigner',
+ ... frame_buffer='_frame_',
+ ... # `_frame_` is an executor-reserved buffer
+ ... object_buffer='animal_pose',
+ ... output_buffer='frame')
+
+ >>> from mmpose.apis.webcam.nodes import NODES
+ >>> node = NODES.build(cfg)
+ """
+
+ def __init__(self, name: str, frame_buffer: str, object_buffer: str,
+ output_buffer: Union[str, List[str]]):
+ super().__init__(name=name, enable=True)
+ self.synchronous = None
+
+ # Cache the latest model result
+ self.last_object_msg = None
+ self.last_output_msg = None
+
+ # Inference speed analysis
+ self.frame_fps = RunningAverage(window=10)
+ self.frame_lag = RunningAverage(window=10)
+ self.object_fps = RunningAverage(window=10)
+ self.object_lag = RunningAverage(window=10)
+
+ # Register buffers
+ # The trigger buffer depends on the executor.synchronous attribute,
+ # so it will be set later after the executor is assigned in
+ # ``set_executor``.
+ self.register_input_buffer(object_buffer, 'object', trigger=False)
+ self.register_input_buffer(frame_buffer, 'frame', trigger=False)
+ self.register_output_buffer(output_buffer)
+
+ def set_executor(self, executor):
+ super().set_executor(executor)
+ # Set synchronous according to the executor
+ if executor.synchronous:
+ self.synchronous = True
+ trigger = 'object'
+ else:
+ self.synchronous = False
+ trigger = 'frame'
+
+ # Set trigger input buffer according to the synchronous setting
+ for buffer_info in self._input_buffers:
+ if buffer_info.input_name == trigger:
+ buffer_info.trigger = True
+
+ def process(self, input_msgs):
+ object_msg = input_msgs['object']
+
+ # Update last result
+ if object_msg is not None:
+ # Update result FPS
+ if self.last_object_msg is not None:
+ self.object_fps.update(
+ 1.0 /
+ (object_msg.timestamp - self.last_object_msg.timestamp))
+ # Update inference latency
+ self.object_lag.update(time.time() - object_msg.timestamp)
+ # Update last inference result
+ self.last_object_msg = object_msg
+
+ if not self.synchronous:
+ # Asynchronous mode:
+ # Assign the latest object information to the
+ # current frame.
+ frame_msg = input_msgs['frame']
+
+ self.frame_lag.update(time.time() - frame_msg.timestamp)
+
+ # Assign objects to frame
+ if self.last_object_msg is not None:
+ frame_msg.update_objects(self.last_object_msg.get_objects())
+ frame_msg.merge_route_info(
+ self.last_object_msg.get_route_info())
+
+ output_msg = frame_msg
+
+ else:
+ # Synchronous mode:
+ # The current frame will be ignored. Instead,
+ # the frame from which the latest object information is obtained
+ # will be used.
+ self.frame_lag.update(time.time() - object_msg.timestamp)
+ output_msg = object_msg
+
+ # Update frame fps and lag
+ if self.last_output_msg is not None:
+ self.frame_lag.update(time.time() - output_msg.timestamp)
+ self.frame_fps.update(
+ 1.0 / (output_msg.timestamp - self.last_output_msg.timestamp))
+ self.last_output_msg = output_msg
+
+ return output_msg
+
+ def _get_node_info(self):
+ info = super()._get_node_info()
+ info['object_fps'] = self.object_fps.average()
+ info['object_lag (ms)'] = self.object_lag.average() * 1000
+ info['frame_fps'] = self.frame_fps.average()
+ info['frame_lag (ms)'] = self.frame_lag.average() * 1000
+ return info
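The assigner only makes sense as part of a larger pipeline, so here is a minimal, hypothetical sketch of how the model branch and the video I/O branch could be wired together through ``ObjectAssignerNode``. Only the node list is shown; the surrounding webcam executor configuration (camera settings, hot-keys, etc.) is omitted, and the ``'...'`` placeholders stand for real model config/checkpoint paths.

# Hypothetical node list wiring the buffers described in the docstrings above.
# '_input_', '_frame_' and '_display_' are executor-reserved buffers; the
# model_config/model_checkpoint values are placeholders to be filled in.
nodes = [
    dict(
        type='DetectorNode',
        name='detector',
        model_config='...',            # an MMDetection config file
        model_checkpoint='...',        # the matching checkpoint
        input_buffer='_input_',        # model-branch input (reserved)
        output_buffer='det_result'),
    dict(
        type='TopdownPoseEstimatorNode',
        name='pose estimator',
        model_config='...',            # an MMPose top-down config file
        model_checkpoint='...',        # the matching checkpoint
        labels=['person'],
        input_buffer='det_result',
        output_buffer='human_pose'),
    dict(
        type='ObjectAssignerNode',
        name='object assigner',
        frame_buffer='_frame_',        # video-I/O-branch input (reserved)
        object_buffer='human_pose',
        output_buffer='frame'),
    dict(
        type='ObjectVisualizerNode',
        name='object visualizer',
        input_buffer='frame',
        output_buffer='_display_'),    # consumed by the executor for display
]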
diff --git a/mmpose/apis/webcam/nodes/helper_nodes/recorder_node.py b/mmpose/apis/webcam/nodes/helper_nodes/recorder_node.py
new file mode 100644
index 0000000000000000000000000000000000000000..b35a77869248817d34ed6ee499d89dfbc0948c80
--- /dev/null
+++ b/mmpose/apis/webcam/nodes/helper_nodes/recorder_node.py
@@ -0,0 +1,126 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from queue import Full, Queue
+from threading import Thread
+from typing import List, Union
+
+import cv2
+
+from ..node import Node
+from ..registry import NODES
+
+
+@NODES.register_module()
+class RecorderNode(Node):
+ """Record the video frames into a local file.
+
+ :class:`RecorderNode` uses the OpenCV backend to record the video.
+ Recording is performed in a separate thread to avoid blocking the data
+ stream. A buffer queue is used to cache the incoming frame images.
+
+ Args:
+ name (str): The node name (also thread name)
+ input_buffer (str): The name of the input buffer
+ output_buffer (str|list): The name(s) of the output buffer(s)
+ out_video_file (str): The path of the output video file
+ out_video_fps (int): The frame rate of the output video. Default: 30
+ out_video_codec (str): The codec of the output video. Default: 'mp4v'
+ buffer_size (int): Size of the buffer queue that caches the incoming
+ frame images. Default: 30
+ enable (bool): Default enable/disable status. Default: ``True``.
+
+ Example::
+ >>> cfg = dict(
+ ... type='RecorderNode',
+ ... name='recorder',
+ ... out_video_file='webcam_demo.mp4',
+ ... input_buffer='display',
+ ... output_buffer='_display_'
+ ... # `_display_` is an executor-reserved buffer
+ ... )
+
+ >>> from mmpose.apis.webcam.nodes import NODES
+ >>> node = NODES.build(cfg)
+ """
+
+ def __init__(
+ self,
+ name: str,
+ input_buffer: str,
+ output_buffer: Union[str, List[str]],
+ out_video_file: str,
+ out_video_fps: int = 30,
+ out_video_codec: str = 'mp4v',
+ buffer_size: int = 30,
+ enable: bool = True,
+ ):
+ super().__init__(name=name, enable_key=None, enable=enable)
+
+ self.queue = Queue(maxsize=buffer_size)
+ self.out_video_file = out_video_file
+ self.out_video_fps = out_video_fps
+ self.out_video_codec = out_video_codec
+ self.vwriter = None
+
+ # Register buffers
+ self.register_input_buffer(input_buffer, 'input', trigger=True)
+ self.register_output_buffer(output_buffer)
+
+ # Start a new thread to write frame
+ self.t_record = Thread(target=self._record, args=(), daemon=True)
+ self.t_record.start()
+
+ def process(self, input_msgs):
+
+ input_msg = input_msgs['input']
+ img = input_msg.get_image() if input_msg is not None else None
+ img_queued = False
+
+ while not img_queued:
+ try:
+ self.queue.put(img, timeout=1)
+ img_queued = True
+ self.logger.info('Recorder received one frame.')
+ except Full:
+ self.logger.warning('Recorder jammed!')
+
+ return input_msg
+
+ def _record(self):
+ """This method is used to create a thread to get frame images from the
+ buffer queue and write them into the file."""
+
+ while True:
+
+ img = self.queue.get()
+
+ if img is None:
+ break
+
+ if self.vwriter is None:
+ fourcc = cv2.VideoWriter_fourcc(*self.out_video_codec)
+ fps = self.out_video_fps
+ frame_size = (img.shape[1], img.shape[0])
+ self.vwriter = cv2.VideoWriter(self.out_video_file, fourcc,
+ fps, frame_size)
+ assert self.vwriter.isOpened()
+
+ self.vwriter.write(img)
+
+ if self.vwriter is not None:
+ self.vwriter.release()
+ self.logger.info('Recorder released.')
+
+ def on_exit(self):
+ try:
+ # Try putting a None sentinel into the frame queue so that
+ # self.vwriter will be released after all queued frames have been
+ # written to file.
+ self.queue.put(None, timeout=1)
+ self.t_record.join(timeout=1)
+ except Full:
+ pass
+
+ if self.t_record.is_alive():
+ # Force to release self.vwriter
+ self.logger.warning('Recorder forced release!')
+ if self.vwriter is not None:
+ self.vwriter.release()
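For reference, the following self-contained sketch reproduces the recording pattern used by ``RecorderNode`` outside of the webcam framework: a bounded queue, a background writer thread, lazy creation of the ``cv2.VideoWriter`` once the frame size is known, and a ``None`` sentinel to finish. The synthetic frames and output file name are illustrative only.

# Minimal sketch of the queue-plus-writer-thread recording pattern.
from queue import Queue
from threading import Thread

import cv2
import numpy as np

frame_queue = Queue(maxsize=30)


def writer_loop(out_file='demo_record.mp4', fps=30, codec='mp4v'):
    vwriter = None
    while True:
        frame = frame_queue.get()
        if frame is None:          # sentinel: stop and release the writer
            break
        if vwriter is None:        # lazily open once the frame size is known
            fourcc = cv2.VideoWriter_fourcc(*codec)
            vwriter = cv2.VideoWriter(out_file, fourcc, fps,
                                      (frame.shape[1], frame.shape[0]))
        vwriter.write(frame)
    if vwriter is not None:
        vwriter.release()


t = Thread(target=writer_loop, daemon=True)
t.start()

for i in range(60):                          # 60 synthetic gray frames
    frame_queue.put(np.full((240, 320, 3), i * 4, dtype=np.uint8))
frame_queue.put(None)                        # ask the writer to finish
t.join(timeout=5)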
diff --git a/mmpose/apis/webcam/nodes/model_nodes/__init__.py b/mmpose/apis/webcam/nodes/model_nodes/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9a116bfec8d7011edeca54ac8a867c6fc0e05c3
--- /dev/null
+++ b/mmpose/apis/webcam/nodes/model_nodes/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .detector_node import DetectorNode
+from .pose_estimator_node import TopdownPoseEstimatorNode
+
+__all__ = ['DetectorNode', 'TopdownPoseEstimatorNode']
diff --git a/mmpose/apis/webcam/nodes/model_nodes/detector_node.py b/mmpose/apis/webcam/nodes/model_nodes/detector_node.py
new file mode 100644
index 0000000000000000000000000000000000000000..350831fe62a0a53b363bb24ab37ac9b0fe33b260
--- /dev/null
+++ b/mmpose/apis/webcam/nodes/model_nodes/detector_node.py
@@ -0,0 +1,143 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+
+from mmpose.utils import adapt_mmdet_pipeline
+from ...utils import get_config_path
+from ..node import Node
+from ..registry import NODES
+
+try:
+ from mmdet.apis import inference_detector, init_detector
+ has_mmdet = True
+except (ImportError, ModuleNotFoundError):
+ has_mmdet = False
+
+
+@NODES.register_module()
+class DetectorNode(Node):
+ """Detect objects from the frame image using MMDetection model.
+
+ Note that MMDetection is required for this node. Please refer to the
+ MMDetection documentation for the installation guide.
+
+ Parameters:
+ name (str): The node name (also thread name)
+ model_config (str): The model config file
+ model_checkpoint (str): The model checkpoint file
+ input_buffer (str): The name of the input buffer
+ output_buffer (str|list): The name(s) of the output buffer(s)
+ enable_key (str|int, optional): Set a hot-key to toggle enable/disable
+ of the node. If an int value is given, it will be treated as an
+ ascii code of a key. Please note: (1) If ``enable_key`` is set,
+ the ``bypass()`` method need to be overridden to define the node
+ behavior when disabled; (2) Some hot-keys are reserved for
+ particular use. For example: 'q', 'Q' and 27 are used for exiting.
+ Default: ``None``
+ enable (bool): Default enable/disable status. Default: ``True``
+ device (str): Specify the device to hold model weights and inference
+ the model. Default: ``'cuda:0'``
+ bbox_thr (float): Set a threshold to filter out objects with low bbox
+ scores. Default: 0.5
+ multi_input (bool): Whether to load all frames in the input buffer. If
+ True, all frames in the buffer will be loaded and stacked, and the
+ latest frame is used to detect objects of interest. Default: False
+
+ Example::
+ >>> cfg = dict(
+ ... type='DetectorNode',
+ ... name='detector',
+ ... model_config='demo/mmdetection_cfg/'
+ ... 'ssdlite_mobilenetv2_scratch_600e_coco.py',
+ ... model_checkpoint='https://download.openmmlab.com'
+ ... '/mmdetection/v2.0/ssd/'
+ ... 'ssdlite_mobilenetv2_scratch_600e_coco/ssdlite_mobilenetv2_'
+ ... 'scratch_600e_coco_20210629_110627-974d9307.pth',
+ ... # `_input_` is an executor-reserved buffer
+ ... input_buffer='_input_',
+ ... output_buffer='det_result')
+
+ >>> from mmpose.apis.webcam.nodes import NODES
+ >>> node = NODES.build(cfg)
+ """
+
+ def __init__(self,
+ name: str,
+ model_config: str,
+ model_checkpoint: str,
+ input_buffer: str,
+ output_buffer: Union[str, List[str]],
+ enable_key: Optional[Union[str, int]] = None,
+ enable: bool = True,
+ device: str = 'cuda:0',
+ bbox_thr: float = 0.5,
+ multi_input: bool = False):
+ # Check mmdetection is installed
+ assert has_mmdet, \
+ f'MMDetection is required for {self.__class__.__name__}.'
+
+ super().__init__(
+ name=name,
+ enable_key=enable_key,
+ enable=enable,
+ multi_input=multi_input)
+
+ self.model_config = get_config_path(model_config, 'mmdet')
+ self.model_checkpoint = model_checkpoint
+ self.device = device.lower()
+ self.bbox_thr = bbox_thr
+
+ # Init model
+ self.model = init_detector(
+ self.model_config, self.model_checkpoint, device=self.device)
+ self.model.cfg = adapt_mmdet_pipeline(self.model.cfg)
+
+ # Register buffers
+ self.register_input_buffer(input_buffer, 'input', trigger=True)
+ self.register_output_buffer(output_buffer)
+
+ def bypass(self, input_msgs):
+ return input_msgs['input']
+
+ def process(self, input_msgs):
+ input_msg = input_msgs['input']
+
+ if self.multi_input:
+ imgs = [frame.get_image() for frame in input_msg]
+ input_msg = input_msg[-1]
+
+ img = input_msg.get_image()
+
+ preds = inference_detector(self.model, img)
+ objects = self._post_process(preds)
+ input_msg.update_objects(objects)
+
+ if self.multi_input:
+ input_msg.set_image(np.stack(imgs, axis=0))
+
+ return input_msg
+
+ def _post_process(self, preds) -> List[Dict]:
+ """Post-process the predictions of MMDetection model."""
+ instances = preds.pred_instances.cpu().numpy()
+
+ classes = self.model.dataset_meta['classes']
+ if isinstance(classes, str):
+ classes = (classes, )
+
+ objects = []
+ for i in range(len(instances)):
+ if instances.scores[i] < self.bbox_thr:
+ continue
+ class_id = instances.labels[i]
+ obj = {
+ 'class_id': class_id,
+ 'label': classes[class_id],
+ 'bbox': instances.bboxes[i],
+ 'det_model_cfg': self.model.cfg,
+ 'dataset_meta': self.model.dataset_meta.copy(),
+ }
+ objects.append(obj)
+ return objects
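The thresholding in ``_post_process`` is easy to miss among the MMDetection plumbing, so here is a small illustration of the same filtering logic on synthetic arrays that mimic the ``pred_instances`` fields; real values come from the detector.

# Synthetic stand-ins for the pred_instances fields used above.
import numpy as np

bbox_thr = 0.5
classes = ('person', 'cat', 'dog')
scores = np.array([0.9, 0.3, 0.7])
labels = np.array([0, 2, 1])
bboxes = np.array([[10, 20, 110, 220],
                   [5, 5, 50, 50],
                   [200, 40, 320, 260]], dtype=np.float32)

objects = []
for score, label, bbox in zip(scores, labels, bboxes):
    if score < bbox_thr:       # drop low-confidence detections
        continue
    objects.append({
        'class_id': int(label),
        'label': classes[label],
        'bbox': bbox,
    })

print([obj['label'] for obj in objects])  # -> ['person', 'dog']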
diff --git a/mmpose/apis/webcam/nodes/model_nodes/pose_estimator_node.py b/mmpose/apis/webcam/nodes/model_nodes/pose_estimator_node.py
new file mode 100644
index 0000000000000000000000000000000000000000..64691cf5606d950ed18053740d1f6aca822a8386
--- /dev/null
+++ b/mmpose/apis/webcam/nodes/model_nodes/pose_estimator_node.py
@@ -0,0 +1,135 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from dataclasses import dataclass
+from typing import List, Optional, Union
+
+import numpy as np
+
+from mmpose.apis import inference_topdown, init_model
+from ...utils import get_config_path
+from ..node import Node
+from ..registry import NODES
+
+
+@dataclass
+class TrackInfo:
+ """Dataclass for object tracking information."""
+ next_id: int = 0
+ last_objects: List = None
+
+
+@NODES.register_module()
+class TopdownPoseEstimatorNode(Node):
+ """Perform top-down pose estimation using MMPose model.
+
+ The node should be placed after an object detection node.
+
+ Parameters:
+ name (str): The node name (also thread name)
+ model_config (str): The model config file
+ model_checkpoint (str): The model checkpoint file
+ input_buffer (str): The name of the input buffer
+ output_buffer (str|list): The name(s) of the output buffer(s)
+ enable_key (str|int, optional): Set a hot-key to toggle enable/disable
+ of the node. If an int value is given, it will be treated as an
+ ascii code of a key. Please note: (1) If ``enable_key`` is set,
+ the ``bypass()`` method need to be overridden to define the node
+ behavior when disabled; (2) Some hot-keys are reserved for
+ particular use. For example: 'q', 'Q' and 27 are used for exiting.
+ Default: ``None``
+ enable (bool): Default enable/disable status. Default: ``True``
+ device (str): Specify the device to hold model weights and inference
+ the model. Default: ``'cuda:0'``
+ class_ids (list[int], optional): Specify the object category indices
+ to apply pose estimation. If both ``class_ids`` and ``labels``
+ are given, ``labels`` will be ignored. If neither is given, pose
+ estimation will be applied for all objects. Default: ``None``
+ labels (list[str], optional): Specify the object category names to
+ apply pose estimation. See also ``class_ids``. Default: ``None``
+ bbox_thr (float): Set a threshold to filter out objects with low bbox
+ scores. Default: 0.5
+
+ Example::
+ >>> cfg = dict(
+ ... type='TopdownPoseEstimatorNode',
+ ... name='human pose estimator',
+ ... model_config='configs/wholebody/2d_kpt_sview_rgb_img/'
+ ... 'topdown_heatmap/coco-wholebody/'
+ ... 'vipnas_mbv3_coco_wholebody_256x192_dark.py',
+ ... model_checkpoint='https://download.openmmlab.com/mmpose/'
+ ... 'top_down/vipnas/vipnas_mbv3_coco_wholebody_256x192_dark'
+ ... '-e2158108_20211205.pth',
+ ... labels=['person'],
+ ... input_buffer='det_result',
+ ... output_buffer='human_pose')
+
+ >>> from mmpose.apis.webcam.nodes import NODES
+ >>> node = NODES.build(cfg)
+ """
+
+ def __init__(self,
+ name: str,
+ model_config: str,
+ model_checkpoint: str,
+ input_buffer: str,
+ output_buffer: Union[str, List[str]],
+ enable_key: Optional[Union[str, int]] = None,
+ enable: bool = True,
+ device: str = 'cuda:0',
+ class_ids: Optional[List[int]] = None,
+ labels: Optional[List[str]] = None,
+ bbox_thr: float = 0.5):
+ super().__init__(name=name, enable_key=enable_key, enable=enable)
+
+ # Init model
+ self.model_config = get_config_path(model_config, 'mmpose')
+ self.model_checkpoint = model_checkpoint
+ self.device = device.lower()
+
+ self.class_ids = class_ids
+ self.labels = labels
+ self.bbox_thr = bbox_thr
+
+ # Init model
+ self.model = init_model(
+ self.model_config, self.model_checkpoint, device=self.device)
+
+ # Register buffers
+ self.register_input_buffer(input_buffer, 'input', trigger=True)
+ self.register_output_buffer(output_buffer)
+
+ def bypass(self, input_msgs):
+ return input_msgs['input']
+
+ def process(self, input_msgs):
+
+ input_msg = input_msgs['input']
+ img = input_msg.get_image()
+
+ if self.class_ids:
+ objects = input_msg.get_objects(
+ lambda x: x.get('class_id') in self.class_ids)
+ elif self.labels:
+ objects = input_msg.get_objects(
+ lambda x: x.get('label') in self.labels)
+ else:
+ objects = input_msg.get_objects()
+
+ if len(objects) > 0:
+ # Inference pose
+ bboxes = np.stack([object['bbox'] for object in objects])
+ pose_results = inference_topdown(self.model, img, bboxes)
+
+ # Update objects
+ for pose_result, object in zip(pose_results, objects):
+ pred_instances = pose_result.pred_instances
+ object['keypoints'] = pred_instances.keypoints[0]
+ object['keypoint_scores'] = pred_instances.keypoint_scores[0]
+
+ dataset_meta = self.model.dataset_meta.copy()
+ dataset_meta.update(object.get('dataset_meta', dict()))
+ object['dataset_meta'] = dataset_meta
+ object['pose_model_cfg'] = self.model.cfg
+
+ input_msg.update_objects(objects)
+
+ return input_msg
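As a small illustration of how this node selects objects and prepares the bbox batch for ``inference_topdown``, the snippet below applies the same label filter and stacking to synthetic object dicts; the values are made up, and in the pipeline they come from ``DetectorNode``.

# Synthetic object dicts standing in for DetectorNode outputs.
import numpy as np

labels = ['person']
objects = [
    {'label': 'person', 'class_id': 0, 'bbox': np.array([10, 20, 110, 220])},
    {'label': 'dog', 'class_id': 16, 'bbox': np.array([5, 5, 60, 60])},
    {'label': 'person', 'class_id': 0, 'bbox': np.array([150, 30, 250, 230])},
]

selected = [obj for obj in objects if obj.get('label') in labels]
bboxes = np.stack([obj['bbox'] for obj in selected])
print(bboxes.shape)  # (2, 4) -> passed to inference_topdown(model, img, bboxes)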
diff --git a/mmpose/apis/webcam/nodes/node.py b/mmpose/apis/webcam/nodes/node.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d34ae1cc081f7a71f5f62397b40ceee7b894503
--- /dev/null
+++ b/mmpose/apis/webcam/nodes/node.py
@@ -0,0 +1,407 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import logging
+import time
+from abc import ABCMeta, abstractmethod
+from dataclasses import dataclass
+from threading import Thread
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+from mmengine import is_method_overridden
+
+from mmpose.utils import StopWatch
+from ..utils import Message, VideoEndingMessage, limit_max_fps
+
+
+@dataclass
+class BufferInfo():
+ """Dataclass for buffer information."""
+ buffer_name: str
+ input_name: Optional[str] = None
+ trigger: bool = False
+
+
+@dataclass
+class EventInfo():
+ """Dataclass for event handler information."""
+ event_name: str
+ is_keyboard: bool = False
+ handler_func: Optional[Callable] = None
+
+
+class Node(Thread, metaclass=ABCMeta):
+ """Base class for node, which is the interface of basic function module.
+
+ :class:`Node` inherits :class:`threading.Thread`. All subclasses should
+ override following methods:
+
+ - ``process()``
+ - ``bypass()`` (optional)
+
+
+ Parameters:
+ name (str): The node name (also thread name)
+ enable_key (str|int, optional): Set a hot-key to toggle enable/disable
+ of the node. If an int value is given, it will be treated as an
+ ascii code of a key. Please note: (1) If ``enable_key`` is set,
+ the ``bypass()`` method need to be overridden to define the node
+ behavior when disabled; (2) Some hot-keys are reserved for
+ particular use. For example: 'q', 'Q' and 27 are used for exiting.
+ Default: ``None``
+ max_fps (int): Maximum FPS of the node. This is to avoid the node
+ running unrestrictedly and consuming excessive resources.
+ Default: 30
+ input_check_interval (float): Minimum interval (in seconds) between
+ checks of input readiness. Default: 0.01
+ enable (bool): Default enable/disable status. Default: ``True``
+ daemon (bool): Whether the node is a daemon thread. Default: ``False``
+ multi_input (bool): Whether to load all messages in the buffer. If
+ False, only one message will be loaded each time. Default: ``False``
+ """
+
+ def __init__(self,
+ name: str,
+ enable_key: Optional[Union[str, int]] = None,
+ max_fps: int = 30,
+ input_check_interval: float = 0.01,
+ enable: bool = True,
+ daemon: bool = False,
+ multi_input: bool = False):
+ super().__init__(name=name, daemon=daemon)
+ self._executor = None
+ self._enabled = enable
+ self.enable_key = enable_key
+ self.max_fps = max_fps
+ self.input_check_interval = input_check_interval
+ self.multi_input = multi_input
+
+ # A partitioned view of the executor's buffer manager that only
+ # accesses the buffers registered by this node
+ self._buffer_manager = None
+
+ # Input/output buffers are a list of registered buffers' information
+ self._input_buffers = []
+ self._output_buffers = []
+
+ # Event manager is a copy of assigned executor's event manager
+ self._event_manager = None
+
+ # A list of registered event information
+ # See register_event() for more information
+ # Note that we recommend handling events in nodes by registering
+ # handlers, but one can still access the raw events via _event_manager
+ self._registered_events = []
+
+ # A list of (listener_threads, event_info)
+ # See set_executor() for more information
+ self._event_listener_threads = []
+
+ # A timer to calculate node FPS
+ self._timer = StopWatch(window=10)
+
+ # Register enable toggle key
+ if self.enable_key:
+ # If the node allows toggling enable, it should override the
+ # `bypass` method to define the node behavior when disabled.
+ if not is_method_overridden('bypass', Node, self.__class__):
+ raise NotImplementedError(
+ f'The node {self.__class__} does not support toggling '
+ 'enable but got argument `enable_key`. To support toggling '
+ 'enable, please override the `bypass` method of the node.')
+
+ self.register_event(
+ event_name=self.enable_key,
+ is_keyboard=True,
+ handler_func=self._toggle_enable,
+ )
+
+ # Logger
+ self.logger = logging.getLogger(f'Node "{self.name}"')
+
+ @property
+ def registered_buffers(self):
+ return self._input_buffers + self._output_buffers
+
+ @property
+ def registered_events(self):
+ return self._registered_events.copy()
+
+ def _toggle_enable(self):
+ self._enabled = not self._enabled
+
+ def register_input_buffer(self,
+ buffer_name: str,
+ input_name: str,
+ trigger: bool = False):
+ """Register an input buffer, so that Node can automatically check if
+ data is ready, fetch data from the buffers and format the inputs to
+ feed into the ``process()`` method.
+
+ The subclass of Node should invoke `register_input_buffer` in its
+ `__init__` method. This method can be invoked multiple times to
+ register multiple input buffers.
+
+ Args:
+ buffer_name (str): The name of the buffer
+ input_name (str): The name of the fetched message from the
+ corresponding buffer
+ trigger (bool): If True, the input is a trigger input, i.e. the node
+ will wait until this input is ready before processing. Otherwise,
+ the input will not block the processing; a None will be fetched
+ instead if the buffer is not ready.
+ """
+ buffer_info = BufferInfo(buffer_name, input_name, trigger)
+ self._input_buffers.append(buffer_info)
+
+ def register_output_buffer(self, buffer_name: Union[str, List[str]]):
+ """Register one or multiple output buffers, so that the Node can
+ automatically send the output of the `process` method to these buffers.
+
+ The subclass of Node should invoke `register_output_buffer` in its
+ `__init__` method.
+
+ Args:
+ buffer_name (str|list): The name(s) of the output buffer(s).
+ """
+
+ if not isinstance(buffer_name, list):
+ buffer_name = [buffer_name]
+
+ for name in buffer_name:
+ buffer_info = BufferInfo(name)
+ self._output_buffers.append(buffer_info)
+
+ def register_event(self,
+ event_name: str,
+ is_keyboard: bool = False,
+ handler_func: Optional[Callable] = None):
+ """Register an event. All events used in the node need to be registered
+ in __init__(). If a callable handler is given, a thread will be create
+ to listen and handle the event when the node starts.
+
+ Args:
+ Args:
+ event_name (str|int): The event name. If is_keyboard==True,
+ event_name should be a str (as char) or an int (as ascii)
+ is_keyboard (bool): Indicate whether it is an keyboard
+ event. If True, the argument event_name will be regarded as a
+ key indicator.
+ handler_func (callable, optional): The event handler function,
+ which should be a collable object with no arguments or
+ return values. Default: ``None``.
+ """
+ event_info = EventInfo(event_name, is_keyboard, handler_func)
+ self._registered_events.append(event_info)
+
+ def set_executor(self, executor):
+ """Assign the node to an executor so the node can access the buffers
+ and event manager of the executor.
+
+ This method should be invoked by the executor instance.
+
+ Args:
+ executor (:obj:`WebcamExecutor`): The executor to hold the node
+ """
+ # Get partitioned buffer manager
+ buffer_names = [
+ buffer.buffer_name
+ for buffer in self._input_buffers + self._output_buffers
+ ]
+ self._buffer_manager = executor.buffer_manager.get_sub_manager(
+ buffer_names)
+
+ # Get event manager
+ self._event_manager = executor.event_manager
+
+ def _get_input_from_buffer(self) -> Tuple[bool, Optional[Dict]]:
+ """Get and pack input data.
+
+ The function returns a tuple (status, data). If the trigger buffers
+ are ready, the status flag will be True, and the packed data is a dict
+ whose items are buffer names and corresponding messages (unready
+ non-trigger buffers will give a `None`). Otherwise, the status flag is
+ False and the packed data is None.
+
+ Returns:
+ tuple[bool, dict]: The first item is a bool value indicating
+ whether input is ready (i.e., all trigger buffers are ready). The
+ second value is a dict of buffer names and messages.
+ """
+ buffer_manager = self._buffer_manager
+
+ if buffer_manager is None:
+ raise ValueError(f'Node "{self.name}": not set to an executor.')
+
+ # Check that trigger buffers are ready
+ for buffer_info in self._input_buffers:
+ if buffer_info.trigger and buffer_manager.is_empty(
+ buffer_info.buffer_name):
+ return False, None
+
+ # Default input
+ result = {
+ buffer_info.input_name: None
+ for buffer_info in self._input_buffers
+ }
+
+ for buffer_info in self._input_buffers:
+
+ while not buffer_manager.is_empty(buffer_info.buffer_name):
+ msg = buffer_manager.get(buffer_info.buffer_name, block=False)
+ if self.multi_input:
+ if result[buffer_info.input_name] is None:
+ result[buffer_info.input_name] = []
+ result[buffer_info.input_name].append(msg)
+ else:
+ result[buffer_info.input_name] = msg
+ break
+
+ # Return unsuccessful flag if any trigger input is unready
+ if buffer_info.trigger and result[buffer_info.input_name] is None:
+ return False, None
+
+ return True, result
+
+ def _send_output_to_buffers(self, output_msg):
+ """Send output of ``process()`` to the registered output buffers.
+
+ Args:
+ output_msg (Message): output message
+ """
+ for buffer_info in self._output_buffers:
+ buffer_name = buffer_info.buffer_name
+ self._buffer_manager.put_force(buffer_name, output_msg)
+
+ @abstractmethod
+ def process(self, input_msgs: Dict[str, Message]) -> Union[Message, None]:
+ """The method that implements the function of the node.
+
+ This method will be invoked when the node is enabled and the input
+ data is ready. All subclasses of Node should override this method.
+
+ Args:
+ input_msgs (dict[str, :obj:`Message`]): The input data collected
+ from the buffers. For each item, the key is the `input_name`
+ of the registered input buffer, and the value is a Message
+ instance fetched from the buffer (or None if the buffer is
+ non-trigger and not ready).
+
+ Returns:
+ Message: The output message of the node, which will be sent to all
+ registered output buffers.
+ """
+
+ def bypass(self, input_msgs: Dict[str, Message]) -> Union[Message, None]:
+ """The method that defines the node behavior when disabled.
+
+ Note that a node must override this method if it has `enable_key`.
+ This method has the same signature as ``process()``.
+
+ Args:
+ input_msgs (dict[str, :obj:`Message`]): The input data collected
+ from the buffers. For each item, the key is the `input_name`
+ of the registered input buffer, and the value is a Message
+ instance fetched from the buffer (or None if the buffer is
+ non-trigger and not ready).
+
+ Returns:
+ Message: The output message of the node, which will be sent to all
+ registered output buffers.
+ """
+ raise NotImplementedError
+
+ def _get_node_info(self) -> Dict:
+ """Get route information of the node.
+
+ Default information includes:
+ - ``'fps'``: The processing speed of the node
+ - ``'timestamp'``: The time that this method is invoked
+
+ Subclasses can override this method to customize the node information.
+
+ Returns:
+ dict: The items of node information
+ """
+ info = {'fps': self._timer.report('_FPS_'), 'timestamp': time.time()}
+ return info
+
+ def on_exit(self):
+ """This method will be invoked on event `_exit_`.
+
+ Subclasses should override this method to specify the exit behavior.
+ """
+
+ def run(self):
+ """Method representing the Node's activity.
+
+ This method overrides the standard ``run()`` method of
+ :class:`threading.Thread`. Subclasses of :class:`Node` should not
+ override it.
+ """
+
+ self.logger.info('Process starts.')
+
+ # Create event listener threads
+ for event_info in self._registered_events:
+
+ if event_info.handler_func is None:
+ continue
+
+ # Bind event_info as a default argument to avoid the late-binding
+ # closure issue when creating listeners in a loop
+ def event_listener(event_info=event_info):
+ while True:
+ with self._event_manager.wait_and_handle(
+ event_info.event_name, event_info.is_keyboard):
+ event_info.handler_func()
+
+ t_listener = Thread(target=event_listener, args=(), daemon=True)
+ t_listener.start()
+ self._event_listener_threads.append(t_listener)
+
+ # Loop
+ while True:
+ # Exit
+ if self._event_manager.is_set('_exit_'):
+ self.on_exit()
+ break
+
+ # Check if input is ready
+ input_status, input_msgs = self._get_input_from_buffer()
+
+ # Input is not ready
+ if not input_status:
+ time.sleep(self.input_check_interval)
+ continue
+
+ # If a VideoEndingMessage is received, broadcast the signal
+ # without invoking process() or bypass()
+ video_ending = False
+ for _, msg in input_msgs.items():
+ if isinstance(msg, VideoEndingMessage):
+ self._send_output_to_buffers(msg)
+ video_ending = True
+ break
+
+ if video_ending:
+ self.on_exit()
+ break
+
+ # Check if enabled
+ if not self._enabled:
+ # Override bypass method to define node behavior when disabled
+ output_msg = self.bypass(input_msgs)
+ else:
+ with self._timer.timeit():
+ with limit_max_fps(self.max_fps):
+ # Process
+ output_msg = self.process(input_msgs)
+
+ if output_msg:
+ # Update route information
+ node_info = self._get_node_info()
+ output_msg.update_route_info(node=self, info=node_info)
+
+ # Send output message
+ if output_msg is not None:
+ self._send_output_to_buffers(output_msg)
+
+ self.logger.info('Process ends.')
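To make the contract above concrete, here is a minimal, hypothetical custom node: it registers one trigger input buffer and one output buffer, implements ``process()``, and overrides ``bypass()`` because an ``enable_key`` is used. ``FrameFlipNode`` and its buffer names are invented for illustration; running the node still requires assigning it to a webcam executor.

# A hypothetical example node built on the Node base class above.
from typing import List, Optional, Union

from mmpose.apis.webcam.nodes import NODES
from mmpose.apis.webcam.nodes.node import Node


@NODES.register_module()
class FrameFlipNode(Node):
    """Horizontally flip the frame image (illustrative example only)."""

    def __init__(self,
                 name: str,
                 input_buffer: str,
                 output_buffer: Union[str, List[str]],
                 enable_key: Optional[Union[str, int]] = None,
                 enable: bool = True):
        super().__init__(name=name, enable_key=enable_key, enable=enable)
        # One trigger input buffer and one output buffer
        self.register_input_buffer(input_buffer, 'input', trigger=True)
        self.register_output_buffer(output_buffer)

    def process(self, input_msgs):
        msg = input_msgs['input']
        # Flip the image horizontally and write it back into the message
        msg.set_image(msg.get_image()[:, ::-1].copy())
        return msg

    def bypass(self, input_msgs):
        # When disabled, forward the frame unchanged
        return input_msgs['input']


# Building from a config works like the built-in nodes; the buffer names
# here are placeholders and must match the rest of the pipeline.
node = NODES.build(
    dict(
        type='FrameFlipNode',
        name='flip',
        enable_key='f',
        input_buffer='frame',
        output_buffer='display'))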
diff --git a/mmpose/apis/webcam/nodes/registry.py b/mmpose/apis/webcam/nodes/registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..06d39fed63bb1972d5a59892c8d3d208f113e792
--- /dev/null
+++ b/mmpose/apis/webcam/nodes/registry.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmengine.registry import Registry
+
+NODES = Registry('node')
diff --git a/mmpose/apis/webcam/nodes/visualizer_nodes/__init__.py b/mmpose/apis/webcam/nodes/visualizer_nodes/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fad7e3037673fb2ef01fdcd709fb9e2212ecef5a
--- /dev/null
+++ b/mmpose/apis/webcam/nodes/visualizer_nodes/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .bigeye_effect_node import BigeyeEffectNode
+from .notice_board_node import NoticeBoardNode
+from .object_visualizer_node import ObjectVisualizerNode
+from .sunglasses_effect_node import SunglassesEffectNode
+
+__all__ = [
+ 'ObjectVisualizerNode', 'NoticeBoardNode', 'SunglassesEffectNode',
+ 'BigeyeEffectNode'
+]
diff --git a/mmpose/apis/webcam/nodes/visualizer_nodes/bigeye_effect_node.py b/mmpose/apis/webcam/nodes/visualizer_nodes/bigeye_effect_node.py
new file mode 100644
index 0000000000000000000000000000000000000000..3bbec3d670b90528dd234ac57fb792e379e8a9f5
--- /dev/null
+++ b/mmpose/apis/webcam/nodes/visualizer_nodes/bigeye_effect_node.py
@@ -0,0 +1,127 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from itertools import groupby
+from typing import Dict, List, Optional, Union
+
+import cv2
+import numpy as np
+
+from ...utils import get_eye_keypoint_ids
+from ..base_visualizer_node import BaseVisualizerNode
+from ..registry import NODES
+
+
+@NODES.register_module()
+class BigeyeEffectNode(BaseVisualizerNode):
+ """Apply big-eye effect to the objects with eye keypoints in the frame.
+
+ Args:
+ name (str): The node name (also thread name)
+ input_buffer (str): The name of the input buffer
+ output_buffer (str|list): The name(s) of the output buffer(s)
+ enable_key (str|int, optional): Set a hot-key to toggle enable/disable
+ of the node. If an int value is given, it will be treated as an
+ ascii code of a key. Please note: (1) If ``enable_key`` is set,
+ the ``bypass()`` method need to be overridden to define the node
+ behavior when disabled; (2) Some hot-keys are reserved for
+ particular use. For example: 'q', 'Q' and 27 are used for exiting.
+ Default: ``None``
+ enable (bool): Default enable/disable status. Default: ``True``
+ kpt_thr (float): The score threshold of valid keypoints. Default: 0.5
+
+ Example::
+ >>> cfg = dict(
+ ... type='BigeyeEffectNode',
+ ... name='big-eye',
+ ... enable_key='b',
+ ... enable=False,
+ ... input_buffer='vis',
+ ... output_buffer='vis_bigeye')
+
+ >>> from mmpose.apis.webcam.nodes import NODES
+ >>> node = NODES.build(cfg)
+ """
+
+ def __init__(self,
+ name: str,
+ input_buffer: str,
+ output_buffer: Union[str, List[str]],
+ enable_key: Optional[Union[str, int]] = None,
+ enable: bool = True,
+ kpt_thr: float = 0.5):
+
+ super().__init__(
+ name=name,
+ input_buffer=input_buffer,
+ output_buffer=output_buffer,
+ enable_key=enable_key,
+ enable=enable)
+ self.kpt_thr = kpt_thr
+
+ def draw(self, input_msg):
+ canvas = input_msg.get_image()
+
+ objects = input_msg.get_objects(lambda x:
+ ('keypoints' in x and 'bbox' in x))
+
+ for dataset_meta, group in groupby(objects,
+ lambda x: x['dataset_meta']):
+ left_eye_index, right_eye_index = get_eye_keypoint_ids(
+ dataset_meta)
+ canvas = self.apply_bigeye_effect(canvas, group, left_eye_index,
+ right_eye_index)
+ return canvas
+
+ def apply_bigeye_effect(self, canvas: np.ndarray, objects: List[Dict],
+ left_eye_index: int,
+ right_eye_index: int) -> np.ndarray:
+ """Apply big-eye effect.
+
+ Args:
+ canvas (np.ndarray): The image to apply the effect
+ objects (list[dict]): The object list with bbox and keypoints
+ - "bbox" ([K, 4(or 5)]): bbox in [x1, y1, x2, y2, (score)]
+ - "keypoints" ([K,3]): keypoints in [x, y, score]
+ left_eye_index (int): Keypoint index of left eye
+ right_eye_index (int): Keypoint index of right eye
+
+ Returns:
+ np.ndarray: Processed image.
+ """
+
+ xx, yy = np.meshgrid(
+ np.arange(canvas.shape[1]), np.arange(canvas.shape[0]))
+ xx = xx.astype(np.float32)
+ yy = yy.astype(np.float32)
+
+ for obj in objects:
+ bbox = obj['bbox']
+ kpts = obj['keypoints']
+ kpt_scores = obj['keypoint_scores']
+
+ if kpt_scores[left_eye_index] < self.kpt_thr or kpt_scores[
+ right_eye_index] < self.kpt_thr:
+ continue
+
+ kpt_leye = kpts[left_eye_index, :2]
+ kpt_reye = kpts[right_eye_index, :2]
+ for xc, yc in [kpt_leye, kpt_reye]:
+
+ # distortion parameters
+ k1 = 0.001
+ epe = 1e-5
+
+ scale = (bbox[2] - bbox[0])**2 + (bbox[3] - bbox[1])**2
+ r2 = ((xx - xc)**2 + (yy - yc)**2)
+ r2 = (r2 + epe) / scale # normalized by bbox scale
+
+ xx = (xx - xc) / (1 + k1 / r2) + xc
+ yy = (yy - yc) / (1 + k1 / r2) + yc
+
+ canvas = cv2.remap(
+ canvas,
+ xx,
+ yy,
+ interpolation=cv2.INTER_AREA,
+ borderMode=cv2.BORDER_REPLICATE)
+
+ return canvas
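The coordinate-remapping trick above is easier to see in isolation. The sketch below applies the same radial "magnifying" distortion around a single synthetic center point; the image, center and constants are arbitrary, and ``cv2.INTER_LINEAR`` is used here for portability.

# Standalone sketch of the big-eye style radial remap on a synthetic image.
import cv2
import numpy as np

img = cv2.cvtColor(
    np.random.randint(0, 255, (240, 320), dtype=np.uint8), cv2.COLOR_GRAY2BGR)
xc, yc = 160.0, 120.0                 # effect center (arbitrary)
k1, eps, scale = 0.001, 1e-5, 240.0**2

xx, yy = np.meshgrid(np.arange(img.shape[1]), np.arange(img.shape[0]))
xx = xx.astype(np.float32)
yy = yy.astype(np.float32)

# Pull sampling coordinates toward the center to enlarge the local region
r2 = ((xx - xc)**2 + (yy - yc)**2 + eps) / scale
map_x = (xx - xc) / (1 + k1 / r2) + xc
map_y = (yy - yc) / (1 + k1 / r2) + yc

out = cv2.remap(img, map_x, map_y, interpolation=cv2.INTER_LINEAR,
                borderMode=cv2.BORDER_REPLICATE)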
diff --git a/mmpose/apis/webcam/nodes/visualizer_nodes/notice_board_node.py b/mmpose/apis/webcam/nodes/visualizer_nodes/notice_board_node.py
new file mode 100644
index 0000000000000000000000000000000000000000..0578ec38eb399645eef7795577de29f4fb7240b9
--- /dev/null
+++ b/mmpose/apis/webcam/nodes/visualizer_nodes/notice_board_node.py
@@ -0,0 +1,128 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional, Tuple, Union
+
+import cv2
+import numpy as np
+from mmcv import color_val
+
+from ...utils import FrameMessage
+from ..base_visualizer_node import BaseVisualizerNode
+from ..registry import NODES
+
+
+@NODES.register_module()
+class NoticeBoardNode(BaseVisualizerNode):
+ """Show text messages in the frame.
+
+ Args:
+ name (str): The node name (also thread name)
+ input_buffer (str): The name of the input buffer
+ output_buffer (str|list): The name(s) of the output buffer(s)
+ enable_key (str|int, optional): Set a hot-key to toggle enable/disable
+ of the node. If an int value is given, it will be treated as an
+ ascii code of a key. Please note: (1) If ``enable_key`` is set,
+ the ``bypass()`` method need to be overridden to define the node
+ behavior when disabled; (2) Some hot-keys are reserved for
+ particular use. For example: 'q', 'Q' and 27 are used for exiting.
+ Default: ``None``
+ enable (bool): Default enable/disable status. Default: ``True``
+ content_lines (list[str], optional): The lines of text message to show
+ in the frame. If not given, a default message will be shown.
+ Default: ``None``
+ x_offset (int): The position of the notice board's left border in
+ pixels. Default: 20
+ y_offset (int): The position of the notice board's top border in
+ pixels. Default: 20
+ y_delta (int): The line height in pixels. Default: 15
+ text_color (str|tuple): The font color represented in a color name or
+ a BGR tuple. Default: ``'black'``
+ background_color (str|tuple): The background color represented in a
+ color name or a BGR tuple. Default: (255, 183, 0)
+ text_scale (float): The font scale factor that is multiplied by the
+ base size. Default: 0.4
+
+ Example::
+ >>> cfg = dict(
+ ... type='NoticeBoardNode',
+ ... name='instruction',
+ ... enable_key='h',
+ ... enable=True,
+ ... input_buffer='vis_bigeye',
+ ... output_buffer='vis_notice',
+ ... content_lines=[
+ ... 'This is a demo for pose visualization and simple image '
+ ... 'effects. Have fun!', '', 'Hot-keys:',
+ ... '"v": Pose estimation result visualization',
+ ... '"s": Sunglasses effect B-)', '"b": Big-eye effect 0_0',
+ ... '"h": Show help information',
+ ... '"m": Show diagnostic information', '"q": Exit'
+ ... ],
+ ... )
+
+ >>> from mmpose.apis.webcam.nodes import NODES
+ >>> node = NODES.build(cfg)
+ """
+
+ default_content_lines = ['This is a notice board!']
+
+ def __init__(self,
+ name: str,
+ input_buffer: str,
+ output_buffer: Union[str, List[str]],
+ enable_key: Optional[Union[str, int]] = None,
+ enable: bool = True,
+ content_lines: Optional[List[str]] = None,
+ x_offset: int = 20,
+ y_offset: int = 20,
+ y_delta: int = 15,
+ text_color: Union[str, Tuple[int, int, int]] = 'black',
+ background_color: Union[str, Tuple[int, int,
+ int]] = (255, 183, 0),
+ text_scale: float = 0.4):
+ super().__init__(
+ name=name,
+ input_buffer=input_buffer,
+ output_buffer=output_buffer,
+ enable_key=enable_key,
+ enable=enable)
+
+ self.x_offset = x_offset
+ self.y_offset = y_offset
+ self.y_delta = y_delta
+ self.text_color = color_val(text_color)
+ self.background_color = color_val(background_color)
+ self.text_scale = text_scale
+
+ if content_lines:
+ self.content_lines = content_lines
+ else:
+ self.content_lines = self.default_content_lines
+
+ def draw(self, input_msg: FrameMessage) -> np.ndarray:
+ img = input_msg.get_image()
+ canvas = np.full(img.shape, self.background_color, dtype=img.dtype)
+
+ x = self.x_offset
+ y = self.y_offset
+
+ max_len = max([len(line) for line in self.content_lines])
+
+ def _put_line(line=''):
+ nonlocal y
+ cv2.putText(canvas, line, (x, y), cv2.FONT_HERSHEY_DUPLEX,
+ self.text_scale, self.text_color, 1)
+ y += self.y_delta
+
+ for line in self.content_lines:
+ _put_line(line)
+
+ x1 = max(0, self.x_offset)
+ x2 = min(img.shape[1], int(x + max_len * self.text_scale * 20))
+ y1 = max(0, self.y_offset - self.y_delta)
+ y2 = min(img.shape[0], y)
+
+ src1 = canvas[y1:y2, x1:x2]
+ src2 = img[y1:y2, x1:x2]
+ img[y1:y2, x1:x2] = cv2.addWeighted(src1, 0.5, src2, 0.5, 0)
+
+ return img
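As a quick standalone illustration of the drawing-and-blending step above, the snippet below renders text on a solid-color board and alpha-blends it (50/50) into a region of a synthetic frame; sizes, colors and the text are arbitrary.

# Standalone sketch of the notice-board overlay blending.
import cv2
import numpy as np

img = np.zeros((240, 320, 3), dtype=np.uint8)                  # fake frame
board = np.full((60, 200, 3), (255, 183, 0), dtype=np.uint8)   # board canvas
cv2.putText(board, 'Hot-keys: "q" to exit', (10, 35),
            cv2.FONT_HERSHEY_DUPLEX, 0.4, (0, 0, 0), 1)

y1, y2, x1, x2 = 20, 80, 20, 220                               # placement
img[y1:y2, x1:x2] = cv2.addWeighted(board, 0.5, img[y1:y2, x1:x2], 0.5, 0)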
diff --git a/mmpose/apis/webcam/nodes/visualizer_nodes/object_visualizer_node.py b/mmpose/apis/webcam/nodes/visualizer_nodes/object_visualizer_node.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef28a0804cd86352700f9636bd97751b31383ce1
--- /dev/null
+++ b/mmpose/apis/webcam/nodes/visualizer_nodes/object_visualizer_node.py
@@ -0,0 +1,341 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+from itertools import groupby
+from typing import Dict, List, Optional, Tuple, Union
+
+import cv2
+import mmcv
+import numpy as np
+
+from ...utils import FrameMessage
+from ..base_visualizer_node import BaseVisualizerNode
+from ..registry import NODES
+
+
+def imshow_bboxes(img,
+ bboxes,
+ labels=None,
+ colors='green',
+ text_color='white',
+ thickness=1,
+ font_scale=0.5):
+ """Draw bboxes with labels (optional) on an image. This is a wrapper of
+ mmcv.imshow_bboxes.
+
+ Args:
+ img (str or ndarray): The image to be displayed.
+ bboxes (ndarray): ndarray of shape (k, 4), each row is a bbox in
+ format [x1, y1, x2, y2].
+ labels (str or list[str], optional): labels of each bbox.
+ colors (list[str or tuple or :obj:`Color`]): A list of colors.
+ text_color (str or tuple or :obj:`Color`): Color of texts.
+ thickness (int): Thickness of lines.
+ font_scale (float): Font scales of texts.
+
+ Returns:
+ ndarray: The image with bboxes drawn on it.
+ """
+
+ # adapt to mmcv.imshow_bboxes input format
+ bboxes = np.split(
+ bboxes, bboxes.shape[0], axis=0) if bboxes.shape[0] > 0 else []
+ if not isinstance(colors, list):
+ colors = [colors for _ in range(len(bboxes))]
+ colors = [mmcv.color_val(c) for c in colors]
+ assert len(bboxes) == len(colors)
+
+ img = mmcv.imshow_bboxes(
+ img,
+ bboxes,
+ colors,
+ top_k=-1,
+ thickness=thickness,
+ show=False,
+ out_file=None)
+
+ if labels is not None:
+ if not isinstance(labels, list):
+ labels = [labels for _ in range(len(bboxes))]
+ assert len(labels) == len(bboxes)
+
+ for bbox, label, color in zip(bboxes, labels, colors):
+ if label is None:
+ continue
+ bbox_int = bbox[0, :4].astype(np.int32)
+ # roughly estimate the proper font size
+ text_size, text_baseline = cv2.getTextSize(label,
+ cv2.FONT_HERSHEY_DUPLEX,
+ font_scale, thickness)
+ text_x1 = bbox_int[0]
+ text_y1 = max(0, bbox_int[1] - text_size[1] - text_baseline)
+ text_x2 = bbox_int[0] + text_size[0]
+ text_y2 = text_y1 + text_size[1] + text_baseline
+ cv2.rectangle(img, (text_x1, text_y1), (text_x2, text_y2), color,
+ cv2.FILLED)
+ cv2.putText(img, label, (text_x1, text_y2 - text_baseline),
+ cv2.FONT_HERSHEY_DUPLEX, font_scale,
+ mmcv.color_val(text_color), thickness)
+
+ return img
+
+
+def imshow_keypoints(img,
+ pose_result,
+ skeleton=None,
+ kpt_score_thr=0.3,
+ pose_kpt_color=None,
+ pose_link_color=None,
+ radius=4,
+ thickness=1,
+ show_keypoint_weight=False):
+ """Draw keypoints and links on an image.
+
+ Args:
+ img (str or ndarray): The image to draw poses on. If an image array
+ is given, it will be modified in-place.
+ pose_result (list[kpts]): The poses to draw. Each element kpts is
+ a set of K keypoints as a Kx3 numpy.ndarray, where each
+ keypoint is represented as x, y, score.
+ skeleton (list[tuple], optional): Pairs of keypoint indices defining
+ the links to draw. Default: None
+ kpt_score_thr (float, optional): Minimum score of keypoints
+ to be shown. Default: 0.3.
+ pose_kpt_color (np.array[Nx3]): Color of N keypoints. If None,
+ the keypoints will not be drawn.
+ pose_link_color (np.array[Mx3]): Color of M links. If None, the
+ links will not be drawn.
+ radius (int): Radius of the keypoint circles.
+ thickness (int): Thickness of lines.
+ show_keypoint_weight (bool): Whether to adjust the transparency of
+ keypoints and links according to the keypoint scores.
+ """
+
+ img = mmcv.imread(img)
+ img_h, img_w, _ = img.shape
+
+ for kpts in pose_result:
+
+ kpts = np.array(kpts, copy=False)
+
+ # draw each point on image
+ if pose_kpt_color is not None:
+ assert len(pose_kpt_color) == len(kpts)
+
+ for kid, kpt in enumerate(kpts):
+ x_coord, y_coord, kpt_score = int(kpt[0]), int(kpt[1]), kpt[2]
+
+ if kpt_score < kpt_score_thr or pose_kpt_color[kid] is None:
+ # skip the point that should not be drawn
+ continue
+
+ color = tuple(int(c) for c in pose_kpt_color[kid])
+ if show_keypoint_weight:
+ img_copy = img.copy()
+ cv2.circle(img_copy, (int(x_coord), int(y_coord)), radius,
+ color, -1)
+ transparency = max(0, min(1, kpt_score))
+ cv2.addWeighted(
+ img_copy,
+ transparency,
+ img,
+ 1 - transparency,
+ 0,
+ dst=img)
+ else:
+ cv2.circle(img, (int(x_coord), int(y_coord)), radius,
+ color, -1)
+
+ # draw links
+ if skeleton is not None and pose_link_color is not None:
+ assert len(pose_link_color) == len(skeleton)
+
+ for sk_id, sk in enumerate(skeleton):
+ pos1 = (int(kpts[sk[0], 0]), int(kpts[sk[0], 1]))
+ pos2 = (int(kpts[sk[1], 0]), int(kpts[sk[1], 1]))
+
+ if (pos1[0] <= 0 or pos1[0] >= img_w or pos1[1] <= 0
+ or pos1[1] >= img_h or pos2[0] <= 0 or pos2[0] >= img_w
+ or pos2[1] <= 0 or pos2[1] >= img_h
+ or kpts[sk[0], 2] < kpt_score_thr
+ or kpts[sk[1], 2] < kpt_score_thr
+ or pose_link_color[sk_id] is None):
+ # skip the link that should not be drawn
+ continue
+ color = tuple(int(c) for c in pose_link_color[sk_id])
+ if show_keypoint_weight:
+ img_copy = img.copy()
+ X = (pos1[0], pos2[0])
+ Y = (pos1[1], pos2[1])
+ mX = np.mean(X)
+ mY = np.mean(Y)
+ length = ((Y[0] - Y[1])**2 + (X[0] - X[1])**2)**0.5
+ angle = math.degrees(math.atan2(Y[0] - Y[1], X[0] - X[1]))
+ stickwidth = 2
+ polygon = cv2.ellipse2Poly(
+ (int(mX), int(mY)), (int(length / 2), int(stickwidth)),
+ int(angle), 0, 360, 1)
+ cv2.fillConvexPoly(img_copy, polygon, color)
+ transparency = max(
+ 0, min(1, 0.5 * (kpts[sk[0], 2] + kpts[sk[1], 2])))
+ cv2.addWeighted(
+ img_copy,
+ transparency,
+ img,
+ 1 - transparency,
+ 0,
+ dst=img)
+ else:
+ cv2.line(img, pos1, pos2, color, thickness=thickness)
+
+ return img
+
+
+@NODES.register_module()
+class ObjectVisualizerNode(BaseVisualizerNode):
+ """Visualize the bounding box and keypoints of objects.
+
+ Args:
+ name (str): The node name (also thread name)
+ input_buffer (str): The name of the input buffer
+ output_buffer (str|list): The name(s) of the output buffer(s)
+ enable_key (str|int, optional): Set a hot-key to toggle enable/disable
+ of the node. If an int value is given, it will be treated as an
+ ascii code of a key. Please note: (1) If ``enable_key`` is set,
+ the ``bypass()`` method need to be overridden to define the node
+ behavior when disabled; (2) Some hot-keys are reserved for
+ particular use. For example: 'q', 'Q' and 27 are used for exiting.
+ Default: ``None``
+ enable (bool): Default enable/disable status. Default: ``True``
+ show_bbox (bool): Set ``True`` to show the bboxes of detection
+ objects. Default: ``True``
+ show_keypoint (bool): Set ``True`` to show the pose estimation
+ results. Default: ``True``
+ must_have_keypoint (bool): Only draw bboxes for objects that also
+ have keypoints. Default: ``False``
+ kpt_thr (float): The threshold of keypoint score. Default: 0.3
+ radius (int): The radius of keypoint. Default: 4
+ thickness (int): The thickness of skeleton. Default: 2
+ bbox_color (str|tuple|dict): The color of bboxes. If a single color is
+ given (a str like 'green' or a BGR tuple like (0, 255, 0)), it
+ will be used for all bboxes. If a dict is given, it will be used
+ as a map from class labels to bbox colors. Default: ``'green'``
+
+ Example::
+ >>> cfg = dict(
+ ... type='ObjectVisualizerNode',
+ ... name='object visualizer',
+ ... enable_key='v',
+ ... enable=True,
+ ... show_bbox=True,
+ ... must_have_keypoint=False,
+ ... show_keypoint=True,
+ ... input_buffer='frame',
+ ... output_buffer='vis')
+
+ >>> from mmpose.apis.webcam.nodes import NODES
+ >>> node = NODES.build(cfg)
+ """
+
+ default_bbox_color = {
+ 'person': (148, 139, 255),
+ 'cat': (255, 255, 0),
+ 'dog': (255, 255, 0),
+ }
+
+ def __init__(self,
+ name: str,
+ input_buffer: str,
+ output_buffer: Union[str, List[str]],
+ enable_key: Optional[Union[str, int]] = None,
+ enable: bool = True,
+ show_bbox: bool = True,
+ show_keypoint: bool = True,
+ must_have_keypoint: bool = False,
+ kpt_thr: float = 0.3,
+ radius: int = 4,
+ thickness: int = 2,
+ bbox_color: Optional[Union[str, Tuple, Dict]] = 'green'):
+
+ super().__init__(
+ name=name,
+ input_buffer=input_buffer,
+ output_buffer=output_buffer,
+ enable_key=enable_key,
+ enable=enable)
+
+ self.kpt_thr = kpt_thr
+ self.bbox_color = bbox_color
+ self.show_bbox = show_bbox
+ self.show_keypoint = show_keypoint
+ self.must_have_keypoint = must_have_keypoint
+ self.radius = radius
+ self.thickness = thickness
+
+ def _draw_bbox(self, canvas: np.ndarray, input_msg: FrameMessage):
+ """Draw object bboxes."""
+
+ if self.must_have_keypoint:
+ objects = input_msg.get_objects(
+ lambda x: 'bbox' in x and 'keypoints' in x)
+ else:
+ objects = input_msg.get_objects(lambda x: 'bbox' in x)
+ # return if there is no detected objects
+ if not objects:
+ return canvas
+
+ bboxes = [obj['bbox'] for obj in objects]
+ labels = [obj.get('label', None) for obj in objects]
+ default_color = (0, 255, 0)
+
+ # Get bbox colors
+ if isinstance(self.bbox_color, dict):
+ colors = [
+ self.bbox_color.get(label, default_color) for label in labels
+ ]
+ else:
+ colors = self.bbox_color
+
+ imshow_bboxes(
+ canvas,
+ np.vstack(bboxes),
+ labels=labels,
+ colors=colors,
+ text_color='white',
+ font_scale=0.5)
+
+ return canvas
+
+ def _draw_keypoint(self, canvas: np.ndarray, input_msg: FrameMessage):
+ """Draw object keypoints."""
+ objects = input_msg.get_objects(lambda x: 'pose_model_cfg' in x)
+
+ # return if there is no object with keypoints
+ if not objects:
+ return canvas
+
+ for model_cfg, group in groupby(objects,
+ lambda x: x['pose_model_cfg']):
+ dataset_info = objects[0]['dataset_meta']
+ keypoints = [
+ np.concatenate(
+ (obj['keypoints'], obj['keypoint_scores'][:, None]),
+ axis=1) for obj in group
+ ]
+ imshow_keypoints(
+ canvas,
+ keypoints,
+ skeleton=dataset_info['skeleton_links'],
+ kpt_score_thr=self.kpt_thr,
+ pose_kpt_color=dataset_info['keypoint_colors'],
+ pose_link_color=dataset_info['skeleton_link_colors'],
+ radius=self.radius,
+ thickness=self.thickness)
+
+ return canvas
+
+ def draw(self, input_msg: FrameMessage) -> np.ndarray:
+ canvas = input_msg.get_image()
+
+ if self.show_bbox:
+ canvas = self._draw_bbox(canvas, input_msg)
+
+ if self.show_keypoint:
+ canvas = self._draw_keypoint(canvas, input_msg)
+
+ return canvas
diff --git a/mmpose/apis/webcam/nodes/visualizer_nodes/sunglasses_effect_node.py b/mmpose/apis/webcam/nodes/visualizer_nodes/sunglasses_effect_node.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c011177f57c95a959bc97f167eb58229707b641
--- /dev/null
+++ b/mmpose/apis/webcam/nodes/visualizer_nodes/sunglasses_effect_node.py
@@ -0,0 +1,143 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from itertools import groupby
+from typing import Dict, List, Optional, Union
+
+import cv2
+import numpy as np
+
+from ...utils import get_eye_keypoint_ids, load_image_from_disk_or_url
+from ..base_visualizer_node import BaseVisualizerNode
+from ..registry import NODES
+
+
+@NODES.register_module()
+class SunglassesEffectNode(BaseVisualizerNode):
+ """Apply sunglasses effect (draw sunglasses at the facial area)to the
+ objects with eye keypoints in the frame.
+
+ Args:
+ name (str): The node name (also thread name)
+ input_buffer (str): The name of the input buffer
+ output_buffer (str|list): The name(s) of the output buffer(s)
+ enable_key (str|int, optional): Set a hot-key to toggle enable/disable
+ of the node. If an int value is given, it will be treated as an
+ ascii code of a key. Please note: (1) If ``enable_key`` is set,
+ the ``bypass()`` method need to be overridden to define the node
+ behavior when disabled; (2) Some hot-keys are reserved for
+ particular use. For example: 'q', 'Q' and 27 are used for exiting.
+ Default: ``None``
+ enable (bool): Default enable/disable status. Default: ``True``.
+ kpt_thr (float): The score threshold of valid keypoints. Default: 0.5
+ resource_img_path (str, optional): The resource image path or url.
+ The image should be a pair of sunglasses with white background.
+ If not specified, the url of a default image will be used. See
+ ``SunglassesNode.default_resource_img_path``. Default: ``None``
+
+ Example::
+ >>> cfg = dict(
+ ... type='SunglassesEffectNode',
+ ... name='sunglasses',
+ ... enable_key='s',
+ ... enable=False,
+ ... input_buffer='vis',
+ ... output_buffer='vis_sunglasses')
+
+ >>> from mmpose.apis.webcam.nodes import NODES
+ >>> node = NODES.build(cfg)
+ """
+
+ # The image is attributed to:
+ # "https://www.vecteezy.com/vector-art/1932353-summer-sunglasses-
+ # accessory-isolated-icon" by Vecteezy
+ default_resource_img_path = (
+ 'https://user-images.githubusercontent.com/15977946/'
+ '170850839-acc59e26-c6b3-48c9-a9ec-87556edb99ed.jpg')
+
+ def __init__(self,
+ name: str,
+ input_buffer: str,
+ output_buffer: Union[str, List[str]],
+ enable_key: Optional[Union[str, int]] = None,
+ enable: bool = True,
+ kpt_thr: float = 0.5,
+ resource_img_path: Optional[str] = None):
+
+ super().__init__(
+ name=name,
+ input_buffer=input_buffer,
+ output_buffer=output_buffer,
+ enable_key=enable_key,
+ enable=enable)
+
+ if resource_img_path is None:
+ resource_img_path = self.default_resource_img_path
+
+ self.resource_img = load_image_from_disk_or_url(resource_img_path)
+ self.kpt_thr = kpt_thr
+
+ def draw(self, input_msg):
+ canvas = input_msg.get_image()
+
+ objects = input_msg.get_objects(lambda x: 'keypoints' in x)
+
+ for dataset_meta, group in groupby(objects,
+ lambda x: x['dataset_meta']):
+ left_eye_index, right_eye_index = get_eye_keypoint_ids(
+ dataset_meta)
+ canvas = self.apply_sunglasses_effect(canvas, group,
+ left_eye_index,
+ right_eye_index)
+ return canvas
+
+ def apply_sunglasses_effect(self, canvas: np.ndarray, objects: List[Dict],
+ left_eye_index: int,
+ right_eye_index: int) -> np.ndarray:
+ """Apply sunglasses effect.
+
+ Args:
+ canvas (np.ndarray): The image to apply the effect
+ objects (list[dict]): The object list with keypoints
+ - "keypoints" ([K,3]): keypoints in [x, y, score]
+ left_eye_index (int): Keypoint index of the left eye
+ right_eye_index (int): Keypoint index of the right eye
+
+ Returns:
+ np.ndarray: Processed image
+ """
+
+ hm, wm = self.resource_img.shape[:2]
+ # anchor points in the sunglasses image
+ pts_src = np.array([[0.3 * wm, 0.3 * hm], [0.3 * wm, 0.7 * hm],
+ [0.7 * wm, 0.3 * hm], [0.7 * wm, 0.7 * hm]],
+ dtype=np.float32)
+
+ for obj in objects:
+ kpts = obj['keypoints']
+ kpt_scores = obj['keypoint_scores']
+
+ if kpt_scores[left_eye_index] < self.kpt_thr or kpt_scores[
+ right_eye_index] < self.kpt_thr:
+ continue
+
+ kpt_leye = kpts[left_eye_index, :2]
+ kpt_reye = kpts[right_eye_index, :2]
+ # orthogonal vector to the left-to-right eyes
+ vo = 0.5 * (kpt_reye - kpt_leye)[::-1] * [-1, 1]
+
+ # anchor points in the image by eye positions
+ pts_tar = np.vstack(
+ [kpt_reye + vo, kpt_reye - vo, kpt_leye + vo, kpt_leye - vo])
+
+ h_mat, _ = cv2.findHomography(pts_src, pts_tar)
+ patch = cv2.warpPerspective(
+ self.resource_img,
+ h_mat,
+ dsize=(canvas.shape[1], canvas.shape[0]),
+ borderValue=(255, 255, 255))
+ # mask the white background area in the patch with a threshold 200
+ mask = cv2.cvtColor(patch, cv2.COLOR_BGR2GRAY)
+ mask = (mask < 200).astype(np.uint8)
+ canvas = cv2.copyTo(patch, mask, canvas)
+
+ return canvas
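The homography-based overlay is also easy to demonstrate in isolation. The sketch below maps four anchor points of a synthetic "sunglasses" image onto four target points in a synthetic frame, warps it, and pastes only the non-white pixels; all coordinates and images are placeholders.

# Standalone sketch of the homography warp-and-paste overlay.
import cv2
import numpy as np

canvas = np.full((240, 320, 3), 80, dtype=np.uint8)          # fake frame
resource = np.full((100, 200, 3), 255, dtype=np.uint8)       # white background
cv2.rectangle(resource, (20, 30), (180, 70), (0, 0, 0), -1)  # fake sunglasses

h, w = resource.shape[:2]
pts_src = np.float32([[0.3 * w, 0.3 * h], [0.3 * w, 0.7 * h],
                      [0.7 * w, 0.3 * h], [0.7 * w, 0.7 * h]])
pts_tar = np.float32([[120, 100], [120, 130], [200, 100], [200, 130]])

h_mat, _ = cv2.findHomography(pts_src, pts_tar)
patch = cv2.warpPerspective(resource, h_mat,
                            dsize=(canvas.shape[1], canvas.shape[0]),
                            borderValue=(255, 255, 255))
# keep only the non-white pixels of the warped patch
mask = (cv2.cvtColor(patch, cv2.COLOR_BGR2GRAY) < 200).astype(np.uint8)
canvas = cv2.copyTo(patch, mask, canvas)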
diff --git a/mmpose/apis/webcam/utils/__init__.py b/mmpose/apis/webcam/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2911bcd5bf451477d0f4266d89786028cae68293
--- /dev/null
+++ b/mmpose/apis/webcam/utils/__init__.py
@@ -0,0 +1,20 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .buffer import BufferManager
+from .event import EventManager
+from .image_capture import ImageCapture
+from .message import FrameMessage, Message, VideoEndingMessage
+from .misc import (copy_and_paste, expand_and_clamp, get_cached_file_path,
+ get_config_path, is_image_file, limit_max_fps,
+ load_image_from_disk_or_url, screen_matting)
+from .pose import (get_eye_keypoint_ids, get_face_keypoint_ids,
+ get_hand_keypoint_ids, get_mouth_keypoint_ids,
+ get_wrist_keypoint_ids)
+
+__all__ = [
+ 'BufferManager', 'EventManager', 'FrameMessage', 'Message',
+ 'limit_max_fps', 'VideoEndingMessage', 'load_image_from_disk_or_url',
+ 'get_cached_file_path', 'screen_matting', 'get_config_path',
+ 'expand_and_clamp', 'copy_and_paste', 'is_image_file', 'ImageCapture',
+ 'get_eye_keypoint_ids', 'get_face_keypoint_ids', 'get_wrist_keypoint_ids',
+ 'get_mouth_keypoint_ids', 'get_hand_keypoint_ids'
+]
diff --git a/mmpose/apis/webcam/utils/buffer.py b/mmpose/apis/webcam/utils/buffer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7f8b9864ee0136173b161a352138c9fac8e879e
--- /dev/null
+++ b/mmpose/apis/webcam/utils/buffer.py
@@ -0,0 +1,203 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from functools import wraps
+from queue import Queue
+from typing import Any, Dict, List, Optional
+
+from mmengine import is_seq_of
+
+__all__ = ['BufferManager']
+
+
+def check_buffer_registered(exist=True):
+ """A function wrapper to check the buffer existence before it is being used
+ by the wrapped function.
+
+ Args:
+ exist (bool): If set to ``True``, assert the buffer exists; if set to
+ ``False``, assert the buffer does not exist. Default: ``True``
+ """
+
+ def wrapper(func):
+
+ @wraps(func)
+ def wrapped(manager, name, *args, **kwargs):
+ if exist:
+ # Assert buffer exist
+ if name not in manager:
+ raise ValueError(f'Fail to call {func.__name__}: '
+ f'buffer "{name}" is not registered.')
+ else:
+ # Assert buffer not exist
+ if name in manager:
+ raise ValueError(f'Fail to call {func.__name__}: '
+ f'buffer "{name}" is already registered.')
+ return func(manager, name, *args, **kwargs)
+
+ return wrapped
+
+ return wrapper
+
+
+class Buffer(Queue):
+
+ def put_force(self, item: Any):
+ """Force to put an item into the buffer.
+
+ If the buffer is already full, the earliest item in the buffer will be
+ removed to make room for the incoming item.
+
+ Args:
+ item (any): The item to put into the buffer
+ """
+ with self.mutex:
+ if self.maxsize > 0:
+ while self._qsize() >= self.maxsize:
+ _ = self._get()
+ self.unfinished_tasks -= 1
+
+ self._put(item)
+ self.unfinished_tasks += 1
+ self.not_empty.notify()
+
+
+class BufferManager():
+ """A helper class to manage multiple buffers.
+
+ Parameters:
+ buffer_type (type): The class to build buffer instances. Default:
+ :class:`mmpose.apis.webcam.utils.buffer.Buffer`.
+ buffers (dict, optional): Create :class:`BufferManager` from existing
+ buffers. Each item maps a buffer name to the corresponding buffer.
+ If not given, an empty buffer manager will be created.
+ Default: ``None``
+ """
+
+ def __init__(self,
+ buffer_type: type = Buffer,
+ buffers: Optional[Dict] = None):
+ self.buffer_type = buffer_type
+ if buffers is None:
+ self._buffers = {}
+ else:
+ if is_seq_of(list(buffers.values()), buffer_type):
+ self._buffers = buffers.copy()
+ else:
+ raise ValueError('The values of buffers should be instance '
+ f'of {buffer_type}')
+
+ def __contains__(self, name):
+ return name in self._buffers
+
+ @check_buffer_registered(False)
+ def register_buffer(self, name, maxsize: int = 0):
+ """Register a buffer.
+
+ If the buffer already exists, a ValueError will be raised.
+
+ Args:
+ name (any): The buffer name
+ maxsize (int): The capacity of the buffer. If set to 0, the
+ capacity is unlimited. Default: 0
+ """
+ self._buffers[name] = self.buffer_type(maxsize)
+
+ @check_buffer_registered()
+ def put(self, name, item, block: bool = True, timeout: float = None):
+ """Put an item into specified buffer.
+
+ Args:
+ name (any): The buffer name
+ item (any): The item to put into the buffer
+ block (bool): If set to ``True``, block if necessary until a free
+ slot is available in the target buffer. It blocks at most
+ ``timeout`` seconds and raises the ``Full`` exception.
+ Otherwise, put an item on the queue if a free slot is
+ immediately available, else raise the ``Full`` exception.
+ Default: ``True``
+ timeout (float, optional): The maximum waiting time in seconds if
+ ``block`` is ``True``. Default: ``None``
+ """
+ self._buffers[name].put(item, block, timeout)
+
+ @check_buffer_registered()
+ def put_force(self, name, item):
+ """Force to put an item into specified buffer. If the buffer was full,
+ the earliest item within the buffer will be popped out to make a free
+ slot.
+
+ Args:
+ name (any): The buffer name
+ item (any): The item to put into the buffer
+ """
+ self._buffers[name].put_force(item)
+
+ @check_buffer_registered()
+ def get(self, name, block: bool = True, timeout: float = None) -> Any:
+ """Remove an return an item from the specified buffer.
+
+ Args:
+ name (any): The buffer name
+ block (bool): If set to ``True``, block if necessary until an item
+ is available in the target buffer. It blocks at most
+ ``timeout`` seconds and raises the ``Empty`` exception.
+ Otherwise, return an item if one is immediately available,
+ else raise the ``Empty`` exception. Default: ``True``
+ timeout (float, optional): The maximum waiting time in seconds if
+ ``block`` is ``True``. Default: ``None``
+
+ Returns:
+ any: The returned item.
+ """
+ return self._buffers[name].get(block, timeout)
+
+ @check_buffer_registered()
+ def is_empty(self, name) -> bool:
+ """Check if a buffer is empty.
+
+ Args:
+ name (any): The buffer name
+
+ Returns:
+ bool: Whether the buffer is empty.
+ """
+ return self._buffers[name].empty()
+
+ @check_buffer_registered()
+ def is_full(self, name):
+ """Check if a buffer is full.
+
+ Args:
+ name (any): The buffer name
+
+ Returns:
+ bool: Whether the buffer is full.
+ """
+ return self._buffers[name].full()
+
+ def get_sub_manager(self, buffer_names: List[str]) -> 'BufferManager':
+ """Return a :class:`BufferManager` instance that covers a subset of the
+ buffers in the parent. The is usually used to partially share the
+ buffers of the executor to the node.
+
+ Args:
+ buffer_names (list): The list of buffers to create the sub manager
+
+ Returns:
+ BufferManager: The created sub buffer manager.
+ """
+ buffers = {name: self._buffers[name] for name in buffer_names}
+ return BufferManager(self.buffer_type, buffers)
+
+ def get_info(self):
+ """Returns the information of all buffers in the manager.
+
+ Returns:
+ dict[any, dict]: Each item is a buffer name and the information
+ dict of that buffer.
+ """
+ buffer_info = {}
+ for name, buffer in self._buffers.items():
+ buffer_info[name] = {
+ 'size': buffer.qsize(),
+ 'maxsize': buffer.maxsize
+ }
+ return buffer_info
diff --git a/mmpose/apis/webcam/utils/event.py b/mmpose/apis/webcam/utils/event.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8e88e1d8bdbecdeef0352cfd1ef00fd0940725a
--- /dev/null
+++ b/mmpose/apis/webcam/utils/event.py
@@ -0,0 +1,137 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import logging
+from collections import defaultdict
+from contextlib import contextmanager
+from threading import Event
+from typing import Optional
+
+logger = logging.getLogger('Event')
+
+
+class EventManager():
+ """A helper class to manage events.
+
+ :class:`EventManager` provides interfaces to register, set, clear and
+ check events by name.
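+
+ Example (a minimal usage sketch; the event name is arbitrary)::
+ >>> manager = EventManager()
+ >>> manager.register_event('_exit_')
+ >>> manager.is_set('_exit_')
+ False
+ >>> manager.set('_exit_')
+ >>> manager.is_set('_exit_')
+ True
+ >>> manager.clear('_exit_')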
+ """
+
+ def __init__(self):
+ self._events = defaultdict(Event)
+
+ def register_event(self, event_name: str, is_keyboard: bool = False):
+ """Register an event. A event must be registered first before being
+ set, cleared or checked.
+
+ Args:
+ event_name (str): The indicator of the event. The name should be
+ unique in one :class:`EventManager` instance
+ is_keyboard (bool): Specify whether it is a keyboard event. If so,
+ the ``event_name`` should be the key value, and the indicator
+ will be set as ``'_keyboard_{event_name}'``. Otherwise, the
+ ``event_name`` will be directly used as the indicator.
+ Default: ``False``
+ """
+ if is_keyboard:
+ event_name = self._get_keyboard_event_name(event_name)
+ self._events[event_name] = Event()
+
+ def set(self, event_name: str, is_keyboard: bool = False):
+ """Set the internal flag of an event to ``True``.
+
+ Args:
+ event_name (str): The indicator of the event
+ is_keyboard (bool): Specify whether it is a keyboard event. See
+ ``register_event()`` for details. Default: False
+ """
+ if is_keyboard:
+ event_name = self._get_keyboard_event_name(event_name)
+ self._events[event_name].set()
+ logger.info(f'Event {event_name} is set.')
+
+ def wait(self,
+ event_name: str = None,
+ is_keyboard: bool = False,
+ timeout: Optional[float] = None) -> bool:
+ """Block until the internal flag of an event is ``True``.
+
+ Args:
+ event_name (str): The indicator of the event
+ is_keyboard (bool): Specify whether it is a keyboard event. See
+ ``register_event()`` for details. Default: False
+ timeout (float, optional): The optional maximum blocking time in
+ seconds. Default: ``None``
+
+ Returns:
+ bool: The internal event flag on exit.
+ """
+ if is_keyboard:
+ event_name = self._get_keyboard_event_name(event_name)
+ return self._events[event_name].wait(timeout)
+
+ def is_set(self,
+ event_name: str = None,
+ is_keyboard: Optional[bool] = False) -> bool:
+ """Check weather the internal flag of an event is ``True``.
+
+ Args:
+ event_name (str): The indicator of the event
+ is_keyboard (bool): Specify whether it is a keyboard event. See
+ ``register_event()`` for details. Default: False
+ Returns:
+ bool: The internal event flag.
+ """
+ if is_keyboard:
+ event_name = self._get_keyboard_event_name(event_name)
+ return self._events[event_name].is_set()
+
+ def clear(self,
+ event_name: str = None,
+ is_keyboard: Optional[bool] = False):
+ """Reset the internal flag of en event to False.
+
+ Args:
+ event_name (str): The indicator of the event
+ is_keyboard (bool): Specify whether it is a keyboard event. See
+ ``register_event()`` for details. Default: False
+ """
+ if is_keyboard:
+ event_name = self._get_keyboard_event_name(event_name)
+ self._events[event_name].clear()
+ logger.info(f'Event {event_name} is cleared.')
+
+ @staticmethod
+ def _get_keyboard_event_name(key):
+ """Get keyboard event name from the key value."""
+ return f'_keyboard_{chr(key) if isinstance(key,int) else key}'
+
+ @contextmanager
+ def wait_and_handle(self,
+ event_name: str = None,
+ is_keyboard: Optional[bool] = False):
+ """Context manager that blocks until an evenet is set ``True`` and then
+ goes into the context.
+
+ The internal event flag will be reset to ``False`` automatically when
+ leaving the context.
+
+ Args:
+ event_name (str): The indicator of the event
+ is_keyboard (bool): Specify whether it is a keyboard event. See
+ ``register_event()`` for details. Default: False
+
+ Example::
+ >>> from mmpose.apis.webcam.utils import EventManager
+ >>> manager = EventManager()
+ >>> manager.register_event('q', is_keyboard=True)
+
+ >>> # Once the keyboard event `q` is set, ``wait_and_handle``
+ >>> # will enter the context to invoke ``foo()``, and reset the
+ >>> # event when leaving the context
+ >>> with manager.wait_and_handle('q', is_keyboard=True):
+ ... foo()
+ """
+ self.wait(event_name, is_keyboard)
+ try:
+ yield
+ finally:
+ self.clear(event_name, is_keyboard)
diff --git a/mmpose/apis/webcam/utils/image_capture.py b/mmpose/apis/webcam/utils/image_capture.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb28acff942d007b26345f5633746d6f948b9e70
--- /dev/null
+++ b/mmpose/apis/webcam/utils/image_capture.py
@@ -0,0 +1,40 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Union
+
+import cv2
+import numpy as np
+
+from .misc import load_image_from_disk_or_url
+
+
+class ImageCapture:
+ """A mock-up of cv2.VideoCapture that always return a const image.
+
+ Args:
+ image (str | ndarray): The image path or image data
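+
+ Example (an illustrative sketch using a dummy image array)::
+ >>> import cv2
+ >>> import numpy as np
+ >>> cap = ImageCapture(np.zeros((480, 640, 3), dtype=np.uint8))
+ >>> cap.isOpened()
+ True
+ >>> ret, frame = cap.read()
+ >>> int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+ 640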
+ """
+
+ def __init__(self, image: Union[str, np.ndarray]):
+ if isinstance(image, str):
+ self.image = load_image_from_disk_or_url(image)
+ else:
+ self.image = image
+
+ def isOpened(self):
+ return (self.image is not None)
+
+ def read(self):
+ return True, self.image.copy()
+
+ def release(self):
+ pass
+
+ def get(self, propId):
+ if propId == cv2.CAP_PROP_FRAME_WIDTH:
+ return self.image.shape[1]
+ elif propId == cv2.CAP_PROP_FRAME_HEIGHT:
+ return self.image.shape[0]
+ elif propId == cv2.CAP_PROP_FPS:
+ return np.nan
+ else:
+ raise NotImplementedError()
diff --git a/mmpose/apis/webcam/utils/message.py b/mmpose/apis/webcam/utils/message.py
new file mode 100644
index 0000000000000000000000000000000000000000..8961ea39c29e88ebe7a42bd250656f417859e7c8
--- /dev/null
+++ b/mmpose/apis/webcam/utils/message.py
@@ -0,0 +1,186 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import time
+import uuid
+import warnings
+from typing import Callable, Dict, List, Optional
+
+import numpy as np
+
+Filter = Callable[[Dict], bool]
+
+
+class Message():
+ """Message base class.
+
+ All message classes should inherit from this class. The basic use of a
+ Message instance is to carry a piece of text message (``self.msg``) and a
+ dict that stores structured data (``self.data``), e.g. frame image, model
+ prediction, etc.
+
+ A message may also hold route information, which is composed of
+ information of all nodes the message has passed through.
+
+ Parameters:
+ msg (str): The text message.
+ data (dict, optional): The structured data.
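+
+ Example (a minimal sketch; the message content is arbitrary)::
+ >>> msg = Message(msg='hello', data=dict(frame_id=0))
+ >>> msg.update_route_info(node_name='Camera', node_type='none')
+ >>> msg.get_route_info()
+ [{'node': 'Camera', 'node_type': 'none', 'info': None}]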
+ """
+
+ def __init__(self, msg: str = '', data: Optional[Dict] = None):
+ self.msg = msg
+ self.data = data if data else {}
+ self.route_info = []
+ self.timestamp = time.time()
+ self.id = uuid.uuid1()
+
+ def update_route_info(self,
+ node=None,
+ node_name: Optional[str] = None,
+ node_type: Optional[str] = None,
+ info: Optional[Dict] = None):
+ """Append new node information to the route information.
+
+ Args:
+ node (Node, optional): An instance of Node that provides basic
+ information like the node name and type. Default: ``None``.
+ node_name (str, optional): The node name. If node is given,
+ node_name will be ignored. Default: ``None``.
+ node_type (str, optional): The class name of the node. If node
+ is given, node_type will be ignored. Default: ``None``.
+ info (dict, optional): The node information, which is usually
+ given by node.get_node_info(). Default: ``None``.
+ """
+ if node is not None:
+ if node_name is not None or node_type is not None:
+ warnings.warn(
+ '`node_name` and `node_type` will be overridden if node '
+ 'is provided.')
+ node_name = node.name
+ node_type = node.__class__.__name__
+
+ node_info = {'node': node_name, 'node_type': node_type, 'info': info}
+ self.route_info.append(node_info)
+
+ def set_route_info(self, route_info: List[Dict]):
+ """Directly set the entire route information.
+
+ Args:
+ route_info (list): route information to set to the message.
+ """
+ self.route_info = route_info
+
+ def merge_route_info(self, route_info: List[Dict]):
+ """Merge the given route information into the original one of the
+ message. This is used for combining route information from multiple
+ messages. The node information in the route will be reordered according
+ to their timestamps.
+
+ Args:
+ route_info (list): route information to merge.
+ """
+ self.route_info += route_info
+ self.route_info.sort(key=lambda x: x.get('timestamp', np.inf))
+
+ def get_route_info(self) -> List:
+ return self.route_info.copy()
+
+
+class VideoEndingMessage(Message):
+ """The special message to indicate the ending of the input video."""
+
+
+class FrameMessage(Message):
+ """The message to store information of a video frame."""
+
+ def __init__(self, img):
+ super().__init__(data=dict(image=img, objects={}, model_cfgs={}))
+
+ def get_image(self) -> np.ndarray:
+ """Get the frame image.
+
+ Returns:
+ np.ndarray: The frame image.
+ """
+ return self.data.get('image', None)
+
+ def set_image(self, img):
+ """Set the frame image to the message.
+
+ Args:
+ img (np.ndarray): The frame image.
+ """
+ self.data['image'] = img
+
+ def set_objects(self, objects: List[Dict]):
+ """Set the object information. The old object information will be
+ cleared.
+
+ Args:
+ objects (list[dict]): A list of object information
+
+ See also :func:`update_objects`.
+ """
+ self.data['objects'] = {}
+ self.update_objects(objects)
+
+ def update_objects(self, objects: List[Dict]):
+ """Update object information.
+
+ Each object will be assigned a unique ID if it does not have one. If
+ an object's ID already exists in ``self.data['objects']``, the object
+ information will be updated; otherwise it will be added as a new
+ object.
+
+ Args:
+ objects (list[dict]): A list of object information
+ """
+ for obj in objects:
+ if '_id_' in obj:
+ # get the object id if it exists
+ obj_id = obj['_id_']
+ else:
+ # otherwise assign a new object id
+ obj_id = uuid.uuid1()
+ obj['_id_'] = obj_id
+ self.data['objects'][obj_id] = obj
+
+ def get_objects(self, obj_filter: Optional[Filter] = None) -> List[Dict]:
+ """Get object information from the frame data.
+
+ By default, all objects in the frame data are returned. Optionally, a
+ filter function can be given to retrieve only the objects that meet a
+ specific condition, e.g. objects with a certain key or value.
+
+ Args:
+ obj_filter (callable, optional): A filter function that returns a
+ bool value from a object (dict). If provided, only objects
+ that return True will be retrieved. Otherwise all objects will
+ be retrieved. Default: ``None``.
+
+ Returns:
+ list[dict]: A list of object information.
+
+
+ Example::
+ >>> objects = [
+ ... {'_id_': 1, 'label': 'cat'},
+ ... {'_id_': 2, 'label': 'dog'},
+ ... ]
+ >>> frame = FrameMessage(img)
+ >>> frame.set_objects(objects)
+ >>> frame.get_objects()
+ [
+ {'_id_': 1, 'label': 'cat'},
+ {'_id_': 2, 'label': 'dog'}
+ ]
+ >>> frame.get_objects(obj_filter=lambda x:x['label'] == 'cat')
+ [{'_id_': 1, 'label': 'cat'}]
+ """
+
+ objects = [
+ obj.copy()
+ for obj in filter(obj_filter, self.data['objects'].values())
+ ]
+
+ return objects
diff --git a/mmpose/apis/webcam/utils/misc.py b/mmpose/apis/webcam/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c6f5417aeeceb40fbe92ddb0aa8e27a9a0faadc
--- /dev/null
+++ b/mmpose/apis/webcam/utils/misc.py
@@ -0,0 +1,367 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import importlib
+import os
+import os.path as osp
+import sys
+import time
+from contextlib import contextmanager
+from typing import List, Optional, Tuple
+from urllib.parse import urlparse
+from urllib.request import urlopen
+
+import cv2
+import numpy as np
+from mmengine import mkdir_or_exist
+from torch.hub import HASH_REGEX, download_url_to_file
+
+
+@contextmanager
+def limit_max_fps(fps: float):
+ """A context manager to limit maximum frequence of entering the context.
+
+ Args:
+ fps (float): The maximum frequency of entering the context
+
+ Example::
+ >>> from mmpose.apis.webcam.utils import limit_max_fps
+ >>> import cv2
+
+ >>> while True:
+ ... with limit_max_fps(20):
+ ... cv2.imshow('Demo', img) # display image at most 20 fps
+ """
+ t_start = time.time()
+ try:
+ yield
+ finally:
+ t_end = time.time()
+ if fps is not None:
+ t_sleep = 1.0 / fps - t_end + t_start
+ if t_sleep > 0:
+ time.sleep(t_sleep)
+
+
+def _is_url(filename: str) -> bool:
+ """Check if the file is a url link.
+
+ Args:
+ filename (str): the file name or url link
+
+ Returns:
+ bool: is url or not.
+ """
+ prefixes = ['http://', 'https://']
+ for p in prefixes:
+ if filename.startswith(p):
+ return True
+ return False
+
+
+def load_image_from_disk_or_url(filename: str,
+ readFlag: int = cv2.IMREAD_COLOR
+ ) -> np.ndarray:
+ """Load an image file, from disk or url.
+
+ Args:
+ filename (str): file name on the disk or url link
+ readFlag (int): readFlag for imdecode. Default: cv2.IMREAD_COLOR
+
+ Returns:
+ np.ndarray: A loaded image
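+
+ Example (an illustrative sketch; the local path and the URL below are
+ placeholders)::
+ >>> img = load_image_from_disk_or_url('path/to/image.png')
+ >>> img = load_image_from_disk_or_url(
+ ... 'https://example.com/image.png')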
+ """
+ if _is_url(filename):
+ # download the image, convert it to a NumPy array, and then read
+ # it into OpenCV format
+ resp = urlopen(filename)
+ image = np.asarray(bytearray(resp.read()), dtype='uint8')
+ image = cv2.imdecode(image, readFlag)
+ return image
+ else:
+ image = cv2.imread(filename, readFlag)
+ return image
+
+
+def get_cached_file_path(url: str,
+ save_dir: str,
+ progress: bool = True,
+ check_hash: bool = False,
+ file_name: Optional[str] = None) -> str:
+ r"""Loads the Torch serialized object at the given URL.
+
+ If downloaded file is a zip file, it will be automatically decompressed
+
+ If the object is already present in `model_dir`, it's deserialized and
+ returned.
+ The default value of ``model_dir`` is ``/checkpoints`` where
+ ``hub_dir`` is the directory returned by :func:`~torch.hub.get_dir`.
+
+ Args:
+ url (str): URL of the object to download
+ save_dir (str): directory in which to save the object
+ progress (bool): whether or not to display a progress bar
+ to stderr. Default: ``True``
+ check_hash (bool): If ``True``, the filename part of the URL
+ should follow the naming convention ``filename-<sha256>.ext``,
+ where ``<sha256>`` is the first eight or more digits of the
+ SHA256 hash of the contents of the file. The hash is used to
+ ensure unique names and to verify the contents of the file.
+ Default: ``False``
+ file_name (str, optional): name for the downloaded file. Filename
+ from ``url`` will be used if not set. Default: ``None``.
+
+ Returns:
+ str: The path to the cached file.
+ """
+
+ mkdir_or_exist(save_dir)
+
+ parts = urlparse(url)
+ filename = os.path.basename(parts.path)
+ if file_name is not None:
+ filename = file_name
+ cached_file = os.path.join(save_dir, filename)
+ if not os.path.exists(cached_file):
+ sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
+ hash_prefix = None
+ if check_hash:
+ r = HASH_REGEX.search(filename) # r is Optional[Match[str]]
+ hash_prefix = r.group(1) if r else None
+ download_url_to_file(url, cached_file, hash_prefix, progress=progress)
+ return cached_file
+
+
+def screen_matting(img: np.ndarray,
+ color_low: Optional[Tuple] = None,
+ color_high: Optional[Tuple] = None,
+ color: Optional[str] = None) -> np.ndarray:
+ """Get screen matting mask.
+
+ Args:
+ img (np.ndarray): Image data.
+ color_low (tuple): Lower limit (b, g, r).
+ color_high (tuple): Upper limit (b, g, r).
+ color (str): Support colors include:
+
+ - 'green' or 'g'
+ - 'blue' or 'b'
+ - 'black' or 'k'
+ - 'white' or 'w'
+
+ Returns:
+ np.ndarray: A mask with the same shape of the input image. The value
+ is 0 at the pixels in the matting color range, and 1 everywhere else.
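+
+ Example (an illustrative sketch using a dummy pure-green image)::
+ >>> import numpy as np
+ >>> img = np.zeros((4, 4, 3), dtype=np.uint8)
+ >>> img[..., 1] = 255 # pure green in BGR
+ >>> mask = screen_matting(img, color='green')
+ >>> int(mask.sum()) # every pixel falls in the green range
+ 0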
+ """
+
+ if color_high is None or color_low is None:
+ if color is not None:
+ if color.lower() == 'g' or color.lower() == 'green':
+ color_low = (0, 200, 0)
+ color_high = (60, 255, 60)
+ elif color.lower() == 'b' or color.lower() == 'blue':
+ color_low = (230, 0, 0)
+ color_high = (255, 40, 40)
+ elif color.lower() == 'k' or color.lower() == 'black':
+ color_low = (0, 0, 0)
+ color_high = (40, 40, 40)
+ elif color.lower() == 'w' or color.lower() == 'white':
+ color_low = (230, 230, 230)
+ color_high = (255, 255, 255)
+ else:
+ raise NotImplementedError(f'Not supported color: {color}.')
+ else:
+ raise ValueError(
+ 'color or color_high | color_low should be given.')
+
+ mask = cv2.inRange(img, np.array(color_low), np.array(color_high)) == 0
+
+ return mask.astype(np.uint8)
+
+
+def expand_and_clamp(box: List, im_shape: Tuple, scale: float = 1.25) -> List:
+ """Expand the bbox and clip it to fit the image shape.
+
+ Args:
+ box (list): x1, y1, x2, y2
+ im_shape (tuple): image shape (h, w, c)
+ scale (float): expand ratio
+
+ Returns:
+ list: x1, y1, x2, y2
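+
+ Example (a small numeric sketch with the default ``scale=1.25``)::
+ >>> expand_and_clamp([10, 10, 20, 20], (100, 100, 3))
+ [8, 8, 21, 21]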
+ """
+
+ x1, y1, x2, y2 = box[:4]
+ w = x2 - x1
+ h = y2 - y1
+ delta_w = w * (scale - 1) / 2
+ delta_h = h * (scale - 1) / 2
+
+ x1, y1, x2, y2 = x1 - delta_w, y1 - delta_h, x2 + delta_w, y2 + delta_h
+
+ img_h, img_w = im_shape[:2]
+
+ x1 = min(max(0, int(x1)), img_w - 1)
+ y1 = min(max(0, int(y1)), img_h - 1)
+ x2 = min(max(0, int(x2)), img_w - 1)
+ y2 = min(max(0, int(y2)), img_h - 1)
+
+ return [x1, y1, x2, y2]
+
+
+def _find_bbox(mask):
+ """Find the bounding box for the mask.
+
+ Args:
+ mask (ndarray): Mask.
+
+ Returns:
+ list[int]: The bounding box [x1, y1, x2, y2].
+ """
+ mask_shape = mask.shape
+ if len(mask_shape) == 3:
+ assert mask_shape[-1] == 1, 'the channel of the mask should be 1.'
+ elif len(mask_shape) == 2:
+ pass
+ else:
+ raise NotImplementedError()
+
+ h, w = mask_shape[:2]
+ mask_w = mask.sum(0)
+ mask_h = mask.sum(1)
+
+ left = 0
+ right = w - 1
+ up = 0
+ down = h - 1
+
+ for i in range(w):
+ if mask_w[i] > 0:
+ break
+ left += 1
+
+ for i in range(w - 1, left, -1):
+ if mask_w[i] > 0:
+ break
+ right -= 1
+
+ for i in range(h):
+ if mask_h[i] > 0:
+ break
+ up += 1
+
+ for i in range(h - 1, up, -1):
+ if mask_h[i] > 0:
+ break
+ down -= 1
+
+ return [left, up, right, down]
+
+
+def copy_and_paste(
+ img: np.ndarray,
+ background_img: np.ndarray,
+ mask: np.ndarray,
+ bbox: Optional[List] = None,
+ effect_region: Tuple = (0.2, 0.2, 0.8, 0.8),
+ min_size: Tuple = (20, 20)
+) -> np.ndarray:
+ """Copy the image region and paste to the background.
+
+ Args:
+ img (np.ndarray): Image data.
+ background_img (np.ndarray): Background image data.
+ mask (ndarray): instance segmentation result.
+ bbox (list, optional): instance bbox in (x1, y1, x2, y2). If not
+ given, the bbox will be obtained by ``_find_bbox()``. Default:
+ ``None``
+ effect_region (tuple): The region to apply mask, the coordinates
+ are normalized (x1, y1, x2, y2). Default: (0.2, 0.2, 0.8, 0.8)
+ min_size (tuple): The minimum region size (w, h) in pixels.
+ Default: (20, 20)
+
+ Returns:
+ np.ndarray: The background with pasted image region.
+ """
+ background_img = background_img.copy()
+ background_h, background_w = background_img.shape[:2]
+ region_h = (effect_region[3] - effect_region[1]) * background_h
+ region_w = (effect_region[2] - effect_region[0]) * background_w
+ region_aspect_ratio = region_w / region_h
+
+ if bbox is None:
+ bbox = _find_bbox(mask)
+ instance_w = bbox[2] - bbox[0]
+ instance_h = bbox[3] - bbox[1]
+
+ if instance_w > min_size[0] and instance_h > min_size[1]:
+ aspect_ratio = instance_w / instance_h
+ if region_aspect_ratio > aspect_ratio:
+ resize_rate = region_h / instance_h
+ else:
+ resize_rate = region_w / instance_w
+
+ mask_inst = mask[int(bbox[1]):int(bbox[3]), int(bbox[0]):int(bbox[2])]
+ img_inst = img[int(bbox[1]):int(bbox[3]), int(bbox[0]):int(bbox[2])]
+ img_inst = cv2.resize(
+ img_inst.astype('float32'),
+ (int(resize_rate * instance_w), int(resize_rate * instance_h)))
+ img_inst = img_inst.astype(background_img.dtype)
+ mask_inst = cv2.resize(
+ mask_inst.astype('float32'),
+ (int(resize_rate * instance_w), int(resize_rate * instance_h)),
+ interpolation=cv2.INTER_NEAREST)
+
+ mask_ids = list(np.where(mask_inst == 1))
+ mask_ids[1] += int(effect_region[0] * background_w)
+ mask_ids[0] += int(effect_region[1] * background_h)
+
+ background_img[tuple(mask_ids)] = img_inst[np.where(mask_inst == 1)]
+
+ return background_img
+
+
+def is_image_file(path: str) -> bool:
+ """Check if a path is an image file by its extension.
+
+ Args:
+ path (str): The image path.
+
+ Returns:
+ bool: Whether the path is an image file.
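+
+ Example (file names are for illustration only)::
+ >>> is_image_file('demo.jpg')
+ True
+ >>> is_image_file('demo.mp4')
+ False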
+ """
+ if isinstance(path, str):
+ if path.lower().endswith(('.png', '.jpg', '.jpeg', '.tiff', '.bmp')):
+ return True
+ return False
+
+
+def get_config_path(path: str, module_name: str):
+ """Get config path from an OpenMMLab codebase.
+
+ If the path is an existing file, it will be directly returned. If the file
+ doesn't exist, it will be searched in the 'configs' folder of the
+ specified module.
+
+ Args:
+ path (str): the path of the config file
+ module_name (str): The module name of an OpenMMLab codebase
+
+ Returns:
+ str: The config file path.
+
+ Example::
+ >>> path = 'configs/_base_/filters/one_euro.py'
+ >>> get_config_path(path, 'mmpose')
+ '/home/mmpose/configs/_base_/filters/one_euro.py'
+ """
+
+ if osp.isfile(path):
+ return path
+
+ module = importlib.import_module(module_name)
+ module_dir = osp.dirname(module.__file__)
+ path_in_module = osp.join(module_dir, '.mim', path)
+
+ if not osp.isfile(path_in_module):
+ raise FileNotFoundError(f'Can not find the config file "{path}"')
+
+ return path_in_module
diff --git a/mmpose/apis/webcam/utils/pose.py b/mmpose/apis/webcam/utils/pose.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ff32f9e169ad4070f7ffb8d1f2664dfe50b8013
--- /dev/null
+++ b/mmpose/apis/webcam/utils/pose.py
@@ -0,0 +1,181 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, List, Tuple
+
+
+def get_eye_keypoint_ids(dataset_meta: Dict) -> Tuple[int, int]:
+ """A helper function to get the keypoint indices of left and right eyes
+ from the dataset meta information.
+
+ Args:
+ dataset_meta (dict): dataset meta information.
+
+ Returns:
+ tuple[int, int]: The keypoint indices of left eye and right eye.
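+
+ Example (a minimal sketch using only the dataset name)::
+ >>> get_eye_keypoint_ids(dict(dataset_name='coco'))
+ (1, 2)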
+ """
+ left_eye_idx = None
+ right_eye_idx = None
+
+ # try obtaining eye point ids from dataset_meta
+ keypoint_name2id = dataset_meta.get('keypoint_name2id', {})
+ left_eye_idx = keypoint_name2id.get('left_eye', None)
+ right_eye_idx = keypoint_name2id.get('right_eye', None)
+
+ if left_eye_idx is None or right_eye_idx is None:
+ # Fall back to hard coded keypoint id
+ dataset_name = dataset_meta.get('dataset_name', 'unknown dataset')
+ if dataset_name in {'coco', 'coco_wholebody'}:
+ left_eye_idx = 1
+ right_eye_idx = 2
+ elif dataset_name in {'animalpose', 'ap10k'}:
+ left_eye_idx = 0
+ right_eye_idx = 1
+ else:
+ raise ValueError('Can not determine the eye keypoint id of '
+ f'{dataset_name}')
+
+ return left_eye_idx, right_eye_idx
+
+
+def get_face_keypoint_ids(dataset_meta: Dict) -> List:
+ """A helper function to get the keypoint indices of the face from the
+ dataset meta information.
+
+ Args:
+ dataset_meta (dict): dataset meta information.
+
+ Returns:
+ list[int]: face keypoint indices. The length depends on the dataset.
+ """
+ face_indices = []
+
+ # try obtaining nose point ids from dataset_meta
+ keypoint_name2id = dataset_meta.get('keypoint_name2id', {})
+ for id in range(68):
+ face_indices.append(keypoint_name2id.get(f'face-{id}', None))
+
+ if None in face_indices:
+ # Fall back to hard coded keypoint id
+ dataset_name = dataset_meta.get('dataset_name', 'unknown dataset')
+ if dataset_name in {'coco_wholebody'}:
+ face_indices = list(range(23, 91))
+ else:
+ raise ValueError('Can not determine the face id of '
+ f'{dataset_name}')
+
+ return face_indices
+
+
+def get_wrist_keypoint_ids(dataset_meta: Dict) -> Tuple[int, int]:
+ """A helper function to get the keypoint indices of left and right wrists
+ from the dataset meta information.
+
+ Args:
+ dataset_meta (dict): dataset meta information.
+ Returns:
+ tuple[int, int]: The keypoint indices of left and right wrists.
+ """
+
+ # try obtaining wrist point ids from dataset_meta
+ keypoint_name2id = dataset_meta.get('keypoint_name2id', {})
+ left_wrist_idx = keypoint_name2id.get('left_wrist', None)
+ right_wrist_idx = keypoint_name2id.get('right_wrist', None)
+
+ if left_wrist_idx is None or right_wrist_idx is None:
+ # Fall back to hard coded keypoint id
+ dataset_name = dataset_meta.get('dataset_name', 'unknown dataset')
+ if dataset_name in {'coco', 'coco_wholebody'}:
+ left_wrist_idx = 9
+ right_wrist_idx = 10
+ elif dataset_name == 'animalpose':
+ left_wrist_idx = 16
+ right_wrist_idx = 17
+ elif dataset_name == 'ap10k':
+ left_wrist_idx = 7
+ right_wrist_idx = 10
+ else:
+ raise ValueError('Can not determine the wrist keypoint id of '
+ f'{dataset_name}')
+
+ return left_wrist_idx, right_wrist_idx
+
+
+def get_mouth_keypoint_ids(dataset_meta: Dict) -> int:
+ """A helper function to get the mouth keypoint index from the dataset meta
+ information.
+
+ Args:
+ dataset_meta (dict): dataset meta information.
+ Returns:
+ int: The mouth keypoint index
+ """
+ # try obtaining mouth point ids from dataset_info
+ keypoint_name2id = dataset_meta.get('keypoint_name2id', {})
+ mouth_index = keypoint_name2id.get('face-62', None)
+
+ if mouth_index is None:
+ # Fall back to hard coded keypoint id
+ dataset_name = dataset_meta.get('dataset_name', 'unknown dataset')
+ if dataset_name == 'coco_wholebody':
+ mouth_index = 85
+ else:
+ raise ValueError('Can not determine the mouth keypoint id of '
+ f'{dataset_name}')
+
+ return mouth_index
+
+
+def get_hand_keypoint_ids(dataset_meta: Dict) -> List[int]:
+ """A helper function to get the keypoint indices of left and right hand
+ from the dataset meta information.
+
+ Args:
+ dataset_meta (dict): dataset meta information.
+ Returns:
+ list[int]: hand keypoint indices. The length depends on the dataset.
+ """
+ # try obtaining hand keypoint ids from dataset_meta
+ keypoint_name2id = dataset_meta.get('keypoint_name2id', {})
+ hand_indices = []
+ hand_indices.append(keypoint_name2id.get('left_hand_root', None))
+
+ for id in range(1, 5):
+ hand_indices.append(keypoint_name2id.get(f'left_thumb{id}', None))
+ for id in range(1, 5):
+ hand_indices.append(keypoint_name2id.get(f'left_forefinger{id}', None))
+ for id in range(1, 5):
+ hand_indices.append(
+ keypoint_name2id.get(f'left_middle_finger{id}', None))
+ for id in range(1, 5):
+ hand_indices.append(
+ keypoint_name2id.get(f'left_ring_finger{id}', None))
+ for id in range(1, 5):
+ hand_indices.append(
+ keypoint_name2id.get(f'left_pinky_finger{id}', None))
+
+ hand_indices.append(keypoint_name2id.get('right_hand_root', None))
+
+ for id in range(1, 5):
+ hand_indices.append(keypoint_name2id.get(f'right_thumb{id}', None))
+ for id in range(1, 5):
+ hand_indices.append(
+ keypoint_name2id.get(f'right_forefinger{id}', None))
+ for id in range(1, 5):
+ hand_indices.append(
+ keypoint_name2id.get(f'right_middle_finger{id}', None))
+ for id in range(1, 5):
+ hand_indices.append(
+ keypoint_name2id.get(f'right_ring_finger{id}', None))
+ for id in range(1, 5):
+ hand_indices.append(
+ keypoint_name2id.get(f'right_pinky_finger{id}', None))
+
+ if None in hand_indices:
+ # Fall back to hard coded keypoint id
+ dataset_name = dataset_meta.get('dataset_name', 'unknown dataset')
+ if dataset_name in {'coco_wholebody'}:
+ hand_indices = list(range(91, 133))
+ else:
+ raise ValueError('Can not determine the hand id of '
+ f'{dataset_name}')
+
+ return hand_indices
diff --git a/mmpose/apis/webcam/webcam_executor.py b/mmpose/apis/webcam/webcam_executor.py
new file mode 100644
index 0000000000000000000000000000000000000000..f39aa4b84710c69064fe09e8cf0a3e4db8b006a2
--- /dev/null
+++ b/mmpose/apis/webcam/webcam_executor.py
@@ -0,0 +1,329 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import logging
+import sys
+import time
+import warnings
+from threading import Thread
+from typing import Dict, List, Optional, Tuple, Union
+
+import cv2
+
+from .nodes import NODES
+from .utils import (BufferManager, EventManager, FrameMessage, ImageCapture,
+ VideoEndingMessage, is_image_file, limit_max_fps)
+
+try:
+ from contextlib import nullcontext
+except ImportError:
+ # compatible with python3.6
+ from contextlib import contextmanager
+
+ @contextmanager
+ def nullcontext(enter_result=None):
+ yield enter_result
+
+
+DEFAULT_FRAME_BUFFER_SIZE = 1
+DEFAULT_INPUT_BUFFER_SIZE = 1
+DEFAULT_DISPLAY_BUFFER_SIZE = 0
+DEFAULT_USER_BUFFER_SIZE = 1
+
+logger = logging.getLogger('Executor')
+
+
+class WebcamExecutor():
+ """The interface to build and execute webcam applications from configs.
+
+ Parameters:
+ nodes (list[dict]): Node configs. See :class:`webcam.nodes.Node` for
+ details
+ name (str): Executor name. Default: 'MMPose Webcam App'.
+ camera_id (int | str): The camera ID (the default camera usually has
+ ID 0). Alternatively, a file path or a URL can be given to load
+ frames from a video or image file.
+ camera_frame_shape (tuple, optional): Set the frame shape of the
+ camera in (width, height). If not given, the default frame shape
+ will be used. This argument is only valid when using a camera
+ as the input source. Default: ``None``
+ camera_max_fps (int): Video reading maximum FPS. Default: 30
+ synchronous (bool): If ``True``, read a new frame only after the
+ previous one has been processed and displayed. Default: ``False``
+ buffer_sizes (dict, optional): A dict to specify buffer sizes. The
+ key is the buffer name and the value is the buffer size.
+ Default: ``None``
+
+ Example::
+ >>> cfg = dict(
+ >>> name='Test Webcam',
+ >>> camera_id=0,
+ >>> camera_max_fps=30,
+ >>> nodes=[
+ >>> dict(
+ >>> type='MonitorNode',
+ >>> name='monitor',
+ >>> enable_key='m',
+ >>> enable=False,
+ >>> input_buffer='_frame_',
+ >>> output_buffer='display'),
+ >>> dict(
+ >>> type='RecorderNode',
+ >>> name='recorder',
+ >>> out_video_file='webcam_output.mp4',
+ >>> input_buffer='display',
+ >>> output_buffer='_display_')
+ >>> ])
+
+ >>> executor = WebcamExecutor(**cfg)
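+
+ >>> # Calling ``run()`` (commented out here so the snippet stays
+ >>> # illustrative) would start all nodes and the video I/O, and
+ >>> # block until an exit event is triggered:
+ >>> # executor.run()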
+ """
+
+ def __init__(self,
+ nodes: List[Dict],
+ name: str = 'MMPose Webcam App',
+ camera_id: Union[int, str] = 0,
+ camera_max_fps: int = 30,
+ camera_frame_shape: Optional[Tuple[int, int]] = None,
+ synchronous: bool = False,
+ buffer_sizes: Optional[Dict[str, int]] = None):
+
+ # Basic parameters
+ self.name = name
+ self.camera_id = camera_id
+ self.camera_max_fps = camera_max_fps
+ self.camera_frame_shape = camera_frame_shape
+ self.synchronous = synchronous
+
+ # self.buffer_manager manages data flow between executor and nodes
+ self.buffer_manager = BufferManager()
+ # self.event_manager manages event-based asynchronous communication
+ self.event_manager = EventManager()
+ # self.node_list holds all node instances
+ self.node_list = []
+ # self.vcap is used to read camera frames. It will be built when the
+ # executor starts running
+ self.vcap = None
+
+ # Register executor events
+ self.event_manager.register_event('_exit_', is_keyboard=False)
+ if self.synchronous:
+ self.event_manager.register_event('_idle_', is_keyboard=False)
+
+ # Register nodes
+ if not nodes:
+ raise ValueError('No node is registered to the executor.')
+
+ # Register default buffers
+ if buffer_sizes is None:
+ buffer_sizes = {}
+ # _frame_ buffer
+ frame_buffer_size = buffer_sizes.get('_frame_',
+ DEFAULT_FRAME_BUFFER_SIZE)
+ self.buffer_manager.register_buffer('_frame_', frame_buffer_size)
+ # _input_ buffer
+ input_buffer_size = buffer_sizes.get('_input_',
+ DEFAULT_INPUT_BUFFER_SIZE)
+ self.buffer_manager.register_buffer('_input_', input_buffer_size)
+ # _display_ buffer
+ display_buffer_size = buffer_sizes.get('_display_',
+ DEFAULT_DISPLAY_BUFFER_SIZE)
+ self.buffer_manager.register_buffer('_display_', display_buffer_size)
+
+ # Build all nodes:
+ for node_cfg in nodes:
+ logger.info(f'Create node: {node_cfg.name}({node_cfg.type})')
+ node = NODES.build(node_cfg)
+
+ # Register node
+ self.node_list.append(node)
+
+ # Register buffers
+ for buffer_info in node.registered_buffers:
+ buffer_name = buffer_info.buffer_name
+ if buffer_name in self.buffer_manager:
+ continue
+ buffer_size = buffer_sizes.get(buffer_name,
+ DEFAULT_USER_BUFFER_SIZE)
+ self.buffer_manager.register_buffer(buffer_name, buffer_size)
+ logger.info(
+ f'Register user buffer: {buffer_name}({buffer_size})')
+
+ # Register events
+ for event_info in node.registered_events:
+ self.event_manager.register_event(
+ event_name=event_info.event_name,
+ is_keyboard=event_info.is_keyboard)
+ logger.info(f'Register event: {event_info.event_name}')
+
+ # Set executor for nodes
+ # This step is performed after all nodes have been built, so that the
+ # executor has created the full buffer/event managers for the nodes to use
+ for node in self.node_list:
+ logger.info(f'Set executor for node: {node.name}')
+ node.set_executor(self)
+
+ def _read_camera(self):
+ """Read video frames from the caemra (or the source video/image) and
+ put them into input buffers."""
+
+ camera_id = self.camera_id
+ fps = self.camera_max_fps
+
+ # Build video capture
+ if is_image_file(camera_id):
+ self.vcap = ImageCapture(camera_id)
+ else:
+ self.vcap = cv2.VideoCapture(camera_id)
+ if self.camera_frame_shape is not None:
+ width, height = self.camera_frame_shape
+ self.vcap.set(cv2.CAP_PROP_FRAME_WIDTH, width)
+ self.vcap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
+
+ if not self.vcap.isOpened():
+ warnings.warn(f'Cannot open camera (ID={camera_id})')
+ sys.exit()
+
+ # Read video frames in a loop
+ first_frame = True
+ while not self.event_manager.is_set('_exit_'):
+ if self.synchronous:
+ if first_frame:
+ cm = nullcontext()
+ else:
+ # Wait until the last frame has been processed before reading a new one
+ cm = self.event_manager.wait_and_handle('_idle_')
+ else:
+ # Read frames with a maximum FPS
+ cm = limit_max_fps(fps)
+
+ first_frame = False
+
+ with cm:
+ # Read a frame
+ ret_val, frame = self.vcap.read()
+ if ret_val:
+ # Put frame message (for display) into buffer `_frame_`
+ frame_msg = FrameMessage(frame)
+ self.buffer_manager.put('_frame_', frame_msg)
+
+ # Put input message (for model inference or other use)
+ # into buffer `_input_`
+ input_msg = FrameMessage(frame.copy())
+ input_msg.update_route_info(
+ node_name='Camera Info',
+ node_type='none',
+ info=self._get_camera_info())
+ self.buffer_manager.put_force('_input_', input_msg)
+ logger.info('Read one frame.')
+ else:
+ logger.info('Reached the end of the video.')
+ # Put a video ending signal
+ self.buffer_manager.put_force('_frame_',
+ VideoEndingMessage())
+ self.buffer_manager.put_force('_input_',
+ VideoEndingMessage())
+ # Wait for `_exit_` event until a timeout occurs
+ if not self.event_manager.wait('_exit_', timeout=5.0):
+ break
+
+ self.vcap.release()
+
+ def _display(self):
+ """Receive processed frames from the output buffer and display on
+ screen."""
+
+ output_msg = None
+
+ while not self.event_manager.is_set('_exit_'):
+ while self.buffer_manager.is_empty('_display_'):
+ time.sleep(0.001)
+
+ # Set _idle_ to allow reading next frame
+ if self.synchronous:
+ self.event_manager.set('_idle_')
+
+ # acquire output from buffer
+ output_msg = self.buffer_manager.get('_display_')
+
+ # A VideoEndingMessage indicates the input stream has ended
+ if isinstance(output_msg, VideoEndingMessage):
+ self.event_manager.set('_exit_')
+ break
+
+ img = output_msg.get_image()
+
+ # show in a window
+ cv2.imshow(self.name, img)
+
+ # handle keyboard input
+ key = cv2.waitKey(1)
+ if key != -1:
+ self._on_keyboard_input(key)
+
+ cv2.destroyAllWindows()
+
+ # Avoid deadlock
+ if self.synchronous:
+ self.event_manager.set('_idle_')
+
+ def _on_keyboard_input(self, key):
+ """Handle the keyboard input.
+
+ The keys 'q'/'Q' and `ESC` will trigger an '_exit_' event, to which all
+ nodes and the executor itself respond by exiting. Other keys will
+ trigger keyboard events that are handled by the nodes which have
+ registered the corresponding event. See
+ :class:`webcam.utils.EventManager` for details.
+ """
+
+ if key in (27, ord('q'), ord('Q')):
+ logger.info(f'Exit event captured: {key}')
+ self.event_manager.set('_exit_')
+ else:
+ logger.info(f'Keyboard event captured: {key}')
+ self.event_manager.set(key, is_keyboard=True)
+
+ def _get_camera_info(self):
+ """Return the camera information in a dict."""
+
+ frame_width = self.vcap.get(cv2.CAP_PROP_FRAME_WIDTH)
+ frame_height = self.vcap.get(cv2.CAP_PROP_FRAME_HEIGHT)
+ frame_rate = self.vcap.get(cv2.CAP_PROP_FPS)
+
+ cam_info = {
+ 'Camera ID': self.camera_id,
+ 'Camera resolution': f'{frame_width}x{frame_height}',
+ 'Camera FPS': frame_rate,
+ }
+
+ return cam_info
+
+ def run(self):
+ """Start the executor.
+
+ This method starts all nodes as well as video I/O in separate threads.
+ """
+
+ try:
+ # Start node threads
+ non_daemon_nodes = []
+ for node in self.node_list:
+ node.start()
+ if not node.daemon:
+ non_daemon_nodes.append(node)
+
+ # Create a thread to read video frames
+ t_read = Thread(target=self._read_camera, args=())
+ t_read.start()
+
+ # Run display in the main thread
+ self._display()
+ logger.info('Display has stopped.')
+
+ # Join the camera-reading thread and non-daemon node threads
+ logger.info('Camera reading is about to join.')
+ t_read.join()
+
+ for node in non_daemon_nodes:
+ logger.info(f'Node {node.name} is about to join.')
+ node.join()
+ logger.info('All nodes joined successfully.')
+
+ except KeyboardInterrupt:
+ pass
diff --git a/mmpose/codecs/__init__.py b/mmpose/codecs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a88ebac70169321c248a70346b1db75c7908e58b
--- /dev/null
+++ b/mmpose/codecs/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .associative_embedding import AssociativeEmbedding
+from .decoupled_heatmap import DecoupledHeatmap
+from .integral_regression_label import IntegralRegressionLabel
+from .megvii_heatmap import MegviiHeatmap
+from .msra_heatmap import MSRAHeatmap
+from .regression_label import RegressionLabel
+from .simcc_label import SimCCLabel
+from .spr import SPR
+from .udp_heatmap import UDPHeatmap
+
+__all__ = [
+ 'MSRAHeatmap', 'MegviiHeatmap', 'UDPHeatmap', 'RegressionLabel',
+ 'SimCCLabel', 'IntegralRegressionLabel', 'AssociativeEmbedding', 'SPR',
+ 'DecoupledHeatmap'
+]
diff --git a/mmpose/codecs/__pycache__/__init__.cpython-38.pyc b/mmpose/codecs/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ae501df82620881a4b9e83192275219f68e9762d
Binary files /dev/null and b/mmpose/codecs/__pycache__/__init__.cpython-38.pyc differ
diff --git a/mmpose/codecs/__pycache__/associative_embedding.cpython-38.pyc b/mmpose/codecs/__pycache__/associative_embedding.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e407ecb8baf13482acb72c057037f24eee344c34
Binary files /dev/null and b/mmpose/codecs/__pycache__/associative_embedding.cpython-38.pyc differ
diff --git a/mmpose/codecs/__pycache__/base.cpython-38.pyc b/mmpose/codecs/__pycache__/base.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ef07cf8a623136b02c2175bb4e01d33286dc98f6
Binary files /dev/null and b/mmpose/codecs/__pycache__/base.cpython-38.pyc differ
diff --git a/mmpose/codecs/__pycache__/decoupled_heatmap.cpython-38.pyc b/mmpose/codecs/__pycache__/decoupled_heatmap.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..094eb04b5366ee46d1d113f276cd89abee8c877c
Binary files /dev/null and b/mmpose/codecs/__pycache__/decoupled_heatmap.cpython-38.pyc differ
diff --git a/mmpose/codecs/__pycache__/integral_regression_label.cpython-38.pyc b/mmpose/codecs/__pycache__/integral_regression_label.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ea5d7c364b100dc30ccd0aad366f8868a6c97750
Binary files /dev/null and b/mmpose/codecs/__pycache__/integral_regression_label.cpython-38.pyc differ
diff --git a/mmpose/codecs/__pycache__/megvii_heatmap.cpython-38.pyc b/mmpose/codecs/__pycache__/megvii_heatmap.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3128624daa262bd0a8ac6acad71d969b0ceb4978
Binary files /dev/null and b/mmpose/codecs/__pycache__/megvii_heatmap.cpython-38.pyc differ
diff --git a/mmpose/codecs/__pycache__/msra_heatmap.cpython-38.pyc b/mmpose/codecs/__pycache__/msra_heatmap.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bc63009b8f820cd14f0cdb09f5325b8d5b0da309
Binary files /dev/null and b/mmpose/codecs/__pycache__/msra_heatmap.cpython-38.pyc differ
diff --git a/mmpose/codecs/__pycache__/regression_label.cpython-38.pyc b/mmpose/codecs/__pycache__/regression_label.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ab1cd72c23e64fbff523ecb6491de1c6ebed1800
Binary files /dev/null and b/mmpose/codecs/__pycache__/regression_label.cpython-38.pyc differ
diff --git a/mmpose/codecs/__pycache__/simcc_label.cpython-38.pyc b/mmpose/codecs/__pycache__/simcc_label.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e5de9cb61615685f4f32d944eed93a5219c87a1a
Binary files /dev/null and b/mmpose/codecs/__pycache__/simcc_label.cpython-38.pyc differ
diff --git a/mmpose/codecs/__pycache__/spr.cpython-38.pyc b/mmpose/codecs/__pycache__/spr.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9ec5689cb7da25c6d34dafee423500d1f5da486f
Binary files /dev/null and b/mmpose/codecs/__pycache__/spr.cpython-38.pyc differ
diff --git a/mmpose/codecs/__pycache__/udp_heatmap.cpython-38.pyc b/mmpose/codecs/__pycache__/udp_heatmap.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bbc7726fc0f90d16bade148624b9cbdad70f0cb0
Binary files /dev/null and b/mmpose/codecs/__pycache__/udp_heatmap.cpython-38.pyc differ
diff --git a/mmpose/codecs/associative_embedding.py b/mmpose/codecs/associative_embedding.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e080f1657d17deabf6e44bea275432216d621e5
--- /dev/null
+++ b/mmpose/codecs/associative_embedding.py
@@ -0,0 +1,512 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from collections import namedtuple
+from itertools import product
+from typing import Any, List, Optional, Tuple
+
+import numpy as np
+import torch
+from munkres import Munkres
+from torch import Tensor
+
+from mmpose.registry import KEYPOINT_CODECS
+from mmpose.utils.tensor_utils import to_numpy
+from .base import BaseKeypointCodec
+from .utils import (batch_heatmap_nms, generate_gaussian_heatmaps,
+ generate_udp_gaussian_heatmaps, refine_keypoints,
+ refine_keypoints_dark_udp)
+
+
+def _group_keypoints_by_tags(vals: np.ndarray,
+ tags: np.ndarray,
+ locs: np.ndarray,
+ keypoint_order: List[int],
+ val_thr: float,
+ tag_thr: float = 1.0,
+ max_groups: Optional[int] = None) -> np.ndarray:
+ """Group the keypoints by tags using Munkres algorithm.
+
+ Note:
+
+ - keypoint number: K
+ - candidate number: M
+ - tag dimension: L
+ - coordinate dimension: D
+ - group number: G
+
+ Args:
+ vals (np.ndarray): The heatmap response values of keypoints in shape
+ (K, M)
+ tags (np.ndarray): The tags of the keypoint candidates in shape
+ (K, M, L)
+ locs (np.ndarray): The locations of the keypoint candidates in shape
+ (K, M, D)
+ keypoint_order (List[int]): The grouping order of the keypoints.
+ The grouping usually starts from keypoints around the head and
+ torso, and gradually moves out to the limbs
+ val_thr (float): The threshold of the keypoint response value
+ tag_thr (float): The maximum allowed tag distance when matching a
+ keypoint to a group. A keypoint with larger tag distance to any
+ of the existing groups will initialize a new group
+ max_groups (int, optional): The maximum group number. ``None`` means
+ no limitation. Defaults to ``None``
+
+ Returns:
+ np.ndarray: grouped keypoints in shape (G, K, D+1), where the last
+ dimension is the concatenated keypoint coordinates and scores.
+ """
+ K, M, D = locs.shape
+ assert vals.shape == tags.shape[:2] == (K, M)
+ assert len(keypoint_order) == K
+
+ # Build Munkres instance
+ munkres = Munkres()
+
+ # Build a group pool, each group contains the keypoints of an instance
+ groups = []
+
+ Group = namedtuple('Group', field_names=['kpts', 'scores', 'tag_list'])
+
+ def _init_group():
+ """Initialize a group, which is composed of the keypoints, keypoint
+ scores and the tag of each keypoint."""
+ _group = Group(
+ kpts=np.zeros((K, D), dtype=np.float32),
+ scores=np.zeros(K, dtype=np.float32),
+ tag_list=[])
+ return _group
+
+ for i in keypoint_order:
+ # Get all valid candidate of the i-th keypoints
+ valid = vals[i] > val_thr
+ if not valid.any():
+ continue
+
+ tags_i = tags[i, valid] # (M', L)
+ vals_i = vals[i, valid] # (M',)
+ locs_i = locs[i, valid] # (M', D)
+
+ if len(groups) == 0: # Initialize the group pool
+ for tag, val, loc in zip(tags_i, vals_i, locs_i):
+ group = _init_group()
+ group.kpts[i] = loc
+ group.scores[i] = val
+ group.tag_list.append(tag)
+
+ groups.append(group)
+
+ else: # Match keypoints to existing groups
+ groups = groups[:max_groups]
+ group_tags = [np.mean(g.tag_list, axis=0) for g in groups]
+
+ # Calculate distance matrix between group tags and tag candidates
+ # of the i-th keypoint
+ # Shape: (M', 1, L) , (1, G, L) -> (M', G, L)
+ diff = tags_i[:, None] - np.array(group_tags)[None]
+ dists = np.linalg.norm(diff, ord=2, axis=2)
+ num_kpts, num_groups = dists.shape[:2]
+
+ # Experimental cost function for keypoint-group matching
+ costs = np.round(dists) * 100 - vals_i[..., None]
+ if num_kpts > num_groups:
+ padding = np.full((num_kpts, num_kpts - num_groups),
+ 1e10,
+ dtype=np.float32)
+ costs = np.concatenate((costs, padding), axis=1)
+
+ # Match keypoints and groups by Munkres algorithm
+ matches = munkres.compute(costs)
+ for kpt_idx, group_idx in matches:
+ if group_idx < num_groups and dists[kpt_idx,
+ group_idx] < tag_thr:
+ # Add the keypoint to the matched group
+ group = groups[group_idx]
+ else:
+ # Initialize a new group with unmatched keypoint
+ group = _init_group()
+ groups.append(group)
+
+ group.kpts[i] = locs_i[kpt_idx]
+ group.scores[i] = vals_i[kpt_idx]
+ group.tag_list.append(tags_i[kpt_idx])
+
+ groups = groups[:max_groups]
+ if groups:
+ grouped_keypoints = np.stack(
+ [np.r_['1', g.kpts, g.scores[:, None]] for g in groups])
+ else:
+ grouped_keypoints = np.empty((0, K, D + 1))
+
+ return grouped_keypoints
+
+
+@KEYPOINT_CODECS.register_module()
+class AssociativeEmbedding(BaseKeypointCodec):
+ """Encode/decode keypoints with the method introduced in "Associative
+ Embedding". This is an asymmetric codec, where the keypoints are
+ represented as gaussian heatmaps and position indices during encoding, and
+ restored from predicted heatmaps and group tags.
+
+ See the paper `Associative Embedding: End-to-End Learning for Joint
+ Detection and Grouping`_ by Newell et al (2017) for details
+
+ Note:
+
+ - instance number: N
+ - keypoint number: K
+ - keypoint dimension: D
+ - embedding tag dimension: L
+ - image size: [w, h]
+ - heatmap size: [W, H]
+
+ Encoded:
+
+ - heatmaps (np.ndarray): The generated heatmap in shape (K, H, W)
+ where [W, H] is the `heatmap_size`
+ - keypoint_indices (np.ndarray): The keypoint position indices in shape
+ (N, K, 2). Each keypoint's index is [i, v], where i is the position
+ index in the heatmap (:math:`i=y*w+x`) and v is the visibility
+ - keypoint_weights (np.ndarray): The target weights in shape (N, K)
+
+ Args:
+ input_size (tuple): Image size in [w, h]
+ heatmap_size (tuple): Heatmap size in [W, H]
+ sigma (float): The sigma value of the Gaussian heatmap
+ use_udp (bool): Whether use unbiased data processing. See
+ `UDP (CVPR 2020)`_ for details. Defaults to ``False``
+ decode_keypoint_order (List[int]): The grouping order of the
+ keypoint indices. The grouping usually starts from keypoints
+ around the head and torso, and gradually moves out to the limbs
+ decode_keypoint_thr (float): The threshold of keypoint response value
+ in heatmaps. Defaults to 0.1
+ decode_tag_thr (float): The maximum allowed tag distance when matching
+ a keypoint to a group. A keypoint with larger tag distance to any
+ of the existing groups will initialize a new group. Defaults to
+ 1.0
+ decode_nms_kernel (int): The kernel size of the NMS during decoding,
+ which should be an odd integer. Defaults to 5
+ decode_gaussian_kernel (int): The kernel size of the Gaussian blur
+ during decoding, which should be an odd integer. It is only used
+ when ``self.use_udp==True``. Defaults to 3
+ decode_topk (int): The number of top-k candidates of each keypoint
+ that will be retrieved from the heatmaps during decoding. Defaults to
+ 20
+ decode_max_instances (int, optional): The maximum number of instances
+ to decode. ``None`` means no limitation to the instance number.
+ Defaults to ``None``
+
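+ Example (a minimal encoding sketch; the keypoints below are random
+ and for illustration only)::
+ >>> import numpy as np
+ >>> codec = AssociativeEmbedding(
+ ... input_size=(512, 512),
+ ... heatmap_size=(128, 128),
+ ... decode_keypoint_order=list(range(17)))
+ >>> keypoints = np.random.rand(2, 17, 2) * 512
+ >>> encoded = codec.encode(keypoints)
+ >>> encoded['heatmaps'].shape
+ (17, 128, 128)
+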
+ .. _`Associative Embedding: End-to-End Learning for Joint Detection and
+ Grouping`: https://arxiv.org/abs/1611.05424
+ .. _`UDP (CVPR 2020)`: https://arxiv.org/abs/1911.07524
+ """
+
+ def __init__(
+ self,
+ input_size: Tuple[int, int],
+ heatmap_size: Tuple[int, int],
+ sigma: Optional[float] = None,
+ use_udp: bool = False,
+ decode_keypoint_order: List[int] = [],
+ decode_nms_kernel: int = 5,
+ decode_gaussian_kernel: int = 3,
+ decode_keypoint_thr: float = 0.1,
+ decode_tag_thr: float = 1.0,
+ decode_topk: int = 20,
+ decode_max_instances: Optional[int] = None,
+ ) -> None:
+ super().__init__()
+ self.input_size = input_size
+ self.heatmap_size = heatmap_size
+ self.use_udp = use_udp
+ self.decode_nms_kernel = decode_nms_kernel
+ self.decode_gaussian_kernel = decode_gaussian_kernel
+ self.decode_keypoint_thr = decode_keypoint_thr
+ self.decode_tag_thr = decode_tag_thr
+ self.decode_topk = decode_topk
+ self.decode_max_instances = decode_max_instances
+ self.dedecode_keypoint_order = decode_keypoint_order.copy()
+
+ if self.use_udp:
+ self.scale_factor = ((np.array(input_size) - 1) /
+ (np.array(heatmap_size) - 1)).astype(
+ np.float32)
+ else:
+ self.scale_factor = (np.array(input_size) /
+ heatmap_size).astype(np.float32)
+
+ if sigma is None:
+ sigma = (heatmap_size[0] * heatmap_size[1])**0.5 / 64
+ self.sigma = sigma
+
+ def encode(
+ self,
+ keypoints: np.ndarray,
+ keypoints_visible: Optional[np.ndarray] = None
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+ """Encode keypoints into heatmaps and position indices. Note that the
+ original keypoint coordinates should be in the input image space.
+
+ Args:
+ keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
+ keypoints_visible (np.ndarray): Keypoint visibilities in shape
+ (N, K)
+
+ Returns:
+ dict:
+ - heatmaps (np.ndarray): The generated heatmap in shape
+ (K, H, W) where [W, H] is the `heatmap_size`
+ - keypoint_indices (np.ndarray): The keypoint position indices
+ in shape (N, K, 2). Each keypoint's index is [i, v], where i
+ is the position index in the heatmap (:math:`i=y*w+x`) and v
+ is the visibility
+ - keypoint_weights (np.ndarray): The target weights in shape
+ (N, K)
+ """
+
+ if keypoints_visible is None:
+ keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32)
+
+ # keypoint coordinates in heatmap
+ _keypoints = keypoints / self.scale_factor
+
+ if self.use_udp:
+ heatmaps, keypoint_weights = generate_udp_gaussian_heatmaps(
+ heatmap_size=self.heatmap_size,
+ keypoints=_keypoints,
+ keypoints_visible=keypoints_visible,
+ sigma=self.sigma)
+ else:
+ heatmaps, keypoint_weights = generate_gaussian_heatmaps(
+ heatmap_size=self.heatmap_size,
+ keypoints=_keypoints,
+ keypoints_visible=keypoints_visible,
+ sigma=self.sigma)
+
+ keypoint_indices = self._encode_keypoint_indices(
+ heatmap_size=self.heatmap_size,
+ keypoints=_keypoints,
+ keypoints_visible=keypoints_visible)
+
+ encoded = dict(
+ heatmaps=heatmaps,
+ keypoint_indices=keypoint_indices,
+ keypoint_weights=keypoint_weights)
+
+ return encoded
+
+ def _encode_keypoint_indices(self, heatmap_size: Tuple[int, int],
+ keypoints: np.ndarray,
+ keypoints_visible: np.ndarray) -> np.ndarray:
+ w, h = heatmap_size
+ N, K, _ = keypoints.shape
+ keypoint_indices = np.zeros((N, K, 2), dtype=np.int64)
+
+ for n, k in product(range(N), range(K)):
+ x, y = (keypoints[n, k] + 0.5).astype(np.int64)
+ index = y * w + x
+ vis = (keypoints_visible[n, k] > 0.5 and 0 <= x < w and 0 <= y < h)
+ keypoint_indices[n, k] = [index, vis]
+
+ return keypoint_indices
+
+ def decode(self, encoded: Any) -> Tuple[np.ndarray, np.ndarray]:
+ raise NotImplementedError()
+
+ def _get_batch_topk(self, batch_heatmaps: Tensor, batch_tags: Tensor,
+ k: int):
+ """Get top-k response values from the heatmaps and corresponding tag
+ values from the tagging heatmaps.
+
+ Args:
+ batch_heatmaps (Tensor): Keypoint detection heatmaps in shape
+ (B, K, H, W)
+ batch_tags (Tensor): Tagging heatmaps in shape (B, C, H, W), where
+ the tag dim C is 2*K when using flip testing, or K otherwise
+ k (int): The number of top responses to get
+
+ Returns:
+ tuple:
+ - topk_vals (Tensor): Top-k response values of each heatmap in
+ shape (B, K, Topk)
+ - topk_tags (Tensor): The corresponding embedding tags of the
+ top-k responses, in shape (B, K, Topk, L)
+ - topk_locs (Tensor): The location of the top-k responses in each
+ heatmap, in shape (B, K, Topk, 2) where last dimension
+ represents x and y coordinates
+ """
+ B, K, H, W = batch_heatmaps.shape
+ L = batch_tags.shape[1] // K
+
+ # shape of topk_val, top_indices: (B, K, TopK)
+ topk_vals, topk_indices = batch_heatmaps.flatten(-2, -1).topk(
+ k, dim=-1)
+
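+        # Gather the tag value at each top-k location, one gather per tag
+        # dimension (L = C // K, i.e. L=2 with flip testing, L=1 otherwise)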
+ topk_tags_per_kpts = [
+ torch.gather(_tag, dim=2, index=topk_indices)
+ for _tag in torch.unbind(batch_tags.view(B, L, K, H * W), dim=1)
+ ]
+
+ topk_tags = torch.stack(topk_tags_per_kpts, dim=-1) # (B, K, TopK, L)
+ topk_locs = torch.stack([topk_indices % W, topk_indices // W],
+ dim=-1) # (B, K, TopK, 2)
+
+ return topk_vals, topk_tags, topk_locs
+
+ def _group_keypoints(self, batch_vals: np.ndarray, batch_tags: np.ndarray,
+ batch_locs: np.ndarray):
+ """Group keypoints into groups (each represents an instance) by tags.
+
+ Args:
+ batch_vals (Tensor): Heatmap response values of keypoint
+ candidates in shape (B, K, Topk)
+ batch_tags (Tensor): Tags of keypoint candidates in shape
+ (B, K, Topk, L)
+ batch_locs (Tensor): Locations of keypoint candidates in shape
+ (B, K, Topk, 2)
+
+ Returns:
+ List[np.ndarray]: Grouping results of a batch, each element is a
+ np.ndarray (in shape [N, K, D+1]) that contains the groups
+ detected in an image, including both keypoint coordinates and
+ scores.
+ """
+
+ def _group_func(inputs: Tuple):
+ vals, tags, locs = inputs
+ return _group_keypoints_by_tags(
+ vals,
+ tags,
+ locs,
+                keypoint_order=self.decode_keypoint_order,
+ val_thr=self.decode_keypoint_thr,
+ tag_thr=self.decode_tag_thr,
+ max_groups=self.decode_max_instances)
+
+ _results = map(_group_func, zip(batch_vals, batch_tags, batch_locs))
+ results = list(_results)
+ return results
+
+ def _fill_missing_keypoints(self, keypoints: np.ndarray,
+ keypoint_scores: np.ndarray,
+ heatmaps: np.ndarray, tags: np.ndarray):
+ """Fill the missing keypoints in the initial predictions.
+
+ Args:
+ keypoints (np.ndarray): Keypoint predictions in shape (N, K, D)
+            keypoint_scores (np.ndarray): Keypoint score predictions in shape
+                (N, K), in which 0 means the corresponding keypoint is
+                missing in the initial prediction
+            heatmaps (np.ndarray): Heatmaps in shape (K, H, W)
+ tags (np.ndarray): Tagging heatmaps in shape (C, H, W) where
+ C=L*K
+
+ Returns:
+ tuple:
+ - keypoints (np.ndarray): Keypoint predictions with missing
+ ones filled
+ - keypoint_scores (np.ndarray): Keypoint score predictions with
+ missing ones filled
+ """
+
+ N, K = keypoints.shape[:2]
+ H, W = heatmaps.shape[1:]
+ L = tags.shape[0] // K
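+        # Split the tag channels by keypoint: keypoint_tags[k] is the (L, H, W)
+        # tag map of the k-th keypoint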
+ keypoint_tags = [tags[k::K] for k in range(K)]
+
+ for n in range(N):
+ # Calculate the instance tag (mean tag of detected keypoints)
+ _tag = []
+ for k in range(K):
+ if keypoint_scores[n, k] > 0:
+ x, y = keypoints[n, k, :2].astype(np.int64)
+ x = np.clip(x, 0, W - 1)
+ y = np.clip(y, 0, H - 1)
+ _tag.append(keypoint_tags[k][:, y, x])
+
+ tag = np.mean(_tag, axis=0)
+ tag = tag.reshape(L, 1, 1)
+ # Search maximum response of the missing keypoints
+ for k in range(K):
+ if keypoint_scores[n, k] > 0:
+ continue
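+                # The cost prefers locations whose tag matches the instance
+                # tag (rounded distance, weighted by 100) and, among those,
+                # the highest heatmap response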
+ dist_map = np.linalg.norm(
+ keypoint_tags[k] - tag, ord=2, axis=0)
+ cost_map = np.round(dist_map) * 100 - heatmaps[k] # H, W
+ y, x = np.unravel_index(np.argmin(cost_map), shape=(H, W))
+ keypoints[n, k] = [x, y]
+ keypoint_scores[n, k] = heatmaps[k, y, x]
+
+ return keypoints, keypoint_scores
+
+ def batch_decode(self, batch_heatmaps: Tensor, batch_tags: Tensor
+ ) -> Tuple[List[np.ndarray], List[np.ndarray]]:
+ """Decode the keypoint coordinates from a batch of heatmaps and tagging
+ heatmaps. The decoded keypoint coordinates are in the input image
+ space.
+
+ Args:
+ batch_heatmaps (Tensor): Keypoint detection heatmaps in shape
+ (B, K, H, W)
+ batch_tags (Tensor): Tagging heatmaps in shape (B, C, H, W), where
+ :math:`C=L*K`
+
+ Returns:
+ tuple:
+ - batch_keypoints (List[np.ndarray]): Decoded keypoint coordinates
+ of the batch, each is in shape (N, K, D)
+ - batch_scores (List[np.ndarray]): Decoded keypoint scores of the
+ batch, each is in shape (N, K). It usually represents the
+                confidence of the keypoint prediction
+ """
+ B, _, H, W = batch_heatmaps.shape
+ assert batch_tags.shape[0] == B and batch_tags.shape[2:4] == (H, W), (
+ f'Mismatched shapes of heatmap ({batch_heatmaps.shape}) and '
+ f'tagging map ({batch_tags.shape})')
+
+ # Heatmap NMS
+ batch_heatmaps = batch_heatmap_nms(batch_heatmaps,
+ self.decode_nms_kernel)
+
+        # Get top-k in each heatmap and convert to numpy
+ batch_topk_vals, batch_topk_tags, batch_topk_locs = to_numpy(
+ self._get_batch_topk(
+ batch_heatmaps, batch_tags, k=self.decode_topk))
+
+ # Group keypoint candidates into groups (instances)
+ batch_groups = self._group_keypoints(batch_topk_vals, batch_topk_tags,
+ batch_topk_locs)
+
+ # Convert to numpy
+ batch_heatmaps_np = to_numpy(batch_heatmaps)
+ batch_tags_np = to_numpy(batch_tags)
+
+ # Refine the keypoint prediction
+ batch_keypoints = []
+ batch_keypoint_scores = []
+ for i, (groups, heatmaps, tags) in enumerate(
+ zip(batch_groups, batch_heatmaps_np, batch_tags_np)):
+
+ keypoints, scores = groups[..., :-1], groups[..., -1]
+
+ if keypoints.size > 0:
+                # fill in the missing keypoints
+ keypoints, scores = self._fill_missing_keypoints(
+ keypoints, scores, heatmaps, tags)
+
+ # refine keypoint coordinates according to heatmap distribution
+ if self.use_udp:
+ keypoints = refine_keypoints_dark_udp(
+ keypoints,
+ heatmaps,
+ blur_kernel_size=self.decode_gaussian_kernel)
+ else:
+ keypoints = refine_keypoints(keypoints, heatmaps)
+
+ batch_keypoints.append(keypoints)
+ batch_keypoint_scores.append(scores)
+
+ # restore keypoint scale
+ batch_keypoints = [
+ kpts * self.scale_factor for kpts in batch_keypoints
+ ]
+
+ return batch_keypoints, batch_keypoint_scores
diff --git a/mmpose/codecs/base.py b/mmpose/codecs/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8479fdf1e1c5d51c0c0a2722c54b8a6c018c113
--- /dev/null
+++ b/mmpose/codecs/base.py
@@ -0,0 +1,77 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABCMeta, abstractmethod
+from typing import Any, List, Optional, Tuple
+
+import numpy as np
+from mmengine.utils import is_method_overridden
+
+
+class BaseKeypointCodec(metaclass=ABCMeta):
+ """The base class of the keypoint codec.
+
+    A keypoint codec is a module to encode keypoint coordinates into a
+    specific representation (e.g. heatmap) and vice versa. A subclass should
+    implement the methods :meth:`encode` and :meth:`decode`.
+ """
+
+    # Auxiliary keys of the encoding inputs that should be passed to the
+    # `encode` method, beyond the mandatory `keypoints` and
+    # `keypoints_visible` arguments.
+ auxiliary_encode_keys = set()
+
+ @abstractmethod
+ def encode(self,
+ keypoints: np.ndarray,
+ keypoints_visible: Optional[np.ndarray] = None) -> dict:
+ """Encode keypoints.
+
+ Note:
+
+ - instance number: N
+ - keypoint number: K
+ - keypoint dimension: D
+
+ Args:
+ keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
+ keypoints_visible (np.ndarray): Keypoint visibility in shape
+                (N, K)
+
+ Returns:
+ dict: Encoded items.
+ """
+
+ @abstractmethod
+ def decode(self, encoded: Any) -> Tuple[np.ndarray, np.ndarray]:
+ """Decode keypoints.
+
+ Args:
+ encoded (any): Encoded keypoint representation using the codec
+
+ Returns:
+ tuple:
+ - keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
+ - keypoints_visible (np.ndarray): Keypoint visibility in shape
+                (N, K)
+ """
+
+ def batch_decode(self, batch_encoded: Any
+ ) -> Tuple[List[np.ndarray], List[np.ndarray]]:
+ """Decode keypoints.
+
+ Args:
+ batch_encoded (any): A batch of encoded keypoint
+ representations
+
+ Returns:
+ tuple:
+            - batch_keypoints (List[np.ndarray]): Each element is keypoint
+                coordinates in shape (N, K, D)
+            - batch_keypoints_visible (List[np.ndarray]): Each element is
+                keypoint visibility in shape (N, K)
+ """
+ raise NotImplementedError()
+
+ @property
+ def support_batch_decoding(self) -> bool:
+        """Return whether the codec supports decoding from batch data."""
+ return is_method_overridden('batch_decode', BaseKeypointCodec,
+ self.__class__)
diff --git a/mmpose/codecs/decoupled_heatmap.py b/mmpose/codecs/decoupled_heatmap.py
new file mode 100644
index 0000000000000000000000000000000000000000..da38a4ce2c825f5f362b19ee09d2515720c1a1f3
--- /dev/null
+++ b/mmpose/codecs/decoupled_heatmap.py
@@ -0,0 +1,265 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import random
+from typing import Optional, Tuple
+
+import numpy as np
+
+from mmpose.registry import KEYPOINT_CODECS
+from .base import BaseKeypointCodec
+from .utils import (generate_gaussian_heatmaps, get_diagonal_lengths,
+ get_instance_bbox, get_instance_root)
+from .utils.post_processing import get_heatmap_maximum
+from .utils.refinement import refine_keypoints
+
+
+@KEYPOINT_CODECS.register_module()
+class DecoupledHeatmap(BaseKeypointCodec):
+ """Encode/decode keypoints with the method introduced in the paper CID.
+
+    See the paper `CID`_ (Contextual Instance Decoupling for Robust
+    Multi-Person Pose Estimation) by Wang et al (2022) for details.
+
+ Note:
+
+ - instance number: N
+ - keypoint number: K
+ - keypoint dimension: D
+ - image size: [w, h]
+ - heatmap size: [W, H]
+
+ Encoded:
+ - heatmaps (np.ndarray): The coupled heatmap in shape
+ (1+K, H, W) where [W, H] is the `heatmap_size`.
+ - instance_heatmaps (np.ndarray): The decoupled heatmap in shape
+ (M*K, H, W) where M is the number of instances.
+ - keypoint_weights (np.ndarray): The weight for heatmaps in shape
+ (M*K).
+ - instance_coords (np.ndarray): The coordinates of instance roots
+ in shape (M, 2)
+
+ Args:
+ input_size (tuple): Image size in [w, h]
+ heatmap_size (tuple): Heatmap size in [W, H]
+ root_type (str): The method to generate the instance root. Options
+ are:
+
+ - ``'kpt_center'``: Average coordinate of all visible keypoints.
+ - ``'bbox_center'``: Center point of bounding boxes outlined by
+ all visible keypoints.
+
+ Defaults to ``'kpt_center'``
+
+ heatmap_min_overlap (float): Minimum overlap rate among instances.
+ Used when calculating sigmas for instances. Defaults to 0.7
+ background_weight (float): Loss weight of background pixels.
+ Defaults to 0.1
+ encode_max_instances (int): The maximum number of instances
+ to encode for each sample. Defaults to 30
+
+ .. _`CID`: https://openaccess.thecvf.com/content/CVPR2022/html/Wang_
+ Contextual_Instance_Decoupling_for_Robust_Multi-Person_Pose_Estimation_
+ CVPR_2022_paper.html
+ """
+
+ # DecoupledHeatmap requires bounding boxes to determine the size of each
+ # instance, so that it can assign varying sigmas based on their size
+ auxiliary_encode_keys = {'bbox'}
+
+ def __init__(
+ self,
+ input_size: Tuple[int, int],
+ heatmap_size: Tuple[int, int],
+ root_type: str = 'kpt_center',
+ heatmap_min_overlap: float = 0.7,
+ encode_max_instances: int = 30,
+ ):
+ super().__init__()
+
+ self.input_size = input_size
+ self.heatmap_size = heatmap_size
+ self.root_type = root_type
+ self.encode_max_instances = encode_max_instances
+ self.heatmap_min_overlap = heatmap_min_overlap
+
+ self.scale_factor = (np.array(input_size) /
+ heatmap_size).astype(np.float32)
+
+ def _get_instance_wise_sigmas(
+ self,
+ bbox: np.ndarray,
+ ) -> np.ndarray:
+ """Get sigma values for each instance according to their size.
+
+ Args:
+ bbox (np.ndarray): Bounding box in shape (N, 4, 2)
+
+ Returns:
+ np.ndarray: Array containing the sigma values for each instance.
+ """
+ sigmas = np.zeros((bbox.shape[0], ), dtype=np.float32)
+
+ heights = np.sqrt(np.power(bbox[:, 0] - bbox[:, 1], 2).sum(axis=-1))
+ widths = np.sqrt(np.power(bbox[:, 0] - bbox[:, 2], 2).sum(axis=-1))
+
+ for i in range(bbox.shape[0]):
+ h, w = heights[i], widths[i]
+
+ # compute sigma for each instance
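+            # The three cases below follow the CornerNet-style Gaussian radius
+            # computation: each solves a quadratic for the corner displacement
+            # that keeps at least `heatmap_min_overlap` IoU with the box, and
+            # the smallest of the three radii is kept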
+ # condition 1
+ a1, b1 = 1, h + w
+ c1 = w * h * (1 - self.heatmap_min_overlap) / (
+ 1 + self.heatmap_min_overlap)
+ sq1 = np.sqrt(b1**2 - 4 * a1 * c1)
+ r1 = (b1 + sq1) / 2
+
+ # condition 2
+ a2 = 4
+ b2 = 2 * (h + w)
+ c2 = (1 - self.heatmap_min_overlap) * w * h
+ sq2 = np.sqrt(b2**2 - 4 * a2 * c2)
+ r2 = (b2 + sq2) / 2
+
+ # condition 3
+ a3 = 4 * self.heatmap_min_overlap
+ b3 = -2 * self.heatmap_min_overlap * (h + w)
+ c3 = (self.heatmap_min_overlap - 1) * w * h
+ sq3 = np.sqrt(b3**2 - 4 * a3 * c3)
+ r3 = (b3 + sq3) / 2
+
+ sigmas[i] = min(r1, r2, r3) / 3
+
+ return sigmas
+
+ def encode(self,
+ keypoints: np.ndarray,
+ keypoints_visible: Optional[np.ndarray] = None,
+ bbox: Optional[np.ndarray] = None) -> dict:
+ """Encode keypoints into heatmaps.
+
+ Args:
+ keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
+ keypoints_visible (np.ndarray): Keypoint visibilities in shape
+ (N, K)
+ bbox (np.ndarray): Bounding box in shape (N, 8) which includes
+ coordinates of 4 corners.
+
+ Returns:
+ dict:
+ - heatmaps (np.ndarray): The coupled heatmap in shape
+ (1+K, H, W) where [W, H] is the `heatmap_size`.
+ - instance_heatmaps (np.ndarray): The decoupled heatmap in shape
+                (N*K, H, W) where N is the number of instances.
+ - keypoint_weights (np.ndarray): The weight for heatmaps in shape
+ (N*K).
+ - instance_coords (np.ndarray): The coordinates of instance roots
+ in shape (N, 2)
+ """
+
+ if keypoints_visible is None:
+ keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32)
+ if bbox is None:
+ # generate pseudo bbox via visible keypoints
+ bbox = get_instance_bbox(keypoints, keypoints_visible)
+ bbox = np.tile(bbox, 2).reshape(-1, 4, 2)
+ # corner order: left_top, left_bottom, right_top, right_bottom
+ bbox[:, 1:3, 0] = bbox[:, 0:2, 0]
+
+ # keypoint coordinates in heatmap
+ _keypoints = keypoints / self.scale_factor
+ _bbox = bbox.reshape(-1, 4, 2) / self.scale_factor
+
+ # compute the root and scale of each instance
+ roots, roots_visible = get_instance_root(_keypoints, keypoints_visible,
+ self.root_type)
+
+ sigmas = self._get_instance_wise_sigmas(_bbox)
+
+ # generate global heatmaps
+ heatmaps, keypoint_weights = generate_gaussian_heatmaps(
+ heatmap_size=self.heatmap_size,
+ keypoints=np.concatenate((_keypoints, roots[:, None]), axis=1),
+ keypoints_visible=np.concatenate(
+ (keypoints_visible, roots_visible[:, None]), axis=1),
+ sigma=sigmas)
+ roots_visible = keypoint_weights[:, -1]
+
+ # select instances
+ inst_roots, inst_indices = [], []
+ diagonal_lengths = get_diagonal_lengths(_keypoints, keypoints_visible)
+ for i in np.argsort(diagonal_lengths):
+ if roots_visible[i] < 1:
+ continue
+ # rand root point in 3x3 grid
+ x, y = roots[i] + np.random.randint(-1, 2, (2, ))
+ x = max(0, min(x, self.heatmap_size[0] - 1))
+ y = max(0, min(y, self.heatmap_size[1] - 1))
+ if (x, y) not in inst_roots:
+ inst_roots.append((x, y))
+ inst_indices.append(i)
+ if len(inst_indices) > self.encode_max_instances:
+ rand_indices = random.sample(
+ range(len(inst_indices)), self.encode_max_instances)
+ inst_roots = [inst_roots[i] for i in rand_indices]
+ inst_indices = [inst_indices[i] for i in rand_indices]
+
+ # generate instance-wise heatmaps
+ inst_heatmaps, inst_heatmap_weights = [], []
+ for i in inst_indices:
+ inst_heatmap, inst_heatmap_weight = generate_gaussian_heatmaps(
+ heatmap_size=self.heatmap_size,
+ keypoints=_keypoints[i:i + 1],
+ keypoints_visible=keypoints_visible[i:i + 1],
+ sigma=sigmas[i].item())
+ inst_heatmaps.append(inst_heatmap)
+ inst_heatmap_weights.append(inst_heatmap_weight)
+
+ if len(inst_indices) > 0:
+ inst_heatmaps = np.concatenate(inst_heatmaps)
+ inst_heatmap_weights = np.concatenate(inst_heatmap_weights)
+ inst_roots = np.array(inst_roots, dtype=np.int32)
+ else:
+ inst_heatmaps = np.empty((0, *self.heatmap_size[::-1]))
+ inst_heatmap_weights = np.empty((0, ))
+ inst_roots = np.empty((0, 2), dtype=np.int32)
+
+ encoded = dict(
+ heatmaps=heatmaps,
+ instance_heatmaps=inst_heatmaps,
+ keypoint_weights=inst_heatmap_weights,
+ instance_coords=inst_roots)
+
+ return encoded
+
+ def decode(self, instance_heatmaps: np.ndarray,
+ instance_scores: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+ """Decode keypoint coordinates from decoupled heatmaps. The decoded
+ keypoint coordinates are in the input image space.
+
+ Args:
+ instance_heatmaps (np.ndarray): Heatmaps in shape (N, K, H, W)
+ instance_scores (np.ndarray): Confidence of instance roots
+ prediction in shape (N, 1)
+
+ Returns:
+ tuple:
+ - keypoints (np.ndarray): Decoded keypoint coordinates in shape
+ (N, K, D)
+ - scores (np.ndarray): The keypoint scores in shape (N, K). It
+ usually represents the confidence of the keypoint prediction
+ """
+ keypoints, keypoint_scores = [], []
+
+ for i in range(instance_heatmaps.shape[0]):
+ heatmaps = instance_heatmaps[i].copy()
+ kpts, scores = get_heatmap_maximum(heatmaps)
+ keypoints.append(refine_keypoints(kpts[None], heatmaps))
+ keypoint_scores.append(scores[None])
+
+ keypoints = np.concatenate(keypoints)
+ # Restore the keypoint scale
+ keypoints = keypoints * self.scale_factor
+
+ keypoint_scores = np.concatenate(keypoint_scores)
+ keypoint_scores *= instance_scores
+
+ return keypoints, keypoint_scores
diff --git a/mmpose/codecs/integral_regression_label.py b/mmpose/codecs/integral_regression_label.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed8e72cb104bde62353d5e5c08fd40a2b8e635f6
--- /dev/null
+++ b/mmpose/codecs/integral_regression_label.py
@@ -0,0 +1,115 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+from typing import Optional, Tuple
+
+import numpy as np
+
+from mmpose.registry import KEYPOINT_CODECS
+from .base import BaseKeypointCodec
+from .msra_heatmap import MSRAHeatmap
+from .regression_label import RegressionLabel
+
+
+@KEYPOINT_CODECS.register_module()
+class IntegralRegressionLabel(BaseKeypointCodec):
+ """Generate keypoint coordinates and normalized heatmaps. See the paper:
+    `DSNT`_ by Nibali et al (2018).
+
+ Note:
+
+ - instance number: N
+ - keypoint number: K
+ - keypoint dimension: D
+ - image size: [w, h]
+
+ Encoded:
+
+ - keypoint_labels (np.ndarray): The normalized regression labels in
+ shape (N, K, D) where D is 2 for 2d coordinates
+ - heatmaps (np.ndarray): The generated heatmap in shape (K, H, W) where
+ [W, H] is the `heatmap_size`
+ - keypoint_weights (np.ndarray): The target weights in shape (N, K)
+
+ Args:
+ input_size (tuple): Input image size in [w, h]
+ heatmap_size (tuple): Heatmap size in [W, H]
+ sigma (float): The sigma value of the Gaussian heatmap
+        unbiased (bool): Whether to use the unbiased method (DarkPose) in
+            ``'msra'`` encoding. See `Dark Pose`_ for details. Defaults to
+            ``False``
+        blur_kernel_size (int): The Gaussian blur kernel size of the heatmap
+            modulation in DarkPose. The kernel size and sigma should follow
+            the empirical formula :math:`sigma = 0.3*((ks-1)*0.5-1)+0.8`.
+ Defaults to 11
+ normalize (bool): Whether to normalize the heatmaps. Defaults to True.
+
+    .. _`DSNT`: https://arxiv.org/abs/1801.07372
+    .. _`Dark Pose`: https://arxiv.org/abs/1910.06278
+ """
+
+ def __init__(self,
+ input_size: Tuple[int, int],
+ heatmap_size: Tuple[int, int],
+ sigma: float,
+ unbiased: bool = False,
+ blur_kernel_size: int = 11,
+ normalize: bool = True) -> None:
+ super().__init__()
+
+ self.heatmap_codec = MSRAHeatmap(input_size, heatmap_size, sigma,
+ unbiased, blur_kernel_size)
+ self.keypoint_codec = RegressionLabel(input_size)
+ self.normalize = normalize
+
+ def encode(self,
+ keypoints: np.ndarray,
+ keypoints_visible: Optional[np.ndarray] = None) -> dict:
+ """Encoding keypoints to regression labels and heatmaps.
+
+ Args:
+ keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
+ keypoints_visible (np.ndarray): Keypoint visibilities in shape
+ (N, K)
+
+ Returns:
+ dict:
+ - keypoint_labels (np.ndarray): The normalized regression labels in
+ shape (N, K, D) where D is 2 for 2d coordinates
+ - heatmaps (np.ndarray): The generated heatmap in shape
+ (K, H, W) where [W, H] is the `heatmap_size`
+ - keypoint_weights (np.ndarray): The target weights in shape
+ (N, K)
+ """
+ encoded_hm = self.heatmap_codec.encode(keypoints, keypoints_visible)
+ encoded_kp = self.keypoint_codec.encode(keypoints, keypoints_visible)
+
+ heatmaps = encoded_hm['heatmaps']
+ keypoint_labels = encoded_kp['keypoint_labels']
+ keypoint_weights = encoded_kp['keypoint_weights']
+
+ if self.normalize:
+ val_sum = heatmaps.sum(axis=(-1, -2)).reshape(-1, 1, 1) + 1e-24
+ heatmaps = heatmaps / val_sum
+
+ encoded = dict(
+ keypoint_labels=keypoint_labels,
+ heatmaps=heatmaps,
+ keypoint_weights=keypoint_weights)
+
+ return encoded
+
+ def decode(self, encoded: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+ """Decode keypoint coordinates from normalized space to input image
+ space.
+
+ Args:
+ encoded (np.ndarray): Coordinates in shape (N, K, D)
+
+ Returns:
+ tuple:
+ - keypoints (np.ndarray): Decoded coordinates in shape (N, K, D)
+            - scores (np.ndarray): The keypoint scores in shape (N, K).
+ It usually represents the confidence of the keypoint prediction
+ """
+
+ keypoints, scores = self.keypoint_codec.decode(encoded)
+
+ return keypoints, scores
diff --git a/mmpose/codecs/megvii_heatmap.py b/mmpose/codecs/megvii_heatmap.py
new file mode 100644
index 0000000000000000000000000000000000000000..e898004637b6d804cb256de369a64ed4d41e560d
--- /dev/null
+++ b/mmpose/codecs/megvii_heatmap.py
@@ -0,0 +1,144 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from itertools import product
+from typing import Optional, Tuple
+
+import cv2
+import numpy as np
+
+from mmpose.registry import KEYPOINT_CODECS
+from .base import BaseKeypointCodec
+from .utils import gaussian_blur, get_heatmap_maximum
+
+
+@KEYPOINT_CODECS.register_module()
+class MegviiHeatmap(BaseKeypointCodec):
+    """Represent keypoints as heatmaps via the "Megvii" approach. See `MSPN`_
+ (2019) and `CPN`_ (2018) for details.
+
+ Note:
+
+ - instance number: N
+ - keypoint number: K
+ - keypoint dimension: D
+ - image size: [w, h]
+ - heatmap size: [W, H]
+
+ Encoded:
+
+ - heatmaps (np.ndarray): The generated heatmap in shape (K, H, W)
+ where [W, H] is the `heatmap_size`
+ - keypoint_weights (np.ndarray): The target weights in shape (N, K)
+
+ Args:
+ input_size (tuple): Image size in [w, h]
+ heatmap_size (tuple): Heatmap size in [W, H]
+        kernel_size (int): The kernel size of the Gaussian blur applied to
+            the heatmap
+
+ .. _`MSPN`: https://arxiv.org/abs/1901.00148
+ .. _`CPN`: https://arxiv.org/abs/1711.07319
+ """
+
+ def __init__(
+ self,
+ input_size: Tuple[int, int],
+ heatmap_size: Tuple[int, int],
+ kernel_size: int,
+ ) -> None:
+
+ super().__init__()
+ self.input_size = input_size
+ self.heatmap_size = heatmap_size
+ self.kernel_size = kernel_size
+ self.scale_factor = (np.array(input_size) /
+ heatmap_size).astype(np.float32)
+
+ def encode(self,
+ keypoints: np.ndarray,
+ keypoints_visible: Optional[np.ndarray] = None) -> dict:
+ """Encode keypoints into heatmaps. Note that the original keypoint
+ coordinates should be in the input image space.
+
+ Args:
+ keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
+ keypoints_visible (np.ndarray): Keypoint visibilities in shape
+ (N, K)
+
+ Returns:
+ dict:
+ - heatmaps (np.ndarray): The generated heatmap in shape
+ (K, H, W) where [W, H] is the `heatmap_size`
+ - keypoint_weights (np.ndarray): The target weights in shape
+ (N, K)
+ """
+
+ N, K, _ = keypoints.shape
+ W, H = self.heatmap_size
+
+ assert N == 1, (
+            f'{self.__class__.__name__} only supports single-instance '
+ 'keypoint encoding')
+
+ heatmaps = np.zeros((K, H, W), dtype=np.float32)
+ keypoint_weights = keypoints_visible.copy()
+
+ for n, k in product(range(N), range(K)):
+            # skip unlabeled keypoints
+ if keypoints_visible[n, k] < 0.5:
+ continue
+
+ # get center coordinates
+ kx, ky = (keypoints[n, k] / self.scale_factor).astype(np.int64)
+ if kx < 0 or kx >= W or ky < 0 or ky >= H:
+ keypoint_weights[n, k] = 0
+ continue
+
+ heatmaps[k, ky, kx] = 1.
+ kernel_size = (self.kernel_size, self.kernel_size)
+ heatmaps[k] = cv2.GaussianBlur(heatmaps[k], kernel_size, 0)
+
+ # normalize the heatmap
+ heatmaps[k] = heatmaps[k] / heatmaps[k, ky, kx] * 255.
+
+ encoded = dict(heatmaps=heatmaps, keypoint_weights=keypoint_weights)
+
+ return encoded
+
+ def decode(self, encoded: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+ """Decode keypoint coordinates from heatmaps. The decoded keypoint
+ coordinates are in the input image space.
+
+ Args:
+ encoded (np.ndarray): Heatmaps in shape (K, H, W)
+
+ Returns:
+ tuple:
+ - keypoints (np.ndarray): Decoded keypoint coordinates in shape
+ (K, D)
+ - scores (np.ndarray): The keypoint scores in shape (K,). It
+ usually represents the confidence of the keypoint prediction
+ """
+ heatmaps = gaussian_blur(encoded.copy(), self.kernel_size)
+ K, H, W = heatmaps.shape
+
+ keypoints, scores = get_heatmap_maximum(heatmaps)
+
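+        # Shift each keypoint a quarter pixel towards the higher of its two
+        # neighboring responses and add a half-pixel offset before rescaling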
+ for k in range(K):
+ heatmap = heatmaps[k]
+ px = int(keypoints[k, 0])
+ py = int(keypoints[k, 1])
+ if 1 < px < W - 1 and 1 < py < H - 1:
+ diff = np.array([
+ heatmap[py][px + 1] - heatmap[py][px - 1],
+ heatmap[py + 1][px] - heatmap[py - 1][px]
+ ])
+ keypoints[k] += (np.sign(diff) * 0.25 + 0.5)
+
+ scores = scores / 255.0 + 0.5
+
+ # Unsqueeze the instance dimension for single-instance results
+ # and restore the keypoint scales
+ keypoints = keypoints[None] * self.scale_factor
+ scores = scores[None]
+
+ return keypoints, scores
diff --git a/mmpose/codecs/msra_heatmap.py b/mmpose/codecs/msra_heatmap.py
new file mode 100644
index 0000000000000000000000000000000000000000..63ba292e4de213d3151c6b0668257f1e77f3d195
--- /dev/null
+++ b/mmpose/codecs/msra_heatmap.py
@@ -0,0 +1,150 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional, Tuple
+
+import numpy as np
+
+from mmpose.registry import KEYPOINT_CODECS
+from .base import BaseKeypointCodec
+from .utils.gaussian_heatmap import (generate_gaussian_heatmaps,
+ generate_unbiased_gaussian_heatmaps)
+from .utils.post_processing import get_heatmap_maximum
+from .utils.refinement import refine_keypoints, refine_keypoints_dark
+
+
+@KEYPOINT_CODECS.register_module()
+class MSRAHeatmap(BaseKeypointCodec):
+    """Represent keypoints as heatmaps via the "MSRA" approach. See the paper:
+ `Simple Baselines for Human Pose Estimation and Tracking`_ by Xiao et al
+ (2018) for details.
+
+ Note:
+
+ - instance number: N
+ - keypoint number: K
+ - keypoint dimension: D
+ - image size: [w, h]
+ - heatmap size: [W, H]
+
+ Encoded:
+
+ - heatmaps (np.ndarray): The generated heatmap in shape (K, H, W)
+ where [W, H] is the `heatmap_size`
+ - keypoint_weights (np.ndarray): The target weights in shape (N, K)
+
+ Args:
+ input_size (tuple): Image size in [w, h]
+ heatmap_size (tuple): Heatmap size in [W, H]
+ sigma (float): The sigma value of the Gaussian heatmap
+        unbiased (bool): Whether to use the unbiased method (DarkPose) in
+            ``'msra'`` encoding. See `Dark Pose`_ for details. Defaults to
+            ``False``
+        blur_kernel_size (int): The Gaussian blur kernel size of the heatmap
+            modulation in DarkPose. The kernel size and sigma should follow
+            the empirical formula :math:`sigma = 0.3*((ks-1)*0.5-1)+0.8`.
+ Defaults to 11
+
+ .. _`Simple Baselines for Human Pose Estimation and Tracking`:
+ https://arxiv.org/abs/1804.06208
+ .. _`Dark Pose`: https://arxiv.org/abs/1910.06278
+ """
+
+ def __init__(self,
+ input_size: Tuple[int, int],
+ heatmap_size: Tuple[int, int],
+ sigma: float,
+ unbiased: bool = False,
+ blur_kernel_size: int = 11) -> None:
+ super().__init__()
+ self.input_size = input_size
+ self.heatmap_size = heatmap_size
+ self.sigma = sigma
+ self.unbiased = unbiased
+
+ # The Gaussian blur kernel size of the heatmap modulation
+        # in DarkPose and the sigma value follows the empirical
+ # formula :math:`sigma = 0.3*((ks-1)*0.5-1)+0.8`
+ # which gives:
+ # sigma~=3 if ks=17
+ # sigma=2 if ks=11;
+ # sigma~=1.5 if ks=7;
+ # sigma~=1 if ks=3;
+ self.blur_kernel_size = blur_kernel_size
+ self.scale_factor = (np.array(input_size) /
+ heatmap_size).astype(np.float32)
+
+ def encode(self,
+ keypoints: np.ndarray,
+ keypoints_visible: Optional[np.ndarray] = None) -> dict:
+ """Encode keypoints into heatmaps. Note that the original keypoint
+ coordinates should be in the input image space.
+
+ Args:
+ keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
+ keypoints_visible (np.ndarray): Keypoint visibilities in shape
+ (N, K)
+
+ Returns:
+ dict:
+ - heatmaps (np.ndarray): The generated heatmap in shape
+ (K, H, W) where [W, H] is the `heatmap_size`
+ - keypoint_weights (np.ndarray): The target weights in shape
+ (N, K)
+ """
+
+ assert keypoints.shape[0] == 1, (
+            f'{self.__class__.__name__} only supports single-instance '
+ 'keypoint encoding')
+
+ if keypoints_visible is None:
+ keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32)
+
+ if self.unbiased:
+ heatmaps, keypoint_weights = generate_unbiased_gaussian_heatmaps(
+ heatmap_size=self.heatmap_size,
+ keypoints=keypoints / self.scale_factor,
+ keypoints_visible=keypoints_visible,
+ sigma=self.sigma)
+ else:
+ heatmaps, keypoint_weights = generate_gaussian_heatmaps(
+ heatmap_size=self.heatmap_size,
+ keypoints=keypoints / self.scale_factor,
+ keypoints_visible=keypoints_visible,
+ sigma=self.sigma)
+
+ encoded = dict(heatmaps=heatmaps, keypoint_weights=keypoint_weights)
+
+ return encoded
+
+ def decode(self, encoded: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+ """Decode keypoint coordinates from heatmaps. The decoded keypoint
+ coordinates are in the input image space.
+
+ Args:
+ encoded (np.ndarray): Heatmaps in shape (K, H, W)
+
+ Returns:
+ tuple:
+ - keypoints (np.ndarray): Decoded keypoint coordinates in shape
+ (N, K, D)
+ - scores (np.ndarray): The keypoint scores in shape (N, K). It
+ usually represents the confidence of the keypoint prediction
+ """
+ heatmaps = encoded.copy()
+ K, H, W = heatmaps.shape
+
+ keypoints, scores = get_heatmap_maximum(heatmaps)
+
+ # Unsqueeze the instance dimension for single-instance results
+ keypoints, scores = keypoints[None], scores[None]
+
+ if self.unbiased:
+ # Alleviate biased coordinate
+ keypoints = refine_keypoints_dark(
+ keypoints, heatmaps, blur_kernel_size=self.blur_kernel_size)
+
+ else:
+ keypoints = refine_keypoints(keypoints, heatmaps)
+
+ # Restore the keypoint scale
+ keypoints = keypoints * self.scale_factor
+
+ return keypoints, scores
diff --git a/mmpose/codecs/regression_label.py b/mmpose/codecs/regression_label.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ae385d2d97cebf3a7c097318b76a14626b0db89
--- /dev/null
+++ b/mmpose/codecs/regression_label.py
@@ -0,0 +1,103 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+from typing import Optional, Tuple
+
+import numpy as np
+
+from mmpose.registry import KEYPOINT_CODECS
+from .base import BaseKeypointCodec
+
+
+@KEYPOINT_CODECS.register_module()
+class RegressionLabel(BaseKeypointCodec):
+ r"""Generate keypoint coordinates.
+
+ Note:
+
+ - instance number: N
+ - keypoint number: K
+ - keypoint dimension: D
+ - image size: [w, h]
+
+ Encoded:
+
+ - keypoint_labels (np.ndarray): The normalized regression labels in
+ shape (N, K, D) where D is 2 for 2d coordinates
+ - keypoint_weights (np.ndarray): The target weights in shape (N, K)
+
+ Args:
+ input_size (tuple): Input image size in [w, h]
+
+ """
+
+ def __init__(self, input_size: Tuple[int, int]) -> None:
+ super().__init__()
+
+ self.input_size = input_size
+
+ def encode(self,
+ keypoints: np.ndarray,
+ keypoints_visible: Optional[np.ndarray] = None) -> dict:
+ """Encoding keypoints from input image space to normalized space.
+
+ Args:
+ keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
+ keypoints_visible (np.ndarray): Keypoint visibilities in shape
+ (N, K)
+
+ Returns:
+ dict:
+ - keypoint_labels (np.ndarray): The normalized regression labels in
+ shape (N, K, D) where D is 2 for 2d coordinates
+ - keypoint_weights (np.ndarray): The target weights in shape
+ (N, K)
+ """
+ if keypoints_visible is None:
+ keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32)
+
+ w, h = self.input_size
+ valid = ((keypoints >= 0) &
+ (keypoints <= [w - 1, h - 1])).all(axis=-1) & (
+ keypoints_visible > 0.5)
+
+ keypoint_labels = (keypoints / np.array([w, h])).astype(np.float32)
+ keypoint_weights = np.where(valid, 1., 0.).astype(np.float32)
+
+ encoded = dict(
+ keypoint_labels=keypoint_labels, keypoint_weights=keypoint_weights)
+
+ return encoded
+
+ def decode(self, encoded: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+ """Decode keypoint coordinates from normalized space to input image
+ space.
+
+ Args:
+ encoded (np.ndarray): Coordinates in shape (N, K, D)
+
+ Returns:
+ tuple:
+ - keypoints (np.ndarray): Decoded coordinates in shape (N, K, D)
+            - scores (np.ndarray): The keypoint scores in shape (N, K).
+ It usually represents the confidence of the keypoint prediction
+ """
+
+ if encoded.shape[-1] == 2:
+ N, K, _ = encoded.shape
+ normalized_coords = encoded.copy()
+ scores = np.ones((N, K), dtype=np.float32)
+ elif encoded.shape[-1] == 4:
+ # split coords and sigma if outputs contain output_sigma
+ normalized_coords = encoded[..., :2].copy()
+ output_sigma = encoded[..., 2:4].copy()
+
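+            # Convert the predicted per-axis sigmas into a confidence score:
+            # a smaller sigma (lower uncertainty) gives a higher score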
+ scores = (1 - output_sigma).mean(axis=-1)
+ else:
+ raise ValueError(
+ 'Keypoint dimension should be 2 or 4 (with sigma), '
+ f'but got {encoded.shape[-1]}')
+
+ w, h = self.input_size
+ keypoints = normalized_coords * np.array([w, h])
+
+ return keypoints, scores
diff --git a/mmpose/codecs/simcc_label.py b/mmpose/codecs/simcc_label.py
new file mode 100644
index 0000000000000000000000000000000000000000..a22498c35291dfcd838ce1fe6451bd60dd0e6d48
--- /dev/null
+++ b/mmpose/codecs/simcc_label.py
@@ -0,0 +1,286 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from itertools import product
+from typing import Optional, Tuple, Union
+
+import numpy as np
+
+from mmpose.codecs.utils import get_simcc_maximum
+from mmpose.codecs.utils.refinement import refine_simcc_dark
+from mmpose.registry import KEYPOINT_CODECS
+from .base import BaseKeypointCodec
+
+
+@KEYPOINT_CODECS.register_module()
+class SimCCLabel(BaseKeypointCodec):
+ r"""Generate keypoint representation via "SimCC" approach.
+ See the paper: `SimCC: a Simple Coordinate Classification Perspective for
+ Human Pose Estimation`_ by Li et al (2022) for more details.
+ Old name: SimDR
+
+ Note:
+
+ - instance number: N
+ - keypoint number: K
+ - keypoint dimension: D
+ - image size: [w, h]
+
+ Encoded:
+
+    - keypoint_x_labels (np.ndarray): The generated SimCC label for x-axis
+        in shape (N, K, Wx), where :math:`Wx=w*simcc_split_ratio`
+    - keypoint_y_labels (np.ndarray): The generated SimCC label for y-axis
+        in shape (N, K, Wy), where :math:`Wy=h*simcc_split_ratio`
+ - keypoint_weights (np.ndarray): The target weights in shape (N, K)
+
+ Args:
+ input_size (tuple): Input image size in [w, h]
+ smoothing_type (str): The SimCC label smoothing strategy. Options are
+ ``'gaussian'`` and ``'standard'``. Defaults to ``'gaussian'``
+ sigma (float | int | tuple): The sigma value in the Gaussian SimCC
+ label. Defaults to 6.0
+ simcc_split_ratio (float): The ratio of the label size to the input
+ size. For example, if the input width is ``w``, the x label size
+ will be :math:`w*simcc_split_ratio`. Defaults to 2.0
+ label_smooth_weight (float): Label Smoothing weight. Defaults to 0.0
+        normalize (bool): Whether to normalize the generated SimCC labels.
+            Defaults to True.
+        use_dark (bool): Whether to apply DARK-style refinement to the decoded
+            coordinates. Defaults to False.
+
+ .. _`SimCC: a Simple Coordinate Classification Perspective for Human Pose
+ Estimation`: https://arxiv.org/abs/2107.03332
+ """
+
+ def __init__(self,
+ input_size: Tuple[int, int],
+ smoothing_type: str = 'gaussian',
+ sigma: Union[float, int, Tuple[float]] = 6.0,
+ simcc_split_ratio: float = 2.0,
+ label_smooth_weight: float = 0.0,
+ normalize: bool = True,
+ use_dark: bool = False) -> None:
+ super().__init__()
+
+ self.input_size = input_size
+ self.smoothing_type = smoothing_type
+ self.simcc_split_ratio = simcc_split_ratio
+ self.label_smooth_weight = label_smooth_weight
+ self.normalize = normalize
+ self.use_dark = use_dark
+
+ if isinstance(sigma, (float, int)):
+ self.sigma = np.array([sigma, sigma])
+ else:
+ self.sigma = np.array(sigma)
+
+ if self.smoothing_type not in {'gaussian', 'standard'}:
+ raise ValueError(
+                f'{self.__class__.__name__} got invalid `smoothing_type` value '
+ f'{self.smoothing_type}. Should be one of '
+ '{"gaussian", "standard"}')
+
+ if self.smoothing_type == 'gaussian' and self.label_smooth_weight > 0:
+ raise ValueError('Attribute `label_smooth_weight` is only '
+ 'used for `standard` mode.')
+
+ if self.label_smooth_weight < 0.0 or self.label_smooth_weight > 1.0:
+ raise ValueError('`label_smooth_weight` should be in range [0, 1]')
+
+ def encode(self,
+ keypoints: np.ndarray,
+ keypoints_visible: Optional[np.ndarray] = None) -> dict:
+ """Encoding keypoints into SimCC labels. Note that the original
+ keypoint coordinates should be in the input image space.
+
+ Args:
+ keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
+ keypoints_visible (np.ndarray): Keypoint visibilities in shape
+ (N, K)
+
+ Returns:
+ dict:
+            - keypoint_x_labels (np.ndarray): The generated SimCC label for
+                x-axis in shape (N, K, Wx), where
+                :math:`Wx=w*simcc_split_ratio`
+            - keypoint_y_labels (np.ndarray): The generated SimCC label for
+                y-axis in shape (N, K, Wy), where
+                :math:`Wy=h*simcc_split_ratio`
+ - keypoint_weights (np.ndarray): The target weights in shape
+ (N, K)
+ """
+ if keypoints_visible is None:
+ keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32)
+
+ if self.smoothing_type == 'gaussian':
+ x_labels, y_labels, keypoint_weights = self._generate_gaussian(
+ keypoints, keypoints_visible)
+ elif self.smoothing_type == 'standard':
+ x_labels, y_labels, keypoint_weights = self._generate_standard(
+ keypoints, keypoints_visible)
+ else:
+ raise ValueError(
+                f'{self.__class__.__name__} got invalid `smoothing_type` value '
+ f'{self.smoothing_type}. Should be one of '
+ '{"gaussian", "standard"}')
+
+ encoded = dict(
+ keypoint_x_labels=x_labels,
+ keypoint_y_labels=y_labels,
+ keypoint_weights=keypoint_weights)
+
+ return encoded
+
+ def decode(self, simcc_x: np.ndarray,
+ simcc_y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+ """Decode keypoint coordinates from SimCC representations. The decoded
+ coordinates are in the input image space.
+
+ Args:
+ simcc_x (np.ndarray): SimCC label for x-axis
+ simcc_y (np.ndarray): SimCC label for y-axis
+
+ Returns:
+ tuple:
+ - keypoints (np.ndarray): Decoded coordinates in shape (N, K, D)
+            - scores (np.ndarray): The keypoint scores in shape (N, K).
+ It usually represents the confidence of the keypoint prediction
+ """
+
+ keypoints, scores = get_simcc_maximum(simcc_x, simcc_y)
+
+ # Unsqueeze the instance dimension for single-instance results
+ if keypoints.ndim == 2:
+ keypoints = keypoints[None, :]
+ scores = scores[None, :]
+
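+        # DARK refinement: derive an odd Gaussian blur kernel size from the
+        # label sigma before refining the discrete SimCC maxima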
+ if self.use_dark:
+ x_blur = int((self.sigma[0] * 20 - 7) // 3)
+ y_blur = int((self.sigma[1] * 20 - 7) // 3)
+ x_blur -= int((x_blur % 2) == 0)
+ y_blur -= int((y_blur % 2) == 0)
+ keypoints[:, :, 0] = refine_simcc_dark(keypoints[:, :, 0], simcc_x,
+ x_blur)
+ keypoints[:, :, 1] = refine_simcc_dark(keypoints[:, :, 1], simcc_y,
+ y_blur)
+
+ keypoints /= self.simcc_split_ratio
+
+ return keypoints, scores
+
+ def _map_coordinates(
+ self,
+ keypoints: np.ndarray,
+ keypoints_visible: Optional[np.ndarray] = None
+ ) -> Tuple[np.ndarray, np.ndarray]:
+ """Mapping keypoint coordinates into SimCC space."""
+
+ keypoints_split = keypoints.copy()
+ keypoints_split = np.around(keypoints_split * self.simcc_split_ratio)
+ keypoints_split = keypoints_split.astype(np.int64)
+ keypoint_weights = keypoints_visible.copy()
+
+ return keypoints_split, keypoint_weights
+
+ def _generate_standard(
+ self,
+ keypoints: np.ndarray,
+ keypoints_visible: Optional[np.ndarray] = None
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+ """Encoding keypoints into SimCC labels with Standard Label Smoothing
+ strategy.
+
+ Labels will be one-hot vectors if self.label_smooth_weight==0.0
+ """
+
+ N, K, _ = keypoints.shape
+ w, h = self.input_size
+ W = np.around(w * self.simcc_split_ratio).astype(int)
+ H = np.around(h * self.simcc_split_ratio).astype(int)
+
+ keypoints_split, keypoint_weights = self._map_coordinates(
+ keypoints, keypoints_visible)
+
+ target_x = np.zeros((N, K, W), dtype=np.float32)
+ target_y = np.zeros((N, K, H), dtype=np.float32)
+
+ for n, k in product(range(N), range(K)):
+            # skip unlabeled keypoints
+ if keypoints_visible[n, k] < 0.5:
+ continue
+
+ # get center coordinates
+ mu_x, mu_y = keypoints_split[n, k].astype(np.int64)
+
+ # detect abnormal coords and assign the weight 0
+ if mu_x >= W or mu_y >= H or mu_x < 0 or mu_y < 0:
+ keypoint_weights[n, k] = 0
+ continue
+
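+            # Standard label smoothing: spread `label_smooth_weight` uniformly
+            # over the non-target bins and keep the rest on the target bin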
+ if self.label_smooth_weight > 0:
+ target_x[n, k] = self.label_smooth_weight / (W - 1)
+ target_y[n, k] = self.label_smooth_weight / (H - 1)
+
+ target_x[n, k, mu_x] = 1.0 - self.label_smooth_weight
+ target_y[n, k, mu_y] = 1.0 - self.label_smooth_weight
+
+ return target_x, target_y, keypoint_weights
+
+ def _generate_gaussian(
+ self,
+ keypoints: np.ndarray,
+ keypoints_visible: Optional[np.ndarray] = None
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+ """Encoding keypoints into SimCC labels with Gaussian Label Smoothing
+ strategy."""
+
+ N, K, _ = keypoints.shape
+ w, h = self.input_size
+ W = np.around(w * self.simcc_split_ratio).astype(int)
+ H = np.around(h * self.simcc_split_ratio).astype(int)
+
+ keypoints_split, keypoint_weights = self._map_coordinates(
+ keypoints, keypoints_visible)
+
+ target_x = np.zeros((N, K, W), dtype=np.float32)
+ target_y = np.zeros((N, K, H), dtype=np.float32)
+
+ # 3-sigma rule
+ radius = self.sigma * 3
+
+ # xy grid
+ x = np.arange(0, W, 1, dtype=np.float32)
+ y = np.arange(0, H, 1, dtype=np.float32)
+
+ for n, k in product(range(N), range(K)):
+            # skip unlabeled keypoints
+ if keypoints_visible[n, k] < 0.5:
+ continue
+
+ mu = keypoints_split[n, k]
+
+ # check that the gaussian has in-bounds part
+ left, top = mu - radius
+ right, bottom = mu + radius + 1
+
+ if left >= W or top >= H or right < 0 or bottom < 0:
+ keypoint_weights[n, k] = 0
+ continue
+
+ mu_x, mu_y = mu
+
+ target_x[n, k] = np.exp(-((x - mu_x)**2) / (2 * self.sigma[0]**2))
+ target_y[n, k] = np.exp(-((y - mu_y)**2) / (2 * self.sigma[1]**2))
+
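+        # Normalize by the Gaussian partition function so that each 1-D label
+        # approximately sums to 1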
+ if self.normalize:
+ norm_value = self.sigma * np.sqrt(np.pi * 2)
+ target_x /= norm_value[0]
+ target_y /= norm_value[1]
+
+ return target_x, target_y, keypoint_weights
diff --git a/mmpose/codecs/spr.py b/mmpose/codecs/spr.py
new file mode 100644
index 0000000000000000000000000000000000000000..add6f5715b52a42d58114bcb3d432ea38cfec29f
--- /dev/null
+++ b/mmpose/codecs/spr.py
@@ -0,0 +1,299 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+from torch import Tensor
+
+from mmpose.registry import KEYPOINT_CODECS
+from .base import BaseKeypointCodec
+from .utils import (batch_heatmap_nms, generate_displacement_heatmap,
+ generate_gaussian_heatmaps, get_diagonal_lengths,
+ get_instance_root)
+
+
+@KEYPOINT_CODECS.register_module()
+class SPR(BaseKeypointCodec):
+ """Encode/decode keypoints with Structured Pose Representation (SPR).
+
+ See the paper `Single-stage multi-person pose machines`_
+    by Nie et al (2019) for details.
+
+ Note:
+
+ - instance number: N
+ - keypoint number: K
+ - keypoint dimension: D
+ - image size: [w, h]
+ - heatmap size: [W, H]
+
+ Encoded:
+
+ - heatmaps (np.ndarray): The generated heatmap in shape (1, H, W)
+ where [W, H] is the `heatmap_size`. If the keypoint heatmap is
+ generated together, the output heatmap shape is (K+1, H, W)
+    - heatmap_weights (np.ndarray): The target weights for the heatmaps,
+        in the same shape as `heatmaps`.
+    - displacements (np.ndarray): The dense keypoint displacements in
+        shape (K*2, H, W).
+    - displacement_weights (np.ndarray): The target weights for the
+        displacements, in the same shape as `displacements`.
+
+ Args:
+ input_size (tuple): Image size in [w, h]
+ heatmap_size (tuple): Heatmap size in [W, H]
+ sigma (float or tuple, optional): The sigma values of the Gaussian
+ heatmaps. If sigma is a tuple, it includes both sigmas for root
+ and keypoint heatmaps. ``None`` means the sigmas are computed
+ automatically from the heatmap size. Defaults to ``None``
+ generate_keypoint_heatmaps (bool): Whether to generate Gaussian
+ heatmaps for each keypoint. Defaults to ``False``
+ root_type (str): The method to generate the instance root. Options
+ are:
+
+ - ``'kpt_center'``: Average coordinate of all visible keypoints.
+ - ``'bbox_center'``: Center point of bounding boxes outlined by
+ all visible keypoints.
+
+ Defaults to ``'kpt_center'``
+
+ minimal_diagonal_length (int or float): The threshold of diagonal
+ length of instance bounding box. Small instances will not be
+            used in training. Defaults to 5
+ background_weight (float): Loss weight of background pixels.
+ Defaults to 0.1
+ decode_thr (float): The threshold of keypoint response value in
+ heatmaps. Defaults to 0.01
+ decode_nms_kernel (int): The kernel size of the NMS during decoding,
+ which should be an odd integer. Defaults to 5
+ decode_max_instances (int): The maximum number of instances
+ to decode. Defaults to 30
+
+ .. _`Single-stage multi-person pose machines`:
+ https://arxiv.org/abs/1908.09220
+ """
+
+ def __init__(
+ self,
+ input_size: Tuple[int, int],
+ heatmap_size: Tuple[int, int],
+ sigma: Optional[Union[float, Tuple[float]]] = None,
+ generate_keypoint_heatmaps: bool = False,
+ root_type: str = 'kpt_center',
+ minimal_diagonal_length: Union[int, float] = 5,
+ background_weight: float = 0.1,
+ decode_nms_kernel: int = 5,
+ decode_max_instances: int = 30,
+ decode_thr: float = 0.01,
+ ):
+ super().__init__()
+
+ self.input_size = input_size
+ self.heatmap_size = heatmap_size
+ self.generate_keypoint_heatmaps = generate_keypoint_heatmaps
+ self.root_type = root_type
+ self.minimal_diagonal_length = minimal_diagonal_length
+ self.background_weight = background_weight
+ self.decode_nms_kernel = decode_nms_kernel
+ self.decode_max_instances = decode_max_instances
+ self.decode_thr = decode_thr
+
+ self.scale_factor = (np.array(input_size) /
+ heatmap_size).astype(np.float32)
+
+ if sigma is None:
+ sigma = (heatmap_size[0] * heatmap_size[1])**0.5 / 32
+ if generate_keypoint_heatmaps:
+ # sigma for root heatmap and keypoint heatmaps
+ self.sigma = (sigma, sigma // 2)
+ else:
+ self.sigma = (sigma, )
+ else:
+ if not isinstance(sigma, (tuple, list)):
+ sigma = (sigma, )
+ if generate_keypoint_heatmaps:
+ assert len(sigma) == 2, 'sigma for keypoints must be given ' \
+ 'if `generate_keypoint_heatmaps` ' \
+ 'is True. e.g. sigma=(4, 2)'
+ self.sigma = sigma
+
+ def _get_heatmap_weights(self,
+ heatmaps,
+ fg_weight: float = 1,
+ bg_weight: float = 0):
+ """Generate weight array for heatmaps.
+
+ Args:
+ heatmaps (np.ndarray): Root and keypoint (optional) heatmaps
+ fg_weight (float): Weight for foreground pixels. Defaults to 1.0
+ bg_weight (float): Weight for background pixels. Defaults to 0.0
+
+ Returns:
+ np.ndarray: Heatmap weight array in the same shape with heatmaps
+ """
+ heatmap_weights = np.ones(heatmaps.shape) * bg_weight
+ heatmap_weights[heatmaps > 0] = fg_weight
+ return heatmap_weights
+
+ def encode(self,
+ keypoints: np.ndarray,
+ keypoints_visible: Optional[np.ndarray] = None) -> dict:
+ """Encode keypoints into root heatmaps and keypoint displacement
+ fields. Note that the original keypoint coordinates should be in the
+ input image space.
+
+ Args:
+ keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
+ keypoints_visible (np.ndarray): Keypoint visibilities in shape
+ (N, K)
+
+ Returns:
+ dict:
+ - heatmaps (np.ndarray): The generated heatmap in shape
+ (1, H, W) where [W, H] is the `heatmap_size`. If keypoint
+ heatmaps are generated together, the shape is (K+1, H, W)
+            - heatmap_weights (np.ndarray): The pixel-wise weight for the
+                heatmaps, in the same shape as `heatmaps`
+            - displacements (np.ndarray): The generated displacement fields in
+                shape (K*D, H, W). The vector at each pixel represents the
+                displacement from this pixel to each keypoint of the
+                associated instance.
+            - displacement_weights (np.ndarray): The pixel-wise weight for the
+                displacements, in the same shape as `displacements`
+ """
+
+ if keypoints_visible is None:
+ keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32)
+
+ # keypoint coordinates in heatmap
+ _keypoints = keypoints / self.scale_factor
+
+ # compute the root and scale of each instance
+ roots, roots_visible = get_instance_root(_keypoints, keypoints_visible,
+ self.root_type)
+ diagonal_lengths = get_diagonal_lengths(_keypoints, keypoints_visible)
+
+ # discard the small instances
+ roots_visible[diagonal_lengths < self.minimal_diagonal_length] = 0
+
+ # generate heatmaps
+ heatmaps, _ = generate_gaussian_heatmaps(
+ heatmap_size=self.heatmap_size,
+ keypoints=roots[:, None],
+ keypoints_visible=roots_visible[:, None],
+ sigma=self.sigma[0])
+ heatmap_weights = self._get_heatmap_weights(
+ heatmaps, bg_weight=self.background_weight)
+
+ if self.generate_keypoint_heatmaps:
+ keypoint_heatmaps, _ = generate_gaussian_heatmaps(
+ heatmap_size=self.heatmap_size,
+ keypoints=_keypoints,
+ keypoints_visible=keypoints_visible,
+ sigma=self.sigma[1])
+
+ keypoint_heatmaps_weights = self._get_heatmap_weights(
+ keypoint_heatmaps, bg_weight=self.background_weight)
+
+ heatmaps = np.concatenate((keypoint_heatmaps, heatmaps), axis=0)
+ heatmap_weights = np.concatenate(
+ (keypoint_heatmaps_weights, heatmap_weights), axis=0)
+
+ # generate displacements
+ displacements, displacement_weights = \
+ generate_displacement_heatmap(
+ self.heatmap_size,
+ _keypoints,
+ keypoints_visible,
+ roots,
+ roots_visible,
+ diagonal_lengths,
+ self.sigma[0],
+ )
+
+ encoded = dict(
+ heatmaps=heatmaps,
+ heatmap_weights=heatmap_weights,
+ displacements=displacements,
+ displacement_weights=displacement_weights)
+
+ return encoded
+
+ def decode(self, heatmaps: Tensor,
+ displacements: Tensor) -> Tuple[np.ndarray, np.ndarray]:
+ """Decode the keypoint coordinates from heatmaps and displacements. The
+ decoded keypoint coordinates are in the input image space.
+
+ Args:
+ heatmaps (Tensor): Encoded root and keypoints (optional) heatmaps
+ in shape (1, H, W) or (K+1, H, W)
+ displacements (Tensor): Encoded keypoints displacement fields
+ in shape (K*D, H, W)
+
+ Returns:
+ tuple:
+ - keypoints (Tensor): Decoded keypoint coordinates in shape
+ (N, K, D)
+ - scores (tuple):
+ - root_scores (Tensor): The root scores in shape (N, )
+ - keypoint_scores (Tensor): The keypoint scores in
+ shape (N, K). If keypoint heatmaps are not generated,
+ `keypoint_scores` will be `None`
+ """
+ # heatmaps, displacements = encoded
+ _k, h, w = displacements.shape
+ k = _k // 2
+ displacements = displacements.view(k, 2, h, w)
+
+ # convert displacements to a dense keypoint prediction
+ y, x = torch.meshgrid(torch.arange(h), torch.arange(w))
+ regular_grid = torch.stack([x, y], dim=0).to(displacements)
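+        # posemaps[:, :, i] holds the predicted coordinates of the K keypoints
+        # of the instance whose root lies at the flattened pixel index i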
+ posemaps = (regular_grid[None] + displacements).flatten(2)
+
+ # find local maximum on root heatmap
+ root_heatmap_peaks = batch_heatmap_nms(heatmaps[None, -1:],
+ self.decode_nms_kernel)
+ root_scores, pos_idx = root_heatmap_peaks.flatten().topk(
+ self.decode_max_instances)
+ mask = root_scores > self.decode_thr
+ root_scores, pos_idx = root_scores[mask], pos_idx[mask]
+
+ keypoints = posemaps[:, :, pos_idx].permute(2, 0, 1).contiguous()
+
+ if self.generate_keypoint_heatmaps and heatmaps.shape[0] == 1 + k:
+ # compute scores for each keypoint
+ keypoint_scores = self.get_keypoint_scores(heatmaps[:k], keypoints)
+ else:
+ keypoint_scores = None
+
+ keypoints = torch.cat([
+ kpt * self.scale_factor[i]
+ for i, kpt in enumerate(keypoints.split(1, -1))
+ ],
+ dim=-1)
+ return keypoints, (root_scores, keypoint_scores)
+
+ def get_keypoint_scores(self, heatmaps: Tensor, keypoints: Tensor):
+ """Calculate the keypoint scores with keypoints heatmaps and
+ coordinates.
+
+ Args:
+ heatmaps (Tensor): Keypoint heatmaps in shape (K, H, W)
+ keypoints (Tensor): Keypoint coordinates in shape (N, K, D)
+
+ Returns:
+ Tensor: Keypoint scores in [N, K]
+ """
+ k, h, w = heatmaps.shape
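+        # Normalize keypoint coordinates to the [-1, 1] range expected by
+        # torch.nn.functional.grid_sample before sampling the heatmaps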
+ keypoints = torch.stack((
+ keypoints[..., 0] / (w - 1) * 2 - 1,
+ keypoints[..., 1] / (h - 1) * 2 - 1,
+ ),
+ dim=-1)
+ keypoints = keypoints.transpose(0, 1).unsqueeze(1).contiguous()
+
+ keypoint_scores = torch.nn.functional.grid_sample(
+ heatmaps.unsqueeze(1), keypoints,
+ padding_mode='border').view(k, -1).transpose(0, 1).contiguous()
+
+ return keypoint_scores
diff --git a/mmpose/codecs/udp_heatmap.py b/mmpose/codecs/udp_heatmap.py
new file mode 100644
index 0000000000000000000000000000000000000000..c38ea17be4327e6f7d433198c6a78d5ad2869342
--- /dev/null
+++ b/mmpose/codecs/udp_heatmap.py
@@ -0,0 +1,185 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional, Tuple
+
+import cv2
+import numpy as np
+
+from mmpose.registry import KEYPOINT_CODECS
+from .base import BaseKeypointCodec
+from .utils import (generate_offset_heatmap, generate_udp_gaussian_heatmaps,
+ get_heatmap_maximum, refine_keypoints_dark_udp)
+
+
+@KEYPOINT_CODECS.register_module()
+class UDPHeatmap(BaseKeypointCodec):
+ r"""Generate keypoint heatmaps by Unbiased Data Processing (UDP).
+ See the paper: `The Devil is in the Details: Delving into Unbiased Data
+ Processing for Human Pose Estimation`_ by Huang et al (2020) for details.
+
+ Note:
+
+ - instance number: N
+ - keypoint number: K
+ - keypoint dimension: D
+ - image size: [w, h]
+ - heatmap size: [W, H]
+
+ Encoded:
+
+ - heatmap (np.ndarray): The generated heatmap in shape (C_out, H, W)
+ where [W, H] is the `heatmap_size`, and the C_out is the output
+ channel number which depends on the `heatmap_type`. If
+ `heatmap_type=='gaussian'`, C_out equals to keypoint number K;
+ if `heatmap_type=='combined'`, C_out equals to K*3
+ (x_offset, y_offset and class label)
+ - keypoint_weights (np.ndarray): The target weights in shape (K,)
+
+ Args:
+ input_size (tuple): Image size in [w, h]
+ heatmap_size (tuple): Heatmap size in [W, H]
+        heatmap_type (str): The heatmap type to encode the keypoints. Options
+ are:
+
+ - ``'gaussian'``: Gaussian heatmap
+ - ``'combined'``: Combination of a binary label map and offset
+ maps for X and Y axes.
+
+ sigma (float): The sigma value of the Gaussian heatmap when
+ ``heatmap_type=='gaussian'``. Defaults to 2.0
+ radius_factor (float): The radius factor of the binary label
+ map when ``heatmap_type=='combined'``. The positive region is
+ defined as the neighbor of the keypoit with the radius
+            defined as the neighborhood of the keypoint with the radius
+ blur_kernel_size (int): The Gaussian blur kernel size of the heatmap
+ modulation in DarkPose. Defaults to 11
+
+ .. _`The Devil is in the Details: Delving into Unbiased Data Processing for
+ Human Pose Estimation`: https://arxiv.org/abs/1911.07524
+ """
+
+ def __init__(self,
+ input_size: Tuple[int, int],
+ heatmap_size: Tuple[int, int],
+ heatmap_type: str = 'gaussian',
+ sigma: float = 2.,
+ radius_factor: float = 0.0546875,
+ blur_kernel_size: int = 11) -> None:
+ super().__init__()
+ self.input_size = input_size
+ self.heatmap_size = heatmap_size
+ self.sigma = sigma
+ self.radius_factor = radius_factor
+ self.heatmap_type = heatmap_type
+ self.blur_kernel_size = blur_kernel_size
+ self.scale_factor = ((np.array(input_size) - 1) /
+ (np.array(heatmap_size) - 1)).astype(np.float32)
+
+ if self.heatmap_type not in {'gaussian', 'combined'}:
+ raise ValueError(
+                f'{self.__class__.__name__} got invalid `heatmap_type` value '
+ f'{self.heatmap_type}. Should be one of '
+ '{"gaussian", "combined"}')
+
+ def encode(self,
+ keypoints: np.ndarray,
+ keypoints_visible: Optional[np.ndarray] = None) -> dict:
+ """Encode keypoints into heatmaps. Note that the original keypoint
+ coordinates should be in the input image space.
+
+ Args:
+ keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
+ keypoints_visible (np.ndarray): Keypoint visibilities in shape
+ (N, K)
+
+ Returns:
+ dict:
+ - heatmap (np.ndarray): The generated heatmap in shape
+ (C_out, H, W) where [W, H] is the `heatmap_size`, and the
+ C_out is the output channel number which depends on the
+ `heatmap_type`. If `heatmap_type=='gaussian'`, C_out equals
+ the keypoint number K; if `heatmap_type=='combined'`, C_out
+ equals K*3 (x_offset, y_offset and class label)
+ - keypoint_weights (np.ndarray): The target weights in shape
+ (N, K)
+ """
+ assert keypoints.shape[0] == 1, (
+ f'{self.__class__.__name__} only supports single-instance '
+ 'keypoint encoding')
+
+ if keypoints_visible is None:
+ keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32)
+
+ if self.heatmap_type == 'gaussian':
+ heatmaps, keypoint_weights = generate_udp_gaussian_heatmaps(
+ heatmap_size=self.heatmap_size,
+ keypoints=keypoints / self.scale_factor,
+ keypoints_visible=keypoints_visible,
+ sigma=self.sigma)
+ elif self.heatmap_type == 'combined':
+ heatmaps, keypoint_weights = generate_offset_heatmap(
+ heatmap_size=self.heatmap_size,
+ keypoints=keypoints / self.scale_factor,
+ keypoints_visible=keypoints_visible,
+ radius_factor=self.radius_factor)
+ else:
+ raise ValueError(
+ f'{self.__class__.__name__} got invalid `heatmap_type` value '
+ f'{self.heatmap_type}. Should be one of '
+ '{"gaussian", "combined"}')
+
+ encoded = dict(heatmaps=heatmaps, keypoint_weights=keypoint_weights)
+
+ return encoded
+
+ def decode(self, encoded: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+ """Decode keypoint coordinates from heatmaps. The decoded keypoint
+ coordinates are in the input image space.
+
+ Args:
+ encoded (np.ndarray): Heatmaps in shape (K, H, W)
+
+ Returns:
+ tuple:
+ - keypoints (np.ndarray): Decoded keypoint coordinates in shape
+ (N, K, D)
+ - scores (np.ndarray): The keypoint scores in shape (N, K). It
+ usually represents the confidence of the keypoint prediction
+ """
+ heatmaps = encoded.copy()
+
+ if self.heatmap_type == 'gaussian':
+ keypoints, scores = get_heatmap_maximum(heatmaps)
+ # unsqueeze the instance dimension for single-instance results
+ keypoints = keypoints[None]
+ scores = scores[None]
+
+ keypoints = refine_keypoints_dark_udp(
+ keypoints, heatmaps, blur_kernel_size=self.blur_kernel_size)
+
+ elif self.heatmap_type == 'combined':
+ _K, H, W = heatmaps.shape
+ K = _K // 3
+
+ for cls_heatmap in heatmaps[::3]:
+ # Apply Gaussian blur on classification maps
+ ks = 2 * self.blur_kernel_size + 1
+ cv2.GaussianBlur(cls_heatmap, (ks, ks), 0, cls_heatmap)
+
+ # valid radius
+ radius = self.radius_factor * max(W, H)
+
+ x_offset = heatmaps[1::3].flatten() * radius
+ y_offset = heatmaps[2::3].flatten() * radius
+ keypoints, scores = get_heatmap_maximum(heatmaps=heatmaps[::3])
+ index = (keypoints[..., 0] + keypoints[..., 1] * W).flatten()
+ index += W * H * np.arange(0, K)
+ index = index.astype(int)
+ keypoints += np.stack((x_offset[index], y_offset[index]), axis=-1)
+ # unsqueeze the instance dimension for single-instance results
+ keypoints = keypoints[None].astype(np.float32)
+ scores = scores[None]
+
+ W, H = self.heatmap_size
+ keypoints = keypoints / [W - 1, H - 1] * self.input_size
+
+ return keypoints, scores
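Note (reviewer sketch, not part of the diff): a minimal round trip through the codec above. It assumes the vendored package exposes `UDPHeatmap` via `mmpose.codecs`; the image/heatmap sizes and keypoints are made-up values.

```python
import numpy as np

from mmpose.codecs import UDPHeatmap

# Hypothetical single-instance input: 17 keypoints in a 192x256 (w, h) image.
codec = UDPHeatmap(input_size=(192, 256), heatmap_size=(48, 64), sigma=2.0)

keypoints = (np.random.rand(1, 17, 2) * [192, 256]).astype(np.float32)  # (N, K, D)
keypoints_visible = np.ones((1, 17), dtype=np.float32)

encoded = codec.encode(keypoints, keypoints_visible)
print(encoded['heatmaps'].shape)          # (17, 64, 48) for heatmap_type='gaussian'
print(encoded['keypoint_weights'].shape)  # (1, 17)

# Decoding maps heatmap-space maxima back to input-image coordinates.
decoded_kpts, scores = codec.decode(encoded['heatmaps'])
print(decoded_kpts.shape, scores.shape)   # (1, 17, 2), (1, 17)
```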
diff --git a/mmpose/codecs/utils/__init__.py b/mmpose/codecs/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..eaa093f12bc6d69748509721b4b481811583339c
--- /dev/null
+++ b/mmpose/codecs/utils/__init__.py
@@ -0,0 +1,23 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .gaussian_heatmap import (generate_gaussian_heatmaps,
+ generate_udp_gaussian_heatmaps,
+ generate_unbiased_gaussian_heatmaps)
+from .instance_property import (get_diagonal_lengths, get_instance_bbox,
+ get_instance_root)
+from .offset_heatmap import (generate_displacement_heatmap,
+ generate_offset_heatmap)
+from .post_processing import (batch_heatmap_nms, gaussian_blur,
+ gaussian_blur1d, get_heatmap_maximum,
+ get_simcc_maximum, get_simcc_normalized)
+from .refinement import (refine_keypoints, refine_keypoints_dark,
+ refine_keypoints_dark_udp, refine_simcc_dark)
+
+__all__ = [
+ 'generate_gaussian_heatmaps', 'generate_udp_gaussian_heatmaps',
+ 'generate_unbiased_gaussian_heatmaps', 'gaussian_blur',
+ 'get_heatmap_maximum', 'get_simcc_maximum', 'generate_offset_heatmap',
+ 'batch_heatmap_nms', 'refine_keypoints', 'refine_keypoints_dark',
+ 'refine_keypoints_dark_udp', 'generate_displacement_heatmap',
+ 'refine_simcc_dark', 'gaussian_blur1d', 'get_diagonal_lengths',
+ 'get_instance_root', 'get_instance_bbox', 'get_simcc_normalized'
+]
diff --git a/mmpose/codecs/utils/__pycache__/__init__.cpython-38.pyc b/mmpose/codecs/utils/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bf9afad4dbcabd24bbe2146f2dcfc482498dd53e
Binary files /dev/null and b/mmpose/codecs/utils/__pycache__/__init__.cpython-38.pyc differ
diff --git a/mmpose/codecs/utils/__pycache__/gaussian_heatmap.cpython-38.pyc b/mmpose/codecs/utils/__pycache__/gaussian_heatmap.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..851accf31e4c11ea2a4f7fa703f01d8837b1880b
Binary files /dev/null and b/mmpose/codecs/utils/__pycache__/gaussian_heatmap.cpython-38.pyc differ
diff --git a/mmpose/codecs/utils/__pycache__/instance_property.cpython-38.pyc b/mmpose/codecs/utils/__pycache__/instance_property.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0b80bd7e07b2d49c3760825295676502eb4ddf5d
Binary files /dev/null and b/mmpose/codecs/utils/__pycache__/instance_property.cpython-38.pyc differ
diff --git a/mmpose/codecs/utils/__pycache__/offset_heatmap.cpython-38.pyc b/mmpose/codecs/utils/__pycache__/offset_heatmap.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..efa9e9e323b38f7dad803d265cfac7171517c794
Binary files /dev/null and b/mmpose/codecs/utils/__pycache__/offset_heatmap.cpython-38.pyc differ
diff --git a/mmpose/codecs/utils/__pycache__/post_processing.cpython-38.pyc b/mmpose/codecs/utils/__pycache__/post_processing.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0c358fd906e0a050246a7839ff1e2cbfe9362658
Binary files /dev/null and b/mmpose/codecs/utils/__pycache__/post_processing.cpython-38.pyc differ
diff --git a/mmpose/codecs/utils/__pycache__/refinement.cpython-38.pyc b/mmpose/codecs/utils/__pycache__/refinement.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9c931f875f52229b1749016d40991bcc1d8b0dcb
Binary files /dev/null and b/mmpose/codecs/utils/__pycache__/refinement.cpython-38.pyc differ
diff --git a/mmpose/codecs/utils/gaussian_heatmap.py b/mmpose/codecs/utils/gaussian_heatmap.py
new file mode 100644
index 0000000000000000000000000000000000000000..91e08c2cdd41ddcb751e3afd13555731731a0602
--- /dev/null
+++ b/mmpose/codecs/utils/gaussian_heatmap.py
@@ -0,0 +1,227 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from itertools import product
+from typing import Tuple, Union
+
+import numpy as np
+
+
+def generate_gaussian_heatmaps(
+ heatmap_size: Tuple[int, int],
+ keypoints: np.ndarray,
+ keypoints_visible: np.ndarray,
+ sigma: Union[float, Tuple[float], np.ndarray],
+) -> Tuple[np.ndarray, np.ndarray]:
+ """Generate gaussian heatmaps of keypoints.
+
+ Args:
+ heatmap_size (Tuple[int, int]): Heatmap size in [W, H]
+ keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
+ keypoints_visible (np.ndarray): Keypoint visibilities in shape
+ (N, K)
+ sigma (float or Tuple[float] or np.ndarray): The sigma value(s) of
+ the Gaussian heatmap for each instance. If a single float is
+ given, it will be expanded into a tuple with one value per instance
+
+ Returns:
+ tuple:
+ - heatmaps (np.ndarray): The generated heatmap in shape
+ (K, H, W) where [W, H] is the `heatmap_size`
+ - keypoint_weights (np.ndarray): The target weights in shape
+ (N, K)
+ """
+
+ N, K, _ = keypoints.shape
+ W, H = heatmap_size
+
+ heatmaps = np.zeros((K, H, W), dtype=np.float32)
+ keypoint_weights = keypoints_visible.copy()
+
+ if isinstance(sigma, (int, float)):
+ sigma = (sigma, ) * N
+
+ for n in range(N):
+ # 3-sigma rule
+ radius = sigma[n] * 3
+
+ # xy grid
+ gaussian_size = 2 * radius + 1
+ x = np.arange(0, gaussian_size, 1, dtype=np.float32)
+ y = x[:, None]
+ x0 = y0 = gaussian_size // 2
+
+ for k in range(K):
+ # skip unlabeled keypoints
+ if keypoints_visible[n, k] < 0.5:
+ continue
+
+ # get gaussian center coordinates
+ mu = (keypoints[n, k] + 0.5).astype(np.int64)
+
+ # check that the gaussian has in-bounds part
+ left, top = (mu - radius).astype(np.int64)
+ right, bottom = (mu + radius + 1).astype(np.int64)
+
+ if left >= W or top >= H or right < 0 or bottom < 0:
+ keypoint_weights[n, k] = 0
+ continue
+
+ # The gaussian is not normalized,
+ # we want the center value to equal 1
+ gaussian = np.exp(-((x - x0)**2 + (y - y0)**2) / (2 * sigma[n]**2))
+
+ # valid range in gaussian
+ g_x1 = max(0, -left)
+ g_x2 = min(W, right) - left
+ g_y1 = max(0, -top)
+ g_y2 = min(H, bottom) - top
+
+ # valid range in heatmap
+ h_x1 = max(0, left)
+ h_x2 = min(W, right)
+ h_y1 = max(0, top)
+ h_y2 = min(H, bottom)
+
+ heatmap_region = heatmaps[k, h_y1:h_y2, h_x1:h_x2]
+ gaussian_region = gaussian[g_y1:g_y2, g_x1:g_x2]
+
+ _ = np.maximum(
+ heatmap_region, gaussian_region, out=heatmap_region)
+
+ return heatmaps, keypoint_weights
+
+
+def generate_unbiased_gaussian_heatmaps(
+ heatmap_size: Tuple[int, int],
+ keypoints: np.ndarray,
+ keypoints_visible: np.ndarray,
+ sigma: float,
+) -> Tuple[np.ndarray, np.ndarray]:
+ """Generate gaussian heatmaps of keypoints using `Dark Pose`_.
+
+ Args:
+ heatmap_size (Tuple[int, int]): Heatmap size in [W, H]
+ keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
+ keypoints_visible (np.ndarray): Keypoint visibilities in shape
+ (N, K)
+ sigma (float): The sigma value of the Gaussian heatmap
+
+ Returns:
+ tuple:
+ - heatmaps (np.ndarray): The generated heatmap in shape
+ (K, H, W) where [W, H] is the `heatmap_size`
+ - keypoint_weights (np.ndarray): The target weights in shape
+ (N, K)
+
+ .. _`Dark Pose`: https://arxiv.org/abs/1910.06278
+ """
+
+ N, K, _ = keypoints.shape
+ W, H = heatmap_size
+
+ heatmaps = np.zeros((K, H, W), dtype=np.float32)
+ keypoint_weights = keypoints_visible.copy()
+
+ # 3-sigma rule
+ radius = sigma * 3
+
+ # xy grid
+ x = np.arange(0, W, 1, dtype=np.float32)
+ y = np.arange(0, H, 1, dtype=np.float32)[:, None]
+
+ for n, k in product(range(N), range(K)):
+ # skip unlabeled keypoints
+ if keypoints_visible[n, k] < 0.5:
+ continue
+
+ mu = keypoints[n, k]
+ # check that the gaussian has in-bounds part
+ left, top = mu - radius
+ right, bottom = mu + radius + 1
+
+ if left >= W or top >= H or right < 0 or bottom < 0:
+ keypoint_weights[n, k] = 0
+ continue
+
+ gaussian = np.exp(-((x - mu[0])**2 + (y - mu[1])**2) / (2 * sigma**2))
+
+ _ = np.maximum(gaussian, heatmaps[k], out=heatmaps[k])
+
+ return heatmaps, keypoint_weights
+
+
+def generate_udp_gaussian_heatmaps(
+ heatmap_size: Tuple[int, int],
+ keypoints: np.ndarray,
+ keypoints_visible: np.ndarray,
+ sigma: float,
+) -> Tuple[np.ndarray, np.ndarray]:
+ """Generate gaussian heatmaps of keypoints using `UDP`_.
+
+ Args:
+ heatmap_size (Tuple[int, int]): Heatmap size in [W, H]
+ keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
+ keypoints_visible (np.ndarray): Keypoint visibilities in shape
+ (N, K)
+ sigma (float): The sigma value of the Gaussian heatmap
+
+ Returns:
+ tuple:
+ - heatmaps (np.ndarray): The generated heatmap in shape
+ (K, H, W) where [W, H] is the `heatmap_size`
+ - keypoint_weights (np.ndarray): The target weights in shape
+ (N, K)
+
+ .. _`UDP`: https://arxiv.org/abs/1911.07524
+ """
+
+ N, K, _ = keypoints.shape
+ W, H = heatmap_size
+
+ heatmaps = np.zeros((K, H, W), dtype=np.float32)
+ keypoint_weights = keypoints_visible.copy()
+
+ # 3-sigma rule
+ radius = sigma * 3
+
+ # xy grid
+ gaussian_size = 2 * radius + 1
+ x = np.arange(0, gaussian_size, 1, dtype=np.float32)
+ y = x[:, None]
+
+ for n, k in product(range(N), range(K)):
+ # skip unlabeled keypoints
+ if keypoints_visible[n, k] < 0.5:
+ continue
+
+ mu = (keypoints[n, k] + 0.5).astype(np.int64)
+ # check that the gaussian has in-bounds part
+ left, top = (mu - radius).astype(np.int64)
+ right, bottom = (mu + radius + 1).astype(np.int64)
+
+ if left >= W or top >= H or right < 0 or bottom < 0:
+ keypoint_weights[n, k] = 0
+ continue
+
+ mu_ac = keypoints[n, k]
+ x0 = y0 = gaussian_size // 2
+ x0 += mu_ac[0] - mu[0]
+ y0 += mu_ac[1] - mu[1]
+ gaussian = np.exp(-((x - x0)**2 + (y - y0)**2) / (2 * sigma**2))
+
+ # valid range in gaussian
+ g_x1 = max(0, -left)
+ g_x2 = min(W, right) - left
+ g_y1 = max(0, -top)
+ g_y2 = min(H, bottom) - top
+
+ # valid range in heatmap
+ h_x1 = max(0, left)
+ h_x2 = min(W, right)
+ h_y1 = max(0, top)
+ h_y2 = min(H, bottom)
+
+ heatmap_region = heatmaps[k, h_y1:h_y2, h_x1:h_x2]
+ gaussian_region = gaussian[g_y1:g_y2, g_x1:g_x2]
+
+ _ = np.maximum(heatmap_region, gaussian_region, out=heatmap_region)
+
+ return heatmaps, keypoint_weights
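Reviewer sketch (not part of the diff): a quick shape check of the two Gaussian generators above, using toy keypoints. The UDP variant centers the Gaussian on the exact sub-pixel keypoint rather than the rounded grid point.

```python
import numpy as np

from mmpose.codecs.utils import (generate_gaussian_heatmaps,
                                 generate_udp_gaussian_heatmaps)

# One instance, two keypoints on a 48x64 (W, H) heatmap grid.
keypoints = np.array([[[10.3, 20.7], [40.0, 5.0]]], dtype=np.float32)  # (1, 2, 2)
visible = np.ones((1, 2), dtype=np.float32)

hm, w = generate_gaussian_heatmaps((48, 64), keypoints, visible, sigma=2.0)
hm_udp, w_udp = generate_udp_gaussian_heatmaps((48, 64), keypoints, visible, sigma=2.0)

print(hm.shape, hm_udp.shape)   # (2, 64, 48) for both: (K, H, W)
# The UDP peak can sit on a slightly different grid cell than the
# rounded-center version because of the sub-pixel offset it applies.
print(np.unravel_index(hm.argmax(), hm.shape),
      np.unravel_index(hm_udp.argmax(), hm_udp.shape))
```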
diff --git a/mmpose/codecs/utils/instance_property.py b/mmpose/codecs/utils/instance_property.py
new file mode 100644
index 0000000000000000000000000000000000000000..15ae30aef021939e2f0dbf276ce8b1c3cceaa40e
--- /dev/null
+++ b/mmpose/codecs/utils/instance_property.py
@@ -0,0 +1,111 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional, Tuple
+
+import numpy as np
+
+
+def get_instance_root(keypoints: np.ndarray,
+ keypoints_visible: Optional[np.ndarray] = None,
+ root_type: str = 'kpt_center') -> Tuple[np.ndarray, np.ndarray]:
+ """Calculate the coordinates and visibility of instance roots.
+
+ Args:
+ keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
+ keypoints_visible (np.ndarray): Keypoint visibilities in shape
+ (N, K)
+ root_type (str): Calculation of instance roots which should
+ be one of the following options:
+
+ - ``'kpt_center'``: The roots' coordinates are the mean
+ coordinates of visible keypoints
+ - ``'bbox_center'``: The roots' are the center of bounding
+ boxes outlined by visible keypoints
+
+ Defaults to ``'kpt_center'``
+
+ Returns:
+ tuple:
+ - roots_coordinate (np.ndarray): Coordinates of instance roots in
+ shape [N, D]
+ - roots_visible (np.ndarray): Visibility of instance roots in
+ shape [N]
+ """
+
+ roots_coordinate = np.zeros((keypoints.shape[0], 2), dtype=np.float32)
+ roots_visible = np.ones((keypoints.shape[0]), dtype=np.float32) * 2
+
+ for i in range(keypoints.shape[0]):
+
+ # collect visible keypoints
+ if keypoints_visible is not None:
+ visible_keypoints = keypoints[i][keypoints_visible[i] > 0]
+ else:
+ visible_keypoints = keypoints[i]
+ if visible_keypoints.size == 0:
+ roots_visible[i] = 0
+ continue
+
+ # compute the instance root with visible keypoints
+ if root_type == 'kpt_center':
+ roots_coordinate[i] = visible_keypoints.mean(axis=0)
+ roots_visible[i] = 1
+ elif root_type == 'bbox_center':
+ roots_coordinate[i] = (visible_keypoints.max(axis=0) +
+ visible_keypoints.min(axis=0)) / 2.0
+ roots_visible[i] = 1
+ else:
+ raise ValueError(
+ f'the value of `root_type` must be \'kpt_center\' or '
+ f'\'bbox_center\', but got \'{root_type}\'')
+
+ return roots_coordinate, roots_visible
+
+
+def get_instance_bbox(keypoints: np.ndarray,
+ keypoints_visible: Optional[np.ndarray] = None
+ ) -> np.ndarray:
+ """Calculate the pseudo instance bounding box from visible keypoints. The
+ bounding boxes are in the xyxy format.
+
+ Args:
+ keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
+ keypoints_visible (np.ndarray): Keypoint visibilities in shape
+ (N, K)
+
+ Returns:
+ np.ndarray: bounding boxes in [N, 4]
+ """
+ bbox = np.zeros((keypoints.shape[0], 4), dtype=np.float32)
+ for i in range(keypoints.shape[0]):
+ if keypoints_visible is not None:
+ visible_keypoints = keypoints[i][keypoints_visible[i] > 0]
+ else:
+ visible_keypoints = keypoints[i]
+ if visible_keypoints.size == 0:
+ continue
+
+ bbox[i, :2] = visible_keypoints.min(axis=0)
+ bbox[i, 2:] = visible_keypoints.max(axis=0)
+ return bbox
+
+
+def get_diagonal_lengths(keypoints: np.ndarray,
+ keypoints_visible: Optional[np.ndarray] = None
+ ) -> np.ndarray:
+ """Calculate the diagonal length of instance bounding box from visible
+ keypoints.
+
+ Args:
+ keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
+ keypoints_visible (np.ndarray): Keypoint visibilities in shape
+ (N, K)
+
+ Returns:
+ np.ndarray: bounding box diagonal length in [N]
+ """
+ pseudo_bbox = get_instance_bbox(keypoints, keypoints_visible)
+ pseudo_bbox = pseudo_bbox.reshape(-1, 2, 2)
+ h_w_diff = pseudo_bbox[:, 1] - pseudo_bbox[:, 0]
+ diagonal_length = np.sqrt(np.power(h_w_diff, 2).sum(axis=1))
+
+ return diagonal_length
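Reviewer sketch (not part of the diff): the three helpers above on two toy instances, the second of which has an unlabeled keypoint.

```python
import numpy as np

from mmpose.codecs.utils import (get_diagonal_lengths, get_instance_bbox,
                                 get_instance_root)

# Two toy instances with three 2D keypoints each; the last keypoint of the
# second instance is unlabeled (visibility 0).
keypoints = np.array([[[0., 0.], [4., 0.], [4., 3.]],
                      [[10., 10.], [12., 14.], [0., 0.]]], dtype=np.float32)
visible = np.array([[1., 1., 1.],
                    [1., 1., 0.]], dtype=np.float32)

roots, roots_visible = get_instance_root(keypoints, visible, root_type='kpt_center')
bboxes = get_instance_bbox(keypoints, visible)        # xyxy per instance
diagonals = get_diagonal_lengths(keypoints, visible)  # bbox diagonal per instance

print(roots)      # mean of visible keypoints, e.g. [2.667, 1.0] for instance 0
print(bboxes)     # [[0, 0, 4, 3], [10, 10, 12, 14]]
print(diagonals)  # [5.0, ~4.47]
```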
diff --git a/mmpose/codecs/utils/offset_heatmap.py b/mmpose/codecs/utils/offset_heatmap.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3c1c32ed391982fa0f8cd31b6240363b4fe1c52
--- /dev/null
+++ b/mmpose/codecs/utils/offset_heatmap.py
@@ -0,0 +1,143 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from itertools import product
+from typing import Tuple
+
+import numpy as np
+
+
+def generate_offset_heatmap(
+ heatmap_size: Tuple[int, int],
+ keypoints: np.ndarray,
+ keypoints_visible: np.ndarray,
+ radius_factor: float,
+) -> Tuple[np.ndarray, np.ndarray]:
+ """Generate offset heatmaps of keypoints, where each keypoint is
+ represented by 3 maps: one pixel-level class label map (1 for keypoint and
+ 0 for non-keypoint) and 2 pixel-level offset maps for x and y directions
+ respectively.
+
+ Args:
+ heatmap_size (Tuple[int, int]): Heatmap size in [W, H]
+ keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
+ keypoints_visible (np.ndarray): Keypoint visibilities in shape
+ (N, K)
+ radius_factor (float): The radius factor of the binary label
+ map. The positive region is defined as the neighbor of the
+ keypoint with the radius :math:`r=radius_factor*max(W, H)`
+
+ Returns:
+ tuple:
+ - heatmap (np.ndarray): The generated heatmap in shape
+ (K*3, H, W) where [W, H] is the `heatmap_size`
+ - keypoint_weights (np.ndarray): The target weights in shape
+ (N, K)
+ """
+
+ N, K, _ = keypoints.shape
+ W, H = heatmap_size
+
+ heatmaps = np.zeros((K, 3, H, W), dtype=np.float32)
+ keypoint_weights = keypoints_visible.copy()
+
+ # xy grid
+ x = np.arange(0, W, 1)
+ y = np.arange(0, H, 1)[:, None]
+
+ # positive area radius in the classification map
+ radius = radius_factor * max(W, H)
+
+ for n, k in product(range(N), range(K)):
+ if keypoints_visible[n, k] < 0.5:
+ continue
+
+ mu = keypoints[n, k]
+
+ x_offset = (mu[0] - x) / radius
+ y_offset = (mu[1] - y) / radius
+
+ heatmaps[k, 0] = np.where(x_offset**2 + y_offset**2 <= 1, 1., 0.)
+ heatmaps[k, 1] = x_offset
+ heatmaps[k, 2] = y_offset
+
+ heatmaps = heatmaps.reshape(K * 3, H, W)
+
+ return heatmaps, keypoint_weights
+
+
+def generate_displacement_heatmap(
+ heatmap_size: Tuple[int, int],
+ keypoints: np.ndarray,
+ keypoints_visible: np.ndarray,
+ roots: np.ndarray,
+ roots_visible: np.ndarray,
+ diagonal_lengths: np.ndarray,
+ radius: float,
+):
+ """Generate displacement heatmaps of keypoints, where each keypoint is
+ represented by 3 maps: one pixel-level class label map (1 for keypoint and
+ 0 for non-keypoint) and 2 pixel-level offset maps for x and y directions
+ respectively.
+
+ Args:
+ heatmap_size (Tuple[int, int]): Heatmap size in [W, H]
+ keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
+ keypoints_visible (np.ndarray): Keypoint visibilities in shape
+ (N, K)
+ roots (np.ndarray): Coordinates of instance centers in shape (N, D).
+ The displacement fields of each instance will locate around its
+ center.
+ roots_visible (np.ndarray): Roots visibilities in shape (N,)
+ diagonal_lengths (np.ndarray): Diagonal length of the bounding boxes
+ of each instance in shape (N,)
+ radius (float): The radius (in heatmap pixels) of the region around
+ each instance root within which the displacement values are
+ generated
+
+ Returns:
+ tuple:
+ - displacements (np.ndarray): The generated displacement map in
+ shape (K*2, H, W) where [W, H] is the `heatmap_size`
+ - displacement_weights (np.ndarray): The target weights in shape
+ (K*2, H, W)
+ """
+ N, K, _ = keypoints.shape
+ W, H = heatmap_size
+
+ displacements = np.zeros((K * 2, H, W), dtype=np.float32)
+ displacement_weights = np.zeros((K * 2, H, W), dtype=np.float32)
+ instance_size_map = np.zeros((H, W), dtype=np.float32)
+
+ for n in range(N):
+ if (roots_visible[n] < 1 or (roots[n, 0] < 0 or roots[n, 1] < 0)
+ or (roots[n, 0] >= W or roots[n, 1] >= H)):
+ continue
+
+ diagonal_length = diagonal_lengths[n]
+
+ for k in range(K):
+ if keypoints_visible[n, k] < 1 or keypoints[n, k, 0] < 0 \
+ or keypoints[n, k, 1] < 0 or keypoints[n, k, 0] >= W \
+ or keypoints[n, k, 1] >= H:
+ continue
+
+ start_x = max(int(roots[n, 0] - radius), 0)
+ start_y = max(int(roots[n, 1] - radius), 0)
+ end_x = min(int(roots[n, 0] + radius), W)
+ end_y = min(int(roots[n, 1] + radius), H)
+
+ for x in range(start_x, end_x):
+ for y in range(start_y, end_y):
+ if displacements[2 * k, y,
+ x] != 0 or displacements[2 * k + 1, y,
+ x] != 0:
+ if diagonal_length > instance_size_map[y, x]:
+ # keep the gt displacement of smaller instance
+ continue
+
+ displacement_weights[2 * k:2 * k + 2, y,
+ x] = 1 / diagonal_length
+ displacements[2 * k:2 * k + 2, y,
+ x] = keypoints[n, k] - [x, y]
+ instance_size_map[y, x] = diagonal_length
+
+ return displacements, displacement_weights
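Reviewer sketch (not part of the diff): `generate_offset_heatmap` packs three channels per keypoint (class label, x offset, y offset); the toy call below just verifies the layout. The values are made up.

```python
import numpy as np

from mmpose.codecs.utils import generate_offset_heatmap

keypoints = np.array([[[24.0, 32.0]]], dtype=np.float32)  # (N=1, K=1, D=2)
visible = np.ones((1, 1), dtype=np.float32)

heatmaps, weights = generate_offset_heatmap(
    heatmap_size=(48, 64), keypoints=keypoints, keypoints_visible=visible,
    radius_factor=0.0546875)

# Each keypoint contributes 3 channels: class label, x offset, y offset.
print(heatmaps.shape)        # (3, 64, 48)  ->  (K * 3, H, W)
label_map = heatmaps[0]
print(label_map[32, 24])     # 1.0 at the keypoint location
print(label_map.sum())       # roughly pi * r**2 positive pixels, r = 0.0546875 * 64
```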
diff --git a/mmpose/codecs/utils/post_processing.py b/mmpose/codecs/utils/post_processing.py
new file mode 100644
index 0000000000000000000000000000000000000000..75356388dc408d8dda0a72324aa16c3b4f3b6068
--- /dev/null
+++ b/mmpose/codecs/utils/post_processing.py
@@ -0,0 +1,227 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from itertools import product
+from typing import Tuple
+
+import cv2
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+
+
+def get_simcc_normalized(batch_pred_simcc, sigma=None):
+ """Normalize the predicted SimCC.
+
+ Args:
+ batch_pred_simcc (torch.Tensor): The predicted SimCC.
+ sigma (float): The sigma of the Gaussian distribution.
+
+ Returns:
+ torch.Tensor: The normalized SimCC.
+ """
+ B, K, _ = batch_pred_simcc.shape
+
+ # Scale and clamp the tensor
+ if sigma is not None:
+ batch_pred_simcc = batch_pred_simcc / (sigma * np.sqrt(np.pi * 2))
+ batch_pred_simcc = batch_pred_simcc.clamp(min=0)
+
+ # Compute the binary mask
+ mask = (batch_pred_simcc.amax(dim=-1) > 1).reshape(B, K, 1)
+
+ # Normalize the tensor using the maximum value
+ norm = (batch_pred_simcc / batch_pred_simcc.amax(dim=-1).reshape(B, K, 1))
+
+ # Apply normalization
+ batch_pred_simcc = torch.where(mask, norm, batch_pred_simcc)
+
+ return batch_pred_simcc
+
+
+def get_simcc_maximum(simcc_x: np.ndarray,
+ simcc_y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+ """Get maximum response location and value from simcc representations.
+
+ Note:
+ instance number: N
+ num_keypoints: K
+ heatmap height: H
+ heatmap width: W
+
+ Args:
+ simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx)
+ simcc_y (np.ndarray): y-axis SimCC in shape (K, Wy) or (N, K, Wy)
+
+ Returns:
+ tuple:
+ - locs (np.ndarray): locations of maximum heatmap responses in shape
+ (K, 2) or (N, K, 2)
+ - vals (np.ndarray): values of maximum heatmap responses in shape
+ (K,) or (N, K)
+ """
+
+ assert isinstance(simcc_x, np.ndarray), ('simcc_x should be numpy.ndarray')
+ assert isinstance(simcc_y, np.ndarray), ('simcc_y should be numpy.ndarray')
+ assert simcc_x.ndim == 2 or simcc_x.ndim == 3, (
+ f'Invalid shape {simcc_x.shape}')
+ assert simcc_y.ndim == 2 or simcc_y.ndim == 3, (
+ f'Invalid shape {simcc_y.shape}')
+ assert simcc_x.ndim == simcc_y.ndim, (
+ f'{simcc_x.shape} != {simcc_y.shape}')
+
+ if simcc_x.ndim == 3:
+ N, K, Wx = simcc_x.shape
+ simcc_x = simcc_x.reshape(N * K, -1)
+ simcc_y = simcc_y.reshape(N * K, -1)
+ else:
+ N = None
+
+ x_locs = np.argmax(simcc_x, axis=1)
+ y_locs = np.argmax(simcc_y, axis=1)
+ locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32)
+ max_val_x = np.amax(simcc_x, axis=1)
+ max_val_y = np.amax(simcc_y, axis=1)
+
+ mask = max_val_x > max_val_y
+ max_val_x[mask] = max_val_y[mask]
+ vals = max_val_x
+ locs[vals <= 0.] = -1
+
+ if N:
+ locs = locs.reshape(N, K, 2)
+ vals = vals.reshape(N, K)
+
+ return locs, vals
+
+
+def get_heatmap_maximum(heatmaps: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+ """Get maximum response location and value from heatmaps.
+
+ Note:
+ batch_size: B
+ num_keypoints: K
+ heatmap height: H
+ heatmap width: W
+
+ Args:
+ heatmaps (np.ndarray): Heatmaps in shape (K, H, W) or (B, K, H, W)
+
+ Returns:
+ tuple:
+ - locs (np.ndarray): locations of maximum heatmap responses in shape
+ (K, 2) or (B, K, 2)
+ - vals (np.ndarray): values of maximum heatmap responses in shape
+ (K,) or (B, K)
+ """
+ assert isinstance(heatmaps,
+ np.ndarray), ('heatmaps should be numpy.ndarray')
+ assert heatmaps.ndim == 3 or heatmaps.ndim == 4, (
+ f'Invalid shape {heatmaps.shape}')
+
+ if heatmaps.ndim == 3:
+ K, H, W = heatmaps.shape
+ B = None
+ heatmaps_flatten = heatmaps.reshape(K, -1)
+ else:
+ B, K, H, W = heatmaps.shape
+ heatmaps_flatten = heatmaps.reshape(B * K, -1)
+
+ y_locs, x_locs = np.unravel_index(
+ np.argmax(heatmaps_flatten, axis=1), shape=(H, W))
+ locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32)
+ vals = np.amax(heatmaps_flatten, axis=1)
+ locs[vals <= 0.] = -1
+
+ if B:
+ locs = locs.reshape(B, K, 2)
+ vals = vals.reshape(B, K)
+
+ return locs, vals
+
+
+def gaussian_blur(heatmaps: np.ndarray, kernel: int = 11) -> np.ndarray:
+ """Modulate heatmap distribution with Gaussian.
+
+ Note:
+ - num_keypoints: K
+ - heatmap height: H
+ - heatmap width: W
+
+ Args:
+ heatmaps (np.ndarray[K, H, W]): model predicted heatmaps.
+ kernel (int): Gaussian kernel size (k) for modulation, which should
+ match the heatmap gaussian sigma when training.
+ k=17 for sigma=3 and k=11 for sigma=2.
+
+ Returns:
+ np.ndarray ([K, H, W]): Modulated heatmap distribution.
+ """
+ assert kernel % 2 == 1
+
+ border = (kernel - 1) // 2
+ K, H, W = heatmaps.shape
+
+ for k in range(K):
+ origin_max = np.max(heatmaps[k])
+ dr = np.zeros((H + 2 * border, W + 2 * border), dtype=np.float32)
+ dr[border:-border, border:-border] = heatmaps[k].copy()
+ dr = cv2.GaussianBlur(dr, (kernel, kernel), 0)
+ heatmaps[k] = dr[border:-border, border:-border].copy()
+ heatmaps[k] *= origin_max / np.max(heatmaps[k])
+ return heatmaps
+
+
+def gaussian_blur1d(simcc: np.ndarray, kernel: int = 11) -> np.ndarray:
+ """Modulate simcc distribution with Gaussian.
+
+ Note:
+ - instance number: N
+ - num_keypoints: K
+ - simcc length: Wx
+
+ Args:
+ simcc (np.ndarray[N, K, Wx]): model predicted simcc.
+ kernel (int): Gaussian kernel size (k) for modulation, which should
+ match the simcc gaussian sigma when training.
+ k=17 for sigma=3 and k=11 for sigma=2.
+
+ Returns:
+ np.ndarray ([N, K, Wx]): Modulated simcc distribution.
+ """
+ assert kernel % 2 == 1
+
+ border = (kernel - 1) // 2
+ N, K, Wx = simcc.shape
+
+ for n, k in product(range(N), range(K)):
+ origin_max = np.max(simcc[n, k])
+ dr = np.zeros((1, Wx + 2 * border), dtype=np.float32)
+ dr[0, border:-border] = simcc[n, k].copy()
+ dr = cv2.GaussianBlur(dr, (kernel, 1), 0)
+ simcc[n, k] = dr[0, border:-border].copy()
+ simcc[n, k] *= origin_max / np.max(simcc[n, k])
+ return simcc
+
+
+def batch_heatmap_nms(batch_heatmaps: Tensor, kernel_size: int = 5):
+ """Apply NMS on a batch of heatmaps.
+
+ Args:
+ batch_heatmaps (Tensor): batch heatmaps in shape (B, K, H, W)
+ kernel_size (int): The kernel size of the NMS which should be
+ an odd integer. Defaults to 5
+
+ Returns:
+ Tensor: The batch heatmaps after NMS.
+ """
+
+ assert isinstance(kernel_size, int) and kernel_size % 2 == 1, \
+ f'The kernel_size should be an odd integer, got {kernel_size}'
+
+ padding = (kernel_size - 1) // 2
+
+ maximum = F.max_pool2d(
+ batch_heatmaps, kernel_size, stride=1, padding=padding)
+ maximum_indicator = torch.eq(batch_heatmaps, maximum)
+ batch_heatmaps = batch_heatmaps * maximum_indicator.float()
+
+ return batch_heatmaps
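Reviewer sketch (not part of the diff): `get_heatmap_maximum` returns (x, y) argmax locations and peak values, and `batch_heatmap_nms` keeps only pixels that are local maxima within a sliding window. Toy peaks below; nothing here comes from real model output.

```python
import numpy as np
import torch

from mmpose.codecs.utils import batch_heatmap_nms, get_heatmap_maximum

# A toy heatmap with a single peak per keypoint channel.
heatmaps = np.zeros((2, 64, 48), dtype=np.float32)
heatmaps[0, 20, 10] = 0.9
heatmaps[1, 5, 30] = 0.7

locs, vals = get_heatmap_maximum(heatmaps)
print(locs)   # [[10., 20.], [30., 5.]]  -> (x, y) per keypoint
print(vals)   # [0.9, 0.7]

# batch_heatmap_nms expects a Tensor of shape (B, K, H, W).
batch = torch.from_numpy(heatmaps)[None]
nms_out = batch_heatmap_nms(batch, kernel_size=5)
print(int((nms_out > 0).sum()))  # 2 surviving peaks
```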
diff --git a/mmpose/codecs/utils/refinement.py b/mmpose/codecs/utils/refinement.py
new file mode 100644
index 0000000000000000000000000000000000000000..3495f37d0acf62f860b76609e4ddfd35737c22d8
--- /dev/null
+++ b/mmpose/codecs/utils/refinement.py
@@ -0,0 +1,215 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from itertools import product
+
+import numpy as np
+
+from .post_processing import gaussian_blur, gaussian_blur1d
+
+
+def refine_keypoints(keypoints: np.ndarray,
+ heatmaps: np.ndarray) -> np.ndarray:
+ """Refine keypoint predictions by moving from the maximum towards the
+ second maximum by 0.25 pixel. The operation is in-place.
+
+ Note:
+
+ - instance number: N
+ - keypoint number: K
+ - keypoint dimension: D
+ - heatmap size: [W, H]
+
+ Args:
+ keypoints (np.ndarray): The keypoint coordinates in shape (N, K, D)
+ heatmaps (np.ndarray): The heatmaps in shape (K, H, W)
+
+ Returns:
+ np.ndarray: Refined keypoint coordinates in shape (N, K, D)
+ """
+ N, K = keypoints.shape[:2]
+ H, W = heatmaps.shape[1:]
+
+ for n, k in product(range(N), range(K)):
+ x, y = keypoints[n, k, :2].astype(int)
+
+ if 1 < x < W - 1 and 0 < y < H:
+ dx = heatmaps[k, y, x + 1] - heatmaps[k, y, x - 1]
+ else:
+ dx = 0.
+
+ if 1 < y < H - 1 and 0 < x < W:
+ dy = heatmaps[k, y + 1, x] - heatmaps[k, y - 1, x]
+ else:
+ dy = 0.
+
+ keypoints[n, k] += np.sign([dx, dy], dtype=np.float32) * 0.25
+
+ return keypoints
+
+
+def refine_keypoints_dark(keypoints: np.ndarray, heatmaps: np.ndarray,
+ blur_kernel_size: int) -> np.ndarray:
+ """Refine keypoint predictions using distribution aware coordinate
+ decoding. See `Dark Pose`_ for details. The operation is in-place.
+
+ Note:
+
+ - instance number: N
+ - keypoint number: K
+ - keypoint dimension: D
+ - heatmap size: [W, H]
+
+ Args:
+ keypoints (np.ndarray): The keypoint coordinates in shape (N, K, D)
+ heatmaps (np.ndarray): The heatmaps in shape (K, H, W)
+ blur_kernel_size (int): The Gaussian blur kernel size of the heatmap
+ modulation
+
+ Returns:
+ np.ndarray: Refined keypoint coordinates in shape (N, K, D)
+
+ .. _`Dark Pose`: https://arxiv.org/abs/1910.06278
+ """
+ N, K = keypoints.shape[:2]
+ H, W = heatmaps.shape[1:]
+
+ # modulate heatmaps
+ heatmaps = gaussian_blur(heatmaps, blur_kernel_size)
+ np.maximum(heatmaps, 1e-10, heatmaps)
+ np.log(heatmaps, heatmaps)
+
+ for n, k in product(range(N), range(K)):
+ x, y = keypoints[n, k, :2].astype(int)
+ if 1 < x < W - 2 and 1 < y < H - 2:
+ dx = 0.5 * (heatmaps[k, y, x + 1] - heatmaps[k, y, x - 1])
+ dy = 0.5 * (heatmaps[k, y + 1, x] - heatmaps[k, y - 1, x])
+
+ dxx = 0.25 * (
+ heatmaps[k, y, x + 2] - 2 * heatmaps[k, y, x] +
+ heatmaps[k, y, x - 2])
+ dxy = 0.25 * (
+ heatmaps[k, y + 1, x + 1] - heatmaps[k, y - 1, x + 1] -
+ heatmaps[k, y + 1, x - 1] + heatmaps[k, y - 1, x - 1])
+ dyy = 0.25 * (
+ heatmaps[k, y + 2, x] - 2 * heatmaps[k, y, x] +
+ heatmaps[k, y - 2, x])
+ derivative = np.array([[dx], [dy]])
+ hessian = np.array([[dxx, dxy], [dxy, dyy]])
+ if dxx * dyy - dxy**2 != 0:
+ hessianinv = np.linalg.inv(hessian)
+ offset = -hessianinv @ derivative
+ offset = np.squeeze(np.array(offset.T), axis=0)
+ keypoints[n, k, :2] += offset
+ return keypoints
+
+
+def refine_keypoints_dark_udp(keypoints: np.ndarray, heatmaps: np.ndarray,
+ blur_kernel_size: int) -> np.ndarray:
+ """Refine keypoint predictions using distribution aware coordinate decoding
+ for UDP. See `UDP`_ for details. The operation is in-place.
+
+ Note:
+
+ - instance number: N
+ - keypoint number: K
+ - keypoint dimension: D
+ - heatmap size: [W, H]
+
+ Args:
+ keypoints (np.ndarray): The keypoint coordinates in shape (N, K, D)
+ heatmaps (np.ndarray): The heatmaps in shape (K, H, W)
+ blur_kernel_size (int): The Gaussian blur kernel size of the heatmap
+ modulation
+
+ Returns:
+ np.ndarray: Refined keypoint coordinates in shape (N, K, D)
+
+ .. _`UDP`: https://arxiv.org/abs/1911.07524
+ """
+ N, K = keypoints.shape[:2]
+ H, W = heatmaps.shape[1:]
+
+ # modulate heatmaps
+ heatmaps = gaussian_blur(heatmaps, blur_kernel_size)
+ np.clip(heatmaps, 1e-3, 50., heatmaps)
+ np.log(heatmaps, heatmaps)
+
+ heatmaps_pad = np.pad(
+ heatmaps, ((0, 0), (1, 1), (1, 1)), mode='edge').flatten()
+
+ for n in range(N):
+ index = keypoints[n, :, 0] + 1 + (keypoints[n, :, 1] + 1) * (W + 2)
+ index += (W + 2) * (H + 2) * np.arange(0, K)
+ index = index.astype(int).reshape(-1, 1)
+ i_ = heatmaps_pad[index]
+ ix1 = heatmaps_pad[index + 1]
+ iy1 = heatmaps_pad[index + W + 2]
+ ix1y1 = heatmaps_pad[index + W + 3]
+ ix1_y1_ = heatmaps_pad[index - W - 3]
+ ix1_ = heatmaps_pad[index - 1]
+ iy1_ = heatmaps_pad[index - 2 - W]
+
+ dx = 0.5 * (ix1 - ix1_)
+ dy = 0.5 * (iy1 - iy1_)
+ derivative = np.concatenate([dx, dy], axis=1)
+ derivative = derivative.reshape(K, 2, 1)
+
+ dxx = ix1 - 2 * i_ + ix1_
+ dyy = iy1 - 2 * i_ + iy1_
+ dxy = 0.5 * (ix1y1 - ix1 - iy1 + i_ + i_ - ix1_ - iy1_ + ix1_y1_)
+ hessian = np.concatenate([dxx, dxy, dxy, dyy], axis=1)
+ hessian = hessian.reshape(K, 2, 2)
+ hessian = np.linalg.inv(hessian + np.finfo(np.float32).eps * np.eye(2))
+ keypoints[n] -= np.einsum('imn,ink->imk', hessian,
+ derivative).squeeze()
+
+ return keypoints
+
+
+def refine_simcc_dark(keypoints: np.ndarray, simcc: np.ndarray,
+ blur_kernel_size: int) -> np.ndarray:
+ """SimCC version. Refine keypoint predictions using distribution aware
+ coordinate decoding for UDP. See `UDP`_ for details. The operation is in-
+ place.
+
+ Note:
+
+ - instance number: N
+ - keypoint number: K
+ - keypoint dimension: D
+
+ Args:
+ keypoints (np.ndarray): The keypoint coordinates in shape (N, K, D)
+ simcc (np.ndarray): The heatmaps in shape (N, K, Wx)
+ blur_kernel_size (int): The Gaussian blur kernel size of the heatmap
+ modulation
+
+ Returns:
+ np.ndarray: Refined keypoint coordinates in shape (N, K, D)
+
+ .. _`UDP`: https://arxiv.org/abs/1911.07524
+ """
+ N = simcc.shape[0]
+
+ # modulate simcc
+ simcc = gaussian_blur1d(simcc, blur_kernel_size)
+ np.clip(simcc, 1e-3, 50., simcc)
+ np.log(simcc, simcc)
+
+ simcc = np.pad(simcc, ((0, 0), (0, 0), (2, 2)), 'edge')
+
+ for n in range(N):
+ px = (keypoints[n] + 2.5).astype(np.int64).reshape(-1, 1) # K, 1
+
+ dx0 = np.take_along_axis(simcc[n], px, axis=1) # K, 1
+ dx1 = np.take_along_axis(simcc[n], px + 1, axis=1)
+ dx_1 = np.take_along_axis(simcc[n], px - 1, axis=1)
+ dx2 = np.take_along_axis(simcc[n], px + 2, axis=1)
+ dx_2 = np.take_along_axis(simcc[n], px - 2, axis=1)
+
+ dx = 0.5 * (dx1 - dx_1)
+ dxx = 1e-9 + 0.25 * (dx2 - 2 * dx0 + dx_2)
+
+ offset = dx / dxx
+ keypoints[n] -= offset.reshape(-1)
+
+ return keypoints
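Reviewer sketch (not part of the diff): to see what the Taylor-expansion refinement buys, place a Gaussian peak at a non-integer location, take the integer argmax, and let `refine_keypoints_dark_udp` pull the estimate toward the true sub-pixel position. The keypoint and sizes are my own toy values.

```python
import numpy as np

from mmpose.codecs.utils import (generate_udp_gaussian_heatmaps,
                                 get_heatmap_maximum, refine_keypoints_dark_udp)

true_kpt = np.array([[[23.3, 17.8]]], dtype=np.float32)  # sub-pixel ground truth
visible = np.ones((1, 1), dtype=np.float32)

heatmaps, _ = generate_udp_gaussian_heatmaps((48, 64), true_kpt, visible, sigma=2.0)

coarse, _ = get_heatmap_maximum(heatmaps)  # integer-grid argmax, shape (1, 2)
coarse = coarse[None]                      # add instance dim -> (1, 1, 2)

refined = refine_keypoints_dark_udp(
    coarse.copy(), heatmaps.copy(), blur_kernel_size=11)
print(coarse[0, 0], refined[0, 0])  # refined should be closer to (23.3, 17.8)
```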
diff --git a/mmpose/datasets/__init__.py b/mmpose/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b90a12db4937ffca9ff103b1e5a0c7604de52e0b
--- /dev/null
+++ b/mmpose/datasets/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .builder import build_dataset
+from .dataset_wrappers import CombinedDataset
+from .datasets import * # noqa
+from .samplers import MultiSourceSampler
+from .transforms import * # noqa
+
+__all__ = ['build_dataset', 'CombinedDataset', 'MultiSourceSampler']
diff --git a/mmpose/datasets/__pycache__/__init__.cpython-38.pyc b/mmpose/datasets/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..14ad2c4a8c94b4ef8fbd4f2f555faf4e87ebb6ce
Binary files /dev/null and b/mmpose/datasets/__pycache__/__init__.cpython-38.pyc differ
diff --git a/mmpose/datasets/__pycache__/builder.cpython-38.pyc b/mmpose/datasets/__pycache__/builder.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fd5c0867baf23dc53275361f3623148f3c0c1656
Binary files /dev/null and b/mmpose/datasets/__pycache__/builder.cpython-38.pyc differ
diff --git a/mmpose/datasets/__pycache__/dataset_wrappers.cpython-38.pyc b/mmpose/datasets/__pycache__/dataset_wrappers.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cd82e726c759c40f56013c0a96a63f178bb3e06a
Binary files /dev/null and b/mmpose/datasets/__pycache__/dataset_wrappers.cpython-38.pyc differ
diff --git a/mmpose/datasets/__pycache__/samplers.cpython-38.pyc b/mmpose/datasets/__pycache__/samplers.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b758d8b51ab4fb08b319f36d3542c84f27711646
Binary files /dev/null and b/mmpose/datasets/__pycache__/samplers.cpython-38.pyc differ
diff --git a/mmpose/datasets/builder.py b/mmpose/datasets/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e5a236ff49b70b86149d318cbccdfd5af5a6450
--- /dev/null
+++ b/mmpose/datasets/builder.py
@@ -0,0 +1,90 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import platform
+import random
+
+import numpy as np
+import torch
+from mmengine import build_from_cfg, is_seq_of
+from mmengine.dataset import ConcatDataset, RepeatDataset
+
+from mmpose.registry import DATASETS
+
+if platform.system() != 'Windows':
+ # https://github.com/pytorch/pytorch/issues/973
+ import resource
+ rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
+ base_soft_limit = rlimit[0]
+ hard_limit = rlimit[1]
+ soft_limit = min(max(4096, base_soft_limit), hard_limit)
+ resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit))
+
+
+def _concat_dataset(cfg, default_args=None):
+ types = cfg['type']
+ ann_files = cfg['ann_file']
+ img_prefixes = cfg.get('img_prefix', None)
+ dataset_infos = cfg.get('dataset_info', None)
+
+ num_joints = cfg['data_cfg'].get('num_joints', None)
+ dataset_channel = cfg['data_cfg'].get('dataset_channel', None)
+
+ datasets = []
+ num_dset = len(ann_files)
+ for i in range(num_dset):
+ cfg_copy = copy.deepcopy(cfg)
+ cfg_copy['ann_file'] = ann_files[i]
+
+ if isinstance(types, (list, tuple)):
+ cfg_copy['type'] = types[i]
+ if isinstance(img_prefixes, (list, tuple)):
+ cfg_copy['img_prefix'] = img_prefixes[i]
+ if isinstance(dataset_infos, (list, tuple)):
+ cfg_copy['dataset_info'] = dataset_infos[i]
+
+ if isinstance(num_joints, (list, tuple)):
+ cfg_copy['data_cfg']['num_joints'] = num_joints[i]
+
+ if is_seq_of(dataset_channel, list):
+ cfg_copy['data_cfg']['dataset_channel'] = dataset_channel[i]
+
+ datasets.append(build_dataset(cfg_copy, default_args))
+
+ return ConcatDataset(datasets)
+
+
+def build_dataset(cfg, default_args=None):
+ """Build a dataset from config dict.
+
+ Args:
+ cfg (dict): Config dict. It should at least contain the key "type".
+ default_args (dict, optional): Default initialization arguments.
+ Default: None.
+
+ Returns:
+ Dataset: The constructed dataset.
+ """
+
+ if isinstance(cfg, (list, tuple)):
+ dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg])
+ elif cfg['type'] == 'ConcatDataset':
+ dataset = ConcatDataset(
+ [build_dataset(c, default_args) for c in cfg['datasets']])
+ elif cfg['type'] == 'RepeatDataset':
+ dataset = RepeatDataset(
+ build_dataset(cfg['dataset'], default_args), cfg['times'])
+ elif isinstance(cfg.get('ann_file'), (list, tuple)):
+ dataset = _concat_dataset(cfg, default_args)
+ else:
+ dataset = build_from_cfg(cfg, DATASETS, default_args)
+ return dataset
+
+
+def worker_init_fn(worker_id, num_workers, rank, seed):
+ """Init the random seed for various workers."""
+ # The seed of each worker equals to
+ # num_worker * rank + worker_id + user_seed
+ worker_seed = num_workers * rank + worker_id + seed
+ np.random.seed(worker_seed)
+ random.seed(worker_seed)
+ torch.manual_seed(worker_seed)
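Reviewer sketch (not part of the diff): how `build_dataset` might be called from a config. The dataset type, paths and the empty pipeline below are placeholders; the annotation file would have to exist on disk for this to actually build.

```python
from mmpose.datasets import build_dataset

# Placeholder config: a COCO-style dataset wrapped in RepeatDataset.
dataset_cfg = dict(
    type='RepeatDataset',
    times=3,
    dataset=dict(
        type='CocoDataset',
        data_root='data/coco/',
        ann_file='annotations/person_keypoints_train2017.json',
        data_prefix=dict(img='train2017/'),
        pipeline=[],
    ),
)

dataset = build_dataset(dataset_cfg)
print(len(dataset))  # 3x the length of the wrapped dataset
```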
diff --git a/mmpose/datasets/dataset_wrappers.py b/mmpose/datasets/dataset_wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..28eeac9945199c08d0de89b5348511c3caae790d
--- /dev/null
+++ b/mmpose/datasets/dataset_wrappers.py
@@ -0,0 +1,122 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+from copy import deepcopy
+from typing import Any, Callable, List, Tuple, Union
+
+from mmengine.dataset import BaseDataset
+from mmengine.registry import build_from_cfg
+
+from mmpose.registry import DATASETS
+from .datasets.utils import parse_pose_metainfo
+
+
+@DATASETS.register_module()
+class CombinedDataset(BaseDataset):
+ """A wrapper of combined dataset.
+
+ Args:
+ metainfo (dict): The meta information of combined dataset.
+ datasets (list): The configs of datasets to be combined.
+ pipeline (list, optional): Processing pipeline. Defaults to [].
+ """
+
+ def __init__(self,
+ metainfo: dict,
+ datasets: list,
+ pipeline: List[Union[dict, Callable]] = [],
+ **kwargs):
+
+ self.datasets = []
+
+ for cfg in datasets:
+ dataset = build_from_cfg(cfg, DATASETS)
+ self.datasets.append(dataset)
+
+ self._lens = [len(dataset) for dataset in self.datasets]
+ self._len = sum(self._lens)
+
+ super(CombinedDataset, self).__init__(pipeline=pipeline, **kwargs)
+ self._metainfo = parse_pose_metainfo(metainfo)
+
+ @property
+ def metainfo(self):
+ return deepcopy(self._metainfo)
+
+ def __len__(self):
+ return self._len
+
+ def _get_subset_index(self, index: int) -> Tuple[int, int]:
+ """Given a data sample's global index, return the index of the sub-
+ dataset the data sample belongs to, and the local index within that
+ sub-dataset.
+
+ Args:
+ index (int): The global data sample index
+
+ Returns:
+ tuple[int, int]:
+ - subset_index (int): The index of the sub-dataset
+ - local_index (int): The index of the data sample within
+ the sub-dataset
+ """
+ if index >= len(self) or index < -len(self):
+ raise ValueError(
+ f'index({index}) is out of bounds for dataset with '
+ f'length({len(self)}).')
+
+ if index < 0:
+ index = index + len(self)
+
+ subset_index = 0
+ while index >= self._lens[subset_index]:
+ index -= self._lens[subset_index]
+ subset_index += 1
+ return subset_index, index
+
+ def prepare_data(self, idx: int) -> Any:
+ """Get data processed by ``self.pipeline``.The source dataset is
+ depending on the index.
+
+ Args:
+ idx (int): The index of ``data_info``.
+
+ Returns:
+ Any: Depends on ``self.pipeline``.
+ """
+
+ data_info = self.get_data_info(idx)
+
+ return self.pipeline(data_info)
+
+ def get_data_info(self, idx: int) -> dict:
+ """Get annotation by index.
+
+ Args:
+ idx (int): Global index of ``CombinedDataset``.
+ Returns:
+ dict: The idx-th annotation of the datasets.
+ """
+ subset_idx, sample_idx = self._get_subset_index(idx)
+ # Get data sample processed by ``subset.pipeline``
+ data_info = self.datasets[subset_idx][sample_idx]
+
+ # Add metainfo items that are required in the pipeline and the model
+ metainfo_keys = [
+ 'upper_body_ids', 'lower_body_ids', 'flip_pairs',
+ 'dataset_keypoint_weights', 'flip_indices'
+ ]
+
+ for key in metainfo_keys:
+ data_info[key] = deepcopy(self._metainfo[key])
+
+ return data_info
+
+ def full_init(self):
+ """Fully initialize all sub datasets."""
+
+ if self._fully_initialized:
+ return
+
+ for dataset in self.datasets:
+ dataset.full_init()
+ self._fully_initialized = True
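Reviewer sketch (not part of the diff): a hypothetical `CombinedDataset` setup. The sub-dataset types, paths and metainfo file are assumptions for illustration; building them requires the corresponding annotation files on disk.

```python
from mmpose.datasets import CombinedDataset

combined = CombinedDataset(
    metainfo=dict(from_file='configs/_base_/datasets/coco.py'),
    datasets=[
        dict(type='CocoDataset', data_root='data/coco/',
             ann_file='annotations/person_keypoints_train2017.json',
             data_prefix=dict(img='train2017/'), pipeline=[]),
        dict(type='AicDataset', data_root='data/aic/',
             ann_file='annotations/aic_train.json',
             data_prefix=dict(img='train/'), pipeline=[]),
    ],
    pipeline=[],  # shared pipeline applied on top of each sub-dataset's own pipeline
)

print(len(combined))  # sum of the two sub-dataset lengths
sample = combined[0]  # global index 0 is routed to the first sub-dataset
```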
diff --git a/mmpose/datasets/datasets/__init__.py b/mmpose/datasets/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..03a0f493ca197fe2d6004ed0eb91af1d3e693524
--- /dev/null
+++ b/mmpose/datasets/datasets/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .animal import * # noqa: F401, F403
+from .base import * # noqa: F401, F403
+from .body import * # noqa: F401, F403
+from .face import * # noqa: F401, F403
+from .fashion import * # noqa: F401, F403
+from .hand import * # noqa: F401, F403
+from .wholebody import * # noqa: F401, F403
diff --git a/mmpose/datasets/datasets/__pycache__/__init__.cpython-38.pyc b/mmpose/datasets/datasets/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b506b9edca6cc639a78b51e976890787716250e0
Binary files /dev/null and b/mmpose/datasets/datasets/__pycache__/__init__.cpython-38.pyc differ
diff --git a/mmpose/datasets/datasets/__pycache__/utils.cpython-38.pyc b/mmpose/datasets/datasets/__pycache__/utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4e472c408ae95f858609ecc38664dacdaaf88b9e
Binary files /dev/null and b/mmpose/datasets/datasets/__pycache__/utils.cpython-38.pyc differ
diff --git a/mmpose/datasets/datasets/animal/__init__.py b/mmpose/datasets/datasets/animal/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfe9b5938c6c16241727922ef1e4e8c91eb058aa
--- /dev/null
+++ b/mmpose/datasets/datasets/animal/__init__.py
@@ -0,0 +1,14 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .animalpose_dataset import AnimalPoseDataset
+from .ap10k_dataset import AP10KDataset
+from .atrw_dataset import ATRWDataset
+from .fly_dataset import FlyDataset
+from .horse10_dataset import Horse10Dataset
+from .locust_dataset import LocustDataset
+from .macaque_dataset import MacaqueDataset
+from .zebra_dataset import ZebraDataset
+
+__all__ = [
+ 'AnimalPoseDataset', 'AP10KDataset', 'Horse10Dataset', 'MacaqueDataset',
+ 'FlyDataset', 'LocustDataset', 'ZebraDataset', 'ATRWDataset'
+]
diff --git a/mmpose/datasets/datasets/animal/__pycache__/__init__.cpython-38.pyc b/mmpose/datasets/datasets/animal/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..df0a3bdcdade8eea3dab047352a47b0bc82b3acd
Binary files /dev/null and b/mmpose/datasets/datasets/animal/__pycache__/__init__.cpython-38.pyc differ
diff --git a/mmpose/datasets/datasets/animal/__pycache__/animalpose_dataset.cpython-38.pyc b/mmpose/datasets/datasets/animal/__pycache__/animalpose_dataset.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bda9277b9a423ee27ce709ea4369c3170560e978
Binary files /dev/null and b/mmpose/datasets/datasets/animal/__pycache__/animalpose_dataset.cpython-38.pyc differ
diff --git a/mmpose/datasets/datasets/animal/__pycache__/ap10k_dataset.cpython-38.pyc b/mmpose/datasets/datasets/animal/__pycache__/ap10k_dataset.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..879e1a96fbafdb61eb0fb816222efbf05badecc0
Binary files /dev/null and b/mmpose/datasets/datasets/animal/__pycache__/ap10k_dataset.cpython-38.pyc differ
diff --git a/mmpose/datasets/datasets/animal/__pycache__/atrw_dataset.cpython-38.pyc b/mmpose/datasets/datasets/animal/__pycache__/atrw_dataset.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..92149fa8b6a61770d8bca4486ae12467d4c967d5
Binary files /dev/null and b/mmpose/datasets/datasets/animal/__pycache__/atrw_dataset.cpython-38.pyc differ
diff --git a/mmpose/datasets/datasets/animal/__pycache__/fly_dataset.cpython-38.pyc b/mmpose/datasets/datasets/animal/__pycache__/fly_dataset.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..579409003837ce6c27375d159242bfce39ef1a78
Binary files /dev/null and b/mmpose/datasets/datasets/animal/__pycache__/fly_dataset.cpython-38.pyc differ
diff --git a/mmpose/datasets/datasets/animal/__pycache__/horse10_dataset.cpython-38.pyc b/mmpose/datasets/datasets/animal/__pycache__/horse10_dataset.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..79780b41be5a7125fd72e0e629437e2805b39451
Binary files /dev/null and b/mmpose/datasets/datasets/animal/__pycache__/horse10_dataset.cpython-38.pyc differ
diff --git a/mmpose/datasets/datasets/animal/__pycache__/locust_dataset.cpython-38.pyc b/mmpose/datasets/datasets/animal/__pycache__/locust_dataset.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..256924df7a26dc2ba4dabb9336b322bafbc3fee1
Binary files /dev/null and b/mmpose/datasets/datasets/animal/__pycache__/locust_dataset.cpython-38.pyc differ
diff --git a/mmpose/datasets/datasets/animal/__pycache__/macaque_dataset.cpython-38.pyc b/mmpose/datasets/datasets/animal/__pycache__/macaque_dataset.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4efa67a70c8ee2f9b7a1b87f35c5f9642d650f59
Binary files /dev/null and b/mmpose/datasets/datasets/animal/__pycache__/macaque_dataset.cpython-38.pyc differ
diff --git a/mmpose/datasets/datasets/animal/__pycache__/zebra_dataset.cpython-38.pyc b/mmpose/datasets/datasets/animal/__pycache__/zebra_dataset.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..df79f7f8042dd4aefcc3e7c7aace026769cf2396
Binary files /dev/null and b/mmpose/datasets/datasets/animal/__pycache__/zebra_dataset.cpython-38.pyc differ
diff --git a/mmpose/datasets/datasets/animal/animalpose_dataset.py b/mmpose/datasets/datasets/animal/animalpose_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..0279cf9de0907626f2a6686170dc5e99aafa2d9d
--- /dev/null
+++ b/mmpose/datasets/datasets/animal/animalpose_dataset.py
@@ -0,0 +1,75 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmpose.registry import DATASETS
+from ..base import BaseCocoStyleDataset
+
+
+@DATASETS.register_module()
+class AnimalPoseDataset(BaseCocoStyleDataset):
+ """Animal-Pose dataset for animal pose estimation.
+
+ "Cross-domain Adaptation For Animal Pose Estimation" ICCV'2019
+ More details can be found in the `paper
+ `__ .
+
+ Animal-Pose keypoints::
+
+ 0: 'L_Eye',
+ 1: 'R_Eye',
+ 2: 'L_EarBase',
+ 3: 'R_EarBase',
+ 4: 'Nose',
+ 5: 'Throat',
+ 6: 'TailBase',
+ 7: 'Withers',
+ 8: 'L_F_Elbow',
+ 9: 'R_F_Elbow',
+ 10: 'L_B_Elbow',
+ 11: 'R_B_Elbow',
+ 12: 'L_F_Knee',
+ 13: 'R_F_Knee',
+ 14: 'L_B_Knee',
+ 15: 'R_B_Knee',
+ 16: 'L_F_Paw',
+ 17: 'R_F_Paw',
+ 18: 'L_B_Paw',
+ 19: 'R_B_Paw'
+
+ Args:
+ ann_file (str): Annotation file path. Default: ''.
+ bbox_file (str, optional): Detection result file path. If
+ ``bbox_file`` is set, detected bboxes loaded from this file will
+ be used instead of ground-truth bboxes. This setting is only for
+ evaluation, i.e., ignored when ``test_mode`` is ``False``.
+ Default: ``None``.
+ data_mode (str): Specifies the mode of data samples: ``'topdown'`` or
+ ``'bottomup'``. In ``'topdown'`` mode, each data sample contains
+ one instance; while in ``'bottomup'`` mode, each data sample
+ contains all instances in an image. Default: ``'topdown'``
+ metainfo (dict, optional): Meta information for dataset, such as class
+ information. Default: ``None``.
+ data_root (str, optional): The root directory for ``data_prefix`` and
+ ``ann_file``. Default: ``None``.
+ data_prefix (dict, optional): Prefix for training data. Default:
+ ``dict(img=None, ann=None)``.
+ filter_cfg (dict, optional): Config for filter data. Default: `None`.
+ indices (int or Sequence[int], optional): Support using first few
+ data in annotation file to facilitate training/testing on a smaller
+ dataset. Default: ``None`` which means using all ``data_infos``.
+ serialize_data (bool, optional): Whether to hold memory using
+ serialized objects, when enabled, data loader workers can use
+ shared RAM from master process instead of making a copy.
+ Default: ``True``.
+ pipeline (list, optional): Processing pipeline. Default: [].
+ test_mode (bool, optional): ``test_mode=True`` means in test phase.
+ Default: ``False``.
+ lazy_init (bool, optional): Whether to load annotations during
+ instantiation. In some cases, such as visualization, only the meta
+ information of the dataset is needed, so it is not necessary to
+ load the annotation file. ``Basedataset`` can skip loading
+ annotations to save time by setting ``lazy_init=True``. Default: ``False``.
+ max_refetch (int, optional): If ``Basedataset.prepare_data`` gets a
+ ``None`` image, the maximum number of extra cycles to retry for a
+ valid image. Default: 1000.
+ """
+
+ METAINFO: dict = dict(from_file='configs/_base_/datasets/animalpose.py')
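Reviewer sketch (not part of the diff): instantiating the dataset above through the registry. The annotation path is a placeholder and must point at a real Animal-Pose style JSON; the metainfo is assumed to be parsed from the referenced config file.

```python
from mmpose.registry import DATASETS

cfg = dict(
    type='AnimalPoseDataset',
    data_root='data/animalpose/',
    ann_file='annotations/animalpose_train.json',
    data_prefix=dict(img='images/'),
    data_mode='topdown',
    pipeline=[],
)
dataset = DATASETS.build(cfg)
# METAINFO above points at configs/_base_/datasets/animalpose.py, which is
# parsed into dataset.metainfo (keypoint names, flip pairs, skeleton, ...).
print(dataset.metainfo.keys())
```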
diff --git a/mmpose/datasets/datasets/animal/ap10k_dataset.py b/mmpose/datasets/datasets/animal/ap10k_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..de1efbc67f7be55c57532684174442a3f865d5fd
--- /dev/null
+++ b/mmpose/datasets/datasets/animal/ap10k_dataset.py
@@ -0,0 +1,73 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmpose.registry import DATASETS
+from ..base import BaseCocoStyleDataset
+
+
+@DATASETS.register_module()
+class AP10KDataset(BaseCocoStyleDataset):
+ """AP-10K dataset for animal pose estimation.
+
+ "AP-10K: A Benchmark for Animal Pose Estimation in the Wild"
+ Neurips Dataset Track'2021.
+ More details can be found in the `paper
+ `__ .
+
+ AP-10K keypoints::
+
+ 0: 'L_Eye',
+ 1: 'R_Eye',
+ 2: 'Nose',
+ 3: 'Neck',
+ 4: 'root of tail',
+ 5: 'L_Shoulder',
+ 6: 'L_Elbow',
+ 7: 'L_F_Paw',
+ 8: 'R_Shoulder',
+ 9: 'R_Elbow',
+ 10: 'R_F_Paw',
+ 11: 'L_Hip',
+ 12: 'L_Knee',
+ 13: 'L_B_Paw',
+ 14: 'R_Hip',
+ 15: 'R_Knee',
+ 16: 'R_B_Paw'
+
+ Args:
+ ann_file (str): Annotation file path. Default: ''.
+ bbox_file (str, optional): Detection result file path. If
+ ``bbox_file`` is set, detected bboxes loaded from this file will
+ be used instead of ground-truth bboxes. This setting is only for
+ evaluation, i.e., ignored when ``test_mode`` is ``False``.
+ Default: ``None``.
+ data_mode (str): Specifies the mode of data samples: ``'topdown'`` or
+ ``'bottomup'``. In ``'topdown'`` mode, each data sample contains
+ one instance; while in ``'bottomup'`` mode, each data sample
+ contains all instances in an image. Default: ``'topdown'``
+ metainfo (dict, optional): Meta information for dataset, such as class
+ information. Default: ``None``.
+ data_root (str, optional): The root directory for ``data_prefix`` and
+ ``ann_file``. Default: ``None``.
+ data_prefix (dict, optional): Prefix for training data. Default:
+ ``dict(img=None, ann=None)``.
+ filter_cfg (dict, optional): Config for filter data. Default: `None`.
+ indices (int or Sequence[int], optional): Support using first few
+ data in annotation file to facilitate training/testing on a smaller
+ dataset. Default: ``None`` which means using all ``data_infos``.
+ serialize_data (bool, optional): Whether to hold memory using
+ serialized objects, when enabled, data loader workers can use
+ shared RAM from master process instead of making a copy.
+ Default: ``True``.
+ pipeline (list, optional): Processing pipeline. Default: [].
+ test_mode (bool, optional): ``test_mode=True`` means in test phase.
+ Default: ``False``.
+ lazy_init (bool, optional): Whether to load annotations during
+ instantiation. In some cases, such as visualization, only the meta
+ information of the dataset is needed, so it is not necessary to
+ load the annotation file. ``Basedataset`` can skip loading
+ annotations to save time by setting ``lazy_init=True``. Default: ``False``.
+ max_refetch (int, optional): If ``Basedataset.prepare_data`` gets a
+ ``None`` image, the maximum number of extra cycles to retry for a
+ valid image. Default: 1000.
+ """
+
+ METAINFO: dict = dict(from_file='configs/_base_/datasets/ap10k.py')
diff --git a/mmpose/datasets/datasets/animal/atrw_dataset.py b/mmpose/datasets/datasets/animal/atrw_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..de5b1a09a0510969ea0a6d57c15e5bd13104b99b
--- /dev/null
+++ b/mmpose/datasets/datasets/animal/atrw_dataset.py
@@ -0,0 +1,71 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmpose.registry import DATASETS
+from ..base import BaseCocoStyleDataset
+
+
+@DATASETS.register_module()
+class ATRWDataset(BaseCocoStyleDataset):
+ """ATRW dataset for animal pose estimation.
+
+ "ATRW: A Benchmark for Amur Tiger Re-identification in the Wild"
+ ACM MM'2020.
+ More details can be found in the `paper
+ `__ .
+
+ ATRW keypoints::
+
+ 0: "left_ear",
+ 1: "right_ear",
+ 2: "nose",
+ 3: "right_shoulder",
+ 4: "right_front_paw",
+ 5: "left_shoulder",
+ 6: "left_front_paw",
+ 7: "right_hip",
+ 8: "right_knee",
+ 9: "right_back_paw",
+ 10: "left_hip",
+ 11: "left_knee",
+ 12: "left_back_paw",
+ 13: "tail",
+ 14: "center"
+
+ Args:
+ ann_file (str): Annotation file path. Default: ''.
+ bbox_file (str, optional): Detection result file path. If
+ ``bbox_file`` is set, detected bboxes loaded from this file will
+ be used instead of ground-truth bboxes. This setting is only for
+ evaluation, i.e., ignored when ``test_mode`` is ``False``.
+ Default: ``None``.
+ data_mode (str): Specifies the mode of data samples: ``'topdown'`` or
+ ``'bottomup'``. In ``'topdown'`` mode, each data sample contains
+ one instance; while in ``'bottomup'`` mode, each data sample
+ contains all instances in an image. Default: ``'topdown'``
+ metainfo (dict, optional): Meta information for dataset, such as class
+ information. Default: ``None``.
+ data_root (str, optional): The root directory for ``data_prefix`` and
+ ``ann_file``. Default: ``None``.
+ data_prefix (dict, optional): Prefix for training data. Default:
+ ``dict(img=None, ann=None)``.
+ filter_cfg (dict, optional): Config for filter data. Default: `None`.
+ indices (int or Sequence[int], optional): Support using first few
+ data in annotation file to facilitate training/testing on a smaller
+ dataset. Default: ``None`` which means using all ``data_infos``.
+ serialize_data (bool, optional): Whether to hold memory using
+ serialized objects, when enabled, data loader workers can use
+ shared RAM from master process instead of making a copy.
+ Default: ``True``.
+ pipeline (list, optional): Processing pipeline. Default: [].
+ test_mode (bool, optional): ``test_mode=True`` means in test phase.
+ Default: ``False``.
+ lazy_init (bool, optional): Whether to load annotation during
+ instantiation. In some cases, such as visualization, only the meta
+ information of the dataset is needed, which is not necessary to
+ load annotation file. ``Basedataset`` can skip load annotations to
+ save time by set ``lazy_init=False``. Default: ``False``.
+ max_refetch (int, optional): If ``Basedataset.prepare_data`` get a
+ None img. The maximum extra number of cycles to get a valid
+ image. Default: 1000.
+ """
+
+ METAINFO: dict = dict(from_file='configs/_base_/datasets/atrw.py')
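The ``data_mode`` argument documented above decides whether each sample is a single cropped instance (``'topdown'``) or a whole image with all of its instances (``'bottomup'``). A minimal usage sketch, not part of this diff, showing how one of these classes could be instantiated directly; the paths are hypothetical placeholders and must point at a real ATRW-style annotation file, and with an empty ``pipeline`` the raw ``data_info`` dicts are returned unchanged:

    from mmpose.datasets.datasets.animal.atrw_dataset import ATRWDataset

    # Hypothetical paths for illustration only.
    dataset = ATRWDataset(
        data_root='data/atrw/',
        ann_file='annotations/keypoint_val.json',
        data_prefix=dict(img='images/val/'),
        data_mode='topdown',   # one data sample per annotated instance
        pipeline=[],           # no transforms: raw data_info dicts come back
        test_mode=True)

    print(len(dataset))        # number of valid instances
    sample = dataset[0]        # dict with 'img_path', 'bbox', 'keypoints', ...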
diff --git a/mmpose/datasets/datasets/animal/fly_dataset.py b/mmpose/datasets/datasets/animal/fly_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..b614d9b9f77b1e2eb7f067ea6cfb21d788857554
--- /dev/null
+++ b/mmpose/datasets/datasets/animal/fly_dataset.py
@@ -0,0 +1,88 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmpose.registry import DATASETS
+from ..base import BaseCocoStyleDataset
+
+
+@DATASETS.register_module()
+class FlyDataset(BaseCocoStyleDataset):
+ """FlyDataset for animal pose estimation.
+
+ "Fast animal pose estimation using deep neural networks"
+ Nature methods'2019. More details can be found in the `paper
+ `__ .
+
+ Vinegar Fly keypoints::
+
+ 0: "head",
+ 1: "eyeL",
+ 2: "eyeR",
+ 3: "neck",
+ 4: "thorax",
+ 5: "abdomen",
+ 6: "forelegR1",
+ 7: "forelegR2",
+ 8: "forelegR3",
+ 9: "forelegR4",
+ 10: "midlegR1",
+ 11: "midlegR2",
+ 12: "midlegR3",
+ 13: "midlegR4",
+ 14: "hindlegR1",
+ 15: "hindlegR2",
+ 16: "hindlegR3",
+ 17: "hindlegR4",
+ 18: "forelegL1",
+ 19: "forelegL2",
+ 20: "forelegL3",
+ 21: "forelegL4",
+ 22: "midlegL1",
+ 23: "midlegL2",
+ 24: "midlegL3",
+ 25: "midlegL4",
+ 26: "hindlegL1",
+ 27: "hindlegL2",
+ 28: "hindlegL3",
+ 29: "hindlegL4",
+ 30: "wingL",
+ 31: "wingR"
+
+ Args:
+ ann_file (str): Annotation file path. Default: ''.
+ bbox_file (str, optional): Detection result file path. If
+ ``bbox_file`` is set, detected bboxes loaded from this file will
+ be used instead of ground-truth bboxes. This setting is only for
+ evaluation, i.e., ignored when ``test_mode`` is ``False``.
+ Default: ``None``.
+ data_mode (str): Specifies the mode of data samples: ``'topdown'`` or
+ ``'bottomup'``. In ``'topdown'`` mode, each data sample contains
+ one instance; while in ``'bottomup'`` mode, each data sample
+ contains all instances in an image. Default: ``'topdown'``.
+ metainfo (dict, optional): Meta information for dataset, such as class
+ information. Default: ``None``.
+ data_root (str, optional): The root directory for ``data_prefix`` and
+ ``ann_file``. Default: ``None``.
+ data_prefix (dict, optional): Prefix for training data. Default:
+ ``dict(img=None, ann=None)``.
+ filter_cfg (dict, optional): Config for filtering data. Default: ``None``.
+ indices (int or Sequence[int], optional): Support using the first few
+ data samples in the annotation file to facilitate training/testing on
+ a smaller dataset. Default: ``None``, which means using all
+ ``data_infos``.
+ serialize_data (bool, optional): Whether to hold memory using
+ serialized objects; when enabled, data loader workers can use
+ shared RAM from the master process instead of making a copy.
+ Default: ``True``.
+ pipeline (list, optional): Processing pipeline. Default: [].
+ test_mode (bool, optional): ``test_mode=True`` means the dataset is
+ used in the test phase. Default: ``False``.
+ lazy_init (bool, optional): Whether to load annotations during
+ instantiation. In some cases, such as visualization, only the meta
+ information of the dataset is needed, so it is not necessary to load
+ the annotation file. ``BaseDataset`` can skip loading annotations to
+ save time by setting ``lazy_init=True``. Default: ``False``.
+ max_refetch (int, optional): The maximum number of extra cycles to
+ fetch a valid image if ``BaseDataset.prepare_data`` gets a ``None``
+ image. Default: 1000.
+ """
+
+ METAINFO: dict = dict(from_file='configs/_base_/datasets/fly.py')
diff --git a/mmpose/datasets/datasets/animal/horse10_dataset.py b/mmpose/datasets/datasets/animal/horse10_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c25dba6a705045b731bddd176bf20a46c285764
--- /dev/null
+++ b/mmpose/datasets/datasets/animal/horse10_dataset.py
@@ -0,0 +1,77 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmpose.registry import DATASETS
+from ..base import BaseCocoStyleDataset
+
+
+@DATASETS.register_module()
+class Horse10Dataset(BaseCocoStyleDataset):
+ """Horse10Dataset for animal pose estimation.
+
+ "Pretraining boosts out-of-domain robustness for pose estimation"
+ WACV'2021. More details can be found in the `paper
+ `__ .
+
+ Horse-10 keypoints::
+
+ 0: 'Nose',
+ 1: 'Eye',
+ 2: 'Nearknee',
+ 3: 'Nearfrontfetlock',
+ 4: 'Nearfrontfoot',
+ 5: 'Offknee',
+ 6: 'Offfrontfetlock',
+ 7: 'Offfrontfoot',
+ 8: 'Shoulder',
+ 9: 'Midshoulder',
+ 10: 'Elbow',
+ 11: 'Girth',
+ 12: 'Wither',
+ 13: 'Nearhindhock',
+ 14: 'Nearhindfetlock',
+ 15: 'Nearhindfoot',
+ 16: 'Hip',
+ 17: 'Stifle',
+ 18: 'Offhindhock',
+ 19: 'Offhindfetlock',
+ 20: 'Offhindfoot',
+ 21: 'Ischium'
+
+ Args:
+ ann_file (str): Annotation file path. Default: ''.
+ bbox_file (str, optional): Detection result file path. If
+ ``bbox_file`` is set, detected bboxes loaded from this file will
+ be used instead of ground-truth bboxes. This setting is only for
+ evaluation, i.e., ignored when ``test_mode`` is ``False``.
+ Default: ``None``.
+ data_mode (str): Specifies the mode of data samples: ``'topdown'`` or
+ ``'bottomup'``. In ``'topdown'`` mode, each data sample contains
+ one instance; while in ``'bottomup'`` mode, each data sample
+ contains all instances in an image. Default: ``'topdown'``.
+ metainfo (dict, optional): Meta information for dataset, such as class
+ information. Default: ``None``.
+ data_root (str, optional): The root directory for ``data_prefix`` and
+ ``ann_file``. Default: ``None``.
+ data_prefix (dict, optional): Prefix for training data. Default:
+ ``dict(img=None, ann=None)``.
+ filter_cfg (dict, optional): Config for filtering data. Default: ``None``.
+ indices (int or Sequence[int], optional): Support using the first few
+ data samples in the annotation file to facilitate training/testing on
+ a smaller dataset. Default: ``None``, which means using all
+ ``data_infos``.
+ serialize_data (bool, optional): Whether to hold memory using
+ serialized objects; when enabled, data loader workers can use
+ shared RAM from the master process instead of making a copy.
+ Default: ``True``.
+ pipeline (list, optional): Processing pipeline. Default: [].
+ test_mode (bool, optional): ``test_mode=True`` means the dataset is
+ used in the test phase. Default: ``False``.
+ lazy_init (bool, optional): Whether to load annotations during
+ instantiation. In some cases, such as visualization, only the meta
+ information of the dataset is needed, so it is not necessary to load
+ the annotation file. ``BaseDataset`` can skip loading annotations to
+ save time by setting ``lazy_init=True``. Default: ``False``.
+ max_refetch (int, optional): The maximum number of extra cycles to
+ fetch a valid image if ``BaseDataset.prepare_data`` gets a ``None``
+ image. Default: 1000.
+ """
+
+ METAINFO: dict = dict(from_file='configs/_base_/datasets/horse10.py')
diff --git a/mmpose/datasets/datasets/animal/locust_dataset.py b/mmpose/datasets/datasets/animal/locust_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ada76034db8e9cbc25d68ccd9a430ea62394c74
--- /dev/null
+++ b/mmpose/datasets/datasets/animal/locust_dataset.py
@@ -0,0 +1,140 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+from typing import Optional
+
+import numpy as np
+
+from mmpose.registry import DATASETS
+from ..base import BaseCocoStyleDataset
+
+
+@DATASETS.register_module()
+class LocustDataset(BaseCocoStyleDataset):
+ """LocustDataset for animal pose estimation.
+
+ "DeepPoseKit, a software toolkit for fast and robust animal
+ pose estimation using deep learning" Elife'2019.
+ More details can be found in the `paper
+ `__ .
+
+ Desert Locust keypoints::
+
+ 0: "head",
+ 1: "neck",
+ 2: "thorax",
+ 3: "abdomen1",
+ 4: "abdomen2",
+ 5: "anttipL",
+ 6: "antbaseL",
+ 7: "eyeL",
+ 8: "forelegL1",
+ 9: "forelegL2",
+ 10: "forelegL3",
+ 11: "forelegL4",
+ 12: "midlegL1",
+ 13: "midlegL2",
+ 14: "midlegL3",
+ 15: "midlegL4",
+ 16: "hindlegL1",
+ 17: "hindlegL2",
+ 18: "hindlegL3",
+ 19: "hindlegL4",
+ 20: "anttipR",
+ 21: "antbaseR",
+ 22: "eyeR",
+ 23: "forelegR1",
+ 24: "forelegR2",
+ 25: "forelegR3",
+ 26: "forelegR4",
+ 27: "midlegR1",
+ 28: "midlegR2",
+ 29: "midlegR3",
+ 30: "midlegR4",
+ 31: "hindlegR1",
+ 32: "hindlegR2",
+ 33: "hindlegR3",
+ 34: "hindlegR4"
+
+ Args:
+ ann_file (str): Annotation file path. Default: ''.
+ bbox_file (str, optional): Detection result file path. If
+ ``bbox_file`` is set, detected bboxes loaded from this file will
+ be used instead of ground-truth bboxes. This setting is only for
+ evaluation, i.e., ignored when ``test_mode`` is ``False``.
+ Default: ``None``.
+ data_mode (str): Specifies the mode of data samples: ``'topdown'`` or
+ ``'bottomup'``. In ``'topdown'`` mode, each data sample contains
+ one instance; while in ``'bottomup'`` mode, each data sample
+ contains all instances in an image. Default: ``'topdown'``.
+ metainfo (dict, optional): Meta information for dataset, such as class
+ information. Default: ``None``.
+ data_root (str, optional): The root directory for ``data_prefix`` and
+ ``ann_file``. Default: ``None``.
+ data_prefix (dict, optional): Prefix for training data. Default:
+ ``dict(img=None, ann=None)``.
+ filter_cfg (dict, optional): Config for filtering data. Default: ``None``.
+ indices (int or Sequence[int], optional): Support using the first few
+ data samples in the annotation file to facilitate training/testing on
+ a smaller dataset. Default: ``None``, which means using all
+ ``data_infos``.
+ serialize_data (bool, optional): Whether to hold memory using
+ serialized objects; when enabled, data loader workers can use
+ shared RAM from the master process instead of making a copy.
+ Default: ``True``.
+ pipeline (list, optional): Processing pipeline. Default: [].
+ test_mode (bool, optional): ``test_mode=True`` means the dataset is
+ used in the test phase. Default: ``False``.
+ lazy_init (bool, optional): Whether to load annotations during
+ instantiation. In some cases, such as visualization, only the meta
+ information of the dataset is needed, so it is not necessary to load
+ the annotation file. ``BaseDataset`` can skip loading annotations to
+ save time by setting ``lazy_init=True``. Default: ``False``.
+ max_refetch (int, optional): The maximum number of extra cycles to
+ fetch a valid image if ``BaseDataset.prepare_data`` gets a ``None``
+ image. Default: 1000.
+ """
+
+ METAINFO: dict = dict(from_file='configs/_base_/datasets/locust.py')
+
+ def parse_data_info(self, raw_data_info: dict) -> Optional[dict]:
+ """Parse raw Locust annotation of an instance.
+
+ Args:
+ raw_data_info (dict): Raw data information loaded from
+ ``ann_file``. It should have following contents:
+
+ - ``'raw_ann_info'``: Raw annotation of an instance
+ - ``'raw_img_info'``: Raw information of the image that
+ contains the instance
+
+ Returns:
+ dict: Parsed instance annotation
+ """
+
+ ann = raw_data_info['raw_ann_info']
+ img = raw_data_info['raw_img_info']
+
+ img_path = osp.join(self.data_prefix['img'], img['file_name'])
+
+ # get bbox in shape [1, 4], formatted as xywh
+ # use the entire image which is 160x160
+ bbox = np.array([0, 0, 160, 160], dtype=np.float32).reshape(1, 4)
+
+ # keypoints in shape [1, K, 2] and keypoints_visible in [1, K]
+ _keypoints = np.array(
+ ann['keypoints'], dtype=np.float32).reshape(1, -1, 3)
+ keypoints = _keypoints[..., :2]
+ keypoints_visible = np.minimum(1, _keypoints[..., 2])
+
+ data_info = {
+ 'img_id': ann['image_id'],
+ 'img_path': img_path,
+ 'bbox': bbox,
+ 'bbox_score': np.ones(1, dtype=np.float32),
+ 'num_keypoints': ann['num_keypoints'],
+ 'keypoints': keypoints,
+ 'keypoints_visible': keypoints_visible,
+ 'iscrowd': ann['iscrowd'],
+ 'id': ann['id'],
+ }
+
+ return data_info
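A standalone sketch (not part of this diff) of the keypoint handling in ``parse_data_info`` above: the flat COCO-style ``[x1, y1, v1, x2, y2, v2, ...]`` list is reshaped to ``[1, K, 3]``, split into coordinates and a visibility mask, and the bbox is fixed to the full 160x160 frame used by the DeepPoseKit images. The values below are toy numbers for illustration:

    import numpy as np

    raw = [10.0, 20.0, 2, 0.0, 0.0, 0, 35.5, 40.25, 1]     # toy 3-keypoint annotation
    _keypoints = np.array(raw, dtype=np.float32).reshape(1, -1, 3)
    keypoints = _keypoints[..., :2]                         # shape (1, 3, 2)
    keypoints_visible = np.minimum(1, _keypoints[..., 2])   # v=2 is clipped to 1
    bbox = np.array([0, 0, 160, 160], dtype=np.float32).reshape(1, 4)

    print(keypoints.shape, keypoints_visible)               # (1, 3, 2) [[1. 0. 1.]]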
diff --git a/mmpose/datasets/datasets/animal/macaque_dataset.py b/mmpose/datasets/datasets/animal/macaque_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..08da981a1a2299efaadaf727b3960e769999fc35
--- /dev/null
+++ b/mmpose/datasets/datasets/animal/macaque_dataset.py
@@ -0,0 +1,74 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+from mmpose.registry import DATASETS
+from ..base import BaseCocoStyleDataset
+
+
+@DATASETS.register_module()
+class MacaqueDataset(BaseCocoStyleDataset):
+ """MacaquePose dataset for animal pose estimation.
+
+ "MacaquePose: A novel 'in the wild' macaque monkey pose dataset
+ for markerless motion capture" bioRxiv'2020.
+ More details can be found in the `paper
+ `__ .
+
+ Macaque keypoints::
+
+ 0: 'nose',
+ 1: 'left_eye',
+ 2: 'right_eye',
+ 3: 'left_ear',
+ 4: 'right_ear',
+ 5: 'left_shoulder',
+ 6: 'right_shoulder',
+ 7: 'left_elbow',
+ 8: 'right_elbow',
+ 9: 'left_wrist',
+ 10: 'right_wrist',
+ 11: 'left_hip',
+ 12: 'right_hip',
+ 13: 'left_knee',
+ 14: 'right_knee',
+ 15: 'left_ankle',
+ 16: 'right_ankle'
+
+ Args:
+ ann_file (str): Annotation file path. Default: ''.
+ bbox_file (str, optional): Detection result file path. If
+ ``bbox_file`` is set, detected bboxes loaded from this file will
+ be used instead of ground-truth bboxes. This setting is only for
+ evaluation, i.e., ignored when ``test_mode`` is ``False``.
+ Default: ``None``.
+ data_mode (str): Specifies the mode of data samples: ``'topdown'`` or
+ ``'bottomup'``. In ``'topdown'`` mode, each data sample contains
+ one instance; while in ``'bottomup'`` mode, each data sample
+ contains all instances in an image. Default: ``'topdown'``.
+ metainfo (dict, optional): Meta information for dataset, such as class
+ information. Default: ``None``.
+ data_root (str, optional): The root directory for ``data_prefix`` and
+ ``ann_file``. Default: ``None``.
+ data_prefix (dict, optional): Prefix for training data. Default:
+ ``dict(img=None, ann=None)``.
+ filter_cfg (dict, optional): Config for filtering data. Default: ``None``.
+ indices (int or Sequence[int], optional): Support using the first few
+ data samples in the annotation file to facilitate training/testing on
+ a smaller dataset. Default: ``None``, which means using all
+ ``data_infos``.
+ serialize_data (bool, optional): Whether to hold memory using
+ serialized objects; when enabled, data loader workers can use
+ shared RAM from the master process instead of making a copy.
+ Default: ``True``.
+ pipeline (list, optional): Processing pipeline. Default: [].
+ test_mode (bool, optional): ``test_mode=True`` means the dataset is
+ used in the test phase. Default: ``False``.
+ lazy_init (bool, optional): Whether to load annotations during
+ instantiation. In some cases, such as visualization, only the meta
+ information of the dataset is needed, so it is not necessary to load
+ the annotation file. ``BaseDataset`` can skip loading annotations to
+ save time by setting ``lazy_init=True``. Default: ``False``.
+ max_refetch (int, optional): The maximum number of extra cycles to
+ fetch a valid image if ``BaseDataset.prepare_data`` gets a ``None``
+ image. Default: 1000.
+ """
+
+ METAINFO: dict = dict(from_file='configs/_base_/datasets/macaque.py')
diff --git a/mmpose/datasets/datasets/animal/zebra_dataset.py b/mmpose/datasets/datasets/animal/zebra_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..b399a8479bcf18b8b33115b4cd703563e1a846d3
--- /dev/null
+++ b/mmpose/datasets/datasets/animal/zebra_dataset.py
@@ -0,0 +1,116 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+from typing import Optional
+
+import numpy as np
+
+from mmpose.registry import DATASETS
+from ..base import BaseCocoStyleDataset
+
+
+@DATASETS.register_module()
+class ZebraDataset(BaseCocoStyleDataset):
+ """ZebraDataset for animal pose estimation.
+
+ "DeepPoseKit, a software toolkit for fast and robust animal
+ pose estimation using deep learning" Elife'2019.
+ More details can be found in the `paper
+ `__ .
+
+ Zebra keypoints::
+
+ 0: "snout",
+ 1: "head",
+ 2: "neck",
+ 3: "forelegL1",
+ 4: "forelegR1",
+ 5: "hindlegL1",
+ 6: "hindlegR1",
+ 7: "tailbase",
+ 8: "tailtip"
+
+ Args:
+ ann_file (str): Annotation file path. Default: ''.
+ bbox_file (str, optional): Detection result file path. If
+ ``bbox_file`` is set, detected bboxes loaded from this file will
+ be used instead of ground-truth bboxes. This setting is only for
+ evaluation, i.e., ignored when ``test_mode`` is ``False``.
+ Default: ``None``.
+ data_mode (str): Specifies the mode of data samples: ``'topdown'`` or
+ ``'bottomup'``. In ``'topdown'`` mode, each data sample contains
+ one instance; while in ``'bottomup'`` mode, each data sample
+ contains all instances in an image. Default: ``'topdown'``.
+ metainfo (dict, optional): Meta information for dataset, such as class
+ information. Default: ``None``.
+ data_root (str, optional): The root directory for ``data_prefix`` and
+ ``ann_file``. Default: ``None``.
+ data_prefix (dict, optional): Prefix for training data. Default:
+ ``dict(img=None, ann=None)``.
+ filter_cfg (dict, optional): Config for filtering data. Default: ``None``.
+ indices (int or Sequence[int], optional): Support using the first few
+ data samples in the annotation file to facilitate training/testing on
+ a smaller dataset. Default: ``None``, which means using all
+ ``data_infos``.
+ serialize_data (bool, optional): Whether to hold memory using
+ serialized objects; when enabled, data loader workers can use
+ shared RAM from the master process instead of making a copy.
+ Default: ``True``.
+ pipeline (list, optional): Processing pipeline. Default: [].
+ test_mode (bool, optional): ``test_mode=True`` means the dataset is
+ used in the test phase. Default: ``False``.
+ lazy_init (bool, optional): Whether to load annotations during
+ instantiation. In some cases, such as visualization, only the meta
+ information of the dataset is needed, so it is not necessary to load
+ the annotation file. ``BaseDataset`` can skip loading annotations to
+ save time by setting ``lazy_init=True``. Default: ``False``.
+ max_refetch (int, optional): The maximum number of extra cycles to
+ fetch a valid image if ``BaseDataset.prepare_data`` gets a ``None``
+ image. Default: 1000.
+ """
+
+ METAINFO: dict = dict(from_file='configs/_base_/datasets/zebra.py')
+
+ def parse_data_info(self, raw_data_info: dict) -> Optional[dict]:
+ """Parse raw Zebra annotation of an instance.
+
+ Args:
+ raw_data_info (dict): Raw data information loaded from
+ ``ann_file``. It should have following contents:
+
+ - ``'raw_ann_info'``: Raw annotation of an instance
+ - ``'raw_img_info'``: Raw information of the image that
+ contains the instance
+
+ Returns:
+ dict: Parsed instance annotation
+ """
+
+ ann = raw_data_info['raw_ann_info']
+ img = raw_data_info['raw_img_info']
+
+ img_path = osp.join(self.data_prefix['img'], img['file_name'])
+
+ # get bbox in shape [1, 4], formatted as xywh
+ # use the entire image which is 160x160
+ bbox = np.array([0, 0, 160, 160], dtype=np.float32).reshape(1, 4)
+
+ # keypoints in shape [1, K, 2] and keypoints_visible in [1, K]
+ _keypoints = np.array(
+ ann['keypoints'], dtype=np.float32).reshape(1, -1, 3)
+ keypoints = _keypoints[..., :2]
+ keypoints_visible = np.minimum(1, _keypoints[..., 2])
+
+ num_keypoints = ann['num_keypoints']
+
+ data_info = {
+ 'img_id': ann['image_id'],
+ 'img_path': img_path,
+ 'bbox': bbox,
+ 'bbox_score': np.ones(1, dtype=np.float32),
+ 'num_keypoints': num_keypoints,
+ 'keypoints': keypoints,
+ 'keypoints_visible': keypoints_visible,
+ 'iscrowd': ann['iscrowd'],
+ 'id': ann['id'],
+ }
+
+ return data_info
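Because every class in this diff is decorated with ``@DATASETS.register_module()``, it can also be built from a config dict through the registry rather than imported directly. A hedged sketch with hypothetical paths:

    from mmpose.registry import DATASETS

    cfg = dict(
        type='ZebraDataset',                       # resolved via the registry
        data_root='data/zebra/',                   # hypothetical root
        ann_file='annotations/zebra_test.json',    # hypothetical file
        data_prefix=dict(img='images/'),
        data_mode='topdown',
        pipeline=[],
        test_mode=True)

    dataset = DATASETS.build(cfg)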
diff --git a/mmpose/datasets/datasets/base/__init__.py b/mmpose/datasets/datasets/base/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..23bb4efb48daaf71d4d29e46834251c8da3ebbb9
--- /dev/null
+++ b/mmpose/datasets/datasets/base/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .base_coco_style_dataset import BaseCocoStyleDataset
+
+__all__ = ['BaseCocoStyleDataset']
diff --git a/mmpose/datasets/datasets/base/__pycache__/__init__.cpython-38.pyc b/mmpose/datasets/datasets/base/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ffb67dad34dba321c2ac3f2e884463515bb33473
Binary files /dev/null and b/mmpose/datasets/datasets/base/__pycache__/__init__.cpython-38.pyc differ
diff --git a/mmpose/datasets/datasets/base/__pycache__/base_coco_style_dataset.cpython-38.pyc b/mmpose/datasets/datasets/base/__pycache__/base_coco_style_dataset.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d2860e8f32b3377bd8d2704d6fa334b71420a7e3
Binary files /dev/null and b/mmpose/datasets/datasets/base/__pycache__/base_coco_style_dataset.cpython-38.pyc differ
diff --git a/mmpose/datasets/datasets/base/base_coco_style_dataset.py b/mmpose/datasets/datasets/base/base_coco_style_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b592813d8faa614ef26623421a69b497ba3f982
--- /dev/null
+++ b/mmpose/datasets/datasets/base/base_coco_style_dataset.py
@@ -0,0 +1,458 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import os.path as osp
+from copy import deepcopy
+from itertools import filterfalse, groupby
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
+
+import numpy as np
+from mmengine.dataset import BaseDataset, force_full_init
+from mmengine.fileio import exists, get_local_path, load
+from mmengine.utils import is_list_of
+from xtcocotools.coco import COCO
+
+from mmpose.registry import DATASETS
+from mmpose.structures.bbox import bbox_xywh2xyxy
+from ..utils import parse_pose_metainfo
+
+
+@DATASETS.register_module()
+class BaseCocoStyleDataset(BaseDataset):
+ """Base class for COCO-style datasets.
+
+ Args:
+ ann_file (str): Annotation file path. Default: ''.
+ bbox_file (str, optional): Detection result file path. If
+ ``bbox_file`` is set, detected bboxes loaded from this file will
+ be used instead of ground-truth bboxes. This setting is only for
+ evaluation, i.e., ignored when ``test_mode`` is ``False``.
+ Default: ``None``.
+ data_mode (str): Specifies the mode of data samples: ``'topdown'`` or
+ ``'bottomup'``. In ``'topdown'`` mode, each data sample contains
+ one instance; while in ``'bottomup'`` mode, each data sample
+ contains all instances in an image. Default: ``'topdown'``.
+ metainfo (dict, optional): Meta information for dataset, such as class
+ information. Default: ``None``.
+ data_root (str, optional): The root directory for ``data_prefix`` and
+ ``ann_file``. Default: ``None``.
+ data_prefix (dict, optional): Prefix for training data.
+ Default: ``dict(img='')``.
+ filter_cfg (dict, optional): Config for filtering data. Default: ``None``.
+ indices (int or Sequence[int], optional): Support using the first few
+ data samples in the annotation file to facilitate training/testing on
+ a smaller dataset. Default: ``None``, which means using all
+ ``data_infos``.
+ serialize_data (bool, optional): Whether to hold memory using
+ serialized objects; when enabled, data loader workers can use
+ shared RAM from the master process instead of making a copy.
+ Default: ``True``.
+ pipeline (list, optional): Processing pipeline. Default: [].
+ test_mode (bool, optional): ``test_mode=True`` means the dataset is
+ used in the test phase. Default: ``False``.
+ lazy_init (bool, optional): Whether to load annotations during
+ instantiation. In some cases, such as visualization, only the meta
+ information of the dataset is needed, so it is not necessary to load
+ the annotation file. ``BaseDataset`` can skip loading annotations to
+ save time by setting ``lazy_init=True``. Default: ``False``.
+ max_refetch (int, optional): The maximum number of extra cycles to
+ fetch a valid image if ``BaseDataset.prepare_data`` gets a ``None``
+ image. Default: 1000.
+ """
+
+ METAINFO: dict = dict()
+
+ def __init__(self,
+ ann_file: str = '',
+ bbox_file: Optional[str] = None,
+ data_mode: str = 'topdown',
+ metainfo: Optional[dict] = None,
+ data_root: Optional[str] = None,
+ data_prefix: dict = dict(img=''),
+ filter_cfg: Optional[dict] = None,
+ indices: Optional[Union[int, Sequence[int]]] = None,
+ serialize_data: bool = True,
+ pipeline: List[Union[dict, Callable]] = [],
+ test_mode: bool = False,
+ lazy_init: bool = False,
+ max_refetch: int = 1000):
+
+ if data_mode not in {'topdown', 'bottomup'}:
+ raise ValueError(
+ f'{self.__class__.__name__} got invalid data_mode: '
+ f'{data_mode}. Should be "topdown" or "bottomup".')
+ self.data_mode = data_mode
+
+ if bbox_file:
+ if self.data_mode != 'topdown':
+ raise ValueError(
+ f'{self.__class__.__name__} is set to {self.data_mode} '
+ 'mode, while "bbox_file" is only '
+ 'supported in topdown mode.')
+
+ if not test_mode:
+ raise ValueError(
+ f'{self.__class__.__name__} has `test_mode==False` '
+ 'while "bbox_file" is only '
+ 'supported when `test_mode==True`.')
+ self.bbox_file = bbox_file
+
+ super().__init__(
+ ann_file=ann_file,
+ metainfo=metainfo,
+ data_root=data_root,
+ data_prefix=data_prefix,
+ filter_cfg=filter_cfg,
+ indices=indices,
+ serialize_data=serialize_data,
+ pipeline=pipeline,
+ test_mode=test_mode,
+ lazy_init=lazy_init,
+ max_refetch=max_refetch)
+
+ @classmethod
+ def _load_metainfo(cls, metainfo: dict = None) -> dict:
+ """Collect meta information from the dictionary of meta.
+
+ Args:
+ metainfo (dict): Raw data of pose meta information.
+
+ Returns:
+ dict: Parsed meta information.
+ """
+
+ if metainfo is None:
+ metainfo = deepcopy(cls.METAINFO)
+
+ if not isinstance(metainfo, dict):
+ raise TypeError(
+ f'metainfo should be a dict, but got {type(metainfo)}')
+
+ # parse pose metainfo if it has been assigned
+ if metainfo:
+ metainfo = parse_pose_metainfo(metainfo)
+ return metainfo
+
+ @force_full_init
+ def prepare_data(self, idx) -> Any:
+ """Get data processed by ``self.pipeline``.
+
+ :class:`BaseCocoStyleDataset` overrides this method from
+ :class:`mmengine.dataset.BaseDataset` to add the metainfo into
+ the ``data_info`` before it is passed to the pipeline.
+
+ Args:
+ idx (int): The index of ``data_info``.
+
+ Returns:
+ Any: Depends on ``self.pipeline``.
+ """
+ data_info = self.get_data_info(idx)
+
+ return self.pipeline(data_info)
+
+ def get_data_info(self, idx: int) -> dict:
+ """Get data info by index.
+
+ Args:
+ idx (int): Index of data info.
+
+ Returns:
+ dict: Data info.
+ """
+ data_info = super().get_data_info(idx)
+
+ # Add metainfo items that are required in the pipeline and the model
+ metainfo_keys = [
+ 'upper_body_ids', 'lower_body_ids', 'flip_pairs',
+ 'dataset_keypoint_weights', 'flip_indices', 'skeleton_links'
+ ]
+
+ for key in metainfo_keys:
+ assert key not in data_info, (
+ f'"{key}" is a reserved key for `metainfo`, but already '
+ 'exists in the `data_info`.')
+
+ data_info[key] = deepcopy(self._metainfo[key])
+
+ return data_info
+
+ def load_data_list(self) -> List[dict]:
+ """Load data list from COCO annotation file or person detection result
+ file."""
+
+ if self.bbox_file:
+ data_list = self._load_detection_results()
+ else:
+ instance_list, image_list = self._load_annotations()
+
+ if self.data_mode == 'topdown':
+ data_list = self._get_topdown_data_infos(instance_list)
+ else:
+ data_list = self._get_bottomup_data_infos(
+ instance_list, image_list)
+
+ return data_list
+
+ def _load_annotations(self) -> Tuple[List[dict], List[dict]]:
+ """Load data from annotations in COCO format."""
+
+ assert exists(self.ann_file), 'Annotation file does not exist'
+
+ with get_local_path(self.ann_file) as local_path:
+ self.coco = COCO(local_path)
+ # set the metainfo about categories, which is a list of dict
+ # and each dict contains the 'id', 'name', etc. about this category
+ self._metainfo['CLASSES'] = self.coco.loadCats(self.coco.getCatIds())
+
+ instance_list = []
+ image_list = []
+
+ for img_id in self.coco.getImgIds():
+ img = self.coco.loadImgs(img_id)[0]
+ img.update({
+ 'img_id':
+ img_id,
+ 'img_path':
+ osp.join(self.data_prefix['img'], img['file_name']),
+ })
+ image_list.append(img)
+
+ ann_ids = self.coco.getAnnIds(imgIds=img_id)
+ for ann in self.coco.loadAnns(ann_ids):
+
+ instance_info = self.parse_data_info(
+ dict(raw_ann_info=ann, raw_img_info=img))
+
+ # skip invalid instance annotation.
+ if not instance_info:
+ continue
+
+ instance_list.append(instance_info)
+ return instance_list, image_list
+
+ def parse_data_info(self, raw_data_info: dict) -> Optional[dict]:
+ """Parse raw COCO annotation of an instance.
+
+ Args:
+ raw_data_info (dict): Raw data information loaded from
+ ``ann_file``. It should have following contents:
+
+ - ``'raw_ann_info'``: Raw annotation of an instance
+ - ``'raw_img_info'``: Raw information of the image that
+ contains the instance
+
+ Returns:
+ dict | None: Parsed instance annotation
+ """
+
+ ann = raw_data_info['raw_ann_info']
+ img = raw_data_info['raw_img_info']
+
+ # filter invalid instance
+ if 'bbox' not in ann or 'keypoints' not in ann:
+ return None
+
+ img_w, img_h = img['width'], img['height']
+
+ # get bbox in shape [1, 4], formatted as xywh
+ x, y, w, h = ann['bbox']
+ x1 = np.clip(x, 0, img_w - 1)
+ y1 = np.clip(y, 0, img_h - 1)
+ x2 = np.clip(x + w, 0, img_w - 1)
+ y2 = np.clip(y + h, 0, img_h - 1)
+
+ bbox = np.array([x1, y1, x2, y2], dtype=np.float32).reshape(1, 4)
+
+ # keypoints in shape [1, K, 2] and keypoints_visible in [1, K]
+ _keypoints = np.array(
+ ann['keypoints'], dtype=np.float32).reshape(1, -1, 3)
+ keypoints = _keypoints[..., :2]
+ keypoints_visible = np.minimum(1, _keypoints[..., 2])
+
+ if 'num_keypoints' in ann:
+ num_keypoints = ann['num_keypoints']
+ else:
+ num_keypoints = np.count_nonzero(keypoints.max(axis=2))
+
+ data_info = {
+ 'img_id': ann['image_id'],
+ 'img_path': img['img_path'],
+ 'bbox': bbox,
+ 'bbox_score': np.ones(1, dtype=np.float32),
+ 'num_keypoints': num_keypoints,
+ 'keypoints': keypoints,
+ 'keypoints_visible': keypoints_visible,
+ 'iscrowd': ann.get('iscrowd', 0),
+ 'segmentation': ann.get('segmentation', None),
+ 'id': ann['id'],
+ 'category_id': ann['category_id'],
+ # store the raw annotation of the instance
+ # it is useful for evaluation without providing ann_file
+ 'raw_ann_info': copy.deepcopy(ann),
+ }
+
+ if 'crowdIndex' in img:
+ data_info['crowd_index'] = img['crowdIndex']
+
+ return data_info
+
+ @staticmethod
+ def _is_valid_instance(data_info: Dict) -> bool:
+ """Check a data info is an instance with valid bbox and keypoint
+ annotations."""
+ # crowd annotation
+ if 'iscrowd' in data_info and data_info['iscrowd']:
+ return False
+ # invalid keypoints
+ if 'num_keypoints' in data_info and data_info['num_keypoints'] == 0:
+ return False
+ # invalid bbox
+ if 'bbox' in data_info:
+ bbox = data_info['bbox'][0]
+ w, h = bbox[2:4] - bbox[:2]
+ if w <= 0 or h <= 0:
+ return False
+ # invalid keypoints
+ if 'keypoints' in data_info:
+ if np.max(data_info['keypoints']) <= 0:
+ return False
+ return True
+
+ def _get_topdown_data_infos(self, instance_list: List[Dict]) -> List[Dict]:
+ """Organize the data list in top-down mode."""
+ # sanitize data samples
+ data_list_tp = list(filter(self._is_valid_instance, instance_list))
+
+ return data_list_tp
+
+ def _get_bottomup_data_infos(self, instance_list: List[Dict],
+ image_list: List[Dict]) -> List[Dict]:
+ """Organize the data list in bottom-up mode."""
+
+ # bottom-up data list
+ data_list_bu = []
+
+ used_img_ids = set()
+
+ # group instances by img_id
+ for img_id, data_infos in groupby(instance_list,
+ lambda x: x['img_id']):
+ used_img_ids.add(img_id)
+ data_infos = list(data_infos)
+
+ # image data
+ img_path = data_infos[0]['img_path']
+ data_info_bu = {
+ 'img_id': img_id,
+ 'img_path': img_path,
+ }
+
+ for key in data_infos[0].keys():
+ if key not in data_info_bu:
+ seq = [d[key] for d in data_infos]
+ if isinstance(seq[0], np.ndarray):
+ seq = np.concatenate(seq, axis=0)
+ data_info_bu[key] = seq
+
+ # The segmentation annotation of invalid objects will be used
+ # to generate valid region mask in the pipeline.
+ invalid_segs = []
+ for data_info_invalid in filterfalse(self._is_valid_instance,
+ data_infos):
+ if 'segmentation' in data_info_invalid:
+ invalid_segs.append(data_info_invalid['segmentation'])
+ data_info_bu['invalid_segs'] = invalid_segs
+
+ data_list_bu.append(data_info_bu)
+
+ # add images without instance for evaluation
+ if self.test_mode:
+ for img_info in image_list:
+ if img_info['img_id'] not in used_img_ids:
+ data_info_bu = {
+ 'img_id': img_info['img_id'],
+ 'img_path': img_info['img_path'],
+ 'id': list(),
+ 'raw_ann_info': None,
+ }
+ data_list_bu.append(data_info_bu)
+
+ return data_list_bu
+
+ def _load_detection_results(self) -> List[dict]:
+ """Load data from detection results with dummy keypoint annotations."""
+
+ assert exists(self.ann_file), 'Annotation file does not exist'
+ assert exists(self.bbox_file), 'Bbox file does not exist'
+ # load detection results
+ det_results = load(self.bbox_file)
+ assert is_list_of(det_results, dict)
+
+ # load coco annotations to build image id-to-name index
+ with get_local_path(self.ann_file) as local_path:
+ self.coco = COCO(local_path)
+ # set the metainfo about categories, which is a list of dict
+ # and each dict contains the 'id', 'name', etc. about this category
+ self._metainfo['CLASSES'] = self.coco.loadCats(self.coco.getCatIds())
+
+ num_keypoints = self.metainfo['num_keypoints']
+ data_list = []
+ id_ = 0
+ for det in det_results:
+ # remove non-human instances
+ if det['category_id'] != 1:
+ continue
+
+ img = self.coco.loadImgs(det['image_id'])[0]
+
+ img_path = osp.join(self.data_prefix['img'], img['file_name'])
+ bbox_xywh = np.array(
+ det['bbox'][:4], dtype=np.float32).reshape(1, 4)
+ bbox = bbox_xywh2xyxy(bbox_xywh)
+ bbox_score = np.array(det['score'], dtype=np.float32).reshape(1)
+
+ # use dummy keypoint location and visibility
+ keypoints = np.zeros((1, num_keypoints, 2), dtype=np.float32)
+ keypoints_visible = np.ones((1, num_keypoints), dtype=np.float32)
+
+ data_list.append({
+ 'img_id': det['image_id'],
+ 'img_path': img_path,
+ 'img_shape': (img['height'], img['width']),
+ 'bbox': bbox,
+ 'bbox_score': bbox_score,
+ 'keypoints': keypoints,
+ 'keypoints_visible': keypoints_visible,
+ 'id': id_,
+ })
+
+ id_ += 1
+
+ return data_list
+
+ def filter_data(self) -> List[dict]:
+ """Filter annotations according to filter_cfg. Defaults return full
+ ``data_list``.
+
+ If 'bbox_score_thr` in filter_cfg, the annotation with bbox_score below
+ the threshold `bbox_score_thr` will be filtered out.
+ """
+
+ data_list = self.data_list
+
+ if self.filter_cfg is None:
+ return data_list
+
+ # filter out annotations with a bbox_score below the threshold
+ if 'bbox_score_thr' in self.filter_cfg:
+
+ if self.data_mode != 'topdown':
+ raise ValueError(
+ f'{self.__class__.__name__} is set to {self.data_mode} '
+ 'mode, while "bbox_score_thr" is only supported in '
+ 'topdown mode.')
+
+ thr = self.filter_cfg['bbox_score_thr']
+ data_list = list(
+ filterfalse(lambda ann: ann['bbox_score'] < thr, data_list))
+
+ return data_list
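A standalone sketch (not part of this diff) of how ``_get_bottomup_data_infos`` merges per-instance annotations into one record per image. Note that ``itertools.groupby`` only groups consecutive items, which works here because ``_load_annotations`` emits instances image by image:

    from itertools import groupby

    import numpy as np

    instance_list = [
        {'img_id': 1, 'img_path': 'a.jpg', 'bbox': np.zeros((1, 4))},
        {'img_id': 1, 'img_path': 'a.jpg', 'bbox': np.ones((1, 4))},
        {'img_id': 2, 'img_path': 'b.jpg', 'bbox': np.ones((1, 4))},
    ]

    data_list_bu = []
    for img_id, data_infos in groupby(instance_list, lambda x: x['img_id']):
        data_infos = list(data_infos)
        merged = {'img_id': img_id, 'img_path': data_infos[0]['img_path']}
        for key in data_infos[0]:
            if key not in merged:
                seq = [d[key] for d in data_infos]
                if isinstance(seq[0], np.ndarray):
                    # stack per-instance arrays along the instance axis
                    seq = np.concatenate(seq, axis=0)
                merged[key] = seq
        data_list_bu.append(merged)

    print([d['bbox'].shape for d in data_list_bu])   # [(2, 4), (1, 4)]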
diff --git a/mmpose/datasets/datasets/body/__init__.py b/mmpose/datasets/datasets/body/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4aeef851907f1799bfc1e3507d66c8e7752f32e
--- /dev/null
+++ b/mmpose/datasets/datasets/body/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .aic_dataset import AicDataset
+from .coco_dataset import CocoDataset
+from .crowdpose_dataset import CrowdPoseDataset
+from .jhmdb_dataset import JhmdbDataset
+from .mhp_dataset import MhpDataset
+from .mpii_dataset import MpiiDataset
+from .mpii_trb_dataset import MpiiTrbDataset
+from .ochuman_dataset import OCHumanDataset
+from .posetrack18_dataset import PoseTrack18Dataset
+from .posetrack18_video_dataset import PoseTrack18VideoDataset
+
+__all__ = [
+ 'CocoDataset', 'MpiiDataset', 'MpiiTrbDataset', 'AicDataset',
+ 'CrowdPoseDataset', 'OCHumanDataset', 'MhpDataset', 'PoseTrack18Dataset',
+ 'JhmdbDataset', 'PoseTrack18VideoDataset'
+]
diff --git a/mmpose/datasets/datasets/body/__pycache__/__init__.cpython-38.pyc b/mmpose/datasets/datasets/body/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9614d92c8bf229e22822c14d7d9474c93623b2aa
Binary files /dev/null and b/mmpose/datasets/datasets/body/__pycache__/__init__.cpython-38.pyc differ
diff --git a/mmpose/datasets/datasets/body/__pycache__/aic_dataset.cpython-38.pyc b/mmpose/datasets/datasets/body/__pycache__/aic_dataset.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3df2dd0883e58fefb3362125b44e5d3979bbee09
Binary files /dev/null and b/mmpose/datasets/datasets/body/__pycache__/aic_dataset.cpython-38.pyc differ
diff --git a/mmpose/datasets/datasets/body/__pycache__/coco_dataset.cpython-38.pyc b/mmpose/datasets/datasets/body/__pycache__/coco_dataset.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f09a57dfe92799d8af7dc978baf6a63871bea2e3
Binary files /dev/null and b/mmpose/datasets/datasets/body/__pycache__/coco_dataset.cpython-38.pyc differ
diff --git a/mmpose/datasets/datasets/body/__pycache__/crowdpose_dataset.cpython-38.pyc b/mmpose/datasets/datasets/body/__pycache__/crowdpose_dataset.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c04a5f70ee71f436fa03a9fbb4c7c7e8ff6a02e1
Binary files /dev/null and b/mmpose/datasets/datasets/body/__pycache__/crowdpose_dataset.cpython-38.pyc differ
diff --git a/mmpose/datasets/datasets/body/__pycache__/jhmdb_dataset.cpython-38.pyc b/mmpose/datasets/datasets/body/__pycache__/jhmdb_dataset.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..655a75a671439353d5d95eb4dda3f36406a973a7
Binary files /dev/null and b/mmpose/datasets/datasets/body/__pycache__/jhmdb_dataset.cpython-38.pyc differ
diff --git a/mmpose/datasets/datasets/body/__pycache__/mhp_dataset.cpython-38.pyc b/mmpose/datasets/datasets/body/__pycache__/mhp_dataset.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..454497bfa500fdc0953164fbd721fa691d37e8a1
Binary files /dev/null and b/mmpose/datasets/datasets/body/__pycache__/mhp_dataset.cpython-38.pyc differ
diff --git a/mmpose/datasets/datasets/body/__pycache__/mpii_dataset.cpython-38.pyc b/mmpose/datasets/datasets/body/__pycache__/mpii_dataset.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dab035203d1532b1d6c19cef8f736474c8c9d302
Binary files /dev/null and b/mmpose/datasets/datasets/body/__pycache__/mpii_dataset.cpython-38.pyc differ
diff --git a/mmpose/datasets/datasets/body/__pycache__/mpii_trb_dataset.cpython-38.pyc b/mmpose/datasets/datasets/body/__pycache__/mpii_trb_dataset.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..437bcb0d22da4272e8461c21a8cf937fceb93ff9
Binary files /dev/null and b/mmpose/datasets/datasets/body/__pycache__/mpii_trb_dataset.cpython-38.pyc differ
diff --git a/mmpose/datasets/datasets/body/__pycache__/ochuman_dataset.cpython-38.pyc b/mmpose/datasets/datasets/body/__pycache__/ochuman_dataset.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..67b583aa2a759d453c9357df2b74202b0f71c140
Binary files /dev/null and b/mmpose/datasets/datasets/body/__pycache__/ochuman_dataset.cpython-38.pyc differ
diff --git a/mmpose/datasets/datasets/body/__pycache__/posetrack18_dataset.cpython-38.pyc b/mmpose/datasets/datasets/body/__pycache__/posetrack18_dataset.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..88f2951e19f33db3472a98d76675be7ac0cbce1a
Binary files /dev/null and b/mmpose/datasets/datasets/body/__pycache__/posetrack18_dataset.cpython-38.pyc differ
diff --git a/mmpose/datasets/datasets/body/__pycache__/posetrack18_video_dataset.cpython-38.pyc b/mmpose/datasets/datasets/body/__pycache__/posetrack18_video_dataset.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2ef9d73814c2fc23c63408e36dbc91f911090edf
Binary files /dev/null and b/mmpose/datasets/datasets/body/__pycache__/posetrack18_video_dataset.cpython-38.pyc differ
diff --git a/mmpose/datasets/datasets/body/aic_dataset.py b/mmpose/datasets/datasets/body/aic_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9c7cccc76fb47b53cd73f3152878e051b442199
--- /dev/null
+++ b/mmpose/datasets/datasets/body/aic_dataset.py
@@ -0,0 +1,70 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmpose.registry import DATASETS
+from ..base import BaseCocoStyleDataset
+
+
+@DATASETS.register_module()
+class AicDataset(BaseCocoStyleDataset):
+ """AIC dataset for pose estimation.
+
+ "AI Challenger : A Large-scale Dataset for Going Deeper
+ in Image Understanding", arXiv'2017.
+ More details can be found in the `paper
+ `__
+
+ AIC keypoints::
+
+ 0: "right_shoulder",
+ 1: "right_elbow",
+ 2: "right_wrist",
+ 3: "left_shoulder",
+ 4: "left_elbow",
+ 5: "left_wrist",
+ 6: "right_hip",
+ 7: "right_knee",
+ 8: "right_ankle",
+ 9: "left_hip",
+ 10: "left_knee",
+ 11: "left_ankle",
+ 12: "head_top",
+ 13: "neck"
+
+ Args:
+ ann_file (str): Annotation file path. Default: ''.
+ bbox_file (str, optional): Detection result file path. If
+ ``bbox_file`` is set, detected bboxes loaded from this file will
+ be used instead of ground-truth bboxes. This setting is only for
+ evaluation, i.e., ignored when ``test_mode`` is ``False``.
+ Default: ``None``.
+ data_mode (str): Specifies the mode of data samples: ``'topdown'`` or
+ ``'bottomup'``. In ``'topdown'`` mode, each data sample contains
+ one instance; while in ``'bottomup'`` mode, each data sample
+ contains all instances in an image. Default: ``'topdown'``.
+ metainfo (dict, optional): Meta information for dataset, such as class
+ information. Default: ``None``.
+ data_root (str, optional): The root directory for ``data_prefix`` and
+ ``ann_file``. Default: ``None``.
+ data_prefix (dict, optional): Prefix for training data. Default:
+ ``dict(img=None, ann=None)``.
+ filter_cfg (dict, optional): Config for filtering data. Default: ``None``.
+ indices (int or Sequence[int], optional): Support using the first few
+ data samples in the annotation file to facilitate training/testing on
+ a smaller dataset. Default: ``None``, which means using all
+ ``data_infos``.
+ serialize_data (bool, optional): Whether to hold memory using
+ serialized objects; when enabled, data loader workers can use
+ shared RAM from the master process instead of making a copy.
+ Default: ``True``.
+ pipeline (list, optional): Processing pipeline. Default: [].
+ test_mode (bool, optional): ``test_mode=True`` means the dataset is
+ used in the test phase. Default: ``False``.
+ lazy_init (bool, optional): Whether to load annotations during
+ instantiation. In some cases, such as visualization, only the meta
+ information of the dataset is needed, so it is not necessary to load
+ the annotation file. ``BaseDataset`` can skip loading annotations to
+ save time by setting ``lazy_init=True``. Default: ``False``.
+ max_refetch (int, optional): The maximum number of extra cycles to
+ fetch a valid image if ``BaseDataset.prepare_data`` gets a ``None``
+ image. Default: 1000.
+ """
+
+ METAINFO: dict = dict(from_file='configs/_base_/datasets/aic.py')
diff --git a/mmpose/datasets/datasets/body/coco_dataset.py b/mmpose/datasets/datasets/body/coco_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..7cc971f91f70ba28de1b9ae520d10a2f491eb32b
--- /dev/null
+++ b/mmpose/datasets/datasets/body/coco_dataset.py
@@ -0,0 +1,72 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmpose.registry import DATASETS
+from ..base import BaseCocoStyleDataset
+
+
+@DATASETS.register_module()
+class CocoDataset(BaseCocoStyleDataset):
+ """COCO dataset for pose estimation.
+
+ "Microsoft COCO: Common Objects in Context", ECCV'2014.
+ More details can be found in the `paper
+ `__ .
+
+ COCO keypoints::
+
+ 0: 'nose',
+ 1: 'left_eye',
+ 2: 'right_eye',
+ 3: 'left_ear',
+ 4: 'right_ear',
+ 5: 'left_shoulder',
+ 6: 'right_shoulder',
+ 7: 'left_elbow',
+ 8: 'right_elbow',
+ 9: 'left_wrist',
+ 10: 'right_wrist',
+ 11: 'left_hip',
+ 12: 'right_hip',
+ 13: 'left_knee',
+ 14: 'right_knee',
+ 15: 'left_ankle',
+ 16: 'right_ankle'
+
+ Args:
+ ann_file (str): Annotation file path. Default: ''.
+ bbox_file (str, optional): Detection result file path. If
+ ``bbox_file`` is set, detected bboxes loaded from this file will
+ be used instead of ground-truth bboxes. This setting is only for
+ evaluation, i.e., ignored when ``test_mode`` is ``False``.
+ Default: ``None``.
+ data_mode (str): Specifies the mode of data samples: ``'topdown'`` or
+ ``'bottomup'``. In ``'topdown'`` mode, each data sample contains
+ one instance; while in ``'bottomup'`` mode, each data sample
+ contains all instances in an image. Default: ``'topdown'``.
+ metainfo (dict, optional): Meta information for dataset, such as class
+ information. Default: ``None``.
+ data_root (str, optional): The root directory for ``data_prefix`` and
+ ``ann_file``. Default: ``None``.
+ data_prefix (dict, optional): Prefix for training data. Default:
+ ``dict(img=None, ann=None)``.
+ filter_cfg (dict, optional): Config for filtering data. Default: ``None``.
+ indices (int or Sequence[int], optional): Support using the first few
+ data samples in the annotation file to facilitate training/testing on
+ a smaller dataset. Default: ``None``, which means using all
+ ``data_infos``.
+ serialize_data (bool, optional): Whether to hold memory using
+ serialized objects; when enabled, data loader workers can use
+ shared RAM from the master process instead of making a copy.
+ Default: ``True``.
+ pipeline (list, optional): Processing pipeline. Default: [].
+ test_mode (bool, optional): ``test_mode=True`` means the dataset is
+ used in the test phase. Default: ``False``.
+ lazy_init (bool, optional): Whether to load annotations during
+ instantiation. In some cases, such as visualization, only the meta
+ information of the dataset is needed, so it is not necessary to load
+ the annotation file. ``BaseDataset`` can skip loading annotations to
+ save time by setting ``lazy_init=True``. Default: ``False``.
+ max_refetch (int, optional): The maximum number of extra cycles to
+ fetch a valid image if ``BaseDataset.prepare_data`` gets a ``None``
+ image. Default: 1000.
+ """
+
+ METAINFO: dict = dict(from_file='configs/_base_/datasets/coco.py')
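A hedged sketch (not part of this diff) of the detection-box evaluation path described in the docstring above: when a detection result file is supplied it replaces the ground-truth boxes, ``test_mode=True`` is required, and low-scoring detections can be dropped via ``filter_cfg``. All paths are hypothetical placeholders:

    from mmpose.datasets.datasets.body.coco_dataset import CocoDataset

    dataset = CocoDataset(
        data_root='data/coco/',                                       # hypothetical
        ann_file='annotations/person_keypoints_val2017.json',         # hypothetical
        bbox_file='detections/coco_val2017_person_detections.json',   # hypothetical
        data_prefix=dict(img='val2017/'),
        data_mode='topdown',                  # bbox_file needs topdown mode
        filter_cfg=dict(bbox_score_thr=0.3),  # drop detections scoring below 0.3
        pipeline=[],
        test_mode=True)                       # bbox_file needs test_mode=True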
diff --git a/mmpose/datasets/datasets/body/crowdpose_dataset.py b/mmpose/datasets/datasets/body/crowdpose_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..4218708ff27b37dce7992d73695193442207b6d9
--- /dev/null
+++ b/mmpose/datasets/datasets/body/crowdpose_dataset.py
@@ -0,0 +1,70 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmpose.registry import DATASETS
+from ..base import BaseCocoStyleDataset
+
+
+@DATASETS.register_module()
+class CrowdPoseDataset(BaseCocoStyleDataset):
+ """CrowdPose dataset for pose estimation.
+
+ "CrowdPose: Efficient Crowded Scenes Pose Estimation and
+ A New Benchmark", CVPR'2019.
+ More details can be found in the `paper
+ `__.
+
+ CrowdPose keypoints::
+
+ 0: 'left_shoulder',
+ 1: 'right_shoulder',
+ 2: 'left_elbow',
+ 3: 'right_elbow',
+ 4: 'left_wrist',
+ 5: 'right_wrist',
+ 6: 'left_hip',
+ 7: 'right_hip',
+ 8: 'left_knee',
+ 9: 'right_knee',
+ 10: 'left_ankle',
+ 11: 'right_ankle',
+ 12: 'top_head',
+ 13: 'neck'
+
+ Args:
+ ann_file (str): Annotation file path. Default: ''.
+ bbox_file (str, optional): Detection result file path. If
+ ``bbox_file`` is set, detected bboxes loaded from this file will
+ be used instead of ground-truth bboxes. This setting is only for
+ evaluation, i.e., ignored when ``test_mode`` is ``False``.
+ Default: ``None``.
+ data_mode (str): Specifies the mode of data samples: ``'topdown'`` or
+ ``'bottomup'``. In ``'topdown'`` mode, each data sample contains
+ one instance; while in ``'bottomup'`` mode, each data sample
+ contains all instances in an image. Default: ``'topdown'``.
+ metainfo (dict, optional): Meta information for dataset, such as class
+ information. Default: ``None``.
+ data_root (str, optional): The root directory for ``data_prefix`` and
+ ``ann_file``. Default: ``None``.
+ data_prefix (dict, optional): Prefix for training data. Default:
+ ``dict(img=None, ann=None)``.
+ filter_cfg (dict, optional): Config for filtering data. Default: ``None``.
+ indices (int or Sequence[int], optional): Support using the first few
+ data samples in the annotation file to facilitate training/testing on
+ a smaller dataset. Default: ``None``, which means using all
+ ``data_infos``.
+ serialize_data (bool, optional): Whether to hold memory using
+ serialized objects; when enabled, data loader workers can use
+ shared RAM from the master process instead of making a copy.
+ Default: ``True``.
+ pipeline (list, optional): Processing pipeline. Default: [].
+ test_mode (bool, optional): ``test_mode=True`` means the dataset is
+ used in the test phase. Default: ``False``.
+ lazy_init (bool, optional): Whether to load annotations during
+ instantiation. In some cases, such as visualization, only the meta
+ information of the dataset is needed, so it is not necessary to load
+ the annotation file. ``BaseDataset`` can skip loading annotations to
+ save time by setting ``lazy_init=True``. Default: ``False``.
+ max_refetch (int, optional): The maximum number of extra cycles to
+ fetch a valid image if ``BaseDataset.prepare_data`` gets a ``None``
+ image. Default: 1000.
+ """
+
+ METAINFO: dict = dict(from_file='configs/_base_/datasets/crowdpose.py')
diff --git a/mmpose/datasets/datasets/body/jhmdb_dataset.py b/mmpose/datasets/datasets/body/jhmdb_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d72a7ddc5129af5e1de144853e5389b09465fd8
--- /dev/null
+++ b/mmpose/datasets/datasets/body/jhmdb_dataset.py
@@ -0,0 +1,135 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+from typing import Optional
+
+import numpy as np
+
+from mmpose.registry import DATASETS
+from ..base import BaseCocoStyleDataset
+
+
+@DATASETS.register_module()
+class JhmdbDataset(BaseCocoStyleDataset):
+ """JhmdbDataset dataset for pose estimation.
+
+ "Towards understanding action recognition", ICCV'2013.
+ More details can be found in the `paper
+ `__
+
+ sub-JHMDB keypoints::
+
+ 0: "neck",
+ 1: "belly",
+ 2: "head",
+ 3: "right_shoulder",
+ 4: "left_shoulder",
+ 5: "right_hip",
+ 6: "left_hip",
+ 7: "right_elbow",
+ 8: "left_elbow",
+ 9: "right_knee",
+ 10: "left_knee",
+ 11: "right_wrist",
+ 12: "left_wrist",
+ 13: "right_ankle",
+ 14: "left_ankle"
+
+ Args:
+ ann_file (str): Annotation file path. Default: ''.
+ bbox_file (str, optional): Detection result file path. If
+ ``bbox_file`` is set, detected bboxes loaded from this file will
+ be used instead of ground-truth bboxes. This setting is only for
+ evaluation, i.e., ignored when ``test_mode`` is ``False``.
+ Default: ``None``.
+ data_mode (str): Specifies the mode of data samples: ``'topdown'`` or
+ ``'bottomup'``. In ``'topdown'`` mode, each data sample contains
+ one instance; while in ``'bottomup'`` mode, each data sample
+ contains all instances in an image. Default: ``'topdown'``.
+ metainfo (dict, optional): Meta information for dataset, such as class
+ information. Default: ``None``.
+ data_root (str, optional): The root directory for ``data_prefix`` and
+ ``ann_file``. Default: ``None``.
+ data_prefix (dict, optional): Prefix for training data. Default:
+ ``dict(img=None, ann=None)``.
+ filter_cfg (dict, optional): Config for filtering data. Default: ``None``.
+ indices (int or Sequence[int], optional): Support using the first few
+ data samples in the annotation file to facilitate training/testing on
+ a smaller dataset. Default: ``None``, which means using all
+ ``data_infos``.
+ serialize_data (bool, optional): Whether to hold memory using
+ serialized objects; when enabled, data loader workers can use
+ shared RAM from the master process instead of making a copy.
+ Default: ``True``.
+ pipeline (list, optional): Processing pipeline. Default: [].
+ test_mode (bool, optional): ``test_mode=True`` means the dataset is
+ used in the test phase. Default: ``False``.
+ lazy_init (bool, optional): Whether to load annotations during
+ instantiation. In some cases, such as visualization, only the meta
+ information of the dataset is needed, so it is not necessary to load
+ the annotation file. ``BaseDataset`` can skip loading annotations to
+ save time by setting ``lazy_init=True``. Default: ``False``.
+ max_refetch (int, optional): The maximum number of extra cycles to
+ fetch a valid image if ``BaseDataset.prepare_data`` gets a ``None``
+ image. Default: 1000.
+ """
+
+ METAINFO: dict = dict(from_file='configs/_base_/datasets/jhmdb.py')
+
+ def parse_data_info(self, raw_data_info: dict) -> Optional[dict]:
+ """Parse raw COCO annotation of an instance.
+
+ Args:
+ raw_data_info (dict): Raw data information loaded from
+ ``ann_file``. It should have following contents:
+
+ - ``'raw_ann_info'``: Raw annotation of an instance
+ - ``'raw_img_info'``: Raw information of the image that
+ contains the instance
+
+ Returns:
+ dict: Parsed instance annotation
+ """
+
+ ann = raw_data_info['raw_ann_info']
+ img = raw_data_info['raw_img_info']
+
+ img_path = osp.join(self.data_prefix['img'], img['file_name'])
+ img_w, img_h = img['width'], img['height']
+
+ # get bbox in shape [1, 4], formatted as xywh
+ x, y, w, h = ann['bbox']
+ # JHMDB uses MATLAB-style 1-based indexing,
+ # so convert the bbox origin to 0-based indices first
+ x -= 1
+ y -= 1
+ x1 = np.clip(x, 0, img_w - 1)
+ y1 = np.clip(y, 0, img_h - 1)
+ x2 = np.clip(x + w, 0, img_w - 1)
+ y2 = np.clip(y + h, 0, img_h - 1)
+
+ bbox = np.array([x1, y1, x2, y2], dtype=np.float32).reshape(1, 4)
+
+ # keypoints in shape [1, K, 2] and keypoints_visible in [1, K]
+ _keypoints = np.array(
+ ann['keypoints'], dtype=np.float32).reshape(1, -1, 3)
+ # JHMDB uses MATLAB-style 1-based indexing,
+ # so convert the keypoint coordinates to 0-based indices first
+ keypoints = _keypoints[..., :2] - 1
+ keypoints_visible = np.minimum(1, _keypoints[..., 2])
+
+ num_keypoints = np.count_nonzero(keypoints.max(axis=2))
+
+ data_info = {
+ 'img_id': ann['image_id'],
+ 'img_path': img_path,
+ 'bbox': bbox,
+ 'bbox_score': np.ones(1, dtype=np.float32),
+ 'num_keypoints': num_keypoints,
+ 'keypoints': keypoints,
+ 'keypoints_visible': keypoints_visible,
+ 'iscrowd': ann.get('iscrowd', 0),
+ 'segmentation': ann.get('segmentation', None),
+ 'id': ann['id'],
+ }
+
+ return data_info
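A standalone sketch (not part of this diff) of the MATLAB-index correction in ``parse_data_info`` above: the bbox origin and the keypoint coordinates are shifted by -1 to become 0-based, and the bbox corners are then clipped to the image. The numbers are toy values:

    import numpy as np

    img_w, img_h = 320, 240
    x, y, w, h = 5.0, 3.0, 100.0, 250.0     # toy 1-based xywh bbox
    x -= 1
    y -= 1
    bbox = np.array([
        np.clip(x, 0, img_w - 1),           # x1 = 4
        np.clip(y, 0, img_h - 1),           # y1 = 2
        np.clip(x + w, 0, img_w - 1),       # x2 = 104
        np.clip(y + h, 0, img_h - 1),       # y2 clipped to 239
    ], dtype=np.float32).reshape(1, 4)

    kpts = np.array([1.0, 1.0, 2], dtype=np.float32).reshape(1, -1, 3)
    keypoints = kpts[..., :2] - 1           # (1, 1) becomes (0, 0)
    print(bbox, keypoints)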
diff --git a/mmpose/datasets/datasets/body/mhp_dataset.py b/mmpose/datasets/datasets/body/mhp_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..55d33602536383898c8b65ca48994d33c1616bea
--- /dev/null
+++ b/mmpose/datasets/datasets/body/mhp_dataset.py
@@ -0,0 +1,72 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmpose.registry import DATASETS
+from ..base import BaseCocoStyleDataset
+
+
+@DATASETS.register_module()
+class MhpDataset(BaseCocoStyleDataset):
+ """MHPv2.0 dataset for pose estimation.
+
+ "Understanding Humans in Crowded Scenes: Deep Nested Adversarial
+ Learning and A New Benchmark for Multi-Human Parsing", ACM MM'2018.
+ More details can be found in the `paper
+ `__
+
+ MHP keypoints::
+
+ 0: "right ankle",
+ 1: "right knee",
+ 2: "right hip",
+ 3: "left hip",
+ 4: "left knee",
+ 5: "left ankle",
+ 6: "pelvis",
+ 7: "thorax",
+ 8: "upper neck",
+ 9: "head top",
+ 10: "right wrist",
+ 11: "right elbow",
+ 12: "right shoulder",
+ 13: "left shoulder",
+ 14: "left elbow",
+ 15: "left wrist",
+
+ Args:
+ ann_file (str): Annotation file path. Default: ''.
+ bbox_file (str, optional): Detection result file path. If
+ ``bbox_file`` is set, detected bboxes loaded from this file will
+ be used instead of ground-truth bboxes. This setting is only for
+ evaluation, i.e., ignored when ``test_mode`` is ``False``.
+ Default: ``None``.
+ data_mode (str): Specifies the mode of data samples: ``'topdown'`` or
+ ``'bottomup'``. In ``'topdown'`` mode, each data sample contains
+ one instance; while in ``'bottomup'`` mode, each data sample
+ contains all instances in a image. Default: ``'topdown'``
+ metainfo (dict, optional): Meta information for dataset, such as class
+ information. Default: ``None``.
+ data_root (str, optional): The root directory for ``data_prefix`` and
+ ``ann_file``. Default: ``None``.
+ data_prefix (dict, optional): Prefix for training data. Default:
+ ``dict(img=None, ann=None)``.
+ filter_cfg (dict, optional): Config for filter data. Default: `None`.
+ indices (int or Sequence[int], optional): Support using first few
+ data in annotation file to facilitate training/testing on a smaller
+ dataset. Default: ``None`` which means using all ``data_infos``.
+ serialize_data (bool, optional): Whether to hold memory using
+ serialized objects; when enabled, data loader workers can use
+ shared RAM from the master process instead of making a copy.
+ Default: ``True``.
+ pipeline (list, optional): Processing pipeline. Default: [].
+ test_mode (bool, optional): ``test_mode=True`` means in test phase.
+ Default: ``False``.
+ lazy_init (bool, optional): Whether to load annotation during
+ instantiation. In some cases, such as visualization, only the meta
+ information of the dataset is needed, so it is not necessary to
+ load the annotation file; ``BaseDataset`` can skip loading
+ annotations to save time by setting ``lazy_init=True``.
+ Default: ``False``.
+ max_refetch (int, optional): The maximum number of extra cycles to
+ refetch data when ``BaseDataset.prepare_data`` returns a ``None``
+ image. Default: 1000.
+ """
+
+ METAINFO: dict = dict(from_file='configs/_base_/datasets/mhp.py')
diff --git a/mmpose/datasets/datasets/body/mpii_dataset.py b/mmpose/datasets/datasets/body/mpii_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..237f1ab2b61a6bda816e626b987de31560fda22a
--- /dev/null
+++ b/mmpose/datasets/datasets/body/mpii_dataset.py
@@ -0,0 +1,212 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import json
+import os.path as osp
+from typing import Callable, List, Optional, Sequence, Tuple, Union
+
+import numpy as np
+from mmengine.fileio import exists, get_local_path
+from scipy.io import loadmat
+
+from mmpose.registry import DATASETS
+from mmpose.structures.bbox import bbox_cs2xyxy
+from ..base import BaseCocoStyleDataset
+
+
+@DATASETS.register_module()
+class MpiiDataset(BaseCocoStyleDataset):
+ """MPII Dataset for pose estimation.
+
+ "2D Human Pose Estimation: New Benchmark and State of the Art Analysis"
+ ,CVPR'2014. More details can be found in the `paper
+ `__ .
+
+ MPII keypoints::
+
+ 0: 'right_ankle'
+ 1: 'right_knee',
+ 2: 'right_hip',
+ 3: 'left_hip',
+ 4: 'left_knee',
+ 5: 'left_ankle',
+ 6: 'pelvis',
+ 7: 'thorax',
+ 8: 'upper_neck',
+ 9: 'head_top',
+ 10: 'right_wrist',
+ 11: 'right_elbow',
+ 12: 'right_shoulder',
+ 13: 'left_shoulder',
+ 14: 'left_elbow',
+ 15: 'left_wrist'
+
+ Args:
+ ann_file (str): Annotation file path. Default: ''.
+ bbox_file (str, optional): Detection result file path. If
+ ``bbox_file`` is set, detected bboxes loaded from this file will
+ be used instead of ground-truth bboxes. This setting is only for
+ evaluation, i.e., ignored when ``test_mode`` is ``False``.
+ Default: ``None``.
+ headbox_file (str, optional): The path of ``mpii_gt_val.mat`` which
+ provides the headboxes information used for ``PCKh``.
+ Default: ``None``.
+ data_mode (str): Specifies the mode of data samples: ``'topdown'`` or
+ ``'bottomup'``. In ``'topdown'`` mode, each data sample contains
+ one instance; while in ``'bottomup'`` mode, each data sample
+ contains all instances in an image. Default: ``'topdown'``.
+ metainfo (dict, optional): Meta information for dataset, such as class
+ information. Default: ``None``.
+ data_root (str, optional): The root directory for ``data_prefix`` and
+ ``ann_file``. Default: ``None``.
+ data_prefix (dict, optional): Prefix for training data. Default:
+ ``dict(img=None, ann=None)``.
+ filter_cfg (dict, optional): Config for filter data. Default: `None`.
+ indices (int or Sequence[int], optional): Support using first few
+ data in annotation file to facilitate training/testing on a smaller
+ dataset. Default: ``None`` which means using all ``data_infos``.
+ serialize_data (bool, optional): Whether to hold memory using
+ serialized objects; when enabled, data loader workers can use
+ shared RAM from the master process instead of making a copy.
+ Default: ``True``.
+ pipeline (list, optional): Processing pipeline. Default: [].
+ test_mode (bool, optional): ``test_mode=True`` means in test phase.
+ Default: ``False``.
+ lazy_init (bool, optional): Whether to load annotation during
+ instantiation. In some cases, such as visualization, only the meta
+ information of the dataset is needed, so it is not necessary to
+ load the annotation file; ``BaseDataset`` can skip loading
+ annotations to save time by setting ``lazy_init=True``.
+ Default: ``False``.
+ max_refetch (int, optional): The maximum number of extra cycles to
+ refetch data when ``BaseDataset.prepare_data`` returns a ``None``
+ image. Default: 1000.
+ """
+
+ METAINFO: dict = dict(from_file='configs/_base_/datasets/mpii.py')
+
+ def __init__(self,
+ ann_file: str = '',
+ bbox_file: Optional[str] = None,
+ headbox_file: Optional[str] = None,
+ data_mode: str = 'topdown',
+ metainfo: Optional[dict] = None,
+ data_root: Optional[str] = None,
+ data_prefix: dict = dict(img=''),
+ filter_cfg: Optional[dict] = None,
+ indices: Optional[Union[int, Sequence[int]]] = None,
+ serialize_data: bool = True,
+ pipeline: List[Union[dict, Callable]] = [],
+ test_mode: bool = False,
+ lazy_init: bool = False,
+ max_refetch: int = 1000):
+
+ if headbox_file:
+ if data_mode != 'topdown':
+ raise ValueError(
+ f'{self.__class__.__name__} is set to {data_mode} '
+ 'mode, while "headbox_file" is only '
+ 'supported in topdown mode.')
+
+ if not test_mode:
+ raise ValueError(
+ f'{self.__class__.__name__} has `test_mode==False` '
+ 'while "headbox_file" is only '
+ 'supported when `test_mode==True`.')
+
+ headbox_file_type = headbox_file[-3:]
+ allow_headbox_file_type = ['mat']
+ if headbox_file_type not in allow_headbox_file_type:
+ raise KeyError(
+ f'The head boxes file type {headbox_file_type} is not '
+ f'supported. Should be `mat` but got {headbox_file_type}.')
+ self.headbox_file = headbox_file
+
+ super().__init__(
+ ann_file=ann_file,
+ bbox_file=bbox_file,
+ data_mode=data_mode,
+ metainfo=metainfo,
+ data_root=data_root,
+ data_prefix=data_prefix,
+ filter_cfg=filter_cfg,
+ indices=indices,
+ serialize_data=serialize_data,
+ pipeline=pipeline,
+ test_mode=test_mode,
+ lazy_init=lazy_init,
+ max_refetch=max_refetch)
+
+ def _load_annotations(self) -> Tuple[List[dict], List[dict]]:
+ """Load data from annotations in MPII format."""
+
+ assert exists(self.ann_file), 'Annotation file does not exist'
+ with get_local_path(self.ann_file) as local_path:
+ with open(local_path) as anno_file:
+ self.anns = json.load(anno_file)
+
+ if self.headbox_file:
+ assert exists(self.headbox_file), 'Headbox file does not exist'
+ with get_local_path(self.headbox_file) as local_path:
+ self.headbox_dict = loadmat(local_path)
+ headboxes_src = np.transpose(self.headbox_dict['headboxes_src'],
+ [2, 0, 1])
+ SC_BIAS = 0.6
+
+ instance_list = []
+ image_list = []
+ used_img_ids = set()
+ ann_id = 0
+
+ # mpii bbox scales are normalized with factor 200.
+ pixel_std = 200.
+
+ for idx, ann in enumerate(self.anns):
+ center = np.array(ann['center'], dtype=np.float32)
+ scale = np.array([ann['scale'], ann['scale']],
+ dtype=np.float32) * pixel_std
+
+ # Adjust center/scale slightly to avoid cropping limbs
+ if center[0] != -1:
+ center[1] = center[1] + 15. / pixel_std * scale[1]
+
+ # MPII uses matlab format, index is 1-based,
+ # we should first convert to 0-based index
+ center = center - 1
+
+ # unify shape with coco datasets
+ center = center.reshape(1, -1)
+ scale = scale.reshape(1, -1)
+ bbox = bbox_cs2xyxy(center, scale)
+
+ # load keypoints in shape [1, K, 2] and keypoints_visible in [1, K]
+ keypoints = np.array(ann['joints']).reshape(1, -1, 2)
+ keypoints_visible = np.array(ann['joints_vis']).reshape(1, -1)
+
+ instance_info = {
+ 'id': ann_id,
+ 'img_id': int(ann['image'].split('.')[0]),
+ 'img_path': osp.join(self.data_prefix['img'], ann['image']),
+ 'bbox_center': center,
+ 'bbox_scale': scale,
+ 'bbox': bbox,
+ 'bbox_score': np.ones(1, dtype=np.float32),
+ 'keypoints': keypoints,
+ 'keypoints_visible': keypoints_visible,
+ }
+
+ if self.headbox_file:
+ # calculate the diagonal length of head box as norm_factor
+ headbox = headboxes_src[idx]
+ head_size = np.linalg.norm(headbox[1] - headbox[0], axis=0)
+ head_size *= SC_BIAS
+ instance_info['head_size'] = head_size.reshape(1, -1)
+
+ if instance_info['img_id'] not in used_img_ids:
+ used_img_ids.add(instance_info['img_id'])
+ image_list.append({
+ 'img_id': instance_info['img_id'],
+ 'img_path': instance_info['img_path'],
+ })
+
+ instance_list.append(instance_info)
+ ann_id = ann_id + 1
+
+ return instance_list, image_list
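For reference, the ``head_size`` stored above is the PCKh normalization factor: the diagonal length of the ground-truth head box scaled by ``SC_BIAS``. A quick numeric check with made-up coordinates:

    import numpy as np

    SC_BIAS = 0.6
    headbox = np.array([[10.0, 20.0],     # top-left     [x1, y1]
                        [70.0, 100.0]])   # bottom-right [x2, y2]
    head_size = np.linalg.norm(headbox[1] - headbox[0], axis=0) * SC_BIAS
    # diagonal = 100.0, so head_size = 60.0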
diff --git a/mmpose/datasets/datasets/body/mpii_trb_dataset.py b/mmpose/datasets/datasets/body/mpii_trb_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb96ad876f483dd0b01946030485a53defcf41c8
--- /dev/null
+++ b/mmpose/datasets/datasets/body/mpii_trb_dataset.py
@@ -0,0 +1,169 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import json
+import os.path as osp
+from typing import List, Tuple
+
+import numpy as np
+from mmengine.fileio import exists, get_local_path
+
+from mmpose.registry import DATASETS
+from mmpose.structures.bbox import bbox_cs2xyxy
+from ..base import BaseCocoStyleDataset
+
+
+@DATASETS.register_module()
+class MpiiTrbDataset(BaseCocoStyleDataset):
+ """MPII-TRB Dataset dataset for pose estimation.
+
+ "TRB: A Novel Triplet Representation for Understanding 2D Human Body",
+ ICCV'2019. More details can be found in the `paper
+ `__ .
+
+ MPII-TRB keypoints::
+
+ 0: 'left_shoulder'
+ 1: 'right_shoulder'
+ 2: 'left_elbow'
+ 3: 'right_elbow'
+ 4: 'left_wrist'
+ 5: 'right_wrist'
+ 6: 'left_hip'
+ 7: 'right_hip'
+ 8: 'left_knee'
+ 9: 'right_knee'
+ 10: 'left_ankle'
+ 11: 'right_ankle'
+ 12: 'head'
+ 13: 'neck'
+
+ 14: 'right_neck'
+ 15: 'left_neck'
+ 16: 'medial_right_shoulder'
+ 17: 'lateral_right_shoulder'
+ 18: 'medial_right_bow'
+ 19: 'lateral_right_bow'
+ 20: 'medial_right_wrist'
+ 21: 'lateral_right_wrist'
+ 22: 'medial_left_shoulder'
+ 23: 'lateral_left_shoulder'
+ 24: 'medial_left_bow'
+ 25: 'lateral_left_bow'
+ 26: 'medial_left_wrist'
+ 27: 'lateral_left_wrist'
+ 28: 'medial_right_hip'
+ 29: 'lateral_right_hip'
+ 30: 'medial_right_knee'
+ 31: 'lateral_right_knee'
+ 32: 'medial_right_ankle'
+ 33: 'lateral_right_ankle'
+ 34: 'medial_left_hip'
+ 35: 'lateral_left_hip'
+ 36: 'medial_left_knee'
+ 37: 'lateral_left_knee'
+ 38: 'medial_left_ankle'
+ 39: 'lateral_left_ankle'
+
+ Args:
+ ann_file (str): Annotation file path. Default: ''.
+ bbox_file (str, optional): Detection result file path. If
+ ``bbox_file`` is set, detected bboxes loaded from this file will
+ be used instead of ground-truth bboxes. This setting is only for
+ evaluation, i.e., ignored when ``test_mode`` is ``False``.
+ Default: ``None``.
+ data_mode (str): Specifies the mode of data samples: ``'topdown'`` or
+ ``'bottomup'``. In ``'topdown'`` mode, each data sample contains
+ one instance; while in ``'bottomup'`` mode, each data sample
+ contains all instances in an image. Default: ``'topdown'``.
+ metainfo (dict, optional): Meta information for dataset, such as class
+ information. Default: ``None``.
+ data_root (str, optional): The root directory for ``data_prefix`` and
+ ``ann_file``. Default: ``None``.
+ data_prefix (dict, optional): Prefix for training data. Default:
+ ``dict(img=None, ann=None)``.
+ filter_cfg (dict, optional): Config for filter data. Default: `None`.
+ indices (int or Sequence[int], optional): Support using first few
+ data in annotation file to facilitate training/testing on a smaller
+ dataset. Default: ``None`` which means using all ``data_infos``.
+ serialize_data (bool, optional): Whether to hold memory using
+ serialized objects; when enabled, data loader workers can use
+ shared RAM from the master process instead of making a copy.
+ Default: ``True``.
+ pipeline (list, optional): Processing pipeline. Default: [].
+ test_mode (bool, optional): ``test_mode=True`` means in test phase.
+ Default: ``False``.
+ lazy_init (bool, optional): Whether to load annotation during
+ instantiation. In some cases, such as visualization, only the meta
+ information of the dataset is needed, so it is not necessary to
+ load the annotation file; ``BaseDataset`` can skip loading
+ annotations to save time by setting ``lazy_init=True``.
+ Default: ``False``.
+ max_refetch (int, optional): The maximum number of extra cycles to
+ refetch data when ``BaseDataset.prepare_data`` returns a ``None``
+ image. Default: 1000.
+ """
+
+ METAINFO: dict = dict(from_file='configs/_base_/datasets/mpii_trb.py')
+
+ def _load_annotations(self) -> Tuple[List[dict], List[dict]]:
+ """Load data from annotations in MPII-TRB format."""
+
+ assert exists(self.ann_file), 'Annotation file does not exist'
+ with get_local_path(self.ann_file) as local_path:
+ with open(local_path) as anno_file:
+ self.data = json.load(anno_file)
+
+ imgid2info = {img['id']: img for img in self.data['images']}
+
+ instance_list = []
+ image_list = []
+ used_img_ids = set()
+
+ # mpii-trb bbox scales are normalized with factor 200.
+ pixel_std = 200.
+
+ for ann in self.data['annotations']:
+ img_id = ann['image_id']
+
+ # center, scale in shape [1, 2] and bbox in [1, 4]
+ center = np.array([ann['center']], dtype=np.float32)
+ scale = np.array([[ann['scale'], ann['scale']]],
+ dtype=np.float32) * pixel_std
+ bbox = bbox_cs2xyxy(center, scale)
+
+ # keypoints in shape [1, K, 2] and keypoints_visible in [1, K]
+ _keypoints = np.array(
+ ann['keypoints'], dtype=np.float32).reshape(1, -1, 3)
+ keypoints = _keypoints[..., :2]
+ keypoints_visible = np.minimum(1, _keypoints[..., 2])
+
+ img_path = osp.join(self.data_prefix['img'],
+ imgid2info[img_id]['file_name'])
+
+ instance_info = {
+ 'id': ann['id'],
+ 'img_id': img_id,
+ 'img_path': img_path,
+ 'bbox_center': center,
+ 'bbox_scale': scale,
+ 'bbox': bbox,
+ 'bbox_score': np.ones(1, dtype=np.float32),
+ 'num_keypoints': ann['num_joints'],
+ 'keypoints': keypoints,
+ 'keypoints_visible': keypoints_visible,
+ 'iscrowd': ann['iscrowd'],
+ }
+
+ # val set
+ if 'headbox' in ann:
+ instance_info['headbox'] = np.array(
+ ann['headbox'], dtype=np.float32)
+
+ instance_list.append(instance_info)
+ if instance_info['img_id'] not in used_img_ids:
+ used_img_ids.add(instance_info['img_id'])
+ image_list.append({
+ 'img_id': instance_info['img_id'],
+ 'img_path': instance_info['img_path'],
+ })
+
+ instance_list = sorted(instance_list, key=lambda x: x['id'])
+ return instance_list, image_list
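Both MPII-style loaders above turn the annotated center/scale pair into an xyxy box with ``bbox_cs2xyxy`` after multiplying the scale by ``pixel_std = 200``. Assuming the usual convention that the (pixel) scale is the full box size, the conversion behaves roughly like the sketch below; the actual implementation lives in ``mmpose.structures.bbox``:

    import numpy as np

    def bbox_cs2xyxy_sketch(center: np.ndarray, scale: np.ndarray) -> np.ndarray:
        # center, scale: [..., 2] -> bbox [..., 4] as [x1, y1, x2, y2]
        half = scale * 0.5
        return np.concatenate([center - half, center + half], axis=-1)

    center = np.array([[300.0, 400.0]], dtype=np.float32)
    scale = np.array([[1.2, 1.2]], dtype=np.float32) * 200.0
    print(bbox_cs2xyxy_sketch(center, scale))        # [[180. 280. 420. 520.]]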
diff --git a/mmpose/datasets/datasets/body/ochuman_dataset.py b/mmpose/datasets/datasets/body/ochuman_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..695d090ea998dd530e0f65f902916107e77c4f6d
--- /dev/null
+++ b/mmpose/datasets/datasets/body/ochuman_dataset.py
@@ -0,0 +1,78 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmpose.registry import DATASETS
+from ..base import BaseCocoStyleDataset
+
+
+@DATASETS.register_module()
+class OCHumanDataset(BaseCocoStyleDataset):
+ """OChuman dataset for pose estimation.
+
+ "Pose2Seg: Detection Free Human Instance Segmentation", CVPR'2019.
+ More details can be found in the `paper
+ `__ .
+
+ "Occluded Human (OCHuman)" dataset contains 8110 heavily occluded
+ human instances within 4731 images. OCHuman dataset is designed for
+ validation and testing. To evaluate on OCHuman, the model should be
+ trained on the COCO training set and then tested on OCHuman to
+ measure its robustness to occlusion.
+
+ OCHuman keypoints (same as COCO)::
+
+ 0: 'nose',
+ 1: 'left_eye',
+ 2: 'right_eye',
+ 3: 'left_ear',
+ 4: 'right_ear',
+ 5: 'left_shoulder',
+ 6: 'right_shoulder',
+ 7: 'left_elbow',
+ 8: 'right_elbow',
+ 9: 'left_wrist',
+ 10: 'right_wrist',
+ 11: 'left_hip',
+ 12: 'right_hip',
+ 13: 'left_knee',
+ 14: 'right_knee',
+ 15: 'left_ankle',
+ 16: 'right_ankle'
+
+ Args:
+ ann_file (str): Annotation file path. Default: ''.
+ bbox_file (str, optional): Detection result file path. If
+ ``bbox_file`` is set, detected bboxes loaded from this file will
+ be used instead of ground-truth bboxes. This setting is only for
+ evaluation, i.e., ignored when ``test_mode`` is ``False``.
+ Default: ``None``.
+ data_mode (str): Specifies the mode of data samples: ``'topdown'`` or
+ ``'bottomup'``. In ``'topdown'`` mode, each data sample contains
+ one instance; while in ``'bottomup'`` mode, each data sample
+ contains all instances in an image. Default: ``'topdown'``.
+ metainfo (dict, optional): Meta information for dataset, such as class
+ information. Default: ``None``.
+ data_root (str, optional): The root directory for ``data_prefix`` and
+ ``ann_file``. Default: ``None``.
+ data_prefix (dict, optional): Prefix for training data. Default:
+ ``dict(img=None, ann=None)``.
+ filter_cfg (dict, optional): Config for filter data. Default: `None`.
+ indices (int or Sequence[int], optional): Support using first few
+ data in annotation file to facilitate training/testing on a smaller
+ dataset. Default: ``None`` which means using all ``data_infos``.
+ serialize_data (bool, optional): Whether to hold memory using
+ serialized objects; when enabled, data loader workers can use
+ shared RAM from the master process instead of making a copy.
+ Default: ``True``.
+ pipeline (list, optional): Processing pipeline. Default: [].
+ test_mode (bool, optional): ``test_mode=True`` means in test phase.
+ Default: ``False``.
+ lazy_init (bool, optional): Whether to load annotation during
+ instantiation. In some cases, such as visualization, only the meta
+ information of the dataset is needed, so it is not necessary to
+ load the annotation file; ``BaseDataset`` can skip loading
+ annotations to save time by setting ``lazy_init=True``.
+ Default: ``False``.
+ max_refetch (int, optional): The maximum number of extra cycles to
+ refetch data when ``BaseDataset.prepare_data`` returns a ``None``
+ image. Default: 1000.
+ """
+
+ METAINFO: dict = dict(from_file='configs/_base_/datasets/ochuman.py')
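Since ``OCHumanDataset`` only overrides ``METAINFO``, it is configured like any other COCO-style dataset. A hypothetical top-down evaluation entry, with placeholder paths and an empty pipeline that a real config would fill in:

    # hypothetical config snippet; paths and file names are placeholders
    val_dataset = dict(
        type='OCHumanDataset',
        data_root='data/ochuman/',
        ann_file='annotations/ochuman_val.json',
        data_prefix=dict(img='images/'),
        data_mode='topdown',
        test_mode=True,
        pipeline=[],  # add loading/transform steps in a real config
    )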
diff --git a/mmpose/datasets/datasets/body/posetrack18_dataset.py b/mmpose/datasets/datasets/body/posetrack18_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8110c107f6869085ed795c8f1f0338d2c6ed21d
--- /dev/null
+++ b/mmpose/datasets/datasets/body/posetrack18_dataset.py
@@ -0,0 +1,72 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmpose.registry import DATASETS
+from ..base import BaseCocoStyleDataset
+
+
+@DATASETS.register_module()
+class PoseTrack18Dataset(BaseCocoStyleDataset):
+ """PoseTrack18 dataset for pose estimation.
+
+ "Posetrack: A benchmark for human pose estimation and tracking", CVPR'2018.
+ More details can be found in the `paper
+ `__ .
+
+ PoseTrack2018 keypoints::
+
+ 0: 'nose',
+ 1: 'head_bottom',
+ 2: 'head_top',
+ 3: 'left_ear',
+ 4: 'right_ear',
+ 5: 'left_shoulder',
+ 6: 'right_shoulder',
+ 7: 'left_elbow',
+ 8: 'right_elbow',
+ 9: 'left_wrist',
+ 10: 'right_wrist',
+ 11: 'left_hip',
+ 12: 'right_hip',
+ 13: 'left_knee',
+ 14: 'right_knee',
+ 15: 'left_ankle',
+ 16: 'right_ankle'
+
+ Args:
+ ann_file (str): Annotation file path. Default: ''.
+ bbox_file (str, optional): Detection result file path. If
+ ``bbox_file`` is set, detected bboxes loaded from this file will
+ be used instead of ground-truth bboxes. This setting is only for
+ evaluation, i.e., ignored when ``test_mode`` is ``False``.
+ Default: ``None``.
+ data_mode (str): Specifies the mode of data samples: ``'topdown'`` or
+ ``'bottomup'``. In ``'topdown'`` mode, each data sample contains
+ one instance; while in ``'bottomup'`` mode, each data sample
+ contains all instances in an image. Default: ``'topdown'``.
+ metainfo (dict, optional): Meta information for dataset, such as class
+ information. Default: ``None``.
+ data_root (str, optional): The root directory for ``data_prefix`` and
+ ``ann_file``. Default: ``None``.
+ data_prefix (dict, optional): Prefix for training data. Default:
+ ``dict(img=None, ann=None)``.
+ filter_cfg (dict, optional): Config for filter data. Default: `None`.
+ indices (int or Sequence[int], optional): Support using first few
+ data in annotation file to facilitate training/testing on a smaller
+ dataset. Default: ``None`` which means using all ``data_infos``.
+ serialize_data (bool, optional): Whether to hold memory using
+ serialized objects; when enabled, data loader workers can use
+ shared RAM from the master process instead of making a copy.
+ Default: ``True``.
+ pipeline (list, optional): Processing pipeline. Default: [].
+ test_mode (bool, optional): ``test_mode=True`` means in test phase.
+ Default: ``False``.
+ lazy_init (bool, optional): Whether to load annotation during
+ instantiation. In some cases, such as visualization, only the meta
+ information of the dataset is needed, so it is not necessary to
+ load the annotation file; ``BaseDataset`` can skip loading
+ annotations to save time by setting ``lazy_init=True``.
+ Default: ``False``.
+ max_refetch (int, optional): The maximum number of extra cycles to
+ refetch data when ``BaseDataset.prepare_data`` returns a ``None``
+ image. Default: 1000.
+ """
+
+ METAINFO: dict = dict(from_file='configs/_base_/datasets/posetrack18.py')
diff --git a/mmpose/datasets/datasets/body/posetrack18_video_dataset.py b/mmpose/datasets/datasets/body/posetrack18_video_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc5fe8646c54c6205b7bf6e0e1889dc13db02b02
--- /dev/null
+++ b/mmpose/datasets/datasets/body/posetrack18_video_dataset.py
@@ -0,0 +1,389 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+from typing import Callable, List, Optional, Sequence, Union
+
+import numpy as np
+from mmengine.fileio import exists, get_local_path, load
+from mmengine.utils import is_list_of
+from xtcocotools.coco import COCO
+
+from mmpose.registry import DATASETS
+from mmpose.structures.bbox import bbox_xywh2xyxy
+from ..base import BaseCocoStyleDataset
+
+
+@DATASETS.register_module()
+class PoseTrack18VideoDataset(BaseCocoStyleDataset):
+ """PoseTrack18 dataset for video pose estimation.
+
+ "Posetrack: A benchmark for human pose estimation and tracking", CVPR'2018.
+ More details can be found in the `paper
+ `__ .
+
+ PoseTrack2018 keypoints::
+
+ 0: 'nose',
+ 1: 'head_bottom',
+ 2: 'head_top',
+ 3: 'left_ear',
+ 4: 'right_ear',
+ 5: 'left_shoulder',
+ 6: 'right_shoulder',
+ 7: 'left_elbow',
+ 8: 'right_elbow',
+ 9: 'left_wrist',
+ 10: 'right_wrist',
+ 11: 'left_hip',
+ 12: 'right_hip',
+ 13: 'left_knee',
+ 14: 'right_knee',
+ 15: 'left_ankle',
+ 16: 'right_ankle'
+
+ Args:
+ ann_file (str): Annotation file path. Default: ''.
+ bbox_file (str, optional): Detection result file path. If
+ ``bbox_file`` is set, detected bboxes loaded from this file will
+ be used instead of ground-truth bboxes. This setting is only for
+ evaluation, i.e., ignored when ``test_mode`` is ``False``.
+ Default: ``None``.
+ data_mode (str): Specifies the mode of data samples: ``'topdown'`` or
+ ``'bottomup'``. In ``'topdown'`` mode, each data sample contains
+ one instance; while in ``'bottomup'`` mode, each data sample
+ contains all instances in an image. Default: ``'topdown'``.
+ frame_weights (List[Union[int, float]]): The weight of each frame
+ for aggregation. The first weight is for the center frame, followed
+ by the supporting frames in ascending order of frame indices. Note
+ that the length of
+ ``frame_weights`` should be consistent with the number of sampled
+ frames. Default: [0.0, 1.0]
+ frame_sampler_mode (str): Specifies the mode of frame sampler:
+ ``'fixed'`` or ``'random'``. In ``'fixed'`` mode, each frame
+ index relative to the center frame is fixed, specified by
+ ``frame_indices``, while in ``'random'`` mode, each frame index
+ relative to the center frame is sampled from ``frame_range``
+ with certain randomness. Default: ``'random'``.
+ frame_range (int | List[int], optional): The sampling range of
+ supporting frames in the same video for center frame.
+ Only valid when ``frame_sampler_mode`` is ``'random'``.
+ Default: ``None``.
+ num_sampled_frame (int, optional): The number of sampled frames, except
+ the center frame. Only valid when ``frame_sampler_mode`` is
+ ``'random'``. Default: 1.
+ frame_indices (Sequence[int], optional): The sampled frame indices,
+ including the center frame indicated by 0. Only valid when
+ ``frame_sampler_mode`` is ``'fixed'``. Default: ``None``.
+ ph_fill_len (int): The length of the placeholder to fill in the
+ image filenames. Default: 6
+ metainfo (dict, optional): Meta information for dataset, such as class
+ information. Default: ``None``.
+ data_root (str, optional): The root directory for ``data_prefix`` and
+ ``ann_file``. Default: ``None``.
+ data_prefix (dict, optional): Prefix for training data. Default:
+ ``dict(img='')``.
+ filter_cfg (dict, optional): Config for filter data. Default: `None`.
+ indices (int or Sequence[int], optional): Support using first few
+ data in annotation file to facilitate training/testing on a smaller
+ dataset. Default: ``None`` which means using all ``data_infos``.
+ serialize_data (bool, optional): Whether to hold memory using
+ serialized objects; when enabled, data loader workers can use
+ shared RAM from the master process instead of making a copy.
+ Default: ``True``.
+ pipeline (list, optional): Processing pipeline. Default: [].
+ test_mode (bool, optional): ``test_mode=True`` means in test phase.
+ Default: ``False``.
+ lazy_init (bool, optional): Whether to load annotation during
+ instantiation. In some cases, such as visualization, only the meta
+ information of the dataset is needed, so it is not necessary to
+ load the annotation file; ``BaseDataset`` can skip loading
+ annotations to save time by setting ``lazy_init=True``.
+ Default: ``False``.
+ max_refetch (int, optional): The maximum number of extra cycles to
+ refetch data when ``BaseDataset.prepare_data`` returns a ``None``
+ image. Default: 1000.
+ """
+
+ METAINFO: dict = dict(from_file='configs/_base_/datasets/posetrack18.py')
+
+ def __init__(self,
+ ann_file: str = '',
+ bbox_file: Optional[str] = None,
+ data_mode: str = 'topdown',
+ frame_weights: List[Union[int, float]] = [0.0, 1.0],
+ frame_sampler_mode: str = 'random',
+ frame_range: Optional[Union[int, List[int]]] = None,
+ num_sampled_frame: Optional[int] = None,
+ frame_indices: Optional[Sequence[int]] = None,
+ ph_fill_len: int = 6,
+ metainfo: Optional[dict] = None,
+ data_root: Optional[str] = None,
+ data_prefix: dict = dict(img=''),
+ filter_cfg: Optional[dict] = None,
+ indices: Optional[Union[int, Sequence[int]]] = None,
+ serialize_data: bool = True,
+ pipeline: List[Union[dict, Callable]] = [],
+ test_mode: bool = False,
+ lazy_init: bool = False,
+ max_refetch: int = 1000):
+ assert sum(frame_weights) == 1, 'Invalid `frame_weights`: should sum'\
+ f' to 1.0, but got {frame_weights}.'
+ for weight in frame_weights:
+ assert weight >= 0, 'frame_weight can not be a negative value.'
+ self.frame_weights = np.array(frame_weights)
+
+ if frame_sampler_mode not in {'fixed', 'random'}:
+ raise ValueError(
+ f'{self.__class__.__name__} got invalid frame_sampler_mode: '
+ f'{frame_sampler_mode}. Should be `"fixed"` or `"random"`.')
+ self.frame_sampler_mode = frame_sampler_mode
+
+ if frame_sampler_mode == 'random':
+ assert frame_range is not None, \
+ '`frame_sampler_mode` is set as `random`, ' \
+ 'please specify the `frame_range`.'
+
+ if isinstance(frame_range, int):
+ assert frame_range >= 0, \
+ 'frame_range can not be a negative value.'
+ self.frame_range = [-frame_range, frame_range]
+
+ elif isinstance(frame_range, Sequence):
+ assert len(frame_range) == 2, 'The length must be 2.'
+ assert frame_range[0] <= 0 and frame_range[
+ 1] >= 0 and frame_range[1] > frame_range[
+ 0], 'Invalid `frame_range`'
+ for i in frame_range:
+ assert isinstance(i, int), 'Each element must be int.'
+ self.frame_range = frame_range
+ else:
+ raise TypeError(
+ f'The type of `frame_range` must be int or Sequence, '
+ f'but got {type(frame_range)}.')
+
+ assert num_sampled_frame is not None, \
+ '`frame_sampler_mode` is set as `random`, please specify ' \
+ '`num_sampled_frame`, i.e. the number of sampled frames.'
+
+ assert len(frame_weights) == num_sampled_frame + 1, \
+ f'the length of frame_weights({len(frame_weights)}) '\
+ f'does not match the number of sampled adjacent '\
+ f'frames({num_sampled_frame})'
+ self.frame_indices = None
+ self.num_sampled_frame = num_sampled_frame
+
+ if frame_sampler_mode == 'fixed':
+ assert frame_indices is not None, \
+ '`frame_sampler_mode` is set as `fixed`, ' \
+ 'please specify the `frame_indices`.'
+ assert len(frame_weights) == len(frame_indices), \
+ f'the length of frame_weights({len(frame_weights)}) does not '\
+ f'match the length of frame_indices({len(frame_indices)}).'
+ frame_indices.sort()
+ self.frame_indices = frame_indices
+ self.frame_range = None
+ self.num_sampled_frame = None
+
+ self.ph_fill_len = ph_fill_len
+
+ super().__init__(
+ ann_file=ann_file,
+ bbox_file=bbox_file,
+ data_mode=data_mode,
+ metainfo=metainfo,
+ data_root=data_root,
+ data_prefix=data_prefix,
+ filter_cfg=filter_cfg,
+ indices=indices,
+ serialize_data=serialize_data,
+ pipeline=pipeline,
+ test_mode=test_mode,
+ lazy_init=lazy_init,
+ max_refetch=max_refetch)
+
+ def parse_data_info(self, raw_data_info: dict) -> Optional[dict]:
+ """Parse raw annotation of an instance.
+
+ Args:
+ raw_data_info (dict): Raw data information loaded from
+ ``ann_file``. It should have following contents:
+
+ - ``'raw_ann_info'``: Raw annotation of an instance
+ - ``'raw_img_info'``: Raw information of the image that
+ contains the instance
+
+ Returns:
+ dict: Parsed instance annotation
+ """
+
+ ann = raw_data_info['raw_ann_info']
+ img = raw_data_info['raw_img_info']
+
+ # filter invalid instance
+ if 'bbox' not in ann or 'keypoints' not in ann or max(
+ ann['keypoints']) == 0:
+ return None
+
+ img_w, img_h = img['width'], img['height']
+ # get the bbox of the center frame
+ # get bbox in shape [1, 4], formatted as xywh
+ x, y, w, h = ann['bbox']
+ x1 = np.clip(x, 0, img_w - 1)
+ y1 = np.clip(y, 0, img_h - 1)
+ x2 = np.clip(x + w, 0, img_w - 1)
+ y2 = np.clip(y + h, 0, img_h - 1)
+
+ bbox = np.array([x1, y1, x2, y2], dtype=np.float32).reshape(1, 4)
+
+ # get the keypoints of the center frame
+ # keypoints in shape [1, K, 2] and keypoints_visible in [1, K]
+ _keypoints = np.array(
+ ann['keypoints'], dtype=np.float32).reshape(1, -1, 3)
+ keypoints = _keypoints[..., :2]
+ keypoints_visible = np.minimum(1, _keypoints[..., 2])
+
+ # deal with multiple image paths
+ img_paths: list = []
+ # get the image path of the center frame
+ center_img_path = osp.join(self.data_prefix['img'], img['file_name'])
+ # append the center image path first
+ img_paths.append(center_img_path)
+
+ # select the frame indices
+ if self.frame_sampler_mode == 'fixed':
+ indices = self.frame_indices
+ else: # self.frame_sampler_mode == 'random':
+ low, high = self.frame_range
+ indices = np.random.randint(low, high + 1, self.num_sampled_frame)
+
+ nframes = int(img['nframes'])
+ file_name = img['file_name']
+ ref_idx = int(osp.splitext(osp.basename(file_name))[0])
+
+ for idx in indices:
+ if self.test_mode and idx == 0:
+ continue
+ # the supporting frame index
+ support_idx = ref_idx + idx
+ # clip the frame index to make sure that it does not exceed
+ # the boundings of frame indices
+ support_idx = np.clip(support_idx, 0, nframes - 1)
+ sup_img_path = osp.join(
+ osp.dirname(center_img_path),
+ str(support_idx).zfill(self.ph_fill_len) + '.jpg')
+
+ img_paths.append(sup_img_path)
+
+ data_info = {
+ 'img_id': int(img['frame_id']),
+ 'img_path': img_paths,
+ 'bbox': bbox,
+ 'bbox_score': np.ones(1, dtype=np.float32),
+ 'num_keypoints': ann['num_keypoints'],
+ 'keypoints': keypoints,
+ 'keypoints_visible': keypoints_visible,
+ 'frame_weights': self.frame_weights,
+ 'id': ann['id'],
+ }
+
+ return data_info
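The supporting-frame selection above is easier to follow in isolation. A standalone sketch of the ``'random'`` sampling branch with toy values (the image path is hypothetical):

    import os.path as osp
    import numpy as np

    frame_range, num_sampled_frame, ph_fill_len = [-2, 2], 1, 6
    nframes, ref_idx = 100, 41
    center_img_path = 'images/val/000000_test/000041.jpg'   # hypothetical

    low, high = frame_range
    indices = np.random.randint(low, high + 1, num_sampled_frame)

    img_paths = [center_img_path]
    for idx in indices:
        support_idx = int(np.clip(ref_idx + idx, 0, nframes - 1))
        img_paths.append(
            osp.join(osp.dirname(center_img_path),
                     str(support_idx).zfill(ph_fill_len) + '.jpg'))
    # e.g. ['images/val/000000_test/000041.jpg',
    #       'images/val/000000_test/000043.jpg']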
+
+ def _load_detection_results(self) -> List[dict]:
+ """Load data from detection results with dummy keypoint annotations."""
+ assert exists(self.ann_file), 'Annotation file does not exist'
+ assert exists(self.bbox_file), 'Bbox file does not exist'
+
+ # load detection results
+ det_results = load(self.bbox_file)
+ assert is_list_of(det_results, dict)
+
+ # load coco annotations to build image id-to-name index
+ with get_local_path(self.ann_file) as local_path:
+ self.coco = COCO(local_path)
+
+ # mapping image name to id
+ name2id = {}
+ # mapping image id to name
+ id2name = {}
+ for img_id, image in self.coco.imgs.items():
+ file_name = image['file_name']
+ id2name[img_id] = file_name
+ name2id[file_name] = img_id
+
+ num_keypoints = self.metainfo['num_keypoints']
+ data_list = []
+ id_ = 0
+ for det in det_results:
+ # remove non-human instances
+ if det['category_id'] != 1:
+ continue
+
+ # get the predicted bbox and bbox_score
+ bbox_xywh = np.array(
+ det['bbox'][:4], dtype=np.float32).reshape(1, 4)
+ bbox = bbox_xywh2xyxy(bbox_xywh)
+ bbox_score = np.array(det['score'], dtype=np.float32).reshape(1)
+
+ # use dummy keypoint location and visibility
+ keypoints = np.zeros((1, num_keypoints, 2), dtype=np.float32)
+ keypoints_visible = np.ones((1, num_keypoints), dtype=np.float32)
+
+ # deal with different bbox file formats
+ if 'nframes' in det:
+ nframes = int(det['nframes'])
+ else:
+ if 'image_name' in det:
+ img_id = name2id[det['image_name']]
+ else:
+ img_id = det['image_id']
+ img_ann = self.coco.loadImgs(img_id)[0]
+ nframes = int(img_ann['nframes'])
+
+ # deal with multiple image paths
+ img_paths: list = []
+ if 'image_name' in det:
+ image_name = det['image_name']
+ else:
+ image_name = id2name[det['image_id']]
+ # get the image path of the center frame
+ center_img_path = osp.join(self.data_prefix['img'], image_name)
+ # append the center image path first
+ img_paths.append(center_img_path)
+
+ # "images/val/012834_mpii_test/000000.jpg" -->> "000000.jpg"
+ center_image_name = image_name.split('/')[-1]
+ ref_idx = int(center_image_name.replace('.jpg', ''))
+
+ # select the frame indices
+ if self.frame_sampler_mode == 'fixed':
+ indices = self.frame_indices
+ else: # self.frame_sampler_mode == 'random':
+ low, high = self.frame_range
+ indices = np.random.randint(low, high + 1,
+ self.num_sampled_frame)
+
+ for idx in indices:
+ if self.test_mode and idx == 0:
+ continue
+ # the supporting frame index
+ support_idx = ref_idx + idx
+ # clip the frame index to make sure that it does not exceed
+ # the boundings of frame indices
+ support_idx = np.clip(support_idx, 0, nframes - 1)
+ sup_img_path = center_img_path.replace(
+ center_image_name,
+ str(support_idx).zfill(self.ph_fill_len) + '.jpg')
+
+ img_paths.append(sup_img_path)
+
+ data_list.append({
+ 'img_id': det['image_id'],
+ 'img_path': img_paths,
+ 'frame_weights': self.frame_weights,
+ 'bbox': bbox,
+ 'bbox_score': bbox_score,
+ 'keypoints': keypoints,
+ 'keypoints_visible': keypoints_visible,
+ 'id': id_,
+ })
+
+ id_ += 1
+
+ return data_list
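``_load_detection_results`` reads only a handful of fields from each detection dict. A minimal example of the expected structure (the values are illustrative, not real detections):

    det_results = [
        dict(
            image_id=10003420000,       # or 'image_name' in some bbox files
            category_id=1,              # non-person detections are skipped
            bbox=[55.0, 70.0, 120.0, 230.0],   # xywh; converted to xyxy
            score=0.93,
            # 'nframes' is optional; if absent it is looked up from the
            # COCO annotation file via the image id/name
        ),
    ]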
diff --git a/mmpose/datasets/datasets/face/__init__.py b/mmpose/datasets/datasets/face/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..700cb605f7e5bb5177ce382b0f162edd8959d277
--- /dev/null
+++ b/mmpose/datasets/datasets/face/__init__.py
@@ -0,0 +1,12 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .aflw_dataset import AFLWDataset
+from .coco_wholebody_face_dataset import CocoWholeBodyFaceDataset
+from .cofw_dataset import COFWDataset
+from .face_300w_dataset import Face300WDataset
+from .lapa_dataset import LapaDataset
+from .wflw_dataset import WFLWDataset
+
+__all__ = [
+ 'Face300WDataset', 'WFLWDataset', 'AFLWDataset', 'COFWDataset',
+ 'CocoWholeBodyFaceDataset', 'LapaDataset'
+]
diff --git a/mmpose/datasets/datasets/face/__pycache__/__init__.cpython-38.pyc b/mmpose/datasets/datasets/face/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ccd906c49009614660108413fc1c45a16d52ddc4
Binary files /dev/null and b/mmpose/datasets/datasets/face/__pycache__/__init__.cpython-38.pyc differ
diff --git a/mmpose/datasets/datasets/face/__pycache__/aflw_dataset.cpython-38.pyc b/mmpose/datasets/datasets/face/__pycache__/aflw_dataset.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..724b99e6a5770018a7e1fbc36c72c6bbb7c9012f
Binary files /dev/null and b/mmpose/datasets/datasets/face/__pycache__/aflw_dataset.cpython-38.pyc differ
diff --git a/mmpose/datasets/datasets/face/__pycache__/coco_wholebody_face_dataset.cpython-38.pyc b/mmpose/datasets/datasets/face/__pycache__/coco_wholebody_face_dataset.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7d26f6fe00f0077c32a3c0fc34a7156b1b8b8ec5
Binary files /dev/null and b/mmpose/datasets/datasets/face/__pycache__/coco_wholebody_face_dataset.cpython-38.pyc differ
diff --git a/mmpose/datasets/datasets/face/__pycache__/cofw_dataset.cpython-38.pyc b/mmpose/datasets/datasets/face/__pycache__/cofw_dataset.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..918a8bd74e3060ed34160702559144e4955cb837
Binary files /dev/null and b/mmpose/datasets/datasets/face/__pycache__/cofw_dataset.cpython-38.pyc differ
diff --git a/mmpose/datasets/datasets/face/__pycache__/face_300w_dataset.cpython-38.pyc b/mmpose/datasets/datasets/face/__pycache__/face_300w_dataset.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fcbd9ad4e84e84cb6d25de643a75187d93045813
Binary files /dev/null and b/mmpose/datasets/datasets/face/__pycache__/face_300w_dataset.cpython-38.pyc differ
diff --git a/mmpose/datasets/datasets/face/__pycache__/lapa_dataset.cpython-38.pyc b/mmpose/datasets/datasets/face/__pycache__/lapa_dataset.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..46ed2e5ab63ccf931c226c7bf17857963df9e610
Binary files /dev/null and b/mmpose/datasets/datasets/face/__pycache__/lapa_dataset.cpython-38.pyc differ
diff --git a/mmpose/datasets/datasets/face/__pycache__/wflw_dataset.cpython-38.pyc b/mmpose/datasets/datasets/face/__pycache__/wflw_dataset.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1b92d0794ade8e34977b148a230e8e3469567241
Binary files /dev/null and b/mmpose/datasets/datasets/face/__pycache__/wflw_dataset.cpython-38.pyc differ
diff --git a/mmpose/datasets/datasets/face/aflw_dataset.py b/mmpose/datasets/datasets/face/aflw_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..deda0974bb58ba52371f727e788342b5502987a5
--- /dev/null
+++ b/mmpose/datasets/datasets/face/aflw_dataset.py
@@ -0,0 +1,122 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+from typing import Optional
+
+import numpy as np
+
+from mmpose.registry import DATASETS
+from mmpose.structures.bbox import bbox_cs2xyxy
+from ..base import BaseCocoStyleDataset
+
+
+@DATASETS.register_module()
+class AFLWDataset(BaseCocoStyleDataset):
+ """AFLW dataset for face keypoint localization.
+
+ "Annotated Facial Landmarks in the Wild: A Large-scale,
+ Real-world Database for Facial Landmark Localization".
+ In Proc. First IEEE International Workshop on Benchmarking
+ Facial Image Analysis Technologies, 2011.
+
+ The landmark annotations follow the 19 points mark-up. The definition
+ can be found in `https://www.tugraz.at/institute/icg/research`
+ `/team-bischof/lrs/downloads/aflw/`
+
+ Args:
+ ann_file (str): Annotation file path. Default: ''.
+ bbox_file (str, optional): Detection result file path. If
+ ``bbox_file`` is set, detected bboxes loaded from this file will
+ be used instead of ground-truth bboxes. This setting is only for
+ evaluation, i.e., ignored when ``test_mode`` is ``False``.
+ Default: ``None``.
+ data_mode (str): Specifies the mode of data samples: ``'topdown'`` or
+ ``'bottomup'``. In ``'topdown'`` mode, each data sample contains
+ one instance; while in ``'bottomup'`` mode, each data sample
+ contains all instances in an image. Default: ``'topdown'``.
+ metainfo (dict, optional): Meta information for dataset, such as class
+ information. Default: ``None``.
+ data_root (str, optional): The root directory for ``data_prefix`` and
+ ``ann_file``. Default: ``None``.
+ data_prefix (dict, optional): Prefix for training data. Default:
+ ``dict(img=None, ann=None)``.
+ filter_cfg (dict, optional): Config for filter data. Default: `None`.
+ indices (int or Sequence[int], optional): Support using first few
+ data in annotation file to facilitate training/testing on a smaller
+ dataset. Default: ``None`` which means using all ``data_infos``.
+ serialize_data (bool, optional): Whether to hold memory using
+ serialized objects; when enabled, data loader workers can use
+ shared RAM from the master process instead of making a copy.
+ Default: ``True``.
+ pipeline (list, optional): Processing pipeline. Default: [].
+ test_mode (bool, optional): ``test_mode=True`` means in test phase.
+ Default: ``False``.
+ lazy_init (bool, optional): Whether to load annotation during
+ instantiation. In some cases, such as visualization, only the meta
+ information of the dataset is needed, so it is not necessary to
+ load the annotation file; ``BaseDataset`` can skip loading
+ annotations to save time by setting ``lazy_init=True``.
+ Default: ``False``.
+ max_refetch (int, optional): The maximum number of extra cycles to
+ refetch data when ``BaseDataset.prepare_data`` returns a ``None``
+ image. Default: 1000.
+ """
+
+ METAINFO: dict = dict(from_file='configs/_base_/datasets/aflw.py')
+
+ def parse_data_info(self, raw_data_info: dict) -> Optional[dict]:
+ """Parse raw Face AFLW annotation of an instance.
+
+ Args:
+ raw_data_info (dict): Raw data information loaded from
+ ``ann_file``. It should have following contents:
+
+ - ``'raw_ann_info'``: Raw annotation of an instance
+ - ``'raw_img_info'``: Raw information of the image that
+ contains the instance
+
+ Returns:
+ dict: Parsed instance annotation
+ """
+
+ ann = raw_data_info['raw_ann_info']
+ img = raw_data_info['raw_img_info']
+
+ img_path = osp.join(self.data_prefix['img'], img['file_name'])
+
+ # aflw bbox scales are normalized with factor 200.
+ pixel_std = 200.
+
+ # center, scale in shape [1, 2] and bbox in [1, 4]
+ center = np.array([ann['center']], dtype=np.float32)
+ scale = np.array([[ann['scale'], ann['scale']]],
+ dtype=np.float32) * pixel_std
+ bbox = bbox_cs2xyxy(center, scale)
+
+ # keypoints in shape [1, K, 2] and keypoints_visible in [1, K]
+ _keypoints = np.array(
+ ann['keypoints'], dtype=np.float32).reshape(1, -1, 3)
+ keypoints = _keypoints[..., :2]
+ keypoints_visible = np.minimum(1, _keypoints[..., 2])
+
+ num_keypoints = ann['num_keypoints']
+
+ data_info = {
+ 'img_id': ann['image_id'],
+ 'img_path': img_path,
+ 'bbox': bbox,
+ 'bbox_center': center,
+ 'bbox_scale': scale,
+ 'bbox_score': np.ones(1, dtype=np.float32),
+ 'num_keypoints': num_keypoints,
+ 'keypoints': keypoints,
+ 'keypoints_visible': keypoints_visible,
+ 'iscrowd': ann['iscrowd'],
+ 'id': ann['id'],
+ }
+
+ if self.test_mode:
+ # 'box_size' is used as normalization factor
+ assert 'box_size' in ann, '"box_size" is missing in annotation, '\
+ 'which is required for evaluation.'
+ data_info['box_size'] = ann['box_size']
+
+ return data_info
diff --git a/mmpose/datasets/datasets/face/coco_wholebody_face_dataset.py b/mmpose/datasets/datasets/face/coco_wholebody_face_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc2c5be386012a341879a3910dcf72e5672e5d6f
--- /dev/null
+++ b/mmpose/datasets/datasets/face/coco_wholebody_face_dataset.py
@@ -0,0 +1,115 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+from typing import Optional
+
+import numpy as np
+
+from mmpose.registry import DATASETS
+from ..base import BaseCocoStyleDataset
+
+
+@DATASETS.register_module()
+class CocoWholeBodyFaceDataset(BaseCocoStyleDataset):
+ """CocoWholeBodyDataset for face keypoint localization.
+
+ "Whole-Body Human Pose Estimation in the Wild", ECCV'2020.
+ More details can be found in the `paper
+ `__ .
+
+ The face landmark annotations follow the 68 points mark-up.
+
+ Args:
+ ann_file (str): Annotation file path. Default: ''.
+ bbox_file (str, optional): Detection result file path. If
+ ``bbox_file`` is set, detected bboxes loaded from this file will
+ be used instead of ground-truth bboxes. This setting is only for
+ evaluation, i.e., ignored when ``test_mode`` is ``False``.
+ Default: ``None``.
+ data_mode (str): Specifies the mode of data samples: ``'topdown'`` or
+ ``'bottomup'``. In ``'topdown'`` mode, each data sample contains
+ one instance; while in ``'bottomup'`` mode, each data sample
+ contains all instances in an image. Default: ``'topdown'``.
+ metainfo (dict, optional): Meta information for dataset, such as class
+ information. Default: ``None``.
+ data_root (str, optional): The root directory for ``data_prefix`` and
+ ``ann_file``. Default: ``None``.
+ data_prefix (dict, optional): Prefix for training data. Default:
+ ``dict(img=None, ann=None)``.
+ filter_cfg (dict, optional): Config for filter data. Default: `None`.
+ indices (int or Sequence[int], optional): Support using first few
+ data in annotation file to facilitate training/testing on a smaller
+ dataset. Default: ``None`` which means using all ``data_infos``.
+ serialize_data (bool, optional): Whether to hold memory using
+ serialized objects; when enabled, data loader workers can use
+ shared RAM from the master process instead of making a copy.
+ Default: ``True``.
+ pipeline (list, optional): Processing pipeline. Default: [].
+ test_mode (bool, optional): ``test_mode=True`` means in test phase.
+ Default: ``False``.
+ lazy_init (bool, optional): Whether to load annotation during
+ instantiation. In some cases, such as visualization, only the meta
+ information of the dataset is needed, so it is not necessary to
+ load the annotation file; ``BaseDataset`` can skip loading
+ annotations to save time by setting ``lazy_init=True``.
+ Default: ``False``.
+ max_refetch (int, optional): The maximum number of extra cycles to
+ refetch data when ``BaseDataset.prepare_data`` returns a ``None``
+ image. Default: 1000.
+ """
+
+ METAINFO: dict = dict(
+ from_file='configs/_base_/datasets/coco_wholebody_face.py')
+
+ def parse_data_info(self, raw_data_info: dict) -> Optional[dict]:
+ """Parse raw CocoWholeBody Face annotation of an instance.
+
+ Args:
+ raw_data_info (dict): Raw data information loaded from
+ ``ann_file``. It should have following contents:
+
+ - ``'raw_ann_info'``: Raw annotation of an instance
+ - ``'raw_img_info'``: Raw information of the image that
+ contains the instance
+
+ Returns:
+ dict: Parsed instance annotation
+ """
+
+ ann = raw_data_info['raw_ann_info']
+ img = raw_data_info['raw_img_info']
+
+ # filter invalid instance
+ if not ann['face_valid'] or max(ann['face_kpts']) <= 0:
+ return None
+
+ img_path = osp.join(self.data_prefix['img'], img['file_name'])
+ img_w, img_h = img['width'], img['height']
+
+ # get bbox in shape [1, 4], formatted as xywh
+ x, y, w, h = ann['face_box']
+ x1 = np.clip(x, 0, img_w - 1)
+ y1 = np.clip(y, 0, img_h - 1)
+ x2 = np.clip(x + w, 0, img_w - 1)
+ y2 = np.clip(y + h, 0, img_h - 1)
+
+ bbox = np.array([x1, y1, x2, y2], dtype=np.float32).reshape(1, 4)
+
+ # keypoints in shape [1, K, 2] and keypoints_visible in [1, K]
+ _keypoints = np.array(
+ ann['face_kpts'], dtype=np.float32).reshape(1, -1, 3)
+ keypoints = _keypoints[..., :2]
+ keypoints_visible = np.minimum(1, _keypoints[..., 2])
+
+ num_keypoints = np.count_nonzero(keypoints.max(axis=2))
+
+ data_info = {
+ 'img_id': ann['image_id'],
+ 'img_path': img_path,
+ 'bbox': bbox,
+ 'bbox_score': np.ones(1, dtype=np.float32),
+ 'num_keypoints': num_keypoints,
+ 'keypoints': keypoints,
+ 'keypoints_visible': keypoints_visible,
+ 'iscrowd': ann['iscrowd'],
+ 'id': ann['id'],
+ }
+ return data_info
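The ``num_keypoints`` expression above counts a keypoint as labelled when its larger coordinate is non-zero, matching the COCO convention of storing unlabelled keypoints as (0, 0). A quick check with a toy array:

    import numpy as np

    keypoints = np.array([[[0.0, 0.0],     # unlabelled -> not counted
                           [12.5, 0.0],
                           [3.0, 44.0]]])
    print(np.count_nonzero(keypoints.max(axis=2)))   # 2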
diff --git a/mmpose/datasets/datasets/face/cofw_dataset.py b/mmpose/datasets/datasets/face/cofw_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ec2a37efd8b7fc125ebd87df88bc9c99cd86250
--- /dev/null
+++ b/mmpose/datasets/datasets/face/cofw_dataset.py
@@ -0,0 +1,53 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmpose.registry import DATASETS
+from ..base import BaseCocoStyleDataset
+
+
+@DATASETS.register_module()
+class COFWDataset(BaseCocoStyleDataset):
+ """COFW dataset for face keypoint localization.
+
+ "Robust face landmark estimation under occlusion", ICCV'2013.
+
+ The landmark annotations follow the 29 points mark-up. The definition
+ can be found in `http://www.vision.caltech.edu/xpburgos/ICCV13/`__ .
+
+ Args:
+ ann_file (str): Annotation file path. Default: ''.
+ bbox_file (str, optional): Detection result file path. If
+ ``bbox_file`` is set, detected bboxes loaded from this file will
+ be used instead of ground-truth bboxes. This setting is only for
+ evaluation, i.e., ignored when ``test_mode`` is ``False``.
+ Default: ``None``.
+ data_mode (str): Specifies the mode of data samples: ``'topdown'`` or
+ ``'bottomup'``. In ``'topdown'`` mode, each data sample contains
+ one instance; while in ``'bottomup'`` mode, each data sample
+ contains all instances in an image. Default: ``'topdown'``.
+ metainfo (dict, optional): Meta information for dataset, such as class
+ information. Default: ``None``.
+ data_root (str, optional): The root directory for ``data_prefix`` and
+ ``ann_file``. Default: ``None``.
+ data_prefix (dict, optional): Prefix for training data. Default:
+ ``dict(img=None, ann=None)``.
+ filter_cfg (dict, optional): Config for filter data. Default: `None`.
+ indices (int or Sequence[int], optional): Support using first few
+ data in annotation file to facilitate training/testing on a smaller
+ dataset. Default: ``None`` which means using all ``data_infos``.
+ serialize_data (bool, optional): Whether to hold memory using
+ serialized objects; when enabled, data loader workers can use
+ shared RAM from the master process instead of making a copy.
+ Default: ``True``.
+ pipeline (list, optional): Processing pipeline. Default: [].
+ test_mode (bool, optional): ``test_mode=True`` means in test phase.
+ Default: ``False``.
+ lazy_init (bool, optional): Whether to load annotation during
+ instantiation. In some cases, such as visualization, only the meta
+ information of the dataset is needed, so it is not necessary to
+ load the annotation file; ``BaseDataset`` can skip loading
+ annotations to save time by setting ``lazy_init=True``.
+ Default: ``False``.
+ max_refetch (int, optional): The maximum number of extra cycles to
+ refetch data when ``BaseDataset.prepare_data`` returns a ``None``
+ image. Default: 1000.
+ """
+
+ METAINFO: dict = dict(from_file='configs/_base_/datasets/cofw.py')
diff --git a/mmpose/datasets/datasets/face/face_300w_dataset.py b/mmpose/datasets/datasets/face/face_300w_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..c70e892b4f707dc5990566b760e0a2566eb4a53f
--- /dev/null
+++ b/mmpose/datasets/datasets/face/face_300w_dataset.py
@@ -0,0 +1,112 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+from typing import Optional
+
+import numpy as np
+
+from mmpose.registry import DATASETS
+from mmpose.structures.bbox import bbox_cs2xyxy
+from ..base import BaseCocoStyleDataset
+
+
+@DATASETS.register_module()
+class Face300WDataset(BaseCocoStyleDataset):
+ """300W dataset for face keypoint localization.
+
+ "300 faces In-the-wild challenge: Database and results",
+ Image and Vision Computing (IMAVIS) 2019.
+
+ The landmark annotations follow the 68 points mark-up. The definition
+ can be found in `https://ibug.doc.ic.ac.uk/resources/300-W/`.
+
+ Args:
+ ann_file (str): Annotation file path. Default: ''.
+ bbox_file (str, optional): Detection result file path. If
+ ``bbox_file`` is set, detected bboxes loaded from this file will
+ be used instead of ground-truth bboxes. This setting is only for
+ evaluation, i.e., ignored when ``test_mode`` is ``False``.
+ Default: ``None``.
+ data_mode (str): Specifies the mode of data samples: ``'topdown'`` or
+ ``'bottomup'``. In ``'topdown'`` mode, each data sample contains
+ one instance; while in ``'bottomup'`` mode, each data sample
+ contains all instances in an image. Default: ``'topdown'``.
+ metainfo (dict, optional): Meta information for dataset, such as class
+ information. Default: ``None``.
+ data_root (str, optional): The root directory for ``data_prefix`` and
+ ``ann_file``. Default: ``None``.
+ data_prefix (dict, optional): Prefix for training data. Default:
+ ``dict(img=None, ann=None)``.
+ filter_cfg (dict, optional): Config for filter data. Default: `None`.
+ indices (int or Sequence[int], optional): Support using first few
+ data in annotation file to facilitate training/testing on a smaller
+ dataset. Default: ``None`` which means using all ``data_infos``.
+ serialize_data (bool, optional): Whether to hold memory using
+ serialized objects; when enabled, data loader workers can use
+ shared RAM from the master process instead of making a copy.
+ Default: ``True``.
+ pipeline (list, optional): Processing pipeline. Default: [].
+ test_mode (bool, optional): ``test_mode=True`` means in test phase.
+ Default: ``False``.
+ lazy_init (bool, optional): Whether to load annotation during
+ instantiation. In some cases, such as visualization, only the meta
+ information of the dataset is needed, so it is not necessary to
+ load the annotation file; ``BaseDataset`` can skip loading
+ annotations to save time by setting ``lazy_init=True``.
+ Default: ``False``.
+ max_refetch (int, optional): The maximum number of extra cycles to
+ refetch data when ``BaseDataset.prepare_data`` returns a ``None``
+ image. Default: 1000.
+ """
+
+ METAINFO: dict = dict(from_file='configs/_base_/datasets/300w.py')
+
+ def parse_data_info(self, raw_data_info: dict) -> Optional[dict]:
+ """Parse raw Face300W annotation of an instance.
+
+ Args:
+ raw_data_info (dict): Raw data information loaded from
+ ``ann_file``. It should have following contents:
+
+ - ``'raw_ann_info'``: Raw annotation of an instance
+ - ``'raw_img_info'``: Raw information of the image that
+ contains the instance
+
+ Returns:
+ dict: Parsed instance annotation
+ """
+
+ ann = raw_data_info['raw_ann_info']
+ img = raw_data_info['raw_img_info']
+
+ img_path = osp.join(self.data_prefix['img'], img['file_name'])
+
+ # 300w bbox scales are normalized with factor 200.
+ pixel_std = 200.
+
+ # center, scale in shape [1, 2] and bbox in [1, 4]
+ center = np.array([ann['center']], dtype=np.float32)
+ scale = np.array([[ann['scale'], ann['scale']]],
+ dtype=np.float32) * pixel_std
+ bbox = bbox_cs2xyxy(center, scale)
+
+ # keypoints in shape [1, K, 2] and keypoints_visible in [1, K]
+ _keypoints = np.array(
+ ann['keypoints'], dtype=np.float32).reshape(1, -1, 3)
+ keypoints = _keypoints[..., :2]
+ keypoints_visible = np.minimum(1, _keypoints[..., 2])
+
+ num_keypoints = ann['num_keypoints']
+
+ data_info = {
+ 'img_id': ann['image_id'],
+ 'img_path': img_path,
+ 'bbox': bbox,
+ 'bbox_center': center,
+ 'bbox_scale': scale,
+ 'bbox_score': np.ones(1, dtype=np.float32),
+ 'num_keypoints': num_keypoints,
+ 'keypoints': keypoints,
+ 'keypoints_visible': keypoints_visible,
+ 'iscrowd': ann['iscrowd'],
+ 'id': ann['id'],
+ }
+ return data_info
diff --git a/mmpose/datasets/datasets/face/lapa_dataset.py b/mmpose/datasets/datasets/face/lapa_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a5bdc4ec08cebe690ae1f5f2a659e9c087634ec
--- /dev/null
+++ b/mmpose/datasets/datasets/face/lapa_dataset.py
@@ -0,0 +1,54 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmpose.registry import DATASETS
+from ..base import BaseCocoStyleDataset
+
+
+@DATASETS.register_module()
+class LapaDataset(BaseCocoStyleDataset):
+ """LaPa dataset for face keypoint localization.
+
+ "A New Dataset and Boundary-Attention Semantic Segmentation
+ for Face Parsing", AAAI'2020.
+
+ The landmark annotations follow the 106 points mark-up. The definition
+ can be found in `https://github.com/JDAI-CV/lapa-dataset/`__ .
+
+ Args:
+ ann_file (str): Annotation file path. Default: ''.
+ bbox_file (str, optional): Detection result file path. If
+ ``bbox_file`` is set, detected bboxes loaded from this file will
+ be used instead of ground-truth bboxes. This setting is only for
+ evaluation, i.e., ignored when ``test_mode`` is ``False``.
+ Default: ``None``.
+ data_mode (str): Specifies the mode of data samples: ``'topdown'`` or
+ ``'bottomup'``. In ``'topdown'`` mode, each data sample contains
+ one instance; while in ``'bottomup'`` mode, each data sample
+ contains all instances in an image. Default: ``'topdown'``.
+ metainfo (dict, optional): Meta information for dataset, such as class
+ information. Default: ``None``.
+ data_root (str, optional): The root directory for ``data_prefix`` and
+ ``ann_file``. Default: ``None``.
+ data_prefix (dict, optional): Prefix for training data. Default:
+ ``dict(img=None, ann=None)``.
+ filter_cfg (dict, optional): Config for filtering data. Default: ``None``.
+ indices (int or Sequence[int], optional): Use only the first few
+ entries of the annotation file to facilitate training/testing on a
+ smaller dataset. Default: ``None``, which means using all ``data_infos``.
+ serialize_data (bool, optional): Whether to hold memory as serialized
+ objects. When enabled, data loader workers can read shared RAM from
+ the master process instead of making their own copies.
+ Default: ``True``.
+ pipeline (list, optional): Processing pipeline. Default: [].
+ test_mode (bool, optional): ``test_mode=True`` means the dataset is
+ used in the test phase. Default: ``False``.
+ lazy_init (bool, optional): Whether to defer loading the annotation
+ file until it is actually needed. In some cases, such as
+ visualization, only the meta information of the dataset is needed,
+ so ``BaseDataset`` can skip loading annotations to save time by
+ setting ``lazy_init=True``. Default: ``False``.
+ max_refetch (int, optional): The maximum number of extra cycles to
+ fetch a valid image when ``BaseDataset.prepare_data`` returns a
+ ``None`` image. Default: 1000.
+ """
+
+ METAINFO: dict = dict(from_file='configs/_base_/datasets/lapa.py')
diff --git a/mmpose/datasets/datasets/face/wflw_dataset.py b/mmpose/datasets/datasets/face/wflw_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c1c23053ce87fc92a234334e637e7a8e0402a9e
--- /dev/null
+++ b/mmpose/datasets/datasets/face/wflw_dataset.py
@@ -0,0 +1,112 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+from typing import Optional
+
+import numpy as np
+
+from mmpose.registry import DATASETS
+from mmpose.structures.bbox import bbox_cs2xyxy
+from ..base import BaseCocoStyleDataset
+
+
+@DATASETS.register_module()
+class WFLWDataset(BaseCocoStyleDataset):
+ """WFLW dataset for face keypoint localization.
+
+ "Look at Boundary: A Boundary-Aware Face Alignment Algorithm",
+ CVPR'2018.
+
+ The landmark annotations follow the 98 points mark-up. The definition
+ can be found in `https://wywu.github.io/projects/LAB/WFLW.html`__ .
+
+ Args:
+ ann_file (str): Annotation file path. Default: ''.
+ bbox_file (str, optional): Detection result file path. If
+ ``bbox_file`` is set, detected bboxes loaded from this file will
+ be used instead of ground-truth bboxes. This setting is only for
+ evaluation, i.e., ignored when ``test_mode`` is ``False``.
+ Default: ``None``.
+ data_mode (str): Specifies the mode of data samples: ``'topdown'`` or
+ ``'bottomup'``. In ``'topdown'`` mode, each data sample contains
+ one instance; while in ``'bottomup'`` mode, each data sample
+ contains all instances in an image. Default: ``'topdown'``.
+ metainfo (dict, optional): Meta information for dataset, such as class
+ information. Default: ``None``.
+ data_root (str, optional): The root directory for ``data_prefix`` and
+ ``ann_file``. Default: ``None``.
+ data_prefix (dict, optional): Prefix for training data. Default:
+ ``dict(img=None, ann=None)``.
+ filter_cfg (dict, optional): Config for filtering data. Default: ``None``.
+ indices (int or Sequence[int], optional): Use only the first few
+ entries of the annotation file to facilitate training/testing on a
+ smaller dataset. Default: ``None``, which means using all ``data_infos``.
+ serialize_data (bool, optional): Whether to hold memory as serialized
+ objects. When enabled, data loader workers can read shared RAM from
+ the master process instead of making their own copies.
+ Default: ``True``.
+ pipeline (list, optional): Processing pipeline. Default: [].
+ test_mode (bool, optional): ``test_mode=True`` means the dataset is
+ used in the test phase. Default: ``False``.
+ lazy_init (bool, optional): Whether to defer loading the annotation
+ file until it is actually needed. In some cases, such as
+ visualization, only the meta information of the dataset is needed,
+ so ``BaseDataset`` can skip loading annotations to save time by
+ setting ``lazy_init=True``. Default: ``False``.
+ max_refetch (int, optional): The maximum number of extra cycles to
+ fetch a valid image when ``BaseDataset.prepare_data`` returns a
+ ``None`` image. Default: 1000.
+ """
+
+ METAINFO: dict = dict(from_file='configs/_base_/datasets/wflw.py')
+
+ def parse_data_info(self, raw_data_info: dict) -> Optional[dict]:
+ """Parse raw Face WFLW annotation of an instance.
+
+ Args:
+ raw_data_info (dict): Raw data information loaded from
+ ``ann_file``. It should have following contents:
+
+ - ``'raw_ann_info'``: Raw annotation of an instance
+ - ``'raw_img_info'``: Raw information of the image that
+ contains the instance
+
+ Returns:
+ dict: Parsed instance annotation
+ """
+
+ ann = raw_data_info['raw_ann_info']
+ img = raw_data_info['raw_img_info']
+
+ img_path = osp.join(self.data_prefix['img'], img['file_name'])
+
+ # wflw bbox scales are normalized with factor 200.
+ pixel_std = 200.
+
+ # center, scale in shape [1, 2] and bbox in [1, 4]
+ center = np.array([ann['center']], dtype=np.float32)
+ scale = np.array([[ann['scale'], ann['scale']]],
+ dtype=np.float32) * pixel_std
+ bbox = bbox_cs2xyxy(center, scale)
+
+ # keypoints in shape [1, K, 2] and keypoints_visible in [1, K]
+ _keypoints = np.array(
+ ann['keypoints'], dtype=np.float32).reshape(1, -1, 3)
+ keypoints = _keypoints[..., :2]
+ keypoints_visible = np.minimum(1, _keypoints[..., 2])
+
+ num_keypoints = ann['num_keypoints']
+
+ data_info = {
+ 'img_id': ann['image_id'],
+ 'img_path': img_path,
+ 'bbox': bbox,
+ 'bbox_center': center,
+ 'bbox_scale': scale,
+ 'bbox_score': np.ones(1, dtype=np.float32),
+ 'num_keypoints': num_keypoints,
+ 'keypoints': keypoints,
+ 'keypoints_visible': keypoints_visible,
+ 'iscrowd': ann['iscrowd'],
+ 'id': ann['id'],
+ }
+ return data_info
diff --git a/mmpose/datasets/datasets/fashion/__init__.py b/mmpose/datasets/datasets/fashion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8be25dede3d16dfb7754c794d86d7f236e8f647b
--- /dev/null
+++ b/mmpose/datasets/datasets/fashion/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .deepfashion2_dataset import DeepFashion2Dataset
+from .deepfashion_dataset import DeepFashionDataset
+
+__all__ = ['DeepFashionDataset', 'DeepFashion2Dataset']
diff --git a/mmpose/datasets/datasets/fashion/__pycache__/__init__.cpython-38.pyc b/mmpose/datasets/datasets/fashion/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..023eb00097f5274f331bfd4571e51264af1edb4a
Binary files /dev/null and b/mmpose/datasets/datasets/fashion/__pycache__/__init__.cpython-38.pyc differ
diff --git a/mmpose/datasets/datasets/fashion/__pycache__/deepfashion2_dataset.cpython-38.pyc b/mmpose/datasets/datasets/fashion/__pycache__/deepfashion2_dataset.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e773929a8a21d85be41798b06303fa59ffd5580c
Binary files /dev/null and b/mmpose/datasets/datasets/fashion/__pycache__/deepfashion2_dataset.cpython-38.pyc differ
diff --git a/mmpose/datasets/datasets/fashion/__pycache__/deepfashion_dataset.cpython-38.pyc b/mmpose/datasets/datasets/fashion/__pycache__/deepfashion_dataset.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5b90b5436aca0c00ed0872ab0f9e27c0f00d99b5
Binary files /dev/null and b/mmpose/datasets/datasets/fashion/__pycache__/deepfashion_dataset.cpython-38.pyc differ
diff --git a/mmpose/datasets/datasets/fashion/deepfashion2_dataset.py b/mmpose/datasets/datasets/fashion/deepfashion2_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3cde9bf97be254927aa6a06f46bdcc225f14283
--- /dev/null
+++ b/mmpose/datasets/datasets/fashion/deepfashion2_dataset.py
@@ -0,0 +1,10 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmpose.registry import DATASETS
+from ..base import BaseCocoStyleDataset
+
+
+@DATASETS.register_module(name='DeepFashion2Dataset')
+class DeepFashion2Dataset(BaseCocoStyleDataset):
+ """DeepFashion2 dataset for fashion landmark detection."""
+
+ METAINFO: dict = dict(from_file='configs/_base_/datasets/deepfashion2.py')
diff --git a/mmpose/datasets/datasets/fashion/deepfashion_dataset.py b/mmpose/datasets/datasets/fashion/deepfashion_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0aa4937323e41333d48a82a11862e68ffc697f0
--- /dev/null
+++ b/mmpose/datasets/datasets/fashion/deepfashion_dataset.py
@@ -0,0 +1,137 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Callable, List, Optional, Sequence, Union
+
+from mmpose.registry import DATASETS
+from ..base import BaseCocoStyleDataset
+
+
+@DATASETS.register_module()
+class DeepFashionDataset(BaseCocoStyleDataset):
+ """DeepFashion dataset (full-body clothes) for fashion landmark detection.
+
+ "DeepFashion: Powering Robust Clothes Recognition
+ and Retrieval with Rich Annotations", CVPR'2016.
+ "Fashion Landmark Detection in the Wild", ECCV'2016.
+
+ The dataset contains 3 categories for full-body, upper-body and lower-body.
+
+ Fashion landmark indexes for upper-body clothes::
+
+ 0: 'left collar',
+ 1: 'right collar',
+ 2: 'left sleeve',
+ 3: 'right sleeve',
+ 4: 'left hem',
+ 5: 'right hem'
+
+ Fashion landmark indexes for lower-body clothes::
+
+ 0: 'left waistline',
+ 1: 'right waistline',
+ 2: 'left hem',
+ 3: 'right hem'
+
+ Fashion landmark indexes for full-body clothes::
+
+ 0: 'left collar',
+ 1: 'right collar',
+ 2: 'left sleeve',
+ 3: 'right sleeve',
+ 4: 'left waistline',
+ 5: 'right waistline',
+ 6: 'left hem',
+ 7: 'right hem'
+
+ Args:
+ ann_file (str): Annotation file path. Default: ''.
+ subset (str): Specifies the subset of body: ``'full'``, ``'upper'`` or
+ ``'lower'``. Default: '', which means ``'full'``.
+ bbox_file (str, optional): Detection result file path. If
+ ``bbox_file`` is set, detected bboxes loaded from this file will
+ be used instead of ground-truth bboxes. This setting is only for
+ evaluation, i.e., ignored when ``test_mode`` is ``False``.
+ Default: ``None``.
+ data_mode (str): Specifies the mode of data samples: ``'topdown'`` or
+ ``'bottomup'``. In ``'topdown'`` mode, each data sample contains
+ one instance; while in ``'bottomup'`` mode, each data sample
+ contains all instances in an image. Default: ``'topdown'``.
+ metainfo (dict, optional): Meta information for dataset, such as class
+ information. Default: ``None``.
+ data_root (str, optional): The root directory for ``data_prefix`` and
+ ``ann_file``. Default: ``None``.
+ data_prefix (dict, optional): Prefix for training data. Default:
+ ``dict(img='')``.
+ filter_cfg (dict, optional): Config for filtering data. Default: ``None``.
+ indices (int or Sequence[int], optional): Use only the first few
+ entries of the annotation file to facilitate training/testing on a
+ smaller dataset. Default: ``None``, which means using all ``data_infos``.
+ serialize_data (bool, optional): Whether to hold memory as serialized
+ objects. When enabled, data loader workers can read shared RAM from
+ the master process instead of making their own copies.
+ Default: ``True``.
+ pipeline (list, optional): Processing pipeline. Default: [].
+ test_mode (bool, optional): ``test_mode=True`` means the dataset is
+ used in the test phase. Default: ``False``.
+ lazy_init (bool, optional): Whether to defer loading the annotation
+ file until it is actually needed. In some cases, such as
+ visualization, only the meta information of the dataset is needed,
+ so ``BaseDataset`` can skip loading annotations to save time by
+ setting ``lazy_init=True``. Default: ``False``.
+ max_refetch (int, optional): The maximum number of extra cycles to
+ fetch a valid image when ``BaseDataset.prepare_data`` returns a
+ ``None`` image. Default: 1000.
+ """
+
+ def __init__(self,
+ ann_file: str = '',
+ subset: str = '',
+ bbox_file: Optional[str] = None,
+ data_mode: str = 'topdown',
+ metainfo: Optional[dict] = None,
+ data_root: Optional[str] = None,
+ data_prefix: dict = dict(img=''),
+ filter_cfg: Optional[dict] = None,
+ indices: Optional[Union[int, Sequence[int]]] = None,
+ serialize_data: bool = True,
+ pipeline: List[Union[dict, Callable]] = [],
+ test_mode: bool = False,
+ lazy_init: bool = False,
+ max_refetch: int = 1000):
+ self._check_subset_and_metainfo(subset)
+
+ super().__init__(
+ ann_file=ann_file,
+ bbox_file=bbox_file,
+ data_mode=data_mode,
+ metainfo=metainfo,
+ data_root=data_root,
+ data_prefix=data_prefix,
+ filter_cfg=filter_cfg,
+ indices=indices,
+ serialize_data=serialize_data,
+ pipeline=pipeline,
+ test_mode=test_mode,
+ lazy_init=lazy_init,
+ max_refetch=max_refetch)
+
+ @classmethod
+ def _check_subset_and_metainfo(cls, subset: str = '') -> None:
+ """Check the subset of body and set the corresponding metainfo.
+
+ Args:
+ subset(str): the subset of body: could be ``'full'``, ``'upper'``
+ or ``'lower'``. Default: '', which means ``'full'``.
+ """
+ if subset == '' or subset == 'full':
+ cls.METAINFO = dict(
+ from_file='configs/_base_/datasets/deepfashion_full.py')
+ elif subset == 'upper':
+ cls.METAINFO = dict(
+ from_file='configs/_base_/datasets/deepfashion_upper.py')
+ elif subset == 'lower':
+ cls.METAINFO = dict(
+ from_file='configs/_base_/datasets/deepfashion_lower.py')
+ else:
+ raise ValueError(
+ f'{cls.__name__} got invalid subset: '
+ f'{subset}. Should be "full", "lower" or "upper".')
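A hedged usage sketch for the subset mechanism above: _check_subset_and_metainfo swaps the METAINFO config before the parent constructor runs. The annotation paths are hypothetical placeholders, and the snippet assumes mmpose is installed so the base class and its configs can be resolved.

# Illustrative sketch (not part of the diff); paths are hypothetical.
from mmpose.datasets.datasets.fashion import DeepFashionDataset

DeepFashionDataset._check_subset_and_metainfo('upper')
print(DeepFashionDataset.METAINFO)
# {'from_file': 'configs/_base_/datasets/deepfashion_upper.py'}

dataset = DeepFashionDataset(
    ann_file='annotations/fld_upper_train.json',  # hypothetical path
    subset='upper',
    data_root='data/fld',                         # hypothetical path
    data_prefix=dict(img='img/'),
    pipeline=[],
    lazy_init=True)  # defer loading annotations, e.g. to inspect metainfo only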
diff --git a/mmpose/datasets/datasets/hand/__init__.py b/mmpose/datasets/datasets/hand/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d5e2222be9e981c51b38ac879b534966ec6aa861
--- /dev/null
+++ b/mmpose/datasets/datasets/hand/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .coco_wholebody_hand_dataset import CocoWholeBodyHandDataset
+from .freihand_dataset import FreiHandDataset
+from .onehand10k_dataset import OneHand10KDataset
+from .panoptic_hand2d_dataset import PanopticHand2DDataset
+from .rhd2d_dataset import Rhd2DDataset
+
+__all__ = [
+ 'OneHand10KDataset', 'FreiHandDataset', 'PanopticHand2DDataset',
+ 'Rhd2DDataset', 'CocoWholeBodyHandDataset'
+]
diff --git a/mmpose/datasets/datasets/hand/__pycache__/__init__.cpython-38.pyc b/mmpose/datasets/datasets/hand/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4996f11b284444f4f56f9c59626bd57135355a37
Binary files /dev/null and b/mmpose/datasets/datasets/hand/__pycache__/__init__.cpython-38.pyc differ
diff --git a/mmpose/datasets/datasets/hand/__pycache__/coco_wholebody_hand_dataset.cpython-38.pyc b/mmpose/datasets/datasets/hand/__pycache__/coco_wholebody_hand_dataset.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1d99b789df71e7d2ab77040055c0c15458ec2d16
Binary files /dev/null and b/mmpose/datasets/datasets/hand/__pycache__/coco_wholebody_hand_dataset.cpython-38.pyc differ
diff --git a/mmpose/datasets/datasets/hand/__pycache__/freihand_dataset.cpython-38.pyc b/mmpose/datasets/datasets/hand/__pycache__/freihand_dataset.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3ab8c55ae6b9e45ee4763bb84bc3a45b9f7c02f5
Binary files /dev/null and b/mmpose/datasets/datasets/hand/__pycache__/freihand_dataset.cpython-38.pyc differ
diff --git a/mmpose/datasets/datasets/hand/__pycache__/onehand10k_dataset.cpython-38.pyc b/mmpose/datasets/datasets/hand/__pycache__/onehand10k_dataset.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..181339aa6f8122d17590f93c3d016fd44a71966f
Binary files /dev/null and b/mmpose/datasets/datasets/hand/__pycache__/onehand10k_dataset.cpython-38.pyc differ
diff --git a/mmpose/datasets/datasets/hand/__pycache__/panoptic_hand2d_dataset.cpython-38.pyc b/mmpose/datasets/datasets/hand/__pycache__/panoptic_hand2d_dataset.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..55a5d3d52cf5d03826d21baf37887eaf0145622a
Binary files /dev/null and b/mmpose/datasets/datasets/hand/__pycache__/panoptic_hand2d_dataset.cpython-38.pyc differ
diff --git a/mmpose/datasets/datasets/hand/__pycache__/rhd2d_dataset.cpython-38.pyc b/mmpose/datasets/datasets/hand/__pycache__/rhd2d_dataset.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0602b0335da5d433cdfe49b6be1fa13ab689a5df
Binary files /dev/null and b/mmpose/datasets/datasets/hand/__pycache__/rhd2d_dataset.cpython-38.pyc differ
diff --git a/mmpose/datasets/datasets/hand/coco_wholebody_hand_dataset.py b/mmpose/datasets/datasets/hand/coco_wholebody_hand_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..dba0132f584f812c323ea040675931a14ca8841c
--- /dev/null
+++ b/mmpose/datasets/datasets/hand/coco_wholebody_hand_dataset.py
@@ -0,0 +1,148 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+from typing import List, Tuple
+
+import numpy as np
+from mmengine.fileio import exists, get_local_path
+from xtcocotools.coco import COCO
+
+from mmpose.registry import DATASETS
+from mmpose.structures.bbox import bbox_xywh2xyxy
+from ..base import BaseCocoStyleDataset
+
+
+@DATASETS.register_module()
+class CocoWholeBodyHandDataset(BaseCocoStyleDataset):
+ """CocoWholeBodyDataset for hand pose estimation.
+
+ "Whole-Body Human Pose Estimation in the Wild", ECCV'2020.
+ More details can be found in the `paper
+ <https://arxiv.org/abs/2007.11858>`__ .
+
+ COCO-WholeBody Hand keypoints::
+
+ 0: 'wrist',
+ 1: 'thumb1',
+ 2: 'thumb2',
+ 3: 'thumb3',
+ 4: 'thumb4',
+ 5: 'forefinger1',
+ 6: 'forefinger2',
+ 7: 'forefinger3',
+ 8: 'forefinger4',
+ 9: 'middle_finger1',
+ 10: 'middle_finger2',
+ 11: 'middle_finger3',
+ 12: 'middle_finger4',
+ 13: 'ring_finger1',
+ 14: 'ring_finger2',
+ 15: 'ring_finger3',
+ 16: 'ring_finger4',
+ 17: 'pinky_finger1',
+ 18: 'pinky_finger2',
+ 19: 'pinky_finger3',
+ 20: 'pinky_finger4'
+
+ Args:
+ ann_file (str): Annotation file path. Default: ''.
+ bbox_file (str, optional): Detection result file path. If
+ ``bbox_file`` is set, detected bboxes loaded from this file will
+ be used instead of ground-truth bboxes. This setting is only for
+ evaluation, i.e., ignored when ``test_mode`` is ``False``.
+ Default: ``None``.
+ data_mode (str): Specifies the mode of data samples: ``'topdown'`` or
+ ``'bottomup'``. In ``'topdown'`` mode, each data sample contains
+ one instance; while in ``'bottomup'`` mode, each data sample
+ contains all instances in an image. Default: ``'topdown'``.
+ metainfo (dict, optional): Meta information for dataset, such as class
+ information. Default: ``None``.
+ data_root (str, optional): The root directory for ``data_prefix`` and
+ ``ann_file``. Default: ``None``.
+ data_prefix (dict, optional): Prefix for training data. Default:
+ ``dict(img=None, ann=None)``.
+ filter_cfg (dict, optional): Config for filtering data. Default: ``None``.
+ indices (int or Sequence[int], optional): Use only the first few
+ entries of the annotation file to facilitate training/testing on a
+ smaller dataset. Default: ``None``, which means using all ``data_infos``.
+ serialize_data (bool, optional): Whether to hold memory as serialized
+ objects. When enabled, data loader workers can read shared RAM from
+ the master process instead of making their own copies.
+ Default: ``True``.
+ pipeline (list, optional): Processing pipeline. Default: [].
+ test_mode (bool, optional): ``test_mode=True`` means the dataset is
+ used in the test phase. Default: ``False``.
+ lazy_init (bool, optional): Whether to defer loading the annotation
+ file until it is actually needed. In some cases, such as
+ visualization, only the meta information of the dataset is needed,
+ so ``BaseDataset`` can skip loading annotations to save time by
+ setting ``lazy_init=True``. Default: ``False``.
+ max_refetch (int, optional): The maximum number of extra cycles to
+ fetch a valid image when ``BaseDataset.prepare_data`` returns a
+ ``None`` image. Default: 1000.
+ """
+
+ METAINFO: dict = dict(
+ from_file='configs/_base_/datasets/coco_wholebody_hand.py')
+
+ def _load_annotations(self) -> Tuple[List[dict], List[dict]]:
+ """Load data from annotations in COCO format."""
+
+ assert exists(self.ann_file), 'Annotation file does not exist'
+
+ with get_local_path(self.ann_file) as local_path:
+ self.coco = COCO(local_path)
+ instance_list = []
+ image_list = []
+ id = 0
+
+ for img_id in self.coco.getImgIds():
+ img = self.coco.loadImgs(img_id)[0]
+
+ img.update({
+ 'img_id':
+ img_id,
+ 'img_path':
+ osp.join(self.data_prefix['img'], img['file_name']),
+ })
+ image_list.append(img)
+
+ ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=False)
+ anns = self.coco.loadAnns(ann_ids)
+ for ann in anns:
+ for type in ['left', 'right']:
+ # filter out invalid hand annotations; there might be two
+ # valid instances (left and right hands) in one image
+ if ann[f'{type}hand_valid'] and max(
+ ann[f'{type}hand_kpts']) > 0:
+
+ bbox_xywh = np.array(
+ ann[f'{type}hand_box'],
+ dtype=np.float32).reshape(1, 4)
+
+ bbox = bbox_xywh2xyxy(bbox_xywh)
+
+ _keypoints = np.array(
+ ann[f'{type}hand_kpts'],
+ dtype=np.float32).reshape(1, -1, 3)
+ keypoints = _keypoints[..., :2]
+ keypoints_visible = np.minimum(1, _keypoints[..., 2])
+
+ num_keypoints = np.count_nonzero(keypoints.max(axis=2))
+
+ instance_info = {
+ 'img_id': ann['image_id'],
+ 'img_path': img['img_path'],
+ 'bbox': bbox,
+ 'bbox_score': np.ones(1, dtype=np.float32),
+ 'num_keypoints': num_keypoints,
+ 'keypoints': keypoints,
+ 'keypoints_visible': keypoints_visible,
+ 'iscrowd': ann['iscrowd'],
+ 'segmentation': ann['segmentation'],
+ 'id': id,
+ }
+ instance_list.append(instance_info)
+ id = id + 1
+
+ instance_list = sorted(instance_list, key=lambda x: x['id'])
+ return instance_list, image_list
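The left/right filtering in _load_annotations is the only place this class diverges from the base loader, so here is a standalone sketch (numbers invented, field names taken from the code above) showing which hand annotations survive the filter.

# Illustrative sketch (not part of the diff): the left/right validity filter
# applied to a mocked COCO-WholeBody annotation.
import numpy as np

ann = {
    'lefthand_valid': True,
    'lefthand_box': [30.0, 40.0, 50.0, 60.0],     # xywh
    'lefthand_kpts': [35.0, 45.0, 2.0] * 21,      # 21 keypoints as x, y, score
    'righthand_valid': False,
    'righthand_box': [0.0, 0.0, 0.0, 0.0],
    'righthand_kpts': [0.0, 0.0, 0.0] * 21,
}

for side in ('left', 'right'):
    # same validity test as _load_annotations above
    if ann[f'{side}hand_valid'] and max(ann[f'{side}hand_kpts']) > 0:
        x, y, w, h = ann[f'{side}hand_box']
        bbox = np.array([[x, y, x + w, y + h]], dtype=np.float32)  # xyxy
        kpts = np.array(ann[f'{side}hand_kpts'],
                        dtype=np.float32).reshape(1, -1, 3)
        keypoints_visible = np.minimum(1, kpts[..., 2])
        print(side, bbox.tolist(), int(keypoints_visible.sum()))
# only the left hand yields an instance; the invalid right hand is skipped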
diff --git a/mmpose/datasets/datasets/hand/freihand_dataset.py b/mmpose/datasets/datasets/hand/freihand_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f0e23cdd577d12e6d20656fde59f7da58a45150
--- /dev/null
+++ b/mmpose/datasets/datasets/hand/freihand_dataset.py
@@ -0,0 +1,128 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+from typing import Optional
+
+import numpy as np
+
+from mmpose.registry import DATASETS
+from ..base import BaseCocoStyleDataset
+
+
+@DATASETS.register_module()
+class FreiHandDataset(BaseCocoStyleDataset):
+ """FreiHand dataset for hand pose estimation.
+
+ "FreiHAND: A Dataset for Markerless Capture of Hand Pose
+ and Shape from Single RGB Images", ICCV'2019.
+ More details can be found in the `paper
+ <https://arxiv.org/abs/1909.04349>`__ .
+
+ FreiHand keypoints::
+
+ 0: 'wrist',
+ 1: 'thumb1',
+ 2: 'thumb2',
+ 3: 'thumb3',
+ 4: 'thumb4',
+ 5: 'forefinger1',
+ 6: 'forefinger2',
+ 7: 'forefinger3',
+ 8: 'forefinger4',
+ 9: 'middle_finger1',
+ 10: 'middle_finger2',
+ 11: 'middle_finger3',
+ 12: 'middle_finger4',
+ 13: 'ring_finger1',
+ 14: 'ring_finger2',
+ 15: 'ring_finger3',
+ 16: 'ring_finger4',
+ 17: 'pinky_finger1',
+ 18: 'pinky_finger2',
+ 19: 'pinky_finger3',
+ 20: 'pinky_finger4'
+
+ Args:
+ ann_file (str): Annotation file path. Default: ''.
+ bbox_file (str, optional): Detection result file path. If
+ ``bbox_file`` is set, detected bboxes loaded from this file will
+ be used instead of ground-truth bboxes. This setting is only for
+ evaluation, i.e., ignored when ``test_mode`` is ``False``.
+ Default: ``None``.
+ data_mode (str): Specifies the mode of data samples: ``'topdown'`` or
+ ``'bottomup'``. In ``'topdown'`` mode, each data sample contains
+ one instance; while in ``'bottomup'`` mode, each data sample
+ contains all instances in an image. Default: ``'topdown'``.
+ metainfo (dict, optional): Meta information for dataset, such as class
+ information. Default: ``None``.
+ data_root (str, optional): The root directory for ``data_prefix`` and
+ ``ann_file``. Default: ``None``.
+ data_prefix (dict, optional): Prefix for training data. Default:
+ ``dict(img=None, ann=None)``.
+ filter_cfg (dict, optional): Config for filtering data. Default: ``None``.
+ indices (int or Sequence[int], optional): Use only the first few
+ entries of the annotation file to facilitate training/testing on a
+ smaller dataset. Default: ``None``, which means using all ``data_infos``.
+ serialize_data (bool, optional): Whether to hold memory as serialized
+ objects. When enabled, data loader workers can read shared RAM from
+ the master process instead of making their own copies.
+ Default: ``True``.
+ pipeline (list, optional): Processing pipeline. Default: [].
+ test_mode (bool, optional): ``test_mode=True`` means the dataset is
+ used in the test phase. Default: ``False``.
+ lazy_init (bool, optional): Whether to defer loading the annotation
+ file until it is actually needed. In some cases, such as
+ visualization, only the meta information of the dataset is needed,
+ so ``BaseDataset`` can skip loading annotations to save time by
+ setting ``lazy_init=True``. Default: ``False``.
+ max_refetch (int, optional): The maximum number of extra cycles to
+ fetch a valid image when ``BaseDataset.prepare_data`` returns a
+ ``None`` image. Default: 1000.
+ """
+
+ METAINFO: dict = dict(from_file='configs/_base_/datasets/freihand2d.py')
+
+ def parse_data_info(self, raw_data_info: dict) -> Optional[dict]:
+ """Parse raw COCO annotation of an instance.
+
+ Args:
+ raw_data_info (dict): Raw data information loaded from
+ ``ann_file``. It should have following contents:
+
+ - ``'raw_ann_info'``: Raw annotation of an instance
+ - ``'raw_img_info'``: Raw information of the image that
+ contains the instance
+
+ Returns:
+ dict: Parsed instance annotation
+ """
+
+ ann = raw_data_info['raw_ann_info']
+ img = raw_data_info['raw_img_info']
+
+ img_path = osp.join(self.data_prefix['img'], img['file_name'])
+
+ # use the entire image which is 224x224
+ bbox = np.array([0, 0, 224, 224], dtype=np.float32).reshape(1, 4)
+
+ # keypoints in shape [1, K, 2] and keypoints_visible in [1, K]
+ _keypoints = np.array(
+ ann['keypoints'], dtype=np.float32).reshape(1, -1, 3)
+ keypoints = _keypoints[..., :2]
+ keypoints_visible = np.minimum(1, _keypoints[..., 2])
+
+ num_keypoints = np.count_nonzero(keypoints.max(axis=2))
+
+ data_info = {
+ 'img_id': ann['image_id'],
+ 'img_path': img_path,
+ 'bbox': bbox,
+ 'bbox_score': np.ones(1, dtype=np.float32),
+ 'num_keypoints': num_keypoints,
+ 'keypoints': keypoints,
+ 'keypoints_visible': keypoints_visible,
+ 'iscrowd': ann['iscrowd'],
+ 'segmentation': ann['segmentation'],
+ 'id': ann['id'],
+ }
+
+ return data_info
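Since FreiHAND provides no per-instance boxes, parse_data_info pins the bbox to the full 224x224 image and derives num_keypoints from the coordinates alone. A small numpy sketch with made-up values:

# Illustrative sketch (not part of the diff): whole-image bbox and the
# coordinate-based keypoint count used above.
import numpy as np

bbox = np.array([0, 0, 224, 224], dtype=np.float32).reshape(1, 4)

_keypoints = np.zeros((1, 21, 3), dtype=np.float32)
_keypoints[0, :5, :] = [[10, 20, 1]] * 5   # pretend only 5 keypoints annotated
keypoints = _keypoints[..., :2]

# a keypoint counts as annotated when its x or y coordinate is non-zero
num_keypoints = np.count_nonzero(keypoints.max(axis=2))
print(num_keypoints)  # 5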
diff --git a/mmpose/datasets/datasets/hand/onehand10k_dataset.py b/mmpose/datasets/datasets/hand/onehand10k_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..3519ace560ef70ce680955bfa82d52c1a11b6b3e
--- /dev/null
+++ b/mmpose/datasets/datasets/hand/onehand10k_dataset.py
@@ -0,0 +1,77 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmpose.registry import DATASETS
+from ..base import BaseCocoStyleDataset
+
+
+@DATASETS.register_module()
+class OneHand10KDataset(BaseCocoStyleDataset):
+ """OneHand10K dataset for hand pose estimation.
+
+ "Mask-pose Cascaded CNN for 2D Hand Pose Estimation from
+ Single Color Images", TCSVT'2019.
+ More details can be found in the
+ paper cited above.
+
+ OneHand10K keypoints::
+
+ 0: 'wrist',
+ 1: 'thumb1',
+ 2: 'thumb2',
+ 3: 'thumb3',
+ 4: 'thumb4',
+ 5: 'forefinger1',
+ 6: 'forefinger2',
+ 7: 'forefinger3',
+ 8: 'forefinger4',
+ 9: 'middle_finger1',
+ 10: 'middle_finger2',
+ 11: 'middle_finger3',
+ 12: 'middle_finger4',
+ 13: 'ring_finger1',
+ 14: 'ring_finger2',
+ 15: 'ring_finger3',
+ 16: 'ring_finger4',
+ 17: 'pinky_finger1',
+ 18: 'pinky_finger2',
+ 19: 'pinky_finger3',
+ 20: 'pinky_finger4'
+
+ Args:
+ ann_file (str): Annotation file path. Default: ''.
+ bbox_file (str, optional): Detection result file path. If
+ ``bbox_file`` is set, detected bboxes loaded from this file will
+ be used instead of ground-truth bboxes. This setting is only for
+ evaluation, i.e., ignored when ``test_mode`` is ``False``.
+ Default: ``None``.
+ data_mode (str): Specifies the mode of data samples: ``'topdown'`` or
+ ``'bottomup'``. In ``'topdown'`` mode, each data sample contains
+ one instance; while in ``'bottomup'`` mode, each data sample
+ contains all instances in an image. Default: ``'topdown'``.
+ metainfo (dict, optional): Meta information for dataset, such as class
+ information. Default: ``None``.
+ data_root (str, optional): The root directory for ``data_prefix`` and
+ ``ann_file``. Default: ``None``.
+ data_prefix (dict, optional): Prefix for training data. Default:
+ ``dict(img=None, ann=None)``.
+ filter_cfg (dict, optional): Config for filtering data. Default: ``None``.
+ indices (int or Sequence[int], optional): Use only the first few
+ entries of the annotation file to facilitate training/testing on a
+ smaller dataset. Default: ``None``, which means using all ``data_infos``.
+ serialize_data (bool, optional): Whether to hold memory as serialized
+ objects. When enabled, data loader workers can read shared RAM from
+ the master process instead of making their own copies.
+ Default: ``True``.
+ pipeline (list, optional): Processing pipeline. Default: [].
+ test_mode (bool, optional): ``test_mode=True`` means the dataset is
+ used in the test phase. Default: ``False``.
+ lazy_init (bool, optional): Whether to defer loading the annotation
+ file until it is actually needed. In some cases, such as
+ visualization, only the meta information of the dataset is needed,
+ so ``BaseDataset`` can skip loading annotations to save time by
+ setting ``lazy_init=True``. Default: ``False``.
+ max_refetch (int, optional): The maximum number of extra cycles to
+ fetch a valid image when ``BaseDataset.prepare_data`` returns a
+ ``None`` image. Default: 1000.
+ """
+
+ METAINFO: dict = dict(from_file='configs/_base_/datasets/onehand10k.py')
diff --git a/mmpose/datasets/datasets/hand/panoptic_hand2d_dataset.py b/mmpose/datasets/datasets/hand/panoptic_hand2d_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..26d364840ebe5756687a72a4de52b0213ffdcea2
--- /dev/null
+++ b/mmpose/datasets/datasets/hand/panoptic_hand2d_dataset.py
@@ -0,0 +1,137 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+from typing import Optional
+
+import numpy as np
+
+from mmpose.registry import DATASETS
+from ..base import BaseCocoStyleDataset
+
+
+@DATASETS.register_module()
+class PanopticHand2DDataset(BaseCocoStyleDataset):
+ """Panoptic 2D dataset for hand pose estimation.
+
+ "Hand Keypoint Detection in Single Images using Multiview
+ Bootstrapping", CVPR'2017.
+ More details can be found in the `paper
+ <https://arxiv.org/abs/1704.07809>`__ .
+
+ Panoptic keypoints::
+
+ 0: 'wrist',
+ 1: 'thumb1',
+ 2: 'thumb2',
+ 3: 'thumb3',
+ 4: 'thumb4',
+ 5: 'forefinger1',
+ 6: 'forefinger2',
+ 7: 'forefinger3',
+ 8: 'forefinger4',
+ 9: 'middle_finger1',
+ 10: 'middle_finger2',
+ 11: 'middle_finger3',
+ 12: 'middle_finger4',
+ 13: 'ring_finger1',
+ 14: 'ring_finger2',
+ 15: 'ring_finger3',
+ 16: 'ring_finger4',
+ 17: 'pinky_finger1',
+ 18: 'pinky_finger2',
+ 19: 'pinky_finger3',
+ 20: 'pinky_finger4'
+
+ Args:
+ ann_file (str): Annotation file path. Default: ''.
+ bbox_file (str, optional): Detection result file path. If
+ ``bbox_file`` is set, detected bboxes loaded from this file will
+ be used instead of ground-truth bboxes. This setting is only for
+ evaluation, i.e., ignored when ``test_mode`` is ``False``.
+ Default: ``None``.
+ data_mode (str): Specifies the mode of data samples: ``'topdown'`` or
+ ``'bottomup'``. In ``'topdown'`` mode, each data sample contains
+ one instance; while in ``'bottomup'`` mode, each data sample
+ contains all instances in an image. Default: ``'topdown'``.
+ metainfo (dict, optional): Meta information for dataset, such as class
+ information. Default: ``None``.
+ data_root (str, optional): The root directory for ``data_prefix`` and
+ ``ann_file``. Default: ``None``.
+ data_prefix (dict, optional): Prefix for training data. Default:
+ ``dict(img=None, ann=None)``.
+ filter_cfg (dict, optional): Config for filtering data. Default: ``None``.
+ indices (int or Sequence[int], optional): Use only the first few
+ entries of the annotation file to facilitate training/testing on a
+ smaller dataset. Default: ``None``, which means using all ``data_infos``.
+ serialize_data (bool, optional): Whether to hold memory as serialized
+ objects. When enabled, data loader workers can read shared RAM from
+ the master process instead of making their own copies.
+ Default: ``True``.
+ pipeline (list, optional): Processing pipeline. Default: [].
+ test_mode (bool, optional): ``test_mode=True`` means the dataset is
+ used in the test phase. Default: ``False``.
+ lazy_init (bool, optional): Whether to defer loading the annotation
+ file until it is actually needed. In some cases, such as
+ visualization, only the meta information of the dataset is needed,
+ so ``BaseDataset`` can skip loading annotations to save time by
+ setting ``lazy_init=True``. Default: ``False``.
+ max_refetch (int, optional): The maximum number of extra cycles to
+ fetch a valid image when ``BaseDataset.prepare_data`` returns a
+ ``None`` image. Default: 1000.
+ """
+
+ METAINFO: dict = dict(
+ from_file='configs/_base_/datasets/panoptic_hand2d.py')
+
+ def parse_data_info(self, raw_data_info: dict) -> Optional[dict]:
+ """Parse raw COCO annotation of an instance.
+
+ Args:
+ raw_data_info (dict): Raw data information loaded from
+ ``ann_file``. It should have following contents:
+
+ - ``'raw_ann_info'``: Raw annotation of an instance
+ - ``'raw_img_info'``: Raw information of the image that
+ contains the instance
+
+ Returns:
+ dict: Parsed instance annotation
+ """
+
+ ann = raw_data_info['raw_ann_info']
+ img = raw_data_info['raw_img_info']
+
+ img_path = osp.join(self.data_prefix['img'], img['file_name'])
+ img_w, img_h = img['width'], img['height']
+
+ # clip the xywh bbox to the image and convert it to xyxy, in shape [1, 4]
+ x, y, w, h = ann['bbox']
+ x1 = np.clip(x, 0, img_w - 1)
+ y1 = np.clip(y, 0, img_h - 1)
+ x2 = np.clip(x + w, 0, img_w - 1)
+ y2 = np.clip(y + h, 0, img_h - 1)
+
+ bbox = np.array([x1, y1, x2, y2], dtype=np.float32).reshape(1, 4)
+
+ # keypoints in shape [1, K, 2] and keypoints_visible in [1, K]
+ _keypoints = np.array(
+ ann['keypoints'], dtype=np.float32).reshape(1, -1, 3)
+ keypoints = _keypoints[..., :2]
+ keypoints_visible = np.minimum(1, _keypoints[..., 2])
+
+ num_keypoints = np.count_nonzero(keypoints.max(axis=2))
+
+ data_info = {
+ 'img_id': ann['image_id'],
+ 'img_path': img_path,
+ 'bbox': bbox,
+ 'bbox_score': np.ones(1, dtype=np.float32),
+ 'num_keypoints': num_keypoints,
+ 'keypoints': keypoints,
+ 'keypoints_visible': keypoints_visible,
+ 'iscrowd': ann['iscrowd'],
+ 'segmentation': ann['segmentation'],
+ 'head_size': ann['head_size'],
+ 'id': ann['id'],
+ }
+
+ return data_info
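The clipping step above is easy to get wrong at image borders, so here is a minimal numpy sketch with invented numbers showing how an out-of-bounds xywh box is clamped before the xyxy conversion.

# Illustrative sketch (not part of the diff): bbox clipping as in
# parse_data_info above, with made-up numbers.
import numpy as np

img_w, img_h = 640, 480
x, y, w, h = 600.0, -10.0, 100.0, 120.0   # partially outside the image

x1 = np.clip(x, 0, img_w - 1)
y1 = np.clip(y, 0, img_h - 1)
x2 = np.clip(x + w, 0, img_w - 1)
y2 = np.clip(y + h, 0, img_h - 1)

bbox = np.array([x1, y1, x2, y2], dtype=np.float32).reshape(1, 4)
print(bbox)  # [[600.   0. 639. 110.]]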
diff --git a/mmpose/datasets/datasets/hand/rhd2d_dataset.py b/mmpose/datasets/datasets/hand/rhd2d_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebc4301590a5f8c8d474b0ef37de4d03309ad0b9
--- /dev/null
+++ b/mmpose/datasets/datasets/hand/rhd2d_dataset.py
@@ -0,0 +1,77 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmpose.registry import DATASETS
+from ..base import BaseCocoStyleDataset
+
+
+@DATASETS.register_module()
+class Rhd2DDataset(BaseCocoStyleDataset):
+ """Rendered Handpose Dataset for hand pose estimation.
+
+ "Learning to Estimate 3D Hand Pose from Single RGB Images",
+ ICCV'2017.
+ More details can be found in the `paper
+ <https://arxiv.org/abs/1705.01389>`__ .
+
+ Rhd keypoints::
+
+ 0: 'wrist',
+ 1: 'thumb4',
+ 2: 'thumb3',
+ 3: 'thumb2',
+ 4: 'thumb1',
+ 5: 'forefinger4',
+ 6: 'forefinger3',
+ 7: 'forefinger2',
+ 8: 'forefinger1',
+ 9: 'middle_finger4',
+ 10: 'middle_finger3',
+ 11: 'middle_finger2',
+ 12: 'middle_finger1',
+ 13: 'ring_finger4',
+ 14: 'ring_finger3',
+ 15: 'ring_finger2',
+ 16: 'ring_finger1',
+ 17: 'pinky_finger4',
+ 18: 'pinky_finger3',
+ 19: 'pinky_finger2',
+ 20: 'pinky_finger1'
+
+ Args:
+ ann_file (str): Annotation file path. Default: ''.
+ bbox_file (str, optional): Detection result file path. If
+ ``bbox_file`` is set, detected bboxes loaded from this file will
+ be used instead of ground-truth bboxes. This setting is only for
+ evaluation, i.e., ignored when ``test_mode`` is ``False``.
+ Default: ``None``.
+ data_mode (str): Specifies the mode of data samples: ``'topdown'`` or
+ ``'bottomup'``. In ``'topdown'`` mode, each data sample contains
+ one instance; while in ``'bottomup'`` mode, each data sample
+ contains all instances in an image. Default: ``'topdown'``.
+ metainfo (dict, optional): Meta information for dataset, such as class
+ information. Default: ``None``.
+ data_root (str, optional): The root directory for ``data_prefix`` and
+ ``ann_file``. Default: ``None``.
+ data_prefix (dict, optional): Prefix for training data. Default:
+ ``dict(img=None, ann=None)``.
+ filter_cfg (dict, optional): Config for filtering data. Default: ``None``.
+ indices (int or Sequence[int], optional): Use only the first few
+ entries of the annotation file to facilitate training/testing on a
+ smaller dataset. Default: ``None``, which means using all ``data_infos``.
+ serialize_data (bool, optional): Whether to hold memory as serialized
+ objects. When enabled, data loader workers can read shared RAM from
+ the master process instead of making their own copies.
+ Default: ``True``.
+ pipeline (list, optional): Processing pipeline. Default: [].
+ test_mode (bool, optional): ``test_mode=True`` means the dataset is
+ used in the test phase. Default: ``False``.
+ lazy_init (bool, optional): Whether to defer loading the annotation
+ file until it is actually needed. In some cases, such as
+ visualization, only the meta information of the dataset is needed,
+ so ``BaseDataset`` can skip loading annotations to save time by
+ setting ``lazy_init=True``. Default: ``False``.
+ max_refetch (int, optional): The maximum number of extra cycles to
+ fetch a valid image when ``BaseDataset.prepare_data`` returns a
+ ``None`` image. Default: 1000.
+ """
+
+ METAINFO: dict = dict(from_file='configs/_base_/datasets/rhd2d.py')
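Because every dataset in this directory is registered via @DATASETS.register_module(), these classes are normally built from a config dict rather than instantiated directly. A hedged sketch with hypothetical paths:

# Illustrative sketch (not part of the diff); paths are hypothetical.
from mmpose.registry import DATASETS

cfg = dict(
    type='Rhd2DDataset',
    ann_file='annotations/rhd_train.json',    # hypothetical path
    data_root='data/rhd',                     # hypothetical path
    data_prefix=dict(img='training/color/'),  # hypothetical path
    pipeline=[],
    test_mode=False)

dataset = DATASETS.build(cfg)  # resolves the class through the registry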
diff --git a/mmpose/datasets/datasets/utils.py b/mmpose/datasets/datasets/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5140126163d49eb92a25fe0db1d5d4ebb7144fa6
--- /dev/null
+++ b/mmpose/datasets/datasets/utils.py
@@ -0,0 +1,197 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import warnings
+
+import numpy as np
+from mmengine import Config
+
+
+def parse_pose_metainfo(metainfo: dict):
+ """Load meta information of pose dataset and check its integrity.
+
+ Args:
+ metainfo (dict): Raw data of pose meta information, which should
+ contain following contents:
+
+ - "dataset_name" (str): The name of the dataset
+ - "keypoint_info" (dict): The keypoint-related meta information,
+ e.g., name, upper/lower body, and symmetry
+ - "skeleton_info" (dict): The skeleton-related meta information,
+ e.g., start/end keypoint of limbs
+ - "joint_weights" (list[float]): The loss weights of keypoints
+ - "sigmas" (list[float]): The keypoint distribution parameters
+ to calculate OKS score. See `COCO keypoint evaluation
+ `__.
+
+ An example of metainfo is shown as follows.
+
+ .. code-block:: none
+ {
+ "dataset_name": "coco",
+ "keypoint_info":
+ {
+ 0:
+ {
+ "name": "nose",
+ "type": "upper",
+ "swap": "",
+ "color": [51, 153, 255],
+ },
+ 1:
+ {
+ "name": "right_eye",
+ "type": "upper",
+ "swap": "left_eye",
+ "color": [51, 153, 255],
+ },
+ ...
+ },
+ "skeleton_info":
+ {
+ 0:
+ {
+ "link": ("left_ankle", "left_knee"),
+ "color": [0, 255, 0],
+ },
+ ...
+ },
+ "joint_weights": [1., 1., ...],
+ "sigmas": [0.026, 0.025, ...],
+ }
+
+
+ A special case is that `metainfo` can have the key "from_file",
+ which should be the path of a config file. In this case, the
+ actual metainfo will be loaded by:
+
+ .. code-block:: python
+ metainfo = mmengine.Config.fromfile(metainfo['from_file'])
+
+ Returns:
+ Dict: pose meta information that contains following contents:
+
+ - "dataset_name" (str): Same as ``"dataset_name"`` in the input
+ - "num_keypoints" (int): Number of keypoints
+ - "keypoint_id2name" (dict): Mapping from keypoint id to name
+ - "keypoint_name2id" (dict): Mapping from keypoint name to id
+ - "upper_body_ids" (list): Ids of upper-body keypoint
+ - "lower_body_ids" (list): Ids of lower-body keypoint
+ - "flip_indices" (list): The Id of each keypoint's symmetric keypoint
+ - "flip_pairs" (list): The Ids of symmetric keypoint pairs
+ - "keypoint_colors" (numpy.ndarray): The keypoint color matrix of
+ shape [K, 3], where each row is the color of one keypint in bgr
+ - "num_skeleton_links" (int): The number of links
+ - "skeleton_links" (list): The links represented by Id pairs of start
+ and end points
+ - "skeleton_link_colors" (numpy.ndarray): The link color matrix
+ - "dataset_keypoint_weights" (numpy.ndarray): Same as the
+ ``"joint_weights"`` in the input
+ - "sigmas" (numpy.ndarray): Same as the ``"sigmas"`` in the input
+ """
+
+ if 'from_file' in metainfo:
+ cfg_file = metainfo['from_file']
+ if not osp.isfile(cfg_file):
+ # Search configs in 'mmpose/.mim/configs/' in case that mmpose
+ # is installed in non-editable mode.
+ import mmpose
+ mmpose_path = osp.dirname(mmpose.__file__)
+ _cfg_file = osp.join(mmpose_path, '.mim', 'configs', '_base_',
+ 'datasets', osp.basename(cfg_file))
+ if osp.isfile(_cfg_file):
+ warnings.warn(
+ f'The metainfo config file "{cfg_file}" does not exist. '
+ f'A matched config file "{_cfg_file}" will be used '
+ 'instead.')
+ cfg_file = _cfg_file
+ else:
+ raise FileNotFoundError(
+ f'The metainfo config file "{cfg_file}" does not exist.')
+
+ # TODO: remove the nested structure of dataset_info
+ # metainfo = Config.fromfile(metainfo['from_file'])
+ metainfo = Config.fromfile(cfg_file).dataset_info
+
+ # check data integrity
+ assert 'dataset_name' in metainfo
+ assert 'keypoint_info' in metainfo
+ assert 'skeleton_info' in metainfo
+ assert 'joint_weights' in metainfo
+ assert 'sigmas' in metainfo
+
+ # parse metainfo
+ parsed = dict(
+ dataset_name=None,
+ num_keypoints=None,
+ keypoint_id2name={},
+ keypoint_name2id={},
+ upper_body_ids=[],
+ lower_body_ids=[],
+ flip_indices=[],
+ flip_pairs=[],
+ keypoint_colors=[],
+ num_skeleton_links=None,
+ skeleton_links=[],
+ skeleton_link_colors=[],
+ dataset_keypoint_weights=None,
+ sigmas=None,
+ )
+
+ parsed['dataset_name'] = metainfo['dataset_name']
+
+ # parse keypoint information
+ parsed['num_keypoints'] = len(metainfo['keypoint_info'])
+
+ for kpt_id, kpt in metainfo['keypoint_info'].items():
+ kpt_name = kpt['name']
+ parsed['keypoint_id2name'][kpt_id] = kpt_name
+ parsed['keypoint_name2id'][kpt_name] = kpt_id
+ parsed['keypoint_colors'].append(kpt.get('color', [255, 128, 0]))
+
+ kpt_type = kpt.get('type', '')
+ if kpt_type == 'upper':
+ parsed['upper_body_ids'].append(kpt_id)
+ elif kpt_type == 'lower':
+ parsed['lower_body_ids'].append(kpt_id)
+
+ swap_kpt = kpt.get('swap', '')
+ if swap_kpt == kpt_name or swap_kpt == '':
+ parsed['flip_indices'].append(kpt_name)
+ else:
+ parsed['flip_indices'].append(swap_kpt)
+ pair = (swap_kpt, kpt_name)
+ if pair not in parsed['flip_pairs']:
+ parsed['flip_pairs'].append(pair)
+
+ # parse skeleton information
+ parsed['num_skeleton_links'] = len(metainfo['skeleton_info'])
+ for _, sk in metainfo['skeleton_info'].items():
+ parsed['skeleton_links'].append(sk['link'])
+ parsed['skeleton_link_colors'].append(sk.get('color', [96, 96, 255]))
+
+ # parse extra information
+ parsed['dataset_keypoint_weights'] = np.array(
+ metainfo['joint_weights'], dtype=np.float32)
+ parsed['sigmas'] = np.array(metainfo['sigmas'], dtype=np.float32)
+
+ # formatting
+ def _map(src, mapping: dict):
+ if isinstance(src, (list, tuple)):
+ cls = type(src)
+ return cls(_map(s, mapping) for s in src)
+ else:
+ return mapping[src]
+
+ parsed['flip_pairs'] = _map(
+ parsed['flip_pairs'], mapping=parsed['keypoint_name2id'])
+ parsed['flip_indices'] = _map(
+ parsed['flip_indices'], mapping=parsed['keypoint_name2id'])
+ parsed['skeleton_links'] = _map(
+ parsed['skeleton_links'], mapping=parsed['keypoint_name2id'])
+
+ parsed['keypoint_colors'] = np.array(
+ parsed['keypoint_colors'], dtype=np.uint8)
+ parsed['skeleton_link_colors'] = np.array(
+ parsed['skeleton_link_colors'], dtype=np.uint8)
+
+ return parsed
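To make the parsing rules above concrete, here is a minimal, invented three-keypoint metainfo dict run through parse_pose_metainfo; only the fields the function requires are filled in.

# Illustrative sketch (not part of the diff): a toy metainfo dict, invented
# purely to show how keypoint names are mapped back to ids.
from mmpose.datasets.datasets.utils import parse_pose_metainfo

metainfo = dict(
    dataset_name='toy',
    keypoint_info={
        0: dict(name='nose', type='upper', swap=''),
        1: dict(name='left_eye', type='upper', swap='right_eye'),
        2: dict(name='right_eye', type='upper', swap='left_eye'),
    },
    skeleton_info={
        0: dict(link=('nose', 'left_eye')),
        1: dict(link=('nose', 'right_eye')),
    },
    joint_weights=[1.0, 1.0, 1.0],
    sigmas=[0.026, 0.025, 0.025])

parsed = parse_pose_metainfo(metainfo)
print(parsed['keypoint_name2id'])  # {'nose': 0, 'left_eye': 1, 'right_eye': 2}
print(parsed['flip_indices'])      # [0, 2, 1]
print(parsed['skeleton_links'])    # [(0, 1), (0, 2)]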
diff --git a/mmpose/datasets/datasets/wholebody/__init__.py b/mmpose/datasets/datasets/wholebody/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..156094c2b03b297ac5e1ad839cc970fc67a0ecf4
--- /dev/null
+++ b/mmpose/datasets/datasets/wholebody/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .coco_wholebody_dataset import CocoWholeBodyDataset
+from .halpe_dataset import HalpeDataset
+
+__all__ = ['CocoWholeBodyDataset', 'HalpeDataset']
diff --git a/mmpose/datasets/datasets/wholebody/__pycache__/__init__.cpython-38.pyc b/mmpose/datasets/datasets/wholebody/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..931b681247e546fbc46f4a343f10edc4770d798d
Binary files /dev/null and b/mmpose/datasets/datasets/wholebody/__pycache__/__init__.cpython-38.pyc differ
diff --git a/mmpose/datasets/datasets/wholebody/__pycache__/coco_wholebody_dataset.cpython-38.pyc b/mmpose/datasets/datasets/wholebody/__pycache__/coco_wholebody_dataset.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a6b2963204ad9ba6d4c41864820fd39d6f0fdb59
Binary files /dev/null and b/mmpose/datasets/datasets/wholebody/__pycache__/coco_wholebody_dataset.cpython-38.pyc differ
diff --git a/mmpose/datasets/datasets/wholebody/__pycache__/halpe_dataset.cpython-38.pyc b/mmpose/datasets/datasets/wholebody/__pycache__/halpe_dataset.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..571e0a24eb11996a7c9c8891cc7672f5c1f54c64
Binary files /dev/null and b/mmpose/datasets/datasets/wholebody/__pycache__/halpe_dataset.cpython-38.pyc differ
diff --git a/mmpose/datasets/datasets/wholebody/coco_wholebody_dataset.py b/mmpose/datasets/datasets/wholebody/coco_wholebody_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..00a2ea418f95c1973881072256c455e3de9f2046
--- /dev/null
+++ b/mmpose/datasets/datasets/wholebody/coco_wholebody_dataset.py
@@ -0,0 +1,127 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import os.path as osp
+from typing import Optional
+
+import numpy as np
+
+from mmpose.registry import DATASETS
+from ..base import BaseCocoStyleDataset
+
+
+@DATASETS.register_module()
+class CocoWholeBodyDataset(BaseCocoStyleDataset):
+ """CocoWholeBody dataset for pose estimation.
+
+ "Whole-Body Human Pose Estimation in the Wild", ECCV'2020.
+ More details can be found in the `paper
+ <https://arxiv.org/abs/2007.11858>`__ .
+
+ COCO-WholeBody keypoints::
+
+ 0-16: 17 body keypoints,
+ 17-22: 6 foot keypoints,
+ 23-90: 68 face keypoints,
+ 91-132: 42 hand keypoints
+
+ In total, we have 133 keypoints for wholebody pose estimation.
+
+ Args:
+ ann_file (str): Annotation file path. Default: ''.
+ bbox_file (str, optional): Detection result file path. If
+ ``bbox_file`` is set, detected bboxes loaded from this file will
+ be used instead of ground-truth bboxes. This setting is only for
+ evaluation, i.e., ignored when ``test_mode`` is ``False``.
+ Default: ``None``.
+ data_mode (str): Specifies the mode of data samples: ``'topdown'`` or
+ ``'bottomup'``. In ``'topdown'`` mode, each data sample contains
+ one instance; while in ``'bottomup'`` mode, each data sample
+ contains all instances in an image. Default: ``'topdown'``.
+ metainfo (dict, optional): Meta information for dataset, such as class
+ information. Default: ``None``.
+ data_root (str, optional): The root directory for ``data_prefix`` and
+ ``ann_file``. Default: ``None``.
+ data_prefix (dict, optional): Prefix for training data. Default:
+ ``dict(img=None, ann=None)``.
+ filter_cfg (dict, optional): Config for filtering data. Default: ``None``.
+ indices (int or Sequence[int], optional): Use only the first few
+ entries of the annotation file to facilitate training/testing on a
+ smaller dataset. Default: ``None``, which means using all ``data_infos``.
+ serialize_data (bool, optional): Whether to hold memory as serialized
+ objects. When enabled, data loader workers can read shared RAM from
+ the master process instead of making their own copies.
+ Default: ``True``.
+ pipeline (list, optional): Processing pipeline. Default: [].
+ test_mode (bool, optional): ``test_mode=True`` means the dataset is
+ used in the test phase. Default: ``False``.
+ lazy_init (bool, optional): Whether to defer loading the annotation
+ file until it is actually needed. In some cases, such as
+ visualization, only the meta information of the dataset is needed,
+ so ``BaseDataset`` can skip loading annotations to save time by
+ setting ``lazy_init=True``. Default: ``False``.
+ max_refetch (int, optional): The maximum number of extra cycles to
+ fetch a valid image when ``BaseDataset.prepare_data`` returns a
+ ``None`` image. Default: 1000.
+ """
+
+ METAINFO: dict = dict(
+ from_file='configs/_base_/datasets/coco_wholebody.py')
+
+ def parse_data_info(self, raw_data_info: dict) -> Optional[dict]:
+ """Parse raw COCO annotation of an instance.
+
+ Args:
+ raw_data_info (dict): Raw data information loaded from
+ ``ann_file``. It should have following contents:
+
+ - ``'raw_ann_info'``: Raw annotation of an instance
+ - ``'raw_img_info'``: Raw information of the image that
+ contains the instance
+
+ Returns:
+ dict: Parsed instance annotation
+ """
+
+ ann = raw_data_info['raw_ann_info']
+ img = raw_data_info['raw_img_info']
+
+ img_path = osp.join(self.data_prefix['img'], img['file_name'])
+ img_w, img_h = img['width'], img['height']
+
+ # clip the xywh bbox to the image and convert it to xyxy, in shape [1, 4]
+ x, y, w, h = ann['bbox']
+ x1 = np.clip(x, 0, img_w - 1)
+ y1 = np.clip(y, 0, img_h - 1)
+ x2 = np.clip(x + w, 0, img_w - 1)
+ y2 = np.clip(y + h, 0, img_h - 1)
+
+ bbox = np.array([x1, y1, x2, y2], dtype=np.float32).reshape(1, 4)
+
+ # keypoints in shape [1, K, 2] and keypoints_visible in [1, K]
+ # COCO-Wholebody: consisting of body, foot, face and hand keypoints
+ _keypoints = np.array(ann['keypoints'] + ann['foot_kpts'] +
+ ann['face_kpts'] + ann['lefthand_kpts'] +
+ ann['righthand_kpts']).reshape(1, -1, 3)
+ keypoints = _keypoints[..., :2]
+ keypoints_visible = np.minimum(1, _keypoints[..., 2] > 0)
+
+ num_keypoints = ann['num_keypoints']
+
+ data_info = {
+ 'img_id': ann['image_id'],
+ 'img_path': img_path,
+ 'bbox': bbox,
+ 'bbox_score': np.ones(1, dtype=np.float32),
+ 'num_keypoints': num_keypoints,
+ 'keypoints': keypoints,
+ 'keypoints_visible': keypoints_visible,
+ 'iscrowd': ann['iscrowd'],
+ 'segmentation': ann['segmentation'],
+ 'id': ann['id'],
+ 'category_id': ann['category_id'],
+ # store the raw annotation of the instance
+ # it is useful for evaluation without providing ann_file
+ 'raw_ann_info': copy.deepcopy(ann),
+ }
+
+ return data_info
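The concatenation above is where the 133 wholebody keypoints come from (17 body + 6 foot + 68 face + 21 + 21 hand keypoints). A zero-filled sketch just to confirm the shapes:

# Illustrative sketch (not part of the diff): assembling the 133 wholebody
# keypoints from the separate COCO-WholeBody fields. The zero-filled lists
# only stand in for real annotations.
import numpy as np

ann = {
    'keypoints': [0.0] * 17 * 3,       # body
    'foot_kpts': [0.0] * 6 * 3,        # feet
    'face_kpts': [0.0] * 68 * 3,       # face
    'lefthand_kpts': [0.0] * 21 * 3,   # left hand
    'righthand_kpts': [0.0] * 21 * 3,  # right hand
}

_keypoints = np.array(ann['keypoints'] + ann['foot_kpts'] + ann['face_kpts'] +
                      ann['lefthand_kpts'] + ann['righthand_kpts'],
                      dtype=np.float32).reshape(1, -1, 3)
print(_keypoints.shape)  # (1, 133, 3)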
diff --git a/mmpose/datasets/datasets/wholebody/halpe_dataset.py b/mmpose/datasets/datasets/wholebody/halpe_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..0699f3b7023b200ee42e3cfe7f475a51123ef190
--- /dev/null
+++ b/mmpose/datasets/datasets/wholebody/halpe_dataset.py
@@ -0,0 +1,59 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmpose.registry import DATASETS
+from ..base import BaseCocoStyleDataset
+
+
+@DATASETS.register_module()
+class HalpeDataset(BaseCocoStyleDataset):
+ """Halpe dataset for pose estimation.
+
+ 'https://github.com/Fang-Haoshu/Halpe-FullBody'
+
+ Halpe keypoints::
+
+ 0-19: 20 body keypoints,
+ 20-25: 6 foot keypoints,
+ 26-93: 68 face keypoints,
+ 94-135: 42 hand keypoints
+
+ In total, we have 136 keypoints for wholebody pose estimation.
+
+ Args:
+ ann_file (str): Annotation file path. Default: ''.
+ bbox_file (str, optional): Detection result file path. If
+ ``bbox_file`` is set, detected bboxes loaded from this file will
+ be used instead of ground-truth bboxes. This setting is only for
+ evaluation, i.e., ignored when ``test_mode`` is ``False``.
+ Default: ``None``.
+ data_mode (str): Specifies the mode of data samples: ``'topdown'`` or
+ ``'bottomup'``. In ``'topdown'`` mode, each data sample contains
+ one instance; while in ``'bottomup'`` mode, each data sample
+ contains all instances in an image. Default: ``'topdown'``.
+ metainfo (dict, optional): Meta information for dataset, such as class
+ information. Default: ``None``.
+ data_root (str, optional): The root directory for ``data_prefix`` and
+ ``ann_file``. Default: ``None``.
+ data_prefix (dict, optional): Prefix for training data. Default:
+ ``dict(img=None, ann=None)``.
+ filter_cfg (dict, optional): Config for filtering data. Default: ``None``.
+ indices (int or Sequence[int], optional): Use only the first few
+ entries of the annotation file to facilitate training/testing on a
+ smaller dataset. Default: ``None``, which means using all ``data_infos``.
+ serialize_data (bool, optional): Whether to hold memory as serialized
+ objects. When enabled, data loader workers can read shared RAM from
+ the master process instead of making their own copies.
+ Default: ``True``.
+ pipeline (list, optional): Processing pipeline. Default: [].
+ test_mode (bool, optional): ``test_mode=True`` means the dataset is
+ used in the test phase. Default: ``False``.
+ lazy_init (bool, optional): Whether to defer loading the annotation
+ file until it is actually needed. In some cases, such as
+ visualization, only the meta information of the dataset is needed,
+ so ``BaseDataset`` can skip loading annotations to save time by
+ setting ``lazy_init=True``. Default: ``False``.
+ max_refetch (int, optional): The maximum number of extra cycles to
+ fetch a valid image when ``BaseDataset.prepare_data`` returns a
+ ``None`` image. Default: 1000.
+ """
+
+ METAINFO: dict = dict(from_file='configs/_base_/datasets/halpe.py')
diff --git a/mmpose/datasets/samplers.py b/mmpose/datasets/samplers.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6bb34287a8c6b43552601eeb9b2e7c9a4fa90df
--- /dev/null
+++ b/mmpose/datasets/samplers.py
@@ -0,0 +1,114 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import itertools
+import math
+from typing import Iterator, List, Optional, Sized, Union
+
+import torch
+from mmengine.dist import get_dist_info, sync_random_seed
+from torch.utils.data import Sampler
+
+from mmpose.datasets import CombinedDataset
+from mmpose.registry import DATA_SAMPLERS
+
+
+@DATA_SAMPLERS.register_module()
+class MultiSourceSampler(Sampler):
+ """Multi-Source Sampler. According to the sampling ratio, sample data from
+ different datasets to form batches.
+
+ Args:
+ dataset (Sized): The dataset
+ batch_size (int): Size of mini-batch
+ source_ratio (list[int | float]): The sampling ratio of different
+ source datasets in a mini-batch
+ shuffle (bool): Whether shuffle the dataset or not. Defaults to
+ ``True``
+ round_up (bool): Whether to add extra samples to make the number of
+ samples evenly divisible by the world size. Defaults to True.
+ seed (int, optional): Random seed. If ``None``, set a random seed.
+ Defaults to ``None``
+ """
+
+ def __init__(self,
+ dataset: Sized,
+ batch_size: int,
+ source_ratio: List[Union[int, float]],
+ shuffle: bool = True,
+ round_up: bool = True,
+ seed: Optional[int] = None) -> None:
+
+ assert isinstance(dataset, CombinedDataset),\
+ f'The dataset must be CombinedDataset, but got {dataset}'
+ assert isinstance(batch_size, int) and batch_size > 0, \
+ 'batch_size must be a positive integer value, ' \
+ f'but got batch_size={batch_size}'
+ assert isinstance(source_ratio, list), \
+ f'source_ratio must be a list, but got source_ratio={source_ratio}'
+ assert len(source_ratio) == len(dataset._lens), \
+ 'The length of source_ratio must be equal to ' \
+ f'the number of datasets, but got source_ratio={source_ratio}'
+
+ rank, world_size = get_dist_info()
+ self.rank = rank
+ self.world_size = world_size
+
+ self.dataset = dataset
+ self.cumulative_sizes = [0] + list(itertools.accumulate(dataset._lens))
+ self.batch_size = batch_size
+ self.source_ratio = source_ratio
+ self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / world_size))
+ self.num_per_source = [
+ int(batch_size * sr / sum(source_ratio)) for sr in source_ratio
+ ]
+ self.num_per_source[0] = batch_size - sum(self.num_per_source[1:])
+
+ assert sum(self.num_per_source) == batch_size, \
+ 'The sum of num_per_source must be equal to ' \
+ f'batch_size, but got {self.num_per_source}'
+
+ self.seed = sync_random_seed() if seed is None else seed
+ self.shuffle = shuffle
+ self.round_up = round_up
+ self.source2inds = {
+ source: self._indices_of_rank(len(ds))
+ for source, ds in enumerate(dataset.datasets)
+ }
+
+ def _infinite_indices(self, sample_size: int) -> Iterator[int]:
+ """Infinitely yield a sequence of indices."""
+ g = torch.Generator()
+ g.manual_seed(self.seed)
+ while True:
+ if self.shuffle:
+ yield from torch.randperm(sample_size, generator=g).tolist()
+ else:
+ yield from torch.arange(sample_size).tolist()
+
+ def _indices_of_rank(self, sample_size: int) -> Iterator[int]:
+ """Slice the infinite indices by rank."""
+ yield from itertools.islice(
+ self._infinite_indices(sample_size), self.rank, None,
+ self.world_size)
+
+ def __iter__(self) -> Iterator[int]:
+ batch_buffer = []
+ num_iters = self.num_samples // self.batch_size
+ if self.round_up and self.num_samples > num_iters * self.batch_size:
+ num_iters += 1
+ for i in range(num_iters):
+ for source, num in enumerate(self.num_per_source):
+ batch_buffer_per_source = []
+ for idx in self.source2inds[source]:
+ idx += self.cumulative_sizes[source]
+ batch_buffer_per_source.append(idx)
+ if len(batch_buffer_per_source) == num:
+ batch_buffer += batch_buffer_per_source
+ break
+ return iter(batch_buffer)
+
+ def __len__(self) -> int:
+ return self.num_samples
+
+ def set_epoch(self, epoch: int) -> None:
+ """Keep the interface compatible with the epoch-based runner."""
+ pass
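For context, here is a minimal sketch of how this sampler might be wired into an MMEngine-style dataloader config. The batch size, ratio, and dataset/pipeline variables are illustrative placeholders, not values taken from this diff; the runner supplies ``dataset`` to the sampler at build time, so the sampler config only carries ``batch_size`` and ``source_ratio``.

```python
# Hypothetical dataloader config using MultiSourceSampler with a
# CombinedDataset of two sources, drawing a 2:1 share of each batch.
train_dataloader = dict(
    batch_size=48,
    num_workers=4,
    sampler=dict(
        type='MultiSourceSampler',
        batch_size=48,
        source_ratio=[2, 1],   # 32 samples from source 0, 16 from source 1
        shuffle=True),
    dataset=dict(
        type='CombinedDataset',
        datasets=[dataset_a, dataset_b],  # placeholder dataset configs
        pipeline=train_pipeline),         # placeholder pipeline
)
```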
diff --git a/mmpose/datasets/transforms/__init__.py b/mmpose/datasets/transforms/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..61dae74b8c6e756a87151c2ef8ae2423eb507515
--- /dev/null
+++ b/mmpose/datasets/transforms/__init__.py
@@ -0,0 +1,19 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .bottomup_transforms import (BottomupGetHeatmapMask, BottomupRandomAffine,
+ BottomupResize)
+from .common_transforms import (Albumentation, GenerateTarget,
+ GetBBoxCenterScale, PhotometricDistortion,
+ RandomBBoxTransform, RandomFlip,
+ RandomHalfBody)
+from .converting import KeypointConverter
+from .formatting import PackPoseInputs
+from .loading import LoadImage
+from .topdown_transforms import TopdownAffine
+
+__all__ = [
+ 'GetBBoxCenterScale', 'RandomBBoxTransform', 'RandomFlip',
+ 'RandomHalfBody', 'TopdownAffine', 'Albumentation',
+ 'PhotometricDistortion', 'PackPoseInputs', 'LoadImage',
+ 'BottomupGetHeatmapMask', 'BottomupRandomAffine', 'BottomupResize',
+ 'GenerateTarget', 'KeypointConverter'
+]
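Because every transform below registers itself with ``TRANSFORMS``, a pipeline step can be built directly from a config dict. A minimal sketch, with illustrative config values:

```python
from mmpose.registry import TRANSFORMS

# Build one registered transform from its config and apply it to the
# pipeline's running `results` dict, as the dataset pipeline does internally.
flip = TRANSFORMS.build(
    dict(type='RandomFlip', prob=0.5, direction='horizontal'))
results = flip(results)  # `results` is the running data-pipeline dict
```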
diff --git a/mmpose/datasets/transforms/__pycache__/__init__.cpython-38.pyc b/mmpose/datasets/transforms/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..25f01623b630515298da7ef1e525c23d7d40927b
Binary files /dev/null and b/mmpose/datasets/transforms/__pycache__/__init__.cpython-38.pyc differ
diff --git a/mmpose/datasets/transforms/__pycache__/bottomup_transforms.cpython-38.pyc b/mmpose/datasets/transforms/__pycache__/bottomup_transforms.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..76ffe18189f729a5019a1a6f143e7df5387f68d0
Binary files /dev/null and b/mmpose/datasets/transforms/__pycache__/bottomup_transforms.cpython-38.pyc differ
diff --git a/mmpose/datasets/transforms/__pycache__/common_transforms.cpython-38.pyc b/mmpose/datasets/transforms/__pycache__/common_transforms.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4ccf52a667cbdc6f3e958b5dfecf0bb3c225c342
Binary files /dev/null and b/mmpose/datasets/transforms/__pycache__/common_transforms.cpython-38.pyc differ
diff --git a/mmpose/datasets/transforms/__pycache__/converting.cpython-38.pyc b/mmpose/datasets/transforms/__pycache__/converting.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1bc65203f98e4a7cfe327d59bfd9c08a3dea83c2
Binary files /dev/null and b/mmpose/datasets/transforms/__pycache__/converting.cpython-38.pyc differ
diff --git a/mmpose/datasets/transforms/__pycache__/formatting.cpython-38.pyc b/mmpose/datasets/transforms/__pycache__/formatting.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4bf2b35e6a5b2e170dccbed1983248a441e4efa4
Binary files /dev/null and b/mmpose/datasets/transforms/__pycache__/formatting.cpython-38.pyc differ
diff --git a/mmpose/datasets/transforms/__pycache__/loading.cpython-38.pyc b/mmpose/datasets/transforms/__pycache__/loading.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3368dc2771aafe618300b2ea832ffb6206eaa8f0
Binary files /dev/null and b/mmpose/datasets/transforms/__pycache__/loading.cpython-38.pyc differ
diff --git a/mmpose/datasets/transforms/__pycache__/topdown_transforms.cpython-38.pyc b/mmpose/datasets/transforms/__pycache__/topdown_transforms.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..91d214bbd6c54fc3192f67f4756cada2ad338af5
Binary files /dev/null and b/mmpose/datasets/transforms/__pycache__/topdown_transforms.cpython-38.pyc differ
diff --git a/mmpose/datasets/transforms/bottomup_transforms.py b/mmpose/datasets/transforms/bottomup_transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..c31e0ae17d65f074a89c56210285184f2eeedc0b
--- /dev/null
+++ b/mmpose/datasets/transforms/bottomup_transforms.py
@@ -0,0 +1,517 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, List, Optional, Tuple
+
+import cv2
+import numpy as np
+import xtcocotools.mask as cocomask
+from mmcv.image import imflip_, imresize
+from mmcv.transforms import BaseTransform
+from mmcv.transforms.utils import cache_randomness
+from scipy.stats import truncnorm
+
+from mmpose.registry import TRANSFORMS
+from mmpose.structures.bbox import get_udp_warp_matrix, get_warp_matrix
+
+
+@TRANSFORMS.register_module()
+class BottomupGetHeatmapMask(BaseTransform):
+ """Generate the mask of valid regions from the segmentation annotation.
+
+ Required Keys:
+
+ - img_shape
+ - invalid_segs (optional)
+ - warp_mat (optional)
+ - flip (optional)
+ - flip_direction (optional)
+ - heatmaps (optional)
+
+ Added Keys:
+
+ - heatmap_mask
+ """
+
+ def _segs_to_mask(self, segs: list, img_shape: Tuple[int,
+ int]) -> np.ndarray:
+ """Calculate mask from object segmentations.
+
+ Args:
+ segs (List): The object segmentation annotations in COCO format
+ img_shape (Tuple): The image shape in (h, w)
+
+ Returns:
+ np.ndarray: The binary object mask in size (h, w), where the
+ object pixels are 1 and background pixels are 0
+ """
+
+ # RLE is a simple yet efficient format for storing binary masks.
+ # details can be found in the COCO API mask utilities (xtcocotools.mask)
+ rles = []
+ for seg in segs:
+ rle = cocomask.frPyObjects(seg, img_shape[0], img_shape[1])
+ if isinstance(rle, list):
+ # For non-crowded objects (e.g. a human with no visible
+ # keypoints), the result is a list of RLEs
+ rles.extend(rle)
+ else:
+ # For crowded objects, the result is a single rle
+ rles.append(rle)
+
+ if rles:
+ mask = cocomask.decode(cocomask.merge(rles))
+ else:
+ mask = np.zeros(img_shape, dtype=np.uint8)
+
+ return mask
+
+ def transform(self, results: Dict) -> Optional[dict]:
+ """The transform function of :class:`BottomupGetHeatmapMask` to
+ generate the mask of valid regions for heatmap supervision.
+
+ See ``transform()`` method of :class:`BaseTransform` for details.
+
+
+ Args:
+ results (dict): Result dict from the data pipeline.
+
+ Returns:
+ dict: Result dict with ``heatmap_mask`` added.
+ """
+
+ invalid_segs = results.get('invalid_segs', [])
+ img_shape = results['img_shape'] # (img_h, img_w)
+ input_size = results['input_size']
+
+ # Calculate the mask of the valid region by negating the segmentation
+ # mask of invalid objects
+ mask = 1 - self._segs_to_mask(invalid_segs, img_shape)
+
+ # Apply an affine transform to the mask if the image has been
+ # transformed
+ if 'warp_mat' in results:
+ warp_mat = results['warp_mat']
+
+ mask = mask.astype(np.float32)
+ mask = cv2.warpAffine(
+ mask, warp_mat, input_size, flags=cv2.INTER_LINEAR)
+
+ # Flip the mask if the image has been flipped
+ if results.get('flip', False):
+ flip_dir = results['flip_direction']
+ if flip_dir is not None:
+ mask = imflip_(mask, flip_dir)
+
+ # Resize the mask to the same size of heatmaps
+ if 'heatmaps' in results:
+ heatmaps = results['heatmaps']
+ if isinstance(heatmaps, list):
+ # Multi-level heatmaps
+ heatmap_mask = []
+ for hm in results['heatmaps']:
+ h, w = hm.shape[1:3]
+ _mask = imresize(
+ mask, size=(w, h), interpolation='bilinear')
+ heatmap_mask.append(_mask)
+ else:
+ h, w = heatmaps.shape[1:3]
+ heatmap_mask = imresize(
+ mask, size=(w, h), interpolation='bilinear')
+ else:
+ heatmap_mask = mask
+
+ # Binarize the mask(s)
+ if isinstance(heatmap_mask, list):
+ results['heatmap_mask'] = [hm > 0.5 for hm in heatmap_mask]
+ else:
+ results['heatmap_mask'] = heatmap_mask > 0.5
+
+ return results
+
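The ``_segs_to_mask`` helper above leans on the COCO RLE utilities; a standalone sketch of the same merge-and-decode flow on a toy polygon (the image size and coordinates are made up for illustration):

```python
import xtcocotools.mask as cocomask

img_h, img_w = 4, 6
# One polygon segmentation in COCO format: [x1, y1, x2, y2, ...]
segs = [[1.0, 1.0, 4.0, 1.0, 4.0, 3.0, 1.0, 3.0]]

rles = cocomask.frPyObjects(segs, img_h, img_w)  # list of RLE objects
mask = cocomask.decode(cocomask.merge(rles))     # (img_h, img_w) uint8 mask
assert mask.shape == (img_h, img_w)
```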
+
+@TRANSFORMS.register_module()
+class BottomupRandomAffine(BaseTransform):
+ r"""Randomly shift, resize and rotate the image.
+
+ Required Keys:
+
+ - img
+ - img_shape
+ - keypoints (optional)
+
+ Modified Keys:
+
+ - img
+ - keypoints (optional)
+
+ Added Keys:
+
+ - input_size
+ - warp_mat
+
+ Args:
+ input_size (Tuple[int, int]): The input image size of the model in
+ [w, h]
+ shift_factor (float): Randomly shift the image in range
+ :math:`[-dx, dx]` and :math:`[-dy, dy]` in X and Y directions,
+ where :math:`dx(y) = img_w(h) \cdot shift_factor` in pixels.
+ Defaults to 0.2
+ shift_prob (float): Probability of applying random shift. Defaults to
+ 1.0
+ scale_factor (Tuple[float, float]): Randomly resize the image in range
+ :math:`[scale_factor[0], scale_factor[1]]`. Defaults to
+ (0.75, 1.5)
+ scale_prob (float): Probability of applying random resizing. Defaults
+ to 1.0
+ scale_type (str): Whether the scale is adjusted w.r.t. the ``long``
+ or ``short`` side of the image. Defaults to ``short``
+ rotate_factor (float): Randomly rotate the image in
+ :math:`[-rotate_factor, rotate_factor]` in degrees. Defaults
+ to 30.0
+ rotate_prob (float): Probability of applying random rotation.
+ Defaults to 1.0
+ use_udp (bool): Whether to use unbiased data processing. See
+ `UDP (CVPR 2020)`_ for details. Defaults to ``False``
+
+ .. _`UDP (CVPR 2020)`: https://arxiv.org/abs/1911.07524
+ """
+
+ def __init__(self,
+ input_size: Tuple[int, int],
+ shift_factor: float = 0.2,
+ shift_prob: float = 1.,
+ scale_factor: Tuple[float, float] = (0.75, 1.5),
+ scale_prob: float = 1.,
+ scale_type: str = 'short',
+ rotate_factor: float = 30.,
+ rotate_prob: float = 1,
+ use_udp: bool = False) -> None:
+ super().__init__()
+
+ self.input_size = input_size
+ self.shift_factor = shift_factor
+ self.shift_prob = shift_prob
+ self.scale_factor = scale_factor
+ self.scale_prob = scale_prob
+ self.scale_type = scale_type
+ self.rotate_factor = rotate_factor
+ self.rotate_prob = rotate_prob
+ self.use_udp = use_udp
+
+ @staticmethod
+ def _truncnorm(low: float = -1.,
+ high: float = 1.,
+ size: tuple = ()) -> np.ndarray:
+ """Sample from a truncated normal distribution."""
+ return truncnorm.rvs(low, high, size=size).astype(np.float32)
+
+ def _fix_aspect_ratio(self, scale: np.ndarray, aspect_ratio: float):
+ """Extend the scale to match the given aspect ratio.
+
+ Args:
+ scale (np.ndarray): The image scale (w, h) in shape (2, )
+ aspect_ratio (float): The ratio of ``w/h``
+
+ Returns:
+ np.ndarray: The reshaped image scale in (2, )
+ """
+ w, h = scale
+ if w > h * aspect_ratio:
+ if self.scale_type == 'long':
+ _w, _h = w, w / aspect_ratio
+ elif self.scale_type == 'short':
+ _w, _h = h * aspect_ratio, h
+ else:
+ raise ValueError(f'Unknown scale type: {self.scale_type}')
+ else:
+ if self.scale_type == 'short':
+ _w, _h = w, w / aspect_ratio
+ elif self.scale_type == 'long':
+ _w, _h = h * aspect_ratio, h
+ else:
+ raise ValueError(f'Unknown scale type: {self.scale_type}')
+ return np.array([_w, _h], dtype=scale.dtype)
+
+ @cache_randomness
+ def _get_transform_params(self) -> Tuple:
+ """Get random transform parameters.
+
+ Returns:
+ tuple:
+ - offset (np.ndarray): Image offset rate in shape (2, )
+ - scale (np.ndarray): Image scaling rate factor in shape (1, )
+ - rotate (np.ndarray): Image rotation degree in shape (1, )
+ """
+ # get offset
+ if np.random.rand() < self.shift_prob:
+ offset = self._truncnorm(size=(2, )) * self.shift_factor
+ else:
+ offset = np.zeros((2, ), dtype=np.float32)
+
+ # get scale
+ if np.random.rand() < self.scale_prob:
+ scale_min, scale_max = self.scale_factor
+ scale = scale_min + (scale_max - scale_min) * (
+ self._truncnorm(size=(1, )) + 1) / 2
+ else:
+ scale = np.ones(1, dtype=np.float32)
+
+ # get rotation
+ if np.random.rand() < self.rotate_prob:
+ rotate = self._truncnorm() * self.rotate_factor
+ else:
+ rotate = 0
+
+ return offset, scale, rotate
+
+ def transform(self, results: Dict) -> Optional[dict]:
+ """The transform function of :class:`BottomupRandomAffine` to apply a
+ random affine transform to the image and annotations.
+
+ See ``transform()`` method of :class:`BaseTransform` for details.
+
+
+ Args:
+ results (dict): Result dict from the data pipeline.
+
+ Returns:
+ dict: Result dict with the image and annotations affine-transformed.
+ """
+
+ img_h, img_w = results['img_shape']
+ w, h = self.input_size
+
+ offset_rate, scale_rate, rotate = self._get_transform_params()
+ offset = offset_rate * [img_w, img_h]
+ scale = scale_rate * [img_w, img_h]
+ # adjust the scale to match the target aspect ratio
+ scale = self._fix_aspect_ratio(scale, aspect_ratio=w / h)
+
+ if self.use_udp:
+ center = np.array([(img_w - 1.0) / 2, (img_h - 1.0) / 2],
+ dtype=np.float32)
+ warp_mat = get_udp_warp_matrix(
+ center=center + offset,
+ scale=scale,
+ rot=rotate,
+ output_size=(w, h))
+ else:
+ center = np.array([img_w / 2, img_h / 2], dtype=np.float32)
+ warp_mat = get_warp_matrix(
+ center=center + offset,
+ scale=scale,
+ rot=rotate,
+ output_size=(w, h))
+
+ # warp image and keypoints
+ results['img'] = cv2.warpAffine(
+ results['img'], warp_mat, (int(w), int(h)), flags=cv2.INTER_LINEAR)
+
+ if 'keypoints' in results:
+ # Only transform (x, y) coordinates
+ results['keypoints'][..., :2] = cv2.transform(
+ results['keypoints'][..., :2], warp_mat)
+
+ if 'bbox' in results:
+ bbox = np.tile(results['bbox'], 2).reshape(-1, 4, 2)
+ # corner order: left_top, left_bottom, right_top, right_bottom
+ bbox[:, 1:3, 0] = bbox[:, 0:2, 0]
+ results['bbox'] = cv2.transform(bbox, warp_mat).reshape(-1, 8)
+
+ results['input_size'] = self.input_size
+ results['warp_mat'] = warp_mat
+
+ return results
+
+
+@TRANSFORMS.register_module()
+class BottomupResize(BaseTransform):
+ """Resize the image to the input size of the model. Optionally, the image
+ can be resized to multiple sizes to build an image pyramid for multi-scale
+ inference.
+
+ Required Keys:
+
+ - img
+ - ori_shape
+
+ Modified Keys:
+
+ - img
+ - img_shape
+
+ Added Keys:
+
+ - input_size
+ - warp_mat
+ - aug_scale
+
+ Args:
+ input_size (Tuple[int, int]): The input size of the model in [w, h].
+ Note that the actual size of the resized image will be affected
+ by ``resize_mode`` and ``size_factor``, thus may not exactly equal
+ the ``input_size``
+ aug_scales (List[float], optional): The extra input scales for
+ multi-scale testing. If given, the input image will be resized
+ to different scales to build an image pyramid, and heatmaps from
+ all scales will be aggregated to make the final prediction. Defaults
+ to ``None``
+ size_factor (int): The actual input size will be ceiled to
+ a multiple of the `size_factor` value at both sides.
+ Defaults to 32
+ resize_mode (str): The method to resize the image to the input size.
+ Options are:
+
+ - ``'fit'``: The image will be resized according to the
+ relatively longer side with the aspect ratio kept. The
+ resized image will fit entirely within the input size
+ - ``'expand'``: The image will be resized according to the
+ relatively shorter side with the aspect ratio kept. The
+ resized image will exceed the given input size at the
+ longer side
+ use_udp (bool): Whether to use unbiased data processing. See
+ `UDP (CVPR 2020)`_ for details. Defaults to ``False``
+
+ .. _`UDP (CVPR 2020)`: https://arxiv.org/abs/1911.07524
+ """
+
+ def __init__(self,
+ input_size: Tuple[int, int],
+ aug_scales: Optional[List[float]] = None,
+ size_factor: int = 32,
+ resize_mode: str = 'fit',
+ use_udp: bool = False):
+ super().__init__()
+
+ self.input_size = input_size
+ self.aug_scales = aug_scales
+ self.resize_mode = resize_mode
+ self.size_factor = size_factor
+ self.use_udp = use_udp
+
+ @staticmethod
+ def _ceil_to_multiple(size: Tuple[int, int], base: int):
+ """Ceil the given size (tuple of [w, h]) to a multiple of the base."""
+ return tuple(int(np.ceil(s / base) * base) for s in size)
+
+ def _get_input_size(self, img_size: Tuple[int, int],
+ input_size: Tuple[int, int]) -> Tuple:
+ """Calculate the actual input size (which the original image will be
+ resized to) and the padded input size (which the resized image will be
+ padded to, or which is the size of the model input).
+
+ Args:
+ img_size (Tuple[int, int]): The original image size in [w, h]
+ input_size (Tuple[int, int]): The expected input size in [w, h]
+
+ Returns:
+ tuple:
+ - actual_input_size (Tuple[int, int]): The target size to resize
+ the image
+ - padded_input_size (Tuple[int, int]): The target size to generate
+ the model input which will contain the resized image
+ """
+ img_w, img_h = img_size
+ ratio = img_w / img_h
+
+ if self.resize_mode == 'fit':
+ padded_input_size = self._ceil_to_multiple(input_size,
+ self.size_factor)
+ if padded_input_size != input_size:
+ raise ValueError(
+ 'When ``resize_mode==\'fit\'``, the input size (height and'
+ ' width) should be multiples of the size_factor ('
+ f'{self.size_factor}) at all scales. Got invalid input '
+ f'size {input_size}.')
+
+ pad_w, pad_h = padded_input_size
+ rsz_w = min(pad_w, pad_h * ratio)
+ rsz_h = min(pad_h, pad_w / ratio)
+ actual_input_size = (rsz_w, rsz_h)
+
+ elif self.resize_mode == 'expand':
+ _padded_input_size = self._ceil_to_multiple(
+ input_size, self.size_factor)
+ pad_w, pad_h = _padded_input_size
+ rsz_w = max(pad_w, pad_h * ratio)
+ rsz_h = max(pad_h, pad_w / ratio)
+
+ actual_input_size = (rsz_w, rsz_h)
+ padded_input_size = self._ceil_to_multiple(actual_input_size,
+ self.size_factor)
+
+ else:
+ raise ValueError(f'Invalid resize mode {self.resize_mode}')
+
+ return actual_input_size, padded_input_size
+
+ def transform(self, results: Dict) -> Optional[dict]:
+ """The transform function of :class:`BottomupResize` to resize the
+ image(s) to the model input size.
+
+ See ``transform()`` method of :class:`BaseTransform` for details.
+
+
+ Args:
+ results (dict): Result dict from the data pipeline.
+
+ Returns:
+ dict: Result dict with the resized image(s).
+ """
+
+ img = results['img']
+ img_h, img_w = results['ori_shape']
+ w, h = self.input_size
+
+ input_sizes = [(w, h)]
+ if self.aug_scales:
+ input_sizes += [(int(w * s), int(h * s)) for s in self.aug_scales]
+
+ imgs = []
+ for i, (_w, _h) in enumerate(input_sizes):
+
+ actual_input_size, padded_input_size = self._get_input_size(
+ img_size=(img_w, img_h), input_size=(_w, _h))
+
+ if self.use_udp:
+ center = np.array([(img_w - 1.0) / 2, (img_h - 1.0) / 2],
+ dtype=np.float32)
+ scale = np.array([img_w, img_h], dtype=np.float32)
+ warp_mat = get_udp_warp_matrix(
+ center=center,
+ scale=scale,
+ rot=0,
+ output_size=actual_input_size)
+ else:
+ center = np.array([img_w / 2, img_h / 2], dtype=np.float32)
+ scale = np.array([
+ img_w * padded_input_size[0] / actual_input_size[0],
+ img_h * padded_input_size[1] / actual_input_size[1]
+ ],
+ dtype=np.float32)
+ warp_mat = get_warp_matrix(
+ center=center,
+ scale=scale,
+ rot=0,
+ output_size=padded_input_size)
+
+ _img = cv2.warpAffine(
+ img, warp_mat, padded_input_size, flags=cv2.INTER_LINEAR)
+
+ imgs.append(_img)
+
+ # Store the transform information w.r.t. the main input size
+ if i == 0:
+ results['img_shape'] = padded_input_size[::-1]
+ results['input_center'] = center
+ results['input_scale'] = scale
+ results['input_size'] = padded_input_size
+
+ if self.aug_scales:
+ results['img'] = imgs
+ results['aug_scales'] = self.aug_scales
+ else:
+ results['img'] = imgs[0]
+ results['aug_scale'] = None
+
+ return results
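The size bookkeeping in ``BottomupResize._get_input_size`` is plain ceiling-to-multiple arithmetic. A small sketch of what the ``'fit'`` and ``'expand'`` modes produce for a hypothetical 640x480 image with a 512x512 input and ``size_factor=32`` (numbers chosen only for illustration):

```python
import numpy as np

def ceil_to_multiple(size, base):
    # Same rounding as BottomupResize._ceil_to_multiple
    return tuple(int(np.ceil(s / base) * base) for s in size)

img_w, img_h = 640, 480                           # original image, ratio 4:3
pad_w, pad_h = ceil_to_multiple((512, 512), 32)   # -> (512, 512)
ratio = img_w / img_h

# 'fit': keep the aspect ratio and fit entirely inside the padded input
fit_size = (min(pad_w, pad_h * ratio), min(pad_h, pad_w / ratio))
# -> (512.0, 384.0); the resized image is then padded up to 512 x 512

# 'expand': keep the aspect ratio and cover the input, then pad up again
expand_size = (max(pad_w, pad_h * ratio), max(pad_h, pad_w / ratio))
padded = ceil_to_multiple(expand_size, 32)
# expand_size is roughly (682.7, 512.0); padded model input -> (704, 512)
```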
diff --git a/mmpose/datasets/transforms/common_transforms.py b/mmpose/datasets/transforms/common_transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..8db0ff37c75faec94e6e0df15b9cf2b099db86c9
--- /dev/null
+++ b/mmpose/datasets/transforms/common_transforms.py
@@ -0,0 +1,1044 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+from copy import deepcopy
+from typing import Dict, List, Optional, Sequence, Tuple, Union
+
+import mmcv
+import mmengine
+import numpy as np
+from mmcv.image import imflip
+from mmcv.transforms import BaseTransform
+from mmcv.transforms.utils import avoid_cache_randomness, cache_randomness
+from mmengine import is_list_of
+from mmengine.dist import get_dist_info
+from scipy.stats import truncnorm
+
+from mmpose.codecs import * # noqa: F401, F403
+from mmpose.registry import KEYPOINT_CODECS, TRANSFORMS
+from mmpose.structures.bbox import bbox_xyxy2cs, flip_bbox
+from mmpose.structures.keypoint import flip_keypoints
+from mmpose.utils.typing import MultiConfig
+
+try:
+ import albumentations
+except ImportError:
+ albumentations = None
+
+Number = Union[int, float]
+
+
+@TRANSFORMS.register_module()
+class GetBBoxCenterScale(BaseTransform):
+ """Convert bboxes from [x, y, w, h] to center and scale.
+
+ The center is the coordinates of the bbox center, and the scale is the
+ bbox width and height normalized by a scale factor.
+
+ Required Keys:
+
+ - bbox
+
+ Added Keys:
+
+ - bbox_center
+ - bbox_scale
+
+ Args:
+ padding (float): The bbox padding scale that will be multiplied with
+ ``bbox_scale``. Defaults to 1.25
+ """
+
+ def __init__(self, padding: float = 1.25) -> None:
+ super().__init__()
+
+ self.padding = padding
+
+ def transform(self, results: Dict) -> Optional[dict]:
+ """The transform function of :class:`GetBBoxCenterScale`.
+
+ See ``transform()`` method of :class:`BaseTransform` for details.
+
+ Args:
+ results (dict): The result dict
+
+ Returns:
+ dict: The result dict.
+ """
+ if 'bbox_center' in results and 'bbox_scale' in results:
+ rank, _ = get_dist_info()
+ if rank == 0:
+ warnings.warn('Use the existing "bbox_center" and "bbox_scale"'
+ '. The padding will still be applied.')
+ results['bbox_scale'] *= self.padding
+
+ else:
+ bbox = results['bbox']
+ center, scale = bbox_xyxy2cs(bbox, padding=self.padding)
+
+ results['bbox_center'] = center
+ results['bbox_scale'] = scale
+
+ return results
+
+ def __repr__(self) -> str:
+ """print the basic information of the transform.
+
+ Returns:
+ str: Formatted string.
+ """
+ repr_str = self.__class__.__name__ + f'(padding={self.padding})'
+ return repr_str
+
+
+@TRANSFORMS.register_module()
+class RandomFlip(BaseTransform):
+ """Randomly flip the image, bbox and keypoints.
+
+ Required Keys:
+
+ - img
+ - img_shape
+ - flip_indices
+ - input_size (optional)
+ - bbox (optional)
+ - bbox_center (optional)
+ - keypoints (optional)
+ - keypoints_visible (optional)
+ - img_mask (optional)
+
+ Modified Keys:
+
+ - img
+ - bbox (optional)
+ - bbox_center (optional)
+ - keypoints (optional)
+ - keypoints_visible (optional)
+ - img_mask (optional)
+
+ Added Keys:
+
+ - flip
+ - flip_direction
+
+ Args:
+ prob (float | list[float]): The flipping probability. If a list is
+ given, the argument `direction` should be a list with the same
+ length. And each element in `prob` indicates the flipping
+ probability of the corresponding one in ``direction``. Defaults
+ to 0.5
+ direction (str | list[str]): The flipping direction. Options are
+ ``'horizontal'``, ``'vertical'`` and ``'diagonal'``. If a list
+ is given, each data sample's flipping direction will be sampled
+ from a distribution determined by the argument ``prob``. Defaults
+ to ``'horizontal'``.
+ """
+
+ def __init__(self,
+ prob: Union[float, List[float]] = 0.5,
+ direction: Union[str, List[str]] = 'horizontal') -> None:
+ if isinstance(prob, list):
+ assert is_list_of(prob, float)
+ assert 0 <= sum(prob) <= 1
+ elif isinstance(prob, float):
+ assert 0 <= prob <= 1
+ else:
+ raise ValueError(f'prob must be float or list of float, but \
+ got `{type(prob)}`.')
+ self.prob = prob
+
+ valid_directions = ['horizontal', 'vertical', 'diagonal']
+ if isinstance(direction, str):
+ assert direction in valid_directions
+ elif isinstance(direction, list):
+ assert is_list_of(direction, str)
+ assert set(direction).issubset(set(valid_directions))
+ else:
+ raise ValueError(f'direction must be either str or list of str, \
+ but got `{type(direction)}`.')
+ self.direction = direction
+
+ if isinstance(prob, list):
+ assert len(prob) == len(self.direction)
+
+ @cache_randomness
+ def _choose_direction(self) -> str:
+ """Choose the flip direction according to `prob` and `direction`"""
+ if isinstance(self.direction,
+ List) and not isinstance(self.direction, str):
+ # None means non-flip
+ direction_list: list = list(self.direction) + [None]
+ elif isinstance(self.direction, str):
+ # None means non-flip
+ direction_list = [self.direction, None]
+
+ if isinstance(self.prob, list):
+ non_prob: float = 1 - sum(self.prob)
+ prob_list = self.prob + [non_prob]
+ elif isinstance(self.prob, float):
+ non_prob = 1. - self.prob
+ # exclude non-flip
+ single_ratio = self.prob / (len(direction_list) - 1)
+ prob_list = [single_ratio] * (len(direction_list) - 1) + [non_prob]
+
+ cur_dir = np.random.choice(direction_list, p=prob_list)
+
+ return cur_dir
+
+ def transform(self, results: dict) -> dict:
+ """The transform function of :class:`RandomFlip`.
+
+ See ``transform()`` method of :class:`BaseTransform` for details.
+
+ Args:
+ results (dict): The result dict
+
+ Returns:
+ dict: The result dict.
+ """
+
+ flip_dir = self._choose_direction()
+
+ if flip_dir is None:
+ results['flip'] = False
+ results['flip_direction'] = None
+ else:
+ results['flip'] = True
+ results['flip_direction'] = flip_dir
+
+ h, w = results.get('input_size', results['img_shape'])
+ # flip image and mask
+ if isinstance(results['img'], list):
+ results['img'] = [
+ imflip(img, direction=flip_dir) for img in results['img']
+ ]
+ else:
+ results['img'] = imflip(results['img'], direction=flip_dir)
+
+ if 'img_mask' in results:
+ results['img_mask'] = imflip(
+ results['img_mask'], direction=flip_dir)
+
+ # flip bboxes
+ if results.get('bbox', None) is not None:
+ results['bbox'] = flip_bbox(
+ results['bbox'],
+ image_size=(w, h),
+ bbox_format='xyxy',
+ direction=flip_dir)
+
+ if results.get('bbox_center', None) is not None:
+ results['bbox_center'] = flip_bbox(
+ results['bbox_center'],
+ image_size=(w, h),
+ bbox_format='center',
+ direction=flip_dir)
+
+ # flip keypoints
+ if results.get('keypoints', None) is not None:
+ keypoints, keypoints_visible = flip_keypoints(
+ results['keypoints'],
+ results.get('keypoints_visible', None),
+ image_size=(w, h),
+ flip_indices=results['flip_indices'],
+ direction=flip_dir)
+
+ results['keypoints'] = keypoints
+ results['keypoints_visible'] = keypoints_visible
+
+ return results
+
+ def __repr__(self) -> str:
+ """print the basic information of the transform.
+
+ Returns:
+ str: Formatted string.
+ """
+ repr_str = self.__class__.__name__
+ repr_str += f'(prob={self.prob}, '
+ repr_str += f'direction={self.direction})'
+ return repr_str
+
+
+@TRANSFORMS.register_module()
+class RandomHalfBody(BaseTransform):
+ """Data augmentation with half-body transform that keeps only the upper or
+ lower body at random.
+
+ Required Keys:
+
+ - keypoints
+ - keypoints_visible
+ - upper_body_ids
+ - lower_body_ids
+
+ Modified Keys:
+
+ - bbox
+ - bbox_center
+ - bbox_scale
+
+ Args:
+ min_total_keypoints (int): The minimum required number of total valid
+ keypoints of a person to apply half-body transform. Defaults to 9
+ min_upper_keypoints (int): The minimum required number of valid
+ upper-body keypoints to keep the upper body. Defaults to 2
+ min_lower_keypoints (int): The minimum required number of valid
+ lower-body keypoints to keep the lower body. Defaults to 3
+ padding (float): The bbox padding scale that will be multiplied with
+ ``bbox_scale``. Defaults to 1.5
+ prob (float): The probability to apply half-body transform when the
+ keypoint number meets the requirement. Defaults to 0.3
+ upper_prioritized_prob (float): The probability of prioritizing the
+ upper body when both halves have enough valid keypoints.
+ Defaults to 0.7
+ """
+
+ def __init__(self,
+ min_total_keypoints: int = 9,
+ min_upper_keypoints: int = 2,
+ min_lower_keypoints: int = 3,
+ padding: float = 1.5,
+ prob: float = 0.3,
+ upper_prioritized_prob: float = 0.7) -> None:
+ super().__init__()
+ self.min_total_keypoints = min_total_keypoints
+ self.min_upper_keypoints = min_upper_keypoints
+ self.min_lower_keypoints = min_lower_keypoints
+ self.padding = padding
+ self.prob = prob
+ self.upper_prioritized_prob = upper_prioritized_prob
+
+ def _get_half_body_bbox(self, keypoints: np.ndarray,
+ half_body_ids: List[int]
+ ) -> Tuple[np.ndarray, np.ndarray]:
+ """Get half-body bbox center and scale of a single instance.
+
+ Args:
+ keypoints (np.ndarray): Keypoints in shape (K, D)
+ half_body_ids (list): The list of half-body keypoint indices
+
+ Returns:
+ tuple: A tuple containing half-body bbox center and scale
+ - center: Center (x, y) of the bbox
+ - scale: Scale (w, h) of the bbox
+ """
+
+ selected_keypoints = keypoints[half_body_ids]
+ center = selected_keypoints.mean(axis=0)[:2]
+
+ x1, y1 = selected_keypoints.min(axis=0)
+ x2, y2 = selected_keypoints.max(axis=0)
+ w = x2 - x1
+ h = y2 - y1
+ scale = np.array([w, h], dtype=center.dtype) * self.padding
+
+ return center, scale
+
+ @cache_randomness
+ def _random_select_half_body(self, keypoints_visible: np.ndarray,
+ upper_body_ids: List[int],
+ lower_body_ids: List[int]
+ ) -> List[Optional[List[int]]]:
+ """Randomly determine whether to apply the half-body transform and
+ get the half-body keypoint indices of each instance.
+
+ Args:
+ keypoints_visible (np.ndarray, optional): The visibility of
+ keypoints in shape (N, K, 1).
+ upper_body_ids (list): The list of upper body keypoint indices
+ lower_body_ids (list): The list of lower body keypoint indices
+
+ Returns:
+ list[list[int] | None]: The selected half-body keypoint indices
+ of each instance. ``None`` means not applying half-body transform.
+ """
+
+ half_body_ids = []
+
+ for visible in keypoints_visible:
+ if visible.sum() < self.min_total_keypoints:
+ indices = None
+ elif np.random.rand() > self.prob:
+ indices = None
+ else:
+ upper_valid_ids = [i for i in upper_body_ids if visible[i] > 0]
+ lower_valid_ids = [i for i in lower_body_ids if visible[i] > 0]
+
+ num_upper = len(upper_valid_ids)
+ num_lower = len(lower_valid_ids)
+
+ prefer_upper = np.random.rand() < self.upper_prioritized_prob
+ if (num_upper < self.min_upper_keypoints
+ and num_lower < self.min_lower_keypoints):
+ indices = None
+ elif num_lower < self.min_lower_keypoints:
+ indices = upper_valid_ids
+ elif num_upper < self.min_upper_keypoints:
+ indices = lower_valid_ids
+ else:
+ indices = (
+ upper_valid_ids if prefer_upper else lower_valid_ids)
+
+ half_body_ids.append(indices)
+
+ return half_body_ids
+
+ def transform(self, results: Dict) -> Optional[dict]:
+ """The transform function of :class:`RandomHalfBody`.
+
+ See ``transform()`` method of :class:`BaseTransform` for details.
+
+ Args:
+ results (dict): The result dict
+
+ Returns:
+ dict: The result dict.
+ """
+
+ half_body_ids = self._random_select_half_body(
+ keypoints_visible=results['keypoints_visible'],
+ upper_body_ids=results['upper_body_ids'],
+ lower_body_ids=results['lower_body_ids'])
+
+ bbox_center = []
+ bbox_scale = []
+
+ for i, indices in enumerate(half_body_ids):
+ if indices is None:
+ bbox_center.append(results['bbox_center'][i])
+ bbox_scale.append(results['bbox_scale'][i])
+ else:
+ _center, _scale = self._get_half_body_bbox(
+ results['keypoints'][i], indices)
+ bbox_center.append(_center)
+ bbox_scale.append(_scale)
+
+ results['bbox_center'] = np.stack(bbox_center)
+ results['bbox_scale'] = np.stack(bbox_scale)
+ return results
+
+ def __repr__(self) -> str:
+ """print the basic information of the transform.
+
+ Returns:
+ str: Formatted string.
+ """
+ repr_str = self.__class__.__name__
+ repr_str += f'(min_total_keypoints={self.min_total_keypoints}, '
+ repr_str += f'min_upper_keypoints={self.min_upper_keypoints}, '
+ repr_str += f'min_lower_keypoints={self.min_lower_keypoints}, '
+ repr_str += f'padding={self.padding}, '
+ repr_str += f'prob={self.prob}, '
+ repr_str += f'upper_prioritized_prob={self.upper_prioritized_prob})'
+ return repr_str
+
+
+@TRANSFORMS.register_module()
+class RandomBBoxTransform(BaseTransform):
+ r"""Randomly shift, resize and rotate the bounding boxes.
+
+ Required Keys:
+
+ - bbox_center
+ - bbox_scale
+
+ Modified Keys:
+
+ - bbox_center
+ - bbox_scale
+
+ Added Keys:
+ - bbox_rotation
+
+ Args:
+ shift_factor (float): Randomly shift the bbox in range
+ :math:`[-dx, dx]` and :math:`[-dy, dy]` in X and Y directions,
+ where :math:`dx(y) = x(y)_scale \cdot shift_factor` in pixels.
+ Defaults to 0.16
+ shift_prob (float): Probability of applying random shift. Defaults to
+ 0.3
+ scale_factor (Tuple[float, float]): Randomly resize the bbox in range
+ :math:`[scale_factor[0], scale_factor[1]]`. Defaults to (0.5, 1.5)
+ scale_prob (float): Probability of applying random resizing. Defaults
+ to 1.0
+ rotate_factor (float): Randomly rotate the bbox in
+ :math:`[-rotate_factor, rotate_factor]` in degrees. Defaults
+ to 80.0
+ rotate_prob (float): Probability of applying random rotation. Defaults
+ to 0.6
+ """
+
+ def __init__(self,
+ shift_factor: float = 0.16,
+ shift_prob: float = 0.3,
+ scale_factor: Tuple[float, float] = (0.5, 1.5),
+ scale_prob: float = 1.0,
+ rotate_factor: float = 80.0,
+ rotate_prob: float = 0.6) -> None:
+ super().__init__()
+
+ self.shift_factor = shift_factor
+ self.shift_prob = shift_prob
+ self.scale_factor = scale_factor
+ self.scale_prob = scale_prob
+ self.rotate_factor = rotate_factor
+ self.rotate_prob = rotate_prob
+
+ @staticmethod
+ def _truncnorm(low: float = -1.,
+ high: float = 1.,
+ size: tuple = ()) -> np.ndarray:
+ """Sample from a truncated normal distribution."""
+ return truncnorm.rvs(low, high, size=size).astype(np.float32)
+
+ @cache_randomness
+ def _get_transform_params(self, num_bboxes: int) -> Tuple:
+ """Get random transform parameters.
+
+ Args:
+ num_bboxes (int): The number of bboxes
+
+ Returns:
+ tuple:
+ - offset (np.ndarray): Offset factor of each bbox in shape (n, 2)
+ - scale (np.ndarray): Scaling factor of each bbox in shape (n, 1)
+ - rotate (np.ndarray): Rotation degree of each bbox in shape (n,)
+ """
+ # Get shift parameters
+ offset = self._truncnorm(size=(num_bboxes, 2)) * self.shift_factor
+ offset = np.where(
+ np.random.rand(num_bboxes, 1) < self.shift_prob, offset, 0.)
+
+ # Get scaling parameters
+ scale_min, scale_max = self.scale_factor
+ mu = (scale_max + scale_min) * 0.5
+ sigma = (scale_max - scale_min) * 0.5
+ scale = self._truncnorm(size=(num_bboxes, 1)) * sigma + mu
+ scale = np.where(
+ np.random.rand(num_bboxes, 1) < self.scale_prob, scale, 1.)
+
+ # Get rotation parameters
+ rotate = self._truncnorm(size=(num_bboxes, )) * self.rotate_factor
+ rotate = np.where(
+ np.random.rand(num_bboxes) < self.rotate_prob, rotate, 0.)
+
+ return offset, scale, rotate
+
+ def transform(self, results: Dict) -> Optional[dict]:
+ """The transform function of :class:`RandomBBoxTransform`.
+
+ See ``transform()`` method of :class:`BaseTransform` for details.
+
+ Args:
+ results (dict): The result dict
+
+ Returns:
+ dict: The result dict.
+ """
+ bbox_scale = results['bbox_scale']
+ num_bboxes = bbox_scale.shape[0]
+
+ offset, scale, rotate = self._get_transform_params(num_bboxes)
+
+ results['bbox_center'] += offset * bbox_scale
+ results['bbox_scale'] *= scale
+ results['bbox_rotation'] = rotate
+
+ return results
+
+ def __repr__(self) -> str:
+ """print the basic information of the transform.
+
+ Returns:
+ str: Formatted string.
+ """
+ repr_str = self.__class__.__name__
+ repr_str += f'(shift_prob={self.shift_prob}, '
+ repr_str += f'shift_factor={self.shift_factor}, '
+ repr_str += f'scale_prob={self.scale_prob}, '
+ repr_str += f'scale_factor={self.scale_factor}, '
+ repr_str += f'rotate_prob={self.rotate_prob}, '
+ repr_str += f'rotate_factor={self.rotate_factor})'
+ return repr_str
+
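The parameter sampling above is an ordinary truncated normal on [-1, 1] followed by a per-box Bernoulli gate. A standalone sketch of the shift branch; the box centers and scales are invented for illustration:

```python
import numpy as np
from scipy.stats import truncnorm

shift_factor, shift_prob = 0.16, 0.3
num_bboxes = 4

# Offsets in [-1, 1] from a truncated normal, scaled by shift_factor,
# then zeroed out for boxes that the Bernoulli gate leaves untouched.
offset = truncnorm.rvs(-1., 1., size=(num_bboxes, 2)).astype(np.float32)
offset *= shift_factor
offset = np.where(np.random.rand(num_bboxes, 1) < shift_prob, offset, 0.)

bbox_center = np.full((num_bboxes, 2), 100., dtype=np.float32)
bbox_scale = np.tile(np.array([50., 80.], dtype=np.float32), (num_bboxes, 1))
bbox_center += offset * bbox_scale   # the shift is relative to the box size
```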
+
+@TRANSFORMS.register_module()
+@avoid_cache_randomness
+class Albumentation(BaseTransform):
+ """Albumentation augmentation (pixel-level transforms only).
+
+ Adds custom pixel-level transformations from Albumentations library.
+ Please visit `https://albumentations.ai/docs/`
+ to get more information.
+
+ Note: we only support pixel-level transforms.
+ Please visit `https://github.com/albumentations-team/`
+ `albumentations#pixel-level-transforms`
+ to get more information about pixel-level transforms.
+
+ Required Keys:
+
+ - img
+
+ Modified Keys:
+
+ - img
+
+ Args:
+ transforms (List[dict]): A list of Albumentation transforms.
+ An example of ``transforms`` is as follows:
+ .. code-block:: python
+
+ [
+ dict(
+ type='RandomBrightnessContrast',
+ brightness_limit=[0.1, 0.3],
+ contrast_limit=[0.1, 0.3],
+ p=0.2),
+ dict(type='ChannelShuffle', p=0.1),
+ dict(
+ type='OneOf',
+ transforms=[
+ dict(type='Blur', blur_limit=3, p=1.0),
+ dict(type='MedianBlur', blur_limit=3, p=1.0)
+ ],
+ p=0.1),
+ ]
+ keymap (dict | None): key mapping from ``input key`` to
+ ``albumentation-style key``.
+ Defaults to None, which will use {'img': 'image'}.
+ """
+
+ def __init__(self,
+ transforms: List[dict],
+ keymap: Optional[dict] = None) -> None:
+ if albumentations is None:
+ raise RuntimeError('albumentations is not installed')
+
+ self.transforms = transforms
+
+ self.aug = albumentations.Compose(
+ [self.albu_builder(t) for t in self.transforms])
+
+ if not keymap:
+ self.keymap_to_albu = {
+ 'img': 'image',
+ }
+ else:
+ self.keymap_to_albu = keymap
+
+ def albu_builder(self, cfg: dict) -> albumentations:
+ """Import a module from albumentations.
+
+ It resembles some of :func:`build_from_cfg` logic.
+
+ Args:
+ cfg (dict): Config dict. It should at least contain the key "type".
+
+ Returns:
+ albumentations.BasicTransform: The constructed transform object
+ """
+
+ assert isinstance(cfg, dict) and 'type' in cfg
+ args = cfg.copy()
+
+ obj_type = args.pop('type')
+ if mmengine.is_str(obj_type):
+ if albumentations is None:
+ raise RuntimeError('albumentations is not installed')
+ rank, _ = get_dist_info()
+ if rank == 0 and not hasattr(
+ albumentations.augmentations.transforms, obj_type):
+ warnings.warn(
+ f'{obj_type} is not a pixel-level transformation. '
+ 'Please use with caution.')
+ obj_cls = getattr(albumentations, obj_type)
+ else:
+ raise TypeError(f'type must be a str, but got {type(obj_type)}')
+
+ if 'transforms' in args:
+ args['transforms'] = [
+ self.albu_builder(transform)
+ for transform in args['transforms']
+ ]
+
+ return obj_cls(**args)
+
+ def transform(self, results: dict) -> dict:
+ """The transform function of :class:`Albumentation` to apply
+ albumentations transforms.
+
+ See ``transform()`` method of :class:`BaseTransform` for details.
+
+ Args:
+ results (dict): Result dict from the data pipeline.
+
+ Return:
+ dict: updated result dict.
+ """
+ # map result dict to albumentations format
+ results_albu = {}
+ for k, v in self.keymap_to_albu.items():
+ assert k in results, \
+ f'The `{k}` is required to perform albumentations transforms'
+ results_albu[v] = results[k]
+
+ # Apply albumentations transforms
+ results_albu = self.aug(**results_albu)
+
+ # map the albu results back to the original format
+ for k, v in self.keymap_to_albu.items():
+ results[k] = results_albu[v]
+
+ return results
+
+ def __repr__(self) -> str:
+ """print the basic information of the transform.
+
+ Returns:
+ str: Formatted string.
+ """
+ repr_str = self.__class__.__name__ + f'(transforms={self.transforms})'
+ return repr_str
+
+
+@TRANSFORMS.register_module()
+class PhotometricDistortion(BaseTransform):
+ """Apply photometric distortions to the image sequentially; every
+ transformation is applied with a probability of 0.5. The random contrast
+ step is applied either second or second to last.
+
+ 1. random brightness
+ 2. random contrast (mode 0)
+ 3. convert color from BGR to HSV
+ 4. random saturation
+ 5. random hue
+ 6. convert color from HSV to BGR
+ 7. random contrast (mode 1)
+ 8. randomly swap channels
+
+ Required Keys:
+
+ - img
+
+ Modified Keys:
+
+ - img
+
+ Args:
+ brightness_delta (int): delta of brightness.
+ contrast_range (tuple): range of contrast.
+ saturation_range (tuple): range of saturation.
+ hue_delta (int): delta of hue.
+ """
+
+ def __init__(self,
+ brightness_delta: int = 32,
+ contrast_range: Sequence[Number] = (0.5, 1.5),
+ saturation_range: Sequence[Number] = (0.5, 1.5),
+ hue_delta: int = 18) -> None:
+ self.brightness_delta = brightness_delta
+ self.contrast_lower, self.contrast_upper = contrast_range
+ self.saturation_lower, self.saturation_upper = saturation_range
+ self.hue_delta = hue_delta
+
+ @cache_randomness
+ def _random_flags(self) -> Sequence[Number]:
+ """Generate the random flags for subsequent transforms.
+
+ Returns:
+ Sequence[Number]: a sequence of numbers that indicate whether to
+ do the corresponding transforms.
+ """
+ # contrast_mode == 0 --> do random contrast first
+ # contrast_mode == 1 --> do random contrast last
+ contrast_mode = np.random.randint(2)
+ # whether to apply brightness distortion
+ brightness_flag = np.random.randint(2)
+ # whether to apply contrast distortion
+ contrast_flag = np.random.randint(2)
+ # the mode to convert color from BGR to HSV
+ hsv_mode = np.random.randint(4)
+ # whether to apply channel swap
+ swap_flag = np.random.randint(2)
+
+ # the beta in `self._convert` to be added to image array
+ # in brightness distortion
+ brightness_beta = np.random.uniform(-self.brightness_delta,
+ self.brightness_delta)
+ # the alpha in `self._convert` to be multiplied to image array
+ # in contrast distortion
+ contrast_alpha = np.random.uniform(self.contrast_lower,
+ self.contrast_upper)
+ # the alpha in `self._convert` to be multiplied to image array
+ # in saturation distortion to hsv-formatted img
+ saturation_alpha = np.random.uniform(self.saturation_lower,
+ self.saturation_upper)
+ # delta of hue to add to image array in hue distortion
+ hue_delta = np.random.randint(-self.hue_delta, self.hue_delta)
+ # the random permutation of channel order
+ swap_channel_order = np.random.permutation(3)
+
+ return (contrast_mode, brightness_flag, contrast_flag, hsv_mode,
+ swap_flag, brightness_beta, contrast_alpha, saturation_alpha,
+ hue_delta, swap_channel_order)
+
+ def _convert(self,
+ img: np.ndarray,
+ alpha: float = 1,
+ beta: float = 0) -> np.ndarray:
+ """Multiply by ``alpha``, add ``beta``, and clip the result to [0, 255].
+
+ Args:
+ img (np.ndarray): The image array.
+ alpha (float): The random multiplier.
+ beta (float): The random offset.
+
+ Returns:
+ np.ndarray: The updated image array.
+ """
+ img = img.astype(np.float32) * alpha + beta
+ img = np.clip(img, 0, 255)
+ return img.astype(np.uint8)
+
+ def transform(self, results: dict) -> dict:
+ """The transform function of :class:`PhotometricDistortion` to perform
+ photometric distortion on images.
+
+ See ``transform()`` method of :class:`BaseTransform` for details.
+
+
+ Args:
+ results (dict): Result dict from the data pipeline.
+
+ Returns:
+ dict: Result dict with images distorted.
+ """
+
+ assert 'img' in results, '`img` is not found in results'
+ img = results['img']
+
+ (contrast_mode, brightness_flag, contrast_flag, hsv_mode, swap_flag,
+ brightness_beta, contrast_alpha, saturation_alpha, hue_delta,
+ swap_channel_order) = self._random_flags()
+
+ # random brightness distortion
+ if brightness_flag:
+ img = self._convert(img, beta=brightness_beta)
+
+ # contrast_mode == 0 --> do random contrast first
+ # contrast_mode == 1 --> do random contrast last
+ if contrast_mode == 0:
+ if contrast_flag:
+ img = self._convert(img, alpha=contrast_alpha)
+
+ if hsv_mode:
+ # random saturation/hue distortion
+ img = mmcv.bgr2hsv(img)
+ if hsv_mode == 1 or hsv_mode == 3:
+ # apply saturation distortion to hsv-formatted img
+ img[:, :, 1] = self._convert(
+ img[:, :, 1], alpha=saturation_alpha)
+ if hsv_mode == 2 or hsv_mode == 3:
+ # apply hue distortion to hsv-formatted img
+ img[:, :, 0] = img[:, :, 0].astype(int) + hue_delta
+ img = mmcv.hsv2bgr(img)
+
+ if contrast_mode == 1:
+ if contrast_flag:
+ img = self._convert(img, alpha=contrast_alpha)
+
+ # randomly swap channels
+ if swap_flag:
+ img = img[..., swap_channel_order]
+
+ results['img'] = img
+ return results
+
+ def __repr__(self) -> str:
+ """print the basic information of the transform.
+
+ Returns:
+ str: Formatted string.
+ """
+ repr_str = self.__class__.__name__
+ repr_str += (f'(brightness_delta={self.brightness_delta}, '
+ f'contrast_range=({self.contrast_lower}, '
+ f'{self.contrast_upper}), '
+ f'saturation_range=({self.saturation_lower}, '
+ f'{self.saturation_upper}), '
+ f'hue_delta={self.hue_delta})')
+ return repr_str
+
+
+@TRANSFORMS.register_module()
+class GenerateTarget(BaseTransform):
+ """Encode keypoints into Target.
+
+ The generated target is usually the supervision signal of the model
+ learning, e.g. heatmaps or regression labels.
+
+ Required Keys:
+
+ - keypoints
+ - keypoints_visible
+ - dataset_keypoint_weights
+
+ Added Keys:
+
+ - The keys of the encoded items from the codec will be updated into
+ the results, e.g. ``'heatmaps'`` or ``'keypoint_weights'``. See
+ the specific codec for more details.
+
+ Args:
+ encoder (dict | list[dict]): The codec config for keypoint encoding.
+ Both single encoder and multiple encoders (given as a list) are
+ supported
+ multilevel (bool): Determine the method to handle multiple encoders.
+ If ``multilevel==True``, generate multilevel targets from a group
+ of encoders of the same type (e.g. multiple :class:`MSRAHeatmap`
+ encoders with different sigma values); If ``multilevel==False``,
+ generate combined targets from a group of different encoders. This
+ argument will have no effect in case of single encoder. Defaults
+ to ``False``
+ use_dataset_keypoint_weights (bool): Whether to use the keypoint weights
+ from the dataset meta information. Defaults to ``False``
+ target_type (str, deprecated): This argument is deprecated and has no
+ effect. Defaults to ``None``
+ """
+
+ def __init__(self,
+ encoder: MultiConfig,
+ target_type: Optional[str] = None,
+ multilevel: bool = False,
+ use_dataset_keypoint_weights: bool = False) -> None:
+ super().__init__()
+
+ if target_type is not None:
+ rank, _ = get_dist_info()
+ if rank == 0:
+ warnings.warn(
+ 'The argument `target_type` is deprecated in'
+ ' GenerateTarget. The target type and encoded '
+ 'keys will be determined by encoder(s).',
+ DeprecationWarning)
+
+ self.encoder_cfg = deepcopy(encoder)
+ self.multilevel = multilevel
+ self.use_dataset_keypoint_weights = use_dataset_keypoint_weights
+
+ if isinstance(self.encoder_cfg, list):
+ self.encoder = [
+ KEYPOINT_CODECS.build(cfg) for cfg in self.encoder_cfg
+ ]
+ else:
+ assert not self.multilevel, (
+ 'Need multiple encoder configs if ``multilevel==True``')
+ self.encoder = KEYPOINT_CODECS.build(self.encoder_cfg)
+
+ def transform(self, results: Dict) -> Optional[dict]:
+ """The transform function of :class:`GenerateTarget`.
+
+ See ``transform()`` method of :class:`BaseTransform` for details.
+ """
+
+ if results.get('transformed_keypoints', None) is not None:
+ # use keypoints transformed by TopdownAffine
+ keypoints = results['transformed_keypoints']
+ elif results.get('keypoints', None) is not None:
+ # use original keypoints
+ keypoints = results['keypoints']
+ else:
+ raise ValueError(
+ 'GenerateTarget requires \'transformed_keypoints\' or'
+ ' \'keypoints\' in the results.')
+
+ keypoints_visible = results['keypoints_visible']
+
+ # Encoded items from the encoder(s) will be updated into the results.
+ # Please refer to the document of the specific codec for details about
+ # encoded items.
+ if not isinstance(self.encoder, list):
+ # For single encoding, the encoded items will be directly added
+ # into results.
+ auxiliary_encode_kwargs = {
+ key: results[key]
+ for key in self.encoder.auxiliary_encode_keys
+ }
+ encoded = self.encoder.encode(
+ keypoints=keypoints,
+ keypoints_visible=keypoints_visible,
+ **auxiliary_encode_kwargs)
+
+ else:
+ encoded_list = []
+ for _encoder in self.encoder:
+ auxiliary_encode_kwargs = {
+ key: results[key]
+ for key in _encoder.auxiliary_encode_keys
+ }
+ encoded_list.append(
+ _encoder.encode(
+ keypoints=keypoints,
+ keypoints_visible=keypoints_visible,
+ **auxiliary_encode_kwargs))
+
+ if self.multilevel:
+ # For multilevel encoding, the encoded items from each encoder
+ # should have the same keys.
+
+ keys = encoded_list[0].keys()
+ if not all(_encoded.keys() == keys
+ for _encoded in encoded_list):
+ raise ValueError(
+ 'Encoded items from all encoders must have the same '
+ 'keys if ``multilevel==True``.')
+
+ encoded = {
+ k: [_encoded[k] for _encoded in encoded_list]
+ for k in keys
+ }
+
+ else:
+ # For combined encoding, the encoded items from different
+ # encoders should have no overlapping items, except for
+ # `keypoint_weights`. If multiple `keypoint_weights` are given,
+ # they will be multiplied as the final `keypoint_weights`.
+
+ encoded = dict()
+ keypoint_weights = []
+
+ for _encoded in encoded_list:
+ for key, value in _encoded.items():
+ if key == 'keypoint_weights':
+ keypoint_weights.append(value)
+ elif key not in encoded:
+ encoded[key] = value
+ else:
+ raise ValueError(
+ f'Overlapping item "{key}" from multiple '
+ 'encoders, which is not supported when '
+ '``multilevel==False``')
+
+ if keypoint_weights:
+ encoded['keypoint_weights'] = keypoint_weights
+
+ if self.use_dataset_keypoint_weights and 'keypoint_weights' in encoded:
+ if isinstance(encoded['keypoint_weights'], list):
+ for w in encoded['keypoint_weights']:
+ w *= results['dataset_keypoint_weights']
+ else:
+ encoded['keypoint_weights'] *= results[
+ 'dataset_keypoint_weights']
+
+ results.update(encoded)
+
+ return results
+
+ def __repr__(self) -> str:
+ """print the basic information of the transform.
+
+ Returns:
+ str: Formatted string.
+ """
+ repr_str = self.__class__.__name__
+ repr_str += (f'(encoder={str(self.encoder_cfg)}, ')
+ repr_str += ('use_dataset_keypoint_weights='
+ f'{self.use_dataset_keypoint_weights})')
+ return repr_str
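Taken together, these transforms are normally chained into a top-down training pipeline. A minimal sketch follows; the ``input_size`` and the ``codec`` settings are illustrative assumptions, not values taken from this diff:

```python
# Hypothetical top-down training pipeline built from the transforms in this
# file plus the loading/affine/packing transforms registered alongside it.
codec = dict(type='MSRAHeatmap', input_size=(192, 256),
             heatmap_size=(48, 64), sigma=2)   # placeholder codec config

train_pipeline = [
    dict(type='LoadImage'),
    dict(type='GetBBoxCenterScale'),
    dict(type='RandomFlip', direction='horizontal'),
    dict(type='RandomHalfBody'),
    dict(type='RandomBBoxTransform'),
    dict(type='TopdownAffine', input_size=(192, 256)),
    dict(type='GenerateTarget', encoder=codec),
    dict(type='PackPoseInputs'),
]
```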
diff --git a/mmpose/datasets/transforms/converting.py b/mmpose/datasets/transforms/converting.py
new file mode 100644
index 0000000000000000000000000000000000000000..38dcea09946eaafe396f4ba8e23cafa3d2da7ecc
--- /dev/null
+++ b/mmpose/datasets/transforms/converting.py
@@ -0,0 +1,125 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Tuple, Union
+
+import numpy as np
+from mmcv.transforms import BaseTransform
+
+from mmpose.registry import TRANSFORMS
+
+
+@TRANSFORMS.register_module()
+class KeypointConverter(BaseTransform):
+ """Change the order of keypoints according to the given mapping.
+
+ Required Keys:
+
+ - keypoints
+ - keypoints_visible
+
+ Modified Keys:
+
+ - keypoints
+ - keypoints_visible
+
+ Args:
+ num_keypoints (int): The number of keypoints in target dataset.
+ mapping (list): A list containing mapping indexes. Each element has
+ format (source_index, target_index)
+
+ Example:
+ >>> import numpy as np
+ >>> # case 1: 1-to-1 mapping
+ >>> # (0, 0) means target[0] = source[0]
+ >>> self = KeypointConverter(
+ >>> num_keypoints=3,
+ >>> mapping=[
+ >>> (0, 0), (1, 1), (2, 2)
+ >>> ])
+ >>> results = dict(
+ >>> keypoints=np.arange(12).reshape(2, 3, 2),
+ >>> keypoints_visible=np.arange(6).reshape(2, 3) % 2)
+ >>> results = self(results)
+ >>> assert np.equal(results['keypoints'],
+ >>> np.arange(12).reshape(2, 3, 2)).all()
+ >>> assert np.equal(results['keypoints_visible'],
+ >>> np.arange(6).reshape(2, 3) % 2).all()
+ >>>
+ >>> # case 2: 2-to-1 mapping
+ >>> # ((1, 2), 0) means target[0] = (source[1] + source[2]) / 2
+ >>> self = KeypointConverter(
+ >>> num_keypoints=3,
+ >>> mapping=[
+ >>> ((1, 2), 0), (1, 1), (2, 2)
+ >>> ])
+ >>> results = dict(
+ >>> keypoints=np.arange(12).reshape(2, 3, 2),
+ >>> keypoints_visible=np.arange(6).reshape(2, 3) % 2)
+ >>> results = self(results)
+ """
+
+ def __init__(self, num_keypoints: int,
+ mapping: Union[List[Tuple[int, int]], List[Tuple[Tuple,
+ int]]]):
+ self.num_keypoints = num_keypoints
+ self.mapping = mapping
+ source_index, target_index = zip(*mapping)
+
+ src1, src2 = [], []
+ interpolation = False
+ for x in source_index:
+ if isinstance(x, (list, tuple)):
+ assert len(x) == 2, 'source_index should be a list/tuple of ' \
+ 'length 2'
+ src1.append(x[0])
+ src2.append(x[1])
+ interpolation = True
+ else:
+ src1.append(x)
+ src2.append(x)
+
+ # When paired source_indexes are input,
+ # keep a self.source_index2 for interpolation
+ if interpolation:
+ self.source_index2 = src2
+
+ self.source_index = src1
+ self.target_index = target_index
+ self.interpolation = interpolation
+
+ def transform(self, results: dict) -> dict:
+ num_instances = results['keypoints'].shape[0]
+
+ keypoints = np.zeros((num_instances, self.num_keypoints, 2))
+ keypoints_visible = np.zeros((num_instances, self.num_keypoints))
+
+ # When paired source_indexes are input,
+ # perform interpolation with self.source_index and self.source_index2
+ if self.interpolation:
+ keypoints[:, self.target_index] = 0.5 * (
+ results['keypoints'][:, self.source_index] +
+ results['keypoints'][:, self.source_index2])
+
+ keypoints_visible[:, self.target_index] = results[
+ 'keypoints_visible'][:, self.source_index] * \
+ results['keypoints_visible'][:, self.source_index2]
+ else:
+ keypoints[:, self.target_index] = \
+ results['keypoints'][:, self.source_index]
+ keypoints_visible[:, self.target_index] = results[
+ 'keypoints_visible'][:, self.source_index]
+
+ results['keypoints'] = keypoints
+ results['keypoints_visible'] = keypoints_visible
+ return results
+
+ def __repr__(self) -> str:
+ """print the basic information of the transform.
+
+ Returns:
+ str: Formatted string.
+ """
+ repr_str = self.__class__.__name__
+ repr_str += f'(num_keypoints={self.num_keypoints}, '\
+ f'mapping={self.mapping})'
+ return repr_str
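A quick numeric check of the 2-to-1 mapping described in the docstring above, using toy keypoints: target keypoint 0 becomes the midpoint of source keypoints 1 and 2, and it is only visible when both sources are.

```python
import numpy as np
from mmpose.datasets.transforms import KeypointConverter

converter = KeypointConverter(
    num_keypoints=3, mapping=[((1, 2), 0), (1, 1), (2, 2)])

results = converter(dict(
    keypoints=np.array([[[0., 0.], [2., 2.], [4., 6.]]]),  # (1, 3, 2)
    keypoints_visible=np.array([[1., 1., 0.]])))           # (1, 3)

# target 0 = midpoint of sources 1 and 2; visibility is their product
assert np.allclose(results['keypoints'][0, 0], [3., 4.])
assert results['keypoints_visible'][0, 0] == 0.
```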
diff --git a/mmpose/datasets/transforms/formatting.py b/mmpose/datasets/transforms/formatting.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd9ad522f295d86daec05d91d95e567af6ad9878
--- /dev/null
+++ b/mmpose/datasets/transforms/formatting.py
@@ -0,0 +1,218 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Sequence, Union
+
+import numpy as np
+import torch
+from mmcv.transforms import BaseTransform
+from mmengine.structures import InstanceData, PixelData
+from mmengine.utils import is_seq_of
+
+from mmpose.registry import TRANSFORMS
+from mmpose.structures import MultilevelPixelData, PoseDataSample
+
+
+def image_to_tensor(img: Union[np.ndarray,
+ Sequence[np.ndarray]]) -> torch.Tensor:
+ """Translate image or sequence of images to tensor. Multiple image tensors
+ will be stacked.
+
+ Args:
+ img (np.ndarray | Sequence[np.ndarray]): The original image or
+ image sequence
+
+ Returns:
+ torch.Tensor: The output tensor.
+ """
+
+ if isinstance(img, np.ndarray):
+ if len(img.shape) < 3:
+ img = np.expand_dims(img, -1)
+
+ img = np.ascontiguousarray(img)
+ tensor = torch.from_numpy(img).permute(2, 0, 1).contiguous()
+ else:
+ assert is_seq_of(img, np.ndarray)
+ tensor = torch.stack([image_to_tensor(_img) for _img in img])
+
+ return tensor
+
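A quick sanity check of the helper above: a single HWC image becomes a CHW tensor, and a sequence of images is stacked along a new leading dimension. The shapes below are illustrative.

```python
import numpy as np

img = np.zeros((256, 192, 3), dtype=np.uint8)
assert image_to_tensor(img).shape == (3, 256, 192)            # HWC -> CHW
assert image_to_tensor([img, img]).shape == (2, 3, 256, 192)  # stacked
```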
+
+@TRANSFORMS.register_module()
+class PackPoseInputs(BaseTransform):
+ """Pack the inputs data for pose estimation.
+
+ The ``img_meta`` item is always populated. The contents of the
+ ``img_meta`` dictionary depend on ``meta_keys``. By default it includes:
+
+ - ``id``: id of the data sample
+
+ - ``img_id``: id of the image
+
+ - ``'category_id'``: the id of the instance category
+
+ - ``img_path``: path to the image file
+
+ - ``crowd_index`` (optional): measure the crowding level of an image,
+ defined in CrowdPose dataset
+
+ - ``ori_shape``: original shape of the image as a tuple (h, w, c)
+
+ - ``img_shape``: shape of the image input to the network as a tuple \
+ (h, w). Note that images may be zero padded on the \
+ bottom/right if the batch tensor is larger than this shape.
+
+ - ``input_size``: the input size to the network
+
+ - ``flip``: a boolean indicating if image flip transform was used
+
+ - ``flip_direction``: the flipping direction
+
+ - ``flip_indices``: the indices of each keypoint's symmetric keypoint
+
+ - ``raw_ann_info`` (optional): raw annotation of the instance(s)
+
+ Args:
+ meta_keys (Sequence[str], optional): Meta keys which will be stored in
+            :obj:`PoseDataSample` as meta info. Defaults to ``('id',
+            'img_id', 'img_path', 'category_id', 'crowd_index', 'ori_shape',
+            'img_shape', 'input_size', 'input_center', 'input_scale', 'flip',
+ 'flip_direction', 'flip_indices', 'raw_ann_info')``
+ """
+
+ # items in `instance_mapping_table` will be directly packed into
+ # PoseDataSample.gt_instances without converting to Tensor
+ instance_mapping_table = {
+ 'bbox': 'bboxes',
+ 'head_size': 'head_size',
+ 'bbox_center': 'bbox_centers',
+ 'bbox_scale': 'bbox_scales',
+ 'bbox_score': 'bbox_scores',
+ 'keypoints': 'keypoints',
+ 'keypoints_visible': 'keypoints_visible',
+ }
+
+ # items in `label_mapping_table` will be packed into
+ # PoseDataSample.gt_instance_labels and converted to Tensor. These items
+ # will be used for computing losses
+ label_mapping_table = {
+ 'keypoint_labels': 'keypoint_labels',
+ 'keypoint_x_labels': 'keypoint_x_labels',
+ 'keypoint_y_labels': 'keypoint_y_labels',
+ 'keypoint_weights': 'keypoint_weights',
+ 'instance_coords': 'instance_coords'
+ }
+
+ # items in `field_mapping_table` will be packed into
+ # PoseDataSample.gt_fields and converted to Tensor. These items will be
+ # used for computing losses
+ field_mapping_table = {
+ 'heatmaps': 'heatmaps',
+ 'instance_heatmaps': 'instance_heatmaps',
+ 'heatmap_mask': 'heatmap_mask',
+ 'heatmap_weights': 'heatmap_weights',
+ 'displacements': 'displacements',
+ 'displacement_weights': 'displacement_weights',
+ }
+
+ def __init__(self,
+ meta_keys=('id', 'img_id', 'img_path', 'category_id',
+ 'crowd_index', 'ori_shape', 'img_shape',
+ 'input_size', 'input_center', 'input_scale',
+ 'flip', 'flip_direction', 'flip_indices',
+ 'raw_ann_info'),
+ pack_transformed=False):
+ self.meta_keys = meta_keys
+ self.pack_transformed = pack_transformed
+
+ def transform(self, results: dict) -> dict:
+ """Method to pack the input data.
+
+ Args:
+ results (dict): Result dict from the data pipeline.
+
+ Returns:
+ dict:
+
+            - 'inputs' (:obj:`torch.Tensor`): The forward data of models.
+            - 'data_samples' (:obj:`PoseDataSample`): The annotation info of
+              the sample.
+ """
+ # Pack image(s)
+ if 'img' in results:
+ img = results['img']
+ img_tensor = image_to_tensor(img)
+
+ data_sample = PoseDataSample()
+
+ # pack instance data
+ gt_instances = InstanceData()
+ for key, packed_key in self.instance_mapping_table.items():
+ if key in results:
+ gt_instances.set_field(results[key], packed_key)
+
+ # pack `transformed_keypoints` for visualizing data transform
+ # and augmentation results
+ if self.pack_transformed and 'transformed_keypoints' in results:
+ gt_instances.set_field(results['transformed_keypoints'],
+ 'transformed_keypoints')
+
+ data_sample.gt_instances = gt_instances
+
+ # pack instance labels
+ gt_instance_labels = InstanceData()
+ for key, packed_key in self.label_mapping_table.items():
+ if key in results:
+ if isinstance(results[key], list):
+ # A list of labels is usually generated by combined
+ # multiple encoders (See ``GenerateTarget`` in
+ # mmpose/datasets/transforms/common_transforms.py)
+ # In this case, labels in list should have the same
+ # shape and will be stacked.
+ _labels = np.stack(results[key])
+ gt_instance_labels.set_field(_labels, packed_key)
+ else:
+ gt_instance_labels.set_field(results[key], packed_key)
+ data_sample.gt_instance_labels = gt_instance_labels.to_tensor()
+
+ # pack fields
+ gt_fields = None
+ for key, packed_key in self.field_mapping_table.items():
+ if key in results:
+ if isinstance(results[key], list):
+ if gt_fields is None:
+ gt_fields = MultilevelPixelData()
+ else:
+ assert isinstance(
+ gt_fields, MultilevelPixelData
+ ), 'Got mixed single-level and multi-level pixel data.'
+ else:
+ if gt_fields is None:
+ gt_fields = PixelData()
+ else:
+ assert isinstance(
+ gt_fields, PixelData
+ ), 'Got mixed single-level and multi-level pixel data.'
+
+ gt_fields.set_field(results[key], packed_key)
+
+ if gt_fields:
+ data_sample.gt_fields = gt_fields.to_tensor()
+
+ img_meta = {k: results[k] for k in self.meta_keys if k in results}
+ data_sample.set_metainfo(img_meta)
+
+ packed_results = dict()
+ packed_results['inputs'] = img_tensor
+ packed_results['data_samples'] = data_sample
+
+ return packed_results
+
+ def __repr__(self) -> str:
+ """print the basic information of the transform.
+
+ Returns:
+ str: Formatted string.
+ """
+ repr_str = self.__class__.__name__
+ repr_str += f'(meta_keys={self.meta_keys})'
+ return repr_str
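
A quick shape check for ``image_to_tensor`` defined above (values are arbitrary; the import path follows the file added in this diff): a single HWC image becomes a CHW tensor, and a list of images is stacked along an extra batch dimension.

import numpy as np
from mmpose.datasets.transforms.formatting import image_to_tensor

img = (np.random.rand(256, 192, 3) * 255).astype(np.uint8)   # H, W, C
print(image_to_tensor(img).shape)          # torch.Size([3, 256, 192])
print(image_to_tensor([img, img]).shape)   # torch.Size([2, 3, 256, 192])
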
diff --git a/mmpose/datasets/transforms/loading.py b/mmpose/datasets/transforms/loading.py
new file mode 100644
index 0000000000000000000000000000000000000000..28edcb48066b7a7b5534f69e6b3981d31825beca
--- /dev/null
+++ b/mmpose/datasets/transforms/loading.py
@@ -0,0 +1,66 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional
+
+import numpy as np
+from mmcv.transforms import LoadImageFromFile
+
+from mmpose.registry import TRANSFORMS
+
+
+@TRANSFORMS.register_module()
+class LoadImage(LoadImageFromFile):
+ """Load an image from file or from the np.ndarray in ``results['img']``.
+
+ Required Keys:
+
+ - img_path
+ - img (optional)
+
+ Modified Keys:
+
+ - img
+ - img_shape
+ - ori_shape
+ - img_path (optional)
+
+ Args:
+ to_float32 (bool): Whether to convert the loaded image to a float32
+            numpy array. If set to False, the loaded image is a uint8 array.
+ Defaults to False.
+ color_type (str): The flag argument for :func:``mmcv.imfrombytes``.
+ Defaults to 'color'.
+ imdecode_backend (str): The image decoding backend type. The backend
+ argument for :func:``mmcv.imfrombytes``.
+ See :func:``mmcv.imfrombytes`` for details.
+ Defaults to 'cv2'.
+        backend_args (dict, optional): Arguments to instantiate the backend
+            corresponding to the URI prefix. Defaults to None.
+        ignore_empty (bool): Whether to allow loading an empty image or a
+            non-existent file path. Defaults to False.
+ """
+
+ def transform(self, results: dict) -> Optional[dict]:
+ """The transform function of :class:`LoadImage`.
+
+ Args:
+ results (dict): The result dict
+
+ Returns:
+ dict: The result dict.
+ """
+
+ if 'img' not in results:
+ # Load image from file by :meth:`LoadImageFromFile.transform`
+ results = super().transform(results)
+ else:
+ img = results['img']
+ assert isinstance(img, np.ndarray)
+ if self.to_float32:
+ img = img.astype(np.float32)
+
+ if 'img_path' not in results:
+ results['img_path'] = None
+ results['img_shape'] = img.shape[:2]
+ results['ori_shape'] = img.shape[:2]
+
+ return results
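
A hedged usage sketch for ``LoadImage``: when ``results['img']`` already holds a NumPy array, the file read is skipped and only the shape and path keys are filled in (the array contents here are arbitrary).

import numpy as np
from mmpose.datasets.transforms.loading import LoadImage

transform = LoadImage()
results = transform({'img': np.zeros((480, 640, 3), dtype=np.uint8)})
print(results['img_shape'], results['ori_shape'])   # (480, 640) (480, 640)
print(results['img_path'])                          # None
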
diff --git a/mmpose/datasets/transforms/topdown_transforms.py b/mmpose/datasets/transforms/topdown_transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..29aa48eb06576059c8fad4e1d46de4e5061f372f
--- /dev/null
+++ b/mmpose/datasets/transforms/topdown_transforms.py
@@ -0,0 +1,140 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, Optional, Tuple
+
+import cv2
+import numpy as np
+from mmcv.transforms import BaseTransform
+from mmengine import is_seq_of
+
+from mmpose.registry import TRANSFORMS
+from mmpose.structures.bbox import get_udp_warp_matrix, get_warp_matrix
+
+
+@TRANSFORMS.register_module()
+class TopdownAffine(BaseTransform):
+ """Get the bbox image as the model input by affine transform.
+
+ Required Keys:
+
+ - img
+ - bbox_center
+ - bbox_scale
+ - bbox_rotation (optional)
+ - keypoints (optional)
+
+ Modified Keys:
+
+ - img
+ - bbox_scale
+
+ Added Keys:
+
+ - input_size
+ - transformed_keypoints
+
+ Args:
+ input_size (Tuple[int, int]): The input image size of the model in
+            [w, h]. The bbox region will be cropped and resized to
+            ``input_size``.
+        use_udp (bool): Whether to use unbiased data processing. See
+ `UDP (CVPR 2020)`_ for details. Defaults to ``False``
+
+ .. _`UDP (CVPR 2020)`: https://arxiv.org/abs/1911.07524
+ """
+
+ def __init__(self,
+ input_size: Tuple[int, int],
+ use_udp: bool = False) -> None:
+ super().__init__()
+
+ assert is_seq_of(input_size, int) and len(input_size) == 2, (
+ f'Invalid input_size {input_size}')
+
+ self.input_size = input_size
+ self.use_udp = use_udp
+
+ @staticmethod
+ def _fix_aspect_ratio(bbox_scale: np.ndarray, aspect_ratio: float):
+ """Reshape the bbox to a fixed aspect ratio.
+
+ Args:
+ bbox_scale (np.ndarray): The bbox scales (w, h) in shape (n, 2)
+ aspect_ratio (float): The ratio of ``w/h``
+
+ Returns:
+            np.ndarray: The reshaped bbox scales in (n, 2)
+ """
+
+ w, h = np.hsplit(bbox_scale, [1])
+ bbox_scale = np.where(w > h * aspect_ratio,
+ np.hstack([w, w / aspect_ratio]),
+ np.hstack([h * aspect_ratio, h]))
+ return bbox_scale
+
+ def transform(self, results: Dict) -> Optional[dict]:
+ """The transform function of :class:`TopdownAffine`.
+
+ See ``transform()`` method of :class:`BaseTransform` for details.
+
+ Args:
+ results (dict): The result dict
+
+ Returns:
+ dict: The result dict.
+ """
+
+ w, h = self.input_size
+ warp_size = (int(w), int(h))
+
+ # reshape bbox to fixed aspect ratio
+ results['bbox_scale'] = self._fix_aspect_ratio(
+ results['bbox_scale'], aspect_ratio=w / h)
+
+ # TODO: support multi-instance
+ assert results['bbox_center'].shape[0] == 1, (
+ 'Top-down heatmap only supports single instance. Got invalid '
+ f'shape of bbox_center {results["bbox_center"].shape}.')
+
+ center = results['bbox_center'][0]
+ scale = results['bbox_scale'][0]
+ if 'bbox_rotation' in results:
+ rot = results['bbox_rotation'][0]
+ else:
+ rot = 0.
+
+ if self.use_udp:
+ warp_mat = get_udp_warp_matrix(
+ center, scale, rot, output_size=(w, h))
+ else:
+ warp_mat = get_warp_matrix(center, scale, rot, output_size=(w, h))
+
+ if isinstance(results['img'], list):
+ results['img'] = [
+ cv2.warpAffine(
+ img, warp_mat, warp_size, flags=cv2.INTER_LINEAR)
+ for img in results['img']
+ ]
+ else:
+ results['img'] = cv2.warpAffine(
+ results['img'], warp_mat, warp_size, flags=cv2.INTER_LINEAR)
+
+ if results.get('keypoints', None) is not None:
+ transformed_keypoints = results['keypoints'].copy()
+ # Only transform (x, y) coordinates
+ transformed_keypoints[..., :2] = cv2.transform(
+ results['keypoints'][..., :2], warp_mat)
+ results['transformed_keypoints'] = transformed_keypoints
+
+ results['input_size'] = (w, h)
+
+ return results
+
+ def __repr__(self) -> str:
+ """print the basic information of the transform.
+
+ Returns:
+ str: Formatted string.
+ """
+ repr_str = self.__class__.__name__
+ repr_str += f'(input_size={self.input_size}, '
+ repr_str += f'use_udp={self.use_udp})'
+ return repr_str
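
To illustrate ``_fix_aspect_ratio`` above with a concrete number (the bbox value is made up): for a 192x256 input the aspect ratio is 0.75, so a square 100x100 bbox scale has its height stretched to keep w/h fixed before the affine crop.

import numpy as np
from mmpose.datasets.transforms.topdown_transforms import TopdownAffine

bbox_scale = np.array([[100., 100.]])     # (n, 2) as (w, h)
fixed = TopdownAffine._fix_aspect_ratio(bbox_scale, aspect_ratio=192 / 256)
print(fixed)                              # [[100.         133.33333333]]
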
diff --git a/mmpose/engine/__init__.py b/mmpose/engine/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac85928986718cfc181ac311949c238ea11cf34c
--- /dev/null
+++ b/mmpose/engine/__init__.py
@@ -0,0 +1,3 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .hooks import * # noqa: F401, F403
+from .optim_wrappers import * # noqa: F401, F403
diff --git a/mmpose/engine/__pycache__/__init__.cpython-38.pyc b/mmpose/engine/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5d8cc2201d5a5be2e91e002781177270fbfc6793
Binary files /dev/null and b/mmpose/engine/__pycache__/__init__.cpython-38.pyc differ
diff --git a/mmpose/engine/hooks/__init__.py b/mmpose/engine/hooks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dadb9c5f913e3e61eebb4ee4e246076bb3d45dd9
--- /dev/null
+++ b/mmpose/engine/hooks/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .ema_hook import ExpMomentumEMA
+from .visualization_hook import PoseVisualizationHook
+
+__all__ = ['PoseVisualizationHook', 'ExpMomentumEMA']
diff --git a/mmpose/engine/hooks/__pycache__/__init__.cpython-38.pyc b/mmpose/engine/hooks/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1ab45eaf91278f1165c44221818f8694926a65a8
Binary files /dev/null and b/mmpose/engine/hooks/__pycache__/__init__.cpython-38.pyc differ
diff --git a/mmpose/engine/hooks/__pycache__/ema_hook.cpython-38.pyc b/mmpose/engine/hooks/__pycache__/ema_hook.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f638feee9b91007eaa02ccdce18e6d5bca270020
Binary files /dev/null and b/mmpose/engine/hooks/__pycache__/ema_hook.cpython-38.pyc differ
diff --git a/mmpose/engine/hooks/__pycache__/visualization_hook.cpython-38.pyc b/mmpose/engine/hooks/__pycache__/visualization_hook.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9590b9dceaf9f8ae48bdb96438e73c465772526d
Binary files /dev/null and b/mmpose/engine/hooks/__pycache__/visualization_hook.cpython-38.pyc differ
diff --git a/mmpose/engine/hooks/ema_hook.py b/mmpose/engine/hooks/ema_hook.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd1a689f96f49c33059ec1e4afbe7b01b85164f9
--- /dev/null
+++ b/mmpose/engine/hooks/ema_hook.py
@@ -0,0 +1,69 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from mmengine.model import ExponentialMovingAverage
+from torch import Tensor
+
+from mmpose.registry import MODELS
+
+
+@MODELS.register_module()
+class ExpMomentumEMA(ExponentialMovingAverage):
+ """Exponential moving average (EMA) with exponential momentum strategy,
+ which is used in YOLOX.
+
+    Ported from the implementation of MMDetection.
+
+ Args:
+ model (nn.Module): The model to be averaged.
+        momentum (float): The momentum used for updating the EMA parameters.
+            The EMA parameters are updated with the formula:
+ `averaged_param = (1-momentum) * averaged_param + momentum *
+ source_param`. Defaults to 0.0002.
+        gamma (int): Use a larger momentum early in training and gradually
+            anneal it to a smaller value to update the EMA model smoothly. The
+ momentum is calculated as
+ `(1 - momentum) * exp(-(1 + steps) / gamma) + momentum`.
+ Defaults to 2000.
+ interval (int): Interval between two updates. Defaults to 1.
+ device (torch.device, optional): If provided, the averaged model will
+ be stored on the :attr:`device`. Defaults to None.
+ update_buffers (bool): if True, it will compute running averages for
+ both the parameters and the buffers of the model. Defaults to
+ False.
+ """
+
+ def __init__(self,
+ model: nn.Module,
+ momentum: float = 0.0002,
+ gamma: int = 2000,
+ interval=1,
+ device: Optional[torch.device] = None,
+ update_buffers: bool = False) -> None:
+ super().__init__(
+ model=model,
+ momentum=momentum,
+ interval=interval,
+ device=device,
+ update_buffers=update_buffers)
+ assert gamma > 0, f'gamma must be greater than 0, but got {gamma}'
+ self.gamma = gamma
+
+ def avg_func(self, averaged_param: Tensor, source_param: Tensor,
+ steps: int) -> None:
+ """Compute the moving average of the parameters using the exponential
+ momentum strategy.
+
+ Args:
+ averaged_param (Tensor): The averaged parameters.
+ source_param (Tensor): The source parameters.
+ steps (int): The number of times the parameters have been
+ updated.
+ """
+ momentum = (1 - self.momentum) * math.exp(
+ -float(1 + steps) / self.gamma) + self.momentum
+ averaged_param.mul_(1 - momentum).add_(source_param, alpha=momentum)
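
The momentum schedule in ``avg_func`` can be checked with plain math (no model required): early in training the effective momentum is close to 1, so the averaged parameters track the source closely, and it decays towards the base ``momentum`` as ``steps`` grows.

import math

base_momentum, gamma = 0.0002, 2000
for steps in (0, 2000, 20000):
    momentum = (1 - base_momentum) * math.exp(
        -float(1 + steps) / gamma) + base_momentum
    print(steps, momentum)
# steps=0     -> ~0.9995   (EMA follows the source almost exactly)
# steps=2000  -> ~0.3678
# steps=20000 -> ~0.000245 (close to the base momentum of 0.0002)
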
diff --git a/mmpose/engine/hooks/visualization_hook.py b/mmpose/engine/hooks/visualization_hook.py
new file mode 100644
index 0000000000000000000000000000000000000000..24b845f282291bfd60d2a1a507ce42d586a35073
--- /dev/null
+++ b/mmpose/engine/hooks/visualization_hook.py
@@ -0,0 +1,168 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+import warnings
+from typing import Optional, Sequence
+
+import mmcv
+import mmengine
+import mmengine.fileio as fileio
+from mmengine.hooks import Hook
+from mmengine.runner import Runner
+from mmengine.visualization import Visualizer
+
+from mmpose.registry import HOOKS
+from mmpose.structures import PoseDataSample, merge_data_samples
+
+
+@HOOKS.register_module()
+class PoseVisualizationHook(Hook):
+ """Pose Estimation Visualization Hook. Used to visualize validation and
+ testing process prediction results.
+
+ In the testing phase:
+
+ 1. If ``show`` is True, it means that only the prediction results are
+ visualized without storing data, so ``vis_backends`` needs to
+ be excluded.
+    2. If ``out_dir`` is specified, it means that the prediction results
+        need to be saved to ``out_dir``. To avoid ``vis_backends`` also
+        storing data, ``vis_backends`` needs to be excluded.
+    3. ``vis_backends`` takes effect if the user does not specify ``show``
+        or ``out_dir``. You can set ``vis_backends`` to WandbVisBackend or
+ TensorboardVisBackend to store the prediction result in Wandb or
+ Tensorboard.
+
+ Args:
+ enable (bool): whether to draw prediction results. If it is False,
+ it means that no drawing will be done. Defaults to False.
+ interval (int): The interval of visualization. Defaults to 50.
+        kpt_thr (float): The threshold to visualize the keypoints.
+            Defaults to 0.3.
+ show (bool): Whether to display the drawn image. Default to False.
+ wait_time (float): The interval of show (s). Defaults to 0.
+ out_dir (str, optional): directory where painted images
+ will be saved in testing process.
+        backend_args (dict, optional): Arguments to instantiate the backend
+            corresponding to the URI prefix. Defaults to None.
+ """
+
+ def __init__(
+ self,
+ enable: bool = False,
+ interval: int = 50,
+ kpt_thr: float = 0.3,
+ show: bool = False,
+ wait_time: float = 0.,
+ out_dir: Optional[str] = None,
+ backend_args: Optional[dict] = None,
+ ):
+ self._visualizer: Visualizer = Visualizer.get_current_instance()
+ self.interval = interval
+ self.kpt_thr = kpt_thr
+ self.show = show
+ if self.show:
+ # No need to think about vis backends.
+ self._visualizer._vis_backends = {}
+            warnings.warn('`show` is set to True, so only the prediction '
+                          'results will be visualized without storing data; '
+                          'the configured vis_backends are excluded.')
+
+ self.wait_time = wait_time
+ self.enable = enable
+ self.out_dir = out_dir
+ self._test_index = 0
+ self.backend_args = backend_args
+
+ def after_val_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
+ outputs: Sequence[PoseDataSample]) -> None:
+ """Run after every ``self.interval`` validation iterations.
+
+ Args:
+ runner (:obj:`Runner`): The runner of the validation process.
+ batch_idx (int): The index of the current batch in the val loop.
+ data_batch (dict): Data from dataloader.
+ outputs (Sequence[:obj:`PoseDataSample`]): Outputs from model.
+ """
+ if self.enable is False:
+ return
+
+ self._visualizer.set_dataset_meta(runner.val_evaluator.dataset_meta)
+
+ # There is no guarantee that the same batch of images
+ # is visualized for each evaluation.
+ total_curr_iter = runner.iter + batch_idx
+
+ # Visualize only the first data
+ img_path = data_batch['data_samples'][0].get('img_path')
+ img_bytes = fileio.get(img_path, backend_args=self.backend_args)
+ img = mmcv.imfrombytes(img_bytes, channel_order='rgb')
+ data_sample = outputs[0]
+
+ # revert the heatmap on the original image
+ data_sample = merge_data_samples([data_sample])
+
+ if total_curr_iter % self.interval == 0:
+ self._visualizer.add_datasample(
+ os.path.basename(img_path) if self.show else 'val_img',
+ img,
+ data_sample=data_sample,
+ draw_gt=False,
+ draw_bbox=True,
+ draw_heatmap=True,
+ show=self.show,
+ wait_time=self.wait_time,
+ kpt_thr=self.kpt_thr,
+ step=total_curr_iter)
+
+ def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
+ outputs: Sequence[PoseDataSample]) -> None:
+        """Run after every testing iteration.
+
+ Args:
+ runner (:obj:`Runner`): The runner of the testing process.
+ batch_idx (int): The index of the current batch in the test loop.
+ data_batch (dict): Data from dataloader.
+ outputs (Sequence[:obj:`PoseDataSample`]): Outputs from model.
+ """
+ if self.enable is False:
+ return
+
+ if self.out_dir is not None:
+ self.out_dir = os.path.join(runner.work_dir, runner.timestamp,
+ self.out_dir)
+ mmengine.mkdir_or_exist(self.out_dir)
+
+ self._visualizer.set_dataset_meta(runner.test_evaluator.dataset_meta)
+
+ for data_sample in outputs:
+ self._test_index += 1
+
+ img_path = data_sample.get('img_path')
+ img_bytes = fileio.get(img_path, backend_args=self.backend_args)
+ img = mmcv.imfrombytes(img_bytes, channel_order='rgb')
+ data_sample = merge_data_samples([data_sample])
+
+ out_file = None
+ if self.out_dir is not None:
+ out_file_name, postfix = os.path.basename(img_path).rsplit(
+ '.', 1)
+ index = len([
+ fname for fname in os.listdir(self.out_dir)
+ if fname.startswith(out_file_name)
+ ])
+ out_file = f'{out_file_name}_{index}.{postfix}'
+ out_file = os.path.join(self.out_dir, out_file)
+
+ self._visualizer.add_datasample(
+ os.path.basename(img_path) if self.show else 'test_img',
+ img,
+ data_sample=data_sample,
+ show=self.show,
+ draw_gt=False,
+ draw_bbox=True,
+ draw_heatmap=True,
+ wait_time=self.wait_time,
+ kpt_thr=self.kpt_thr,
+ out_file=out_file,
+ step=self._test_index)
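
A hypothetical config snippet for enabling this hook; the ``visualization`` key under ``default_hooks`` follows the usual MMEngine runtime layout and is an assumption here, so adapt it to the project's actual config.

default_hooks = dict(
    visualization=dict(
        type='PoseVisualizationHook',
        enable=True,             # draw predictions during val/test
        interval=50,             # visualize every 50 validation iterations
        kpt_thr=0.3,             # keypoints below this score are not drawn
        show=False,              # True pops up a window and disables vis_backends
        out_dir='vis_results'))  # saved under <work_dir>/<timestamp>/vis_results at test time
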
diff --git a/mmpose/engine/optim_wrappers/__init__.py b/mmpose/engine/optim_wrappers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c0b1f533ae88e0d9914bda0bd3e79532c779339
--- /dev/null
+++ b/mmpose/engine/optim_wrappers/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .layer_decay_optim_wrapper import LayerDecayOptimWrapperConstructor
+
+__all__ = ['LayerDecayOptimWrapperConstructor']
diff --git a/mmpose/engine/optim_wrappers/__pycache__/__init__.cpython-38.pyc b/mmpose/engine/optim_wrappers/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6409ffa2e0dadc364d2902061eef3ec985f6dcb7
Binary files /dev/null and b/mmpose/engine/optim_wrappers/__pycache__/__init__.cpython-38.pyc differ
diff --git a/mmpose/engine/optim_wrappers/__pycache__/layer_decay_optim_wrapper.cpython-38.pyc b/mmpose/engine/optim_wrappers/__pycache__/layer_decay_optim_wrapper.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..eb88ac3eb2a65d1220b46853498dcd540a049e21
Binary files /dev/null and b/mmpose/engine/optim_wrappers/__pycache__/layer_decay_optim_wrapper.cpython-38.pyc differ
diff --git a/mmpose/engine/optim_wrappers/layer_decay_optim_wrapper.py b/mmpose/engine/optim_wrappers/layer_decay_optim_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..6513e5593d98e9aa77a2795529ddeb538b6099c3
--- /dev/null
+++ b/mmpose/engine/optim_wrappers/layer_decay_optim_wrapper.py
@@ -0,0 +1,73 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmengine.dist.utils import get_dist_info
+from mmengine.optim import DefaultOptimWrapperConstructor
+from mmengine.registry import OPTIM_WRAPPER_CONSTRUCTORS
+
+
+def get_num_layer_for_vit(var_name, num_max_layer):
+ if var_name in ('backbone.cls_token', 'backbone.mask_token',
+ 'backbone.pos_embed'):
+ return 0
+ elif var_name.startswith('backbone.patch_embed'):
+ return 0
+ elif var_name.startswith('backbone.layers'):
+ layer_id = int(var_name.split('.')[2])
+ return layer_id + 1
+ else:
+ return num_max_layer - 1
+
+
+@OPTIM_WRAPPER_CONSTRUCTORS.register_module(force=True)
+class LayerDecayOptimWrapperConstructor(DefaultOptimWrapperConstructor):
+
+ def __init__(self, optim_wrapper_cfg, paramwise_cfg=None):
+ super().__init__(optim_wrapper_cfg, paramwise_cfg=None)
+ self.layer_decay_rate = paramwise_cfg.get('layer_decay_rate', 0.5)
+
+ super().__init__(optim_wrapper_cfg, paramwise_cfg)
+
+ def add_params(self, params, module, prefix='', lr=None):
+ parameter_groups = {}
+ print(self.paramwise_cfg)
+ num_layers = self.paramwise_cfg.get('num_layers') + 2
+ layer_decay_rate = self.paramwise_cfg.get('layer_decay_rate')
+ weight_decay = self.base_wd
+
+ for name, param in module.named_parameters():
+ if not param.requires_grad:
+ continue # frozen weights
+ if (len(param.shape) == 1 or name.endswith('.bias')
+ or 'pos_embed' in name):
+ group_name = 'no_decay'
+ this_weight_decay = 0.
+ else:
+ group_name = 'decay'
+ this_weight_decay = weight_decay
+ layer_id = get_num_layer_for_vit(name, num_layers)
+ group_name = 'layer_%d_%s' % (layer_id, group_name)
+
+ if group_name not in parameter_groups:
+ scale = layer_decay_rate**(num_layers - layer_id - 1)
+
+ parameter_groups[group_name] = {
+ 'weight_decay': this_weight_decay,
+ 'params': [],
+ 'param_names': [],
+ 'lr_scale': scale,
+ 'group_name': group_name,
+ 'lr': scale * self.base_lr,
+ }
+
+ parameter_groups[group_name]['params'].append(param)
+ parameter_groups[group_name]['param_names'].append(name)
+ rank, _ = get_dist_info()
+ if rank == 0:
+ to_display = {}
+ for key in parameter_groups:
+ to_display[key] = {
+ 'param_names': parameter_groups[key]['param_names'],
+ 'lr_scale': parameter_groups[key]['lr_scale'],
+ 'lr': parameter_groups[key]['lr'],
+ 'weight_decay': parameter_groups[key]['weight_decay'],
+ }
+ params.extend(parameter_groups.values())
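
A small sketch of the bookkeeping above, assuming a hypothetical 12-block ViT backbone and ``layer_decay_rate=0.75``: each parameter name is mapped to a layer id and its learning rate is scaled by ``layer_decay_rate ** (num_layers - layer_id - 1)``, so early blocks train with much smaller rates than the head.

from mmpose.engine.optim_wrappers.layer_decay_optim_wrapper import \
    get_num_layer_for_vit

num_layers = 12 + 2      # paramwise_cfg['num_layers'] + 2, as in add_params
layer_decay_rate = 0.75

for name in ('backbone.pos_embed',
             'backbone.layers.0.attn.qkv.weight',
             'backbone.layers.11.mlp.fc2.weight',
             'keypoint_head.final_layer.weight'):
    layer_id = get_num_layer_for_vit(name, num_layers)
    scale = layer_decay_rate**(num_layers - layer_id - 1)
    print(f'{name}: layer {layer_id}, lr_scale {scale:.4f}')
# pos_embed      -> layer 0,  lr_scale ~0.0238
# layers.0.*     -> layer 1,  lr_scale ~0.0317
# layers.11.*    -> layer 12, lr_scale 0.7500
# keypoint_head  -> layer 13, lr_scale 1.0000
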
diff --git a/mmpose/evaluation/__init__.py b/mmpose/evaluation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f70dc226d30f7b8e4ee5a44ca163ad1ae04eabf5
--- /dev/null
+++ b/mmpose/evaluation/__init__.py
@@ -0,0 +1,3 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .functional import * # noqa: F401,F403
+from .metrics import * # noqa: F401,F403
diff --git a/mmpose/evaluation/__pycache__/__init__.cpython-38.pyc b/mmpose/evaluation/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b3626954e238738013228b15cddb5f4003cc8542
Binary files /dev/null and b/mmpose/evaluation/__pycache__/__init__.cpython-38.pyc differ
diff --git a/mmpose/evaluation/functional/__init__.py b/mmpose/evaluation/functional/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c4a8b5d1ebde25eb71d7c82e138b1309af617f3
--- /dev/null
+++ b/mmpose/evaluation/functional/__init__.py
@@ -0,0 +1,12 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .keypoint_eval import (keypoint_auc, keypoint_epe, keypoint_nme,
+ keypoint_pck_accuracy,
+ multilabel_classification_accuracy,
+ pose_pck_accuracy, simcc_pck_accuracy)
+from .nms import nms, oks_nms, soft_oks_nms
+
+__all__ = [
+ 'keypoint_pck_accuracy', 'keypoint_auc', 'keypoint_nme', 'keypoint_epe',
+ 'pose_pck_accuracy', 'multilabel_classification_accuracy',
+ 'simcc_pck_accuracy', 'nms', 'oks_nms', 'soft_oks_nms'
+]
diff --git a/mmpose/evaluation/functional/__pycache__/__init__.cpython-38.pyc b/mmpose/evaluation/functional/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c2213fd51049355be16f1f57c86488ac91255913
Binary files /dev/null and b/mmpose/evaluation/functional/__pycache__/__init__.cpython-38.pyc differ
diff --git a/mmpose/evaluation/functional/__pycache__/keypoint_eval.cpython-38.pyc b/mmpose/evaluation/functional/__pycache__/keypoint_eval.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3cbb7d04a06c5d7b2c4fee75cdcaff54a5fd0f11
Binary files /dev/null and b/mmpose/evaluation/functional/__pycache__/keypoint_eval.cpython-38.pyc differ
diff --git a/mmpose/evaluation/functional/__pycache__/nms.cpython-38.pyc b/mmpose/evaluation/functional/__pycache__/nms.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3e227dd3ce213c549ad10f315c4ebc24aa780419
Binary files /dev/null and b/mmpose/evaluation/functional/__pycache__/nms.cpython-38.pyc differ
diff --git a/mmpose/evaluation/functional/keypoint_eval.py b/mmpose/evaluation/functional/keypoint_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..060243357b39f9bc7f4689b6ba59320bd86b9d4b
--- /dev/null
+++ b/mmpose/evaluation/functional/keypoint_eval.py
@@ -0,0 +1,320 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional, Tuple
+
+import numpy as np
+
+from mmpose.codecs.utils import get_heatmap_maximum, get_simcc_maximum
+
+
+def _calc_distances(preds: np.ndarray, gts: np.ndarray, mask: np.ndarray,
+ norm_factor: np.ndarray) -> np.ndarray:
+ """Calculate the normalized distances between preds and target.
+
+ Note:
+ - instance number: N
+ - keypoint number: K
+ - keypoint dimension: D (normally, D=2 or D=3)
+
+ Args:
+ preds (np.ndarray[N, K, D]): Predicted keypoint location.
+ gts (np.ndarray[N, K, D]): Groundtruth keypoint location.
+ mask (np.ndarray[N, K]): Visibility of the target. False for invisible
+ joints, and True for visible. Invisible joints will be ignored for
+ accuracy calculation.
+ norm_factor (np.ndarray[N, D]): Normalization factor.
+ Typical value is heatmap_size.
+
+ Returns:
+ np.ndarray[K, N]: The normalized distances. \
+ If target keypoints are missing, the distance is -1.
+ """
+ N, K, _ = preds.shape
+ # set mask=0 when norm_factor==0
+ _mask = mask.copy()
+ _mask[np.where((norm_factor == 0).sum(1))[0], :] = False
+
+ distances = np.full((N, K), -1, dtype=np.float32)
+ # handle invalid values
+ norm_factor[np.where(norm_factor <= 0)] = 1e6
+ distances[_mask] = np.linalg.norm(
+ ((preds - gts) / norm_factor[:, None, :])[_mask], axis=-1)
+ return distances.T
+
+
+def _distance_acc(distances: np.ndarray, thr: float = 0.5) -> float:
+ """Return the percentage below the distance threshold, while ignoring
+ distances values with -1.
+
+ Note:
+ - instance number: N
+
+ Args:
+ distances (np.ndarray[N, ]): The normalized distances.
+ thr (float): Threshold of the distances.
+
+ Returns:
+ float: Percentage of distances below the threshold. \
+ If all target keypoints are missing, return -1.
+ """
+ distance_valid = distances != -1
+ num_distance_valid = distance_valid.sum()
+ if num_distance_valid > 0:
+ return (distances[distance_valid] < thr).sum() / num_distance_valid
+ return -1
+
+
+def keypoint_pck_accuracy(pred: np.ndarray, gt: np.ndarray, mask: np.ndarray,
+ thr: np.ndarray, norm_factor: np.ndarray) -> tuple:
+ """Calculate the pose accuracy of PCK for each individual keypoint and the
+ averaged accuracy across all keypoints for coordinates.
+
+ Note:
+ PCK metric measures accuracy of the localization of the body joints.
+ The distances between predicted positions and the ground-truth ones
+ are typically normalized by the bounding box size.
+ The threshold (thr) of the normalized distance is commonly set
+ as 0.05, 0.1 or 0.2 etc.
+
+ - instance number: N
+ - keypoint number: K
+
+ Args:
+ pred (np.ndarray[N, K, 2]): Predicted keypoint location.
+ gt (np.ndarray[N, K, 2]): Groundtruth keypoint location.
+ mask (np.ndarray[N, K]): Visibility of the target. False for invisible
+ joints, and True for visible. Invisible joints will be ignored for
+ accuracy calculation.
+ thr (float): Threshold of PCK calculation.
+ norm_factor (np.ndarray[N, 2]): Normalization factor for H&W.
+
+ Returns:
+ tuple: A tuple containing keypoint accuracy.
+
+ - acc (np.ndarray[K]): Accuracy of each keypoint.
+ - avg_acc (float): Averaged accuracy across all keypoints.
+ - cnt (int): Number of valid keypoints.
+ """
+ distances = _calc_distances(pred, gt, mask, norm_factor)
+ acc = np.array([_distance_acc(d, thr) for d in distances])
+ valid_acc = acc[acc >= 0]
+ cnt = len(valid_acc)
+ avg_acc = valid_acc.mean() if cnt > 0 else 0
+ return acc, avg_acc, cnt
+
+
+def keypoint_auc(pred: np.ndarray,
+ gt: np.ndarray,
+ mask: np.ndarray,
+ norm_factor: np.ndarray,
+ num_thrs: int = 20) -> float:
+ """Calculate the Area under curve (AUC) of keypoint PCK accuracy.
+
+ Note:
+ - instance number: N
+ - keypoint number: K
+
+ Args:
+ pred (np.ndarray[N, K, 2]): Predicted keypoint location.
+ gt (np.ndarray[N, K, 2]): Groundtruth keypoint location.
+ mask (np.ndarray[N, K]): Visibility of the target. False for invisible
+ joints, and True for visible. Invisible joints will be ignored for
+ accuracy calculation.
+ norm_factor (float): Normalization factor.
+ num_thrs (int): number of thresholds to calculate auc.
+
+ Returns:
+ float: Area under curve (AUC) of keypoint PCK accuracy.
+ """
+ nor = np.tile(np.array([[norm_factor, norm_factor]]), (pred.shape[0], 1))
+ thrs = [1.0 * i / num_thrs for i in range(num_thrs)]
+ avg_accs = []
+ for thr in thrs:
+ _, avg_acc, _ = keypoint_pck_accuracy(pred, gt, mask, thr, nor)
+ avg_accs.append(avg_acc)
+
+ auc = 0
+ for i in range(num_thrs):
+ auc += 1.0 / num_thrs * avg_accs[i]
+ return auc
+
+
+def keypoint_nme(pred: np.ndarray, gt: np.ndarray, mask: np.ndarray,
+ normalize_factor: np.ndarray) -> float:
+ """Calculate the normalized mean error (NME).
+
+ Note:
+ - instance number: N
+ - keypoint number: K
+
+ Args:
+ pred (np.ndarray[N, K, 2]): Predicted keypoint location.
+ gt (np.ndarray[N, K, 2]): Groundtruth keypoint location.
+ mask (np.ndarray[N, K]): Visibility of the target. False for invisible
+ joints, and True for visible. Invisible joints will be ignored for
+ accuracy calculation.
+ normalize_factor (np.ndarray[N, 2]): Normalization factor.
+
+ Returns:
+ float: normalized mean error
+ """
+ distances = _calc_distances(pred, gt, mask, normalize_factor)
+ distance_valid = distances[distances != -1]
+ return distance_valid.sum() / max(1, len(distance_valid))
+
+
+def keypoint_epe(pred: np.ndarray, gt: np.ndarray, mask: np.ndarray) -> float:
+ """Calculate the end-point error.
+
+ Note:
+ - instance number: N
+ - keypoint number: K
+
+ Args:
+ pred (np.ndarray[N, K, 2]): Predicted keypoint location.
+ gt (np.ndarray[N, K, 2]): Groundtruth keypoint location.
+ mask (np.ndarray[N, K]): Visibility of the target. False for invisible
+ joints, and True for visible. Invisible joints will be ignored for
+ accuracy calculation.
+
+ Returns:
+ float: Average end-point error.
+ """
+
+ distances = _calc_distances(
+ pred, gt, mask,
+ np.ones((pred.shape[0], pred.shape[2]), dtype=np.float32))
+ distance_valid = distances[distances != -1]
+ return distance_valid.sum() / max(1, len(distance_valid))
+
+
+def pose_pck_accuracy(output: np.ndarray,
+ target: np.ndarray,
+ mask: np.ndarray,
+ thr: float = 0.05,
+ normalize: Optional[np.ndarray] = None) -> tuple:
+ """Calculate the pose accuracy of PCK for each individual keypoint and the
+ averaged accuracy across all keypoints from heatmaps.
+
+ Note:
+ PCK metric measures accuracy of the localization of the body joints.
+ The distances between predicted positions and the ground-truth ones
+ are typically normalized by the bounding box size.
+ The threshold (thr) of the normalized distance is commonly set
+ as 0.05, 0.1 or 0.2 etc.
+
+ - batch_size: N
+ - num_keypoints: K
+ - heatmap height: H
+ - heatmap width: W
+
+ Args:
+ output (np.ndarray[N, K, H, W]): Model output heatmaps.
+ target (np.ndarray[N, K, H, W]): Groundtruth heatmaps.
+ mask (np.ndarray[N, K]): Visibility of the target. False for invisible
+ joints, and True for visible. Invisible joints will be ignored for
+ accuracy calculation.
+ thr (float): Threshold of PCK calculation. Default 0.05.
+ normalize (np.ndarray[N, 2]): Normalization factor for H&W.
+
+ Returns:
+ tuple: A tuple containing keypoint accuracy.
+
+ - np.ndarray[K]: Accuracy of each keypoint.
+ - float: Averaged accuracy across all keypoints.
+ - int: Number of valid keypoints.
+ """
+ N, K, H, W = output.shape
+ if K == 0:
+ return None, 0, 0
+ if normalize is None:
+ normalize = np.tile(np.array([[H, W]]), (N, 1))
+
+ pred, _ = get_heatmap_maximum(output)
+ gt, _ = get_heatmap_maximum(target)
+ return keypoint_pck_accuracy(pred, gt, mask, thr, normalize)
+
+
+def simcc_pck_accuracy(output: Tuple[np.ndarray, np.ndarray],
+ target: Tuple[np.ndarray, np.ndarray],
+ simcc_split_ratio: float,
+ mask: np.ndarray,
+ thr: float = 0.05,
+ normalize: Optional[np.ndarray] = None) -> tuple:
+ """Calculate the pose accuracy of PCK for each individual keypoint and the
+ averaged accuracy across all keypoints from SimCC.
+
+ Note:
+ PCK metric measures accuracy of the localization of the body joints.
+ The distances between predicted positions and the ground-truth ones
+ are typically normalized by the bounding box size.
+ The threshold (thr) of the normalized distance is commonly set
+ as 0.05, 0.1 or 0.2 etc.
+
+ - instance number: N
+ - keypoint number: K
+
+ Args:
+ output (Tuple[np.ndarray, np.ndarray]): Model predicted SimCC.
+ target (Tuple[np.ndarray, np.ndarray]): Groundtruth SimCC.
+ mask (np.ndarray[N, K]): Visibility of the target. False for invisible
+ joints, and True for visible. Invisible joints will be ignored for
+ accuracy calculation.
+ thr (float): Threshold of PCK calculation. Default 0.05.
+ normalize (np.ndarray[N, 2]): Normalization factor for H&W.
+
+ Returns:
+ tuple: A tuple containing keypoint accuracy.
+
+ - np.ndarray[K]: Accuracy of each keypoint.
+ - float: Averaged accuracy across all keypoints.
+ - int: Number of valid keypoints.
+ """
+ pred_x, pred_y = output
+ gt_x, gt_y = target
+
+ N, _, Wx = pred_x.shape
+ _, _, Wy = pred_y.shape
+ W, H = int(Wx / simcc_split_ratio), int(Wy / simcc_split_ratio)
+
+ if normalize is None:
+ normalize = np.tile(np.array([[H, W]]), (N, 1))
+
+ pred_coords, _ = get_simcc_maximum(pred_x, pred_y)
+ pred_coords /= simcc_split_ratio
+ gt_coords, _ = get_simcc_maximum(gt_x, gt_y)
+ gt_coords /= simcc_split_ratio
+
+ return keypoint_pck_accuracy(pred_coords, gt_coords, mask, thr, normalize)
+
+
+def multilabel_classification_accuracy(pred: np.ndarray,
+ gt: np.ndarray,
+ mask: np.ndarray,
+ thr: float = 0.5) -> float:
+ """Get multi-label classification accuracy.
+
+ Note:
+ - batch size: N
+ - label number: L
+
+ Args:
+ pred (np.ndarray[N, L, 2]): model predicted labels.
+ gt (np.ndarray[N, L, 2]): ground-truth labels.
+ mask (np.ndarray[N, 1] or np.ndarray[N, L] ): reliability of
+ ground-truth labels.
+ thr (float): Threshold for calculating accuracy.
+
+ Returns:
+ float: multi-label classification accuracy.
+ """
+ # we only compute accuracy on the samples with ground-truth of all labels.
+ valid = (mask > 0).min(axis=1) if mask.ndim == 2 else (mask > 0)
+ pred, gt = pred[valid], gt[valid]
+
+ if pred.shape[0] == 0:
+ acc = 0.0 # when no sample is with gt labels, set acc to 0.
+ else:
+ # The classification of a sample is regarded as correct
+ # only if it's correct for all labels.
+ acc = (((pred - thr) * (gt - thr)) > 0).all(axis=1).mean()
+ return acc
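
A toy check of ``keypoint_pck_accuracy`` with hand-picked coordinates: one instance, two keypoints, where the second prediction is 20 px off. With a 100x100 normalization and ``thr=0.1``, only the first keypoint counts as correct.

import numpy as np
from mmpose.evaluation.functional import keypoint_pck_accuracy

pred = np.array([[[10., 10.], [50., 70.]]])   # (N=1, K=2, 2)
gt = np.array([[[10., 10.], [50., 50.]]])
mask = np.array([[True, True]])
norm = np.full((1, 2), 100.)                  # normalize by a 100x100 box

acc, avg_acc, cnt = keypoint_pck_accuracy(pred, gt, mask, 0.1, norm)
print(acc, avg_acc, cnt)                      # [1. 0.] 0.5 2
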
diff --git a/mmpose/evaluation/functional/nms.py b/mmpose/evaluation/functional/nms.py
new file mode 100644
index 0000000000000000000000000000000000000000..eed4e5cf736585667657a1d4a4ca34f1bc8c0423
--- /dev/null
+++ b/mmpose/evaluation/functional/nms.py
@@ -0,0 +1,327 @@
+# ------------------------------------------------------------------------------
+# Adapted from https://github.com/leoxiaobin/deep-high-resolution-net.pytorch
+# and https://github.com/HRNet/DEKR
+# Original licence: Copyright (c) Microsoft, under the MIT License.
+# ------------------------------------------------------------------------------
+
+from typing import List, Optional
+
+import numpy as np
+
+
+def nms(dets: np.ndarray, thr: float) -> List[int]:
+ """Greedily select boxes with high confidence and overlap <= thr.
+
+ Args:
+ dets (np.ndarray): [[x1, y1, x2, y2, score]].
+ thr (float): Retain overlap < thr.
+
+ Returns:
+ list: Indexes to keep.
+ """
+ if len(dets) == 0:
+ return []
+
+ x1 = dets[:, 0]
+ y1 = dets[:, 1]
+ x2 = dets[:, 2]
+ y2 = dets[:, 3]
+ scores = dets[:, 4]
+
+ areas = (x2 - x1 + 1) * (y2 - y1 + 1)
+ order = scores.argsort()[::-1]
+
+ keep = []
+ while len(order) > 0:
+ i = order[0]
+ keep.append(i)
+ xx1 = np.maximum(x1[i], x1[order[1:]])
+ yy1 = np.maximum(y1[i], y1[order[1:]])
+ xx2 = np.minimum(x2[i], x2[order[1:]])
+ yy2 = np.minimum(y2[i], y2[order[1:]])
+
+ w = np.maximum(0.0, xx2 - xx1 + 1)
+ h = np.maximum(0.0, yy2 - yy1 + 1)
+ inter = w * h
+ ovr = inter / (areas[i] + areas[order[1:]] - inter)
+
+ inds = np.where(ovr <= thr)[0]
+ order = order[inds + 1]
+
+ return keep
+
+
+def oks_iou(g: np.ndarray,
+ d: np.ndarray,
+ a_g: float,
+ a_d: np.ndarray,
+ sigmas: Optional[np.ndarray] = None,
+ vis_thr: Optional[float] = None) -> np.ndarray:
+ """Calculate oks ious.
+
+ Note:
+
+ - number of keypoints: K
+ - number of instances: N
+
+ Args:
+ g (np.ndarray): The instance to calculate OKS IOU with other
+ instances. Containing the keypoints coordinates. Shape: (K*3, )
+ d (np.ndarray): The rest instances. Containing the keypoints
+ coordinates. Shape: (N, K*3)
+ a_g (float): Area of the ground truth object.
+ a_d (np.ndarray): Area of the detected object. Shape: (N, )
+ sigmas (np.ndarray, optional): Keypoint labelling uncertainty.
+            Please refer to the COCO keypoint evaluation protocol
+            for more details.
+ If not given, use the sigmas on COCO dataset.
+ If specified, shape: (K, ). Defaults to ``None``
+ vis_thr(float, optional): Threshold of the keypoint visibility.
+ If specified, will calculate OKS based on those keypoints whose
+ visibility higher than vis_thr. If not given, calculate the OKS
+ based on all keypoints. Defaults to ``None``
+
+ Returns:
+ np.ndarray: The oks ious.
+ """
+ if sigmas is None:
+ sigmas = np.array([
+ .26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07, 1.07,
+ .87, .87, .89, .89
+ ]) / 10.0
+ vars = (sigmas * 2)**2
+ xg = g[0::3]
+ yg = g[1::3]
+ vg = g[2::3]
+ ious = np.zeros(len(d), dtype=np.float32)
+ for n_d in range(0, len(d)):
+ xd = d[n_d, 0::3]
+ yd = d[n_d, 1::3]
+ vd = d[n_d, 2::3]
+ dx = xd - xg
+ dy = yd - yg
+ e = (dx**2 + dy**2) / vars / ((a_g + a_d[n_d]) / 2 + np.spacing(1)) / 2
+ if vis_thr is not None:
+ ind = list((vg > vis_thr) & (vd > vis_thr))
+ e = e[ind]
+ ious[n_d] = np.sum(np.exp(-e)) / len(e) if len(e) != 0 else 0.0
+ return ious
+
+
+def oks_nms(kpts_db: List[dict],
+ thr: float,
+ sigmas: Optional[np.ndarray] = None,
+ vis_thr: Optional[float] = None,
+ score_per_joint: bool = False):
+ """OKS NMS implementations.
+
+ Args:
+ kpts_db (List[dict]): The keypoints results of the same image.
+ thr (float): The threshold of NMS. Will retain oks overlap < thr.
+ sigmas (np.ndarray, optional): Keypoint labelling uncertainty.
+            Please refer to the COCO keypoint evaluation protocol
+            for more details.
+ If not given, use the sigmas on COCO dataset. Defaults to ``None``
+ vis_thr(float, optional): Threshold of the keypoint visibility.
+ If specified, will calculate OKS based on those keypoints whose
+ visibility higher than vis_thr. If not given, calculate the OKS
+ based on all keypoints. Defaults to ``None``
+ score_per_joint(bool): Whether the input scores (in kpts_db) are
+ per-joint scores. Defaults to ``False``
+
+ Returns:
+ np.ndarray: indexes to keep.
+ """
+ if len(kpts_db) == 0:
+ return []
+
+ if score_per_joint:
+ scores = np.array([k['score'].mean() for k in kpts_db])
+ else:
+ scores = np.array([k['score'] for k in kpts_db])
+
+ kpts = np.array([k['keypoints'].flatten() for k in kpts_db])
+ areas = np.array([k['area'] for k in kpts_db])
+
+ order = scores.argsort()[::-1]
+
+ keep = []
+ while len(order) > 0:
+ i = order[0]
+ keep.append(i)
+
+ oks_ovr = oks_iou(kpts[i], kpts[order[1:]], areas[i], areas[order[1:]],
+ sigmas, vis_thr)
+
+ inds = np.where(oks_ovr <= thr)[0]
+ order = order[inds + 1]
+
+ keep = np.array(keep)
+
+ return keep
+
+
+def _rescore(overlap: np.ndarray,
+ scores: np.ndarray,
+ thr: float,
+ type: str = 'gaussian'):
+ """Rescoring mechanism gaussian or linear.
+
+ Args:
+ overlap (np.ndarray): The calculated oks ious.
+ scores (np.ndarray): target scores.
+ thr (float): retain oks overlap < thr.
+ type (str): The rescoring type. Could be 'gaussian' or 'linear'.
+ Defaults to ``'gaussian'``
+
+ Returns:
+ np.ndarray: indexes to keep
+ """
+ assert len(overlap) == len(scores)
+ assert type in ['gaussian', 'linear']
+
+ if type == 'linear':
+ inds = np.where(overlap >= thr)[0]
+ scores[inds] = scores[inds] * (1 - overlap[inds])
+ else:
+ scores = scores * np.exp(-overlap**2 / thr)
+
+ return scores
+
+
+def soft_oks_nms(kpts_db: List[dict],
+ thr: float,
+ max_dets: int = 20,
+ sigmas: Optional[np.ndarray] = None,
+ vis_thr: Optional[float] = None,
+ score_per_joint: bool = False):
+ """Soft OKS NMS implementations.
+
+ Args:
+ kpts_db (List[dict]): The keypoints results of the same image.
+ thr (float): The threshold of NMS. Will retain oks overlap < thr.
+ max_dets (int): Maximum number of detections to keep. Defaults to 20
+ sigmas (np.ndarray, optional): Keypoint labelling uncertainty.
+            Please refer to the COCO keypoint evaluation protocol
+            for more details.
+ If not given, use the sigmas on COCO dataset. Defaults to ``None``
+ vis_thr(float, optional): Threshold of the keypoint visibility.
+ If specified, will calculate OKS based on those keypoints whose
+ visibility higher than vis_thr. If not given, calculate the OKS
+ based on all keypoints. Defaults to ``None``
+ score_per_joint(bool): Whether the input scores (in kpts_db) are
+ per-joint scores. Defaults to ``False``
+
+ Returns:
+ np.ndarray: indexes to keep.
+ """
+ if len(kpts_db) == 0:
+ return []
+
+ if score_per_joint:
+ scores = np.array([k['score'].mean() for k in kpts_db])
+ else:
+ scores = np.array([k['score'] for k in kpts_db])
+
+ kpts = np.array([k['keypoints'].flatten() for k in kpts_db])
+ areas = np.array([k['area'] for k in kpts_db])
+
+ order = scores.argsort()[::-1]
+ scores = scores[order]
+
+ keep = np.zeros(max_dets, dtype=np.intp)
+ keep_cnt = 0
+ while len(order) > 0 and keep_cnt < max_dets:
+ i = order[0]
+
+ oks_ovr = oks_iou(kpts[i], kpts[order[1:]], areas[i], areas[order[1:]],
+ sigmas, vis_thr)
+
+ order = order[1:]
+ scores = _rescore(oks_ovr, scores[1:], thr)
+
+ tmp = scores.argsort()[::-1]
+ order = order[tmp]
+ scores = scores[tmp]
+
+ keep[keep_cnt] = i
+ keep_cnt += 1
+
+ keep = keep[:keep_cnt]
+
+ return keep
+
+
+def nearby_joints_nms(
+ kpts_db: List[dict],
+ dist_thr: float,
+ num_nearby_joints_thr: Optional[int] = None,
+ score_per_joint: bool = False,
+ max_dets: int = 30,
+):
+ """Nearby joints NMS implementations. Instances with non-maximum scores
+    will be suppressed if too many of their joints are close to those of
+    other instances. This function is modified from the `DEKR
+    <https://github.com/HRNet/DEKR>`_ project.
+
+ Args:
+ kpts_db (list[dict]): keypoints and scores.
+ dist_thr (float): threshold for judging whether two joints are close.
+ num_nearby_joints_thr (int): threshold for judging whether two
+ instances are close.
+ max_dets (int): max number of detections to keep.
+ score_per_joint (bool): the input scores (in kpts_db) are per joint
+ scores.
+
+ Returns:
+ np.ndarray: indexes to keep.
+ """
+
+ assert dist_thr > 0, '`dist_thr` must be greater than 0.'
+ if len(kpts_db) == 0:
+ return []
+
+ if score_per_joint:
+ scores = np.array([k['score'].mean() for k in kpts_db])
+ else:
+ scores = np.array([k['score'] for k in kpts_db])
+
+ kpts = np.array([k['keypoints'] for k in kpts_db])
+
+ num_people, num_joints, _ = kpts.shape
+ if num_nearby_joints_thr is None:
+ num_nearby_joints_thr = num_joints // 2
+ assert num_nearby_joints_thr < num_joints, '`num_nearby_joints_thr` must '\
+ 'be less than the number of joints.'
+
+ # compute distance threshold
+ pose_area = kpts.max(axis=1) - kpts.min(axis=1)
+ pose_area = np.sqrt(np.power(pose_area, 2).sum(axis=1))
+ pose_area = pose_area.reshape(num_people, 1, 1)
+ pose_area = np.tile(pose_area, (num_people, num_joints))
+ close_dist_thr = pose_area * dist_thr
+
+ # count nearby joints between instances
+ instance_dist = kpts[:, None] - kpts
+ instance_dist = np.sqrt(np.power(instance_dist, 2).sum(axis=3))
+ close_instance_num = (instance_dist < close_dist_thr).sum(2)
+ close_instance = close_instance_num > num_nearby_joints_thr
+
+ # apply nms
+ ignored_pose_inds, keep_pose_inds = set(), list()
+ indexes = np.argsort(scores)[::-1]
+ for i in indexes:
+ if i in ignored_pose_inds:
+ continue
+ keep_inds = close_instance[i].nonzero()[0]
+ keep_ind = keep_inds[np.argmax(scores[keep_inds])]
+ if keep_ind not in ignored_pose_inds:
+ keep_pose_inds.append(keep_ind)
+ ignored_pose_inds = ignored_pose_inds.union(set(keep_inds))
+
+ # limit the number of output instances
+ if max_dets > 0 and len(keep_pose_inds) > max_dets:
+ sub_inds = np.argsort(scores[keep_pose_inds])[-1:-max_dets - 1:-1]
+ keep_pose_inds = [keep_pose_inds[i] for i in sub_inds]
+
+ return keep_pose_inds
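
A minimal sketch of ``oks_nms`` on two nearly identical synthetic 17-keypoint detections: their OKS is close to 1, so with ``thr=0.9`` the lower-scoring duplicate is suppressed and only index 0 is kept.

import numpy as np
from mmpose.evaluation.functional import oks_nms

kpts_a = np.concatenate([np.random.rand(17, 2) * 100, np.ones((17, 1))], axis=1)
kpts_b = kpts_a.copy()
kpts_b[:, :2] += 0.5                          # almost the same pose

kpts_db = [
    dict(keypoints=kpts_a, score=0.9, area=100 * 100),
    dict(keypoints=kpts_b, score=0.8, area=100 * 100),
]
print(oks_nms(kpts_db, thr=0.9))              # [0]
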
diff --git a/mmpose/evaluation/metrics/__init__.py b/mmpose/evaluation/metrics/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f02c353ef7d3d973a5ab3fa88128fa0814c6a1c7
--- /dev/null
+++ b/mmpose/evaluation/metrics/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .coco_metric import CocoMetric
+from .coco_wholebody_metric import CocoWholeBodyMetric
+from .keypoint_2d_metrics import (AUC, EPE, NME, JhmdbPCKAccuracy,
+ MpiiPCKAccuracy, PCKAccuracy)
+from .keypoint_partition_metric import KeypointPartitionMetric
+from .posetrack18_metric import PoseTrack18Metric
+
+__all__ = [
+ 'CocoMetric', 'PCKAccuracy', 'MpiiPCKAccuracy', 'JhmdbPCKAccuracy', 'AUC',
+ 'EPE', 'NME', 'PoseTrack18Metric', 'CocoWholeBodyMetric',
+ 'KeypointPartitionMetric'
+]
diff --git a/mmpose/evaluation/metrics/__pycache__/__init__.cpython-38.pyc b/mmpose/evaluation/metrics/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..987a926b3e38850427b129f0ea66b45fe96bae52
Binary files /dev/null and b/mmpose/evaluation/metrics/__pycache__/__init__.cpython-38.pyc differ
diff --git a/mmpose/evaluation/metrics/__pycache__/coco_metric.cpython-38.pyc b/mmpose/evaluation/metrics/__pycache__/coco_metric.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2915fcb40b533eff868d71d71348e28de59fd1f5
Binary files /dev/null and b/mmpose/evaluation/metrics/__pycache__/coco_metric.cpython-38.pyc differ
diff --git a/mmpose/evaluation/metrics/__pycache__/coco_wholebody_metric.cpython-38.pyc b/mmpose/evaluation/metrics/__pycache__/coco_wholebody_metric.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6127cdb481d081ab5a5e41d8d5bed1a2ace39225
Binary files /dev/null and b/mmpose/evaluation/metrics/__pycache__/coco_wholebody_metric.cpython-38.pyc differ
diff --git a/mmpose/evaluation/metrics/__pycache__/keypoint_2d_metrics.cpython-38.pyc b/mmpose/evaluation/metrics/__pycache__/keypoint_2d_metrics.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d5f9ef8c87b51e8c35d4b8849b9c2fea2c6be6c6
Binary files /dev/null and b/mmpose/evaluation/metrics/__pycache__/keypoint_2d_metrics.cpython-38.pyc differ
diff --git a/mmpose/evaluation/metrics/__pycache__/keypoint_partition_metric.cpython-38.pyc b/mmpose/evaluation/metrics/__pycache__/keypoint_partition_metric.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1ce8be0eca9a1d86bda8566ebdd6ff6ca18a19d0
Binary files /dev/null and b/mmpose/evaluation/metrics/__pycache__/keypoint_partition_metric.cpython-38.pyc differ
diff --git a/mmpose/evaluation/metrics/__pycache__/posetrack18_metric.cpython-38.pyc b/mmpose/evaluation/metrics/__pycache__/posetrack18_metric.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..66b861698ee1776f390beb79b0898666257ed862
Binary files /dev/null and b/mmpose/evaluation/metrics/__pycache__/posetrack18_metric.cpython-38.pyc differ
diff --git a/mmpose/evaluation/metrics/coco_metric.py b/mmpose/evaluation/metrics/coco_metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..8327e2eca7b978f66f3329fa1f45c3619e528c89
--- /dev/null
+++ b/mmpose/evaluation/metrics/coco_metric.py
@@ -0,0 +1,550 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import datetime
+import os.path as osp
+import tempfile
+from collections import OrderedDict, defaultdict
+from typing import Dict, Optional, Sequence
+
+import numpy as np
+from mmengine.evaluator import BaseMetric
+from mmengine.fileio import dump, get_local_path, load
+from mmengine.logging import MMLogger
+from xtcocotools.coco import COCO
+from xtcocotools.cocoeval import COCOeval
+
+from mmpose.registry import METRICS
+from ..functional import oks_nms, soft_oks_nms
+
+
+@METRICS.register_module()
+class CocoMetric(BaseMetric):
+ """COCO pose estimation task evaluation metric.
+
+ Evaluate AR, AP, and mAP for keypoint detection tasks. Support COCO
+ dataset and other datasets in COCO format. Please refer to
+    the COCO keypoint evaluation protocol
+ for more details.
+
+ Args:
+ ann_file (str, optional): Path to the coco format annotation file.
+ If not specified, ground truth annotations from the dataset will
+ be converted to coco format. Defaults to None
+        use_area (bool): Whether to use the ``'area'`` field in the
+            annotations.
+ If the ground truth annotations (e.g. CrowdPose, AIC) do not have
+ the field ``'area'``, please set ``use_area=False``.
+ Defaults to ``True``
+ iou_type (str): The same parameter as `iouType` in
+ :class:`xtcocotools.COCOeval`, which can be ``'keypoints'``, or
+ ``'keypoints_crowd'`` (used in CrowdPose dataset).
+ Defaults to ``'keypoints'``
+ score_mode (str): The mode to score the prediction results which
+ should be one of the following options:
+
+ - ``'bbox'``: Take the score of bbox as the score of the
+ prediction results.
+ - ``'bbox_keypoint'``: Use keypoint score to rescore the
+ prediction results.
+ - ``'bbox_rle'``: Use rle_score to rescore the
+ prediction results.
+
+        Defaults to ``'bbox_keypoint'``
+ keypoint_score_thr (float): The threshold of keypoint score. The
+ keypoints with score lower than it will not be included to
+ rescore the prediction results. Valid only when ``score_mode`` is
+ ``bbox_keypoint``. Defaults to ``0.2``
+ nms_mode (str): The mode to perform Non-Maximum Suppression (NMS),
+ which should be one of the following options:
+
+ - ``'oks_nms'``: Use Object Keypoint Similarity (OKS) to
+ perform NMS.
+ - ``'soft_oks_nms'``: Use Object Keypoint Similarity (OKS)
+ to perform soft NMS.
+ - ``'none'``: Do not perform NMS. Typically for bottomup mode
+ output.
+
+        Defaults to ``'oks_nms'``
+ nms_thr (float): The Object Keypoint Similarity (OKS) threshold
+ used in NMS when ``nms_mode`` is ``'oks_nms'`` or
+ ``'soft_oks_nms'``. Will retain the prediction results with OKS
+ lower than ``nms_thr``. Defaults to ``0.9``
+ format_only (bool): Whether only format the output results without
+ doing quantitative evaluation. This is designed for the need of
+ test submission when the ground truth annotations are absent. If
+ set to ``True``, ``outfile_prefix`` should specify the path to
+ store the output results. Defaults to ``False``
+ outfile_prefix (str | None): The prefix of json files. It includes
+ the file path and the prefix of filename, e.g., ``'a/b/prefix'``.
+ If not specified, a temp file will be created. Defaults to ``None``
+ collect_device (str): Device name used for collecting results from
+ different ranks during distributed training. Must be ``'cpu'`` or
+ ``'gpu'``. Defaults to ``'cpu'``
+ prefix (str, optional): The prefix that will be added in the metric
+ names to disambiguate homonymous metrics of different evaluators.
+ If prefix is not provided in the argument, ``self.default_prefix``
+ will be used instead. Defaults to ``None``
+ """
+ default_prefix: Optional[str] = 'coco'
+
+ def __init__(self,
+ ann_file: Optional[str] = None,
+ use_area: bool = True,
+ iou_type: str = 'keypoints',
+ score_mode: str = 'bbox_keypoint',
+ keypoint_score_thr: float = 0.2,
+ nms_mode: str = 'oks_nms',
+ nms_thr: float = 0.9,
+ format_only: bool = False,
+ outfile_prefix: Optional[str] = None,
+ collect_device: str = 'cpu',
+ prefix: Optional[str] = None) -> None:
+ super().__init__(collect_device=collect_device, prefix=prefix)
+ self.ann_file = ann_file
+ # initialize coco helper with the annotation json file
+ # if ann_file is not specified, initialize with the converted dataset
+ if ann_file is not None:
+ with get_local_path(ann_file) as local_path:
+ self.coco = COCO(local_path)
+ else:
+ self.coco = None
+
+ self.use_area = use_area
+ self.iou_type = iou_type
+
+ allowed_score_modes = ['bbox', 'bbox_keypoint', 'bbox_rle', 'keypoint']
+ if score_mode not in allowed_score_modes:
+ raise ValueError(
+ "`score_mode` should be one of 'bbox', 'bbox_keypoint', "
+                f"'bbox_rle', 'keypoint', but got {score_mode}")
+ self.score_mode = score_mode
+ self.keypoint_score_thr = keypoint_score_thr
+
+ allowed_nms_modes = ['oks_nms', 'soft_oks_nms', 'none']
+ if nms_mode not in allowed_nms_modes:
+ raise ValueError(
+ "`nms_mode` should be one of 'oks_nms', 'soft_oks_nms', "
+ f"'none', but got {nms_mode}")
+ self.nms_mode = nms_mode
+ self.nms_thr = nms_thr
+
+ if format_only:
+ assert outfile_prefix is not None, '`outfile_prefix` can not be '\
+ 'None when `format_only` is True, otherwise the result file '\
+ 'will be saved to a temp directory which will be cleaned up '\
+ 'in the end.'
+ elif ann_file is not None:
+ # do evaluation only if the ground truth annotations exist
+ assert 'annotations' in load(ann_file), \
+ 'Ground truth annotations are required for evaluation '\
+ 'when `format_only` is False.'
+
+ self.format_only = format_only
+ self.outfile_prefix = outfile_prefix
+
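For reference, a typical evaluator configuration for this metric could look like the following sketch in the standard MMEngine config style (the annotation path is a placeholder, not part of this diff):

```python
# Hypothetical evaluator config for CocoMetric; ann_file is a placeholder
# and should point to a COCO-format keypoint annotation file.
val_evaluator = dict(
    type='CocoMetric',
    ann_file='data/coco/annotations/person_keypoints_val2017.json',
    score_mode='bbox_keypoint',   # rescore bbox scores with keypoint scores
    keypoint_score_thr=0.2,
    nms_mode='oks_nms',
    nms_thr=0.9,
)
test_evaluator = val_evaluator
```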
+ def process(self, data_batch: Sequence[dict],
+ data_samples: Sequence[dict]) -> None:
+ """Process one batch of data samples and predictions. The processed
+ results should be stored in ``self.results``, which will be used to
+ compute the metrics when all batches have been processed.
+
+ Args:
+ data_batch (Sequence[dict]): A batch of data
+ from the dataloader.
+ data_samples (Sequence[dict]): A batch of outputs from
+ the model, each of which has the following keys:
+
+ - 'id': The id of the sample
+ - 'img_id': The image_id of the sample
+ - 'pred_instances': The prediction results of instance(s)
+ """
+ for data_sample in data_samples:
+ if 'pred_instances' not in data_sample:
+ raise ValueError(
+ '`pred_instances` are required to process the '
+ f'prediction results in {self.__class__.__name__}.')
+
+ # keypoints.shape: [N, K, 2],
+ # N: number of instances, K: number of keypoints
+ # for topdown-style output, N is usually 1, while for
+ # bottomup-style output, N is the number of instances in the image
+ keypoints = data_sample['pred_instances']['keypoints']
+ # [N, K], the scores for all keypoints of all instances
+ keypoint_scores = data_sample['pred_instances']['keypoint_scores']
+ assert keypoint_scores.shape == keypoints.shape[:2]
+
+ # parse prediction results
+ pred = dict()
+ pred['id'] = data_sample['id']
+ pred['img_id'] = data_sample['img_id']
+ pred['keypoints'] = keypoints
+ pred['keypoint_scores'] = keypoint_scores
+ pred['category_id'] = data_sample.get('category_id', 1)
+
+ if 'bbox_scores' in data_sample['pred_instances']:
+ # some one-stage models will predict bboxes and scores
+ # together with keypoints
+ bbox_scores = data_sample['pred_instances']['bbox_scores']
+ elif ('bbox_scores' not in data_sample['gt_instances']
+ or len(data_sample['gt_instances']['bbox_scores']) !=
+ len(keypoints)):
+ # bottom-up models might output different number of
+ # instances from annotation
+ bbox_scores = np.ones(len(keypoints))
+ else:
+ # top-down models use detected bboxes, the scores of which
+ # are contained in the gt_instances
+ bbox_scores = data_sample['gt_instances']['bbox_scores']
+ pred['bbox_scores'] = bbox_scores
+
+ # get area information
+ if 'bbox_scales' in data_sample['gt_instances']:
+ pred['areas'] = np.prod(
+ data_sample['gt_instances']['bbox_scales'], axis=1)
+
+ # parse gt
+ gt = dict()
+ if self.coco is None:
+ gt['width'] = data_sample['ori_shape'][1]
+ gt['height'] = data_sample['ori_shape'][0]
+ gt['img_id'] = data_sample['img_id']
+ if self.iou_type == 'keypoints_crowd':
+ assert 'crowd_index' in data_sample, \
+ '`crowd_index` is required when `self.iou_type` is ' \
+ '`keypoints_crowd`'
+ gt['crowd_index'] = data_sample['crowd_index']
+ assert 'raw_ann_info' in data_sample, \
+ 'The raw ground truth annotations are required for ' \
+ 'evaluation when `ann_file` is not provided'
+ anns = data_sample['raw_ann_info']
+ gt['raw_ann_info'] = anns if isinstance(anns, list) else [anns]
+
+ # add converted result to the results list
+ self.results.append((pred, gt))
+
+ def gt_to_coco_json(self, gt_dicts: Sequence[dict],
+ outfile_prefix: str) -> str:
+ """Convert ground truth to coco format json file.
+
+ Args:
+ gt_dicts (Sequence[dict]): Ground truth of the dataset. Each dict
+ contains the ground truth information about the data sample.
+ Required keys of the each `gt_dict` in `gt_dicts`:
+ - `img_id`: image id of the data sample
+ - `width`: original image width
+ - `height`: original image height
+ - `raw_ann_info`: the raw annotation information
+ Optional keys:
+ - `crowd_index`: measure the crowding level of an image,
+ defined in CrowdPose dataset
+ It is worth mentioning that, in order to compute `CocoMetric`,
+ there are some required keys in the `raw_ann_info`:
+ - `id`: the id to distinguish different annotations
+ - `image_id`: the image id of this annotation
+ - `category_id`: the category of the instance.
+ - `bbox`: the object bounding box
+ - `keypoints`: the keypoint coordinates along with their
+ visibilities. Note that it needs to be aligned
+ with the official COCO format, e.g., a list with length
+ N * 3, in which N is the number of keypoints. And each
+ triplet represents the [x, y, visible] of the keypoint.
+ - `iscrowd`: indicating whether the annotation is a crowd.
+ It is useful when matching the detection results to
+ the ground truth.
+ There are some optional keys as well:
+ - `area`: it is necessary when `self.use_area` is `True`
+ - `num_keypoints`: it is necessary when `self.iou_type`
+ is set as `keypoints_crowd`.
+ outfile_prefix (str): The filename prefix of the json files. If the
+ prefix is "somepath/xxx", the json file will be named
+ "somepath/xxx.gt.json".
+ Returns:
+ str: The filename of the json file.
+ """
+ image_infos = []
+ annotations = []
+ img_ids = []
+ ann_ids = []
+
+ for gt_dict in gt_dicts:
+ # filter duplicate image_info
+ if gt_dict['img_id'] not in img_ids:
+ image_info = dict(
+ id=gt_dict['img_id'],
+ width=gt_dict['width'],
+ height=gt_dict['height'],
+ )
+ if self.iou_type == 'keypoints_crowd':
+ image_info['crowdIndex'] = gt_dict['crowd_index']
+
+ image_infos.append(image_info)
+ img_ids.append(gt_dict['img_id'])
+
+ # filter duplicate annotations
+ for ann in gt_dict['raw_ann_info']:
+ if ann is None:
+ # during evaluation on bottom-up datasets, some images
+ # do not have instance annotation
+ continue
+
+ annotation = dict(
+ id=ann['id'],
+ image_id=ann['image_id'],
+ category_id=ann['category_id'],
+ bbox=ann['bbox'],
+ keypoints=ann['keypoints'],
+ iscrowd=ann['iscrowd'],
+ )
+ if self.use_area:
+ assert 'area' in ann, \
+ '`area` is required when `self.use_area` is `True`'
+ annotation['area'] = ann['area']
+
+ if self.iou_type == 'keypoints_crowd':
+ assert 'num_keypoints' in ann, \
+ '`num_keypoints` is required when `self.iou_type` ' \
+ 'is `keypoints_crowd`'
+ annotation['num_keypoints'] = ann['num_keypoints']
+
+ annotations.append(annotation)
+ ann_ids.append(ann['id'])
+
+ info = dict(
+ date_created=str(datetime.datetime.now()),
+ description='Coco json file converted by mmpose CocoMetric.')
+ coco_json = dict(
+ info=info,
+ images=image_infos,
+ categories=self.dataset_meta['CLASSES'],
+ licenses=None,
+ annotations=annotations,
+ )
+ converted_json_path = f'{outfile_prefix}.gt.json'
+ dump(coco_json, converted_json_path, sort_keys=True, indent=4)
+ return converted_json_path
+
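The structure consumed by ``gt_to_coco_json`` can be sketched as follows; all values are invented for illustration:

```python
# Minimal sketch of gt_dicts for a single image with one annotated person
# (17 COCO keypoints); the numbers are made up.
gt_dicts = [
    dict(
        img_id=42,
        width=640,
        height=480,
        raw_ann_info=[
            dict(
                id=1,
                image_id=42,
                category_id=1,
                bbox=[100.0, 120.0, 200.0, 300.0],
                # flattened [x, y, visibility] triplets, length 17 * 3
                keypoints=[150.0, 160.0, 2.0] * 17,
                iscrowd=0,
                area=60000.0,  # required when use_area=True
            )
        ],
    )
]
# json_path = metric.gt_to_coco_json(gt_dicts, outfile_prefix='work_dir/gt')
```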
+ def compute_metrics(self, results: list) -> Dict[str, float]:
+ """Compute the metrics from processed results.
+
+ Args:
+ results (list): The processed results of each batch.
+
+ Returns:
+ Dict[str, float]: The computed metrics. The keys are the names of
+ the metrics, and the values are corresponding results.
+ """
+ logger: MMLogger = MMLogger.get_current_instance()
+
+ # split prediction and gt list
+ preds, gts = zip(*results)
+
+ tmp_dir = None
+ if self.outfile_prefix is None:
+ tmp_dir = tempfile.TemporaryDirectory()
+ outfile_prefix = osp.join(tmp_dir.name, 'results')
+ else:
+ outfile_prefix = self.outfile_prefix
+
+ if self.coco is None:
+ # use converted gt json file to initialize coco helper
+ logger.info('Converting ground truth to coco format...')
+ coco_json_path = self.gt_to_coco_json(
+ gt_dicts=gts, outfile_prefix=outfile_prefix)
+ self.coco = COCO(coco_json_path)
+
+ kpts = defaultdict(list)
+
+ # group the preds by img_id
+ for pred in preds:
+ img_id = pred['img_id']
+ for idx in range(len(pred['keypoints'])):
+ instance = {
+ 'id': pred['id'],
+ 'img_id': pred['img_id'],
+ 'category_id': pred['category_id'],
+ 'keypoints': pred['keypoints'][idx],
+ 'keypoint_scores': pred['keypoint_scores'][idx],
+ 'bbox_score': pred['bbox_scores'][idx],
+ }
+
+ if 'areas' in pred:
+ instance['area'] = pred['areas'][idx]
+ else:
+ # use keypoint to calculate bbox and get area
+ keypoints = pred['keypoints'][idx]
+ area = (
+ np.max(keypoints[:, 0]) - np.min(keypoints[:, 0])) * (
+ np.max(keypoints[:, 1]) - np.min(keypoints[:, 1]))
+ instance['area'] = area
+
+ kpts[img_id].append(instance)
+
+ # sort keypoint results according to id and remove duplicate ones
+ kpts = self._sort_and_unique_bboxes(kpts, key='id')
+
+ # score the prediction results according to `score_mode`
+ # and perform NMS according to `nms_mode`
+ valid_kpts = defaultdict(list)
+ num_keypoints = self.dataset_meta['num_keypoints']
+ for img_id, instances in kpts.items():
+ for instance in instances:
+ # concatenate the keypoint coordinates and scores
+ instance['keypoints'] = np.concatenate([
+ instance['keypoints'], instance['keypoint_scores'][:, None]
+ ],
+ axis=-1)
+ if self.score_mode == 'bbox':
+ instance['score'] = instance['bbox_score']
+ elif self.score_mode == 'keypoint':
+ instance['score'] = np.mean(instance['keypoint_scores'])
+ else:
+ bbox_score = instance['bbox_score']
+ if self.score_mode == 'bbox_rle':
+ keypoint_scores = instance['keypoint_scores']
+ instance['score'] = float(bbox_score +
+ np.mean(keypoint_scores) +
+ np.max(keypoint_scores))
+
+ else: # self.score_mode == 'bbox_keypoint':
+ mean_kpt_score = 0
+ valid_num = 0
+ for kpt_idx in range(num_keypoints):
+ kpt_score = instance['keypoint_scores'][kpt_idx]
+ if kpt_score > self.keypoint_score_thr:
+ mean_kpt_score += kpt_score
+ valid_num += 1
+ if valid_num != 0:
+ mean_kpt_score /= valid_num
+ instance['score'] = bbox_score * mean_kpt_score
+ # perform nms
+ if self.nms_mode == 'none':
+ valid_kpts[img_id] = instances
+ else:
+ nms = oks_nms if self.nms_mode == 'oks_nms' else soft_oks_nms
+ keep = nms(
+ instances,
+ self.nms_thr,
+ sigmas=self.dataset_meta['sigmas'])
+ valid_kpts[img_id] = [instances[_keep] for _keep in keep]
+
+ # convert results to coco style and dump into a json file
+ self.results2json(valid_kpts, outfile_prefix=outfile_prefix)
+
+ # only format the results without doing quantitative evaluation
+ if self.format_only:
+ logger.info('results are saved in '
+ f'{osp.dirname(outfile_prefix)}')
+ return {}
+
+ # evaluation results
+ eval_results = OrderedDict()
+ logger.info(f'Evaluating {self.__class__.__name__}...')
+ info_str = self._do_python_keypoint_eval(outfile_prefix)
+ name_value = OrderedDict(info_str)
+ eval_results.update(name_value)
+
+ if tmp_dir is not None:
+ tmp_dir.cleanup()
+ return eval_results
+
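The rescoring branches in ``compute_metrics`` above can be illustrated with a small numeric sketch (all values are hypothetical):

```python
import numpy as np

bbox_score = 0.9
keypoint_scores = np.array([0.8, 0.6, 0.1])
thr = 0.2  # keypoint_score_thr

# 'bbox_keypoint': bbox score times the mean of keypoint scores above thr
valid = keypoint_scores[keypoint_scores > thr]
score_bbox_keypoint = bbox_score * valid.mean()  # 0.9 * 0.7 = 0.63
# (the actual code additionally guards against no keypoint passing thr)

# 'bbox_rle': bbox score + mean + max of all keypoint scores
score_bbox_rle = bbox_score + keypoint_scores.mean() + keypoint_scores.max()
# 0.9 + 0.5 + 0.8 = 2.2
```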
+ def results2json(self, keypoints: Dict[int, list],
+ outfile_prefix: str) -> str:
+ """Dump the keypoint detection results to a COCO style json file.
+
+ Args:
+ keypoints (Dict[int, list]): Keypoint detection results
+ of the dataset.
+ outfile_prefix (str): The filename prefix of the json files. If the
+ prefix is "somepath/xxx", the json files will be named
+ "somepath/xxx.keypoints.json",
+
+ Returns:
+ str: The json file name of keypoint results.
+ """
+ # the results with category_id
+ cat_results = []
+
+ for _, img_kpts in keypoints.items():
+ _keypoints = np.array(
+ [img_kpt['keypoints'] for img_kpt in img_kpts])
+ num_keypoints = self.dataset_meta['num_keypoints']
+ # collect all the person keypoints in current image
+ _keypoints = _keypoints.reshape(-1, num_keypoints * 3)
+
+ result = [{
+ 'image_id': img_kpt['img_id'],
+ 'category_id': img_kpt['category_id'],
+ 'keypoints': keypoint.tolist(),
+ 'score': float(img_kpt['score']),
+ } for img_kpt, keypoint in zip(img_kpts, _keypoints)]
+
+ cat_results.extend(result)
+
+ res_file = f'{outfile_prefix}.keypoints.json'
+ dump(cat_results, res_file, sort_keys=True, indent=4)
+ return res_file
+
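Each entry of the dumped ``<outfile_prefix>.keypoints.json`` list then looks roughly like the following sketch (placeholder values):

```python
# Approximate shape of one dumped result entry; keypoints are flattened
# [x, y, score] triplets for the 17 COCO keypoints.
entry = {
    'image_id': 42,
    'category_id': 1,
    'keypoints': [150.0, 160.0, 0.9] * 17,
    'score': 0.87,
}
```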
+ def _do_python_keypoint_eval(self, outfile_prefix: str) -> list:
+ """Do keypoint evaluation using COCOAPI.
+
+ Args:
+ outfile_prefix (str): The filename prefix of the json files. If the
+ prefix is "somepath/xxx", the json files will be named
+ "somepath/xxx.keypoints.json",
+
+ Returns:
+ list: a list of tuples. Each tuple contains the evaluation stats
+ name and corresponding stats value.
+ """
+ res_file = f'{outfile_prefix}.keypoints.json'
+ coco_det = self.coco.loadRes(res_file)
+ sigmas = self.dataset_meta['sigmas']
+ coco_eval = COCOeval(self.coco, coco_det, self.iou_type, sigmas,
+ self.use_area)
+ coco_eval.params.useSegm = None
+ coco_eval.evaluate()
+ coco_eval.accumulate()
+ coco_eval.summarize()
+
+ if self.iou_type == 'keypoints_crowd':
+ stats_names = [
+ 'AP', 'AP .5', 'AP .75', 'AR', 'AR .5', 'AR .75', 'AP(E)',
+ 'AP(M)', 'AP(H)'
+ ]
+ else:
+ stats_names = [
+ 'AP', 'AP .5', 'AP .75', 'AP (M)', 'AP (L)', 'AR', 'AR .5',
+ 'AR .75', 'AR (M)', 'AR (L)'
+ ]
+
+ info_str = list(zip(stats_names, coco_eval.stats))
+
+ return info_str
+
+ def _sort_and_unique_bboxes(self,
+ kpts: Dict[int, list],
+ key: str = 'id') -> Dict[int, list]:
+ """Sort keypoint detection results in each image and remove the
+ duplicate ones. Usually performed in multi-batch testing.
+
+ Args:
+ kpts (Dict[int, list]): keypoint prediction results. The keys are
+ '`img_id`' and the values are lists that may contain
+ keypoints of multiple persons. Each element in the list is a
+ dict containing the ``'key'`` field.
+ See the argument ``key`` for details.
+ key (str): The key name in each person prediction results. The
+ corresponding value will be used for sorting the results.
+ Default: ``'id'``.
+
+ Returns:
+ Dict[int, list]: The sorted keypoint detection results.
+ """
+ for img_id, persons in kpts.items():
+ # deal with bottomup-style output
+ if isinstance(kpts[img_id][0][key], Sequence):
+ return kpts
+ num = len(persons)
+ kpts[img_id] = sorted(kpts[img_id], key=lambda x: x[key])
+ for i in range(num - 1, 0, -1):
+ if kpts[img_id][i][key] == kpts[img_id][i - 1][key]:
+ del kpts[img_id][i]
+
+ return kpts
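A toy illustration of the dedup step above, equivalent in effect to the loop (ids are made up and the small dicts stand in for full instance records):

```python
kpts = {7: [{'id': 3}, {'id': 1}, {'id': 1}]}
kpts[7] = sorted(kpts[7], key=lambda x: x['id'])
kpts[7] = [p for i, p in enumerate(kpts[7])
           if i == 0 or p['id'] != kpts[7][i - 1]['id']]
# kpts is now {7: [{'id': 1}, {'id': 3}]}
```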
diff --git a/mmpose/evaluation/metrics/coco_wholebody_metric.py b/mmpose/evaluation/metrics/coco_wholebody_metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5675f54c8fe793f3f5c1a30284dc2ab44b28ccd
--- /dev/null
+++ b/mmpose/evaluation/metrics/coco_wholebody_metric.py
@@ -0,0 +1,312 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import datetime
+from typing import Dict, Optional, Sequence
+
+import numpy as np
+from mmengine.fileio import dump
+from xtcocotools.cocoeval import COCOeval
+
+from mmpose.registry import METRICS
+from .coco_metric import CocoMetric
+
+
+@METRICS.register_module()
+class CocoWholeBodyMetric(CocoMetric):
+ """COCO-WholeBody evaluation metric.
+
+ Evaluate AR, AP, and mAP for COCO-WholeBody keypoint detection tasks.
+ Support COCO-WholeBody dataset. Please refer to
+ `COCO keypoint evaluation <https://cocodataset.org/#keypoints-eval>`__
+ for more details.
+
+ Args:
+ ann_file (str, optional): Path to the coco format annotation file.
+ If not specified, ground truth annotations from the dataset will
+ be converted to coco format. Defaults to None
+ use_area (bool): Whether to use ``'area'`` message in the annotations.
+ If the ground truth annotations (e.g. CrowdPose, AIC) do not have
+ the field ``'area'``, please set ``use_area=False``.
+ Defaults to ``True``
+ iou_type (str): The same parameter as `iouType` in
+ :class:`xtcocotools.COCOeval`, which can be ``'keypoints'``, or
+ ``'keypoints_crowd'`` (used in CrowdPose dataset).
+ Defaults to ``'keypoints'``
+ score_mode (str): The mode to score the prediction results which
+ should be one of the following options:
+
+ - ``'bbox'``: Take the score of bbox as the score of the
+ prediction results.
+ - ``'bbox_keypoint'``: Use keypoint score to rescore the
+ prediction results.
+ - ``'bbox_rle'``: Use rle_score to rescore the
+ prediction results.
+
+ Defaults to ``'bbox_keypoint'``
+ keypoint_score_thr (float): The threshold of keypoint score. Keypoints
+ with a score lower than this threshold will not be included when
+ rescoring the prediction results. Valid only when ``score_mode`` is
+ ``bbox_keypoint``. Defaults to ``0.2``
+ nms_mode (str): The mode to perform Non-Maximum Suppression (NMS),
+ which should be one of the following options:
+
+ - ``'oks_nms'``: Use Object Keypoint Similarity (OKS) to
+ perform NMS.
+ - ``'soft_oks_nms'``: Use Object Keypoint Similarity (OKS)
+ to perform soft NMS.
+ - ``'none'``: Do not perform NMS. Typically for bottomup mode
+ output.
+
+ Defaults to ``'oks_nms'``
+ nms_thr (float): The Object Keypoint Similarity (OKS) threshold
+ used in NMS when ``nms_mode`` is ``'oks_nms'`` or
+ ``'soft_oks_nms'``. Will retain the prediction results with OKS
+ lower than ``nms_thr``. Defaults to ``0.9``
+ format_only (bool): Whether to only format the output results without
+ doing quantitative evaluation. This is designed for the need of
+ test submission when the ground truth annotations are absent. If
+ set to ``True``, ``outfile_prefix`` should specify the path to
+ store the output results. Defaults to ``False``
+ outfile_prefix (str | None): The prefix of json files. It includes
+ the file path and the prefix of filename, e.g., ``'a/b/prefix'``.
+ If not specified, a temp file will be created. Defaults to ``None``
+ **kwargs: Keyword parameters passed to :class:`mmeval.BaseMetric`
+ """
+ default_prefix: Optional[str] = 'coco-wholebody'
+ body_num = 17
+ foot_num = 6
+ face_num = 68
+ left_hand_num = 21
+ right_hand_num = 21
+
+ def gt_to_coco_json(self, gt_dicts: Sequence[dict],
+ outfile_prefix: str) -> str:
+ """Convert ground truth to coco format json file.
+
+ Args:
+ gt_dicts (Sequence[dict]): Ground truth of the dataset. Each dict
+ contains the ground truth information about the data sample.
+ Required keys of the each `gt_dict` in `gt_dicts`:
+ - `img_id`: image id of the data sample
+ - `width`: original image width
+ - `height`: original image height
+ - `raw_ann_info`: the raw annotation information
+ Optional keys:
+ - `crowd_index`: measure the crowding level of an image,
+ defined in CrowdPose dataset
+ It is worth mentioning that, in order to compute `CocoMetric`,
+ there are some required keys in the `raw_ann_info`:
+ - `id`: the id to distinguish different annotations
+ - `image_id`: the image id of this annotation
+ - `category_id`: the category of the instance.
+ - `bbox`: the object bounding box
+ - `keypoints`: the keypoint coordinates along with their
+ visibilities. Note that it needs to be aligned
+ with the official COCO format, e.g., a list with length
+ N * 3, in which N is the number of keypoints. And each
+ triplet represents the [x, y, visible] of the keypoint.
+ - `foot_kpts`, `face_kpts`, `lefthand_kpts`, `righthand_kpts`:
+ the keypoints of the corresponding body parts, in the same
+ flattened [x, y, visible] format.
+ - `iscrowd`: indicating whether the annotation is a crowd.
+ It is useful when matching the detection results to
+ the ground truth.
+ There are some optional keys as well:
+ - `area`: it is necessary when `self.use_area` is `True`
+ - `num_keypoints`: it is necessary when `self.iou_type`
+ is set as `keypoints_crowd`.
+ outfile_prefix (str): The filename prefix of the json files. If the
+ prefix is "somepath/xxx", the json file will be named
+ "somepath/xxx.gt.json".
+ Returns:
+ str: The filename of the json file.
+ """
+ image_infos = []
+ annotations = []
+ img_ids = []
+ ann_ids = []
+
+ for gt_dict in gt_dicts:
+ # filter duplicate image_info
+ if gt_dict['img_id'] not in img_ids:
+ image_info = dict(
+ id=gt_dict['img_id'],
+ width=gt_dict['width'],
+ height=gt_dict['height'],
+ )
+ if self.iou_type == 'keypoints_crowd':
+ image_info['crowdIndex'] = gt_dict['crowd_index']
+
+ image_infos.append(image_info)
+ img_ids.append(gt_dict['img_id'])
+
+ # filter duplicate annotations
+ for ann in gt_dict['raw_ann_info']:
+ annotation = dict(
+ id=ann['id'],
+ image_id=ann['image_id'],
+ category_id=ann['category_id'],
+ bbox=ann['bbox'],
+ keypoints=ann['keypoints'],
+ foot_kpts=ann['foot_kpts'],
+ face_kpts=ann['face_kpts'],
+ lefthand_kpts=ann['lefthand_kpts'],
+ righthand_kpts=ann['righthand_kpts'],
+ iscrowd=ann['iscrowd'],
+ )
+ if self.use_area:
+ assert 'area' in ann, \
+ '`area` is required when `self.use_area` is `True`'
+ annotation['area'] = ann['area']
+
+ annotations.append(annotation)
+ ann_ids.append(ann['id'])
+
+ info = dict(
+ date_created=str(datetime.datetime.now()),
+ description='Coco json file converted by mmpose CocoMetric.')
+ coco_json: dict = dict(
+ info=info,
+ images=image_infos,
+ categories=self.dataset_meta['CLASSES'],
+ licenses=None,
+ annotations=annotations,
+ )
+ converted_json_path = f'{outfile_prefix}.gt.json'
+ dump(coco_json, converted_json_path, sort_keys=True, indent=4)
+ return converted_json_path
+
+ def results2json(self, keypoints: Dict[int, list],
+ outfile_prefix: str) -> str:
+ """Dump the keypoint detection results to a COCO style json file.
+
+ Args:
+ keypoints (Dict[int, list]): Keypoint detection results
+ of the dataset.
+ outfile_prefix (str): The filename prefix of the json files. If the
+ prefix is "somepath/xxx", the json files will be named
+ "somepath/xxx.keypoints.json",
+
+ Returns:
+ str: The json file name of keypoint results.
+ """
+ # the results with category_id
+ cat_id = 1
+ cat_results = []
+
+ cuts = np.cumsum([
+ 0, self.body_num, self.foot_num, self.face_num, self.left_hand_num,
+ self.right_hand_num
+ ]) * 3
+
+ for _, img_kpts in keypoints.items():
+ _keypoints = np.array(
+ [img_kpt['keypoints'] for img_kpt in img_kpts])
+ num_keypoints = self.dataset_meta['num_keypoints']
+ # collect all the person keypoints in current image
+ _keypoints = _keypoints.reshape(-1, num_keypoints * 3)
+
+ result = [{
+ 'image_id': img_kpt['img_id'],
+ 'category_id': cat_id,
+ 'keypoints': _keypoint[cuts[0]:cuts[1]].tolist(),
+ 'foot_kpts': _keypoint[cuts[1]:cuts[2]].tolist(),
+ 'face_kpts': _keypoint[cuts[2]:cuts[3]].tolist(),
+ 'lefthand_kpts': _keypoint[cuts[3]:cuts[4]].tolist(),
+ 'righthand_kpts': _keypoint[cuts[4]:cuts[5]].tolist(),
+ 'score': float(img_kpt['score']),
+ } for img_kpt, _keypoint in zip(img_kpts, _keypoints)]
+
+ cat_results.extend(result)
+
+ res_file = f'{outfile_prefix}.keypoints.json'
+ dump(cat_results, res_file, sort_keys=True, indent=4)
+ return res_file
+
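The cumulative offsets used above can be made concrete for the COCO-WholeBody layout (17 body + 6 foot + 68 face + 21 left-hand + 21 right-hand = 133 keypoints):

```python
import numpy as np

cuts = np.cumsum([0, 17, 6, 68, 21, 21]) * 3
# -> array([  0,  51,  69, 273, 336, 399])
# e.g. _keypoint[51:69] holds the 6 foot keypoints as [x, y, score] triplets.
# _do_python_keypoint_eval below reuses the same cumsum without the factor
# of 3 to slice the per-keypoint sigmas for each body part.
```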
+ def _do_python_keypoint_eval(self, outfile_prefix: str) -> list:
+ """Do keypoint evaluation using COCOAPI.
+
+ Args:
+ outfile_prefix (str): The filename prefix of the json files. If the
+ prefix is "somepath/xxx", the json files will be named
+ "somepath/xxx.keypoints.json",
+
+ Returns:
+ list: a list of tuples. Each tuple contains the evaluation stats
+ name and corresponding stats value.
+ """
+ res_file = f'{outfile_prefix}.keypoints.json'
+ coco_det = self.coco.loadRes(res_file)
+ sigmas = self.dataset_meta['sigmas']
+
+ cuts = np.cumsum([
+ 0, self.body_num, self.foot_num, self.face_num, self.left_hand_num,
+ self.right_hand_num
+ ])
+
+ coco_eval = COCOeval(
+ self.coco,
+ coco_det,
+ 'keypoints_body',
+ sigmas[cuts[0]:cuts[1]],
+ use_area=True)
+ coco_eval.params.useSegm = None
+ coco_eval.evaluate()
+ coco_eval.accumulate()
+ coco_eval.summarize()
+
+ coco_eval = COCOeval(
+ self.coco,
+ coco_det,
+ 'keypoints_foot',
+ sigmas[cuts[1]:cuts[2]],
+ use_area=True)
+ coco_eval.params.useSegm = None
+ coco_eval.evaluate()
+ coco_eval.accumulate()
+ coco_eval.summarize()
+
+ coco_eval = COCOeval(
+ self.coco,
+ coco_det,
+ 'keypoints_face',
+ sigmas[cuts[2]:cuts[3]],
+ use_area=True)
+ coco_eval.params.useSegm = None
+ coco_eval.evaluate()
+ coco_eval.accumulate()
+ coco_eval.summarize()
+
+ coco_eval = COCOeval(
+ self.coco,
+ coco_det,
+ 'keypoints_lefthand',
+ sigmas[cuts[3]:cuts[4]],
+ use_area=True)
+ coco_eval.params.useSegm = None
+ coco_eval.evaluate()
+ coco_eval.accumulate()
+ coco_eval.summarize()
+
+ coco_eval = COCOeval(
+ self.coco,
+ coco_det,
+ 'keypoints_righthand',
+ sigmas[cuts[4]:cuts[5]],
+ use_area=True)
+ coco_eval.params.useSegm = None
+ coco_eval.evaluate()
+ coco_eval.accumulate()
+ coco_eval.summarize()
+
+ coco_eval = COCOeval(
+ self.coco, coco_det, 'keypoints_wholebody', sigmas, use_area=True)
+ coco_eval.params.useSegm = None
+ coco_eval.evaluate()
+ coco_eval.accumulate()
+ coco_eval.summarize()
+
+ stats_names = [
+ 'AP', 'AP .5', 'AP .75', 'AP (M)', 'AP (L)', 'AR', 'AR .5',
+ 'AR .75', 'AR (M)', 'AR (L)'
+ ]
+
+ info_str = list(zip(stats_names, coco_eval.stats))
+
+ return info_str
diff --git a/mmpose/evaluation/metrics/keypoint_2d_metrics.py b/mmpose/evaluation/metrics/keypoint_2d_metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6a63f1e5193872493fc32ec852a3a31dafa17cd
--- /dev/null
+++ b/mmpose/evaluation/metrics/keypoint_2d_metrics.py
@@ -0,0 +1,910 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+from typing import Dict, Optional, Sequence, Union
+
+import numpy as np
+from mmengine.evaluator import BaseMetric
+from mmengine.logging import MMLogger
+
+from mmpose.registry import METRICS
+from ..functional import (keypoint_auc, keypoint_epe, keypoint_nme,
+ keypoint_pck_accuracy)
+
+
+@METRICS.register_module()
+class PCKAccuracy(BaseMetric):
+ """PCK accuracy evaluation metric.
+ Calculate the pose accuracy of Percentage of Correct Keypoints (PCK) for
+ each individual keypoint and the averaged accuracy across all keypoints.
+ PCK metric measures accuracy of the localization of the body joints.
+ The distances between predicted positions and the ground-truth ones
+ are typically normalized by the person bounding box size.
+ The threshold (thr) of the normalized distance is commonly set
+ as 0.05, 0.1 or 0.2 etc.
+ Note:
+ - length of dataset: N
+ - num_keypoints: K
+ - number of keypoint dimensions: D (typically D = 2)
+ Args:
+ thr (float): Threshold of PCK calculation. Default: 0.05.
+ norm_item (str | Sequence[str]): The item used for normalization.
+ Valid items include 'bbox', 'head', 'torso', which correspond
+ to 'PCK', 'PCKh' and 'tPCK' respectively. Default: ``'bbox'``.
+ collect_device (str): Device name used for collecting results from
+ different ranks during distributed training. Must be ``'cpu'`` or
+ ``'gpu'``. Default: ``'cpu'``.
+ prefix (str, optional): The prefix that will be added in the metric
+ names to disambiguate homonymous metrics of different evaluators.
+ If prefix is not provided in the argument, ``self.default_prefix``
+ will be used instead. Default: ``None``.
+
+ Examples:
+
+ >>> from mmpose.evaluation.metrics import PCKAccuracy
+ >>> import numpy as np
+ >>> from mmengine.structures import InstanceData
+ >>> num_keypoints = 15
+ >>> keypoints = np.random.random((1, num_keypoints, 2)) * 10
+ >>> gt_instances = InstanceData()
+ >>> gt_instances.keypoints = keypoints
+ >>> gt_instances.keypoints_visible = np.ones(
+ ... (1, num_keypoints, 1)).astype(bool)
+ >>> gt_instances.bboxes = np.random.random((1, 4)) * 20
+ >>> pred_instances = InstanceData()
+ >>> pred_instances.keypoints = keypoints
+ >>> data_sample = {
+ ... 'gt_instances': gt_instances.to_dict(),
+ ... 'pred_instances': pred_instances.to_dict(),
+ ... }
+ >>> data_samples = [data_sample]
+ >>> data_batch = [{'inputs': None}]
+ >>> pck_metric = PCKAccuracy(thr=0.5, norm_item='bbox')
+ ...: UserWarning: The prefix is not set in metric class PCKAccuracy.
+ >>> pck_metric.process(data_batch, data_samples)
+ >>> pck_metric.evaluate(1)
+ 10/26 15:37:57 - mmengine - INFO - Evaluating PCKAccuracy (normalized by ``"bbox_size"``)... # noqa
+ {'PCK': 1.0}
+
+ """
+
+ def __init__(self,
+ thr: float = 0.05,
+ norm_item: Union[str, Sequence[str]] = 'bbox',
+ collect_device: str = 'cpu',
+ prefix: Optional[str] = None) -> None:
+ super().__init__(collect_device=collect_device, prefix=prefix)
+ self.thr = thr
+ self.norm_item = norm_item if isinstance(norm_item,
+ (tuple,
+ list)) else [norm_item]
+ allow_normalized_items = ['bbox', 'head', 'torso']
+ for item in self.norm_item:
+ if item not in allow_normalized_items:
+ raise KeyError(
+ f'The normalized item {item} is not supported by '
+ f"{self.__class__.__name__}. Should be one of 'bbox', "
+ f"'head', 'torso', but got {item}.")
+
+ def process(self, data_batch: Sequence[dict],
+ data_samples: Sequence[dict]) -> None:
+ """Process one batch of data samples and predictions.
+
+ The processed
+ results should be stored in ``self.results``, which will be used to
+ compute the metrics when all batches have been processed.
+ Args:
+ data_batch (Sequence[dict]): A batch of data
+ from the dataloader.
+ data_samples (Sequence[dict]): A batch of outputs from
+ the model.
+ """
+ for data_sample in data_samples:
+ # predicted keypoints coordinates, [1, K, D]
+ pred_coords = data_sample['pred_instances']['keypoints']
+ # ground truth data_info
+ gt = data_sample['gt_instances']
+ # ground truth keypoints coordinates, [1, K, D]
+ gt_coords = gt['keypoints']
+ # ground truth keypoints_visible, [1, K, 1]
+ mask = gt['keypoints_visible'].astype(bool).reshape(1, -1)
+
+ result = {
+ 'pred_coords': pred_coords,
+ 'gt_coords': gt_coords,
+ 'mask': mask,
+ }
+
+ if 'bbox' in self.norm_item:
+ assert 'bboxes' in gt, 'The ground truth data info do not ' \
+ 'have the expected normalized_item ``"bbox"``.'
+ # ground truth bboxes, [1, 4]
+ bbox_size_ = np.max(gt['bboxes'][0][2:] - gt['bboxes'][0][:2])
+ bbox_size = np.array([bbox_size_, bbox_size_]).reshape(-1, 2)
+ result['bbox_size'] = bbox_size
+
+ if 'head' in self.norm_item:
+ assert 'head_size' in gt, 'The ground truth data info do ' \
+ 'not have the expected normalized_item ``"head_size"``.'
+ # ground truth bboxes
+ head_size_ = gt['head_size']
+ head_size = np.array([head_size_, head_size_]).reshape(-1, 2)
+ result['head_size'] = head_size
+
+ if 'torso' in self.norm_item:
+ # used in JhmdbDataset
+ torso_size_ = np.linalg.norm(gt_coords[0][4] - gt_coords[0][5])
+ if torso_size_ < 1:
+ torso_size_ = np.linalg.norm(pred_coords[0][4] -
+ pred_coords[0][5])
+ warnings.warn('Ground truth torso size < 1. '
+ 'Use torso size from predicted '
+ 'keypoint results instead.')
+ torso_size = np.array([torso_size_,
+ torso_size_]).reshape(-1, 2)
+ result['torso_size'] = torso_size
+
+ self.results.append(result)
+
+ def compute_metrics(self, results: list) -> Dict[str, float]:
+ """Compute the metrics from processed results.
+
+ Args:
+ results (list): The processed results of each batch.
+ Returns:
+ Dict[str, float]: The computed metrics. The keys are the names of
+ the metrics, and the values are corresponding results.
+ The returned result dict may have the following keys:
+ - 'PCK': The pck accuracy normalized by `bbox_size`.
+ - 'PCKh': The pck accuracy normalized by `head_size`.
+ - 'tPCK': The pck accuracy normalized by `torso_size`.
+ """
+ logger: MMLogger = MMLogger.get_current_instance()
+
+ # pred_coords: [N, K, D]
+ pred_coords = np.concatenate(
+ [result['pred_coords'] for result in results])
+ # gt_coords: [N, K, D]
+ gt_coords = np.concatenate([result['gt_coords'] for result in results])
+ # mask: [N, K]
+ mask = np.concatenate([result['mask'] for result in results])
+
+ metrics = dict()
+ if 'bbox' in self.norm_item:
+ norm_size_bbox = np.concatenate(
+ [result['bbox_size'] for result in results])
+
+ logger.info(f'Evaluating {self.__class__.__name__} '
+ f'(normalized by ``"bbox_size"``)...')
+
+ _, pck, _ = keypoint_pck_accuracy(pred_coords, gt_coords, mask,
+ self.thr, norm_size_bbox)
+ metrics['PCK'] = pck
+
+ if 'head' in self.norm_item:
+ norm_size_head = np.concatenate(
+ [result['head_size'] for result in results])
+
+ logger.info(f'Evaluating {self.__class__.__name__} '
+ f'(normalized by ``"head_size"``)...')
+
+ _, pckh, _ = keypoint_pck_accuracy(pred_coords, gt_coords, mask,
+ self.thr, norm_size_head)
+ metrics['PCKh'] = pckh
+
+ if 'torso' in self.norm_item:
+ norm_size_torso = np.concatenate(
+ [result['torso_size'] for result in results])
+
+ logger.info(f'Evaluating {self.__class__.__name__} '
+ f'(normalized by ``"torso_size"``)...')
+
+ _, tpck, _ = keypoint_pck_accuracy(pred_coords, gt_coords, mask,
+ self.thr, norm_size_torso)
+ metrics['tPCK'] = tpck
+
+ return metrics
+
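In formula form, the quantity behind ``keypoint_pck_accuracy`` is roughly the following sketch, where v is the visibility mask, s the per-sample normalization size and t the threshold; the reported PCK averages the per-keypoint values:

```latex
% Per-keypoint PCK and the reported average (sketch)
\mathrm{PCK}_k =
  \frac{\sum_{i=1}^{N} v_{ik}\,
        \mathbb{1}\!\left[\frac{\lVert \hat{p}_{ik} - p_{ik} \rVert_2}{s_i} \le t\right]}
       {\sum_{i=1}^{N} v_{ik}},
\qquad
\mathrm{PCK} = \frac{1}{K} \sum_{k=1}^{K} \mathrm{PCK}_k
```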
+
+@METRICS.register_module()
+class MpiiPCKAccuracy(PCKAccuracy):
+ """PCKh accuracy evaluation metric for MPII dataset.
+
+ Calculate the pose accuracy of Percentage of Correct Keypoints (PCK) for
+ each individual keypoint and the averaged accuracy across all keypoints.
+ PCK metric measures accuracy of the localization of the body joints.
+ The distances between predicted positions and the ground-truth ones
+ are typically normalized by the person bounding box size.
+ The threshold (thr) of the normalized distance is commonly set
+ as 0.05, 0.1 or 0.2 etc.
+
+ Note:
+ - length of dataset: N
+ - num_keypoints: K
+ - number of keypoint dimensions: D (typically D = 2)
+
+ Args:
+ thr (float): Threshold of PCK calculation. Default: 0.05.
+ norm_item (str | Sequence[str]): The item used for normalization.
+ Valid items include 'bbox', 'head', 'torso', which correspond
+ to 'PCK', 'PCKh' and 'tPCK' respectively. Default: ``'head'``.
+ collect_device (str): Device name used for collecting results from
+ different ranks during distributed training. Must be ``'cpu'`` or
+ ``'gpu'``. Default: ``'cpu'``.
+ prefix (str, optional): The prefix that will be added in the metric
+ names to disambiguate homonymous metrics of different evaluators.
+ If prefix is not provided in the argument, ``self.default_prefix``
+ will be used instead. Default: ``None``.
+
+ Examples:
+
+ >>> from mmpose.evaluation.metrics import MpiiPCKAccuracy
+ >>> import numpy as np
+ >>> from mmengine.structures import InstanceData
+ >>> num_keypoints = 16
+ >>> keypoints = np.random.random((1, num_keypoints, 2)) * 10
+ >>> gt_instances = InstanceData()
+ >>> gt_instances.keypoints = keypoints + 1.0
+ >>> gt_instances.keypoints_visible = np.ones(
+ ... (1, num_keypoints, 1)).astype(bool)
+ >>> gt_instances.head_size = np.random.random((1, 1)) * 10
+ >>> pred_instances = InstanceData()
+ >>> pred_instances.keypoints = keypoints
+ >>> data_sample = {
+ ... 'gt_instances': gt_instances.to_dict(),
+ ... 'pred_instances': pred_instances.to_dict(),
+ ... }
+ >>> data_samples = [data_sample]
+ >>> data_batch = [{'inputs': None}]
+ >>> mpii_pck_metric = MpiiPCKAccuracy(thr=0.3, norm_item='head')
+ ... UserWarning: The prefix is not set in metric class MpiiPCKAccuracy.
+ >>> mpii_pck_metric.process(data_batch, data_samples)
+ >>> mpii_pck_metric.evaluate(1)
+ 10/26 17:43:39 - mmengine - INFO - Evaluating MpiiPCKAccuracy (normalized by ``"head_size"``)... # noqa
+ {'Head PCK': 100.0, 'Shoulder PCK': 100.0, 'Elbow PCK': 100.0,
+ 'Wrist PCK': 100.0, 'Hip PCK': 100.0, 'Knee PCK': 100.0,
+ 'Ankle PCK': 100.0, 'PCK': 100.0, 'PCK@0.1': 100.0}
+ """
+
+ def __init__(self,
+ thr: float = 0.5,
+ norm_item: Union[str, Sequence[str]] = 'head',
+ collect_device: str = 'cpu',
+ prefix: Optional[str] = None) -> None:
+ super().__init__(
+ thr=thr,
+ norm_item=norm_item,
+ collect_device=collect_device,
+ prefix=prefix)
+
+ def compute_metrics(self, results: list) -> Dict[str, float]:
+ """Compute the metrics from processed results.
+
+ Args:
+ results (list): The processed results of each batch.
+
+ Returns:
+ Dict[str, float]: The computed metrics. The keys are the names of
+ the metrics, and the values are corresponding results.
+ If `'head'` in `self.norm_item`, the returned results are the pck
+ accuracy normalized by `head_size`, which have the following keys:
+ - 'Head PCK': The PCK of head
+ - 'Shoulder PCK': The PCK of shoulder
+ - 'Elbow PCK': The PCK of elbow
+ - 'Wrist PCK': The PCK of wrist
+ - 'Hip PCK': The PCK of hip
+ - 'Knee PCK': The PCK of knee
+ - 'Ankle PCK': The PCK of ankle
+ - 'PCK': The mean PCK over all keypoints
+ - 'PCK@0.1': The mean PCK at threshold 0.1
+ """
+ logger: MMLogger = MMLogger.get_current_instance()
+
+ # pred_coords: [N, K, D]
+ pred_coords = np.concatenate(
+ [result['pred_coords'] for result in results])
+ # gt_coords: [N, K, D]
+ gt_coords = np.concatenate([result['gt_coords'] for result in results])
+ # mask: [N, K]
+ mask = np.concatenate([result['mask'] for result in results])
+
+ # MPII uses matlab format, gt index is 1-based,
+ # convert 0-based index to 1-based index
+ pred_coords = pred_coords + 1.0
+
+ metrics = {}
+ if 'head' in self.norm_item:
+ norm_size_head = np.concatenate(
+ [result['head_size'] for result in results])
+
+ logger.info(f'Evaluating {self.__class__.__name__} '
+ f'(normalized by ``"head_size"``)...')
+
+ pck_p, _, _ = keypoint_pck_accuracy(pred_coords, gt_coords, mask,
+ self.thr, norm_size_head)
+
+ jnt_count = np.sum(mask, axis=0)
+ PCKh = 100. * pck_p
+
+ rng = np.arange(0, 0.5 + 0.01, 0.01)
+ pckAll = np.zeros((len(rng), 16), dtype=np.float32)
+
+ for r, threshold in enumerate(rng):
+ _pck, _, _ = keypoint_pck_accuracy(pred_coords, gt_coords,
+ mask, threshold,
+ norm_size_head)
+ pckAll[r, :] = 100. * _pck
+
+ PCKh = np.ma.array(PCKh, mask=False)
+ PCKh.mask[6:8] = True
+
+ jnt_count = np.ma.array(jnt_count, mask=False)
+ jnt_count.mask[6:8] = True
+ jnt_ratio = jnt_count / np.sum(jnt_count).astype(np.float64)
+
+ # dataset_joints_idx:
+ # head 9
+ # lsho 13 rsho 12
+ # lelb 14 relb 11
+ # lwri 15 rwri 10
+ # lhip 3 rhip 2
+ # lkne 4 rkne 1
+ # lank 5 rank 0
+ stats = {
+ 'Head PCK': PCKh[9],
+ 'Shoulder PCK': 0.5 * (PCKh[13] + PCKh[12]),
+ 'Elbow PCK': 0.5 * (PCKh[14] + PCKh[11]),
+ 'Wrist PCK': 0.5 * (PCKh[15] + PCKh[10]),
+ 'Hip PCK': 0.5 * (PCKh[3] + PCKh[2]),
+ 'Knee PCK': 0.5 * (PCKh[4] + PCKh[1]),
+ 'Ankle PCK': 0.5 * (PCKh[5] + PCKh[0]),
+ 'PCK': np.sum(PCKh * jnt_ratio),
+ 'PCK@0.1': np.sum(pckAll[10, :] * jnt_ratio)
+ }
+
+ for stats_name, stat in stats.items():
+ metrics[stats_name] = stat
+
+ return metrics
+
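The masking and weighting above can be sketched with uniform hypothetical per-joint values (indices 6 and 7 are pelvis and thorax in the MPII joint order):

```python
import numpy as np

# Hypothetical per-joint PCKh values and visible-joint counts.
PCKh = np.ma.array(np.full(16, 90.0), mask=False)
jnt_count = np.ma.array(np.full(16, 100), mask=False)
# pelvis (6) and thorax (7) are excluded from the overall score
PCKh.mask[6:8] = True
jnt_count.mask[6:8] = True
jnt_ratio = jnt_count / np.sum(jnt_count).astype(np.float64)
overall_pck = np.sum(PCKh * jnt_ratio)  # 90.0 for this uniform example
```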
+
+@METRICS.register_module()
+class JhmdbPCKAccuracy(PCKAccuracy):
+ """PCK accuracy evaluation metric for Jhmdb dataset.
+
+ Calculate the pose accuracy of Percentage of Correct Keypoints (PCK) for
+ each individual keypoint and the averaged accuracy across all keypoints.
+ PCK metric measures accuracy of the localization of the body joints.
+ The distances between predicted positions and the ground-truth ones
+ are typically normalized by the person bounding box size.
+ The threshold (thr) of the normalized distance is commonly set
+ as 0.05, 0.1 or 0.2 etc.
+
+ Note:
+ - length of dataset: N
+ - num_keypoints: K
+ - number of keypoint dimensions: D (typically D = 2)
+
+ Args:
+ thr (float): Threshold of PCK calculation. Default: 0.05.
+ norm_item (str | Sequence[str]): The item used for normalization.
+ Valid items include 'bbox', 'head', 'torso', which correspond
+ to 'PCK', 'PCKh' and 'tPCK' respectively. Default: ``'bbox'``.
+ collect_device (str): Device name used for collecting results from
+ different ranks during distributed training. Must be ``'cpu'`` or
+ ``'gpu'``. Default: ``'cpu'``.
+ prefix (str, optional): The prefix that will be added in the metric
+ names to disambiguate homonymous metrics of different evaluators.
+ If prefix is not provided in the argument, ``self.default_prefix``
+ will be used instead. Default: ``None``.
+
+ Examples:
+
+ >>> from mmpose.evaluation.metrics import JhmdbPCKAccuracy
+ >>> import numpy as np
+ >>> from mmengine.structures import InstanceData
+ >>> num_keypoints = 15
+ >>> keypoints = np.random.random((1, num_keypoints, 2)) * 10
+ >>> gt_instances = InstanceData()
+ >>> gt_instances.keypoints = keypoints
+ >>> gt_instances.keypoints_visible = np.ones(
+ ... (1, num_keypoints, 1)).astype(bool)
+ >>> gt_instances.bboxes = np.random.random((1, 4)) * 20
+ >>> gt_instances.head_size = np.random.random((1, 1)) * 10
+ >>> pred_instances = InstanceData()
+ >>> pred_instances.keypoints = keypoints
+ >>> data_sample = {
+ ... 'gt_instances': gt_instances.to_dict(),
+ ... 'pred_instances': pred_instances.to_dict(),
+ ... }
+ >>> data_samples = [data_sample]
+ >>> data_batch = [{'inputs': None}]
+ >>> jhmdb_pck_metric = JhmdbPCKAccuracy(thr=0.2, norm_item=['bbox', 'torso'])
+ ... UserWarning: The prefix is not set in metric class JhmdbPCKAccuracy.
+ >>> jhmdb_pck_metric.process(data_batch, data_samples)
+ >>> jhmdb_pck_metric.evaluate(1)
+ 10/26 17:48:09 - mmengine - INFO - Evaluating JhmdbPCKAccuracy (normalized by ``"bbox_size"``)... # noqa
+ 10/26 17:48:09 - mmengine - INFO - Evaluating JhmdbPCKAccuracy (normalized by ``"torso_size"``)... # noqa
+ {'Head PCK': 1.0, 'Sho PCK': 1.0, 'Elb PCK': 1.0, 'Wri PCK': 1.0,
+ 'Hip PCK': 1.0, 'Knee PCK': 1.0, 'Ank PCK': 1.0, 'PCK': 1.0,
+ 'Head tPCK': 1.0, 'Sho tPCK': 1.0, 'Elb tPCK': 1.0, 'Wri tPCK': 1.0,
+ 'Hip tPCK': 1.0, 'Knee tPCK': 1.0, 'Ank tPCK': 1.0, 'tPCK': 1.0}
+ """
+
+ def __init__(self,
+ thr: float = 0.05,
+ norm_item: Union[str, Sequence[str]] = 'bbox',
+ collect_device: str = 'cpu',
+ prefix: Optional[str] = None) -> None:
+ super().__init__(
+ thr=thr,
+ norm_item=norm_item,
+ collect_device=collect_device,
+ prefix=prefix)
+
+ def compute_metrics(self, results: list) -> Dict[str, float]:
+ """Compute the metrics from processed results.
+
+ Args:
+ results (list): The processed results of each batch.
+
+ Returns:
+ Dict[str, float]: The computed metrics. The keys are the names of
+ the metrics, and the values are corresponding results.
+ If `'bbox'` in `self.norm_item`, the returned results are the pck
+ accuracy normalized by `bbox_size`, which have the following keys:
+ - 'Head PCK': The PCK of head
+ - 'Sho PCK': The PCK of shoulder
+ - 'Elb PCK': The PCK of elbow
+ - 'Wri PCK': The PCK of wrist
+ - 'Hip PCK': The PCK of hip
+ - 'Knee PCK': The PCK of knee
+ - 'Ank PCK': The PCK of ankle
+ - 'PCK': The mean PCK over all keypoints
+ If `'torso'` in `self.norm_item`, the returned results are the pck
+ accuracy normalized by `torso_size`, which have the following keys:
+ - 'Head tPCK': The PCK of head
+ - 'Sho tPCK': The PCK of shoulder
+ - 'Elb tPCK': The PCK of elbow
+ - 'Wri tPCK': The PCK of wrist
+ - 'Hip tPCK': The PCK of hip
+ - 'Knee tPCK': The PCK of knee
+ - 'Ank tPCK': The PCK of ankle
+ - 'tPCK': The mean PCK over all keypoints
+ """
+ logger: MMLogger = MMLogger.get_current_instance()
+
+ # pred_coords: [N, K, D]
+ pred_coords = np.concatenate(
+ [result['pred_coords'] for result in results])
+ # gt_coords: [N, K, D]
+ gt_coords = np.concatenate([result['gt_coords'] for result in results])
+ # mask: [N, K]
+ mask = np.concatenate([result['mask'] for result in results])
+
+ metrics = dict()
+ if 'bbox' in self.norm_item:
+ norm_size_bbox = np.concatenate(
+ [result['bbox_size'] for result in results])
+
+ logger.info(f'Evaluating {self.__class__.__name__} '
+ f'(normalized by ``"bbox_size"``)...')
+
+ pck_p, pck, _ = keypoint_pck_accuracy(pred_coords, gt_coords, mask,
+ self.thr, norm_size_bbox)
+ stats = {
+ 'Head PCK': pck_p[2],
+ 'Sho PCK': 0.5 * pck_p[3] + 0.5 * pck_p[4],
+ 'Elb PCK': 0.5 * pck_p[7] + 0.5 * pck_p[8],
+ 'Wri PCK': 0.5 * pck_p[11] + 0.5 * pck_p[12],
+ 'Hip PCK': 0.5 * pck_p[5] + 0.5 * pck_p[6],
+ 'Knee PCK': 0.5 * pck_p[9] + 0.5 * pck_p[10],
+ 'Ank PCK': 0.5 * pck_p[13] + 0.5 * pck_p[14],
+ 'PCK': pck
+ }
+
+ for stats_name, stat in stats.items():
+ metrics[stats_name] = stat
+
+ if 'torso' in self.norm_item:
+ norm_size_torso = np.concatenate(
+ [result['torso_size'] for result in results])
+
+ logger.info(f'Evaluating {self.__class__.__name__} '
+ f'(normalized by ``"torso_size"``)...')
+
+ pck_p, pck, _ = keypoint_pck_accuracy(pred_coords, gt_coords, mask,
+ self.thr, norm_size_torso)
+
+ stats = {
+ 'Head tPCK': pck_p[2],
+ 'Sho tPCK': 0.5 * pck_p[3] + 0.5 * pck_p[4],
+ 'Elb tPCK': 0.5 * pck_p[7] + 0.5 * pck_p[8],
+ 'Wri tPCK': 0.5 * pck_p[11] + 0.5 * pck_p[12],
+ 'Hip tPCK': 0.5 * pck_p[5] + 0.5 * pck_p[6],
+ 'Knee tPCK': 0.5 * pck_p[9] + 0.5 * pck_p[10],
+ 'Ank tPCK': 0.5 * pck_p[13] + 0.5 * pck_p[14],
+ 'tPCK': pck
+ }
+
+ for stats_name, stat in stats.items():
+ metrics[stats_name] = stat
+
+ return metrics
+
+
+@METRICS.register_module()
+class AUC(BaseMetric):
+ """AUC evaluation metric.
+
+ Calculate the Area Under Curve (AUC) of keypoint PCK accuracy.
+
+ By altering the threshold percentage in the calculation of PCK accuracy,
+ AUC can be generated to further evaluate the pose estimation algorithms.
+
+ Note:
+ - length of dataset: N
+ - num_keypoints: K
+ - number of keypoint dimensions: D (typically D = 2)
+
+ Args:
+ norm_factor (float): AUC normalization factor, Default: 30 (pixels).
+ num_thrs (int): number of thresholds to calculate auc. Default: 20.
+ collect_device (str): Device name used for collecting results from
+ different ranks during distributed training. Must be ``'cpu'`` or
+ ``'gpu'``. Default: ``'cpu'``.
+ prefix (str, optional): The prefix that will be added in the metric
+ names to disambiguate homonymous metrics of different evaluators.
+ If prefix is not provided in the argument, ``self.default_prefix``
+ will be used instead. Default: ``None``.
+ """
+
+ def __init__(self,
+ norm_factor: float = 30,
+ num_thrs: int = 20,
+ collect_device: str = 'cpu',
+ prefix: Optional[str] = None) -> None:
+ super().__init__(collect_device=collect_device, prefix=prefix)
+ self.norm_factor = norm_factor
+ self.num_thrs = num_thrs
+
+ def process(self, data_batch: Sequence[dict],
+ data_samples: Sequence[dict]) -> None:
+ """Process one batch of data samples and predictions. The processed
+ results should be stored in ``self.results``, which will be used to
+ compute the metrics when all batches have been processed.
+
+ Args:
+ data_batch (Sequence[dict]): A batch of data
+ from the dataloader.
+ data_samples (Sequence[dict]): A batch of outputs from
+ the model.
+ """
+ for data_sample in data_samples:
+ # predicted keypoints coordinates, [1, K, D]
+ pred_coords = data_sample['pred_instances']['keypoints']
+ # ground truth data_info
+ gt = data_sample['gt_instances']
+ # ground truth keypoints coordinates, [1, K, D]
+ gt_coords = gt['keypoints']
+ # ground truth keypoints_visible, [1, K, 1]
+ mask = gt['keypoints_visible'].astype(bool).reshape(1, -1)
+
+ result = {
+ 'pred_coords': pred_coords,
+ 'gt_coords': gt_coords,
+ 'mask': mask,
+ }
+
+ self.results.append(result)
+
+ def compute_metrics(self, results: list) -> Dict[str, float]:
+ """Compute the metrics from processed results.
+
+ Args:
+ results (list): The processed results of each batch.
+
+ Returns:
+ Dict[str, float]: The computed metrics. The keys are the names of
+ the metrics, and the values are corresponding results.
+ """
+ logger: MMLogger = MMLogger.get_current_instance()
+
+ # pred_coords: [N, K, D]
+ pred_coords = np.concatenate(
+ [result['pred_coords'] for result in results])
+ # gt_coords: [N, K, D]
+ gt_coords = np.concatenate([result['gt_coords'] for result in results])
+ # mask: [N, K]
+ mask = np.concatenate([result['mask'] for result in results])
+
+ logger.info(f'Evaluating {self.__class__.__name__}...')
+
+ auc = keypoint_auc(pred_coords, gt_coords, mask, self.norm_factor,
+ self.num_thrs)
+
+ metrics = dict()
+ metrics['AUC'] = auc
+
+ return metrics
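Conceptually, the AUC computed here sweeps the PCK distance threshold and averages the resulting accuracies. A self-contained sketch of that idea with random data (not the exact ``keypoint_auc`` internals):

```python
import numpy as np

rng = np.random.default_rng(0)
pred = rng.random((4, 15, 2)) * 10                  # [N, K, 2] predictions
gt = pred + rng.normal(scale=0.5, size=pred.shape)  # made-up ground truth
mask = np.ones((4, 15), dtype=bool)                 # all keypoints visible
norm_factor = 30.0

def pck_at(thr):
    dist = np.linalg.norm(pred - gt, axis=-1) / norm_factor
    return (dist[mask] <= thr).mean()

auc = np.mean([pck_at(t) for t in np.linspace(0, 1, 20)])
```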
+
+
+@METRICS.register_module()
+class EPE(BaseMetric):
+ """EPE evaluation metric.
+
+ Calculate the end-point error (EPE) of keypoints.
+
+ Note:
+ - length of dataset: N
+ - num_keypoints: K
+ - number of keypoint dimensions: D (typically D = 2)
+
+ Args:
+ collect_device (str): Device name used for collecting results from
+ different ranks during distributed training. Must be ``'cpu'`` or
+ ``'gpu'``. Default: ``'cpu'``.
+ prefix (str, optional): The prefix that will be added in the metric
+ names to disambiguate homonymous metrics of different evaluators.
+ If prefix is not provided in the argument, ``self.default_prefix``
+ will be used instead. Default: ``None``.
+ """
+
+ def process(self, data_batch: Sequence[dict],
+ data_samples: Sequence[dict]) -> None:
+ """Process one batch of data samples and predictions. The processed
+ results should be stored in ``self.results``, which will be used to
+ compute the metrics when all batches have been processed.
+
+ Args:
+ data_batch (Sequence[dict]): A batch of data
+ from the dataloader.
+ data_samples (Sequence[dict]): A batch of outputs from
+ the model.
+ """
+ for data_sample in data_samples:
+ # predicted keypoints coordinates, [1, K, D]
+ pred_coords = data_sample['pred_instances']['keypoints']
+ # ground truth data_info
+ gt = data_sample['gt_instances']
+ # ground truth keypoints coordinates, [1, K, D]
+ gt_coords = gt['keypoints']
+ # ground truth keypoints_visible, [1, K, 1]
+ mask = gt['keypoints_visible'].astype(bool).reshape(1, -1)
+
+ result = {
+ 'pred_coords': pred_coords,
+ 'gt_coords': gt_coords,
+ 'mask': mask,
+ }
+
+ self.results.append(result)
+
+ def compute_metrics(self, results: list) -> Dict[str, float]:
+ """Compute the metrics from processed results.
+
+ Args:
+ results (list): The processed results of each batch.
+
+ Returns:
+ Dict[str, float]: The computed metrics. The keys are the names of
+ the metrics, and the values are corresponding results.
+ """
+ logger: MMLogger = MMLogger.get_current_instance()
+
+ # pred_coords: [N, K, D]
+ pred_coords = np.concatenate(
+ [result['pred_coords'] for result in results])
+ # gt_coords: [N, K, D]
+ gt_coords = np.concatenate([result['gt_coords'] for result in results])
+ # mask: [N, K]
+ mask = np.concatenate([result['mask'] for result in results])
+
+ logger.info(f'Evaluating {self.__class__.__name__}...')
+
+ epe = keypoint_epe(pred_coords, gt_coords, mask)
+
+ metrics = dict()
+ metrics['EPE'] = epe
+
+ return metrics
+
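For reference, the end-point error reported here is simply the masked mean Euclidean distance between predicted and ground-truth keypoints; a minimal sketch assuming that reading of ``keypoint_epe``:

```python
import numpy as np

def epe_sketch(pred: np.ndarray, gt: np.ndarray, mask: np.ndarray) -> float:
    """Mean Euclidean distance over visible keypoints.

    pred, gt: [N, K, D] coordinates; mask: [N, K] boolean visibility.
    """
    dist = np.linalg.norm(pred - gt, axis=-1)  # [N, K]
    return float(dist[mask].mean())
```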
+
+@METRICS.register_module()
+class NME(BaseMetric):
+ """NME evaluation metric.
+
+ Calculate the normalized mean error (NME) of keypoints.
+
+ Note:
+ - length of dataset: N
+ - num_keypoints: K
+ - number of keypoint dimensions: D (typically D = 2)
+
+ Args:
+ norm_mode (str): The normalization mode. There are two valid modes:
+ `'use_norm_item'` and `'keypoint_distance'`.
+ When set as `'use_norm_item'`, should specify the argument
+ `norm_item`, which represents the item in the datainfo that
+ will be used as the normalization factor.
+ When set as `'keypoint_distance'`, should specify the argument
+ `keypoint_indices` that are used to calculate the keypoint
+ distance as the normalization factor.
+ norm_item (str, optional): The item used as the normalization factor.
+ For example, `'bbox_size'` in `'AFLWDataset'`. Only valid when
+ ``norm_mode`` is ``use_norm_item``.
+ Default: ``None``.
+ keypoint_indices (Sequence[int], optional): The keypoint indices used
+ to calculate the keypoint distance as the normalization factor.
+ Only valid when ``norm_mode`` is ``keypoint_distance``.
+ If set as None, will use the default ``keypoint_indices`` in
+ `DEFAULT_KEYPOINT_INDICES` for specific datasets, else use the
+ given ``keypoint_indices`` of the dataset. Default: ``None``.
+ collect_device (str): Device name used for collecting results from
+ different ranks during distributed training. Must be ``'cpu'`` or
+ ``'gpu'``. Default: ``'cpu'``.
+ prefix (str, optional): The prefix that will be added in the metric
+ names to disambiguate homonymous metrics of different evaluators.
+ If prefix is not provided in the argument, ``self.default_prefix``
+ will be used instead. Default: ``None``.
+ """
+
+ DEFAULT_KEYPOINT_INDICES = {
+ # horse10: corresponding to `nose` and `eye` keypoints
+ 'horse10': [0, 1],
+ # 300w: corresponding to `right-most` and `left-most` eye keypoints
+ '300w': [36, 45],
+ # coco_wholebody_face corresponding to `right-most` and `left-most`
+ # eye keypoints
+ 'coco_wholebody_face': [36, 45],
+ # cofw: corresponding to `right-most` and `left-most` eye keypoints
+ 'cofw': [8, 9],
+ # wflw: corresponding to `right-most` and `left-most` eye keypoints
+ 'wflw': [60, 72],
+ }
+
+ def __init__(self,
+ norm_mode: str,
+ norm_item: Optional[str] = None,
+ keypoint_indices: Optional[Sequence[int]] = None,
+ collect_device: str = 'cpu',
+ prefix: Optional[str] = None) -> None:
+ super().__init__(collect_device=collect_device, prefix=prefix)
+ allowed_norm_modes = ['use_norm_item', 'keypoint_distance']
+ if norm_mode not in allowed_norm_modes:
+ raise KeyError("`norm_mode` should be 'use_norm_item' or "
+ f"'keypoint_distance', but got {norm_mode}.")
+
+ self.norm_mode = norm_mode
+ if self.norm_mode == 'use_norm_item':
+ if not norm_item:
+ raise KeyError('`norm_mode` is set to `"use_norm_item"`, '
+ 'please specify the `norm_item` in the '
+ 'datainfo used as the normalization factor.')
+ self.norm_item = norm_item
+ self.keypoint_indices = keypoint_indices
+
+ def process(self, data_batch: Sequence[dict],
+ data_samples: Sequence[dict]) -> None:
+ """Process one batch of data samples and predictions. The processed
+ results should be stored in ``self.results``, which will be used to
+ compute the metrics when all batches have been processed.
+
+ Args:
+ data_batch (Sequence[dict]): A batch of data
+ from the dataloader.
+ data_samples (Sequence[dict]): A batch of outputs from
+ the model.
+ """
+ for data_sample in data_samples:
+ # predicted keypoints coordinates, [1, K, D]
+ pred_coords = data_sample['pred_instances']['keypoints']
+ # ground truth data_info
+ gt = data_sample['gt_instances']
+ # ground truth keypoints coordinates, [1, K, D]
+ gt_coords = gt['keypoints']
+ # ground truth keypoints_visible, [1, K, 1]
+ mask = gt['keypoints_visible'].astype(bool).reshape(1, -1)
+
+ result = {
+ 'pred_coords': pred_coords,
+ 'gt_coords': gt_coords,
+ 'mask': mask,
+ }
+
+ if self.norm_item:
+ if self.norm_item == 'bbox_size':
+ assert 'bboxes' in gt, 'The ground truth data info do ' \
+ 'not have the item ``bboxes`` for expected ' \
+ 'normalized_item ``"bbox_size"``.'
+ # ground truth bboxes, [1, 4]
+ bbox_size = np.max(gt['bboxes'][0][2:] -
+ gt['bboxes'][0][:2])
+ result['bbox_size'] = np.array([bbox_size]).reshape(-1, 1)
+ else:
+ assert self.norm_item in gt, f'The ground truth data ' \
+ f'info do not have the expected normalized factor ' \
+ f'"{self.norm_item}"'
+ # ground truth norm_item
+ result[self.norm_item] = np.array(
+ gt[self.norm_item]).reshape([-1, 1])
+
+ self.results.append(result)
+
+ def compute_metrics(self, results: list) -> Dict[str, float]:
+ """Compute the metrics from processed results.
+
+ Args:
+ results (list): The processed results of each batch.
+
+ Returns:
+ Dict[str, float]: The computed metrics. The keys are the names of
+ the metrics, and the values are corresponding results.
+ """
+ logger: MMLogger = MMLogger.get_current_instance()
+
+ # pred_coords: [N, K, D]
+ pred_coords = np.concatenate(
+ [result['pred_coords'] for result in results])
+ # gt_coords: [N, K, D]
+ gt_coords = np.concatenate([result['gt_coords'] for result in results])
+ # mask: [N, K]
+ mask = np.concatenate([result['mask'] for result in results])
+
+ logger.info(f'Evaluating {self.__class__.__name__}...')
+ metrics = dict()
+
+ if self.norm_mode == 'use_norm_item':
+ normalize_factor_ = np.concatenate(
+ [result[self.norm_item] for result in results])
+ # normalize_factor: [N, 2]
+ normalize_factor = np.tile(normalize_factor_, [1, 2])
+ nme = keypoint_nme(pred_coords, gt_coords, mask, normalize_factor)
+ metrics['NME'] = nme
+
+ else:
+ if self.keypoint_indices is None:
+ # use default keypoint_indices in some datasets
+ dataset_name = self.dataset_meta['dataset_name']
+ if dataset_name not in self.DEFAULT_KEYPOINT_INDICES:
+ raise KeyError(
+ '`norm_mode` is set to `keypoint_distance`, and the '
+ 'keypoint_indices is set to None, can not find the '
+ 'keypoint_indices in `DEFAULT_KEYPOINT_INDICES`, '
+ 'please specify `keypoint_indices` appropriately.')
+ self.keypoint_indices = self.DEFAULT_KEYPOINT_INDICES[
+ dataset_name]
+ else:
+ assert len(self.keypoint_indices) == 2, 'The keypoint '\
+ 'indices used for normalization should be a pair.'
+ keypoint_id2name = self.dataset_meta['keypoint_id2name']
+ dataset_name = self.dataset_meta['dataset_name']
+ for idx in self.keypoint_indices:
+ assert idx in keypoint_id2name, f'The {dataset_name} '\
+ f'dataset does not contain the required '\
+ f'{idx}-th keypoint.'
+ # normalize_factor: [N, 2]
+ normalize_factor = self._get_normalize_factor(gt_coords=gt_coords)
+ nme = keypoint_nme(pred_coords, gt_coords, mask, normalize_factor)
+ metrics['NME'] = nme
+
+ return metrics
+
+ def _get_normalize_factor(self, gt_coords: np.ndarray) -> np.ndarray:
+ """Get the normalize factor. generally inter-ocular distance measured
+ as the Euclidean distance between the outer corners of the eyes is
+ used.
+
+ Args:
+ gt_coords (np.ndarray[N, K, 2]): Groundtruth keypoint coordinates.
+
+ Returns:
+ np.ndarray[N, 2]: normalized factor
+ """
+ idx1, idx2 = self.keypoint_indices
+
+ interocular = np.linalg.norm(
+ gt_coords[:, idx1, :] - gt_coords[:, idx2, :],
+ axis=1,
+ keepdims=True)
+
+ return np.tile(interocular, [1, 2])
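Illustrative NME configurations for the two normalization modes (the dataset-specific values are examples only):

```python
# norm_mode='use_norm_item': normalize by an item stored in the data info,
# e.g. 'bbox_size'.
nme_by_item = dict(type='NME', norm_mode='use_norm_item', norm_item='bbox_size')

# norm_mode='keypoint_distance': normalize by the distance between two
# keypoints, e.g. the outer eye corners in the 300W layout.
nme_by_keypoint_distance = dict(
    type='NME',
    norm_mode='keypoint_distance',
    keypoint_indices=[36, 45],
)
```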
diff --git a/mmpose/evaluation/metrics/keypoint_partition_metric.py b/mmpose/evaluation/metrics/keypoint_partition_metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb30eca0d57f68e94cba93deec1f63bd333468aa
--- /dev/null
+++ b/mmpose/evaluation/metrics/keypoint_partition_metric.py
@@ -0,0 +1,203 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+from collections import OrderedDict
+from copy import deepcopy
+from typing import Sequence
+
+import numpy as np
+from mmengine.evaluator import BaseMetric
+
+from mmpose.registry import METRICS
+
+
+@METRICS.register_module()
+class KeypointPartitionMetric(BaseMetric):
+ """Wrapper metric for evaluating pose metric on user-defined body parts.
+
+ Sometimes one may be interested in the performance of a pose model on
+ certain body parts rather than on all the keypoints. For example,
+ ``CocoWholeBodyMetric`` evaluates coco metric on body, foot, face,
+ lefthand and righthand. However, ``CocoWholeBodyMetric`` cannot be
+ applied to arbitrary custom datasets. This wrapper metric solves this
+ problem.
+
+ Supported metrics:
+ ``CocoMetric`` Note 1: all keypoint ground truth should be stored in
+ `keypoints` not other data fields. Note 2: `ann_file` is not
+ supported, it will be ignored. Note 3: `score_mode` other than
+ 'bbox' may produce results different from the
+            ``CocoWholeBodyMetric``. Note 4: `nms_mode` other than 'none' may
+            produce results different from the ``CocoWholeBodyMetric``.
+ ``PCKAccuracy`` Note 1: data fields required by ``PCKAccuracy`` should
+            be provided, such as bbox, head_size, etc. Note 2: the 'torso'
+            normalization is specifically designed for ``JhmdbDataset`` and
+            is not recommended for other datasets.
+ ``AUC`` supported without limitations.
+ ``EPE`` supported without limitations.
+ ``NME`` only `norm_mode` = 'use_norm_item' is supported,
+ 'keypoint_distance' is incompatible with ``KeypointPartitionMetric``.
+
+ Incompatible metrics:
+ The following metrics are dataset specific metrics:
+ ``CocoWholeBodyMetric``
+ ``MpiiPCKAccuracy``
+ ``JhmdbPCKAccuracy``
+ ``PoseTrack18Metric``
+ Keypoint partitioning is included in these metrics.
+
+ Args:
+ metric (dict): arguments to instantiate a metric, please refer to the
+ arguments required by the metric of your choice.
+ partitions (dict): definition of body partitions. For example, if we
+ have 10 keypoints in total, the first 7 keypoints belong to body
+ and the last 3 keypoints belong to foot, this field can be like
+ this:
+ dict(
+ body=[0, 1, 2, 3, 4, 5, 6],
+ foot=[7, 8, 9],
+ all=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
+ )
+ where the numbers are the indices of keypoints and they can be
+ discontinuous.
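+
+    Example:
+        An illustrative config sketch (the wrapped metric type and the
+        partition indices below are assumptions for demonstration only):
+
+        >>> val_evaluator = dict(
+        ...     type='KeypointPartitionMetric',
+        ...     metric=dict(type='EPE'),
+        ...     partitions=dict(
+        ...         body=[0, 1, 2, 3, 4, 5, 6],
+        ...         foot=[7, 8, 9],
+        ...         all=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]))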
+ """
+
+ def __init__(
+ self,
+ metric: dict,
+ partitions: dict,
+ ) -> None:
+ super().__init__()
+ # check metric type
+ supported_metric_types = [
+ 'CocoMetric', 'PCKAccuracy', 'AUC', 'EPE', 'NME'
+ ]
+ if metric['type'] not in supported_metric_types:
+ raise ValueError(
+ 'Metrics supported by KeypointPartitionMetric are CocoMetric, '
+ 'PCKAccuracy, AUC, EPE and NME, '
+ f"but got {metric['type']}")
+
+ # check CocoMetric arguments
+ if metric['type'] == 'CocoMetric':
+ if 'ann_file' in metric:
+ warnings.warn(
+ 'KeypointPartitionMetric does not support the ann_file '
+ 'argument of CocoMetric, this argument will be ignored.')
+ metric['ann_file'] = None
+ score_mode = metric.get('score_mode', 'bbox_keypoint')
+ if score_mode != 'bbox':
+ warnings.warn(
+ 'When using KeypointPartitionMetric with CocoMetric, '
+ "if score_mode is not 'bbox', pose scores will be "
+ "calculated part by part rather than by 'wholebody'. "
+ 'Therefore, this may produce results different from the '
+                    'CocoWholeBodyMetric.')
+ nms_mode = metric.get('nms_mode', 'oks_nms')
+ if nms_mode != 'none':
+ warnings.warn(
+ 'When using KeypointPartitionMetric with CocoMetric, '
+ 'oks_nms and soft_oks_nms will be calculated part by part '
+ "rather than by 'wholebody'. Therefore, this may produce "
+                    'results different from the CocoWholeBodyMetric.')
+
+ # check PCKAccuracy arguments
+ if metric['type'] == 'PCKAccuracy':
+ norm_item = metric.get('norm_item', 'bbox')
+ if norm_item == 'torso' or 'torso' in norm_item:
+ warnings.warn(
+                    "The norm_item 'torso' is designed for JhmdbDataset; "
+                    'it may not be compatible with other datasets, so use '
+                    'it at your own risk.')
+
+ # check NME arguments
+ if metric['type'] == 'NME':
+ assert 'norm_mode' in metric, \
+ 'Missing norm_mode required by the NME metric.'
+ if metric['norm_mode'] != 'use_norm_item':
+ raise ValueError(
+ "NME norm_mode 'keypoint_distance' is incompatible with "
+ 'KeypointPartitionMetric.')
+
+ # check partitions
+ assert len(partitions) > 0, 'There should be at least one partition.'
+ for partition_name, partition in partitions.items():
+ assert isinstance(partition, Sequence), \
+ 'Each partition should be a sequence.'
+ assert len(partition) > 0, \
+ 'Each partition should have at least one element.'
+ self.partitions = partitions
+
+ # instantiate metrics for each partition
+ self.metrics = {}
+ for partition_name in partitions.keys():
+ _metric = deepcopy(metric)
+ if 'outfile_prefix' in _metric:
+ _metric['outfile_prefix'] = _metric[
+ 'outfile_prefix'] + '.' + partition_name
+ self.metrics[partition_name] = METRICS.build(_metric)
+
+ @BaseMetric.dataset_meta.setter
+ def dataset_meta(self, dataset_meta: dict) -> None:
+ """Set the dataset meta info to the metric."""
+ self._dataset_meta = dataset_meta
+ # sigmas required by coco metric have to be split as well
+ for partition_name, keypoint_ids in self.partitions.items():
+ _dataset_meta = deepcopy(dataset_meta)
+ _dataset_meta['num_keypoints'] = len(keypoint_ids)
+ _dataset_meta['sigmas'] = _dataset_meta['sigmas'][keypoint_ids]
+ self.metrics[partition_name].dataset_meta = _dataset_meta
+
+ def process(self, data_batch: Sequence[dict],
+ data_samples: Sequence[dict]) -> None:
+ """Split data samples by partitions, then call metric.process part by
+ part."""
+ parted_data_samples = {
+ partition_name: []
+ for partition_name in self.partitions.keys()
+ }
+ for data_sample in data_samples:
+ for partition_name, keypoint_ids in self.partitions.items():
+ _data_sample = deepcopy(data_sample)
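+                # slice predictions and ground truth along the keypoint axis
+                # so that each wrapped metric only sees its own partition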
+ if 'keypoint_scores' in _data_sample['pred_instances']:
+ _data_sample['pred_instances'][
+ 'keypoint_scores'] = _data_sample['pred_instances'][
+ 'keypoint_scores'][:, keypoint_ids]
+ _data_sample['pred_instances']['keypoints'] = _data_sample[
+ 'pred_instances']['keypoints'][:, keypoint_ids]
+ _data_sample['gt_instances']['keypoints'] = _data_sample[
+ 'gt_instances']['keypoints'][:, keypoint_ids]
+ _data_sample['gt_instances'][
+ 'keypoints_visible'] = _data_sample['gt_instances'][
+ 'keypoints_visible'][:, keypoint_ids]
+
+ # for coco metric
+ if 'raw_ann_info' in _data_sample:
+ raw_ann_info = _data_sample['raw_ann_info']
+ anns = raw_ann_info if isinstance(
+ raw_ann_info, list) else [raw_ann_info]
+ for ann in anns:
+ if 'keypoints' in ann:
+ keypoints = np.array(ann['keypoints']).reshape(
+ -1, 3)
+ keypoints = keypoints[keypoint_ids]
+ num_keypoints = np.sum(keypoints[:, 2] > 0)
+ ann['keypoints'] = keypoints.flatten().tolist()
+ ann['num_keypoints'] = num_keypoints
+
+ parted_data_samples[partition_name].append(_data_sample)
+
+ for partition_name, metric in self.metrics.items():
+ metric.process(data_batch, parted_data_samples[partition_name])
+
+ def compute_metrics(self, results: list) -> dict:
+ pass
+
+ def evaluate(self, size: int) -> dict:
+ """Run evaluation for each partition."""
+ eval_results = OrderedDict()
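+        # prefix each sub-metric's result keys with its partition name,
+        # e.g. 'EPE' -> 'body/EPE'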
+ for partition_name, metric in self.metrics.items():
+ _eval_results = metric.evaluate(size)
+ for key in list(_eval_results.keys()):
+ new_key = partition_name + '/' + key
+ _eval_results[new_key] = _eval_results.pop(key)
+ eval_results.update(_eval_results)
+ return eval_results
diff --git a/mmpose/evaluation/metrics/posetrack18_metric.py b/mmpose/evaluation/metrics/posetrack18_metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..86f801455a62467aaf45722210a6018c95b0bdd4
--- /dev/null
+++ b/mmpose/evaluation/metrics/posetrack18_metric.py
@@ -0,0 +1,220 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+import os.path as osp
+from typing import Dict, List, Optional
+
+import numpy as np
+from mmengine.fileio import dump, load
+from mmengine.logging import MMLogger
+
+from mmpose.registry import METRICS
+from .coco_metric import CocoMetric
+
+try:
+ from poseval import eval_helpers
+ from poseval.evaluateAP import evaluateAP
+ has_poseval = True
+except (ImportError, ModuleNotFoundError):
+ has_poseval = False
+
+
+@METRICS.register_module()
+class PoseTrack18Metric(CocoMetric):
+ """PoseTrack18 evaluation metric.
+
+    Evaluate AP and mAP for keypoint detection tasks.
+ Support PoseTrack18 (video) dataset. Please refer to
+ ``__
+ for more details.
+
+ Args:
+ ann_file (str, optional): Path to the coco format annotation file.
+ If not specified, ground truth annotations from the dataset will
+ be converted to coco format. Defaults to None
+ score_mode (str): The mode to score the prediction results which
+ should be one of the following options:
+
+ - ``'bbox'``: Take the score of bbox as the score of the
+ prediction results.
+ - ``'bbox_keypoint'``: Use keypoint score to rescore the
+ prediction results.
+
+            Defaults to ``'bbox_keypoint'``
+ keypoint_score_thr (float): The threshold of keypoint score. The
+ keypoints with score lower than it will not be included to
+ rescore the prediction results. Valid only when ``score_mode`` is
+ ``bbox_keypoint``. Defaults to ``0.2``
+ nms_mode (str): The mode to perform Non-Maximum Suppression (NMS),
+ which should be one of the following options:
+
+ - ``'oks_nms'``: Use Object Keypoint Similarity (OKS) to
+ perform NMS.
+ - ``'soft_oks_nms'``: Use Object Keypoint Similarity (OKS)
+ to perform soft NMS.
+ - ``'none'``: Do not perform NMS. Typically for bottomup mode
+ output.
+
+            Defaults to ``'oks_nms'``
+ nms_thr (float): The Object Keypoint Similarity (OKS) threshold
+ used in NMS when ``nms_mode`` is ``'oks_nms'`` or
+ ``'soft_oks_nms'``. Will retain the prediction results with OKS
+ lower than ``nms_thr``. Defaults to ``0.9``
+ format_only (bool): Whether only format the output results without
+ doing quantitative evaluation. This is designed for the need of
+ test submission when the ground truth annotations are absent. If
+ set to ``True``, ``outfile_prefix`` should specify the path to
+ store the output results. Defaults to ``False``
+ outfile_prefix (str | None): The prefix of json files. It includes
+ the file path and the prefix of filename, e.g., ``'a/b/prefix'``.
+ If not specified, a temp file will be created. Defaults to ``None``
+ **kwargs: Keyword parameters passed to :class:`mmeval.BaseMetric`
+ """
+ default_prefix: Optional[str] = 'posetrack18'
+
+ def __init__(self,
+ ann_file: Optional[str] = None,
+ score_mode: str = 'bbox_keypoint',
+ keypoint_score_thr: float = 0.2,
+ nms_mode: str = 'oks_nms',
+ nms_thr: float = 0.9,
+ format_only: bool = False,
+ outfile_prefix: Optional[str] = None,
+ collect_device: str = 'cpu',
+ prefix: Optional[str] = None) -> None:
+ # raise an error to avoid long time running without getting results
+ if not has_poseval:
+ raise ImportError('Please install ``poseval`` package for '
+ 'evaluation on PoseTrack dataset '
+ '(see `requirements/optional.txt`)')
+ super().__init__(
+ ann_file=ann_file,
+ score_mode=score_mode,
+ keypoint_score_thr=keypoint_score_thr,
+ nms_mode=nms_mode,
+ nms_thr=nms_thr,
+ format_only=format_only,
+ outfile_prefix=outfile_prefix,
+ collect_device=collect_device,
+ prefix=prefix)
+
+ def results2json(self, keypoints: Dict[int, list],
+ outfile_prefix: str) -> str:
+ """Dump the keypoint detection results into a json file.
+
+ Args:
+ keypoints (Dict[int, list]): Keypoint detection results
+ of the dataset.
+ outfile_prefix (str): The filename prefix of the json files.
+                If the prefix is "somepath/xxx", one json file per video
+                sequence will be written to the directory "somepath".
+
+ Returns:
+ str: The json file name of keypoint results.
+ """
+ categories = []
+
+ cat = {}
+ cat['supercategory'] = 'person'
+ cat['id'] = 1
+ cat['name'] = 'person'
+ cat['keypoints'] = [
+ 'nose', 'head_bottom', 'head_top', 'left_ear', 'right_ear',
+ 'left_shoulder', 'right_shoulder', 'left_elbow', 'right_elbow',
+ 'left_wrist', 'right_wrist', 'left_hip', 'right_hip', 'left_knee',
+ 'right_knee', 'left_ankle', 'right_ankle'
+ ]
+ cat['skeleton'] = [[16, 14], [14, 12], [17, 15], [15, 13], [12, 13],
+ [6, 12], [7, 13], [6, 7], [6, 8], [7, 9], [8, 10],
+ [9, 11], [2, 3], [1, 2], [1, 3], [2, 4], [3, 5],
+ [4, 6], [5, 7]]
+ categories.append(cat)
+
+ # path of directory for official gt files
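+        # e.g. 'xxx/posetrack18_train.json' -> 'xxx/train/'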
+ gt_folder = osp.join(
+ osp.dirname(self.ann_file),
+ osp.splitext(self.ann_file.split('_')[-1])[0])
+ # the json file for each video sequence
+ json_files = [
+ pos for pos in os.listdir(gt_folder) if pos.endswith('.json')
+ ]
+
+ for json_file in json_files:
+ gt = load(osp.join(gt_folder, json_file))
+ annotations = []
+ images = []
+
+ for image in gt['images']:
+ img = {}
+ img['id'] = image['id']
+ img['file_name'] = image['file_name']
+ images.append(img)
+
+ img_kpts = keypoints[img['id']]
+
+ for track_id, img_kpt in enumerate(img_kpts):
+ ann = {}
+ ann['image_id'] = img_kpt['img_id']
+ ann['keypoints'] = np.array(
+ img_kpt['keypoints']).reshape(-1).tolist()
+ ann['scores'] = np.array(ann['keypoints']).reshape(
+ [-1, 3])[:, 2].tolist()
+ ann['score'] = float(img_kpt['score'])
+ ann['track_id'] = track_id
+ annotations.append(ann)
+
+ pred_file = osp.join(osp.dirname(outfile_prefix), json_file)
+ info = {}
+ info['images'] = images
+ info['categories'] = categories
+ info['annotations'] = annotations
+
+ dump(info, pred_file, sort_keys=True, indent=4)
+
+ def _do_python_keypoint_eval(self, outfile_prefix: str) -> List[tuple]:
+ """Do keypoint evaluation using `poseval` package.
+
+ Args:
+ outfile_prefix (str): The filename prefix of the json files.
+ If the prefix is "somepath/xxx", the json files will be named
+ "somepath/xxx.keypoints.json".
+
+ Returns:
+ list: a list of tuples. Each tuple contains the evaluation stats
+ name and corresponding stats value.
+ """
+ logger: MMLogger = MMLogger.get_current_instance()
+
+ # path of directory for official gt files
+ # 'xxx/posetrack18_train.json' -> 'xxx/train/'
+ gt_folder = osp.join(
+ osp.dirname(self.ann_file),
+ osp.splitext(self.ann_file.split('_')[-1])[0])
+ pred_folder = osp.dirname(outfile_prefix)
+
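+        # argv mimics poseval's CLI: [prog, gt_dir, pred_dir]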
+ argv = ['', gt_folder + '/', pred_folder + '/']
+
+ logger.info('Loading data')
+ gtFramesAll, prFramesAll = eval_helpers.load_data_dir(argv)
+
+ logger.info(f'# gt frames : {len(gtFramesAll)}')
+ logger.info(f'# pred frames: {len(prFramesAll)}')
+
+ # evaluate per-frame multi-person pose estimation (AP)
+ # compute AP
+ logger.info('Evaluation of per-frame multi-person pose estimation')
+ apAll, _, _ = evaluateAP(gtFramesAll, prFramesAll, None, False, False)
+
+ # print AP
+ logger.info('Average Precision (AP) metric:')
+ eval_helpers.printTable(apAll)
+
+ stats = eval_helpers.getCum(apAll)
+
+ stats_names = [
+ 'Head AP', 'Shou AP', 'Elb AP', 'Wri AP', 'Hip AP', 'Knee AP',
+ 'Ankl AP', 'AP'
+ ]
+
+ info_str = list(zip(stats_names, stats))
+
+ return info_str
diff --git a/mmpose/models/__init__.py b/mmpose/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e236f9928601328c6cd42817d842b034c4b9b13
--- /dev/null
+++ b/mmpose/models/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .backbones import * # noqa
+from .builder import (BACKBONES, HEADS, LOSSES, NECKS, build_backbone,
+ build_head, build_loss, build_neck, build_pose_estimator,
+ build_posenet)
+from .data_preprocessors import * # noqa
+from .heads import * # noqa
+from .losses import * # noqa
+from .necks import * # noqa
+from .pose_estimators import * # noqa
+
+__all__ = [
+ 'BACKBONES', 'HEADS', 'NECKS', 'LOSSES', 'build_backbone', 'build_head',
+ 'build_loss', 'build_posenet', 'build_neck', 'build_pose_estimator'
+]
diff --git a/mmpose/models/__pycache__/__init__.cpython-38.pyc b/mmpose/models/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dbd830f81f73c0252b056888761e237100e6b4b9
Binary files /dev/null and b/mmpose/models/__pycache__/__init__.cpython-38.pyc differ
diff --git a/mmpose/models/__pycache__/builder.cpython-38.pyc b/mmpose/models/__pycache__/builder.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bb117b14e01105b2cfa30261f6a07bbddf37255d
Binary files /dev/null and b/mmpose/models/__pycache__/builder.cpython-38.pyc differ
diff --git a/mmpose/models/backbones/__init__.py b/mmpose/models/backbones/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb2498560a34afca0f3218eb9fb3d9a5dce04f33
--- /dev/null
+++ b/mmpose/models/backbones/__init__.py
@@ -0,0 +1,37 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .alexnet import AlexNet
+from .cpm import CPM
+from .hourglass import HourglassNet
+from .hourglass_ae import HourglassAENet
+from .hrformer import HRFormer
+from .hrnet import HRNet
+from .litehrnet import LiteHRNet
+from .mobilenet_v2 import MobileNetV2
+from .mobilenet_v3 import MobileNetV3
+from .mspn import MSPN
+from .pvt import PyramidVisionTransformer, PyramidVisionTransformerV2
+from .regnet import RegNet
+from .resnest import ResNeSt
+from .resnet import ResNet, ResNetV1d
+from .resnext import ResNeXt
+from .rsn import RSN
+from .scnet import SCNet
+from .seresnet import SEResNet
+from .seresnext import SEResNeXt
+from .shufflenet_v1 import ShuffleNetV1
+from .shufflenet_v2 import ShuffleNetV2
+from .swin import SwinTransformer
+from .tcn import TCN
+from .v2v_net import V2VNet
+from .vgg import VGG
+from .vipnas_mbv3 import ViPNAS_MobileNetV3
+from .vipnas_resnet import ViPNAS_ResNet
+
+__all__ = [
+ 'AlexNet', 'HourglassNet', 'HourglassAENet', 'HRNet', 'MobileNetV2',
+ 'MobileNetV3', 'RegNet', 'ResNet', 'ResNetV1d', 'ResNeXt', 'SCNet',
+ 'SEResNet', 'SEResNeXt', 'ShuffleNetV1', 'ShuffleNetV2', 'CPM', 'RSN',
+ 'MSPN', 'ResNeSt', 'VGG', 'TCN', 'ViPNAS_ResNet', 'ViPNAS_MobileNetV3',
+ 'LiteHRNet', 'V2VNet', 'HRFormer', 'PyramidVisionTransformer',
+ 'PyramidVisionTransformerV2', 'SwinTransformer'
+]
diff --git a/mmpose/models/backbones/__pycache__/__init__.cpython-38.pyc b/mmpose/models/backbones/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b4e4cfedbe68164a189ed5944aac624262bdce8e
Binary files /dev/null and b/mmpose/models/backbones/__pycache__/__init__.cpython-38.pyc differ
diff --git a/mmpose/models/backbones/__pycache__/alexnet.cpython-38.pyc b/mmpose/models/backbones/__pycache__/alexnet.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..02c57e557c0134d15251f5ac12b55f77d68b4df3
Binary files /dev/null and b/mmpose/models/backbones/__pycache__/alexnet.cpython-38.pyc differ
diff --git a/mmpose/models/backbones/__pycache__/base_backbone.cpython-38.pyc b/mmpose/models/backbones/__pycache__/base_backbone.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..094c6ee895026f56fbcc2e7bcbcf6321d96abe0e
Binary files /dev/null and b/mmpose/models/backbones/__pycache__/base_backbone.cpython-38.pyc differ
diff --git a/mmpose/models/backbones/__pycache__/cpm.cpython-38.pyc b/mmpose/models/backbones/__pycache__/cpm.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cb8562d4df605418d4a928315f1b028dfa52e7df
Binary files /dev/null and b/mmpose/models/backbones/__pycache__/cpm.cpython-38.pyc differ
diff --git a/mmpose/models/backbones/__pycache__/hourglass.cpython-38.pyc b/mmpose/models/backbones/__pycache__/hourglass.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..80f4d651cc22b66f19c213535969aca9d43aea9d
Binary files /dev/null and b/mmpose/models/backbones/__pycache__/hourglass.cpython-38.pyc differ
diff --git a/mmpose/models/backbones/__pycache__/hourglass_ae.cpython-38.pyc b/mmpose/models/backbones/__pycache__/hourglass_ae.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..930671a1781da7dfa946dfd89092c2a6a614e18c
Binary files /dev/null and b/mmpose/models/backbones/__pycache__/hourglass_ae.cpython-38.pyc differ
diff --git a/mmpose/models/backbones/__pycache__/hrformer.cpython-38.pyc b/mmpose/models/backbones/__pycache__/hrformer.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c0454fdd8b97ac049f6af1f4f489494e254e8a74
Binary files /dev/null and b/mmpose/models/backbones/__pycache__/hrformer.cpython-38.pyc differ
diff --git a/mmpose/models/backbones/__pycache__/hrnet.cpython-38.pyc b/mmpose/models/backbones/__pycache__/hrnet.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8d6cf2922bcf62743f8301108b7336cb7ed76296
Binary files /dev/null and b/mmpose/models/backbones/__pycache__/hrnet.cpython-38.pyc differ
diff --git a/mmpose/models/backbones/__pycache__/litehrnet.cpython-38.pyc b/mmpose/models/backbones/__pycache__/litehrnet.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b221bf8e66196f8da93b0ce99eb01ea1aecccf67
Binary files /dev/null and b/mmpose/models/backbones/__pycache__/litehrnet.cpython-38.pyc differ
diff --git a/mmpose/models/backbones/__pycache__/mobilenet_v2.cpython-38.pyc b/mmpose/models/backbones/__pycache__/mobilenet_v2.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2a19238d7c874b35916c03c0828c79d8ec8c6e54
Binary files /dev/null and b/mmpose/models/backbones/__pycache__/mobilenet_v2.cpython-38.pyc differ
diff --git a/mmpose/models/backbones/__pycache__/mobilenet_v3.cpython-38.pyc b/mmpose/models/backbones/__pycache__/mobilenet_v3.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6efce920167c3aec9a4770240ac633f2dd6de743
Binary files /dev/null and b/mmpose/models/backbones/__pycache__/mobilenet_v3.cpython-38.pyc differ
diff --git a/mmpose/models/backbones/__pycache__/mspn.cpython-38.pyc b/mmpose/models/backbones/__pycache__/mspn.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..57485f1119da1c1d2d7117b2c20f5a3180d8bd97
Binary files /dev/null and b/mmpose/models/backbones/__pycache__/mspn.cpython-38.pyc differ
diff --git a/mmpose/models/backbones/__pycache__/pvt.cpython-38.pyc b/mmpose/models/backbones/__pycache__/pvt.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..44f378804eacefbf068b182bcfb0ffdb57f77b1c
Binary files /dev/null and b/mmpose/models/backbones/__pycache__/pvt.cpython-38.pyc differ
diff --git a/mmpose/models/backbones/__pycache__/regnet.cpython-38.pyc b/mmpose/models/backbones/__pycache__/regnet.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9a8b3a935e4a4fb583623f0d661a3302a44e8f0c
Binary files /dev/null and b/mmpose/models/backbones/__pycache__/regnet.cpython-38.pyc differ
diff --git a/mmpose/models/backbones/__pycache__/resnest.cpython-38.pyc b/mmpose/models/backbones/__pycache__/resnest.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f7744c3dc0b1cbfffa7eb70f6ac41624fad87c33
Binary files /dev/null and b/mmpose/models/backbones/__pycache__/resnest.cpython-38.pyc differ
diff --git a/mmpose/models/backbones/__pycache__/resnet.cpython-38.pyc b/mmpose/models/backbones/__pycache__/resnet.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..53c46e3a4d63845f435050e10ea36083680bcff0
Binary files /dev/null and b/mmpose/models/backbones/__pycache__/resnet.cpython-38.pyc differ
diff --git a/mmpose/models/backbones/__pycache__/resnext.cpython-38.pyc b/mmpose/models/backbones/__pycache__/resnext.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d12fe070429012319a668ed14ea6ce8e8bb10f24
Binary files /dev/null and b/mmpose/models/backbones/__pycache__/resnext.cpython-38.pyc differ
diff --git a/mmpose/models/backbones/__pycache__/rsn.cpython-38.pyc b/mmpose/models/backbones/__pycache__/rsn.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e89ef05de9f755ede5f823d3656dc21e3f433db1
Binary files /dev/null and b/mmpose/models/backbones/__pycache__/rsn.cpython-38.pyc differ
diff --git a/mmpose/models/backbones/__pycache__/scnet.cpython-38.pyc b/mmpose/models/backbones/__pycache__/scnet.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fd163eea9dea0ba2e629a5dca6a5e35685ff3b52
Binary files /dev/null and b/mmpose/models/backbones/__pycache__/scnet.cpython-38.pyc differ
diff --git a/mmpose/models/backbones/__pycache__/seresnet.cpython-38.pyc b/mmpose/models/backbones/__pycache__/seresnet.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..225ff7dc153b2ceced76869a66954e4149b69e62
Binary files /dev/null and b/mmpose/models/backbones/__pycache__/seresnet.cpython-38.pyc differ
diff --git a/mmpose/models/backbones/__pycache__/seresnext.cpython-38.pyc b/mmpose/models/backbones/__pycache__/seresnext.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a4d263daf00c53f9bc9bfae8c1717c6da5f0090b
Binary files /dev/null and b/mmpose/models/backbones/__pycache__/seresnext.cpython-38.pyc differ
diff --git a/mmpose/models/backbones/__pycache__/shufflenet_v1.cpython-38.pyc b/mmpose/models/backbones/__pycache__/shufflenet_v1.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8ddc089fa4190b165f2ba0429a13a453bd438621
Binary files /dev/null and b/mmpose/models/backbones/__pycache__/shufflenet_v1.cpython-38.pyc differ
diff --git a/mmpose/models/backbones/__pycache__/shufflenet_v2.cpython-38.pyc b/mmpose/models/backbones/__pycache__/shufflenet_v2.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6818af553035028af47fe2142ea04a947215a05a
Binary files /dev/null and b/mmpose/models/backbones/__pycache__/shufflenet_v2.cpython-38.pyc differ
diff --git a/mmpose/models/backbones/__pycache__/swin.cpython-38.pyc b/mmpose/models/backbones/__pycache__/swin.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0b29f78b9cceecb95a90606481d4712c18e2d1b5
Binary files /dev/null and b/mmpose/models/backbones/__pycache__/swin.cpython-38.pyc differ
diff --git a/mmpose/models/backbones/__pycache__/tcn.cpython-38.pyc b/mmpose/models/backbones/__pycache__/tcn.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..58da95f1730f1f849bb082661ebdb905c58894cc
Binary files /dev/null and b/mmpose/models/backbones/__pycache__/tcn.cpython-38.pyc differ
diff --git a/mmpose/models/backbones/__pycache__/v2v_net.cpython-38.pyc b/mmpose/models/backbones/__pycache__/v2v_net.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..60dd9065bab22797b21fe6dd1e8e9b488f52adea
Binary files /dev/null and b/mmpose/models/backbones/__pycache__/v2v_net.cpython-38.pyc differ
diff --git a/mmpose/models/backbones/__pycache__/vgg.cpython-38.pyc b/mmpose/models/backbones/__pycache__/vgg.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5bd7809dd77d58bf6af1a9223f4d5e0ca80ed9c1
Binary files /dev/null and b/mmpose/models/backbones/__pycache__/vgg.cpython-38.pyc differ
diff --git a/mmpose/models/backbones/__pycache__/vipnas_mbv3.cpython-38.pyc b/mmpose/models/backbones/__pycache__/vipnas_mbv3.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8bf2f571b1c9317f5830ac3872621686222c6bd9
Binary files /dev/null and b/mmpose/models/backbones/__pycache__/vipnas_mbv3.cpython-38.pyc differ
diff --git a/mmpose/models/backbones/__pycache__/vipnas_resnet.cpython-38.pyc b/mmpose/models/backbones/__pycache__/vipnas_resnet.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9c6851030a84692aa8c5fb44b00472a4ca499e4c
Binary files /dev/null and b/mmpose/models/backbones/__pycache__/vipnas_resnet.cpython-38.pyc differ
diff --git a/mmpose/models/backbones/alexnet.py b/mmpose/models/backbones/alexnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..2262658f4718a079b2effc276282be4d39fbe6ad
--- /dev/null
+++ b/mmpose/models/backbones/alexnet.py
@@ -0,0 +1,58 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+
+from mmpose.registry import MODELS
+from .base_backbone import BaseBackbone
+
+
+@MODELS.register_module()
+class AlexNet(BaseBackbone):
+ """`AlexNet `__ backbone.
+
+ The input for AlexNet is a 224x224 RGB image.
+
+ Args:
+ num_classes (int): number of classes for classification.
+ The default value is -1, which uses the backbone as
+ a feature extractor without the top classifier.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
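+
+    Example:
+        A quick shape check (illustrative; assumes the default
+        feature-extractor mode with ``num_classes=-1``):
+
+        >>> import torch
+        >>> from mmpose.models import AlexNet
+        >>> self = AlexNet()
+        >>> inputs = torch.rand(1, 3, 224, 224)
+        >>> outputs = self.forward(inputs)
+        >>> print(tuple(outputs[0].shape))
+        (1, 256, 6, 6)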
+ """
+
+ def __init__(self, num_classes=-1, init_cfg=None):
+ super().__init__(init_cfg=init_cfg)
+ self.num_classes = num_classes
+ self.features = nn.Sequential(
+ nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
+ nn.ReLU(inplace=True),
+ nn.MaxPool2d(kernel_size=3, stride=2),
+ nn.Conv2d(64, 192, kernel_size=5, padding=2),
+ nn.ReLU(inplace=True),
+ nn.MaxPool2d(kernel_size=3, stride=2),
+ nn.Conv2d(192, 384, kernel_size=3, padding=1),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(384, 256, kernel_size=3, padding=1),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(256, 256, kernel_size=3, padding=1),
+ nn.ReLU(inplace=True),
+ nn.MaxPool2d(kernel_size=3, stride=2),
+ )
+ if self.num_classes > 0:
+ self.classifier = nn.Sequential(
+ nn.Dropout(),
+ nn.Linear(256 * 6 * 6, 4096),
+ nn.ReLU(inplace=True),
+ nn.Dropout(),
+ nn.Linear(4096, 4096),
+ nn.ReLU(inplace=True),
+ nn.Linear(4096, num_classes),
+ )
+
+ def forward(self, x):
+
+ x = self.features(x)
+ if self.num_classes > 0:
+ x = x.view(x.size(0), 256 * 6 * 6)
+ x = self.classifier(x)
+
+ return (x, )
diff --git a/mmpose/models/backbones/base_backbone.py b/mmpose/models/backbones/base_backbone.py
new file mode 100644
index 0000000000000000000000000000000000000000..6094b4e831f992b052e4db206022f489a7f729b3
--- /dev/null
+++ b/mmpose/models/backbones/base_backbone.py
@@ -0,0 +1,29 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABCMeta, abstractmethod
+
+from mmengine.model import BaseModule
+
+
+class BaseBackbone(BaseModule, metaclass=ABCMeta):
+ """Base backbone.
+
+ This class defines the basic functions of a backbone. Any backbone that
+ inherits this class should at least define its own `forward` function.
+ """
+
+ @abstractmethod
+ def forward(self, x):
+ """Forward function.
+
+ Args:
+ x (Tensor | tuple[Tensor]): x could be a torch.Tensor or a tuple of
+ torch.Tensor, containing input data for forward computation.
+ """
+
+ def train(self, mode=True):
+ """Set module status before forward computation.
+
+ Args:
+ mode (bool): Whether it is train_mode or test_mode
+ """
+ super(BaseBackbone, self).train(mode)
diff --git a/mmpose/models/backbones/cpm.py b/mmpose/models/backbones/cpm.py
new file mode 100644
index 0000000000000000000000000000000000000000..256769c43a4d7b9d0cdd40fb6de19a90727012e8
--- /dev/null
+++ b/mmpose/models/backbones/cpm.py
@@ -0,0 +1,183 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+
+import torch
+import torch.nn as nn
+from mmcv.cnn import ConvModule
+from mmengine.model import BaseModule
+
+from mmpose.registry import MODELS
+from .base_backbone import BaseBackbone
+
+
+class CpmBlock(BaseModule):
+ """CpmBlock for Convolutional Pose Machine.
+
+ Args:
+ in_channels (int): Input channels of this block.
+ channels (list): Output channels of each conv module.
+ kernels (list): Kernel sizes of each conv module.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self,
+ in_channels,
+ channels=(128, 128, 128),
+ kernels=(11, 11, 11),
+ norm_cfg=None,
+ init_cfg=None):
+ super().__init__(init_cfg=init_cfg)
+
+ assert len(channels) == len(kernels)
+ layers = []
+ for i in range(len(channels)):
+ if i == 0:
+ input_channels = in_channels
+ else:
+ input_channels = channels[i - 1]
+ layers.append(
+ ConvModule(
+ input_channels,
+ channels[i],
+ kernels[i],
+ padding=(kernels[i] - 1) // 2,
+ norm_cfg=norm_cfg))
+ self.model = nn.Sequential(*layers)
+
+ def forward(self, x):
+ """Model forward function."""
+ out = self.model(x)
+ return out
+
+
+@MODELS.register_module()
+class CPM(BaseBackbone):
+ """CPM backbone.
+
+ Convolutional Pose Machines.
+ More details can be found in the `paper
+ `__ .
+
+ Args:
+ in_channels (int): The input channels of the CPM.
+ out_channels (int): The output channels of the CPM.
+ feat_channels (int): Feature channel of each CPM stage.
+ middle_channels (int): Feature channel of conv after the middle stage.
+ num_stages (int): Number of stages.
+ norm_cfg (dict): Dictionary to construct and config norm layer.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default:
+ ``[
+ dict(type='Normal', std=0.001, layer=['Conv2d']),
+ dict(
+ type='Constant',
+ val=1,
+ layer=['_BatchNorm', 'GroupNorm'])
+ ]``
+
+ Example:
+ >>> from mmpose.models import CPM
+ >>> import torch
+ >>> self = CPM(3, 17)
+ >>> self.eval()
+ >>> inputs = torch.rand(1, 3, 368, 368)
+ >>> level_outputs = self.forward(inputs)
+ >>> for level_output in level_outputs:
+ ... print(tuple(level_output.shape))
+ (1, 17, 46, 46)
+ (1, 17, 46, 46)
+ (1, 17, 46, 46)
+ (1, 17, 46, 46)
+ (1, 17, 46, 46)
+ (1, 17, 46, 46)
+ """
+
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ feat_channels=128,
+ middle_channels=32,
+ num_stages=6,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ init_cfg=[
+ dict(type='Normal', std=0.001, layer=['Conv2d']),
+ dict(type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm'])
+ ],
+ ):
+ # Protect mutable default arguments
+ norm_cfg = copy.deepcopy(norm_cfg)
+ super().__init__(init_cfg=init_cfg)
+
+ assert in_channels == 3
+
+ self.num_stages = num_stages
+ assert self.num_stages >= 1
+
+ self.stem = nn.Sequential(
+ ConvModule(in_channels, 128, 9, padding=4, norm_cfg=norm_cfg),
+ nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
+ ConvModule(128, 128, 9, padding=4, norm_cfg=norm_cfg),
+ nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
+ ConvModule(128, 128, 9, padding=4, norm_cfg=norm_cfg),
+ nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
+ ConvModule(128, 32, 5, padding=2, norm_cfg=norm_cfg),
+ ConvModule(32, 512, 9, padding=4, norm_cfg=norm_cfg),
+ ConvModule(512, 512, 1, padding=0, norm_cfg=norm_cfg),
+ ConvModule(512, out_channels, 1, padding=0, act_cfg=None))
+
+ self.middle = nn.Sequential(
+ ConvModule(in_channels, 128, 9, padding=4, norm_cfg=norm_cfg),
+ nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
+ ConvModule(128, 128, 9, padding=4, norm_cfg=norm_cfg),
+ nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
+ ConvModule(128, 128, 9, padding=4, norm_cfg=norm_cfg),
+ nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
+
+ self.cpm_stages = nn.ModuleList([
+ CpmBlock(
+ middle_channels + out_channels,
+ channels=[feat_channels, feat_channels, feat_channels],
+ kernels=[11, 11, 11],
+ norm_cfg=norm_cfg) for _ in range(num_stages - 1)
+ ])
+
+ self.middle_conv = nn.ModuleList([
+ nn.Sequential(
+ ConvModule(
+ 128, middle_channels, 5, padding=2, norm_cfg=norm_cfg))
+ for _ in range(num_stages - 1)
+ ])
+
+ self.out_convs = nn.ModuleList([
+ nn.Sequential(
+ ConvModule(
+ feat_channels,
+ feat_channels,
+ 1,
+ padding=0,
+ norm_cfg=norm_cfg),
+ ConvModule(feat_channels, out_channels, 1, act_cfg=None))
+ for _ in range(num_stages - 1)
+ ])
+
+ def forward(self, x):
+ """Model forward function."""
+ stage1_out = self.stem(x)
+ middle_out = self.middle(x)
+ out_feats = []
+
+ out_feats.append(stage1_out)
+
+ for ind in range(self.num_stages - 1):
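+            # each refinement stage takes the previous stage's heatmaps
+            # concatenated with the shared mid-level image features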
+ single_stage = self.cpm_stages[ind]
+ out_conv = self.out_convs[ind]
+
+ inp_feat = torch.cat(
+ [out_feats[-1], self.middle_conv[ind](middle_out)], 1)
+ cpm_feat = single_stage(inp_feat)
+ out_feat = out_conv(cpm_feat)
+ out_feats.append(out_feat)
+
+ return out_feats
diff --git a/mmpose/models/backbones/hourglass.py b/mmpose/models/backbones/hourglass.py
new file mode 100644
index 0000000000000000000000000000000000000000..cfc8d6d328da5b63094015351cc10084cda46da0
--- /dev/null
+++ b/mmpose/models/backbones/hourglass.py
@@ -0,0 +1,209 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+
+import torch.nn as nn
+from mmcv.cnn import ConvModule
+from mmengine.model import BaseModule
+
+from mmpose.registry import MODELS
+from .base_backbone import BaseBackbone
+from .resnet import BasicBlock, ResLayer
+
+
+class HourglassModule(BaseModule):
+ """Hourglass Module for HourglassNet backbone.
+
+ Generate module recursively and use BasicBlock as the base unit.
+
+ Args:
+ depth (int): Depth of current HourglassModule.
+ stage_channels (list[int]): Feature channels of sub-modules in current
+ and follow-up HourglassModule.
+ stage_blocks (list[int]): Number of sub-modules stacked in current and
+ follow-up HourglassModule.
+ norm_cfg (dict): Dictionary to construct and config norm layer.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self,
+ depth,
+ stage_channels,
+ stage_blocks,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ init_cfg=None):
+ # Protect mutable default arguments
+ norm_cfg = copy.deepcopy(norm_cfg)
+ super().__init__(init_cfg=init_cfg)
+
+ self.depth = depth
+
+ cur_block = stage_blocks[0]
+ next_block = stage_blocks[1]
+
+ cur_channel = stage_channels[0]
+ next_channel = stage_channels[1]
+
+ self.up1 = ResLayer(
+ BasicBlock, cur_block, cur_channel, cur_channel, norm_cfg=norm_cfg)
+
+ self.low1 = ResLayer(
+ BasicBlock,
+ cur_block,
+ cur_channel,
+ next_channel,
+ stride=2,
+ norm_cfg=norm_cfg)
+
+ if self.depth > 1:
+ self.low2 = HourglassModule(depth - 1, stage_channels[1:],
+ stage_blocks[1:])
+ else:
+ self.low2 = ResLayer(
+ BasicBlock,
+ next_block,
+ next_channel,
+ next_channel,
+ norm_cfg=norm_cfg)
+
+ self.low3 = ResLayer(
+ BasicBlock,
+ cur_block,
+ next_channel,
+ cur_channel,
+ norm_cfg=norm_cfg,
+ downsample_first=False)
+
+ self.up2 = nn.Upsample(scale_factor=2)
+
+ def forward(self, x):
+ """Model forward function."""
+ up1 = self.up1(x)
+ low1 = self.low1(x)
+ low2 = self.low2(low1)
+ low3 = self.low3(low2)
+ up2 = self.up2(low3)
+ return up1 + up2
+
+
+@MODELS.register_module()
+class HourglassNet(BaseBackbone):
+ """HourglassNet backbone.
+
+ Stacked Hourglass Networks for Human Pose Estimation.
+ More details can be found in the `paper
+ `__ .
+
+ Args:
+ downsample_times (int): Downsample times in a HourglassModule.
+ num_stacks (int): Number of HourglassModule modules stacked,
+ 1 for Hourglass-52, 2 for Hourglass-104.
+ stage_channels (list[int]): Feature channel of each sub-module in a
+ HourglassModule.
+ stage_blocks (list[int]): Number of sub-modules stacked in a
+ HourglassModule.
+ feat_channel (int): Feature channel of conv after a HourglassModule.
+ norm_cfg (dict): Dictionary to construct and config norm layer.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default:
+ ``[
+ dict(type='Normal', std=0.001, layer=['Conv2d']),
+ dict(
+ type='Constant',
+ val=1,
+ layer=['_BatchNorm', 'GroupNorm'])
+ ]``
+
+ Example:
+ >>> from mmpose.models import HourglassNet
+ >>> import torch
+ >>> self = HourglassNet()
+ >>> self.eval()
+ >>> inputs = torch.rand(1, 3, 511, 511)
+ >>> level_outputs = self.forward(inputs)
+ >>> for level_output in level_outputs:
+ ... print(tuple(level_output.shape))
+ (1, 256, 128, 128)
+ (1, 256, 128, 128)
+ """
+
+ def __init__(
+ self,
+ downsample_times=5,
+ num_stacks=2,
+ stage_channels=(256, 256, 384, 384, 384, 512),
+ stage_blocks=(2, 2, 2, 2, 2, 4),
+ feat_channel=256,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ init_cfg=[
+ dict(type='Normal', std=0.001, layer=['Conv2d']),
+ dict(type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm'])
+ ],
+ ):
+ # Protect mutable default arguments
+ norm_cfg = copy.deepcopy(norm_cfg)
+ super().__init__(init_cfg=init_cfg)
+
+ self.num_stacks = num_stacks
+ assert self.num_stacks >= 1
+ assert len(stage_channels) == len(stage_blocks)
+ assert len(stage_channels) > downsample_times
+
+ cur_channel = stage_channels[0]
+
+ self.stem = nn.Sequential(
+ ConvModule(3, 128, 7, padding=3, stride=2, norm_cfg=norm_cfg),
+ ResLayer(BasicBlock, 1, 128, 256, stride=2, norm_cfg=norm_cfg))
+
+ self.hourglass_modules = nn.ModuleList([
+ HourglassModule(downsample_times, stage_channels, stage_blocks)
+ for _ in range(num_stacks)
+ ])
+
+ self.inters = ResLayer(
+ BasicBlock,
+ num_stacks - 1,
+ cur_channel,
+ cur_channel,
+ norm_cfg=norm_cfg)
+
+ self.conv1x1s = nn.ModuleList([
+ ConvModule(
+ cur_channel, cur_channel, 1, norm_cfg=norm_cfg, act_cfg=None)
+ for _ in range(num_stacks - 1)
+ ])
+
+ self.out_convs = nn.ModuleList([
+ ConvModule(
+ cur_channel, feat_channel, 3, padding=1, norm_cfg=norm_cfg)
+ for _ in range(num_stacks)
+ ])
+
+ self.remap_convs = nn.ModuleList([
+ ConvModule(
+ feat_channel, cur_channel, 1, norm_cfg=norm_cfg, act_cfg=None)
+ for _ in range(num_stacks - 1)
+ ])
+
+ self.relu = nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ """Model forward function."""
+ inter_feat = self.stem(x)
+ out_feats = []
+
+ for ind in range(self.num_stacks):
+ single_hourglass = self.hourglass_modules[ind]
+ out_conv = self.out_convs[ind]
+
+ hourglass_feat = single_hourglass(inter_feat)
+ out_feat = out_conv(hourglass_feat)
+ out_feats.append(out_feat)
+
+ if ind < self.num_stacks - 1:
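+                # fuse the stack input with the remapped prediction to form
+                # the input of the next hourglass stack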
+ inter_feat = self.conv1x1s[ind](
+ inter_feat) + self.remap_convs[ind](
+ out_feat)
+ inter_feat = self.inters[ind](self.relu(inter_feat))
+
+ return out_feats
diff --git a/mmpose/models/backbones/hourglass_ae.py b/mmpose/models/backbones/hourglass_ae.py
new file mode 100644
index 0000000000000000000000000000000000000000..93e62dd4067c3489de00c5cd1f7875489725de2e
--- /dev/null
+++ b/mmpose/models/backbones/hourglass_ae.py
@@ -0,0 +1,209 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+
+import torch.nn as nn
+from mmcv.cnn import ConvModule, MaxPool2d
+from mmengine.model import BaseModule
+
+from mmpose.registry import MODELS
+from .base_backbone import BaseBackbone
+
+
+class HourglassAEModule(BaseModule):
+ """Modified Hourglass Module for HourglassNet_AE backbone.
+
+    Generate the module recursively and use ConvModule as the base unit.
+
+ Args:
+ depth (int): Depth of current HourglassModule.
+ stage_channels (list[int]): Feature channels of sub-modules in current
+ and follow-up HourglassModule.
+ norm_cfg (dict): Dictionary to construct and config norm layer.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self,
+ depth,
+ stage_channels,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ init_cfg=None):
+ # Protect mutable default arguments
+ norm_cfg = copy.deepcopy(norm_cfg)
+ super().__init__(init_cfg=init_cfg)
+
+ self.depth = depth
+
+ cur_channel = stage_channels[0]
+ next_channel = stage_channels[1]
+
+ self.up1 = ConvModule(
+ cur_channel, cur_channel, 3, padding=1, norm_cfg=norm_cfg)
+
+ self.pool1 = MaxPool2d(2, 2)
+
+ self.low1 = ConvModule(
+ cur_channel, next_channel, 3, padding=1, norm_cfg=norm_cfg)
+
+ if self.depth > 1:
+ self.low2 = HourglassAEModule(depth - 1, stage_channels[1:])
+ else:
+ self.low2 = ConvModule(
+ next_channel, next_channel, 3, padding=1, norm_cfg=norm_cfg)
+
+ self.low3 = ConvModule(
+ next_channel, cur_channel, 3, padding=1, norm_cfg=norm_cfg)
+
+ self.up2 = nn.UpsamplingNearest2d(scale_factor=2)
+
+ def forward(self, x):
+ """Model forward function."""
+ up1 = self.up1(x)
+ pool1 = self.pool1(x)
+ low1 = self.low1(pool1)
+ low2 = self.low2(low1)
+ low3 = self.low3(low2)
+ up2 = self.up2(low3)
+ return up1 + up2
+
+
+@MODELS.register_module()
+class HourglassAENet(BaseBackbone):
+ """Hourglass-AE Network proposed by Newell et al.
+
+ Associative Embedding: End-to-End Learning for Joint
+ Detection and Grouping.
+
+ More details can be found in the `paper
+ `__ .
+
+ Args:
+ downsample_times (int): Downsample times in a HourglassModule.
+ num_stacks (int): Number of HourglassModule modules stacked,
+ 1 for Hourglass-52, 2 for Hourglass-104.
+ stage_channels (list[int]): Feature channel of each sub-module in a
+ HourglassModule.
+ stage_blocks (list[int]): Number of sub-modules stacked in a
+ HourglassModule.
+ feat_channels (int): Feature channel of conv after a HourglassModule.
+ norm_cfg (dict): Dictionary to construct and config norm layer.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default:
+ ``[
+ dict(type='Normal', std=0.001, layer=['Conv2d']),
+ dict(
+ type='Constant',
+ val=1,
+ layer=['_BatchNorm', 'GroupNorm'])
+ ]``
+
+ Example:
+ >>> from mmpose.models import HourglassAENet
+ >>> import torch
+ >>> self = HourglassAENet()
+ >>> self.eval()
+ >>> inputs = torch.rand(1, 3, 512, 512)
+ >>> level_outputs = self.forward(inputs)
+ >>> for level_output in level_outputs:
+ ... print(tuple(level_output.shape))
+ (1, 34, 128, 128)
+ """
+
+ def __init__(
+ self,
+ downsample_times=4,
+ num_stacks=1,
+ out_channels=34,
+ stage_channels=(256, 384, 512, 640, 768),
+ feat_channels=256,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ init_cfg=[
+ dict(type='Normal', std=0.001, layer=['Conv2d']),
+ dict(type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm'])
+ ],
+ ):
+ # Protect mutable default arguments
+ norm_cfg = copy.deepcopy(norm_cfg)
+ super().__init__(init_cfg=init_cfg)
+
+ self.num_stacks = num_stacks
+ assert self.num_stacks >= 1
+ assert len(stage_channels) > downsample_times
+
+ cur_channels = stage_channels[0]
+
+ self.stem = nn.Sequential(
+ ConvModule(3, 64, 7, padding=3, stride=2, norm_cfg=norm_cfg),
+ ConvModule(64, 128, 3, padding=1, norm_cfg=norm_cfg),
+ MaxPool2d(2, 2),
+ ConvModule(128, 128, 3, padding=1, norm_cfg=norm_cfg),
+ ConvModule(128, feat_channels, 3, padding=1, norm_cfg=norm_cfg),
+ )
+
+ self.hourglass_modules = nn.ModuleList([
+ nn.Sequential(
+ HourglassAEModule(
+ downsample_times, stage_channels, norm_cfg=norm_cfg),
+ ConvModule(
+ feat_channels,
+ feat_channels,
+ 3,
+ padding=1,
+ norm_cfg=norm_cfg),
+ ConvModule(
+ feat_channels,
+ feat_channels,
+ 3,
+ padding=1,
+ norm_cfg=norm_cfg)) for _ in range(num_stacks)
+ ])
+
+ self.out_convs = nn.ModuleList([
+ ConvModule(
+ cur_channels,
+ out_channels,
+ 1,
+ padding=0,
+ norm_cfg=None,
+ act_cfg=None) for _ in range(num_stacks)
+ ])
+
+ self.remap_out_convs = nn.ModuleList([
+ ConvModule(
+ out_channels,
+ feat_channels,
+ 1,
+ norm_cfg=norm_cfg,
+ act_cfg=None) for _ in range(num_stacks - 1)
+ ])
+
+ self.remap_feature_convs = nn.ModuleList([
+ ConvModule(
+ feat_channels,
+ feat_channels,
+ 1,
+ norm_cfg=norm_cfg,
+ act_cfg=None) for _ in range(num_stacks - 1)
+ ])
+
+ self.relu = nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ """Model forward function."""
+ inter_feat = self.stem(x)
+ out_feats = []
+
+ for ind in range(self.num_stacks):
+ single_hourglass = self.hourglass_modules[ind]
+ out_conv = self.out_convs[ind]
+
+ hourglass_feat = single_hourglass(inter_feat)
+ out_feat = out_conv(hourglass_feat)
+ out_feats.append(out_feat)
+
+ if ind < self.num_stacks - 1:
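+                # residually fuse the remapped prediction and hourglass
+                # features into the input of the next stack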
+ inter_feat = inter_feat + self.remap_out_convs[ind](
+ out_feat) + self.remap_feature_convs[ind](
+ hourglass_feat)
+
+ return out_feats
diff --git a/mmpose/models/backbones/hrformer.py b/mmpose/models/backbones/hrformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..4712cfdfb5c0454e00d7e2e09b244424f2c80f5a
--- /dev/null
+++ b/mmpose/models/backbones/hrformer.py
@@ -0,0 +1,763 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+import math
+
+import torch
+import torch.nn as nn
+from mmcv.cnn import build_activation_layer, build_conv_layer, build_norm_layer
+from mmcv.cnn.bricks.transformer import build_dropout
+from mmengine.model import BaseModule, trunc_normal_init
+from torch.nn.functional import pad
+
+from mmpose.registry import MODELS
+from .hrnet import Bottleneck, HRModule, HRNet
+
+
+def nlc_to_nchw(x, hw_shape):
+ """Convert [N, L, C] shape tensor to [N, C, H, W] shape tensor.
+
+ Args:
+ x (Tensor): The input tensor of shape [N, L, C] before conversion.
+ hw_shape (Sequence[int]): The height and width of output feature map.
+
+ Returns:
+ Tensor: The output tensor of shape [N, C, H, W] after conversion.
+ """
+ H, W = hw_shape
+ assert len(x.shape) == 3
+ B, L, C = x.shape
+ assert L == H * W, 'The seq_len doesn\'t match H, W'
+ return x.transpose(1, 2).reshape(B, C, H, W)
+
+
+def nchw_to_nlc(x):
+ """Flatten [N, C, H, W] shape tensor to [N, L, C] shape tensor.
+
+ Args:
+ x (Tensor): The input tensor of shape [N, C, H, W] before conversion.
+
+ Returns:
+ Tensor: The output tensor of shape [N, L, C] after conversion.
+ """
+ assert len(x.shape) == 4
+ return x.flatten(2).transpose(1, 2).contiguous()
+
+
+def build_drop_path(drop_path_rate):
+ """Build drop path layer."""
+ return build_dropout(dict(type='DropPath', drop_prob=drop_path_rate))
+
+
+class WindowMSA(BaseModule):
+ """Window based multi-head self-attention (W-MSA) module with relative
+ position bias.
+
+ Args:
+ embed_dims (int): Number of input channels.
+ num_heads (int): Number of attention heads.
+ window_size (tuple[int]): The height and width of the window.
+ qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
+ Default: True.
+ qk_scale (float | None, optional): Override default qk scale of
+ head_dim ** -0.5 if set. Default: None.
+ attn_drop_rate (float, optional): Dropout ratio of attention weight.
+ Default: 0.0
+ proj_drop_rate (float, optional): Dropout ratio of output. Default: 0.
+ with_rpe (bool, optional): If True, use relative position bias.
+ Default: True.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None.
+ """
+
+ def __init__(self,
+ embed_dims,
+ num_heads,
+ window_size,
+ qkv_bias=True,
+ qk_scale=None,
+ attn_drop_rate=0.,
+ proj_drop_rate=0.,
+ with_rpe=True,
+ init_cfg=None):
+
+ super().__init__(init_cfg=init_cfg)
+ self.embed_dims = embed_dims
+ self.window_size = window_size # Wh, Ww
+ self.num_heads = num_heads
+ head_embed_dims = embed_dims // num_heads
+ self.scale = qk_scale or head_embed_dims**-0.5
+
+ self.with_rpe = with_rpe
+ if self.with_rpe:
+ # define a parameter table of relative position bias
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros(
+ (2 * window_size[0] - 1) * (2 * window_size[1] - 1),
+ num_heads)) # 2*Wh-1 * 2*Ww-1, nH
+
+ Wh, Ww = self.window_size
+ rel_index_coords = self.double_step_seq(2 * Ww - 1, Wh, 1, Ww)
+ rel_position_index = rel_index_coords + rel_index_coords.T
+ rel_position_index = rel_position_index.flip(1).contiguous()
+ self.register_buffer('relative_position_index', rel_position_index)
+
+ self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop_rate)
+ self.proj = nn.Linear(embed_dims, embed_dims)
+ self.proj_drop = nn.Dropout(proj_drop_rate)
+
+ self.softmax = nn.Softmax(dim=-1)
+
+ def init_weights(self):
+ trunc_normal_init(self.relative_position_bias_table, std=0.02)
+
+ def forward(self, x, mask=None):
+ """
+ Args:
+
+ x (tensor): input features with shape of (B*num_windows, N, C)
+ mask (tensor | None, Optional): mask with shape of (num_windows,
+ Wh*Ww, Wh*Ww), value should be between (-inf, 0].
+ """
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
+ C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2]
+
+ q = q * self.scale
+ attn = (q @ k.transpose(-2, -1))
+
+ if self.with_rpe:
+ relative_position_bias = self.relative_position_bias_table[
+ self.relative_position_index.view(-1)].view(
+ self.window_size[0] * self.window_size[1],
+ self.window_size[0] * self.window_size[1],
+ -1) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(
+ 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+ attn = attn + relative_position_bias.unsqueeze(0)
+
+ if mask is not None:
+ nW = mask.shape[0]
+ attn = attn.view(B // nW, nW, self.num_heads, N,
+ N) + mask.unsqueeze(1).unsqueeze(0)
+ attn = attn.view(-1, self.num_heads, N, N)
+ attn = self.softmax(attn)
+
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+ @staticmethod
+ def double_step_seq(step1, len1, step2, len2):
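+        # outer sum of two arithmetic progressions (steps step1 / step2),
+        # flattened into a 1 x (len1*len2) row of relative position indices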
+ seq1 = torch.arange(0, step1 * len1, step1)
+ seq2 = torch.arange(0, step2 * len2, step2)
+ return (seq1[:, None] + seq2[None, :]).reshape(1, -1)
+
+
+class LocalWindowSelfAttention(BaseModule):
+ r""" Local-window Self Attention (LSA) module with relative position bias.
+
+ This module is the short-range self-attention module in the
+ Interlaced Sparse Self-Attention `_.
+
+ Args:
+ embed_dims (int): Number of input channels.
+ num_heads (int): Number of attention heads.
+ window_size (tuple[int] | int): The height and width of the window.
+ qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
+ Default: True.
+ qk_scale (float | None, optional): Override default qk scale of
+ head_dim ** -0.5 if set. Default: None.
+ attn_drop_rate (float, optional): Dropout ratio of attention weight.
+ Default: 0.0
+ proj_drop_rate (float, optional): Dropout ratio of output. Default: 0.
+ with_rpe (bool, optional): If True, use relative position bias.
+ Default: True.
+ with_pad_mask (bool, optional): If True, mask out the padded tokens in
+ the attention process. Default: False.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None.
+ """
+
+ def __init__(self,
+ embed_dims,
+ num_heads,
+ window_size,
+ qkv_bias=True,
+ qk_scale=None,
+ attn_drop_rate=0.,
+ proj_drop_rate=0.,
+ with_rpe=True,
+ with_pad_mask=False,
+ init_cfg=None):
+ super().__init__(init_cfg=init_cfg)
+ if isinstance(window_size, int):
+ window_size = (window_size, window_size)
+ self.window_size = window_size
+ self.with_pad_mask = with_pad_mask
+ self.attn = WindowMSA(
+ embed_dims=embed_dims,
+ num_heads=num_heads,
+ window_size=window_size,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop_rate=attn_drop_rate,
+ proj_drop_rate=proj_drop_rate,
+ with_rpe=with_rpe,
+ init_cfg=init_cfg)
+
+ def forward(self, x, H, W, **kwargs):
+ """Forward function."""
+ B, N, C = x.shape
+ x = x.view(B, H, W, C)
+ Wh, Ww = self.window_size
+
+ # center-pad the feature on H and W axes
+ pad_h = math.ceil(H / Wh) * Wh - H
+ pad_w = math.ceil(W / Ww) * Ww - W
+ x = pad(x, (0, 0, pad_w // 2, pad_w - pad_w // 2, pad_h // 2,
+ pad_h - pad_h // 2))
+
+ # permute
+ x = x.view(B, math.ceil(H / Wh), Wh, math.ceil(W / Ww), Ww, C)
+ x = x.permute(0, 1, 3, 2, 4, 5)
+ x = x.reshape(-1, Wh * Ww, C) # (B*num_window, Wh*Ww, C)
+
+ # attention
+ if self.with_pad_mask and pad_h > 0 and pad_w > 0:
+ pad_mask = x.new_zeros(1, H, W, 1)
+ pad_mask = pad(
+ pad_mask, [
+ 0, 0, pad_w // 2, pad_w - pad_w // 2, pad_h // 2,
+ pad_h - pad_h // 2
+ ],
+ value=-float('inf'))
+ pad_mask = pad_mask.view(1, math.ceil(H / Wh), Wh,
+ math.ceil(W / Ww), Ww, 1)
+ pad_mask = pad_mask.permute(1, 3, 0, 2, 4, 5)
+ pad_mask = pad_mask.reshape(-1, Wh * Ww)
+ pad_mask = pad_mask[:, None, :].expand([-1, Wh * Ww, -1])
+ out = self.attn(x, pad_mask, **kwargs)
+ else:
+ out = self.attn(x, **kwargs)
+
+ # reverse permutation
+ out = out.reshape(B, math.ceil(H / Wh), math.ceil(W / Ww), Wh, Ww, C)
+ out = out.permute(0, 1, 3, 2, 4, 5)
+ out = out.reshape(B, H + pad_h, W + pad_w, C)
+
+ # de-pad
+ out = out[:, pad_h // 2:H + pad_h // 2, pad_w // 2:W + pad_w // 2]
+ return out.reshape(B, N, C)
+
+
+class CrossFFN(BaseModule):
+ r"""FFN with Depthwise Conv of HRFormer.
+
+ Args:
+ in_features (int): The feature dimension.
+ hidden_features (int, optional): The hidden dimension of FFNs.
+ Defaults: The same as in_features.
+ act_cfg (dict, optional): Config of activation layer.
+ Default: dict(type='GELU').
+ dw_act_cfg (dict, optional): Config of activation layer appended
+ right after DW Conv. Default: dict(type='GELU').
+ norm_cfg (dict, optional): Config of norm layer.
+ Default: dict(type='SyncBN').
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None.
+ """
+
+ def __init__(self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ act_cfg=dict(type='GELU'),
+ dw_act_cfg=dict(type='GELU'),
+ norm_cfg=dict(type='SyncBN'),
+ init_cfg=None):
+ super().__init__(init_cfg=init_cfg)
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1)
+ self.act1 = build_activation_layer(act_cfg)
+ self.norm1 = build_norm_layer(norm_cfg, hidden_features)[1]
+ self.dw3x3 = nn.Conv2d(
+ hidden_features,
+ hidden_features,
+ kernel_size=3,
+ stride=1,
+ groups=hidden_features,
+ padding=1)
+ self.act2 = build_activation_layer(dw_act_cfg)
+ self.norm2 = build_norm_layer(norm_cfg, hidden_features)[1]
+ self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1)
+ self.act3 = build_activation_layer(act_cfg)
+ self.norm3 = build_norm_layer(norm_cfg, out_features)[1]
+
+ # put the modules together
+ self.layers = [
+ self.fc1, self.norm1, self.act1, self.dw3x3, self.norm2, self.act2,
+ self.fc2, self.norm3, self.act3
+ ]
+
+ def forward(self, x, H, W):
+ """Forward function."""
+ x = nlc_to_nchw(x, (H, W))
+ for layer in self.layers:
+ x = layer(x)
+ x = nchw_to_nlc(x)
+ return x
+
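+# Minimal usage sketch for CrossFFN (illustrative only; assumes a single-GPU
+# setting where plain 'BN' is used instead of the default 'SyncBN'):
+#   ffn = CrossFFN(in_features=32, hidden_features=128, norm_cfg=dict(type='BN'))
+#   tokens = torch.rand(2, 64, 32)   # (B, H*W, C) with H = W = 8
+#   out = ffn(tokens, 8, 8)          # -> (2, 64, 32)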
+
+class HRFormerBlock(BaseModule):
+ """High-Resolution Block for HRFormer.
+
+ Args:
+ in_features (int): The input dimension.
+ out_features (int): The output dimension.
+ num_heads (int): The number of head within each LSA.
+ window_size (int, optional): The window size for the LSA.
+ Default: 7
+ mlp_ratio (float, optional): The expansion ratio of the FFN.
+ Default: 4.0
+ drop_path (float, optional): The drop path rate of the block.
+ Default: 0.0
+ act_cfg (dict, optional): Config of activation layer.
+ Default: dict(type='GELU').
+ norm_cfg (dict, optional): Config of norm layer.
+ Default: dict(type='SyncBN').
+ transformer_norm_cfg (dict, optional): Config of transformer norm
+ layer. Default: dict(type='LN', eps=1e-6).
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None.
+ """
+
+ expansion = 1
+
+ def __init__(self,
+ in_features,
+ out_features,
+ num_heads,
+ window_size=7,
+ mlp_ratio=4.0,
+ drop_path=0.0,
+ act_cfg=dict(type='GELU'),
+ norm_cfg=dict(type='SyncBN'),
+ transformer_norm_cfg=dict(type='LN', eps=1e-6),
+ init_cfg=None,
+ **kwargs):
+ super(HRFormerBlock, self).__init__(init_cfg=init_cfg)
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.mlp_ratio = mlp_ratio
+
+ self.norm1 = build_norm_layer(transformer_norm_cfg, in_features)[1]
+ self.attn = LocalWindowSelfAttention(
+ in_features,
+ num_heads=num_heads,
+ window_size=window_size,
+ init_cfg=None,
+ **kwargs)
+
+ self.norm2 = build_norm_layer(transformer_norm_cfg, out_features)[1]
+ self.ffn = CrossFFN(
+ in_features=in_features,
+ hidden_features=int(in_features * mlp_ratio),
+ out_features=out_features,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ dw_act_cfg=act_cfg,
+ init_cfg=None)
+
+ self.drop_path = build_drop_path(
+ drop_path) if drop_path > 0.0 else nn.Identity()
+
+ def forward(self, x):
+ """Forward function."""
+ B, C, H, W = x.size()
+ # Attention
+ x = x.view(B, C, -1).permute(0, 2, 1)
+ x = x + self.drop_path(self.attn(self.norm1(x), H, W))
+ # FFN
+ x = x + self.drop_path(self.ffn(self.norm2(x), H, W))
+ x = x.permute(0, 2, 1).view(B, C, H, W)
+ return x
+
+ def extra_repr(self):
+ """Return the extra information about this module."""
+ return 'num_heads={}, window_size={}, mlp_ratio={}'.format(
+ self.num_heads, self.window_size, self.mlp_ratio)
+
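+# Note on HRFormerBlock above: it is a pre-norm transformer block on flattened
+# tokens, i.e. x = x + drop_path(attn(norm1(x))) followed by
+# x = x + drop_path(ffn(norm2(x))), where attn is the local-window
+# self-attention and ffn is the depthwise-conv FFN; the (B, C, H, W) layout is
+# restored at the end of forward().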
+
+class HRFomerModule(HRModule):
+ """High-Resolution Module for HRFormer.
+
+ Args:
+ num_branches (int): The number of branches in the HRFormerModule.
+ block (nn.Module): The building block of HRFormer.
+ The block should be the HRFormerBlock.
+ num_blocks (tuple): The number of blocks in each branch.
+ The length must be equal to num_branches.
+ num_inchannels (tuple): The number of input channels in each branch.
+ The length must be equal to num_branches.
+ num_channels (tuple): The number of channels in each branch.
+ The length must be equal to num_branches.
+ num_heads (tuple): The number of heads within the LSAs.
+ num_window_sizes (tuple): The window size for the LSAs.
+ num_mlp_ratios (tuple): The expansion ratio for the FFNs.
+ drop_paths (list[float] | float, optional): The drop path rates of the
+ HRFormer blocks in this module. Default: 0.0
+ multiscale_output (bool, optional): Whether to output multi-level
+ features produced by multiple branches. If False, only the first
+ level feature will be output. Default: True.
+ conv_cfg (dict, optional): Config of the conv layers.
+ Default: None.
+ norm_cfg (dict, optional): Config of the norm layers appended
+ right after conv. Default: dict(type='SyncBN', requires_grad=True)
+ transformer_norm_cfg (dict, optional): Config of the norm layers.
+ Default: dict(type='LN', eps=1e-6)
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False
+ upsample_cfg(dict, optional): The config of upsample layers in fuse
+ layers. Default: dict(mode='bilinear', align_corners=False)
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None.
+ """
+
+ def __init__(self,
+ num_branches,
+ block,
+ num_blocks,
+ num_inchannels,
+ num_channels,
+ num_heads,
+ num_window_sizes,
+ num_mlp_ratios,
+ multiscale_output=True,
+ drop_paths=0.0,
+ with_rpe=True,
+ with_pad_mask=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='SyncBN', requires_grad=True),
+ transformer_norm_cfg=dict(type='LN', eps=1e-6),
+ with_cp=False,
+ upsample_cfg=dict(mode='bilinear', align_corners=False),
+ **kwargs):
+
+ self.transformer_norm_cfg = transformer_norm_cfg
+ self.drop_paths = drop_paths
+ self.num_heads = num_heads
+ self.num_window_sizes = num_window_sizes
+ self.num_mlp_ratios = num_mlp_ratios
+ self.with_rpe = with_rpe
+ self.with_pad_mask = with_pad_mask
+
+ super().__init__(num_branches, block, num_blocks, num_inchannels,
+ num_channels, multiscale_output, with_cp, conv_cfg,
+ norm_cfg, upsample_cfg, **kwargs)
+
+ def _make_one_branch(self,
+ branch_index,
+ block,
+ num_blocks,
+ num_channels,
+ stride=1):
+ """Build one branch."""
+ # HRFormerBlock does not support a downsample layer yet.
+ assert stride == 1 and self.in_channels[branch_index] == num_channels[
+ branch_index]
+ layers = []
+ layers.append(
+ block(
+ self.in_channels[branch_index],
+ num_channels[branch_index],
+ num_heads=self.num_heads[branch_index],
+ window_size=self.num_window_sizes[branch_index],
+ mlp_ratio=self.num_mlp_ratios[branch_index],
+ drop_path=self.drop_paths[0],
+ norm_cfg=self.norm_cfg,
+ transformer_norm_cfg=self.transformer_norm_cfg,
+ init_cfg=None,
+ with_rpe=self.with_rpe,
+ with_pad_mask=self.with_pad_mask))
+
+ self.in_channels[
+ branch_index] = self.in_channels[branch_index] * block.expansion
+ for i in range(1, num_blocks[branch_index]):
+ layers.append(
+ block(
+ self.in_channels[branch_index],
+ num_channels[branch_index],
+ num_heads=self.num_heads[branch_index],
+ window_size=self.num_window_sizes[branch_index],
+ mlp_ratio=self.num_mlp_ratios[branch_index],
+ drop_path=self.drop_paths[i],
+ norm_cfg=self.norm_cfg,
+ transformer_norm_cfg=self.transformer_norm_cfg,
+ init_cfg=None,
+ with_rpe=self.with_rpe,
+ with_pad_mask=self.with_pad_mask))
+ return nn.Sequential(*layers)
+
+ def _make_fuse_layers(self):
+ """Build fuse layers."""
+ if self.num_branches == 1:
+ return None
+ num_branches = self.num_branches
+ num_inchannels = self.in_channels
+ fuse_layers = []
+ for i in range(num_branches if self.multiscale_output else 1):
+ fuse_layer = []
+ for j in range(num_branches):
+ if j > i:
+ fuse_layer.append(
+ nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ num_inchannels[j],
+ num_inchannels[i],
+ kernel_size=1,
+ stride=1,
+ bias=False),
+ build_norm_layer(self.norm_cfg,
+ num_inchannels[i])[1],
+ nn.Upsample(
+ scale_factor=2**(j - i),
+ mode=self.upsample_cfg['mode'],
+ align_corners=self.
+ upsample_cfg['align_corners'])))
+ elif j == i:
+ fuse_layer.append(None)
+ else:
+ conv3x3s = []
+ for k in range(i - j):
+ if k == i - j - 1:
+ num_outchannels_conv3x3 = num_inchannels[i]
+ with_out_act = False
+ else:
+ num_outchannels_conv3x3 = num_inchannels[j]
+ with_out_act = True
+ sub_modules = [
+ build_conv_layer(
+ self.conv_cfg,
+ num_inchannels[j],
+ num_inchannels[j],
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ groups=num_inchannels[j],
+ bias=False,
+ ),
+ build_norm_layer(self.norm_cfg,
+ num_inchannels[j])[1],
+ build_conv_layer(
+ self.conv_cfg,
+ num_inchannels[j],
+ num_outchannels_conv3x3,
+ kernel_size=1,
+ stride=1,
+ bias=False,
+ ),
+ build_norm_layer(self.norm_cfg,
+ num_outchannels_conv3x3)[1]
+ ]
+ if with_out_act:
+ sub_modules.append(nn.ReLU(False))
+ conv3x3s.append(nn.Sequential(*sub_modules))
+ fuse_layer.append(nn.Sequential(*conv3x3s))
+ fuse_layers.append(nn.ModuleList(fuse_layer))
+
+ return nn.ModuleList(fuse_layers)
+
+ def get_num_inchannels(self):
+ """Return the number of input channels."""
+ return self.in_channels
+
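+# Note on _make_fuse_layers above: for a target branch i, a higher-resolution
+# branch j < i is downsampled by (i - j) stride-2 depthwise 3x3 + pointwise
+# 1x1 conv pairs, while a lower-resolution branch j > i is projected with a
+# 1x1 conv and upsampled by 2**(j - i) (bilinear by default in HRFormer)
+# before the per-branch summation.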
+
+@MODELS.register_module()
+class HRFormer(HRNet):
+ """HRFormer backbone.
+
+ This backbone is the implementation of `HRFormer: High-Resolution
+ Transformer for Dense Prediction <https://arxiv.org/abs/2110.09408>`_.
+
+ Args:
+ extra (dict): Detailed configuration for each stage of HRNet.
+ There must be 4 stages, the configuration for each stage must have
+ 5 keys:
+
+ - num_modules (int): The number of HRModule in this stage.
+ - num_branches (int): The number of branches in the HRModule.
+ - block (str): The type of block.
+ - num_blocks (tuple): The number of blocks in each branch.
+ The length must be equal to num_branches.
+ - num_channels (tuple): The number of channels in each branch.
+ The length must be equal to num_branches.
+ in_channels (int): Number of input image channels. Normally 3.
+ conv_cfg (dict): Dictionary to construct and config conv layer.
+ Default: None.
+ norm_cfg (dict): Config of norm layer.
+ Use `SyncBN` by default.
+ transformer_norm_cfg (dict): Config of transformer norm layer.
+ Use `LN` by default.
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Default: False.
+ zero_init_residual (bool): Whether to use zero init for last norm layer
+ in resblocks to let them behave as identity. Default: False.
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+ -1 means not freezing any parameters. Default: -1.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default:
+ ``[
+ dict(type='Normal', std=0.001, layer=['Conv2d']),
+ dict(
+ type='Constant',
+ val=1,
+ layer=['_BatchNorm', 'GroupNorm'])
+ ]``
+
+ Example:
+ >>> from mmpose.models import HRFormer
+ >>> import torch
+ >>> extra = dict(
+ >>> drop_path_rate=0.0,
+ >>> stage1=dict(
+ >>> num_modules=1,
+ >>> num_branches=1,
+ >>> block='BOTTLENECK',
+ >>> num_blocks=(2, ),
+ >>> num_channels=(64, )),
+ >>> stage2=dict(
+ >>> num_modules=1,
+ >>> num_branches=2,
+ >>> block='HRFORMERBLOCK',
+ >>> window_sizes=(7, 7),
+ >>> num_heads=(1, 2),
+ >>> mlp_ratios=(4, 4),
+ >>> num_blocks=(2, 2),
+ >>> num_channels=(32, 64)),
+ >>> stage3=dict(
+ >>> num_modules=4,
+ >>> num_branches=3,
+ >>> block='HRFORMERBLOCK',
+ >>> window_sizes=(7, 7, 7),
+ >>> num_heads=(1, 2, 4),
+ >>> mlp_ratios=(4, 4, 4),
+ >>> num_blocks=(2, 2, 2),
+ >>> num_channels=(32, 64, 128)),
+ >>> stage4=dict(
+ >>> multiscale_output=True,
+ >>> num_modules=2,
+ >>> num_branches=4,
+ >>> block='HRFORMERBLOCK',
+ >>> window_sizes=(7, 7, 7, 7),
+ >>> num_heads=(1, 2, 4, 8),
+ >>> mlp_ratios=(4, 4, 4, 4),
+ >>> num_blocks=(2, 2, 2, 2),
+ >>> num_channels=(32, 64, 128, 256)))
+ >>> self = HRFormer(extra, in_channels=1)
+ >>> self.eval()
+ >>> inputs = torch.rand(1, 1, 32, 32)
+ >>> level_outputs = self.forward(inputs)
+ >>> for level_out in level_outputs:
+ ... print(tuple(level_out.shape))
+ (1, 32, 8, 8)
+ (1, 64, 4, 4)
+ (1, 128, 2, 2)
+ (1, 256, 1, 1)
+ """
+
+ blocks_dict = {'BOTTLENECK': Bottleneck, 'HRFORMERBLOCK': HRFormerBlock}
+
+ def __init__(
+ self,
+ extra,
+ in_channels=3,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ transformer_norm_cfg=dict(type='LN', eps=1e-6),
+ norm_eval=False,
+ with_cp=False,
+ zero_init_residual=False,
+ frozen_stages=-1,
+ init_cfg=[
+ dict(type='Normal', std=0.001, layer=['Conv2d']),
+ dict(type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm'])
+ ],
+ ):
+
+ # stochastic depth
+ depths = [
+ extra[stage]['num_blocks'][0] * extra[stage]['num_modules']
+ for stage in ['stage2', 'stage3', 'stage4']
+ ]
+ depth_s2, depth_s3, _ = depths
+ drop_path_rate = extra['drop_path_rate']
+ dpr = [
+ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
+ ]
+ extra['stage2']['drop_path_rates'] = dpr[0:depth_s2]
+ extra['stage3']['drop_path_rates'] = dpr[depth_s2:depth_s2 + depth_s3]
+ extra['stage4']['drop_path_rates'] = dpr[depth_s2 + depth_s3:]
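+ # e.g. with the docstring example (depths 2, 8 and 4 for stages 2-4),
+ # torch.linspace(0, drop_path_rate, 14) is sliced as dpr[0:2] for stage2,
+ # dpr[2:10] for stage3 and dpr[10:14] for stage4, so deeper blocks get
+ # progressively larger drop path rates.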
+
+ # HRFormer uses bilinear upsampling by default
+ upsample_cfg = extra.get('upsample', {
+ 'mode': 'bilinear',
+ 'align_corners': False
+ })
+ extra['upsample'] = upsample_cfg
+ self.transformer_norm_cfg = transformer_norm_cfg
+ self.with_rpe = extra.get('with_rpe', True)
+ self.with_pad_mask = extra.get('with_pad_mask', False)
+
+ super().__init__(extra, in_channels, conv_cfg, norm_cfg, norm_eval,
+ with_cp, zero_init_residual, frozen_stages, init_cfg)
+
+ def _make_stage(self,
+ layer_config,
+ num_inchannels,
+ multiscale_output=True):
+ """Make each stage."""
+ num_modules = layer_config['num_modules']
+ num_branches = layer_config['num_branches']
+ num_blocks = layer_config['num_blocks']
+ num_channels = layer_config['num_channels']
+ block = self.blocks_dict[layer_config['block']]
+ num_heads = layer_config['num_heads']
+ num_window_sizes = layer_config['window_sizes']
+ num_mlp_ratios = layer_config['mlp_ratios']
+ drop_path_rates = layer_config['drop_path_rates']
+
+ modules = []
+ for i in range(num_modules):
+ # multiscale_output is only used at the last module
+ if not multiscale_output and i == num_modules - 1:
+ reset_multiscale_output = False
+ else:
+ reset_multiscale_output = True
+
+ modules.append(
+ HRFomerModule(
+ num_branches,
+ block,
+ num_blocks,
+ num_inchannels,
+ num_channels,
+ num_heads,
+ num_window_sizes,
+ num_mlp_ratios,
+ reset_multiscale_output,
+ drop_paths=drop_path_rates[num_blocks[0] *
+ i:num_blocks[0] * (i + 1)],
+ with_rpe=self.with_rpe,
+ with_pad_mask=self.with_pad_mask,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ transformer_norm_cfg=self.transformer_norm_cfg,
+ with_cp=self.with_cp,
+ upsample_cfg=self.upsample_cfg))
+ num_inchannels = modules[-1].get_num_inchannels()
+
+ return nn.Sequential(*modules), num_inchannels
diff --git a/mmpose/models/backbones/hrnet.py b/mmpose/models/backbones/hrnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..381b22d60ec886ecb6d8c52fc9e7ccab52c05e99
--- /dev/null
+++ b/mmpose/models/backbones/hrnet.py
@@ -0,0 +1,610 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+
+import torch.nn as nn
+from mmcv.cnn import build_conv_layer, build_norm_layer
+from mmengine.model import BaseModule, constant_init
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from mmpose.registry import MODELS
+from .base_backbone import BaseBackbone
+from .resnet import BasicBlock, Bottleneck, get_expansion
+
+
+class HRModule(BaseModule):
+ """High-Resolution Module for HRNet.
+
+ In this module, each branch stacks several BasicBlocks/Bottlenecks, and the
+ multi-resolution fusion (exchange) across branches is also performed here.
+ """
+
+ def __init__(self,
+ num_branches,
+ blocks,
+ num_blocks,
+ in_channels,
+ num_channels,
+ multiscale_output=False,
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ upsample_cfg=dict(mode='nearest', align_corners=None),
+ init_cfg=None):
+
+ # Protect mutable default arguments
+ norm_cfg = copy.deepcopy(norm_cfg)
+ super().__init__(init_cfg=init_cfg)
+ self._check_branches(num_branches, num_blocks, in_channels,
+ num_channels)
+
+ self.in_channels = in_channels
+ self.num_branches = num_branches
+
+ self.multiscale_output = multiscale_output
+ self.norm_cfg = norm_cfg
+ self.conv_cfg = conv_cfg
+ self.upsample_cfg = upsample_cfg
+ self.with_cp = with_cp
+ self.branches = self._make_branches(num_branches, blocks, num_blocks,
+ num_channels)
+ self.fuse_layers = self._make_fuse_layers()
+ self.relu = nn.ReLU(inplace=True)
+
+ @staticmethod
+ def _check_branches(num_branches, num_blocks, in_channels, num_channels):
+ """Check input to avoid ValueError."""
+ if num_branches != len(num_blocks):
+ error_msg = f'NUM_BRANCHES({num_branches}) ' \
+ f'!= NUM_BLOCKS({len(num_blocks)})'
+ raise ValueError(error_msg)
+
+ if num_branches != len(num_channels):
+ error_msg = f'NUM_BRANCHES({num_branches}) ' \
+ f'!= NUM_CHANNELS({len(num_channels)})'
+ raise ValueError(error_msg)
+
+ if num_branches != len(in_channels):
+ error_msg = f'NUM_BRANCHES({num_branches}) ' \
+ f'!= NUM_INCHANNELS({len(in_channels)})'
+ raise ValueError(error_msg)
+
+ def _make_one_branch(self,
+ branch_index,
+ block,
+ num_blocks,
+ num_channels,
+ stride=1):
+ """Make one branch."""
+ downsample = None
+ if stride != 1 or \
+ self.in_channels[branch_index] != \
+ num_channels[branch_index] * get_expansion(block):
+ downsample = nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ self.in_channels[branch_index],
+ num_channels[branch_index] * get_expansion(block),
+ kernel_size=1,
+ stride=stride,
+ bias=False),
+ build_norm_layer(
+ self.norm_cfg,
+ num_channels[branch_index] * get_expansion(block))[1])
+
+ layers = []
+ layers.append(
+ block(
+ self.in_channels[branch_index],
+ num_channels[branch_index] * get_expansion(block),
+ stride=stride,
+ downsample=downsample,
+ with_cp=self.with_cp,
+ norm_cfg=self.norm_cfg,
+ conv_cfg=self.conv_cfg))
+ self.in_channels[branch_index] = \
+ num_channels[branch_index] * get_expansion(block)
+ for _ in range(1, num_blocks[branch_index]):
+ layers.append(
+ block(
+ self.in_channels[branch_index],
+ num_channels[branch_index] * get_expansion(block),
+ with_cp=self.with_cp,
+ norm_cfg=self.norm_cfg,
+ conv_cfg=self.conv_cfg))
+
+ return nn.Sequential(*layers)
+
+ def _make_branches(self, num_branches, block, num_blocks, num_channels):
+ """Make branches."""
+ branches = []
+
+ for i in range(num_branches):
+ branches.append(
+ self._make_one_branch(i, block, num_blocks, num_channels))
+
+ return nn.ModuleList(branches)
+
+ def _make_fuse_layers(self):
+ """Make fuse layer."""
+ if self.num_branches == 1:
+ return None
+
+ num_branches = self.num_branches
+ in_channels = self.in_channels
+ fuse_layers = []
+ num_out_branches = num_branches if self.multiscale_output else 1
+
+ for i in range(num_out_branches):
+ fuse_layer = []
+ for j in range(num_branches):
+ if j > i:
+ fuse_layer.append(
+ nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ in_channels[j],
+ in_channels[i],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=False),
+ build_norm_layer(self.norm_cfg, in_channels[i])[1],
+ nn.Upsample(
+ scale_factor=2**(j - i),
+ mode=self.upsample_cfg['mode'],
+ align_corners=self.
+ upsample_cfg['align_corners'])))
+ elif j == i:
+ fuse_layer.append(None)
+ else:
+ conv_downsamples = []
+ for k in range(i - j):
+ if k == i - j - 1:
+ conv_downsamples.append(
+ nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ in_channels[j],
+ in_channels[i],
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False),
+ build_norm_layer(self.norm_cfg,
+ in_channels[i])[1]))
+ else:
+ conv_downsamples.append(
+ nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ in_channels[j],
+ in_channels[j],
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False),
+ build_norm_layer(self.norm_cfg,
+ in_channels[j])[1],
+ nn.ReLU(inplace=True)))
+ fuse_layer.append(nn.Sequential(*conv_downsamples))
+ fuse_layers.append(nn.ModuleList(fuse_layer))
+
+ return nn.ModuleList(fuse_layers)
+
+ def forward(self, x):
+ """Forward function."""
+ if self.num_branches == 1:
+ return [self.branches[0](x[0])]
+
+ for i in range(self.num_branches):
+ x[i] = self.branches[i](x[i])
+
+ x_fuse = []
+ for i in range(len(self.fuse_layers)):
+ y = 0
+ for j in range(self.num_branches):
+ if i == j:
+ y += x[j]
+ else:
+ y += self.fuse_layers[i][j](x[j])
+ x_fuse.append(self.relu(y))
+ return x_fuse
+
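+# Note on HRModule.forward above: after each branch processes its own
+# resolution, output i is ReLU(sum_j fuse_layers[i][j](x[j])), with the j == i
+# term taken as the identity; higher-resolution inputs are strided down and
+# lower-resolution inputs are upsampled so every term matches resolution i.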
+
+@MODELS.register_module()
+class HRNet(BaseBackbone):
+ """HRNet backbone.
+
+ `High-Resolution Representations for Labeling Pixels and Regions
+ <https://arxiv.org/abs/1904.04514>`__
+
+ Args:
+ extra (dict): detailed configuration for each stage of HRNet.
+ in_channels (int): Number of input image channels. Default: 3.
+ conv_cfg (dict): dictionary to construct and config conv layer.
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Default: False
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed.
+ zero_init_residual (bool): whether to use zero init for last norm layer
+ in resblocks to let them behave as identity.
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+ -1 means not freezing any parameters. Default: -1.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default:
+ ``[
+ dict(type='Normal', std=0.001, layer=['Conv2d']),
+ dict(
+ type='Constant',
+ val=1,
+ layer=['_BatchNorm', 'GroupNorm'])
+ ]``
+
+ Example:
+ >>> from mmpose.models import HRNet
+ >>> import torch
+ >>> extra = dict(
+ >>> stage1=dict(
+ >>> num_modules=1,
+ >>> num_branches=1,
+ >>> block='BOTTLENECK',
+ >>> num_blocks=(4, ),
+ >>> num_channels=(64, )),
+ >>> stage2=dict(
+ >>> num_modules=1,
+ >>> num_branches=2,
+ >>> block='BASIC',
+ >>> num_blocks=(4, 4),
+ >>> num_channels=(32, 64)),
+ >>> stage3=dict(
+ >>> num_modules=4,
+ >>> num_branches=3,
+ >>> block='BASIC',
+ >>> num_blocks=(4, 4, 4),
+ >>> num_channels=(32, 64, 128)),
+ >>> stage4=dict(
+ >>> num_modules=3,
+ >>> num_branches=4,
+ >>> block='BASIC',
+ >>> num_blocks=(4, 4, 4, 4),
+ >>> num_channels=(32, 64, 128, 256)))
+ >>> self = HRNet(extra, in_channels=1)
+ >>> self.eval()
+ >>> inputs = torch.rand(1, 1, 32, 32)
+ >>> level_outputs = self.forward(inputs)
+ >>> for level_out in level_outputs:
+ ... print(tuple(level_out.shape))
+ (1, 32, 8, 8)
+ """
+
+ blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck}
+
+ def __init__(
+ self,
+ extra,
+ in_channels=3,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ norm_eval=False,
+ with_cp=False,
+ zero_init_residual=False,
+ frozen_stages=-1,
+ init_cfg=[
+ dict(type='Normal', std=0.001, layer=['Conv2d']),
+ dict(type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm'])
+ ],
+ ):
+ # Protect mutable default arguments
+ norm_cfg = copy.deepcopy(norm_cfg)
+ super().__init__(init_cfg=init_cfg)
+ self.extra = extra
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.init_cfg = init_cfg
+ self.norm_eval = norm_eval
+ self.with_cp = with_cp
+ self.zero_init_residual = zero_init_residual
+ self.frozen_stages = frozen_stages
+
+ # stem net
+ self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
+ self.norm2_name, norm2 = build_norm_layer(self.norm_cfg, 64, postfix=2)
+
+ self.conv1 = build_conv_layer(
+ self.conv_cfg,
+ in_channels,
+ 64,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False)
+
+ self.add_module(self.norm1_name, norm1)
+ self.conv2 = build_conv_layer(
+ self.conv_cfg,
+ 64,
+ 64,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False)
+
+ self.add_module(self.norm2_name, norm2)
+ self.relu = nn.ReLU(inplace=True)
+
+ self.upsample_cfg = self.extra.get('upsample', {
+ 'mode': 'nearest',
+ 'align_corners': None
+ })
+
+ # stage 1
+ self.stage1_cfg = self.extra['stage1']
+ num_channels = self.stage1_cfg['num_channels'][0]
+ block_type = self.stage1_cfg['block']
+ num_blocks = self.stage1_cfg['num_blocks'][0]
+
+ block = self.blocks_dict[block_type]
+ stage1_out_channels = num_channels * get_expansion(block)
+ self.layer1 = self._make_layer(block, 64, stage1_out_channels,
+ num_blocks)
+
+ # stage 2
+ self.stage2_cfg = self.extra['stage2']
+ num_channels = self.stage2_cfg['num_channels']
+ block_type = self.stage2_cfg['block']
+
+ block = self.blocks_dict[block_type]
+ num_channels = [
+ channel * get_expansion(block) for channel in num_channels
+ ]
+ self.transition1 = self._make_transition_layer([stage1_out_channels],
+ num_channels)
+ self.stage2, pre_stage_channels = self._make_stage(
+ self.stage2_cfg, num_channels)
+
+ # stage 3
+ self.stage3_cfg = self.extra['stage3']
+ num_channels = self.stage3_cfg['num_channels']
+ block_type = self.stage3_cfg['block']
+
+ block = self.blocks_dict[block_type]
+ num_channels = [
+ channel * get_expansion(block) for channel in num_channels
+ ]
+ self.transition2 = self._make_transition_layer(pre_stage_channels,
+ num_channels)
+ self.stage3, pre_stage_channels = self._make_stage(
+ self.stage3_cfg, num_channels)
+
+ # stage 4
+ self.stage4_cfg = self.extra['stage4']
+ num_channels = self.stage4_cfg['num_channels']
+ block_type = self.stage4_cfg['block']
+
+ block = self.blocks_dict[block_type]
+ num_channels = [
+ channel * get_expansion(block) for channel in num_channels
+ ]
+ self.transition3 = self._make_transition_layer(pre_stage_channels,
+ num_channels)
+
+ self.stage4, pre_stage_channels = self._make_stage(
+ self.stage4_cfg,
+ num_channels,
+ multiscale_output=self.stage4_cfg.get('multiscale_output', False))
+
+ self._freeze_stages()
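+ # At this point the stem (conv1/conv2, both stride 2, i.e. 4x downsampling)
+ # and the four stages have been built; _freeze_stages() freezes the stem and
+ # the first `frozen_stages` stages when frozen_stages >= 0.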
+
+ @property
+ def norm1(self):
+ """nn.Module: the normalization layer named "norm1" """
+ return getattr(self, self.norm1_name)
+
+ @property
+ def norm2(self):
+ """nn.Module: the normalization layer named "norm2" """
+ return getattr(self, self.norm2_name)
+
+ def _make_transition_layer(self, num_channels_pre_layer,
+ num_channels_cur_layer):
+ """Make transition layer."""
+ num_branches_cur = len(num_channels_cur_layer)
+ num_branches_pre = len(num_channels_pre_layer)
+
+ transition_layers = []
+ for i in range(num_branches_cur):
+ if i < num_branches_pre:
+ if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
+ transition_layers.append(
+ nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ num_channels_pre_layer[i],
+ num_channels_cur_layer[i],
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False),
+ build_norm_layer(self.norm_cfg,
+ num_channels_cur_layer[i])[1],
+ nn.ReLU(inplace=True)))
+ else:
+ transition_layers.append(None)
+ else:
+ conv_downsamples = []
+ for j in range(i + 1 - num_branches_pre):
+ in_channels = num_channels_pre_layer[-1]
+ out_channels = num_channels_cur_layer[i] \
+ if j == i - num_branches_pre else in_channels
+ conv_downsamples.append(
+ nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False),
+ build_norm_layer(self.norm_cfg, out_channels)[1],
+ nn.ReLU(inplace=True)))
+ transition_layers.append(nn.Sequential(*conv_downsamples))
+
+ return nn.ModuleList(transition_layers)
+
+ def _make_layer(self, block, in_channels, out_channels, blocks, stride=1):
+ """Make layer."""
+ downsample = None
+ if stride != 1 or in_channels != out_channels:
+ downsample = nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=stride,
+ bias=False),
+ build_norm_layer(self.norm_cfg, out_channels)[1])
+
+ layers = []
+ layers.append(
+ block(
+ in_channels,
+ out_channels,
+ stride=stride,
+ downsample=downsample,
+ with_cp=self.with_cp,
+ norm_cfg=self.norm_cfg,
+ conv_cfg=self.conv_cfg))
+ for _ in range(1, blocks):
+ layers.append(
+ block(
+ out_channels,
+ out_channels,
+ with_cp=self.with_cp,
+ norm_cfg=self.norm_cfg,
+ conv_cfg=self.conv_cfg))
+
+ return nn.Sequential(*layers)
+
+ def _make_stage(self, layer_config, in_channels, multiscale_output=True):
+ """Make stage."""
+ num_modules = layer_config['num_modules']
+ num_branches = layer_config['num_branches']
+ num_blocks = layer_config['num_blocks']
+ num_channels = layer_config['num_channels']
+ block = self.blocks_dict[layer_config['block']]
+
+ hr_modules = []
+ for i in range(num_modules):
+ # multi_scale_output is only used for the last module
+ if not multiscale_output and i == num_modules - 1:
+ reset_multiscale_output = False
+ else:
+ reset_multiscale_output = True
+
+ hr_modules.append(
+ HRModule(
+ num_branches,
+ block,
+ num_blocks,
+ in_channels,
+ num_channels,
+ reset_multiscale_output,
+ with_cp=self.with_cp,
+ norm_cfg=self.norm_cfg,
+ conv_cfg=self.conv_cfg,
+ upsample_cfg=self.upsample_cfg))
+
+ in_channels = hr_modules[-1].in_channels
+
+ return nn.Sequential(*hr_modules), in_channels
+
+ def _freeze_stages(self):
+ """Freeze parameters."""
+ if self.frozen_stages >= 0:
+ self.norm1.eval()
+ self.norm2.eval()
+
+ for m in [self.conv1, self.norm1, self.conv2, self.norm2]:
+ for param in m.parameters():
+ param.requires_grad = False
+
+ for i in range(1, self.frozen_stages + 1):
+ if i == 1:
+ m = getattr(self, 'layer1')
+ else:
+ m = getattr(self, f'stage{i}')
+
+ m.eval()
+ for param in m.parameters():
+ param.requires_grad = False
+
+ if i < 4:
+ m = getattr(self, f'transition{i}')
+ m.eval()
+ for param in m.parameters():
+ param.requires_grad = False
+
+ def init_weights(self):
+ """Initialize the weights in backbone."""
+ super(HRNet, self).init_weights()
+
+ if (isinstance(self.init_cfg, dict)
+ and self.init_cfg['type'] == 'Pretrained'):
+ # Suppress zero_init_residual if use pretrained model.
+ return
+
+ if self.zero_init_residual:
+ for m in self.modules():
+ if isinstance(m, Bottleneck):
+ constant_init(m.norm3, 0)
+ elif isinstance(m, BasicBlock):
+ constant_init(m.norm2, 0)
+
+ def forward(self, x):
+ """Forward function."""
+ x = self.conv1(x)
+ x = self.norm1(x)
+ x = self.relu(x)
+ x = self.conv2(x)
+ x = self.norm2(x)
+ x = self.relu(x)
+ x = self.layer1(x)
+
+ x_list = []
+ for i in range(self.stage2_cfg['num_branches']):
+ if self.transition1[i] is not None:
+ x_list.append(self.transition1[i](x))
+ else:
+ x_list.append(x)
+ y_list = self.stage2(x_list)
+
+ x_list = []
+ for i in range(self.stage3_cfg['num_branches']):
+ if self.transition2[i] is not None:
+ x_list.append(self.transition2[i](y_list[-1]))
+ else:
+ x_list.append(y_list[i])
+ y_list = self.stage3(x_list)
+
+ x_list = []
+ for i in range(self.stage4_cfg['num_branches']):
+ if self.transition3[i] is not None:
+ x_list.append(self.transition3[i](y_list[-1]))
+ else:
+ x_list.append(y_list[i])
+ y_list = self.stage4(x_list)
+
+ return tuple(y_list)
+
+ def train(self, mode=True):
+ """Convert the model into training mode."""
+ super().train(mode)
+ self._freeze_stages()
+ if mode and self.norm_eval:
+ for m in self.modules():
+ if isinstance(m, _BatchNorm):
+ m.eval()
diff --git a/mmpose/models/backbones/litehrnet.py b/mmpose/models/backbones/litehrnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ad5f63014553129a02ca3dc4bfda4c181fcd6a6
--- /dev/null
+++ b/mmpose/models/backbones/litehrnet.py
@@ -0,0 +1,999 @@
+# ------------------------------------------------------------------------------
+# Adapted from https://github.com/HRNet/Lite-HRNet
+# Original licence: Apache License 2.0.
+# ------------------------------------------------------------------------------
+
+import mmengine
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as cp
+from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule,
+ build_conv_layer, build_norm_layer)
+from mmengine.model import BaseModule
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from mmpose.registry import MODELS
+from .base_backbone import BaseBackbone
+from .utils import channel_shuffle
+
+
+class SpatialWeighting(BaseModule):
+ """Spatial weighting module.
+
+ Args:
+ channels (int): The channels of the module.
+ ratio (int): channel reduction ratio.
+ conv_cfg (dict): Config dict for convolution layer.
+ Default: None, which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: None.
+ act_cfg (dict): Config dict for activation layer.
+ Default: (dict(type='ReLU'), dict(type='Sigmoid')).
+ The last ConvModule uses Sigmoid by default.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self,
+ channels,
+ ratio=16,
+ conv_cfg=None,
+ norm_cfg=None,
+ act_cfg=(dict(type='ReLU'), dict(type='Sigmoid')),
+ init_cfg=None):
+ super().__init__(init_cfg=init_cfg)
+ if isinstance(act_cfg, dict):
+ act_cfg = (act_cfg, act_cfg)
+ assert len(act_cfg) == 2
+ assert mmengine.is_tuple_of(act_cfg, dict)
+ self.global_avgpool = nn.AdaptiveAvgPool2d(1)
+ self.conv1 = ConvModule(
+ in_channels=channels,
+ out_channels=int(channels / ratio),
+ kernel_size=1,
+ stride=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg[0])
+ self.conv2 = ConvModule(
+ in_channels=int(channels / ratio),
+ out_channels=channels,
+ kernel_size=1,
+ stride=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg[1])
+
+ def forward(self, x):
+ out = self.global_avgpool(x)
+ out = self.conv1(out)
+ out = self.conv2(out)
+ return x * out
+
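+# Note on SpatialWeighting above: it is a squeeze-and-excitation style channel
+# gate: global average pooling, a 1x1 conv with ReLU reducing channels by
+# `ratio`, a 1x1 conv with Sigmoid restoring them, and a channel-wise
+# multiplication with the input.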
+
+class CrossResolutionWeighting(BaseModule):
+ """Cross-resolution channel weighting module.
+
+ Args:
+ channels (int): The channels of the module.
+ ratio (int): channel reduction ratio.
+ conv_cfg (dict): Config dict for convolution layer.
+ Default: None, which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: None.
+ act_cfg (dict): Config dict for activation layer.
+ Default: (dict(type='ReLU'), dict(type='Sigmoid')).
+ The last ConvModule uses Sigmoid by default.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self,
+ channels,
+ ratio=16,
+ conv_cfg=None,
+ norm_cfg=None,
+ act_cfg=(dict(type='ReLU'), dict(type='Sigmoid')),
+ init_cfg=None):
+ super().__init__(init_cfg=init_cfg)
+ if isinstance(act_cfg, dict):
+ act_cfg = (act_cfg, act_cfg)
+ assert len(act_cfg) == 2
+ assert mmengine.is_tuple_of(act_cfg, dict)
+ self.channels = channels
+ total_channel = sum(channels)
+ self.conv1 = ConvModule(
+ in_channels=total_channel,
+ out_channels=int(total_channel / ratio),
+ kernel_size=1,
+ stride=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg[0])
+ self.conv2 = ConvModule(
+ in_channels=int(total_channel / ratio),
+ out_channels=total_channel,
+ kernel_size=1,
+ stride=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg[1])
+
+ def forward(self, x):
+ mini_size = x[-1].size()[-2:]
+ out = [F.adaptive_avg_pool2d(s, mini_size) for s in x[:-1]] + [x[-1]]
+ out = torch.cat(out, dim=1)
+ out = self.conv1(out)
+ out = self.conv2(out)
+ out = torch.split(out, self.channels, dim=1)
+ out = [
+ s * F.interpolate(a, size=s.size()[-2:], mode='nearest')
+ for s, a in zip(x, out)
+ ]
+ return out
+
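+# Note on CrossResolutionWeighting.forward above: every branch is adaptively
+# average-pooled to the spatial size of the smallest branch, the pooled maps
+# are concatenated along channels, passed through the two 1x1 convs (ReLU then
+# Sigmoid), and the resulting weights are split per branch and upsampled
+# (nearest) back to each branch's resolution before the element-wise multiply.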
+
+class ConditionalChannelWeighting(BaseModule):
+ """Conditional channel weighting block.
+
+ Args:
+ in_channels (int): The input channels of the block.
+ stride (int): Stride of the 3x3 convolution layer.
+ reduce_ratio (int): channel reduction ratio.
+ conv_cfg (dict): Config dict for convolution layer.
+ Default: None, which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='BN').
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self,
+ in_channels,
+ stride,
+ reduce_ratio,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ with_cp=False,
+ init_cfg=None):
+ super().__init__(init_cfg=init_cfg)
+ self.with_cp = with_cp
+ self.stride = stride
+ assert stride in [1, 2]
+
+ branch_channels = [channel // 2 for channel in in_channels]
+
+ self.cross_resolution_weighting = CrossResolutionWeighting(
+ branch_channels,
+ ratio=reduce_ratio,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg)
+
+ self.depthwise_convs = nn.ModuleList([
+ ConvModule(
+ channel,
+ channel,
+ kernel_size=3,
+ stride=self.stride,
+ padding=1,
+ groups=channel,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=None) for channel in branch_channels
+ ])
+
+ self.spatial_weighting = nn.ModuleList([
+ SpatialWeighting(channels=channel, ratio=4)
+ for channel in branch_channels
+ ])
+
+ def forward(self, x):
+
+ def _inner_forward(x):
+ x = [s.chunk(2, dim=1) for s in x]
+ x1 = [s[0] for s in x]
+ x2 = [s[1] for s in x]
+
+ x2 = self.cross_resolution_weighting(x2)
+ x2 = [dw(s) for s, dw in zip(x2, self.depthwise_convs)]
+ x2 = [sw(s) for s, sw in zip(x2, self.spatial_weighting)]
+
+ out = [torch.cat([s1, s2], dim=1) for s1, s2 in zip(x1, x2)]
+ out = [channel_shuffle(s, 2) for s in out]
+
+ return out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ return out
+
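+# Note on ConditionalChannelWeighting.forward above: each branch is split in
+# half along channels; one half goes through cross-resolution weighting, a
+# depthwise 3x3 conv and spatial weighting, is re-concatenated with the
+# untouched half, and the result is channel-shuffled with 2 groups
+# (ShuffleNetV2 style).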
+
+class Stem(BaseModule):
+ """Stem network block.
+
+ Args:
+ in_channels (int): The input channels of the block.
+ stem_channels (int): Output channels of the stem layer.
+ out_channels (int): The output channels of the block.
+ expand_ratio (int): adjusts number of channels of the hidden layer
+ in InvertedResidual by this amount.
+ conv_cfg (dict): Config dict for convolution layer.
+ Default: None, which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='BN').
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self,
+ in_channels,
+ stem_channels,
+ out_channels,
+ expand_ratio,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ with_cp=False,
+ init_cfg=None):
+ super().__init__(init_cfg=init_cfg)
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.with_cp = with_cp
+
+ self.conv1 = ConvModule(
+ in_channels=in_channels,
+ out_channels=stem_channels,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=dict(type='ReLU'))
+
+ mid_channels = int(round(stem_channels * expand_ratio))
+ branch_channels = stem_channels // 2
+ if stem_channels == self.out_channels:
+ inc_channels = self.out_channels - branch_channels
+ else:
+ inc_channels = self.out_channels - stem_channels
+
+ self.branch1 = nn.Sequential(
+ ConvModule(
+ branch_channels,
+ branch_channels,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ groups=branch_channels,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=None),
+ ConvModule(
+ branch_channels,
+ inc_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=dict(type='ReLU')),
+ )
+
+ self.expand_conv = ConvModule(
+ branch_channels,
+ mid_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=dict(type='ReLU'))
+ self.depthwise_conv = ConvModule(
+ mid_channels,
+ mid_channels,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ groups=mid_channels,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=None)
+ self.linear_conv = ConvModule(
+ mid_channels,
+ branch_channels
+ if stem_channels == self.out_channels else stem_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=dict(type='ReLU'))
+
+ def forward(self, x):
+
+ def _inner_forward(x):
+ x = self.conv1(x)
+ x1, x2 = x.chunk(2, dim=1)
+
+ x2 = self.expand_conv(x2)
+ x2 = self.depthwise_conv(x2)
+ x2 = self.linear_conv(x2)
+
+ out = torch.cat((self.branch1(x1), x2), dim=1)
+
+ out = channel_shuffle(out, 2)
+
+ return out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ return out
+
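+# Note on Stem.forward above: conv1 (stride 2) halves the resolution, the
+# result is split channel-wise into two branches that are each downsampled by
+# another stride-2 depthwise conv, and their concatenation is channel-shuffled,
+# so the stem reduces the input resolution by 4x overall.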
+
+class IterativeHead(BaseModule):
+ """Extra iterative head for feature learning.
+
+ Args:
+ in_channels (int): The input channels of the block.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='BN').
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self, in_channels, norm_cfg=dict(type='BN'), init_cfg=None):
+ super().__init__(init_cfg=init_cfg)
+ projects = []
+ num_branchs = len(in_channels)
+ self.in_channels = in_channels[::-1]
+
+ for i in range(num_branchs):
+ if i != num_branchs - 1:
+ projects.append(
+ DepthwiseSeparableConvModule(
+ in_channels=self.in_channels[i],
+ out_channels=self.in_channels[i + 1],
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ norm_cfg=norm_cfg,
+ act_cfg=dict(type='ReLU'),
+ dw_act_cfg=None,
+ pw_act_cfg=dict(type='ReLU')))
+ else:
+ projects.append(
+ DepthwiseSeparableConvModule(
+ in_channels=self.in_channels[i],
+ out_channels=self.in_channels[i],
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ norm_cfg=norm_cfg,
+ act_cfg=dict(type='ReLU'),
+ dw_act_cfg=None,
+ pw_act_cfg=dict(type='ReLU')))
+ self.projects = nn.ModuleList(projects)
+
+ def forward(self, x):
+ x = x[::-1]
+
+ y = []
+ last_x = None
+ for i, s in enumerate(x):
+ if last_x is not None:
+ last_x = F.interpolate(
+ last_x,
+ size=s.size()[-2:],
+ mode='bilinear',
+ align_corners=True)
+ s = s + last_x
+ s = self.projects[i](s)
+ y.append(s)
+ last_x = s
+
+ return y[::-1]
+
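+# Note on IterativeHead.forward above: branches are visited from lowest to
+# highest resolution; each previous output is bilinearly upsampled to the
+# current branch's size and added before the depthwise-separable projection,
+# and the refined features are returned in the original high-to-low order.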
+
+class ShuffleUnit(BaseModule):
+ """InvertedResidual block for ShuffleNetV2 backbone.
+
+ Args:
+ in_channels (int): The input channels of the block.
+ out_channels (int): The output channels of the block.
+ stride (int): Stride of the 3x3 convolution layer. Default: 1
+ conv_cfg (dict): Config dict for convolution layer.
+ Default: None, which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='BN').
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='ReLU').
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ stride=1,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ with_cp=False,
+ init_cfg=None):
+ super().__init__(init_cfg=init_cfg)
+ self.stride = stride
+ self.with_cp = with_cp
+
+ branch_features = out_channels // 2
+ if self.stride == 1:
+ assert in_channels == branch_features * 2, (
+ f'in_channels ({in_channels}) should be equal to '
+ f'branch_features * 2 ({branch_features * 2}) '
+ 'when stride is 1')
+
+ if in_channels != branch_features * 2:
+ assert self.stride != 1, (
+ f'stride ({self.stride}) should not equal 1 when '
+ f'in_channels != branch_features * 2')
+
+ if self.stride > 1:
+ self.branch1 = nn.Sequential(
+ ConvModule(
+ in_channels,
+ in_channels,
+ kernel_size=3,
+ stride=self.stride,
+ padding=1,
+ groups=in_channels,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=None),
+ ConvModule(
+ in_channels,
+ branch_features,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg),
+ )
+
+ self.branch2 = nn.Sequential(
+ ConvModule(
+ in_channels if (self.stride > 1) else branch_features,
+ branch_features,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg),
+ ConvModule(
+ branch_features,
+ branch_features,
+ kernel_size=3,
+ stride=self.stride,
+ padding=1,
+ groups=branch_features,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=None),
+ ConvModule(
+ branch_features,
+ branch_features,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg))
+
+ def forward(self, x):
+
+ def _inner_forward(x):
+ if self.stride > 1:
+ out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
+ else:
+ x1, x2 = x.chunk(2, dim=1)
+ out = torch.cat((x1, self.branch2(x2)), dim=1)
+
+ out = channel_shuffle(out, 2)
+
+ return out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ return out
+
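+# Note on ShuffleUnit.forward above: with stride 1 the input is split
+# channel-wise, only one half goes through branch2 and the halves are
+# re-concatenated; with stride > 1 both branch1 and branch2 downsample the
+# full input and their outputs are concatenated. A 2-group channel shuffle
+# follows in both cases.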
+
+class LiteHRModule(BaseModule):
+ """High-Resolution Module for LiteHRNet.
+
+ It contains conditional channel weighting blocks and
+ shuffle blocks.
+
+ Args:
+ num_branches (int): Number of branches in the module.
+ num_blocks (int): Number of blocks in the module.
+ in_channels (list(int)): Number of input image channels.
+ reduce_ratio (int): Channel reduction ratio.
+ module_type (str): 'LITE' or 'NAIVE'
+ multiscale_output (bool): Whether to output multi-scale features.
+ with_fuse (bool): Whether to use fuse layers.
+ conv_cfg (dict): dictionary to construct and config conv layer.
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self,
+ num_branches,
+ num_blocks,
+ in_channels,
+ reduce_ratio,
+ module_type,
+ multiscale_output=False,
+ with_fuse=True,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ with_cp=False,
+ init_cfg=None):
+ super().__init__(init_cfg=init_cfg)
+ self._check_branches(num_branches, in_channels)
+
+ self.in_channels = in_channels
+ self.num_branches = num_branches
+
+ self.module_type = module_type
+ self.multiscale_output = multiscale_output
+ self.with_fuse = with_fuse
+ self.norm_cfg = norm_cfg
+ self.conv_cfg = conv_cfg
+ self.with_cp = with_cp
+
+ if self.module_type.upper() == 'LITE':
+ self.layers = self._make_weighting_blocks(num_blocks, reduce_ratio)
+ elif self.module_type.upper() == 'NAIVE':
+ self.layers = self._make_naive_branches(num_branches, num_blocks)
+ else:
+ raise ValueError("module_type should be either 'LITE' or 'NAIVE'.")
+ if self.with_fuse:
+ self.fuse_layers = self._make_fuse_layers()
+ self.relu = nn.ReLU()
+
+ def _check_branches(self, num_branches, in_channels):
+ """Check input to avoid ValueError."""
+ if num_branches != len(in_channels):
+ error_msg = f'NUM_BRANCHES({num_branches}) ' \
+ f'!= NUM_INCHANNELS({len(in_channels)})'
+ raise ValueError(error_msg)
+
+ def _make_weighting_blocks(self, num_blocks, reduce_ratio, stride=1):
+ """Make channel weighting blocks."""
+ layers = []
+ for i in range(num_blocks):
+ layers.append(
+ ConditionalChannelWeighting(
+ self.in_channels,
+ stride=stride,
+ reduce_ratio=reduce_ratio,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ with_cp=self.with_cp))
+
+ return nn.Sequential(*layers)
+
+ def _make_one_branch(self, branch_index, num_blocks, stride=1):
+ """Make one branch."""
+ layers = []
+ layers.append(
+ ShuffleUnit(
+ self.in_channels[branch_index],
+ self.in_channels[branch_index],
+ stride=stride,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=dict(type='ReLU'),
+ with_cp=self.with_cp))
+ for i in range(1, num_blocks):
+ layers.append(
+ ShuffleUnit(
+ self.in_channels[branch_index],
+ self.in_channels[branch_index],
+ stride=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=dict(type='ReLU'),
+ with_cp=self.with_cp))
+
+ return nn.Sequential(*layers)
+
+ def _make_naive_branches(self, num_branches, num_blocks):
+ """Make branches."""
+ branches = []
+
+ for i in range(num_branches):
+ branches.append(self._make_one_branch(i, num_blocks))
+
+ return nn.ModuleList(branches)
+
+ def _make_fuse_layers(self):
+ """Make fuse layer."""
+ if self.num_branches == 1:
+ return None
+
+ num_branches = self.num_branches
+ in_channels = self.in_channels
+ fuse_layers = []
+ num_out_branches = num_branches if self.multiscale_output else 1
+ for i in range(num_out_branches):
+ fuse_layer = []
+ for j in range(num_branches):
+ if j > i:
+ fuse_layer.append(
+ nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ in_channels[j],
+ in_channels[i],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=False),
+ build_norm_layer(self.norm_cfg, in_channels[i])[1],
+ nn.Upsample(
+ scale_factor=2**(j - i), mode='nearest')))
+ elif j == i:
+ fuse_layer.append(None)
+ else:
+ conv_downsamples = []
+ for k in range(i - j):
+ if k == i - j - 1:
+ conv_downsamples.append(
+ nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ in_channels[j],
+ in_channels[j],
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ groups=in_channels[j],
+ bias=False),
+ build_norm_layer(self.norm_cfg,
+ in_channels[j])[1],
+ build_conv_layer(
+ self.conv_cfg,
+ in_channels[j],
+ in_channels[i],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=False),
+ build_norm_layer(self.norm_cfg,
+ in_channels[i])[1]))
+ else:
+ conv_downsamples.append(
+ nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ in_channels[j],
+ in_channels[j],
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ groups=in_channels[j],
+ bias=False),
+ build_norm_layer(self.norm_cfg,
+ in_channels[j])[1],
+ build_conv_layer(
+ self.conv_cfg,
+ in_channels[j],
+ in_channels[j],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=False),
+ build_norm_layer(self.norm_cfg,
+ in_channels[j])[1],
+ nn.ReLU(inplace=True)))
+ fuse_layer.append(nn.Sequential(*conv_downsamples))
+ fuse_layers.append(nn.ModuleList(fuse_layer))
+
+ return nn.ModuleList(fuse_layers)
+
+ def forward(self, x):
+ """Forward function."""
+ if self.num_branches == 1:
+ return [self.layers[0](x[0])]
+
+ if self.module_type.upper() == 'LITE':
+ out = self.layers(x)
+ elif self.module_type.upper() == 'NAIVE':
+ for i in range(self.num_branches):
+ x[i] = self.layers[i](x[i])
+ out = x
+
+ if self.with_fuse:
+ out_fuse = []
+ for i in range(len(self.fuse_layers)):
+ # `y = 0` will lead to decreased accuracy (0.5~1 mAP)
+ y = out[0] if i == 0 else self.fuse_layers[i][0](out[0])
+ for j in range(self.num_branches):
+ if i == j:
+ y += out[j]
+ else:
+ y += self.fuse_layers[i][j](out[j])
+ out_fuse.append(self.relu(y))
+ out = out_fuse
+ if not self.multiscale_output:
+ out = [out[0]]
+ return out
+
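+# Note on LiteHRModule above: module_type 'LITE' stacks
+# ConditionalChannelWeighting blocks that operate on all branches jointly,
+# while 'NAIVE' gives each branch its own stack of ShuffleUnits; when
+# with_fuse is True the per-branch outputs are exchanged across resolutions
+# as in HRNet before being returned.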
+
+@MODELS.register_module()
+class LiteHRNet(BaseBackbone):
+ """Lite-HRNet backbone.
+
+ `Lite-HRNet: A Lightweight High-Resolution Network
+ <https://arxiv.org/abs/2104.06403>`_.
+
+ Code adapted from 'https://github.com/HRNet/Lite-HRNet'.
+
+ Args:
+ extra (dict): detailed configuration for each stage of HRNet.
+ in_channels (int): Number of input image channels. Default: 3.
+ conv_cfg (dict): dictionary to construct and config conv layer.
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Default: False
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default:
+ ``[
+ dict(type='Normal', std=0.001, layer=['Conv2d']),
+ dict(
+ type='Constant',
+ val=1,
+ layer=['_BatchNorm', 'GroupNorm'])
+ ]``
+
+ Example:
+ >>> from mmpose.models import LiteHRNet
+ >>> import torch
+ >>> extra=dict(
+ >>> stem=dict(stem_channels=32, out_channels=32, expand_ratio=1),
+ >>> num_stages=3,
+ >>> stages_spec=dict(
+ >>> num_modules=(2, 4, 2),
+ >>> num_branches=(2, 3, 4),
+ >>> num_blocks=(2, 2, 2),
+ >>> module_type=('LITE', 'LITE', 'LITE'),
+ >>> with_fuse=(True, True, True),
+ >>> reduce_ratios=(8, 8, 8),
+ >>> num_channels=(
+ >>> (40, 80),
+ >>> (40, 80, 160),
+ >>> (40, 80, 160, 320),
+ >>> )),
+ >>> with_head=False)
+ >>> self = LiteHRNet(extra, in_channels=1)
+ >>> self.eval()
+ >>> inputs = torch.rand(1, 1, 32, 32)
+ >>> level_outputs = self.forward(inputs)
+ >>> for level_out in level_outputs:
+ ... print(tuple(level_out.shape))
+ (1, 40, 8, 8)
+ """
+
+ def __init__(self,
+ extra,
+ in_channels=3,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ norm_eval=False,
+ with_cp=False,
+ init_cfg=[
+ dict(type='Normal', std=0.001, layer=['Conv2d']),
+ dict(
+ type='Constant',
+ val=1,
+ layer=['_BatchNorm', 'GroupNorm'])
+ ]):
+ super().__init__(init_cfg=init_cfg)
+ self.extra = extra
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.norm_eval = norm_eval
+ self.with_cp = with_cp
+
+ self.stem = Stem(
+ in_channels,
+ stem_channels=self.extra['stem']['stem_channels'],
+ out_channels=self.extra['stem']['out_channels'],
+ expand_ratio=self.extra['stem']['expand_ratio'],
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg)
+
+ self.num_stages = self.extra['num_stages']
+ self.stages_spec = self.extra['stages_spec']
+
+ num_channels_last = [
+ self.stem.out_channels,
+ ]
+ for i in range(self.num_stages):
+ num_channels = self.stages_spec['num_channels'][i]
+ num_channels = [num_channels[i] for i in range(len(num_channels))]
+ setattr(
+ self, f'transition{i}',
+ self._make_transition_layer(num_channels_last, num_channels))
+
+ stage, num_channels_last = self._make_stage(
+ self.stages_spec, i, num_channels, multiscale_output=True)
+ setattr(self, f'stage{i}', stage)
+
+ self.with_head = self.extra['with_head']
+ if self.with_head:
+ self.head_layer = IterativeHead(
+ in_channels=num_channels_last,
+ norm_cfg=self.norm_cfg,
+ )
+
+ def _make_transition_layer(self, num_channels_pre_layer,
+ num_channels_cur_layer):
+ """Make transition layer."""
+ num_branches_cur = len(num_channels_cur_layer)
+ num_branches_pre = len(num_channels_pre_layer)
+
+ transition_layers = []
+ for i in range(num_branches_cur):
+ if i < num_branches_pre:
+ if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
+ transition_layers.append(
+ nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ num_channels_pre_layer[i],
+ num_channels_pre_layer[i],
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ groups=num_channels_pre_layer[i],
+ bias=False),
+ build_norm_layer(self.norm_cfg,
+ num_channels_pre_layer[i])[1],
+ build_conv_layer(
+ self.conv_cfg,
+ num_channels_pre_layer[i],
+ num_channels_cur_layer[i],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=False),
+ build_norm_layer(self.norm_cfg,
+ num_channels_cur_layer[i])[1],
+ nn.ReLU()))
+ else:
+ transition_layers.append(None)
+ else:
+ conv_downsamples = []
+ for j in range(i + 1 - num_branches_pre):
+ in_channels = num_channels_pre_layer[-1]
+ out_channels = num_channels_cur_layer[i] \
+ if j == i - num_branches_pre else in_channels
+ conv_downsamples.append(
+ nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ in_channels,
+ in_channels,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ groups=in_channels,
+ bias=False),
+ build_norm_layer(self.norm_cfg, in_channels)[1],
+ build_conv_layer(
+ self.conv_cfg,
+ in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=False),
+ build_norm_layer(self.norm_cfg, out_channels)[1],
+ nn.ReLU()))
+ transition_layers.append(nn.Sequential(*conv_downsamples))
+
+ return nn.ModuleList(transition_layers)
+
+ def _make_stage(self,
+ stages_spec,
+ stage_index,
+ in_channels,
+ multiscale_output=True):
+ num_modules = stages_spec['num_modules'][stage_index]
+ num_branches = stages_spec['num_branches'][stage_index]
+ num_blocks = stages_spec['num_blocks'][stage_index]
+ reduce_ratio = stages_spec['reduce_ratios'][stage_index]
+ with_fuse = stages_spec['with_fuse'][stage_index]
+ module_type = stages_spec['module_type'][stage_index]
+
+ modules = []
+ for i in range(num_modules):
+ # multi_scale_output is only used for the last module
+ if not multiscale_output and i == num_modules - 1:
+ reset_multiscale_output = False
+ else:
+ reset_multiscale_output = True
+
+ modules.append(
+ LiteHRModule(
+ num_branches,
+ num_blocks,
+ in_channels,
+ reduce_ratio,
+ module_type,
+ multiscale_output=reset_multiscale_output,
+ with_fuse=with_fuse,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ with_cp=self.with_cp))
+ in_channels = modules[-1].in_channels
+
+ return nn.Sequential(*modules), in_channels
+
+ def forward(self, x):
+ """Forward function."""
+ x = self.stem(x)
+
+ y_list = [x]
+ for i in range(self.num_stages):
+ x_list = []
+ transition = getattr(self, f'transition{i}')
+ for j in range(self.stages_spec['num_branches'][i]):
+ if transition[j]:
+ if j >= len(y_list):
+ x_list.append(transition[j](y_list[-1]))
+ else:
+ x_list.append(transition[j](y_list[j]))
+ else:
+ x_list.append(y_list[j])
+ y_list = getattr(self, f'stage{i}')(x_list)
+
+ x = y_list
+ if self.with_head:
+ x = self.head_layer(x)
+
+ return (x[0], )
+
+ def train(self, mode=True):
+ """Convert the model into training mode."""
+ super().train(mode)
+ if mode and self.norm_eval:
+ for m in self.modules():
+ if isinstance(m, _BatchNorm):
+ m.eval()
diff --git a/mmpose/models/backbones/mobilenet_v2.py b/mmpose/models/backbones/mobilenet_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..b64c0d73d41d3763018a8e46621c6ab695be6856
--- /dev/null
+++ b/mmpose/models/backbones/mobilenet_v2.py
@@ -0,0 +1,279 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+from mmcv.cnn import ConvModule
+from mmengine.model import BaseModule
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from mmpose.registry import MODELS
+from .base_backbone import BaseBackbone
+from .utils import make_divisible
+
+
+class InvertedResidual(BaseModule):
+ """InvertedResidual block for MobileNetV2.
+
+ Args:
+ in_channels (int): The input channels of the InvertedResidual block.
+ out_channels (int): The output channels of the InvertedResidual block.
+ stride (int): Stride of the middle (first) 3x3 convolution.
+ expand_ratio (int): adjusts number of channels of the hidden layer
+ in InvertedResidual by this amount.
+ conv_cfg (dict): Config dict for convolution layer.
+ Default: None, which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='BN').
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='ReLU6').
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ stride,
+ expand_ratio,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU6'),
+ with_cp=False,
+ init_cfg=None):
+ # Protect mutable default arguments
+ norm_cfg = copy.deepcopy(norm_cfg)
+ act_cfg = copy.deepcopy(act_cfg)
+ super().__init__(init_cfg=init_cfg)
+ self.stride = stride
+        assert stride in [1, 2], f'stride must be in [1, 2]. ' \
+ f'But received {stride}.'
+ self.with_cp = with_cp
+ self.use_res_connect = self.stride == 1 and in_channels == out_channels
+ hidden_dim = int(round(in_channels * expand_ratio))
+
+ layers = []
+ if expand_ratio != 1:
+ layers.append(
+ ConvModule(
+ in_channels=in_channels,
+ out_channels=hidden_dim,
+ kernel_size=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg))
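+        # Depthwise 3x3 conv followed by a linear (activation-free) 1x1
+        # projection; the residual connection is added in ``forward`` only
+        # when stride == 1 and the channel counts match.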
+ layers.extend([
+ ConvModule(
+ in_channels=hidden_dim,
+ out_channels=hidden_dim,
+ kernel_size=3,
+ stride=stride,
+ padding=1,
+ groups=hidden_dim,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg),
+ ConvModule(
+ in_channels=hidden_dim,
+ out_channels=out_channels,
+ kernel_size=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=None)
+ ])
+ self.conv = nn.Sequential(*layers)
+
+ def forward(self, x):
+
+ def _inner_forward(x):
+ if self.use_res_connect:
+ return x + self.conv(x)
+ return self.conv(x)
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ return out
+
+
+@MODELS.register_module()
+class MobileNetV2(BaseBackbone):
+ """MobileNetV2 backbone.
+
+ Args:
+ widen_factor (float): Width multiplier, multiply number of
+ channels in each layer by this amount. Default: 1.0.
+ out_indices (None or Sequence[int]): Output from which stages.
+ Default: (7, ).
+ frozen_stages (int): Stages to be frozen (all param fixed).
+ Default: -1, which means not freezing any parameters.
+ conv_cfg (dict): Config dict for convolution layer.
+ Default: None, which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='BN').
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='ReLU6').
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Default: False.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default:
+ ``[
+ dict(type='Kaiming', layer=['Conv2d']),
+ dict(
+ type='Constant',
+ val=1,
+ layer=['_BatchNorm', 'GroupNorm'])
+ ]``
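+
+    Example:
+        A minimal usage sketch; the printed shape assumes the default
+        ``widen_factor=1.0`` and ``out_indices=(7, )`` with a 224x224 input.
+
+        >>> from mmpose.models import MobileNetV2
+        >>> import torch
+        >>> self = MobileNetV2()
+        >>> self.eval()
+        >>> inputs = torch.rand(1, 3, 224, 224)
+        >>> level_outputs = self.forward(inputs)
+        >>> for level_out in level_outputs:
+        ...     print(tuple(level_out.shape))
+        (1, 1280, 7, 7)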
+ """
+
+ # Parameters to build layers. 4 parameters are needed to construct a
+ # layer, from left to right: expand_ratio, channel, num_blocks, stride.
+ arch_settings = [[1, 16, 1, 1], [6, 24, 2, 2], [6, 32, 3, 2],
+ [6, 64, 4, 2], [6, 96, 3, 1], [6, 160, 3, 2],
+ [6, 320, 1, 1]]
+
+ def __init__(self,
+ widen_factor=1.,
+ out_indices=(7, ),
+ frozen_stages=-1,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU6'),
+ norm_eval=False,
+ with_cp=False,
+ init_cfg=[
+ dict(type='Kaiming', layer=['Conv2d']),
+ dict(
+ type='Constant',
+ val=1,
+ layer=['_BatchNorm', 'GroupNorm'])
+ ]):
+ # Protect mutable default arguments
+ norm_cfg = copy.deepcopy(norm_cfg)
+ act_cfg = copy.deepcopy(act_cfg)
+ super().__init__(init_cfg=init_cfg)
+ self.widen_factor = widen_factor
+ self.out_indices = out_indices
+ for index in out_indices:
+ if index not in range(0, 8):
+                raise ValueError('the item in out_indices must be in '
+ f'range(0, 8). But received {index}')
+
+ if frozen_stages not in range(-1, 8):
+ raise ValueError('frozen_stages must be in range(-1, 8). '
+ f'But received {frozen_stages}')
+ self.frozen_stages = frozen_stages
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ self.norm_eval = norm_eval
+ self.with_cp = with_cp
+
+ self.in_channels = make_divisible(32 * widen_factor, 8)
+
+ self.conv1 = ConvModule(
+ in_channels=3,
+ out_channels=self.in_channels,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ self.layers = []
+
+ for i, layer_cfg in enumerate(self.arch_settings):
+ expand_ratio, channel, num_blocks, stride = layer_cfg
+ out_channels = make_divisible(channel * widen_factor, 8)
+ inverted_res_layer = self.make_layer(
+ out_channels=out_channels,
+ num_blocks=num_blocks,
+ stride=stride,
+ expand_ratio=expand_ratio)
+ layer_name = f'layer{i + 1}'
+ self.add_module(layer_name, inverted_res_layer)
+ self.layers.append(layer_name)
+
+ if widen_factor > 1.0:
+ self.out_channel = int(1280 * widen_factor)
+ else:
+ self.out_channel = 1280
+
+ layer = ConvModule(
+ in_channels=self.in_channels,
+ out_channels=self.out_channel,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.add_module('conv2', layer)
+ self.layers.append('conv2')
+
+ def make_layer(self, out_channels, num_blocks, stride, expand_ratio):
+ """Stack InvertedResidual blocks to build a layer for MobileNetV2.
+
+ Args:
+ out_channels (int): out_channels of block.
+ num_blocks (int): number of blocks.
+ stride (int): stride of the first block. Default: 1
+ expand_ratio (int): Expand the number of channels of the
+ hidden layer in InvertedResidual by this ratio. Default: 6.
+ """
+ layers = []
+ for i in range(num_blocks):
+ if i >= 1:
+ stride = 1
+ layers.append(
+ InvertedResidual(
+ self.in_channels,
+ out_channels,
+ stride,
+ expand_ratio=expand_ratio,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg,
+ with_cp=self.with_cp))
+ self.in_channels = out_channels
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x = self.conv1(x)
+
+ outs = []
+ for i, layer_name in enumerate(self.layers):
+ layer = getattr(self, layer_name)
+ x = layer(x)
+ if i in self.out_indices:
+ outs.append(x)
+
+ return tuple(outs)
+
+ def _freeze_stages(self):
+ if self.frozen_stages >= 0:
+ for param in self.conv1.parameters():
+ param.requires_grad = False
+ for i in range(1, self.frozen_stages + 1):
+ layer = getattr(self, f'layer{i}')
+ layer.eval()
+ for param in layer.parameters():
+ param.requires_grad = False
+
+ def train(self, mode=True):
+ super().train(mode)
+ self._freeze_stages()
+ if mode and self.norm_eval:
+ for m in self.modules():
+ if isinstance(m, _BatchNorm):
+ m.eval()
diff --git a/mmpose/models/backbones/mobilenet_v3.py b/mmpose/models/backbones/mobilenet_v3.py
new file mode 100644
index 0000000000000000000000000000000000000000..03ecf90dd22d42a3650a4eac00c070ec556c7912
--- /dev/null
+++ b/mmpose/models/backbones/mobilenet_v3.py
@@ -0,0 +1,185 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+
+from mmcv.cnn import ConvModule
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from mmpose.registry import MODELS
+from .base_backbone import BaseBackbone
+from .utils import InvertedResidual
+
+
+@MODELS.register_module()
+class MobileNetV3(BaseBackbone):
+ """MobileNetV3 backbone.
+
+ Args:
+        arch (str): Architecture of MobileNetV3, from {small, big}.
+ Default: small.
+ conv_cfg (dict): Config dict for convolution layer.
+ Default: None, which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='BN').
+ out_indices (None or Sequence[int]): Output from which stages.
+ Default: (-1, ), which means output tensors from final stage.
+ frozen_stages (int): Stages to be frozen (all param fixed).
+ Default: -1, which means not freezing any parameters.
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Default: False.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save
+ some memory while slowing down the training speed.
+ Default: False.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default:
+ ``[
+ dict(type='Kaiming', layer=['Conv2d']),
+ dict(
+ type='Constant',
+ val=1,
+ layer=['_BatchNorm'])
+ ]``
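+
+    Example:
+        A minimal usage sketch; the printed shape assumes ``arch='small'``,
+        the default ``out_indices=(-1, )`` and a 224x224 input.
+
+        >>> from mmpose.models import MobileNetV3
+        >>> import torch
+        >>> self = MobileNetV3(arch='small')
+        >>> self.eval()
+        >>> inputs = torch.rand(1, 3, 224, 224)
+        >>> level_outputs = self.forward(inputs)
+        >>> for level_out in level_outputs:
+        ...     print(tuple(level_out.shape))
+        (1, 96, 7, 7)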
+ """
+ # Parameters to build each block:
+ # [kernel size, mid channels, out channels, with_se, act type, stride]
+ arch_settings = {
+ 'small': [[3, 16, 16, True, 'ReLU', 2],
+ [3, 72, 24, False, 'ReLU', 2],
+ [3, 88, 24, False, 'ReLU', 1],
+ [5, 96, 40, True, 'HSwish', 2],
+ [5, 240, 40, True, 'HSwish', 1],
+ [5, 240, 40, True, 'HSwish', 1],
+ [5, 120, 48, True, 'HSwish', 1],
+ [5, 144, 48, True, 'HSwish', 1],
+ [5, 288, 96, True, 'HSwish', 2],
+ [5, 576, 96, True, 'HSwish', 1],
+ [5, 576, 96, True, 'HSwish', 1]],
+ 'big': [[3, 16, 16, False, 'ReLU', 1],
+ [3, 64, 24, False, 'ReLU', 2],
+ [3, 72, 24, False, 'ReLU', 1],
+ [5, 72, 40, True, 'ReLU', 2],
+ [5, 120, 40, True, 'ReLU', 1],
+ [5, 120, 40, True, 'ReLU', 1],
+ [3, 240, 80, False, 'HSwish', 2],
+ [3, 200, 80, False, 'HSwish', 1],
+ [3, 184, 80, False, 'HSwish', 1],
+ [3, 184, 80, False, 'HSwish', 1],
+ [3, 480, 112, True, 'HSwish', 1],
+ [3, 672, 112, True, 'HSwish', 1],
+ [5, 672, 160, True, 'HSwish', 1],
+ [5, 672, 160, True, 'HSwish', 2],
+ [5, 960, 160, True, 'HSwish', 1]]
+ } # yapf: disable
+
+ def __init__(self,
+ arch='small',
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ out_indices=(-1, ),
+ frozen_stages=-1,
+ norm_eval=False,
+ with_cp=False,
+ init_cfg=[
+ dict(type='Kaiming', layer=['Conv2d']),
+ dict(type='Constant', val=1, layer=['_BatchNorm'])
+ ]):
+ # Protect mutable default arguments
+ norm_cfg = copy.deepcopy(norm_cfg)
+ super().__init__(init_cfg=init_cfg)
+ assert arch in self.arch_settings
+ for index in out_indices:
+ if index not in range(-len(self.arch_settings[arch]),
+ len(self.arch_settings[arch])):
+                raise ValueError('the item in out_indices must be in '
+                                 f'range(-{len(self.arch_settings[arch])}, '
+                                 f'{len(self.arch_settings[arch])}). '
+                                 f'But received {index}')
+
+ if frozen_stages not in range(-1, len(self.arch_settings[arch])):
+ raise ValueError('frozen_stages must be in range(-1, '
+ f'{len(self.arch_settings[arch])}). '
+ f'But received {frozen_stages}')
+ self.arch = arch
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.out_indices = out_indices
+ self.frozen_stages = frozen_stages
+ self.norm_eval = norm_eval
+ self.with_cp = with_cp
+
+ self.in_channels = 16
+ self.conv1 = ConvModule(
+ in_channels=3,
+ out_channels=self.in_channels,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=dict(type='HSwish'))
+
+ self.layers = self._make_layer()
+ self.feat_dim = self.arch_settings[arch][-1][2]
+
+ def _make_layer(self):
+ layers = []
+ layer_setting = self.arch_settings[self.arch]
+ for i, params in enumerate(layer_setting):
+ (kernel_size, mid_channels, out_channels, with_se, act,
+ stride) = params
+ if with_se:
+ se_cfg = dict(
+ channels=mid_channels,
+ ratio=4,
+ act_cfg=(dict(type='ReLU'),
+ dict(type='HSigmoid', bias=1.0, divisor=2.0)))
+ else:
+ se_cfg = None
+
+ layer = InvertedResidual(
+ in_channels=self.in_channels,
+ out_channels=out_channels,
+ mid_channels=mid_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ se_cfg=se_cfg,
+ with_expand_conv=True,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=dict(type=act),
+ with_cp=self.with_cp)
+ self.in_channels = out_channels
+ layer_name = f'layer{i + 1}'
+ self.add_module(layer_name, layer)
+ layers.append(layer_name)
+ return layers
+
+ def forward(self, x):
+ x = self.conv1(x)
+
+ outs = []
+ for i, layer_name in enumerate(self.layers):
+ layer = getattr(self, layer_name)
+ x = layer(x)
+ if i in self.out_indices or \
+ i - len(self.layers) in self.out_indices:
+ outs.append(x)
+
+ return tuple(outs)
+
+ def _freeze_stages(self):
+ if self.frozen_stages >= 0:
+ for param in self.conv1.parameters():
+ param.requires_grad = False
+ for i in range(1, self.frozen_stages + 1):
+ layer = getattr(self, f'layer{i}')
+ layer.eval()
+ for param in layer.parameters():
+ param.requires_grad = False
+
+ def train(self, mode=True):
+ super().train(mode)
+ self._freeze_stages()
+ if mode and self.norm_eval:
+ for m in self.modules():
+ if isinstance(m, _BatchNorm):
+ m.eval()
diff --git a/mmpose/models/backbones/mspn.py b/mmpose/models/backbones/mspn.py
new file mode 100644
index 0000000000000000000000000000000000000000..bcb636b1a3fdc0357fa7dc7c3751738914d58980
--- /dev/null
+++ b/mmpose/models/backbones/mspn.py
@@ -0,0 +1,541 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy as cp
+from collections import OrderedDict
+
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule, MaxPool2d
+from mmengine.model import BaseModule
+from mmengine.runner import load_state_dict
+
+from mmpose.registry import MODELS
+from mmpose.utils import get_root_logger
+from .base_backbone import BaseBackbone
+from .resnet import Bottleneck as _Bottleneck
+from .utils import get_state_dict
+
+
+class Bottleneck(_Bottleneck):
+    """Bottleneck block for MSPN.
+
+    Args:
+        in_channels (int): Input channels of this block.
+        out_channels (int): Output channels of this block.
+        stride (int): stride of the block. Default: 1
+        downsample (nn.Module): downsample operation on identity branch.
+            Default: None
+        norm_cfg (dict): dictionary to construct and config norm layer.
+            Default: dict(type='BN')
+        init_cfg (dict or list[dict], optional): Initialization config dict.
+            Default: None
+    """
+
+    expansion = 4
+
+ def __init__(self, in_channels, out_channels, **kwargs):
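+        # ``out_channels`` here is the base bottleneck width; the parent
+        # ``_Bottleneck`` expects the expanded output channels, hence the
+        # multiplication by the expansion factor of 4.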
+ super().__init__(in_channels, out_channels * 4, **kwargs)
+
+
+class DownsampleModule(BaseModule):
+ """Downsample module for MSPN.
+
+ Args:
+ block (nn.Module): Downsample block.
+ num_blocks (list): Number of blocks in each downsample unit.
+        num_units (int): Number of downsample units. Default: 4
+ has_skip (bool): Have skip connections from prior upsample
+ module or not. Default:False
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: dict(type='BN')
+ in_channels (int): Number of channels of the input feature to
+ downsample module. Default: 64
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self,
+ block,
+ num_blocks,
+ num_units=4,
+ has_skip=False,
+ norm_cfg=dict(type='BN'),
+ in_channels=64,
+ init_cfg=None):
+ # Protect mutable default arguments
+ norm_cfg = cp.deepcopy(norm_cfg)
+ super().__init__(init_cfg=init_cfg)
+ self.has_skip = has_skip
+ self.in_channels = in_channels
+ assert len(num_blocks) == num_units
+ self.num_blocks = num_blocks
+ self.num_units = num_units
+ self.norm_cfg = norm_cfg
+ self.layer1 = self._make_layer(block, in_channels, num_blocks[0])
+ for i in range(1, num_units):
+ module_name = f'layer{i + 1}'
+ self.add_module(
+ module_name,
+ self._make_layer(
+ block, in_channels * pow(2, i), num_blocks[i], stride=2))
+
+ def _make_layer(self, block, out_channels, blocks, stride=1):
+ downsample = None
+ if stride != 1 or self.in_channels != out_channels * block.expansion:
+ downsample = ConvModule(
+ self.in_channels,
+ out_channels * block.expansion,
+ kernel_size=1,
+ stride=stride,
+ padding=0,
+ norm_cfg=self.norm_cfg,
+ act_cfg=None,
+ inplace=True)
+
+ units = list()
+ units.append(
+ block(
+ self.in_channels,
+ out_channels,
+ stride=stride,
+ downsample=downsample,
+ norm_cfg=self.norm_cfg))
+ self.in_channels = out_channels * block.expansion
+ for _ in range(1, blocks):
+ units.append(block(self.in_channels, out_channels))
+
+ return nn.Sequential(*units)
+
+ def forward(self, x, skip1, skip2):
+ out = list()
+ for i in range(self.num_units):
+ module_name = f'layer{i + 1}'
+ module_i = getattr(self, module_name)
+ x = module_i(x)
+ if self.has_skip:
+ x = x + skip1[i] + skip2[i]
+ out.append(x)
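+        # Reverse so that the deepest (lowest-resolution) feature comes
+        # first, matching the order expected by ``UpsampleModule``.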
+ out.reverse()
+
+ return tuple(out)
+
+
+class UpsampleUnit(BaseModule):
+ """Upsample unit for upsample module.
+
+ Args:
+ ind (int): Indicates whether to interpolate (>0) and whether to
+ generate feature map for the next hourglass-like module.
+        num_units (int): Number of units that form an upsample module. Along
+            with ind and gen_cross_conv, num_units is used to decide whether
+ to generate feature map for the next hourglass-like module.
+ in_channels (int): Channel number of the skip-in feature maps from
+ the corresponding downsample unit.
+ unit_channels (int): Channel number in this unit. Default:256.
+        gen_skip (bool): Whether or not to generate skips for the posterior
+ downsample module. Default:False
+ gen_cross_conv (bool): Whether to generate feature map for the next
+ hourglass-like module. Default:False
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: dict(type='BN')
+ out_channels (int): Number of channels of feature output by upsample
+            module. Must be equal to the in_channels of the downsample
+            module. Default: 64
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self,
+ ind,
+ num_units,
+ in_channels,
+ unit_channels=256,
+ gen_skip=False,
+ gen_cross_conv=False,
+ norm_cfg=dict(type='BN'),
+ out_channels=64,
+ init_cfg=None):
+ # Protect mutable default arguments
+ norm_cfg = cp.deepcopy(norm_cfg)
+ super().__init__(init_cfg=init_cfg)
+ self.num_units = num_units
+ self.norm_cfg = norm_cfg
+ self.in_skip = ConvModule(
+ in_channels,
+ unit_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ norm_cfg=self.norm_cfg,
+ act_cfg=None,
+ inplace=True)
+ self.relu = nn.ReLU(inplace=True)
+
+ self.ind = ind
+ if self.ind > 0:
+ self.up_conv = ConvModule(
+ unit_channels,
+ unit_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ norm_cfg=self.norm_cfg,
+ act_cfg=None,
+ inplace=True)
+
+ self.gen_skip = gen_skip
+ if self.gen_skip:
+ self.out_skip1 = ConvModule(
+ in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ norm_cfg=self.norm_cfg,
+ inplace=True)
+
+ self.out_skip2 = ConvModule(
+ unit_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ norm_cfg=self.norm_cfg,
+ inplace=True)
+
+ self.gen_cross_conv = gen_cross_conv
+ if self.ind == num_units - 1 and self.gen_cross_conv:
+ self.cross_conv = ConvModule(
+ unit_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ norm_cfg=self.norm_cfg,
+ inplace=True)
+
+ def forward(self, x, up_x):
+ out = self.in_skip(x)
+
+ if self.ind > 0:
+ up_x = F.interpolate(
+ up_x,
+ size=(x.size(2), x.size(3)),
+ mode='bilinear',
+ align_corners=True)
+ up_x = self.up_conv(up_x)
+ out = out + up_x
+ out = self.relu(out)
+
+ skip1 = None
+ skip2 = None
+ if self.gen_skip:
+ skip1 = self.out_skip1(x)
+ skip2 = self.out_skip2(out)
+
+ cross_conv = None
+ if self.ind == self.num_units - 1 and self.gen_cross_conv:
+ cross_conv = self.cross_conv(out)
+
+ return out, skip1, skip2, cross_conv
+
+
+class UpsampleModule(BaseModule):
+ """Upsample module for MSPN.
+
+ Args:
+ unit_channels (int): Channel number in the upsample units.
+ Default:256.
+        num_units (int): Number of upsample units. Default: 4
+ gen_skip (bool): Whether to generate skip for posterior downsample
+ module or not. Default:False
+ gen_cross_conv (bool): Whether to generate feature map for the next
+ hourglass-like module. Default:False
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: dict(type='BN')
+ out_channels (int): Number of channels of feature output by upsample
+            module. Must be equal to the in_channels of the downsample
+            module. Default: 64
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self,
+ unit_channels=256,
+ num_units=4,
+ gen_skip=False,
+ gen_cross_conv=False,
+ norm_cfg=dict(type='BN'),
+ out_channels=64,
+ init_cfg=None):
+ # Protect mutable default arguments
+ norm_cfg = cp.deepcopy(norm_cfg)
+ super().__init__(init_cfg=init_cfg)
+ self.in_channels = list()
+ for i in range(num_units):
+ self.in_channels.append(Bottleneck.expansion * out_channels *
+ pow(2, i))
+ self.in_channels.reverse()
+ self.num_units = num_units
+ self.gen_skip = gen_skip
+ self.gen_cross_conv = gen_cross_conv
+ self.norm_cfg = norm_cfg
+ for i in range(num_units):
+ module_name = f'up{i + 1}'
+ self.add_module(
+ module_name,
+ UpsampleUnit(
+ i,
+ self.num_units,
+ self.in_channels[i],
+ unit_channels,
+ self.gen_skip,
+ self.gen_cross_conv,
+ norm_cfg=self.norm_cfg,
+ out_channels=64))
+
+ def forward(self, x):
+ out = list()
+ skip1 = list()
+ skip2 = list()
+ cross_conv = None
+ for i in range(self.num_units):
+ module_i = getattr(self, f'up{i + 1}')
+ if i == 0:
+ outi, skip1_i, skip2_i, _ = module_i(x[i], None)
+ elif i == self.num_units - 1:
+ outi, skip1_i, skip2_i, cross_conv = module_i(x[i], out[i - 1])
+ else:
+ outi, skip1_i, skip2_i, _ = module_i(x[i], out[i - 1])
+ out.append(outi)
+ skip1.append(skip1_i)
+ skip2.append(skip2_i)
+ skip1.reverse()
+ skip2.reverse()
+
+ return out, skip1, skip2, cross_conv
+
+
+class SingleStageNetwork(BaseModule):
+ """Single_stage Network.
+
+ Args:
+ unit_channels (int): Channel number in the upsample units. Default:256.
+ num_units (int): Numbers of downsample/upsample units. Default: 4
+ gen_skip (bool): Whether to generate skip for posterior downsample
+ module or not. Default:False
+ gen_cross_conv (bool): Whether to generate feature map for the next
+ hourglass-like module. Default:False
+ has_skip (bool): Have skip connections from prior upsample
+ module or not. Default:False
+ num_blocks (list): Number of blocks in each downsample unit.
+ Default: [2, 2, 2, 2] Note: Make sure num_units==len(num_blocks)
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: dict(type='BN')
+ in_channels (int): Number of channels of the feature from ResNetTop.
+ Default: 64.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self,
+ has_skip=False,
+ gen_skip=False,
+ gen_cross_conv=False,
+ unit_channels=256,
+ num_units=4,
+ num_blocks=[2, 2, 2, 2],
+ norm_cfg=dict(type='BN'),
+ in_channels=64,
+ init_cfg=None):
+ # Protect mutable default arguments
+ norm_cfg = cp.deepcopy(norm_cfg)
+ num_blocks = cp.deepcopy(num_blocks)
+ super().__init__(init_cfg=init_cfg)
+ assert len(num_blocks) == num_units
+ self.has_skip = has_skip
+ self.gen_skip = gen_skip
+ self.gen_cross_conv = gen_cross_conv
+ self.num_units = num_units
+ self.unit_channels = unit_channels
+ self.num_blocks = num_blocks
+ self.norm_cfg = norm_cfg
+
+ self.downsample = DownsampleModule(Bottleneck, num_blocks, num_units,
+ has_skip, norm_cfg, in_channels)
+ self.upsample = UpsampleModule(unit_channels, num_units, gen_skip,
+ gen_cross_conv, norm_cfg, in_channels)
+
+ def forward(self, x, skip1, skip2):
+ mid = self.downsample(x, skip1, skip2)
+ out, skip1, skip2, cross_conv = self.upsample(mid)
+
+ return out, skip1, skip2, cross_conv
+
+
+class ResNetTop(BaseModule):
+ """ResNet top for MSPN.
+
+ Args:
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: dict(type='BN')
+ channels (int): Number of channels of the feature output by ResNetTop.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self, norm_cfg=dict(type='BN'), channels=64, init_cfg=None):
+ # Protect mutable default arguments
+ norm_cfg = cp.deepcopy(norm_cfg)
+ super().__init__(init_cfg=init_cfg)
+ self.top = nn.Sequential(
+ ConvModule(
+ 3,
+ channels,
+ kernel_size=7,
+ stride=2,
+ padding=3,
+ norm_cfg=norm_cfg,
+ inplace=True), MaxPool2d(kernel_size=3, stride=2, padding=1))
+
+ def forward(self, img):
+ return self.top(img)
+
+
+@MODELS.register_module()
+class MSPN(BaseBackbone):
+ """MSPN backbone. Paper ref: Li et al. "Rethinking on Multi-Stage Networks
+ for Human Pose Estimation" (CVPR 2020).
+
+ Args:
+ unit_channels (int): Number of Channels in an upsample unit.
+ Default: 256
+ num_stages (int): Number of stages in a multi-stage MSPN. Default: 4
+ num_units (int): Number of downsample/upsample units in a single-stage
+ network. Default: 4
+ Note: Make sure num_units == len(self.num_blocks)
+ num_blocks (list): Number of bottlenecks in each
+ downsample unit. Default: [2, 2, 2, 2]
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: dict(type='BN')
+ res_top_channels (int): Number of channels of feature from ResNetTop.
+ Default: 64.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default:
+ ``[
+ dict(type='Kaiming', layer=['Conv2d']),
+ dict(
+ type='Constant',
+ val=1,
+ layer=['_BatchNorm', 'GroupNorm']),
+ dict(
+ type='Normal',
+ std=0.01,
+ layer=['Linear']),
+ ]``
+
+ Example:
+ >>> from mmpose.models import MSPN
+ >>> import torch
+ >>> self = MSPN(num_stages=2,num_units=2,num_blocks=[2,2])
+ >>> self.eval()
+ >>> inputs = torch.rand(1, 3, 511, 511)
+ >>> level_outputs = self.forward(inputs)
+ >>> for level_output in level_outputs:
+ ... for feature in level_output:
+ ... print(tuple(feature.shape))
+ ...
+ (1, 256, 64, 64)
+ (1, 256, 128, 128)
+ (1, 256, 64, 64)
+ (1, 256, 128, 128)
+ """
+
+ def __init__(self,
+ unit_channels=256,
+ num_stages=4,
+ num_units=4,
+ num_blocks=[2, 2, 2, 2],
+ norm_cfg=dict(type='BN'),
+ res_top_channels=64,
+ init_cfg=[
+ dict(type='Kaiming', layer=['Conv2d']),
+ dict(
+ type='Constant',
+ val=1,
+ layer=['_BatchNorm', 'GroupNorm']),
+ dict(type='Normal', std=0.01, layer=['Linear']),
+ ]):
+ # Protect mutable default arguments
+ norm_cfg = cp.deepcopy(norm_cfg)
+ num_blocks = cp.deepcopy(num_blocks)
+ super().__init__(init_cfg=init_cfg)
+ self.unit_channels = unit_channels
+ self.num_stages = num_stages
+ self.num_units = num_units
+ self.num_blocks = num_blocks
+ self.norm_cfg = norm_cfg
+
+ assert self.num_stages > 0
+ assert self.num_units > 1
+ assert self.num_units == len(self.num_blocks)
+ self.top = ResNetTop(norm_cfg=norm_cfg)
+ self.multi_stage_mspn = nn.ModuleList([])
+ for i in range(self.num_stages):
+ if i == 0:
+ has_skip = False
+ else:
+ has_skip = True
+ if i != self.num_stages - 1:
+ gen_skip = True
+ gen_cross_conv = True
+ else:
+ gen_skip = False
+ gen_cross_conv = False
+ self.multi_stage_mspn.append(
+ SingleStageNetwork(has_skip, gen_skip, gen_cross_conv,
+ unit_channels, num_units, num_blocks,
+ norm_cfg, res_top_channels))
+
+ def forward(self, x):
+ """Model forward function."""
+ out_feats = []
+ skip1 = None
+ skip2 = None
+ x = self.top(x)
+ for i in range(self.num_stages):
+ out, skip1, skip2, x = self.multi_stage_mspn[i](x, skip1, skip2)
+ out_feats.append(out)
+
+ return out_feats
+
+ def init_weights(self):
+ """Initialize model weights."""
+ if (isinstance(self.init_cfg, dict)
+ and self.init_cfg['type'] == 'Pretrained'):
+ logger = get_root_logger()
+ state_dict_tmp = get_state_dict(self.init_cfg['checkpoint'])
+ state_dict = OrderedDict()
+ state_dict['top'] = OrderedDict()
+ state_dict['bottlenecks'] = OrderedDict()
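+            # Split a standard ResNet checkpoint into stem ('top') and
+            # bottleneck weights, renaming ``downsample.0/1``, ``conv1`` and
+            # ``bn1`` to the ConvModule-style names used in this backbone.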
+ for k, v in state_dict_tmp.items():
+ if k.startswith('layer'):
+ if 'downsample.0' in k:
+ state_dict['bottlenecks'][k.replace(
+ 'downsample.0', 'downsample.conv')] = v
+ elif 'downsample.1' in k:
+ state_dict['bottlenecks'][k.replace(
+ 'downsample.1', 'downsample.bn')] = v
+ else:
+ state_dict['bottlenecks'][k] = v
+ elif k.startswith('conv1'):
+ state_dict['top'][k.replace('conv1', 'top.0.conv')] = v
+ elif k.startswith('bn1'):
+ state_dict['top'][k.replace('bn1', 'top.0.bn')] = v
+
+ load_state_dict(
+ self.top, state_dict['top'], strict=False, logger=logger)
+ for i in range(self.num_stages):
+ load_state_dict(
+ self.multi_stage_mspn[i].downsample,
+ state_dict['bottlenecks'],
+ strict=False,
+ logger=logger)
+ else:
+ super(MSPN, self).init_weights()
diff --git a/mmpose/models/backbones/pvt.py b/mmpose/models/backbones/pvt.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f2b6495482b4feadd86f51fa11b64ee10878fef
--- /dev/null
+++ b/mmpose/models/backbones/pvt.py
@@ -0,0 +1,569 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import Conv2d, build_activation_layer, build_norm_layer
+from mmcv.cnn.bricks.drop import build_dropout
+from mmcv.cnn.bricks.transformer import MultiheadAttention
+from mmengine.model import BaseModule, ModuleList, Sequential
+from mmengine.model.weight_init import trunc_normal_
+from mmengine.runner import load_state_dict
+from mmengine.utils import to_2tuple
+
+from mmpose.registry import MODELS
+from ...utils import get_root_logger
+from ..utils import PatchEmbed, nchw_to_nlc, nlc_to_nchw, pvt_convert
+from .utils import get_state_dict
+
+
+class MixFFN(BaseModule):
+ """An implementation of MixFFN of PVT.
+
+ The differences between MixFFN & FFN:
+ 1. Use 1X1 Conv to replace Linear layer.
+ 2. Introduce 3X3 Depth-wise Conv to encode positional information.
+
+ Args:
+ embed_dims (int): The feature dimension. Same as
+ `MultiheadAttention`.
+ feedforward_channels (int): The hidden dimension of FFNs.
+ act_cfg (dict, optional): The activation config for FFNs.
+ Default: dict(type='GELU').
+ ffn_drop (float, optional): Probability of an element to be
+ zeroed in FFN. Default 0.0.
+ dropout_layer (obj:`ConfigDict`): The dropout_layer used
+ when adding the shortcut.
+ Default: None.
+ use_conv (bool): If True, add 3x3 DWConv between two Linear layers.
+ Defaults: False.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self,
+ embed_dims,
+ feedforward_channels,
+ act_cfg=dict(type='GELU'),
+ ffn_drop=0.,
+ dropout_layer=None,
+ use_conv=False,
+ init_cfg=None):
+ super(MixFFN, self).__init__(init_cfg=init_cfg)
+
+ self.embed_dims = embed_dims
+ self.feedforward_channels = feedforward_channels
+ self.act_cfg = act_cfg
+ activate = build_activation_layer(act_cfg)
+
+ in_channels = embed_dims
+ fc1 = Conv2d(
+ in_channels=in_channels,
+ out_channels=feedforward_channels,
+ kernel_size=1,
+ stride=1,
+ bias=True)
+ if use_conv:
+ # 3x3 depth wise conv to provide positional encode information
+ dw_conv = Conv2d(
+ in_channels=feedforward_channels,
+ out_channels=feedforward_channels,
+ kernel_size=3,
+ stride=1,
+ padding=(3 - 1) // 2,
+ bias=True,
+ groups=feedforward_channels)
+ fc2 = Conv2d(
+ in_channels=feedforward_channels,
+ out_channels=in_channels,
+ kernel_size=1,
+ stride=1,
+ bias=True)
+ drop = nn.Dropout(ffn_drop)
+ layers = [fc1, activate, drop, fc2, drop]
+ if use_conv:
+ layers.insert(1, dw_conv)
+ self.layers = Sequential(*layers)
+ self.dropout_layer = build_dropout(
+ dropout_layer) if dropout_layer else torch.nn.Identity()
+
+ def forward(self, x, hw_shape, identity=None):
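+        # ``x`` is in (B, L, C) token form with L == H * W from ``hw_shape``;
+        # it is reshaped to NCHW for the conv layers and flattened back.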
+ out = nlc_to_nchw(x, hw_shape)
+ out = self.layers(out)
+ out = nchw_to_nlc(out)
+ if identity is None:
+ identity = x
+ return identity + self.dropout_layer(out)
+
+
+class SpatialReductionAttention(MultiheadAttention):
+ """An implementation of Spatial Reduction Attention of PVT.
+
+ This module is modified from MultiheadAttention which is a module from
+ mmcv.cnn.bricks.transformer.
+
+ Args:
+ embed_dims (int): The embedding dimension.
+ num_heads (int): Parallel attention heads.
+ attn_drop (float): A Dropout layer on attn_output_weights.
+ Default: 0.0.
+ proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
+ Default: 0.0.
+ dropout_layer (obj:`ConfigDict`): The dropout_layer used
+ when adding the shortcut. Default: None.
+ batch_first (bool): Key, Query and Value are shape of
+ (batch, n, embed_dim)
+ or (n, batch, embed_dim). Default: False.
+ qkv_bias (bool): enable bias for qkv if True. Default: True.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='LN').
+ sr_ratio (int): The ratio of spatial reduction of Spatial Reduction
+ Attention of PVT. Default: 1.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self,
+ embed_dims,
+ num_heads,
+ attn_drop=0.,
+ proj_drop=0.,
+ dropout_layer=None,
+ batch_first=True,
+ qkv_bias=True,
+ norm_cfg=dict(type='LN'),
+ sr_ratio=1,
+ init_cfg=None):
+ super().__init__(
+ embed_dims,
+ num_heads,
+ attn_drop,
+ proj_drop,
+ batch_first=batch_first,
+ dropout_layer=dropout_layer,
+ bias=qkv_bias,
+ init_cfg=init_cfg)
+
+ self.sr_ratio = sr_ratio
+ if sr_ratio > 1:
+ self.sr = Conv2d(
+ in_channels=embed_dims,
+ out_channels=embed_dims,
+ kernel_size=sr_ratio,
+ stride=sr_ratio)
+ # The ret[0] of build_norm_layer is norm name.
+ self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
+
+ # handle the BC-breaking from https://github.com/open-mmlab/mmcv/pull/1418 # noqa
+ from mmpose import digit_version, mmcv_version
+ if mmcv_version < digit_version('1.3.17'):
+ warnings.warn('The legacy version of forward function in'
+ 'SpatialReductionAttention is deprecated in'
+ 'mmcv>=1.3.17 and will no longer support in the'
+ 'future. Please upgrade your mmcv.')
+ self.forward = self.legacy_forward
+
+ def forward(self, x, hw_shape, identity=None):
+
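+        # When sr_ratio > 1, keys/values are spatially downsampled by a
+        # sr_ratio x sr_ratio strided conv, reducing the attention cost
+        # from O(L^2) to O(L^2 / sr_ratio^2).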
+ x_q = x
+ if self.sr_ratio > 1:
+ x_kv = nlc_to_nchw(x, hw_shape)
+ x_kv = self.sr(x_kv)
+ x_kv = nchw_to_nlc(x_kv)
+ x_kv = self.norm(x_kv)
+ else:
+ x_kv = x
+
+ if identity is None:
+ identity = x_q
+
+ # Because the dataflow('key', 'query', 'value') of
+ # ``torch.nn.MultiheadAttention`` is (num_query, batch,
+ # embed_dims), We should adjust the shape of dataflow from
+ # batch_first (batch, num_query, embed_dims) to num_query_first
+ # (num_query ,batch, embed_dims), and recover ``attn_output``
+ # from num_query_first to batch_first.
+ if self.batch_first:
+ x_q = x_q.transpose(0, 1)
+ x_kv = x_kv.transpose(0, 1)
+
+ out = self.attn(query=x_q, key=x_kv, value=x_kv)[0]
+
+ if self.batch_first:
+ out = out.transpose(0, 1)
+
+ return identity + self.dropout_layer(self.proj_drop(out))
+
+ def legacy_forward(self, x, hw_shape, identity=None):
+ """multi head attention forward in mmcv version < 1.3.17."""
+ x_q = x
+ if self.sr_ratio > 1:
+ x_kv = nlc_to_nchw(x, hw_shape)
+ x_kv = self.sr(x_kv)
+ x_kv = nchw_to_nlc(x_kv)
+ x_kv = self.norm(x_kv)
+ else:
+ x_kv = x
+
+ if identity is None:
+ identity = x_q
+
+ out = self.attn(query=x_q, key=x_kv, value=x_kv)[0]
+
+ return identity + self.dropout_layer(self.proj_drop(out))
+
+
+class PVTEncoderLayer(BaseModule):
+ """Implements one encoder layer in PVT.
+
+ Args:
+ embed_dims (int): The feature dimension.
+ num_heads (int): Parallel attention heads.
+ feedforward_channels (int): The hidden dimension for FFNs.
+ drop_rate (float): Probability of an element to be zeroed.
+ after the feed forward layer. Default: 0.0.
+ attn_drop_rate (float): The drop out rate for attention layer.
+ Default: 0.0.
+ drop_path_rate (float): stochastic depth rate. Default: 0.0.
+ qkv_bias (bool): enable bias for qkv if True.
+ Default: True.
+ act_cfg (dict): The activation config for FFNs.
+ Default: dict(type='GELU').
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='LN').
+ sr_ratio (int): The ratio of spatial reduction of Spatial Reduction
+ Attention of PVT. Default: 1.
+ use_conv_ffn (bool): If True, use Convolutional FFN to replace FFN.
+ Default: False.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self,
+ embed_dims,
+ num_heads,
+ feedforward_channels,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.,
+ qkv_bias=True,
+ act_cfg=dict(type='GELU'),
+ norm_cfg=dict(type='LN'),
+ sr_ratio=1,
+ use_conv_ffn=False,
+ init_cfg=None):
+ super(PVTEncoderLayer, self).__init__(init_cfg=init_cfg)
+
+ # The ret[0] of build_norm_layer is norm name.
+ self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
+
+ self.attn = SpatialReductionAttention(
+ embed_dims=embed_dims,
+ num_heads=num_heads,
+ attn_drop=attn_drop_rate,
+ proj_drop=drop_rate,
+ dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
+ qkv_bias=qkv_bias,
+ norm_cfg=norm_cfg,
+ sr_ratio=sr_ratio)
+
+ # The ret[0] of build_norm_layer is norm name.
+ self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
+
+ self.ffn = MixFFN(
+ embed_dims=embed_dims,
+ feedforward_channels=feedforward_channels,
+ ffn_drop=drop_rate,
+ dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
+ use_conv=use_conv_ffn,
+ act_cfg=act_cfg)
+
+ def forward(self, x, hw_shape):
+ x = self.attn(self.norm1(x), hw_shape, identity=x)
+ x = self.ffn(self.norm2(x), hw_shape, identity=x)
+
+ return x
+
+
+class AbsolutePositionEmbedding(BaseModule):
+ """An implementation of the absolute position embedding in PVT.
+
+ Args:
+ pos_shape (int): The shape of the absolute position embedding.
+ pos_dim (int): The dimension of the absolute position embedding.
+ drop_rate (float): Probability of an element to be zeroed.
+ Default: 0.0.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None.
+ """
+
+ def __init__(self, pos_shape, pos_dim, drop_rate=0., init_cfg=None):
+ super().__init__(init_cfg=init_cfg)
+
+ if isinstance(pos_shape, int):
+ pos_shape = to_2tuple(pos_shape)
+ elif isinstance(pos_shape, tuple):
+ if len(pos_shape) == 1:
+ pos_shape = to_2tuple(pos_shape[0])
+ assert len(pos_shape) == 2, \
+            f'The pos_shape should have length 1 or 2, ' \
+ f'but got {len(pos_shape)}'
+ self.pos_shape = pos_shape
+ self.pos_dim = pos_dim
+
+ self.pos_embed = nn.Parameter(
+ torch.zeros(1, pos_shape[0] * pos_shape[1], pos_dim))
+ self.drop = nn.Dropout(p=drop_rate)
+
+ def init_weights(self):
+ trunc_normal_(self.pos_embed, std=0.02)
+
+ def resize_pos_embed(self, pos_embed, input_shape, mode='bilinear'):
+ """Resize pos_embed weights.
+
+ Resize pos_embed using bilinear interpolate method.
+
+ Args:
+ pos_embed (torch.Tensor): Position embedding weights.
+ input_shape (tuple): Tuple for (downsampled input image height,
+ downsampled input image width).
+ mode (str): Algorithm used for upsampling:
+ ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
+ ``'trilinear'``. Default: ``'bilinear'``.
+
+ Return:
+ torch.Tensor: The resized pos_embed of shape [B, L_new, C].
+ """
+ assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
+ pos_h, pos_w = self.pos_shape
+ pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
+ pos_embed_weight = pos_embed_weight.reshape(
+ 1, pos_h, pos_w, self.pos_dim).permute(0, 3, 1, 2).contiguous()
+ pos_embed_weight = F.interpolate(
+ pos_embed_weight, size=input_shape, mode=mode)
+ pos_embed_weight = torch.flatten(pos_embed_weight,
+ 2).transpose(1, 2).contiguous()
+ pos_embed = pos_embed_weight
+
+ return pos_embed
+
+ def forward(self, x, hw_shape, mode='bilinear'):
+ pos_embed = self.resize_pos_embed(self.pos_embed, hw_shape, mode)
+ return self.drop(x + pos_embed)
+
+
+@MODELS.register_module()
+class PyramidVisionTransformer(BaseModule):
+ """Pyramid Vision Transformer (PVT)
+
+ Implementation of `Pyramid Vision Transformer: A Versatile Backbone for
+ Dense Prediction without Convolutions
+    <https://arxiv.org/abs/2102.12122>`_.
+
+ Args:
+ pretrain_img_size (int | tuple[int]): The size of input image when
+ pretrain. Defaults: 224.
+ in_channels (int): Number of input channels. Default: 3.
+ embed_dims (int): Embedding dimension. Default: 64.
+        num_stages (int): The number of stages. Default: 4.
+ num_layers (Sequence[int]): The layer number of each transformer encode
+ layer. Default: [3, 4, 6, 3].
+ num_heads (Sequence[int]): The attention heads of each transformer
+ encode layer. Default: [1, 2, 5, 8].
+ patch_sizes (Sequence[int]): The patch_size of each patch embedding.
+ Default: [4, 2, 2, 2].
+ strides (Sequence[int]): The stride of each patch embedding.
+ Default: [4, 2, 2, 2].
+ paddings (Sequence[int]): The padding of each patch embedding.
+ Default: [0, 0, 0, 0].
+ sr_ratios (Sequence[int]): The spatial reduction rate of each
+ transformer encode layer. Default: [8, 4, 2, 1].
+ out_indices (Sequence[int] | int): Output from which stages.
+ Default: (0, 1, 2, 3).
+ mlp_ratios (Sequence[int]): The ratio of the mlp hidden dim to the
+ embedding dim of each transformer encode layer.
+ Default: [8, 8, 4, 4].
+ qkv_bias (bool): Enable bias for qkv if True. Default: True.
+ drop_rate (float): Probability of an element to be zeroed.
+ Default 0.0.
+ attn_drop_rate (float): The drop out rate for attention layer.
+ Default 0.0.
+ drop_path_rate (float): stochastic depth rate. Default 0.1.
+ use_abs_pos_embed (bool): If True, add absolute position embedding to
+ the patch embedding. Defaults: True.
+ use_conv_ffn (bool): If True, use Convolutional FFN to replace FFN.
+ Default: False.
+ act_cfg (dict): The activation config for FFNs.
+ Default: dict(type='GELU').
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='LN').
+ pretrained (str, optional): model pretrained path. Default: None.
+ convert_weights (bool): The flag indicates whether the
+ pre-trained model is from the original repo. We may need
+ to convert some keys to make it compatible.
+ Default: True.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default:
+ ``[
+ dict(type='TruncNormal', std=.02, layer=['Linear']),
+ dict(type='Constant', val=1, layer=['LayerNorm']),
+ dict(type='Normal', std=0.01, layer=['Conv2d'])
+ ]``
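+
+    Example:
+        A minimal usage sketch; the printed shapes assume the default
+        configuration and a 224x224 input.
+
+        >>> from mmpose.models import PyramidVisionTransformer
+        >>> import torch
+        >>> self = PyramidVisionTransformer()
+        >>> self.eval()
+        >>> inputs = torch.rand(1, 3, 224, 224)
+        >>> outputs = self.forward(inputs)
+        >>> for out in outputs:
+        ...     print(tuple(out.shape))
+        (1, 64, 56, 56)
+        (1, 128, 28, 28)
+        (1, 320, 14, 14)
+        (1, 512, 7, 7)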
+ """
+
+ def __init__(self,
+ pretrain_img_size=224,
+ in_channels=3,
+ embed_dims=64,
+ num_stages=4,
+ num_layers=[3, 4, 6, 3],
+ num_heads=[1, 2, 5, 8],
+ patch_sizes=[4, 2, 2, 2],
+ strides=[4, 2, 2, 2],
+ paddings=[0, 0, 0, 0],
+ sr_ratios=[8, 4, 2, 1],
+ out_indices=(0, 1, 2, 3),
+ mlp_ratios=[8, 8, 4, 4],
+ qkv_bias=True,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.1,
+ use_abs_pos_embed=True,
+ norm_after_stage=False,
+ use_conv_ffn=False,
+ act_cfg=dict(type='GELU'),
+ norm_cfg=dict(type='LN', eps=1e-6),
+ convert_weights=True,
+ init_cfg=[
+ dict(type='TruncNormal', std=.02, layer=['Linear']),
+ dict(type='Constant', val=1, layer=['LayerNorm']),
+ dict(type='Kaiming', layer=['Conv2d'])
+ ]):
+ super().__init__(init_cfg=init_cfg)
+
+ self.convert_weights = convert_weights
+ if isinstance(pretrain_img_size, int):
+ pretrain_img_size = to_2tuple(pretrain_img_size)
+ elif isinstance(pretrain_img_size, tuple):
+ if len(pretrain_img_size) == 1:
+ pretrain_img_size = to_2tuple(pretrain_img_size[0])
+ assert len(pretrain_img_size) == 2, \
+ f'The size of image should have length 1 or 2, ' \
+ f'but got {len(pretrain_img_size)}'
+
+ self.embed_dims = embed_dims
+
+ self.num_stages = num_stages
+ self.num_layers = num_layers
+ self.num_heads = num_heads
+ self.patch_sizes = patch_sizes
+ self.strides = strides
+ self.sr_ratios = sr_ratios
+ assert num_stages == len(num_layers) == len(num_heads) \
+ == len(patch_sizes) == len(strides) == len(sr_ratios)
+
+ self.out_indices = out_indices
+ assert max(out_indices) < self.num_stages
+
+ # transformer encoder
+ dpr = [
+ x.item()
+ for x in torch.linspace(0, drop_path_rate, sum(num_layers))
+        ]  # stochastic depth (drop path) decay rule
+
+ cur = 0
+ self.layers = ModuleList()
+ for i, num_layer in enumerate(num_layers):
+ embed_dims_i = embed_dims * num_heads[i]
+ patch_embed = PatchEmbed(
+ in_channels=in_channels,
+ embed_dims=embed_dims_i,
+ kernel_size=patch_sizes[i],
+ stride=strides[i],
+ padding=paddings[i],
+ bias=True,
+ norm_cfg=norm_cfg)
+
+ layers = ModuleList()
+ if use_abs_pos_embed:
+ pos_shape = pretrain_img_size // np.prod(patch_sizes[:i + 1])
+ pos_embed = AbsolutePositionEmbedding(
+ pos_shape=pos_shape,
+ pos_dim=embed_dims_i,
+ drop_rate=drop_rate)
+ layers.append(pos_embed)
+ layers.extend([
+ PVTEncoderLayer(
+ embed_dims=embed_dims_i,
+ num_heads=num_heads[i],
+ feedforward_channels=mlp_ratios[i] * embed_dims_i,
+ drop_rate=drop_rate,
+ attn_drop_rate=attn_drop_rate,
+ drop_path_rate=dpr[cur + idx],
+ qkv_bias=qkv_bias,
+ act_cfg=act_cfg,
+ norm_cfg=norm_cfg,
+ sr_ratio=sr_ratios[i],
+ use_conv_ffn=use_conv_ffn) for idx in range(num_layer)
+ ])
+ in_channels = embed_dims_i
+ # The ret[0] of build_norm_layer is norm name.
+ if norm_after_stage:
+ norm = build_norm_layer(norm_cfg, embed_dims_i)[1]
+ else:
+ norm = nn.Identity()
+ self.layers.append(ModuleList([patch_embed, layers, norm]))
+ cur += num_layer
+
+ def init_weights(self):
+ """Initialize the weights in backbone."""
+
+ if (isinstance(self.init_cfg, dict)
+ and self.init_cfg['type'] == 'Pretrained'):
+ logger = get_root_logger()
+ state_dict = get_state_dict(
+ self.init_cfg['checkpoint'], map_location='cpu')
+            logger.warning(f'Load pre-trained model for '
+                           f'{self.__class__.__name__} from original repo')
+
+ if self.convert_weights:
+                # PVT backbones are not supported by mmcls, so the
+                # pre-trained weights have to be converted to match this
+                # implementation.
+ state_dict = pvt_convert(state_dict)
+ load_state_dict(self, state_dict, strict=False, logger=logger)
+
+ else:
+ super(PyramidVisionTransformer, self).init_weights()
+
+ def forward(self, x):
+ outs = []
+
+ for i, layer in enumerate(self.layers):
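+            # Each entry holds [patch_embed, transformer blocks, norm];
+            # tokens are reshaped back to NCHW before being collected.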
+ x, hw_shape = layer[0](x)
+
+ for block in layer[1]:
+ x = block(x, hw_shape)
+ x = layer[2](x)
+ x = nlc_to_nchw(x, hw_shape)
+ if i in self.out_indices:
+ outs.append(x)
+
+ return outs
+
+
+@MODELS.register_module()
+class PyramidVisionTransformerV2(PyramidVisionTransformer):
+ """Implementation of `PVTv2: Improved Baselines with Pyramid Vision
+    Transformer <https://arxiv.org/abs/2106.13797>`_."""
+
+ def __init__(self, **kwargs):
+ super(PyramidVisionTransformerV2, self).__init__(
+ patch_sizes=[7, 3, 3, 3],
+ paddings=[3, 1, 1, 1],
+ use_abs_pos_embed=False,
+ norm_after_stage=True,
+ use_conv_ffn=True,
+ **kwargs)
diff --git a/mmpose/models/backbones/regnet.py b/mmpose/models/backbones/regnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..120523e658ecb2b3134eba45508ac47457a87f1d
--- /dev/null
+++ b/mmpose/models/backbones/regnet.py
@@ -0,0 +1,331 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+
+import numpy as np
+import torch.nn as nn
+from mmcv.cnn import build_conv_layer, build_norm_layer
+
+from mmpose.registry import MODELS
+from .resnet import ResNet
+from .resnext import Bottleneck
+
+
+@MODELS.register_module()
+class RegNet(ResNet):
+ """RegNet backbone.
+
+    More details can be found in `paper <https://arxiv.org/abs/2003.13678>`__ .
+
+ Args:
+ arch (dict): The parameter of RegNets.
+ - w0 (int): initial width
+ - wa (float): slope of width
+ - wm (float): quantization parameter to quantize the width
+ - depth (int): depth of the backbone
+ - group_w (int): width of group
+ - bot_mul (float): bottleneck ratio, i.e. expansion of bottleneck.
+ strides (Sequence[int]): Strides of the first block of each stage.
+ base_channels (int): Base channels after stem layer.
+ in_channels (int): Number of input image channels. Default: 3.
+ dilations (Sequence[int]): Dilation of each stage.
+ out_indices (Sequence[int]): Output from which stages.
+ style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
+ layer is the 3x3 conv layer, otherwise the stride-two layer is
+ the first 1x1 conv layer. Default: "pytorch".
+ frozen_stages (int): Stages to be frozen (all param fixed). -1 means
+ not freezing any parameters. Default: -1.
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: dict(type='BN', requires_grad=True).
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Default: False.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ zero_init_residual (bool): whether to use zero init for last norm layer
+ in resblocks to let them behave as identity. Default: True.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default:
+ ``[
+ dict(type='Kaiming', layer=['Conv2d']),
+ dict(
+ type='Constant',
+ val=1,
+ layer=['_BatchNorm', 'GroupNorm'])
+ ]``
+
+ Example:
+ >>> from mmpose.models import RegNet
+ >>> import torch
+ >>> self = RegNet(
+ arch=dict(
+ w0=88,
+ wa=26.31,
+ wm=2.25,
+ group_w=48,
+ depth=25,
+ bot_mul=1.0),
+ out_indices=(0, 1, 2, 3))
+ >>> self.eval()
+ >>> inputs = torch.rand(1, 3, 32, 32)
+ >>> level_outputs = self.forward(inputs)
+ >>> for level_out in level_outputs:
+ ... print(tuple(level_out.shape))
+ (1, 96, 8, 8)
+ (1, 192, 4, 4)
+ (1, 432, 2, 2)
+ (1, 1008, 1, 1)
+ """
+ arch_settings = {
+ 'regnetx_400mf':
+ dict(w0=24, wa=24.48, wm=2.54, group_w=16, depth=22, bot_mul=1.0),
+ 'regnetx_800mf':
+ dict(w0=56, wa=35.73, wm=2.28, group_w=16, depth=16, bot_mul=1.0),
+ 'regnetx_1.6gf':
+ dict(w0=80, wa=34.01, wm=2.25, group_w=24, depth=18, bot_mul=1.0),
+ 'regnetx_3.2gf':
+ dict(w0=88, wa=26.31, wm=2.25, group_w=48, depth=25, bot_mul=1.0),
+ 'regnetx_4.0gf':
+ dict(w0=96, wa=38.65, wm=2.43, group_w=40, depth=23, bot_mul=1.0),
+ 'regnetx_6.4gf':
+ dict(w0=184, wa=60.83, wm=2.07, group_w=56, depth=17, bot_mul=1.0),
+ 'regnetx_8.0gf':
+ dict(w0=80, wa=49.56, wm=2.88, group_w=120, depth=23, bot_mul=1.0),
+ 'regnetx_12gf':
+ dict(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19, bot_mul=1.0),
+ }
+
+ def __init__(self,
+ arch,
+ in_channels=3,
+ stem_channels=32,
+ base_channels=32,
+ strides=(2, 2, 2, 2),
+ dilations=(1, 1, 1, 1),
+ out_indices=(3, ),
+ style='pytorch',
+ deep_stem=False,
+ avg_down=False,
+ frozen_stages=-1,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ norm_eval=False,
+ with_cp=False,
+ zero_init_residual=True,
+ init_cfg=[
+ dict(type='Kaiming', layer=['Conv2d']),
+ dict(
+ type='Constant',
+ val=1,
+ layer=['_BatchNorm', 'GroupNorm'])
+ ]):
+ # Protect mutable default arguments
+ norm_cfg = copy.deepcopy(norm_cfg)
+ super(ResNet, self).__init__(init_cfg=init_cfg)
+
+ # Generate RegNet parameters first
+ if isinstance(arch, str):
+ assert arch in self.arch_settings, \
+ f'"arch": "{arch}" is not one of the' \
+ ' arch_settings'
+ arch = self.arch_settings[arch]
+ elif not isinstance(arch, dict):
+ raise TypeError('Expect "arch" to be either a string '
+ f'or a dict, got {type(arch)}')
+
+ widths, num_stages = self.generate_regnet(
+ arch['w0'],
+ arch['wa'],
+ arch['wm'],
+ arch['depth'],
+ )
+ # Convert to per stage format
+ stage_widths, stage_blocks = self.get_stages_from_blocks(widths)
+ # Generate group widths and bot muls
+ group_widths = [arch['group_w'] for _ in range(num_stages)]
+ self.bottleneck_ratio = [arch['bot_mul'] for _ in range(num_stages)]
+ # Adjust the compatibility of stage_widths and group_widths
+ stage_widths, group_widths = self.adjust_width_group(
+ stage_widths, self.bottleneck_ratio, group_widths)
+
+ # Group params by stage
+ self.stage_widths = stage_widths
+ self.group_widths = group_widths
+ self.depth = sum(stage_blocks)
+ self.stem_channels = stem_channels
+ self.base_channels = base_channels
+ self.num_stages = num_stages
+ assert 1 <= num_stages <= 4
+ self.strides = strides
+ self.dilations = dilations
+ assert len(strides) == len(dilations) == num_stages
+ self.out_indices = out_indices
+ assert max(out_indices) < num_stages
+ self.style = style
+ self.deep_stem = deep_stem
+ if self.deep_stem:
+ raise NotImplementedError(
+ 'deep_stem has not been implemented for RegNet')
+ self.avg_down = avg_down
+ self.frozen_stages = frozen_stages
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.with_cp = with_cp
+ self.norm_eval = norm_eval
+ self.zero_init_residual = zero_init_residual
+ self.stage_blocks = stage_blocks[:num_stages]
+
+ self._make_stem_layer(in_channels, stem_channels)
+
+ _in_channels = stem_channels
+ self.res_layers = []
+ for i, num_blocks in enumerate(self.stage_blocks):
+ stride = self.strides[i]
+ dilation = self.dilations[i]
+ group_width = self.group_widths[i]
+ width = int(round(self.stage_widths[i] * self.bottleneck_ratio[i]))
+ stage_groups = width // group_width
+
+ res_layer = self.make_res_layer(
+ block=Bottleneck,
+ num_blocks=num_blocks,
+ in_channels=_in_channels,
+ out_channels=self.stage_widths[i],
+ expansion=1,
+ stride=stride,
+ dilation=dilation,
+ style=self.style,
+ avg_down=self.avg_down,
+ with_cp=self.with_cp,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ base_channels=self.stage_widths[i],
+ groups=stage_groups,
+ width_per_group=group_width)
+ _in_channels = self.stage_widths[i]
+ layer_name = f'layer{i + 1}'
+ self.add_module(layer_name, res_layer)
+ self.res_layers.append(layer_name)
+
+ self._freeze_stages()
+
+ self.feat_dim = stage_widths[-1]
+
+ def _make_stem_layer(self, in_channels, base_channels):
+ self.conv1 = build_conv_layer(
+ self.conv_cfg,
+ in_channels,
+ base_channels,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False)
+ self.norm1_name, norm1 = build_norm_layer(
+ self.norm_cfg, base_channels, postfix=1)
+ self.add_module(self.norm1_name, norm1)
+ self.relu = nn.ReLU(inplace=True)
+
+ @staticmethod
+ def generate_regnet(initial_width,
+ width_slope,
+ width_parameter,
+ depth,
+ divisor=8):
+ """Generates per block width from RegNet parameters.
+
+ Args:
+ initial_width ([int]): Initial width of the backbone
+ width_slope ([float]): Slope of the quantized linear function
+ width_parameter ([int]): Parameter used to quantize the width.
+ depth ([int]): Depth of the backbone.
+ divisor (int, optional): The divisor of channels. Defaults to 8.
+
+ Returns:
+            tuple(list[int], int): The width of each block and the number
+            of stages.
+ """
+ assert width_slope >= 0
+ assert initial_width > 0
+ assert width_parameter > 1
+ assert initial_width % divisor == 0
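+        # Per-block widths follow the linear rule w_j = w_0 + w_a * j; each
+        # width is then snapped to the nearest power of ``width_parameter``
+        # times the initial width and rounded to a multiple of ``divisor``.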
+ widths_cont = np.arange(depth) * width_slope + initial_width
+ ks = np.round(
+ np.log(widths_cont / initial_width) / np.log(width_parameter))
+ widths = initial_width * np.power(width_parameter, ks)
+ widths = np.round(np.divide(widths, divisor)) * divisor
+ num_stages = len(np.unique(widths))
+ widths, widths_cont = widths.astype(int).tolist(), widths_cont.tolist()
+ return widths, num_stages
+
+ @staticmethod
+ def quantize_float(number, divisor):
+ """Converts a float to closest non-zero int divisible by divior.
+
+ Args:
+ number (int): Original number to be quantized.
+ divisor (int): Divisor used to quantize the number.
+
+ Returns:
+            int: quantized number that is divisible by divisor.
+ """
+ return int(round(number / divisor) * divisor)
+
+ def adjust_width_group(self, widths, bottleneck_ratio, groups):
+ """Adjusts the compatibility of widths and groups.
+
+ Args:
+ widths (list[int]): Width of each stage.
+            bottleneck_ratio (list[float]): Bottleneck ratio of each stage.
+            groups (list[int]): Number of groups in each stage.
+
+ Returns:
+ tuple(list): The adjusted widths and groups of each stage.
+ """
+ bottleneck_width = [
+ int(w * b) for w, b in zip(widths, bottleneck_ratio)
+ ]
+ groups = [min(g, w_bot) for g, w_bot in zip(groups, bottleneck_width)]
+ bottleneck_width = [
+ self.quantize_float(w_bot, g)
+ for w_bot, g in zip(bottleneck_width, groups)
+ ]
+ widths = [
+ int(w_bot / b)
+ for w_bot, b in zip(bottleneck_width, bottleneck_ratio)
+ ]
+ return widths, groups
+
+ def get_stages_from_blocks(self, widths):
+ """Gets widths/stage_blocks of network at each stage.
+
+ Args:
+ widths (list[int]): Width in each stage.
+
+ Returns:
+ tuple(list): width and depth of each stage
+ """
+ width_diff = [
+ width != width_prev
+ for width, width_prev in zip(widths + [0], [0] + widths)
+ ]
+ stage_widths = [
+ width for width, diff in zip(widths, width_diff[:-1]) if diff
+ ]
+ stage_blocks = np.diff([
+ depth for depth, diff in zip(range(len(width_diff)), width_diff)
+ if diff
+ ]).tolist()
+ return stage_widths, stage_blocks
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.norm1(x)
+ x = self.relu(x)
+
+ outs = []
+ for i, layer_name in enumerate(self.res_layers):
+ res_layer = getattr(self, layer_name)
+ x = res_layer(x)
+ if i in self.out_indices:
+ outs.append(x)
+
+ return tuple(outs)
diff --git a/mmpose/models/backbones/resnest.py b/mmpose/models/backbones/resnest.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5eea8ad7e50c2ab997e2df17316943fcaf3a5fe
--- /dev/null
+++ b/mmpose/models/backbones/resnest.py
@@ -0,0 +1,353 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as cp
+from mmcv.cnn import build_conv_layer, build_norm_layer
+from mmengine.model import BaseModule
+
+from mmpose.registry import MODELS
+from .resnet import Bottleneck as _Bottleneck
+from .resnet import ResLayer, ResNetV1d
+
+
+class RSoftmax(nn.Module):
+ """Radix Softmax module in ``SplitAttentionConv2d``.
+
+ Args:
+ radix (int): Radix of input.
+ groups (int): Groups of input.
+ """
+
+ def __init__(self, radix, groups):
+ super().__init__()
+ self.radix = radix
+ self.groups = groups
+
+ def forward(self, x):
+ batch = x.size(0)
+ if self.radix > 1:
+ x = x.view(batch, self.groups, self.radix, -1).transpose(1, 2)
+ x = F.softmax(x, dim=1)
+ x = x.reshape(batch, -1)
+ else:
+ x = torch.sigmoid(x)
+ return x
+
+
+class SplitAttentionConv2d(BaseModule):
+ """Split-Attention Conv2d.
+
+ Args:
+ in_channels (int): Same as nn.Conv2d.
+ channels (int): Same as nn.Conv2d.
+ kernel_size (int | tuple[int]): Same as nn.Conv2d.
+ stride (int | tuple[int]): Same as nn.Conv2d.
+ padding (int | tuple[int]): Same as nn.Conv2d.
+ dilation (int | tuple[int]): Same as nn.Conv2d.
+ groups (int): Same as nn.Conv2d.
+ radix (int): Radix of SplitAttentionConv2d. Default: 2.
+ reduction_factor (int): Reduction factor of SplitAttentionConv2d.
+ Default: 4.
+ conv_cfg (dict): Config dict for convolution layer. Default: None,
+ which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer. Default: dict(type='BN').
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self,
+ in_channels,
+ channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ radix=2,
+ reduction_factor=4,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ init_cfg=None):
+ super().__init__(init_cfg=init_cfg)
+ inter_channels = max(in_channels * radix // reduction_factor, 32)
+ self.radix = radix
+ self.groups = groups
+ self.channels = channels
+ self.conv = build_conv_layer(
+ conv_cfg,
+ in_channels,
+ channels * radix,
+ kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups * radix,
+ bias=False)
+ self.norm0_name, norm0 = build_norm_layer(
+ norm_cfg, channels * radix, postfix=0)
+ self.add_module(self.norm0_name, norm0)
+ self.relu = nn.ReLU(inplace=True)
+ self.fc1 = build_conv_layer(
+ None, channels, inter_channels, 1, groups=self.groups)
+ self.norm1_name, norm1 = build_norm_layer(
+ norm_cfg, inter_channels, postfix=1)
+ self.add_module(self.norm1_name, norm1)
+ self.fc2 = build_conv_layer(
+ None, inter_channels, channels * radix, 1, groups=self.groups)
+ self.rsoftmax = RSoftmax(radix, groups)
+
+ @property
+ def norm0(self):
+ return getattr(self, self.norm0_name)
+
+ @property
+ def norm1(self):
+ return getattr(self, self.norm1_name)
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.norm0(x)
+ x = self.relu(x)
+
+ batch, rchannel = x.shape[:2]
+ if self.radix > 1:
+ splits = x.view(batch, self.radix, -1, *x.shape[2:])
+ gap = splits.sum(dim=1)
+ else:
+ gap = x
+ gap = F.adaptive_avg_pool2d(gap, 1)
+ gap = self.fc1(gap)
+
+ gap = self.norm1(gap)
+ gap = self.relu(gap)
+
+ atten = self.fc2(gap)
+ atten = self.rsoftmax(atten).view(batch, -1, 1, 1)
+
+ if self.radix > 1:
+ attens = atten.view(batch, self.radix, -1, *atten.shape[2:])
+ out = torch.sum(attens * splits, dim=1)
+ else:
+ out = atten * x
+ return out.contiguous()
+
+
+class Bottleneck(_Bottleneck):
+ """Bottleneck block for ResNeSt.
+
+ Args:
+ in_channels (int): Input channels of this block.
+ out_channels (int): Output channels of this block.
+ groups (int): Groups of conv2.
+ width_per_group (int): Width per group of conv2. 64x4d indicates
+ ``groups=64, width_per_group=4`` and 32x8d indicates
+ ``groups=32, width_per_group=8``.
+ radix (int): Radix of SplitAttentionConv2d. Default: 2.
+ reduction_factor (int): Reduction factor of SplitAttentionConv2d.
+ Default: 4.
+ avg_down_stride (bool): Whether to use average pool for stride in
+ Bottleneck. Default: True.
+ stride (int): stride of the block. Default: 1
+ dilation (int): dilation of convolution. Default: 1
+ downsample (nn.Module): downsample operation on identity branch.
+ Default: None
+ style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
+ layer is the 3x3 conv layer, otherwise the stride-two layer is
+ the first 1x1 conv layer.
+ conv_cfg (dict): dictionary to construct and config conv layer.
+ Default: None
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: dict(type='BN')
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ groups=1,
+ width_per_group=4,
+ base_channels=64,
+ radix=2,
+ reduction_factor=4,
+ avg_down_stride=True,
+ **kwargs):
+ super().__init__(in_channels, out_channels, **kwargs)
+
+ self.groups = groups
+ self.width_per_group = width_per_group
+
+ # For ResNet bottleneck, middle channels are determined by expansion
+ # and out_channels, but for ResNeXt bottleneck, it is determined by
+ # groups and width_per_group and the stage it is located in.
+ if groups != 1:
+ assert self.mid_channels % base_channels == 0
+ self.mid_channels = (
+ groups * width_per_group * self.mid_channels // base_channels)
+
+ self.avg_down_stride = avg_down_stride and self.conv2_stride > 1
+
+ self.norm1_name, norm1 = build_norm_layer(
+ self.norm_cfg, self.mid_channels, postfix=1)
+ self.norm3_name, norm3 = build_norm_layer(
+ self.norm_cfg, self.out_channels, postfix=3)
+
+ self.conv1 = build_conv_layer(
+ self.conv_cfg,
+ self.in_channels,
+ self.mid_channels,
+ kernel_size=1,
+ stride=self.conv1_stride,
+ bias=False)
+ self.add_module(self.norm1_name, norm1)
+ self.conv2 = SplitAttentionConv2d(
+ self.mid_channels,
+ self.mid_channels,
+ kernel_size=3,
+ stride=1 if self.avg_down_stride else self.conv2_stride,
+ padding=self.dilation,
+ dilation=self.dilation,
+ groups=groups,
+ radix=radix,
+ reduction_factor=reduction_factor,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg)
+ delattr(self, self.norm2_name)
+
+ if self.avg_down_stride:
+ self.avd_layer = nn.AvgPool2d(3, self.conv2_stride, padding=1)
+
+ self.conv3 = build_conv_layer(
+ self.conv_cfg,
+ self.mid_channels,
+ self.out_channels,
+ kernel_size=1,
+ bias=False)
+ self.add_module(self.norm3_name, norm3)
+
+ def forward(self, x):
+
+ def _inner_forward(x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.norm1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+
+ if self.avg_down_stride:
+ out = self.avd_layer(out)
+
+ out = self.conv3(out)
+ out = self.norm3(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+
+ return out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ out = self.relu(out)
+
+ return out
+
+
+@MODELS.register_module()
+class ResNeSt(ResNetV1d):
+ """ResNeSt backbone.
+
+ Please refer to the `paper `__
+ for details.
+
+ Args:
+ depth (int): Network depth, from {50, 101, 152, 200, 269}.
+ groups (int): Groups of conv2 in Bottleneck. Default: 1.
+ width_per_group (int): Width per group of conv2 in Bottleneck.
+ Default: 4.
+ radix (int): Radix of SplitAttentionConv2d. Default: 2.
+ reduction_factor (int): Reduction factor of SplitAttentionConv2d.
+ Default: 4.
+ avg_down_stride (bool): Whether to use average pool for stride in
+ Bottleneck. Default: True.
+ in_channels (int): Number of input image channels. Default: 3.
+ stem_channels (int): Output channels of the stem layer. Default: 64.
+ num_stages (int): Stages of the network. Default: 4.
+ strides (Sequence[int]): Strides of the first block of each stage.
+ Default: ``(1, 2, 2, 2)``.
+ dilations (Sequence[int]): Dilation of each stage.
+ Default: ``(1, 1, 1, 1)``.
+ out_indices (Sequence[int]): Output from which stages. If only one
+ stage is specified, a single tensor (feature map) is returned,
+ otherwise multiple stages are specified, a tuple of tensors will
+ be returned. Default: ``(3, )``.
+ style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
+ layer is the 3x3 conv layer, otherwise the stride-two layer is
+ the first 1x1 conv layer.
+ deep_stem (bool): Replace the 7x7 conv in the input stem with three 3x3 convs.
+ Default: False.
+ avg_down (bool): Use AvgPool instead of stride conv when
+ downsampling in the bottleneck. Default: False.
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+ -1 means not freezing any parameters. Default: -1.
+ conv_cfg (dict | None): The config dict for conv layers. Default: None.
+ norm_cfg (dict): The config dict for norm layers.
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Default: False.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ zero_init_residual (bool): Whether to use zero init for last norm layer
+ in resblocks to let them behave as identity. Default: True.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default:
+ ``[
+ dict(type='Kaiming', layer=['Conv2d']),
+ dict(
+ type='Constant',
+ val=1,
+ layer=['_BatchNorm', 'GroupNorm'])
+ ]``
+ """
+
+ arch_settings = {
+ 50: (Bottleneck, (3, 4, 6, 3)),
+ 101: (Bottleneck, (3, 4, 23, 3)),
+ 152: (Bottleneck, (3, 8, 36, 3)),
+ 200: (Bottleneck, (3, 24, 36, 3)),
+ 269: (Bottleneck, (3, 30, 48, 8))
+ }
+
+ def __init__(self,
+ depth,
+ groups=1,
+ width_per_group=4,
+ radix=2,
+ reduction_factor=4,
+ avg_down_stride=True,
+ **kwargs):
+ self.groups = groups
+ self.width_per_group = width_per_group
+ self.radix = radix
+ self.reduction_factor = reduction_factor
+ self.avg_down_stride = avg_down_stride
+ super().__init__(depth=depth, **kwargs)
+
+ def make_res_layer(self, **kwargs):
+ return ResLayer(
+ groups=self.groups,
+ width_per_group=self.width_per_group,
+ base_channels=self.base_channels,
+ radix=self.radix,
+ reduction_factor=self.reduction_factor,
+ avg_down_stride=self.avg_down_stride,
+ **kwargs)
diff --git a/mmpose/models/backbones/resnet.py b/mmpose/models/backbones/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..a04853f60d179ee2450ca199b0a8c28ae893941f
--- /dev/null
+++ b/mmpose/models/backbones/resnet.py
@@ -0,0 +1,715 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+from mmcv.cnn import ConvModule, build_conv_layer, build_norm_layer
+from mmengine.model import BaseModule, constant_init
+from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
+
+from mmpose.registry import MODELS
+from .base_backbone import BaseBackbone
+
+
+class BasicBlock(BaseModule):
+ """BasicBlock for ResNet.
+
+ Args:
+ in_channels (int): Input channels of this block.
+ out_channels (int): Output channels of this block.
+ expansion (int): The ratio of ``out_channels/mid_channels`` where
+ ``mid_channels`` is the output channels of conv1. This is a
+ reserved argument in BasicBlock and should always be 1. Default: 1.
+ stride (int): stride of the block. Default: 1
+ dilation (int): dilation of convolution. Default: 1
+ downsample (nn.Module): downsample operation on identity branch.
+ Default: None.
+ style (str): `pytorch` or `caffe`. It is unused and reserved for
+ unified API with Bottleneck.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed.
+ conv_cfg (dict): dictionary to construct and config conv layer.
+ Default: None
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: dict(type='BN')
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ expansion=1,
+ stride=1,
+ dilation=1,
+ downsample=None,
+ style='pytorch',
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ init_cfg=None):
+ # Protect mutable default arguments
+ norm_cfg = copy.deepcopy(norm_cfg)
+ super().__init__(init_cfg=init_cfg)
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.expansion = expansion
+ assert self.expansion == 1
+ assert out_channels % expansion == 0
+ self.mid_channels = out_channels // expansion
+ self.stride = stride
+ self.dilation = dilation
+ self.style = style
+ self.with_cp = with_cp
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+
+ self.norm1_name, norm1 = build_norm_layer(
+ norm_cfg, self.mid_channels, postfix=1)
+ self.norm2_name, norm2 = build_norm_layer(
+ norm_cfg, out_channels, postfix=2)
+
+ self.conv1 = build_conv_layer(
+ conv_cfg,
+ in_channels,
+ self.mid_channels,
+ 3,
+ stride=stride,
+ padding=dilation,
+ dilation=dilation,
+ bias=False)
+ self.add_module(self.norm1_name, norm1)
+ self.conv2 = build_conv_layer(
+ conv_cfg,
+ self.mid_channels,
+ out_channels,
+ 3,
+ padding=1,
+ bias=False)
+ self.add_module(self.norm2_name, norm2)
+
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+
+ @property
+ def norm1(self):
+ """nn.Module: the normalization layer named "norm1" """
+ return getattr(self, self.norm1_name)
+
+ @property
+ def norm2(self):
+ """nn.Module: the normalization layer named "norm2" """
+ return getattr(self, self.norm2_name)
+
+ def forward(self, x):
+ """Forward function."""
+
+ def _inner_forward(x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.norm1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.norm2(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+
+ return out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(BaseModule):
+ """Bottleneck block for ResNet.
+
+ Args:
+ in_channels (int): Input channels of this block.
+ out_channels (int): Output channels of this block.
+ expansion (int): The ratio of ``out_channels/mid_channels`` where
+ ``mid_channels`` is the input/output channels of conv2. Default: 4.
+ stride (int): stride of the block. Default: 1
+ dilation (int): dilation of convolution. Default: 1
+ downsample (nn.Module): downsample operation on identity branch.
+ Default: None.
+ style (str): ``"pytorch"`` or ``"caffe"``. If set to "pytorch", the
+ stride-two layer is the 3x3 conv layer, otherwise the stride-two
+ layer is the first 1x1 conv layer. Default: "pytorch".
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed.
+ conv_cfg (dict): dictionary to construct and config conv layer.
+ Default: None
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: dict(type='BN')
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ expansion=4,
+ stride=1,
+ dilation=1,
+ downsample=None,
+ style='pytorch',
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ init_cfg=None):
+ # Protect mutable default arguments
+ norm_cfg = copy.deepcopy(norm_cfg)
+ super().__init__(init_cfg=init_cfg)
+ assert style in ['pytorch', 'caffe']
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.expansion = expansion
+ assert out_channels % expansion == 0
+ self.mid_channels = out_channels // expansion
+ self.stride = stride
+ self.dilation = dilation
+ self.style = style
+ self.with_cp = with_cp
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+
+ if self.style == 'pytorch':
+ self.conv1_stride = 1
+ self.conv2_stride = stride
+ else:
+ self.conv1_stride = stride
+ self.conv2_stride = 1
+
+ self.norm1_name, norm1 = build_norm_layer(
+ norm_cfg, self.mid_channels, postfix=1)
+ self.norm2_name, norm2 = build_norm_layer(
+ norm_cfg, self.mid_channels, postfix=2)
+ self.norm3_name, norm3 = build_norm_layer(
+ norm_cfg, out_channels, postfix=3)
+
+ self.conv1 = build_conv_layer(
+ conv_cfg,
+ in_channels,
+ self.mid_channels,
+ kernel_size=1,
+ stride=self.conv1_stride,
+ bias=False)
+ self.add_module(self.norm1_name, norm1)
+ self.conv2 = build_conv_layer(
+ conv_cfg,
+ self.mid_channels,
+ self.mid_channels,
+ kernel_size=3,
+ stride=self.conv2_stride,
+ padding=dilation,
+ dilation=dilation,
+ bias=False)
+
+ self.add_module(self.norm2_name, norm2)
+ self.conv3 = build_conv_layer(
+ conv_cfg,
+ self.mid_channels,
+ out_channels,
+ kernel_size=1,
+ bias=False)
+ self.add_module(self.norm3_name, norm3)
+
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+
+ @property
+ def norm1(self):
+ """nn.Module: the normalization layer named "norm1" """
+ return getattr(self, self.norm1_name)
+
+ @property
+ def norm2(self):
+ """nn.Module: the normalization layer named "norm2" """
+ return getattr(self, self.norm2_name)
+
+ @property
+ def norm3(self):
+ """nn.Module: the normalization layer named "norm3" """
+ return getattr(self, self.norm3_name)
+
+ def forward(self, x):
+ """Forward function."""
+
+ def _inner_forward(x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.norm1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.norm2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.norm3(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+
+ return out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ out = self.relu(out)
+
+ return out
+
+
+def get_expansion(block, expansion=None):
+ """Get the expansion of a residual block.
+
+ The block expansion will be obtained by the following order:
+
+ 1. If ``expansion`` is given, just return it.
+ 2. If ``block`` has the attribute ``expansion``, then return
+ ``block.expansion``.
+ 3. Return the default value according to the block type:
+ 1 for ``BasicBlock`` and 4 for ``Bottleneck``.
+
+ Args:
+ block (class): The block class.
+ expansion (int | None): The given expansion ratio.
+
+ Returns:
+ int: The expansion of the block.
+ """
+ if isinstance(expansion, int):
+ assert expansion > 0
+ elif expansion is None:
+ if hasattr(block, 'expansion'):
+ expansion = block.expansion
+ elif issubclass(block, BasicBlock):
+ expansion = 1
+ elif issubclass(block, Bottleneck):
+ expansion = 4
+ else:
+ raise TypeError(f'expansion is not specified for {block.__name__}')
+ else:
+ raise TypeError('expansion must be an integer or None')
+
+ return expansion
+
+
+class ResLayer(nn.Sequential):
+ """ResLayer to build ResNet style backbone.
+
+ Args:
+ block (nn.Module): Residual block used to build ResLayer.
+ num_blocks (int): Number of blocks.
+ in_channels (int): Input channels of this block.
+ out_channels (int): Output channels of this block.
+ expansion (int, optional): The expansion for BasicBlock/Bottleneck.
+ If not specified, it will firstly be obtained via
+ ``block.expansion``. If the block has no attribute "expansion",
+ the following default values will be used: 1 for BasicBlock and
+ 4 for Bottleneck. Default: None.
+ stride (int): stride of the first block. Default: 1.
+ avg_down (bool): Use AvgPool instead of stride conv when
+ downsampling in the bottleneck. Default: False
+ conv_cfg (dict): dictionary to construct and config conv layer.
+ Default: None
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: dict(type='BN')
+ downsample_first (bool): Downsample at the first block or last block.
+ False for Hourglass, True for ResNet. Default: True
+ """
+
+ def __init__(self,
+ block,
+ num_blocks,
+ in_channels,
+ out_channels,
+ expansion=None,
+ stride=1,
+ avg_down=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ downsample_first=True,
+ **kwargs):
+ # Protect mutable default arguments
+ norm_cfg = copy.deepcopy(norm_cfg)
+ self.block = block
+ self.expansion = get_expansion(block, expansion)
+
+ downsample = None
+ if stride != 1 or in_channels != out_channels:
+ downsample = []
+ conv_stride = stride
+ if avg_down and stride != 1:
+ conv_stride = 1
+ downsample.append(
+ nn.AvgPool2d(
+ kernel_size=stride,
+ stride=stride,
+ ceil_mode=True,
+ count_include_pad=False))
+ downsample.extend([
+ build_conv_layer(
+ conv_cfg,
+ in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=conv_stride,
+ bias=False),
+ build_norm_layer(norm_cfg, out_channels)[1]
+ ])
+ downsample = nn.Sequential(*downsample)
+
+ layers = []
+ if downsample_first:
+ layers.append(
+ block(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ expansion=self.expansion,
+ stride=stride,
+ downsample=downsample,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ **kwargs))
+ in_channels = out_channels
+ for _ in range(1, num_blocks):
+ layers.append(
+ block(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ expansion=self.expansion,
+ stride=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ **kwargs))
+ else: # downsample_first=False is for HourglassModule
+ for i in range(0, num_blocks - 1):
+ layers.append(
+ block(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ expansion=self.expansion,
+ stride=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ **kwargs))
+ layers.append(
+ block(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ expansion=self.expansion,
+ stride=stride,
+ downsample=downsample,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ **kwargs))
+
+ super().__init__(*layers)
+
+
+@MODELS.register_module()
+class ResNet(BaseBackbone):
+ """ResNet backbone.
+
+ Please refer to the `paper `__ for
+ details.
+
+ Args:
+ depth (int): Network depth, from {18, 34, 50, 101, 152}.
+ in_channels (int): Number of input image channels. Default: 3.
+ stem_channels (int): Output channels of the stem layer. Default: 64.
+ base_channels (int): Middle channels of the first stage. Default: 64.
+ num_stages (int): Stages of the network. Default: 4.
+ strides (Sequence[int]): Strides of the first block of each stage.
+ Default: ``(1, 2, 2, 2)``.
+ dilations (Sequence[int]): Dilation of each stage.
+ Default: ``(1, 1, 1, 1)``.
+ out_indices (Sequence[int]): Output from which stages. If only one
+ stage is specified, a single tensor (feature map) is returned,
+ otherwise multiple stages are specified, a tuple of tensors will
+ be returned. Default: ``(3, )``.
+ style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
+ layer is the 3x3 conv layer, otherwise the stride-two layer is
+ the first 1x1 conv layer.
+ deep_stem (bool): Replace the 7x7 conv in the input stem with three 3x3 convs.
+ Default: False.
+ avg_down (bool): Use AvgPool instead of stride conv when
+ downsampling in the bottleneck. Default: False.
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+ -1 means not freezing any parameters. Default: -1.
+ conv_cfg (dict | None): The config dict for conv layers. Default: None.
+ norm_cfg (dict): The config dict for norm layers.
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Default: False.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ zero_init_residual (bool): Whether to use zero init for last norm layer
+ in resblocks to let them behave as identity. Default: True.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default:
+ ``[
+ dict(type='Kaiming', layer=['Conv2d']),
+ dict(
+ type='Constant',
+ val=1,
+ layer=['_BatchNorm', 'GroupNorm'])
+ ]``
+
+ Example:
+ >>> from mmpose.models import ResNet
+ >>> import torch
+ >>> self = ResNet(depth=18, out_indices=(0, 1, 2, 3))
+ >>> self.eval()
+ >>> inputs = torch.rand(1, 3, 32, 32)
+ >>> level_outputs = self.forward(inputs)
+ >>> for level_out in level_outputs:
+ ... print(tuple(level_out.shape))
+ (1, 64, 8, 8)
+ (1, 128, 4, 4)
+ (1, 256, 2, 2)
+ (1, 512, 1, 1)
+ """
+
+ arch_settings = {
+ 18: (BasicBlock, (2, 2, 2, 2)),
+ 34: (BasicBlock, (3, 4, 6, 3)),
+ 50: (Bottleneck, (3, 4, 6, 3)),
+ 101: (Bottleneck, (3, 4, 23, 3)),
+ 152: (Bottleneck, (3, 8, 36, 3))
+ }
+
+ def __init__(self,
+ depth,
+ in_channels=3,
+ stem_channels=64,
+ base_channels=64,
+ expansion=None,
+ num_stages=4,
+ strides=(1, 2, 2, 2),
+ dilations=(1, 1, 1, 1),
+ out_indices=(3, ),
+ style='pytorch',
+ deep_stem=False,
+ avg_down=False,
+ frozen_stages=-1,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ norm_eval=False,
+ with_cp=False,
+ zero_init_residual=True,
+ init_cfg=[
+ dict(type='Kaiming', layer=['Conv2d']),
+ dict(
+ type='Constant',
+ val=1,
+ layer=['_BatchNorm', 'GroupNorm'])
+ ]):
+ # Protect mutable default arguments
+ norm_cfg = copy.deepcopy(norm_cfg)
+ super(ResNet, self).__init__(init_cfg)
+ if depth not in self.arch_settings:
+ raise KeyError(f'invalid depth {depth} for resnet')
+ self.depth = depth
+ self.stem_channels = stem_channels
+ self.base_channels = base_channels
+ self.num_stages = num_stages
+ assert 1 <= num_stages <= 4
+ self.strides = strides
+ self.dilations = dilations
+ assert len(strides) == len(dilations) == num_stages
+ self.out_indices = out_indices
+ assert max(out_indices) < num_stages
+ self.style = style
+ self.deep_stem = deep_stem
+ self.avg_down = avg_down
+ self.frozen_stages = frozen_stages
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.with_cp = with_cp
+ self.norm_eval = norm_eval
+ self.zero_init_residual = zero_init_residual
+ self.block, stage_blocks = self.arch_settings[depth]
+ self.stage_blocks = stage_blocks[:num_stages]
+ self.expansion = get_expansion(self.block, expansion)
+
+ self._make_stem_layer(in_channels, stem_channels)
+
+ self.res_layers = []
+ _in_channels = stem_channels
+ _out_channels = base_channels * self.expansion
+ for i, num_blocks in enumerate(self.stage_blocks):
+ stride = strides[i]
+ dilation = dilations[i]
+ res_layer = self.make_res_layer(
+ block=self.block,
+ num_blocks=num_blocks,
+ in_channels=_in_channels,
+ out_channels=_out_channels,
+ expansion=self.expansion,
+ stride=stride,
+ dilation=dilation,
+ style=self.style,
+ avg_down=self.avg_down,
+ with_cp=with_cp,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg)
+ _in_channels = _out_channels
+ _out_channels *= 2
+ layer_name = f'layer{i + 1}'
+ self.add_module(layer_name, res_layer)
+ self.res_layers.append(layer_name)
+
+ self._freeze_stages()
+
+ self.feat_dim = res_layer[-1].out_channels
+
+ def make_res_layer(self, **kwargs):
+ """Make a ResLayer."""
+ return ResLayer(**kwargs)
+
+ @property
+ def norm1(self):
+ """nn.Module: the normalization layer named "norm1" """
+ return getattr(self, self.norm1_name)
+
+ def _make_stem_layer(self, in_channels, stem_channels):
+ """Make stem layer."""
+ if self.deep_stem:
+ self.stem = nn.Sequential(
+ ConvModule(
+ in_channels,
+ stem_channels // 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ inplace=True),
+ ConvModule(
+ stem_channels // 2,
+ stem_channels // 2,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ inplace=True),
+ ConvModule(
+ stem_channels // 2,
+ stem_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ inplace=True))
+ else:
+ self.conv1 = build_conv_layer(
+ self.conv_cfg,
+ in_channels,
+ stem_channels,
+ kernel_size=7,
+ stride=2,
+ padding=3,
+ bias=False)
+ self.norm1_name, norm1 = build_norm_layer(
+ self.norm_cfg, stem_channels, postfix=1)
+ self.add_module(self.norm1_name, norm1)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+ def _freeze_stages(self):
+ """Freeze parameters."""
+ if self.frozen_stages >= 0:
+ if self.deep_stem:
+ self.stem.eval()
+ for param in self.stem.parameters():
+ param.requires_grad = False
+ else:
+ self.norm1.eval()
+ for m in [self.conv1, self.norm1]:
+ for param in m.parameters():
+ param.requires_grad = False
+
+ for i in range(1, self.frozen_stages + 1):
+ m = getattr(self, f'layer{i}')
+ m.eval()
+ for param in m.parameters():
+ param.requires_grad = False
+
+ def init_weights(self):
+ """Initialize the weights in backbone."""
+ super(ResNet, self).init_weights()
+
+ if (isinstance(self.init_cfg, dict)
+ and self.init_cfg['type'] == 'Pretrained'):
+ # Suppress zero_init_residual if use pretrained model.
+ return
+
+ if self.zero_init_residual:
+ for m in self.modules():
+ if isinstance(m, Bottleneck):
+ constant_init(m.norm3, 0)
+ elif isinstance(m, BasicBlock):
+ constant_init(m.norm2, 0)
+
+ def forward(self, x):
+ """Forward function."""
+ if self.deep_stem:
+ x = self.stem(x)
+ else:
+ x = self.conv1(x)
+ x = self.norm1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+ outs = []
+ for i, layer_name in enumerate(self.res_layers):
+ res_layer = getattr(self, layer_name)
+ x = res_layer(x)
+ if i in self.out_indices:
+ outs.append(x)
+ return tuple(outs)
+
+ def train(self, mode=True):
+ """Convert the model into training mode."""
+ super().train(mode)
+ self._freeze_stages()
+ if mode and self.norm_eval:
+ for m in self.modules():
+ # trick: eval have effect on BatchNorm only
+ if isinstance(m, _BatchNorm):
+ m.eval()
+
+
+@MODELS.register_module()
+class ResNetV1d(ResNet):
+ r"""ResNetV1d variant described in `Bag of Tricks
+ `__.
+
+ Compared with default ResNet(ResNetV1b), ResNetV1d replaces the 7x7 conv in
+ the input stem with three 3x3 convs. And in the downsampling block, a 2x2
+ avg_pool with stride 2 is added before conv, whose stride is changed to 1.
+ """
+
+ def __init__(self, **kwargs):
+ super().__init__(deep_stem=True, avg_down=True, **kwargs)
diff --git a/mmpose/models/backbones/resnext.py b/mmpose/models/backbones/resnext.py
new file mode 100644
index 0000000000000000000000000000000000000000..241f83a11449d3e816d4dbb16bd5715cf9ba6e3f
--- /dev/null
+++ b/mmpose/models/backbones/resnext.py
@@ -0,0 +1,171 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmcv.cnn import build_conv_layer, build_norm_layer
+
+from mmpose.registry import MODELS
+from .resnet import Bottleneck as _Bottleneck
+from .resnet import ResLayer, ResNet
+
+
+class Bottleneck(_Bottleneck):
+ """Bottleneck block for ResNeXt.
+
+ Args:
+ in_channels (int): Input channels of this block.
+ out_channels (int): Output channels of this block.
+ groups (int): Groups of conv2.
+ width_per_group (int): Width per group of conv2. 64x4d indicates
+ ``groups=64, width_per_group=4`` and 32x8d indicates
+ ``groups=32, width_per_group=8``.
+ stride (int): stride of the block. Default: 1
+ dilation (int): dilation of convolution. Default: 1
+ downsample (nn.Module): downsample operation on identity branch.
+ Default: None
+ style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
+ layer is the 3x3 conv layer, otherwise the stride-two layer is
+ the first 1x1 conv layer.
+ conv_cfg (dict): dictionary to construct and config conv layer.
+ Default: None
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: dict(type='BN')
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ base_channels=64,
+ groups=32,
+ width_per_group=4,
+ **kwargs):
+ super().__init__(in_channels, out_channels, **kwargs)
+ self.groups = groups
+ self.width_per_group = width_per_group
+
+ # For ResNet bottleneck, middle channels are determined by expansion
+ # and out_channels, but for ResNeXt bottleneck, it is determined by
+ # groups and width_per_group and the stage it is located in.
+ if groups != 1:
+ assert self.mid_channels % base_channels == 0
+ self.mid_channels = (
+ groups * width_per_group * self.mid_channels // base_channels)
+
+ self.norm1_name, norm1 = build_norm_layer(
+ self.norm_cfg, self.mid_channels, postfix=1)
+ self.norm2_name, norm2 = build_norm_layer(
+ self.norm_cfg, self.mid_channels, postfix=2)
+ self.norm3_name, norm3 = build_norm_layer(
+ self.norm_cfg, self.out_channels, postfix=3)
+
+ self.conv1 = build_conv_layer(
+ self.conv_cfg,
+ self.in_channels,
+ self.mid_channels,
+ kernel_size=1,
+ stride=self.conv1_stride,
+ bias=False)
+ self.add_module(self.norm1_name, norm1)
+ self.conv2 = build_conv_layer(
+ self.conv_cfg,
+ self.mid_channels,
+ self.mid_channels,
+ kernel_size=3,
+ stride=self.conv2_stride,
+ padding=self.dilation,
+ dilation=self.dilation,
+ groups=groups,
+ bias=False)
+
+ self.add_module(self.norm2_name, norm2)
+ self.conv3 = build_conv_layer(
+ self.conv_cfg,
+ self.mid_channels,
+ self.out_channels,
+ kernel_size=1,
+ bias=False)
+ self.add_module(self.norm3_name, norm3)
+
+
+@MODELS.register_module()
+class ResNeXt(ResNet):
+ """ResNeXt backbone.
+
+ Please refer to the `paper `__ for
+ details.
+
+ Args:
+ depth (int): Network depth, from {50, 101, 152}.
+ groups (int): Groups of conv2 in Bottleneck. Default: 32.
+ width_per_group (int): Width per group of conv2 in Bottleneck.
+ Default: 4.
+ in_channels (int): Number of input image channels. Default: 3.
+ stem_channels (int): Output channels of the stem layer. Default: 64.
+ num_stages (int): Stages of the network. Default: 4.
+ strides (Sequence[int]): Strides of the first block of each stage.
+ Default: ``(1, 2, 2, 2)``.
+ dilations (Sequence[int]): Dilation of each stage.
+ Default: ``(1, 1, 1, 1)``.
+ out_indices (Sequence[int]): Output from which stages. If only one
+ stage is specified, a single tensor (feature map) is returned,
+ otherwise multiple stages are specified, a tuple of tensors will
+ be returned. Default: ``(3, )``.
+ style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
+ layer is the 3x3 conv layer, otherwise the stride-two layer is
+ the first 1x1 conv layer.
+ deep_stem (bool): Replace the 7x7 conv in the input stem with three 3x3 convs.
+ Default: False.
+ avg_down (bool): Use AvgPool instead of stride conv when
+ downsampling in the bottleneck. Default: False.
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+ -1 means not freezing any parameters. Default: -1.
+ conv_cfg (dict | None): The config dict for conv layers. Default: None.
+ norm_cfg (dict): The config dict for norm layers.
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Default: False.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ zero_init_residual (bool): Whether to use zero init for last norm layer
+ in resblocks to let them behave as identity. Default: True.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default:
+ ``[
+ dict(type='Kaiming', layer=['Conv2d']),
+ dict(
+ type='Constant',
+ val=1,
+ layer=['_BatchNorm', 'GroupNorm'])
+ ]``
+
+ Example:
+ >>> from mmpose.models import ResNeXt
+ >>> import torch
+ >>> self = ResNeXt(depth=50, out_indices=(0, 1, 2, 3))
+ >>> self.eval()
+ >>> inputs = torch.rand(1, 3, 32, 32)
+ >>> level_outputs = self.forward(inputs)
+ >>> for level_out in level_outputs:
+ ... print(tuple(level_out.shape))
+ (1, 256, 8, 8)
+ (1, 512, 4, 4)
+ (1, 1024, 2, 2)
+ (1, 2048, 1, 1)
+ """
+
+ arch_settings = {
+ 50: (Bottleneck, (3, 4, 6, 3)),
+ 101: (Bottleneck, (3, 4, 23, 3)),
+ 152: (Bottleneck, (3, 8, 36, 3))
+ }
+
+ def __init__(self, depth, groups=32, width_per_group=4, **kwargs):
+ self.groups = groups
+ self.width_per_group = width_per_group
+ super().__init__(depth, **kwargs)
+
+ def make_res_layer(self, **kwargs):
+ return ResLayer(
+ groups=self.groups,
+ width_per_group=self.width_per_group,
+ base_channels=self.base_channels,
+ **kwargs)
diff --git a/mmpose/models/backbones/rsn.py b/mmpose/models/backbones/rsn.py
new file mode 100644
index 0000000000000000000000000000000000000000..8267d23d952f9639dff524cfea8e8d111ce19584
--- /dev/null
+++ b/mmpose/models/backbones/rsn.py
@@ -0,0 +1,640 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy as cp
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule, MaxPool2d
+from mmengine.model import BaseModule
+
+from mmpose.registry import MODELS
+from .base_backbone import BaseBackbone
+
+
+class RSB(BaseModule):
+ """Residual Steps block for RSN. Paper ref: Cai et al. "Learning Delicate
+ Local Representations for Multi-Person Pose Estimation" (ECCV 2020).
+
+ Args:
+ in_channels (int): Input channels of this block.
+ out_channels (int): Output channels of this block.
+ num_steps (int): Number of steps in RSB. Default: 4.
+ stride (int): stride of the block. Default: 1
+ downsample (nn.Module): downsample operation on identity branch.
+ Default: None.
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: dict(type='BN')
+ expand_times (int): Times by which the in_channels are expanded.
+ Default:26.
+ res_top_channels (int): Number of channels of feature output by
+ ResNet_top. Default:64.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ expansion = 1
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ num_steps=4,
+ stride=1,
+ downsample=None,
+ with_cp=False,
+ norm_cfg=dict(type='BN'),
+ expand_times=26,
+ res_top_channels=64,
+ init_cfg=None):
+ # Protect mutable default arguments
+ norm_cfg = cp.deepcopy(norm_cfg)
+ super().__init__(init_cfg=init_cfg)
+ assert num_steps > 1
+ self.in_channels = in_channels
+ self.branch_channels = self.in_channels * expand_times
+ self.branch_channels //= res_top_channels
+ self.out_channels = out_channels
+ self.stride = stride
+ self.downsample = downsample
+ self.with_cp = with_cp
+ self.norm_cfg = norm_cfg
+ self.num_steps = num_steps
+ self.conv_bn_relu1 = ConvModule(
+ self.in_channels,
+ self.num_steps * self.branch_channels,
+ kernel_size=1,
+ stride=self.stride,
+ padding=0,
+ norm_cfg=self.norm_cfg,
+ inplace=False)
+ for i in range(self.num_steps):
+ for j in range(i + 1):
+ module_name = f'conv_bn_relu2_{i + 1}_{j + 1}'
+ self.add_module(
+ module_name,
+ ConvModule(
+ self.branch_channels,
+ self.branch_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ norm_cfg=self.norm_cfg,
+ inplace=False))
+ self.conv_bn3 = ConvModule(
+ self.num_steps * self.branch_channels,
+ self.out_channels * self.expansion,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ act_cfg=None,
+ norm_cfg=self.norm_cfg,
+ inplace=False)
+ self.relu = nn.ReLU(inplace=False)
+
+ def forward(self, x):
+ """Forward function."""
+
+ identity = x
+ x = self.conv_bn_relu1(x)
+ spx = torch.split(x, self.branch_channels, 1)
+ outputs = list()
+ outs = list()
+ for i in range(self.num_steps):
+ outputs_i = list()
+ outputs.append(outputs_i)
+ for j in range(i + 1):
+ if j == 0:
+ inputs = spx[i]
+ else:
+ inputs = outputs[i][j - 1]
+ if i > j:
+ inputs = inputs + outputs[i - 1][j]
+ module_name = f'conv_bn_relu2_{i + 1}_{j + 1}'
+ module_i_j = getattr(self, module_name)
+ outputs[i].append(module_i_j(inputs))
+
+ outs.append(outputs[i][i])
+ out = torch.cat(tuple(outs), 1)
+ out = self.conv_bn3(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(identity)
+ out = out + identity
+
+ out = self.relu(out)
+
+ return out
+
+
+class Downsample_module(BaseModule):
+ """Downsample module for RSN.
+
+ Args:
+ block (nn.Module): Downsample block.
+ num_blocks (list): Number of blocks in each downsample unit.
+ num_units (int): Number of downsample units. Default: 4
+ has_skip (bool): Have skip connections from prior upsample
+ module or not. Default:False
+ num_steps (int): Number of steps in a block. Default:4
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: dict(type='BN')
+ in_channels (int): Number of channels of the input feature to
+ downsample module. Default: 64
+ expand_times (int): Times by which the in_channels are expanded.
+ Default:26.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self,
+ block,
+ num_blocks,
+ num_steps=4,
+ num_units=4,
+ has_skip=False,
+ norm_cfg=dict(type='BN'),
+ in_channels=64,
+ expand_times=26,
+ init_cfg=None):
+ # Protect mutable default arguments
+ norm_cfg = cp.deepcopy(norm_cfg)
+ super().__init__(init_cfg=init_cfg)
+ self.has_skip = has_skip
+ self.in_channels = in_channels
+ assert len(num_blocks) == num_units
+ self.num_blocks = num_blocks
+ self.num_units = num_units
+ self.num_steps = num_steps
+ self.norm_cfg = norm_cfg
+ self.layer1 = self._make_layer(
+ block,
+ in_channels,
+ num_blocks[0],
+ expand_times=expand_times,
+ res_top_channels=in_channels)
+ for i in range(1, num_units):
+ module_name = f'layer{i + 1}'
+ self.add_module(
+ module_name,
+ self._make_layer(
+ block,
+ in_channels * pow(2, i),
+ num_blocks[i],
+ stride=2,
+ expand_times=expand_times,
+ res_top_channels=in_channels))
+
+ def _make_layer(self,
+ block,
+ out_channels,
+ blocks,
+ stride=1,
+ expand_times=26,
+ res_top_channels=64):
+ downsample = None
+ if stride != 1 or self.in_channels != out_channels * block.expansion:
+ downsample = ConvModule(
+ self.in_channels,
+ out_channels * block.expansion,
+ kernel_size=1,
+ stride=stride,
+ padding=0,
+ norm_cfg=self.norm_cfg,
+ act_cfg=None,
+ inplace=True)
+
+ units = list()
+ units.append(
+ block(
+ self.in_channels,
+ out_channels,
+ num_steps=self.num_steps,
+ stride=stride,
+ downsample=downsample,
+ norm_cfg=self.norm_cfg,
+ expand_times=expand_times,
+ res_top_channels=res_top_channels))
+ self.in_channels = out_channels * block.expansion
+ for _ in range(1, blocks):
+ units.append(
+ block(
+ self.in_channels,
+ out_channels,
+ num_steps=self.num_steps,
+ expand_times=expand_times,
+ res_top_channels=res_top_channels))
+
+ return nn.Sequential(*units)
+
+ def forward(self, x, skip1, skip2):
+ out = list()
+ for i in range(self.num_units):
+ module_name = f'layer{i + 1}'
+ module_i = getattr(self, module_name)
+ x = module_i(x)
+ if self.has_skip:
+ x = x + skip1[i] + skip2[i]
+ out.append(x)
+ out.reverse()
+
+ return tuple(out)
+
+
+class Upsample_unit(BaseModule):
+ """Upsample unit for upsample module.
+
+ Args:
+ ind (int): Indicates whether to interpolate (>0) and whether to
+ generate feature map for the next hourglass-like module.
+ num_units (int): Number of units that form an upsample module. Along
+ with ind and gen_cross_conv, num_units is used to decide whether
+ to generate the feature map for the next hourglass-like module.
+ in_channels (int): Channel number of the skip-in feature maps from
+ the corresponding downsample unit.
+ unit_channels (int): Channel number in this unit. Default:256.
+ gen_skip (bool): Whether to generate skips for the posterior
+ downsample module. Default:False
+ gen_cross_conv (bool): Whether to generate feature map for the next
+ hourglass-like module. Default:False
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: dict(type='BN')
+ out_channels (int): Number of channels of feature output by upsample
+ module. Must equal in_channels of downsample module. Default: 64.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self,
+ ind,
+ num_units,
+ in_channels,
+ unit_channels=256,
+ gen_skip=False,
+ gen_cross_conv=False,
+ norm_cfg=dict(type='BN'),
+ out_channels=64,
+ init_cfg=None):
+ # Protect mutable default arguments
+ norm_cfg = cp.deepcopy(norm_cfg)
+ super().__init__(init_cfg=init_cfg)
+ self.num_units = num_units
+ self.norm_cfg = norm_cfg
+ self.in_skip = ConvModule(
+ in_channels,
+ unit_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ norm_cfg=self.norm_cfg,
+ act_cfg=None,
+ inplace=True)
+ self.relu = nn.ReLU(inplace=True)
+
+ self.ind = ind
+ if self.ind > 0:
+ self.up_conv = ConvModule(
+ unit_channels,
+ unit_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ norm_cfg=self.norm_cfg,
+ act_cfg=None,
+ inplace=True)
+
+ self.gen_skip = gen_skip
+ if self.gen_skip:
+ self.out_skip1 = ConvModule(
+ in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ norm_cfg=self.norm_cfg,
+ inplace=True)
+
+ self.out_skip2 = ConvModule(
+ unit_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ norm_cfg=self.norm_cfg,
+ inplace=True)
+
+ self.gen_cross_conv = gen_cross_conv
+ if self.ind == num_units - 1 and self.gen_cross_conv:
+ self.cross_conv = ConvModule(
+ unit_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ norm_cfg=self.norm_cfg,
+ inplace=True)
+
+ def forward(self, x, up_x):
+ out = self.in_skip(x)
+
+ if self.ind > 0:
+ up_x = F.interpolate(
+ up_x,
+ size=(x.size(2), x.size(3)),
+ mode='bilinear',
+ align_corners=True)
+ up_x = self.up_conv(up_x)
+ out = out + up_x
+ out = self.relu(out)
+
+ skip1 = None
+ skip2 = None
+ if self.gen_skip:
+ skip1 = self.out_skip1(x)
+ skip2 = self.out_skip2(out)
+
+ cross_conv = None
+ if self.ind == self.num_units - 1 and self.gen_cross_conv:
+ cross_conv = self.cross_conv(out)
+
+ return out, skip1, skip2, cross_conv
+
+
+class Upsample_module(BaseModule):
+ """Upsample module for RSN.
+
+ Args:
+ unit_channels (int): Channel number in the upsample units.
+ Default:256.
+ num_units (int): Number of upsample units. Default: 4
+ gen_skip (bool): Whether to generate skip for posterior downsample
+ module or not. Default:False
+ gen_cross_conv (bool): Whether to generate feature map for the next
+ hourglass-like module. Default:False
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: dict(type='BN')
+ out_channels (int): Number of channels of feature output by upsample
+ module. Must equal in_channels of downsample module. Default: 64.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self,
+ unit_channels=256,
+ num_units=4,
+ gen_skip=False,
+ gen_cross_conv=False,
+ norm_cfg=dict(type='BN'),
+ out_channels=64,
+ init_cfg=None):
+ # Protect mutable default arguments
+ norm_cfg = cp.deepcopy(norm_cfg)
+ super().__init__(init_cfg=init_cfg)
+ self.in_channels = list()
+ for i in range(num_units):
+ self.in_channels.append(RSB.expansion * out_channels * pow(2, i))
+ self.in_channels.reverse()
+ self.num_units = num_units
+ self.gen_skip = gen_skip
+ self.gen_cross_conv = gen_cross_conv
+ self.norm_cfg = norm_cfg
+ for i in range(num_units):
+ module_name = f'up{i + 1}'
+ self.add_module(
+ module_name,
+ Upsample_unit(
+ i,
+ self.num_units,
+ self.in_channels[i],
+ unit_channels,
+ self.gen_skip,
+ self.gen_cross_conv,
+ norm_cfg=self.norm_cfg,
+ out_channels=64))
+
+ def forward(self, x):
+ out = list()
+ skip1 = list()
+ skip2 = list()
+ cross_conv = None
+ for i in range(self.num_units):
+ module_i = getattr(self, f'up{i + 1}')
+ if i == 0:
+ outi, skip1_i, skip2_i, _ = module_i(x[i], None)
+ elif i == self.num_units - 1:
+ outi, skip1_i, skip2_i, cross_conv = module_i(x[i], out[i - 1])
+ else:
+ outi, skip1_i, skip2_i, _ = module_i(x[i], out[i - 1])
+ out.append(outi)
+ skip1.append(skip1_i)
+ skip2.append(skip2_i)
+ skip1.reverse()
+ skip2.reverse()
+
+ return out, skip1, skip2, cross_conv
+
+
+class Single_stage_RSN(BaseModule):
+ """Single_stage Residual Steps Network.
+
+ Args:
+ unit_channels (int): Channel number in the upsample units. Default:256.
+ num_units (int): Number of downsample/upsample units. Default: 4
+ gen_skip (bool): Whether to generate skip for posterior downsample
+ module or not. Default:False
+ gen_cross_conv (bool): Whether to generate feature map for the next
+ hourglass-like module. Default:False
+ has_skip (bool): Have skip connections from prior upsample
+ module or not. Default:False
+ num_steps (int): Number of steps in RSB. Default: 4
+ num_blocks (list): Number of blocks in each downsample unit.
+ Default: [2, 2, 2, 2] Note: Make sure num_units==len(num_blocks)
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: dict(type='BN')
+ in_channels (int): Number of channels of the feature from ResNet_Top.
+ Default: 64.
+ expand_times (int): Times by which the in_channels are expanded in RSB.
+ Default:26.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self,
+ has_skip=False,
+ gen_skip=False,
+ gen_cross_conv=False,
+ unit_channels=256,
+ num_units=4,
+ num_steps=4,
+ num_blocks=[2, 2, 2, 2],
+ norm_cfg=dict(type='BN'),
+ in_channels=64,
+ expand_times=26,
+ init_cfg=None):
+ # Protect mutable default arguments
+ norm_cfg = cp.deepcopy(norm_cfg)
+ num_blocks = cp.deepcopy(num_blocks)
+ super().__init__(init_cfg=init_cfg)
+ assert len(num_blocks) == num_units
+ self.has_skip = has_skip
+ self.gen_skip = gen_skip
+ self.gen_cross_conv = gen_cross_conv
+ self.num_units = num_units
+ self.num_steps = num_steps
+ self.unit_channels = unit_channels
+ self.num_blocks = num_blocks
+ self.norm_cfg = norm_cfg
+
+ self.downsample = Downsample_module(RSB, num_blocks, num_steps,
+ num_units, has_skip, norm_cfg,
+ in_channels, expand_times)
+ self.upsample = Upsample_module(unit_channels, num_units, gen_skip,
+ gen_cross_conv, norm_cfg, in_channels)
+
+ def forward(self, x, skip1, skip2):
+ mid = self.downsample(x, skip1, skip2)
+ out, skip1, skip2, cross_conv = self.upsample(mid)
+
+ return out, skip1, skip2, cross_conv
+
+
+class ResNet_top(BaseModule):
+ """ResNet top for RSN.
+
+ Args:
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: dict(type='BN')
+ channels (int): Number of channels of the feature output by ResNet_top.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self, norm_cfg=dict(type='BN'), channels=64, init_cfg=None):
+ # Protect mutable default arguments
+ norm_cfg = cp.deepcopy(norm_cfg)
+ super().__init__(init_cfg=init_cfg)
+ self.top = nn.Sequential(
+ ConvModule(
+ 3,
+ channels,
+ kernel_size=7,
+ stride=2,
+ padding=3,
+ norm_cfg=norm_cfg,
+ inplace=True), MaxPool2d(kernel_size=3, stride=2, padding=1))
+
+ def forward(self, img):
+ return self.top(img)
+
+
+@MODELS.register_module()
+class RSN(BaseBackbone):
+ """Residual Steps Network backbone. Paper ref: Cai et al. "Learning
+ Delicate Local Representations for Multi-Person Pose Estimation" (ECCV
+ 2020).
+
+ Args:
+ unit_channels (int): Number of channels in an upsample unit.
+ Default: 256
+ num_stages (int): Number of stages in a multi-stage RSN. Default: 4
+ num_units (int): Number of downsample/upsample units in a single-stage
+ RSN. Default: 4 Note: Make sure num_units == len(self.num_blocks)
+ num_blocks (list): Number of RSBs (Residual Steps Block) in each
+ downsample unit. Default: [2, 2, 2, 2]
+ num_steps (int): Number of steps in an RSB. Default: 4.
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: dict(type='BN')
+ res_top_channels (int): Number of channels of feature from ResNet_top.
+ Default: 64.
+ expand_times (int): Times by which the in_channels are expanded in RSB.
+ Default:26.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default:
+ ``[
+ dict(type='Kaiming', layer=['Conv2d']),
+ dict(
+ type='Constant',
+ val=1,
+ layer=['_BatchNorm', 'GroupNorm']),
+ dict(
+ type='Normal',
+ std=0.01,
+ layer=['Linear']),
+ ]``
+ Example:
+ >>> from mmpose.models import RSN
+ >>> import torch
+ >>> self = RSN(num_stages=2,num_units=2,num_blocks=[2,2])
+ >>> self.eval()
+ >>> inputs = torch.rand(1, 3, 511, 511)
+ >>> level_outputs = self.forward(inputs)
+ >>> for level_output in level_outputs:
+ ... for feature in level_output:
+ ... print(tuple(feature.shape))
+ ...
+ (1, 256, 64, 64)
+ (1, 256, 128, 128)
+ (1, 256, 64, 64)
+ (1, 256, 128, 128)
+ """
+
+ def __init__(self,
+ unit_channels=256,
+ num_stages=4,
+ num_units=4,
+ num_blocks=[2, 2, 2, 2],
+ num_steps=4,
+ norm_cfg=dict(type='BN'),
+ res_top_channels=64,
+ expand_times=26,
+ init_cfg=[
+ dict(type='Kaiming', layer=['Conv2d']),
+ dict(
+ type='Constant',
+ val=1,
+ layer=['_BatchNorm', 'GroupNorm']),
+ dict(type='Normal', std=0.01, layer=['Linear']),
+ ]):
+ # Protect mutable default arguments
+ norm_cfg = cp.deepcopy(norm_cfg)
+ num_blocks = cp.deepcopy(num_blocks)
+ super().__init__(init_cfg=init_cfg)
+ self.unit_channels = unit_channels
+ self.num_stages = num_stages
+ self.num_units = num_units
+ self.num_blocks = num_blocks
+ self.num_steps = num_steps
+ self.norm_cfg = norm_cfg
+
+ assert self.num_stages > 0
+ assert self.num_steps > 1
+ assert self.num_units > 1
+ assert self.num_units == len(self.num_blocks)
+ self.top = ResNet_top(norm_cfg=norm_cfg)
+ self.multi_stage_rsn = nn.ModuleList([])
+ for i in range(self.num_stages):
+ if i == 0:
+ has_skip = False
+ else:
+ has_skip = True
+ if i != self.num_stages - 1:
+ gen_skip = True
+ gen_cross_conv = True
+ else:
+ gen_skip = False
+ gen_cross_conv = False
+ self.multi_stage_rsn.append(
+ Single_stage_RSN(has_skip, gen_skip, gen_cross_conv,
+ unit_channels, num_units, num_steps,
+ num_blocks, norm_cfg, res_top_channels,
+ expand_times))
+
+ def forward(self, x):
+ """Model forward function."""
+ out_feats = []
+ skip1 = None
+ skip2 = None
+ x = self.top(x)
+ for i in range(self.num_stages):
+ out, skip1, skip2, x = self.multi_stage_rsn[i](x, skip1, skip2)
+ out_feats.append(out)
+
+ return out_feats
diff --git a/mmpose/models/backbones/scnet.py b/mmpose/models/backbones/scnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c802d256e711aa70c955ac5bb91d2f7ff724604
--- /dev/null
+++ b/mmpose/models/backbones/scnet.py
@@ -0,0 +1,252 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as cp
+from mmcv.cnn import build_conv_layer, build_norm_layer
+from mmengine.model import BaseModule
+
+from mmpose.registry import MODELS
+from .resnet import Bottleneck, ResNet
+
+
+class SCConv(BaseModule):
+ """SCConv (Self-calibrated Convolution)
+
+ Args:
+ in_channels (int): The input channels of the SCConv.
+ out_channels (int): The output channel of the SCConv.
+ stride (int): stride of SCConv.
+ pooling_r (int): Size of pooling for SCConv.
+ conv_cfg (dict): dictionary to construct and config conv layer.
+ Default: None
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: dict(type='BN')
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ stride,
+ pooling_r,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN', momentum=0.1),
+ init_cfg=None):
+ # Protect mutable default arguments
+ norm_cfg = copy.deepcopy(norm_cfg)
+ super().__init__(init_cfg=init_cfg)
+
+ assert in_channels == out_channels
+
+ self.k2 = nn.Sequential(
+ nn.AvgPool2d(kernel_size=pooling_r, stride=pooling_r),
+ build_conv_layer(
+ conv_cfg,
+ in_channels,
+ in_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False),
+ build_norm_layer(norm_cfg, in_channels)[1],
+ )
+ self.k3 = nn.Sequential(
+ build_conv_layer(
+ conv_cfg,
+ in_channels,
+ in_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False),
+ build_norm_layer(norm_cfg, in_channels)[1],
+ )
+ self.k4 = nn.Sequential(
+ build_conv_layer(
+ conv_cfg,
+ in_channels,
+ in_channels,
+ kernel_size=3,
+ stride=stride,
+ padding=1,
+ bias=False),
+ build_norm_layer(norm_cfg, out_channels)[1],
+ nn.ReLU(inplace=True),
+ )
+
+ def forward(self, x):
+ """Forward function."""
+ identity = x
+
+ out = torch.sigmoid(
+ torch.add(identity, F.interpolate(self.k2(x),
+ identity.size()[2:])))
+ out = torch.mul(self.k3(x), out)
+ out = self.k4(out)
+
+ return out
+
+
+class SCBottleneck(Bottleneck):
+ """SC(Self-calibrated) Bottleneck.
+
+ Args:
+ in_channels (int): The input channels of the SCBottleneck block.
+ out_channels (int): The output channel of the SCBottleneck block.
+ """
+
+ pooling_r = 4
+
+ def __init__(self, in_channels, out_channels, **kwargs):
+ super().__init__(in_channels, out_channels, **kwargs)
+ self.mid_channels = out_channels // self.expansion // 2
+
+ self.norm1_name, norm1 = build_norm_layer(
+ self.norm_cfg, self.mid_channels, postfix=1)
+ self.norm2_name, norm2 = build_norm_layer(
+ self.norm_cfg, self.mid_channels, postfix=2)
+ self.norm3_name, norm3 = build_norm_layer(
+ self.norm_cfg, out_channels, postfix=3)
+
+ self.conv1 = build_conv_layer(
+ self.conv_cfg,
+ in_channels,
+ self.mid_channels,
+ kernel_size=1,
+ stride=1,
+ bias=False)
+ self.add_module(self.norm1_name, norm1)
+
+ self.k1 = nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ self.mid_channels,
+ self.mid_channels,
+ kernel_size=3,
+ stride=self.stride,
+ padding=1,
+ bias=False),
+ build_norm_layer(self.norm_cfg, self.mid_channels)[1],
+ nn.ReLU(inplace=True))
+
+ self.conv2 = build_conv_layer(
+ self.conv_cfg,
+ in_channels,
+ self.mid_channels,
+ kernel_size=1,
+ stride=1,
+ bias=False)
+ self.add_module(self.norm2_name, norm2)
+
+ self.scconv = SCConv(self.mid_channels, self.mid_channels, self.stride,
+ self.pooling_r, self.conv_cfg, self.norm_cfg)
+
+ self.conv3 = build_conv_layer(
+ self.conv_cfg,
+ self.mid_channels * 2,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ bias=False)
+ self.add_module(self.norm3_name, norm3)
+
+ def forward(self, x):
+ """Forward function."""
+
+ def _inner_forward(x):
+ identity = x
+
+ out_a = self.conv1(x)
+ out_a = self.norm1(out_a)
+ out_a = self.relu(out_a)
+
+ out_a = self.k1(out_a)
+
+ out_b = self.conv2(x)
+ out_b = self.norm2(out_b)
+ out_b = self.relu(out_b)
+
+ out_b = self.scconv(out_b)
+
+ out = self.conv3(torch.cat([out_a, out_b], dim=1))
+ out = self.norm3(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+
+ return out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ out = self.relu(out)
+
+ return out
+
+
+@MODELS.register_module()
+class SCNet(ResNet):
+ """SCNet backbone.
+
+ Improving Convolutional Networks with Self-Calibrated Convolutions,
+ Jiang-Jiang Liu, Qibin Hou, Ming-Ming Cheng, Changhu Wang, Jiashi Feng,
+ IEEE CVPR, 2020.
+ http://mftp.mmcheng.net/Papers/20cvprSCNet.pdf
+
+ Args:
+ depth (int): Depth of scnet, from {50, 101}.
+ in_channels (int): Number of input image channels. Normally 3.
+ base_channels (int): Number of base channels of hidden layer.
+ num_stages (int): SCNet stages, normally 4.
+ strides (Sequence[int]): Strides of the first block of each stage.
+ dilations (Sequence[int]): Dilation of each stage.
+ out_indices (Sequence[int]): Output from which stages.
+ style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
+ layer is the 3x3 conv layer, otherwise the stride-two layer is
+ the first 1x1 conv layer.
+ deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv
+ avg_down (bool): Use AvgPool instead of stride conv when
+ downsampling in the bottleneck.
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+ -1 means not freezing any parameters.
+ norm_cfg (dict): Dictionary to construct and config norm layer.
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed.
+ zero_init_residual (bool): Whether to use zero init for last norm layer
+ in resblocks to let them behave as identity.
+
+ Example:
+ >>> from mmpose.models import SCNet
+ >>> import torch
+ >>> self = SCNet(depth=50, out_indices=(0, 1, 2, 3))
+ >>> self.eval()
+ >>> inputs = torch.rand(1, 3, 224, 224)
+ >>> level_outputs = self.forward(inputs)
+ >>> for level_out in level_outputs:
+ ... print(tuple(level_out.shape))
+ (1, 256, 56, 56)
+ (1, 512, 28, 28)
+ (1, 1024, 14, 14)
+ (1, 2048, 7, 7)
+ """
+
+ arch_settings = {
+ 50: (SCBottleneck, [3, 4, 6, 3]),
+ 101: (SCBottleneck, [3, 4, 23, 3])
+ }
+
+ def __init__(self, depth, **kwargs):
+ if depth not in self.arch_settings:
+ raise KeyError(f'invalid depth {depth} for SCNet')
+ super().__init__(depth, **kwargs)
diff --git a/mmpose/models/backbones/seresnet.py b/mmpose/models/backbones/seresnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..617a1b72bee737ef0f3fb305e83ce33d8c8a7ea1
--- /dev/null
+++ b/mmpose/models/backbones/seresnet.py
@@ -0,0 +1,134 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.utils.checkpoint as cp
+
+from mmpose.registry import MODELS
+from .resnet import Bottleneck, ResLayer, ResNet
+from .utils.se_layer import SELayer
+
+
+class SEBottleneck(Bottleneck):
+ """SEBottleneck block for SEResNet.
+
+ Args:
+ in_channels (int): The input channels of the SEBottleneck block.
+ out_channels (int): The output channel of the SEBottleneck block.
+ se_ratio (int): Squeeze ratio in SELayer. Default: 16
+ """
+
+ def __init__(self, in_channels, out_channels, se_ratio=16, **kwargs):
+ super().__init__(in_channels, out_channels, **kwargs)
+ self.se_layer = SELayer(out_channels, ratio=se_ratio)
+
+ def forward(self, x):
+
+ def _inner_forward(x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.norm1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.norm2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.norm3(out)
+
+ out = self.se_layer(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+
+ return out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ out = self.relu(out)
+
+ return out
+
+
+@MODELS.register_module()
+class SEResNet(ResNet):
+ """SEResNet backbone.
+
+    Please refer to the `paper <https://arxiv.org/abs/1709.01507>`__ for
+    details.
+
+ Args:
+ depth (int): Network depth, from {50, 101, 152}.
+ se_ratio (int): Squeeze ratio in SELayer. Default: 16.
+ in_channels (int): Number of input image channels. Default: 3.
+ stem_channels (int): Output channels of the stem layer. Default: 64.
+ num_stages (int): Stages of the network. Default: 4.
+ strides (Sequence[int]): Strides of the first block of each stage.
+ Default: ``(1, 2, 2, 2)``.
+ dilations (Sequence[int]): Dilation of each stage.
+ Default: ``(1, 1, 1, 1)``.
+ out_indices (Sequence[int]): Output from which stages. If only one
+ stage is specified, a single tensor (feature map) is returned,
+ otherwise multiple stages are specified, a tuple of tensors will
+ be returned. Default: ``(3, )``.
+ style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
+ layer is the 3x3 conv layer, otherwise the stride-two layer is
+ the first 1x1 conv layer.
+ deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv.
+ Default: False.
+ avg_down (bool): Use AvgPool instead of stride conv when
+ downsampling in the bottleneck. Default: False.
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+ -1 means not freezing any parameters. Default: -1.
+ conv_cfg (dict | None): The config dict for conv layers. Default: None.
+ norm_cfg (dict): The config dict for norm layers.
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Default: False.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ zero_init_residual (bool): Whether to use zero init for last norm layer
+ in resblocks to let them behave as identity. Default: True.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default:
+ ``[
+ dict(type='Kaiming', layer=['Conv2d']),
+ dict(
+ type='Constant',
+ val=1,
+ layer=['_BatchNorm', 'GroupNorm'])
+ ]``
+
+ Example:
+ >>> from mmpose.models import SEResNet
+ >>> import torch
+ >>> self = SEResNet(depth=50, out_indices=(0, 1, 2, 3))
+ >>> self.eval()
+ >>> inputs = torch.rand(1, 3, 224, 224)
+ >>> level_outputs = self.forward(inputs)
+ >>> for level_out in level_outputs:
+ ... print(tuple(level_out.shape))
+ (1, 256, 56, 56)
+ (1, 512, 28, 28)
+ (1, 1024, 14, 14)
+ (1, 2048, 7, 7)
+ """
+
+ arch_settings = {
+ 50: (SEBottleneck, (3, 4, 6, 3)),
+ 101: (SEBottleneck, (3, 4, 23, 3)),
+ 152: (SEBottleneck, (3, 8, 36, 3))
+ }
+
+ def __init__(self, depth, se_ratio=16, **kwargs):
+ if depth not in self.arch_settings:
+ raise KeyError(f'invalid depth {depth} for SEResNet')
+ self.se_ratio = se_ratio
+ super().__init__(depth, **kwargs)
+
+ def make_res_layer(self, **kwargs):
+ return ResLayer(se_ratio=self.se_ratio, **kwargs)
diff --git a/mmpose/models/backbones/seresnext.py b/mmpose/models/backbones/seresnext.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1f5a6c8f3fe6b602aceb331781cd119958518b7
--- /dev/null
+++ b/mmpose/models/backbones/seresnext.py
@@ -0,0 +1,179 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmcv.cnn import build_conv_layer, build_norm_layer
+
+from mmpose.registry import MODELS
+from .resnet import ResLayer
+from .seresnet import SEBottleneck as _SEBottleneck
+from .seresnet import SEResNet
+
+
+class SEBottleneck(_SEBottleneck):
+ """SEBottleneck block for SEResNeXt.
+
+ Args:
+ in_channels (int): Input channels of this block.
+ out_channels (int): Output channels of this block.
+ base_channels (int): Middle channels of the first stage. Default: 64.
+ groups (int): Groups of conv2.
+ width_per_group (int): Width per group of conv2. 64x4d indicates
+ ``groups=64, width_per_group=4`` and 32x8d indicates
+ ``groups=32, width_per_group=8``.
+ stride (int): stride of the block. Default: 1
+ dilation (int): dilation of convolution. Default: 1
+ downsample (nn.Module): downsample operation on identity branch.
+ Default: None
+ se_ratio (int): Squeeze ratio in SELayer. Default: 16
+ style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
+ layer is the 3x3 conv layer, otherwise the stride-two layer is
+ the first 1x1 conv layer.
+ conv_cfg (dict): dictionary to construct and config conv layer.
+ Default: None
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: dict(type='BN')
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ base_channels=64,
+ groups=32,
+ width_per_group=4,
+ se_ratio=16,
+ **kwargs):
+ super().__init__(in_channels, out_channels, se_ratio, **kwargs)
+ self.groups = groups
+ self.width_per_group = width_per_group
+
+        # We follow the same rationale as ResNeXt to compute mid_channels.
+ # For SEResNet bottleneck, middle channels are determined by expansion
+ # and out_channels, but for SEResNeXt bottleneck, it is determined by
+ # groups and width_per_group and the stage it is located in.
+ if groups != 1:
+ assert self.mid_channels % base_channels == 0
+ self.mid_channels = (
+ groups * width_per_group * self.mid_channels // base_channels)
+
+ self.norm1_name, norm1 = build_norm_layer(
+ self.norm_cfg, self.mid_channels, postfix=1)
+ self.norm2_name, norm2 = build_norm_layer(
+ self.norm_cfg, self.mid_channels, postfix=2)
+ self.norm3_name, norm3 = build_norm_layer(
+ self.norm_cfg, self.out_channels, postfix=3)
+
+ self.conv1 = build_conv_layer(
+ self.conv_cfg,
+ self.in_channels,
+ self.mid_channels,
+ kernel_size=1,
+ stride=self.conv1_stride,
+ bias=False)
+ self.add_module(self.norm1_name, norm1)
+ self.conv2 = build_conv_layer(
+ self.conv_cfg,
+ self.mid_channels,
+ self.mid_channels,
+ kernel_size=3,
+ stride=self.conv2_stride,
+ padding=self.dilation,
+ dilation=self.dilation,
+ groups=groups,
+ bias=False)
+
+ self.add_module(self.norm2_name, norm2)
+ self.conv3 = build_conv_layer(
+ self.conv_cfg,
+ self.mid_channels,
+ self.out_channels,
+ kernel_size=1,
+ bias=False)
+ self.add_module(self.norm3_name, norm3)
+
+
+@MODELS.register_module()
+class SEResNeXt(SEResNet):
+ """SEResNeXt backbone.
+
+    Please refer to the `paper <https://arxiv.org/abs/1709.01507>`__ for
+    details.
+
+ Args:
+ depth (int): Network depth, from {50, 101, 152}.
+ groups (int): Groups of conv2 in Bottleneck. Default: 32.
+ width_per_group (int): Width per group of conv2 in Bottleneck.
+ Default: 4.
+ se_ratio (int): Squeeze ratio in SELayer. Default: 16.
+ in_channels (int): Number of input image channels. Default: 3.
+ stem_channels (int): Output channels of the stem layer. Default: 64.
+ num_stages (int): Stages of the network. Default: 4.
+ strides (Sequence[int]): Strides of the first block of each stage.
+ Default: ``(1, 2, 2, 2)``.
+ dilations (Sequence[int]): Dilation of each stage.
+ Default: ``(1, 1, 1, 1)``.
+ out_indices (Sequence[int]): Output from which stages. If only one
+ stage is specified, a single tensor (feature map) is returned,
+ otherwise multiple stages are specified, a tuple of tensors will
+ be returned. Default: ``(3, )``.
+ style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
+ layer is the 3x3 conv layer, otherwise the stride-two layer is
+ the first 1x1 conv layer.
+ deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv.
+ Default: False.
+ avg_down (bool): Use AvgPool instead of stride conv when
+ downsampling in the bottleneck. Default: False.
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+ -1 means not freezing any parameters. Default: -1.
+ conv_cfg (dict | None): The config dict for conv layers. Default: None.
+ norm_cfg (dict): The config dict for norm layers.
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Default: False.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ zero_init_residual (bool): Whether to use zero init for last norm layer
+ in resblocks to let them behave as identity. Default: True.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default:
+ ``[
+ dict(type='Kaiming', layer=['Conv2d']),
+ dict(
+ type='Constant',
+ val=1,
+ layer=['_BatchNorm', 'GroupNorm'])
+ ]``
+
+ Example:
+ >>> from mmpose.models import SEResNeXt
+ >>> import torch
+        >>> self = SEResNeXt(depth=50, out_indices=(0, 1, 2, 3))
+ >>> self.eval()
+ >>> inputs = torch.rand(1, 3, 224, 224)
+ >>> level_outputs = self.forward(inputs)
+ >>> for level_out in level_outputs:
+ ... print(tuple(level_out.shape))
+ (1, 256, 56, 56)
+ (1, 512, 28, 28)
+ (1, 1024, 14, 14)
+ (1, 2048, 7, 7)
+ """
+
+ arch_settings = {
+ 50: (SEBottleneck, (3, 4, 6, 3)),
+ 101: (SEBottleneck, (3, 4, 23, 3)),
+ 152: (SEBottleneck, (3, 8, 36, 3))
+ }
+
+ def __init__(self, depth, groups=32, width_per_group=4, **kwargs):
+ self.groups = groups
+ self.width_per_group = width_per_group
+ super().__init__(depth, **kwargs)
+
+ def make_res_layer(self, **kwargs):
+ return ResLayer(
+ groups=self.groups,
+ width_per_group=self.width_per_group,
+ base_channels=self.base_channels,
+ **kwargs)
diff --git a/mmpose/models/backbones/shufflenet_v1.py b/mmpose/models/backbones/shufflenet_v1.py
new file mode 100644
index 0000000000000000000000000000000000000000..17491910e9c1c2ec4eea04ca715dc91293f00cd4
--- /dev/null
+++ b/mmpose/models/backbones/shufflenet_v1.py
@@ -0,0 +1,338 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+from mmcv.cnn import ConvModule, build_activation_layer
+from mmengine.model import BaseModule
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from mmpose.registry import MODELS
+from .base_backbone import BaseBackbone
+from .utils import channel_shuffle, make_divisible
+
+
+class ShuffleUnit(BaseModule):
+ """ShuffleUnit block.
+
+ ShuffleNet unit with pointwise group convolution (GConv) and channel
+ shuffle.
+
+ Args:
+ in_channels (int): The input channels of the ShuffleUnit.
+ out_channels (int): The output channels of the ShuffleUnit.
+ groups (int, optional): The number of groups to be used in grouped 1x1
+ convolutions in each ShuffleUnit. Default: 3
+        first_block (bool, optional): Whether it is the first ShuffleUnit in
+            a sequence of ShuffleUnits. Default: True, which means not using
+            the grouped 1x1 convolution.
+        combine (str, optional): The way to combine the input and output
+            branches. Default: 'add'.
+ conv_cfg (dict): Config dict for convolution layer. Default: None,
+ which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='BN').
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='ReLU').
+ with_cp (bool, optional): Use checkpoint or not. Using checkpoint
+ will save some memory while slowing down the training speed.
+ Default: False.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+
+ Returns:
+ Tensor: The output tensor.
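+
+    Example:
+        An illustrative sketch of the down-sampling ('concat') variant; the
+        channel numbers and input size are arbitrary example values:
+
+        >>> import torch
+        >>> unit = ShuffleUnit(24, 240, groups=3, first_block=True,
+        ...                    combine='concat')
+        >>> x = torch.rand(1, 24, 56, 56)  # arbitrary example input
+        >>> tuple(unit(x).shape)  # stride-2 branch halves the resolution
+        (1, 240, 28, 28)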
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ groups=3,
+ first_block=True,
+ combine='add',
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ with_cp=False,
+ init_cfg=None):
+ # Protect mutable default arguments
+ norm_cfg = copy.deepcopy(norm_cfg)
+ act_cfg = copy.deepcopy(act_cfg)
+ super().__init__(init_cfg=init_cfg)
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.first_block = first_block
+ self.combine = combine
+ self.groups = groups
+ self.bottleneck_channels = self.out_channels // 4
+ self.with_cp = with_cp
+
+ if self.combine == 'add':
+ self.depthwise_stride = 1
+ self._combine_func = self._add
+ assert in_channels == out_channels, (
+ 'in_channels must be equal to out_channels when combine '
+ 'is add')
+ elif self.combine == 'concat':
+ self.depthwise_stride = 2
+ self._combine_func = self._concat
+ self.out_channels -= self.in_channels
+ self.avgpool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
+ else:
+ raise ValueError(f'Cannot combine tensors with {self.combine}. '
+ 'Only "add" and "concat" are supported')
+
+ self.first_1x1_groups = 1 if first_block else self.groups
+ self.g_conv_1x1_compress = ConvModule(
+ in_channels=self.in_channels,
+ out_channels=self.bottleneck_channels,
+ kernel_size=1,
+ groups=self.first_1x1_groups,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+
+ self.depthwise_conv3x3_bn = ConvModule(
+ in_channels=self.bottleneck_channels,
+ out_channels=self.bottleneck_channels,
+ kernel_size=3,
+ stride=self.depthwise_stride,
+ padding=1,
+ groups=self.bottleneck_channels,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=None)
+
+ self.g_conv_1x1_expand = ConvModule(
+ in_channels=self.bottleneck_channels,
+ out_channels=self.out_channels,
+ kernel_size=1,
+ groups=self.groups,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=None)
+
+ self.act = build_activation_layer(act_cfg)
+
+ @staticmethod
+ def _add(x, out):
+ # residual connection
+ return x + out
+
+ @staticmethod
+ def _concat(x, out):
+ # concatenate along channel axis
+ return torch.cat((x, out), 1)
+
+ def forward(self, x):
+
+ def _inner_forward(x):
+ residual = x
+
+ out = self.g_conv_1x1_compress(x)
+ out = self.depthwise_conv3x3_bn(out)
+
+ if self.groups > 1:
+ out = channel_shuffle(out, self.groups)
+
+ out = self.g_conv_1x1_expand(out)
+
+ if self.combine == 'concat':
+ residual = self.avgpool(residual)
+ out = self.act(out)
+ out = self._combine_func(residual, out)
+ else:
+ out = self._combine_func(residual, out)
+ out = self.act(out)
+ return out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ return out
+
+
+@MODELS.register_module()
+class ShuffleNetV1(BaseBackbone):
+ """ShuffleNetV1 backbone.
+
+ Args:
+ groups (int, optional): The number of groups to be used in grouped 1x1
+ convolutions in each ShuffleUnit. Default: 3.
+ widen_factor (float, optional): Width multiplier - adjusts the number
+ of channels in each layer by this amount. Default: 1.0.
+ out_indices (Sequence[int]): Output from which stages.
+ Default: (2, )
+ frozen_stages (int): Stages to be frozen (all param fixed).
+ Default: -1, which means not freezing any parameters.
+ conv_cfg (dict): Config dict for convolution layer. Default: None,
+ which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='BN').
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='ReLU').
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Default: False.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default:
+ ``[
+ dict(type='Normal', std=0.01, layer=['Conv2d']),
+ dict(
+ type='Constant',
+ val=1,
+                bias=0.0001,
+ layer=['_BatchNorm', 'GroupNorm'])
+ ]``
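+
+    Example:
+        An illustrative usage sketch (the import path assumes the backbone
+        is exported from ``mmpose.models`` as in upstream MMPose; the input
+        size is an arbitrary example):
+
+        >>> from mmpose.models import ShuffleNetV1
+        >>> import torch
+        >>> self = ShuffleNetV1(groups=3, out_indices=(0, 1, 2))
+        >>> self.eval()
+        >>> inputs = torch.rand(1, 3, 224, 224)  # arbitrary example input
+        >>> level_outputs = self.forward(inputs)
+        >>> for level_out in level_outputs:
+        ...     print(tuple(level_out.shape))
+        (1, 240, 28, 28)
+        (1, 480, 14, 14)
+        (1, 960, 7, 7)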
+ """
+
+ def __init__(self,
+ groups=3,
+ widen_factor=1.0,
+ out_indices=(2, ),
+ frozen_stages=-1,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ norm_eval=False,
+ with_cp=False,
+ init_cfg=[
+ dict(type='Normal', std=0.01, layer=['Conv2d']),
+ dict(
+ type='Constant',
+ val=1,
+ bias=0.0001,
+ layer=['_BatchNorm', 'GroupNorm'])
+ ]):
+ # Protect mutable default arguments
+ norm_cfg = copy.deepcopy(norm_cfg)
+ act_cfg = copy.deepcopy(act_cfg)
+ super().__init__(init_cfg=init_cfg)
+ self.stage_blocks = [4, 8, 4]
+ self.groups = groups
+
+ for index in out_indices:
+ if index not in range(0, 3):
+                raise ValueError('the item in out_indices must be in '
+ f'range(0, 3). But received {index}')
+
+ if frozen_stages not in range(-1, 3):
+ raise ValueError('frozen_stages must be in range(-1, 3). '
+ f'But received {frozen_stages}')
+ self.out_indices = out_indices
+ self.frozen_stages = frozen_stages
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ self.norm_eval = norm_eval
+ self.with_cp = with_cp
+
+ if groups == 1:
+ channels = (144, 288, 576)
+ elif groups == 2:
+ channels = (200, 400, 800)
+ elif groups == 3:
+ channels = (240, 480, 960)
+ elif groups == 4:
+ channels = (272, 544, 1088)
+ elif groups == 8:
+ channels = (384, 768, 1536)
+ else:
+ raise ValueError(f'{groups} groups is not supported for 1x1 '
+ 'Grouped Convolutions')
+
+ channels = [make_divisible(ch * widen_factor, 8) for ch in channels]
+
+ self.in_channels = int(24 * widen_factor)
+
+ self.conv1 = ConvModule(
+ in_channels=3,
+ out_channels=self.in_channels,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+ self.layers = nn.ModuleList()
+ for i, num_blocks in enumerate(self.stage_blocks):
+ first_block = (i == 0)
+ layer = self.make_layer(channels[i], num_blocks, first_block)
+ self.layers.append(layer)
+
+ def _freeze_stages(self):
+ if self.frozen_stages >= 0:
+ for param in self.conv1.parameters():
+ param.requires_grad = False
+ for i in range(self.frozen_stages):
+ layer = self.layers[i]
+ layer.eval()
+ for param in layer.parameters():
+ param.requires_grad = False
+
+ def init_weights(self, pretrained=None):
+ super(ShuffleNetV1, self).init_weights()
+
+ if (isinstance(self.init_cfg, dict)
+ and self.init_cfg['type'] == 'Pretrained'):
+ return
+
+ for name, m in self.named_modules():
+ if isinstance(m, nn.Conv2d) and 'conv1' not in name:
+ nn.init.normal_(m.weight, mean=0, std=1.0 / m.weight.shape[1])
+
+ def make_layer(self, out_channels, num_blocks, first_block=False):
+ """Stack ShuffleUnit blocks to make a layer.
+
+ Args:
+ out_channels (int): out_channels of the block.
+ num_blocks (int): Number of blocks.
+            first_block (bool, optional): Whether this is the first
+                ShuffleUnit in a sequence of ShuffleUnits. Default: False,
+                which means using the grouped 1x1 convolution.
+ """
+ layers = []
+ for i in range(num_blocks):
+ first_block = first_block if i == 0 else False
+ combine_mode = 'concat' if i == 0 else 'add'
+ layers.append(
+ ShuffleUnit(
+ self.in_channels,
+ out_channels,
+ groups=self.groups,
+ first_block=first_block,
+ combine=combine_mode,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg,
+ with_cp=self.with_cp))
+ self.in_channels = out_channels
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.maxpool(x)
+
+ outs = []
+ for i, layer in enumerate(self.layers):
+ x = layer(x)
+ if i in self.out_indices:
+ outs.append(x)
+
+ return tuple(outs)
+
+ def train(self, mode=True):
+ super().train(mode)
+ self._freeze_stages()
+ if mode and self.norm_eval:
+ for m in self.modules():
+ if isinstance(m, _BatchNorm):
+ m.eval()
diff --git a/mmpose/models/backbones/shufflenet_v2.py b/mmpose/models/backbones/shufflenet_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..9757841e73bf547fde77cf847a917c46acfb0b00
--- /dev/null
+++ b/mmpose/models/backbones/shufflenet_v2.py
@@ -0,0 +1,311 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+from mmcv.cnn import ConvModule
+from mmengine.model import BaseModule
+
+from mmpose.registry import MODELS
+from .base_backbone import BaseBackbone
+from .utils import channel_shuffle
+
+
+class InvertedResidual(BaseModule):
+ """InvertedResidual block for ShuffleNetV2 backbone.
+
+ Args:
+ in_channels (int): The input channels of the block.
+ out_channels (int): The output channels of the block.
+ stride (int): Stride of the 3x3 convolution layer. Default: 1
+ conv_cfg (dict): Config dict for convolution layer.
+ Default: None, which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='BN').
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='ReLU').
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
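+
+    Example:
+        An illustrative shape-check sketch; the channel numbers and input
+        size below are arbitrary example values:
+
+        >>> import torch
+        >>> x = torch.rand(1, 116, 28, 28)  # arbitrary example input
+        >>> block = InvertedResidual(116, 116, stride=1)
+        >>> tuple(block(x).shape)  # basic unit keeps the resolution
+        (1, 116, 28, 28)
+        >>> down = InvertedResidual(116, 232, stride=2)
+        >>> tuple(down(x).shape)  # spatial down-sampling unit
+        (1, 232, 14, 14)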
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ stride=1,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ with_cp=False,
+ init_cfg=None):
+ # Protect mutable default arguments
+ norm_cfg = copy.deepcopy(norm_cfg)
+ act_cfg = copy.deepcopy(act_cfg)
+ super().__init__(init_cfg=init_cfg)
+ self.stride = stride
+ self.with_cp = with_cp
+
+ branch_features = out_channels // 2
+ if self.stride == 1:
+ assert in_channels == branch_features * 2, (
+ f'in_channels ({in_channels}) should equal to '
+ f'branch_features * 2 ({branch_features * 2}) '
+ 'when stride is 1')
+
+ if in_channels != branch_features * 2:
+ assert self.stride != 1, (
+ f'stride ({self.stride}) should not equal 1 when '
+ f'in_channels != branch_features * 2')
+
+ if self.stride > 1:
+ self.branch1 = nn.Sequential(
+ ConvModule(
+ in_channels,
+ in_channels,
+ kernel_size=3,
+ stride=self.stride,
+ padding=1,
+ groups=in_channels,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=None),
+ ConvModule(
+ in_channels,
+ branch_features,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg),
+ )
+
+ self.branch2 = nn.Sequential(
+ ConvModule(
+ in_channels if (self.stride > 1) else branch_features,
+ branch_features,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg),
+ ConvModule(
+ branch_features,
+ branch_features,
+ kernel_size=3,
+ stride=self.stride,
+ padding=1,
+ groups=branch_features,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=None),
+ ConvModule(
+ branch_features,
+ branch_features,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg))
+
+ def forward(self, x):
+
+ def _inner_forward(x):
+ if self.stride > 1:
+ out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
+ else:
+ x1, x2 = x.chunk(2, dim=1)
+ out = torch.cat((x1, self.branch2(x2)), dim=1)
+
+ out = channel_shuffle(out, 2)
+
+ return out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ return out
+
+
+@MODELS.register_module()
+class ShuffleNetV2(BaseBackbone):
+ """ShuffleNetV2 backbone.
+
+ Args:
+ widen_factor (float): Width multiplier - adjusts the number of
+ channels in each layer by this amount. Default: 1.0.
+ out_indices (Sequence[int]): Output from which stages.
+            Default: (3, ).
+ frozen_stages (int): Stages to be frozen (all param fixed).
+ Default: -1, which means not freezing any parameters.
+ conv_cfg (dict): Config dict for convolution layer.
+ Default: None, which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='BN').
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='ReLU').
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Default: False.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default:
+ ``[
+ dict(type='Normal', std=0.01, layer=['Conv2d']),
+ dict(
+ type='Constant',
+ val=1,
+                bias=0.0001,
+ layer=['_BatchNorm', 'GroupNorm'])
+ ]``
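+
+    Example:
+        An illustrative usage sketch (the import path assumes the backbone
+        is exported from ``mmpose.models`` as in upstream MMPose; the input
+        size is an arbitrary example):
+
+        >>> from mmpose.models import ShuffleNetV2
+        >>> import torch
+        >>> self = ShuffleNetV2(widen_factor=1.0, out_indices=(0, 1, 2, 3))
+        >>> self.eval()
+        >>> inputs = torch.rand(1, 3, 224, 224)  # arbitrary example input
+        >>> level_outputs = self.forward(inputs)
+        >>> for level_out in level_outputs:
+        ...     print(tuple(level_out.shape))
+        (1, 116, 28, 28)
+        (1, 232, 14, 14)
+        (1, 464, 7, 7)
+        (1, 1024, 7, 7)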
+ """
+
+ def __init__(self,
+ widen_factor=1.0,
+ out_indices=(3, ),
+ frozen_stages=-1,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ norm_eval=False,
+ with_cp=False,
+ init_cfg=[
+ dict(type='Normal', std=0.01, layer=['Conv2d']),
+ dict(
+ type='Constant',
+ val=1,
+ bias=0.0001,
+ layer=['_BatchNorm', 'GroupNorm'])
+ ]):
+ # Protect mutable default arguments
+ norm_cfg = copy.deepcopy(norm_cfg)
+ act_cfg = copy.deepcopy(act_cfg)
+ super().__init__(init_cfg=init_cfg)
+ self.stage_blocks = [4, 8, 4]
+ for index in out_indices:
+ if index not in range(0, 4):
+                raise ValueError('the item in out_indices must be in '
+ f'range(0, 4). But received {index}')
+
+ if frozen_stages not in range(-1, 4):
+ raise ValueError('frozen_stages must be in range(-1, 4). '
+ f'But received {frozen_stages}')
+ self.out_indices = out_indices
+ self.frozen_stages = frozen_stages
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ self.norm_eval = norm_eval
+ self.with_cp = with_cp
+
+ if widen_factor == 0.5:
+ channels = [48, 96, 192, 1024]
+ elif widen_factor == 1.0:
+ channels = [116, 232, 464, 1024]
+ elif widen_factor == 1.5:
+ channels = [176, 352, 704, 1024]
+ elif widen_factor == 2.0:
+ channels = [244, 488, 976, 2048]
+ else:
+ raise ValueError('widen_factor must be in [0.5, 1.0, 1.5, 2.0]. '
+ f'But received {widen_factor}')
+
+ self.in_channels = 24
+ self.conv1 = ConvModule(
+ in_channels=3,
+ out_channels=self.in_channels,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+ self.layers = nn.ModuleList()
+ for i, num_blocks in enumerate(self.stage_blocks):
+ layer = self._make_layer(channels[i], num_blocks)
+ self.layers.append(layer)
+
+ output_channels = channels[-1]
+ self.layers.append(
+ ConvModule(
+ in_channels=self.in_channels,
+ out_channels=output_channels,
+ kernel_size=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg))
+
+ def _make_layer(self, out_channels, num_blocks):
+ """Stack blocks to make a layer.
+
+ Args:
+ out_channels (int): out_channels of the block.
+ num_blocks (int): number of blocks.
+ """
+ layers = []
+ for i in range(num_blocks):
+ stride = 2 if i == 0 else 1
+ layers.append(
+ InvertedResidual(
+ in_channels=self.in_channels,
+ out_channels=out_channels,
+ stride=stride,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg,
+ with_cp=self.with_cp))
+ self.in_channels = out_channels
+
+ return nn.Sequential(*layers)
+
+ def _freeze_stages(self):
+ if self.frozen_stages >= 0:
+ for param in self.conv1.parameters():
+ param.requires_grad = False
+
+ for i in range(self.frozen_stages):
+ m = self.layers[i]
+ m.eval()
+ for param in m.parameters():
+ param.requires_grad = False
+
+ def init_weights(self):
+ super(ShuffleNetV2, self).init_weights()
+
+ if (isinstance(self.init_cfg, dict)
+ and self.init_cfg['type'] == 'Pretrained'):
+ return
+
+ for name, m in self.named_modules():
+ if isinstance(m, nn.Conv2d) and 'conv1' not in name:
+ nn.init.normal_(m.weight, mean=0, std=1.0 / m.weight.shape[1])
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.maxpool(x)
+
+ outs = []
+ for i, layer in enumerate(self.layers):
+ x = layer(x)
+ if i in self.out_indices:
+ outs.append(x)
+
+ return tuple(outs)
+
+ def train(self, mode=True):
+ super().train(mode)
+ self._freeze_stages()
+ if mode and self.norm_eval:
+ for m in self.modules():
+ if isinstance(m, nn.BatchNorm2d):
+ m.eval()
diff --git a/mmpose/models/backbones/swin.py b/mmpose/models/backbones/swin.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8f7c972787c19f64eb398615966722c5bdcd533
--- /dev/null
+++ b/mmpose/models/backbones/swin.py
@@ -0,0 +1,739 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from copy import deepcopy
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as cp
+from mmcv.cnn import build_norm_layer
+from mmcv.cnn.bricks.transformer import FFN, build_dropout
+from mmengine.model import BaseModule
+from mmengine.model.weight_init import trunc_normal_
+from mmengine.runner import load_state_dict
+from mmengine.utils import to_2tuple
+
+from mmpose.registry import MODELS
+from mmpose.utils import get_root_logger
+from ..utils.transformer import PatchEmbed, PatchMerging
+from .base_backbone import BaseBackbone
+from .utils import get_state_dict
+from .utils.ckpt_convert import swin_converter
+
+
+class WindowMSA(BaseModule):
+ """Window based multi-head self-attention (W-MSA) module with relative
+ position bias.
+
+ Args:
+ embed_dims (int): Number of input channels.
+ num_heads (int): Number of attention heads.
+ window_size (tuple[int]): The height and width of the window.
+ qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
+ Default: True.
+ qk_scale (float | None, optional): Override default qk scale of
+ head_dim ** -0.5 if set. Default: None.
+ attn_drop_rate (float, optional): Dropout ratio of attention weight.
+ Default: 0.0
+ proj_drop_rate (float, optional): Dropout ratio of output. Default: 0.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None.
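+
+    Example:
+        An illustrative shape-check sketch; the embedding size, head count
+        and window layout below are arbitrary example values:
+
+        >>> import torch
+        >>> attn = WindowMSA(embed_dims=96, num_heads=3, window_size=(7, 7))
+        >>> # 4 windows of 7x7=49 tokens, 96 channels (arbitrary example)
+        >>> x = torch.rand(4, 49, 96)
+        >>> tuple(attn(x).shape)
+        (4, 49, 96)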
+ """
+
+ def __init__(self,
+ embed_dims,
+ num_heads,
+ window_size,
+ qkv_bias=True,
+ qk_scale=None,
+ attn_drop_rate=0.,
+ proj_drop_rate=0.,
+ init_cfg=None):
+
+ super().__init__(init_cfg=init_cfg)
+ self.embed_dims = embed_dims
+ self.window_size = window_size # Wh, Ww
+ self.num_heads = num_heads
+ head_embed_dims = embed_dims // num_heads
+ self.scale = qk_scale or head_embed_dims**-0.5
+
+ # define a parameter table of relative position bias
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
+ num_heads)) # 2*Wh-1 * 2*Ww-1, nH
+
+ # About 2x faster than original impl
+ Wh, Ww = self.window_size
+ rel_index_coords = self.double_step_seq(2 * Ww - 1, Wh, 1, Ww)
+ rel_position_index = rel_index_coords + rel_index_coords.T
+ rel_position_index = rel_position_index.flip(1).contiguous()
+ self.register_buffer('relative_position_index', rel_position_index)
+
+ self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop_rate)
+ self.proj = nn.Linear(embed_dims, embed_dims)
+ self.proj_drop = nn.Dropout(proj_drop_rate)
+
+ self.softmax = nn.Softmax(dim=-1)
+
+ def init_weights(self):
+ trunc_normal_(self.relative_position_bias_table, std=0.02)
+
+ def forward(self, x, mask=None):
+ """
+ Args:
+
+ x (tensor): input features with shape of (num_windows*B, N, C)
+ mask (tensor | None, Optional): mask with shape of (num_windows,
+ Wh*Ww, Wh*Ww), value should be between (-inf, 0].
+ """
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
+ C // self.num_heads).permute(2, 0, 3, 1, 4)
+ # make torchscript happy (cannot use tensor as tuple)
+ q, k, v = qkv[0], qkv[1], qkv[2]
+
+ q = q * self.scale
+ attn = (q @ k.transpose(-2, -1))
+
+ relative_position_bias = self.relative_position_bias_table[
+ self.relative_position_index.view(-1)].view(
+ self.window_size[0] * self.window_size[1],
+ self.window_size[0] * self.window_size[1],
+ -1) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(
+ 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+ attn = attn + relative_position_bias.unsqueeze(0)
+
+ if mask is not None:
+ nW = mask.shape[0]
+ attn = attn.view(B // nW, nW, self.num_heads, N,
+ N) + mask.unsqueeze(1).unsqueeze(0)
+ attn = attn.view(-1, self.num_heads, N, N)
+ attn = self.softmax(attn)
+
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+ @staticmethod
+ def double_step_seq(step1, len1, step2, len2):
+ seq1 = torch.arange(0, step1 * len1, step1)
+ seq2 = torch.arange(0, step2 * len2, step2)
+ return (seq1[:, None] + seq2[None, :]).reshape(1, -1)
+
+
+class ShiftWindowMSA(BaseModule):
+ """Shifted Window Multihead Self-Attention Module.
+
+ Args:
+ embed_dims (int): Number of input channels.
+ num_heads (int): Number of attention heads.
+ window_size (int): The height and width of the window.
+ shift_size (int, optional): The shift step of each window towards
+ right-bottom. If zero, act as regular window-msa. Defaults to 0.
+ qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
+ Default: True
+ qk_scale (float | None, optional): Override default qk scale of
+ head_dim ** -0.5 if set. Defaults: None.
+ attn_drop_rate (float, optional): Dropout ratio of attention weight.
+ Defaults: 0.
+ proj_drop_rate (float, optional): Dropout ratio of output.
+ Defaults: 0.
+ dropout_layer (dict, optional): The dropout_layer used before output.
+ Defaults: dict(type='DropPath', drop_prob=0.).
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
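+
+    Example:
+        An illustrative shape-check sketch of the unshifted case; the
+        feature-map layout below is an arbitrary example:
+
+        >>> import torch
+        >>> attn = ShiftWindowMSA(embed_dims=96, num_heads=3, window_size=7,
+        ...                       shift_size=0)
+        >>> query = torch.rand(1, 56 * 56, 96)  # arbitrary example input
+        >>> tuple(attn(query, (56, 56)).shape)
+        (1, 3136, 96)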
+ """
+
+ def __init__(self,
+ embed_dims,
+ num_heads,
+ window_size,
+ shift_size=0,
+ qkv_bias=True,
+ qk_scale=None,
+ attn_drop_rate=0,
+ proj_drop_rate=0,
+ dropout_layer=dict(type='DropPath', drop_prob=0.),
+ init_cfg=None):
+ super().__init__(init_cfg=init_cfg)
+
+ self.window_size = window_size
+ self.shift_size = shift_size
+ assert 0 <= self.shift_size < self.window_size
+
+ self.w_msa = WindowMSA(
+ embed_dims=embed_dims,
+ num_heads=num_heads,
+ window_size=to_2tuple(window_size),
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop_rate=attn_drop_rate,
+ proj_drop_rate=proj_drop_rate)
+
+ self.drop = build_dropout(dropout_layer)
+
+ def forward(self, query, hw_shape):
+ B, L, C = query.shape
+ H, W = hw_shape
+ assert L == H * W, 'input feature has wrong size'
+ query = query.view(B, H, W, C)
+
+ # pad feature maps to multiples of window size
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
+ query = F.pad(query, (0, 0, 0, pad_r, 0, pad_b))
+ H_pad, W_pad = query.shape[1], query.shape[2]
+
+ # cyclic shift
+ if self.shift_size > 0:
+ shifted_query = torch.roll(
+ query,
+ shifts=(-self.shift_size, -self.shift_size),
+ dims=(1, 2))
+
+ # calculate attention mask for SW-MSA
+ img_mask = torch.zeros((1, H_pad, W_pad, 1), device=query.device)
+ h_slices = (slice(0, -self.window_size),
+ slice(-self.window_size,
+ -self.shift_size), slice(-self.shift_size, None))
+ w_slices = (slice(0, -self.window_size),
+ slice(-self.window_size,
+ -self.shift_size), slice(-self.shift_size, None))
+ cnt = 0
+ for h in h_slices:
+ for w in w_slices:
+ img_mask[:, h, w, :] = cnt
+ cnt += 1
+
+ # nW, window_size, window_size, 1
+ mask_windows = self.window_partition(img_mask)
+ mask_windows = mask_windows.view(
+ -1, self.window_size * self.window_size)
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+ attn_mask = attn_mask.masked_fill(attn_mask != 0,
+ float(-100.0)).masked_fill(
+ attn_mask == 0, float(0.0))
+ else:
+ shifted_query = query
+ attn_mask = None
+
+ # nW*B, window_size, window_size, C
+ query_windows = self.window_partition(shifted_query)
+ # nW*B, window_size*window_size, C
+ query_windows = query_windows.view(-1, self.window_size**2, C)
+
+ # W-MSA/SW-MSA (nW*B, window_size*window_size, C)
+ attn_windows = self.w_msa(query_windows, mask=attn_mask)
+
+ # merge windows
+ attn_windows = attn_windows.view(-1, self.window_size,
+ self.window_size, C)
+
+ # B H' W' C
+ shifted_x = self.window_reverse(attn_windows, H_pad, W_pad)
+ # reverse cyclic shift
+ if self.shift_size > 0:
+ x = torch.roll(
+ shifted_x,
+ shifts=(self.shift_size, self.shift_size),
+ dims=(1, 2))
+ else:
+ x = shifted_x
+
+        if pad_r > 0 or pad_b > 0:
+ x = x[:, :H, :W, :].contiguous()
+
+ x = x.view(B, H * W, C)
+
+ x = self.drop(x)
+ return x
+
+ def window_reverse(self, windows, H, W):
+ """
+ Args:
+ windows: (num_windows*B, window_size, window_size, C)
+ H (int): Height of image
+ W (int): Width of image
+ Returns:
+ x: (B, H, W, C)
+ """
+ window_size = self.window_size
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
+ x = windows.view(B, H // window_size, W // window_size, window_size,
+ window_size, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+ return x
+
+ def window_partition(self, x):
+ """
+ Args:
+ x: (B, H, W, C)
+ Returns:
+ windows: (num_windows*B, window_size, window_size, C)
+ """
+ B, H, W, C = x.shape
+ window_size = self.window_size
+ x = x.view(B, H // window_size, window_size, W // window_size,
+ window_size, C)
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
+ windows = windows.view(-1, window_size, window_size, C)
+ return windows
+
+
+class SwinBlock(BaseModule):
+    """Swin Transformer block.
+
+ Args:
+ embed_dims (int): The feature dimension.
+ num_heads (int): Parallel attention heads.
+ feedforward_channels (int): The hidden dimension for FFNs.
+ window_size (int, optional): The local window scale. Default: 7.
+ shift (bool, optional): whether to shift window or not. Default False.
+ qkv_bias (bool, optional): enable bias for qkv if True. Default: True.
+ qk_scale (float | None, optional): Override default qk scale of
+ head_dim ** -0.5 if set. Default: None.
+ drop_rate (float, optional): Dropout rate. Default: 0.
+ attn_drop_rate (float, optional): Attention dropout rate. Default: 0.
+ drop_path_rate (float, optional): Stochastic depth rate. Default: 0.
+ act_cfg (dict, optional): The config dict of activation function.
+ Default: dict(type='GELU').
+ norm_cfg (dict, optional): The config dict of normalization.
+ Default: dict(type='LN').
+ with_cp (bool, optional): Use checkpoint or not. Using checkpoint
+ will save some memory while slowing down the training speed.
+ Default: False.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
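+
+    Example:
+        An illustrative shape-check sketch; the embedding size, head count
+        and feature-map layout below are arbitrary example values:
+
+        >>> import torch
+        >>> block = SwinBlock(embed_dims=96, num_heads=3,
+        ...                   feedforward_channels=384)
+        >>> x = torch.rand(1, 56 * 56, 96)  # arbitrary example input
+        >>> tuple(block(x, (56, 56)).shape)
+        (1, 3136, 96)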
+ """
+
+ def __init__(self,
+ embed_dims,
+ num_heads,
+ feedforward_channels,
+ window_size=7,
+ shift=False,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.,
+ act_cfg=dict(type='GELU'),
+ norm_cfg=dict(type='LN'),
+ with_cp=False,
+ init_cfg=None):
+
+ super(SwinBlock, self).__init__(init_cfg=init_cfg)
+
+ self.with_cp = with_cp
+
+ self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
+ self.attn = ShiftWindowMSA(
+ embed_dims=embed_dims,
+ num_heads=num_heads,
+ window_size=window_size,
+ shift_size=window_size // 2 if shift else 0,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop_rate=attn_drop_rate,
+ proj_drop_rate=drop_rate,
+ dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate))
+
+ self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
+ self.ffn = FFN(
+ embed_dims=embed_dims,
+ feedforward_channels=feedforward_channels,
+ num_fcs=2,
+ ffn_drop=drop_rate,
+ dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
+ act_cfg=act_cfg,
+ add_identity=True,
+ init_cfg=None)
+
+ def forward(self, x, hw_shape):
+
+ def _inner_forward(x):
+ identity = x
+ x = self.norm1(x)
+ x = self.attn(x, hw_shape)
+
+ x = x + identity
+
+ identity = x
+ x = self.norm2(x)
+ x = self.ffn(x, identity=identity)
+
+ return x
+
+ if self.with_cp and x.requires_grad:
+ x = cp.checkpoint(_inner_forward, x)
+ else:
+ x = _inner_forward(x)
+
+ return x
+
+
+class SwinBlockSequence(BaseModule):
+ """Implements one stage in Swin Transformer.
+
+ Args:
+ embed_dims (int): The feature dimension.
+ num_heads (int): Parallel attention heads.
+ feedforward_channels (int): The hidden dimension for FFNs.
+ depth (int): The number of blocks in this stage.
+ window_size (int, optional): The local window scale. Default: 7.
+ qkv_bias (bool, optional): enable bias for qkv if True. Default: True.
+ qk_scale (float | None, optional): Override default qk scale of
+ head_dim ** -0.5 if set. Default: None.
+ drop_rate (float, optional): Dropout rate. Default: 0.
+ attn_drop_rate (float, optional): Attention dropout rate. Default: 0.
+ drop_path_rate (float | list[float], optional): Stochastic depth
+ rate. Default: 0.
+ downsample (nn.Module | None, optional): The downsample operation
+ module. Default: None.
+ act_cfg (dict, optional): The config dict of activation function.
+ Default: dict(type='GELU').
+ norm_cfg (dict, optional): The config dict of normalization.
+ Default: dict(type='LN').
+ with_cp (bool, optional): Use checkpoint or not. Using checkpoint
+ will save some memory while slowing down the training speed.
+ Default: False.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self,
+ embed_dims,
+ num_heads,
+ feedforward_channels,
+ depth,
+ window_size=7,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.,
+ downsample=None,
+ act_cfg=dict(type='GELU'),
+ norm_cfg=dict(type='LN'),
+ with_cp=False,
+ init_cfg=None):
+ super().__init__(init_cfg=init_cfg)
+
+ if isinstance(drop_path_rate, list):
+ drop_path_rates = drop_path_rate
+ assert len(drop_path_rates) == depth
+ else:
+ drop_path_rates = [deepcopy(drop_path_rate) for _ in range(depth)]
+
+ self.blocks = nn.ModuleList()
+ for i in range(depth):
+ block = SwinBlock(
+ embed_dims=embed_dims,
+ num_heads=num_heads,
+ feedforward_channels=feedforward_channels,
+ window_size=window_size,
+ shift=False if i % 2 == 0 else True,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop_rate=drop_rate,
+ attn_drop_rate=attn_drop_rate,
+ drop_path_rate=drop_path_rates[i],
+ act_cfg=act_cfg,
+ norm_cfg=norm_cfg,
+ with_cp=with_cp)
+ self.blocks.append(block)
+
+ self.downsample = downsample
+
+ def forward(self, x, hw_shape):
+ for block in self.blocks:
+ x = block(x, hw_shape)
+
+ if self.downsample:
+ x_down, down_hw_shape = self.downsample(x, hw_shape)
+ return x_down, down_hw_shape, x, hw_shape
+ else:
+ return x, hw_shape, x, hw_shape
+
+
+@MODELS.register_module()
+class SwinTransformer(BaseBackbone):
+    """Swin Transformer.
+
+    A PyTorch implementation of `Swin Transformer:
+ Hierarchical Vision Transformer using Shifted Windows` -
+ https://arxiv.org/abs/2103.14030
+
+ Inspiration from
+ https://github.com/microsoft/Swin-Transformer
+
+ Args:
+ pretrain_img_size (int | tuple[int]): The size of input image when
+ pretrain. Defaults: 224.
+ in_channels (int): The num of input channels.
+ Defaults: 3.
+ embed_dims (int): The feature dimension. Default: 96.
+ patch_size (int | tuple[int]): Patch size. Default: 4.
+ window_size (int): Window size. Default: 7.
+ mlp_ratio (int): Ratio of mlp hidden dim to embedding dim.
+ Default: 4.
+ depths (tuple[int]): Depths of each Swin Transformer stage.
+ Default: (2, 2, 6, 2).
+ num_heads (tuple[int]): Parallel attention heads of each Swin
+ Transformer stage. Default: (3, 6, 12, 24).
+ strides (tuple[int]): The patch merging or patch embedding stride of
+ each Swin Transformer stage. (In swin, we set kernel size equal to
+ stride.) Default: (4, 2, 2, 2).
+ out_indices (tuple[int]): Output from which stages.
+ Default: (0, 1, 2, 3).
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key,
+ value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of
+ head_dim ** -0.5 if set. Default: None.
+ patch_norm (bool): If add a norm layer for patch embed and patch
+ merging. Default: True.
+ drop_rate (float): Dropout rate. Defaults: 0.
+ attn_drop_rate (float): Attention dropout rate. Default: 0.
+ drop_path_rate (float): Stochastic depth rate. Defaults: 0.1.
+ use_abs_pos_embed (bool): If True, add absolute position embedding to
+ the patch embedding. Defaults: False.
+ act_cfg (dict): Config dict for activation layer.
+        Default: dict(type='GELU').
+ norm_cfg (dict): Config dict for normalization layer at
+        the output of the backbone. Default: dict(type='LN').
+ with_cp (bool, optional): Use checkpoint or not. Using checkpoint
+ will save some memory while slowing down the training speed.
+ Default: False.
+ pretrained (str, optional): model pretrained path. Default: None.
+ convert_weights (bool): The flag indicates whether the
+ pre-trained model is from the original repo. We may need
+ to convert some keys to make it compatible.
+ Default: False.
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+ Default: -1 (-1 means not freezing any parameters).
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: ``[
+ dict(type='TruncNormal', std=.02, layer=['Linear']),
+ dict(type='Constant', val=1, layer=['LayerNorm']),
+ ]``
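+
+    Example:
+        An illustrative usage sketch of the default (Swin-T style)
+        configuration; the import path assumes the backbone is exported
+        from ``mmpose.models`` as in upstream MMPose:
+
+        >>> from mmpose.models import SwinTransformer
+        >>> import torch
+        >>> self = SwinTransformer()
+        >>> self.eval()
+        >>> inputs = torch.rand(1, 3, 224, 224)  # arbitrary example input
+        >>> level_outputs = self.forward(inputs)
+        >>> for level_out in level_outputs:
+        ...     print(tuple(level_out.shape))
+        (1, 96, 56, 56)
+        (1, 192, 28, 28)
+        (1, 384, 14, 14)
+        (1, 768, 7, 7)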
+ """
+
+ def __init__(self,
+ pretrain_img_size=224,
+ in_channels=3,
+ embed_dims=96,
+ patch_size=4,
+ window_size=7,
+ mlp_ratio=4,
+ depths=(2, 2, 6, 2),
+ num_heads=(3, 6, 12, 24),
+ strides=(4, 2, 2, 2),
+ out_indices=(0, 1, 2, 3),
+ qkv_bias=True,
+ qk_scale=None,
+ patch_norm=True,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.1,
+ use_abs_pos_embed=False,
+ act_cfg=dict(type='GELU'),
+ norm_cfg=dict(type='LN'),
+ with_cp=False,
+ convert_weights=False,
+ frozen_stages=-1,
+ init_cfg=[
+ dict(type='TruncNormal', std=.02, layer=['Linear']),
+ dict(type='Constant', val=1, layer=['LayerNorm']),
+ ]):
+ self.convert_weights = convert_weights
+ self.frozen_stages = frozen_stages
+ if isinstance(pretrain_img_size, int):
+ pretrain_img_size = to_2tuple(pretrain_img_size)
+ elif isinstance(pretrain_img_size, tuple):
+ if len(pretrain_img_size) == 1:
+ pretrain_img_size = to_2tuple(pretrain_img_size[0])
+ assert len(pretrain_img_size) == 2, \
+ f'The size of image should have length 1 or 2, ' \
+ f'but got {len(pretrain_img_size)}'
+
+ super(SwinTransformer, self).__init__(init_cfg=init_cfg)
+
+ num_layers = len(depths)
+ self.out_indices = out_indices
+ self.use_abs_pos_embed = use_abs_pos_embed
+
+ assert strides[0] == patch_size, 'Use non-overlapping patch embed.'
+
+ self.patch_embed = PatchEmbed(
+ in_channels=in_channels,
+ embed_dims=embed_dims,
+ conv_type='Conv2d',
+ kernel_size=patch_size,
+ stride=strides[0],
+ norm_cfg=norm_cfg if patch_norm else None,
+ init_cfg=None)
+
+ if self.use_abs_pos_embed:
+ patch_row = pretrain_img_size[0] // patch_size
+ patch_col = pretrain_img_size[1] // patch_size
+ num_patches = patch_row * patch_col
+ self.absolute_pos_embed = nn.Parameter(
+ torch.zeros((1, num_patches, embed_dims)))
+
+ self.drop_after_pos = nn.Dropout(p=drop_rate)
+
+ # set stochastic depth decay rule
+ total_depth = sum(depths)
+ dpr = [
+ x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
+ ]
+
+ self.stages = nn.ModuleList()
+ in_channels = embed_dims
+ for i in range(num_layers):
+ if i < num_layers - 1:
+ downsample = PatchMerging(
+ in_channels=in_channels,
+ out_channels=2 * in_channels,
+ stride=strides[i + 1],
+ norm_cfg=norm_cfg if patch_norm else None,
+ init_cfg=None)
+ else:
+ downsample = None
+
+ stage = SwinBlockSequence(
+ embed_dims=in_channels,
+ num_heads=num_heads[i],
+ feedforward_channels=mlp_ratio * in_channels,
+ depth=depths[i],
+ window_size=window_size,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop_rate=drop_rate,
+ attn_drop_rate=attn_drop_rate,
+ drop_path_rate=dpr[sum(depths[:i]):sum(depths[:i + 1])],
+ downsample=downsample,
+ act_cfg=act_cfg,
+ norm_cfg=norm_cfg,
+ with_cp=with_cp)
+ self.stages.append(stage)
+ if downsample:
+ in_channels = downsample.out_channels
+
+ self.num_features = [int(embed_dims * 2**i) for i in range(num_layers)]
+ # Add a norm layer for each output
+ for i in out_indices:
+ layer = build_norm_layer(norm_cfg, self.num_features[i])[1]
+ layer_name = f'norm{i}'
+ self.add_module(layer_name, layer)
+
+ def train(self, mode=True):
+        """Convert the model into training mode while keeping layers frozen."""
+ super(SwinTransformer, self).train(mode)
+ self._freeze_stages()
+
+ def _freeze_stages(self):
+ if self.frozen_stages >= 0:
+ self.patch_embed.eval()
+ for param in self.patch_embed.parameters():
+ param.requires_grad = False
+ if self.use_abs_pos_embed:
+ self.absolute_pos_embed.requires_grad = False
+ self.drop_after_pos.eval()
+
+ for i in range(1, self.frozen_stages + 1):
+
+ if (i - 1) in self.out_indices:
+ norm_layer = getattr(self, f'norm{i-1}')
+ norm_layer.eval()
+ for param in norm_layer.parameters():
+ param.requires_grad = False
+
+ m = self.stages[i - 1]
+ m.eval()
+ for param in m.parameters():
+ param.requires_grad = False
+
+ def init_weights(self, pretrained=None):
+ """Initialize the weights in backbone.
+
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+ if (isinstance(self.init_cfg, dict)
+ and self.init_cfg['type'] == 'Pretrained'):
+            # Load weights from the pretrained checkpoint.
+ logger = get_root_logger()
+ state_dict = get_state_dict(
+ self.init_cfg['checkpoint'], map_location='cpu')
+ if self.convert_weights:
+ # supported loading weight from original repo
+ state_dict = swin_converter(state_dict)
+
+ # strip prefix of state_dict
+ if list(state_dict.keys())[0].startswith('module.'):
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
+
+ # reshape absolute position embedding
+ if state_dict.get('absolute_pos_embed') is not None:
+ absolute_pos_embed = state_dict['absolute_pos_embed']
+ N1, L, C1 = absolute_pos_embed.size()
+ N2, C2, H, W = self.absolute_pos_embed.size()
+ if N1 != N2 or C1 != C2 or L != H * W:
+ logger.warning('Error in loading absolute_pos_embed, pass')
+ else:
+ state_dict['absolute_pos_embed'] = absolute_pos_embed.view(
+ N2, H, W, C2).permute(0, 3, 1, 2).contiguous()
+
+ # interpolate position bias table if needed
+ relative_position_bias_table_keys = [
+ k for k in state_dict.keys()
+ if 'relative_position_bias_table' in k
+ ]
+ for table_key in relative_position_bias_table_keys:
+ table_pretrained = state_dict[table_key]
+ table_current = self.state_dict()[table_key]
+ L1, nH1 = table_pretrained.size()
+ L2, nH2 = table_current.size()
+ if nH1 != nH2:
+ logger.warning(f'Error in loading {table_key}, pass')
+ elif L1 != L2:
+ S1 = int(L1**0.5)
+ S2 = int(L2**0.5)
+ table_pretrained_resized = F.interpolate(
+ table_pretrained.permute(1, 0).reshape(1, nH1, S1, S1),
+ size=(S2, S2),
+ mode='bicubic')
+ state_dict[table_key] = table_pretrained_resized.view(
+ nH2, L2).permute(1, 0).contiguous()
+
+ # load state_dict
+ load_state_dict(self, state_dict, strict=False, logger=logger)
+
+ else:
+ super(SwinTransformer, self).init_weights()
+ if self.use_abs_pos_embed:
+ trunc_normal_(self.absolute_pos_embed, std=0.02)
+
+ def forward(self, x):
+ x, hw_shape = self.patch_embed(x)
+
+ if self.use_abs_pos_embed:
+ x = x + self.absolute_pos_embed
+ x = self.drop_after_pos(x)
+
+ outs = []
+ for i, stage in enumerate(self.stages):
+ x, hw_shape, out, out_hw_shape = stage(x, hw_shape)
+ if i in self.out_indices:
+ norm_layer = getattr(self, f'norm{i}')
+ out = norm_layer(out)
+ out = out.view(-1, *out_hw_shape,
+ self.num_features[i]).permute(0, 3, 1,
+ 2).contiguous()
+ outs.append(out)
+
+ return tuple(outs)
diff --git a/mmpose/models/backbones/tcn.py b/mmpose/models/backbones/tcn.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef49a1ff075288cc7a23f51f47c5b1bcdd383894
--- /dev/null
+++ b/mmpose/models/backbones/tcn.py
@@ -0,0 +1,284 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+
+import torch.nn as nn
+from mmcv.cnn import ConvModule, build_conv_layer
+from mmengine.model import BaseModule
+
+from mmpose.registry import MODELS
+from ..utils.regularizations import WeightNormClipHook
+from .base_backbone import BaseBackbone
+
+
+class BasicTemporalBlock(BaseModule):
+ """Basic block for VideoPose3D.
+
+ Args:
+ in_channels (int): Input channels of this block.
+ out_channels (int): Output channels of this block.
+ mid_channels (int): The output channels of conv1. Default: 1024.
+ kernel_size (int): Size of the convolving kernel. Default: 3.
+ dilation (int): Spacing between kernel elements. Default: 3.
+ dropout (float): Dropout rate. Default: 0.25.
+ causal (bool): Use causal convolutions instead of symmetric
+ convolutions (for real-time applications). Default: False.
+ residual (bool): Use residual connection. Default: True.
+ use_stride_conv (bool): Use the optimized TCN that is designed
+ specifically for single-frame batching, i.e. where batches have
+ input length = receptive field, and output length = 1. This
+ implementation replaces dilated convolutions with strided
+ convolutions to avoid generating unused intermediate results.
+ Default: False.
+ conv_cfg (dict): dictionary to construct and config conv layer.
+ Default: dict(type='Conv1d').
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: dict(type='BN1d').
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
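+
+ Example (a minimal sketch; channel and frame numbers are
+ illustrative, and the default kernel_size=3 with dilation=3
+ trims (3 - 1) * 3 = 6 frames):
+ >>> import torch
+ >>> block = BasicTemporalBlock(16, 16, mid_channels=16)
+ >>> x = torch.rand(1, 16, 27)
+ >>> print(tuple(block(x).shape))
+ (1, 16, 21)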
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ mid_channels=1024,
+ kernel_size=3,
+ dilation=3,
+ dropout=0.25,
+ causal=False,
+ residual=True,
+ use_stride_conv=False,
+ conv_cfg=dict(type='Conv1d'),
+ norm_cfg=dict(type='BN1d'),
+ init_cfg=None):
+ # Protect mutable default arguments
+ conv_cfg = copy.deepcopy(conv_cfg)
+ norm_cfg = copy.deepcopy(norm_cfg)
+ super().__init__(init_cfg=init_cfg)
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.mid_channels = mid_channels
+ self.kernel_size = kernel_size
+ self.dilation = dilation
+ self.dropout = dropout
+ self.causal = causal
+ self.residual = residual
+ self.use_stride_conv = use_stride_conv
+
+ self.pad = (kernel_size - 1) * dilation // 2
+ if use_stride_conv:
+ self.stride = kernel_size
+ self.causal_shift = kernel_size // 2 if causal else 0
+ self.dilation = 1
+ else:
+ self.stride = 1
+ self.causal_shift = kernel_size // 2 * dilation if causal else 0
+
+ self.conv1 = nn.Sequential(
+ ConvModule(
+ in_channels,
+ mid_channels,
+ kernel_size=kernel_size,
+ stride=self.stride,
+ dilation=self.dilation,
+ bias='auto',
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg))
+ self.conv2 = nn.Sequential(
+ ConvModule(
+ mid_channels,
+ out_channels,
+ kernel_size=1,
+ bias='auto',
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg))
+
+ if residual and in_channels != out_channels:
+ self.short_cut = build_conv_layer(conv_cfg, in_channels,
+ out_channels, 1)
+ else:
+ self.short_cut = None
+
+ self.dropout = nn.Dropout(dropout) if dropout > 0 else None
+
+ def forward(self, x):
+ """Forward function."""
+ if self.use_stride_conv:
+ assert self.causal_shift + self.kernel_size // 2 < x.shape[2]
+ else:
+ assert 0 <= self.pad + self.causal_shift < x.shape[2] - \
+ self.pad + self.causal_shift <= x.shape[2]
+
+ out = self.conv1(x)
+ if self.dropout is not None:
+ out = self.dropout(out)
+
+ out = self.conv2(out)
+ if self.dropout is not None:
+ out = self.dropout(out)
+
+ if self.residual:
+ if self.use_stride_conv:
+ res = x[:, :, self.causal_shift +
+ self.kernel_size // 2::self.kernel_size]
+ else:
+ res = x[:, :,
+ (self.pad + self.causal_shift):(x.shape[2] - self.pad +
+ self.causal_shift)]
+
+ if self.short_cut is not None:
+ res = self.short_cut(res)
+ out = out + res
+
+ return out
+
+
+@MODELS.register_module()
+class TCN(BaseBackbone):
+ """TCN backbone.
+
+ Temporal Convolutional Networks.
+ More details can be found in the
+ `paper `__ .
+
+ Args:
+ in_channels (int): Number of input channels, which equals to
+ num_keypoints * num_features.
+ stem_channels (int): Number of feature channels. Default: 1024.
+ num_blocks (int): Number of basic temporal convolutional blocks.
+ Default: 2.
+ kernel_sizes (Sequence[int]): Sizes of the convolving kernel of
+ each basic block. Default: ``(3, 3, 3)``.
+ dropout (float): Dropout rate. Default: 0.25.
+ causal (bool): Use causal convolutions instead of symmetric
+ convolutions (for real-time applications).
+ Default: False.
+ residual (bool): Use residual connection. Default: True.
+ use_stride_conv (bool): Use the TCN backbone optimized for
+ single-frame batching, i.e. where batches have input length =
+ receptive field, and output length = 1. This implementation
+ replaces dilated convolutions with strided convolutions to avoid
+ generating unused intermediate results. The weights are
+ interchangeable with the reference implementation. Default: False
+ conv_cfg (dict): dictionary to construct and config conv layer.
+ Default: dict(type='Conv1d').
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: dict(type='BN1d').
+ max_norm (float|None): if not None, the weight of convolution layers
+ will be clipped to have a maximum norm of max_norm.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default:
+ ``[
+ dict(
+ type='Kaiming',
+ mode='fan_in',
+ nonlinearity='relu',
+ layer=['Conv2d']),
+ dict(
+ type='Constant',
+ val=1,
+ layer=['_BatchNorm', 'GroupNorm'])
+ ]``
+
+ Example:
+ >>> from mmpose.models import TCN
+ >>> import torch
+ >>> self = TCN(in_channels=34)
+ >>> self.eval()
+ >>> inputs = torch.rand(1, 34, 243)
+ >>> level_outputs = self.forward(inputs)
+ >>> for level_out in level_outputs:
+ ... print(tuple(level_out.shape))
+ (1, 1024, 235)
+ (1, 1024, 217)
+ """
+
+ def __init__(self,
+ in_channels,
+ stem_channels=1024,
+ num_blocks=2,
+ kernel_sizes=(3, 3, 3),
+ dropout=0.25,
+ causal=False,
+ residual=True,
+ use_stride_conv=False,
+ conv_cfg=dict(type='Conv1d'),
+ norm_cfg=dict(type='BN1d'),
+ max_norm=None,
+ init_cfg=[
+ dict(
+ type='Kaiming',
+ mode='fan_in',
+ nonlinearity='relu',
+ layer=['Conv2d']),
+ dict(
+ type='Constant',
+ val=1,
+ layer=['_BatchNorm', 'GroupNorm'])
+ ]):
+ # Protect mutable default arguments
+ conv_cfg = copy.deepcopy(conv_cfg)
+ norm_cfg = copy.deepcopy(norm_cfg)
+ super().__init__()
+ self.in_channels = in_channels
+ self.stem_channels = stem_channels
+ self.num_blocks = num_blocks
+ self.kernel_sizes = kernel_sizes
+ self.dropout = dropout
+ self.causal = causal
+ self.residual = residual
+ self.use_stride_conv = use_stride_conv
+ self.max_norm = max_norm
+
+ assert num_blocks == len(kernel_sizes) - 1
+ for ks in kernel_sizes:
+ assert ks % 2 == 1, 'Only odd filter widths are supported.'
+
+ self.expand_conv = ConvModule(
+ in_channels,
+ stem_channels,
+ kernel_size=kernel_sizes[0],
+ stride=kernel_sizes[0] if use_stride_conv else 1,
+ bias='auto',
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg)
+
+ dilation = kernel_sizes[0]
+ self.tcn_blocks = nn.ModuleList()
+ for i in range(1, num_blocks + 1):
+ self.tcn_blocks.append(
+ BasicTemporalBlock(
+ in_channels=stem_channels,
+ out_channels=stem_channels,
+ mid_channels=stem_channels,
+ kernel_size=kernel_sizes[i],
+ dilation=dilation,
+ dropout=dropout,
+ causal=causal,
+ residual=residual,
+ use_stride_conv=use_stride_conv,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg))
+ dilation *= kernel_sizes[i]
+
+ if self.max_norm is not None:
+ # Apply weight norm clip to conv layers
+ weight_clip = WeightNormClipHook(self.max_norm)
+ for module in self.modules():
+ if isinstance(module, nn.modules.conv._ConvNd):
+ weight_clip.register(module)
+
+ self.dropout = nn.Dropout(dropout) if dropout > 0 else None
+
+ def forward(self, x):
+ """Forward function."""
+ x = self.expand_conv(x)
+
+ if self.dropout is not None:
+ x = self.dropout(x)
+
+ outs = []
+ for i in range(self.num_blocks):
+ x = self.tcn_blocks[i](x)
+ outs.append(x)
+
+ return tuple(outs)
diff --git a/mmpose/models/backbones/utils/__init__.py b/mmpose/models/backbones/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..07e42f89126c9e5663123794f92987b4f9b347f1
--- /dev/null
+++ b/mmpose/models/backbones/utils/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .channel_shuffle import channel_shuffle
+from .inverted_residual import InvertedResidual
+from .make_divisible import make_divisible
+from .se_layer import SELayer
+from .utils import get_state_dict, load_checkpoint
+
+__all__ = [
+ 'channel_shuffle', 'make_divisible', 'InvertedResidual', 'SELayer',
+ 'load_checkpoint', 'get_state_dict'
+]
diff --git a/mmpose/models/backbones/utils/__pycache__/__init__.cpython-38.pyc b/mmpose/models/backbones/utils/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..18849dde1a38fd3b8acbd319875d64b9a25873ba
Binary files /dev/null and b/mmpose/models/backbones/utils/__pycache__/__init__.cpython-38.pyc differ
diff --git a/mmpose/models/backbones/utils/__pycache__/channel_shuffle.cpython-38.pyc b/mmpose/models/backbones/utils/__pycache__/channel_shuffle.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c81c5afae2980d7b9186af1c1a238b6f0dc87e7b
Binary files /dev/null and b/mmpose/models/backbones/utils/__pycache__/channel_shuffle.cpython-38.pyc differ
diff --git a/mmpose/models/backbones/utils/__pycache__/ckpt_convert.cpython-38.pyc b/mmpose/models/backbones/utils/__pycache__/ckpt_convert.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0eabf9f401f7d7cc416f98fe636cb61d68c6f373
Binary files /dev/null and b/mmpose/models/backbones/utils/__pycache__/ckpt_convert.cpython-38.pyc differ
diff --git a/mmpose/models/backbones/utils/__pycache__/inverted_residual.cpython-38.pyc b/mmpose/models/backbones/utils/__pycache__/inverted_residual.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7c0591bcd4449bc76299c39a324cbd3ae9ec9c40
Binary files /dev/null and b/mmpose/models/backbones/utils/__pycache__/inverted_residual.cpython-38.pyc differ
diff --git a/mmpose/models/backbones/utils/__pycache__/make_divisible.cpython-38.pyc b/mmpose/models/backbones/utils/__pycache__/make_divisible.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..881b3b9f6688d5a9442faeb8df836a43cb2d2a0d
Binary files /dev/null and b/mmpose/models/backbones/utils/__pycache__/make_divisible.cpython-38.pyc differ
diff --git a/mmpose/models/backbones/utils/__pycache__/se_layer.cpython-38.pyc b/mmpose/models/backbones/utils/__pycache__/se_layer.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ba43107fa90ef4e0822e649087196d6222e5da31
Binary files /dev/null and b/mmpose/models/backbones/utils/__pycache__/se_layer.cpython-38.pyc differ
diff --git a/mmpose/models/backbones/utils/__pycache__/utils.cpython-38.pyc b/mmpose/models/backbones/utils/__pycache__/utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..38312ac3b4e43d1c1271ec9039a0e384347f30a6
Binary files /dev/null and b/mmpose/models/backbones/utils/__pycache__/utils.cpython-38.pyc differ
diff --git a/mmpose/models/backbones/utils/channel_shuffle.py b/mmpose/models/backbones/utils/channel_shuffle.py
new file mode 100644
index 0000000000000000000000000000000000000000..aedd826bee690d42d92ed8a7f538b221e5b069e2
--- /dev/null
+++ b/mmpose/models/backbones/utils/channel_shuffle.py
@@ -0,0 +1,29 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+
+def channel_shuffle(x, groups):
+ """Channel Shuffle operation.
+
+ This function enables cross-group information flow for multiple groups
+ convolution layers.
+
+ Args:
+ x (Tensor): The input tensor.
+ groups (int): The number of groups to divide the input tensor
+ in the channel dimension.
+
+ Returns:
+ Tensor: The output tensor after channel shuffle operation.
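+
+ Example (a minimal sketch; 4 channels split into 2 groups are
+ interleaved):
+ >>> import torch
+ >>> x = torch.arange(4.).reshape(1, 4, 1, 1)
+ >>> channel_shuffle(x, 2).flatten().tolist()
+ [0.0, 2.0, 1.0, 3.0]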
+ """
+
+ batch_size, num_channels, height, width = x.size()
+ assert (num_channels % groups == 0), ('num_channels should be '
+ 'divisible by groups')
+ channels_per_group = num_channels // groups
+
+ x = x.view(batch_size, groups, channels_per_group, height, width)
+ x = torch.transpose(x, 1, 2).contiguous()
+ x = x.view(batch_size, groups * channels_per_group, height, width)
+
+ return x
diff --git a/mmpose/models/backbones/utils/ckpt_convert.py b/mmpose/models/backbones/utils/ckpt_convert.py
new file mode 100644
index 0000000000000000000000000000000000000000..14a43892c6630be31e915ed1f8b9164ba250e8bd
--- /dev/null
+++ b/mmpose/models/backbones/utils/ckpt_convert.py
@@ -0,0 +1,62 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+# This script consists of several conversion functions that
+# modify the weights of a model from the original repo so that
+# they can be loaded as pre-trained weights.
+
+from collections import OrderedDict
+
+
+def swin_converter(ckpt):
+
+ new_ckpt = OrderedDict()
+
+ def correct_unfold_reduction_order(x):
+ out_channel, in_channel = x.shape
+ x = x.reshape(out_channel, 4, in_channel // 4)
+ x = x[:, [0, 2, 1, 3], :].transpose(1,
+ 2).reshape(out_channel, in_channel)
+ return x
+
+ def correct_unfold_norm_order(x):
+ in_channel = x.shape[0]
+ x = x.reshape(4, in_channel // 4)
+ x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel)
+ return x
+
+ for k, v in ckpt.items():
+ if k.startswith('head'):
+ continue
+ elif k.startswith('layers'):
+ new_v = v
+ if 'attn.' in k:
+ new_k = k.replace('attn.', 'attn.w_msa.')
+ elif 'mlp.' in k:
+ if 'mlp.fc1.' in k:
+ new_k = k.replace('mlp.fc1.', 'ffn.layers.0.0.')
+ elif 'mlp.fc2.' in k:
+ new_k = k.replace('mlp.fc2.', 'ffn.layers.1.')
+ else:
+ new_k = k.replace('mlp.', 'ffn.')
+ elif 'downsample' in k:
+ new_k = k
+ if 'reduction.' in k:
+ new_v = correct_unfold_reduction_order(v)
+ elif 'norm.' in k:
+ new_v = correct_unfold_norm_order(v)
+ else:
+ new_k = k
+ new_k = new_k.replace('layers', 'stages', 1)
+ elif k.startswith('patch_embed'):
+ new_v = v
+ if 'proj' in k:
+ new_k = k.replace('proj', 'projection')
+ else:
+ new_k = k
+ else:
+ new_v = v
+ new_k = k
+
+ new_ckpt['backbone.' + new_k] = new_v
+
+ return new_ckpt
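+
+
+# Illustrative examples of the key mapping above (original repo -> mmpose):
+# 'layers.0.blocks.0.attn.qkv.weight'
+# -> 'backbone.stages.0.blocks.0.attn.w_msa.qkv.weight'
+# 'layers.0.blocks.0.mlp.fc1.weight'
+# -> 'backbone.stages.0.blocks.0.ffn.layers.0.0.weight'
+# 'patch_embed.proj.weight' -> 'backbone.patch_embed.projection.weight'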
diff --git a/mmpose/models/backbones/utils/inverted_residual.py b/mmpose/models/backbones/utils/inverted_residual.py
new file mode 100644
index 0000000000000000000000000000000000000000..dff762c570550e4a738ae1833a4c82c18777115d
--- /dev/null
+++ b/mmpose/models/backbones/utils/inverted_residual.py
@@ -0,0 +1,128 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+from mmcv.cnn import ConvModule
+
+from .se_layer import SELayer
+
+
+class InvertedResidual(nn.Module):
+ """Inverted Residual Block.
+
+ Args:
+ in_channels (int): The input channels of this Module.
+ out_channels (int): The output channels of this Module.
+ mid_channels (int): The input channels of the depthwise convolution.
+ kernel_size (int): The kernel size of the depthwise convolution.
+ Default: 3.
+ groups (None or int): The group number of the depthwise convolution.
+ Default: None, which means group number = mid_channels.
+ stride (int): The stride of the depthwise convolution. Default: 1.
+ se_cfg (dict): Config dict for se layer. Default: None, which means no
+ se layer.
+ with_expand_conv (bool): Use expand conv or not. If set False,
+ mid_channels must be the same with in_channels.
+ Default: True.
+ conv_cfg (dict): Config dict for convolution layer. Default: None,
+ which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='BN').
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='ReLU').
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+
+ Returns:
+ Tensor: The output tensor.
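+
+ Example (a minimal sketch; stride 1 with equal in/out channels keeps
+ the residual shortcut and the spatial size):
+ >>> import torch
+ >>> block = InvertedResidual(16, 16, mid_channels=64)
+ >>> x = torch.rand(1, 16, 32, 32)
+ >>> print(tuple(block(x).shape))
+ (1, 16, 32, 32)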
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ mid_channels,
+ kernel_size=3,
+ groups=None,
+ stride=1,
+ se_cfg=None,
+ with_expand_conv=True,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ with_cp=False):
+ # Protect mutable default arguments
+ norm_cfg = copy.deepcopy(norm_cfg)
+ act_cfg = copy.deepcopy(act_cfg)
+ super().__init__()
+ self.with_res_shortcut = (stride == 1 and in_channels == out_channels)
+ assert stride in [1, 2]
+ self.with_cp = with_cp
+ self.with_se = se_cfg is not None
+ self.with_expand_conv = with_expand_conv
+
+ if groups is None:
+ groups = mid_channels
+
+ if self.with_se:
+ assert isinstance(se_cfg, dict)
+ if not self.with_expand_conv:
+ assert mid_channels == in_channels
+
+ if self.with_expand_conv:
+ self.expand_conv = ConvModule(
+ in_channels=in_channels,
+ out_channels=mid_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ self.depthwise_conv = ConvModule(
+ in_channels=mid_channels,
+ out_channels=mid_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=kernel_size // 2,
+ groups=groups,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ if self.with_se:
+ self.se = SELayer(**se_cfg)
+ self.linear_conv = ConvModule(
+ in_channels=mid_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=None)
+
+ def forward(self, x):
+
+ def _inner_forward(x):
+ out = x
+
+ if self.with_expand_conv:
+ out = self.expand_conv(out)
+
+ out = self.depthwise_conv(out)
+
+ if self.with_se:
+ out = self.se(out)
+
+ out = self.linear_conv(out)
+
+ if self.with_res_shortcut:
+ return x + out
+ return out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ return out
diff --git a/mmpose/models/backbones/utils/make_divisible.py b/mmpose/models/backbones/utils/make_divisible.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7666be65939d5c76057e73927c230029cb1871d
--- /dev/null
+++ b/mmpose/models/backbones/utils/make_divisible.py
@@ -0,0 +1,25 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+def make_divisible(value, divisor, min_value=None, min_ratio=0.9):
+ """Make divisible function.
+
+ This function rounds the channel number to the nearest value that is
+ divisible by the divisor, while ensuring that the result is not smaller
+ than ``min_ratio`` times the original value.
+
+ Args:
+ value (int): The original channel number.
+ divisor (int): The divisor to fully divide the channel number.
+ min_value (int, optional): The minimum value of the output channel.
+ Default: None, means that the minimum value equal to the divisor.
+ min_ratio (float, optional): The minimum ratio of the rounded channel
+ number to the original channel number. Default: 0.9.
+ Returns:
+ int: The modified output channel number
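+
+ Example (illustrative values):
+ >>> make_divisible(34, 8)
+ 32
+ >>> # 8 < 0.9 * 10, so the value is bumped up by one divisor
+ >>> make_divisible(10, 8)
+ 16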
+ """
+
+ if min_value is None:
+ min_value = divisor
+ new_value = max(min_value, int(value + divisor / 2) // divisor * divisor)
+ # Make sure that rounding does not reduce the value by more than (1 - min_ratio).
+ if new_value < min_ratio * value:
+ new_value += divisor
+ return new_value
diff --git a/mmpose/models/backbones/utils/se_layer.py b/mmpose/models/backbones/utils/se_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec6d7aeaa9a990dbaf437b4ff4f4ba685e008245
--- /dev/null
+++ b/mmpose/models/backbones/utils/se_layer.py
@@ -0,0 +1,54 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import mmengine
+import torch.nn as nn
+from mmcv.cnn import ConvModule
+
+
+class SELayer(nn.Module):
+ """Squeeze-and-Excitation Module.
+
+ Args:
+ channels (int): The input (and output) channels of the SE layer.
+ ratio (int): Squeeze ratio in SELayer, the intermediate channel will be
+ ``int(channels/ratio)``. Default: 16.
+ conv_cfg (None or dict): Config dict for convolution layer.
+ Default: None, which means using conv2d.
+ act_cfg (dict or Sequence[dict]): Config dict for activation layer.
+ If act_cfg is a dict, two activation layers will be configured
+ by this dict. If act_cfg is a sequence of dicts, the first
+ activation layer will be configured by the first dict and the
+ second activation layer will be configured by the second dict.
+ Default: (dict(type='ReLU'), dict(type='Sigmoid'))
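+
+ Example (a minimal sketch; the output shape always matches the input):
+ >>> import torch
+ >>> layer = SELayer(channels=32, ratio=4)
+ >>> x = torch.rand(1, 32, 16, 16)
+ >>> print(tuple(layer(x).shape))
+ (1, 32, 16, 16)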
+ """
+
+ def __init__(self,
+ channels,
+ ratio=16,
+ conv_cfg=None,
+ act_cfg=(dict(type='ReLU'), dict(type='Sigmoid'))):
+ super().__init__()
+ if isinstance(act_cfg, dict):
+ act_cfg = (act_cfg, act_cfg)
+ assert len(act_cfg) == 2
+ assert mmengine.is_tuple_of(act_cfg, dict)
+ self.global_avgpool = nn.AdaptiveAvgPool2d(1)
+ self.conv1 = ConvModule(
+ in_channels=channels,
+ out_channels=int(channels / ratio),
+ kernel_size=1,
+ stride=1,
+ conv_cfg=conv_cfg,
+ act_cfg=act_cfg[0])
+ self.conv2 = ConvModule(
+ in_channels=int(channels / ratio),
+ out_channels=channels,
+ kernel_size=1,
+ stride=1,
+ conv_cfg=conv_cfg,
+ act_cfg=act_cfg[1])
+
+ def forward(self, x):
+ out = self.global_avgpool(x)
+ out = self.conv1(out)
+ out = self.conv2(out)
+ return x * out
diff --git a/mmpose/models/backbones/utils/utils.py b/mmpose/models/backbones/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebc4fe40cd481391edf73872e2d4f6eb35592779
--- /dev/null
+++ b/mmpose/models/backbones/utils/utils.py
@@ -0,0 +1,89 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from collections import OrderedDict
+
+from mmengine.runner import CheckpointLoader, load_state_dict
+
+
+def load_checkpoint(model,
+ filename,
+ map_location='cpu',
+ strict=False,
+ logger=None):
+ """Load checkpoint from a file or URI.
+
+ Args:
+ model (Module): Module to load checkpoint.
+ filename (str): Accept local filepath, URL, ``torchvision://xxx``,
+ ``open-mmlab://xxx``.
+ map_location (str): Same as :func:`torch.load`.
+ strict (bool): Whether to allow different params for the model and
+ checkpoint.
+ logger (:mod:`logging.Logger` or None): The logger for error message.
+
+ Returns:
+ dict or OrderedDict: The loaded checkpoint.
+ """
+ checkpoint = CheckpointLoader.load_checkpoint(filename, map_location)
+ # OrderedDict is a subclass of dict
+ if not isinstance(checkpoint, dict):
+ raise RuntimeError(
+ f'No state_dict found in checkpoint file {filename}')
+ # get state_dict from checkpoint
+ if 'state_dict' in checkpoint:
+ state_dict_tmp = checkpoint['state_dict']
+ elif 'model' in checkpoint:
+ state_dict_tmp = checkpoint['model']
+ else:
+ state_dict_tmp = checkpoint
+
+ state_dict = OrderedDict()
+ # strip prefix of state_dict
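+ # e.g. 'module.backbone.conv1.weight' and 'backbone.conv1.weight'
+ # both become 'conv1.weight'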
+ for k, v in state_dict_tmp.items():
+ if k.startswith('module.backbone.'):
+ state_dict[k[16:]] = v
+ elif k.startswith('module.'):
+ state_dict[k[7:]] = v
+ elif k.startswith('backbone.'):
+ state_dict[k[9:]] = v
+ else:
+ state_dict[k] = v
+ # load state_dict
+ load_state_dict(model, state_dict, strict, logger)
+ return checkpoint
+
+
+def get_state_dict(filename, map_location='cpu'):
+ """Get state_dict from a file or URI.
+
+ Args:
+ filename (str): Accept local filepath, URL, ``torchvision://xxx``,
+ ``open-mmlab://xxx``.
+ map_location (str): Same as :func:`torch.load`.
+
+ Returns:
+ OrderedDict: The state_dict.
+ """
+ checkpoint = CheckpointLoader.load_checkpoint(filename, map_location)
+ # OrderedDict is a subclass of dict
+ if not isinstance(checkpoint, dict):
+ raise RuntimeError(
+ f'No state_dict found in checkpoint file {filename}')
+ # get state_dict from checkpoint
+ if 'state_dict' in checkpoint:
+ state_dict_tmp = checkpoint['state_dict']
+ else:
+ state_dict_tmp = checkpoint
+
+ state_dict = OrderedDict()
+ # strip prefix of state_dict
+ for k, v in state_dict_tmp.items():
+ if k.startswith('module.backbone.'):
+ state_dict[k[16:]] = v
+ elif k.startswith('module.'):
+ state_dict[k[7:]] = v
+ elif k.startswith('backbone.'):
+ state_dict[k[9:]] = v
+ else:
+ state_dict[k] = v
+
+ return state_dict
diff --git a/mmpose/models/backbones/v2v_net.py b/mmpose/models/backbones/v2v_net.py
new file mode 100644
index 0000000000000000000000000000000000000000..2cd1ab93b105b345aabc0ace2c7e776cd99e36a9
--- /dev/null
+++ b/mmpose/models/backbones/v2v_net.py
@@ -0,0 +1,275 @@
+# ------------------------------------------------------------------------------
+# Copyright and License Information
+# Adapted from
+# https://github.com/microsoft/voxelpose-pytorch/blob/main/lib/models/v2v_net.py
+# Original Licence: MIT License
+# ------------------------------------------------------------------------------
+
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule
+from mmengine.model import BaseModule
+
+from mmpose.registry import MODELS
+from .base_backbone import BaseBackbone
+
+
+class Basic3DBlock(BaseModule):
+ """A basic 3D convolutional block.
+
+ Args:
+ in_channels (int): Input channels of this block.
+ out_channels (int): Output channels of this block.
+ kernel_size (int): Kernel size of the convolution operation
+ conv_cfg (dict): Dictionary to construct and config conv layer.
+ Default: dict(type='Conv3d')
+ norm_cfg (dict): Dictionary to construct and config norm layer.
+ Default: dict(type='BN3d')
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ conv_cfg=dict(type='Conv3d'),
+ norm_cfg=dict(type='BN3d'),
+ init_cfg=None):
+ super(Basic3DBlock, self).__init__(init_cfg=init_cfg)
+ self.block = ConvModule(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=((kernel_size - 1) // 2),
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ bias=True)
+
+ def forward(self, x):
+ """Forward function."""
+ return self.block(x)
+
+
+class Res3DBlock(BaseModule):
+ """A residual 3D convolutional block.
+
+ Args:
+ in_channels (int): Input channels of this block.
+ out_channels (int): Output channels of this block.
+ kernel_size (int): Kernel size of the convolution operation
+ Default: 3
+ conv_cfg (dict): Dictionary to construct and config conv layer.
+ Default: dict(type='Conv3d')
+ norm_cfg (dict): Dictionary to construct and config norm layer.
+ Default: dict(type='BN3d')
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ conv_cfg=dict(type='Conv3d'),
+ norm_cfg=dict(type='BN3d'),
+ init_cfg=None):
+ super(Res3DBlock, self).__init__(init_cfg=init_cfg)
+ self.res_branch = nn.Sequential(
+ ConvModule(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=((kernel_size - 1) // 2),
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ bias=True),
+ ConvModule(
+ out_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=((kernel_size - 1) // 2),
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=None,
+ bias=True))
+
+ if in_channels == out_channels:
+ self.skip_con = nn.Sequential()
+ else:
+ self.skip_con = ConvModule(
+ in_channels,
+ out_channels,
+ 1,
+ stride=1,
+ padding=0,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=None,
+ bias=True)
+
+ def forward(self, x):
+ """Forward function."""
+ res = self.res_branch(x)
+ skip = self.skip_con(x)
+ return F.relu(res + skip, True)
+
+
+class Pool3DBlock(BaseModule):
+ """A 3D max-pool block.
+
+ Args:
+ pool_size (int): Pool size of the 3D max-pool layer
+ """
+
+ def __init__(self, pool_size):
+ super(Pool3DBlock, self).__init__()
+ self.pool_size = pool_size
+
+ def forward(self, x):
+ """Forward function."""
+ return F.max_pool3d(
+ x, kernel_size=self.pool_size, stride=self.pool_size)
+
+
+class Upsample3DBlock(BaseModule):
+ """A 3D upsample block.
+
+ Args:
+ in_channels (int): Input channels of this block.
+ out_channels (int): Output channels of this block.
+ kernel_size (int): Kernel size of the transposed convolution operation.
+ Default: 2
+ stride (int): Stride of the transposed convolution operation.
+ Default: 2
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size=2,
+ stride=2,
+ init_cfg=None):
+ super(Upsample3DBlock, self).__init__(init_cfg=init_cfg)
+ assert kernel_size == 2
+ assert stride == 2
+ self.block = nn.Sequential(
+ nn.ConvTranspose3d(
+ in_channels,
+ out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=0,
+ output_padding=0), nn.BatchNorm3d(out_channels), nn.ReLU(True))
+
+ def forward(self, x):
+ """Forward function."""
+ return self.block(x)
+
+
+class EncoderDecorder(BaseModule):
+ """An encoder-decoder block.
+
+ Args:
+ in_channels (int): Input channels of this block
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self, in_channels=32, init_cfg=None):
+ super(EncoderDecorder, self).__init__(init_cfg=init_cfg)
+
+ self.encoder_pool1 = Pool3DBlock(2)
+ self.encoder_res1 = Res3DBlock(in_channels, in_channels * 2)
+ self.encoder_pool2 = Pool3DBlock(2)
+ self.encoder_res2 = Res3DBlock(in_channels * 2, in_channels * 4)
+
+ self.mid_res = Res3DBlock(in_channels * 4, in_channels * 4)
+
+ self.decoder_res2 = Res3DBlock(in_channels * 4, in_channels * 4)
+ self.decoder_upsample2 = Upsample3DBlock(in_channels * 4,
+ in_channels * 2, 2, 2)
+ self.decoder_res1 = Res3DBlock(in_channels * 2, in_channels * 2)
+ self.decoder_upsample1 = Upsample3DBlock(in_channels * 2, in_channels,
+ 2, 2)
+
+ self.skip_res1 = Res3DBlock(in_channels, in_channels)
+ self.skip_res2 = Res3DBlock(in_channels * 2, in_channels * 2)
+
+ def forward(self, x):
+ """Forward function."""
+ skip_x1 = self.skip_res1(x)
+ x = self.encoder_pool1(x)
+ x = self.encoder_res1(x)
+
+ skip_x2 = self.skip_res2(x)
+ x = self.encoder_pool2(x)
+ x = self.encoder_res2(x)
+
+ x = self.mid_res(x)
+
+ x = self.decoder_res2(x)
+ x = self.decoder_upsample2(x)
+ x = x + skip_x2
+
+ x = self.decoder_res1(x)
+ x = self.decoder_upsample1(x)
+ x = x + skip_x1
+
+ return x
+
+
+@MODELS.register_module()
+class V2VNet(BaseBackbone):
+ """V2VNet.
+
+ Please refer to the `paper `
+ for details.
+
+ Args:
+ input_channels (int):
+ Number of channels of the input feature volume.
+ output_channels (int):
+ Number of channels of the output volume.
+ mid_channels (int):
+ Input and output channels of the encoder-decoder block.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: ``dict(
+ type='Normal',
+ std=0.001,
+ layer=['Conv3d', 'ConvTranspose3d']
+ )``
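+
+ Example (a minimal sketch; the volume size should be divisible by 4
+ since the encoder downsamples twice by a factor of 2):
+ >>> import torch
+ >>> model = V2VNet(input_channels=17, output_channels=15)
+ >>> x = torch.rand(1, 17, 32, 32, 32)
+ >>> print(tuple(model(x)[0].shape))
+ (1, 15, 32, 32, 32)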
+ """
+
+ def __init__(self,
+ input_channels,
+ output_channels,
+ mid_channels=32,
+ init_cfg=dict(
+ type='Normal',
+ std=0.001,
+ layer=['Conv3d', 'ConvTranspose3d'])):
+ super(V2VNet, self).__init__(init_cfg=init_cfg)
+
+ self.front_layers = nn.Sequential(
+ Basic3DBlock(input_channels, mid_channels // 2, 7),
+ Res3DBlock(mid_channels // 2, mid_channels),
+ )
+
+ self.encoder_decoder = EncoderDecorder(in_channels=mid_channels)
+
+ self.output_layer = nn.Conv3d(
+ mid_channels, output_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, x):
+ """Forward function."""
+ x = self.front_layers(x)
+ x = self.encoder_decoder(x)
+ x = self.output_layer(x)
+
+ return (x, )
diff --git a/mmpose/models/backbones/vgg.py b/mmpose/models/backbones/vgg.py
new file mode 100644
index 0000000000000000000000000000000000000000..8fa09d8dc7ded75678e8e23846474acee763a532
--- /dev/null
+++ b/mmpose/models/backbones/vgg.py
@@ -0,0 +1,201 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+from mmcv.cnn import ConvModule
+from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
+
+from mmpose.registry import MODELS
+from .base_backbone import BaseBackbone
+
+
+def make_vgg_layer(in_channels,
+ out_channels,
+ num_blocks,
+ conv_cfg=None,
+ norm_cfg=None,
+ act_cfg=dict(type='ReLU'),
+ dilation=1,
+ with_norm=False,
+ ceil_mode=False):
+ layers = []
+ for _ in range(num_blocks):
+ layer = ConvModule(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ dilation=dilation,
+ padding=dilation,
+ bias=True,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ layers.append(layer)
+ in_channels = out_channels
+ layers.append(nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=ceil_mode))
+
+ return layers
+
+
+@MODELS.register_module()
+class VGG(BaseBackbone):
+ """VGG backbone.
+
+ Args:
+ depth (int): Depth of vgg, from {11, 13, 16, 19}.
+ with_norm (bool): Use BatchNorm or not.
+ num_classes (int): number of classes for classification.
+ num_stages (int): VGG stages, normally 5.
+ dilations (Sequence[int]): Dilation of each stage.
+ out_indices (Sequence[int]): Output from which stages. If only one
+ stage is specified, a single tensor (feature map) is returned,
+ otherwise multiple stages are specified, a tuple of tensors will
+ be returned. When it is None, the default behavior depends on
+ whether num_classes is specified. If num_classes <= 0, the default
+ value is (4, ), outputting the last feature map before classifier.
+ If num_classes > 0, the default value is (5, ), outputting the
+ classification score. Default: None.
+ frozen_stages (int): Stages to be frozen (all param fixed). -1 means
+ not freezing any parameters.
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Default: False.
+ ceil_mode (bool): Whether to use ceil_mode of MaxPool. Default: False.
+ with_last_pool (bool): Whether to keep the last pooling before
+ classifier. Default: True.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default:
+ ``[
+ dict(type='Kaiming', layer=['Conv2d']),
+ dict(
+ type='Constant',
+ val=1,
+ layer=['_BatchNorm', 'GroupNorm']),
+ dict(
+ type='Normal',
+ std=0.01,
+ layer=['Linear']),
+ ]``
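+
+ Example (a minimal sketch; with the default ``out_indices`` only the
+ last feature map before the classifier is returned):
+ >>> import torch
+ >>> model = VGG(depth=11)
+ >>> x = torch.rand(1, 3, 224, 224)
+ >>> print(tuple(model(x)[0].shape))
+ (1, 512, 7, 7)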
+ """
+
+ # Parameters to build layers. Each element specifies the number of conv layers in
+ # each stage. For example, VGG11 contains 11 layers with learnable
+ # parameters. 11 is computed as 11 = (1 + 1 + 2 + 2 + 2) + 3,
+ # where 3 indicates the last three fully-connected layers.
+ arch_settings = {
+ 11: (1, 1, 2, 2, 2),
+ 13: (2, 2, 2, 2, 2),
+ 16: (2, 2, 3, 3, 3),
+ 19: (2, 2, 4, 4, 4)
+ }
+
+ def __init__(self,
+ depth,
+ num_classes=-1,
+ num_stages=5,
+ dilations=(1, 1, 1, 1, 1),
+ out_indices=None,
+ frozen_stages=-1,
+ conv_cfg=None,
+ norm_cfg=None,
+ act_cfg=dict(type='ReLU'),
+ norm_eval=False,
+ ceil_mode=False,
+ with_last_pool=True,
+ init_cfg=[
+ dict(type='Kaiming', layer=['Conv2d']),
+ dict(
+ type='Constant',
+ val=1,
+ layer=['_BatchNorm', 'GroupNorm']),
+ dict(type='Normal', std=0.01, layer=['Linear']),
+ ]):
+ super().__init__(init_cfg=init_cfg)
+ if depth not in self.arch_settings:
+ raise KeyError(f'invalid depth {depth} for vgg')
+ assert num_stages >= 1 and num_stages <= 5
+ stage_blocks = self.arch_settings[depth]
+ self.stage_blocks = stage_blocks[:num_stages]
+ assert len(dilations) == num_stages
+
+ self.num_classes = num_classes
+ self.frozen_stages = frozen_stages
+ self.norm_eval = norm_eval
+ with_norm = norm_cfg is not None
+
+ if out_indices is None:
+ out_indices = (5, ) if num_classes > 0 else (4, )
+ assert max(out_indices) <= num_stages
+ self.out_indices = out_indices
+
+ self.in_channels = 3
+ start_idx = 0
+ vgg_layers = []
+ self.range_sub_modules = []
+ for i, num_blocks in enumerate(self.stage_blocks):
+ num_modules = num_blocks + 1
+ end_idx = start_idx + num_modules
+ dilation = dilations[i]
+ out_channels = 64 * 2**i if i < 4 else 512
+ vgg_layer = make_vgg_layer(
+ self.in_channels,
+ out_channels,
+ num_blocks,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ dilation=dilation,
+ with_norm=with_norm,
+ ceil_mode=ceil_mode)
+ vgg_layers.extend(vgg_layer)
+ self.in_channels = out_channels
+ self.range_sub_modules.append([start_idx, end_idx])
+ start_idx = end_idx
+ if not with_last_pool:
+ vgg_layers.pop(-1)
+ self.range_sub_modules[-1][1] -= 1
+ self.module_name = 'features'
+ self.add_module(self.module_name, nn.Sequential(*vgg_layers))
+
+ if self.num_classes > 0:
+ self.classifier = nn.Sequential(
+ nn.Linear(512 * 7 * 7, 4096),
+ nn.ReLU(True),
+ nn.Dropout(),
+ nn.Linear(4096, 4096),
+ nn.ReLU(True),
+ nn.Dropout(),
+ nn.Linear(4096, num_classes),
+ )
+
+ def forward(self, x):
+ outs = []
+ vgg_layers = getattr(self, self.module_name)
+ for i in range(len(self.stage_blocks)):
+ for j in range(*self.range_sub_modules[i]):
+ vgg_layer = vgg_layers[j]
+ x = vgg_layer(x)
+ if i in self.out_indices:
+ outs.append(x)
+ if self.num_classes > 0:
+ x = x.view(x.size(0), -1)
+ x = self.classifier(x)
+ outs.append(x)
+
+ return tuple(outs)
+
+ def _freeze_stages(self):
+ vgg_layers = getattr(self, self.module_name)
+ for i in range(self.frozen_stages):
+ for j in range(*self.range_sub_modules[i]):
+ m = vgg_layers[j]
+ m.eval()
+ for param in m.parameters():
+ param.requires_grad = False
+
+ def train(self, mode=True):
+ super().train(mode)
+ self._freeze_stages()
+ if mode and self.norm_eval:
+ for m in self.modules():
+ # trick: eval has an effect on BatchNorm only
+ if isinstance(m, _BatchNorm):
+ m.eval()
diff --git a/mmpose/models/backbones/vipnas_mbv3.py b/mmpose/models/backbones/vipnas_mbv3.py
new file mode 100644
index 0000000000000000000000000000000000000000..9156cafa56d4f15766e48c77cd492e52345aed65
--- /dev/null
+++ b/mmpose/models/backbones/vipnas_mbv3.py
@@ -0,0 +1,173 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+
+from mmcv.cnn import ConvModule
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from mmpose.registry import MODELS
+from .base_backbone import BaseBackbone
+from .utils import InvertedResidual
+
+
+@MODELS.register_module()
+class ViPNAS_MobileNetV3(BaseBackbone):
+ """ViPNAS_MobileNetV3 backbone.
+
+ "ViPNAS: Efficient Video Pose Estimation via Neural Architecture Search"
+ More details can be found in the `paper
+ `__ .
+
+ Args:
+ wid (list(int)): Searched width config for each stage.
+ expan (list(int)): Searched expansion ratio config for each stage.
+ dep (list(int)): Searched depth config for each stage.
+ ks (list(int)): Searched kernel size config for each stage.
+ group (list(int)): Searched group number config for each stage.
+ att (list(bool)): Searched attention config for each stage.
+ stride (list(int)): Stride config for each stage.
+ act (list(dict)): Activation config for each stage.
+ conv_cfg (dict): Config dict for convolution layer.
+ Default: None, which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='BN').
+ frozen_stages (int): Stages to be frozen (all param fixed).
+ Default: -1, which means not freezing any parameters.
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Default: False.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save
+ some memory while slowing down the training speed.
+ Default: False.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default:
+ ``[
+ dict(type='Normal', std=0.001, layer=['Conv2d']),
+ dict(
+ type='Constant',
+ val=1,
+ layer=['_BatchNorm', 'GroupNorm'])
+ ]``
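+
+ Example (a minimal sketch; the shapes assume the default searched
+ architecture, which has an overall stride of 32 and 160 output
+ channels):
+ >>> import torch
+ >>> model = ViPNAS_MobileNetV3()
+ >>> x = torch.rand(1, 3, 224, 224)
+ >>> print(tuple(model(x)[0].shape))
+ (1, 160, 7, 7)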
+ """
+
+ def __init__(
+ self,
+ wid=[16, 16, 24, 40, 80, 112, 160],
+ expan=[None, 1, 5, 4, 5, 5, 6],
+ dep=[None, 1, 4, 4, 4, 4, 4],
+ ks=[3, 3, 7, 7, 5, 7, 5],
+ group=[None, 8, 120, 20, 100, 280, 240],
+ att=[None, True, True, False, True, True, True],
+ stride=[2, 1, 2, 2, 2, 1, 2],
+ act=['HSwish', 'ReLU', 'ReLU', 'ReLU', 'HSwish', 'HSwish', 'HSwish'],
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ frozen_stages=-1,
+ norm_eval=False,
+ with_cp=False,
+ init_cfg=[
+ dict(type='Normal', std=0.001, layer=['Conv2d']),
+ dict(type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm'])
+ ],
+ ):
+ # Protect mutable default arguments
+ norm_cfg = copy.deepcopy(norm_cfg)
+ super().__init__(init_cfg=init_cfg)
+ self.wid = wid
+ self.expan = expan
+ self.dep = dep
+ self.ks = ks
+ self.group = group
+ self.att = att
+ self.stride = stride
+ self.act = act
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.frozen_stages = frozen_stages
+ self.norm_eval = norm_eval
+ self.with_cp = with_cp
+
+ self.conv1 = ConvModule(
+ in_channels=3,
+ out_channels=self.wid[0],
+ kernel_size=self.ks[0],
+ stride=self.stride[0],
+ padding=self.ks[0] // 2,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=dict(type=self.act[0]))
+
+ self.layers = self._make_layer()
+
+ def _make_layer(self):
+ layers = []
+ layer_index = 0
+ for i, dep in enumerate(self.dep[1:]):
+ mid_channels = self.wid[i + 1] * self.expan[i + 1]
+
+ if self.att[i + 1]:
+ se_cfg = dict(
+ channels=mid_channels,
+ ratio=4,
+ act_cfg=(dict(type='ReLU'),
+ dict(type='HSigmoid', bias=1.0, divisor=2.0)))
+ else:
+ se_cfg = None
+
+ if self.expan[i + 1] == 1:
+ with_expand_conv = False
+ else:
+ with_expand_conv = True
+
+ for j in range(dep):
+ if j == 0:
+ stride = self.stride[i + 1]
+ in_channels = self.wid[i]
+ else:
+ stride = 1
+ in_channels = self.wid[i + 1]
+
+ layer = InvertedResidual(
+ in_channels=in_channels,
+ out_channels=self.wid[i + 1],
+ mid_channels=mid_channels,
+ kernel_size=self.ks[i + 1],
+ groups=self.group[i + 1],
+ stride=stride,
+ se_cfg=se_cfg,
+ with_expand_conv=with_expand_conv,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=dict(type=self.act[i + 1]),
+ with_cp=self.with_cp)
+ layer_index += 1
+ layer_name = f'layer{layer_index}'
+ self.add_module(layer_name, layer)
+ layers.append(layer_name)
+ return layers
+
+ def forward(self, x):
+ x = self.conv1(x)
+
+ for i, layer_name in enumerate(self.layers):
+ layer = getattr(self, layer_name)
+ x = layer(x)
+
+ return (x, )
+
+ def _freeze_stages(self):
+ if self.frozen_stages >= 0:
+ for param in self.conv1.parameters():
+ param.requires_grad = False
+ for i in range(1, self.frozen_stages + 1):
+ layer = getattr(self, f'layer{i}')
+ layer.eval()
+ for param in layer.parameters():
+ param.requires_grad = False
+
+ def train(self, mode=True):
+ super().train(mode)
+ self._freeze_stages()
+ if mode and self.norm_eval:
+ for m in self.modules():
+ if isinstance(m, _BatchNorm):
+ m.eval()
diff --git a/mmpose/models/backbones/vipnas_resnet.py b/mmpose/models/backbones/vipnas_resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..7be810b449c1a840c425c69e3d1d1340583e52ea
--- /dev/null
+++ b/mmpose/models/backbones/vipnas_resnet.py
@@ -0,0 +1,596 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+from mmcv.cnn import ConvModule, build_conv_layer, build_norm_layer
+from mmcv.cnn.bricks import ContextBlock
+from mmengine.model import BaseModule, Sequential
+from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
+
+from mmpose.registry import MODELS
+from .base_backbone import BaseBackbone
+
+
+class ViPNAS_Bottleneck(BaseModule):
+ """Bottleneck block for ViPNAS_ResNet.
+
+ Args:
+ in_channels (int): Input channels of this block.
+ out_channels (int): Output channels of this block.
+ expansion (int): The ratio of ``out_channels/mid_channels`` where
+ ``mid_channels`` is the input/output channels of conv2. Default: 4.
+ stride (int): stride of the block. Default: 1
+ dilation (int): dilation of convolution. Default: 1
+ downsample (nn.Module): downsample operation on identity branch.
+ Default: None.
+ style (str): ``"pytorch"`` or ``"caffe"``. If set to "pytorch", the
+ stride-two layer is the 3x3 conv layer, otherwise the stride-two
+ layer is the first 1x1 conv layer. Default: "pytorch".
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed.
+ conv_cfg (dict): dictionary to construct and config conv layer.
+ Default: None
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: dict(type='BN')
+ kernel_size (int): kernel size of conv2 searched in ViPNAS.
+ groups (int): group number of conv2 searched in ViPNAS.
+ attention (bool): whether to use attention module in the end of
+ the block.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ expansion=4,
+ stride=1,
+ dilation=1,
+ downsample=None,
+ style='pytorch',
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ kernel_size=3,
+ groups=1,
+ attention=False,
+ init_cfg=None):
+ # Protect mutable default arguments
+ norm_cfg = copy.deepcopy(norm_cfg)
+ super().__init__(init_cfg=init_cfg)
+ assert style in ['pytorch', 'caffe']
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.expansion = expansion
+ assert out_channels % expansion == 0
+ self.mid_channels = out_channels // expansion
+ self.stride = stride
+ self.dilation = dilation
+ self.style = style
+ self.with_cp = with_cp
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+
+ if self.style == 'pytorch':
+ self.conv1_stride = 1
+ self.conv2_stride = stride
+ else:
+ self.conv1_stride = stride
+ self.conv2_stride = 1
+
+ self.norm1_name, norm1 = build_norm_layer(
+ norm_cfg, self.mid_channels, postfix=1)
+ self.norm2_name, norm2 = build_norm_layer(
+ norm_cfg, self.mid_channels, postfix=2)
+ self.norm3_name, norm3 = build_norm_layer(
+ norm_cfg, out_channels, postfix=3)
+
+ self.conv1 = build_conv_layer(
+ conv_cfg,
+ in_channels,
+ self.mid_channels,
+ kernel_size=1,
+ stride=self.conv1_stride,
+ bias=False)
+ self.add_module(self.norm1_name, norm1)
+ self.conv2 = build_conv_layer(
+ conv_cfg,
+ self.mid_channels,
+ self.mid_channels,
+ kernel_size=kernel_size,
+ stride=self.conv2_stride,
+ padding=kernel_size // 2,
+ groups=groups,
+ dilation=dilation,
+ bias=False)
+
+ self.add_module(self.norm2_name, norm2)
+ self.conv3 = build_conv_layer(
+ conv_cfg,
+ self.mid_channels,
+ out_channels,
+ kernel_size=1,
+ bias=False)
+ self.add_module(self.norm3_name, norm3)
+
+ if attention:
+ self.attention = ContextBlock(out_channels,
+ max(1.0 / 16, 16.0 / out_channels))
+ else:
+ self.attention = None
+
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+
+ @property
+ def norm1(self):
+ """nn.Module: the normalization layer named "norm1" """
+ return getattr(self, self.norm1_name)
+
+ @property
+ def norm2(self):
+ """nn.Module: the normalization layer named "norm2" """
+ return getattr(self, self.norm2_name)
+
+ @property
+ def norm3(self):
+ """nn.Module: the normalization layer named "norm3" """
+ return getattr(self, self.norm3_name)
+
+ def forward(self, x):
+ """Forward function."""
+
+ def _inner_forward(x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.norm1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.norm2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.norm3(out)
+
+ if self.attention is not None:
+ out = self.attention(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+
+ return out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ out = self.relu(out)
+
+ return out
+
+
+def get_expansion(block, expansion=None):
+ """Get the expansion of a residual block.
+
+ The block expansion will be obtained by the following order:
+
+ 1. If ``expansion`` is given, just return it.
+ 2. If ``block`` has the attribute ``expansion``, then return
+ ``block.expansion``.
+ 3. Return the default value according to the block type:
+ 1 for ``ViPNAS_Bottleneck``.
+
+ Args:
+ block (class): The block class.
+ expansion (int | None): The given expansion ratio.
+
+ Returns:
+ int: The expansion of the block.
+ """
+ if isinstance(expansion, int):
+ assert expansion > 0
+ elif expansion is None:
+ if hasattr(block, 'expansion'):
+ expansion = block.expansion
+ elif issubclass(block, ViPNAS_Bottleneck):
+ expansion = 1
+ else:
+ raise TypeError(f'expansion is not specified for {block.__name__}')
+ else:
+ raise TypeError('expansion must be an integer or None')
+
+ return expansion
+
+
+class ViPNAS_ResLayer(Sequential):
+ """ViPNAS_ResLayer to build ResNet style backbone.
+
+ Args:
+ block (nn.Module): Residual block used to build ViPNAS ResLayer.
+ num_blocks (int): Number of blocks.
+ in_channels (int): Input channels of this block.
+ out_channels (int): Output channels of this block.
+ expansion (int, optional): The expansion for BasicBlock/Bottleneck.
+ If not specified, it will firstly be obtained via
+ ``block.expansion``. If the block has no attribute "expansion",
+ the following default values will be used: 1 for BasicBlock and
+ 4 for Bottleneck. Default: None.
+ stride (int): stride of the first block. Default: 1.
+ avg_down (bool): Use AvgPool instead of stride conv when
+ downsampling in the bottleneck. Default: False
+ conv_cfg (dict): dictionary to construct and config conv layer.
+ Default: None
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: dict(type='BN')
+ downsample_first (bool): Downsample at the first block or last block.
+ False for Hourglass, True for ResNet. Default: True
+ kernel_size (int): Kernel Size of the corresponding convolution layer
+ searched in the block.
+ groups (int): Group number of the corresponding convolution layer
+ searched in the block.
+ attention (bool): Whether to use attention module in the end of the
+ block.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self,
+ block,
+ num_blocks,
+ in_channels,
+ out_channels,
+ expansion=None,
+ stride=1,
+ avg_down=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ downsample_first=True,
+ kernel_size=3,
+ groups=1,
+ attention=False,
+ init_cfg=None,
+ **kwargs):
+ # Protect mutable default arguments
+ norm_cfg = copy.deepcopy(norm_cfg)
+ self.block = block
+ self.expansion = get_expansion(block, expansion)
+
+ downsample = None
+ if stride != 1 or in_channels != out_channels:
+ downsample = []
+ conv_stride = stride
+ if avg_down and stride != 1:
+ conv_stride = 1
+ downsample.append(
+ nn.AvgPool2d(
+ kernel_size=stride,
+ stride=stride,
+ ceil_mode=True,
+ count_include_pad=False))
+ downsample.extend([
+ build_conv_layer(
+ conv_cfg,
+ in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=conv_stride,
+ bias=False),
+ build_norm_layer(norm_cfg, out_channels)[1]
+ ])
+ downsample = nn.Sequential(*downsample)
+
+ layers = []
+ if downsample_first:
+ layers.append(
+ block(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ expansion=self.expansion,
+ stride=stride,
+ downsample=downsample,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ kernel_size=kernel_size,
+ groups=groups,
+ attention=attention,
+ **kwargs))
+ in_channels = out_channels
+ for _ in range(1, num_blocks):
+ layers.append(
+ block(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ expansion=self.expansion,
+ stride=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ kernel_size=kernel_size,
+ groups=groups,
+ attention=attention,
+ **kwargs))
+ else: # downsample_first=False is for HourglassModule
+ for i in range(0, num_blocks - 1):
+ layers.append(
+ block(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ expansion=self.expansion,
+ stride=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ kernel_size=kernel_size,
+ groups=groups,
+ attention=attention,
+ **kwargs))
+ layers.append(
+ block(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ expansion=self.expansion,
+ stride=stride,
+ downsample=downsample,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ kernel_size=kernel_size,
+ groups=groups,
+ attention=attention,
+ **kwargs))
+
+ super().__init__(*layers, init_cfg=init_cfg)
+
+
+@MODELS.register_module()
+class ViPNAS_ResNet(BaseBackbone):
+ """ViPNAS_ResNet backbone.
+
+ "ViPNAS: Efficient Video Pose Estimation via Neural Architecture Search"
+ More details can be found in the `paper
+ `__ .
+
+ Args:
+ depth (int): Network depth, from {18, 34, 50, 101, 152}.
+ in_channels (int): Number of input image channels. Default: 3.
+ num_stages (int): Stages of the network. Default: 4.
+ strides (Sequence[int]): Strides of the first block of each stage.
+ Default: ``(1, 2, 2, 2)``.
+ dilations (Sequence[int]): Dilation of each stage.
+ Default: ``(1, 1, 1, 1)``.
+ out_indices (Sequence[int]): Output from which stages. If only one
+ stage is specified, a single tensor (feature map) is returned,
+ otherwise multiple stages are specified, a tuple of tensors will
+ be returned. Default: ``(3, )``.
+ style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
+ layer is the 3x3 conv layer, otherwise the stride-two layer is
+ the first 1x1 conv layer.
+ deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv.
+ Default: False.
+ avg_down (bool): Use AvgPool instead of stride conv when
+ downsampling in the bottleneck. Default: False.
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+ -1 means not freezing any parameters. Default: -1.
+ conv_cfg (dict | None): The config dict for conv layers. Default: None.
+ norm_cfg (dict): The config dict for norm layers.
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Default: False.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ zero_init_residual (bool): Whether to use zero init for last norm layer
+ in resblocks to let them behave as identity. Default: True.
+ wid (list(int)): Searched width config for each stage.
+ expan (list(int)): Searched expansion ratio config for each stage.
+ dep (list(int)): Searched depth config for each stage.
+ ks (list(int)): Searched kernel size config for each stage.
+ group (list(int)): Searched group number config for each stage.
+ att (list(bool)): Searched attention config for each stage.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default:
+ ``[
+ dict(type='Normal', std=0.001, layer=['Conv2d']),
+ dict(
+ type='Constant',
+ val=1,
+ layer=['_BatchNorm', 'GroupNorm'])
+ ]``
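+
+ Example (a minimal sketch; the shapes assume the default searched
+ widths, giving 608 channels at 1/32 resolution):
+ >>> import torch
+ >>> model = ViPNAS_ResNet(depth=50)
+ >>> x = torch.rand(1, 3, 224, 224)
+ >>> print(tuple(model(x)[0].shape))
+ (1, 608, 7, 7)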
+ """
+
+ arch_settings = {
+ 50: ViPNAS_Bottleneck,
+ }
+
+ def __init__(self,
+ depth,
+ in_channels=3,
+ num_stages=4,
+ strides=(1, 2, 2, 2),
+ dilations=(1, 1, 1, 1),
+ out_indices=(3, ),
+ style='pytorch',
+ deep_stem=False,
+ avg_down=False,
+ frozen_stages=-1,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ norm_eval=False,
+ with_cp=False,
+ zero_init_residual=True,
+ wid=[48, 80, 160, 304, 608],
+ expan=[None, 1, 1, 1, 1],
+ dep=[None, 4, 6, 7, 3],
+ ks=[7, 3, 5, 5, 5],
+ group=[None, 16, 16, 16, 16],
+ att=[None, True, False, True, True],
+ init_cfg=[
+ dict(type='Normal', std=0.001, layer=['Conv2d']),
+ dict(
+ type='Constant',
+ val=1,
+ layer=['_BatchNorm', 'GroupNorm'])
+ ]):
+ # Protect mutable default arguments
+ norm_cfg = copy.deepcopy(norm_cfg)
+ super().__init__(init_cfg=init_cfg)
+ if depth not in self.arch_settings:
+ raise KeyError(f'invalid depth {depth} for resnet')
+ self.depth = depth
+ self.stem_channels = dep[0]
+ self.num_stages = num_stages
+ assert 1 <= num_stages <= 4
+ self.strides = strides
+ self.dilations = dilations
+ assert len(strides) == len(dilations) == num_stages
+ self.out_indices = out_indices
+ assert max(out_indices) < num_stages
+ self.style = style
+ self.deep_stem = deep_stem
+ self.avg_down = avg_down
+ self.frozen_stages = frozen_stages
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.with_cp = with_cp
+ self.norm_eval = norm_eval
+ self.zero_init_residual = zero_init_residual
+ self.block = self.arch_settings[depth]
+ self.stage_blocks = dep[1:1 + num_stages]
+
+ self._make_stem_layer(in_channels, wid[0], ks[0])
+
+ self.res_layers = []
+ _in_channels = wid[0]
+ for i, num_blocks in enumerate(self.stage_blocks):
+ expansion = get_expansion(self.block, expan[i + 1])
+ _out_channels = wid[i + 1] * expansion
+ stride = strides[i]
+ dilation = dilations[i]
+ res_layer = self.make_res_layer(
+ block=self.block,
+ num_blocks=num_blocks,
+ in_channels=_in_channels,
+ out_channels=_out_channels,
+ expansion=expansion,
+ stride=stride,
+ dilation=dilation,
+ style=self.style,
+ avg_down=self.avg_down,
+ with_cp=with_cp,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ kernel_size=ks[i + 1],
+ groups=group[i + 1],
+ attention=att[i + 1])
+ _in_channels = _out_channels
+ layer_name = f'layer{i + 1}'
+ self.add_module(layer_name, res_layer)
+ self.res_layers.append(layer_name)
+
+ self._freeze_stages()
+
+ self.feat_dim = res_layer[-1].out_channels
+
+ def make_res_layer(self, **kwargs):
+ """Make a ViPNAS ResLayer."""
+ return ViPNAS_ResLayer(**kwargs)
+
+ @property
+ def norm1(self):
+ """nn.Module: the normalization layer named "norm1" """
+ return getattr(self, self.norm1_name)
+
+ def _make_stem_layer(self, in_channels, stem_channels, kernel_size):
+ """Make stem layer."""
+ if self.deep_stem:
+ self.stem = nn.Sequential(
+ ConvModule(
+ in_channels,
+ stem_channels // 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ inplace=True),
+ ConvModule(
+ stem_channels // 2,
+ stem_channels // 2,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ inplace=True),
+ ConvModule(
+ stem_channels // 2,
+ stem_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ inplace=True))
+ else:
+ self.conv1 = build_conv_layer(
+ self.conv_cfg,
+ in_channels,
+ stem_channels,
+ kernel_size=kernel_size,
+ stride=2,
+ padding=kernel_size // 2,
+ bias=False)
+ self.norm1_name, norm1 = build_norm_layer(
+ self.norm_cfg, stem_channels, postfix=1)
+ self.add_module(self.norm1_name, norm1)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+ def _freeze_stages(self):
+ """Freeze parameters."""
+ if self.frozen_stages >= 0:
+ if self.deep_stem:
+ self.stem.eval()
+ for param in self.stem.parameters():
+ param.requires_grad = False
+ else:
+ self.norm1.eval()
+ for m in [self.conv1, self.norm1]:
+ for param in m.parameters():
+ param.requires_grad = False
+
+ for i in range(1, self.frozen_stages + 1):
+ m = getattr(self, f'layer{i}')
+ m.eval()
+ for param in m.parameters():
+ param.requires_grad = False
+
+ def forward(self, x):
+ """Forward function."""
+ if self.deep_stem:
+ x = self.stem(x)
+ else:
+ x = self.conv1(x)
+ x = self.norm1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+ outs = []
+ for i, layer_name in enumerate(self.res_layers):
+ res_layer = getattr(self, layer_name)
+ x = res_layer(x)
+ if i in self.out_indices:
+ outs.append(x)
+ return tuple(outs)
+
+ def train(self, mode=True):
+ """Convert the model into training mode."""
+ super().train(mode)
+ self._freeze_stages()
+ if mode and self.norm_eval:
+ for m in self.modules():
+                # trick: eval has an effect on BatchNorm only
+ if isinstance(m, _BatchNorm):
+ m.eval()
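The backbone above is normally driven through configs, but a minimal standalone sketch helps show its interface. The snippet below assumes the class is exported as `ViPNAS_ResNet` from `mmpose.models.backbones` and that only `depth=50` is valid (the only key in `arch_settings`); the input resolution is arbitrary.

```python
# Hedged sketch, not part of the diff: exercise the ViPNAS backbone
# directly. Assumes `ViPNAS_ResNet` is exported by
# mmpose.models.backbones; only depth=50 exists in `arch_settings`.
import torch
from mmpose.models.backbones import ViPNAS_ResNet

backbone = ViPNAS_ResNet(depth=50, out_indices=(3, ))
backbone.eval()

with torch.no_grad():
    # forward() returns a tuple with one feature map per entry
    # in `out_indices`.
    feats = backbone(torch.randn(1, 3, 256, 192))

print([f.shape for f in feats])
```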
diff --git a/mmpose/models/builder.py b/mmpose/models/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..cefaedc29100bcbc4c5b9cde55db8f66b25ab637
--- /dev/null
+++ b/mmpose/models/builder.py
@@ -0,0 +1,43 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+from mmpose.registry import MODELS
+
+BACKBONES = MODELS
+NECKS = MODELS
+HEADS = MODELS
+LOSSES = MODELS
+POSE_ESTIMATORS = MODELS
+
+
+def build_backbone(cfg):
+ """Build backbone."""
+ return BACKBONES.build(cfg)
+
+
+def build_neck(cfg):
+ """Build neck."""
+ return NECKS.build(cfg)
+
+
+def build_head(cfg):
+ """Build head."""
+ return HEADS.build(cfg)
+
+
+def build_loss(cfg):
+ """Build loss."""
+ return LOSSES.build(cfg)
+
+
+def build_pose_estimator(cfg):
+ """Build pose estimator."""
+ return POSE_ESTIMATORS.build(cfg)
+
+
+def build_posenet(cfg):
+ """Build posenet."""
+ warnings.warn(
+ '``build_posenet`` will be deprecated soon, '
+ 'please use ``build_pose_estimator`` instead.', DeprecationWarning)
+ return build_pose_estimator(cfg)
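All of the aliases above point at the same `MODELS` registry, so each `build_*` helper takes the usual config dict. A hedged sketch (the `type` name is assumed to match the backbone registered earlier in this diff):

```python
# Hedged sketch, not part of the diff: the builder helpers dispatch to
# the shared MODELS registry, so the `type` key selects any registered
# class. 'ViPNAS_ResNet' is assumed to be the registered backbone name.
from mmpose.models.builder import build_backbone

backbone = build_backbone(dict(type='ViPNAS_ResNet', depth=50))

# `build_posenet(cfg)` still works for old configs, but it emits a
# DeprecationWarning and forwards to `build_pose_estimator(cfg)`.
```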
diff --git a/mmpose/models/data_preprocessors/__init__.py b/mmpose/models/data_preprocessors/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c9bd22e2b20be84a17d05ab3058efd8d934f261
--- /dev/null
+++ b/mmpose/models/data_preprocessors/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .data_preprocessor import PoseDataPreprocessor
+
+__all__ = ['PoseDataPreprocessor']
diff --git a/mmpose/models/data_preprocessors/__pycache__/__init__.cpython-38.pyc b/mmpose/models/data_preprocessors/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4a2f2ab54cd6fe7d1b1c2e81d614ae7169be517b
Binary files /dev/null and b/mmpose/models/data_preprocessors/__pycache__/__init__.cpython-38.pyc differ
diff --git a/mmpose/models/data_preprocessors/__pycache__/data_preprocessor.cpython-38.pyc b/mmpose/models/data_preprocessors/__pycache__/data_preprocessor.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2e73f1ee78878c0e38595df4abf05639bbb30a06
Binary files /dev/null and b/mmpose/models/data_preprocessors/__pycache__/data_preprocessor.cpython-38.pyc differ
diff --git a/mmpose/models/data_preprocessors/data_preprocessor.py b/mmpose/models/data_preprocessors/data_preprocessor.py
new file mode 100644
index 0000000000000000000000000000000000000000..bcfe54ab59108cbd033b9186b8b6aae0144677d9
--- /dev/null
+++ b/mmpose/models/data_preprocessors/data_preprocessor.py
@@ -0,0 +1,9 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmengine.model import ImgDataPreprocessor
+
+from mmpose.registry import MODELS
+
+
+@MODELS.register_module()
+class PoseDataPreprocessor(ImgDataPreprocessor):
+ """Image pre-processor for pose estimation tasks."""
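Because `PoseDataPreprocessor` adds no arguments of its own, a config entry simply passes through whatever `mmengine.model.ImgDataPreprocessor` accepts. The values below are the common ImageNet statistics and are only illustrative:

```python
# Hedged sketch, not part of the diff: an illustrative config fragment.
# PoseDataPreprocessor forwards these arguments to mmengine's
# ImgDataPreprocessor; the statistics are standard ImageNet values.
data_preprocessor = dict(
    type='PoseDataPreprocessor',
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    bgr_to_rgb=True)
```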
diff --git a/mmpose/models/heads/__init__.py b/mmpose/models/heads/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b4d988a5f85bad9bb60728b4a020ba75970a8a8
--- /dev/null
+++ b/mmpose/models/heads/__init__.py
@@ -0,0 +1,14 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .base_head import BaseHead
+from .coord_cls_heads import RTMCCHead, SimCCHead
+from .heatmap_heads import (AssociativeEmbeddingHead, CIDHead, CPMHead,
+ HeatmapHead, MSPNHead, ViPNASHead)
+from .hybrid_heads import DEKRHead
+from .regression_heads import (DSNTHead, IntegralRegressionHead,
+ RegressionHead, RLEHead)
+
+__all__ = [
+ 'BaseHead', 'HeatmapHead', 'CPMHead', 'MSPNHead', 'ViPNASHead',
+ 'RegressionHead', 'IntegralRegressionHead', 'SimCCHead', 'RLEHead',
+ 'DSNTHead', 'AssociativeEmbeddingHead', 'DEKRHead', 'CIDHead', 'RTMCCHead'
+]
diff --git a/mmpose/models/heads/__pycache__/__init__.cpython-38.pyc b/mmpose/models/heads/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e5df3e265f8c637082b69aeb8bf99bb78df0080d
Binary files /dev/null and b/mmpose/models/heads/__pycache__/__init__.cpython-38.pyc differ
diff --git a/mmpose/models/heads/__pycache__/base_head.cpython-38.pyc b/mmpose/models/heads/__pycache__/base_head.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1e25d444bc1607e10e32c56ce604e9303eb621bb
Binary files /dev/null and b/mmpose/models/heads/__pycache__/base_head.cpython-38.pyc differ
diff --git a/mmpose/models/heads/base_head.py b/mmpose/models/heads/base_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..14882db2439d1ff334f066939f5b0cb082b4d0ea
--- /dev/null
+++ b/mmpose/models/heads/base_head.py
@@ -0,0 +1,83 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABCMeta, abstractmethod
+from typing import Tuple, Union
+
+from mmengine.model import BaseModule
+from mmengine.structures import InstanceData
+from torch import Tensor
+
+from mmpose.utils.tensor_utils import to_numpy
+from mmpose.utils.typing import (Features, InstanceList, OptConfigType,
+ OptSampleList, Predictions)
+
+
+class BaseHead(BaseModule, metaclass=ABCMeta):
+ """Base head. A subclass should override :meth:`predict` and :meth:`loss`.
+
+ Args:
+ init_cfg (dict, optional): The extra init config of layers.
+ Defaults to None.
+ """
+
+ @abstractmethod
+ def forward(self, feats: Tuple[Tensor]):
+ """Forward the network."""
+
+ @abstractmethod
+ def predict(self,
+ feats: Features,
+ batch_data_samples: OptSampleList,
+ test_cfg: OptConfigType = {}) -> Predictions:
+ """Predict results from features."""
+
+ @abstractmethod
+ def loss(self,
+ feats: Tuple[Tensor],
+ batch_data_samples: OptSampleList,
+ train_cfg: OptConfigType = {}) -> dict:
+ """Calculate losses from a batch of inputs and data samples."""
+
+ def decode(self, batch_outputs: Union[Tensor,
+ Tuple[Tensor]]) -> InstanceList:
+ """Decode keypoints from outputs.
+
+ Args:
+ batch_outputs (Tensor | Tuple[Tensor]): The network outputs of
+ a data batch
+
+ Returns:
+ List[InstanceData]: A list of InstanceData, each contains the
+ decoded pose information of the instances of one data sample.
+ """
+
+ def _pack_and_call(args, func):
+ if not isinstance(args, tuple):
+ args = (args, )
+ return func(*args)
+
+ if self.decoder is None:
+ raise RuntimeError(
+ f'The decoder has not been set in {self.__class__.__name__}. '
+ 'Please set the decoder configs in the init parameters to '
+ 'enable head methods `head.predict()` and `head.decode()`')
+
+ if self.decoder.support_batch_decoding:
+ batch_keypoints, batch_scores = _pack_and_call(
+ batch_outputs, self.decoder.batch_decode)
+
+ else:
+ batch_output_np = to_numpy(batch_outputs, unzip=True)
+ batch_keypoints = []
+ batch_scores = []
+ for outputs in batch_output_np:
+ keypoints, scores = _pack_and_call(outputs,
+ self.decoder.decode)
+ batch_keypoints.append(keypoints)
+ batch_scores.append(scores)
+
+ preds = [
+ InstanceData(keypoints=keypoints, keypoint_scores=scores)
+ for keypoints, scores in zip(batch_keypoints, batch_scores)
+ ]
+
+ return preds
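To make the `BaseHead` contract concrete: a subclass must implement `forward()`, `predict()` and `loss()`, and must set `self.decoder` before the shared `decode()` helper can be used. The toy head below is a hedged illustration, not an mmpose head, and skips the target weighting and codec handling a real head would need:

```python
# Hedged toy sketch, not part of the diff: the smallest shape a BaseHead
# subclass can take. A real head builds its loss and decoder from
# configs, as the heads later in this diff do.
import torch
from torch import nn
from mmpose.models.heads import BaseHead
from mmpose.registry import KEYPOINT_CODECS


class ToyHead(BaseHead):

    def __init__(self, in_channels, num_keypoints, decoder=None):
        super().__init__()
        self.final_layer = nn.Conv2d(in_channels, num_keypoints, 1)
        self.decoder = KEYPOINT_CODECS.build(decoder) if decoder else None

    def forward(self, feats):
        return self.final_layer(feats[-1])

    def predict(self, feats, batch_data_samples, test_cfg={}):
        # decode() raises a RuntimeError if self.decoder is still None.
        return self.decode(self.forward(feats))

    def loss(self, feats, batch_data_samples, train_cfg={}):
        heatmaps = self.forward(feats)
        gt = torch.stack([d.gt_fields.heatmaps for d in batch_data_samples])
        return dict(loss_kpt=nn.functional.mse_loss(heatmaps, gt))
```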
diff --git a/mmpose/models/heads/coord_cls_heads/__init__.py b/mmpose/models/heads/coord_cls_heads/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..104ff9130898956af54de6d243fcd4a16167f38b
--- /dev/null
+++ b/mmpose/models/heads/coord_cls_heads/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .rtmcc_head import RTMCCHead
+from .simcc_head import SimCCHead
+
+__all__ = ['SimCCHead', 'RTMCCHead']
diff --git a/mmpose/models/heads/coord_cls_heads/__pycache__/__init__.cpython-38.pyc b/mmpose/models/heads/coord_cls_heads/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c5541e59c8e006ab9220cb2ab1ecd200490d94bd
Binary files /dev/null and b/mmpose/models/heads/coord_cls_heads/__pycache__/__init__.cpython-38.pyc differ
diff --git a/mmpose/models/heads/coord_cls_heads/__pycache__/rtmcc_head.cpython-38.pyc b/mmpose/models/heads/coord_cls_heads/__pycache__/rtmcc_head.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9fff41e361acf3e2b09e6336bb4904b581f535a6
Binary files /dev/null and b/mmpose/models/heads/coord_cls_heads/__pycache__/rtmcc_head.cpython-38.pyc differ
diff --git a/mmpose/models/heads/coord_cls_heads/__pycache__/simcc_head.cpython-38.pyc b/mmpose/models/heads/coord_cls_heads/__pycache__/simcc_head.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ffa9bd2c9a58693c71a67b51d37e27a09f428f00
Binary files /dev/null and b/mmpose/models/heads/coord_cls_heads/__pycache__/simcc_head.cpython-38.pyc differ
diff --git a/mmpose/models/heads/coord_cls_heads/rtmcc_head.py b/mmpose/models/heads/coord_cls_heads/rtmcc_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..94d613192c089e2beca14afb832ee4370f8706cf
--- /dev/null
+++ b/mmpose/models/heads/coord_cls_heads/rtmcc_head.py
@@ -0,0 +1,303 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+from typing import Optional, Sequence, Tuple, Union
+
+import torch
+from mmengine.dist import get_dist_info
+from mmengine.structures import PixelData
+from torch import Tensor, nn
+
+from mmpose.codecs.utils import get_simcc_normalized
+from mmpose.evaluation.functional import simcc_pck_accuracy
+from mmpose.models.utils.rtmcc_block import RTMCCBlock, ScaleNorm
+from mmpose.models.utils.tta import flip_vectors
+from mmpose.registry import KEYPOINT_CODECS, MODELS
+from mmpose.utils.tensor_utils import to_numpy
+from mmpose.utils.typing import (ConfigType, InstanceList, OptConfigType,
+ OptSampleList)
+from ..base_head import BaseHead
+
+OptIntSeq = Optional[Sequence[int]]
+
+
+@MODELS.register_module()
+class RTMCCHead(BaseHead):
+ """Top-down head introduced in RTMPose (2023). The head is composed of a
+ large-kernel convolutional layer, a fully-connected layer and a Gated
+ Attention Unit to generate 1d representation from low-resolution feature
+ maps.
+
+ Args:
+ in_channels (int | sequence[int]): Number of channels in the input
+ feature map.
+ out_channels (int): Number of channels in the output heatmap.
+ input_size (tuple): Size of input image in shape [w, h].
+ in_featuremap_size (int | sequence[int]): Size of input feature map.
+ simcc_split_ratio (float): Split ratio of pixels.
+ Default: 2.0.
+ final_layer_kernel_size (int): Kernel size of the convolutional layer.
+ Default: 1.
+ gau_cfg (Config): Config dict for the Gated Attention Unit.
+ Default: dict(
+ hidden_dims=256,
+ s=128,
+ expansion_factor=2,
+ dropout_rate=0.,
+ drop_path=0.,
+ act_fn='ReLU',
+ use_rel_bias=False,
+ pos_enc=False).
+ loss (Config): Config of the keypoint loss. Defaults to use
+ :class:`KLDiscretLoss`
+ decoder (Config, optional): The decoder config that controls decoding
+ keypoint coordinates from the network output. Defaults to ``None``
+ init_cfg (Config, optional): Config to control the initialization. See
+ :attr:`default_init_cfg` for default settings
+ """
+
+ def __init__(
+ self,
+ in_channels: Union[int, Sequence[int]],
+ out_channels: int,
+ input_size: Tuple[int, int],
+ in_featuremap_size: Tuple[int, int],
+ simcc_split_ratio: float = 2.0,
+ final_layer_kernel_size: int = 1,
+ gau_cfg: ConfigType = dict(
+ hidden_dims=256,
+ s=128,
+ expansion_factor=2,
+ dropout_rate=0.,
+ drop_path=0.,
+ act_fn='ReLU',
+ use_rel_bias=False,
+ pos_enc=False),
+ loss: ConfigType = dict(type='KLDiscretLoss', use_target_weight=True),
+ decoder: OptConfigType = None,
+ init_cfg: OptConfigType = None,
+ ):
+
+ if init_cfg is None:
+ init_cfg = self.default_init_cfg
+
+ super().__init__(init_cfg)
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.input_size = input_size
+ self.in_featuremap_size = in_featuremap_size
+ self.simcc_split_ratio = simcc_split_ratio
+
+ self.loss_module = MODELS.build(loss)
+ if decoder is not None:
+ self.decoder = KEYPOINT_CODECS.build(decoder)
+ else:
+ self.decoder = None
+
+ if isinstance(in_channels, (tuple, list)):
+ raise ValueError(
+ f'{self.__class__.__name__} does not support selecting '
+ 'multiple input features.')
+
+ # Define SimCC layers
+ flatten_dims = self.in_featuremap_size[0] * self.in_featuremap_size[1]
+
+ self.final_layer = nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size=final_layer_kernel_size,
+ stride=1,
+ padding=final_layer_kernel_size // 2)
+ self.mlp = nn.Sequential(
+ ScaleNorm(flatten_dims),
+ nn.Linear(flatten_dims, gau_cfg['hidden_dims'], bias=False))
+
+ W = int(self.input_size[0] * self.simcc_split_ratio)
+ H = int(self.input_size[1] * self.simcc_split_ratio)
+
+ self.gau = RTMCCBlock(
+ self.out_channels,
+ gau_cfg['hidden_dims'],
+ gau_cfg['hidden_dims'],
+ s=gau_cfg['s'],
+ expansion_factor=gau_cfg['expansion_factor'],
+ dropout_rate=gau_cfg['dropout_rate'],
+ drop_path=gau_cfg['drop_path'],
+ attn_type='self-attn',
+ act_fn=gau_cfg['act_fn'],
+ use_rel_bias=gau_cfg['use_rel_bias'],
+ pos_enc=gau_cfg['pos_enc'])
+
+ self.cls_x = nn.Linear(gau_cfg['hidden_dims'], W, bias=False)
+ self.cls_y = nn.Linear(gau_cfg['hidden_dims'], H, bias=False)
+
+ def forward(self, feats: Tuple[Tensor]) -> Tuple[Tensor, Tensor]:
+ """Forward the network.
+
+ The input is multi scale feature maps and the
+ output is the heatmap.
+
+ Args:
+ feats (Tuple[Tensor]): Multi scale feature maps.
+
+ Returns:
+ pred_x (Tensor): 1d representation of x.
+ pred_y (Tensor): 1d representation of y.
+ """
+ feats = feats[-1]
+
+ feats = self.final_layer(feats) # -> B, K, H, W
+
+ # flatten the output heatmap
+ feats = torch.flatten(feats, 2)
+
+ feats = self.mlp(feats) # -> B, K, hidden
+
+ feats = self.gau(feats)
+
+ pred_x = self.cls_x(feats)
+ pred_y = self.cls_y(feats)
+
+ return pred_x, pred_y
+
+ def predict(
+ self,
+ feats: Tuple[Tensor],
+ batch_data_samples: OptSampleList,
+ test_cfg: OptConfigType = {},
+ ) -> InstanceList:
+ """Predict results from features.
+
+ Args:
+ feats (Tuple[Tensor] | List[Tuple[Tensor]]): The multi-stage
+ features (or multiple multi-stage features in TTA)
+ batch_data_samples (List[:obj:`PoseDataSample`]): The batch
+ data samples
+ test_cfg (dict): The runtime config for testing process. Defaults
+ to {}
+
+ Returns:
+ List[InstanceData]: The pose predictions, each contains
+ the following fields:
+ - keypoints (np.ndarray): predicted keypoint coordinates in
+ shape (num_instances, K, D) where K is the keypoint number
+ and D is the keypoint dimension
+ - keypoint_scores (np.ndarray): predicted keypoint scores in
+ shape (num_instances, K)
+ - keypoint_x_labels (np.ndarray, optional): The predicted 1-D
+ intensity distribution in the x direction
+ - keypoint_y_labels (np.ndarray, optional): The predicted 1-D
+ intensity distribution in the y direction
+ """
+
+ if test_cfg.get('flip_test', False):
+ # TTA: flip test -> feats = [orig, flipped]
+ assert isinstance(feats, list) and len(feats) == 2
+ flip_indices = batch_data_samples[0].metainfo['flip_indices']
+ _feats, _feats_flip = feats
+
+ _batch_pred_x, _batch_pred_y = self.forward(_feats)
+
+ _batch_pred_x_flip, _batch_pred_y_flip = self.forward(_feats_flip)
+ _batch_pred_x_flip, _batch_pred_y_flip = flip_vectors(
+ _batch_pred_x_flip,
+ _batch_pred_y_flip,
+ flip_indices=flip_indices)
+
+ batch_pred_x = (_batch_pred_x + _batch_pred_x_flip) * 0.5
+ batch_pred_y = (_batch_pred_y + _batch_pred_y_flip) * 0.5
+ else:
+ batch_pred_x, batch_pred_y = self.forward(feats)
+
+ preds = self.decode((batch_pred_x, batch_pred_y))
+
+ if test_cfg.get('output_heatmaps', False):
+ rank, _ = get_dist_info()
+ if rank == 0:
+ warnings.warn('The predicted simcc values are normalized for '
+ 'visualization. This may cause discrepancy '
+ 'between the keypoint scores and the 1D heatmaps'
+ '.')
+
+ # normalize the predicted 1d distribution
+ batch_pred_x = get_simcc_normalized(batch_pred_x)
+ batch_pred_y = get_simcc_normalized(batch_pred_y)
+
+ B, K, _ = batch_pred_x.shape
+            # B, K, Wx -> B, K, 1, Wx
+            x = batch_pred_x.reshape(B, K, 1, -1)
+            # B, K, Wy -> B, K, Wy, 1
+            y = batch_pred_y.reshape(B, K, -1, 1)
+            # (B, K, Wy, 1) @ (B, K, 1, Wx) -> B, K, Wy, Wx
+            batch_heatmaps = torch.matmul(y, x)
+ pred_fields = [
+ PixelData(heatmaps=hm) for hm in batch_heatmaps.detach()
+ ]
+
+ for pred_instances, pred_x, pred_y in zip(preds,
+ to_numpy(batch_pred_x),
+ to_numpy(batch_pred_y)):
+
+ pred_instances.keypoint_x_labels = pred_x[None]
+ pred_instances.keypoint_y_labels = pred_y[None]
+
+ return preds, pred_fields
+ else:
+ return preds
+
+ def loss(
+ self,
+ feats: Tuple[Tensor],
+ batch_data_samples: OptSampleList,
+ train_cfg: OptConfigType = {},
+ ) -> dict:
+ """Calculate losses from a batch of inputs and data samples."""
+
+ pred_x, pred_y = self.forward(feats)
+
+ gt_x = torch.cat([
+ d.gt_instance_labels.keypoint_x_labels for d in batch_data_samples
+ ],
+ dim=0)
+ gt_y = torch.cat([
+ d.gt_instance_labels.keypoint_y_labels for d in batch_data_samples
+ ],
+ dim=0)
+ keypoint_weights = torch.cat(
+ [
+ d.gt_instance_labels.keypoint_weights
+ for d in batch_data_samples
+ ],
+ dim=0,
+ )
+
+ pred_simcc = (pred_x, pred_y)
+ gt_simcc = (gt_x, gt_y)
+
+ # calculate losses
+ losses = dict()
+ loss = self.loss_module(pred_simcc, gt_simcc, keypoint_weights)
+
+ losses.update(loss_kpt=loss)
+
+ # calculate accuracy
+ _, avg_acc, _ = simcc_pck_accuracy(
+ output=to_numpy(pred_simcc),
+ target=to_numpy(gt_simcc),
+ simcc_split_ratio=self.simcc_split_ratio,
+ mask=to_numpy(keypoint_weights) > 0,
+ )
+
+ acc_pose = torch.tensor(avg_acc, device=gt_x.device)
+ losses.update(acc_pose=acc_pose)
+
+ return losses
+
+ @property
+ def default_init_cfg(self):
+ init_cfg = [
+ dict(type='Normal', layer=['Conv2d'], std=0.001),
+ dict(type='Constant', layer='BatchNorm2d', val=1),
+ dict(type='Normal', layer=['Linear'], std=0.01, bias=0),
+ ]
+ return init_cfg
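In a model config this head is usually paired with a SimCC-style codec as its `decoder`. The fragment below is a hedged sketch: the codec name `SimCCLabel` and every numeric value are assumptions for a 256x192 input with a stride-32 backbone, not a verified setting.

```python
# Hedged sketch, not part of the diff: illustrative RTMCCHead config.
# 'SimCCLabel' is assumed to be the SimCC codec registered in
# KEYPOINT_CODECS; channel counts and sigma are placeholders.
codec = dict(
    type='SimCCLabel',
    input_size=(192, 256),
    sigma=(6., 6.93),
    simcc_split_ratio=2.0,
    normalize=False)

head = dict(
    type='RTMCCHead',
    in_channels=768,
    out_channels=17,
    input_size=codec['input_size'],
    in_featuremap_size=(6, 8),
    simcc_split_ratio=codec['simcc_split_ratio'],
    final_layer_kernel_size=7,
    loss=dict(type='KLDiscretLoss', use_target_weight=True),
    decoder=codec)
```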
diff --git a/mmpose/models/heads/coord_cls_heads/simcc_head.py b/mmpose/models/heads/coord_cls_heads/simcc_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9287b7204ac1ca541be62aaf9e43a7d4f472210
--- /dev/null
+++ b/mmpose/models/heads/coord_cls_heads/simcc_head.py
@@ -0,0 +1,369 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+from typing import Optional, Sequence, Tuple, Union
+
+import torch
+from mmcv.cnn import build_conv_layer
+from mmengine.dist import get_dist_info
+from mmengine.structures import PixelData
+from torch import Tensor, nn
+
+from mmpose.codecs.utils import get_simcc_normalized
+from mmpose.evaluation.functional import simcc_pck_accuracy
+from mmpose.models.utils.tta import flip_vectors
+from mmpose.registry import KEYPOINT_CODECS, MODELS
+from mmpose.utils.tensor_utils import to_numpy
+from mmpose.utils.typing import (ConfigType, InstanceList, OptConfigType,
+ OptSampleList)
+from ..base_head import BaseHead
+
+OptIntSeq = Optional[Sequence[int]]
+
+
+@MODELS.register_module()
+class SimCCHead(BaseHead):
+ """Top-down heatmap head introduced in `SimCC`_ by Li et al (2022). The
+ head is composed of a few deconvolutional layers followed by a fully-
+ connected layer to generate 1d representation from low-resolution feature
+ maps.
+
+ Args:
+ in_channels (int | sequence[int]): Number of channels in the input
+ feature map
+ out_channels (int): Number of channels in the output heatmap
+ input_size (tuple): Input image size in shape [w, h]
+ in_featuremap_size (int | sequence[int]): Size of input feature map
+ simcc_split_ratio (float): Split ratio of pixels
+ deconv_type (str, optional): The type of deconv head which should
+ be one of the following options:
+
+ - ``'heatmap'``: make deconv layers in `HeatmapHead`
+ - ``'vipnas'``: make deconv layers in `ViPNASHead`
+
+            Defaults to ``'heatmap'``
+ deconv_out_channels (sequence[int]): The output channel number of each
+ deconv layer. Defaults to ``(256, 256, 256)``
+ deconv_kernel_sizes (sequence[int | tuple], optional): The kernel size
+ of each deconv layer. Each element should be either an integer for
+ both height and width dimensions, or a tuple of two integers for
+            the height and the width dimension respectively. Defaults to
+ ``(4, 4, 4)``
+ deconv_num_groups (Sequence[int], optional): The group number of each
+ deconv layer. Defaults to ``(16, 16, 16)``
+ conv_out_channels (sequence[int], optional): The output channel number
+ of each intermediate conv layer. ``None`` means no intermediate
+ conv layer between deconv layers and the final conv layer.
+ Defaults to ``None``
+ conv_kernel_sizes (sequence[int | tuple], optional): The kernel size
+ of each intermediate conv layer. Defaults to ``None``
+ final_layer (dict): Arguments of the final Conv2d layer.
+ Defaults to ``dict(kernel_size=1)``
+ loss (Config): Config of the keypoint loss. Defaults to use
+ :class:`KLDiscretLoss`
+ decoder (Config, optional): The decoder config that controls decoding
+ keypoint coordinates from the network output. Defaults to ``None``
+ init_cfg (Config, optional): Config to control the initialization. See
+ :attr:`default_init_cfg` for default settings
+
+ .. _`SimCC`: https://arxiv.org/abs/2107.03332
+ """
+
+ _version = 2
+
+ def __init__(
+ self,
+ in_channels: Union[int, Sequence[int]],
+ out_channels: int,
+ input_size: Tuple[int, int],
+ in_featuremap_size: Tuple[int, int],
+ simcc_split_ratio: float = 2.0,
+ deconv_type: str = 'heatmap',
+ deconv_out_channels: OptIntSeq = (256, 256, 256),
+ deconv_kernel_sizes: OptIntSeq = (4, 4, 4),
+ deconv_num_groups: OptIntSeq = (16, 16, 16),
+ conv_out_channels: OptIntSeq = None,
+ conv_kernel_sizes: OptIntSeq = None,
+ final_layer: dict = dict(kernel_size=1),
+ loss: ConfigType = dict(type='KLDiscretLoss', use_target_weight=True),
+ decoder: OptConfigType = None,
+ init_cfg: OptConfigType = None,
+ ):
+
+ if init_cfg is None:
+ init_cfg = self.default_init_cfg
+
+ super().__init__(init_cfg)
+
+ if deconv_type not in {'heatmap', 'vipnas'}:
+ raise ValueError(
+                f'{self.__class__.__name__} got invalid `deconv_type` value '
+ f'{deconv_type}. Should be one of '
+ '{"heatmap", "vipnas"}')
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.input_size = input_size
+ self.in_featuremap_size = in_featuremap_size
+ self.simcc_split_ratio = simcc_split_ratio
+ self.loss_module = MODELS.build(loss)
+ if decoder is not None:
+ self.decoder = KEYPOINT_CODECS.build(decoder)
+ else:
+ self.decoder = None
+
+ num_deconv = len(deconv_out_channels) if deconv_out_channels else 0
+ if num_deconv != 0:
+ self.heatmap_size = tuple(
+ [s * (2**num_deconv) for s in in_featuremap_size])
+
+ # deconv layers + 1x1 conv
+ self.deconv_head = self._make_deconv_head(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ deconv_type=deconv_type,
+ deconv_out_channels=deconv_out_channels,
+ deconv_kernel_sizes=deconv_kernel_sizes,
+ deconv_num_groups=deconv_num_groups,
+ conv_out_channels=conv_out_channels,
+ conv_kernel_sizes=conv_kernel_sizes,
+ final_layer=final_layer)
+
+ if final_layer is not None:
+ in_channels = out_channels
+ else:
+ in_channels = deconv_out_channels[-1]
+
+ else:
+ self.deconv_head = None
+
+ if final_layer is not None:
+ cfg = dict(
+ type='Conv2d',
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=1)
+ cfg.update(final_layer)
+ self.final_layer = build_conv_layer(cfg)
+ else:
+ self.final_layer = None
+
+ self.heatmap_size = in_featuremap_size
+
+ # Define SimCC layers
+ flatten_dims = self.heatmap_size[0] * self.heatmap_size[1]
+
+ W = int(self.input_size[0] * self.simcc_split_ratio)
+ H = int(self.input_size[1] * self.simcc_split_ratio)
+
+ self.mlp_head_x = nn.Linear(flatten_dims, W)
+ self.mlp_head_y = nn.Linear(flatten_dims, H)
+
+ def _make_deconv_head(
+ self,
+ in_channels: Union[int, Sequence[int]],
+ out_channels: int,
+ deconv_type: str = 'heatmap',
+ deconv_out_channels: OptIntSeq = (256, 256, 256),
+ deconv_kernel_sizes: OptIntSeq = (4, 4, 4),
+ deconv_num_groups: OptIntSeq = (16, 16, 16),
+ conv_out_channels: OptIntSeq = None,
+ conv_kernel_sizes: OptIntSeq = None,
+ final_layer: dict = dict(kernel_size=1)
+ ) -> nn.Module:
+ """Create deconvolutional layers by given parameters."""
+
+ if deconv_type == 'heatmap':
+ deconv_head = MODELS.build(
+ dict(
+ type='HeatmapHead',
+ in_channels=self.in_channels,
+ out_channels=out_channels,
+ deconv_out_channels=deconv_out_channels,
+ deconv_kernel_sizes=deconv_kernel_sizes,
+ conv_out_channels=conv_out_channels,
+ conv_kernel_sizes=conv_kernel_sizes,
+ final_layer=final_layer))
+ else:
+ deconv_head = MODELS.build(
+ dict(
+ type='ViPNASHead',
+ in_channels=in_channels,
+ out_channels=out_channels,
+ deconv_out_channels=deconv_out_channels,
+ deconv_num_groups=deconv_num_groups,
+ conv_out_channels=conv_out_channels,
+ conv_kernel_sizes=conv_kernel_sizes,
+ final_layer=final_layer))
+
+ return deconv_head
+
+ def forward(self, feats: Tuple[Tensor]) -> Tuple[Tensor, Tensor]:
+ """Forward the network. The input is multi scale feature maps and the
+ output is the heatmap.
+
+ Args:
+ feats (Tuple[Tensor]): Multi scale feature maps.
+
+ Returns:
+ pred_x (Tensor): 1d representation of x.
+ pred_y (Tensor): 1d representation of y.
+ """
+ if self.deconv_head is None:
+ feats = feats[-1]
+ if self.final_layer is not None:
+ feats = self.final_layer(feats)
+ else:
+ feats = self.deconv_head(feats)
+
+ # flatten the output heatmap
+ x = torch.flatten(feats, 2)
+
+ pred_x = self.mlp_head_x(x)
+ pred_y = self.mlp_head_y(x)
+
+ return pred_x, pred_y
+
+ def predict(
+ self,
+ feats: Tuple[Tensor],
+ batch_data_samples: OptSampleList,
+ test_cfg: OptConfigType = {},
+ ) -> InstanceList:
+ """Predict results from features.
+
+ Args:
+ feats (Tuple[Tensor] | List[Tuple[Tensor]]): The multi-stage
+ features (or multiple multi-stage features in TTA)
+ batch_data_samples (List[:obj:`PoseDataSample`]): The batch
+ data samples
+ test_cfg (dict): The runtime config for testing process. Defaults
+ to {}
+
+ Returns:
+ List[InstanceData]: The pose predictions, each contains
+ the following fields:
+
+ - keypoints (np.ndarray): predicted keypoint coordinates in
+ shape (num_instances, K, D) where K is the keypoint number
+ and D is the keypoint dimension
+ - keypoint_scores (np.ndarray): predicted keypoint scores in
+ shape (num_instances, K)
+ - keypoint_x_labels (np.ndarray, optional): The predicted 1-D
+ intensity distribution in the x direction
+ - keypoint_y_labels (np.ndarray, optional): The predicted 1-D
+ intensity distribution in the y direction
+ """
+
+ if test_cfg.get('flip_test', False):
+ # TTA: flip test -> feats = [orig, flipped]
+ assert isinstance(feats, list) and len(feats) == 2
+ flip_indices = batch_data_samples[0].metainfo['flip_indices']
+ _feats, _feats_flip = feats
+
+ _batch_pred_x, _batch_pred_y = self.forward(_feats)
+
+ _batch_pred_x_flip, _batch_pred_y_flip = self.forward(_feats_flip)
+ _batch_pred_x_flip, _batch_pred_y_flip = flip_vectors(
+ _batch_pred_x_flip,
+ _batch_pred_y_flip,
+ flip_indices=flip_indices)
+
+ batch_pred_x = (_batch_pred_x + _batch_pred_x_flip) * 0.5
+ batch_pred_y = (_batch_pred_y + _batch_pred_y_flip) * 0.5
+ else:
+ batch_pred_x, batch_pred_y = self.forward(feats)
+
+ preds = self.decode((batch_pred_x, batch_pred_y))
+
+ if test_cfg.get('output_heatmaps', False):
+ rank, _ = get_dist_info()
+ if rank == 0:
+ warnings.warn('The predicted simcc values are normalized for '
+ 'visualization. This may cause discrepancy '
+ 'between the keypoint scores and the 1D heatmaps'
+ '.')
+
+ # normalize the predicted 1d distribution
+ sigma = self.decoder.sigma
+ batch_pred_x = get_simcc_normalized(batch_pred_x, sigma[0])
+ batch_pred_y = get_simcc_normalized(batch_pred_y, sigma[1])
+
+ B, K, _ = batch_pred_x.shape
+            # B, K, Wx -> B, K, 1, Wx
+            x = batch_pred_x.reshape(B, K, 1, -1)
+            # B, K, Wy -> B, K, Wy, 1
+            y = batch_pred_y.reshape(B, K, -1, 1)
+            # (B, K, Wy, 1) @ (B, K, 1, Wx) -> B, K, Wy, Wx
+            batch_heatmaps = torch.matmul(y, x)
+ pred_fields = [
+ PixelData(heatmaps=hm) for hm in batch_heatmaps.detach()
+ ]
+
+ for pred_instances, pred_x, pred_y in zip(preds,
+ to_numpy(batch_pred_x),
+ to_numpy(batch_pred_y)):
+
+ pred_instances.keypoint_x_labels = pred_x[None]
+ pred_instances.keypoint_y_labels = pred_y[None]
+
+ return preds, pred_fields
+ else:
+ return preds
+
+ def loss(
+ self,
+ feats: Tuple[Tensor],
+ batch_data_samples: OptSampleList,
+ train_cfg: OptConfigType = {},
+ ) -> dict:
+ """Calculate losses from a batch of inputs and data samples."""
+
+ pred_x, pred_y = self.forward(feats)
+
+ gt_x = torch.cat([
+ d.gt_instance_labels.keypoint_x_labels for d in batch_data_samples
+ ],
+ dim=0)
+ gt_y = torch.cat([
+ d.gt_instance_labels.keypoint_y_labels for d in batch_data_samples
+ ],
+ dim=0)
+ keypoint_weights = torch.cat(
+ [
+ d.gt_instance_labels.keypoint_weights
+ for d in batch_data_samples
+ ],
+ dim=0,
+ )
+
+ pred_simcc = (pred_x, pred_y)
+ gt_simcc = (gt_x, gt_y)
+
+ # calculate losses
+ losses = dict()
+ loss = self.loss_module(pred_simcc, gt_simcc, keypoint_weights)
+
+ losses.update(loss_kpt=loss)
+
+ # calculate accuracy
+ _, avg_acc, _ = simcc_pck_accuracy(
+ output=to_numpy(pred_simcc),
+ target=to_numpy(gt_simcc),
+ simcc_split_ratio=self.simcc_split_ratio,
+ mask=to_numpy(keypoint_weights) > 0,
+ )
+
+ acc_pose = torch.tensor(avg_acc, device=gt_x.device)
+ losses.update(acc_pose=acc_pose)
+
+ return losses
+
+ @property
+ def default_init_cfg(self):
+ init_cfg = [
+ dict(
+ type='Normal', layer=['Conv2d', 'ConvTranspose2d'], std=0.001),
+ dict(type='Constant', layer='BatchNorm2d', val=1),
+ dict(type='Normal', layer=['Linear'], std=0.01, bias=0),
+ ]
+ return init_cfg
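A corresponding hedged sketch for `SimCCHead`; again the codec name and all numeric values are illustrative assumptions (here a ResNet-style backbone with a 2048-channel, 6x8 output feature map and no deconv layers):

```python
# Hedged sketch, not part of the diff: illustrative SimCCHead config.
# With deconv_out_channels=None the head works directly on the
# backbone feature map and only applies the 1x1 final layer.
codec = dict(
    type='SimCCLabel',
    input_size=(192, 256),
    sigma=6.0,
    simcc_split_ratio=2.0)

head = dict(
    type='SimCCHead',
    in_channels=2048,
    out_channels=17,
    input_size=codec['input_size'],
    in_featuremap_size=(6, 8),
    simcc_split_ratio=codec['simcc_split_ratio'],
    deconv_out_channels=None,
    loss=dict(type='KLDiscretLoss', use_target_weight=True),
    decoder=codec)
```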
diff --git a/mmpose/models/heads/heatmap_heads/__init__.py b/mmpose/models/heads/heatmap_heads/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b482216b36f61ceb66aae8974ae178a8455d5022
--- /dev/null
+++ b/mmpose/models/heads/heatmap_heads/__init__.py
@@ -0,0 +1,12 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .ae_head import AssociativeEmbeddingHead
+from .cid_head import CIDHead
+from .cpm_head import CPMHead
+from .heatmap_head import HeatmapHead
+from .mspn_head import MSPNHead
+from .vipnas_head import ViPNASHead
+
+__all__ = [
+ 'HeatmapHead', 'CPMHead', 'MSPNHead', 'ViPNASHead',
+ 'AssociativeEmbeddingHead', 'CIDHead'
+]
diff --git a/mmpose/models/heads/heatmap_heads/__pycache__/__init__.cpython-38.pyc b/mmpose/models/heads/heatmap_heads/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..06f989363ba4a109f08fe2c797f0aa2faa817ca5
Binary files /dev/null and b/mmpose/models/heads/heatmap_heads/__pycache__/__init__.cpython-38.pyc differ
diff --git a/mmpose/models/heads/heatmap_heads/__pycache__/ae_head.cpython-38.pyc b/mmpose/models/heads/heatmap_heads/__pycache__/ae_head.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..113225af5bbeff20f60677c1c806dc498c3adf1c
Binary files /dev/null and b/mmpose/models/heads/heatmap_heads/__pycache__/ae_head.cpython-38.pyc differ
diff --git a/mmpose/models/heads/heatmap_heads/__pycache__/cid_head.cpython-38.pyc b/mmpose/models/heads/heatmap_heads/__pycache__/cid_head.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6674128484615944762f461f506bdc2911336969
Binary files /dev/null and b/mmpose/models/heads/heatmap_heads/__pycache__/cid_head.cpython-38.pyc differ
diff --git a/mmpose/models/heads/heatmap_heads/__pycache__/cpm_head.cpython-38.pyc b/mmpose/models/heads/heatmap_heads/__pycache__/cpm_head.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..44883ef860dda1bd001a10eb60bd4588fd5dc56c
Binary files /dev/null and b/mmpose/models/heads/heatmap_heads/__pycache__/cpm_head.cpython-38.pyc differ
diff --git a/mmpose/models/heads/heatmap_heads/__pycache__/heatmap_head.cpython-38.pyc b/mmpose/models/heads/heatmap_heads/__pycache__/heatmap_head.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4a0b066f379c0fe1630dae375181429f7af532c8
Binary files /dev/null and b/mmpose/models/heads/heatmap_heads/__pycache__/heatmap_head.cpython-38.pyc differ
diff --git a/mmpose/models/heads/heatmap_heads/__pycache__/mspn_head.cpython-38.pyc b/mmpose/models/heads/heatmap_heads/__pycache__/mspn_head.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..232487516041aac335ad248d1004010ebeac6b1e
Binary files /dev/null and b/mmpose/models/heads/heatmap_heads/__pycache__/mspn_head.cpython-38.pyc differ
diff --git a/mmpose/models/heads/heatmap_heads/__pycache__/vipnas_head.cpython-38.pyc b/mmpose/models/heads/heatmap_heads/__pycache__/vipnas_head.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..58ee376a83c13788611754e5e93e8d955f39b442
Binary files /dev/null and b/mmpose/models/heads/heatmap_heads/__pycache__/vipnas_head.cpython-38.pyc differ
diff --git a/mmpose/models/heads/heatmap_heads/ae_head.py b/mmpose/models/heads/heatmap_heads/ae_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd12d57a333bbab91dfdcb11afddd0cbd351b945
--- /dev/null
+++ b/mmpose/models/heads/heatmap_heads/ae_head.py
@@ -0,0 +1,291 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional, Sequence, Tuple, Union
+
+import torch
+from mmengine.structures import PixelData
+from mmengine.utils import is_list_of
+from torch import Tensor
+
+from mmpose.models.utils.tta import aggregate_heatmaps, flip_heatmaps
+from mmpose.registry import MODELS
+from mmpose.utils.typing import (ConfigType, Features, OptConfigType,
+ OptSampleList, Predictions)
+from .heatmap_head import HeatmapHead
+
+OptIntSeq = Optional[Sequence[int]]
+
+
+@MODELS.register_module()
+class AssociativeEmbeddingHead(HeatmapHead):
+
+ def __init__(self,
+ in_channels: Union[int, Sequence[int]],
+ num_keypoints: int,
+ tag_dim: int = 1,
+ tag_per_keypoint: bool = True,
+ deconv_out_channels: OptIntSeq = (256, 256, 256),
+ deconv_kernel_sizes: OptIntSeq = (4, 4, 4),
+ conv_out_channels: OptIntSeq = None,
+ conv_kernel_sizes: OptIntSeq = None,
+ final_layer: dict = dict(kernel_size=1),
+ keypoint_loss: ConfigType = dict(type='KeypointMSELoss'),
+ tag_loss: ConfigType = dict(type='AssociativeEmbeddingLoss'),
+ decoder: OptConfigType = None,
+ init_cfg: OptConfigType = None):
+
+ if tag_per_keypoint:
+ out_channels = num_keypoints * (1 + tag_dim)
+ else:
+ out_channels = num_keypoints + tag_dim
+
+ loss = dict(
+ type='CombinedLoss',
+ losses=dict(keypoint_loss=keypoint_loss, tag_loss=tag_loss))
+
+ super().__init__(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ deconv_out_channels=deconv_out_channels,
+ deconv_kernel_sizes=deconv_kernel_sizes,
+ conv_out_channels=conv_out_channels,
+ conv_kernel_sizes=conv_kernel_sizes,
+ final_layer=final_layer,
+ loss=loss,
+ decoder=decoder,
+ init_cfg=init_cfg)
+
+ self.num_keypoints = num_keypoints
+ self.tag_dim = tag_dim
+ self.tag_per_keypoint = tag_per_keypoint
+
+ def predict(self,
+ feats: Features,
+ batch_data_samples: OptSampleList,
+ test_cfg: ConfigType = {}) -> Predictions:
+ """Predict results from features.
+
+ Args:
+ feats (Features): The features which could be in following forms:
+
+ - Tuple[Tensor]: multi-stage features from the backbone
+ - List[Tuple[Tensor]]: multiple features for TTA where either
+ `flip_test` or `multiscale_test` is applied
+ - List[List[Tuple[Tensor]]]: multiple features for TTA where
+ both `flip_test` and `multiscale_test` are applied
+
+ batch_data_samples (List[:obj:`PoseDataSample`]): The batch
+ data samples
+ test_cfg (dict): The runtime config for testing process. Defaults
+ to {}
+
+ Returns:
+ Union[InstanceList | Tuple[InstanceList | PixelDataList]]: If
+            ``test_cfg['output_heatmaps']==True``, return both pose and heatmap
+ prediction; otherwise only return the pose prediction.
+
+ The pose prediction is a list of ``InstanceData``, each contains
+ the following fields:
+
+ - keypoints (np.ndarray): predicted keypoint coordinates in
+ shape (num_instances, K, D) where K is the keypoint number
+ and D is the keypoint dimension
+ - keypoint_scores (np.ndarray): predicted keypoint scores in
+ shape (num_instances, K)
+
+ The heatmap prediction is a list of ``PixelData``, each contains
+ the following fields:
+
+ - heatmaps (Tensor): The predicted heatmaps in shape (K, h, w)
+ """
+ # test configs
+ multiscale_test = test_cfg.get('multiscale_test', False)
+ flip_test = test_cfg.get('flip_test', False)
+ shift_heatmap = test_cfg.get('shift_heatmap', False)
+ align_corners = test_cfg.get('align_corners', False)
+ restore_heatmap_size = test_cfg.get('restore_heatmap_size', False)
+ output_heatmaps = test_cfg.get('output_heatmaps', False)
+
+ # enable multi-scale test
+ if multiscale_test:
+ # TTA: multi-scale test
+ assert is_list_of(feats, list if flip_test else tuple)
+ else:
+ assert is_list_of(feats, tuple if flip_test else Tensor)
+ feats = [feats]
+
+        # resize heatmaps to align with the input size
+ if restore_heatmap_size:
+ img_shape = batch_data_samples[0].metainfo['img_shape']
+ assert all(d.metainfo['img_shape'] == img_shape
+ for d in batch_data_samples)
+ img_h, img_w = img_shape
+ heatmap_size = (img_w, img_h)
+ else:
+ heatmap_size = None
+
+ multiscale_heatmaps = []
+ multiscale_tags = []
+
+ for scale_idx, _feats in enumerate(feats):
+ if not flip_test:
+ _heatmaps, _tags = self.forward(_feats)
+
+ else:
+ # TTA: flip test
+ assert isinstance(_feats, list) and len(_feats) == 2
+ flip_indices = batch_data_samples[0].metainfo['flip_indices']
+ # original
+ _feats_orig, _feats_flip = _feats
+ _heatmaps_orig, _tags_orig = self.forward(_feats_orig)
+
+ # flipped
+ _heatmaps_flip, _tags_flip = self.forward(_feats_flip)
+ _heatmaps_flip = flip_heatmaps(
+ _heatmaps_flip,
+ flip_mode='heatmap',
+ flip_indices=flip_indices,
+ shift_heatmap=shift_heatmap)
+ _tags_flip = self._flip_tags(
+ _tags_flip,
+ flip_indices=flip_indices,
+ shift_heatmap=shift_heatmap)
+
+ # aggregated heatmaps
+ _heatmaps = aggregate_heatmaps(
+ [_heatmaps_orig, _heatmaps_flip],
+ size=heatmap_size,
+ align_corners=align_corners,
+ mode='average')
+
+ # aggregated tags (only at original scale)
+ if scale_idx == 0:
+ _tags = aggregate_heatmaps([_tags_orig, _tags_flip],
+ size=heatmap_size,
+ align_corners=align_corners,
+ mode='concat')
+ else:
+ _tags = None
+
+ multiscale_heatmaps.append(_heatmaps)
+ multiscale_tags.append(_tags)
+
+ # aggregate multi-scale heatmaps
+ if len(feats) > 1:
+ batch_heatmaps = aggregate_heatmaps(
+ multiscale_heatmaps,
+ align_corners=align_corners,
+ mode='average')
+ else:
+ batch_heatmaps = multiscale_heatmaps[0]
+ # only keep tags at original scale
+ batch_tags = multiscale_tags[0]
+
+ batch_outputs = tuple([batch_heatmaps, batch_tags])
+ preds = self.decode(batch_outputs)
+
+ if output_heatmaps:
+ pred_fields = []
+ for _heatmaps, _tags in zip(batch_heatmaps.detach(),
+ batch_tags.detach()):
+ pred_fields.append(PixelData(heatmaps=_heatmaps, tags=_tags))
+
+ return preds, pred_fields
+ else:
+ return preds
+
+ def _flip_tags(self,
+ tags: Tensor,
+ flip_indices: List[int],
+ shift_heatmap: bool = True):
+ """Flip the tagging heatmaps horizontally for test-time augmentation.
+
+ Args:
+ tags (Tensor): batched tagging heatmaps to flip
+ flip_indices (List[int]): The indices of each keypoint's symmetric
+ keypoint
+ shift_heatmap (bool): Shift the flipped heatmaps to align with the
+ original heatmaps and improve accuracy. Defaults to ``True``
+
+ Returns:
+ Tensor: flipped tagging heatmaps
+ """
+ B, C, H, W = tags.shape
+ K = self.num_keypoints
+ L = self.tag_dim
+
+ tags = tags.flip(-1)
+
+ if self.tag_per_keypoint:
+ assert C == K * L
+ tags = tags.view(B, L, K, H, W)
+ tags = tags[:, :, flip_indices]
+ tags = tags.view(B, C, H, W)
+
+ if shift_heatmap:
+ tags[..., 1:] = tags[..., :-1].clone()
+
+ return tags
+
+ def forward(self, feats: Tuple[Tensor]) -> Tuple[Tensor, Tensor]:
+ """Forward the network. The input is multi scale feature maps and the
+ output is the heatmaps and tags.
+
+ Args:
+ feats (Tuple[Tensor]): Multi scale feature maps.
+
+ Returns:
+ tuple:
+ - heatmaps (Tensor): output heatmaps
+ - tags (Tensor): output tags
+ """
+
+ output = super().forward(feats)
+ heatmaps = output[:, :self.num_keypoints]
+ tags = output[:, self.num_keypoints:]
+ return heatmaps, tags
+
+ def loss(self,
+ feats: Tuple[Tensor],
+ batch_data_samples: OptSampleList,
+ train_cfg: ConfigType = {}) -> dict:
+ """Calculate losses from a batch of inputs and data samples.
+
+ Args:
+ feats (Tuple[Tensor]): The multi-stage features
+ batch_data_samples (List[:obj:`PoseDataSample`]): The batch
+ data samples
+ train_cfg (dict): The runtime config for training process.
+ Defaults to {}
+
+ Returns:
+ dict: A dictionary of losses.
+ """
+ pred_heatmaps, pred_tags = self.forward(feats)
+
+ if not self.tag_per_keypoint:
+ pred_tags = pred_tags.repeat((1, self.num_keypoints, 1, 1))
+
+ gt_heatmaps = torch.stack(
+ [d.gt_fields.heatmaps for d in batch_data_samples])
+ gt_masks = torch.stack(
+ [d.gt_fields.heatmap_mask for d in batch_data_samples])
+ keypoint_weights = torch.cat([
+ d.gt_instance_labels.keypoint_weights for d in batch_data_samples
+ ])
+ keypoint_indices = [
+ d.gt_instance_labels.keypoint_indices for d in batch_data_samples
+ ]
+
+ loss_kpt = self.loss_module.keypoint_loss(pred_heatmaps, gt_heatmaps,
+ keypoint_weights, gt_masks)
+
+ loss_pull, loss_push = self.loss_module.tag_loss(
+ pred_tags, keypoint_indices)
+
+ losses = {
+ 'loss_kpt': loss_kpt,
+ 'loss_pull': loss_pull,
+ 'loss_push': loss_push
+ }
+
+ return losses
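The channel bookkeeping of this head is easy to check directly: with `tag_per_keypoint=True` the final conv emits `K * (1 + tag_dim)` channels, which `forward()` splits into `K` heatmap channels and `K * tag_dim` tag channels. A hedged sketch (sizes arbitrary, assuming a working mmpose install):

```python
# Hedged sketch, not part of the diff: inspect the heatmap/tag split.
import torch
from mmpose.models.heads import AssociativeEmbeddingHead

head = AssociativeEmbeddingHead(
    in_channels=32,
    num_keypoints=17,
    tag_dim=1,
    deconv_out_channels=None,
    deconv_kernel_sizes=None)

heatmaps, tags = head.forward((torch.randn(2, 32, 128, 128), ))
print(heatmaps.shape, tags.shape)  # both (2, 17, 128, 128) here
```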
diff --git a/mmpose/models/heads/heatmap_heads/cid_head.py b/mmpose/models/heads/heatmap_heads/cid_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..39e0211a3e135c1c101c14e37956528d3330ca1b
--- /dev/null
+++ b/mmpose/models/heads/heatmap_heads/cid_head.py
@@ -0,0 +1,743 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+from typing import Dict, Optional, Sequence, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+from mmcv.cnn import build_conv_layer
+from mmengine.model import BaseModule, ModuleDict
+from mmengine.structures import InstanceData, PixelData
+from torch import Tensor
+
+from mmpose.models.utils.tta import flip_heatmaps
+from mmpose.registry import KEYPOINT_CODECS, MODELS
+from mmpose.utils.typing import (ConfigType, Features, OptConfigType,
+ OptSampleList, Predictions)
+from ..base_head import BaseHead
+
+
+def smooth_heatmaps(heatmaps: Tensor, blur_kernel_size: int) -> Tensor:
+ """Smooth the heatmaps by blurring and averaging.
+
+ Args:
+ heatmaps (Tensor): The heatmaps to smooth.
+ blur_kernel_size (int): The kernel size for blurring the heatmaps.
+
+ Returns:
+ Tensor: The smoothed heatmaps.
+ """
+ smoothed_heatmaps = torch.nn.functional.avg_pool2d(
+ heatmaps, blur_kernel_size, 1, (blur_kernel_size - 1) // 2)
+ smoothed_heatmaps = (heatmaps + smoothed_heatmaps) / 2.0
+ return smoothed_heatmaps
+
+
+class TruncSigmoid(nn.Sigmoid):
+ """A sigmoid activation function that truncates the output to the given
+ range.
+
+ Args:
+ min (float, optional): The minimum value to clamp the output to.
+ Defaults to 0.0
+ max (float, optional): The maximum value to clamp the output to.
+ Defaults to 1.0
+ """
+
+ def __init__(self, min: float = 0.0, max: float = 1.0):
+ super(TruncSigmoid, self).__init__()
+ self.min = min
+ self.max = max
+
+ def forward(self, input: Tensor) -> Tensor:
+ """Computes the truncated sigmoid activation of the input tensor."""
+ output = torch.sigmoid(input)
+ output = output.clamp(min=self.min, max=self.max)
+ return output
+
+
+class IIAModule(BaseModule):
+ """Instance Information Abstraction module introduced in `CID`. This module
+ extracts the feature representation vectors for each instance.
+
+ Args:
+ in_channels (int): Number of channels in the input feature tensor
+ out_channels (int): Number of channels of the output heatmaps
+ clamp_delta (float, optional): A small value that prevents the sigmoid
+ activation from becoming saturated. Defaults to 1e-4.
+ init_cfg (Config, optional): Config to control the initialization. See
+ :attr:`default_init_cfg` for default settings
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ clamp_delta: float = 1e-4,
+ init_cfg: OptConfigType = None,
+ ):
+ super().__init__(init_cfg=init_cfg)
+
+ self.keypoint_root_conv = build_conv_layer(
+ dict(
+ type='Conv2d',
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=1))
+ self.sigmoid = TruncSigmoid(min=clamp_delta, max=1 - clamp_delta)
+
+ def forward(self, feats: Tensor):
+ heatmaps = self.keypoint_root_conv(feats)
+ heatmaps = self.sigmoid(heatmaps)
+ return heatmaps
+
+ def _sample_feats(self, feats: Tensor, indices: Tensor) -> Tensor:
+ """Extract feature vectors at the specified indices from the input
+ feature map.
+
+ Args:
+ feats (Tensor): Input feature map.
+ indices (Tensor): Indices of the feature vectors to extract.
+
+ Returns:
+ Tensor: Extracted feature vectors.
+ """
+ assert indices.dtype == torch.long
+ if indices.shape[1] == 3:
+ b, w, h = [ind.squeeze(-1) for ind in indices.split(1, -1)]
+ instance_feats = feats[b, :, h, w]
+ elif indices.shape[1] == 2:
+ w, h = [ind.squeeze(-1) for ind in indices.split(1, -1)]
+ instance_feats = feats[:, :, h, w]
+ instance_feats = instance_feats.permute(0, 2, 1)
+ instance_feats = instance_feats.reshape(-1,
+ instance_feats.shape[-1])
+
+ else:
+ raise ValueError(f'`indices` should have 2 or 3 channels, '
+ f'but got f{indices.shape[1]}')
+ return instance_feats
+
+ def _hierarchical_pool(self, heatmaps: Tensor) -> Tensor:
+        """Conduct max pooling on the input heatmaps with a kernel size
+        chosen according to the input size.
+
+ Args:
+ heatmaps (Tensor): Input heatmaps.
+
+ Returns:
+ Tensor: Result of hierarchical pooling.
+ """
+ map_size = (heatmaps.shape[-1] + heatmaps.shape[-2]) / 2.0
+ if map_size > 300:
+ maxm = torch.nn.functional.max_pool2d(heatmaps, 7, 1, 3)
+ elif map_size > 200:
+ maxm = torch.nn.functional.max_pool2d(heatmaps, 5, 1, 2)
+ else:
+ maxm = torch.nn.functional.max_pool2d(heatmaps, 3, 1, 1)
+ return maxm
+
+ def forward_train(self, feats: Tensor, instance_coords: Tensor,
+ instance_imgids: Tensor) -> Tuple[Tensor, Tensor]:
+ """Forward pass during training.
+
+ Args:
+ feats (Tensor): Input feature tensor.
+ instance_coords (Tensor): Coordinates of the instance roots.
+            instance_imgids (Tensor): Sample indices of each instance
+ in the batch.
+
+ Returns:
+ Tuple[Tensor, Tensor]: Extracted feature vectors and heatmaps
+ for the instances.
+ """
+ heatmaps = self.forward(feats)
+ indices = torch.cat((instance_imgids[:, None], instance_coords), dim=1)
+ instance_feats = self._sample_feats(feats, indices)
+
+ return instance_feats, heatmaps
+
+ def forward_test(
+ self, feats: Tensor, test_cfg: Dict
+ ) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
+ """Forward pass during testing.
+
+ Args:
+ feats (Tensor): Input feature tensor.
+ test_cfg (Dict): Testing configuration, including:
+ - blur_kernel_size (int, optional): Kernel size for blurring
+ the heatmaps. Defaults to 3.
+ - max_instances (int, optional): Maximum number of instances
+ to extract. Defaults to 30.
+ - score_threshold (float, optional): Minimum score for
+ extracting an instance. Defaults to 0.01.
+ - flip_test (bool, optional): Whether to compute the average
+ of the heatmaps across the batch dimension.
+ Defaults to False.
+
+ Returns:
+ A tuple of Tensor including extracted feature vectors,
+ coordinates, and scores of the instances. Any of these can be
+ empty Tensor if no instances are extracted.
+ """
+ blur_kernel_size = test_cfg.get('blur_kernel_size', 3)
+ max_instances = test_cfg.get('max_instances', 30)
+ score_threshold = test_cfg.get('score_threshold', 0.01)
+ H, W = feats.shape[-2:]
+
+ # compute heatmaps
+ heatmaps = self.forward(feats).narrow(1, -1, 1)
+ if test_cfg.get('flip_test', False):
+ heatmaps = heatmaps.mean(dim=0, keepdims=True)
+ smoothed_heatmaps = smooth_heatmaps(heatmaps, blur_kernel_size)
+
+ # decode heatmaps
+ maximums = self._hierarchical_pool(smoothed_heatmaps)
+ maximums = torch.eq(maximums, smoothed_heatmaps).float()
+ maximums = (smoothed_heatmaps * maximums).reshape(-1)
+ scores, pos_ind = maximums.topk(max_instances, dim=0)
+ select_ind = (scores > (score_threshold)).nonzero().squeeze(1)
+ scores, pos_ind = scores[select_ind], pos_ind[select_ind]
+
+ # sample feature vectors from feature map
+ instance_coords = torch.stack((pos_ind % W, pos_ind // W), dim=1)
+ instance_feats = self._sample_feats(feats, instance_coords)
+
+ return instance_feats, instance_coords, scores
+
+
+class ChannelAttention(nn.Module):
+ """Channel-wise attention module introduced in `CID`.
+
+ Args:
+ in_channels (int): The number of channels of the input instance
+ vectors.
+ out_channels (int): The number of channels of the transformed instance
+ vectors.
+ """
+
+ def __init__(self, in_channels: int, out_channels: int):
+ super(ChannelAttention, self).__init__()
+ self.atn = nn.Linear(in_channels, out_channels)
+
+ def forward(self, global_feats: Tensor, instance_feats: Tensor) -> Tensor:
+ """Applies attention to the channel dimension of the input tensor."""
+
+ instance_feats = self.atn(instance_feats).unsqueeze(2).unsqueeze(3)
+ return global_feats * instance_feats
+
+
+class SpatialAttention(nn.Module):
+ """Spatial-wise attention module introduced in `CID`.
+
+ Args:
+ in_channels (int): The number of channels of the input instance
+ vectors.
+ out_channels (int): The number of channels of the transformed instance
+ vectors.
+ """
+
+ def __init__(self, in_channels, out_channels):
+ super(SpatialAttention, self).__init__()
+ self.atn = nn.Linear(in_channels, out_channels)
+ self.feat_stride = 4
+ self.conv = nn.Conv2d(3, 1, 5, 1, 2)
+
+ def _get_pixel_coords(self, heatmap_size: Tuple, device: str = 'cpu'):
+ """Get pixel coordinates for each element in the heatmap.
+
+ Args:
+ heatmap_size (tuple): Size of the heatmap in (W, H) format.
+ device (str): Device to put the resulting tensor on.
+
+ Returns:
+ Tensor of shape (batch_size, num_pixels, 2) containing the pixel
+ coordinates for each element in the heatmap.
+ """
+ w, h = heatmap_size
+ y, x = torch.meshgrid(torch.arange(h), torch.arange(w))
+ pixel_coords = torch.stack((x, y), dim=-1).reshape(-1, 2)
+ pixel_coords = pixel_coords.float().to(device) + 0.5
+ return pixel_coords
+
+ def forward(self, global_feats: Tensor, instance_feats: Tensor,
+ instance_coords: Tensor) -> Tensor:
+ """Perform spatial attention.
+
+ Args:
+ global_feats (Tensor): Tensor containing the global features.
+ instance_feats (Tensor): Tensor containing the instance feature
+ vectors.
+ instance_coords (Tensor): Tensor containing the root coordinates
+ of the instances.
+
+ Returns:
+ Tensor containing the modulated global features.
+ """
+ B, C, H, W = global_feats.size()
+
+ instance_feats = self.atn(instance_feats).reshape(B, C, 1, 1)
+ feats = global_feats * instance_feats.expand_as(global_feats)
+ fsum = torch.sum(feats, dim=1, keepdim=True)
+
+ pixel_coords = self._get_pixel_coords((W, H), feats.device)
+ relative_coords = instance_coords.reshape(
+ -1, 1, 2) - pixel_coords.reshape(1, -1, 2)
+ relative_coords = relative_coords.permute(0, 2, 1) / 32.0
+ relative_coords = relative_coords.reshape(B, 2, H, W)
+
+ input_feats = torch.cat((fsum, relative_coords), dim=1)
+ mask = self.conv(input_feats).sigmoid()
+ return global_feats * mask
+
+
+class GFDModule(BaseModule):
+ """Global Feature Decoupling module introduced in `CID`. This module
+ extracts the decoupled heatmaps for each instance.
+
+ Args:
+ in_channels (int): Number of channels in the input feature map
+ out_channels (int): Number of channels of the output heatmaps
+ for each instance
+ gfd_channels (int): Number of channels in the transformed feature map
+ clamp_delta (float, optional): A small value that prevents the sigmoid
+ activation from becoming saturated. Defaults to 1e-4.
+ init_cfg (Config, optional): Config to control the initialization. See
+ :attr:`default_init_cfg` for default settings
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ gfd_channels: int,
+ clamp_delta: float = 1e-4,
+ init_cfg: OptConfigType = None,
+ ):
+ super().__init__(init_cfg=init_cfg)
+
+ self.conv_down = build_conv_layer(
+ dict(
+ type='Conv2d',
+ in_channels=in_channels,
+ out_channels=gfd_channels,
+ kernel_size=1))
+
+ self.channel_attention = ChannelAttention(in_channels, gfd_channels)
+ self.spatial_attention = SpatialAttention(in_channels, gfd_channels)
+ self.fuse_attention = build_conv_layer(
+ dict(
+ type='Conv2d',
+ in_channels=gfd_channels * 2,
+ out_channels=gfd_channels,
+ kernel_size=1))
+ self.heatmap_conv = build_conv_layer(
+ dict(
+ type='Conv2d',
+ in_channels=gfd_channels,
+ out_channels=out_channels,
+ kernel_size=1))
+ self.sigmoid = TruncSigmoid(min=clamp_delta, max=1 - clamp_delta)
+
+ def forward(
+ self,
+ feats: Tensor,
+ instance_feats: Tensor,
+ instance_coords: Tensor,
+ instance_imgids: Tensor,
+ ) -> Tensor:
+ """Extract decoupled heatmaps for each instance.
+
+ Args:
+ feats (Tensor): Input feature maps.
+ instance_feats (Tensor): Tensor containing the instance feature
+ vectors.
+ instance_coords (Tensor): Tensor containing the root coordinates
+ of the instances.
+            instance_imgids (Tensor): Sample indices of each instance
+ in the batch.
+
+ Returns:
+ A tensor containing decoupled heatmaps.
+ """
+
+ global_feats = self.conv_down(feats)
+ global_feats = global_feats[instance_imgids]
+ cond_instance_feats = torch.cat(
+ (self.channel_attention(global_feats, instance_feats),
+ self.spatial_attention(global_feats, instance_feats,
+ instance_coords)),
+ dim=1)
+
+ cond_instance_feats = self.fuse_attention(cond_instance_feats)
+ cond_instance_feats = torch.nn.functional.relu(cond_instance_feats)
+ cond_instance_feats = self.heatmap_conv(cond_instance_feats)
+ heatmaps = self.sigmoid(cond_instance_feats)
+
+ return heatmaps
+
+
+@MODELS.register_module()
+class CIDHead(BaseHead):
+ """Contextual Instance Decoupling head introduced in `Contextual Instance
+ Decoupling for Robust Multi-Person Pose Estimation (CID)`_ by Wang et al
+ (2022). The head is composed of an Instance Information Abstraction (IIA)
+ module and a Global Feature Decoupling (GFD) module.
+
+ Args:
+ in_channels (int | Sequence[int]): Number of channels in the input
+ feature map
+ num_keypoints (int): Number of keypoints
+ gfd_channels (int): Number of filters in GFD module
+ max_train_instances (int): Maximum number of instances in a batch
+ during training. Defaults to 200
+ heatmap_loss (Config): Config of the heatmap loss. Defaults to use
+ :class:`KeypointMSELoss`
+ coupled_heatmap_loss (Config): Config of the loss for coupled heatmaps.
+ Defaults to use :class:`SoftWeightSmoothL1Loss`
+ decoupled_heatmap_loss (Config): Config of the loss for decoupled
+ heatmaps. Defaults to use :class:`SoftWeightSmoothL1Loss`
+ contrastive_loss (Config): Config of the contrastive loss for
+ representation vectors of instances. Defaults to use
+ :class:`InfoNCELoss`
+ decoder (Config, optional): The decoder config that controls decoding
+ keypoint coordinates from the network output. Defaults to ``None``
+ init_cfg (Config, optional): Config to control the initialization. See
+ :attr:`default_init_cfg` for default settings
+
+ .. _`CID`: https://openaccess.thecvf.com/content/CVPR2022/html/Wang_
+ Contextual_Instance_Decoupling_for_Robust_Multi-Person_Pose_Estimation_
+ CVPR_2022_paper.html
+ """
+ _version = 2
+
+ def __init__(self,
+ in_channels: Union[int, Sequence[int]],
+ gfd_channels: int,
+ num_keypoints: int,
+ prior_prob: float = 0.01,
+ coupled_heatmap_loss: OptConfigType = dict(
+ type='FocalHeatmapLoss'),
+ decoupled_heatmap_loss: OptConfigType = dict(
+ type='FocalHeatmapLoss'),
+ contrastive_loss: OptConfigType = dict(type='InfoNCELoss'),
+ decoder: OptConfigType = None,
+ init_cfg: OptConfigType = None):
+
+ if init_cfg is None:
+ init_cfg = self.default_init_cfg
+
+ super().__init__(init_cfg)
+
+ self.in_channels = in_channels
+ self.num_keypoints = num_keypoints
+ if decoder is not None:
+ self.decoder = KEYPOINT_CODECS.build(decoder)
+ else:
+ self.decoder = None
+
+ # build sub-modules
+ bias_value = -math.log((1 - prior_prob) / prior_prob)
+ self.iia_module = IIAModule(
+ in_channels,
+ num_keypoints + 1,
+ init_cfg=init_cfg + [
+ dict(
+ type='Normal',
+ layer=['Conv2d', 'Linear'],
+ std=0.001,
+ override=dict(
+ name='keypoint_root_conv',
+ type='Normal',
+ std=0.001,
+ bias=bias_value))
+ ])
+ self.gfd_module = GFDModule(
+ in_channels,
+ num_keypoints,
+ gfd_channels,
+ init_cfg=init_cfg + [
+ dict(
+ type='Normal',
+ layer=['Conv2d', 'Linear'],
+ std=0.001,
+ override=dict(
+ name='heatmap_conv',
+ type='Normal',
+ std=0.001,
+ bias=bias_value))
+ ])
+
+ # build losses
+ self.loss_module = ModuleDict(
+ dict(
+ heatmap_coupled=MODELS.build(coupled_heatmap_loss),
+ heatmap_decoupled=MODELS.build(decoupled_heatmap_loss),
+ contrastive=MODELS.build(contrastive_loss),
+ ))
+
+ # Register the hook to automatically convert old version state dicts
+ self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook)
+
+ @property
+ def default_init_cfg(self):
+ init_cfg = [
+ dict(type='Normal', layer=['Conv2d', 'Linear'], std=0.001),
+ dict(type='Constant', layer='BatchNorm2d', val=1)
+ ]
+ return init_cfg
+
+ def forward(self, feats: Tuple[Tensor]) -> Tensor:
+ """Forward the network. The input is multi scale feature maps and the
+ output is the heatmap.
+
+ Args:
+ feats (Tuple[Tensor]): Multi scale feature maps.
+
+ Returns:
+ Tensor: output heatmap.
+ """
+ feats = feats[-1]
+ instance_info = self.iia_module.forward_test(feats, {})
+ instance_feats, instance_coords, instance_scores = instance_info
+ instance_imgids = torch.zeros(
+ instance_coords.size(0), dtype=torch.long, device=feats.device)
+ instance_heatmaps = self.gfd_module(feats, instance_feats,
+ instance_coords, instance_imgids)
+
+ return instance_heatmaps
+
+ def predict(self,
+ feats: Features,
+ batch_data_samples: OptSampleList,
+ test_cfg: ConfigType = {}) -> Predictions:
+ """Predict results from features.
+
+ Args:
+ feats (Tuple[Tensor] | List[Tuple[Tensor]]): The multi-stage
+ features (or multiple multi-stage features in TTA)
+ batch_data_samples (List[:obj:`PoseDataSample`]): The batch
+ data samples
+ test_cfg (dict): The runtime config for testing process. Defaults
+ to {}
+
+ Returns:
+ Union[InstanceList | Tuple[InstanceList | PixelDataList]]: If
+ ``test_cfg['output_heatmap']==True``, return both pose and heatmap
+ prediction; otherwise only return the pose prediction.
+
+ The pose prediction is a list of ``InstanceData``, each contains
+ the following fields:
+
+ - keypoints (np.ndarray): predicted keypoint coordinates in
+ shape (num_instances, K, D) where K is the keypoint number
+ and D is the keypoint dimension
+ - keypoint_scores (np.ndarray): predicted keypoint scores in
+ shape (num_instances, K)
+
+ The heatmap prediction is a list of ``PixelData``, each contains
+ the following fields:
+
+ - heatmaps (Tensor): The predicted heatmaps in shape (K, h, w)
+ """
+ metainfo = batch_data_samples[0].metainfo
+
+ if test_cfg.get('flip_test', False):
+ assert isinstance(feats, list) and len(feats) == 2
+
+ feats_flipped = flip_heatmaps(feats[1][-1], shift_heatmap=False)
+ feats = torch.cat((feats[0][-1], feats_flipped))
+ else:
+ feats = feats[-1]
+
+ instance_info = self.iia_module.forward_test(feats, test_cfg)
+ instance_feats, instance_coords, instance_scores = instance_info
+ if len(instance_coords) > 0:
+ instance_imgids = torch.zeros(
+ instance_coords.size(0), dtype=torch.long, device=feats.device)
+ if test_cfg.get('flip_test', False):
+ instance_coords = torch.cat((instance_coords, instance_coords))
+ instance_imgids = torch.cat(
+ (instance_imgids, instance_imgids + 1))
+ instance_heatmaps = self.gfd_module(feats, instance_feats,
+ instance_coords,
+ instance_imgids)
+ if test_cfg.get('flip_test', False):
+ flip_indices = batch_data_samples[0].metainfo['flip_indices']
+ instance_heatmaps, instance_heatmaps_flip = torch.chunk(
+ instance_heatmaps, 2, dim=0)
+ instance_heatmaps_flip = \
+ instance_heatmaps_flip[:, flip_indices, :, :]
+ instance_heatmaps = (instance_heatmaps +
+ instance_heatmaps_flip) / 2.0
+ instance_heatmaps = smooth_heatmaps(
+ instance_heatmaps, test_cfg.get('blur_kernel_size', 3))
+
+ preds = self.decode((instance_heatmaps, instance_scores[:, None]))
+ preds = InstanceData.cat(preds)
+ preds.keypoints[..., 0] += metainfo['input_size'][
+ 0] / instance_heatmaps.shape[-1] / 2.0
+ preds.keypoints[..., 1] += metainfo['input_size'][
+ 1] / instance_heatmaps.shape[-2] / 2.0
+ preds = [preds]
+
+ else:
+ preds = [
+ InstanceData(
+ keypoints=np.empty((0, self.num_keypoints, 2)),
+ keypoint_scores=np.empty((0, self.num_keypoints)))
+ ]
+ instance_heatmaps = torch.empty(0, self.num_keypoints,
+ *feats.shape[-2:])
+
+ if test_cfg.get('output_heatmaps', False):
+ pred_fields = [
+ PixelData(
+ heatmaps=instance_heatmaps.reshape(
+ -1, *instance_heatmaps.shape[-2:]))
+ ]
+ return preds, pred_fields
+ else:
+ return preds
+
+ def loss(self,
+ feats: Tuple[Tensor],
+ batch_data_samples: OptSampleList,
+ train_cfg: ConfigType = {}) -> dict:
+ """Calculate losses from a batch of inputs and data samples.
+
+ Args:
+ feats (Tuple[Tensor]): The multi-stage features
+ batch_data_samples (List[:obj:`PoseDataSample`]): The batch
+ data samples
+ train_cfg (dict): The runtime config for training process.
+ Defaults to {}
+
+ Returns:
+ dict: A dictionary of losses.
+ """
+
+ # load targets
+ gt_heatmaps, gt_instance_coords, keypoint_weights = [], [], []
+ heatmap_mask = []
+ instance_imgids, gt_instance_heatmaps = [], []
+ for i, d in enumerate(batch_data_samples):
+ gt_heatmaps.append(d.gt_fields.heatmaps)
+ gt_instance_coords.append(d.gt_instance_labels.instance_coords)
+ keypoint_weights.append(d.gt_instance_labels.keypoint_weights)
+ instance_imgids.append(
+ torch.ones(
+ len(d.gt_instance_labels.instance_coords),
+ dtype=torch.long) * i)
+
+ instance_heatmaps = d.gt_fields.instance_heatmaps.reshape(
+ -1, self.num_keypoints,
+ *d.gt_fields.instance_heatmaps.shape[1:])
+ gt_instance_heatmaps.append(instance_heatmaps)
+
+ if 'heatmap_mask' in d.gt_fields:
+ heatmap_mask.append(d.gt_fields.heatmap_mask)
+
+ gt_heatmaps = torch.stack(gt_heatmaps)
+ heatmap_mask = torch.stack(heatmap_mask) if heatmap_mask else None
+
+ gt_instance_coords = torch.cat(gt_instance_coords, dim=0)
+ gt_instance_heatmaps = torch.cat(gt_instance_heatmaps, dim=0)
+ keypoint_weights = torch.cat(keypoint_weights, dim=0)
+ instance_imgids = torch.cat(instance_imgids).to(gt_heatmaps.device)
+
+ # feed-forward
+ feats = feats[-1]
+ pred_instance_feats, pred_heatmaps = self.iia_module.forward_train(
+ feats, gt_instance_coords, instance_imgids)
+
+        # compute contrastive loss
+ contrastive_loss = 0
+ for i in range(len(batch_data_samples)):
+ pred_instance_feat = pred_instance_feats[instance_imgids == i]
+ contrastive_loss += self.loss_module['contrastive'](
+ pred_instance_feat)
+ contrastive_loss = contrastive_loss / max(1, len(instance_imgids))
+
+ # limit the number of instances
+ max_train_instances = train_cfg.get('max_train_instances', -1)
+ if (max_train_instances > 0
+ and len(instance_imgids) > max_train_instances):
+ selected_indices = torch.randperm(
+ len(instance_imgids),
+ device=gt_heatmaps.device,
+ dtype=torch.long)[:max_train_instances]
+ gt_instance_coords = gt_instance_coords[selected_indices]
+ keypoint_weights = keypoint_weights[selected_indices]
+ gt_instance_heatmaps = gt_instance_heatmaps[selected_indices]
+ instance_imgids = instance_imgids[selected_indices]
+ pred_instance_feats = pred_instance_feats[selected_indices]
+
+ # calculate the decoupled heatmaps for each instance
+ pred_instance_heatmaps = self.gfd_module(feats, pred_instance_feats,
+ gt_instance_coords,
+ instance_imgids)
+
+ # calculate losses
+ losses = {
+ 'loss/heatmap_coupled':
+ self.loss_module['heatmap_coupled'](pred_heatmaps, gt_heatmaps,
+ None, heatmap_mask)
+ }
+ if len(instance_imgids) > 0:
+ losses.update({
+ 'loss/heatmap_decoupled':
+ self.loss_module['heatmap_decoupled'](pred_instance_heatmaps,
+ gt_instance_heatmaps,
+ keypoint_weights),
+ 'loss/contrastive':
+ contrastive_loss
+ })
+
+ return losses
+
+ def _load_state_dict_pre_hook(self, state_dict, prefix, local_meta, *args,
+ **kwargs):
+ """A hook function to convert old-version state dict of
+ :class:`CIDHead` (before MMPose v1.0.0) to a compatible format
+ of :class:`CIDHead`.
+
+ The hook will be automatically registered during initialization.
+ """
+ version = local_meta.get('version', None)
+ if version and version >= self._version:
+ return
+
+ # convert old-version state dict
+ keys = list(state_dict.keys())
+ for k in keys:
+ if 'keypoint_center_conv' in k:
+ v = state_dict.pop(k)
+ k = k.replace('keypoint_center_conv',
+ 'iia_module.keypoint_root_conv')
+ state_dict[k] = v
+
+ if 'conv_down' in k:
+ v = state_dict.pop(k)
+ k = k.replace('conv_down', 'gfd_module.conv_down')
+ state_dict[k] = v
+
+ if 'c_attn' in k:
+ v = state_dict.pop(k)
+ k = k.replace('c_attn', 'gfd_module.channel_attention')
+ state_dict[k] = v
+
+ if 's_attn' in k:
+ v = state_dict.pop(k)
+ k = k.replace('s_attn', 'gfd_module.spatial_attention')
+ state_dict[k] = v
+
+ if 'fuse_attn' in k:
+ v = state_dict.pop(k)
+ k = k.replace('fuse_attn', 'gfd_module.fuse_attention')
+ state_dict[k] = v
+
+ if 'heatmap_conv' in k:
+ v = state_dict.pop(k)
+ k = k.replace('heatmap_conv', 'gfd_module.heatmap_conv')
+ state_dict[k] = v
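+
+
+# NOTE: Illustrative sketch only, not part of MMPose; the key names below are
+# a reduced version of the renaming table used by the pre-hook above. It
+# shows how legacy checkpoint keys can be remapped in place before
+# ``load_state_dict`` consumes them.
+def _demo_rename_legacy_keys():
+    import torch
+
+    state_dict = {
+        'keypoint_center_conv.weight': torch.zeros(1),
+        'conv_down.weight': torch.zeros(1),
+    }
+    renames = {
+        'keypoint_center_conv': 'iia_module.keypoint_root_conv',
+        'conv_down': 'gfd_module.conv_down',
+    }
+    for k in list(state_dict.keys()):
+        for old, new in renames.items():
+            if old in k:
+                state_dict[k.replace(old, new)] = state_dict.pop(k)
+                break
+    # state_dict now holds 'iia_module.keypoint_root_conv.weight' and
+    # 'gfd_module.conv_down.weight'
+    return state_dict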
diff --git a/mmpose/models/heads/heatmap_heads/cpm_head.py b/mmpose/models/heads/heatmap_heads/cpm_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ba46357ec5cf72b29b43635a53354f2ed2fd048
--- /dev/null
+++ b/mmpose/models/heads/heatmap_heads/cpm_head.py
@@ -0,0 +1,307 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional, Sequence, Union
+
+import torch
+from mmcv.cnn import build_conv_layer, build_upsample_layer
+from mmengine.structures import PixelData
+from torch import Tensor, nn
+
+from mmpose.evaluation.functional import pose_pck_accuracy
+from mmpose.models.utils.tta import flip_heatmaps
+from mmpose.registry import KEYPOINT_CODECS, MODELS
+from mmpose.utils.tensor_utils import to_numpy
+from mmpose.utils.typing import (Features, MultiConfig, OptConfigType,
+ OptSampleList, Predictions)
+from ..base_head import BaseHead
+
+OptIntSeq = Optional[Sequence[int]]
+
+
+@MODELS.register_module()
+class CPMHead(BaseHead):
+ """Multi-stage heatmap head introduced in `Convolutional Pose Machines`_ by
+ Wei et al (2016) and used by `Stacked Hourglass Networks`_ by Newell et al
+ (2016). The head consists of multiple branches, each of which has some
+ deconv layers and a simple conv2d layer.
+
+ Args:
+ in_channels (int | Sequence[int]): Number of channels in the input
+ feature maps.
+ out_channels (int): Number of channels in the output heatmaps.
+ num_stages (int): Number of stages.
+ deconv_out_channels (Sequence[int], optional): The output channel
+ number of each deconv layer. Defaults to ``(256, 256, 256)``
+ deconv_kernel_sizes (Sequence[int | tuple], optional): The kernel size
+ of each deconv layer. Each element should be either an integer for
+ both height and width dimensions, or a tuple of two integers for
+ the height and the width dimension respectively.
+ Defaults to ``(4, 4, 4)``
+ final_layer (dict): Arguments of the final Conv2d layer.
+ Defaults to ``dict(kernel_size=1)``
+ loss (Config | List[Config]): Config of the keypoint loss of different
+ stages. Defaults to use :class:`KeypointMSELoss`.
+ decoder (Config, optional): The decoder config that controls decoding
+ keypoint coordinates from the network output. Defaults to ``None``
+ init_cfg (Config, optional): Config to control the initialization. See
+ :attr:`default_init_cfg` for default settings
+
+ .. _`Convolutional Pose Machines`: https://arxiv.org/abs/1602.00134
+ .. _`Stacked Hourglass Networks`: https://arxiv.org/abs/1603.06937
+ """
+
+ _version = 2
+
+ def __init__(self,
+ in_channels: Union[int, Sequence[int]],
+ out_channels: int,
+ num_stages: int,
+ deconv_out_channels: OptIntSeq = None,
+ deconv_kernel_sizes: OptIntSeq = None,
+ final_layer: dict = dict(kernel_size=1),
+ loss: MultiConfig = dict(
+ type='KeypointMSELoss', use_target_weight=True),
+ decoder: OptConfigType = None,
+ init_cfg: OptConfigType = None):
+
+ if init_cfg is None:
+ init_cfg = self.default_init_cfg
+ super().__init__(init_cfg)
+
+ self.num_stages = num_stages
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+
+ if isinstance(loss, list):
+ if len(loss) != num_stages:
+ raise ValueError(
+ f'The length of loss_module({len(loss)}) did not match '
+ f'`num_stages`({num_stages})')
+ self.loss_module = nn.ModuleList(
+ MODELS.build(_loss) for _loss in loss)
+ else:
+ self.loss_module = MODELS.build(loss)
+
+ if decoder is not None:
+ self.decoder = KEYPOINT_CODECS.build(decoder)
+ else:
+ self.decoder = None
+
+ # build multi-stage deconv layers
+ self.multi_deconv_layers = nn.ModuleList([])
+ if deconv_out_channels:
+ if deconv_kernel_sizes is None or len(deconv_out_channels) != len(
+ deconv_kernel_sizes):
+ raise ValueError(
+ '"deconv_out_channels" and "deconv_kernel_sizes" should '
+ 'be integer sequences with the same length. Got '
+ f'mismatched lengths {deconv_out_channels} and '
+ f'{deconv_kernel_sizes}')
+
+ for _ in range(self.num_stages):
+ deconv_layers = self._make_deconv_layers(
+ in_channels=in_channels,
+ layer_out_channels=deconv_out_channels,
+ layer_kernel_sizes=deconv_kernel_sizes,
+ )
+ self.multi_deconv_layers.append(deconv_layers)
+ in_channels = deconv_out_channels[-1]
+ else:
+ for _ in range(self.num_stages):
+ self.multi_deconv_layers.append(nn.Identity())
+
+ # build multi-stage final layers
+ self.multi_final_layers = nn.ModuleList([])
+ if final_layer is not None:
+ cfg = dict(
+ type='Conv2d',
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=1)
+ cfg.update(final_layer)
+ for _ in range(self.num_stages):
+ self.multi_final_layers.append(build_conv_layer(cfg))
+ else:
+ for _ in range(self.num_stages):
+ self.multi_final_layers.append(nn.Identity())
+
+ @property
+ def default_init_cfg(self):
+ init_cfg = [
+ dict(
+ type='Normal', layer=['Conv2d', 'ConvTranspose2d'], std=0.001),
+ dict(type='Constant', layer='BatchNorm2d', val=1)
+ ]
+ return init_cfg
+
+ def _make_deconv_layers(self, in_channels: int,
+ layer_out_channels: Sequence[int],
+ layer_kernel_sizes: Sequence[int]) -> nn.Module:
+ """Create deconvolutional layers by given parameters."""
+
+ layers = []
+ for out_channels, kernel_size in zip(layer_out_channels,
+ layer_kernel_sizes):
+ if kernel_size == 4:
+ padding = 1
+ output_padding = 0
+ elif kernel_size == 3:
+ padding = 1
+ output_padding = 1
+ elif kernel_size == 2:
+ padding = 0
+ output_padding = 0
+ else:
+                raise ValueError(f'Unsupported kernel size {kernel_size} for '
+                                 'deconvolutional layers in '
+                                 f'{self.__class__.__name__}')
+ cfg = dict(
+ type='deconv',
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=2,
+ padding=padding,
+ output_padding=output_padding,
+ bias=False)
+ layers.append(build_upsample_layer(cfg))
+ layers.append(nn.BatchNorm2d(num_features=out_channels))
+ layers.append(nn.ReLU(inplace=True))
+ in_channels = out_channels
+
+ return nn.Sequential(*layers)
+
+ def forward(self, feats: Sequence[Tensor]) -> List[Tensor]:
+ """Forward the network. The input is multi-stage feature maps and the
+ output is a list of heatmaps from multiple stages.
+
+ Args:
+ feats (Sequence[Tensor]): Multi-stage feature maps.
+
+ Returns:
+ List[Tensor]: A list of output heatmaps from multiple stages.
+ """
+ out = []
+ assert len(feats) == self.num_stages, (
+ f'The length of feature maps did not match the '
+ f'`num_stages` in {self.__class__.__name__}')
+ for i in range(self.num_stages):
+ y = self.multi_deconv_layers[i](feats[i])
+ y = self.multi_final_layers[i](y)
+ out.append(y)
+
+ return out
+
+ def predict(self,
+ feats: Features,
+ batch_data_samples: OptSampleList,
+ test_cfg: OptConfigType = {}) -> Predictions:
+ """Predict results from multi-stage feature maps.
+
+ Args:
+ feats (Tuple[Tensor] | List[Tuple[Tensor]]): The multi-stage
+ features (or multiple multi-stage features in TTA)
+ batch_data_samples (List[:obj:`PoseDataSample`]): The batch
+ data samples
+ test_cfg (dict): The runtime config for testing process. Defaults
+ to {}
+
+ Returns:
+ Union[InstanceList | Tuple[InstanceList | PixelDataList]]: If
+ ``test_cfg['output_heatmap']==True``, return both pose and heatmap
+ prediction; otherwise only return the pose prediction.
+
+ The pose prediction is a list of ``InstanceData``, each contains
+ the following fields:
+
+ - keypoints (np.ndarray): predicted keypoint coordinates in
+ shape (num_instances, K, D) where K is the keypoint number
+ and D is the keypoint dimension
+ - keypoint_scores (np.ndarray): predicted keypoint scores in
+ shape (num_instances, K)
+
+ The heatmap prediction is a list of ``PixelData``, each contains
+ the following fields:
+
+ - heatmaps (Tensor): The predicted heatmaps in shape (K, h, w)
+ """
+
+ if test_cfg.get('flip_test', False):
+ # TTA: flip test
+ assert isinstance(feats, list) and len(feats) == 2
+ flip_indices = batch_data_samples[0].metainfo['flip_indices']
+ _feats, _feats_flip = feats
+ _batch_heatmaps = self.forward(_feats)[-1]
+ _batch_heatmaps_flip = flip_heatmaps(
+ self.forward(_feats_flip)[-1],
+ flip_mode=test_cfg.get('flip_mode', 'heatmap'),
+ flip_indices=flip_indices,
+ shift_heatmap=test_cfg.get('shift_heatmap', False))
+ batch_heatmaps = (_batch_heatmaps + _batch_heatmaps_flip) * 0.5
+ else:
+ multi_stage_heatmaps = self.forward(feats)
+ batch_heatmaps = multi_stage_heatmaps[-1]
+
+ preds = self.decode(batch_heatmaps)
+
+ if test_cfg.get('output_heatmaps', False):
+ pred_fields = [
+ PixelData(heatmaps=hm) for hm in batch_heatmaps.detach()
+ ]
+ return preds, pred_fields
+ else:
+ return preds
+
+ def loss(self,
+ feats: Sequence[Tensor],
+ batch_data_samples: OptSampleList,
+ train_cfg: OptConfigType = {}) -> dict:
+ """Calculate losses from a batch of inputs and data samples.
+
+ Args:
+ feats (Sequence[Tensor]): Multi-stage feature maps.
+ batch_data_samples (List[:obj:`PoseDataSample`]): The Data
+ Samples. It usually includes information such as
+ `gt_instances`.
+ train_cfg (Config, optional): The training config.
+
+ Returns:
+ dict: A dictionary of loss components.
+ """
+ multi_stage_pred_heatmaps = self.forward(feats)
+
+ gt_heatmaps = torch.stack(
+ [d.gt_fields.heatmaps for d in batch_data_samples])
+ keypoint_weights = torch.cat([
+ d.gt_instance_labels.keypoint_weights for d in batch_data_samples
+ ])
+
+ # calculate losses over multiple stages
+ losses = dict()
+ for i in range(self.num_stages):
+ if isinstance(self.loss_module, nn.ModuleList):
+ # use different loss_module over different stages
+ loss_func = self.loss_module[i]
+ else:
+ # use the same loss_module over different stages
+ loss_func = self.loss_module
+
+ # the `gt_heatmaps` and `keypoint_weights` used to calculate loss
+ # for different stages are the same
+ loss_i = loss_func(multi_stage_pred_heatmaps[i], gt_heatmaps,
+ keypoint_weights)
+
+ if 'loss_kpt' not in losses:
+ losses['loss_kpt'] = loss_i
+ else:
+ losses['loss_kpt'] += loss_i
+
+ # calculate accuracy
+ _, avg_acc, _ = pose_pck_accuracy(
+ output=to_numpy(multi_stage_pred_heatmaps[-1]),
+ target=to_numpy(gt_heatmaps),
+ mask=to_numpy(keypoint_weights) > 0)
+
+ acc_pose = torch.tensor(avg_acc, device=gt_heatmaps.device)
+ losses.update(acc_pose=acc_pose)
+
+ return losses
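+
+
+# NOTE: Illustrative sketch, not part of MMPose. It reproduces with plain
+# tensors the intermediate-supervision pattern of ``CPMHead.loss`` above:
+# every stage predicts a full heatmap, the same target is applied to each
+# prediction, and the keypoint loss is summed over stages. The shapes below
+# (17 COCO keypoints on a 64x48 grid) are only an example.
+def _demo_multi_stage_supervision(num_stages: int = 3):
+    import torch
+
+    gt_heatmaps = torch.rand(2, 17, 64, 48)  # (B, K, H, W) target
+    stage_preds = [torch.rand_like(gt_heatmaps) for _ in range(num_stages)]
+
+    loss_kpt = sum(
+        torch.nn.functional.mse_loss(pred, gt_heatmaps)
+        for pred in stage_preds)
+    return loss_kpt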
diff --git a/mmpose/models/heads/heatmap_heads/heatmap_head.py b/mmpose/models/heads/heatmap_heads/heatmap_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b0fa3f475b3e54ca0dd267e486d1e0424b35ab6
--- /dev/null
+++ b/mmpose/models/heads/heatmap_heads/heatmap_head.py
@@ -0,0 +1,369 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional, Sequence, Tuple, Union
+
+import torch
+from mmcv.cnn import build_conv_layer, build_upsample_layer
+from mmengine.structures import PixelData
+from torch import Tensor, nn
+
+from mmpose.evaluation.functional import pose_pck_accuracy
+from mmpose.models.utils.tta import flip_heatmaps
+from mmpose.registry import KEYPOINT_CODECS, MODELS
+from mmpose.utils.tensor_utils import to_numpy
+from mmpose.utils.typing import (ConfigType, Features, OptConfigType,
+ OptSampleList, Predictions)
+from ..base_head import BaseHead
+
+OptIntSeq = Optional[Sequence[int]]
+
+
+@MODELS.register_module()
+class HeatmapHead(BaseHead):
+ """Top-down heatmap head introduced in `Simple Baselines`_ by Xiao et al
+ (2018). The head is composed of a few deconvolutional layers followed by a
+ convolutional layer to generate heatmaps from low-resolution feature maps.
+
+ Args:
+ in_channels (int | Sequence[int]): Number of channels in the input
+ feature map
+ out_channels (int): Number of channels in the output heatmap
+ deconv_out_channels (Sequence[int], optional): The output channel
+ number of each deconv layer. Defaults to ``(256, 256, 256)``
+ deconv_kernel_sizes (Sequence[int | tuple], optional): The kernel size
+ of each deconv layer. Each element should be either an integer for
+ both height and width dimensions, or a tuple of two integers for
+            the height and the width dimension respectively. Defaults to
+ ``(4, 4, 4)``
+ conv_out_channels (Sequence[int], optional): The output channel number
+ of each intermediate conv layer. ``None`` means no intermediate
+ conv layer between deconv layers and the final conv layer.
+ Defaults to ``None``
+ conv_kernel_sizes (Sequence[int | tuple], optional): The kernel size
+ of each intermediate conv layer. Defaults to ``None``
+ final_layer (dict): Arguments of the final Conv2d layer.
+ Defaults to ``dict(kernel_size=1)``
+ loss (Config): Config of the keypoint loss. Defaults to use
+ :class:`KeypointMSELoss`
+ decoder (Config, optional): The decoder config that controls decoding
+ keypoint coordinates from the network output. Defaults to ``None``
+ init_cfg (Config, optional): Config to control the initialization. See
+ :attr:`default_init_cfg` for default settings
+ extra (dict, optional): Extra configurations.
+ Defaults to ``None``
+
+ .. _`Simple Baselines`: https://arxiv.org/abs/1804.06208
+ """
+
+ _version = 2
+
+ def __init__(self,
+ in_channels: Union[int, Sequence[int]],
+ out_channels: int,
+ deconv_out_channels: OptIntSeq = (256, 256, 256),
+ deconv_kernel_sizes: OptIntSeq = (4, 4, 4),
+ conv_out_channels: OptIntSeq = None,
+ conv_kernel_sizes: OptIntSeq = None,
+ final_layer: dict = dict(kernel_size=1),
+ loss: ConfigType = dict(
+ type='KeypointMSELoss', use_target_weight=True),
+ decoder: OptConfigType = None,
+ init_cfg: OptConfigType = None):
+
+ if init_cfg is None:
+ init_cfg = self.default_init_cfg
+
+ super().__init__(init_cfg)
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.loss_module = MODELS.build(loss)
+ if decoder is not None:
+ self.decoder = KEYPOINT_CODECS.build(decoder)
+ else:
+ self.decoder = None
+
+ if deconv_out_channels:
+ if deconv_kernel_sizes is None or len(deconv_out_channels) != len(
+ deconv_kernel_sizes):
+ raise ValueError(
+ '"deconv_out_channels" and "deconv_kernel_sizes" should '
+ 'be integer sequences with the same length. Got '
+ f'mismatched lengths {deconv_out_channels} and '
+ f'{deconv_kernel_sizes}')
+
+ self.deconv_layers = self._make_deconv_layers(
+ in_channels=in_channels,
+ layer_out_channels=deconv_out_channels,
+ layer_kernel_sizes=deconv_kernel_sizes,
+ )
+ in_channels = deconv_out_channels[-1]
+ else:
+ self.deconv_layers = nn.Identity()
+
+ if conv_out_channels:
+ if conv_kernel_sizes is None or len(conv_out_channels) != len(
+ conv_kernel_sizes):
+ raise ValueError(
+ '"conv_out_channels" and "conv_kernel_sizes" should '
+ 'be integer sequences with the same length. Got '
+ f'mismatched lengths {conv_out_channels} and '
+ f'{conv_kernel_sizes}')
+
+ self.conv_layers = self._make_conv_layers(
+ in_channels=in_channels,
+ layer_out_channels=conv_out_channels,
+ layer_kernel_sizes=conv_kernel_sizes)
+ in_channels = conv_out_channels[-1]
+ else:
+ self.conv_layers = nn.Identity()
+
+ if final_layer is not None:
+ cfg = dict(
+ type='Conv2d',
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=1)
+ cfg.update(final_layer)
+ self.final_layer = build_conv_layer(cfg)
+ else:
+ self.final_layer = nn.Identity()
+
+ # Register the hook to automatically convert old version state dicts
+ self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook)
+
+ def _make_conv_layers(self, in_channels: int,
+ layer_out_channels: Sequence[int],
+ layer_kernel_sizes: Sequence[int]) -> nn.Module:
+ """Create convolutional layers by given parameters."""
+
+ layers = []
+ for out_channels, kernel_size in zip(layer_out_channels,
+ layer_kernel_sizes):
+ padding = (kernel_size - 1) // 2
+ cfg = dict(
+ type='Conv2d',
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=1,
+ padding=padding)
+ layers.append(build_conv_layer(cfg))
+ layers.append(nn.BatchNorm2d(num_features=out_channels))
+ layers.append(nn.ReLU(inplace=True))
+ in_channels = out_channels
+
+ return nn.Sequential(*layers)
+
+ def _make_deconv_layers(self, in_channels: int,
+ layer_out_channels: Sequence[int],
+ layer_kernel_sizes: Sequence[int]) -> nn.Module:
+ """Create deconvolutional layers by given parameters."""
+
+ layers = []
+ for out_channels, kernel_size in zip(layer_out_channels,
+ layer_kernel_sizes):
+ if kernel_size == 4:
+ padding = 1
+ output_padding = 0
+ elif kernel_size == 3:
+ padding = 1
+ output_padding = 1
+ elif kernel_size == 2:
+ padding = 0
+ output_padding = 0
+ else:
+                raise ValueError(f'Unsupported kernel size {kernel_size} for '
+                                 'deconvolutional layers in '
+                                 f'{self.__class__.__name__}')
+ cfg = dict(
+ type='deconv',
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=2,
+ padding=padding,
+ output_padding=output_padding,
+ bias=False)
+ layers.append(build_upsample_layer(cfg))
+ layers.append(nn.BatchNorm2d(num_features=out_channels))
+ layers.append(nn.ReLU(inplace=True))
+ in_channels = out_channels
+
+ return nn.Sequential(*layers)
+
+ @property
+ def default_init_cfg(self):
+ init_cfg = [
+ dict(
+ type='Normal', layer=['Conv2d', 'ConvTranspose2d'], std=0.001),
+ dict(type='Constant', layer='BatchNorm2d', val=1)
+ ]
+ return init_cfg
+
+ def forward(self, feats: Tuple[Tensor]) -> Tensor:
+ """Forward the network. The input is multi scale feature maps and the
+ output is the heatmap.
+
+ Args:
+ feats (Tuple[Tensor]): Multi scale feature maps.
+
+ Returns:
+ Tensor: output heatmap.
+ """
+ x = feats[-1]
+
+ x = self.deconv_layers(x)
+ x = self.conv_layers(x)
+ x = self.final_layer(x)
+
+ return x
+
+ def predict(self,
+ feats: Features,
+ batch_data_samples: OptSampleList,
+ test_cfg: ConfigType = {}) -> Predictions:
+ """Predict results from features.
+
+ Args:
+ feats (Tuple[Tensor] | List[Tuple[Tensor]]): The multi-stage
+ features (or multiple multi-stage features in TTA)
+ batch_data_samples (List[:obj:`PoseDataSample`]): The batch
+ data samples
+ test_cfg (dict): The runtime config for testing process. Defaults
+ to {}
+
+ Returns:
+ Union[InstanceList | Tuple[InstanceList | PixelDataList]]: If
+ ``test_cfg['output_heatmap']==True``, return both pose and heatmap
+ prediction; otherwise only return the pose prediction.
+
+ The pose prediction is a list of ``InstanceData``, each contains
+ the following fields:
+
+ - keypoints (np.ndarray): predicted keypoint coordinates in
+ shape (num_instances, K, D) where K is the keypoint number
+ and D is the keypoint dimension
+ - keypoint_scores (np.ndarray): predicted keypoint scores in
+ shape (num_instances, K)
+
+ The heatmap prediction is a list of ``PixelData``, each contains
+ the following fields:
+
+ - heatmaps (Tensor): The predicted heatmaps in shape (K, h, w)
+ """
+
+ if test_cfg.get('flip_test', False):
+ # TTA: flip test -> feats = [orig, flipped]
+ assert isinstance(feats, list) and len(feats) == 2
+ flip_indices = batch_data_samples[0].metainfo['flip_indices']
+ _feats, _feats_flip = feats
+ _batch_heatmaps = self.forward(_feats)
+ _batch_heatmaps_flip = flip_heatmaps(
+ self.forward(_feats_flip),
+ flip_mode=test_cfg.get('flip_mode', 'heatmap'),
+ flip_indices=flip_indices,
+ shift_heatmap=test_cfg.get('shift_heatmap', False))
+ batch_heatmaps = (_batch_heatmaps + _batch_heatmaps_flip) * 0.5
+ else:
+ batch_heatmaps = self.forward(feats)
+
+ preds = self.decode(batch_heatmaps)
+
+ if test_cfg.get('output_heatmaps', False):
+ pred_fields = [
+ PixelData(heatmaps=hm) for hm in batch_heatmaps.detach()
+ ]
+ return preds, pred_fields
+ else:
+ return preds
+
+ def loss(self,
+ feats: Tuple[Tensor],
+ batch_data_samples: OptSampleList,
+ train_cfg: ConfigType = {}) -> dict:
+ """Calculate losses from a batch of inputs and data samples.
+
+ Args:
+ feats (Tuple[Tensor]): The multi-stage features
+ batch_data_samples (List[:obj:`PoseDataSample`]): The batch
+ data samples
+ train_cfg (dict): The runtime config for training process.
+ Defaults to {}
+
+ Returns:
+ dict: A dictionary of losses.
+ """
+ pred_fields = self.forward(feats)
+ gt_heatmaps = torch.stack(
+ [d.gt_fields.heatmaps for d in batch_data_samples])
+ keypoint_weights = torch.cat([
+ d.gt_instance_labels.keypoint_weights for d in batch_data_samples
+ ])
+
+ # calculate losses
+ losses = dict()
+ loss = self.loss_module(pred_fields, gt_heatmaps, keypoint_weights)
+
+ losses.update(loss_kpt=loss)
+
+ # calculate accuracy
+ if train_cfg.get('compute_acc', True):
+ _, avg_acc, _ = pose_pck_accuracy(
+ output=to_numpy(pred_fields),
+ target=to_numpy(gt_heatmaps),
+ mask=to_numpy(keypoint_weights) > 0)
+
+ acc_pose = torch.tensor(avg_acc, device=gt_heatmaps.device)
+ losses.update(acc_pose=acc_pose)
+
+ return losses
+
+ def _load_state_dict_pre_hook(self, state_dict, prefix, local_meta, *args,
+ **kwargs):
+ """A hook function to convert old-version state dict of
+        the top-down heatmap head (before MMPose v1.0.0) to a
+        compatible format of :class:`HeatmapHead`.
+
+ The hook will be automatically registered during initialization.
+ """
+ version = local_meta.get('version', None)
+ if version and version >= self._version:
+ return
+
+ # convert old-version state dict
+ keys = list(state_dict.keys())
+ for _k in keys:
+ if not _k.startswith(prefix):
+ continue
+ v = state_dict.pop(_k)
+ k = _k[len(prefix):]
+ # In old version, "final_layer" includes both intermediate
+ # conv layers (new "conv_layers") and final conv layers (new
+ # "final_layer").
+ #
+ # If there is no intermediate conv layer, old "final_layer" will
+ # have keys like "final_layer.xxx", which should be still
+ # named "final_layer.xxx";
+ #
+ # If there are intermediate conv layers, old "final_layer" will
+ # have keys like "final_layer.n.xxx", where the weights of the last
+ # one should be renamed "final_layer.xxx", and others should be
+ # renamed "conv_layers.n.xxx"
+ k_parts = k.split('.')
+ if k_parts[0] == 'final_layer':
+ if len(k_parts) == 3:
+ assert isinstance(self.conv_layers, nn.Sequential)
+ idx = int(k_parts[1])
+ if idx < len(self.conv_layers):
+ # final_layer.n.xxx -> conv_layers.n.xxx
+ k_new = 'conv_layers.' + '.'.join(k_parts[1:])
+ else:
+ # final_layer.n.xxx -> final_layer.xxx
+ k_new = 'final_layer.' + k_parts[2]
+ else:
+ # final_layer.xxx remains final_layer.xxx
+ k_new = k
+ else:
+ k_new = k
+
+ state_dict[prefix + k_new] = v
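+
+
+# NOTE: Illustrative sketch, not part of MMPose. It verifies the
+# kernel-size -> (padding, output_padding) mapping used by
+# ``_make_deconv_layers`` above: with stride 2, each supported kernel size
+# upsamples the feature map by exactly 2x. The channel count and input size
+# are arbitrary.
+def _demo_deconv_upsampling():
+    import torch
+    from torch import nn
+
+    mapping = {4: (1, 0), 3: (1, 1), 2: (0, 0)}  # kernel: (padding, out_pad)
+    x = torch.randn(1, 8, 16, 12)
+    for kernel_size, (padding, output_padding) in mapping.items():
+        deconv = nn.ConvTranspose2d(
+            8, 8, kernel_size, stride=2, padding=padding,
+            output_padding=output_padding, bias=False)
+        assert deconv(x).shape[-2:] == (32, 24)  # spatial size doubled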
diff --git a/mmpose/models/heads/heatmap_heads/mspn_head.py b/mmpose/models/heads/heatmap_heads/mspn_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b7cddf7988bfc57cae314ef944f44b4d0d7df09
--- /dev/null
+++ b/mmpose/models/heads/heatmap_heads/mspn_head.py
@@ -0,0 +1,432 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+from typing import List, Optional, Sequence, Union
+
+import torch
+from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule, Linear,
+ build_activation_layer, build_norm_layer)
+from mmengine.structures import PixelData
+from torch import Tensor, nn
+
+from mmpose.evaluation.functional import pose_pck_accuracy
+from mmpose.models.utils.tta import flip_heatmaps
+from mmpose.registry import KEYPOINT_CODECS, MODELS
+from mmpose.utils.tensor_utils import to_numpy
+from mmpose.utils.typing import (ConfigType, MultiConfig, OptConfigType,
+ OptSampleList, Predictions)
+from ..base_head import BaseHead
+
+OptIntSeq = Optional[Sequence[int]]
+MSMUFeatures = Sequence[Sequence[Tensor]] # Multi-stage multi-unit features
+
+
+class PRM(nn.Module):
+ """Pose Refine Machine.
+
+ Please refer to "Learning Delicate Local Representations
+ for Multi-Person Pose Estimation" (ECCV 2020).
+
+ Args:
+ out_channels (int): Number of the output channels, equals to
+ the number of keypoints.
+ norm_cfg (Config): Config to construct the norm layer.
+ Defaults to ``dict(type='BN')``
+ """
+
+ def __init__(self,
+ out_channels: int,
+ norm_cfg: ConfigType = dict(type='BN')):
+ super().__init__()
+
+ # Protect mutable default arguments
+ norm_cfg = copy.deepcopy(norm_cfg)
+ self.out_channels = out_channels
+ self.global_pooling = nn.AdaptiveAvgPool2d((1, 1))
+ self.middle_path = nn.Sequential(
+ Linear(self.out_channels, self.out_channels),
+ build_norm_layer(dict(type='BN1d'), out_channels)[1],
+ build_activation_layer(dict(type='ReLU')),
+ Linear(self.out_channels, self.out_channels),
+ build_norm_layer(dict(type='BN1d'), out_channels)[1],
+ build_activation_layer(dict(type='ReLU')),
+ build_activation_layer(dict(type='Sigmoid')))
+
+ self.bottom_path = nn.Sequential(
+ ConvModule(
+ self.out_channels,
+ self.out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ norm_cfg=norm_cfg,
+ inplace=False),
+ DepthwiseSeparableConvModule(
+ self.out_channels,
+ 1,
+ kernel_size=9,
+ stride=1,
+ padding=4,
+ norm_cfg=norm_cfg,
+ inplace=False), build_activation_layer(dict(type='Sigmoid')))
+ self.conv_bn_relu_prm_1 = ConvModule(
+ self.out_channels,
+ self.out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ norm_cfg=norm_cfg,
+ inplace=False)
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward the network. The input heatmaps will be refined.
+
+ Args:
+ x (Tensor): The input heatmaps.
+
+ Returns:
+ Tensor: output heatmaps.
+ """
+ out = self.conv_bn_relu_prm_1(x)
+ out_1 = out
+
+ out_2 = self.global_pooling(out_1)
+ out_2 = out_2.view(out_2.size(0), -1)
+ out_2 = self.middle_path(out_2)
+ out_2 = out_2.unsqueeze(2)
+ out_2 = out_2.unsqueeze(3)
+
+ out_3 = self.bottom_path(out_1)
+ out = out_1 * (1 + out_2 * out_3)
+
+ return out
+
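+
+# NOTE: Illustrative sketch, not part of MMPose. It reproduces the modulation
+# at the end of ``PRM.forward`` with toy tensors: the refined output is the
+# (non-negative) ReLU feature scaled by ``1 + channel_att * spatial_att``.
+# Since both attention branches end with a sigmoid, the scale factor is at
+# least 1 and the original response is never suppressed. All shapes are
+# made up.
+def _demo_prm_modulation():
+    import torch
+
+    feats = torch.rand(1, 17, 64, 48)       # output of conv_bn_relu_prm_1
+    channel_att = torch.rand(1, 17, 1, 1)   # middle (global pooling) path
+    spatial_att = torch.rand(1, 1, 64, 48)  # bottom (local conv) path
+    out = feats * (1 + channel_att * spatial_att)
+    assert torch.all(out >= feats)
+    return out
+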
+
+class PredictHeatmap(nn.Module):
+ """Predict the heatmap for an input feature.
+
+ Args:
+ unit_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ out_shape (tuple): Shape of the output heatmaps.
+ use_prm (bool): Whether to use pose refine machine. Default: False.
+ norm_cfg (Config): Config to construct the norm layer.
+ Defaults to ``dict(type='BN')``
+ """
+
+ def __init__(self,
+ unit_channels: int,
+ out_channels: int,
+ out_shape: tuple,
+ use_prm: bool = False,
+ norm_cfg: ConfigType = dict(type='BN')):
+
+ super().__init__()
+
+ # Protect mutable default arguments
+ norm_cfg = copy.deepcopy(norm_cfg)
+ self.unit_channels = unit_channels
+ self.out_channels = out_channels
+ self.out_shape = out_shape
+ self.use_prm = use_prm
+ if use_prm:
+ self.prm = PRM(out_channels, norm_cfg=norm_cfg)
+ self.conv_layers = nn.Sequential(
+ ConvModule(
+ unit_channels,
+ unit_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ norm_cfg=norm_cfg,
+ inplace=False),
+ ConvModule(
+ unit_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ norm_cfg=norm_cfg,
+ act_cfg=None,
+ inplace=False))
+
+ def forward(self, feature: Tensor) -> Tensor:
+ """Forward the network.
+
+ Args:
+ feature (Tensor): The input feature maps.
+
+ Returns:
+ Tensor: output heatmaps.
+ """
+ feature = self.conv_layers(feature)
+ output = nn.functional.interpolate(
+ feature, size=self.out_shape, mode='bilinear', align_corners=True)
+ if self.use_prm:
+ output = self.prm(output)
+ return output
+
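+
+# NOTE: Illustrative sketch, not part of MMPose, with a hypothetical
+# 3-keypoint layout. It mimics the flip-test averaging performed in
+# ``MSPNHead.predict`` below: heatmaps predicted from the horizontally
+# flipped image are flipped back, left/right channels are swapped via
+# ``flip_indices``, and the result is averaged with the original prediction.
+def _demo_flip_test_average():
+    import torch
+
+    heatmaps = torch.rand(1, 3, 64, 48)       # prediction on the original
+    heatmaps_flip = torch.rand(1, 3, 64, 48)  # prediction on the flipped
+    flip_indices = [0, 2, 1]                  # e.g. nose, left eye, right eye
+    heatmaps_flip = heatmaps_flip.flip(-1)[:, flip_indices]
+    return (heatmaps + heatmaps_flip) * 0.5
+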
+
+@MODELS.register_module()
+class MSPNHead(BaseHead):
+ """Multi-stage multi-unit heatmap head introduced in `Multi-Stage Pose
+ estimation Network (MSPN)`_ by Li et al (2019), and used by `Residual Steps
+ Networks (RSN)`_ by Cai et al (2020). The head consists of multiple stages
+ and each stage consists of multiple units. Each unit of each stage has some
+ conv layers.
+
+ Args:
+ num_stages (int): Number of stages.
+ num_units (int): Number of units in each stage.
+        out_shape (tuple): The output shape of the output heatmaps.
+        unit_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+ use_prm (bool): Whether to use pose refine machine (PRM).
+ Defaults to ``False``.
+ norm_cfg (Config): Config to construct the norm layer.
+ Defaults to ``dict(type='BN')``
+ loss (Config | List[Config]): Config of the keypoint loss for
+ different stages and different units.
+ Defaults to use :class:`KeypointMSELoss`.
+ level_indices (Sequence[int]): The indices that specified the level
+ of target heatmaps.
+ decoder (Config, optional): The decoder config that controls decoding
+ keypoint coordinates from the network output. Defaults to ``None``
+ init_cfg (Config, optional): Config to control the initialization. See
+ :attr:`default_init_cfg` for default settings
+
+ .. _`MSPN`: https://arxiv.org/abs/1901.00148
+ .. _`RSN`: https://arxiv.org/abs/2003.04030
+ """
+ _version = 2
+
+ def __init__(self,
+ num_stages: int = 4,
+ num_units: int = 4,
+ out_shape: tuple = (64, 48),
+ unit_channels: int = 256,
+ out_channels: int = 17,
+ use_prm: bool = False,
+ norm_cfg: ConfigType = dict(type='BN'),
+ level_indices: Sequence[int] = [],
+ loss: MultiConfig = dict(
+ type='KeypointMSELoss', use_target_weight=True),
+ decoder: OptConfigType = None,
+ init_cfg: OptConfigType = None):
+ if init_cfg is None:
+ init_cfg = self.default_init_cfg
+ super().__init__(init_cfg)
+
+ self.num_stages = num_stages
+ self.num_units = num_units
+ self.out_shape = out_shape
+ self.unit_channels = unit_channels
+ self.out_channels = out_channels
+ if len(level_indices) != num_stages * num_units:
+ raise ValueError(
+ f'The length of level_indices({len(level_indices)}) did not '
+ f'match `num_stages`({num_stages}) * `num_units`({num_units})')
+
+ self.level_indices = level_indices
+
+ if isinstance(loss, list):
+ if len(loss) != num_stages * num_units:
+ raise ValueError(
+ f'The length of loss_module({len(loss)}) did not match '
+ f'`num_stages`({num_stages}) * `num_units`({num_units})')
+ self.loss_module = nn.ModuleList(
+ MODELS.build(_loss) for _loss in loss)
+ else:
+ self.loss_module = MODELS.build(loss)
+
+ if decoder is not None:
+ self.decoder = KEYPOINT_CODECS.build(decoder)
+ else:
+ self.decoder = None
+
+ # Protect mutable default arguments
+ norm_cfg = copy.deepcopy(norm_cfg)
+
+ self.predict_layers = nn.ModuleList([])
+ for i in range(self.num_stages):
+ for j in range(self.num_units):
+ self.predict_layers.append(
+ PredictHeatmap(
+ unit_channels,
+ out_channels,
+ out_shape,
+ use_prm,
+ norm_cfg=norm_cfg))
+
+ @property
+ def default_init_cfg(self):
+ """Default config for weight initialization."""
+ init_cfg = [
+ dict(type='Kaiming', layer='Conv2d'),
+ dict(type='Normal', layer='Linear', std=0.01),
+ dict(type='Constant', layer='BatchNorm2d', val=1),
+ ]
+ return init_cfg
+
+ def forward(self, feats: Sequence[Sequence[Tensor]]) -> List[Tensor]:
+ """Forward the network. The input is multi-stage multi-unit feature
+ maps and the output is a list of heatmaps from multiple stages.
+
+ Args:
+ feats (Sequence[Sequence[Tensor]]): Feature maps from multiple
+ stages and units.
+
+ Returns:
+ List[Tensor]: A list of output heatmaps from multiple stages
+ and units.
+ """
+ out = []
+ assert len(feats) == self.num_stages, (
+ f'The length of feature maps did not match the '
+ f'`num_stages` in {self.__class__.__name__}')
+ for feat in feats:
+ assert len(feat) == self.num_units, (
+ f'The length of feature maps did not match the '
+ f'`num_units` in {self.__class__.__name__}')
+ for f in feat:
+ assert f.shape[1] == self.unit_channels, (
+ f'The number of feature map channels did not match the '
+ f'`unit_channels` in {self.__class__.__name__}')
+
+ for i in range(self.num_stages):
+ for j in range(self.num_units):
+ y = self.predict_layers[i * self.num_units + j](feats[i][j])
+ out.append(y)
+ return out
+
+ def predict(self,
+ feats: Union[MSMUFeatures, List[MSMUFeatures]],
+ batch_data_samples: OptSampleList,
+ test_cfg: OptConfigType = {}) -> Predictions:
+ """Predict results from multi-stage feature maps.
+
+ Args:
+ feats (Sequence[Sequence[Tensor]]): Multi-stage multi-unit
+ features (or multiple MSMU features for TTA)
+ batch_data_samples (List[:obj:`PoseDataSample`]): The Data
+ Samples. It usually includes information such as
+ `gt_instance_labels`.
+ test_cfg (Config, optional): The testing/inference config
+
+ Returns:
+ Union[InstanceList | Tuple[InstanceList | PixelDataList]]: If
+ ``test_cfg['output_heatmap']==True``, return both pose and heatmap
+ prediction; otherwise only return the pose prediction.
+
+ The pose prediction is a list of ``InstanceData``, each contains
+ the following fields:
+
+ - keypoints (np.ndarray): predicted keypoint coordinates in
+ shape (num_instances, K, D) where K is the keypoint number
+ and D is the keypoint dimension
+ - keypoint_scores (np.ndarray): predicted keypoint scores in
+ shape (num_instances, K)
+
+ The heatmap prediction is a list of ``PixelData``, each contains
+ the following fields:
+
+ - heatmaps (Tensor): The predicted heatmaps in shape (K, h, w)
+ """
+ # multi-stage multi-unit batch heatmaps
+ if test_cfg.get('flip_test', False):
+ # TTA: flip test
+ assert isinstance(feats, list) and len(feats) == 2
+ flip_indices = batch_data_samples[0].metainfo['flip_indices']
+ _feats, _feats_flip = feats
+ _batch_heatmaps = self.forward(_feats)[-1]
+ _batch_heatmaps_flip = flip_heatmaps(
+ self.forward(_feats_flip)[-1],
+ flip_mode=test_cfg.get('flip_mode', 'heatmap'),
+ flip_indices=flip_indices,
+ shift_heatmap=test_cfg.get('shift_heatmap', False))
+ batch_heatmaps = (_batch_heatmaps + _batch_heatmaps_flip) * 0.5
+ else:
+ msmu_batch_heatmaps = self.forward(feats)
+ batch_heatmaps = msmu_batch_heatmaps[-1]
+
+ preds = self.decode(batch_heatmaps)
+
+ if test_cfg.get('output_heatmaps', False):
+ pred_fields = [
+ PixelData(heatmaps=hm) for hm in batch_heatmaps.detach()
+ ]
+ return preds, pred_fields
+ else:
+ return preds
+
+ def loss(self,
+ feats: MSMUFeatures,
+ batch_data_samples: OptSampleList,
+ train_cfg: OptConfigType = {}) -> dict:
+ """Calculate losses from a batch of inputs and data samples.
+
+ Note:
+ - batch_size: B
+ - num_output_heatmap_levels: L
+ - num_keypoints: K
+ - heatmaps height: H
+ - heatmaps weight: W
+            - heatmaps width: W
+
+ Args:
+ feats (Sequence[Sequence[Tensor]]): Feature maps from multiple
+ stages and units
+ batch_data_samples (List[:obj:`PoseDataSample`]): The Data
+ Samples. It usually includes information such as
+ `gt_instance_labels` and `gt_fields`.
+ train_cfg (Config, optional): The training config
+
+ Returns:
+ dict: A dictionary of loss components.
+ """
+ # multi-stage multi-unit predict heatmaps
+ msmu_pred_heatmaps = self.forward(feats)
+
+ keypoint_weights = torch.cat([
+ d.gt_instance_labels.keypoint_weights for d in batch_data_samples
+ ]) # shape: [B*N, L, K]
+
+ # calculate losses over multiple stages and multiple units
+ losses = dict()
+ for i in range(self.num_stages * self.num_units):
+ if isinstance(self.loss_module, nn.ModuleList):
+ # use different loss_module over different stages and units
+ loss_func = self.loss_module[i]
+ else:
+ # use the same loss_module over different stages and units
+ loss_func = self.loss_module
+
+ # select `gt_heatmaps` and `keypoint_weights` for different level
+ # according to `self.level_indices` to calculate loss
+ gt_heatmaps = torch.stack([
+ d.gt_fields[self.level_indices[i]].heatmaps
+ for d in batch_data_samples
+ ])
+ loss_i = loss_func(msmu_pred_heatmaps[i], gt_heatmaps,
+ keypoint_weights[:, self.level_indices[i]])
+
+ if 'loss_kpt' not in losses:
+ losses['loss_kpt'] = loss_i
+ else:
+ losses['loss_kpt'] += loss_i
+
+ # calculate accuracy
+ _, avg_acc, _ = pose_pck_accuracy(
+ output=to_numpy(msmu_pred_heatmaps[-1]),
+ target=to_numpy(gt_heatmaps),
+ mask=to_numpy(keypoint_weights[:, -1]) > 0)
+
+ acc_pose = torch.tensor(avg_acc, device=gt_heatmaps.device)
+ losses.update(acc_pose=acc_pose)
+
+ return losses
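+
+
+# NOTE: Illustrative sketch, not part of MMPose. It shows how the flattened
+# (stage, unit) output index is mapped by ``level_indices`` to the level of
+# ground-truth heatmaps used in ``MSPNHead.loss`` above. The uniform
+# assignment below is made up; real configs specify ``level_indices``
+# explicitly.
+def _demo_level_indices(num_stages: int = 4, num_units: int = 4):
+    level_indices = list(range(num_units)) * num_stages
+    assert len(level_indices) == num_stages * num_units
+    assignment = {
+        divmod(i, num_units): level
+        for i, level in enumerate(level_indices)
+    }
+    # e.g. assignment[(2, 1)] == 1: the output of stage 2, unit 1 is
+    # supervised by the ground-truth heatmaps at level 1
+    return assignment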
diff --git a/mmpose/models/heads/heatmap_heads/vipnas_head.py b/mmpose/models/heads/heatmap_heads/vipnas_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..949ee95b096124a162f6d9719446fa80bd26a201
--- /dev/null
+++ b/mmpose/models/heads/heatmap_heads/vipnas_head.py
@@ -0,0 +1,179 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional, Sequence, Union
+
+from mmcv.cnn import build_conv_layer, build_upsample_layer
+from torch import nn
+
+from mmpose.registry import KEYPOINT_CODECS, MODELS
+from mmpose.utils.typing import ConfigType, OptConfigType
+from .heatmap_head import HeatmapHead
+
+OptIntSeq = Optional[Sequence[int]]
+
+
+@MODELS.register_module()
+class ViPNASHead(HeatmapHead):
+ """ViPNAS heatmap head introduced in `ViPNAS`_ by Xu et al (2021). The head
+ is composed of a few deconvolutional layers followed by a convolutional
+ layer to generate heatmaps from low-resolution feature maps. Specifically,
+    different from the :class:`HeatmapHead` introduced in `Simple Baselines`_,
+ the group numbers in the deconvolutional layers are elastic and thus can be
+ optimized by neural architecture search (NAS).
+
+ Args:
+ in_channels (int | Sequence[int]): Number of channels in the input
+ feature map
+ out_channels (int): Number of channels in the output heatmap
+ deconv_out_channels (Sequence[int], optional): The output channel
+ number of each deconv layer. Defaults to ``(144, 144, 144)``
+ deconv_kernel_sizes (Sequence[int | tuple], optional): The kernel size
+ of each deconv layer. Each element should be either an integer for
+ both height and width dimensions, or a tuple of two integers for
+            the height and the width dimension respectively. Defaults to
+ ``(4, 4, 4)``
+ deconv_num_groups (Sequence[int], optional): The group number of each
+ deconv layer. Defaults to ``(16, 16, 16)``
+ conv_out_channels (Sequence[int], optional): The output channel number
+ of each intermediate conv layer. ``None`` means no intermediate
+ conv layer between deconv layers and the final conv layer.
+ Defaults to ``None``
+ conv_kernel_sizes (Sequence[int | tuple], optional): The kernel size
+ of each intermediate conv layer. Defaults to ``None``
+ final_layer (dict): Arguments of the final Conv2d layer.
+ Defaults to ``dict(kernel_size=1)``
+ loss (Config): Config of the keypoint loss. Defaults to use
+ :class:`KeypointMSELoss`
+ decoder (Config, optional): The decoder config that controls decoding
+ keypoint coordinates from the network output. Defaults to ``None``
+ init_cfg (Config, optional): Config to control the initialization. See
+ :attr:`default_init_cfg` for default settings
+
+ .. _`ViPNAS`: https://arxiv.org/abs/2105.10154
+ .. _`Simple Baselines`: https://arxiv.org/abs/1804.06208
+ """
+
+ _version = 2
+
+ def __init__(self,
+ in_channels: Union[int, Sequence[int]],
+ out_channels: int,
+ deconv_out_channels: OptIntSeq = (144, 144, 144),
+ deconv_kernel_sizes: OptIntSeq = (4, 4, 4),
+ deconv_num_groups: OptIntSeq = (16, 16, 16),
+ conv_out_channels: OptIntSeq = None,
+ conv_kernel_sizes: OptIntSeq = None,
+ final_layer: dict = dict(kernel_size=1),
+ loss: ConfigType = dict(
+ type='KeypointMSELoss', use_target_weight=True),
+ decoder: OptConfigType = None,
+ init_cfg: OptConfigType = None):
+
+ if init_cfg is None:
+ init_cfg = self.default_init_cfg
+
+ super(HeatmapHead, self).__init__(init_cfg)
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.loss_module = MODELS.build(loss)
+ if decoder is not None:
+ self.decoder = KEYPOINT_CODECS.build(decoder)
+ else:
+ self.decoder = None
+
+ if deconv_out_channels:
+ if deconv_kernel_sizes is None or len(deconv_out_channels) != len(
+ deconv_kernel_sizes):
+ raise ValueError(
+ '"deconv_out_channels" and "deconv_kernel_sizes" should '
+ 'be integer sequences with the same length. Got '
+ f'mismatched lengths {deconv_out_channels} and '
+ f'{deconv_kernel_sizes}')
+ if deconv_num_groups is None or len(deconv_out_channels) != len(
+ deconv_num_groups):
+ raise ValueError(
+ '"deconv_out_channels" and "deconv_num_groups" should '
+ 'be integer sequences with the same length. Got '
+ f'mismatched lengths {deconv_out_channels} and '
+ f'{deconv_num_groups}')
+
+ self.deconv_layers = self._make_deconv_layers(
+ in_channels=in_channels,
+ layer_out_channels=deconv_out_channels,
+ layer_kernel_sizes=deconv_kernel_sizes,
+ layer_groups=deconv_num_groups,
+ )
+ in_channels = deconv_out_channels[-1]
+ else:
+ self.deconv_layers = nn.Identity()
+
+ if conv_out_channels:
+ if conv_kernel_sizes is None or len(conv_out_channels) != len(
+ conv_kernel_sizes):
+ raise ValueError(
+ '"conv_out_channels" and "conv_kernel_sizes" should '
+ 'be integer sequences with the same length. Got '
+ f'mismatched lengths {conv_out_channels} and '
+ f'{conv_kernel_sizes}')
+
+ self.conv_layers = self._make_conv_layers(
+ in_channels=in_channels,
+ layer_out_channels=conv_out_channels,
+ layer_kernel_sizes=conv_kernel_sizes)
+ in_channels = conv_out_channels[-1]
+ else:
+ self.conv_layers = nn.Identity()
+
+ if final_layer is not None:
+ cfg = dict(
+ type='Conv2d',
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=1)
+ cfg.update(final_layer)
+ self.final_layer = build_conv_layer(cfg)
+ else:
+ self.final_layer = nn.Identity()
+
+ # Register the hook to automatically convert old version state dicts
+ self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook)
+
+ def _make_deconv_layers(self, in_channels: int,
+ layer_out_channels: Sequence[int],
+ layer_kernel_sizes: Sequence[int],
+ layer_groups: Sequence[int]) -> nn.Module:
+ """Create deconvolutional layers by given parameters."""
+
+ layers = []
+ for out_channels, kernel_size, groups in zip(layer_out_channels,
+ layer_kernel_sizes,
+ layer_groups):
+ if kernel_size == 4:
+ padding = 1
+ output_padding = 0
+ elif kernel_size == 3:
+ padding = 1
+ output_padding = 1
+ elif kernel_size == 2:
+ padding = 0
+ output_padding = 0
+ else:
+                raise ValueError(f'Unsupported kernel size {kernel_size} for '
+                                 'deconvolutional layers in '
+                                 f'{self.__class__.__name__}')
+ cfg = dict(
+ type='deconv',
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ groups=groups,
+ stride=2,
+ padding=padding,
+ output_padding=output_padding,
+ bias=False)
+ layers.append(build_upsample_layer(cfg))
+ layers.append(nn.BatchNorm2d(num_features=out_channels))
+ layers.append(nn.ReLU(inplace=True))
+ in_channels = out_channels
+
+ return nn.Sequential(*layers)
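+
+
+# NOTE: Illustrative sketch, not part of MMPose. It quantifies the effect of
+# ``deconv_num_groups`` used by ``_make_deconv_layers`` above: a grouped
+# transposed convolution holds ``1 / groups`` of the weights of its dense
+# counterpart, which is why the group number is a useful dimension for the
+# ViPNAS architecture search. The 144-channel, 16-group setting mirrors the
+# class defaults.
+def _demo_grouped_deconv_params():
+    from torch import nn
+
+    dense = nn.ConvTranspose2d(144, 144, 4, stride=2, padding=1, bias=False)
+    grouped = nn.ConvTranspose2d(
+        144, 144, 4, stride=2, padding=1, groups=16, bias=False)
+    n_dense = sum(p.numel() for p in dense.parameters())
+    n_grouped = sum(p.numel() for p in grouped.parameters())
+    assert n_grouped * 16 == n_dense
+    return n_dense, n_grouped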
diff --git a/mmpose/models/heads/hybrid_heads/__init__.py b/mmpose/models/heads/hybrid_heads/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..55d5a211c1c0e17a61968ca5a266a797587f8c83
--- /dev/null
+++ b/mmpose/models/heads/hybrid_heads/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .dekr_head import DEKRHead
+
+__all__ = [
+ 'DEKRHead',
+]
diff --git a/mmpose/models/heads/hybrid_heads/__pycache__/__init__.cpython-38.pyc b/mmpose/models/heads/hybrid_heads/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f1669ef83fbfcf02d9e985d2e9d1b491119bde95
Binary files /dev/null and b/mmpose/models/heads/hybrid_heads/__pycache__/__init__.cpython-38.pyc differ
diff --git a/mmpose/models/heads/hybrid_heads/__pycache__/dekr_head.cpython-38.pyc b/mmpose/models/heads/hybrid_heads/__pycache__/dekr_head.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c7c3940457ac8d16f34aae702d9b405bb89c96f8
Binary files /dev/null and b/mmpose/models/heads/hybrid_heads/__pycache__/dekr_head.cpython-38.pyc differ
diff --git a/mmpose/models/heads/hybrid_heads/dekr_head.py b/mmpose/models/heads/hybrid_heads/dekr_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..41f7cfc4ce9f7cbb061c18ba14a4847a67a07ffc
--- /dev/null
+++ b/mmpose/models/heads/hybrid_heads/dekr_head.py
@@ -0,0 +1,581 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Sequence, Tuple, Union
+
+import torch
+from mmcv.cnn import (ConvModule, build_activation_layer, build_conv_layer,
+ build_norm_layer)
+from mmengine.model import BaseModule, ModuleDict, Sequential
+from mmengine.structures import InstanceData, PixelData
+from torch import Tensor
+
+from mmpose.evaluation.functional.nms import nearby_joints_nms
+from mmpose.models.utils.tta import flip_heatmaps
+from mmpose.registry import KEYPOINT_CODECS, MODELS
+from mmpose.utils.tensor_utils import to_numpy
+from mmpose.utils.typing import (ConfigType, Features, InstanceList,
+ OptConfigType, OptSampleList, Predictions)
+from ...backbones.resnet import BasicBlock
+from ..base_head import BaseHead
+
+try:
+ from mmcv.ops import DeformConv2d
+ has_mmcv_full = True
+except (ImportError, ModuleNotFoundError):
+ has_mmcv_full = False
+
+
+class AdaptiveActivationBlock(BaseModule):
+ """Adaptive activation convolution block. "Bottom-up human pose estimation
+ via disentangled keypoint regression", CVPR'2021.
+
+ Args:
+ in_channels (int): Number of input channels
+ out_channels (int): Number of output channels
+ groups (int): Number of groups. Generally equal to the
+ number of joints.
+ norm_cfg (dict): Config for normalization layers.
+ act_cfg (dict): Config for activation layers.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ groups=1,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ init_cfg=None):
+ super(AdaptiveActivationBlock, self).__init__(init_cfg=init_cfg)
+
+ assert in_channels % groups == 0 and out_channels % groups == 0
+ self.groups = groups
+
+ regular_matrix = torch.tensor([[-1, -1, -1, 0, 0, 0, 1, 1, 1],
+ [-1, 0, 1, -1, 0, 1, -1, 0, 1],
+ [1, 1, 1, 1, 1, 1, 1, 1, 1]])
+ self.register_buffer('regular_matrix', regular_matrix.float())
+
+ self.transform_matrix_conv = build_conv_layer(
+ dict(type='Conv2d'),
+ in_channels=in_channels,
+ out_channels=6 * groups,
+ kernel_size=3,
+ padding=1,
+ groups=groups,
+ bias=True)
+
+ if has_mmcv_full:
+ self.adapt_conv = DeformConv2d(
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ padding=1,
+ bias=False,
+ groups=groups,
+ deform_groups=groups)
+ else:
+ raise ImportError('Please install the full version of mmcv '
+ 'to use `DeformConv2d`.')
+
+ self.norm = build_norm_layer(norm_cfg, out_channels)[1]
+ self.act = build_activation_layer(act_cfg)
+
+ def forward(self, x):
+ B, _, H, W = x.size()
+ residual = x
+
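+        # The 6 * groups channels predicted below are reshaped into a 2x3
+        # affine matrix per group and per location. Multiplying it with
+        # `regular_matrix` (homogeneous coordinates of the canonical 3x3
+        # kernel grid) yields 2x9 sampling offsets, which are flattened into
+        # the 18-channel-per-group layout expected by `DeformConv2d`.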
+ affine_matrix = self.transform_matrix_conv(x)
+ affine_matrix = affine_matrix.permute(0, 2, 3, 1).contiguous()
+ affine_matrix = affine_matrix.view(B, H, W, self.groups, 2, 3)
+ offset = torch.matmul(affine_matrix, self.regular_matrix)
+ offset = offset.transpose(4, 5).reshape(B, H, W, self.groups * 18)
+ offset = offset.permute(0, 3, 1, 2).contiguous()
+
+ x = self.adapt_conv(x, offset)
+ x = self.norm(x)
+ x = self.act(x + residual)
+
+ return x
+
+
+class RescoreNet(BaseModule):
+    """Rescore net used to predict the OKS score of a predicted pose. We use
+    the off-the-shelf rescore net pretrained by the authors of DEKR.
+
+    Args:
+        in_channels (int): Number of input feature channels
+        norm_indexes (Tuple[int]): Indices of the two torso links in the
+            skeleton, whose average length is used to normalize distances
+ init_cfg (dict, optional): Initialization config dict
+ """
+
+ def __init__(
+ self,
+ in_channels,
+ norm_indexes,
+ init_cfg=None,
+ ):
+ super(RescoreNet, self).__init__(init_cfg=init_cfg)
+
+ self.norm_indexes = norm_indexes
+
+ hidden = 256
+
+ self.l1 = torch.nn.Linear(in_channels, hidden, bias=True)
+ self.l2 = torch.nn.Linear(hidden, hidden, bias=True)
+ self.l3 = torch.nn.Linear(hidden, 1, bias=True)
+ self.relu = torch.nn.ReLU()
+
+ def make_feature(self, keypoints, keypoint_scores, skeleton):
+ """Combine original scores, joint distance and relative distance to
+ make feature.
+
+ Args:
+            keypoints (torch.Tensor): predicted keypoints
+            keypoint_scores (torch.Tensor): predicted keypoint scores
+ skeleton (list(list(int))): joint links
+
+ Returns:
+ torch.Tensor: feature for each instance
+ """
+ joint_1, joint_2 = zip(*skeleton)
+ num_link = len(skeleton)
+
+ joint_relate = (keypoints[:, joint_1] -
+ keypoints[:, joint_2])[:, :, :2]
+ joint_length = joint_relate.norm(dim=2)
+
+        # normalize by the torso length (mean of the two reference links)
+ normalize = (joint_length[:, self.norm_indexes[0]] +
+ joint_length[:, self.norm_indexes[1]]) / 2
+ normalize = normalize.unsqueeze(1).expand(normalize.size(0), num_link)
+ normalize = normalize.clamp(min=1).contiguous()
+
+ joint_length = joint_length / normalize[:, :]
+ joint_relate = joint_relate / normalize.unsqueeze(-1)
+ joint_relate = joint_relate.flatten(1)
+
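+        # The concatenated feature below has 3 * num_link + K channels per
+        # instance: flattened relative offsets (2 * num_link), normalized
+        # link lengths (num_link) and the raw keypoint scores (K).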
+ feature = torch.cat((joint_relate, joint_length, keypoint_scores),
+ dim=1).float()
+ return feature
+
+ def forward(self, keypoints, keypoint_scores, skeleton):
+ feature = self.make_feature(keypoints, keypoint_scores, skeleton)
+ x = self.relu(self.l1(feature))
+ x = self.relu(self.l2(x))
+ x = self.l3(x)
+ return x.squeeze(1)
+
+
+@MODELS.register_module()
+class DEKRHead(BaseHead):
+ """DisEntangled Keypoint Regression head introduced in `Bottom-up human
+ pose estimation via disentangled keypoint regression`_ by Geng et al
+ (2021). The head is composed of a heatmap branch and a displacement branch.
+
+ Args:
+ in_channels (int | Sequence[int]): Number of channels in the input
+ feature map
+        num_keypoints (int): Number of keypoints
+        num_heatmap_filters (int): Number of filters for the heatmap branch.
+            Defaults to 32
+        num_displacement_filters_per_keypoint (int): Number of filters for
+            each keypoint in the displacement branch. Defaults to 15
+ heatmap_loss (Config): Config of the heatmap loss. Defaults to use
+ :class:`KeypointMSELoss`
+ displacement_loss (Config): Config of the displacement regression loss.
+ Defaults to use :class:`SoftWeightSmoothL1Loss`
+ decoder (Config, optional): The decoder config that controls decoding
+ keypoint coordinates from the network output. Defaults to ``None``
+ rescore_cfg (Config, optional): The config for rescore net which
+ estimates OKS via predicted keypoints and keypoint scores.
+ Defaults to ``None``
+ init_cfg (Config, optional): Config to control the initialization. See
+ :attr:`default_init_cfg` for default settings
+
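+    Example (an illustrative sketch; the channel number and input size are
+    arbitrary assumptions, and building the head requires
+    ``mmcv.ops.DeformConv2d``)::
+
+        >>> head = DEKRHead(in_channels=32, num_keypoints=17)
+        >>> feats = (torch.rand(1, 32, 128, 128),)
+        >>> heatmaps, displacements = head.forward(feats)
+        >>> # heatmaps: (1, 1 + 17, 128, 128)
+        >>> # displacements: (1, 2 * 17, 128, 128)
+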
+ .. _`Bottom-up human pose estimation via disentangled keypoint regression`:
+ https://arxiv.org/abs/2104.02300
+ """
+
+ _version = 2
+
+ def __init__(self,
+ in_channels: Union[int, Sequence[int]],
+ num_keypoints: int,
+ num_heatmap_filters: int = 32,
+ num_displacement_filters_per_keypoint: int = 15,
+ heatmap_loss: ConfigType = dict(
+ type='KeypointMSELoss', use_target_weight=True),
+ displacement_loss: ConfigType = dict(
+ type='SoftWeightSmoothL1Loss',
+ use_target_weight=True,
+ supervise_empty=False),
+ decoder: OptConfigType = None,
+ rescore_cfg: OptConfigType = None,
+ init_cfg: OptConfigType = None):
+
+ if init_cfg is None:
+ init_cfg = self.default_init_cfg
+
+ super().__init__(init_cfg)
+
+ self.in_channels = in_channels
+ self.num_keypoints = num_keypoints
+
+ # build heatmap branch
+ self.heatmap_conv_layers = self._make_heatmap_conv_layers(
+ in_channels=in_channels,
+ out_channels=1 + num_keypoints,
+ num_filters=num_heatmap_filters,
+ )
+
+ # build displacement branch
+ self.displacement_conv_layers = self._make_displacement_conv_layers(
+ in_channels=in_channels,
+ out_channels=2 * num_keypoints,
+ num_filters=num_keypoints * num_displacement_filters_per_keypoint,
+ groups=num_keypoints)
+
+ # build losses
+ self.loss_module = ModuleDict(
+ dict(
+ heatmap=MODELS.build(heatmap_loss),
+ displacement=MODELS.build(displacement_loss),
+ ))
+
+ # build decoder
+ if decoder is not None:
+ self.decoder = KEYPOINT_CODECS.build(decoder)
+ else:
+ self.decoder = None
+
+ # build rescore net
+ if rescore_cfg is not None:
+ self.rescore_net = RescoreNet(**rescore_cfg)
+ else:
+ self.rescore_net = None
+
+ # Register the hook to automatically convert old version state dicts
+ self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook)
+
+ @property
+ def default_init_cfg(self):
+ init_cfg = [
+ dict(
+ type='Normal', layer=['Conv2d', 'ConvTranspose2d'], std=0.001),
+ dict(type='Constant', layer='BatchNorm2d', val=1)
+ ]
+ return init_cfg
+
+ def _make_heatmap_conv_layers(self, in_channels: int, out_channels: int,
+ num_filters: int):
+ """Create convolutional layers of heatmap branch by given
+ parameters."""
+ layers = [
+ ConvModule(
+ in_channels=in_channels,
+ out_channels=num_filters,
+ kernel_size=1,
+ norm_cfg=dict(type='BN')),
+ BasicBlock(num_filters, num_filters),
+ build_conv_layer(
+ dict(type='Conv2d'),
+ in_channels=num_filters,
+ out_channels=out_channels,
+ kernel_size=1),
+ ]
+
+ return Sequential(*layers)
+
+ def _make_displacement_conv_layers(self, in_channels: int,
+ out_channels: int, num_filters: int,
+ groups: int):
+ """Create convolutional layers of displacement branch by given
+ parameters."""
+ layers = [
+ ConvModule(
+ in_channels=in_channels,
+ out_channels=num_filters,
+ kernel_size=1,
+ norm_cfg=dict(type='BN')),
+ AdaptiveActivationBlock(num_filters, num_filters, groups=groups),
+ AdaptiveActivationBlock(num_filters, num_filters, groups=groups),
+ build_conv_layer(
+ dict(type='Conv2d'),
+ in_channels=num_filters,
+ out_channels=out_channels,
+ kernel_size=1,
+ groups=groups)
+ ]
+
+ return Sequential(*layers)
+
+ def forward(self, feats: Tuple[Tensor]) -> Tensor:
+ """Forward the network. The input is multi scale feature maps and the
+ output is a tuple of heatmap and displacement.
+
+ Args:
+ feats (Tuple[Tensor]): Multi scale feature maps.
+
+ Returns:
+ Tuple[Tensor]: output heatmap and displacement.
+ """
+ x = feats[-1]
+
+ heatmaps = self.heatmap_conv_layers(x)
+ displacements = self.displacement_conv_layers(x)
+
+ return heatmaps, displacements
+
+ def loss(self,
+ feats: Tuple[Tensor],
+ batch_data_samples: OptSampleList,
+ train_cfg: ConfigType = {}) -> dict:
+ """Calculate losses from a batch of inputs and data samples.
+
+ Args:
+ feats (Tuple[Tensor]): The multi-stage features
+ batch_data_samples (List[:obj:`PoseDataSample`]): The batch
+ data samples
+ train_cfg (dict): The runtime config for training process.
+ Defaults to {}
+
+ Returns:
+ dict: A dictionary of losses.
+ """
+ pred_heatmaps, pred_displacements = self.forward(feats)
+ gt_heatmaps = torch.stack(
+ [d.gt_fields.heatmaps for d in batch_data_samples])
+ heatmap_weights = torch.stack(
+ [d.gt_fields.heatmap_weights for d in batch_data_samples])
+ gt_displacements = torch.stack(
+ [d.gt_fields.displacements for d in batch_data_samples])
+ displacement_weights = torch.stack(
+ [d.gt_fields.displacement_weights for d in batch_data_samples])
+
+ if 'heatmap_mask' in batch_data_samples[0].gt_fields.keys():
+ heatmap_mask = torch.stack(
+ [d.gt_fields.heatmap_mask for d in batch_data_samples])
+ else:
+ heatmap_mask = None
+
+ # calculate losses
+ losses = dict()
+ heatmap_loss = self.loss_module['heatmap'](pred_heatmaps, gt_heatmaps,
+ heatmap_weights,
+ heatmap_mask)
+ displacement_loss = self.loss_module['displacement'](
+ pred_displacements, gt_displacements, displacement_weights)
+
+ losses.update({
+ 'loss/heatmap': heatmap_loss,
+ 'loss/displacement': displacement_loss,
+ })
+
+ return losses
+
+ def predict(self,
+ feats: Features,
+ batch_data_samples: OptSampleList,
+ test_cfg: ConfigType = {}) -> Predictions:
+ """Predict results from features.
+
+ Args:
+ feats (Tuple[Tensor] | List[Tuple[Tensor]]): The multi-stage
+ features (or multiple multi-scale features in TTA)
+ batch_data_samples (List[:obj:`PoseDataSample`]): The batch
+ data samples
+ test_cfg (dict): The runtime config for testing process. Defaults
+ to {}
+
+ Returns:
+ Union[InstanceList | Tuple[InstanceList | PixelDataList]]: If
+            ``test_cfg['output_heatmaps']==True``, return both pose and heatmap
+ prediction; otherwise only return the pose prediction.
+
+ The pose prediction is a list of ``InstanceData``, each contains
+ the following fields:
+
+ - keypoints (np.ndarray): predicted keypoint coordinates in
+ shape (num_instances, K, D) where K is the keypoint number
+ and D is the keypoint dimension
+ - keypoint_scores (np.ndarray): predicted keypoint scores in
+ shape (num_instances, K)
+
+ The heatmap prediction is a list of ``PixelData``, each contains
+ the following fields:
+
+ - heatmaps (Tensor): The predicted heatmaps in shape (1, h, w)
+ or (K+1, h, w) if keypoint heatmaps are predicted
+ - displacements (Tensor): The predicted displacement fields
+ in shape (K*2, h, w)
+ """
+
+ assert len(batch_data_samples) == 1, f'DEKRHead only supports ' \
+ f'prediction with batch_size 1, but got {len(batch_data_samples)}'
+
+ multiscale_test = test_cfg.get('multiscale_test', False)
+ flip_test = test_cfg.get('flip_test', False)
+ metainfo = batch_data_samples[0].metainfo
+ aug_scales = [1]
+
+ if not multiscale_test:
+ feats = [feats]
+ else:
+ aug_scales = aug_scales + metainfo['aug_scales']
+
+ heatmaps, displacements = [], []
+ for feat, s in zip(feats, aug_scales):
+ if flip_test:
+ assert isinstance(feat, list) and len(feat) == 2
+ flip_indices = metainfo['flip_indices']
+ _feat, _feat_flip = feat
+ _heatmaps, _displacements = self.forward(_feat)
+ _heatmaps_flip, _displacements_flip = self.forward(_feat_flip)
+
+ _heatmaps_flip = flip_heatmaps(
+ _heatmaps_flip,
+ flip_mode='heatmap',
+ flip_indices=flip_indices + [len(flip_indices)],
+ shift_heatmap=test_cfg.get('shift_heatmap', False))
+ _heatmaps = (_heatmaps + _heatmaps_flip) / 2.0
+
+ _displacements_flip = flip_heatmaps(
+ _displacements_flip,
+ flip_mode='offset',
+ flip_indices=flip_indices,
+ shift_heatmap=False)
+
+                # coordinate amendment: shift the flipped x-displacements to
+                # compensate for the sub-pixel misalignment that horizontal
+                # flipping introduces at this scale
+ x_scale_factor = s * (
+ metainfo['input_size'][0] / _heatmaps.shape[-1])
+ _displacements_flip[:, ::2] += (x_scale_factor - 1) / (
+ x_scale_factor)
+ _displacements = (_displacements + _displacements_flip) / 2.0
+
+ else:
+ _heatmaps, _displacements = self.forward(feat)
+
+ heatmaps.append(_heatmaps)
+ displacements.append(_displacements)
+
+ preds = self.decode(heatmaps, displacements, test_cfg, metainfo)
+
+ if test_cfg.get('output_heatmaps', False):
+ heatmaps = [hm.detach() for hm in heatmaps]
+ displacements = [dm.detach() for dm in displacements]
+ B = heatmaps[0].shape[0]
+ pred_fields = []
+ for i in range(B):
+ pred_fields.append(
+ PixelData(
+ heatmaps=heatmaps[0][i],
+ displacements=displacements[0][i]))
+ return preds, pred_fields
+ else:
+ return preds
+
+ def decode(self,
+ heatmaps: Tuple[Tensor],
+ displacements: Tuple[Tensor],
+ test_cfg: ConfigType = {},
+ metainfo: dict = {}) -> InstanceList:
+ """Decode keypoints from outputs.
+
+ Args:
+ heatmaps (Tuple[Tensor]): The output heatmaps inferred from one
+ image or multi-scale images.
+ displacements (Tuple[Tensor]): The output displacement fields
+ inferred from one image or multi-scale images.
+ test_cfg (dict): The runtime config for testing process. Defaults
+ to {}
+ metainfo (dict): The metainfo of test dataset. Defaults to {}
+
+ Returns:
+ List[InstanceData]: A list of InstanceData, each contains the
+ decoded pose information of the instances of one data sample.
+ """
+
+ if self.decoder is None:
+ raise RuntimeError(
+ f'The decoder has not been set in {self.__class__.__name__}. '
+ 'Please set the decoder configs in the init parameters to '
+ 'enable head methods `head.predict()` and `head.decode()`')
+
+ multiscale_test = test_cfg.get('multiscale_test', False)
+ skeleton = metainfo.get('skeleton_links', None)
+
+ preds = []
+ batch_size = heatmaps[0].shape[0]
+
+ for b in range(batch_size):
+ if multiscale_test:
+ raise NotImplementedError
+ else:
+ keypoints, (root_scores,
+ keypoint_scores) = self.decoder.decode(
+ heatmaps[0][b], displacements[0][b])
+
+ # rescore each instance
+ if self.rescore_net is not None and skeleton and len(
+ keypoints) > 0:
+ instance_scores = self.rescore_net(keypoints, keypoint_scores,
+ skeleton)
+ instance_scores[torch.isnan(instance_scores)] = 0
+ root_scores = root_scores * instance_scores
+
+ # nms
+ keypoints, keypoint_scores = to_numpy((keypoints, keypoint_scores))
+ scores = to_numpy(root_scores)[..., None] * keypoint_scores
+ if len(keypoints) > 0 and test_cfg.get('nms_dist_thr', 0) > 0:
+ kpts_db = []
+ for i in range(len(keypoints)):
+ kpts_db.append(
+ dict(keypoints=keypoints[i], score=keypoint_scores[i]))
+ keep_instance_inds = nearby_joints_nms(
+ kpts_db,
+ test_cfg['nms_dist_thr'],
+ test_cfg.get('nms_joints_thr', None),
+ score_per_joint=True,
+ max_dets=test_cfg.get('max_num_people', 30))
+ keypoints = keypoints[keep_instance_inds]
+ scores = scores[keep_instance_inds]
+
+ # pack outputs
+ preds.append(
+ InstanceData(keypoints=keypoints, keypoint_scores=scores))
+
+ return preds
+
+ def _load_state_dict_pre_hook(self, state_dict, prefix, local_meta, *args,
+ **kwargs):
+ """A hook function to convert old-version state dict of
+ :class:`DEKRHead` (before MMPose v1.0.0) to a compatible format
+ of :class:`DEKRHead`.
+
+ The hook will be automatically registered during initialization.
+ """
+ version = local_meta.get('version', None)
+ if version and version >= self._version:
+ return
+
+ # convert old-version state dict
+ keys = list(state_dict.keys())
+ for k in keys:
+ if 'offset_conv_layer' in k:
+ v = state_dict.pop(k)
+ k = k.replace('offset_conv_layers', 'displacement_conv_layers')
+ if 'displacement_conv_layers.3.' in k:
+ # the source and target of displacement vectors are
+ # opposite between two versions.
+ v = -v
+ state_dict[k] = v
+
+ if 'heatmap_conv_layers.2' in k:
+ # root heatmap is at the first/last channel of the
+ # heatmap tensor in MMPose v0.x/1.x, respectively.
+ v = state_dict.pop(k)
+ state_dict[k] = torch.cat((v[1:], v[:1]))
+
+ if 'rescore_net' in k:
+ v = state_dict.pop(k)
+ k = k.replace('rescore_net', 'head.rescore_net')
+ state_dict[k] = v
diff --git a/mmpose/models/heads/regression_heads/__init__.py b/mmpose/models/heads/regression_heads/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2a5027b1b6f33f488c268c3bde0142681f4ac4c
--- /dev/null
+++ b/mmpose/models/heads/regression_heads/__init__.py
@@ -0,0 +1,12 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .dsnt_head import DSNTHead
+from .integral_regression_head import IntegralRegressionHead
+from .regression_head import RegressionHead
+from .rle_head import RLEHead
+
+__all__ = [
+ 'RegressionHead',
+ 'IntegralRegressionHead',
+ 'DSNTHead',
+ 'RLEHead',
+]
diff --git a/mmpose/models/heads/regression_heads/dsnt_head.py b/mmpose/models/heads/regression_heads/dsnt_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..3bd49e385db31c996de086419285e2f5fa7748b3
--- /dev/null
+++ b/mmpose/models/heads/regression_heads/dsnt_head.py
@@ -0,0 +1,146 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional, Sequence, Tuple, Union
+
+import numpy as np
+import torch
+from mmengine.logging import MessageHub
+from torch import Tensor
+
+from mmpose.evaluation.functional import keypoint_pck_accuracy
+from mmpose.registry import MODELS
+from mmpose.utils.tensor_utils import to_numpy
+from mmpose.utils.typing import ConfigType, OptConfigType, OptSampleList
+from .integral_regression_head import IntegralRegressionHead
+
+OptIntSeq = Optional[Sequence[int]]
+
+
+@MODELS.register_module()
+class DSNTHead(IntegralRegressionHead):
+    """Top-down integral regression head introduced in `DSNT`_ by Nibali et
+    al. (2018). The head contains a differentiable spatial to numerical
+    transform (DSNT) layer that performs a soft-argmax operation on the
+    predicted heatmaps to regress the coordinates.
+
+ This head is used for algorithms that require supervision of heatmaps
+ in `DSNT` approach.
+
+ Args:
+ in_channels (int | sequence[int]): Number of input channels
+ in_featuremap_size (int | sequence[int]): Size of input feature map
+ num_joints (int): Number of joints
+ lambda_t (int): Discard heatmap-based loss when current
+ epoch > lambda_t. Defaults to -1.
+ debias (bool): Whether to remove the bias of Integral Pose Regression.
+ see `Removing the Bias of Integral Pose Regression`_ by Gu et al
+ (2021). Defaults to ``False``.
+ beta (float): A smoothing parameter in softmax. Defaults to ``1.0``.
+ deconv_out_channels (sequence[int]): The output channel number of each
+ deconv layer. Defaults to ``(256, 256, 256)``
+ deconv_kernel_sizes (sequence[int | tuple], optional): The kernel size
+ of each deconv layer. Each element should be either an integer for
+ both height and width dimensions, or a tuple of two integers for
+            the height and the width dimension, respectively. Defaults to
+ ``(4, 4, 4)``
+ conv_out_channels (sequence[int], optional): The output channel number
+ of each intermediate conv layer. ``None`` means no intermediate
+ conv layer between deconv layers and the final conv layer.
+ Defaults to ``None``
+ conv_kernel_sizes (sequence[int | tuple], optional): The kernel size
+ of each intermediate conv layer. Defaults to ``None``
+ final_layer (dict): Arguments of the final Conv2d layer.
+ Defaults to ``dict(kernel_size=1)``
+ loss (Config): Config for keypoint loss. Defaults to use
+ :class:`DSNTLoss`
+ decoder (Config, optional): The decoder config that controls decoding
+ keypoint coordinates from the network output. Defaults to ``None``
+ init_cfg (Config, optional): Config to control the initialization. See
+ :attr:`default_init_cfg` for default settings
+
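+    Example (illustrative values only; the input channels, feature-map size
+    and ``lambda_t`` below are assumptions, not recommended settings)::
+
+        >>> head = DSNTHead(
+        ...     in_channels=32,
+        ...     in_featuremap_size=(8, 8),
+        ...     num_joints=17,
+        ...     lambda_t=40)  # drop the heatmap (JS) loss term from epoch 40
+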
+ .. _`DSNT`: https://arxiv.org/abs/1801.07372
+ """
+
+ _version = 2
+
+ def __init__(self,
+ in_channels: Union[int, Sequence[int]],
+ in_featuremap_size: Tuple[int, int],
+ num_joints: int,
+ lambda_t: int = -1,
+ debias: bool = False,
+ beta: float = 1.0,
+ deconv_out_channels: OptIntSeq = (256, 256, 256),
+ deconv_kernel_sizes: OptIntSeq = (4, 4, 4),
+ conv_out_channels: OptIntSeq = None,
+ conv_kernel_sizes: OptIntSeq = None,
+ final_layer: dict = dict(kernel_size=1),
+ loss: ConfigType = dict(
+ type='MultipleLossWrapper',
+ losses=[
+ dict(type='SmoothL1Loss', use_target_weight=True),
+ dict(type='JSDiscretLoss', use_target_weight=True)
+ ]),
+ decoder: OptConfigType = None,
+ init_cfg: OptConfigType = None):
+
+ super().__init__(
+ in_channels=in_channels,
+ in_featuremap_size=in_featuremap_size,
+ num_joints=num_joints,
+ debias=debias,
+ beta=beta,
+ deconv_out_channels=deconv_out_channels,
+ deconv_kernel_sizes=deconv_kernel_sizes,
+ conv_out_channels=conv_out_channels,
+ conv_kernel_sizes=conv_kernel_sizes,
+ final_layer=final_layer,
+ loss=loss,
+ decoder=decoder,
+ init_cfg=init_cfg)
+
+ self.lambda_t = lambda_t
+
+ def loss(self,
+ inputs: Tuple[Tensor],
+ batch_data_samples: OptSampleList,
+ train_cfg: ConfigType = {}) -> dict:
+ """Calculate losses from a batch of inputs and data samples."""
+
+ pred_coords, pred_heatmaps = self.forward(inputs)
+ keypoint_labels = torch.cat(
+ [d.gt_instance_labels.keypoint_labels for d in batch_data_samples])
+ keypoint_weights = torch.cat([
+ d.gt_instance_labels.keypoint_weights for d in batch_data_samples
+ ])
+ gt_heatmaps = torch.stack(
+ [d.gt_fields.heatmaps for d in batch_data_samples])
+
+ input_list = [pred_coords, pred_heatmaps]
+ target_list = [keypoint_labels, gt_heatmaps]
+ # calculate losses
+ losses = dict()
+
+ loss_list = self.loss_module(input_list, target_list, keypoint_weights)
+
+ loss = loss_list[0] + loss_list[1]
+
+ if self.lambda_t > 0:
+ mh = MessageHub.get_current_instance()
+ cur_epoch = mh.get_info('epoch')
+ if cur_epoch >= self.lambda_t:
+ loss = loss_list[0]
+
+ losses.update(loss_kpt=loss)
+
+ # calculate accuracy
+ _, avg_acc, _ = keypoint_pck_accuracy(
+ pred=to_numpy(pred_coords),
+ gt=to_numpy(keypoint_labels),
+ mask=to_numpy(keypoint_weights) > 0,
+ thr=0.05,
+ norm_factor=np.ones((pred_coords.size(0), 2), dtype=np.float32))
+
+ acc_pose = torch.tensor(avg_acc, device=keypoint_labels.device)
+ losses.update(acc_pose=acc_pose)
+
+ return losses
diff --git a/mmpose/models/heads/regression_heads/integral_regression_head.py b/mmpose/models/heads/regression_heads/integral_regression_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..9046d94ad4318a19a3037f839ee054a445c80c68
--- /dev/null
+++ b/mmpose/models/heads/regression_heads/integral_regression_head.py
@@ -0,0 +1,339 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+from typing import Optional, Sequence, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from mmcv.cnn import build_conv_layer
+from mmengine.structures import PixelData
+from torch import Tensor, nn
+
+from mmpose.evaluation.functional import keypoint_pck_accuracy
+from mmpose.models.utils.tta import flip_coordinates, flip_heatmaps
+from mmpose.registry import KEYPOINT_CODECS, MODELS
+from mmpose.utils.tensor_utils import to_numpy
+from mmpose.utils.typing import (ConfigType, OptConfigType, OptSampleList,
+ Predictions)
+from .. import HeatmapHead
+from ..base_head import BaseHead
+
+OptIntSeq = Optional[Sequence[int]]
+
+
+@MODELS.register_module()
+class IntegralRegressionHead(BaseHead):
+    """Top-down integral regression head introduced in `IPR`_ by Xiao et
+    al. (2018). The head contains a differentiable spatial to numerical
+    transform (DSNT) layer that performs a soft-argmax operation on the
+    predicted heatmaps to regress the coordinates.
+
+ This head is used for algorithms that only supervise the coordinates.
+
+ Args:
+ in_channels (int | sequence[int]): Number of input channels
+ in_featuremap_size (int | sequence[int]): Size of input feature map
+ num_joints (int): Number of joints
+ debias (bool): Whether to remove the bias of Integral Pose Regression.
+ see `Removing the Bias of Integral Pose Regression`_ by Gu et al
+ (2021). Defaults to ``False``.
+ beta (float): A smoothing parameter in softmax. Defaults to ``1.0``.
+ deconv_out_channels (sequence[int]): The output channel number of each
+ deconv layer. Defaults to ``(256, 256, 256)``
+ deconv_kernel_sizes (sequence[int | tuple], optional): The kernel size
+ of each deconv layer. Each element should be either an integer for
+ both height and width dimensions, or a tuple of two integers for
+            the height and the width dimension, respectively. Defaults to
+ ``(4, 4, 4)``
+ conv_out_channels (sequence[int], optional): The output channel number
+ of each intermediate conv layer. ``None`` means no intermediate
+ conv layer between deconv layers and the final conv layer.
+ Defaults to ``None``
+ conv_kernel_sizes (sequence[int | tuple], optional): The kernel size
+ of each intermediate conv layer. Defaults to ``None``
+ final_layer (dict): Arguments of the final Conv2d layer.
+ Defaults to ``dict(kernel_size=1)``
+ loss (Config): Config for keypoint loss. Defaults to use
+ :class:`SmoothL1Loss`
+ decoder (Config, optional): The decoder config that controls decoding
+ keypoint coordinates from the network output. Defaults to ``None``
+ init_cfg (Config, optional): Config to control the initialization. See
+ :attr:`default_init_cfg` for default settings
+
+ .. _`IPR`: https://arxiv.org/abs/1711.08229
+ .. _`Debias`:
+ """
+
+ _version = 2
+
+ def __init__(self,
+ in_channels: Union[int, Sequence[int]],
+ in_featuremap_size: Tuple[int, int],
+ num_joints: int,
+ debias: bool = False,
+ beta: float = 1.0,
+ deconv_out_channels: OptIntSeq = (256, 256, 256),
+ deconv_kernel_sizes: OptIntSeq = (4, 4, 4),
+ conv_out_channels: OptIntSeq = None,
+ conv_kernel_sizes: OptIntSeq = None,
+ final_layer: dict = dict(kernel_size=1),
+ loss: ConfigType = dict(
+ type='SmoothL1Loss', use_target_weight=True),
+ decoder: OptConfigType = None,
+ init_cfg: OptConfigType = None):
+
+ if init_cfg is None:
+ init_cfg = self.default_init_cfg
+
+ super().__init__(init_cfg)
+
+ self.in_channels = in_channels
+ self.num_joints = num_joints
+ self.debias = debias
+ self.beta = beta
+ self.loss_module = MODELS.build(loss)
+ if decoder is not None:
+ self.decoder = KEYPOINT_CODECS.build(decoder)
+ else:
+ self.decoder = None
+
+ num_deconv = len(deconv_out_channels) if deconv_out_channels else 0
+ if num_deconv != 0:
+
+ self.heatmap_size = tuple(
+ [s * (2**num_deconv) for s in in_featuremap_size])
+
+ # deconv layers + 1x1 conv
+ self.simplebaseline_head = HeatmapHead(
+ in_channels=in_channels,
+ out_channels=num_joints,
+ deconv_out_channels=deconv_out_channels,
+ deconv_kernel_sizes=deconv_kernel_sizes,
+ conv_out_channels=conv_out_channels,
+ conv_kernel_sizes=conv_kernel_sizes,
+ final_layer=final_layer)
+
+ if final_layer is not None:
+ in_channels = num_joints
+ else:
+ in_channels = deconv_out_channels[-1]
+
+ else:
+ self.simplebaseline_head = None
+
+ if final_layer is not None:
+ cfg = dict(
+ type='Conv2d',
+ in_channels=in_channels,
+ out_channels=num_joints,
+ kernel_size=1)
+ cfg.update(final_layer)
+ self.final_layer = build_conv_layer(cfg)
+ else:
+ self.final_layer = None
+
+ self.heatmap_size = in_featuremap_size
+
+ if isinstance(in_channels, list):
+ raise ValueError(
+ f'{self.__class__.__name__} does not support selecting '
+ 'multiple input features.')
+
+ W, H = self.heatmap_size
+ self.linspace_x = torch.arange(0.0, 1.0 * W, 1).reshape(1, 1, 1, W) / W
+ self.linspace_y = torch.arange(0.0, 1.0 * H, 1).reshape(1, 1, H, 1) / H
+
+ self.linspace_x = nn.Parameter(self.linspace_x, requires_grad=False)
+ self.linspace_y = nn.Parameter(self.linspace_y, requires_grad=False)
+
+ self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook)
+
+ def _linear_expectation(self, heatmaps: Tensor,
+ linspace: Tensor) -> Tensor:
+ """Calculate linear expectation."""
+
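+        # Soft-argmax: treat each normalized heatmap as a probability map
+        # over locations and take the expectation of the coordinate grid,
+        # i.e. E[x] = sum_i p_i * x_i with x_i in [0, 1).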
+ B, N, _, _ = heatmaps.shape
+ heatmaps = heatmaps.mul(linspace).reshape(B, N, -1)
+ expectation = torch.sum(heatmaps, dim=2, keepdim=True)
+
+ return expectation
+
+ def _flat_softmax(self, featmaps: Tensor) -> Tensor:
+        """Normalize the feature maps with a channel-wise softmax over the
+        spatial locations."""
+
+ _, N, H, W = featmaps.shape
+
+ featmaps = featmaps.reshape(-1, N, H * W)
+ heatmaps = F.softmax(featmaps, dim=2)
+
+ return heatmaps.reshape(-1, N, H, W)
+
+ def forward(self, feats: Tuple[Tensor]) -> Union[Tensor, Tuple[Tensor]]:
+ """Forward the network. The input is multi scale feature maps and the
+        output is the coordinates together with the normalized heatmaps.
+
+ Args:
+ feats (Tuple[Tensor]): Multi scale feature maps.
+
+ Returns:
+            Tuple[Tensor, Tensor]: The predicted coordinates and the
+            normalized heatmaps.
+ """
+ if self.simplebaseline_head is None:
+ feats = feats[-1]
+ if self.final_layer is not None:
+ feats = self.final_layer(feats)
+ else:
+ feats = self.simplebaseline_head(feats)
+
+ heatmaps = self._flat_softmax(feats * self.beta)
+
+ pred_x = self._linear_expectation(heatmaps, self.linspace_x)
+ pred_y = self._linear_expectation(heatmaps, self.linspace_y)
+
+ if self.debias:
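+            # Debias correction (Gu et al., 2021): the softmax-based
+            # expectation is biased towards the map center; with C being the
+            # per-channel sum of exp-activations, the estimate is corrected
+            # as x <- C / (C - 1) * (x - 1 / (2 * C)).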
+ B, N, H, W = feats.shape
+ C = feats.reshape(B, N, H * W).exp().sum(dim=2).reshape(B, N, 1)
+ pred_x = C / (C - 1) * (pred_x - 1 / (2 * C))
+ pred_y = C / (C - 1) * (pred_y - 1 / (2 * C))
+
+ coords = torch.cat([pred_x, pred_y], dim=-1)
+ return coords, heatmaps
+
+ def predict(self,
+ feats: Tuple[Tensor],
+ batch_data_samples: OptSampleList,
+ test_cfg: ConfigType = {}) -> Predictions:
+ """Predict results from features.
+
+ Args:
+ feats (Tuple[Tensor] | List[Tuple[Tensor]]): The multi-stage
+ features (or multiple multi-stage features in TTA)
+ batch_data_samples (List[:obj:`PoseDataSample`]): The batch
+ data samples
+ test_cfg (dict): The runtime config for testing process. Defaults
+ to {}
+
+ Returns:
+ Union[InstanceList | Tuple[InstanceList | PixelDataList]]: If
+            ``test_cfg['output_heatmaps']==True``, return both pose and heatmap
+ prediction; otherwise only return the pose prediction.
+
+ The pose prediction is a list of ``InstanceData``, each contains
+ the following fields:
+
+ - keypoints (np.ndarray): predicted keypoint coordinates in
+ shape (num_instances, K, D) where K is the keypoint number
+ and D is the keypoint dimension
+ - keypoint_scores (np.ndarray): predicted keypoint scores in
+ shape (num_instances, K)
+
+ The heatmap prediction is a list of ``PixelData``, each contains
+ the following fields:
+
+ - heatmaps (Tensor): The predicted heatmaps in shape (K, h, w)
+ """
+
+ if test_cfg.get('flip_test', False):
+ # TTA: flip test -> feats = [orig, flipped]
+ assert isinstance(feats, list) and len(feats) == 2
+ flip_indices = batch_data_samples[0].metainfo['flip_indices']
+ input_size = batch_data_samples[0].metainfo['input_size']
+ _feats, _feats_flip = feats
+
+ _batch_coords, _batch_heatmaps = self.forward(_feats)
+
+ _batch_coords_flip, _batch_heatmaps_flip = self.forward(
+ _feats_flip)
+ _batch_coords_flip = flip_coordinates(
+ _batch_coords_flip,
+ flip_indices=flip_indices,
+ shift_coords=test_cfg.get('shift_coords', True),
+ input_size=input_size)
+ _batch_heatmaps_flip = flip_heatmaps(
+ _batch_heatmaps_flip,
+ flip_mode='heatmap',
+ flip_indices=flip_indices,
+ shift_heatmap=test_cfg.get('shift_heatmap', False))
+
+ batch_coords = (_batch_coords + _batch_coords_flip) * 0.5
+ batch_heatmaps = (_batch_heatmaps + _batch_heatmaps_flip) * 0.5
+ else:
+ batch_coords, batch_heatmaps = self.forward(feats) # (B, K, D)
+
+ batch_coords.unsqueeze_(dim=1) # (B, N, K, D)
+ preds = self.decode(batch_coords)
+
+ if test_cfg.get('output_heatmaps', False):
+ pred_fields = [
+ PixelData(heatmaps=hm) for hm in batch_heatmaps.detach()
+ ]
+ return preds, pred_fields
+ else:
+ return preds
+
+ def loss(self,
+ inputs: Tuple[Tensor],
+ batch_data_samples: OptSampleList,
+ train_cfg: ConfigType = {}) -> dict:
+ """Calculate losses from a batch of inputs and data samples."""
+
+ pred_coords, _ = self.forward(inputs)
+ keypoint_labels = torch.cat(
+ [d.gt_instance_labels.keypoint_labels for d in batch_data_samples])
+ keypoint_weights = torch.cat([
+ d.gt_instance_labels.keypoint_weights for d in batch_data_samples
+ ])
+
+ # calculate losses
+ losses = dict()
+
+ # TODO: multi-loss calculation
+ loss = self.loss_module(pred_coords, keypoint_labels, keypoint_weights)
+
+ losses.update(loss_kpt=loss)
+
+ # calculate accuracy
+ _, avg_acc, _ = keypoint_pck_accuracy(
+ pred=to_numpy(pred_coords),
+ gt=to_numpy(keypoint_labels),
+ mask=to_numpy(keypoint_weights) > 0,
+ thr=0.05,
+ norm_factor=np.ones((pred_coords.size(0), 2), dtype=np.float32))
+
+ acc_pose = torch.tensor(avg_acc, device=keypoint_labels.device)
+ losses.update(acc_pose=acc_pose)
+
+ return losses
+
+ @property
+ def default_init_cfg(self):
+ init_cfg = [dict(type='Normal', layer=['Linear'], std=0.01, bias=0)]
+ return init_cfg
+
+ def _load_state_dict_pre_hook(self, state_dict, prefix, local_meta, *args,
+ **kwargs):
+ """A hook function to load weights of deconv layers from
+ :class:`HeatmapHead` into `simplebaseline_head`.
+
+ The hook will be automatically registered during initialization.
+ """
+
+ # convert old-version state dict
+ keys = list(state_dict.keys())
+ for _k in keys:
+ if not _k.startswith(prefix):
+ continue
+ v = state_dict.pop(_k)
+ k = _k.lstrip(prefix)
+
+ k_new = _k
+ k_parts = k.split('.')
+ if self.simplebaseline_head is not None:
+ if k_parts[0] == 'conv_layers':
+ k_new = (
+ prefix + 'simplebaseline_head.deconv_layers.' +
+ '.'.join(k_parts[1:]))
+ elif k_parts[0] == 'final_layer':
+ k_new = prefix + 'simplebaseline_head.' + k
+
+ state_dict[k_new] = v
diff --git a/mmpose/models/heads/regression_heads/regression_head.py b/mmpose/models/heads/regression_heads/regression_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ff73aa6ef1bed93e8985d9be20f3c94355d8c21
--- /dev/null
+++ b/mmpose/models/heads/regression_heads/regression_head.py
@@ -0,0 +1,146 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional, Sequence, Tuple, Union
+
+import numpy as np
+import torch
+from torch import Tensor, nn
+
+from mmpose.evaluation.functional import keypoint_pck_accuracy
+from mmpose.models.utils.tta import flip_coordinates
+from mmpose.registry import KEYPOINT_CODECS, MODELS
+from mmpose.utils.tensor_utils import to_numpy
+from mmpose.utils.typing import (ConfigType, OptConfigType, OptSampleList,
+ Predictions)
+from ..base_head import BaseHead
+
+OptIntSeq = Optional[Sequence[int]]
+
+
+@MODELS.register_module()
+class RegressionHead(BaseHead):
+ """Top-down regression head introduced in `Deeppose`_ by Toshev et al
+ (2014). The head is composed of fully-connected layers to predict the
+ coordinates directly.
+
+ Args:
+ in_channels (int | sequence[int]): Number of input channels
+ num_joints (int): Number of joints
+ loss (Config): Config for keypoint loss. Defaults to use
+ :class:`SmoothL1Loss`
+ decoder (Config, optional): The decoder config that controls decoding
+ keypoint coordinates from the network output. Defaults to ``None``
+ init_cfg (Config, optional): Config to control the initialization. See
+ :attr:`default_init_cfg` for default settings
+
+ .. _`Deeppose`: https://arxiv.org/abs/1312.4659
+ """
+
+ _version = 2
+
+ def __init__(self,
+ in_channels: Union[int, Sequence[int]],
+ num_joints: int,
+ loss: ConfigType = dict(
+ type='SmoothL1Loss', use_target_weight=True),
+ decoder: OptConfigType = None,
+ init_cfg: OptConfigType = None):
+
+ if init_cfg is None:
+ init_cfg = self.default_init_cfg
+
+ super().__init__(init_cfg)
+
+ self.in_channels = in_channels
+ self.num_joints = num_joints
+ self.loss_module = MODELS.build(loss)
+ if decoder is not None:
+ self.decoder = KEYPOINT_CODECS.build(decoder)
+ else:
+ self.decoder = None
+
+ # Define fully-connected layers
+ self.fc = nn.Linear(in_channels, self.num_joints * 2)
+
+ def forward(self, feats: Tuple[Tensor]) -> Tensor:
+ """Forward the network. The input is multi scale feature maps and the
+ output is the coordinates.
+
+ Args:
+ feats (Tuple[Tensor]): Multi scale feature maps.
+
+ Returns:
+            Tensor: Output coordinates in shape (B, K, 2).
+ """
+ x = feats[-1]
+
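+        # The flattened feature must have exactly `in_channels` elements per
+        # sample, so the incoming feature map is expected to be pooled to a
+        # 1x1 spatial size (e.g. by a global average pooling neck) before it
+        # reaches this head.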
+ x = torch.flatten(x, 1)
+ x = self.fc(x)
+
+ return x.reshape(-1, self.num_joints, 2)
+
+ def predict(self,
+ feats: Tuple[Tensor],
+ batch_data_samples: OptSampleList,
+ test_cfg: ConfigType = {}) -> Predictions:
+ """Predict results from outputs."""
+
+ if test_cfg.get('flip_test', False):
+ # TTA: flip test -> feats = [orig, flipped]
+ assert isinstance(feats, list) and len(feats) == 2
+ flip_indices = batch_data_samples[0].metainfo['flip_indices']
+ input_size = batch_data_samples[0].metainfo['input_size']
+ _feats, _feats_flip = feats
+
+ _batch_coords = self.forward(_feats)
+ _batch_coords_flip = flip_coordinates(
+ self.forward(_feats_flip),
+ flip_indices=flip_indices,
+ shift_coords=test_cfg.get('shift_coords', True),
+ input_size=input_size)
+ batch_coords = (_batch_coords + _batch_coords_flip) * 0.5
+ else:
+ batch_coords = self.forward(feats) # (B, K, D)
+
+ batch_coords.unsqueeze_(dim=1) # (B, N, K, D)
+ preds = self.decode(batch_coords)
+
+ return preds
+
+ def loss(self,
+ inputs: Tuple[Tensor],
+ batch_data_samples: OptSampleList,
+ train_cfg: ConfigType = {}) -> dict:
+ """Calculate losses from a batch of inputs and data samples."""
+
+ pred_outputs = self.forward(inputs)
+
+ keypoint_labels = torch.cat(
+ [d.gt_instance_labels.keypoint_labels for d in batch_data_samples])
+ keypoint_weights = torch.cat([
+ d.gt_instance_labels.keypoint_weights for d in batch_data_samples
+ ])
+
+ # calculate losses
+ losses = dict()
+ loss = self.loss_module(pred_outputs, keypoint_labels,
+ keypoint_weights.unsqueeze(-1))
+
+ losses.update(loss_kpt=loss)
+
+ # calculate accuracy
+ _, avg_acc, _ = keypoint_pck_accuracy(
+ pred=to_numpy(pred_outputs),
+ gt=to_numpy(keypoint_labels),
+ mask=to_numpy(keypoint_weights) > 0,
+ thr=0.05,
+ norm_factor=np.ones((pred_outputs.size(0), 2), dtype=np.float32))
+
+ acc_pose = torch.tensor(avg_acc, device=keypoint_labels.device)
+ losses.update(acc_pose=acc_pose)
+
+ return losses
+
+ @property
+ def default_init_cfg(self):
+ init_cfg = [dict(type='Normal', layer=['Linear'], std=0.01, bias=0)]
+ return init_cfg
diff --git a/mmpose/models/heads/regression_heads/rle_head.py b/mmpose/models/heads/regression_heads/rle_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef62d7d9acbe1235c47d8acbef47e38ab1be6348
--- /dev/null
+++ b/mmpose/models/heads/regression_heads/rle_head.py
@@ -0,0 +1,187 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional, Sequence, Tuple, Union
+
+import numpy as np
+import torch
+from torch import Tensor, nn
+
+from mmpose.evaluation.functional import keypoint_pck_accuracy
+from mmpose.models.utils.tta import flip_coordinates
+from mmpose.registry import KEYPOINT_CODECS, MODELS
+from mmpose.utils.tensor_utils import to_numpy
+from mmpose.utils.typing import (ConfigType, OptConfigType, OptSampleList,
+ Predictions)
+from ..base_head import BaseHead
+
+OptIntSeq = Optional[Sequence[int]]
+
+
+@MODELS.register_module()
+class RLEHead(BaseHead):
+    """Top-down regression head introduced in `RLE`_ by Li et al. (2021). The
+    head is composed of fully-connected layers to predict the coordinates and
+    sigmas (the variances of the coordinates) together.
+
+ Args:
+ in_channels (int | sequence[int]): Number of input channels
+ num_joints (int): Number of joints
+ loss (Config): Config for keypoint loss. Defaults to use
+ :class:`RLELoss`
+ decoder (Config, optional): The decoder config that controls decoding
+ keypoint coordinates from the network output. Defaults to ``None``
+ init_cfg (Config, optional): Config to control the initialization. See
+ :attr:`default_init_cfg` for default settings
+
+ .. _`RLE`: https://arxiv.org/abs/2107.11291
+ """
+
+ _version = 2
+
+ def __init__(self,
+ in_channels: Union[int, Sequence[int]],
+ num_joints: int,
+ loss: ConfigType = dict(
+ type='RLELoss', use_target_weight=True),
+ decoder: OptConfigType = None,
+ init_cfg: OptConfigType = None):
+
+ if init_cfg is None:
+ init_cfg = self.default_init_cfg
+
+ super().__init__(init_cfg)
+
+ self.in_channels = in_channels
+ self.num_joints = num_joints
+ self.loss_module = MODELS.build(loss)
+ if decoder is not None:
+ self.decoder = KEYPOINT_CODECS.build(decoder)
+ else:
+ self.decoder = None
+
+ # Define fully-connected layers
+ self.fc = nn.Linear(in_channels, self.num_joints * 4)
+
+ # Register the hook to automatically convert old version state dicts
+ self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook)
+
+ def forward(self, feats: Tuple[Tensor]) -> Tensor:
+ """Forward the network. The input is multi scale feature maps and the
+ output is the coordinates.
+
+ Args:
+ feats (Tuple[Tensor]): Multi scale feature maps.
+
+ Returns:
+            Tensor: Output coordinates and sigmas in shape (B, K, 4).
+ """
+ x = feats[-1]
+
+ x = torch.flatten(x, 1)
+ x = self.fc(x)
+
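+        # Each keypoint gets 4 values: the first two are the (x, y)
+        # coordinates and the last two are the sigmas that parameterize the
+        # residual log-likelihood; see how `loss()` slices them below.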
+ return x.reshape(-1, self.num_joints, 4)
+
+ def predict(self,
+ feats: Tuple[Tensor],
+ batch_data_samples: OptSampleList,
+ test_cfg: ConfigType = {}) -> Predictions:
+ """Predict results from outputs."""
+
+ if test_cfg.get('flip_test', False):
+ # TTA: flip test -> feats = [orig, flipped]
+ assert isinstance(feats, list) and len(feats) == 2
+ flip_indices = batch_data_samples[0].metainfo['flip_indices']
+ input_size = batch_data_samples[0].metainfo['input_size']
+
+ _feats, _feats_flip = feats
+
+ _batch_coords = self.forward(_feats)
+ _batch_coords[..., 2:] = _batch_coords[..., 2:].sigmoid()
+
+ _batch_coords_flip = flip_coordinates(
+ self.forward(_feats_flip),
+ flip_indices=flip_indices,
+ shift_coords=test_cfg.get('shift_coords', True),
+ input_size=input_size)
+ _batch_coords_flip[..., 2:] = _batch_coords_flip[..., 2:].sigmoid()
+
+ batch_coords = (_batch_coords + _batch_coords_flip) * 0.5
+ else:
+ batch_coords = self.forward(feats) # (B, K, D)
+ batch_coords[..., 2:] = batch_coords[..., 2:].sigmoid()
+
+ batch_coords.unsqueeze_(dim=1) # (B, N, K, D)
+ preds = self.decode(batch_coords)
+
+ return preds
+
+ def loss(self,
+ inputs: Tuple[Tensor],
+ batch_data_samples: OptSampleList,
+ train_cfg: ConfigType = {}) -> dict:
+ """Calculate losses from a batch of inputs and data samples."""
+
+ pred_outputs = self.forward(inputs)
+
+ keypoint_labels = torch.cat(
+ [d.gt_instance_labels.keypoint_labels for d in batch_data_samples])
+ keypoint_weights = torch.cat([
+ d.gt_instance_labels.keypoint_weights for d in batch_data_samples
+ ])
+
+ pred_coords = pred_outputs[:, :, :2]
+ pred_sigma = pred_outputs[:, :, 2:4]
+
+ # calculate losses
+ losses = dict()
+ loss = self.loss_module(pred_coords, pred_sigma, keypoint_labels,
+ keypoint_weights.unsqueeze(-1))
+
+ losses.update(loss_kpt=loss)
+
+ # calculate accuracy
+ _, avg_acc, _ = keypoint_pck_accuracy(
+ pred=to_numpy(pred_coords),
+ gt=to_numpy(keypoint_labels),
+ mask=to_numpy(keypoint_weights) > 0,
+ thr=0.05,
+ norm_factor=np.ones((pred_coords.size(0), 2), dtype=np.float32))
+
+ acc_pose = torch.tensor(avg_acc, device=keypoint_labels.device)
+ losses.update(acc_pose=acc_pose)
+
+ return losses
+
+ def _load_state_dict_pre_hook(self, state_dict, prefix, local_meta, *args,
+ **kwargs):
+        """A hook function to convert an old-version state dict (before
+        MMPose v1.0.0) into a format compatible with :class:`RLEHead`, e.g.
+        renaming the ``loss`` submodule to ``loss_module``.
+
+ The hook will be automatically registered during initialization.
+ """
+
+ version = local_meta.get('version', None)
+ if version and version >= self._version:
+ return
+
+ # convert old-version state dict
+ keys = list(state_dict.keys())
+ for _k in keys:
+ v = state_dict.pop(_k)
+ k = _k.lstrip(prefix)
+ # In old version, "loss" includes the instances of loss,
+ # now it should be renamed "loss_module"
+ k_parts = k.split('.')
+ if k_parts[0] == 'loss':
+ # loss.xxx -> loss_module.xxx
+ k_new = prefix + 'loss_module.' + '.'.join(k_parts[1:])
+ else:
+ k_new = _k
+
+ state_dict[k_new] = v
+
+ @property
+ def default_init_cfg(self):
+ init_cfg = [dict(type='Normal', layer=['Linear'], std=0.01, bias=0)]
+ return init_cfg
diff --git a/mmpose/models/losses/__init__.py b/mmpose/models/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f21071e15699c42ca21f49e62fe0fb9f2869a68e
--- /dev/null
+++ b/mmpose/models/losses/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .ae_loss import AssociativeEmbeddingLoss
+from .classification_loss import BCELoss, JSDiscretLoss, KLDiscretLoss
+from .heatmap_loss import (AdaptiveWingLoss, KeypointMSELoss,
+ KeypointOHKMMSELoss)
+from .loss_wrappers import CombinedLoss, MultipleLossWrapper
+from .regression_loss import (BoneLoss, L1Loss, MPJPELoss, MSELoss, RLELoss,
+ SemiSupervisionLoss, SmoothL1Loss,
+ SoftWeightSmoothL1Loss, SoftWingLoss, WingLoss)
+
+__all__ = [
+ 'KeypointMSELoss', 'KeypointOHKMMSELoss', 'SmoothL1Loss', 'WingLoss',
+ 'MPJPELoss', 'MSELoss', 'L1Loss', 'BCELoss', 'BoneLoss',
+ 'SemiSupervisionLoss', 'SoftWingLoss', 'AdaptiveWingLoss', 'RLELoss',
+ 'KLDiscretLoss', 'MultipleLossWrapper', 'JSDiscretLoss', 'CombinedLoss',
+ 'AssociativeEmbeddingLoss', 'SoftWeightSmoothL1Loss'
+]
diff --git a/mmpose/models/losses/ae_loss.py b/mmpose/models/losses/ae_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f1e08181beaf835238596d95fe509b122c64b3d
--- /dev/null
+++ b/mmpose/models/losses/ae_loss.py
@@ -0,0 +1,123 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+from typing import List, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+
+from mmpose.registry import MODELS
+
+
+@MODELS.register_module()
+class AssociativeEmbeddingLoss(nn.Module):
+ """Associative Embedding loss.
+
+ Details can be found in
+    `Associative Embedding <https://arxiv.org/abs/1611.05424>`_
+
+ Note:
+
+ - batch size: B
+ - instance number: N
+ - keypoint number: K
+ - keypoint dimension: D
+ - embedding tag dimension: L
+ - heatmap size: [W, H]
+
+ Args:
+ loss_weight (float): Weight of the loss. Defaults to 1.0
+ push_loss_factor (float): A factor that controls the weight between
+ the push loss and the pull loss. Defaults to 0.5
+ """
+
+ def __init__(self,
+ loss_weight: float = 1.0,
+ push_loss_factor: float = 0.5) -> None:
+ super().__init__()
+ self.loss_weight = loss_weight
+ self.push_loss_factor = push_loss_factor
+
+ def _ae_loss_per_image(self, tags: Tensor, keypoint_indices: Tensor):
+ """Compute associative embedding loss for one image.
+
+ Args:
+ tags (Tensor): Tagging heatmaps in shape (K*L, H, W)
+            keypoint_indices (Tensor): Ground-truth keypoint position indices
+ in shape (N, K, 2)
+ """
+ K = keypoint_indices.shape[1]
+ C, H, W = tags.shape
+ L = C // K
+
+ tags = tags.view(L, K, H * W)
+ instance_tags = []
+ instance_kpt_tags = []
+
+ for keypoint_indices_n in keypoint_indices:
+ _kpt_tags = []
+ for k in range(K):
+ if keypoint_indices_n[k, 1]:
+ _kpt_tags.append(tags[:, k, keypoint_indices_n[k, 0]])
+
+ if _kpt_tags:
+ kpt_tags = torch.stack(_kpt_tags)
+ instance_kpt_tags.append(kpt_tags)
+ instance_tags.append(kpt_tags.mean(dim=0))
+
+ N = len(instance_kpt_tags) # number of instances with valid keypoints
+
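+        # Pull loss draws every keypoint tag towards its instance's mean tag;
+        # push loss penalizes pairs of instances whose mean tags are close
+        # via exp(-||t_i - t_j||^2), summed over all pairs (the i == j terms
+        # only add a constant) and normalized by N * (N - 1) below.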
+ if N == 0:
+ pull_loss = tags.new_zeros(size=(), requires_grad=True)
+ push_loss = tags.new_zeros(size=(), requires_grad=True)
+ else:
+ pull_loss = sum(
+ F.mse_loss(_kpt_tags, _tag.expand_as(_kpt_tags))
+ for (_kpt_tags, _tag) in zip(instance_kpt_tags, instance_tags))
+
+ if N == 1:
+ push_loss = tags.new_zeros(size=(), requires_grad=True)
+ else:
+ tag_mat = torch.stack(instance_tags) # (N, L)
+ diff = tag_mat[None] - tag_mat[:, None] # (N, N, L)
+ push_loss = torch.sum(torch.exp(-diff.pow(2)))
+
+ # normalization
+ eps = 1e-6
+ pull_loss = pull_loss / (N + eps)
+ push_loss = push_loss / ((N - 1) * N + eps)
+
+ return pull_loss, push_loss
+
+ def forward(self, tags: Tensor, keypoint_indices: Union[List[Tensor],
+ Tensor]):
+ """Compute associative embedding loss on a batch of data.
+
+ Args:
+ tags (Tensor): Tagging heatmaps in shape (B, L*K, H, W)
+            keypoint_indices (Tensor | List[Tensor]): Ground-truth keypoint
+ position indices represented by a Tensor in shape
+ (B, N, K, 2), or a list of B Tensors in shape (N_i, K, 2)
+ Each keypoint's index is represented as [i, v], where i is the
+ position index in the heatmap (:math:`i=y*w+x`) and v is the
+ visibility
+
+ Returns:
+ tuple:
+ - pull_loss (Tensor)
+ - push_loss (Tensor)
+ """
+
+ assert tags.shape[0] == len(keypoint_indices)
+
+ pull_loss = 0.
+ push_loss = 0.
+
+ for i in range(tags.shape[0]):
+ _pull, _push = self._ae_loss_per_image(tags[i],
+ keypoint_indices[i])
+ pull_loss += _pull * self.loss_weight
+ push_loss += _push * self.loss_weight * self.push_loss_factor
+
+ return pull_loss, push_loss
diff --git a/mmpose/models/losses/classification_loss.py b/mmpose/models/losses/classification_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c3bdf502b7f973d8fe9c2217faa7e6ad3a5c849
--- /dev/null
+++ b/mmpose/models/losses/classification_loss.py
@@ -0,0 +1,213 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from mmpose.registry import MODELS
+
+
+@MODELS.register_module()
+class BCELoss(nn.Module):
+ """Binary Cross Entropy loss.
+
+ Args:
+ use_target_weight (bool): Option to use weighted loss.
+ Different joint types may have different target weights.
+ loss_weight (float): Weight of the loss. Default: 1.0.
+ """
+
+ def __init__(self, use_target_weight=False, loss_weight=1.):
+ super().__init__()
+ self.criterion = F.binary_cross_entropy
+ self.use_target_weight = use_target_weight
+ self.loss_weight = loss_weight
+
+ def forward(self, output, target, target_weight=None):
+ """Forward function.
+
+ Note:
+ - batch_size: N
+ - num_labels: K
+
+ Args:
+ output (torch.Tensor[N, K]): Output classification.
+ target (torch.Tensor[N, K]): Target classification.
+ target_weight (torch.Tensor[N, K] or torch.Tensor[N]):
+ Weights across different labels.
+ """
+
+ if self.use_target_weight:
+ assert target_weight is not None
+ loss = self.criterion(output, target, reduction='none')
+ if target_weight.dim() == 1:
+ target_weight = target_weight[:, None]
+ loss = (loss * target_weight).mean()
+ else:
+ loss = self.criterion(output, target)
+
+ return loss * self.loss_weight
+
+
+@MODELS.register_module()
+class JSDiscretLoss(nn.Module):
+ """Discrete JS Divergence loss for DSNT with Gaussian Heatmap.
+
+    Modified from the official DSNT implementation.
+
+ Args:
+ use_target_weight (bool): Option to use weighted loss.
+ Different joint types may have different target weights.
+ size_average (bool): Option to average the loss by the batch_size.
+ """
+
+ def __init__(
+ self,
+ use_target_weight=True,
+ size_average: bool = True,
+ ):
+ super(JSDiscretLoss, self).__init__()
+ self.use_target_weight = use_target_weight
+ self.size_average = size_average
+ self.kl_loss = nn.KLDivLoss(reduction='none')
+
+ def kl(self, p, q):
+ """Kullback-Leibler Divergence."""
+
+ eps = 1e-24
+ kl_values = self.kl_loss((q + eps).log(), p)
+ return kl_values
+
+ def js(self, pred_hm, gt_hm):
+ """Jensen-Shannon Divergence."""
+
+ m = 0.5 * (pred_hm + gt_hm)
+ js_values = 0.5 * (self.kl(pred_hm, m) + self.kl(gt_hm, m))
+ return js_values
+
+ def forward(self, pred_hm, gt_hm, target_weight=None):
+ """Forward function.
+
+ Args:
+ pred_hm (torch.Tensor[N, K, H, W]): Predicted heatmaps.
+ gt_hm (torch.Tensor[N, K, H, W]): Target heatmaps.
+ target_weight (torch.Tensor[N, K] or torch.Tensor[N]):
+ Weights across different labels.
+
+ Returns:
+ torch.Tensor: Loss value.
+ """
+
+ if self.use_target_weight:
+ assert target_weight is not None
+ assert pred_hm.ndim >= target_weight.ndim
+
+ for i in range(pred_hm.ndim - target_weight.ndim):
+ target_weight = target_weight.unsqueeze(-1)
+
+ loss = self.js(pred_hm * target_weight, gt_hm * target_weight)
+ else:
+ loss = self.js(pred_hm, gt_hm)
+
+ if self.size_average:
+ loss /= len(gt_hm)
+
+ return loss.sum()
+
+
+@MODELS.register_module()
+class KLDiscretLoss(nn.Module):
+    """Discrete KL Divergence loss for SimCC with Gaussian label smoothing.
+
+    Modified from the official SimCC implementation.
+
+    Args:
+ beta (float): Temperature factor of Softmax.
+ label_softmax (bool): Whether to use Softmax on labels.
+ use_target_weight (bool): Option to use weighted loss.
+ Different joint types may have different target weights.
+ """
+
+ def __init__(self, beta=1.0, label_softmax=False, use_target_weight=True):
+ super(KLDiscretLoss, self).__init__()
+ self.beta = beta
+ self.label_softmax = label_softmax
+ self.use_target_weight = use_target_weight
+
+ self.log_softmax = nn.LogSoftmax(dim=1)
+ self.kl_loss = nn.KLDivLoss(reduction='none')
+
+ def criterion(self, dec_outs, labels):
+ """Criterion function."""
+ log_pt = self.log_softmax(dec_outs * self.beta)
+ if self.label_softmax:
+ labels = F.softmax(labels * self.beta, dim=1)
+ loss = torch.mean(self.kl_loss(log_pt, labels), dim=1)
+ return loss
+
+ def forward(self, pred_simcc, gt_simcc, target_weight):
+ """Forward function.
+
+ Args:
+ pred_simcc (Tuple[Tensor, Tensor]): Predicted SimCC vectors of
+ x-axis and y-axis.
+ gt_simcc (Tuple[Tensor, Tensor]): Target representations.
+ target_weight (torch.Tensor[N, K] or torch.Tensor[N]):
+ Weights across different labels.
+ """
+ num_joints = pred_simcc[0].size(1)
+ loss = 0
+
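+        # pred_simcc / gt_simcc are (x, y) pairs of shape (N, K, Wx) and
+        # (N, K, Wy); each is flattened to (N * K, W) so that the keypoint
+        # weights, reshaped to length N * K, can be applied row-wise.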
+ if self.use_target_weight:
+ weight = target_weight.reshape(-1)
+ else:
+ weight = 1.
+
+ for pred, target in zip(pred_simcc, gt_simcc):
+ pred = pred.reshape(-1, pred.size(-1))
+ target = target.reshape(-1, target.size(-1))
+
+ loss += self.criterion(pred, target).mul(weight).sum()
+
+ return loss / num_joints
+
+
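+# A minimal usage sketch of ``KLDiscretLoss`` for SimCC-style outputs
+# (illustrative only; the shapes, values and ``beta`` below are arbitrary
+# dummies):
+#
+#     loss_fn = KLDiscretLoss(beta=10., label_softmax=True)
+#     pred_x, pred_y = torch.rand(2, 17, 192), torch.rand(2, 17, 256)
+#     gt_x, gt_y = torch.rand(2, 17, 192), torch.rand(2, 17, 256)
+#     loss = loss_fn((pred_x, pred_y), (gt_x, gt_y), torch.ones(2, 17))
+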
+@MODELS.register_module()
+class InfoNCELoss(nn.Module):
+ """InfoNCE loss for training a discriminative representation space with a
+ contrastive manner.
+
+ `Representation Learning with Contrastive Predictive Coding
+ arXiv: `_.
+
+ Args:
+ temperature (float, optional): The temperature to use in the softmax
+ function. Higher temperatures lead to softer probability
+ distributions. Defaults to 1.0.
+ loss_weight (float, optional): The weight to apply to the loss.
+ Defaults to 1.0.
+ """
+
+ def __init__(self, temperature: float = 1.0, loss_weight=1.0) -> None:
+ super(InfoNCELoss, self).__init__()
+ assert temperature > 0, f'the argument `temperature` must be ' \
+ f'positive, but got {temperature}'
+ self.temp = temperature
+ self.loss_weight = loss_weight
+
+ def forward(self, features: torch.Tensor) -> torch.Tensor:
+ """Computes the InfoNCE loss.
+
+ Args:
+ features (Tensor): A tensor containing the feature
+ representations of different samples.
+
+ Returns:
+ Tensor: A tensor of shape (1,) containing the InfoNCE loss.
+ """
+ n = features.size(0)
+ features_norm = F.normalize(features, dim=1)
+ logits = features_norm.mm(features_norm.t()) / self.temp
+ targets = torch.arange(n, dtype=torch.long, device=features.device)
+ loss = F.cross_entropy(logits, targets, reduction='sum')
+ return loss * self.loss_weight
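+
+# A minimal usage sketch of ``InfoNCELoss`` (illustrative only; the batch
+# of feature vectors is a random dummy and the temperature is arbitrary):
+#
+#     loss_fn = InfoNCELoss(temperature=0.05)
+#     features = torch.randn(8, 256)  # [num_samples, feature_dim]
+#     loss = loss_fn(features)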
diff --git a/mmpose/models/losses/heatmap_loss.py b/mmpose/models/losses/heatmap_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1051494683c817d7d14a73f2ceebed67834d64b
--- /dev/null
+++ b/mmpose/models/losses/heatmap_loss.py
@@ -0,0 +1,455 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+
+from mmpose.registry import MODELS
+
+
+@MODELS.register_module()
+class KeypointMSELoss(nn.Module):
+ """MSE loss for heatmaps.
+
+ Args:
+ use_target_weight (bool): Option to use weighted MSE loss.
+ Different joint types may have different target weights.
+ Defaults to ``False``
+ skip_empty_channel (bool): If ``True``, heatmap channels with no
+ non-zero value (which means no visible ground-truth keypoint
+ in the image) will not be used to calculate the loss. Defaults to
+ ``False``
+ loss_weight (float): Weight of the loss. Defaults to 1.0
+ """
+
+ def __init__(self,
+ use_target_weight: bool = False,
+ skip_empty_channel: bool = False,
+ loss_weight: float = 1.):
+ super().__init__()
+ self.use_target_weight = use_target_weight
+ self.skip_empty_channel = skip_empty_channel
+ self.loss_weight = loss_weight
+
+ def forward(self,
+ output: Tensor,
+ target: Tensor,
+ target_weights: Optional[Tensor] = None,
+ mask: Optional[Tensor] = None) -> Tensor:
+ """Forward function of loss.
+
+ Note:
+ - batch_size: B
+ - num_keypoints: K
+ - heatmaps height: H
+ - heatmaps width: W
+
+ Args:
+ output (Tensor): The output heatmaps with shape [B, K, H, W]
+ target (Tensor): The target heatmaps with shape [B, K, H, W]
+ target_weights (Tensor, optional): The target weights of different
+ keypoints, with shape [B, K] (keypoint-wise) or
+ [B, K, H, W] (pixel-wise).
+ mask (Tensor, optional): The masks of valid heatmap pixels in
+ shape [B, K, H, W] or [B, 1, H, W]. If ``None``, no mask will
+ be applied. Defaults to ``None``
+
+ Returns:
+ Tensor: The calculated loss.
+ """
+
+ _mask = self._get_mask(target, target_weights, mask)
+ if _mask is None:
+ loss = F.mse_loss(output, target)
+ else:
+ _loss = F.mse_loss(output, target, reduction='none')
+ loss = (_loss * _mask).mean()
+
+ return loss * self.loss_weight
+
+ def _get_mask(self, target: Tensor, target_weights: Optional[Tensor],
+ mask: Optional[Tensor]) -> Optional[Tensor]:
+ """Generate the heatmap mask w.r.t. the given mask, target weight and
+ `skip_empty_channel` setting.
+
+ Returns:
+ Tensor: The mask in shape (B, K, *) or ``None`` if no mask is
+ needed.
+ """
+ # Given spatial mask
+ if mask is not None:
+ # check mask has matching type with target
+ assert (mask.ndim == target.ndim and all(
+ d_m == d_t or d_m == 1
+ for d_m, d_t in zip(mask.shape, target.shape))), (
+ f'mask and target have mismatched shapes {mask.shape} v.s. '
+ f'{target.shape}')
+
+ # Mask by target weights (keypoint-wise mask)
+ if target_weights is not None:
+ # check target weight has matching shape with target
+ assert (target_weights.ndim in (2, 4) and target_weights.shape
+ == target.shape[:target_weights.ndim]), (
+ 'target_weights and target have mismatched shapes '
+ f'{target_weights.shape} v.s. {target.shape}')
+
+ ndim_pad = target.ndim - target_weights.ndim
+ _mask = target_weights.view(target_weights.shape +
+ (1, ) * ndim_pad)
+
+ if mask is None:
+ mask = _mask
+ else:
+ mask = mask * _mask
+
+ # Mask by ``skip_empty_channel``
+ if self.skip_empty_channel:
+ _mask = (target != 0).flatten(2).any(dim=2)
+ ndim_pad = target.ndim - _mask.ndim
+ _mask = _mask.view(_mask.shape + (1, ) * ndim_pad)
+
+ if mask is None:
+ mask = _mask
+ else:
+ mask = mask * _mask
+
+ return mask
+
+
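+# A minimal usage sketch of ``KeypointMSELoss`` with keypoint-wise target
+# weights (illustrative only; random dummy tensors):
+#
+#     loss_fn = KeypointMSELoss(use_target_weight=True)
+#     output = torch.rand(2, 17, 64, 48)    # [B, K, H, W]
+#     target = torch.rand(2, 17, 64, 48)
+#     loss = loss_fn(output, target, target_weights=torch.ones(2, 17))
+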
+@MODELS.register_module()
+class CombinedTargetMSELoss(nn.Module):
+ """MSE loss for combined target.
+
+ CombinedTarget: The combination of classification target
+ (response map) and regression target (offset map).
+ Paper ref: Huang et al. The Devil is in the Details: Delving into
+ Unbiased Data Processing for Human Pose Estimation (CVPR 2020).
+
+ Args:
+ use_target_weight (bool): Option to use weighted MSE loss.
+ Different joint types may have different target weights.
+ Defaults to ``False``
+ loss_weight (float): Weight of the loss. Defaults to 1.0
+ """
+
+ def __init__(self,
+ use_target_weight: bool = False,
+ loss_weight: float = 1.):
+ super().__init__()
+ self.criterion = nn.MSELoss(reduction='mean')
+ self.use_target_weight = use_target_weight
+ self.loss_weight = loss_weight
+
+ def forward(self, output: Tensor, target: Tensor,
+ target_weights: Tensor) -> Tensor:
+ """Forward function of loss.
+
+ Note:
+ - batch_size: B
+ - num_channels: C
+ - heatmaps height: H
+ - heatmaps width: W
+ - num_keypoints: K
+ Here, C = 3 * K
+
+ Args:
+ output (Tensor): The output feature maps with shape [B, C, H, W].
+ target (Tensor): The target feature maps with shape [B, C, H, W].
+ target_weights (Tensor): The target weights of different keypoints,
+ with shape [B, K].
+
+ Returns:
+ Tensor: The calculated loss.
+ """
+ batch_size = output.size(0)
+ num_channels = output.size(1)
+ heatmaps_pred = output.reshape(
+ (batch_size, num_channels, -1)).split(1, 1)
+ heatmaps_gt = target.reshape(
+ (batch_size, num_channels, -1)).split(1, 1)
+ loss = 0.
+ num_joints = num_channels // 3
+ for idx in range(num_joints):
+ heatmap_pred = heatmaps_pred[idx * 3].squeeze()
+ heatmap_gt = heatmaps_gt[idx * 3].squeeze()
+ offset_x_pred = heatmaps_pred[idx * 3 + 1].squeeze()
+ offset_x_gt = heatmaps_gt[idx * 3 + 1].squeeze()
+ offset_y_pred = heatmaps_pred[idx * 3 + 2].squeeze()
+ offset_y_gt = heatmaps_gt[idx * 3 + 2].squeeze()
+ if self.use_target_weight:
+ target_weight = target_weights[:, idx, None]
+ heatmap_pred = heatmap_pred * target_weight
+ heatmap_gt = heatmap_gt * target_weight
+ # classification loss
+ loss += 0.5 * self.criterion(heatmap_pred, heatmap_gt)
+ # regression loss
+ loss += 0.5 * self.criterion(heatmap_gt * offset_x_pred,
+ heatmap_gt * offset_x_gt)
+ loss += 0.5 * self.criterion(heatmap_gt * offset_y_pred,
+ heatmap_gt * offset_y_gt)
+ return loss / num_joints * self.loss_weight
+
+
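+# A minimal usage sketch of ``CombinedTargetMSELoss`` (illustrative only):
+# for K = 17 keypoints the channel layout is
+# [heatmap, offset_x, offset_y] * K, i.e. C = 51.
+#
+#     loss_fn = CombinedTargetMSELoss(use_target_weight=True)
+#     output = torch.rand(2, 51, 64, 48)    # [B, 3 * K, H, W]
+#     target = torch.rand(2, 51, 64, 48)
+#     loss = loss_fn(output, target, target_weights=torch.ones(2, 17))
+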
+@MODELS.register_module()
+class KeypointOHKMMSELoss(nn.Module):
+ """MSE loss with online hard keypoint mining.
+
+ Args:
+ use_target_weight (bool): Option to use weighted MSE loss.
+ Different joint types may have different target weights.
+ Defaults to ``False``
+ topk (int): Only top k joint losses are kept. Defaults to 8
+ loss_weight (float): Weight of the loss. Defaults to 1.0
+ """
+
+ def __init__(self,
+ use_target_weight: bool = False,
+ topk: int = 8,
+ loss_weight: float = 1.):
+ super().__init__()
+ assert topk > 0
+ self.criterion = nn.MSELoss(reduction='none')
+ self.use_target_weight = use_target_weight
+ self.topk = topk
+ self.loss_weight = loss_weight
+
+ def _ohkm(self, losses: Tensor) -> Tensor:
+ """Online hard keypoint mining.
+
+ Note:
+ - batch_size: B
+ - num_keypoints: K
+
+ Args:
+ losses (Tensor): The losses with shape [B, K]
+
+ Returns:
+ Tensor: The calculated loss.
+ """
+ ohkm_loss = 0.
+ B = losses.shape[0]
+ for i in range(B):
+ sub_loss = losses[i]
+ _, topk_idx = torch.topk(
+ sub_loss, k=self.topk, dim=0, sorted=False)
+ tmp_loss = torch.gather(sub_loss, 0, topk_idx)
+ ohkm_loss += torch.sum(tmp_loss) / self.topk
+ ohkm_loss /= B
+ return ohkm_loss
+
+ def forward(self, output: Tensor, target: Tensor,
+ target_weights: Tensor) -> Tensor:
+ """Forward function of loss.
+
+ Note:
+ - batch_size: B
+ - num_keypoints: K
+ - heatmaps height: H
+ - heatmaps width: W
+
+ Args:
+ output (Tensor): The output heatmaps with shape [B, K, H, W].
+ target (Tensor): The target heatmaps with shape [B, K, H, W].
+ target_weights (Tensor): The target weights of different keypoints,
+ with shape [B, K].
+
+ Returns:
+ Tensor: The calculated loss.
+ """
+ num_keypoints = output.size(1)
+ if num_keypoints < self.topk:
+ raise ValueError(f'topk ({self.topk}) should not be '
+ f'larger than num_keypoints ({num_keypoints}).')
+
+ losses = []
+ for idx in range(num_keypoints):
+ if self.use_target_weight:
+ target_weight = target_weights[:, idx, None, None]
+ losses.append(
+ self.criterion(output[:, idx] * target_weight,
+ target[:, idx] * target_weight))
+ else:
+ losses.append(self.criterion(output[:, idx], target[:, idx]))
+
+ losses = [loss.mean(dim=(1, 2)).unsqueeze(dim=1) for loss in losses]
+ losses = torch.cat(losses, dim=1)
+
+ return self._ohkm(losses) * self.loss_weight
+
+
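+# A minimal usage sketch of ``KeypointOHKMMSELoss`` (illustrative only;
+# note that the number of keypoints must be >= ``topk``):
+#
+#     loss_fn = KeypointOHKMMSELoss(use_target_weight=True, topk=8)
+#     output = torch.rand(2, 17, 64, 48)    # [B, K, H, W]
+#     target = torch.rand(2, 17, 64, 48)
+#     loss = loss_fn(output, target, target_weights=torch.ones(2, 17))
+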
+@MODELS.register_module()
+class AdaptiveWingLoss(nn.Module):
+ """Adaptive wing loss. paper ref: 'Adaptive Wing Loss for Robust Face
+ Alignment via Heatmap Regression' Wang et al. ICCV'2019.
+
+ Args:
+ alpha (float), omega (float), epsilon (float), theta (float)
+ are hyper-parameters.
+ use_target_weight (bool): Option to use weighted MSE loss.
+ Different joint types may have different target weights.
+ loss_weight (float): Weight of the loss. Default: 1.0.
+ """
+
+ def __init__(self,
+ alpha=2.1,
+ omega=14,
+ epsilon=1,
+ theta=0.5,
+ use_target_weight=False,
+ loss_weight=1.):
+ super().__init__()
+ self.alpha = float(alpha)
+ self.omega = float(omega)
+ self.epsilon = float(epsilon)
+ self.theta = float(theta)
+ self.use_target_weight = use_target_weight
+ self.loss_weight = loss_weight
+
+ def criterion(self, pred, target):
+ """Criterion of wingloss.
+
+ Note:
+ batch_size: N
+ num_keypoints: K
+
+ Args:
+ pred (torch.Tensor[NxKxHxW]): Predicted heatmaps.
+ target (torch.Tensor[NxKxHxW]): Target heatmaps.
+ """
+ H, W = pred.shape[2:4]
+ delta = (target - pred).abs()
+
+ A = self.omega * (
+ 1 / (1 + torch.pow(self.theta / self.epsilon, self.alpha - target))
+ ) * (self.alpha - target) * (torch.pow(
+ self.theta / self.epsilon,
+ self.alpha - target - 1)) * (1 / self.epsilon)
+ C = self.theta * A - self.omega * torch.log(
+ 1 + torch.pow(self.theta / self.epsilon, self.alpha - target))
+
+ losses = torch.where(
+ delta < self.theta,
+ self.omega *
+ torch.log(1 +
+ torch.pow(delta / self.epsilon, self.alpha - target)),
+ A * delta - C)
+
+ return torch.mean(losses)
+
+ def forward(self,
+ output: Tensor,
+ target: Tensor,
+ target_weights: Optional[Tensor] = None):
+ """Forward function.
+
+ Note:
+ batch_size: N
+ num_keypoints: K
+
+ Args:
+ output (torch.Tensor[N, K, H, W]): Output heatmaps.
+ target (torch.Tensor[N, K, H, W]): Target heatmaps.
+ target_weights (torch.Tensor[N, K] or torch.Tensor[N, K, H, W]):
+ Weights across different joint types.
+ """
+ if self.use_target_weight:
+ assert (target_weights.ndim in (2, 4) and target_weights.shape
+ == target.shape[:target_weights.ndim]), (
+ 'target_weights and target have mismatched shapes '
+ f'{target_weights.shape} v.s. {target.shape}')
+
+ ndim_pad = target.ndim - target_weights.ndim
+ target_weights = target_weights.view(target_weights.shape +
+ (1, ) * ndim_pad)
+ loss = self.criterion(output * target_weights,
+ target * target_weights)
+ else:
+ loss = self.criterion(output, target)
+
+ return loss * self.loss_weight
+
+
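+# A minimal usage sketch of ``AdaptiveWingLoss`` on heatmaps (illustrative
+# only; random dummy tensors with the default hyper-parameters):
+#
+#     loss_fn = AdaptiveWingLoss(use_target_weight=True)
+#     output = torch.rand(2, 17, 64, 48)    # [N, K, H, W]
+#     target = torch.rand(2, 17, 64, 48)
+#     loss = loss_fn(output, target, target_weights=torch.ones(2, 17))
+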
+@MODELS.register_module()
+class FocalHeatmapLoss(KeypointMSELoss):
+ """A class for calculating the modified focal loss for heatmap prediction.
+
+ This loss function is exactly the same as the one used in CornerNet. It
+ runs faster and costs a little bit more memory.
+
+ `CornerNet: Detecting Objects as Paired Keypoints
+ arXiv: `_.
+
+ Arguments:
+ alpha (int): The alpha parameter in the focal loss equation.
+ beta (int): The beta parameter in the focal loss equation.
+ use_target_weight (bool): Option to use weighted MSE loss.
+ Different joint types may have different target weights.
+ Defaults to ``False``
+ skip_empty_channel (bool): If ``True``, heatmap channels with no
+ non-zero value (which means no visible ground-truth keypoint
+ in the image) will not be used to calculate the loss. Defaults to
+ ``False``
+ loss_weight (float): Weight of the loss. Defaults to 1.0
+ """
+
+ def __init__(self,
+ alpha: int = 2,
+ beta: int = 4,
+ use_target_weight: bool = False,
+ skip_empty_channel: bool = False,
+ loss_weight: float = 1.0):
+ super(FocalHeatmapLoss, self).__init__(use_target_weight,
+ skip_empty_channel, loss_weight)
+ self.alpha = alpha
+ self.beta = beta
+
+ def forward(self,
+ output: Tensor,
+ target: Tensor,
+ target_weights: Optional[Tensor] = None,
+ mask: Optional[Tensor] = None) -> Tensor:
+ """Calculate the modified focal loss for heatmap prediction.
+
+ Note:
+ - batch_size: B
+ - num_keypoints: K
+ - heatmaps height: H
+ - heatmaps width: W
+
+ Args:
+ output (Tensor): The output heatmaps with shape [B, K, H, W]
+ target (Tensor): The target heatmaps with shape [B, K, H, W]
+ target_weights (Tensor, optional): The target weights of different
+ keypoints, with shape [B, K] (keypoint-wise) or
+ [B, K, H, W] (pixel-wise).
+ mask (Tensor, optional): The masks of valid heatmap pixels in
+ shape [B, K, H, W] or [B, 1, H, W]. If ``None``, no mask will
+ be applied. Defaults to ``None``
+
+ Returns:
+ Tensor: The calculated loss.
+ """
+ _mask = self._get_mask(target, target_weights, mask)
+
+ pos_inds = target.eq(1).float()
+ neg_inds = target.lt(1).float()
+
+ if _mask is not None:
+ pos_inds = pos_inds * _mask
+ neg_inds = neg_inds * _mask
+
+ neg_weights = torch.pow(1 - target, self.beta)
+
+ pos_loss = torch.log(output) * torch.pow(1 - output,
+ self.alpha) * pos_inds
+ neg_loss = torch.log(1 - output) * torch.pow(
+ output, self.alpha) * neg_weights * neg_inds
+
+ num_pos = pos_inds.float().sum()
+ if num_pos == 0:
+ loss = -neg_loss.sum()
+ else:
+ loss = -(pos_loss.sum() + neg_loss.sum()) / num_pos
+ return loss * self.loss_weight
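+
+# A minimal usage sketch of ``FocalHeatmapLoss`` (illustrative only): the
+# predictions must already be probabilities in (0, 1), and the target is
+# a Gaussian heatmap whose peak value is exactly 1 at visible keypoints.
+#
+#     loss_fn = FocalHeatmapLoss(alpha=2, beta=4)
+#     output = torch.rand(2, 17, 64, 48).clamp(1e-4, 1 - 1e-4)
+#     target = torch.rand(2, 17, 64, 48)  # dummy stand-in for heatmaps
+#     loss = loss_fn(output, target)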
diff --git a/mmpose/models/losses/loss_wrappers.py b/mmpose/models/losses/loss_wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..d821661b48a133ffd6c9232d5a6a2d3eb6bf0a50
--- /dev/null
+++ b/mmpose/models/losses/loss_wrappers.py
@@ -0,0 +1,82 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict
+
+import torch.nn as nn
+
+from mmpose.registry import MODELS
+from mmpose.utils.typing import ConfigType
+
+
+@MODELS.register_module()
+class MultipleLossWrapper(nn.Module):
+ """A wrapper to collect multiple loss functions together and return a list
+ of losses in the same order.
+
+ Args:
+ losses (list): List of Loss Config
+ """
+
+ def __init__(self, losses: list):
+ super().__init__()
+ self.num_losses = len(losses)
+
+ loss_modules = []
+ for loss_cfg in losses:
+ t_loss = MODELS.build(loss_cfg)
+ loss_modules.append(t_loss)
+ self.loss_modules = nn.ModuleList(loss_modules)
+
+ def forward(self, input_list, target_list, keypoint_weights=None):
+ """Forward function.
+
+ Note:
+ - batch_size: N
+ - num_keypoints: K
+ - dimension of keypoints: D (D=2 or D=3)
+
+ Args:
+ input_list (List[Tensor]): List of inputs.
+ target_list (List[Tensor]): List of targets.
+ keypoint_weights (Tensor[N, K, D]):
+ Weights across different joint types.
+ """
+ assert isinstance(input_list, list), 'input_list should be a list'
+ assert isinstance(target_list, list), 'target_list should be a list'
+ assert len(input_list) == len(target_list), (
+ 'input_list and target_list should have the same length')
+
+ losses = []
+ for i in range(self.num_losses):
+ input_i = input_list[i]
+ target_i = target_list[i]
+
+ loss_i = self.loss_modules[i](input_i, target_i, keypoint_weights)
+ losses.append(loss_i)
+
+ return losses
+
+
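+# A minimal usage sketch of ``MultipleLossWrapper`` (illustrative only;
+# assumes the listed loss types are registered in ``MODELS``, e.g. by
+# importing ``mmpose.models``; ``pred*``/``gt*`` stand for tensors of
+# matching shapes):
+#
+#     wrapper = MultipleLossWrapper(losses=[
+#         dict(type='KeypointMSELoss'),
+#         dict(type='KeypointMSELoss', use_target_weight=True),
+#     ])
+#     losses = wrapper([pred1, pred2], [gt1, gt2], keypoint_weights)
+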
+@MODELS.register_module()
+class CombinedLoss(nn.ModuleDict):
+ """A wrapper to combine multiple loss functions. These loss functions can
+ have different input types (e.g. heatmaps or regression values), and can
+ only be invoked individually and explicitly.
+
+ Args:
+ losses (Dict[str, ConfigType]): The names and configs of loss
+ functions to be wrapped
+
+ Example::
+ >>> heatmap_loss_cfg = dict(type='KeypointMSELoss')
+ >>> ae_loss_cfg = dict(type='AssociativeEmbeddingLoss')
+ >>> loss_module = CombinedLoss(
+ ... losses=dict(
+ ... heatmap_loss=heatmap_loss_cfg,
+ ... ae_loss=ae_loss_cfg))
+ >>> loss_hm = loss_module.heatmap_loss(pred_heatmap, gt_heatmap)
+ >>> loss_ae = loss_module.ae_loss(pred_tags, keypoint_indices)
+ """
+
+ def __init__(self, losses: Dict[str, ConfigType]):
+ super().__init__()
+ for loss_name, loss_cfg in losses.items():
+ self.add_module(loss_name, MODELS.build(loss_cfg))
diff --git a/mmpose/models/losses/regression_loss.py b/mmpose/models/losses/regression_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a64a4adfe9c7429d6e72b794d6c2b01af14fe46
--- /dev/null
+++ b/mmpose/models/losses/regression_loss.py
@@ -0,0 +1,618 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+from functools import partial
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from mmpose.registry import MODELS
+from ..utils.realnvp import RealNVP
+
+
+@MODELS.register_module()
+class RLELoss(nn.Module):
+ """RLE Loss.
+
+ `Human Pose Regression With Residual Log-Likelihood Estimation
+ arXiv: `_.
+
+ Code is modified from `the official implementation
+ `_.
+
+ Args:
+ use_target_weight (bool): Option to use weighted loss.
+ Different joint types may have different target weights.
+ size_average (bool): Option to average the loss by the batch_size.
+ residual (bool): Option to add L1 loss and let the flow
+ learn the residual error distribution.
+ q_distribution (str): Option for the identity Q(error) distribution.
+ Options: "laplace" or "gaussian".
+ """
+
+ def __init__(self,
+ use_target_weight=False,
+ size_average=True,
+ residual=True,
+ q_distribution='laplace'):
+ super(RLELoss, self).__init__()
+ self.size_average = size_average
+ self.use_target_weight = use_target_weight
+ self.residual = residual
+ self.q_distribution = q_distribution
+
+ self.flow_model = RealNVP()
+
+ def forward(self, pred, sigma, target, target_weight=None):
+ """Forward function.
+
+ Note:
+ - batch_size: N
+ - num_keypoints: K
+ - dimension of keypoints: D (D=2 or D=3)
+
+ Args:
+ pred (Tensor[N, K, D]): Output regression.
+ sigma (Tensor[N, K, D]): Output sigma.
+ target (Tensor[N, K, D]): Target regression.
+ target_weight (Tensor[N, K, D]):
+ Weights across different joint types.
+ """
+ sigma = sigma.sigmoid()
+
+ error = (pred - target) / (sigma + 1e-9)
+ # (B, K, 2)
+ log_phi = self.flow_model.log_prob(error.reshape(-1, 2))
+ log_phi = log_phi.reshape(target.shape[0], target.shape[1], 1)
+ log_sigma = torch.log(sigma).reshape(target.shape[0], target.shape[1],
+ 2)
+ nf_loss = log_sigma - log_phi
+
+ if self.residual:
+ assert self.q_distribution in ['laplace', 'gaussian']
+ if self.q_distribution == 'laplace':
+ loss_q = torch.log(sigma * 2) + torch.abs(error)
+ else:
+ loss_q = torch.log(
+ sigma * math.sqrt(2 * math.pi)) + 0.5 * error**2
+
+ loss = nf_loss + loss_q
+ else:
+ loss = nf_loss
+
+ if self.use_target_weight:
+ assert target_weight is not None
+ loss *= target_weight
+
+ if self.size_average:
+ loss /= len(loss)
+
+ return loss.sum()
+
+
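+# A minimal usage sketch of ``RLELoss`` (illustrative only; random dummy
+# tensors; the keypoint dimension is 2 here because the flow model in this
+# implementation evaluates 2-D errors):
+#
+#     loss_fn = RLELoss(use_target_weight=True)
+#     pred = torch.rand(2, 17, 2)      # [N, K, 2]
+#     sigma = torch.rand(2, 17, 2)
+#     target = torch.rand(2, 17, 2)
+#     loss = loss_fn(pred, sigma, target, torch.ones(2, 17, 2))
+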
+@MODELS.register_module()
+class SmoothL1Loss(nn.Module):
+ """SmoothL1Loss loss.
+
+ Args:
+ use_target_weight (bool): Option to use weighted MSE loss.
+ Different joint types may have different target weights.
+ loss_weight (float): Weight of the loss. Default: 1.0.
+ """
+
+ def __init__(self, use_target_weight=False, loss_weight=1.):
+ super().__init__()
+ self.criterion = F.smooth_l1_loss
+ self.use_target_weight = use_target_weight
+ self.loss_weight = loss_weight
+
+ def forward(self, output, target, target_weight=None):
+ """Forward function.
+
+ Note:
+ - batch_size: N
+ - num_keypoints: K
+ - dimension of keypoints: D (D=2 or D=3)
+
+ Args:
+ output (torch.Tensor[N, K, D]): Output regression.
+ target (torch.Tensor[N, K, D]): Target regression.
+ target_weight (torch.Tensor[N, K, D]):
+ Weights across different joint types.
+ """
+
+ if self.use_target_weight:
+ assert target_weight is not None
+ assert output.ndim >= target_weight.ndim
+
+ for i in range(output.ndim - target_weight.ndim):
+ target_weight = target_weight.unsqueeze(-1)
+
+ loss = self.criterion(output * target_weight,
+ target * target_weight)
+ else:
+ loss = self.criterion(output, target)
+
+ return loss * self.loss_weight
+
+
+@MODELS.register_module()
+class SoftWeightSmoothL1Loss(nn.Module):
+ """Smooth L1 loss with soft weight for regression.
+
+ Args:
+ use_target_weight (bool): Option to use weighted MSE loss.
+ Different joint types may have different target weights.
+ supervise_empty (bool): Whether to supervise the output with zero
+ weight.
+ beta (float): Specifies the threshold at which to change between
+ L1 and L2 loss.
+ loss_weight (float): Weight of the loss. Default: 1.0.
+ """
+
+ def __init__(self,
+ use_target_weight=False,
+ supervise_empty=True,
+ beta=1.0,
+ loss_weight=1.):
+ super().__init__()
+
+ reduction = 'none' if use_target_weight else 'mean'
+ self.criterion = partial(
+ self.smooth_l1_loss, reduction=reduction, beta=beta)
+
+ self.supervise_empty = supervise_empty
+ self.use_target_weight = use_target_weight
+ self.loss_weight = loss_weight
+
+ @staticmethod
+ def smooth_l1_loss(input, target, reduction='none', beta=1.0):
+ """Re-implement torch.nn.functional.smooth_l1_loss with beta to support
+ pytorch <= 1.6."""
+ delta = input - target
+ mask = delta.abs() < beta
+ delta[mask] = (delta[mask]).pow(2) / (2 * beta)
+ delta[~mask] = delta[~mask].abs() - beta / 2
+
+ if reduction == 'mean':
+ return delta.mean()
+ elif reduction == 'sum':
+ return delta.sum()
+ elif reduction == 'none':
+ return delta
+ else:
+ raise ValueError(f'reduction must be \'mean\', \'sum\' or '
+ f'\'none\', but got \'{reduction}\'')
+
+ def forward(self, output, target, target_weight=None):
+ """Forward function.
+
+ Note:
+ - batch_size: N
+ - num_keypoints: K
+ - dimension of keypoints: D (D=2 or D=3)
+
+ Args:
+ output (torch.Tensor[N, K, D]): Output regression.
+ target (torch.Tensor[N, K, D]): Target regression.
+ target_weight (torch.Tensor[N, K, D]):
+ Weights across different joint types.
+ """
+ if self.use_target_weight:
+ assert target_weight is not None
+ assert output.ndim >= target_weight.ndim
+
+ for i in range(output.ndim - target_weight.ndim):
+ target_weight = target_weight.unsqueeze(-1)
+
+ loss = self.criterion(output, target) * target_weight
+ if self.supervise_empty:
+ loss = loss.mean()
+ else:
+ num_elements = torch.nonzero(target_weight > 0).size()[0]
+ loss = loss.sum() / max(num_elements, 1.0)
+ else:
+ loss = self.criterion(output, target)
+
+ return loss * self.loss_weight
+
+
+@MODELS.register_module()
+class WingLoss(nn.Module):
+ """Wing Loss. paper ref: 'Wing Loss for Robust Facial Landmark Localisation
+ with Convolutional Neural Networks' Feng et al. CVPR'2018.
+
+ Args:
+ omega (float): Also referred to as width.
+ epsilon (float): Also referred to as curvature.
+ use_target_weight (bool): Option to use weighted MSE loss.
+ Different joint types may have different target weights.
+ loss_weight (float): Weight of the loss. Default: 1.0.
+ """
+
+ def __init__(self,
+ omega=10.0,
+ epsilon=2.0,
+ use_target_weight=False,
+ loss_weight=1.):
+ super().__init__()
+ self.omega = omega
+ self.epsilon = epsilon
+ self.use_target_weight = use_target_weight
+ self.loss_weight = loss_weight
+
+ # constant that smoothly links the piecewise-defined linear
+ # and nonlinear parts
+ self.C = self.omega * (1.0 - math.log(1.0 + self.omega / self.epsilon))
+
+ def criterion(self, pred, target):
+ """Criterion of wingloss.
+
+ Note:
+ - batch_size: N
+ - num_keypoints: K
+ - dimension of keypoints: D (D=2 or D=3)
+
+ Args:
+ pred (torch.Tensor[N, K, D]): Output regression.
+ target (torch.Tensor[N, K, D]): Target regression.
+ """
+ delta = (target - pred).abs()
+ losses = torch.where(
+ delta < self.omega,
+ self.omega * torch.log(1.0 + delta / self.epsilon), delta - self.C)
+ return torch.mean(torch.sum(losses, dim=[1, 2]), dim=0)
+
+ def forward(self, output, target, target_weight=None):
+ """Forward function.
+
+ Note:
+ - batch_size: N
+ - num_keypoints: K
+ - dimension of keypoints: D (D=2 or D=3)
+
+ Args:
+ output (torch.Tensor[N, K, D]): Output regression.
+ target (torch.Tensor[N, K, D]): Target regression.
+ target_weight (torch.Tensor[N,K,D]):
+ Weights across different joint types.
+ """
+ if self.use_target_weight:
+ assert target_weight is not None
+ loss = self.criterion(output * target_weight,
+ target * target_weight)
+ else:
+ loss = self.criterion(output, target)
+
+ return loss * self.loss_weight
+
+
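+# A minimal usage sketch of ``WingLoss`` on regressed coordinates
+# (illustrative only; random dummy tensors with the default omega/epsilon):
+#
+#     loss_fn = WingLoss(omega=10.0, epsilon=2.0)
+#     output = torch.rand(2, 17, 2)    # [N, K, D]
+#     target = torch.rand(2, 17, 2)
+#     loss = loss_fn(output, target)
+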
+@MODELS.register_module()
+class SoftWingLoss(nn.Module):
+ """Soft Wing Loss 'Structure-Coherent Deep Feature Learning for Robust Face
+ Alignment' Lin et al. TIP'2021.
+
+ loss =
+ 1. |x| , if |x| < omega1
+ 2. omega2*ln(1+|x|/epsilon) + B, if |x| >= omega1
+
+ Args:
+ omega1 (float): The first threshold.
+ omega2 (float): The second threshold.
+ epsilon (float): Also referred to as curvature.
+ use_target_weight (bool): Option to use weighted MSE loss.
+ Different joint types may have different target weights.
+ loss_weight (float): Weight of the loss. Default: 1.0.
+ """
+
+ def __init__(self,
+ omega1=2.0,
+ omega2=20.0,
+ epsilon=0.5,
+ use_target_weight=False,
+ loss_weight=1.):
+ super().__init__()
+ self.omega1 = omega1
+ self.omega2 = omega2
+ self.epsilon = epsilon
+ self.use_target_weight = use_target_weight
+ self.loss_weight = loss_weight
+
+ # constant that smoothly links the piecewise-defined linear
+ # and nonlinear parts
+ self.B = self.omega1 - self.omega2 * math.log(1.0 + self.omega1 /
+ self.epsilon)
+
+ def criterion(self, pred, target):
+ """Criterion of wingloss.
+
+ Note:
+ batch_size: N
+ num_keypoints: K
+ dimension of keypoints: D (D=2 or D=3)
+
+ Args:
+ pred (torch.Tensor[N, K, D]): Output regression.
+ target (torch.Tensor[N, K, D]): Target regression.
+ """
+ delta = (target - pred).abs()
+ losses = torch.where(
+ delta < self.omega1, delta,
+ self.omega2 * torch.log(1.0 + delta / self.epsilon) + self.B)
+ return torch.mean(torch.sum(losses, dim=[1, 2]), dim=0)
+
+ def forward(self, output, target, target_weight=None):
+ """Forward function.
+
+ Note:
+ batch_size: N
+ num_keypoints: K
+ dimension of keypoints: D (D=2 or D=3)
+
+ Args:
+ output (torch.Tensor[N, K, D]): Output regression.
+ target (torch.Tensor[N, K, D]): Target regression.
+ target_weight (torch.Tensor[N, K, D]):
+ Weights across different joint types.
+ """
+ if self.use_target_weight:
+ assert target_weight is not None
+ loss = self.criterion(output * target_weight,
+ target * target_weight)
+ else:
+ loss = self.criterion(output, target)
+
+ return loss * self.loss_weight
+
+
+@MODELS.register_module()
+class MPJPELoss(nn.Module):
+ """MPJPE (Mean Per Joint Position Error) loss.
+
+ Args:
+ use_target_weight (bool): Option to use weighted MSE loss.
+ Different joint types may have different target weights.
+ loss_weight (float): Weight of the loss. Default: 1.0.
+ """
+
+ def __init__(self, use_target_weight=False, loss_weight=1.):
+ super().__init__()
+ self.use_target_weight = use_target_weight
+ self.loss_weight = loss_weight
+
+ def forward(self, output, target, target_weight=None):
+ """Forward function.
+
+ Note:
+ - batch_size: N
+ - num_keypoints: K
+ - dimension of keypoints: D (D=2 or D=3)
+
+ Args:
+ output (torch.Tensor[N, K, D]): Output regression.
+ target (torch.Tensor[N, K, D]): Target regression.
+ target_weight (torch.Tensor[N,K,D]):
+ Weights across different joint types.
+ """
+
+ if self.use_target_weight:
+ assert target_weight is not None
+ loss = torch.mean(
+ torch.norm((output - target) * target_weight, dim=-1))
+ else:
+ loss = torch.mean(torch.norm(output - target, dim=-1))
+
+ return loss * self.loss_weight
+
+
+@MODELS.register_module()
+class L1Loss(nn.Module):
+ """L1Loss loss ."""
+
+ def __init__(self, use_target_weight=False, loss_weight=1.):
+ super().__init__()
+ self.criterion = F.l1_loss
+ self.use_target_weight = use_target_weight
+ self.loss_weight = loss_weight
+
+ def forward(self, output, target, target_weight=None):
+ """Forward function.
+
+ Note:
+ - batch_size: N
+ - num_keypoints: K
+
+ Args:
+ output (torch.Tensor[N, K, 2]): Output regression.
+ target (torch.Tensor[N, K, 2]): Target regression.
+ target_weight (torch.Tensor[N, K, 2]):
+ Weights across different joint types.
+ """
+ if self.use_target_weight:
+ assert target_weight is not None
+ loss = self.criterion(output * target_weight,
+ target * target_weight)
+ else:
+ loss = self.criterion(output, target)
+
+ return loss * self.loss_weight
+
+
+@MODELS.register_module()
+class MSELoss(nn.Module):
+ """MSE loss for coordinate regression."""
+
+ def __init__(self, use_target_weight=False, loss_weight=1.):
+ super().__init__()
+ self.criterion = F.mse_loss
+ self.use_target_weight = use_target_weight
+ self.loss_weight = loss_weight
+
+ def forward(self, output, target, target_weight=None):
+ """Forward function.
+
+ Note:
+ - batch_size: N
+ - num_keypoints: K
+
+ Args:
+ output (torch.Tensor[N, K, 2]): Output regression.
+ target (torch.Tensor[N, K, 2]): Target regression.
+ target_weight (torch.Tensor[N, K, 2]):
+ Weights across different joint types.
+ """
+
+ if self.use_target_weight:
+ assert target_weight is not None
+ loss = self.criterion(output * target_weight,
+ target * target_weight)
+ else:
+ loss = self.criterion(output, target)
+
+ return loss * self.loss_weight
+
+
+@MODELS.register_module()
+class BoneLoss(nn.Module):
+ """Bone length loss.
+
+ Args:
+ joint_parents (list): Indices of each joint's parent joint.
+ use_target_weight (bool): Option to use weighted bone loss.
+ Different bone types may have different target weights.
+ loss_weight (float): Weight of the loss. Default: 1.0.
+ """
+
+ def __init__(self, joint_parents, use_target_weight=False, loss_weight=1.):
+ super().__init__()
+ self.joint_parents = joint_parents
+ self.use_target_weight = use_target_weight
+ self.loss_weight = loss_weight
+
+ self.non_root_indices = []
+ for i in range(len(self.joint_parents)):
+ if i != self.joint_parents[i]:
+ self.non_root_indices.append(i)
+
+ def forward(self, output, target, target_weight=None):
+ """Forward function.
+
+ Note:
+ - batch_size: N
+ - num_keypoints: K
+ - dimension of keypoints: D (D=2 or D=3)
+
+ Args:
+ output (torch.Tensor[N, K, D]): Output regression.
+ target (torch.Tensor[N, K, D]): Target regression.
+ target_weight (torch.Tensor[N, K-1]):
+ Weights across different bone types.
+ """
+ output_bone = torch.norm(
+ output - output[:, self.joint_parents, :],
+ dim=-1)[:, self.non_root_indices]
+ target_bone = torch.norm(
+ target - target[:, self.joint_parents, :],
+ dim=-1)[:, self.non_root_indices]
+ if self.use_target_weight:
+ assert target_weight is not None
+ loss = torch.mean(
+ torch.abs((output_bone * target_weight).mean(dim=0) -
+ (target_bone * target_weight).mean(dim=0)))
+ else:
+ loss = torch.mean(
+ torch.abs(output_bone.mean(dim=0) - target_bone.mean(dim=0)))
+
+ return loss * self.loss_weight
+
+
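+# A minimal usage sketch of ``BoneLoss`` (illustrative only): a toy
+# 4-joint kinematic chain where joint 0 is the root (its own parent).
+#
+#     loss_fn = BoneLoss(joint_parents=[0, 0, 1, 2])
+#     output = torch.rand(2, 4, 3)     # [N, K, D]
+#     target = torch.rand(2, 4, 3)
+#     loss = loss_fn(output, target)
+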
+@MODELS.register_module()
+class SemiSupervisionLoss(nn.Module):
+ """Semi-supervision loss for unlabeled data. It is composed of projection
+ loss and bone loss.
+
+ Paper ref: `3D human pose estimation in video with temporal convolutions
+ and semi-supervised training` Dario Pavllo et al. CVPR'2019.
+
+ Args:
+ joint_parents (list): Indices of each joint's parent joint.
+ projection_loss_weight (float): Weight for projection loss.
+ bone_loss_weight (float): Weight for bone loss.
+ warmup_iterations (int): Number of warmup iterations. In the first
+ `warmup_iterations` iterations, the model is trained only on
+ labeled data, and semi-supervision loss will be 0.
+ This is a workaround since currently we cannot access
+ epoch number in loss functions. Note that the iteration number in
+ an epoch can be changed due to different GPU numbers in multi-GPU
+ settings. So please set this parameter carefully.
+ warmup_iterations = dataset_size // samples_per_gpu // gpu_num
+ * warmup_epochs
+ """
+
+ def __init__(self,
+ joint_parents,
+ projection_loss_weight=1.,
+ bone_loss_weight=1.,
+ warmup_iterations=0):
+ super().__init__()
+ self.criterion_projection = MPJPELoss(
+ loss_weight=projection_loss_weight)
+ self.criterion_bone = BoneLoss(
+ joint_parents, loss_weight=bone_loss_weight)
+ self.warmup_iterations = warmup_iterations
+ self.num_iterations = 0
+
+ @staticmethod
+ def project_joints(x, intrinsics):
+ """Project 3D joint coordinates to 2D image plane using camera
+ intrinsic parameters.
+
+ Args:
+ x (torch.Tensor[N, K, 3]): 3D joint coordinates.
+ intrinsics (torch.Tensor[N, 4] | torch.Tensor[N, 9]): Camera
+ intrinsics: f (2), c (2), k (3), p (2).
+ """
+ while intrinsics.dim() < x.dim():
+ intrinsics.unsqueeze_(1)
+ f = intrinsics[..., :2]
+ c = intrinsics[..., 2:4]
+ _x = torch.clamp(x[:, :, :2] / x[:, :, 2:], -1, 1)
+ if intrinsics.shape[-1] == 9:
+ k = intrinsics[..., 4:7]
+ p = intrinsics[..., 7:9]
+
+ r2 = torch.sum(_x[:, :, :2]**2, dim=-1, keepdim=True)
+ radial = 1 + torch.sum(
+ k * torch.cat((r2, r2**2, r2**3), dim=-1),
+ dim=-1,
+ keepdim=True)
+ tan = torch.sum(p * _x, dim=-1, keepdim=True)
+ _x = _x * (radial + tan) + p * r2
+ _x = f * _x + c
+ return _x
+
+ def forward(self, output, target):
+ losses = dict()
+
+ self.num_iterations += 1
+ if self.num_iterations <= self.warmup_iterations:
+ return losses
+
+ labeled_pose = output['labeled_pose']
+ unlabeled_pose = output['unlabeled_pose']
+ unlabeled_traj = output['unlabeled_traj']
+ unlabeled_target_2d = target['unlabeled_target_2d']
+ intrinsics = target['intrinsics']
+
+ # projection loss
+ unlabeled_output = unlabeled_pose + unlabeled_traj
+ unlabeled_output_2d = self.project_joints(unlabeled_output, intrinsics)
+ loss_proj = self.criterion_projection(unlabeled_output_2d,
+ unlabeled_target_2d, None)
+ losses['proj_loss'] = loss_proj
+
+ # bone loss
+ loss_bone = self.criterion_bone(unlabeled_pose, labeled_pose, None)
+ losses['bone_loss'] = loss_bone
+
+ return losses
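+
+# A minimal sketch of ``SemiSupervisionLoss.project_joints`` (illustrative
+# only; the focal lengths / principal point below are arbitrary dummies):
+#
+#     x = torch.rand(2, 17, 3) + torch.tensor([0., 0., 3.])  # positive depth
+#     intrinsics = torch.tensor([[1000., 1000., 512., 512.]]).repeat(2, 1)
+#     uv = SemiSupervisionLoss.project_joints(x, intrinsics)  # [2, 17, 2]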
diff --git a/mmpose/models/necks/__init__.py b/mmpose/models/necks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4f9105cb39cdfaba4f26d903b8bffbe05f4272a
--- /dev/null
+++ b/mmpose/models/necks/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .fmap_proc_neck import FeatureMapProcessor
+from .fpn import FPN
+from .gap_neck import GlobalAveragePooling
+from .posewarper_neck import PoseWarperNeck
+
+__all__ = [
+ 'GlobalAveragePooling', 'PoseWarperNeck', 'FPN', 'FeatureMapProcessor'
+]
diff --git a/mmpose/models/necks/__pycache__/__init__.cpython-38.pyc b/mmpose/models/necks/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e2c4ded6e13005fc78afa8a00d53825f5622984e
Binary files /dev/null and b/mmpose/models/necks/__pycache__/__init__.cpython-38.pyc differ
diff --git a/mmpose/models/necks/__pycache__/fmap_proc_neck.cpython-38.pyc b/mmpose/models/necks/__pycache__/fmap_proc_neck.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..278aa6129866721b6c25c3513cc14b7cae35d8b4
Binary files /dev/null and b/mmpose/models/necks/__pycache__/fmap_proc_neck.cpython-38.pyc differ
diff --git a/mmpose/models/necks/__pycache__/fpn.cpython-38.pyc b/mmpose/models/necks/__pycache__/fpn.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f2807d62818c544d1dc35992223459fd32e01b9c
Binary files /dev/null and b/mmpose/models/necks/__pycache__/fpn.cpython-38.pyc differ
diff --git a/mmpose/models/necks/__pycache__/gap_neck.cpython-38.pyc b/mmpose/models/necks/__pycache__/gap_neck.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b3764a3b27eb7209f4a85a5695423861493cf2e1
Binary files /dev/null and b/mmpose/models/necks/__pycache__/gap_neck.cpython-38.pyc differ
diff --git a/mmpose/models/necks/__pycache__/posewarper_neck.cpython-38.pyc b/mmpose/models/necks/__pycache__/posewarper_neck.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..815a18a3620173178e9a9ef38fb5c70f6c879bc2
Binary files /dev/null and b/mmpose/models/necks/__pycache__/posewarper_neck.cpython-38.pyc differ
diff --git a/mmpose/models/necks/fmap_proc_neck.py b/mmpose/models/necks/fmap_proc_neck.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c3a4d7bf44ab07641a4968f143e17c19b24743b
--- /dev/null
+++ b/mmpose/models/necks/fmap_proc_neck.py
@@ -0,0 +1,101 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional, Sequence, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+
+from mmpose.models.utils.ops import resize
+from mmpose.registry import MODELS
+
+
+@MODELS.register_module()
+class FeatureMapProcessor(nn.Module):
+ """A PyTorch module for selecting, concatenating, and rescaling feature
+ maps.
+
+ Args:
+ select_index (Optional[Union[int, Tuple[int]]], optional): Index or
+ indices of feature maps to select. Defaults to None, which means
+ all feature maps are used.
+ concat (bool, optional): Whether to concatenate the selected feature
+ maps. Defaults to False.
+ scale_factor (float, optional): The scaling factor to apply to the
+ feature maps. Defaults to 1.0.
+ apply_relu (bool, optional): Whether to apply ReLU on input feature
+ maps. Defaults to False.
+ align_corners (bool, optional): Whether to align corners when resizing
+ the feature maps. Defaults to False.
+ """
+
+ def __init__(
+ self,
+ select_index: Optional[Union[int, Tuple[int]]] = None,
+ concat: bool = False,
+ scale_factor: float = 1.0,
+ apply_relu: bool = False,
+ align_corners: bool = False,
+ ):
+ super().__init__()
+
+ if isinstance(select_index, int):
+ select_index = (select_index, )
+ self.select_index = select_index
+ self.concat = concat
+
+ assert (
+ scale_factor > 0
+ ), f'the argument `scale_factor` must be positive, ' \
+ f'but got {scale_factor}'
+ self.scale_factor = scale_factor
+ self.apply_relu = apply_relu
+ self.align_corners = align_corners
+
+ def forward(self, inputs: Union[Tensor, Sequence[Tensor]]
+ ) -> Union[Tensor, List[Tensor]]:
+
+ if not isinstance(inputs, (tuple, list)):
+ sequential_input = False
+ inputs = [inputs]
+ else:
+ sequential_input = True
+
+ if self.select_index is not None:
+ inputs = [inputs[i] for i in self.select_index]
+
+ if self.concat:
+ inputs = self._concat(inputs)
+
+ if self.apply_relu:
+ inputs = [F.relu(x) for x in inputs]
+
+ if self.scale_factor != 1.0:
+ inputs = self._rescale(inputs)
+
+ if not sequential_input:
+ inputs = inputs[0]
+
+ return inputs
+
+ def _concat(self, inputs: Sequence[Tensor]) -> List[Tensor]:
+ size = inputs[0].shape[-2:]
+ resized_inputs = [
+ resize(
+ x,
+ size=size,
+ mode='bilinear',
+ align_corners=self.align_corners) for x in inputs
+ ]
+ return [torch.cat(resized_inputs, dim=1)]
+
+ def _rescale(self, inputs: Sequence[Tensor]) -> List[Tensor]:
+ rescaled_inputs = [
+ resize(
+ x,
+ scale_factor=self.scale_factor,
+ mode='bilinear',
+ align_corners=self.align_corners,
+ ) for x in inputs
+ ]
+ return rescaled_inputs
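+
+# A minimal usage sketch of ``FeatureMapProcessor`` (illustrative only;
+# three dummy multi-scale feature maps):
+#
+#     neck = FeatureMapProcessor(select_index=(1, 2), concat=True,
+#                                scale_factor=2.0)
+#     feats = [torch.rand(1, 32, 64, 64),
+#              torch.rand(1, 64, 32, 32),
+#              torch.rand(1, 128, 16, 16)]
+#     outs = neck(feats)  # a list with one [1, 192, 64, 64] tensor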
diff --git a/mmpose/models/necks/fpn.py b/mmpose/models/necks/fpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4d3311bda792898dd1bc7ef9b9462db7b01ce05
--- /dev/null
+++ b/mmpose/models/necks/fpn.py
@@ -0,0 +1,206 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule
+from mmengine.model import xavier_init
+
+from mmpose.registry import MODELS
+
+
+@MODELS.register_module()
+class FPN(nn.Module):
+ r"""Feature Pyramid Network.
+
+ This is an implementation of paper `Feature Pyramid Networks for Object
+ Detection `_.
+
+ Args:
+ in_channels (list[int]): Number of input channels per scale.
+ out_channels (int): Number of output channels (used at each scale).
+ num_outs (int): Number of output scales.
+ start_level (int): Index of the start input backbone level used to
+ build the feature pyramid. Default: 0.
+ end_level (int): Index of the end input backbone level (exclusive) to
+ build the feature pyramid. Default: -1, which means the last level.
+ add_extra_convs (bool | str): If bool, it decides whether to add conv
+ layers on top of the original feature maps. Default to False.
+ If True, it is equivalent to `add_extra_convs='on_input'`.
+ If str, it specifies the source feature map of the extra convs.
+ Only the following options are allowed
+
+ - 'on_input': Last feat map of neck inputs (i.e. backbone feature).
+ - 'on_lateral': Last feature map after lateral convs.
+ - 'on_output': The last output feature map after fpn convs.
+ relu_before_extra_convs (bool): Whether to apply relu before the extra
+ conv. Default: False.
+ no_norm_on_lateral (bool): Whether to apply norm on lateral.
+ Default: False.
+ conv_cfg (dict): Config dict for convolution layer. Default: None.
+ norm_cfg (dict): Config dict for normalization layer. Default: None.
+ act_cfg (dict): Config dict for activation layer in ConvModule.
+ Default: None.
+ upsample_cfg (dict): Config dict for interpolate layer.
+ Default: dict(mode='nearest').
+
+ Example:
+ >>> import torch
+ >>> in_channels = [2, 3, 5, 7]
+ >>> scales = [340, 170, 84, 43]
+ >>> inputs = [torch.rand(1, c, s, s)
+ ... for c, s in zip(in_channels, scales)]
+ >>> self = FPN(in_channels, 11, len(in_channels)).eval()
+ >>> outputs = self.forward(inputs)
+ >>> for i in range(len(outputs)):
+ ... print(f'outputs[{i}].shape = {outputs[i].shape}')
+ outputs[0].shape = torch.Size([1, 11, 340, 340])
+ outputs[1].shape = torch.Size([1, 11, 170, 170])
+ outputs[2].shape = torch.Size([1, 11, 84, 84])
+ outputs[3].shape = torch.Size([1, 11, 43, 43])
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ num_outs,
+ start_level=0,
+ end_level=-1,
+ add_extra_convs=False,
+ relu_before_extra_convs=False,
+ no_norm_on_lateral=False,
+ conv_cfg=None,
+ norm_cfg=None,
+ act_cfg=None,
+ upsample_cfg=dict(mode='nearest')):
+ super().__init__()
+ assert isinstance(in_channels, list)
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.num_ins = len(in_channels)
+ self.num_outs = num_outs
+ self.relu_before_extra_convs = relu_before_extra_convs
+ self.no_norm_on_lateral = no_norm_on_lateral
+ self.fp16_enabled = False
+ self.upsample_cfg = upsample_cfg.copy()
+
+ if end_level == -1 or end_level == self.num_ins - 1:
+ self.backbone_end_level = self.num_ins
+ assert num_outs >= self.num_ins - start_level
+ else:
+ # if end_level is not the last level, no extra level is allowed
+ self.backbone_end_level = end_level + 1
+ assert end_level < self.num_ins
+ assert num_outs == end_level - start_level + 1
+ self.start_level = start_level
+ self.end_level = end_level
+ self.add_extra_convs = add_extra_convs
+ assert isinstance(add_extra_convs, (str, bool))
+ if isinstance(add_extra_convs, str):
+ # Extra_convs_source choices: 'on_input', 'on_lateral', 'on_output'
+ assert add_extra_convs in ('on_input', 'on_lateral', 'on_output')
+ elif add_extra_convs: # True
+ self.add_extra_convs = 'on_input'
+
+ self.lateral_convs = nn.ModuleList()
+ self.fpn_convs = nn.ModuleList()
+
+ for i in range(self.start_level, self.backbone_end_level):
+ l_conv = ConvModule(
+ in_channels[i],
+ out_channels,
+ 1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg if not self.no_norm_on_lateral else None,
+ act_cfg=act_cfg,
+ inplace=False)
+ fpn_conv = ConvModule(
+ out_channels,
+ out_channels,
+ 3,
+ padding=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ inplace=False)
+
+ self.lateral_convs.append(l_conv)
+ self.fpn_convs.append(fpn_conv)
+
+ # add extra conv layers (e.g., RetinaNet)
+ extra_levels = num_outs - self.backbone_end_level + self.start_level
+ if self.add_extra_convs and extra_levels >= 1:
+ for i in range(extra_levels):
+ if i == 0 and self.add_extra_convs == 'on_input':
+ in_channels = self.in_channels[self.backbone_end_level - 1]
+ else:
+ in_channels = out_channels
+ extra_fpn_conv = ConvModule(
+ in_channels,
+ out_channels,
+ 3,
+ stride=2,
+ padding=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ inplace=False)
+ self.fpn_convs.append(extra_fpn_conv)
+
+ def init_weights(self):
+ """Initialize model weights."""
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ xavier_init(m, distribution='uniform')
+
+ def forward(self, inputs):
+ """Forward function."""
+ assert len(inputs) == len(self.in_channels)
+
+ # build laterals
+ laterals = [
+ lateral_conv(inputs[i + self.start_level])
+ for i, lateral_conv in enumerate(self.lateral_convs)
+ ]
+
+ # build top-down path
+ used_backbone_levels = len(laterals)
+ for i in range(used_backbone_levels - 1, 0, -1):
+ # In some cases, fixing `scale factor` (e.g. 2) is preferred, but
+ # it cannot co-exist with `size` in `F.interpolate`.
+ if 'scale_factor' in self.upsample_cfg:
+ # fix runtime error of "+=" inplace operation in PyTorch 1.10
+ laterals[i - 1] = laterals[i - 1] + F.interpolate(
+ laterals[i], **self.upsample_cfg)
+ else:
+ prev_shape = laterals[i - 1].shape[2:]
+ laterals[i - 1] = laterals[i - 1] + F.interpolate(
+ laterals[i], size=prev_shape, **self.upsample_cfg)
+
+ # build outputs
+ # part 1: from original levels
+ outs = [
+ self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels)
+ ]
+ # part 2: add extra levels
+ if self.num_outs > len(outs):
+ # use max pool to get more levels on top of outputs
+ # (e.g., Faster R-CNN, Mask R-CNN)
+ if not self.add_extra_convs:
+ for i in range(self.num_outs - used_backbone_levels):
+ outs.append(F.max_pool2d(outs[-1], 1, stride=2))
+ # add conv layers on top of original feature maps (RetinaNet)
+ else:
+ if self.add_extra_convs == 'on_input':
+ extra_source = inputs[self.backbone_end_level - 1]
+ elif self.add_extra_convs == 'on_lateral':
+ extra_source = laterals[-1]
+ elif self.add_extra_convs == 'on_output':
+ extra_source = outs[-1]
+ else:
+ raise NotImplementedError
+ outs.append(self.fpn_convs[used_backbone_levels](extra_source))
+ for i in range(used_backbone_levels + 1, self.num_outs):
+ if self.relu_before_extra_convs:
+ outs.append(self.fpn_convs[i](F.relu(outs[-1])))
+ else:
+ outs.append(self.fpn_convs[i](outs[-1]))
+ return outs
diff --git a/mmpose/models/necks/gap_neck.py b/mmpose/models/necks/gap_neck.py
new file mode 100644
index 0000000000000000000000000000000000000000..58ce5d939ffdeb912a02e8b1823ab073cbc3d9e3
--- /dev/null
+++ b/mmpose/models/necks/gap_neck.py
@@ -0,0 +1,39 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+
+from mmpose.registry import MODELS
+
+
+@MODELS.register_module()
+class GlobalAveragePooling(nn.Module):
+ """Global Average Pooling neck.
+
+ Note that we use `view` to remove extra channel after pooling. We do not
+ use `squeeze` as it will also remove the batch dimension when the tensor
+ has a batch dimension of size 1, which can lead to unexpected errors.
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.gap = nn.AdaptiveAvgPool2d((1, 1))
+
+ def init_weights(self):
+ pass
+
+ def forward(self, inputs):
+ """Forward function."""
+
+ if isinstance(inputs, tuple):
+ outs = tuple([self.gap(x) for x in inputs])
+ outs = tuple(
+ [out.view(x.size(0), -1) for out, x in zip(outs, inputs)])
+ elif isinstance(inputs, list):
+ outs = [self.gap(x) for x in inputs]
+ outs = [out.view(x.size(0), -1) for out, x in zip(outs, inputs)]
+ elif isinstance(inputs, torch.Tensor):
+ outs = self.gap(inputs)
+ outs = outs.view(inputs.size(0), -1)
+ else:
+ raise TypeError('neck inputs should be tuple, list or torch.Tensor')
+ return outs
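+
+# A minimal usage sketch of ``GlobalAveragePooling`` (illustrative only):
+#
+#     neck = GlobalAveragePooling()
+#     feat = torch.rand(2, 256, 8, 6)
+#     pooled = neck(feat)              # shape [2, 256]
+#     multi = neck((feat, torch.rand(2, 512, 4, 3)))  # tuple of [2, C]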
diff --git a/mmpose/models/necks/posewarper_neck.py b/mmpose/models/necks/posewarper_neck.py
new file mode 100644
index 0000000000000000000000000000000000000000..517fabd2e839878e7cf692c91adad450f432e8f0
--- /dev/null
+++ b/mmpose/models/necks/posewarper_neck.py
@@ -0,0 +1,329 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import mmcv
+import torch
+import torch.nn as nn
+from mmcv.cnn import build_conv_layer, build_norm_layer
+from mmengine.model import constant_init, normal_init
+from mmengine.utils import digit_version
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from mmpose.models.utils.ops import resize
+from mmpose.registry import MODELS
+from ..backbones.resnet import BasicBlock, Bottleneck
+
+try:
+ from mmcv.ops import DeformConv2d
+ has_mmcv_full = True
+except (ImportError, ModuleNotFoundError):
+ has_mmcv_full = False
+
+
+@MODELS.register_module()
+class PoseWarperNeck(nn.Module):
+ """PoseWarper neck.
+
+ `"Learning temporal pose estimation from sparsely-labeled videos"
+ `_.
+
+ Args:
+ in_channels (int): Number of input channels from backbone
+ out_channels (int): Number of output channels
+ inner_channels (int): Number of intermediate channels of the res block
+ deform_groups (int): Number of groups in the deformable conv
+ dilations (list|tuple): different dilations of the offset conv layers
+ trans_conv_kernel (int): the kernel of the trans conv layer, which is
+ used to get heatmap from the output of backbone. Default: 1
+ res_blocks_cfg (dict|None): config of residual blocks. If None,
+ use the default values. If not None, it should contain the
+ following keys:
+
+ - block (str): the type of residual block, Default: 'BASIC'.
+ - num_blocks (int): the number of blocks, Default: 20.
+
+ offsets_kernel (int): the kernel of offset conv layer.
+ deform_conv_kernel (int): the kernel of the deformable conv layer.
+ in_index (int|Sequence[int]): Input feature index. Default: 0
+ input_transform (str|None): Transformation type of input features.
+ Options: 'resize_concat', 'multiple_select', None.
+ Default: None.
+
+ - 'resize_concat': Multiple feature maps will be resized to \
+ the same size as the first one and then concatenated together. \
+ Usually used in FCN head of HRNet.
+ - 'multiple_select': Multiple feature maps will be bundled into \
+ a list and passed into the decode head.
+ - None: Only one selected feature map is allowed.
+
+ freeze_trans_layer (bool): Whether to freeze the transition layer
+ (stop grad and set eval mode). Default: True.
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Default: False.
+ im2col_step (int): the argument `im2col_step` in deformable conv,
+ Default: 80.
+ """
+ blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck}
+ minimum_mmcv_version = '1.3.17'
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ inner_channels,
+ deform_groups=17,
+ dilations=(3, 6, 12, 18, 24),
+ trans_conv_kernel=1,
+ res_blocks_cfg=None,
+ offsets_kernel=3,
+ deform_conv_kernel=3,
+ in_index=0,
+ input_transform=None,
+ freeze_trans_layer=True,
+ norm_eval=False,
+ im2col_step=80):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.inner_channels = inner_channels
+ self.deform_groups = deform_groups
+ self.dilations = dilations
+ self.trans_conv_kernel = trans_conv_kernel
+ self.res_blocks_cfg = res_blocks_cfg
+ self.offsets_kernel = offsets_kernel
+ self.deform_conv_kernel = deform_conv_kernel
+ self.in_index = in_index
+ self.input_transform = input_transform
+ self.freeze_trans_layer = freeze_trans_layer
+ self.norm_eval = norm_eval
+ self.im2col_step = im2col_step
+
+ identity_trans_layer = False
+
+ assert trans_conv_kernel in [0, 1, 3]
+ kernel_size = trans_conv_kernel
+ if kernel_size == 3:
+ padding = 1
+ elif kernel_size == 1:
+ padding = 0
+ else:
+ # 0 for Identity mapping.
+ identity_trans_layer = True
+
+ if identity_trans_layer:
+ self.trans_layer = nn.Identity()
+ else:
+ self.trans_layer = build_conv_layer(
+ cfg=dict(type='Conv2d'),
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=1,
+ padding=padding)
+
+ # build chain of residual blocks
+ if res_blocks_cfg is not None and not isinstance(res_blocks_cfg, dict):
+ raise TypeError('res_blocks_cfg should be dict or None.')
+
+ if res_blocks_cfg is None:
+ block_type = 'BASIC'
+ num_blocks = 20
+ else:
+ block_type = res_blocks_cfg.get('block', 'BASIC')
+ num_blocks = res_blocks_cfg.get('num_blocks', 20)
+
+ block = self.blocks_dict[block_type]
+
+ res_layers = []
+ downsample = nn.Sequential(
+ build_conv_layer(
+ cfg=dict(type='Conv2d'),
+ in_channels=out_channels,
+ out_channels=inner_channels,
+ kernel_size=1,
+ stride=1,
+ bias=False),
+ build_norm_layer(dict(type='BN'), inner_channels)[1])
+ res_layers.append(
+ block(
+ in_channels=out_channels,
+ out_channels=inner_channels,
+ downsample=downsample))
+
+ for _ in range(1, num_blocks):
+ res_layers.append(block(inner_channels, inner_channels))
+ self.offset_feats = nn.Sequential(*res_layers)
+
+ # build offset layers
+ self.num_offset_layers = len(dilations)
+ assert self.num_offset_layers > 0, 'Number of offset layers ' \
+ 'should be larger than 0.'
+
+ target_offset_channels = 2 * offsets_kernel**2 * deform_groups
+
+ offset_layers = [
+ build_conv_layer(
+ cfg=dict(type='Conv2d'),
+ in_channels=inner_channels,
+ out_channels=target_offset_channels,
+ kernel_size=offsets_kernel,
+ stride=1,
+ dilation=dilations[i],
+ padding=dilations[i],
+ bias=False,
+ ) for i in range(self.num_offset_layers)
+ ]
+ self.offset_layers = nn.ModuleList(offset_layers)
+
+ # build deformable conv layers
+ assert digit_version(mmcv.__version__) >= \
+ digit_version(self.minimum_mmcv_version), \
+ f'Current MMCV version: {mmcv.__version__}, ' \
+ f'but MMCV >= {self.minimum_mmcv_version} is required, see ' \
+ f'https://github.com/open-mmlab/mmcv/issues/1440, ' \
+ f'Please install the latest MMCV.'
+
+ if has_mmcv_full:
+ deform_conv_layers = [
+ DeformConv2d(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ kernel_size=deform_conv_kernel,
+ stride=1,
+ padding=int(deform_conv_kernel / 2) * dilations[i],
+ dilation=dilations[i],
+ deform_groups=deform_groups,
+ im2col_step=self.im2col_step,
+ ) for i in range(self.num_offset_layers)
+ ]
+ else:
+ raise ImportError('Please install the full version of mmcv '
+ 'to use `DeformConv2d`.')
+
+ self.deform_conv_layers = nn.ModuleList(deform_conv_layers)
+
+ self.freeze_layers()
+
+ def freeze_layers(self):
+ if self.freeze_trans_layer:
+ self.trans_layer.eval()
+
+ for param in self.trans_layer.parameters():
+ param.requires_grad = False
+
+ def init_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ normal_init(m, std=0.001)
+ elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
+ constant_init(m, 1)
+ elif isinstance(m, DeformConv2d):
+ filler = torch.zeros([
+ m.weight.size(0),
+ m.weight.size(1),
+ m.weight.size(2),
+ m.weight.size(3)
+ ],
+ dtype=torch.float32,
+ device=m.weight.device)
+ for k in range(m.weight.size(0)):
+ filler[k, k,
+ int(m.weight.size(2) / 2),
+ int(m.weight.size(3) / 2)] = 1.0
+ m.weight = torch.nn.Parameter(filler)
+ m.weight.requires_grad = True
+
+ # posewarper offset layer weight initialization
+ for m in self.offset_layers.modules():
+ constant_init(m, 0)
+
+ def _transform_inputs(self, inputs):
+ """Transform inputs for decoder.
+
+ Args:
+ inputs (list[Tensor] | Tensor): multi-level img features.
+
+ Returns:
+ Tensor: The transformed inputs
+ """
+ if not isinstance(inputs, list):
+ return inputs
+
+ if self.input_transform == 'resize_concat':
+ inputs = [inputs[i] for i in self.in_index]
+ upsampled_inputs = [
+ resize(
+ input=x,
+ size=inputs[0].shape[2:],
+ mode='bilinear',
+ align_corners=self.align_corners) for x in inputs
+ ]
+ inputs = torch.cat(upsampled_inputs, dim=1)
+ elif self.input_transform == 'multiple_select':
+ inputs = [inputs[i] for i in self.in_index]
+ else:
+ inputs = inputs[self.in_index]
+
+ return inputs
+
+ def forward(self, inputs, frame_weight):
+        assert isinstance(inputs, (list, tuple)), 'PoseWarperNeck inputs ' \
+            'should be a list or tuple, even if the length is 1, ' \
+            'for unified processing.'
+
+ output_heatmap = 0
+ if len(inputs) > 1:
+ inputs = [self._transform_inputs(input) for input in inputs]
+ inputs = [self.trans_layer(input) for input in inputs]
+
+ # calculate difference features
+ diff_features = [
+ self.offset_feats(inputs[0] - input) for input in inputs
+ ]
+
+ for i in range(len(inputs)):
+ if frame_weight[i] == 0:
+ continue
+ warped_heatmap = 0
+ for j in range(self.num_offset_layers):
+ offset = (self.offset_layers[j](diff_features[i]))
+ warped_heatmap_tmp = self.deform_conv_layers[j](inputs[i],
+ offset)
+ warped_heatmap += warped_heatmap_tmp / \
+ self.num_offset_layers
+
+ output_heatmap += warped_heatmap * frame_weight[i]
+
+ else:
+ inputs = inputs[0]
+ inputs = self._transform_inputs(inputs)
+ inputs = self.trans_layer(inputs)
+
+ num_frames = len(frame_weight)
+ batch_size = inputs.size(0) // num_frames
+ ref_x = inputs[:batch_size]
+ ref_x_tiled = ref_x.repeat(num_frames, 1, 1, 1)
+
+ offset_features = self.offset_feats(ref_x_tiled - inputs)
+
+ warped_heatmap = 0
+ for j in range(self.num_offset_layers):
+ offset = self.offset_layers[j](offset_features)
+
+ warped_heatmap_tmp = self.deform_conv_layers[j](inputs, offset)
+ warped_heatmap += warped_heatmap_tmp / self.num_offset_layers
+
+ for i in range(num_frames):
+ if frame_weight[i] == 0:
+ continue
+ output_heatmap += warped_heatmap[i * batch_size:(i + 1) *
+ batch_size] * frame_weight[i]
+
+ return output_heatmap
+
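+    # Hedged usage sketch (illustrative only; names are placeholders): with a
+    # built neck, the features of T support frames are stacked along the
+    # batch axis and one weight per frame controls its contribution to the
+    # aggregated heatmap, e.g.
+    #
+    #   >>> feats = [backbone(imgs)]   # list with one (N * T, C, H, W) tensor
+    #   >>> heatmaps = neck(feats, frame_weight=[1.0, 0.5, 0.5])   # T = 3
+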
+ def train(self, mode=True):
+ """Convert the model into training mode."""
+ super().train(mode)
+ self.freeze_layers()
+ if mode and self.norm_eval:
+ for m in self.modules():
+ if isinstance(m, _BatchNorm):
+ m.eval()
diff --git a/mmpose/models/pose_estimators/__init__.py b/mmpose/models/pose_estimators/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ead1a979e5a8c38ae98240b181e425ab7eeed35
--- /dev/null
+++ b/mmpose/models/pose_estimators/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .bottomup import BottomupPoseEstimator
+from .topdown import TopdownPoseEstimator
+
+__all__ = ['TopdownPoseEstimator', 'BottomupPoseEstimator']
diff --git a/mmpose/models/pose_estimators/__pycache__/__init__.cpython-38.pyc b/mmpose/models/pose_estimators/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1a039392f447e0ebd7daf846a51f579fb2b1e5da
Binary files /dev/null and b/mmpose/models/pose_estimators/__pycache__/__init__.cpython-38.pyc differ
diff --git a/mmpose/models/pose_estimators/__pycache__/base.cpython-38.pyc b/mmpose/models/pose_estimators/__pycache__/base.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6c6404a265eed582ffe953e0c9be756c5db17e14
Binary files /dev/null and b/mmpose/models/pose_estimators/__pycache__/base.cpython-38.pyc differ
diff --git a/mmpose/models/pose_estimators/__pycache__/bottomup.cpython-38.pyc b/mmpose/models/pose_estimators/__pycache__/bottomup.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f4fbfe2e96c4b850e0e69a4d579dee03f27ee3db
Binary files /dev/null and b/mmpose/models/pose_estimators/__pycache__/bottomup.cpython-38.pyc differ
diff --git a/mmpose/models/pose_estimators/__pycache__/topdown.cpython-38.pyc b/mmpose/models/pose_estimators/__pycache__/topdown.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3d1cedc346a0e34064780c2f66e497ea75de9491
Binary files /dev/null and b/mmpose/models/pose_estimators/__pycache__/topdown.cpython-38.pyc differ
diff --git a/mmpose/models/pose_estimators/base.py b/mmpose/models/pose_estimators/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..73d60de93a245a37093a723408fc25bc5370b448
--- /dev/null
+++ b/mmpose/models/pose_estimators/base.py
@@ -0,0 +1,210 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABCMeta, abstractmethod
+from typing import Tuple, Union
+
+import torch
+from mmengine.model import BaseModel
+from torch import Tensor
+
+from mmpose.datasets.datasets.utils import parse_pose_metainfo
+from mmpose.models.utils import check_and_update_config
+from mmpose.registry import MODELS
+from mmpose.utils.typing import (ConfigType, ForwardResults, OptConfigType,
+ Optional, OptMultiConfig, OptSampleList,
+ SampleList)
+
+
+class BasePoseEstimator(BaseModel, metaclass=ABCMeta):
+ """Base class for pose estimators.
+
+ Args:
+ data_preprocessor (dict | ConfigDict, optional): The pre-processing
+ config of :class:`BaseDataPreprocessor`. Defaults to ``None``
+ init_cfg (dict | ConfigDict): The model initialization config.
+ Defaults to ``None``
+ metainfo (dict): Meta information for dataset, such as keypoints
+ definition and properties. If set, the metainfo of the input data
+ batch will be overridden. For more details, please refer to
+ https://mmpose.readthedocs.io/en/latest/user_guides/
+ prepare_datasets.html#create-a-custom-dataset-info-
+ config-file-for-the-dataset. Defaults to ``None``
+ """
+ _version = 2
+
+ def __init__(self,
+ backbone: ConfigType,
+ neck: OptConfigType = None,
+ head: OptConfigType = None,
+ train_cfg: OptConfigType = None,
+ test_cfg: OptConfigType = None,
+ data_preprocessor: OptConfigType = None,
+ init_cfg: OptMultiConfig = None,
+ metainfo: Optional[dict] = None):
+ super().__init__(
+ data_preprocessor=data_preprocessor, init_cfg=init_cfg)
+ self.metainfo = self._load_metainfo(metainfo)
+
+ self.backbone = MODELS.build(backbone)
+
+        # PRs #2108 and #2126 modified the interface of the neck and head.
+ # The following function automatically detects outdated
+ # configurations and updates them accordingly, while also providing
+ # clear and concise information on the changes made.
+ neck, head = check_and_update_config(neck, head)
+
+ if neck is not None:
+ self.neck = MODELS.build(neck)
+
+ if head is not None:
+ self.head = MODELS.build(head)
+
+ self.train_cfg = train_cfg if train_cfg else {}
+ self.test_cfg = test_cfg if test_cfg else {}
+
+ # Register the hook to automatically convert old version state dicts
+ self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook)
+
+ @property
+ def with_neck(self) -> bool:
+ """bool: whether the pose estimator has a neck."""
+ return hasattr(self, 'neck') and self.neck is not None
+
+ @property
+ def with_head(self) -> bool:
+ """bool: whether the pose estimator has a head."""
+ return hasattr(self, 'head') and self.head is not None
+
+ @staticmethod
+ def _load_metainfo(metainfo: dict = None) -> dict:
+ """Collect meta information from the dictionary of meta.
+
+ Args:
+ metainfo (dict): Raw data of pose meta information.
+
+ Returns:
+ dict: Parsed meta information.
+ """
+
+ if metainfo is None:
+ return None
+
+ if not isinstance(metainfo, dict):
+ raise TypeError(
+ f'metainfo should be a dict, but got {type(metainfo)}')
+
+ metainfo = parse_pose_metainfo(metainfo)
+ return metainfo
+
+ def forward(self,
+ inputs: torch.Tensor,
+ data_samples: OptSampleList,
+ mode: str = 'tensor') -> ForwardResults:
+ """The unified entry for a forward process in both training and test.
+
+ The method should accept three modes: 'tensor', 'predict' and 'loss':
+
+ - 'tensor': Forward the whole network and return tensor or tuple of
+ tensor without any post-processing, same as a common nn.Module.
+ - 'predict': Forward and return the predictions, which are fully
+ processed to a list of :obj:`PoseDataSample`.
+ - 'loss': Forward and return a dict of losses according to the given
+ inputs and data samples.
+
+        Note that this method doesn't handle back propagation or optimizer
+        updating, which are done in :meth:`train_step`.
+
+ Args:
+ inputs (torch.Tensor): The input tensor with shape
+ (N, C, ...) in general
+ data_samples (list[:obj:`PoseDataSample`], optional): The
+ annotation of every sample. Defaults to ``None``
+ mode (str): Set the forward mode and return value type. Defaults
+ to ``'tensor'``
+
+ Returns:
+ The return type depends on ``mode``.
+
+ - If ``mode='tensor'``, return a tensor or a tuple of tensors
+ - If ``mode='predict'``, return a list of :obj:``PoseDataSample``
+ that contains the pose predictions
+ - If ``mode='loss'``, return a dict of tensor(s) which is the loss
+ function value
+ """
+ if mode == 'loss':
+ return self.loss(inputs, data_samples)
+ elif mode == 'predict':
+            # use custom metainfo to override the default metainfo
+ if self.metainfo is not None:
+ for data_sample in data_samples:
+ data_sample.set_metainfo(self.metainfo)
+ return self.predict(inputs, data_samples)
+ elif mode == 'tensor':
+ return self._forward(inputs)
+ else:
+ raise RuntimeError(f'Invalid mode "{mode}". '
+ 'Only supports loss, predict and tensor mode.')
+
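+    # A minimal illustration (assumed usage, not part of the original file) of
+    # the three forward modes on a concrete subclass:
+    #
+    #   >>> losses = model(inputs, data_samples, mode='loss')     # dict of losses
+    #   >>> preds = model(inputs, data_samples, mode='predict')   # data samples
+    #   >>> feats = model(inputs, data_samples, mode='tensor')    # raw outputs
+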
+ @abstractmethod
+ def loss(self, inputs: Tensor, data_samples: SampleList) -> dict:
+ """Calculate losses from a batch of inputs and data samples."""
+
+ @abstractmethod
+ def predict(self, inputs: Tensor, data_samples: SampleList) -> SampleList:
+ """Predict results from a batch of inputs and data samples with post-
+ processing."""
+
+ def _forward(self,
+ inputs: Tensor,
+ data_samples: OptSampleList = None
+ ) -> Union[Tensor, Tuple[Tensor]]:
+ """Network forward process. Usually includes backbone, neck and head
+ forward without any post-processing.
+
+ Args:
+ inputs (Tensor): Inputs with shape (N, C, H, W).
+
+ Returns:
+ Union[Tensor | Tuple[Tensor]]: forward output of the network.
+ """
+
+ x = self.extract_feat(inputs)
+ if self.with_head:
+ x = self.head.forward(x)
+
+ return x
+
+ def extract_feat(self, inputs: Tensor) -> Tuple[Tensor]:
+ """Extract features.
+
+ Args:
+ inputs (Tensor): Image tensor with shape (N, C, H ,W).
+
+ Returns:
+ tuple[Tensor]: Multi-level features that may have various
+ resolutions.
+ """
+ x = self.backbone(inputs)
+ if self.with_neck:
+ x = self.neck(x)
+
+ return x
+
+ def _load_state_dict_pre_hook(self, state_dict, prefix, local_meta, *args,
+ **kwargs):
+ """A hook function to convert old-version state dict of
+ :class:`TopdownHeatmapSimpleHead` (before MMPose v1.0.0) to a
+ compatible format of :class:`HeatmapHead`.
+
+ The hook will be automatically registered during initialization.
+ """
+ version = local_meta.get('version', None)
+ if version and version >= self._version:
+ return
+
+ # convert old-version state dict
+ keys = list(state_dict.keys())
+ for k in keys:
+ if 'keypoint_head' in k:
+ v = state_dict.pop(k)
+ k = k.replace('keypoint_head', 'head')
+ state_dict[k] = v
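+
+# Hedged sketch of the conversion performed by the pre-hook above: when an
+# old-version checkpoint is loaded via ``load_state_dict``, every key that
+# contains 'keypoint_head' is renamed, e.g. (illustrative key name)
+#
+#   'keypoint_head.final_layer.weight' -> 'head.final_layer.weight'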
diff --git a/mmpose/models/pose_estimators/bottomup.py b/mmpose/models/pose_estimators/bottomup.py
new file mode 100644
index 0000000000000000000000000000000000000000..5400f2478e411bbf39b875ad2c8bfd6532ffa4b8
--- /dev/null
+++ b/mmpose/models/pose_estimators/bottomup.py
@@ -0,0 +1,178 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from itertools import zip_longest
+from typing import List, Optional, Union
+
+from mmengine.utils import is_list_of
+from torch import Tensor
+
+from mmpose.registry import MODELS
+from mmpose.utils.typing import (ConfigType, InstanceList, OptConfigType,
+ OptMultiConfig, PixelDataList, SampleList)
+from .base import BasePoseEstimator
+
+
+@MODELS.register_module()
+class BottomupPoseEstimator(BasePoseEstimator):
+ """Base class for bottom-up pose estimators.
+
+ Args:
+ backbone (dict): The backbone config
+ neck (dict, optional): The neck config. Defaults to ``None``
+ head (dict, optional): The head config. Defaults to ``None``
+ train_cfg (dict, optional): The runtime config for training process.
+ Defaults to ``None``
+ test_cfg (dict, optional): The runtime config for testing process.
+ Defaults to ``None``
+ data_preprocessor (dict, optional): The data preprocessing config to
+ build the instance of :class:`BaseDataPreprocessor`. Defaults to
+ ``None``.
+ init_cfg (dict, optional): The config to control the initialization.
+ Defaults to ``None``
+ """
+
+ def __init__(self,
+ backbone: ConfigType,
+ neck: OptConfigType = None,
+ head: OptConfigType = None,
+ train_cfg: OptConfigType = None,
+ test_cfg: OptConfigType = None,
+ data_preprocessor: OptConfigType = None,
+ init_cfg: OptMultiConfig = None):
+ super().__init__(
+ backbone=backbone,
+ neck=neck,
+ head=head,
+ train_cfg=train_cfg,
+ test_cfg=test_cfg,
+ data_preprocessor=data_preprocessor,
+ init_cfg=init_cfg)
+
+ def loss(self, inputs: Tensor, data_samples: SampleList) -> dict:
+ """Calculate losses from a batch of inputs and data samples.
+
+ Args:
+ inputs (Tensor): Inputs with shape (N, C, H, W).
+ data_samples (List[:obj:`PoseDataSample`]): The batch
+ data samples.
+
+ Returns:
+ dict: A dictionary of losses.
+ """
+ feats = self.extract_feat(inputs)
+
+ losses = dict()
+
+ if self.with_head:
+ losses.update(
+ self.head.loss(feats, data_samples, train_cfg=self.train_cfg))
+
+ return losses
+
+ def predict(self, inputs: Union[Tensor, List[Tensor]],
+ data_samples: SampleList) -> SampleList:
+ """Predict results from a batch of inputs and data samples with post-
+ processing.
+
+ Args:
+ inputs (Tensor | List[Tensor]): Input image in tensor or image
+ pyramid as a list of tensors. Each tensor is in shape
+ [B, C, H, W]
+ data_samples (List[:obj:`PoseDataSample`]): The batch
+ data samples
+
+ Returns:
+ list[:obj:`PoseDataSample`]: The pose estimation results of the
+            input images. The return value is a list of `PoseDataSample`
+            instances with ``pred_instances`` and optional ``pred_fields``, and
+ ``pred_instances`` usually contains the following keys:
+
+ - keypoints (Tensor): predicted keypoint coordinates in shape
+ (num_instances, K, D) where K is the keypoint number and D
+ is the keypoint dimension
+ - keypoint_scores (Tensor): predicted keypoint scores in shape
+ (num_instances, K)
+ """
+ assert self.with_head, (
+            'The model must have a head to perform prediction.')
+
+ multiscale_test = self.test_cfg.get('multiscale_test', False)
+ flip_test = self.test_cfg.get('flip_test', False)
+
+ # enable multi-scale test
+ aug_scales = data_samples[0].metainfo.get('aug_scales', None)
+ if multiscale_test:
+ assert isinstance(aug_scales, list)
+ assert is_list_of(inputs, Tensor)
+ # `inputs` includes images in original and augmented scales
+ assert len(inputs) == len(aug_scales) + 1
+ else:
+ assert isinstance(inputs, Tensor)
+ # single-scale test
+ inputs = [inputs]
+
+ feats = []
+ for _inputs in inputs:
+ if flip_test:
+ _feats_orig = self.extract_feat(_inputs)
+ _feats_flip = self.extract_feat(_inputs.flip(-1))
+ _feats = [_feats_orig, _feats_flip]
+ else:
+ _feats = self.extract_feat(_inputs)
+
+ feats.append(_feats)
+
+ if not multiscale_test:
+ feats = feats[0]
+
+ preds = self.head.predict(feats, data_samples, test_cfg=self.test_cfg)
+
+ if isinstance(preds, tuple):
+ batch_pred_instances, batch_pred_fields = preds
+ else:
+ batch_pred_instances = preds
+ batch_pred_fields = None
+
+ results = self.add_pred_to_datasample(batch_pred_instances,
+ batch_pred_fields, data_samples)
+
+ return results
+
+ def add_pred_to_datasample(self, batch_pred_instances: InstanceList,
+ batch_pred_fields: Optional[PixelDataList],
+ batch_data_samples: SampleList) -> SampleList:
+ """Add predictions into data samples.
+
+ Args:
+ batch_pred_instances (List[InstanceData]): The predicted instances
+ of the input data batch
+ batch_pred_fields (List[PixelData], optional): The predicted
+ fields (e.g. heatmaps) of the input batch
+ batch_data_samples (List[PoseDataSample]): The input data batch
+
+ Returns:
+ List[PoseDataSample]: A list of data samples where the predictions
+ are stored in the ``pred_instances`` field of each data sample.
+            The length of the list equals the batch size.
+ """
+ assert len(batch_pred_instances) == len(batch_data_samples)
+ if batch_pred_fields is None:
+ batch_pred_fields = []
+
+ for pred_instances, pred_fields, data_sample in zip_longest(
+ batch_pred_instances, batch_pred_fields, batch_data_samples):
+
+ # convert keypoint coordinates from input space to image space
+ input_size = data_sample.metainfo['input_size']
+ input_center = data_sample.metainfo['input_center']
+ input_scale = data_sample.metainfo['input_scale']
+
+ pred_instances.keypoints = pred_instances.keypoints / input_size \
+ * input_scale + input_center - 0.5 * input_scale
+
+ data_sample.pred_instances = pred_instances
+
+ if pred_fields is not None:
+ data_sample.pred_fields = pred_fields
+
+ return batch_data_samples
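+
+# Worked example (assumed numbers) of the input-to-image mapping above: with
+# input_size = (256, 256), input_center = (320, 240) and
+# input_scale = (512, 512), a keypoint predicted at (128, 128) in input space
+# maps to 128 / 256 * 512 + 320 - 0.5 * 512 = 320 in x and
+# 128 / 256 * 512 + 240 - 0.5 * 512 = 240 in y, i.e. the center of the crop.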
diff --git a/mmpose/models/pose_estimators/topdown.py b/mmpose/models/pose_estimators/topdown.py
new file mode 100644
index 0000000000000000000000000000000000000000..89b332893f0c4e051084a9eb86addf75433c0534
--- /dev/null
+++ b/mmpose/models/pose_estimators/topdown.py
@@ -0,0 +1,182 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from itertools import zip_longest
+from typing import Optional
+
+from torch import Tensor
+
+from mmpose.registry import MODELS
+from mmpose.utils.typing import (ConfigType, InstanceList, OptConfigType,
+ OptMultiConfig, PixelDataList, SampleList)
+from .base import BasePoseEstimator
+
+
+@MODELS.register_module()
+class TopdownPoseEstimator(BasePoseEstimator):
+ """Base class for top-down pose estimators.
+
+ Args:
+ backbone (dict): The backbone config
+ neck (dict, optional): The neck config. Defaults to ``None``
+ head (dict, optional): The head config. Defaults to ``None``
+ train_cfg (dict, optional): The runtime config for training process.
+ Defaults to ``None``
+ test_cfg (dict, optional): The runtime config for testing process.
+ Defaults to ``None``
+ data_preprocessor (dict, optional): The data preprocessing config to
+ build the instance of :class:`BaseDataPreprocessor`. Defaults to
+ ``None``
+ init_cfg (dict, optional): The config to control the initialization.
+ Defaults to ``None``
+ metainfo (dict): Meta information for dataset, such as keypoints
+ definition and properties. If set, the metainfo of the input data
+ batch will be overridden. For more details, please refer to
+ https://mmpose.readthedocs.io/en/latest/user_guides/
+ prepare_datasets.html#create-a-custom-dataset-info-
+ config-file-for-the-dataset. Defaults to ``None``
+ """
+
+ def __init__(self,
+ backbone: ConfigType,
+ neck: OptConfigType = None,
+ head: OptConfigType = None,
+ train_cfg: OptConfigType = None,
+ test_cfg: OptConfigType = None,
+ data_preprocessor: OptConfigType = None,
+ init_cfg: OptMultiConfig = None,
+ metainfo: Optional[dict] = None):
+ super().__init__(
+ backbone=backbone,
+ neck=neck,
+ head=head,
+ train_cfg=train_cfg,
+ test_cfg=test_cfg,
+ data_preprocessor=data_preprocessor,
+ init_cfg=init_cfg,
+ metainfo=metainfo)
+
+ def loss(self, inputs: Tensor, data_samples: SampleList) -> dict:
+ """Calculate losses from a batch of inputs and data samples.
+
+ Args:
+ inputs (Tensor): Inputs with shape (N, C, H, W).
+ data_samples (List[:obj:`PoseDataSample`]): The batch
+ data samples.
+
+ Returns:
+ dict: A dictionary of losses.
+ """
+ feats = self.extract_feat(inputs)
+
+ losses = dict()
+
+ if self.with_head:
+ losses.update(
+ self.head.loss(feats, data_samples, train_cfg=self.train_cfg))
+
+ return losses
+
+ def predict(self, inputs: Tensor, data_samples: SampleList) -> SampleList:
+ """Predict results from a batch of inputs and data samples with post-
+ processing.
+
+ Args:
+ inputs (Tensor): Inputs with shape (N, C, H, W)
+ data_samples (List[:obj:`PoseDataSample`]): The batch
+ data samples
+
+ Returns:
+ list[:obj:`PoseDataSample`]: The pose estimation results of the
+            input images. The return value is a list of `PoseDataSample`
+            instances with ``pred_instances`` and optional ``pred_fields``, and
+ ``pred_instances`` usually contains the following keys:
+
+ - keypoints (Tensor): predicted keypoint coordinates in shape
+ (num_instances, K, D) where K is the keypoint number and D
+ is the keypoint dimension
+ - keypoint_scores (Tensor): predicted keypoint scores in shape
+ (num_instances, K)
+ """
+ assert self.with_head, (
+            'The model must have a head to perform prediction.')
+
+ if self.test_cfg.get('flip_test', False):
+ _feats = self.extract_feat(inputs)
+ _feats_flip = self.extract_feat(inputs.flip(-1))
+ feats = [_feats, _feats_flip]
+ else:
+ feats = self.extract_feat(inputs)
+
+ preds = self.head.predict(feats, data_samples, test_cfg=self.test_cfg)
+
+ if isinstance(preds, tuple):
+ batch_pred_instances, batch_pred_fields = preds
+ else:
+ batch_pred_instances = preds
+ batch_pred_fields = None
+
+ results = self.add_pred_to_datasample(batch_pred_instances,
+ batch_pred_fields, data_samples)
+
+ return results
+
+ def add_pred_to_datasample(self, batch_pred_instances: InstanceList,
+ batch_pred_fields: Optional[PixelDataList],
+ batch_data_samples: SampleList) -> SampleList:
+ """Add predictions into data samples.
+
+ Args:
+ batch_pred_instances (List[InstanceData]): The predicted instances
+ of the input data batch
+ batch_pred_fields (List[PixelData], optional): The predicted
+ fields (e.g. heatmaps) of the input batch
+ batch_data_samples (List[PoseDataSample]): The input data batch
+
+ Returns:
+ List[PoseDataSample]: A list of data samples where the predictions
+ are stored in the ``pred_instances`` field of each data sample.
+ """
+ assert len(batch_pred_instances) == len(batch_data_samples)
+ if batch_pred_fields is None:
+ batch_pred_fields = []
+ output_keypoint_indices = self.test_cfg.get('output_keypoint_indices',
+ None)
+
+ for pred_instances, pred_fields, data_sample in zip_longest(
+ batch_pred_instances, batch_pred_fields, batch_data_samples):
+
+ gt_instances = data_sample.gt_instances
+
+ # convert keypoint coordinates from input space to image space
+ bbox_centers = gt_instances.bbox_centers
+ bbox_scales = gt_instances.bbox_scales
+ input_size = data_sample.metainfo['input_size']
+
+ pred_instances.keypoints = pred_instances.keypoints / input_size \
+ * bbox_scales + bbox_centers - 0.5 * bbox_scales
+
+ if output_keypoint_indices is not None:
+ # select output keypoints with given indices
+ num_keypoints = pred_instances.keypoints.shape[1]
+ for key, value in pred_instances.all_items():
+ if key.startswith('keypoint'):
+ pred_instances.set_field(
+ value[:, output_keypoint_indices], key)
+
+ # add bbox information into pred_instances
+ pred_instances.bboxes = gt_instances.bboxes
+ pred_instances.bbox_scores = gt_instances.bbox_scores
+
+ data_sample.pred_instances = pred_instances
+
+ if pred_fields is not None:
+ if output_keypoint_indices is not None:
+ # select output heatmap channels with keypoint indices
+ # when the number of heatmap channel matches num_keypoints
+ for key, value in pred_fields.all_items():
+ if value.shape[0] != num_keypoints:
+ continue
+ pred_fields.set_field(value[output_keypoint_indices],
+ key)
+ data_sample.pred_fields = pred_fields
+
+ return batch_data_samples
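+
+# Worked example (assumed numbers) of the bbox-based mapping above: with
+# input_size = (192, 256), bbox_centers = (100, 150) and
+# bbox_scales = (192, 256), a keypoint predicted at (96, 128) in input space
+# maps to 96 / 192 * 192 + 100 - 0.5 * 192 = 100 in x and
+# 128 / 256 * 256 + 150 - 0.5 * 256 = 150 in y, i.e. the bbox center.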
diff --git a/mmpose/models/utils/__init__.py b/mmpose/models/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..22d8a89b41da0be83756f2f7a2e1a6194eb7cd5a
--- /dev/null
+++ b/mmpose/models/utils/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .check_and_update_config import check_and_update_config
+from .ckpt_convert import pvt_convert
+from .rtmcc_block import RTMCCBlock, rope
+from .transformer import PatchEmbed, nchw_to_nlc, nlc_to_nchw
+
+__all__ = [
+ 'PatchEmbed', 'nchw_to_nlc', 'nlc_to_nchw', 'pvt_convert', 'RTMCCBlock',
+ 'rope', 'check_and_update_config'
+]
diff --git a/mmpose/models/utils/__pycache__/__init__.cpython-38.pyc b/mmpose/models/utils/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2dcf32a76e1b8c138f5b12c0af80e543e9a96aba
Binary files /dev/null and b/mmpose/models/utils/__pycache__/__init__.cpython-38.pyc differ
diff --git a/mmpose/models/utils/__pycache__/check_and_update_config.cpython-38.pyc b/mmpose/models/utils/__pycache__/check_and_update_config.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..95204c2631a0b8893555c556a72136404950711e
Binary files /dev/null and b/mmpose/models/utils/__pycache__/check_and_update_config.cpython-38.pyc differ
diff --git a/mmpose/models/utils/__pycache__/ckpt_convert.cpython-38.pyc b/mmpose/models/utils/__pycache__/ckpt_convert.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..79fa68fc47c6bb63a1edcc2657c804ff9c10728f
Binary files /dev/null and b/mmpose/models/utils/__pycache__/ckpt_convert.cpython-38.pyc differ
diff --git a/mmpose/models/utils/__pycache__/ops.cpython-38.pyc b/mmpose/models/utils/__pycache__/ops.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..46d89f44788765d578d4f5c1a5705fa742ba2f5b
Binary files /dev/null and b/mmpose/models/utils/__pycache__/ops.cpython-38.pyc differ
diff --git a/mmpose/models/utils/__pycache__/realnvp.cpython-38.pyc b/mmpose/models/utils/__pycache__/realnvp.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4fc7f095b108530daced9690f47553b4e8fb133b
Binary files /dev/null and b/mmpose/models/utils/__pycache__/realnvp.cpython-38.pyc differ
diff --git a/mmpose/models/utils/__pycache__/regularizations.cpython-38.pyc b/mmpose/models/utils/__pycache__/regularizations.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dca329fdab6ee0a888080a0999887f3088f3b767
Binary files /dev/null and b/mmpose/models/utils/__pycache__/regularizations.cpython-38.pyc differ
diff --git a/mmpose/models/utils/__pycache__/rtmcc_block.cpython-38.pyc b/mmpose/models/utils/__pycache__/rtmcc_block.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b29af37d072c3993be24a19958ce30473ef6bb59
Binary files /dev/null and b/mmpose/models/utils/__pycache__/rtmcc_block.cpython-38.pyc differ
diff --git a/mmpose/models/utils/__pycache__/transformer.cpython-38.pyc b/mmpose/models/utils/__pycache__/transformer.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..932e3120b45f06de03258c467d50f0b8ba64d025
Binary files /dev/null and b/mmpose/models/utils/__pycache__/transformer.cpython-38.pyc differ
diff --git a/mmpose/models/utils/__pycache__/tta.cpython-38.pyc b/mmpose/models/utils/__pycache__/tta.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..76f305d1dc70386039b5dbee8309ce974258e90f
Binary files /dev/null and b/mmpose/models/utils/__pycache__/tta.cpython-38.pyc differ
diff --git a/mmpose/models/utils/check_and_update_config.py b/mmpose/models/utils/check_and_update_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..4cd1efa39b584a08055d470343549349907c1a5c
--- /dev/null
+++ b/mmpose/models/utils/check_and_update_config.py
@@ -0,0 +1,230 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, Optional, Tuple, Union
+
+from mmengine.config import Config, ConfigDict
+from mmengine.dist import master_only
+from mmengine.logging import MMLogger
+
+ConfigType = Union[Config, ConfigDict]
+
+
+def process_input_transform(input_transform: str, head: Dict, head_new: Dict,
+ head_deleted_dict: Dict, head_append_dict: Dict,
+ neck_new: Dict, input_index: Tuple[int],
+ align_corners: bool) -> None:
+ """Process the input_transform field and update head and neck
+ dictionaries."""
+ if input_transform == 'resize_concat':
+ in_channels = head_new.pop('in_channels')
+ head_deleted_dict['in_channels'] = str(in_channels)
+ in_channels = sum([in_channels[i] for i in input_index])
+ head_new['in_channels'] = in_channels
+ head_append_dict['in_channels'] = str(in_channels)
+
+ neck_new.update(
+ dict(
+ type='FeatureMapProcessor',
+ concat=True,
+ select_index=input_index,
+ ))
+ if align_corners:
+ neck_new['align_corners'] = align_corners
+
+ elif input_transform == 'select':
+ if input_index != (-1, ):
+ neck_new.update(
+ dict(type='FeatureMapProcessor', select_index=input_index))
+ if isinstance(head['in_channels'], tuple):
+ in_channels = head_new.pop('in_channels')
+ head_deleted_dict['in_channels'] = str(in_channels)
+ if isinstance(input_index, int):
+ in_channels = in_channels[input_index]
+ else:
+ in_channels = tuple([in_channels[i] for i in input_index])
+ head_new['in_channels'] = in_channels
+ head_append_dict['in_channels'] = str(in_channels)
+ if align_corners:
+ neck_new['align_corners'] = align_corners
+
+ else:
+ raise ValueError(f'model.head get invalid value for argument '
+ f'input_transform: {input_transform}')
+
+
+def process_extra_field(extra: Dict, head_new: Dict, head_deleted_dict: Dict,
+ head_append_dict: Dict, neck_new: Dict) -> None:
+ """Process the extra field and update head and neck dictionaries."""
+ head_deleted_dict['extra'] = 'dict('
+ for key, value in extra.items():
+ head_deleted_dict['extra'] += f'{key}={value},'
+ head_deleted_dict['extra'] = head_deleted_dict['extra'][:-1] + ')'
+ if 'final_conv_kernel' in extra:
+ kernel_size = extra['final_conv_kernel']
+ if kernel_size > 1:
+ padding = kernel_size // 2
+ head_new['final_layer'] = dict(
+ kernel_size=kernel_size, padding=padding)
+ head_append_dict[
+ 'final_layer'] = f'dict(kernel_size={kernel_size}, ' \
+ f'padding={padding})'
+ else:
+ head_new['final_layer'] = dict(kernel_size=kernel_size)
+ head_append_dict[
+ 'final_layer'] = f'dict(kernel_size={kernel_size})'
+ if 'upsample' in extra:
+ neck_new.update(
+ dict(
+ type='FeatureMapProcessor',
+ scale_factor=float(extra['upsample']),
+ apply_relu=True,
+ ))
+
+
+def process_has_final_layer(has_final_layer: bool, head_new: Dict,
+ head_deleted_dict: Dict,
+ head_append_dict: Dict) -> None:
+ """Process the has_final_layer field and update the head dictionary."""
+ head_deleted_dict['has_final_layer'] = str(has_final_layer)
+ if not has_final_layer:
+ if 'final_layer' not in head_new:
+ head_new['final_layer'] = None
+ head_append_dict['final_layer'] = 'None'
+
+
+def check_and_update_config(neck: Optional[ConfigType],
+ head: ConfigType) -> Tuple[Optional[Dict], Dict]:
+ """Check and update the configuration of the head and neck components.
+ Args:
+ neck (Optional[ConfigType]): Configuration for the neck component.
+ head (ConfigType): Configuration for the head component.
+
+ Returns:
+ Tuple[Optional[Dict], Dict]: Updated configurations for the neck
+ and head components.
+ """
+ head_new, neck_new = head.copy(), neck.copy() if isinstance(neck,
+ dict) else {}
+ head_deleted_dict, head_append_dict = {}, {}
+
+ if 'input_transform' in head:
+ input_transform = head_new.pop('input_transform')
+ head_deleted_dict['input_transform'] = f'\'{input_transform}\''
+ else:
+ input_transform = 'select'
+
+ if 'input_index' in head:
+ input_index = head_new.pop('input_index')
+ head_deleted_dict['input_index'] = str(input_index)
+ else:
+ input_index = (-1, )
+
+ if 'align_corners' in head:
+ align_corners = head_new.pop('align_corners')
+ head_deleted_dict['align_corners'] = str(align_corners)
+ else:
+ align_corners = False
+
+ process_input_transform(input_transform, head, head_new, head_deleted_dict,
+ head_append_dict, neck_new, input_index,
+ align_corners)
+
+ if 'extra' in head:
+ extra = head_new.pop('extra')
+ process_extra_field(extra, head_new, head_deleted_dict,
+ head_append_dict, neck_new)
+
+ if 'has_final_layer' in head:
+ has_final_layer = head_new.pop('has_final_layer')
+ process_has_final_layer(has_final_layer, head_new, head_deleted_dict,
+ head_append_dict)
+
+ display_modifications(head_deleted_dict, head_append_dict, neck_new)
+
+ neck_new = neck_new if len(neck_new) else None
+ return neck_new, head_new
+
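+# Hedged example of the conversion (the head fields below are hypothetical):
+#
+#   >>> head = dict(type='HeatmapHead', in_channels=(18, 36, 72, 144),
+#   ...             out_channels=17, input_index=0)
+#   >>> neck, head = check_and_update_config(None, head)
+#   >>> neck   # roughly dict(type='FeatureMapProcessor', select_index=0)
+#   >>> head['in_channels']
+#   18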
+
+@master_only
+def display_modifications(head_deleted_dict: Dict, head_append_dict: Dict,
+ neck: Dict) -> None:
+ """Display the modifications made to the head and neck configurations.
+
+ Args:
+ head_deleted_dict (Dict): Dictionary of deleted fields in the head.
+ head_append_dict (Dict): Dictionary of appended fields in the head.
+ neck (Dict): Updated neck configuration.
+ """
+ if len(head_deleted_dict) + len(head_append_dict) == 0:
+ return
+
+ old_model_info, new_model_info = build_model_info(head_deleted_dict,
+ head_append_dict, neck)
+
+ total_info = '\nThe config you are using is outdated. '\
+ 'The following section of the config:\n```\n'
+ total_info += old_model_info
+ total_info += '```\nshould be updated to\n```\n'
+ total_info += new_model_info
+ total_info += '```\nFor more information, please refer to '\
+ 'https://mmpose.readthedocs.io/en/latest/' \
+ 'guide_to_framework.html#step3-model'
+
+ logger: MMLogger = MMLogger.get_current_instance()
+ logger.warning(total_info)
+
+
+def build_model_info(head_deleted_dict: Dict, head_append_dict: Dict,
+ neck: Dict) -> Tuple[str, str]:
+ """Build the old and new model information strings.
+ Args:
+ head_deleted_dict (Dict): Dictionary of deleted fields in the head.
+ head_append_dict (Dict): Dictionary of appended fields in the head.
+ neck (Dict): Updated neck configuration.
+
+ Returns:
+ Tuple[str, str]: Old and new model information strings.
+ """
+ old_head_info = build_head_info(head_deleted_dict)
+ new_head_info = build_head_info(head_append_dict)
+ neck_info = build_neck_info(neck)
+
+ old_model_info = 'model=dict(\n' + ' ' * 4 + '...,\n' + old_head_info
+ new_model_info = 'model=dict(\n' + ' ' * 4 + '...,\n' \
+ + neck_info + new_head_info
+
+ return old_model_info, new_model_info
+
+
+def build_head_info(head_dict: Dict) -> str:
+ """Build the head information string.
+
+ Args:
+ head_dict (Dict): Dictionary of fields in the head configuration.
+ Returns:
+ str: Head information string.
+ """
+ head_info = ' ' * 4 + 'head=dict(\n'
+ for key, value in head_dict.items():
+ head_info += ' ' * 8 + f'{key}={value},\n'
+ head_info += ' ' * 8 + '...),\n'
+ return head_info
+
+
+def build_neck_info(neck: Dict) -> str:
+ """Build the neck information string.
+ Args:
+ neck (Dict): Updated neck configuration.
+
+ Returns:
+ str: Neck information string.
+ """
+ if len(neck) > 0:
+ neck = neck.copy()
+ neck_info = ' ' * 4 + 'neck=dict(\n' + ' ' * 8 + \
+ f'type=\'{neck.pop("type")}\',\n'
+ for key, value in neck.items():
+ neck_info += ' ' * 8 + f'{key}={str(value)},\n'
+ neck_info += ' ' * 4 + '),\n'
+ else:
+ neck_info = ''
+ return neck_info
diff --git a/mmpose/models/utils/ckpt_convert.py b/mmpose/models/utils/ckpt_convert.py
new file mode 100644
index 0000000000000000000000000000000000000000..05f5cdb4a3cdf32ac2b6b7a8888c5a772e582f14
--- /dev/null
+++ b/mmpose/models/utils/ckpt_convert.py
@@ -0,0 +1,82 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+# This script consists of several convert functions which
+# can modify the weights of model in original repo to be
+# pre-trained weights.
+
+from collections import OrderedDict
+
+import torch
+
+
+def pvt_convert(ckpt):
+ new_ckpt = OrderedDict()
+ # Process the concat between q linear weights and kv linear weights
+ use_abs_pos_embed = False
+ use_conv_ffn = False
+ for k in ckpt.keys():
+ if k.startswith('pos_embed'):
+ use_abs_pos_embed = True
+ if k.find('dwconv') >= 0:
+ use_conv_ffn = True
+ for k, v in ckpt.items():
+ if k.startswith('head'):
+ continue
+ if k.startswith('norm.'):
+ continue
+ if k.startswith('cls_token'):
+ continue
+ if k.startswith('pos_embed'):
+ stage_i = int(k.replace('pos_embed', ''))
+ new_k = k.replace(f'pos_embed{stage_i}',
+ f'layers.{stage_i - 1}.1.0.pos_embed')
+ if stage_i == 4 and v.size(1) == 50: # 1 (cls token) + 7 * 7
+ new_v = v[:, 1:, :] # remove cls token
+ else:
+ new_v = v
+ elif k.startswith('patch_embed'):
+ stage_i = int(k.split('.')[0].replace('patch_embed', ''))
+ new_k = k.replace(f'patch_embed{stage_i}',
+ f'layers.{stage_i - 1}.0')
+ new_v = v
+ if 'proj.' in new_k:
+ new_k = new_k.replace('proj.', 'projection.')
+ elif k.startswith('block'):
+ stage_i = int(k.split('.')[0].replace('block', ''))
+ layer_i = int(k.split('.')[1])
+ new_layer_i = layer_i + use_abs_pos_embed
+ new_k = k.replace(f'block{stage_i}.{layer_i}',
+ f'layers.{stage_i - 1}.1.{new_layer_i}')
+ new_v = v
+ if 'attn.q.' in new_k:
+ sub_item_k = k.replace('q.', 'kv.')
+ new_k = new_k.replace('q.', 'attn.in_proj_')
+ new_v = torch.cat([v, ckpt[sub_item_k]], dim=0)
+ elif 'attn.kv.' in new_k:
+ continue
+ elif 'attn.proj.' in new_k:
+ new_k = new_k.replace('proj.', 'attn.out_proj.')
+ elif 'attn.sr.' in new_k:
+ new_k = new_k.replace('sr.', 'sr.')
+ elif 'mlp.' in new_k:
+ new_k = new_k.replace('mlp.', 'ffn.layers.')
+ if 'fc1.weight' in new_k or 'fc2.weight' in new_k:
+ new_v = v.reshape((*v.shape, 1, 1))
+ new_k = new_k.replace('fc1.', '0.')
+ new_k = new_k.replace('dwconv.dwconv.', '1.')
+ if use_conv_ffn:
+ new_k = new_k.replace('fc2.', '4.')
+ else:
+ new_k = new_k.replace('fc2.', '3.')
+ elif k.startswith('norm'):
+ stage_i = int(k[4])
+ new_k = k.replace(f'norm{stage_i}', f'layers.{stage_i - 1}.2')
+ new_v = v
+ else:
+ new_k = k
+ new_v = v
+ new_ckpt[new_k] = new_v
+
+ return new_ckpt
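+
+# Hedged usage sketch (the checkpoint path is a placeholder):
+#
+#   >>> import torch
+#   >>> ckpt = torch.load('pvt_small.pth', map_location='cpu')
+#   >>> state_dict = ckpt.get('state_dict', ckpt)
+#   >>> new_state_dict = pvt_convert(state_dict)  # keys renamed to mmpose style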
diff --git a/mmpose/models/utils/geometry.py b/mmpose/models/utils/geometry.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ceadaec30cd2c9bb3fbada132e1ea674f2e8754
--- /dev/null
+++ b/mmpose/models/utils/geometry.py
@@ -0,0 +1,68 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch.nn import functional as F
+
+
+def rot6d_to_rotmat(x):
+ """Convert 6D rotation representation to 3x3 rotation matrix.
+
+ Based on Zhou et al., "On the Continuity of Rotation
+ Representations in Neural Networks", CVPR 2019
+ Input:
+ (B,6) Batch of 6-D rotation representations
+ Output:
+ (B,3,3) Batch of corresponding rotation matrices
+ """
+ x = x.view(-1, 3, 2)
+ a1 = x[:, :, 0]
+ a2 = x[:, :, 1]
+ b1 = F.normalize(a1)
+ b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1)
+ b3 = torch.cross(b1, b2)
+ return torch.stack((b1, b2, b3), dim=-1)
+
+
+def batch_rodrigues(theta):
+ """Convert axis-angle representation to rotation matrix.
+ Args:
+ theta: size = [B, 3]
+ Returns:
+        Rotation matrix corresponding to the axis-angle input
+ -- size = [B, 3, 3]
+ """
+ l2norm = torch.norm(theta + 1e-8, p=2, dim=1)
+ angle = torch.unsqueeze(l2norm, -1)
+ normalized = torch.div(theta, angle)
+ angle = angle * 0.5
+ v_cos = torch.cos(angle)
+ v_sin = torch.sin(angle)
+ quat = torch.cat([v_cos, v_sin * normalized], dim=1)
+ return quat_to_rotmat(quat)
+
+
+def quat_to_rotmat(quat):
+ """Convert quaternion coefficients to rotation matrix.
+ Args:
+ quat: size = [B, 4] 4 <===>(w, x, y, z)
+ Returns:
+ Rotation matrix corresponding to the quaternion
+ -- size = [B, 3, 3]
+ """
+ norm_quat = quat
+ norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True)
+ w, x, y, z = norm_quat[:, 0], norm_quat[:, 1],\
+ norm_quat[:, 2], norm_quat[:, 3]
+
+ B = quat.size(0)
+
+ w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
+ wx, wy, wz = w * x, w * y, w * z
+ xy, xz, yz = x * y, x * z, y * z
+
+ rotMat = torch.stack([
+ w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, 2 * wz + 2 * xy,
+ w2 - x2 + y2 - z2, 2 * yz - 2 * wx, 2 * xz - 2 * wy, 2 * wx + 2 * yz,
+ w2 - x2 - y2 + z2
+ ],
+ dim=1).view(B, 3, 3)
+ return rotMat
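+
+# Illustrative check (assumed usage): the 6D representation maps to proper
+# rotation matrices, so R @ R^T should equal the identity up to numerical
+# tolerance.
+#
+#   >>> import torch
+#   >>> R = rot6d_to_rotmat(torch.randn(4, 6))   # (4, 3, 3)
+#   >>> torch.allclose(R @ R.transpose(1, 2),
+#   ...                torch.eye(3).expand(4, 3, 3), atol=1e-4)
+#   True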
diff --git a/mmpose/models/utils/ops.py b/mmpose/models/utils/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c94352647178e53e618e8fc7cba36fa7a9c0ad2
--- /dev/null
+++ b/mmpose/models/utils/ops.py
@@ -0,0 +1,52 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+from typing import Optional, Tuple, Union
+
+import torch
+from torch.nn import functional as F
+
+
+def resize(input: torch.Tensor,
+ size: Optional[Union[Tuple[int, int], torch.Size]] = None,
+ scale_factor: Optional[float] = None,
+ mode: str = 'nearest',
+ align_corners: Optional[bool] = None,
+ warning: bool = True) -> torch.Tensor:
+ """Resize a given input tensor using specified size or scale_factor.
+
+ Args:
+ input (torch.Tensor): The input tensor to be resized.
+ size (Optional[Union[Tuple[int, int], torch.Size]]): The desired
+ output size. Defaults to None.
+ scale_factor (Optional[float]): The scaling factor for resizing.
+ Defaults to None.
+ mode (str): The interpolation mode. Defaults to 'nearest'.
+ align_corners (Optional[bool]): Determines whether to align the
+ corners when using certain interpolation modes. Defaults to None.
+ warning (bool): Whether to display a warning when the input and
+ output sizes are not ideal for alignment. Defaults to True.
+
+ Returns:
+ torch.Tensor: The resized tensor.
+ """
+ # Check if a warning should be displayed regarding input and output sizes
+ if warning:
+ if size is not None and align_corners:
+ input_h, input_w = tuple(int(x) for x in input.shape[2:])
+ output_h, output_w = tuple(int(x) for x in size)
+            if output_h > input_h or output_w > input_w:
+ if ((output_h > 1 and output_w > 1 and input_h > 1
+ and input_w > 1) and (output_h - 1) % (input_h - 1)
+ and (output_w - 1) % (input_w - 1)):
+ warnings.warn(
+ f'When align_corners={align_corners}, '
+ 'the output would be more aligned if '
+ f'input size {(input_h, input_w)} is `x+1` and '
+ f'out size {(output_h, output_w)} is `nx+1`')
+
+ # Convert torch.Size to tuple if necessary
+ if isinstance(size, torch.Size):
+ size = tuple(int(x) for x in size)
+
+ # Perform the resizing operation
+ return F.interpolate(input, size, scale_factor, mode, align_corners)
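+
+# Hedged usage sketch:
+#
+#   >>> import torch
+#   >>> x = torch.rand(1, 3, 32, 32)
+#   >>> resize(x, size=(64, 64), mode='bilinear', align_corners=False).shape
+#   torch.Size([1, 3, 64, 64])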
diff --git a/mmpose/models/utils/realnvp.py b/mmpose/models/utils/realnvp.py
new file mode 100644
index 0000000000000000000000000000000000000000..911953e8f9d1056d44a2d3538d750e89b9bd6a7a
--- /dev/null
+++ b/mmpose/models/utils/realnvp.py
@@ -0,0 +1,76 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+from torch import distributions
+
+
+class RealNVP(nn.Module):
+ """RealNVP: a flow-based generative model
+
+ `Density estimation using Real NVP
+ arXiv: `_.
+
+ Code is modified from `the official implementation of RLE
+ `_.
+
+ See also `real-nvp-pytorch
+ `_.
+ """
+
+ @staticmethod
+ def get_scale_net():
+ """Get the scale model in a single invertable mapping."""
+ return nn.Sequential(
+ nn.Linear(2, 64), nn.LeakyReLU(), nn.Linear(64, 64),
+ nn.LeakyReLU(), nn.Linear(64, 2), nn.Tanh())
+
+ @staticmethod
+ def get_trans_net():
+ """Get the translation model in a single invertable mapping."""
+ return nn.Sequential(
+ nn.Linear(2, 64), nn.LeakyReLU(), nn.Linear(64, 64),
+ nn.LeakyReLU(), nn.Linear(64, 2))
+
+ @property
+ def prior(self):
+ """The prior distribution."""
+ return distributions.MultivariateNormal(self.loc, self.cov)
+
+ def __init__(self):
+ super(RealNVP, self).__init__()
+
+ self.register_buffer('loc', torch.zeros(2))
+ self.register_buffer('cov', torch.eye(2))
+ self.register_buffer(
+ 'mask', torch.tensor([[0, 1], [1, 0]] * 3, dtype=torch.float32))
+
+ self.s = torch.nn.ModuleList(
+ [self.get_scale_net() for _ in range(len(self.mask))])
+ self.t = torch.nn.ModuleList(
+ [self.get_trans_net() for _ in range(len(self.mask))])
+ self.init_weights()
+
+ def init_weights(self):
+ """Initialization model weights."""
+ for m in self.modules():
+ if isinstance(m, nn.Linear):
+ nn.init.xavier_uniform_(m.weight, gain=0.01)
+
+ def backward_p(self, x):
+ """Apply mapping form the data space to the latent space and calculate
+ the log determinant of the Jacobian matrix."""
+
+ log_det_jacob, z = x.new_zeros(x.shape[0]), x
+ for i in reversed(range(len(self.t))):
+ z_ = self.mask[i] * z
+ s = self.s[i](z_) * (1 - self.mask[i]) # torch.exp(s): betas
+ t = self.t[i](z_) * (1 - self.mask[i]) # gammas
+ z = (1 - self.mask[i]) * (z - t) * torch.exp(-s) + z_
+ log_det_jacob -= s.sum(dim=1)
+ return z, log_det_jacob
+
+ def log_prob(self, x):
+ """Calculate the log probability of given sample in data space."""
+
+ z, log_det = self.backward_p(x)
+ return self.prior.log_prob(z) + log_det
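+
+# Hedged usage sketch: evaluate the log-density of 2-D samples under the flow.
+#
+#   >>> import torch
+#   >>> flow = RealNVP()
+#   >>> log_p = flow.log_prob(torch.randn(8, 2))   # shape (8,)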
diff --git a/mmpose/models/utils/regularizations.py b/mmpose/models/utils/regularizations.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8c7449038066016f6efb60e126111ace962fe98
--- /dev/null
+++ b/mmpose/models/utils/regularizations.py
@@ -0,0 +1,86 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABCMeta, abstractmethod, abstractproperty
+
+import torch
+
+
+class PytorchModuleHook(metaclass=ABCMeta):
+ """Base class for PyTorch module hook registers.
+
+ An instance of a subclass of PytorchModuleHook can be used to
+    register a hook to a PyTorch module using the `register` method, e.g.:
+ hook_register.register(module)
+
+ Subclasses should add/overwrite the following methods:
+ - __init__
+ - hook
+ - hook_type
+ """
+
+ @abstractmethod
+ def hook(self, *args, **kwargs):
+ """Hook function."""
+
+ @abstractproperty
+ def hook_type(self) -> str:
+ """Hook type Subclasses should overwrite this function to return a
+ string value in.
+
+ {`forward`, `forward_pre`, `backward`}
+ """
+
+ def register(self, module):
+ """Register the hook function to the module.
+
+ Args:
+ module (pytorch module): the module to register the hook.
+
+ Returns:
+ handle (torch.utils.hooks.RemovableHandle): a handle to remove
+ the hook by calling handle.remove()
+ """
+ assert isinstance(module, torch.nn.Module)
+
+ if self.hook_type == 'forward':
+ h = module.register_forward_hook(self.hook)
+ elif self.hook_type == 'forward_pre':
+ h = module.register_forward_pre_hook(self.hook)
+ elif self.hook_type == 'backward':
+ h = module.register_backward_hook(self.hook)
+ else:
+            raise ValueError(f'Invalid hook type {self.hook_type}')
+
+ return h
+
+
+class WeightNormClipHook(PytorchModuleHook):
+ """Apply weight norm clip regularization.
+
+    The module's parameter will be clipped to a given maximum norm before each
+ forward pass.
+
+ Args:
+ max_norm (float): The maximum norm of the parameter.
+ module_param_names (str|list): The parameter name (or name list) to
+ apply weight norm clip.
+ """
+
+ def __init__(self, max_norm=1.0, module_param_names='weight'):
+ self.module_param_names = module_param_names if isinstance(
+ module_param_names, list) else [module_param_names]
+ self.max_norm = max_norm
+
+ @property
+ def hook_type(self):
+ return 'forward_pre'
+
+ def hook(self, module, _input):
+ for name in self.module_param_names:
+ assert name in module._parameters, f'{name} is not a parameter' \
+ f' of the module {type(module)}'
+ param = module._parameters[name]
+
+ with torch.no_grad():
+ m = param.norm().item()
+ if m > self.max_norm:
+ param.mul_(self.max_norm / (m + 1e-6))
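+
+# Hedged usage sketch: clip the weight norm of a linear layer before every
+# forward pass (names are illustrative).
+#
+#   >>> import torch
+#   >>> from torch import nn
+#   >>> layer = nn.Linear(16, 16)
+#   >>> handle = WeightNormClipHook(max_norm=1.0).register(layer)
+#   >>> _ = layer(torch.randn(2, 16))   # norm is clipped right before forward
+#   >>> handle.remove()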
diff --git a/mmpose/models/utils/rtmcc_block.py b/mmpose/models/utils/rtmcc_block.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e317376b2e969d01c31001748a22d6cef042f0f
--- /dev/null
+++ b/mmpose/models/utils/rtmcc_block.py
@@ -0,0 +1,299 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn.bricks import DropPath
+from mmengine.utils import digit_version
+from mmengine.utils.dl_utils import TORCH_VERSION
+
+
+def rope(x, dim):
+ """Applies Rotary Position Embedding to input tensor.
+
+ Args:
+ x (torch.Tensor): Input tensor.
+ dim (int | list[int]): The spatial dimension(s) to apply
+ rotary position embedding.
+
+ Returns:
+ torch.Tensor: The tensor after applying rotary position
+ embedding.
+
+ Reference:
+ `RoFormer: Enhanced Transformer with Rotary
+        Position Embedding <https://arxiv.org/abs/2104.09864>`_
+ """
+ shape = x.shape
+ if isinstance(dim, int):
+ dim = [dim]
+
+ spatial_shape = [shape[i] for i in dim]
+ total_len = 1
+ for i in spatial_shape:
+ total_len *= i
+
+ position = torch.reshape(
+ torch.arange(total_len, dtype=torch.int, device=x.device),
+ spatial_shape)
+
+ for i in range(dim[-1] + 1, len(shape) - 1, 1):
+ position = torch.unsqueeze(position, dim=-1)
+
+ half_size = shape[-1] // 2
+ freq_seq = -torch.arange(
+ half_size, dtype=torch.int, device=x.device) / float(half_size)
+ inv_freq = 10000**-freq_seq
+
+ sinusoid = position[..., None] * inv_freq[None, None, :]
+
+ sin = torch.sin(sinusoid)
+ cos = torch.cos(sinusoid)
+ x1, x2 = torch.chunk(x, 2, dim=-1)
+
+ return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
+
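+# Hedged usage sketch: apply rotary position embedding over the token axis of
+# a (batch, tokens, channels) tensor; the shape is preserved.
+#
+#   >>> import torch
+#   >>> rope(torch.randn(2, 17, 64), dim=1).shape
+#   torch.Size([2, 17, 64])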
+
+class Scale(nn.Module):
+ """Scale vector by element multiplications.
+
+ Args:
+ dim (int): The dimension of the scale vector.
+ init_value (float, optional): The initial value of the scale vector.
+ Defaults to 1.0.
+ trainable (bool, optional): Whether the scale vector is trainable.
+ Defaults to True.
+ """
+
+ def __init__(self, dim, init_value=1., trainable=True):
+ super().__init__()
+ self.scale = nn.Parameter(
+ init_value * torch.ones(dim), requires_grad=trainable)
+
+ def forward(self, x):
+ """Forward function."""
+
+ return x * self.scale
+
+
+class ScaleNorm(nn.Module):
+ """Scale Norm.
+
+ Args:
+ dim (int): The dimension of the scale vector.
+ eps (float, optional): The minimum value in clamp. Defaults to 1e-5.
+
+ Reference:
+ `Transformers without Tears: Improving the Normalization
+        of Self-Attention <https://arxiv.org/abs/1910.05895>`_
+ """
+
+ def __init__(self, dim, eps=1e-5):
+ super().__init__()
+ self.scale = dim**-0.5
+ self.eps = eps
+ self.g = nn.Parameter(torch.ones(1))
+
+ def forward(self, x):
+ """Forward function.
+
+ Args:
+ x (torch.Tensor): Input tensor.
+
+ Returns:
+ torch.Tensor: The tensor after applying scale norm.
+ """
+
+ norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
+ return x / norm.clamp(min=self.eps) * self.g
+
+
+class RTMCCBlock(nn.Module):
+ """Gated Attention Unit (GAU) in RTMBlock.
+
+ Args:
+ num_token (int): The number of tokens.
+ in_token_dims (int): The input token dimension.
+ out_token_dims (int): The output token dimension.
+ expansion_factor (int, optional): The expansion factor of the
+ intermediate token dimension. Defaults to 2.
+ s (int, optional): The self-attention feature dimension.
+ Defaults to 128.
+ eps (float, optional): The minimum value in clamp. Defaults to 1e-5.
+ dropout_rate (float, optional): The dropout rate. Defaults to 0.0.
+ drop_path (float, optional): The drop path rate. Defaults to 0.0.
+ attn_type (str, optional): Type of attention which should be one of
+ the following options:
+
+ - 'self-attn': Self-attention.
+ - 'cross-attn': Cross-attention.
+
+ Defaults to 'self-attn'.
+ act_fn (str, optional): The activation function which should be one
+ of the following options:
+
+ - 'ReLU': ReLU activation.
+ - 'SiLU': SiLU activation.
+
+ Defaults to 'SiLU'.
+ bias (bool, optional): Whether to use bias in linear layers.
+ Defaults to False.
+ use_rel_bias (bool, optional): Whether to use relative bias.
+ Defaults to True.
+ pos_enc (bool, optional): Whether to use rotary position
+ embedding. Defaults to False.
+
+ Reference:
+ `Transformer Quality in Linear Time
+        <https://arxiv.org/abs/2202.10447>`_
+ """
+
+ def __init__(self,
+ num_token,
+ in_token_dims,
+ out_token_dims,
+ expansion_factor=2,
+ s=128,
+ eps=1e-5,
+ dropout_rate=0.,
+ drop_path=0.,
+ attn_type='self-attn',
+ act_fn='SiLU',
+ bias=False,
+ use_rel_bias=True,
+ pos_enc=False):
+
+ super(RTMCCBlock, self).__init__()
+ self.s = s
+ self.num_token = num_token
+ self.use_rel_bias = use_rel_bias
+ self.attn_type = attn_type
+ self.pos_enc = pos_enc
+ self.drop_path = DropPath(drop_path) \
+ if drop_path > 0. else nn.Identity()
+
+ self.e = int(in_token_dims * expansion_factor)
+ if use_rel_bias:
+ if attn_type == 'self-attn':
+ self.w = nn.Parameter(
+ torch.rand([2 * num_token - 1], dtype=torch.float))
+ else:
+ self.a = nn.Parameter(torch.rand([1, s], dtype=torch.float))
+ self.b = nn.Parameter(torch.rand([1, s], dtype=torch.float))
+ self.o = nn.Linear(self.e, out_token_dims, bias=bias)
+
+ if attn_type == 'self-attn':
+ self.uv = nn.Linear(in_token_dims, 2 * self.e + self.s, bias=bias)
+ self.gamma = nn.Parameter(torch.rand((2, self.s)))
+ self.beta = nn.Parameter(torch.rand((2, self.s)))
+ else:
+ self.uv = nn.Linear(in_token_dims, self.e + self.s, bias=bias)
+ self.k_fc = nn.Linear(in_token_dims, self.s, bias=bias)
+ self.v_fc = nn.Linear(in_token_dims, self.e, bias=bias)
+ nn.init.xavier_uniform_(self.k_fc.weight)
+ nn.init.xavier_uniform_(self.v_fc.weight)
+
+ self.ln = ScaleNorm(in_token_dims, eps=eps)
+
+ nn.init.xavier_uniform_(self.uv.weight)
+
+ if act_fn == 'SiLU':
+ assert digit_version(TORCH_VERSION) >= digit_version('1.7.0'), \
+ 'SiLU activation requires PyTorch version >= 1.7'
+
+ self.act_fn = nn.SiLU(True)
+ else:
+ self.act_fn = nn.ReLU(True)
+
+ if in_token_dims == out_token_dims:
+ self.shortcut = True
+ self.res_scale = Scale(in_token_dims)
+ else:
+ self.shortcut = False
+
+ self.sqrt_s = math.sqrt(s)
+
+ self.dropout_rate = dropout_rate
+
+ if dropout_rate > 0.:
+ self.dropout = nn.Dropout(dropout_rate)
+
+ def rel_pos_bias(self, seq_len, k_len=None):
+ """Add relative position bias."""
+
+ if self.attn_type == 'self-attn':
+ t = F.pad(self.w[:2 * seq_len - 1], [0, seq_len]).repeat(seq_len)
+ t = t[..., :-seq_len].reshape(-1, seq_len, 3 * seq_len - 2)
+ r = (2 * seq_len - 1) // 2
+ t = t[..., r:-r]
+ else:
+ a = rope(self.a.repeat(seq_len, 1), dim=0)
+ b = rope(self.b.repeat(k_len, 1), dim=0)
+ t = torch.bmm(a, b.permute(0, 2, 1))
+ return t
+
+ def _forward(self, inputs):
+ """GAU Forward function."""
+
+ if self.attn_type == 'self-attn':
+ x = inputs
+ else:
+ x, k, v = inputs
+
+ x = self.ln(x)
+
+ uv = self.uv(x)
+
+ if self.attn_type == 'self-attn':
+ u, v, base = torch.split(
+ self.act_fn(uv), [self.e, self.e, self.s], dim=-1)
+
+ base = base.unsqueeze(2) * self.gamma[None, None, :] + self.beta
+
+ if self.pos_enc:
+ base = rope(base, dim=1)
+
+ q, k = torch.unbind(base, dim=-2)
+
+ else:
+ u, q = torch.split(self.act_fn(uv), [self.e, self.s], dim=-1)
+
+ k = self.k_fc(k)
+ v = self.v_fc(v)
+
+ if self.pos_enc:
+ q = rope(q, 1)
+ k = rope(k, 1)
+
+ qk = torch.bmm(q, k.permute(0, 2, 1))
+
+ if self.use_rel_bias:
+ if self.attn_type == 'self-attn':
+ bias = self.rel_pos_bias(q.size(1))
+ else:
+ bias = self.rel_pos_bias(q.size(1), k.size(1))
+ qk += bias[:, :q.size(1), :k.size(1)]
+
+ kernel = torch.square(F.relu(qk / self.sqrt_s))
+
+ if self.dropout_rate > 0.:
+ kernel = self.dropout(kernel)
+
+ x = u * torch.bmm(kernel, v)
+ x = self.o(x)
+
+ return x
+
+ def forward(self, x):
+ """Forward function."""
+
+ if self.shortcut:
+ if self.attn_type == 'cross-attn':
+ res_shortcut = x[0]
+ else:
+ res_shortcut = x
+ main_branch = self.drop_path(self._forward(x))
+ return self.res_scale(res_shortcut) + main_branch
+ else:
+ return self.drop_path(self._forward(x))
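+
+# Hedged usage sketch of a self-attention GAU block (argument values are
+# illustrative):
+#
+#   >>> import torch
+#   >>> block = RTMCCBlock(num_token=17, in_token_dims=256, out_token_dims=256)
+#   >>> block(torch.randn(2, 17, 256)).shape
+#   torch.Size([2, 17, 256])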
diff --git a/mmpose/models/utils/transformer.py b/mmpose/models/utils/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..103b9e9970a7b96b1d5ef288d9fbf5b787838a92
--- /dev/null
+++ b/mmpose/models/utils/transformer.py
@@ -0,0 +1,369 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+from typing import Sequence
+
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import build_conv_layer, build_norm_layer
+from mmengine.model import BaseModule
+from mmengine.utils import to_2tuple
+
+
+def nlc_to_nchw(x, hw_shape):
+ """Convert [N, L, C] shape tensor to [N, C, H, W] shape tensor.
+
+ Args:
+ x (Tensor): The input tensor of shape [N, L, C] before conversion.
+ hw_shape (Sequence[int]): The height and width of output feature map.
+
+ Returns:
+ Tensor: The output tensor of shape [N, C, H, W] after conversion.
+ """
+ H, W = hw_shape
+ assert len(x.shape) == 3
+ B, L, C = x.shape
+ assert L == H * W, 'The seq_len does not match H, W'
+ return x.transpose(1, 2).reshape(B, C, H, W).contiguous()
+
+
+def nchw_to_nlc(x):
+ """Flatten [N, C, H, W] shape tensor to [N, L, C] shape tensor.
+
+ Args:
+ x (Tensor): The input tensor of shape [N, C, H, W] before conversion.
+
+ Returns:
+ Tensor: The output tensor of shape [N, L, C] after conversion.
+ """
+ assert len(x.shape) == 4
+ return x.flatten(2).transpose(1, 2).contiguous()
+
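+# Hedged round-trip sketch:
+#
+#   >>> import torch
+#   >>> x = torch.rand(2, 64, 8, 12)
+#   >>> seq = nchw_to_nlc(x)   # (2, 96, 64)
+#   >>> torch.equal(nlc_to_nchw(seq, (8, 12)), x)
+#   True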
+
+class AdaptivePadding(nn.Module):
+ """Applies padding to input (if needed) so that input can get fully covered
+ by filter you specified. It support two modes "same" and "corner". The
+ "same" mode is same with "SAME" padding mode in TensorFlow, pad zero around
+ input. The "corner" mode would pad zero to bottom right.
+
+ Args:
+ kernel_size (int | tuple): Size of the kernel. Default: 1.
+ stride (int | tuple): Stride of the filter. Default: 1.
+ dilation (int | tuple): Spacing between kernel elements.
+ Default: 1
+ padding (str): Support "same" and "corner", "corner" mode
+ would pad zero to bottom right, and "same" mode would
+ pad zero around input. Default: "corner".
+ Example:
+ >>> kernel_size = 16
+ >>> stride = 16
+ >>> dilation = 1
+ >>> input = torch.rand(1, 1, 15, 17)
+ >>> adap_pad = AdaptivePadding(
+ >>> kernel_size=kernel_size,
+ >>> stride=stride,
+ >>> dilation=dilation,
+ >>> padding="corner")
+ >>> out = adap_pad(input)
+ >>> assert (out.shape[2], out.shape[3]) == (16, 32)
+ >>> input = torch.rand(1, 1, 16, 17)
+ >>> out = adap_pad(input)
+ >>> assert (out.shape[2], out.shape[3]) == (16, 32)
+ """
+
+ def __init__(self, kernel_size=1, stride=1, dilation=1, padding='corner'):
+
+ super(AdaptivePadding, self).__init__()
+
+ assert padding in ('same', 'corner')
+
+ kernel_size = to_2tuple(kernel_size)
+ stride = to_2tuple(stride)
+ padding = to_2tuple(padding)
+ dilation = to_2tuple(dilation)
+
+ self.padding = padding
+ self.kernel_size = kernel_size
+ self.stride = stride
+ self.dilation = dilation
+
+ def get_pad_shape(self, input_shape):
+ """Get horizontal and vertical padding shapes."""
+
+ input_h, input_w = input_shape
+ kernel_h, kernel_w = self.kernel_size
+ stride_h, stride_w = self.stride
+ output_h = math.ceil(input_h / stride_h)
+ output_w = math.ceil(input_w / stride_w)
+ pad_h = max((output_h - 1) * stride_h +
+ (kernel_h - 1) * self.dilation[0] + 1 - input_h, 0)
+ pad_w = max((output_w - 1) * stride_w +
+ (kernel_w - 1) * self.dilation[1] + 1 - input_w, 0)
+ return pad_h, pad_w
+
+ def forward(self, x):
+ """Forward function."""
+
+ pad_h, pad_w = self.get_pad_shape(x.size()[-2:])
+ if pad_h > 0 or pad_w > 0:
+ if self.padding == 'corner':
+ x = F.pad(x, [0, pad_w, 0, pad_h])
+ elif self.padding == 'same':
+ x = F.pad(x, [
+ pad_w // 2, pad_w - pad_w // 2, pad_h // 2,
+ pad_h - pad_h // 2
+ ])
+ return x
+
+
+class PatchEmbed(BaseModule):
+ """Image to Patch Embedding.
+
+ We use a conv layer to implement PatchEmbed.
+
+ Args:
+ in_channels (int): The num of input channels. Default: 3
+ embed_dims (int): The dimensions of embedding. Default: 768
+ conv_type (str): The type of convolution layer used for
+ embedding. Default: "Conv2d".
+ kernel_size (int): The kernel_size of embedding conv. Default: 16.
+ stride (int): The stride of embedding conv. Default: 16.
+ If set to ``None``, it falls back to `kernel_size`.
+ padding (int | tuple | str): The padding length of
+ embedding conv. When it is a string, it means the mode
+ of adaptive padding; "same" and "corner" are supported.
+ Default: "corner".
+ dilation (int): The dilation rate of embedding conv. Default: 1.
+ bias (bool): Bias of embed conv. Default: True.
+ norm_cfg (dict, optional): Config dict for normalization layer.
+ Default: None.
+ input_size (int | tuple | None): The size of the input, which will
+ be used to calculate the output size. Default: None.
+ init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization.
+ Default: None.
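+
+ Example:
+ >>> # a minimal illustration with the default 16x16 patches
+ >>> import torch
+ >>> patch_embed = PatchEmbed(in_channels=3, embed_dims=768)
+ >>> x = torch.rand(1, 3, 224, 224)
+ >>> tokens, out_size = patch_embed(x)
+ >>> assert tokens.shape == (1, 14 * 14, 768)
+ >>> assert out_size == (14, 14)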
+ """
+
+ def __init__(
+ self,
+ in_channels=3,
+ embed_dims=768,
+ conv_type='Conv2d',
+ kernel_size=16,
+ stride=16,
+ padding='corner',
+ dilation=1,
+ bias=True,
+ norm_cfg=None,
+ input_size=None,
+ init_cfg=None,
+ ):
+ super(PatchEmbed, self).__init__(init_cfg=init_cfg)
+
+ self.embed_dims = embed_dims
+ if stride is None:
+ stride = kernel_size
+
+ kernel_size = to_2tuple(kernel_size)
+ stride = to_2tuple(stride)
+ dilation = to_2tuple(dilation)
+
+ if isinstance(padding, str):
+ self.adap_padding = AdaptivePadding(
+ kernel_size=kernel_size,
+ stride=stride,
+ dilation=dilation,
+ padding=padding)
+ # disable the padding of conv
+ padding = 0
+ else:
+ self.adap_padding = None
+ padding = to_2tuple(padding)
+
+ self.projection = build_conv_layer(
+ dict(type=conv_type),
+ in_channels=in_channels,
+ out_channels=embed_dims,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ bias=bias)
+
+ if norm_cfg is not None:
+ self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
+ else:
+ self.norm = None
+
+ if input_size:
+ input_size = to_2tuple(input_size)
+ # `init_out_size` would be used outside to
+ # calculate the num_patches
+ # when `use_abs_pos_embed` outside
+ self.init_input_size = input_size
+ if self.adap_padding:
+ pad_h, pad_w = self.adap_padding.get_pad_shape(input_size)
+ input_h, input_w = input_size
+ input_h = input_h + pad_h
+ input_w = input_w + pad_w
+ input_size = (input_h, input_w)
+
+ # https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
+ h_out = (input_size[0] + 2 * padding[0] - dilation[0] *
+ (kernel_size[0] - 1) - 1) // stride[0] + 1
+ w_out = (input_size[1] + 2 * padding[1] - dilation[1] *
+ (kernel_size[1] - 1) - 1) // stride[1] + 1
+ self.init_out_size = (h_out, w_out)
+ else:
+ self.init_input_size = None
+ self.init_out_size = None
+
+ def forward(self, x):
+ """
+ Args:
+ x (Tensor): Has shape (B, C, H, W). In most case, C is 3.
+
+ Returns:
+ tuple: Contains merged results and its spatial shape.
+
+ - x (Tensor): Has shape (B, out_h * out_w, embed_dims)
+ - out_size (tuple[int]): Spatial shape of x, arrange as
+ (out_h, out_w).
+ """
+
+ if self.adap_padding:
+ x = self.adap_padding(x)
+
+ x = self.projection(x)
+ out_size = (x.shape[2], x.shape[3])
+ x = x.flatten(2).transpose(1, 2)
+ if self.norm is not None:
+ x = self.norm(x)
+ return x, out_size
+
+
+class PatchMerging(BaseModule):
+ """Merge patch feature map.
+
+ This layer groups the feature map by kernel_size, and applies norm and
+ linear layers to the grouped feature map. Our implementation uses
+ `nn.Unfold` to merge patches, which is about 25% faster than the original
+ implementation. However, this requires modifying pretrained models for
+ compatibility.
+
+ Args:
+ in_channels (int): The num of input channels.
+ out_channels (int): The num of output channels.
+ kernel_size (int | tuple, optional): the kernel size in the unfold
+ layer. Defaults to 2.
+ stride (int | tuple, optional): the stride of the sliding blocks in the
+ unfold layer. Default: None. (Would be set as `kernel_size`)
+ padding (int | tuple | str): The padding length of
+ the unfold layer. When it is a string, it means the mode
+ of adaptive padding; "same" and "corner" are supported.
+ Default: "corner".
+ dilation (int | tuple, optional): dilation parameter in the unfold
+ layer. Default: 1.
+ bias (bool, optional): Whether to add bias in linear layer or not.
+ Defaults: False.
+ norm_cfg (dict, optional): Config dict for normalization layer.
+ Default: dict(type='LN').
+ init_cfg (dict, optional): The extra config for initialization.
+ Default: None.
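+
+ Example:
+ >>> # a minimal illustration: 2x2 patch merging halves the resolution
+ >>> import torch
+ >>> merge = PatchMerging(in_channels=96, out_channels=192)
+ >>> x = torch.rand(1, 56 * 56, 96)
+ >>> out, out_size = merge(x, (56, 56))
+ >>> assert out.shape == (1, 28 * 28, 192)
+ >>> assert out_size == (28, 28)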
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size=2,
+ stride=None,
+ padding='corner',
+ dilation=1,
+ bias=False,
+ norm_cfg=dict(type='LN'),
+ init_cfg=None):
+ super().__init__(init_cfg=init_cfg)
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ stride = stride if stride else kernel_size
+
+ kernel_size = to_2tuple(kernel_size)
+ stride = to_2tuple(stride)
+ dilation = to_2tuple(dilation)
+
+ if isinstance(padding, str):
+ self.adap_padding = AdaptivePadding(
+ kernel_size=kernel_size,
+ stride=stride,
+ dilation=dilation,
+ padding=padding)
+ # disable the padding of unfold
+ padding = 0
+ else:
+ self.adap_padding = None
+
+ padding = to_2tuple(padding)
+ self.sampler = nn.Unfold(
+ kernel_size=kernel_size,
+ dilation=dilation,
+ padding=padding,
+ stride=stride)
+
+ sample_dim = kernel_size[0] * kernel_size[1] * in_channels
+
+ if norm_cfg is not None:
+ self.norm = build_norm_layer(norm_cfg, sample_dim)[1]
+ else:
+ self.norm = None
+
+ self.reduction = nn.Linear(sample_dim, out_channels, bias=bias)
+
+ def forward(self, x, input_size):
+ """
+ Args:
+ x (Tensor): Has shape (B, H*W, C_in).
+ input_size (tuple[int]): The spatial shape of x, arrange as (H, W).
+ Default: None.
+
+ Returns:
+ tuple: Contains merged results and its spatial shape.
+
+ - x (Tensor): Has shape (B, Merged_H * Merged_W, C_out)
+ - out_size (tuple[int]): Spatial shape of x, arrange as
+ (Merged_H, Merged_W).
+ """
+ B, L, C = x.shape
+ assert isinstance(input_size, Sequence), f'Expect ' \
+ f'input_size is ' \
+ f'`Sequence` ' \
+ f'but get {input_size}'
+
+ H, W = input_size
+ assert L == H * W, 'input feature has wrong size'
+
+ x = x.view(B, H, W, C).permute([0, 3, 1, 2]) # B, C, H, W
+ # Use nn.Unfold to merge patch. About 25% faster than original method,
+ # but need to modify pretrained model for compatibility
+
+ if self.adap_padding:
+ x = self.adap_padding(x)
+ H, W = x.shape[-2:]
+
+ x = self.sampler(x)
+ # if kernel_size=2 and stride=2, x should have shape (B, 4*C, H/2*W/2)
+
+ out_h = (H + 2 * self.sampler.padding[0] - self.sampler.dilation[0] *
+ (self.sampler.kernel_size[0] - 1) -
+ 1) // self.sampler.stride[0] + 1
+ out_w = (W + 2 * self.sampler.padding[1] - self.sampler.dilation[1] *
+ (self.sampler.kernel_size[1] - 1) -
+ 1) // self.sampler.stride[1] + 1
+
+ output_size = (out_h, out_w)
+ x = x.transpose(1, 2) # B, H/2*W/2, 4*C
+ x = self.norm(x) if self.norm else x
+ x = self.reduction(x)
+ return x, output_size
diff --git a/mmpose/models/utils/tta.py b/mmpose/models/utils/tta.py
new file mode 100644
index 0000000000000000000000000000000000000000..0add48a422a676e131fdc2cc31a9d7bfeadc382a
--- /dev/null
+++ b/mmpose/models/utils/tta.py
@@ -0,0 +1,168 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+
+
+def flip_heatmaps(heatmaps: Tensor,
+ flip_indices: Optional[List[int]] = None,
+ flip_mode: str = 'heatmap',
+ shift_heatmap: bool = True):
+ """Flip heatmaps for test-time augmentation.
+
+ Args:
+ heatmaps (Tensor): The heatmaps to flip. Should be a tensor in shape
+ [B, C, H, W]
+ flip_indices (List[int]): The indices of each keypoint's symmetric
+ keypoint. Defaults to ``None``
+ flip_mode (str): Specify the flipping mode. Options are:
+
+ - ``'heatmap'``: horizontally flip the heatmaps and swap heatmaps
+ of symmetric keypoints according to ``flip_indices``
+ - ``'udp_combined'``: similar to ``'heatmap'`` mode but further
+ flip the x_offset values
+ - ``'offset'``: horizontally flip the offset fields and swap
+ heatmaps of symmetric keypoints according to
+ ``flip_indices``. x_offset values are also reversed
+ shift_heatmap (bool): Shift the flipped heatmaps to align with the
+ original heatmaps and improve accuracy. Defaults to ``True``
+
+ Returns:
+ Tensor: flipped heatmaps in shape [B, C, H, W]
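+
+ Example:
+ >>> # a minimal shape check with 3 keypoints whose symmetric pairs
+ >>> # are given by flip_indices
+ >>> import torch
+ >>> heatmaps = torch.rand(1, 3, 64, 48)
+ >>> flipped = flip_heatmaps(heatmaps, flip_indices=[0, 2, 1])
+ >>> assert flipped.shape == (1, 3, 64, 48)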
+ """
+
+ if flip_mode == 'heatmap':
+ heatmaps = heatmaps.flip(-1)
+ if flip_indices is not None:
+ assert len(flip_indices) == heatmaps.shape[1]
+ heatmaps = heatmaps[:, flip_indices]
+ elif flip_mode == 'udp_combined':
+ B, C, H, W = heatmaps.shape
+ heatmaps = heatmaps.view(B, C // 3, 3, H, W)
+ heatmaps = heatmaps.flip(-1)
+ if flip_indices is not None:
+ assert len(flip_indices) == C // 3
+ heatmaps = heatmaps[:, flip_indices]
+ heatmaps[:, :, 1] = -heatmaps[:, :, 1]
+ heatmaps = heatmaps.view(B, C, H, W)
+
+ elif flip_mode == 'offset':
+ B, C, H, W = heatmaps.shape
+ heatmaps = heatmaps.view(B, C // 2, -1, H, W)
+ heatmaps = heatmaps.flip(-1)
+ if flip_indices is not None:
+ assert len(flip_indices) == C // 2
+ heatmaps = heatmaps[:, flip_indices]
+ heatmaps[:, :, 0] = -heatmaps[:, :, 0]
+ heatmaps = heatmaps.view(B, C, H, W)
+
+ else:
+ raise ValueError(f'Invalid flip_mode value "{flip_mode}"')
+
+ if shift_heatmap:
+ # clone data to avoid unexpected in-place operation when using CPU
+ heatmaps[..., 1:] = heatmaps[..., :-1].clone()
+
+ return heatmaps
+
+
+def flip_vectors(x_labels: Tensor, y_labels: Tensor, flip_indices: List[int]):
+ """Flip instance-level labels in specific axis for test-time augmentation.
+
+ Args:
+ x_labels (Tensor): The vector labels in x-axis to flip. Should be
+ a tensor in shape [B, C, Wx]
+ y_labels (Tensor): The vector labels in y-axis to flip. Should be
+ a tensor in shape [B, C, Wy]
+ flip_indices (List[int]): The indices of each keypoint's symmetric
+ keypoint
+ """
+ assert x_labels.ndim == 3 and y_labels.ndim == 3
+ assert len(flip_indices) == x_labels.shape[1] and len(
+ flip_indices) == y_labels.shape[1]
+ x_labels = x_labels[:, flip_indices].flip(-1)
+ y_labels = y_labels[:, flip_indices]
+
+ return x_labels, y_labels
+
+
+def flip_coordinates(coords: Tensor, flip_indices: List[int],
+ shift_coords: bool, input_size: Tuple[int, int]):
+ """Flip normalized coordinates for test-time augmentation.
+
+ Args:
+ coords (Tensor): The coordinates to flip. Should be a tensor in shape
+ [B, K, D]
+ flip_indices (List[int]): The indices of each keypoint's symmetric
+ keypoint
+ shift_coords (bool): Shift the flipped coordinates to align with the
+ original coordinates and improve accuracy. Defaults to ``True``
+ input_size (Tuple[int, int]): The size of input image in [w, h]
+ """
+ assert coords.ndim == 3
+ assert len(flip_indices) == coords.shape[1]
+
+ coords[:, :, 0] = 1.0 - coords[:, :, 0]
+
+ if shift_coords:
+ img_width = input_size[0]
+ coords[:, :, 0] -= 1.0 / img_width
+
+ coords = coords[:, flip_indices]
+ return coords
+
+
+def aggregate_heatmaps(heatmaps: List[Tensor],
+ size: Optional[Tuple[int, int]],
+ align_corners: bool = False,
+ mode: str = 'average'):
+ """Aggregate multiple heatmaps.
+
+ Args:
+ heatmaps (List[Tensor]): Multiple heatmaps to aggregate. Each should
+ be in shape (B, C, H, W)
+ size (Tuple[int, int], optional): The target size in (w, h). All
+ heatmaps will be resized to the target size. If not given, the
+ first heatmap tensor's width and height will be used as the target
+ size. Defaults to ``None``
+ align_corners (bool): Whether align corners when resizing heatmaps.
+ Defaults to ``False``
+ mode (str): Aggregation mode in one of the following:
+
+ - ``'average'``: Get the average of the heatmaps. All heatmaps
+ must have the same channel number
+ - ``'concat'``: Concatenate the heatmaps along the channel dim
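+
+ Example:
+ >>> # average two heatmaps after resizing the smaller one to (w=48, h=64)
+ >>> import torch
+ >>> hms = [torch.rand(1, 17, 64, 48), torch.rand(1, 17, 32, 24)]
+ >>> out = aggregate_heatmaps(hms, size=(48, 64))
+ >>> assert out.shape == (1, 17, 64, 48)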
+ """
+
+ if mode not in {'average', 'concat'}:
+ raise ValueError(f'Invalid aggregation mode `{mode}`')
+
+ if size is None:
+ h, w = heatmaps[0].shape[2:4]
+ else:
+ w, h = size
+
+ for i, _heatmaps in enumerate(heatmaps):
+ assert _heatmaps.ndim == 4
+ if mode == 'average':
+ assert _heatmaps.shape[:2] == heatmaps[0].shape[:2]
+ else:
+ assert _heatmaps.shape[0] == heatmaps[0].shape[0]
+
+ if _heatmaps.shape[2:4] != (h, w):
+ heatmaps[i] = F.interpolate(
+ _heatmaps,
+ size=(h, w),
+ mode='bilinear',
+ align_corners=align_corners)
+
+ if mode == 'average':
+ output = sum(heatmaps).div(len(heatmaps))
+ elif mode == 'concat':
+ output = torch.cat(heatmaps, dim=1)
+ else:
+ raise ValueError()
+
+ return output
diff --git a/mmpose/registry.py b/mmpose/registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3b8d17c4c33dc78b457608d1b3951401cf55d52
--- /dev/null
+++ b/mmpose/registry.py
@@ -0,0 +1,132 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+"""MMPose provides following registry nodes to support using modules across
+projects.
+
+Each node is a child of the root registry in MMEngine.
+More details can be found at
+https://mmengine.readthedocs.io/en/latest/tutorials/registry.html.
+"""
+
+from mmengine.registry import DATA_SAMPLERS as MMENGINE_DATA_SAMPLERS
+from mmengine.registry import DATASETS as MMENGINE_DATASETS
+from mmengine.registry import EVALUATOR as MMENGINE_EVALUATOR
+from mmengine.registry import HOOKS as MMENGINE_HOOKS
+from mmengine.registry import INFERENCERS as MMENGINE_INFERENCERS
+from mmengine.registry import LOG_PROCESSORS as MMENGINE_LOG_PROCESSORS
+from mmengine.registry import LOOPS as MMENGINE_LOOPS
+from mmengine.registry import METRICS as MMENGINE_METRICS
+from mmengine.registry import MODEL_WRAPPERS as MMENGINE_MODEL_WRAPPERS
+from mmengine.registry import MODELS as MMENGINE_MODELS
+from mmengine.registry import \
+ OPTIM_WRAPPER_CONSTRUCTORS as MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS
+from mmengine.registry import OPTIM_WRAPPERS as MMENGINE_OPTIM_WRAPPERS
+from mmengine.registry import OPTIMIZERS as MMENGINE_OPTIMIZERS
+from mmengine.registry import PARAM_SCHEDULERS as MMENGINE_PARAM_SCHEDULERS
+from mmengine.registry import \
+ RUNNER_CONSTRUCTORS as MMENGINE_RUNNER_CONSTRUCTORS
+from mmengine.registry import RUNNERS as MMENGINE_RUNNERS
+from mmengine.registry import TASK_UTILS as MMENGINE_TASK_UTILS
+from mmengine.registry import TRANSFORMS as MMENGINE_TRANSFORMS
+from mmengine.registry import VISBACKENDS as MMENGINE_VISBACKENDS
+from mmengine.registry import VISUALIZERS as MMENGINE_VISUALIZERS
+from mmengine.registry import \
+ WEIGHT_INITIALIZERS as MMENGINE_WEIGHT_INITIALIZERS
+from mmengine.registry import Registry
+
+# Registries For Runner and the related
+# manage all kinds of runners like `EpochBasedRunner` and `IterBasedRunner`
+RUNNERS = Registry('runner', parent=MMENGINE_RUNNERS)
+# manage runner constructors that define how to initialize runners
+RUNNER_CONSTRUCTORS = Registry(
+ 'runner constructor', parent=MMENGINE_RUNNER_CONSTRUCTORS)
+# manage all kinds of loops like `EpochBasedTrainLoop`
+LOOPS = Registry('loop', parent=MMENGINE_LOOPS)
+# manage all kinds of hooks like `CheckpointHook`
+HOOKS = Registry(
+ 'hook', parent=MMENGINE_HOOKS, locations=['mmpose.engine.hooks'])
+
+# Registries For Data and the related
+# manage data-related modules
+DATASETS = Registry(
+ 'dataset', parent=MMENGINE_DATASETS, locations=['mmpose.datasets'])
+DATA_SAMPLERS = Registry(
+ 'data sampler',
+ parent=MMENGINE_DATA_SAMPLERS,
+ locations=['mmpose.datasets.samplers'])
+TRANSFORMS = Registry(
+ 'transform',
+ parent=MMENGINE_TRANSFORMS,
+ locations=['mmpose.datasets.transforms'])
+
+# manage all kinds of modules inheriting `nn.Module`
+MODELS = Registry('model', parent=MMENGINE_MODELS, locations=['mmpose.models'])
+# manage all kinds of model wrappers like 'MMDistributedDataParallel'
+MODEL_WRAPPERS = Registry(
+ 'model_wrapper',
+ parent=MMENGINE_MODEL_WRAPPERS,
+ locations=['mmpose.models'])
+# manage all kinds of weight initialization modules like `Uniform`
+WEIGHT_INITIALIZERS = Registry(
+ 'weight initializer',
+ parent=MMENGINE_WEIGHT_INITIALIZERS,
+ locations=['mmpose.models'])
+# manage all kinds of batch augmentations like Mixup and CutMix.
+BATCH_AUGMENTS = Registry('batch augment', locations=['mmpose.models'])
+
+# Registries For Optimizer and the related
+# manage all kinds of optimizers like `SGD` and `Adam`
+OPTIMIZERS = Registry(
+ 'optimizer', parent=MMENGINE_OPTIMIZERS, locations=['mmpose.engine'])
+# manage optimizer wrapper
+OPTIM_WRAPPERS = Registry(
+ 'optimizer_wrapper',
+ parent=MMENGINE_OPTIM_WRAPPERS,
+ locations=['mmpose.engine'])
+# manage constructors that customize the optimization hyperparameters.
+OPTIM_WRAPPER_CONSTRUCTORS = Registry(
+ 'optimizer wrapper constructor',
+ parent=MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS,
+ locations=['mmpose.engine.optim_wrappers'])
+# manage all kinds of parameter schedulers like `MultiStepLR`
+PARAM_SCHEDULERS = Registry(
+ 'parameter scheduler',
+ parent=MMENGINE_PARAM_SCHEDULERS,
+ locations=['mmpose.engine'])
+
+# manage all kinds of metrics
+METRICS = Registry(
+ 'metric', parent=MMENGINE_METRICS, locations=['mmpose.evaluation.metrics'])
+# manage all kinds of evaluators
+EVALUATORS = Registry(
+ 'evaluator', parent=MMENGINE_EVALUATOR, locations=['mmpose.evaluation'])
+
+# manage task-specific modules like anchor generators and box coders
+TASK_UTILS = Registry(
+ 'task util', parent=MMENGINE_TASK_UTILS, locations=['mmpose.models'])
+
+# Registries For Visualizer and the related
+# manage visualizer
+VISUALIZERS = Registry(
+ 'visualizer',
+ parent=MMENGINE_VISUALIZERS,
+ locations=['mmpose.visualization'])
+# manage visualizer backend
+VISBACKENDS = Registry(
+ 'vis_backend',
+ parent=MMENGINE_VISBACKENDS,
+ locations=['mmpose.visualization'])
+
+# manage all kinds of log processors
+LOG_PROCESSORS = Registry(
+ 'log processor',
+ parent=MMENGINE_LOG_PROCESSORS,
+ locations=['mmpose.visualization'])
+
+# manage keypoint encoders/decoders
+KEYPOINT_CODECS = Registry('KEYPOINT_CODECS', locations=['mmpose.codecs'])
+
+# manage inferencer
+INFERENCERS = Registry(
+ 'inferencer',
+ parent=MMENGINE_INFERENCERS,
+ locations=['mmpose.apis.inferencers'])
diff --git a/mmpose/structures/__init__.py b/mmpose/structures/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4384af1cd0f7cfac9a5f8d26e34caf4c78b9fc9
--- /dev/null
+++ b/mmpose/structures/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .bbox import (bbox_cs2xywh, bbox_cs2xyxy, bbox_xywh2cs, bbox_xywh2xyxy,
+ bbox_xyxy2cs, bbox_xyxy2xywh, flip_bbox,
+ get_udp_warp_matrix, get_warp_matrix)
+from .keypoint import flip_keypoints
+from .multilevel_pixel_data import MultilevelPixelData
+from .pose_data_sample import PoseDataSample
+from .utils import merge_data_samples, revert_heatmap, split_instances
+
+__all__ = [
+ 'PoseDataSample', 'MultilevelPixelData', 'bbox_cs2xywh', 'bbox_cs2xyxy',
+ 'bbox_xywh2cs', 'bbox_xywh2xyxy', 'bbox_xyxy2cs', 'bbox_xyxy2xywh',
+ 'flip_bbox', 'get_udp_warp_matrix', 'get_warp_matrix', 'flip_keypoints',
+ 'merge_data_samples', 'revert_heatmap', 'split_instances'
+]
diff --git a/mmpose/structures/__pycache__/__init__.cpython-38.pyc b/mmpose/structures/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1dd8cb302514cdb3d993770724b9ca27b4861308
Binary files /dev/null and b/mmpose/structures/__pycache__/__init__.cpython-38.pyc differ
diff --git a/mmpose/structures/__pycache__/multilevel_pixel_data.cpython-38.pyc b/mmpose/structures/__pycache__/multilevel_pixel_data.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d9572277296bf43834f212fd83c9fd46effabf2a
Binary files /dev/null and b/mmpose/structures/__pycache__/multilevel_pixel_data.cpython-38.pyc differ
diff --git a/mmpose/structures/__pycache__/pose_data_sample.cpython-38.pyc b/mmpose/structures/__pycache__/pose_data_sample.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..97b16e8532e78307a849c953394e25e19e7c5642
Binary files /dev/null and b/mmpose/structures/__pycache__/pose_data_sample.cpython-38.pyc differ
diff --git a/mmpose/structures/__pycache__/utils.cpython-38.pyc b/mmpose/structures/__pycache__/utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2ef6db350956f3f3a2e7486126e480dc2e482ffe
Binary files /dev/null and b/mmpose/structures/__pycache__/utils.cpython-38.pyc differ
diff --git a/mmpose/structures/bbox/__init__.py b/mmpose/structures/bbox/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3e723918ca4b985b9e5ff3aa36b2a5c3bd2d700
--- /dev/null
+++ b/mmpose/structures/bbox/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .transforms import (bbox_cs2xywh, bbox_cs2xyxy, bbox_xywh2cs,
+ bbox_xywh2xyxy, bbox_xyxy2cs, bbox_xyxy2xywh,
+ flip_bbox, get_udp_warp_matrix, get_warp_matrix)
+
+__all__ = [
+ 'bbox_cs2xywh', 'bbox_cs2xyxy', 'bbox_xywh2cs', 'bbox_xywh2xyxy',
+ 'bbox_xyxy2cs', 'bbox_xyxy2xywh', 'flip_bbox', 'get_udp_warp_matrix',
+ 'get_warp_matrix'
+]
diff --git a/mmpose/structures/bbox/__pycache__/__init__.cpython-38.pyc b/mmpose/structures/bbox/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..388945bffedb893c5e22ed4fbb4d8d0fdc842aaa
Binary files /dev/null and b/mmpose/structures/bbox/__pycache__/__init__.cpython-38.pyc differ
diff --git a/mmpose/structures/bbox/__pycache__/transforms.cpython-38.pyc b/mmpose/structures/bbox/__pycache__/transforms.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..89924501b431dc10506e465976bba57c8d74b061
Binary files /dev/null and b/mmpose/structures/bbox/__pycache__/transforms.cpython-38.pyc differ
diff --git a/mmpose/structures/bbox/transforms.py b/mmpose/structures/bbox/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..027ac0717bfa3fbafa9980b523b38496f234e474
--- /dev/null
+++ b/mmpose/structures/bbox/transforms.py
@@ -0,0 +1,361 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+from typing import Tuple
+
+import cv2
+import numpy as np
+
+
+def bbox_xyxy2xywh(bbox_xyxy: np.ndarray) -> np.ndarray:
+ """Transform the bbox format from x1y1x2y2 to xywh.
+
+ Args:
+ bbox_xyxy (np.ndarray): Bounding boxes (with scores), shaped (n, 4) or
+ (n, 5). (left, top, right, bottom, [score])
+
+ Returns:
+ np.ndarray: Bounding boxes (with scores),
+ shaped (n, 4) or (n, 5). (left, top, width, height, [score])
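+
+ Example:
+ >>> bbox = np.array([[10., 20., 30., 60.]])
+ >>> assert np.allclose(bbox_xyxy2xywh(bbox), [[10., 20., 20., 40.]])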
+ """
+ bbox_xywh = bbox_xyxy.copy()
+ bbox_xywh[:, 2] = bbox_xywh[:, 2] - bbox_xywh[:, 0]
+ bbox_xywh[:, 3] = bbox_xywh[:, 3] - bbox_xywh[:, 1]
+
+ return bbox_xywh
+
+
+def bbox_xywh2xyxy(bbox_xywh: np.ndarray) -> np.ndarray:
+ """Transform the bbox format from xywh to x1y1x2y2.
+
+ Args:
+ bbox_xywh (ndarray): Bounding boxes (with scores),
+ shaped (n, 4) or (n, 5). (left, top, width, height, [score])
+ Returns:
+ np.ndarray: Bounding boxes (with scores), shaped (n, 4) or
+ (n, 5). (left, top, right, bottom, [score])
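+
+ Example:
+ >>> bbox = np.array([[10., 20., 20., 40.]])
+ >>> assert np.allclose(bbox_xywh2xyxy(bbox), [[10., 20., 30., 60.]])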
+ """
+ bbox_xyxy = bbox_xywh.copy()
+ bbox_xyxy[:, 2] = bbox_xyxy[:, 2] + bbox_xyxy[:, 0]
+ bbox_xyxy[:, 3] = bbox_xyxy[:, 3] + bbox_xyxy[:, 1]
+
+ return bbox_xyxy
+
+
+def bbox_xyxy2cs(bbox: np.ndarray,
+ padding: float = 1.) -> Tuple[np.ndarray, np.ndarray]:
+ """Transform the bbox format from (x,y,w,h) into (center, scale)
+
+ Args:
+ bbox (ndarray): Bounding box(es) in shape (4,) or (n, 4), formatted
+ as (left, top, right, bottom)
+ padding (float): BBox padding factor that will be multiplied to scale.
+ Default: 1.0
+
+ Returns:
+ tuple: A tuple containing center and scale.
+ - np.ndarray[float32]: Center (x, y) of the bbox in shape (2,) or
+ (n, 2)
+ - np.ndarray[float32]: Scale (w, h) of the bbox in shape (2,) or
+ (n, 2)
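+
+ Example:
+ >>> center, scale = bbox_xyxy2cs(np.array([10., 20., 30., 60.]), padding=1.25)
+ >>> assert np.allclose(center, [20., 40.]) and np.allclose(scale, [25., 50.])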
+ """
+ # convert single bbox from (4, ) to (1, 4)
+ dim = bbox.ndim
+ if dim == 1:
+ bbox = bbox[None, :]
+
+ x1, y1, x2, y2 = np.hsplit(bbox, [1, 2, 3])
+ center = np.hstack([x1 + x2, y1 + y2]) * 0.5
+ scale = np.hstack([x2 - x1, y2 - y1]) * padding
+
+ if dim == 1:
+ center = center[0]
+ scale = scale[0]
+
+ return center, scale
+
+
+def bbox_xywh2cs(bbox: np.ndarray,
+ padding: float = 1.) -> Tuple[np.ndarray, np.ndarray]:
+ """Transform the bbox format from (x,y,w,h) into (center, scale)
+
+ Args:
+ bbox (ndarray): Bounding box(es) in shape (4,) or (n, 4), formatted
+ as (x, y, w, h)
+ padding (float): BBox padding factor that will be multiplied to scale.
+ Default: 1.0
+
+ Returns:
+ tuple: A tuple containing center and scale.
+ - np.ndarray[float32]: Center (x, y) of the bbox in shape (2,) or
+ (n, 2)
+ - np.ndarray[float32]: Scale (w, h) of the bbox in shape (2,) or
+ (n, 2)
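+
+ Example:
+ >>> center, scale = bbox_xywh2cs(np.array([10., 20., 20., 40.]), padding=1.25)
+ >>> assert np.allclose(center, [20., 40.]) and np.allclose(scale, [25., 50.])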
+ """
+
+ # convert single bbox from (4, ) to (1, 4)
+ dim = bbox.ndim
+ if dim == 1:
+ bbox = bbox[None, :]
+
+ x, y, w, h = np.hsplit(bbox, [1, 2, 3])
+ center = np.hstack([x + w * 0.5, y + h * 0.5])
+ scale = np.hstack([w, h]) * padding
+
+ if dim == 1:
+ center = center[0]
+ scale = scale[0]
+
+ return center, scale
+
+
+def bbox_cs2xyxy(center: np.ndarray,
+ scale: np.ndarray,
+ padding: float = 1.) -> np.ndarray:
+ """Transform the bbox format from (center, scale) to (x,y,w,h).
+
+ Args:
+ center (ndarray): BBox center (x, y) in shape (2,) or (n, 2)
+ scale (ndarray): BBox scale (w, h) in shape (2,) or (n, 2)
+ padding (float): BBox padding factor that will be multilied to scale.
+ Default: 1.0
+
+ Returns:
+ ndarray[float32]: BBox (x, y, w, h) in shape (4, ) or (n, 4)
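+
+ Example:
+ >>> bbox = bbox_cs2xyxy(np.array([20., 40.]), np.array([25., 50.]), padding=1.25)
+ >>> assert np.allclose(bbox, [10., 20., 30., 60.])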
+ """
+
+ dim = center.ndim
+ assert scale.ndim == dim
+
+ if dim == 1:
+ center = center[None, :]
+ scale = scale[None, :]
+
+ wh = scale / padding
+ xy = center - 0.5 * wh
+ bbox = np.hstack((xy, xy + wh))
+
+ if dim == 1:
+ bbox = bbox[0]
+
+ return bbox
+
+
+def bbox_cs2xywh(center: np.ndarray,
+ scale: np.ndarray,
+ padding: float = 1.) -> np.ndarray:
+ """Transform the bbox format from (center, scale) to (x,y,w,h).
+
+ Args:
+ center (ndarray): BBox center (x, y) in shape (2,) or (n, 2)
+ scale (ndarray): BBox scale (w, h) in shape (2,) or (n, 2)
+ padding (float): BBox padding factor that will be multiplied to scale.
+ Default: 1.0
+
+ Returns:
+ ndarray[float32]: BBox (x, y, w, h) in shape (4, ) or (n, 4)
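+
+ Example:
+ >>> bbox = bbox_cs2xywh(np.array([20., 40.]), np.array([25., 50.]), padding=1.25)
+ >>> assert np.allclose(bbox, [10., 20., 20., 40.])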
+ """
+
+ dim = center.ndim
+ assert scale.ndim == dim
+
+ if dim == 1:
+ center = center[None, :]
+ scale = scale[None, :]
+
+ wh = scale / padding
+ xy = center - 0.5 * wh
+ bbox = np.hstack((xy, wh))
+
+ if dim == 1:
+ bbox = bbox[0]
+
+ return bbox
+
+
+def flip_bbox(bbox: np.ndarray,
+ image_size: Tuple[int, int],
+ bbox_format: str = 'xywh',
+ direction: str = 'horizontal') -> np.ndarray:
+ """Flip the bbox in the given direction.
+
+ Args:
+ bbox (np.ndarray): The bounding boxes. The shape should be (..., 4)
+ if ``bbox_format`` is ``'xyxy'`` or ``'xywh'``, and (..., 2) if
+ ``bbox_format`` is ``'center'``
+ image_size (tuple): The image shape in [w, h]
+ bbox_format (str): The bbox format. Options are ``'xywh'``, ``'xyxy'``
+ and ``'center'``.
+ direction (str): The flip direction. Options are ``'horizontal'``,
+ ``'vertical'`` and ``'diagonal'``. Defaults to ``'horizontal'``
+
+ Returns:
+ np.ndarray: The flipped bounding boxes.
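+
+ Example:
+ >>> # horizontally flip a bbox center inside a 100x80 image
+ >>> center = np.array([[30., 40.]])
+ >>> flipped = flip_bbox(center, image_size=(100, 80), bbox_format='center')
+ >>> assert np.allclose(flipped, [[69., 40.]])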
+ """
+ direction_options = {'horizontal', 'vertical', 'diagonal'}
+ assert direction in direction_options, (
+ f'Invalid flipping direction "{direction}". '
+ f'Options are {direction_options}')
+
+ format_options = {'xywh', 'xyxy', 'center'}
+ assert bbox_format in format_options, (
+ f'Invalid bbox format "{bbox_format}". '
+ f'Options are {format_options}')
+
+ bbox_flipped = bbox.copy()
+ w, h = image_size
+
+ # TODO: consider using "integer corner" coordinate system
+ if direction == 'horizontal':
+ if bbox_format == 'xywh' or bbox_format == 'center':
+ bbox_flipped[..., 0] = w - bbox[..., 0] - 1
+ elif bbox_format == 'xyxy':
+ bbox_flipped[..., ::2] = w - bbox[..., ::2] - 1
+ elif direction == 'vertical':
+ if bbox_format == 'xywh' or bbox_format == 'center':
+ bbox_flipped[..., 1] = h - bbox[..., 1] - 1
+ elif bbox_format == 'xyxy':
+ bbox_flipped[..., 1::2] = h - bbox[..., 1::2] - 1
+ elif direction == 'diagonal':
+ if bbox_format == 'xywh' or bbox_format == 'center':
+ bbox_flipped[..., :2] = [w, h] - bbox[..., :2] - 1
+ elif bbox_format == 'xyxy':
+ bbox_flipped[...] = [w, h, w, h] - bbox - 1
+
+ return bbox_flipped
+
+
+def get_udp_warp_matrix(
+ center: np.ndarray,
+ scale: np.ndarray,
+ rot: float,
+ output_size: Tuple[int, int],
+) -> np.ndarray:
+ """Calculate the affine transformation matrix under the unbiased
+ constraint. See `UDP (CVPR 2020)`_ for details.
+
+ Note:
+
+ - The bbox number: N
+
+ Args:
+ center (np.ndarray[2, ]): Center of the bounding box (x, y).
+ scale (np.ndarray[2, ]): Scale of the bounding box
+ wrt [width, height].
+ rot (float): Rotation angle (degree).
+ output_size (tuple): Size ([w, h]) of the output image
+
+ Returns:
+ np.ndarray: A 2x3 transformation matrix
+
+ .. _`UDP (CVPR 2020)`: https://arxiv.org/abs/1911.07524
+ """
+ assert len(center) == 2
+ assert len(scale) == 2
+ assert len(output_size) == 2
+
+ input_size = center * 2
+ rot_rad = np.deg2rad(rot)
+ warp_mat = np.zeros((2, 3), dtype=np.float32)
+ scale_x = (output_size[0] - 1) / scale[0]
+ scale_y = (output_size[1] - 1) / scale[1]
+ warp_mat[0, 0] = math.cos(rot_rad) * scale_x
+ warp_mat[0, 1] = -math.sin(rot_rad) * scale_x
+ warp_mat[0, 2] = scale_x * (-0.5 * input_size[0] * math.cos(rot_rad) +
+ 0.5 * input_size[1] * math.sin(rot_rad) +
+ 0.5 * scale[0])
+ warp_mat[1, 0] = math.sin(rot_rad) * scale_y
+ warp_mat[1, 1] = math.cos(rot_rad) * scale_y
+ warp_mat[1, 2] = scale_y * (-0.5 * input_size[0] * math.sin(rot_rad) -
+ 0.5 * input_size[1] * math.cos(rot_rad) +
+ 0.5 * scale[1])
+ return warp_mat
+
+
+def get_warp_matrix(center: np.ndarray,
+ scale: np.ndarray,
+ rot: float,
+ output_size: Tuple[int, int],
+ shift: Tuple[float, float] = (0., 0.),
+ inv: bool = False) -> np.ndarray:
+ """Calculate the affine transformation matrix that can warp the bbox area
+ in the input image to the output size.
+
+ Args:
+ center (np.ndarray[2, ]): Center of the bounding box (x, y).
+ scale (np.ndarray[2, ]): Scale of the bounding box
+ wrt [width, height].
+ rot (float): Rotation angle (degree).
+ output_size (np.ndarray[2, ] | list(2,)): Size of the
+ destination heatmaps.
+ shift (tuple): Shift translation ratio wrt the width/height,
+ in range [0, 1]. Default: (0., 0.).
+ inv (bool): Option to inverse the affine transform direction.
+ (inv=False: src->dst or inv=True: dst->src)
+
+ Returns:
+ np.ndarray: A 2x3 transformation matrix
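+
+ Example:
+ >>> # warp a 200x200 region around (100, 100) to a 192x256 output patch
+ >>> warp_mat = get_warp_matrix(np.array([100., 100.]),
+ >>> np.array([200., 200.]),
+ >>> rot=0., output_size=(192, 256))
+ >>> assert warp_mat.shape == (2, 3)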
+ """
+ assert len(center) == 2
+ assert len(scale) == 2
+ assert len(output_size) == 2
+ assert len(shift) == 2
+
+ shift = np.array(shift)
+ src_w = scale[0]
+ dst_w = output_size[0]
+ dst_h = output_size[1]
+
+ rot_rad = np.deg2rad(rot)
+ src_dir = _rotate_point(np.array([0., src_w * -0.5]), rot_rad)
+ dst_dir = np.array([0., dst_w * -0.5])
+
+ src = np.zeros((3, 2), dtype=np.float32)
+ src[0, :] = center + scale * shift
+ src[1, :] = center + src_dir + scale * shift
+ src[2, :] = _get_3rd_point(src[0, :], src[1, :])
+
+ dst = np.zeros((3, 2), dtype=np.float32)
+ dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
+ dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
+ dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])
+
+ if inv:
+ warp_mat = cv2.getAffineTransform(np.float32(dst), np.float32(src))
+ else:
+ warp_mat = cv2.getAffineTransform(np.float32(src), np.float32(dst))
+ return warp_mat
+
+
+def _rotate_point(pt: np.ndarray, angle_rad: float) -> np.ndarray:
+ """Rotate a point by an angle.
+
+ Args:
+ pt (np.ndarray): 2D point coordinates (x, y) in shape (2, )
+ angle_rad (float): rotation angle in radian
+
+ Returns:
+ np.ndarray: Rotated point in shape (2, )
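+
+ Example:
+ >>> pt = _rotate_point(np.array([1., 0.]), np.pi / 2)
+ >>> assert np.allclose(pt, [0., 1.])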
+ """
+
+ sn, cs = np.sin(angle_rad), np.cos(angle_rad)
+ rot_mat = np.array([[cs, -sn], [sn, cs]])
+ return rot_mat @ pt
+
+
+def _get_3rd_point(a: np.ndarray, b: np.ndarray):
+ """To calculate the affine matrix, three pairs of points are required. This
+ function is used to get the 3rd point, given 2D points a & b.
+
+ The 3rd point is defined by rotating vector `a - b` by 90 degrees
+ anticlockwise, using b as the rotation center.
+
+ Args:
+ a (np.ndarray): The 1st point (x,y) in shape (2, )
+ b (np.ndarray): The 2nd point (x,y) in shape (2, )
+
+ Returns:
+ np.ndarray: The 3rd point.
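+
+ Example:
+ >>> c = _get_3rd_point(np.array([1., 0.]), np.array([0., 0.]))
+ >>> assert np.allclose(c, [0., 1.])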
+ """
+ direction = a - b
+ c = b + np.r_[-direction[1], direction[0]]
+ return c
diff --git a/mmpose/structures/keypoint/__init__.py b/mmpose/structures/keypoint/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8d5a24c7a8f8f08b4ef8f37b8628249e3734ba7
--- /dev/null
+++ b/mmpose/structures/keypoint/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+from .transforms import flip_keypoints
+
+__all__ = ['flip_keypoints']
diff --git a/mmpose/structures/keypoint/__pycache__/__init__.cpython-38.pyc b/mmpose/structures/keypoint/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f346107454a95a3b9c660968a43b9b88beb2ae8a
Binary files /dev/null and b/mmpose/structures/keypoint/__pycache__/__init__.cpython-38.pyc differ
diff --git a/mmpose/structures/keypoint/__pycache__/transforms.cpython-38.pyc b/mmpose/structures/keypoint/__pycache__/transforms.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..22560b454c20a0244b24d8e9c500b33b0b9a7eeb
Binary files /dev/null and b/mmpose/structures/keypoint/__pycache__/transforms.cpython-38.pyc differ
diff --git a/mmpose/structures/keypoint/transforms.py b/mmpose/structures/keypoint/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..99adaa130619e2e6040d5cc2d055ea09e99bec23
--- /dev/null
+++ b/mmpose/structures/keypoint/transforms.py
@@ -0,0 +1,64 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional, Tuple
+
+import numpy as np
+
+
+def flip_keypoints(keypoints: np.ndarray,
+ keypoints_visible: Optional[np.ndarray],
+ image_size: Tuple[int, int],
+ flip_indices: List[int],
+ direction: str = 'horizontal'
+ ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
+ """Flip keypoints in the given direction.
+
+ Note:
+
+ - keypoint number: K
+ - keypoint dimension: D
+
+ Args:
+ keypoints (np.ndarray): Keypoints in shape (..., K, D)
+ keypoints_visible (np.ndarray, optional): The visibility of keypoints
+ in shape (..., K, 1). Set ``None`` if the keypoint visibility is
+ unavailable
+ image_size (tuple): The image shape in [w, h]
+ flip_indices (List[int]): The indices of each keypoint's symmetric
+ keypoint
+ direction (str): The flip direction. Options are ``'horizontal'``,
+ ``'vertical'`` and ``'diagonal'``. Defaults to ``'horizontal'``
+
+ Returns:
+ tuple:
+ - keypoints_flipped (np.ndarray): Flipped keypoints in shape
+ (..., K, D)
+ - keypoints_visible_flipped (np.ndarray, optional): Flipped keypoints'
+ visibility in shape (..., K, 1). Return ``None`` if the input
+ ``keypoints_visible`` is ``None``
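+
+ Example:
+ >>> # horizontally flip 3 keypoints in a 100x80 image, where keypoints
+ >>> # 1 and 2 are a symmetric pair
+ >>> kpts = np.array([[[10., 20.], [30., 40.], [50., 60.]]])
+ >>> vis = np.ones((1, 3), dtype=np.float32)
+ >>> kpts, vis = flip_keypoints(kpts, vis, image_size=(100, 80),
+ >>> flip_indices=[0, 2, 1])
+ >>> assert np.allclose(kpts, [[[89., 20.], [49., 60.], [69., 40.]]])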
+ """
+
+ assert keypoints.shape[:-1] == keypoints_visible.shape, (
+ f'Mismatched shapes of keypoints {keypoints.shape} and '
+ f'keypoints_visible {keypoints_visible.shape}')
+
+ direction_options = {'horizontal', 'vertical', 'diagonal'}
+ assert direction in direction_options, (
+ f'Invalid flipping direction "{direction}". '
+ f'Options are {direction_options}')
+
+ # swap the symmetric keypoint pairs
+ if direction == 'horizontal' or direction == 'vertical':
+ keypoints = keypoints[..., flip_indices, :]
+ if keypoints_visible is not None:
+ keypoints_visible = keypoints_visible[..., flip_indices]
+
+ # flip the keypoints
+ w, h = image_size
+ if direction == 'horizontal':
+ keypoints[..., 0] = w - 1 - keypoints[..., 0]
+ elif direction == 'vertical':
+ keypoints[..., 1] = h - 1 - keypoints[..., 1]
+ else:
+ keypoints = [w, h] - keypoints - 1
+
+ return keypoints, keypoints_visible
diff --git a/mmpose/structures/multilevel_pixel_data.py b/mmpose/structures/multilevel_pixel_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..bea191e7297c233cc129f2da09ab5a4c6793fa0f
--- /dev/null
+++ b/mmpose/structures/multilevel_pixel_data.py
@@ -0,0 +1,273 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from collections import abc
+from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, Union
+
+import numpy as np
+import torch
+from mmengine.structures import BaseDataElement, PixelData
+from mmengine.utils import is_list_of
+
+IndexType = Union[str, slice, int, list, torch.LongTensor,
+ torch.cuda.LongTensor, torch.BoolTensor,
+ torch.cuda.BoolTensor, np.ndarray]
+
+
+class MultilevelPixelData(BaseDataElement):
+ """Data structure for multi-level pixel-wise annotations or predictions.
+
+ All data items in ``data_fields`` of ``MultilevelPixelData`` are lists
+ of np.ndarray or torch.Tensor, and should meet the following requirements:
+
+ - Have the same length, which is the number of levels
+ - At each level, the data should have 3 dimensions in order of channel,
+ height and width
+ - At each level, the data should have the same height and width
+
+ Examples:
+ >>> metainfo = dict(num_keypoints=17)
+ >>> sizes = [(64, 48), (128, 96), (256, 192)]
+ >>> heatmaps = [np.random.rand(17, h, w) for h, w in sizes]
+ >>> masks = [torch.rand(1, h, w) for h, w in sizes]
+ >>> data = MultilevelPixelData(metainfo=metainfo,
+ ... heatmaps=heatmaps,
+ ... masks=masks)
+
+ >>> # get data item
+ >>> heatmaps = data.heatmaps # A list of 3 numpy.ndarrays
+ >>> masks = data.masks # A list of 3 torch.Tensors
+
+ >>> # get level
+ >>> data_l0 = data[0] # PixelData with fields 'heatmaps' and 'masks'
+ >>> data.nlevel
+ 3
+
+ >>> # get shape
+ >>> data.shape
+ ((64, 48), (128, 96), (256, 192))
+
+ >>> # set
+ >>> offset_maps = [torch.rand(2, h, w) for h, w in sizes]
+ >>> data.offset_maps = offset_maps
+ """
+
+ def __init__(self, *, metainfo: Optional[dict] = None, **kwargs) -> None:
+ object.__setattr__(self, '_nlevel', None)
+ super().__init__(metainfo=metainfo, **kwargs)
+
+ @property
+ def nlevel(self):
+ """Return the level number.
+
+ Returns:
+ Optional[int]: The level number, or ``None`` if the data has not
+ been assigned.
+ """
+ return self._nlevel
+
+ def __getitem__(self, item: Union[int, str, list,
+ slice]) -> Union[PixelData, Sequence]:
+ if isinstance(item, int):
+ if self.nlevel is None or item >= self.nlevel:
+ raise IndexError(
+ f'Level index {item} out of range ({self.nlevel})')
+ return self.get(f'_level_{item}')
+
+ if isinstance(item, str):
+ if item not in self:
+ raise KeyError(item)
+ return getattr(self, item)
+
+ # TODO: support indexing by list and slice over levels
+ raise NotImplementedError(
+ f'{self.__class__.__name__} does not support index type '
+ f'{type(item)}')
+
+ def levels(self) -> List[PixelData]:
+ if self.nlevel:
+ return list(self[i] for i in range(self.nlevel))
+ return []
+
+ @property
+ def shape(self) -> Optional[Tuple[Tuple]]:
+ """Get the shape of multi-level pixel data.
+
+ Returns:
+ Optional[tuple]: A tuple of data shape at each level, or ``None``
+ if the data has not been assigned.
+ """
+ if self.nlevel is None:
+ return None
+
+ return tuple(level.shape for level in self.levels())
+
+ def set_data(self, data: dict) -> None:
+ """Set or change key-value pairs in ``data_field`` by parameter
+ ``data``.
+
+ Args:
+ data (dict): A dict that contains annotations of an image or
+ model predictions.
+ """
+ assert isinstance(data,
+ dict), f'data should be a `dict` but got {data}'
+ for k, v in data.items():
+ self.set_field(v, k, field_type='data')
+
+ def set_field(self,
+ value: Any,
+ name: str,
+ dtype: Optional[Union[Type, Tuple[Type, ...]]] = None,
+ field_type: str = 'data') -> None:
+ """Special method for set union field, used as property.setter
+ functions."""
+ assert field_type in ['metainfo', 'data']
+ if dtype is not None:
+ assert isinstance(
+ value,
+ dtype), f'{value} should be a {dtype} but got {type(value)}'
+
+ if name.startswith('_level_'):
+ raise AttributeError(
+ f'Cannot set {name} to be a field because the pattern '
+ '<_level_{n}> is reserved for inner data field')
+
+ if field_type == 'metainfo':
+ if name in self._data_fields:
+ raise AttributeError(
+ f'Cannot set {name} to be a field of metainfo '
+ f'because {name} is already a data field')
+ self._metainfo_fields.add(name)
+
+ else:
+ if name in self._metainfo_fields:
+ raise AttributeError(
+ f'Cannot set {name} to be a field of data '
+ f'because {name} is already a metainfo field')
+
+ if not isinstance(value, abc.Sequence):
+ raise TypeError(
+ 'The value should be a sequence (of numpy.ndarray or '
+ f'torch.Tensor), but got a {type(value)}')
+
+ if len(value) == 0:
+ raise ValueError('Setting empty value is not allowed')
+
+ if not isinstance(value[0], (torch.Tensor, np.ndarray)):
+ raise TypeError(
+ 'The value should be a sequence of numpy.ndarray or '
+ f'torch.Tensor, but got a sequence of {type(value[0])}')
+
+ if self.nlevel is not None:
+ assert len(value) == self.nlevel, (
+ f'The length of the value ({len(value)}) should match the '
+ f'number of the levels ({self.nlevel})')
+ else:
+ object.__setattr__(self, '_nlevel', len(value))
+ for i in range(self.nlevel):
+ object.__setattr__(self, f'_level_{i}', PixelData())
+
+ for i, v in enumerate(value):
+ self[i].set_field(v, name, field_type='data')
+
+ self._data_fields.add(name)
+
+ object.__setattr__(self, name, value)
+
+ def __delattr__(self, item: str):
+ """delete the item in dataelement.
+
+ Args:
+ item (str): The key to delete.
+ """
+ if item in ('_metainfo_fields', '_data_fields'):
+ raise AttributeError(f'{item} has been used as a '
+ 'private attribute, which is immutable. ')
+
+ if item in self._metainfo_fields:
+ super().__delattr__(item)
+ else:
+ for level in self.levels():
+ level.__delattr__(item)
+ self._data_fields.remove(item)
+
+ def __getattr__(self, name):
+ if name in {'_data_fields', '_metainfo_fields'
+ } or name not in self._data_fields:
+ raise AttributeError(
+ f'\'{self.__class__.__name__}\' object has no attribute '
+ f'\'{name}\'')
+
+ return [getattr(level, name) for level in self.levels()]
+
+ def pop(self, *args) -> Any:
+ """pop property in data and metainfo as the same as python."""
+ assert len(args) < 3, '``pop`` get more than 2 arguments'
+ name = args[0]
+ if name in self._metainfo_fields:
+ self._metainfo_fields.remove(name)
+ return self.__dict__.pop(*args)
+
+ elif name in self._data_fields:
+ self._data_fields.remove(name)
+ return [level.pop(*args) for level in self.levels()]
+
+ # with default value
+ elif len(args) == 2:
+ return args[1]
+ else:
+ # don't just use 'self.__dict__.pop(*args)' for only popping key in
+ # metainfo or data
+ raise KeyError(f'{args[0]} is not contained in metainfo or data')
+
+ def _convert(self, apply_to: Type,
+ func: Callable[[Any], Any]) -> 'MultilevelPixelData':
+ """Convert data items with the given function.
+
+ Args:
+ apply_to (Type): The type of data items to apply the conversion
+ func (Callable): The conversion function that takes a data item
+ as the input and return the converted result
+
+ Returns:
+ MultilevelPixelData: the converted data element.
+ """
+ new_data = self.new()
+ for k, v in self.items():
+ if is_list_of(v, apply_to):
+ v = [func(_v) for _v in v]
+ data = {k: v}
+ new_data.set_data(data)
+ return new_data
+
+ def cpu(self) -> 'MultilevelPixelData':
+ """Convert all tensors to CPU in data."""
+ return self._convert(apply_to=torch.Tensor, func=lambda x: x.cpu())
+
+ def cuda(self) -> 'MultilevelPixelData':
+ """Convert all tensors to GPU in data."""
+ return self._convert(apply_to=torch.Tensor, func=lambda x: x.cuda())
+
+ def detach(self) -> 'MultilevelPixelData':
+ """Detach all tensors in data."""
+ return self._convert(apply_to=torch.Tensor, func=lambda x: x.detach())
+
+ def numpy(self) -> 'MultilevelPixelData':
+ """Convert all tensor to np.narray in data."""
+ return self._convert(
+ apply_to=torch.Tensor, func=lambda x: x.detach().cpu().numpy())
+
+ def to_tensor(self) -> 'MultilevelPixelData':
+ """Convert all tensor to np.narray in data."""
+ return self._convert(
+ apply_to=np.ndarray, func=lambda x: torch.from_numpy(x))
+
+ # Tensor-like methods
+ def to(self, *args, **kwargs) -> 'MultilevelPixelData':
+ """Apply same name function to all tensors in data_fields."""
+ new_data = self.new()
+ for k, v in self.items():
+ if hasattr(v[0], 'to'):
+ v = [v_.to(*args, **kwargs) for v_ in v]
+ data = {k: v}
+ new_data.set_data(data)
+ return new_data
diff --git a/mmpose/structures/pose_data_sample.py b/mmpose/structures/pose_data_sample.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c1d69034e319c6c43870f56e06661c536adf5e2
--- /dev/null
+++ b/mmpose/structures/pose_data_sample.py
@@ -0,0 +1,104 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Union
+
+from mmengine.structures import BaseDataElement, InstanceData, PixelData
+
+from mmpose.structures import MultilevelPixelData
+
+
+class PoseDataSample(BaseDataElement):
+ """The base data structure of MMPose that is used as the interface between
+ modules.
+
+ The attributes of ``PoseDataSample`` includes:
+
+ - ``gt_instances``(InstanceData): Ground truth of instances with
+ keypoint annotations
+ - ``pred_instances``(InstanceData): Instances with keypoint
+ predictions
+ - ``gt_fields``(PixelData): Ground truth of spatial distribution
+ annotations like keypoint heatmaps and part affinity fields (PAF)
+ - ``pred_fields``(PixelData): Predictions of spatial distributions
+
+ Examples:
+ >>> import torch
+ >>> from mmengine.structures import InstanceData, PixelData
+ >>> from mmpose.structures import PoseDataSample
+
+ >>> pose_meta = dict(img_shape=(800, 1216),
+ ... crop_size=(256, 192),
+ ... heatmap_size=(64, 48))
+ >>> gt_instances = InstanceData()
+ >>> gt_instances.bboxes = torch.rand((1, 4))
+ >>> gt_instances.keypoints = torch.rand((1, 17, 2))
+ >>> gt_instances.keypoints_visible = torch.rand((1, 17, 1))
+ >>> gt_fields = PixelData()
+ >>> gt_fields.heatmaps = torch.rand((17, 64, 48))
+
+ >>> data_sample = PoseDataSample(gt_instances=gt_instances,
+ ... gt_fields=gt_fields,
+ ... metainfo=pose_meta)
+ >>> assert 'img_shape' in data_sample
+ >>> len(data_sample.gt_instances)
+ 1
+ """
+
+ @property
+ def gt_instances(self) -> InstanceData:
+ return self._gt_instances
+
+ @gt_instances.setter
+ def gt_instances(self, value: InstanceData):
+ self.set_field(value, '_gt_instances', dtype=InstanceData)
+
+ @gt_instances.deleter
+ def gt_instances(self):
+ del self._gt_instances
+
+ @property
+ def gt_instance_labels(self) -> InstanceData:
+ return self._gt_instance_labels
+
+ @gt_instance_labels.setter
+ def gt_instance_labels(self, value: InstanceData):
+ self.set_field(value, '_gt_instance_labels', dtype=InstanceData)
+
+ @gt_instance_labels.deleter
+ def gt_instance_labels(self):
+ del self._gt_instance_labels
+
+ @property
+ def pred_instances(self) -> InstanceData:
+ return self._pred_instances
+
+ @pred_instances.setter
+ def pred_instances(self, value: InstanceData):
+ self.set_field(value, '_pred_instances', dtype=InstanceData)
+
+ @pred_instances.deleter
+ def pred_instances(self):
+ del self._pred_instances
+
+ @property
+ def gt_fields(self) -> Union[PixelData, MultilevelPixelData]:
+ return self._gt_fields
+
+ @gt_fields.setter
+ def gt_fields(self, value: Union[PixelData, MultilevelPixelData]):
+ self.set_field(value, '_gt_fields', dtype=type(value))
+
+ @gt_fields.deleter
+ def gt_fields(self):
+ del self._gt_fields
+
+ @property
+ def pred_fields(self) -> PixelData:
+ return self._pred_heatmaps
+
+ @pred_fields.setter
+ def pred_fields(self, value: PixelData):
+ self.set_field(value, '_pred_heatmaps', dtype=PixelData)
+
+ @pred_fields.deleter
+ def pred_fields(self):
+ del self._pred_heatmaps
diff --git a/mmpose/structures/utils.py b/mmpose/structures/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..882cda86037bd2c5d68b938df23dfe91b4007957
--- /dev/null
+++ b/mmpose/structures/utils.py
@@ -0,0 +1,138 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+from typing import List
+
+import cv2
+import numpy as np
+import torch
+from mmengine.structures import InstanceData, PixelData
+from mmengine.utils import is_list_of
+
+from .bbox.transforms import get_warp_matrix
+from .pose_data_sample import PoseDataSample
+
+
+def merge_data_samples(data_samples: List[PoseDataSample]) -> PoseDataSample:
+ """Merge the given data samples into a single data sample.
+
+ This function can be used to merge the top-down predictions with
+ bboxes from the same image. The merged data sample will contain all
+ instances from the input data samples, and its metainfo will be taken
+ from the first input data sample.
+
+ Args:
+ data_samples (List[:obj:`PoseDataSample`]): The data samples to
+ merge
+
+ Returns:
+ PoseDataSample: The merged data sample.
+ """
+
+ if not is_list_of(data_samples, PoseDataSample):
+ raise ValueError('Invalid input type, should be a list of '
+ ':obj:`PoseDataSample`')
+
+ if len(data_samples) == 0:
+ warnings.warn('Try to merge an empty list of data samples.')
+ return PoseDataSample()
+
+ merged = PoseDataSample(metainfo=data_samples[0].metainfo)
+
+ if 'gt_instances' in data_samples[0]:
+ merged.gt_instances = InstanceData.cat(
+ [d.gt_instances for d in data_samples])
+
+ if 'pred_instances' in data_samples[0]:
+ merged.pred_instances = InstanceData.cat(
+ [d.pred_instances for d in data_samples])
+
+ if 'pred_fields' in data_samples[0] and 'heatmaps' in data_samples[
+ 0].pred_fields:
+ reverted_heatmaps = [
+ revert_heatmap(data_sample.pred_fields.heatmaps,
+ data_sample.gt_instances.bbox_centers,
+ data_sample.gt_instances.bbox_scales,
+ data_sample.ori_shape)
+ for data_sample in data_samples
+ ]
+
+ merged_heatmaps = np.max(reverted_heatmaps, axis=0)
+ pred_fields = PixelData()
+ pred_fields.set_data(dict(heatmaps=merged_heatmaps))
+ merged.pred_fields = pred_fields
+
+ if 'gt_fields' in data_samples[0] and 'heatmaps' in data_samples[
+ 0].gt_fields:
+ reverted_heatmaps = [
+ revert_heatmap(data_sample.gt_fields.heatmaps,
+ data_sample.gt_instances.bbox_centers,
+ data_sample.gt_instances.bbox_scales,
+ data_sample.ori_shape)
+ for data_sample in data_samples
+ ]
+
+ merged_heatmaps = np.max(reverted_heatmaps, axis=0)
+ gt_fields = PixelData()
+ gt_fields.set_data(dict(heatmaps=merged_heatmaps))
+ merged.gt_fields = gt_fields
+
+ return merged
+
+
+def revert_heatmap(heatmap, bbox_center, bbox_scale, img_shape):
+ """Revert predicted heatmap on the original image.
+
+ Args:
+ heatmap (np.ndarray or torch.tensor): predicted heatmap.
+ bbox_center (np.ndarray): bounding box center coordinate.
+ bbox_scale (np.ndarray): bounding box scale.
+ img_shape (tuple or list): size of original image.
+ """
+ if torch.is_tensor(heatmap):
+ heatmap = heatmap.cpu().detach().numpy()
+
+ ndim = heatmap.ndim
+ # [K, H, W] -> [H, W, K]
+ if ndim == 3:
+ heatmap = heatmap.transpose(1, 2, 0)
+
+ hm_h, hm_w = heatmap.shape[:2]
+ img_h, img_w = img_shape
+ warp_mat = get_warp_matrix(
+ bbox_center.reshape((2, )),
+ bbox_scale.reshape((2, )),
+ rot=0,
+ output_size=(hm_w, hm_h),
+ inv=True)
+
+ heatmap = cv2.warpAffine(
+ heatmap, warp_mat, (img_w, img_h), flags=cv2.INTER_LINEAR)
+
+ # [H, W, K] -> [K, H, W]
+ if ndim == 3:
+ heatmap = heatmap.transpose(2, 0, 1)
+
+ return heatmap
+
+
+def split_instances(instances: InstanceData) -> List[dict]:
+ """Convert instances into a list where each element is a dict that contains
+ information about one instance."""
+ results = []
+
+ # return an empty list if there is no instance detected by the model
+ if instances is None:
+ return results
+
+ for i in range(len(instances.keypoints)):
+ result = dict(
+ keypoints=instances.keypoints[i].tolist(),
+ keypoint_scores=instances.keypoint_scores[i].tolist(),
+ )
+ if 'bboxes' in instances:
+ result['bbox'] = instances.bboxes[i].tolist()
+ if 'bbox_scores' in instances:
+ result['bbox_score'] = instances.bbox_scores[i]
+ results.append(result)
+
+ return results
diff --git a/mmpose/testing/__init__.py b/mmpose/testing/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5612dac6c66e3bf7c2bad86154ac62c9d5e9529a
--- /dev/null
+++ b/mmpose/testing/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ._utils import (get_coco_sample, get_config_file, get_packed_inputs,
+ get_pose_estimator_cfg, get_repo_dir)
+
+__all__ = [
+ 'get_packed_inputs', 'get_coco_sample', 'get_config_file',
+ 'get_pose_estimator_cfg', 'get_repo_dir'
+]
diff --git a/mmpose/testing/_utils.py b/mmpose/testing/_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1908129be858c5ff7c5b5b9a6e14e0f03858e53d
--- /dev/null
+++ b/mmpose/testing/_utils.py
@@ -0,0 +1,248 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+from copy import deepcopy
+from typing import Optional
+
+import numpy as np
+import torch
+from mmengine.config import Config
+from mmengine.dataset import pseudo_collate
+from mmengine.structures import InstanceData, PixelData
+
+from mmpose.structures import MultilevelPixelData, PoseDataSample
+from mmpose.structures.bbox import bbox_xyxy2cs
+
+
+def get_coco_sample(
+ img_shape=(240, 320),
+ img_fill: Optional[int] = None,
+ num_instances=1,
+ with_bbox_cs=True,
+ with_img_mask=False,
+ random_keypoints_visible=False,
+ non_occlusion=False):
+ """Create a dummy data sample in COCO style."""
+ rng = np.random.RandomState(0)
+ h, w = img_shape
+ if img_fill is None:
+ img = np.random.randint(0, 256, (h, w, 3), dtype=np.uint8)
+ else:
+ img = np.full((h, w, 3), img_fill, dtype=np.uint8)
+
+ if non_occlusion:
+ bbox = _rand_bboxes(rng, num_instances, w / num_instances, h)
+ for i in range(num_instances):
+ bbox[i, 0::2] += w / num_instances * i
+ else:
+ bbox = _rand_bboxes(rng, num_instances, w, h)
+
+ keypoints = _rand_keypoints(rng, bbox, 17)
+ if random_keypoints_visible:
+ keypoints_visible = np.random.randint(0, 2, (num_instances,
+ 17)).astype(np.float32)
+ else:
+ keypoints_visible = np.full((num_instances, 17), 1, dtype=np.float32)
+
+ upper_body_ids = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+ lower_body_ids = [11, 12, 13, 14, 15, 16]
+ flip_pairs = [[2, 1], [1, 2], [4, 3], [3, 4], [6, 5], [5, 6], [8, 7],
+ [7, 8], [10, 9], [9, 10], [12, 11], [11, 12], [14, 13],
+ [13, 14], [16, 15], [15, 16]]
+ flip_indices = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]
+ dataset_keypoint_weights = np.array([
+ 1., 1., 1., 1., 1., 1., 1., 1.2, 1.2, 1.5, 1.5, 1., 1., 1.2, 1.2, 1.5,
+ 1.5
+ ]).astype(np.float32)
+
+ data = {
+ 'img': img,
+ 'img_shape': img_shape,
+ 'ori_shape': img_shape,
+ 'bbox': bbox,
+ 'keypoints': keypoints,
+ 'keypoints_visible': keypoints_visible,
+ 'upper_body_ids': upper_body_ids,
+ 'lower_body_ids': lower_body_ids,
+ 'flip_pairs': flip_pairs,
+ 'flip_indices': flip_indices,
+ 'dataset_keypoint_weights': dataset_keypoint_weights,
+ 'invalid_segs': [],
+ }
+
+ if with_bbox_cs:
+ data['bbox_center'], data['bbox_scale'] = bbox_xyxy2cs(data['bbox'])
+
+ if with_img_mask:
+ data['img_mask'] = np.random.randint(0, 2, (h, w), dtype=np.uint8)
+
+ return data
+
+
+def get_packed_inputs(batch_size=2,
+ num_instances=1,
+ num_keypoints=17,
+ num_levels=1,
+ img_shape=(256, 192),
+ input_size=(192, 256),
+ heatmap_size=(48, 64),
+ simcc_split_ratio=2.0,
+ with_heatmap=True,
+ with_reg_label=True,
+ with_simcc_label=True):
+ """Create a dummy batch of model inputs and data samples."""
+ rng = np.random.RandomState(0)
+
+ inputs_list = []
+ for idx in range(batch_size):
+ inputs = dict()
+
+ # input
+ h, w = img_shape
+ image = rng.randint(0, 255, size=(3, h, w), dtype=np.uint8)
+ inputs['inputs'] = torch.from_numpy(image)
+
+ # meta
+ img_meta = {
+ 'id': idx,
+ 'img_id': idx,
+ 'img_path': '.png',
+ 'img_shape': img_shape,
+ 'input_size': input_size,
+ 'flip': False,
+ 'flip_direction': None,
+ 'flip_indices': list(range(num_keypoints))
+ }
+
+ np.random.shuffle(img_meta['flip_indices'])
+ data_sample = PoseDataSample(metainfo=img_meta)
+
+ # gt_instance
+ gt_instances = InstanceData()
+ gt_instance_labels = InstanceData()
+ bboxes = _rand_bboxes(rng, num_instances, w, h)
+ bbox_centers, bbox_scales = bbox_xyxy2cs(bboxes)
+
+ keypoints = _rand_keypoints(rng, bboxes, num_keypoints)
+ keypoints_visible = np.ones((num_instances, num_keypoints),
+ dtype=np.float32)
+
+ # [N, K] -> [N, num_levels, K]
+ # keep the first dimension as the num_instances
+ if num_levels > 1:
+ keypoint_weights = np.tile(keypoints_visible[:, None],
+ (1, num_levels, 1))
+ else:
+ keypoint_weights = keypoints_visible.copy()
+
+ gt_instances.bboxes = bboxes
+ gt_instances.bbox_centers = bbox_centers
+ gt_instances.bbox_scales = bbox_scales
+ gt_instances.bbox_scores = np.ones((num_instances, ), dtype=np.float32)
+ gt_instances.keypoints = keypoints
+ gt_instances.keypoints_visible = keypoints_visible
+
+ gt_instance_labels.keypoint_weights = torch.FloatTensor(
+ keypoint_weights)
+
+ if with_reg_label:
+ gt_instance_labels.keypoint_labels = torch.FloatTensor(keypoints /
+ input_size)
+
+ if with_simcc_label:
+ len_x = np.around(input_size[0] * simcc_split_ratio)
+ len_y = np.around(input_size[1] * simcc_split_ratio)
+ gt_instance_labels.keypoint_x_labels = torch.FloatTensor(
+ _rand_simcc_label(rng, num_instances, num_keypoints, len_x))
+ gt_instance_labels.keypoint_y_labels = torch.FloatTensor(
+ _rand_simcc_label(rng, num_instances, num_keypoints, len_y))
+
+ # gt_fields
+ if with_heatmap:
+ if num_levels == 1:
+ gt_fields = PixelData()
+ # generate single-level heatmaps
+ W, H = heatmap_size
+ heatmaps = rng.rand(num_keypoints, H, W)
+ gt_fields.heatmaps = torch.FloatTensor(heatmaps)
+ else:
+ # generate multilevel heatmaps
+ heatmaps = []
+ for _ in range(num_levels):
+ W, H = heatmap_size
+ heatmaps_ = rng.rand(num_keypoints, H, W)
+ heatmaps.append(torch.FloatTensor(heatmaps_))
+ # [num_levels*K, H, W]
+ gt_fields = MultilevelPixelData()
+ gt_fields.heatmaps = heatmaps
+ data_sample.gt_fields = gt_fields
+
+ data_sample.gt_instances = gt_instances
+ data_sample.gt_instance_labels = gt_instance_labels
+
+ inputs['data_samples'] = data_sample
+ inputs_list.append(inputs)
+
+ packed_inputs = pseudo_collate(inputs_list)
+ return packed_inputs
+
+
+def _rand_keypoints(rng, bboxes, num_keypoints):
+ n = bboxes.shape[0]
+ relative_pos = rng.rand(n, num_keypoints, 2)
+ keypoints = relative_pos * bboxes[:, None, :2] + (
+ 1 - relative_pos) * bboxes[:, None, 2:4]
+
+ return keypoints
+
+
+def _rand_simcc_label(rng, num_instances, num_keypoints, len_feats):
+ simcc_label = rng.rand(num_instances, num_keypoints, int(len_feats))
+ return simcc_label
+
+
+def _rand_bboxes(rng, num_instances, img_w, img_h):
+ cx, cy = rng.rand(num_instances, 2).T
+ bw, bh = 0.2 + 0.8 * rng.rand(num_instances, 2).T
+
+ tl_x = ((cx * img_w) - (img_w * bw / 2)).clip(0, img_w)
+ tl_y = ((cy * img_h) - (img_h * bh / 2)).clip(0, img_h)
+ br_x = ((cx * img_w) + (img_w * bw / 2)).clip(0, img_w)
+ br_y = ((cy * img_h) + (img_h * bh / 2)).clip(0, img_h)
+
+ bboxes = np.vstack([tl_x, tl_y, br_x, br_y]).T
+ return bboxes
+
+
+def get_repo_dir():
+ """Return the path of the MMPose repo directory."""
+ try:
+        # Assume the function is invoked from the source mmpose repo
+ repo_dir = osp.dirname(osp.dirname(osp.dirname(__file__)))
+ except NameError:
+ # For IPython development when __file__ is not defined
+ import mmpose
+ repo_dir = osp.dirname(osp.dirname(mmpose.__file__))
+
+ return repo_dir
+
+
+def get_config_file(fn: str):
+ """Return full path of a config file from the given relative path."""
+ repo_dir = get_repo_dir()
+ if fn.startswith('configs'):
+ fn_config = osp.join(repo_dir, fn)
+ else:
+ fn_config = osp.join(repo_dir, 'configs', fn)
+
+ if not osp.isfile(fn_config):
+ raise FileNotFoundError(f'Cannot find config file {fn_config}')
+
+ return fn_config
+
+
+def get_pose_estimator_cfg(fn: str):
+ """Load model config from a config file."""
+
+ fn_config = get_config_file(fn)
+ config = Config.fromfile(fn_config)
+ return deepcopy(config.model)
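
The helpers above are intended for unit tests; a rough usage sketch follows (not part of the diff; the shapes in the comments follow the default arguments defined above).

# Rough usage sketch for the testing helpers defined above.
from mmpose.testing import get_coco_sample, get_packed_inputs

# A dummy COCO-style sample with two non-overlapping instances.
sample = get_coco_sample(
    img_shape=(240, 320), num_instances=2, non_occlusion=True)
print(sample['img'].shape)        # (240, 320, 3)
print(sample['bbox'].shape)       # (2, 4)
print(sample['keypoints'].shape)  # (2, 17, 2)

# A pseudo-collated batch, e.g. for smoke-testing a head's forward/loss.
batch = get_packed_inputs(batch_size=2, heatmap_size=(48, 64))
data_sample = batch['data_samples'][0]
print(len(batch['inputs']))                  # 2
print(data_sample.gt_fields.heatmaps.shape)  # torch.Size([17, 64, 48])
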
diff --git a/mmpose/utils/__init__.py b/mmpose/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c48ca01cea586d0e35c2f6daf3138ab94fe4e613
--- /dev/null
+++ b/mmpose/utils/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .camera import SimpleCamera, SimpleCameraTorch
+from .collect_env import collect_env
+from .config_utils import adapt_mmdet_pipeline
+from .logger import get_root_logger
+from .setup_env import register_all_modules, setup_multi_processes
+from .timer import StopWatch
+
+__all__ = [
+ 'get_root_logger', 'collect_env', 'StopWatch', 'setup_multi_processes',
+ 'register_all_modules', 'SimpleCamera', 'SimpleCameraTorch',
+ 'adapt_mmdet_pipeline'
+]
diff --git a/mmpose/utils/__pycache__/__init__.cpython-38.pyc b/mmpose/utils/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..71aea1227c1c32fb2e53744620ea9bd2f2f8417b
Binary files /dev/null and b/mmpose/utils/__pycache__/__init__.cpython-38.pyc differ
diff --git a/mmpose/utils/__pycache__/camera.cpython-38.pyc b/mmpose/utils/__pycache__/camera.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2e2183ee7296dd3f7e4c57edfb47ff9055f27334
Binary files /dev/null and b/mmpose/utils/__pycache__/camera.cpython-38.pyc differ
diff --git a/mmpose/utils/__pycache__/collect_env.cpython-38.pyc b/mmpose/utils/__pycache__/collect_env.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6ced64105c27986680519500642887e6659699a8
Binary files /dev/null and b/mmpose/utils/__pycache__/collect_env.cpython-38.pyc differ
diff --git a/mmpose/utils/__pycache__/config_utils.cpython-38.pyc b/mmpose/utils/__pycache__/config_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e3ca7e8d258ba1289c85923cd650aebadf596f9e
Binary files /dev/null and b/mmpose/utils/__pycache__/config_utils.cpython-38.pyc differ
diff --git a/mmpose/utils/__pycache__/logger.cpython-38.pyc b/mmpose/utils/__pycache__/logger.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5657e594754b35bb2a067ea02a7b1d8f34311c6f
Binary files /dev/null and b/mmpose/utils/__pycache__/logger.cpython-38.pyc differ
diff --git a/mmpose/utils/__pycache__/setup_env.cpython-38.pyc b/mmpose/utils/__pycache__/setup_env.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..351b99a3bf9dedd7be3b45aea0cb6788a084150e
Binary files /dev/null and b/mmpose/utils/__pycache__/setup_env.cpython-38.pyc differ
diff --git a/mmpose/utils/__pycache__/tensor_utils.cpython-38.pyc b/mmpose/utils/__pycache__/tensor_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2d077c70c227865b72085a6b2e3a4ae38fb9cbf2
Binary files /dev/null and b/mmpose/utils/__pycache__/tensor_utils.cpython-38.pyc differ
diff --git a/mmpose/utils/__pycache__/timer.cpython-38.pyc b/mmpose/utils/__pycache__/timer.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..95b8a96edd7cb726ae8d36d71cd57377336a60f1
Binary files /dev/null and b/mmpose/utils/__pycache__/timer.cpython-38.pyc differ
diff --git a/mmpose/utils/__pycache__/typing.cpython-38.pyc b/mmpose/utils/__pycache__/typing.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c114174ea50b5264da10b825ece386ae6f1e3ece
Binary files /dev/null and b/mmpose/utils/__pycache__/typing.cpython-38.pyc differ
diff --git a/mmpose/utils/camera.py b/mmpose/utils/camera.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7759d308f38fda99fcf56910b09251d24eccbed
--- /dev/null
+++ b/mmpose/utils/camera.py
@@ -0,0 +1,280 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABCMeta, abstractmethod
+
+import numpy as np
+import torch
+from mmengine.registry import Registry
+
+CAMERAS = Registry('camera')
+
+
+class SingleCameraBase(metaclass=ABCMeta):
+ """Base class for single camera model.
+
+ Args:
+ param (dict): Camera parameters
+
+ Methods:
+ world_to_camera: Project points from world coordinates to camera
+ coordinates
+ camera_to_world: Project points from camera coordinates to world
+ coordinates
+ camera_to_pixel: Project points from camera coordinates to pixel
+ coordinates
+ world_to_pixel: Project points from world coordinates to pixel
+ coordinates
+ """
+
+ @abstractmethod
+ def __init__(self, param):
+ """Load camera parameters and check validity."""
+
+ def world_to_camera(self, X):
+ """Project points from world coordinates to camera coordinates."""
+ raise NotImplementedError
+
+ def camera_to_world(self, X):
+ """Project points from camera coordinates to world coordinates."""
+ raise NotImplementedError
+
+ def camera_to_pixel(self, X):
+ """Project points from camera coordinates to pixel coordinates."""
+ raise NotImplementedError
+
+ def world_to_pixel(self, X):
+ """Project points from world coordinates to pixel coordinates."""
+ _X = self.world_to_camera(X)
+ return self.camera_to_pixel(_X)
+
+
+@CAMERAS.register_module()
+class SimpleCamera(SingleCameraBase):
+ """Camera model to calculate coordinate transformation with given
+ intrinsic/extrinsic camera parameters.
+
+ Note:
+ The keypoint coordinate should be an np.ndarray with a shape of
+ [...,J, C] where J is the keypoint number of an instance, and C is
+ the coordinate dimension. For example:
+
+ [J, C]: shape of joint coordinates of a person with J joints.
+ [N, J, C]: shape of a batch of person joint coordinates.
+ [N, T, J, C]: shape of a batch of pose sequences.
+
+ Args:
+ param (dict): camera parameters including:
+ - R: 3x3, camera rotation matrix (camera-to-world)
+ - T: 3x1, camera translation (camera-to-world)
+ - K: (optional) 2x3, camera intrinsic matrix
+ - k: (optional) nx1, camera radial distortion coefficients
+ - p: (optional) mx1, camera tangential distortion coefficients
+ - f: (optional) 2x1, camera focal length
+ - c: (optional) 2x1, camera center
+ if K is not provided, it will be calculated from f and c.
+
+ Methods:
+ world_to_camera: Project points from world coordinates to camera
+ coordinates
+ camera_to_pixel: Project points from camera coordinates to pixel
+ coordinates
+ world_to_pixel: Project points from world coordinates to pixel
+ coordinates
+ """
+
+ def __init__(self, param):
+
+ self.param = {}
+ # extrinsic param
+ R = np.array(param['R'], dtype=np.float32)
+ T = np.array(param['T'], dtype=np.float32)
+ assert R.shape == (3, 3)
+ assert T.shape == (3, 1)
+ # The camera matrices are transposed in advance because the joint
+ # coordinates are stored as row vectors.
+ self.param['R_c2w'] = R.T
+ self.param['T_c2w'] = T.T
+ self.param['R_w2c'] = R
+ self.param['T_w2c'] = -self.param['T_c2w'] @ self.param['R_w2c']
+
+ # intrinsic param
+ if 'K' in param:
+ K = np.array(param['K'], dtype=np.float32)
+ assert K.shape == (2, 3)
+ self.param['K'] = K.T
+ self.param['f'] = np.array([K[0, 0], K[1, 1]])[:, np.newaxis]
+ self.param['c'] = np.array([K[0, 2], K[1, 2]])[:, np.newaxis]
+ elif 'f' in param and 'c' in param:
+ f = np.array(param['f'], dtype=np.float32)
+ c = np.array(param['c'], dtype=np.float32)
+ assert f.shape == (2, 1)
+ assert c.shape == (2, 1)
+ self.param['K'] = np.concatenate((np.diagflat(f), c), axis=-1).T
+ self.param['f'] = f
+ self.param['c'] = c
+ else:
+ raise ValueError('Camera intrinsic parameters are missing. '
+ 'Either "K" or "f"&"c" should be provided.')
+
+ # distortion param
+ if 'k' in param and 'p' in param:
+ self.undistortion = True
+ self.param['k'] = np.array(param['k'], dtype=np.float32).flatten()
+ self.param['p'] = np.array(param['p'], dtype=np.float32).flatten()
+ assert self.param['k'].size in {3, 6}
+ assert self.param['p'].size == 2
+ else:
+ self.undistortion = False
+
+ def world_to_camera(self, X):
+ assert isinstance(X, np.ndarray)
+ assert X.ndim >= 2 and X.shape[-1] == 3
+ return X @ self.param['R_w2c'] + self.param['T_w2c']
+
+ def camera_to_world(self, X):
+ assert isinstance(X, np.ndarray)
+ assert X.ndim >= 2 and X.shape[-1] == 3
+ return X @ self.param['R_c2w'] + self.param['T_c2w']
+
+ def camera_to_pixel(self, X):
+ assert isinstance(X, np.ndarray)
+ assert X.ndim >= 2 and X.shape[-1] == 3
+
+ _X = X / X[..., 2:]
+
+ if self.undistortion:
+ k = self.param['k']
+ p = self.param['p']
+ _X_2d = _X[..., :2]
+ r2 = (_X_2d**2).sum(-1)
+ radial = 1 + sum(ki * r2**(i + 1) for i, ki in enumerate(k[:3]))
+ if k.size == 6:
+ radial /= 1 + sum(
+ (ki * r2**(i + 1) for i, ki in enumerate(k[3:])))
+
+ tangential = 2 * (p[1] * _X[..., 0] + p[0] * _X[..., 1])
+
+ _X[..., :2] = _X_2d * (radial + tangential)[..., None] + np.outer(
+ r2, p[::-1]).reshape(_X_2d.shape)
+ return _X @ self.param['K']
+
+ def pixel_to_camera(self, X):
+ assert isinstance(X, np.ndarray)
+ assert X.ndim >= 2 and X.shape[-1] == 3
+ _X = X.copy()
+ _X[:, :2] = (X[:, :2] - self.param['c'].T) / self.param['f'].T * X[:,
+ [2]]
+ return _X
+
+
+@CAMERAS.register_module()
+class SimpleCameraTorch(SingleCameraBase):
+ """Camera model to calculate coordinate transformation with given
+ intrinsic/extrinsic camera parameters.
+
+ Notes:
+ The keypoint coordinate should be an np.ndarray with a shape of
+ [...,J, C] where J is the keypoint number of an instance, and C is
+ the coordinate dimension. For example:
+
+ [J, C]: shape of joint coordinates of a person with J joints.
+ [N, J, C]: shape of a batch of person joint coordinates.
+ [N, T, J, C]: shape of a batch of pose sequences.
+
+ Args:
+ param (dict): camera parameters including:
+ - R: 3x3, camera rotation matrix (camera-to-world)
+ - T: 3x1, camera translation (camera-to-world)
+ - K: (optional) 2x3, camera intrinsic matrix
+ - k: (optional) nx1, camera radial distortion coefficients
+ - p: (optional) mx1, camera tangential distortion coefficients
+ - f: (optional) 2x1, camera focal length
+ - c: (optional) 2x1, camera center
+ if K is not provided, it will be calculated from f and c.
+
+ Methods:
+ world_to_camera: Project points from world coordinates to camera
+ coordinates
+ camera_to_pixel: Project points from camera coordinates to pixel
+ coordinates
+ world_to_pixel: Project points from world coordinates to pixel
+ coordinates
+ """
+
+ def __init__(self, param, device):
+
+ self.param = {}
+ # extrinsic param
+ R = torch.tensor(param['R'], device=device)
+ T = torch.tensor(param['T'], device=device)
+
+ assert R.shape == (3, 3)
+ assert T.shape == (3, 1)
+ # The camera matrices are transposed in advance because the joint
+ # coordinates are stored as row vectors.
+ self.param['R_c2w'] = R.T
+ self.param['T_c2w'] = T.T
+ self.param['R_w2c'] = R
+ self.param['T_w2c'] = -self.param['T_c2w'] @ self.param['R_w2c']
+
+ # intrinsic param
+ if 'K' in param:
+ K = torch.tensor(param['K'], device=device)
+ assert K.shape == (2, 3)
+ self.param['K'] = K.T
+ self.param['f'] = torch.tensor([[K[0, 0]], [K[1, 1]]],
+ device=device)
+ self.param['c'] = torch.tensor([[K[0, 2]], [K[1, 2]]],
+ device=device)
+ elif 'f' in param and 'c' in param:
+ f = torch.tensor(param['f'], device=device)
+ c = torch.tensor(param['c'], device=device)
+ assert f.shape == (2, 1)
+ assert c.shape == (2, 1)
+ self.param['K'] = torch.cat([torch.diagflat(f), c], dim=-1).T
+ self.param['f'] = f
+ self.param['c'] = c
+ else:
+ raise ValueError('Camera intrinsic parameters are missing. '
+ 'Either "K" or "f"&"c" should be provided.')
+
+ # distortion param
+ if 'k' in param and 'p' in param:
+ self.undistortion = True
+ self.param['k'] = torch.tensor(param['k'], device=device).view(-1)
+ self.param['p'] = torch.tensor(param['p'], device=device).view(-1)
+ assert len(self.param['k']) in {3, 6}
+ assert len(self.param['p']) == 2
+ else:
+ self.undistortion = False
+
+ def world_to_camera(self, X):
+ assert isinstance(X, torch.Tensor)
+ assert X.ndim >= 2 and X.shape[-1] == 3
+ return X @ self.param['R_w2c'] + self.param['T_w2c']
+
+ def camera_to_world(self, X):
+ assert isinstance(X, torch.Tensor)
+ assert X.ndim >= 2 and X.shape[-1] == 3
+ return X @ self.param['R_c2w'] + self.param['T_c2w']
+
+ def camera_to_pixel(self, X):
+ assert isinstance(X, torch.Tensor)
+ assert X.ndim >= 2 and X.shape[-1] == 3
+
+ _X = X / X[..., 2:]
+
+ if self.undistortion:
+ k = self.param['k']
+ p = self.param['p']
+ _X_2d = _X[..., :2]
+ r2 = (_X_2d**2).sum(-1)
+ radial = 1 + sum(ki * r2**(i + 1) for i, ki in enumerate(k[:3]))
+            if k.numel() == 6:
+ radial /= 1 + sum(
+ (ki * r2**(i + 1) for i, ki in enumerate(k[3:])))
+
+ tangential = 2 * (p[1] * _X[..., 0] + p[0] * _X[..., 1])
+
+ _X[..., :2] = _X_2d * (radial + tangential)[..., None] + torch.ger(
+ r2, p.flip([0])).reshape(_X_2d.shape)
+ return _X @ self.param['K']
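
A short sketch of the `SimpleCamera` API defined above (not part of the diff); all intrinsic/extrinsic values are made up for illustration.

# Usage sketch for SimpleCamera; all parameter values are made up.
import numpy as np

from mmpose.utils import SimpleCamera

param = dict(
    R=np.eye(3),                     # camera-to-world rotation (3x3)
    T=np.zeros((3, 1)),              # camera-to-world translation (3x1)
    f=np.array([[1000.], [1000.]]),  # focal length (2x1)
    c=np.array([[512.], [384.]]),    # principal point (2x1)
)
camera = SimpleCamera(param)

# 17 dummy joints placed in front of the camera (z > 0).
joints_world = np.random.rand(17, 3) + np.array([0., 0., 3.])
joints_cam = camera.world_to_camera(joints_world)
joints_pixel = camera.camera_to_pixel(joints_cam)
print(joints_pixel.shape)  # (17, 2)
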
diff --git a/mmpose/utils/collect_env.py b/mmpose/utils/collect_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..e8fb5f35e10fe6535b49b7eb7def1459b28835e3
--- /dev/null
+++ b/mmpose/utils/collect_env.py
@@ -0,0 +1,16 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmengine.utils import get_git_hash
+from mmengine.utils.dl_utils import collect_env as collect_base_env
+
+import mmpose
+
+
+def collect_env():
+ env_info = collect_base_env()
+ env_info['MMPose'] = (mmpose.__version__ + '+' + get_git_hash(digits=7))
+ return env_info
+
+
+if __name__ == '__main__':
+ for name, val in collect_env().items():
+ print(f'{name}: {val}')
diff --git a/mmpose/utils/config_utils.py b/mmpose/utils/config_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f54d2ef24093a77933dbf8026465e3cdaf5e839
--- /dev/null
+++ b/mmpose/utils/config_utils.py
@@ -0,0 +1,26 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmpose.utils.typing import ConfigDict
+
+
+def adapt_mmdet_pipeline(cfg: ConfigDict) -> ConfigDict:
+ """Converts pipeline types in MMDetection's test dataloader to use the
+ 'mmdet' namespace.
+
+ Args:
+ cfg (ConfigDict): Configuration dictionary for MMDetection.
+
+ Returns:
+ ConfigDict: Configuration dictionary with updated pipeline types.
+ """
+ # use lazy import to avoid hard dependence on mmdet
+ from mmdet.datasets import transforms
+
+ if 'test_dataloader' not in cfg:
+ return cfg
+
+ pipeline = cfg.test_dataloader.dataset.pipeline
+ for trans in pipeline:
+ if trans['type'] in dir(transforms):
+ trans['type'] = 'mmdet.' + trans['type']
+
+ return cfg
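
A toy illustration of what `adapt_mmdet_pipeline` does (not part of the diff). It assumes mmdet is installed; the transform names are only examples of types that live in `mmdet.datasets.transforms`.

# Toy illustration of adapt_mmdet_pipeline; requires mmdet to be installed.
from mmengine.config import Config

from mmpose.utils import adapt_mmdet_pipeline

cfg = Config(
    dict(
        test_dataloader=dict(
            dataset=dict(pipeline=[
                dict(type='Resize', scale=(640, 640), keep_ratio=True),
                dict(type='PackDetInputs'),
            ]))))

cfg = adapt_mmdet_pipeline(cfg)
# Types found in mmdet.datasets.transforms are rewritten with the 'mmdet.'
# prefix, e.g. 'mmdet.Resize' and 'mmdet.PackDetInputs'.
print([t['type'] for t in cfg.test_dataloader.dataset.pipeline])
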
diff --git a/mmpose/utils/hooks.py b/mmpose/utils/hooks.py
new file mode 100644
index 0000000000000000000000000000000000000000..b68940f2b7a8a618916ea5aab331e3ce45ba98e7
--- /dev/null
+++ b/mmpose/utils/hooks.py
@@ -0,0 +1,60 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import functools
+
+
+class OutputHook:
+
+ def __init__(self, module, outputs=None, as_tensor=False):
+ self.outputs = outputs
+ self.as_tensor = as_tensor
+ self.layer_outputs = {}
+ self.register(module)
+
+ def register(self, module):
+
+ def hook_wrapper(name):
+
+ def hook(model, input, output):
+ if self.as_tensor:
+ self.layer_outputs[name] = output
+ else:
+ if isinstance(output, list):
+ self.layer_outputs[name] = [
+ out.detach().cpu().numpy() for out in output
+ ]
+ else:
+ self.layer_outputs[name] = output.detach().cpu().numpy(
+ )
+
+ return hook
+
+ self.handles = []
+ if isinstance(self.outputs, (list, tuple)):
+ for name in self.outputs:
+ try:
+ layer = rgetattr(module, name)
+ h = layer.register_forward_hook(hook_wrapper(name))
+                except AttributeError as attr_error:
+                    # getattr raises AttributeError (not ModuleNotFoundError)
+                    # when a sub-module or attribute is missing
+                    raise AttributeError(
+                        f'Module {name} not found') from attr_error
+ self.handles.append(h)
+
+ def remove(self):
+ for h in self.handles:
+ h.remove()
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.remove()
+
+
+# using wonder's beautiful simplification:
+# https://stackoverflow.com/questions/31174295/getattr-and-setattr-on-nested-objects
+def rgetattr(obj, attr, *args):
+
+ def _getattr(obj, attr):
+ return getattr(obj, attr, *args)
+
+ return functools.reduce(_getattr, [obj] + attr.split('.'))
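
A self-contained sketch of `OutputHook` on a toy module (not part of the diff); the layer names are simply the submodule names registered below.

# Usage sketch for OutputHook on a toy model.
import torch
from torch import nn

from mmpose.utils.hooks import OutputHook

model = nn.Sequential()
model.add_module('backbone', nn.Conv2d(3, 8, 3, padding=1))
model.add_module('head', nn.Conv2d(8, 17, 1))

with OutputHook(model, outputs=['backbone', 'head'], as_tensor=True) as hook:
    _ = model(torch.randn(1, 3, 64, 48))

print(hook.layer_outputs['backbone'].shape)  # torch.Size([1, 8, 64, 48])
print(hook.layer_outputs['head'].shape)      # torch.Size([1, 17, 64, 48])
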
diff --git a/mmpose/utils/logger.py b/mmpose/utils/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..f67e56efeb998cf966e3729c90791b4a70f2bb84
--- /dev/null
+++ b/mmpose/utils/logger.py
@@ -0,0 +1,25 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import logging
+
+from mmengine.logging import MMLogger
+
+
+def get_root_logger(log_file=None, log_level=logging.INFO):
+ """Use `MMLogger` class in mmengine to get the root logger.
+
+ The logger will be initialized if it has not been initialized. By default a
+ StreamHandler will be added. If `log_file` is specified, a FileHandler will
+ also be added. The name of the root logger is the top-level package name,
+ e.g., "mmpose".
+
+ Args:
+ log_file (str | None): The log filename. If specified, a FileHandler
+ will be added to the root logger.
+ log_level (int): The root logger level. Note that only the process of
+ rank 0 is affected, while other processes will set the level to
+ "Error" and be silent most of the time.
+
+ Returns:
+ logging.Logger: The root logger.
+ """
+ return MMLogger('MMLogger', __name__.split('.')[0], log_file, log_level)
diff --git a/mmpose/utils/setup_env.py b/mmpose/utils/setup_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff299539ef8cc83a17a24e41498c01ff4f26667f
--- /dev/null
+++ b/mmpose/utils/setup_env.py
@@ -0,0 +1,86 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import datetime
+import os
+import platform
+import warnings
+
+import cv2
+import torch.multiprocessing as mp
+from mmengine import DefaultScope
+
+
+def setup_multi_processes(cfg):
+ """Setup multi-processing environment variables."""
+ # set multi-process start method as `fork` to speed up the training
+ if platform.system() != 'Windows':
+ mp_start_method = cfg.get('mp_start_method', 'fork')
+ current_method = mp.get_start_method(allow_none=True)
+ if current_method is not None and current_method != mp_start_method:
+ warnings.warn(
+ f'Multi-processing start method `{mp_start_method}` is '
+                f'different from the previous setting `{current_method}`. '
+                f'It will be forcibly set to `{mp_start_method}`. You can change '
+ f'this behavior by changing `mp_start_method` in your config.')
+ mp.set_start_method(mp_start_method, force=True)
+
+ # disable opencv multithreading to avoid system being overloaded
+ opencv_num_threads = cfg.get('opencv_num_threads', 0)
+ cv2.setNumThreads(opencv_num_threads)
+
+ # setup OMP threads
+ # This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa
+ if 'OMP_NUM_THREADS' not in os.environ and cfg.data.workers_per_gpu > 1:
+ omp_num_threads = 1
+ warnings.warn(
+ f'Setting OMP_NUM_THREADS environment variable for each process '
+            f'to be {omp_num_threads} by default, to avoid your system being '
+            f'overloaded. Please further tune the variable for optimal '
+ f'performance in your application as needed.')
+ os.environ['OMP_NUM_THREADS'] = str(omp_num_threads)
+
+ # setup MKL threads
+ if 'MKL_NUM_THREADS' not in os.environ and cfg.data.workers_per_gpu > 1:
+ mkl_num_threads = 1
+ warnings.warn(
+ f'Setting MKL_NUM_THREADS environment variable for each process '
+            f'to be {mkl_num_threads} by default, to avoid your system being '
+            f'overloaded. Please further tune the variable for optimal '
+ f'performance in your application as needed.')
+ os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads)
+
+
+def register_all_modules(init_default_scope: bool = True) -> None:
+ """Register all modules in mmpose into the registries.
+
+ Args:
+ init_default_scope (bool): Whether initialize the mmpose default scope.
+ When `init_default_scope=True`, the global default scope will be
+ set to `mmpose`, and all registries will build modules from mmpose's
+ registry node. To understand more about the registry, please refer
+ to https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/registry.md
+ Defaults to True.
+ """ # noqa
+
+ import mmpose.codecs # noqa: F401, F403
+ import mmpose.datasets # noqa: F401,F403
+ import mmpose.engine # noqa: F401,F403
+ import mmpose.evaluation # noqa: F401,F403
+ import mmpose.models # noqa: F401,F403
+ import mmpose.visualization # noqa: F401,F403
+
+ if init_default_scope:
+ never_created = DefaultScope.get_current_instance() is None \
+ or not DefaultScope.check_instance_created('mmpose')
+ if never_created:
+ DefaultScope.get_instance('mmpose', scope_name='mmpose')
+ return
+ current_scope = DefaultScope.get_current_instance()
+ if current_scope.scope_name != 'mmpose':
+ warnings.warn('The current default scope '
+ f'"{current_scope.scope_name}" is not "mmpose", '
+                          '`register_all_modules` will force the current '
+ 'default scope to be "mmpose". If this is not '
+ 'expected, please set `init_default_scope=False`.')
+ # avoid name conflict
+ new_instance_name = f'mmpose-{datetime.datetime.now()}'
+ DefaultScope.get_instance(new_instance_name, scope_name='mmpose')
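
In practice `register_all_modules` is called once at the entry point before anything is built from the registries; a minimal sketch (not part of the diff), assuming `mmpose.registry.VISUALIZERS` as used elsewhere in this patch.

# Minimal sketch: call once at the entry point, then build from registries.
from mmpose.registry import VISUALIZERS
from mmpose.utils import register_all_modules

register_all_modules(init_default_scope=True)

# With the registries populated, config dicts can be built by type name.
visualizer = VISUALIZERS.build(dict(type='PoseLocalVisualizer', radius=2))
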
diff --git a/mmpose/utils/tensor_utils.py b/mmpose/utils/tensor_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1be73f899178ead4dd4ade9f621aedb07bec4258
--- /dev/null
+++ b/mmpose/utils/tensor_utils.py
@@ -0,0 +1,71 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+from typing import Any, Optional, Sequence, Union
+
+import numpy as np
+import torch
+from mmengine.utils import is_seq_of
+from torch import Tensor
+
+
+def to_numpy(x: Union[Tensor, Sequence[Tensor]],
+ return_device: bool = False,
+ unzip: bool = False) -> Union[np.ndarray, tuple]:
+ """Convert torch tensor to numpy.ndarray.
+
+ Args:
+ x (Tensor | Sequence[Tensor]): A single tensor or a sequence of
+ tensors
+ return_device (bool): Whether return the tensor device. Defaults to
+ ``False``
+ unzip (bool): Whether unzip the input sequence. Defaults to ``False``
+
+ Returns:
+ np.ndarray | tuple: If ``return_device`` is ``True``, return a tuple
+ of converted numpy array(s) and the device indicator; otherwise only
+ return the numpy array(s)
+ """
+
+ if isinstance(x, Tensor):
+ arrays = x.detach().cpu().numpy()
+ device = x.device
+ elif is_seq_of(x, Tensor):
+ if unzip:
+ # convert (A, B) -> [(A[0], B[0]), (A[1], B[1]), ...]
+ arrays = [
+ tuple(to_numpy(_x[None, :]) for _x in _each)
+ for _each in zip(*x)
+ ]
+ else:
+ arrays = [to_numpy(_x) for _x in x]
+
+ device = x[0].device
+
+ else:
+ raise ValueError(f'Invalid input type {type(x)}')
+
+ if return_device:
+ return arrays, device
+ else:
+ return arrays
+
+
+def to_tensor(x: Union[np.ndarray, Sequence[np.ndarray]],
+ device: Optional[Any] = None) -> Union[Tensor, Sequence[Tensor]]:
+ """Convert numpy.ndarray to torch tensor.
+
+ Args:
+        x (np.ndarray | Sequence[np.ndarray]): A single np.ndarray or a
+            sequence of np.ndarray
+        device (Any, optional): The device indicator. Defaults to ``None``
+
+    Returns:
+        Tensor | Sequence[Tensor]: The converted Tensor or Tensor sequence
+ """
+ if isinstance(x, np.ndarray):
+ return torch.tensor(x, device=device)
+ elif is_seq_of(x, np.ndarray):
+ return [to_tensor(_x, device=device) for _x in x]
+ else:
+ raise ValueError(f'Invalid input type {type(x)}')
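
A round-trip sketch for `to_numpy` / `to_tensor` (not part of the diff). Note that `tensor_utils` is not re-exported from the `mmpose/utils/__init__.py` shown above, so the full module path is used.

# Round-trip sketch for to_numpy / to_tensor.
import numpy as np
import torch

from mmpose.utils.tensor_utils import to_numpy, to_tensor

heatmaps = torch.rand(2, 17, 64, 48)
arrays, device = to_numpy(heatmaps, return_device=True)
print(arrays.shape, device)  # (2, 17, 64, 48) cpu

restored = to_tensor(arrays, device=device)
assert torch.allclose(heatmaps, restored)

# Sequences are converted element-wise.
tensors = to_tensor([np.zeros((17, 2)), np.ones((17, 2))])
print(len(tensors), tensors[0].shape)  # 2 torch.Size([17, 2])
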
diff --git a/mmpose/utils/timer.py b/mmpose/utils/timer.py
new file mode 100644
index 0000000000000000000000000000000000000000..c219c04069d239605a7854b06a370876dbe8fd58
--- /dev/null
+++ b/mmpose/utils/timer.py
@@ -0,0 +1,117 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from collections import defaultdict
+from contextlib import contextmanager
+from functools import partial
+
+import numpy as np
+from mmengine import Timer
+
+
+class RunningAverage():
+ r"""A helper class to calculate running average in a sliding window.
+
+ Args:
+ window (int): The size of the sliding window.
+ """
+
+ def __init__(self, window: int = 1):
+ self.window = window
+ self._data = []
+
+ def update(self, value):
+ """Update a new data sample."""
+ self._data.append(value)
+ self._data = self._data[-self.window:]
+
+ def average(self):
+ """Get the average value of current window."""
+ return np.mean(self._data)
+
+
+class StopWatch:
+ r"""A helper class to measure FPS and detailed time consuming of each phase
+ in a video processing loop or similar scenarios.
+
+ Args:
+ window (int): The sliding window size to calculate the running average
+ of the time consuming.
+
+ Example:
+ >>> from mmpose.utils import StopWatch
+ >>> import time
+ >>> stop_watch = StopWatch(window=10)
+ >>> with stop_watch.timeit('total'):
+ >>> time.sleep(0.1)
+ >>> # 'timeit' support nested use
+ >>> with stop_watch.timeit('phase1'):
+ >>> time.sleep(0.1)
+ >>> with stop_watch.timeit('phase2'):
+ >>> time.sleep(0.2)
+ >>> time.sleep(0.2)
+ >>> report = stop_watch.report()
+ """
+
+ def __init__(self, window=1):
+ self.window = window
+ self._record = defaultdict(partial(RunningAverage, window=self.window))
+ self._timer_stack = []
+
+ @contextmanager
+ def timeit(self, timer_name='_FPS_'):
+ """Timing a code snippet with an assigned name.
+
+ Args:
+            timer_name (str): The unique name of the code snippet to time,
+                used to handle multiple timers and generate reports. Note
+                that '_FPS_' is a special key whose measurement is reported
+                in `fps` instead of `milliseconds`. Also see `report` and
+                `report_strings`.
+ Default: '_FPS_'.
+ Note:
+ This function should always be used in a `with` statement, as shown
+ in the example.
+ """
+ self._timer_stack.append((timer_name, Timer()))
+ try:
+ yield
+ finally:
+ timer_name, timer = self._timer_stack.pop()
+ self._record[timer_name].update(timer.since_start())
+
+ def report(self, key=None):
+ """Report timing information.
+
+ Returns:
+ dict: The key is the timer name and the value is the \
+ corresponding average time consuming.
+ """
+ result = {
+ name: r.average() * 1000.
+ for name, r in self._record.items()
+ }
+
+ if '_FPS_' in result:
+ result['_FPS_'] = 1000. / result.pop('_FPS_')
+
+ if key is None:
+ return result
+ return result[key]
+
+ def report_strings(self):
+ """Report timing information in texture strings.
+
+ Returns:
+ list(str): Each element is the information string of a timed \
+ event, in format of '{timer_name}: {time_in_ms}'. \
+ Specially, if timer_name is '_FPS_', the result will \
+ be converted to fps.
+ """
+ result = self.report()
+ strings = []
+ if '_FPS_' in result:
+ strings.append(f'FPS: {result["_FPS_"]:>5.1f}')
+ strings += [f'{name}: {val:>3.0f}' for name, val in result.items()]
+ return strings
+
+ def reset(self):
+        self._record = defaultdict(
+            partial(RunningAverage, window=self.window))
+        self._timer_stack = []
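
Beyond the docstring example above, a typical pattern in a video loop is to time each phase and read the aggregated report afterwards; a sketch (not part of the diff):

# Sketch of a typical video-loop pattern with StopWatch.
import time

from mmpose.utils import StopWatch

stop_watch = StopWatch(window=10)
for _ in range(3):
    with stop_watch.timeit():  # '_FPS_' measures the whole iteration
        with stop_watch.timeit('infer'):
            time.sleep(0.02)
        with stop_watch.timeit('draw'):
            time.sleep(0.01)

# '_FPS_' is reported as frames per second, the rest in milliseconds.
print(stop_watch.report())  # e.g. {'infer': ~20.0, 'draw': ~10.0, '_FPS_': ~33.0}
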
diff --git a/mmpose/utils/typing.py b/mmpose/utils/typing.py
new file mode 100644
index 0000000000000000000000000000000000000000..557891b3b92e657de43eb50d4b5fbce7d369e7ee
--- /dev/null
+++ b/mmpose/utils/typing.py
@@ -0,0 +1,29 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, List, Optional, Tuple, Union
+
+from mmengine.config import ConfigDict
+from mmengine.structures import InstanceData, PixelData
+from torch import Tensor
+
+from mmpose.structures import PoseDataSample
+
+# Type hint of config data
+ConfigType = Union[ConfigDict, dict]
+OptConfigType = Optional[ConfigType]
+# Type hint of one or more config data
+MultiConfig = Union[ConfigType, List[ConfigType]]
+OptMultiConfig = Optional[MultiConfig]
+# Type hint of data samples
+SampleList = List[PoseDataSample]
+OptSampleList = Optional[SampleList]
+InstanceList = List[InstanceData]
+PixelDataList = List[PixelData]
+Predictions = Union[InstanceList, Tuple[InstanceList, PixelDataList]]
+# Type hint of model outputs
+ForwardResults = Union[Dict[str, Tensor], List[PoseDataSample], Tuple[Tensor],
+ Tensor]
+# Type hint of features
+# - Tuple[Tensor]: multi-level features extracted by the network
+# - List[Tuple[Tensor]]: multiple feature pyramids for TTA
+# - List[List[Tuple[Tensor]]]: multi-scale feature pyramids
+Features = Union[Tuple[Tensor], List[Tuple[Tensor]], List[List[Tuple[Tensor]]]]
diff --git a/mmpose/version.py b/mmpose/version.py
new file mode 100644
index 0000000000000000000000000000000000000000..73312cc28dbe08a076c725e0b9a6ea8579db9d1c
--- /dev/null
+++ b/mmpose/version.py
@@ -0,0 +1,31 @@
+# Copyright (c) Open-MMLab. All rights reserved.
+
+__version__ = '1.0.0'
+short_version = __version__
+
+
+def parse_version_info(version_str):
+ """Parse a version string into a tuple.
+
+ Args:
+ version_str (str): The version string.
+ Returns:
+ tuple[int | str]: The version info, e.g., "1.3.0" is parsed into
+ (1, 3, 0), and "2.0.0rc1" is parsed into (2, 0, 0, 'rc1').
+ """
+ version_info = []
+ for x in version_str.split('.'):
+ if x.isdigit():
+ version_info.append(int(x))
+ elif x.find('rc') != -1:
+ patch_version = x.split('rc')
+ version_info.append(int(patch_version[0]))
+ version_info.append(f'rc{patch_version[1]}')
+ elif x.find('b') != -1:
+ patch_version = x.split('b')
+ version_info.append(int(patch_version[0]))
+ version_info.append(f'b{patch_version[1]}')
+ return tuple(version_info)
+
+
+version_info = parse_version_info(__version__)
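
A quick illustration of the parsing rules implemented above (not part of the diff):

# Quick illustration of parse_version_info.
from mmpose.version import parse_version_info, version_info

print(parse_version_info('1.0.0'))     # (1, 0, 0)
print(parse_version_info('1.0.0rc1'))  # (1, 0, 0, 'rc1')
print(parse_version_info('1.0.0b2'))   # (1, 0, 0, 'b2')
print(version_info)                    # parsed from __version__ above
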
diff --git a/mmpose/visualization/__init__.py b/mmpose/visualization/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..357d40a707bd5e87d65fc4236233658dd1aab18d
--- /dev/null
+++ b/mmpose/visualization/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .local_visualizer import PoseLocalVisualizer
+
+__all__ = ['PoseLocalVisualizer']
diff --git a/mmpose/visualization/__pycache__/__init__.cpython-38.pyc b/mmpose/visualization/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..479042786d87f1d1c5e92f2b0f49d44d03c63d7e
Binary files /dev/null and b/mmpose/visualization/__pycache__/__init__.cpython-38.pyc differ
diff --git a/mmpose/visualization/__pycache__/local_visualizer.cpython-38.pyc b/mmpose/visualization/__pycache__/local_visualizer.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..63fbedc3689d09829c2f9471b5768ce68a97c9ae
Binary files /dev/null and b/mmpose/visualization/__pycache__/local_visualizer.cpython-38.pyc differ
diff --git a/mmpose/visualization/__pycache__/opencv_backend_visualizer.cpython-38.pyc b/mmpose/visualization/__pycache__/opencv_backend_visualizer.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..399cb790289114ac8431e78c36c6989f53f07f8e
Binary files /dev/null and b/mmpose/visualization/__pycache__/opencv_backend_visualizer.cpython-38.pyc differ
diff --git a/mmpose/visualization/__pycache__/simcc_vis.cpython-38.pyc b/mmpose/visualization/__pycache__/simcc_vis.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f7b3e5367f695d1d2bfd6e2253e41ef6aa623b1e
Binary files /dev/null and b/mmpose/visualization/__pycache__/simcc_vis.cpython-38.pyc differ
diff --git a/mmpose/visualization/local_visualizer.py b/mmpose/visualization/local_visualizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..205993c006eb5349c0dc371e7d4124b19050cf2f
--- /dev/null
+++ b/mmpose/visualization/local_visualizer.py
@@ -0,0 +1,583 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+from typing import Dict, List, Optional, Tuple, Union
+
+import cv2
+import mmcv
+import numpy as np
+import torch
+from mmengine.dist import master_only
+from mmengine.structures import InstanceData, PixelData
+
+from mmpose.datasets.datasets.utils import parse_pose_metainfo
+from mmpose.registry import VISUALIZERS
+from mmpose.structures import PoseDataSample
+from .opencv_backend_visualizer import OpencvBackendVisualizer
+from .simcc_vis import SimCCVisualizer
+
+
+def _get_adaptive_scales(areas: np.ndarray,
+ min_area: int = 800,
+ max_area: int = 30000) -> np.ndarray:
+ """Get adaptive scales according to areas.
+
+ The scale range is [0.5, 1.0]. When the area is less than
+    ``min_area``, the scale is 0.5; when the area is larger than
+ ``max_area``, the scale is 1.0.
+
+ Args:
+ areas (ndarray): The areas of bboxes or masks with the
+ shape of (n, ).
+ min_area (int): Lower bound areas for adaptive scales.
+ Defaults to 800.
+ max_area (int): Upper bound areas for adaptive scales.
+ Defaults to 30000.
+
+ Returns:
+        ndarray: The adaptive scales with the shape of (n, ).
+ """
+ scales = 0.5 + (areas - min_area) / (max_area - min_area)
+ scales = np.clip(scales, 0.5, 1.0)
+ return scales
+
+
+@VISUALIZERS.register_module()
+class PoseLocalVisualizer(OpencvBackendVisualizer):
+ """MMPose Local Visualizer.
+
+ Args:
+ name (str): Name of the instance. Defaults to 'visualizer'.
+ image (np.ndarray, optional): the origin image to draw. The format
+ should be RGB. Defaults to ``None``
+ vis_backends (list, optional): Visual backend config list. Defaults to
+ ``None``
+ save_dir (str, optional): Save file dir for all storage backends.
+ If it is ``None``, the backend storage will not save any data.
+ Defaults to ``None``
+ bbox_color (str, tuple(int), optional): Color of bbox lines.
+ The tuple of color should be in BGR order. Defaults to ``'green'``
+ kpt_color (str, tuple(tuple(int)), optional): Color of keypoints.
+ The tuple of color should be in BGR order. Defaults to ``'red'``
+ link_color (str, tuple(tuple(int)), optional): Color of skeleton.
+ The tuple of color should be in BGR order. Defaults to ``None``
+ line_width (int, float): The width of lines. Defaults to 1
+ radius (int, float): The radius of keypoints. Defaults to 4
+ show_keypoint_weight (bool): Whether to adjust the transparency
+ of keypoints according to their score. Defaults to ``False``
+ alpha (int, float): The transparency of bboxes. Defaults to ``0.8``
+
+ Examples:
+ >>> import numpy as np
+ >>> from mmengine.structures import InstanceData
+ >>> from mmpose.structures import PoseDataSample
+ >>> from mmpose.visualization import PoseLocalVisualizer
+
+ >>> pose_local_visualizer = PoseLocalVisualizer(radius=1)
+ >>> image = np.random.randint(0, 256,
+ ... size=(10, 12, 3)).astype('uint8')
+ >>> gt_instances = InstanceData()
+ >>> gt_instances.keypoints = np.array([[[1, 1], [2, 2], [4, 4],
+ ... [8, 8]]])
+ >>> gt_pose_data_sample = PoseDataSample()
+ >>> gt_pose_data_sample.gt_instances = gt_instances
+ >>> dataset_meta = {'skeleton_links': [[0, 1], [1, 2], [2, 3]]}
+ >>> pose_local_visualizer.set_dataset_meta(dataset_meta)
+ >>> pose_local_visualizer.add_datasample('image', image,
+ ... gt_pose_data_sample)
+ >>> pose_local_visualizer.add_datasample(
+ ... 'image', image, gt_pose_data_sample,
+ ... out_file='out_file.jpg')
+ >>> pose_local_visualizer.add_datasample(
+ ... 'image', image, gt_pose_data_sample,
+ ... show=True)
+ >>> pred_instances = InstanceData()
+ >>> pred_instances.keypoints = np.array([[[1, 1], [2, 2], [4, 4],
+ ... [8, 8]]])
+ >>> pred_instances.score = np.array([0.8, 1, 0.9, 1])
+ >>> pred_pose_data_sample = PoseDataSample()
+ >>> pred_pose_data_sample.pred_instances = pred_instances
+ >>> pose_local_visualizer.add_datasample('image', image,
+ ... gt_pose_data_sample,
+ ... pred_pose_data_sample)
+ """
+
+ def __init__(self,
+ name: str = 'visualizer',
+ image: Optional[np.ndarray] = None,
+ vis_backends: Optional[Dict] = None,
+ save_dir: Optional[str] = None,
+ bbox_color: Optional[Union[str, Tuple[int]]] = 'green',
+ kpt_color: Optional[Union[str, Tuple[Tuple[int]]]] = 'red',
+ link_color: Optional[Union[str, Tuple[Tuple[int]]]] = None,
+ text_color: Optional[Union[str,
+ Tuple[int]]] = (255, 255, 255),
+ skeleton: Optional[Union[List, Tuple]] = None,
+ line_width: Union[int, float] = 1,
+ radius: Union[int, float] = 3,
+ show_keypoint_weight: bool = False,
+ backend: str = 'opencv',
+ alpha: float = 0.8):
+ super().__init__(
+ name=name,
+ image=image,
+ vis_backends=vis_backends,
+ save_dir=save_dir,
+ backend=backend)
+
+ self.bbox_color = bbox_color
+ self.kpt_color = kpt_color
+ self.link_color = link_color
+ self.line_width = line_width
+ self.text_color = text_color
+ self.skeleton = skeleton
+ self.radius = radius
+ self.alpha = alpha
+ self.show_keypoint_weight = show_keypoint_weight
+ # Set default value. When calling
+ # `PoseLocalVisualizer().set_dataset_meta(xxx)`,
+ # it will override the default value.
+ self.dataset_meta = {}
+
+ def set_dataset_meta(self,
+ dataset_meta: Dict,
+ skeleton_style: str = 'mmpose'):
+ """Assign dataset_meta to the visualizer. The default visualization
+ settings will be overridden.
+
+ Args:
+            dataset_meta (dict): meta information of dataset.
+            skeleton_style (str): Skeleton style selection. Defaults to
+                ``'mmpose'``
+ """
+ if dataset_meta.get(
+ 'dataset_name') == 'coco' and skeleton_style == 'openpose':
+ dataset_meta = parse_pose_metainfo(
+ dict(from_file='configs/_base_/datasets/coco_openpose.py'))
+
+ if isinstance(dataset_meta, dict):
+ self.dataset_meta = dataset_meta.copy()
+ self.bbox_color = dataset_meta.get('bbox_color', self.bbox_color)
+ self.kpt_color = dataset_meta.get('keypoint_colors',
+ self.kpt_color)
+ self.link_color = dataset_meta.get('skeleton_link_colors',
+ self.link_color)
+ self.skeleton = dataset_meta.get('skeleton_links', self.skeleton)
+ # sometimes self.dataset_meta is manually set, which might be None.
+ # it should be converted to a dict at these times
+ if self.dataset_meta is None:
+ self.dataset_meta = {}
+
+ def _draw_instances_bbox(self, image: np.ndarray,
+ instances: InstanceData) -> np.ndarray:
+ """Draw bounding boxes and corresponding labels of GT or prediction.
+
+ Args:
+ image (np.ndarray): The image to draw.
+ instances (:obj:`InstanceData`): Data structure for
+ instance-level annotations or predictions.
+
+ Returns:
+ np.ndarray: the drawn image which channel is RGB.
+ """
+ self.set_image(image)
+
+ if 'bboxes' in instances:
+ bboxes = instances.bboxes
+ self.draw_bboxes(
+ bboxes,
+ edge_colors=self.bbox_color,
+ alpha=self.alpha,
+ line_widths=self.line_width)
+ else:
+ return self.get_image()
+
+ if 'labels' in instances and self.text_color is not None:
+ classes = self.dataset_meta.get('classes', None)
+ labels = instances.labels
+
+ positions = bboxes[:, :2]
+ areas = (bboxes[:, 3] - bboxes[:, 1]) * (
+ bboxes[:, 2] - bboxes[:, 0])
+ scales = _get_adaptive_scales(areas)
+
+ for i, (pos, label) in enumerate(zip(positions, labels)):
+ label_text = classes[
+ label] if classes is not None else f'class {label}'
+
+ if isinstance(self.bbox_color,
+ tuple) and max(self.bbox_color) > 1:
+ facecolor = [c / 255.0 for c in self.bbox_color]
+ else:
+ facecolor = self.bbox_color
+
+ self.draw_texts(
+ label_text,
+ pos,
+ colors=self.text_color,
+ font_sizes=int(13 * scales[i]),
+ vertical_alignments='bottom',
+ bboxes=[{
+ 'facecolor': facecolor,
+ 'alpha': 0.8,
+ 'pad': 0.7,
+ 'edgecolor': 'none'
+ }])
+
+ return self.get_image()
+
+ def _draw_instances_kpts(self,
+ image: np.ndarray,
+ instances: InstanceData,
+ kpt_thr: float = 0.3,
+ show_kpt_idx: bool = False,
+ skeleton_style: str = 'mmpose'):
+ """Draw keypoints and skeletons (optional) of GT or prediction.
+
+ Args:
+ image (np.ndarray): The image to draw.
+ instances (:obj:`InstanceData`): Data structure for
+ instance-level annotations or predictions.
+ kpt_thr (float, optional): Minimum threshold of keypoints
+ to be shown. Default: 0.3.
+ show_kpt_idx (bool): Whether to show the index of keypoints.
+ Defaults to ``False``
+ skeleton_style (str): Skeleton style selection. Defaults to
+ ``'mmpose'``
+
+ Returns:
+ np.ndarray: the drawn image which channel is RGB.
+ """
+
+ self.set_image(image)
+ img_h, img_w, _ = image.shape
+
+ if 'keypoints' in instances:
+ keypoints = instances.get('transformed_keypoints',
+ instances.keypoints)
+
+ if 'keypoint_scores' in instances:
+ scores = instances.keypoint_scores
+ else:
+ scores = np.ones(keypoints.shape[:-1])
+
+ if 'keypoints_visible' in instances:
+ keypoints_visible = instances.keypoints_visible
+ else:
+ keypoints_visible = np.ones(keypoints.shape[:-1])
+
+ if skeleton_style == 'openpose':
+ keypoints_info = np.concatenate(
+ (keypoints, scores[..., None], keypoints_visible[...,
+ None]),
+ axis=-1)
+ # compute neck joint
+ neck = np.mean(keypoints_info[:, [5, 6]], axis=1)
+ # neck score when visualizing pred
+ neck[:, 2:4] = np.logical_and(
+ keypoints_info[:, 5, 2:4] > kpt_thr,
+ keypoints_info[:, 6, 2:4] > kpt_thr).astype(int)
+ new_keypoints_info = np.insert(
+ keypoints_info, 17, neck, axis=1)
+
+ mmpose_idx = [
+ 17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3
+ ]
+ openpose_idx = [
+ 1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17
+ ]
+ new_keypoints_info[:, openpose_idx] = \
+ new_keypoints_info[:, mmpose_idx]
+ keypoints_info = new_keypoints_info
+
+ keypoints, scores, keypoints_visible = keypoints_info[
+ ..., :2], keypoints_info[..., 2], keypoints_info[..., 3]
+
+ for kpts, score, visible in zip(keypoints, scores,
+ keypoints_visible):
+ kpts = np.array(kpts, copy=False)
+
+ if self.kpt_color is None or isinstance(self.kpt_color, str):
+ kpt_color = [self.kpt_color] * len(kpts)
+ elif len(self.kpt_color) == len(kpts):
+ kpt_color = self.kpt_color
+ else:
+ raise ValueError(
+ f'the length of kpt_color '
+                        f'({len(self.kpt_color)}) does not match '
+ f'that of keypoints ({len(kpts)})')
+
+ # draw links
+ if self.skeleton is not None and self.link_color is not None:
+ if self.link_color is None or isinstance(
+ self.link_color, str):
+ link_color = [self.link_color] * len(self.skeleton)
+ elif len(self.link_color) == len(self.skeleton):
+ link_color = self.link_color
+ else:
+ raise ValueError(
+ f'the length of link_color '
+                            f'({len(self.link_color)}) does not match '
+ f'that of skeleton ({len(self.skeleton)})')
+
+ for sk_id, sk in enumerate(self.skeleton):
+ pos1 = (int(kpts[sk[0], 0]), int(kpts[sk[0], 1]))
+ pos2 = (int(kpts[sk[1], 0]), int(kpts[sk[1], 1]))
+ if not (visible[sk[0]] and visible[sk[1]]):
+ continue
+
+ if (pos1[0] <= 0 or pos1[0] >= img_w or pos1[1] <= 0
+ or pos1[1] >= img_h or pos2[0] <= 0
+ or pos2[0] >= img_w or pos2[1] <= 0
+ or pos2[1] >= img_h or score[sk[0]] < kpt_thr
+ or score[sk[1]] < kpt_thr
+ or link_color[sk_id] is None):
+ # skip the link that should not be drawn
+ continue
+ X = np.array((pos1[0], pos2[0]))
+ Y = np.array((pos1[1], pos2[1]))
+ color = link_color[sk_id]
+ if not isinstance(color, str):
+ color = tuple(int(c) for c in color)
+ transparency = self.alpha
+ if self.show_keypoint_weight:
+ transparency *= max(
+ 0, min(1, 0.5 * (score[sk[0]] + score[sk[1]])))
+
+ if skeleton_style == 'openpose':
+ mX = np.mean(X)
+ mY = np.mean(Y)
+ length = ((Y[0] - Y[1])**2 + (X[0] - X[1])**2)**0.5
+ angle = math.degrees(
+ math.atan2(Y[0] - Y[1], X[0] - X[1]))
+ stickwidth = 2
+ polygons = cv2.ellipse2Poly(
+ (int(mX), int(mY)),
+ (int(length / 2), int(stickwidth)), int(angle),
+ 0, 360, 1)
+
+ self.draw_polygons(
+ polygons,
+ edge_colors=color,
+ face_colors=color,
+ alpha=transparency)
+
+ else:
+ self.draw_lines(
+ X, Y, color, line_widths=self.line_width)
+
+ # draw each point on image
+ for kid, kpt in enumerate(kpts):
+ if score[kid] < kpt_thr or not visible[
+ kid] or kpt_color[kid] is None:
+ # skip the point that should not be drawn
+ continue
+
+ color = kpt_color[kid]
+ if not isinstance(color, str):
+ color = tuple(int(c) for c in color)
+ transparency = self.alpha
+ if self.show_keypoint_weight:
+ transparency *= max(0, min(1, score[kid]))
+ self.draw_circles(
+ kpt,
+ radius=np.array([self.radius]),
+ face_colors=color,
+ edge_colors=color,
+ alpha=transparency,
+ line_widths=self.radius)
+ if show_kpt_idx:
+ kpt[0] += self.radius
+ kpt[1] -= self.radius
+ self.draw_texts(
+ str(kid),
+ kpt,
+ colors=color,
+ font_sizes=self.radius * 3,
+ vertical_alignments='bottom',
+ horizontal_alignments='center')
+
+ return self.get_image()
+
+ def _draw_instance_heatmap(
+ self,
+ fields: PixelData,
+ overlaid_image: Optional[np.ndarray] = None,
+ ):
+ """Draw heatmaps of GT or prediction.
+
+ Args:
+ fields (:obj:`PixelData`): Data structure for
+ pixel-level annotations or predictions.
+ overlaid_image (np.ndarray): The image to draw.
+
+ Returns:
+ np.ndarray: the drawn image which channel is RGB.
+ """
+ if 'heatmaps' not in fields:
+ return None
+ heatmaps = fields.heatmaps
+ if isinstance(heatmaps, np.ndarray):
+ heatmaps = torch.from_numpy(heatmaps)
+ if heatmaps.dim() == 3:
+ heatmaps, _ = heatmaps.max(dim=0)
+ heatmaps = heatmaps.unsqueeze(0)
+ out_image = self.draw_featmap(heatmaps, overlaid_image)
+ return out_image
+
+ def _draw_instance_xy_heatmap(
+ self,
+ fields: PixelData,
+ overlaid_image: Optional[np.ndarray] = None,
+ n: int = 20,
+ ):
+ """Draw heatmaps of GT or prediction.
+
+ Args:
+ fields (:obj:`PixelData`): Data structure for
+ pixel-level annotations or predictions.
+ overlaid_image (np.ndarray): The image to draw.
+            n (int): Number of keypoints, up to 20.
+
+ Returns:
+ np.ndarray: the drawn image which channel is RGB.
+ """
+ if 'heatmaps' not in fields:
+ return None
+ heatmaps = fields.heatmaps
+ _, h, w = heatmaps.shape
+ if isinstance(heatmaps, np.ndarray):
+ heatmaps = torch.from_numpy(heatmaps)
+ out_image = SimCCVisualizer().draw_instance_xy_heatmap(
+ heatmaps, overlaid_image, n)
+ out_image = cv2.resize(out_image[:, :, ::-1], (w, h))
+ return out_image
+
+ @master_only
+ def add_datasample(self,
+ name: str,
+ image: np.ndarray,
+ data_sample: PoseDataSample,
+ draw_gt: bool = True,
+ draw_pred: bool = True,
+ draw_heatmap: bool = False,
+ draw_bbox: bool = False,
+ show_kpt_idx: bool = False,
+ skeleton_style: str = 'mmpose',
+ show: bool = False,
+ wait_time: float = 0,
+ out_file: Optional[str] = None,
+ kpt_thr: float = 0.3,
+ step: int = 0) -> None:
+ """Draw datasample and save to all backends.
+
+ - If GT and prediction are plotted at the same time, they are
+ displayed in a stitched image where the left image is the
+ ground truth and the right image is the prediction.
+ - If ``show`` is True, all storage backends are ignored, and
+ the images will be displayed in a local window.
+ - If ``out_file`` is specified, the drawn image will be
+          saved to ``out_file``. It is usually used when the display
+ is not available.
+
+ Args:
+ name (str): The image identifier
+ image (np.ndarray): The image to draw
+ data_sample (:obj:`PoseDataSample`, optional): The data sample
+ to visualize
+ draw_gt (bool): Whether to draw GT PoseDataSample. Default to
+ ``True``
+ draw_pred (bool): Whether to draw Prediction PoseDataSample.
+ Defaults to ``True``
+ draw_bbox (bool): Whether to draw bounding boxes. Default to
+ ``False``
+ draw_heatmap (bool): Whether to draw heatmaps. Defaults to
+ ``False``
+ show_kpt_idx (bool): Whether to show the index of keypoints.
+ Defaults to ``False``
+ skeleton_style (str): Skeleton style selection. Defaults to
+ ``'mmpose'``
+ show (bool): Whether to display the drawn image. Default to
+ ``False``
+ wait_time (float): The interval of show (s). Defaults to 0
+ out_file (str): Path to output file. Defaults to ``None``
+ kpt_thr (float, optional): Minimum threshold of keypoints
+ to be shown. Default: 0.3.
+ step (int): Global step value to record. Defaults to 0
+ """
+
+ gt_img_data = None
+ pred_img_data = None
+
+ if draw_gt:
+ gt_img_data = image.copy()
+ gt_img_heatmap = None
+
+ # draw bboxes & keypoints
+ if 'gt_instances' in data_sample:
+ gt_img_data = self._draw_instances_kpts(
+ gt_img_data, data_sample.gt_instances, kpt_thr,
+ show_kpt_idx, skeleton_style)
+ if draw_bbox:
+ gt_img_data = self._draw_instances_bbox(
+ gt_img_data, data_sample.gt_instances)
+
+ # draw heatmaps
+ if 'gt_fields' in data_sample and draw_heatmap:
+ gt_img_heatmap = self._draw_instance_heatmap(
+ data_sample.gt_fields, image)
+ if gt_img_heatmap is not None:
+ gt_img_data = np.concatenate((gt_img_data, gt_img_heatmap),
+ axis=0)
+
+ if draw_pred:
+ pred_img_data = image.copy()
+ pred_img_heatmap = None
+
+ # draw bboxes & keypoints
+ if 'pred_instances' in data_sample:
+ pred_img_data = self._draw_instances_kpts(
+ pred_img_data, data_sample.pred_instances, kpt_thr,
+ show_kpt_idx, skeleton_style)
+ if draw_bbox:
+ pred_img_data = self._draw_instances_bbox(
+ pred_img_data, data_sample.pred_instances)
+
+ # draw heatmaps
+ if 'pred_fields' in data_sample and draw_heatmap:
+ if 'keypoint_x_labels' in data_sample.pred_instances:
+ pred_img_heatmap = self._draw_instance_xy_heatmap(
+ data_sample.pred_fields, image)
+ else:
+ pred_img_heatmap = self._draw_instance_heatmap(
+ data_sample.pred_fields, image)
+ if pred_img_heatmap is not None:
+ pred_img_data = np.concatenate(
+ (pred_img_data, pred_img_heatmap), axis=0)
+
+ # merge visualization results
+ if gt_img_data is not None and pred_img_data is not None:
+ if gt_img_heatmap is None and pred_img_heatmap is not None:
+ gt_img_data = np.concatenate((gt_img_data, image), axis=0)
+ elif gt_img_heatmap is not None and pred_img_heatmap is None:
+ pred_img_data = np.concatenate((pred_img_data, image), axis=0)
+
+ drawn_img = np.concatenate((gt_img_data, pred_img_data), axis=1)
+
+ elif gt_img_data is not None:
+ drawn_img = gt_img_data
+ else:
+ drawn_img = pred_img_data
+
+ # It is convenient for users to obtain the drawn image.
+ # For example, the user wants to obtain the drawn image and
+ # save it as a video during video inference.
+ self.set_image(drawn_img)
+
+ if show:
+ self.show(drawn_img, win_name=name, wait_time=wait_time)
+
+ if out_file is not None:
+ mmcv.imwrite(drawn_img[..., ::-1], out_file)
+ else:
+ # save drawn_img to backends
+ self.add_image(name, drawn_img, step)
+
+ return self.get_image()
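
Complementing the class docstring example above, a sketch that also overlays predicted heatmaps (not part of the diff; all data below is random dummy data and relies on mmengine's `draw_featmap` via `_draw_instance_heatmap`).

# Sketch: draw predictions together with their heatmaps (dummy data).
import numpy as np
from mmengine.structures import InstanceData, PixelData

from mmpose.structures import PoseDataSample
from mmpose.visualization import PoseLocalVisualizer

visualizer = PoseLocalVisualizer(radius=2, line_width=2)
visualizer.set_dataset_meta({'skeleton_links': [[0, 1], [1, 2]]})

image = np.random.randint(0, 256, size=(64, 48, 3), dtype=np.uint8)

pred_instances = InstanceData()
pred_instances.keypoints = np.array([[[10., 10.], [20., 20.], [30., 40.]]])
pred_instances.keypoint_scores = np.array([[0.9, 0.8, 0.7]])

pred_fields = PixelData()
pred_fields.heatmaps = np.random.rand(3, 64, 48).astype(np.float32)

data_sample = PoseDataSample()
data_sample.pred_instances = pred_instances
data_sample.pred_fields = pred_fields

drawn = visualizer.add_datasample(
    'demo', image, data_sample, draw_gt=False, draw_heatmap=True)
print(drawn.shape)  # keypoint overlay stacked on top of the heatmap overlay
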
diff --git a/mmpose/visualization/opencv_backend_visualizer.py b/mmpose/visualization/opencv_backend_visualizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..66a7731c76e19a8e7f719a24748b64f316326084
--- /dev/null
+++ b/mmpose/visualization/opencv_backend_visualizer.py
@@ -0,0 +1,444 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional, Union
+
+import cv2
+import mmcv
+import numpy as np
+import torch
+from mmengine.dist import master_only
+from mmengine.visualization import Visualizer
+
+
+class OpencvBackendVisualizer(Visualizer):
+ """Base visualizer with opencv backend support.
+
+ Args:
+ name (str): Name of the instance. Defaults to 'visualizer'.
+ image (np.ndarray, optional): the origin image to draw. The format
+ should be RGB. Defaults to None.
+ vis_backends (list, optional): Visual backend config list.
+ Defaults to None.
+ save_dir (str, optional): Save file dir for all storage backends.
+ If it is None, the backend storage will not save any data.
+ fig_save_cfg (dict): Keyword parameters of figure for saving.
+ Defaults to empty dict.
+ fig_show_cfg (dict): Keyword parameters of figure for showing.
+ Defaults to empty dict.
+ backend (str): Backend used to draw elements on the image and display
+ the image. Defaults to 'matplotlib'.
+ """
+
+ def __init__(self,
+ name='visualizer',
+ backend: str = 'matplotlib',
+ *args,
+ **kwargs):
+ super().__init__(name, *args, **kwargs)
+ assert backend in ('opencv', 'matplotlib'), f'the argument ' \
+ f'\'backend\' must be either \'opencv\' or \'matplotlib\', ' \
+ f'but got \'{backend}\'.'
+ self.backend = backend
+
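+    # Illustrative usage sketch (comments only, not part of the API); it
+    # assumes ``img`` is an RGB uint8 ``np.ndarray`` provided by the caller:
+    #
+    #     vis = OpencvBackendVisualizer(backend='opencv')
+    #     vis.set_image(img)
+    #     vis.draw_circles(np.array([50, 50]), np.array([5]),
+    #                      face_colors='red')
+    #     vis.draw_bboxes(np.array([[10, 10, 100, 100]]), edge_colors='green')
+    #     canvas = vis.get_image()  # RGB array with the drawings applied
+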
+ @master_only
+ def set_image(self, image: np.ndarray) -> None:
+ """Set the image to draw.
+
+ Args:
+ image (np.ndarray): The image to draw.
+ """
+ assert image is not None
+ image = image.astype('uint8')
+ self._image = image
+ self.width, self.height = image.shape[1], image.shape[0]
+ self._default_font_size = max(
+ np.sqrt(self.height * self.width) // 90, 10)
+
+ if self.backend == 'matplotlib':
+ # add a small 1e-2 to avoid precision lost due to matplotlib's
+ # truncation (https://github.com/matplotlib/matplotlib/issues/15363) # noqa
+ self.fig_save.set_size_inches( # type: ignore
+ (self.width + 1e-2) / self.dpi,
+ (self.height + 1e-2) / self.dpi)
+ # self.canvas = mpl.backends.backend_cairo.FigureCanvasCairo(fig)
+ self.ax_save.cla()
+ self.ax_save.axis(False)
+ self.ax_save.imshow(
+ image,
+ extent=(0, self.width, self.height, 0),
+ interpolation='none')
+
+ @master_only
+ def get_image(self) -> np.ndarray:
+ """Get the drawn image. The format is RGB.
+
+ Returns:
+            np.ndarray: The drawn image, with channels in RGB order.
+ """
+ assert self._image is not None, 'Please set image using `set_image`'
+ if self.backend == 'matplotlib':
+ return super().get_image()
+ else:
+ return self._image
+
+ @master_only
+ def draw_circles(self,
+ center: Union[np.ndarray, torch.Tensor],
+ radius: Union[np.ndarray, torch.Tensor],
+ face_colors: Union[str, tuple, List[str],
+ List[tuple]] = 'none',
+ **kwargs) -> 'Visualizer':
+ """Draw single or multiple circles.
+
+        Args:
+            center (Union[np.ndarray, torch.Tensor]): The x and y
+                coordinates of the circle centers.
+            radius (Union[np.ndarray, torch.Tensor]): The radius of each
+                circle.
+            edge_colors (Union[str, tuple, List[str], List[tuple]]): The
+                edge colors of circles. ``edge_colors`` can have the same
+                length as the circles or just a single value. If it is a
+                single value, all the circles will have the same edge color.
+                Reference to
+                https://matplotlib.org/stable/gallery/color/named_colors.html
+                for more details. Defaults to 'g'.
+ line_styles (Union[str, List[str]]): The linestyle
+ of lines. ``line_styles`` can have the same length with
+ texts or just single value. If ``line_styles`` is single
+ value, all the lines will have the same linestyle.
+ Reference to
+ https://matplotlib.org/stable/api/collections_api.html?highlight=collection#matplotlib.collections.AsteriskPolygonCollection.set_linestyle
+ for more details. Defaults to '-'.
+ line_widths (Union[Union[int, float], List[Union[int, float]]]):
+ The linewidth of lines. ``line_widths`` can have
+ the same length with lines or just single value.
+ If ``line_widths`` is single value, all the lines will
+ have the same linewidth. Defaults to 2.
+ face_colors (Union[str, tuple, List[str], List[tuple]]):
+ The face colors. Defaults to None.
+ alpha (Union[int, float]): The transparency of circles.
+ Defaults to 0.8.
+ """
+ if self.backend == 'matplotlib':
+ super().draw_circles(
+ center=center,
+ radius=radius,
+ face_colors=face_colors,
+ **kwargs)
+ elif self.backend == 'opencv':
+ if isinstance(face_colors, str):
+ face_colors = mmcv.color_val(face_colors)
+ self._image = cv2.circle(self._image,
+ (int(center[0]), int(center[1])),
+ int(radius), face_colors, -1)
+ else:
+ raise ValueError(f'got unsupported backend {self.backend}')
+
+ @master_only
+ def draw_texts(
+ self,
+ texts: Union[str, List[str]],
+ positions: Union[np.ndarray, torch.Tensor],
+ font_sizes: Optional[Union[int, List[int]]] = None,
+ colors: Union[str, tuple, List[str], List[tuple]] = 'g',
+ vertical_alignments: Union[str, List[str]] = 'top',
+ horizontal_alignments: Union[str, List[str]] = 'left',
+ bboxes: Optional[Union[dict, List[dict]]] = None,
+ **kwargs,
+ ) -> 'Visualizer':
+ """Draw single or multiple text boxes.
+
+ Args:
+ texts (Union[str, List[str]]): Texts to draw.
+ positions (Union[np.ndarray, torch.Tensor]): The position to draw
+ the texts, which should have the same length with texts and
+ each dim contain x and y.
+ font_sizes (Union[int, List[int]], optional): The font size of
+ texts. ``font_sizes`` can have the same length with texts or
+ just single value. If ``font_sizes`` is single value, all the
+ texts will have the same font size. Defaults to None.
+ colors (Union[str, tuple, List[str], List[tuple]]): The colors
+ of texts. ``colors`` can have the same length with texts or
+ just single value. If ``colors`` is single value, all the
+ texts will have the same colors. Reference to
+ https://matplotlib.org/stable/gallery/color/named_colors.html
+                for more details. Defaults to 'g'.
+ vertical_alignments (Union[str, List[str]]): The verticalalignment
+ of texts. verticalalignment controls whether the y positional
+ argument for the text indicates the bottom, center or top side
+ of the text bounding box.
+ ``vertical_alignments`` can have the same length with
+ texts or just single value. If ``vertical_alignments`` is
+ single value, all the texts will have the same
+ verticalalignment. verticalalignment can be 'center' or
+ 'top', 'bottom' or 'baseline'. Defaults to 'top'.
+ horizontal_alignments (Union[str, List[str]]): The
+ horizontalalignment of texts. Horizontalalignment controls
+ whether the x positional argument for the text indicates the
+ left, center or right side of the text bounding box.
+ ``horizontal_alignments`` can have
+ the same length with texts or just single value.
+ If ``horizontal_alignments`` is single value, all the texts
+ will have the same horizontalalignment. Horizontalalignment
+ can be 'center','right' or 'left'. Defaults to 'left'.
+ font_families (Union[str, List[str]]): The font family of
+ texts. ``font_families`` can have the same length with texts or
+ just single value. If ``font_families`` is single value, all
+ the texts will have the same font family.
+ font_familiy can be 'serif', 'sans-serif', 'cursive', 'fantasy'
+ or 'monospace'. Defaults to 'sans-serif'.
+ bboxes (Union[dict, List[dict]], optional): The bounding box of the
+ texts. If bboxes is None, there are no bounding box around
+ texts. ``bboxes`` can have the same length with texts or
+ just single value. If ``bboxes`` is single value, all
+ the texts will have the same bbox. Reference to
+ https://matplotlib.org/stable/api/_as_gen/matplotlib.patches.FancyBboxPatch.html#matplotlib.patches.FancyBboxPatch
+ for more details. Defaults to None.
+ font_properties (Union[FontProperties, List[FontProperties]], optional):
+ The font properties of texts. FontProperties is
+ a ``font_manager.FontProperties()`` object.
+ If you want to draw Chinese texts, you need to prepare
+ a font file that can show Chinese characters properly.
+ For example: `simhei.ttf`, `simsun.ttc`, `simkai.ttf` and so on.
+ Then set ``font_properties=matplotlib.font_manager.FontProperties(fname='path/to/font_file')``
+ ``font_properties`` can have the same length with texts or
+ just single value. If ``font_properties`` is single value,
+ all the texts will have the same font properties.
+ Defaults to None.
+ `New in version 0.6.0.`
+ """ # noqa: E501
+
+ if self.backend == 'matplotlib':
+ super().draw_texts(
+ texts=texts,
+ positions=positions,
+ font_sizes=font_sizes,
+ colors=colors,
+ vertical_alignments=vertical_alignments,
+ horizontal_alignments=horizontal_alignments,
+ bboxes=bboxes,
+ **kwargs)
+
+ elif self.backend == 'opencv':
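+            # Approximate the requested font size with an OpenCV font scale
+            # and thickness (assumes ``font_sizes`` is a single scalar here).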
+ font_scale = max(0.1, font_sizes / 30)
+ thickness = max(1, font_sizes // 15)
+
+ text_size, text_baseline = cv2.getTextSize(texts,
+ cv2.FONT_HERSHEY_DUPLEX,
+ font_scale, thickness)
+
+ x = int(positions[0])
+ if horizontal_alignments == 'right':
+ x = max(0, x - text_size[0])
+ y = int(positions[1])
+ if vertical_alignments == 'top':
+ y = min(self.height, y + text_size[1])
+
+ if bboxes is not None:
+ bbox_color = bboxes[0]['facecolor']
+ if isinstance(bbox_color, str):
+ bbox_color = mmcv.color_val(bbox_color)
+
+ y = y - text_baseline // 2
+ self._image = cv2.rectangle(
+ self._image, (x, y - text_size[1] - text_baseline // 2),
+ (x + text_size[0], y + text_baseline // 2), bbox_color,
+ cv2.FILLED)
+
+ self._image = cv2.putText(self._image, texts, (x, y),
+ cv2.FONT_HERSHEY_SIMPLEX, font_scale,
+ colors, thickness - 1)
+ else:
+ raise ValueError(f'got unsupported backend {self.backend}')
+
+ @master_only
+ def draw_bboxes(self,
+ bboxes: Union[np.ndarray, torch.Tensor],
+ edge_colors: Union[str, tuple, List[str],
+ List[tuple]] = 'g',
+ line_widths: Union[Union[int, float],
+ List[Union[int, float]]] = 2,
+ **kwargs) -> 'Visualizer':
+ """Draw single or multiple bboxes.
+
+ Args:
+ bboxes (Union[np.ndarray, torch.Tensor]): The bboxes to draw with
+                the format of (x1, y1, x2, y2).
+ edge_colors (Union[str, tuple, List[str], List[tuple]]): The
+ colors of bboxes. ``colors`` can have the same length with
+ lines or just single value. If ``colors`` is single value, all
+ the lines will have the same colors. Refer to `matplotlib.
+ colors` for full list of formats that are accepted.
+ Defaults to 'g'.
+ line_styles (Union[str, List[str]]): The linestyle
+ of lines. ``line_styles`` can have the same length with
+ texts or just single value. If ``line_styles`` is single
+ value, all the lines will have the same linestyle.
+ Reference to
+ https://matplotlib.org/stable/api/collections_api.html?highlight=collection#matplotlib.collections.AsteriskPolygonCollection.set_linestyle
+ for more details. Defaults to '-'.
+ line_widths (Union[Union[int, float], List[Union[int, float]]]):
+ The linewidth of lines. ``line_widths`` can have
+ the same length with lines or just single value.
+ If ``line_widths`` is single value, all the lines will
+ have the same linewidth. Defaults to 2.
+ face_colors (Union[str, tuple, List[str], List[tuple]]):
+ The face colors. Defaults to None.
+ alpha (Union[int, float]): The transparency of bboxes.
+ Defaults to 0.8.
+ """
+ if self.backend == 'matplotlib':
+ super().draw_bboxes(
+ bboxes=bboxes,
+ edge_colors=edge_colors,
+ line_widths=line_widths,
+ **kwargs)
+
+ elif self.backend == 'opencv':
+ self._image = mmcv.imshow_bboxes(
+ self._image,
+ bboxes,
+ edge_colors,
+ top_k=-1,
+ thickness=line_widths,
+ show=False)
+ else:
+ raise ValueError(f'got unsupported backend {self.backend}')
+
+ @master_only
+ def draw_lines(self,
+ x_datas: Union[np.ndarray, torch.Tensor],
+ y_datas: Union[np.ndarray, torch.Tensor],
+ colors: Union[str, tuple, List[str], List[tuple]] = 'g',
+ line_widths: Union[Union[int, float],
+ List[Union[int, float]]] = 2,
+ **kwargs) -> 'Visualizer':
+ """Draw single or multiple line segments.
+
+ Args:
+            x_datas (Union[np.ndarray, torch.Tensor]): The x coordinates of
+                each line's start and end points.
+            y_datas (Union[np.ndarray, torch.Tensor]): The y coordinates of
+                each line's start and end points.
+ colors (Union[str, tuple, List[str], List[tuple]]): The colors of
+ lines. ``colors`` can have the same length with lines or just
+ single value. If ``colors`` is single value, all the lines
+ will have the same colors. Reference to
+ https://matplotlib.org/stable/gallery/color/named_colors.html
+ for more details. Defaults to 'g'.
+ line_styles (Union[str, List[str]]): The linestyle
+ of lines. ``line_styles`` can have the same length with
+ texts or just single value. If ``line_styles`` is single
+ value, all the lines will have the same linestyle.
+ Reference to
+ https://matplotlib.org/stable/api/collections_api.html?highlight=collection#matplotlib.collections.AsteriskPolygonCollection.set_linestyle
+ for more details. Defaults to '-'.
+ line_widths (Union[Union[int, float], List[Union[int, float]]]):
+ The linewidth of lines. ``line_widths`` can have
+ the same length with lines or just single value.
+ If ``line_widths`` is single value, all the lines will
+ have the same linewidth. Defaults to 2.
+ """
+ if self.backend == 'matplotlib':
+ super().draw_lines(
+ x_datas=x_datas,
+ y_datas=y_datas,
+ colors=colors,
+ line_widths=line_widths,
+ **kwargs)
+
+ elif self.backend == 'opencv':
+
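+            # Note: the OpenCV backend draws a single segment from
+            # (x_datas[0], y_datas[0]) to (x_datas[1], y_datas[1]).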
+ self._image = cv2.line(
+ self._image, (x_datas[0], y_datas[0]),
+ (x_datas[1], y_datas[1]),
+ colors,
+ thickness=line_widths)
+ else:
+ raise ValueError(f'got unsupported backend {self.backend}')
+
+ @master_only
+ def draw_polygons(self,
+ polygons: Union[Union[np.ndarray, torch.Tensor],
+ List[Union[np.ndarray, torch.Tensor]]],
+ edge_colors: Union[str, tuple, List[str],
+ List[tuple]] = 'g',
+ **kwargs) -> 'Visualizer':
+        """Draw single or multiple polygons.
+
+ Args:
+ polygons (Union[Union[np.ndarray, torch.Tensor],\
+ List[Union[np.ndarray, torch.Tensor]]]): The polygons to draw
+ with the format of (x1,y1,x2,y2,...,xn,yn).
+ edge_colors (Union[str, tuple, List[str], List[tuple]]): The
+ colors of polygons. ``colors`` can have the same length with
+ lines or just single value. If ``colors`` is single value,
+ all the lines will have the same colors. Refer to
+ `matplotlib.colors` for full list of formats that are accepted.
+                Defaults to 'g'.
+ line_styles (Union[str, List[str]]): The linestyle
+ of lines. ``line_styles`` can have the same length with
+ texts or just single value. If ``line_styles`` is single
+ value, all the lines will have the same linestyle.
+ Reference to
+ https://matplotlib.org/stable/api/collections_api.html?highlight=collection#matplotlib.collections.AsteriskPolygonCollection.set_linestyle
+ for more details. Defaults to '-'.
+ line_widths (Union[Union[int, float], List[Union[int, float]]]):
+ The linewidth of lines. ``line_widths`` can have
+ the same length with lines or just single value.
+ If ``line_widths`` is single value, all the lines will
+ have the same linewidth. Defaults to 2.
+ face_colors (Union[str, tuple, List[str], List[tuple]]):
+ The face colors. Defaults to None.
+ alpha (Union[int, float]): The transparency of polygons.
+ Defaults to 0.8.
+ """
+ if self.backend == 'matplotlib':
+ super().draw_polygons(
+ polygons=polygons, edge_colors=edge_colors, **kwargs)
+
+ elif self.backend == 'opencv':
+
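+            # Note: the OpenCV backend fills the (convex) polygon with
+            # ``edge_colors`` via ``cv2.fillConvexPoly`` rather than only
+            # stroking its outline.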
+ self._image = cv2.fillConvexPoly(self._image, polygons,
+ edge_colors)
+ else:
+ raise ValueError(f'got unsupported backend {self.backend}')
+
+ @master_only
+ def show(self,
+ drawn_img: Optional[np.ndarray] = None,
+ win_name: str = 'image',
+ wait_time: float = 0.,
+ continue_key=' ') -> None:
+ """Show the drawn image.
+
+ Args:
+ drawn_img (np.ndarray, optional): The image to show. If drawn_img
+ is None, it will show the image got by Visualizer. Defaults
+ to None.
+ win_name (str): The image title. Defaults to 'image'.
+ wait_time (float): Delay in seconds. 0 is the special
+ value that means "forever". Defaults to 0.
+ continue_key (str): The key for users to continue. Defaults to
+ the space key.
+ """
+ if self.backend == 'matplotlib':
+ super().show(
+ drawn_img=drawn_img,
+ win_name=win_name,
+ wait_time=wait_time,
+ continue_key=continue_key)
+
+ elif self.backend == 'opencv':
+            # Keep showing images in the same window; only the window title
+            # is updated with `win_name`.
+            if not hasattr(self, '_cv_win_name'):
+ self._cv_win_name = win_name
+ cv2.namedWindow(winname=f'{id(self)}')
+ cv2.setWindowTitle(f'{id(self)}', win_name)
+ else:
+ cv2.setWindowTitle(f'{id(self)}', win_name)
+ shown_img = self.get_image() if drawn_img is None else drawn_img
+ cv2.imshow(str(id(self)), mmcv.bgr2rgb(shown_img))
+ cv2.waitKey(int(np.ceil(wait_time * 1000)))
+ else:
+ raise ValueError(f'got unsupported backend {self.backend}')
diff --git a/mmpose/visualization/simcc_vis.py b/mmpose/visualization/simcc_vis.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a5b602fb5c4ffe2a46ddb2cf09a2cd4501b1664
--- /dev/null
+++ b/mmpose/visualization/simcc_vis.py
@@ -0,0 +1,136 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional, Union
+
+import cv2 as cv
+import numpy as np
+import torch
+from torchvision.transforms import ToPILImage
+
+
+class SimCCVisualizer:
+
+ def draw_instance_xy_heatmap(self,
+ heatmap: torch.Tensor,
+ overlaid_image: Optional[np.ndarray],
+ n: int = 20,
+ mix: bool = True,
+ weight: float = 0.5):
+ """Draw heatmaps of GT or prediction.
+
+ Args:
+ heatmap (torch.Tensor): Tensor of heatmap.
+ overlaid_image (np.ndarray): The image to draw.
+            n (int): Number of keypoints to draw, up to 20.
+            mix (bool): Whether to overlay the heatmap on the original image.
+ weight (float): Weight of original image during fusion.
+
+ Returns:
+            np.ndarray: The drawn image, with channels in RGB order.
+ """
+ heatmap2d = heatmap.data.max(0, keepdim=True)[0]
+ xy_heatmap, K = self.split_simcc_xy(heatmap)
+ K = K if K <= n else n
+ blank_size = tuple(heatmap.size()[1:])
+ maps = {'x': [], 'y': []}
+ for i in xy_heatmap:
+ x, y = self.draw_1d_heatmaps(i['x']), self.draw_1d_heatmaps(i['y'])
+ maps['x'].append(x)
+ maps['y'].append(y)
+ white = self.creat_blank(blank_size, K)
+ map2d = self.draw_2d_heatmaps(heatmap2d)
+ if mix:
+ map2d = cv.addWeighted(overlaid_image, 1 - weight, map2d, weight,
+ 0)
+ self.image_cover(white, map2d, int(blank_size[1] * 0.1),
+ int(blank_size[0] * 0.1))
+ white = self.add_1d_heatmaps(maps, white, blank_size, K)
+ return white
+
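+    # Illustrative usage sketch (comments only); it assumes ``simcc_heatmap``
+    # is a (K, H, W) torch.Tensor of per-keypoint heatmaps and ``img`` is the
+    # corresponding RGB image resized to (H, W):
+    #
+    #     vis = SimCCVisualizer()
+    #     canvas = vis.draw_instance_xy_heatmap(simcc_heatmap, img, n=17)
+    #     # ``canvas`` holds the fused 2D heatmap with the per-keypoint 1D
+    #     # x/y heatmaps pasted around it.
+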
+ def split_simcc_xy(self, heatmap: Union[np.ndarray, torch.Tensor]):
+        """Extract one-dimensional heatmaps from the two-dimensional heatmap
+        and calculate the number of keypoints."""
+ size = heatmap.size()
+ k = size[0] if size[0] <= 20 else 20
+ maps = []
+ for _ in range(k):
+ xy_dict = {}
+ single_heatmap = heatmap[_]
+ xy_dict['x'], xy_dict['y'] = self.merge_maps(single_heatmap)
+ maps.append(xy_dict)
+ return maps, k
+
+ def merge_maps(self, map_2d):
+        """Synthesize one-dimensional heatmaps from the 2D heatmap by taking
+        the maximum along each axis."""
+ x = map_2d.data.max(0, keepdim=True)[0]
+ y = map_2d.data.max(1, keepdim=True)[0]
+ return x, y
+
+ def draw_1d_heatmaps(self, heatmap_1d):
+ """Draw one-dimensional heatmap."""
+ size = heatmap_1d.size()
+ length = max(size)
+ np_heatmap = ToPILImage()(heatmap_1d).convert('RGB')
+ cv_img = cv.cvtColor(np.asarray(np_heatmap), cv.COLOR_RGB2BGR)
+ if size[0] < size[1]:
+ cv_img = cv.resize(cv_img, (length, 15))
+ else:
+ cv_img = cv.resize(cv_img, (15, length))
+ single_map = cv.applyColorMap(cv_img, cv.COLORMAP_JET)
+ return single_map
+
+ def creat_blank(self,
+ size: Union[list, tuple],
+ K: int = 20,
+ interval: int = 10):
+ """Create the background."""
+ blank_height = int(
+ max(size[0] * 2, size[0] * 1.1 + (K + 1) * (15 + interval)))
+ blank_width = int(
+ max(size[1] * 2, size[1] * 1.1 + (K + 1) * (15 + interval)))
+ blank = np.zeros((blank_height, blank_width, 3), np.uint8)
+ blank.fill(255)
+ return blank
+
+ def draw_2d_heatmaps(self, heatmap_2d):
+        """Draw the two-dimensional heatmap as a color map."""
+ np_heatmap = ToPILImage()(heatmap_2d).convert('RGB')
+ cv_img = cv.cvtColor(np.asarray(np_heatmap), cv.COLOR_RGB2BGR)
+ map_2d = cv.applyColorMap(cv_img, cv.COLORMAP_JET)
+ return map_2d
+
+ def image_cover(self, background: np.ndarray, foreground: np.ndarray,
+ x: int, y: int):
+ """Paste the foreground on the background."""
+ fore_size = foreground.shape
+ background[y:y + fore_size[0], x:x + fore_size[1]] = foreground
+ return background
+
+ def add_1d_heatmaps(self,
+ maps: dict,
+ background: np.ndarray,
+ map2d_size: Union[tuple, list],
+ K: int,
+ interval: int = 10):
+ """Paste one-dimensional heatmaps onto the background in turn."""
+        y_startpoint = [int(1.1 * map2d_size[1]), int(0.1 * map2d_size[0])]
+        x_startpoint = [int(0.1 * map2d_size[1]), int(1.1 * map2d_size[0])]
+ x_startpoint[1] += interval * 2
+ y_startpoint[0] += interval * 2
+ add = interval + 10
+ for i in range(K):
+ self.image_cover(background, maps['x'][i], x_startpoint[0],
+ x_startpoint[1])
+ cv.putText(background, str(i),
+ (x_startpoint[0] - 30, x_startpoint[1] + 10),
+ cv.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)
+ self.image_cover(background, maps['y'][i], y_startpoint[0],
+ y_startpoint[1])
+ cv.putText(background, str(i),
+ (y_startpoint[0], y_startpoint[1] - 5),
+ cv.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)
+ x_startpoint[1] += add
+ y_startpoint[0] += add
+ return background[:x_startpoint[1] + y_startpoint[1] +
+ 1, :y_startpoint[0] + x_startpoint[0] + 1]
diff --git a/mmpretrain/__init__.py b/mmpretrain/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d065db1cd80c874b9fa93b84752575dd33e457b
--- /dev/null
+++ b/mmpretrain/__init__.py
@@ -0,0 +1,28 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import mmcv
+import mmengine
+from mmengine.utils import digit_version
+
+from .apis import * # noqa: F401, F403
+from .version import __version__
+
+mmcv_minimum_version = '2.0.0'
+mmcv_maximum_version = '2.1.0'
+mmcv_version = digit_version(mmcv.__version__)
+
+mmengine_minimum_version = '0.8.0'
+mmengine_maximum_version = '1.0.0'
+mmengine_version = digit_version(mmengine.__version__)
+
+assert (mmcv_version >= digit_version(mmcv_minimum_version)
+ and mmcv_version < digit_version(mmcv_maximum_version)), \
+ f'MMCV=={mmcv.__version__} is used but incompatible. ' \
+ f'Please install mmcv>={mmcv_minimum_version}, <{mmcv_maximum_version}.'
+
+assert (mmengine_version >= digit_version(mmengine_minimum_version)
+ and mmengine_version < digit_version(mmengine_maximum_version)), \
+ f'MMEngine=={mmengine.__version__} is used but incompatible. ' \
+ f'Please install mmengine>={mmengine_minimum_version}, ' \
+ f'<{mmengine_maximum_version}.'
+
+__all__ = ['__version__']
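+
+# Note (illustrative): ``digit_version`` converts a version string into a
+# comparable tuple, e.g. ``digit_version('2.0.1') >= digit_version('2.0.0')``
+# is True, which is what the range checks above rely on.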
diff --git a/mmpretrain/__pycache__/__init__.cpython-38.pyc b/mmpretrain/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..68b860b5bc183e15e5c451735306cb3d0e555b03
Binary files /dev/null and b/mmpretrain/__pycache__/__init__.cpython-38.pyc differ
diff --git a/mmpretrain/__pycache__/registry.cpython-38.pyc b/mmpretrain/__pycache__/registry.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6d0a1f36601180e535caba8842463df51bddd032
Binary files /dev/null and b/mmpretrain/__pycache__/registry.cpython-38.pyc differ
diff --git a/mmpretrain/__pycache__/version.cpython-38.pyc b/mmpretrain/__pycache__/version.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..129ab5cbf1c760ae3dd5425ada9c2f393f432b2d
Binary files /dev/null and b/mmpretrain/__pycache__/version.cpython-38.pyc differ
diff --git a/mmpretrain/apis/__init__.py b/mmpretrain/apis/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6fbf443772a983c41f7273124f843bdfbb7f0f46
--- /dev/null
+++ b/mmpretrain/apis/__init__.py
@@ -0,0 +1,22 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .base import BaseInferencer
+from .feature_extractor import FeatureExtractor
+from .image_caption import ImageCaptionInferencer
+from .image_classification import ImageClassificationInferencer
+from .image_retrieval import ImageRetrievalInferencer
+from .model import (ModelHub, get_model, inference_model, init_model,
+ list_models)
+from .multimodal_retrieval import (ImageToTextRetrievalInferencer,
+ TextToImageRetrievalInferencer)
+from .nlvr import NLVRInferencer
+from .visual_grounding import VisualGroundingInferencer
+from .visual_question_answering import VisualQuestionAnsweringInferencer
+
+__all__ = [
+ 'init_model', 'inference_model', 'list_models', 'get_model', 'ModelHub',
+ 'ImageClassificationInferencer', 'ImageRetrievalInferencer',
+ 'FeatureExtractor', 'ImageCaptionInferencer',
+ 'TextToImageRetrievalInferencer', 'VisualGroundingInferencer',
+ 'VisualQuestionAnsweringInferencer', 'ImageToTextRetrievalInferencer',
+ 'BaseInferencer', 'NLVRInferencer'
+]
diff --git a/mmpretrain/apis/__pycache__/__init__.cpython-38.pyc b/mmpretrain/apis/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..65bf34ee4d103bc7bda7306c6de0eeb3ff9e8ef0
Binary files /dev/null and b/mmpretrain/apis/__pycache__/__init__.cpython-38.pyc differ
diff --git a/mmpretrain/apis/__pycache__/base.cpython-38.pyc b/mmpretrain/apis/__pycache__/base.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cbbf2a9206b4c9cd8b7d92e211afe41f2eccae7d
Binary files /dev/null and b/mmpretrain/apis/__pycache__/base.cpython-38.pyc differ
diff --git a/mmpretrain/apis/__pycache__/feature_extractor.cpython-38.pyc b/mmpretrain/apis/__pycache__/feature_extractor.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..71ef9eb92b29c22d16482e623d7df6722545a017
Binary files /dev/null and b/mmpretrain/apis/__pycache__/feature_extractor.cpython-38.pyc differ
diff --git a/mmpretrain/apis/__pycache__/image_caption.cpython-38.pyc b/mmpretrain/apis/__pycache__/image_caption.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..97405f793331b9d57c219ab8be4da944ddc92418
Binary files /dev/null and b/mmpretrain/apis/__pycache__/image_caption.cpython-38.pyc differ
diff --git a/mmpretrain/apis/__pycache__/image_classification.cpython-38.pyc b/mmpretrain/apis/__pycache__/image_classification.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c15984ff844b44847a60346c962787c07c682b6a
Binary files /dev/null and b/mmpretrain/apis/__pycache__/image_classification.cpython-38.pyc differ
diff --git a/mmpretrain/apis/__pycache__/image_retrieval.cpython-38.pyc b/mmpretrain/apis/__pycache__/image_retrieval.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fefb69909a2d7b18db87b3fd24fcc0cceac79c87
Binary files /dev/null and b/mmpretrain/apis/__pycache__/image_retrieval.cpython-38.pyc differ
diff --git a/mmpretrain/apis/__pycache__/model.cpython-38.pyc b/mmpretrain/apis/__pycache__/model.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6ab396d53f85702d40e68614323c576847384579
Binary files /dev/null and b/mmpretrain/apis/__pycache__/model.cpython-38.pyc differ
diff --git a/mmpretrain/apis/__pycache__/multimodal_retrieval.cpython-38.pyc b/mmpretrain/apis/__pycache__/multimodal_retrieval.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d02410fb5c709d99c598a172b4ef824f5aea91ea
Binary files /dev/null and b/mmpretrain/apis/__pycache__/multimodal_retrieval.cpython-38.pyc differ
diff --git a/mmpretrain/apis/__pycache__/nlvr.cpython-38.pyc b/mmpretrain/apis/__pycache__/nlvr.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..68cd7897dd000482a0a20c3360d8313132873732
Binary files /dev/null and b/mmpretrain/apis/__pycache__/nlvr.cpython-38.pyc differ
diff --git a/mmpretrain/apis/__pycache__/visual_grounding.cpython-38.pyc b/mmpretrain/apis/__pycache__/visual_grounding.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..56aa69cb83e444180b35cc43d1bc2e741c3a963d
Binary files /dev/null and b/mmpretrain/apis/__pycache__/visual_grounding.cpython-38.pyc differ
diff --git a/mmpretrain/apis/__pycache__/visual_question_answering.cpython-38.pyc b/mmpretrain/apis/__pycache__/visual_question_answering.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8befd7f2e8544ae5695c64be347b5b3e67fdbee4
Binary files /dev/null and b/mmpretrain/apis/__pycache__/visual_question_answering.cpython-38.pyc differ
diff --git a/mmpretrain/apis/base.py b/mmpretrain/apis/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..7bff6bd18675a3a0996dcd09081a15728311657f
--- /dev/null
+++ b/mmpretrain/apis/base.py
@@ -0,0 +1,390 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import abstractmethod
+from math import ceil
+from typing import Callable, Iterable, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from mmengine.config import Config
+from mmengine.dataset import default_collate
+from mmengine.fileio import get_file_backend
+from mmengine.model import BaseModel
+from mmengine.runner import load_checkpoint
+
+from mmpretrain.structures import DataSample
+from mmpretrain.utils import track
+from .model import get_model, list_models
+
+ModelType = Union[BaseModel, str, Config]
+InputType = Union[str, np.ndarray, list]
+
+
+class BaseInferencer:
+ """Base inferencer for various tasks.
+
+ The BaseInferencer provides the standard workflow for inference as follows:
+
+ 1. Preprocess the input data by :meth:`preprocess`.
+ 2. Forward the data to the model by :meth:`forward`. ``BaseInferencer``
+ assumes the model inherits from :class:`mmengine.models.BaseModel` and
+ will call `model.test_step` in :meth:`forward` by default.
+ 3. Visualize the results by :meth:`visualize`.
+ 4. Postprocess and return the results by :meth:`postprocess`.
+
+    When a subclass of ``BaseInferencer`` that does not override ``__call__``
+    is called, this workflow is executed in order.
+
+ All subclasses of BaseInferencer could define the following class
+ attributes for customization:
+
+ - ``preprocess_kwargs``: The keys of the kwargs that will be passed to
+ :meth:`preprocess`.
+ - ``forward_kwargs``: The keys of the kwargs that will be passed to
+ :meth:`forward`
+ - ``visualize_kwargs``: The keys of the kwargs that will be passed to
+ :meth:`visualize`
+ - ``postprocess_kwargs``: The keys of the kwargs that will be passed to
+ :meth:`postprocess`
+
+ All attributes mentioned above should be a ``set`` of keys (strings),
+ and each key should not be duplicated. Actually, :meth:`__call__` will
+ dispatch all the arguments to the corresponding methods according to the
+ ``xxx_kwargs`` mentioned above.
+
+ Subclasses inherited from ``BaseInferencer`` should implement
+ :meth:`_init_pipeline`, :meth:`visualize` and :meth:`postprocess`:
+
+ - _init_pipeline: Return a callable object to preprocess the input data.
+ - visualize: Visualize the results returned by :meth:`forward`.
+ - postprocess: Postprocess the results returned by :meth:`forward` and
+ :meth:`visualize`.
+
+ Args:
+ model (BaseModel | str | Config): A model name or a path to the config
+ file, or a :obj:`BaseModel` object. The model name can be found
+ by ``cls.list_models()`` and you can also query it in
+ :doc:`/modelzoo_statistics`.
+ pretrained (str, optional): Path to the checkpoint. If None, it will
+ try to find a pre-defined weight from the model you specified
+ (only work if the ``model`` is a model name). Defaults to None.
+ device (str | torch.device | None): Transfer the model to the target
+ device. Defaults to None.
+ device_map (str | dict | None): A map that specifies where each
+            submodule should go. It doesn't need to be refined down to each
+            parameter/buffer name; once a given module name is included, every
+            submodule of it will be sent to the same device. You can use
+ `device_map="auto"` to automatically generate the device map.
+ Defaults to None.
+ offload_folder (str | None): If the `device_map` contains any value
+ `"disk"`, the folder where we will offload weights.
+ **kwargs: Other keyword arguments to initialize the model (only work if
+ the ``model`` is a model name).
+ """
+
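+    # Illustrative sketch of a minimal subclass (comments only; the names
+    # below are hypothetical and not part of mmpretrain). A subclass must
+    # implement ``_init_pipeline`` and ``postprocess``; ``visualize`` is
+    # optional:
+    #
+    #     class DummyInferencer(BaseInferencer):
+    #         visualize_kwargs = {'show'}
+    #
+    #         def _init_pipeline(self, cfg):
+    #             # Return a callable turning one raw input into a data dict.
+    #             return lambda x: dict(inputs=torch.tensor(x))
+    #
+    #         def visualize(self, inputs, preds, show=False):
+    #             return None  # nothing to draw for this toy task
+    #
+    #         def postprocess(self, preds, visualization,
+    #                         return_datasample=False):
+    #             return {'predictions': preds,
+    #                     'visualization': visualization}
+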
+ preprocess_kwargs: set = set()
+ forward_kwargs: set = set()
+ visualize_kwargs: set = set()
+ postprocess_kwargs: set = set()
+
+ def __init__(self,
+ model: ModelType,
+ pretrained: Union[bool, str] = True,
+ device: Union[str, torch.device, None] = None,
+ device_map=None,
+ offload_folder=None,
+ **kwargs) -> None:
+
+ if isinstance(model, BaseModel):
+ if isinstance(pretrained, str):
+ load_checkpoint(model, pretrained, map_location='cpu')
+ if device_map is not None:
+ from .utils import dispatch_model
+ model = dispatch_model(
+ model,
+ device_map=device_map,
+ offload_folder=offload_folder)
+ elif device is not None:
+ model.to(device)
+ else:
+ model = get_model(
+ model,
+ pretrained,
+ device=device,
+ device_map=device_map,
+ offload_folder=offload_folder,
+ **kwargs)
+
+ model.eval()
+
+ self.config = model._config
+ self.model = model
+ self.pipeline = self._init_pipeline(self.config)
+ self.visualizer = None
+
+ def __call__(
+ self,
+ inputs,
+ return_datasamples: bool = False,
+ batch_size: int = 1,
+ **kwargs,
+ ) -> dict:
+ """Call the inferencer.
+
+ Args:
+ inputs (InputsType): Inputs for the inferencer.
+ return_datasamples (bool): Whether to return results as
+ :obj:`BaseDataElement`. Defaults to False.
+ batch_size (int): Batch size. Defaults to 1.
+ **kwargs: Key words arguments passed to :meth:`preprocess`,
+ :meth:`forward`, :meth:`visualize` and :meth:`postprocess`.
+ Each key in kwargs should be in the corresponding set of
+ ``preprocess_kwargs``, ``forward_kwargs``, ``visualize_kwargs``
+ and ``postprocess_kwargs``.
+
+ Returns:
+ dict: Inference and visualization results.
+ """
+ (
+ preprocess_kwargs,
+ forward_kwargs,
+ visualize_kwargs,
+ postprocess_kwargs,
+ ) = self._dispatch_kwargs(**kwargs)
+
+ ori_inputs = self._inputs_to_list(inputs)
+ inputs = self.preprocess(
+ ori_inputs, batch_size=batch_size, **preprocess_kwargs)
+ preds = []
+ for data in track(
+ inputs, 'Inference', total=ceil(len(ori_inputs) / batch_size)):
+ preds.extend(self.forward(data, **forward_kwargs))
+ visualization = self.visualize(ori_inputs, preds, **visualize_kwargs)
+ results = self.postprocess(preds, visualization, return_datasamples,
+ **postprocess_kwargs)
+ return results
+
+ def _inputs_to_list(self, inputs: InputType) -> list:
+ """Preprocess the inputs to a list.
+
+ Cast the input data to a list of data.
+
+ - list or tuple: return inputs
+ - str:
+ - Directory path: return all files in the directory
+            - other cases: return a list containing the string. The string
+              could be a path to a file, a URL, or another type of string,
+              depending on the task.
+ - other: return a list with one item.
+
+ Args:
+ inputs (str | array | list): Inputs for the inferencer.
+
+ Returns:
+ list: List of input for the :meth:`preprocess`.
+ """
+ if isinstance(inputs, str):
+ backend = get_file_backend(inputs)
+ if hasattr(backend, 'isdir') and backend.isdir(inputs):
+ # Backends like HttpsBackend do not implement `isdir`, so only
+ # those backends that implement `isdir` could accept the inputs
+ # as a directory
+ file_list = backend.list_dir_or_file(inputs, list_dir=False)
+ inputs = [
+ backend.join_path(inputs, file) for file in file_list
+ ]
+
+ if not isinstance(inputs, (list, tuple)):
+ inputs = [inputs]
+
+ return list(inputs)
+
+ def preprocess(self, inputs: InputType, batch_size: int = 1, **kwargs):
+ """Process the inputs into a model-feedable format.
+
+ Customize your preprocess by overriding this method. Preprocess should
+ return an iterable object, of which each item will be used as the
+ input of ``model.test_step``.
+
+ ``BaseInferencer.preprocess`` will return an iterable chunked data,
+ which will be used in __call__ like this:
+
+ .. code-block:: python
+
+ def __call__(self, inputs, batch_size=1, **kwargs):
+ chunked_data = self.preprocess(inputs, batch_size, **kwargs)
+ for batch in chunked_data:
+ preds = self.forward(batch, **kwargs)
+
+ Args:
+ inputs (InputsType): Inputs given by user.
+ batch_size (int): batch size. Defaults to 1.
+
+ Yields:
+ Any: Data processed by the ``pipeline`` and ``default_collate``.
+ """
+ chunked_data = self._get_chunk_data(
+ map(self.pipeline, inputs), batch_size)
+ yield from map(default_collate, chunked_data)
+
+ @torch.no_grad()
+ def forward(self, inputs: Union[dict, tuple], **kwargs):
+ """Feed the inputs to the model."""
+ return self.model.test_step(inputs)
+
+ def visualize(self,
+ inputs: list,
+ preds: List[DataSample],
+ show: bool = False,
+ **kwargs) -> List[np.ndarray]:
+ """Visualize predictions.
+
+ Customize your visualization by overriding this method. visualize
+ should return visualization results, which could be np.ndarray or any
+ other objects.
+
+ Args:
+ inputs (list): Inputs preprocessed by :meth:`_inputs_to_list`.
+ preds (Any): Predictions of the model.
+ show (bool): Whether to display the image in a popup window.
+ Defaults to False.
+
+ Returns:
+ List[np.ndarray]: Visualization results.
+ """
+ if show:
+ raise NotImplementedError(
+ f'The `visualize` method of {self.__class__.__name__} '
+ 'is not implemented.')
+
+ @abstractmethod
+ def postprocess(
+ self,
+ preds: List[DataSample],
+ visualization: List[np.ndarray],
+ return_datasample=False,
+ **kwargs,
+ ) -> dict:
+ """Process the predictions and visualization results from ``forward``
+ and ``visualize``.
+
+ This method should be responsible for the following tasks:
+
+ 1. Convert datasamples into a json-serializable dict if needed.
+ 2. Pack the predictions and visualization results and return them.
+ 3. Dump or log the predictions.
+
+ Customize your postprocess by overriding this method. Make sure
+ ``postprocess`` will return a dict with visualization results and
+ inference results.
+
+ Args:
+ preds (List[Dict]): Predictions of the model.
+ visualization (np.ndarray): Visualized predictions.
+ return_datasample (bool): Whether to return results as datasamples.
+ Defaults to False.
+
+ Returns:
+ dict: Inference and visualization results with key ``predictions``
+ and ``visualization``
+
+ - ``visualization (Any)``: Returned by :meth:`visualize`
+ - ``predictions`` (dict or DataSample): Returned by
+ :meth:`forward` and processed in :meth:`postprocess`.
+ If ``return_datasample=False``, it usually should be a
+ json-serializable dict containing only basic data elements such
+ as strings and numbers.
+ """
+
+ @abstractmethod
+ def _init_pipeline(self, cfg: Config) -> Callable:
+ """Initialize the test pipeline.
+
+ Return a pipeline to handle various input data, such as ``str``,
+ ``np.ndarray``. It is an abstract method in BaseInferencer, and should
+ be implemented in subclasses.
+
+ The returned pipeline will be used to process a single data.
+ It will be used in :meth:`preprocess` like this:
+
+        .. code-block:: python
+
+            def preprocess(self, inputs, batch_size, **kwargs):
+ ...
+ dataset = map(self.pipeline, dataset)
+ ...
+ """
+
+ def _get_chunk_data(self, inputs: Iterable, chunk_size: int):
+ """Get batch data from dataset.
+
+ Args:
+ inputs (Iterable): An iterable dataset.
+ chunk_size (int): Equivalent to batch size.
+
+ Yields:
+ list: batch data.
+ """
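+        # Example: with ``chunk_size=2``, an iterable yielding 1, 2, 3 is
+        # grouped into [1, 2] followed by the final shorter chunk [3].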
+ inputs_iter = iter(inputs)
+ while True:
+ try:
+ chunk_data = []
+ for _ in range(chunk_size):
+ processed_data = next(inputs_iter)
+ chunk_data.append(processed_data)
+ yield chunk_data
+ except StopIteration:
+ if chunk_data:
+ yield chunk_data
+ break
+
+ def _dispatch_kwargs(self, **kwargs) -> Tuple[dict, dict, dict, dict]:
+ """Dispatch kwargs to preprocess(), forward(), visualize() and
+ postprocess() according to the actual demands.
+
+ Returns:
+ Tuple[Dict, Dict, Dict, Dict]: kwargs passed to preprocess,
+ forward, visualize and postprocess respectively.
+ """
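+        # Example: if ``show`` is declared in ``visualize_kwargs`` and
+        # ``topk`` in ``postprocess_kwargs``, then calling
+        # ``inferencer(img, show=True, topk=5)`` routes ``show`` to
+        # ``visualize()`` and ``topk`` to ``postprocess()``.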
+ # Ensure each argument only matches one function
+ method_kwargs = self.preprocess_kwargs | self.forward_kwargs | \
+ self.visualize_kwargs | self.postprocess_kwargs
+
+ union_kwargs = method_kwargs | set(kwargs.keys())
+ if union_kwargs != method_kwargs:
+ unknown_kwargs = union_kwargs - method_kwargs
+ raise ValueError(
+ f'unknown argument {unknown_kwargs} for `preprocess`, '
+ '`forward`, `visualize` and `postprocess`')
+
+ preprocess_kwargs = {}
+ forward_kwargs = {}
+ visualize_kwargs = {}
+ postprocess_kwargs = {}
+
+ for key, value in kwargs.items():
+ if key in self.preprocess_kwargs:
+ preprocess_kwargs[key] = value
+ if key in self.forward_kwargs:
+ forward_kwargs[key] = value
+ if key in self.visualize_kwargs:
+ visualize_kwargs[key] = value
+ if key in self.postprocess_kwargs:
+ postprocess_kwargs[key] = value
+
+ return (
+ preprocess_kwargs,
+ forward_kwargs,
+ visualize_kwargs,
+ postprocess_kwargs,
+ )
+
+ @staticmethod
+ def list_models(pattern: Optional[str] = None):
+ """List models defined in metafile of corresponding packages.
+
+ Args:
+ pattern (str | None): A wildcard pattern to match model names.
+
+ Returns:
+ List[str]: a list of model names.
+ """
+ return list_models(pattern=pattern)
diff --git a/mmpretrain/apis/feature_extractor.py b/mmpretrain/apis/feature_extractor.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee14f92f489497dd036fe0567786a94207924d4a
--- /dev/null
+++ b/mmpretrain/apis/feature_extractor.py
@@ -0,0 +1,130 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Callable, List, Optional, Union
+
+import torch
+from mmcv.image import imread
+from mmengine.config import Config
+from mmengine.dataset import Compose, default_collate
+
+from mmpretrain.registry import TRANSFORMS
+from .base import BaseInferencer, InputType
+from .model import list_models
+
+
+class FeatureExtractor(BaseInferencer):
+    """The inferencer for extracting features.
+
+ Args:
+ model (BaseModel | str | Config): A model name or a path to the config
+ file, or a :obj:`BaseModel` object. The model name can be found
+ by ``FeatureExtractor.list_models()`` and you can also query it in
+ :doc:`/modelzoo_statistics`.
+ pretrained (str, optional): Path to the checkpoint. If None, it will
+ try to find a pre-defined weight from the model you specified
+ (only work if the ``model`` is a model name). Defaults to None.
+ device (str, optional): Device to run inference. If None, the available
+ device will be automatically used. Defaults to None.
+ **kwargs: Other keyword arguments to initialize the model (only work if
+ the ``model`` is a model name).
+
+ Example:
+ >>> from mmpretrain import FeatureExtractor
+ >>> inferencer = FeatureExtractor('resnet50_8xb32_in1k', backbone=dict(out_indices=(0, 1, 2, 3)))
+ >>> feats = inferencer('demo/demo.JPEG', stage='backbone')[0]
+ >>> for feat in feats:
+ >>> print(feat.shape)
+ torch.Size([256, 56, 56])
+ torch.Size([512, 28, 28])
+ torch.Size([1024, 14, 14])
+ torch.Size([2048, 7, 7])
+ """ # noqa: E501
+
+ def __call__(self,
+ inputs: InputType,
+ batch_size: int = 1,
+ **kwargs) -> dict:
+ """Call the inferencer.
+
+ Args:
+ inputs (str | array | list): The image path or array, or a list of
+ images.
+ batch_size (int): Batch size. Defaults to 1.
+ **kwargs: Other keyword arguments accepted by the `extract_feat`
+ method of the model.
+
+ Returns:
+ tensor | Tuple[tensor]: The extracted features.
+ """
+ ori_inputs = self._inputs_to_list(inputs)
+ inputs = self.preprocess(ori_inputs, batch_size=batch_size)
+ preds = []
+ for data in inputs:
+ preds.extend(self.forward(data, **kwargs))
+
+ return preds
+
+ @torch.no_grad()
+ def forward(self, inputs: Union[dict, tuple], **kwargs):
+ inputs = self.model.data_preprocessor(inputs, False)['inputs']
+ outputs = self.model.extract_feat(inputs, **kwargs)
+
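+        # ``scatter`` recursively picks the i-th sample out of the (possibly
+        # nested) batched feature tensors so that one entry per input image
+        # is returned.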
+ def scatter(feats, index):
+ if isinstance(feats, torch.Tensor):
+ return feats[index]
+ else:
+ # Sequence of tensor
+ return type(feats)([scatter(item, index) for item in feats])
+
+ results = []
+ for i in range(inputs.shape[0]):
+ results.append(scatter(outputs, i))
+
+ return results
+
+ def _init_pipeline(self, cfg: Config) -> Callable:
+ test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
+ from mmpretrain.datasets import remove_transform
+
+ # Image loading is finished in `self.preprocess`.
+ test_pipeline_cfg = remove_transform(test_pipeline_cfg,
+ 'LoadImageFromFile')
+ test_pipeline = Compose(
+ [TRANSFORMS.build(t) for t in test_pipeline_cfg])
+ return test_pipeline
+
+ def preprocess(self, inputs: List[InputType], batch_size: int = 1):
+
+ def load_image(input_):
+ img = imread(input_)
+ if img is None:
+ raise ValueError(f'Failed to read image {input_}.')
+ return dict(
+ img=img,
+ img_shape=img.shape[:2],
+ ori_shape=img.shape[:2],
+ )
+
+ pipeline = Compose([load_image, self.pipeline])
+
+ chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size)
+ yield from map(default_collate, chunked_data)
+
+ def visualize(self):
+ raise NotImplementedError(
+ "The FeatureExtractor doesn't support visualization.")
+
+ def postprocess(self):
+ raise NotImplementedError(
+ "The FeatureExtractor doesn't need postprocessing.")
+
+ @staticmethod
+ def list_models(pattern: Optional[str] = None):
+ """List all available model names.
+
+ Args:
+ pattern (str | None): A wildcard pattern to match model names.
+
+ Returns:
+ List[str]: a list of model names.
+ """
+ return list_models(pattern=pattern)
diff --git a/mmpretrain/apis/image_caption.py b/mmpretrain/apis/image_caption.py
new file mode 100644
index 0000000000000000000000000000000000000000..c11c0d3044d9924aba159782309d2cc20f1745bc
--- /dev/null
+++ b/mmpretrain/apis/image_caption.py
@@ -0,0 +1,166 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from pathlib import Path
+from typing import Callable, List, Optional
+
+import numpy as np
+from mmcv.image import imread
+from mmengine.config import Config
+from mmengine.dataset import Compose, default_collate
+
+from mmpretrain.registry import TRANSFORMS
+from mmpretrain.structures import DataSample
+from .base import BaseInferencer, InputType
+from .model import list_models
+
+
+class ImageCaptionInferencer(BaseInferencer):
+ """The inferencer for image caption.
+
+ Args:
+ model (BaseModel | str | Config): A model name or a path to the config
+ file, or a :obj:`BaseModel` object. The model name can be found
+ by ``ImageCaptionInferencer.list_models()`` and you can also
+ query it in :doc:`/modelzoo_statistics`.
+ pretrained (str, optional): Path to the checkpoint. If None, it will
+ try to find a pre-defined weight from the model you specified
+ (only work if the ``model`` is a model name). Defaults to None.
+ device (str, optional): Device to run inference. If None, the available
+ device will be automatically used. Defaults to None.
+ **kwargs: Other keyword arguments to initialize the model (only work if
+ the ``model`` is a model name).
+
+ Example:
+ >>> from mmpretrain import ImageCaptionInferencer
+ >>> inferencer = ImageCaptionInferencer('blip-base_3rdparty_caption')
+ >>> inferencer('demo/cat-dog.png')[0]
+ {'pred_caption': 'a puppy and a cat sitting on a blanket'}
+ """ # noqa: E501
+
+ visualize_kwargs: set = {'resize', 'show', 'show_dir', 'wait_time'}
+
+ def __call__(self,
+ images: InputType,
+ return_datasamples: bool = False,
+ batch_size: int = 1,
+ **kwargs) -> dict:
+ """Call the inferencer.
+
+ Args:
+ images (str | array | list): The image path or array, or a list of
+ images.
+ return_datasamples (bool): Whether to return results as
+ :obj:`DataSample`. Defaults to False.
+ batch_size (int): Batch size. Defaults to 1.
+ resize (int, optional): Resize the short edge of the image to the
+ specified length before visualization. Defaults to None.
+ draw_score (bool): Whether to draw the prediction scores
+ of prediction categories. Defaults to True.
+ show (bool): Whether to display the visualization result in a
+ window. Defaults to False.
+ wait_time (float): The display time (s). Defaults to 0, which means
+ "forever".
+ show_dir (str, optional): If not None, save the visualization
+ results in the specified directory. Defaults to None.
+
+ Returns:
+ list: The inference results.
+ """
+ return super().__call__(images, return_datasamples, batch_size,
+ **kwargs)
+
+ def _init_pipeline(self, cfg: Config) -> Callable:
+ test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
+ from mmpretrain.datasets import remove_transform
+
+ # Image loading is finished in `self.preprocess`.
+ test_pipeline_cfg = remove_transform(test_pipeline_cfg,
+ 'LoadImageFromFile')
+ test_pipeline = Compose(
+ [TRANSFORMS.build(t) for t in test_pipeline_cfg])
+ return test_pipeline
+
+ def preprocess(self, inputs: List[InputType], batch_size: int = 1):
+
+ def load_image(input_):
+ img = imread(input_)
+ if img is None:
+ raise ValueError(f'Failed to read image {input_}.')
+ return dict(
+ img=img,
+ img_shape=img.shape[:2],
+ ori_shape=img.shape[:2],
+ )
+
+ pipeline = Compose([load_image, self.pipeline])
+
+ chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size)
+ yield from map(default_collate, chunked_data)
+
+ def visualize(self,
+ ori_inputs: List[InputType],
+ preds: List[DataSample],
+ show: bool = False,
+ wait_time: int = 0,
+ resize: Optional[int] = None,
+ show_dir=None):
+ if not show and show_dir is None:
+ return None
+
+ if self.visualizer is None:
+ from mmpretrain.visualization import UniversalVisualizer
+ self.visualizer = UniversalVisualizer()
+
+ visualization = []
+ for i, (input_, data_sample) in enumerate(zip(ori_inputs, preds)):
+ image = imread(input_)
+ if isinstance(input_, str):
+ # The image loaded from path is BGR format.
+ image = image[..., ::-1]
+ name = Path(input_).stem
+ else:
+ name = str(i)
+
+ if show_dir is not None:
+ show_dir = Path(show_dir)
+ show_dir.mkdir(exist_ok=True)
+ out_file = str((show_dir / name).with_suffix('.png'))
+ else:
+ out_file = None
+
+ self.visualizer.visualize_image_caption(
+ image,
+ data_sample,
+ resize=resize,
+ show=show,
+ wait_time=wait_time,
+ name=name,
+ out_file=out_file)
+ visualization.append(self.visualizer.get_image())
+ if show:
+ self.visualizer.close()
+ return visualization
+
+ def postprocess(self,
+ preds: List[DataSample],
+ visualization: List[np.ndarray],
+ return_datasamples=False) -> dict:
+ if return_datasamples:
+ return preds
+
+ results = []
+ for data_sample in preds:
+ results.append({'pred_caption': data_sample.get('pred_caption')})
+
+ return results
+
+ @staticmethod
+ def list_models(pattern: Optional[str] = None):
+ """List all available model names.
+
+ Args:
+ pattern (str | None): A wildcard pattern to match model names.
+
+ Returns:
+ List[str]: a list of model names.
+ """
+ return list_models(pattern=pattern, task='Image Caption')
diff --git a/mmpretrain/apis/image_classification.py b/mmpretrain/apis/image_classification.py
new file mode 100644
index 0000000000000000000000000000000000000000..a20218071c7afc90c6a46d61b5ed3a8fee5bc012
--- /dev/null
+++ b/mmpretrain/apis/image_classification.py
@@ -0,0 +1,223 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from pathlib import Path
+from typing import Callable, List, Optional, Union
+
+import numpy as np
+import torch
+from mmcv.image import imread
+from mmengine.config import Config
+from mmengine.dataset import Compose, default_collate
+
+from mmpretrain.registry import TRANSFORMS
+from mmpretrain.structures import DataSample
+from .base import BaseInferencer, InputType, ModelType
+from .model import list_models
+
+
+class ImageClassificationInferencer(BaseInferencer):
+ """The inferencer for image classification.
+
+ Args:
+ model (BaseModel | str | Config): A model name or a path to the config
+ file, or a :obj:`BaseModel` object. The model name can be found
+ by ``ImageClassificationInferencer.list_models()`` and you can also
+ query it in :doc:`/modelzoo_statistics`.
+ pretrained (str, optional): Path to the checkpoint. If None, it will
+ try to find a pre-defined weight from the model you specified
+ (only work if the ``model`` is a model name). Defaults to None.
+ device (str, optional): Device to run inference. If None, the available
+ device will be automatically used. Defaults to None.
+ **kwargs: Other keyword arguments to initialize the model (only work if
+ the ``model`` is a model name).
+
+ Example:
+ 1. Use a pre-trained model in MMPreTrain to inference an image.
+
+ >>> from mmpretrain import ImageClassificationInferencer
+ >>> inferencer = ImageClassificationInferencer('resnet50_8xb32_in1k')
+ >>> inferencer('demo/demo.JPEG')
+        [{'pred_scores': array([...]),
+ 'pred_label': 65,
+ 'pred_score': 0.6649367809295654,
+ 'pred_class': 'sea snake'}]
+
+ 2. Use a config file and checkpoint to inference multiple images on GPU,
+ and save the visualization results in a folder.
+
+ >>> from mmpretrain import ImageClassificationInferencer
+ >>> inferencer = ImageClassificationInferencer(
+ model='configs/resnet/resnet50_8xb32_in1k.py',
+ pretrained='https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth',
+ device='cuda')
+ >>> inferencer(['demo/dog.jpg', 'demo/bird.JPEG'], show_dir="./visualize/")
+ """ # noqa: E501
+
+ visualize_kwargs: set = {
+ 'resize', 'rescale_factor', 'draw_score', 'show', 'show_dir',
+ 'wait_time'
+ }
+
+ def __init__(self,
+ model: ModelType,
+ pretrained: Union[bool, str] = True,
+ device: Union[str, torch.device, None] = None,
+ classes=None,
+ **kwargs) -> None:
+ super().__init__(
+ model=model, pretrained=pretrained, device=device, **kwargs)
+
+ if classes is not None:
+ self.classes = classes
+ else:
+ self.classes = getattr(self.model, '_dataset_meta',
+ {}).get('classes')
+
+ def __call__(self,
+ inputs: InputType,
+ return_datasamples: bool = False,
+ batch_size: int = 1,
+ **kwargs) -> dict:
+ """Call the inferencer.
+
+ Args:
+ inputs (str | array | list): The image path or array, or a list of
+ images.
+ return_datasamples (bool): Whether to return results as
+ :obj:`DataSample`. Defaults to False.
+ batch_size (int): Batch size. Defaults to 1.
+ resize (int, optional): Resize the short edge of the image to the
+ specified length before visualization. Defaults to None.
+ rescale_factor (float, optional): Rescale the image by the rescale
+ factor for visualization. This is helpful when the image is too
+ large or too small for visualization. Defaults to None.
+ draw_score (bool): Whether to draw the prediction scores
+ of prediction categories. Defaults to True.
+ show (bool): Whether to display the visualization result in a
+ window. Defaults to False.
+ wait_time (float): The display time (s). Defaults to 0, which means
+ "forever".
+ show_dir (str, optional): If not None, save the visualization
+ results in the specified directory. Defaults to None.
+
+ Returns:
+ list: The inference results.
+ """
+ return super().__call__(
+ inputs,
+ return_datasamples=return_datasamples,
+ batch_size=batch_size,
+ **kwargs)
+
+ def _init_pipeline(self, cfg: Config) -> Callable:
+ test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
+ from mmpretrain.datasets import remove_transform
+
+ # Image loading is finished in `self.preprocess`.
+ test_pipeline_cfg = remove_transform(test_pipeline_cfg,
+ 'LoadImageFromFile')
+ test_pipeline = Compose(
+ [TRANSFORMS.build(t) for t in test_pipeline_cfg])
+ return test_pipeline
+
+ def preprocess(self, inputs: List[InputType], batch_size: int = 1):
+
+ def load_image(input_):
+ img = imread(input_)
+ if img is None:
+ raise ValueError(f'Failed to read image {input_}.')
+ return dict(
+ img=img,
+ img_shape=img.shape[:2],
+ ori_shape=img.shape[:2],
+ )
+
+ pipeline = Compose([load_image, self.pipeline])
+
+ chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size)
+ yield from map(default_collate, chunked_data)
+
+ def visualize(self,
+ ori_inputs: List[InputType],
+ preds: List[DataSample],
+ show: bool = False,
+ wait_time: int = 0,
+ resize: Optional[int] = None,
+ rescale_factor: Optional[float] = None,
+ draw_score=True,
+ show_dir=None):
+ if not show and show_dir is None:
+ return None
+
+ if self.visualizer is None:
+ from mmpretrain.visualization import UniversalVisualizer
+ self.visualizer = UniversalVisualizer()
+
+ visualization = []
+ for i, (input_, data_sample) in enumerate(zip(ori_inputs, preds)):
+ image = imread(input_)
+ if isinstance(input_, str):
+ # The image loaded from path is BGR format.
+ image = image[..., ::-1]
+ name = Path(input_).stem
+ else:
+ name = str(i)
+
+ if show_dir is not None:
+ show_dir = Path(show_dir)
+ show_dir.mkdir(exist_ok=True)
+ out_file = str((show_dir / name).with_suffix('.png'))
+ else:
+ out_file = None
+
+ self.visualizer.visualize_cls(
+ image,
+ data_sample,
+ classes=self.classes,
+ resize=resize,
+ show=show,
+ wait_time=wait_time,
+ rescale_factor=rescale_factor,
+ draw_gt=False,
+ draw_pred=True,
+ draw_score=draw_score,
+ name=name,
+ out_file=out_file)
+ visualization.append(self.visualizer.get_image())
+ if show:
+ self.visualizer.close()
+ return visualization
+
+ def postprocess(self,
+ preds: List[DataSample],
+ visualization: List[np.ndarray],
+ return_datasamples=False) -> dict:
+ if return_datasamples:
+ return preds
+
+ results = []
+ for data_sample in preds:
+ pred_scores = data_sample.pred_score
+ pred_score = float(torch.max(pred_scores).item())
+ pred_label = torch.argmax(pred_scores).item()
+ result = {
+ 'pred_scores': pred_scores.detach().cpu().numpy(),
+ 'pred_label': pred_label,
+ 'pred_score': pred_score,
+ }
+ if self.classes is not None:
+ result['pred_class'] = self.classes[pred_label]
+ results.append(result)
+
+ return results
+
+ @staticmethod
+ def list_models(pattern: Optional[str] = None):
+ """List all available model names.
+
+ Args:
+ pattern (str | None): A wildcard pattern to match model names.
+
+ Returns:
+ List[str]: a list of model names.
+ """
+ return list_models(pattern=pattern, task='Image Classification')
diff --git a/mmpretrain/apis/image_retrieval.py b/mmpretrain/apis/image_retrieval.py
new file mode 100644
index 0000000000000000000000000000000000000000..deae1de7975b1e46ccf045bee7fec7ddcbfbfea4
--- /dev/null
+++ b/mmpretrain/apis/image_retrieval.py
@@ -0,0 +1,287 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from pathlib import Path
+from typing import Callable, List, Optional, Union
+
+import numpy as np
+import torch
+from mmcv.image import imread
+from mmengine.config import Config
+from mmengine.dataset import BaseDataset, Compose, default_collate
+
+from mmpretrain.registry import TRANSFORMS
+from mmpretrain.structures import DataSample
+from .base import BaseInferencer, InputType, ModelType
+from .model import list_models
+
+
+class ImageRetrievalInferencer(BaseInferencer):
+ """The inferencer for image to image retrieval.
+
+ Args:
+ model (BaseModel | str | Config): A model name or a path to the config
+ file, or a :obj:`BaseModel` object. The model name can be found
+ by ``ImageRetrievalInferencer.list_models()`` and you can also
+ query it in :doc:`/modelzoo_statistics`.
+        prototype (str | list | dict | DataLoader | BaseDataset): The images
+            to be retrieved. It can be the following types:
+
+            - str: The directory of the images.
+            - list: A list of paths of the images.
+            - dict: A config dict of a prototype dataset.
+ - BaseDataset: A prototype dataset.
+ - DataLoader: A data loader to load the prototype data.
+
+        prototype_cache (str, optional): The path of the generated prototype
+            features. If the file exists, the cached features are loaded
+            directly instead of being re-generated; otherwise, the generated
+            features are saved to this path. Defaults to None.
+ pretrained (str, optional): Path to the checkpoint. If None, it will
+ try to find a pre-defined weight from the model you specified
+ (only work if the ``model`` is a model name). Defaults to None.
+ device (str, optional): Device to run inference. If None, the available
+ device will be automatically used. Defaults to None.
+ **kwargs: Other keyword arguments to initialize the model (only work if
+ the ``model`` is a model name).
+
+ Example:
+ >>> from mmpretrain import ImageRetrievalInferencer
+ >>> inferencer = ImageRetrievalInferencer(
+ ... 'resnet50-arcface_inshop',
+ ... prototype='./demo/',
+ ... prototype_cache='img_retri.pth')
+ >>> inferencer('demo/cat-dog.png', topk=2)[0][1]
+ {'match_score': tensor(0.4088, device='cuda:0'),
+ 'sample_idx': 3,
+ 'sample': {'img_path': './demo/dog.jpg'}}
+ """ # noqa: E501
+
+ visualize_kwargs: set = {
+ 'draw_score', 'resize', 'show_dir', 'show', 'wait_time', 'topk'
+ }
+ postprocess_kwargs: set = {'topk'}
+
+ def __init__(
+ self,
+ model: ModelType,
+ prototype,
+ prototype_cache=None,
+ prepare_batch_size=8,
+ pretrained: Union[bool, str] = True,
+ device: Union[str, torch.device, None] = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ model=model, pretrained=pretrained, device=device, **kwargs)
+
+ self.prototype_dataset = self._prepare_prototype(
+ prototype, prototype_cache, prepare_batch_size)
+
+ def _prepare_prototype(self, prototype, cache=None, batch_size=8):
+ from mmengine.dataset import DefaultSampler
+ from torch.utils.data import DataLoader
+
+ def build_dataloader(dataset):
+ return DataLoader(
+ dataset,
+ batch_size=batch_size,
+ collate_fn=default_collate,
+ sampler=DefaultSampler(dataset, shuffle=False),
+ persistent_workers=False,
+ )
+
+ if isinstance(prototype, str):
+ # A directory path of images
+ prototype = dict(
+ type='CustomDataset', with_label=False, data_root=prototype)
+
+ if isinstance(prototype, list):
+ test_pipeline = [dict(type='LoadImageFromFile'), self.pipeline]
+ dataset = BaseDataset(
+ lazy_init=True, serialize_data=False, pipeline=test_pipeline)
+ dataset.data_list = [{
+ 'sample_idx': i,
+ 'img_path': file
+ } for i, file in enumerate(prototype)]
+ dataset._fully_initialized = True
+ dataloader = build_dataloader(dataset)
+        elif isinstance(prototype, dict):
+            # A config of dataset
+            from mmpretrain.registry import DATASETS
+            test_pipeline = [dict(type='LoadImageFromFile'), self.pipeline]
+            prototype.setdefault('pipeline', test_pipeline)
+            dataset = DATASETS.build(prototype)
+            dataloader = build_dataloader(dataset)
+ elif isinstance(prototype, DataLoader):
+ dataset = prototype.dataset
+ dataloader = prototype
+ elif isinstance(prototype, BaseDataset):
+ dataset = prototype
+ dataloader = build_dataloader(dataset)
+ else:
+ raise TypeError(f'Unsupported prototype type {type(prototype)}.')
+
+ if cache is not None and Path(cache).exists():
+ self.model.prototype = cache
+ else:
+ self.model.prototype = dataloader
+ self.model.prepare_prototype()
+
+ from mmengine.logging import MMLogger
+ logger = MMLogger.get_current_instance()
+ if cache is None:
+            logger.info('The prototype has been prepared; you can use '
+                        '`save_prototype` to dump it into a pickle '
+                        'file for future use.')
+ elif not Path(cache).exists():
+ self.save_prototype(cache)
+ logger.info(f'The prototype has been saved at {cache}.')
+
+ return dataset
+
+ def save_prototype(self, path):
+ self.model.dump_prototype(path)
+
+ def __call__(self,
+ inputs: InputType,
+ return_datasamples: bool = False,
+ batch_size: int = 1,
+ **kwargs) -> dict:
+ """Call the inferencer.
+
+ Args:
+ inputs (str | array | list): The image path or array, or a list of
+ images.
+ return_datasamples (bool): Whether to return results as
+ :obj:`DataSample`. Defaults to False.
+ batch_size (int): Batch size. Defaults to 1.
+ resize (int, optional): Resize the long edge of the image to the
+ specified length before visualization. Defaults to None.
+ draw_score (bool): Whether to draw the match scores.
+ Defaults to True.
+ show (bool): Whether to display the visualization result in a
+ window. Defaults to False.
+ wait_time (float): The display time (s). Defaults to 0, which means
+ "forever".
+ show_dir (str, optional): If not None, save the visualization
+ results in the specified directory. Defaults to None.
+
+ Returns:
+ list: The inference results.
+ """
+ return super().__call__(inputs, return_datasamples, batch_size,
+ **kwargs)
+
+ def _init_pipeline(self, cfg: Config) -> Callable:
+ test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
+ from mmpretrain.datasets import remove_transform
+
+ # Image loading is finished in `self.preprocess`.
+ test_pipeline_cfg = remove_transform(test_pipeline_cfg,
+ 'LoadImageFromFile')
+ test_pipeline = Compose(
+ [TRANSFORMS.build(t) for t in test_pipeline_cfg])
+ return test_pipeline
+
+ def preprocess(self, inputs: List[InputType], batch_size: int = 1):
+
+ def load_image(input_):
+ img = imread(input_)
+ if img is None:
+ raise ValueError(f'Failed to read image {input_}.')
+ return dict(
+ img=img,
+ img_shape=img.shape[:2],
+ ori_shape=img.shape[:2],
+ )
+
+ pipeline = Compose([load_image, self.pipeline])
+
+ chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size)
+ yield from map(default_collate, chunked_data)
+
+ def visualize(self,
+ ori_inputs: List[InputType],
+ preds: List[DataSample],
+ topk: int = 3,
+ resize: Optional[int] = 224,
+ show: bool = False,
+ wait_time: int = 0,
+ draw_score=True,
+ show_dir=None):
+ if not show and show_dir is None:
+ return None
+
+ if self.visualizer is None:
+ from mmpretrain.visualization import UniversalVisualizer
+ self.visualizer = UniversalVisualizer()
+
+ visualization = []
+ for i, (input_, data_sample) in enumerate(zip(ori_inputs, preds)):
+ image = imread(input_)
+ if isinstance(input_, str):
+ # The image loaded from path is BGR format.
+ image = image[..., ::-1]
+ name = Path(input_).stem
+ else:
+ name = str(i)
+
+ if show_dir is not None:
+ show_dir = Path(show_dir)
+ show_dir.mkdir(exist_ok=True)
+ out_file = str((show_dir / name).with_suffix('.png'))
+ else:
+ out_file = None
+
+ self.visualizer.visualize_image_retrieval(
+ image,
+ data_sample,
+ self.prototype_dataset,
+ topk=topk,
+ resize=resize,
+ draw_score=draw_score,
+ show=show,
+ wait_time=wait_time,
+ name=name,
+ out_file=out_file)
+ visualization.append(self.visualizer.get_image())
+ if show:
+ self.visualizer.close()
+ return visualization
+
+ def postprocess(
+ self,
+ preds: List[DataSample],
+ visualization: List[np.ndarray],
+ return_datasamples=False,
+ topk=1,
+ ) -> dict:
+ if return_datasamples:
+ return preds
+
+ results = []
+ for data_sample in preds:
+ match_scores, indices = torch.topk(data_sample.pred_score, k=topk)
+ matches = []
+ for match_score, sample_idx in zip(match_scores, indices):
+ sample = self.prototype_dataset.get_data_info(
+ sample_idx.item())
+ sample_idx = sample.pop('sample_idx')
+ matches.append({
+ 'match_score': match_score,
+ 'sample_idx': sample_idx,
+ 'sample': sample
+ })
+ results.append(matches)
+
+ return results
+
+ @staticmethod
+ def list_models(pattern: Optional[str] = None):
+ """List all available model names.
+
+ Args:
+ pattern (str | None): A wildcard pattern to match model names.
+
+ Returns:
+ List[str]: a list of model names.
+ """
+ return list_models(pattern=pattern, task='Image Retrieval')
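+
+
+# A minimal usage sketch (model name, prototype directory and query image are
+# placeholders taken from the class docstring example): it iterates the
+# ``topk`` match list produced by ``postprocess`` above.
+def _example_image_retrieval_usage():
+    inferencer = ImageRetrievalInferencer(
+        'resnet50-arcface_inshop',
+        prototype='./demo/',
+        prototype_cache='img_retri.pth')
+    for match in inferencer('demo/cat-dog.png', topk=2)[0]:
+        # Every match records its score, prototype index and sample info.
+        print(float(match['match_score']), match['sample_idx'],
+              match['sample']['img_path'])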
diff --git a/mmpretrain/apis/model.py b/mmpretrain/apis/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..eba475e7f791f42eb9aec384afec947f72722f27
--- /dev/null
+++ b/mmpretrain/apis/model.py
@@ -0,0 +1,408 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import fnmatch
+import os.path as osp
+import re
+import warnings
+from os import PathLike
+from pathlib import Path
+from typing import List, Tuple, Union
+
+from mmengine.config import Config
+from modelindex.load_model_index import load
+from modelindex.models.Model import Model
+
+
+class ModelHub:
+ """A hub to host the meta information of all pre-defined models."""
+ _models_dict = {}
+ __mmpretrain_registered = False
+
+ @classmethod
+ def register_model_index(cls,
+ model_index_path: Union[str, PathLike],
+ config_prefix: Union[str, PathLike, None] = None):
+ """Parse the model-index file and register all models.
+
+ Args:
+ model_index_path (str | PathLike): The path of the model-index
+ file.
+ config_prefix (str | PathLike | None): The prefix of all config
+ file paths in the model-index file.
+ """
+ model_index = load(str(model_index_path))
+ model_index.build_models_with_collections()
+
+ for metainfo in model_index.models:
+ model_name = metainfo.name.lower()
+            if model_name in cls._models_dict:
+                raise ValueError(
+                    'The model name {} conflicts between {} and {}.'.format(
+                        model_name, osp.abspath(metainfo.filepath),
+                        osp.abspath(cls._models_dict[model_name].filepath)))
+ metainfo.config = cls._expand_config_path(metainfo, config_prefix)
+ cls._models_dict[model_name] = metainfo
+
+ @classmethod
+ def get(cls, model_name):
+ """Get the model's metainfo by the model name.
+
+ Args:
+ model_name (str): The name of model.
+
+ Returns:
+ modelindex.models.Model: The metainfo of the specified model.
+ """
+ cls._register_mmpretrain_models()
+ # lazy load config
+ metainfo = copy.deepcopy(cls._models_dict.get(model_name.lower()))
+ if metainfo is None:
+ raise ValueError(
+                f'Failed to find model "{model_name}". Please use '
+ '`mmpretrain.list_models` to get all available names.')
+ if isinstance(metainfo.config, str):
+ metainfo.config = Config.fromfile(metainfo.config)
+ return metainfo
+
+ @staticmethod
+ def _expand_config_path(metainfo: Model,
+ config_prefix: Union[str, PathLike] = None):
+ if config_prefix is None:
+ config_prefix = osp.dirname(metainfo.filepath)
+
+ if metainfo.config is None or osp.isabs(metainfo.config):
+ config_path: str = metainfo.config
+ else:
+ config_path = osp.abspath(osp.join(config_prefix, metainfo.config))
+
+ return config_path
+
+ @classmethod
+ def _register_mmpretrain_models(cls):
+ # register models in mmpretrain
+ if not cls.__mmpretrain_registered:
+ from importlib_metadata import distribution
+ root = distribution('mmpretrain').locate_file('mmpretrain')
+ model_index_path = root / '.mim' / 'model-index.yml'
+ ModelHub.register_model_index(
+ model_index_path, config_prefix=root / '.mim')
+ cls.__mmpretrain_registered = True
+
+ @classmethod
+ def has(cls, model_name):
+ """Whether a model name is in the ModelHub."""
+ return model_name in cls._models_dict
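+
+
+# A minimal sketch of registering an external model-index file and querying
+# it (the file paths and model name are placeholders): ``get`` returns the
+# metainfo whose ``config`` and ``weights`` fields are used by ``get_model``
+# below.
+def _example_model_hub_usage():
+    ModelHub.register_model_index(
+        'path/to/model-index.yml', config_prefix='path/to/configs')
+    metainfo = ModelHub.get('my-model-name')
+    print(metainfo.config, metainfo.weights)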
+
+
+def get_model(model: Union[str, Config],
+ pretrained: Union[str, bool] = False,
+ device=None,
+ device_map=None,
+ offload_folder=None,
+ url_mapping: Tuple[str, str] = None,
+ **kwargs):
+ """Get a pre-defined model or create a model from config.
+
+ Args:
+ model (str | Config): The name of model, the config file path or a
+ config instance.
+        pretrained (bool | str): When using a name to specify the model, you
+            can use ``True`` to load the pre-defined pretrained weights, or a
+            string to specify the path or link of the weights to load.
+            Defaults to False.
+ device (str | torch.device | None): Transfer the model to the target
+ device. Defaults to None.
+ device_map (str | dict | None): A map that specifies where each
+ submodule should go. It doesn't need to be refined to each
+            parameter/buffer name; once a given module name is included, all
+            of its submodules will be sent to the same device. You can use
+ `device_map="auto"` to automatically generate the device map.
+ Defaults to None.
+ offload_folder (str | None): If the `device_map` contains any value
+ `"disk"`, the folder where we will offload weights.
+ url_mapping (Tuple[str, str], optional): The mapping of pretrained
+ checkpoint link. For example, load checkpoint from a local dir
+ instead of download by ``('https://.*/', './checkpoint')``.
+ Defaults to None.
+ **kwargs: Other keyword arguments of the model config.
+
+ Returns:
+ mmengine.model.BaseModel: The result model.
+
+ Examples:
+ Get a ResNet-50 model and extract images feature:
+
+ >>> import torch
+ >>> from mmpretrain import get_model
+ >>> inputs = torch.rand(16, 3, 224, 224)
+ >>> model = get_model('resnet50_8xb32_in1k', pretrained=True, backbone=dict(out_indices=(0, 1, 2, 3)))
+ >>> feats = model.extract_feat(inputs)
+ >>> for feat in feats:
+ ... print(feat.shape)
+ torch.Size([16, 256])
+ torch.Size([16, 512])
+ torch.Size([16, 1024])
+ torch.Size([16, 2048])
+
+ Get Swin-Transformer model with pre-trained weights and inference:
+
+ >>> from mmpretrain import get_model, inference_model
+ >>> model = get_model('swin-base_16xb64_in1k', pretrained=True)
+ >>> result = inference_model(model, 'demo/demo.JPEG')
+ >>> print(result['pred_class'])
+ 'sea snake'
+ """ # noqa: E501
+ if device_map is not None:
+ from .utils import dispatch_model
+ dispatch_model._verify_require()
+
+ metainfo = None
+ if isinstance(model, Config):
+ config = copy.deepcopy(model)
+ if pretrained is True and 'load_from' in config:
+ pretrained = config.load_from
+ elif isinstance(model, (str, PathLike)) and Path(model).suffix == '.py':
+ config = Config.fromfile(model)
+ if pretrained is True and 'load_from' in config:
+ pretrained = config.load_from
+ elif isinstance(model, str):
+ metainfo = ModelHub.get(model)
+ config = metainfo.config
+ if pretrained is True and metainfo.weights is not None:
+ pretrained = metainfo.weights
+ else:
+ raise TypeError('model must be a name, a path or a Config object, '
+                        f'but got {type(model)}')
+
+ if pretrained is True:
+ warnings.warn('Unable to find pre-defined checkpoint of the model.')
+ pretrained = None
+ elif pretrained is False:
+ pretrained = None
+
+ if kwargs:
+ config.merge_from_dict({'model': kwargs})
+ config.model.setdefault('data_preprocessor',
+ config.get('data_preprocessor', None))
+
+ from mmengine.registry import DefaultScope
+
+ from mmpretrain.registry import MODELS
+ with DefaultScope.overwrite_default_scope('mmpretrain'):
+ model = MODELS.build(config.model)
+
+ dataset_meta = {}
+ if pretrained:
+ # Mapping the weights to GPU may cause unexpected video memory leak
+ # which refers to https://github.com/open-mmlab/mmdetection/pull/6405
+ from mmengine.runner import load_checkpoint
+ if url_mapping is not None:
+ pretrained = re.sub(url_mapping[0], url_mapping[1], pretrained)
+ checkpoint = load_checkpoint(model, pretrained, map_location='cpu')
+ if 'dataset_meta' in checkpoint.get('meta', {}):
+ # mmpretrain 1.x
+ dataset_meta = checkpoint['meta']['dataset_meta']
+ elif 'CLASSES' in checkpoint.get('meta', {}):
+ # mmcls 0.x
+ dataset_meta = {'classes': checkpoint['meta']['CLASSES']}
+
+ if len(dataset_meta) == 0 and 'test_dataloader' in config:
+ from mmpretrain.registry import DATASETS
+ dataset_class = DATASETS.get(config.test_dataloader.dataset.type)
+ dataset_meta = getattr(dataset_class, 'METAINFO', {})
+
+ if device_map is not None:
+ model = dispatch_model(
+ model, device_map=device_map, offload_folder=offload_folder)
+ elif device is not None:
+ model.to(device)
+
+ model._dataset_meta = dataset_meta # save the dataset meta
+ model._config = config # save the config in the model
+ model._metainfo = metainfo # save the metainfo in the model
+ model.eval()
+ return model
+
+
+def init_model(config, checkpoint=None, device=None, **kwargs):
+ """Initialize a classifier from config file (deprecated).
+
+ It's only for compatibility, please use :func:`get_model` instead.
+
+ Args:
+ config (str | :obj:`mmengine.Config`): Config file path or the config
+ object.
+ checkpoint (str, optional): Checkpoint path. If left as None, the model
+ will not load any weights.
+ device (str | torch.device | None): Transfer the model to the target
+ device. Defaults to None.
+ **kwargs: Other keyword arguments of the model config.
+
+ Returns:
+ nn.Module: The constructed model.
+ """
+ return get_model(config, checkpoint, device, **kwargs)
+
+
+def list_models(pattern=None, exclude_patterns=None, task=None) -> List[str]:
+ """List all models available in MMPretrain.
+
+ Args:
+ pattern (str | None): A wildcard pattern to match model names.
+ Defaults to None.
+ exclude_patterns (list | None): A list of wildcard patterns to
+ exclude names from the matched names. Defaults to None.
+        task (str | None): The evaluation task of the model. Defaults to None.
+
+ Returns:
+ List[str]: a list of model names.
+
+ Examples:
+ List all models:
+
+ >>> from mmpretrain import list_models
+ >>> list_models()
+
+ List ResNet-50 models on ImageNet-1k dataset:
+
+ >>> from mmpretrain import list_models
+ >>> list_models('resnet*in1k')
+ ['resnet50_8xb32_in1k',
+ 'resnet50_8xb32-fp16_in1k',
+ 'resnet50_8xb256-rsb-a1-600e_in1k',
+ 'resnet50_8xb256-rsb-a2-300e_in1k',
+ 'resnet50_8xb256-rsb-a3-100e_in1k']
+
+        List Swin-Transformer models trained from scratch and exclude
+ Swin-Transformer-V2 models:
+
+ >>> from mmpretrain import list_models
+ >>> list_models('swin', exclude_patterns=['swinv2', '*-pre'])
+ ['swin-base_16xb64_in1k',
+ 'swin-base_3rdparty_in1k',
+ 'swin-base_3rdparty_in1k-384',
+ 'swin-large_8xb8_cub-384px',
+ 'swin-small_16xb64_in1k',
+ 'swin-small_3rdparty_in1k',
+ 'swin-tiny_16xb64_in1k',
+ 'swin-tiny_3rdparty_in1k']
+
+ List all EVA models for image classification task.
+
+ >>> from mmpretrain import list_models
+ >>> list_models('eva', task='Image Classification')
+ ['eva-g-p14_30m-in21k-pre_3rdparty_in1k-336px',
+ 'eva-g-p14_30m-in21k-pre_3rdparty_in1k-560px',
+ 'eva-l-p14_mim-in21k-pre_3rdparty_in1k-196px',
+ 'eva-l-p14_mim-in21k-pre_3rdparty_in1k-336px',
+ 'eva-l-p14_mim-pre_3rdparty_in1k-196px',
+ 'eva-l-p14_mim-pre_3rdparty_in1k-336px']
+ """
+ ModelHub._register_mmpretrain_models()
+ matches = set(ModelHub._models_dict.keys())
+
+ if pattern is not None:
+ # Always match keys with any postfix.
+ matches = set(fnmatch.filter(matches, pattern + '*'))
+
+ exclude_patterns = exclude_patterns or []
+ for exclude_pattern in exclude_patterns:
+ exclude = set(fnmatch.filter(matches, exclude_pattern + '*'))
+ matches = matches - exclude
+
+ if task is not None:
+ task_matches = []
+ for key in matches:
+ metainfo = ModelHub._models_dict[key]
+ if metainfo.results is None and task == 'null':
+ task_matches.append(key)
+ elif metainfo.results is None:
+ continue
+ elif task in [result.task for result in metainfo.results]:
+ task_matches.append(key)
+ matches = task_matches
+
+ return sorted(list(matches))
+
+
+def inference_model(model, *args, **kwargs):
+ """Inference an image with the inferencer.
+
+    Automatically select the proper inferencer according to the type of the
+    model. It's a shortcut for a quick start; for advanced usage, please
+    use the corresponding inferencer class.
+
+ Here is the mapping from task to inferencer:
+
+ - Image Classification: :class:`ImageClassificationInferencer`
+ - Image Retrieval: :class:`ImageRetrievalInferencer`
+ - Image Caption: :class:`ImageCaptionInferencer`
+ - Visual Question Answering: :class:`VisualQuestionAnsweringInferencer`
+ - Visual Grounding: :class:`VisualGroundingInferencer`
+ - Text-To-Image Retrieval: :class:`TextToImageRetrievalInferencer`
+ - Image-To-Text Retrieval: :class:`ImageToTextRetrievalInferencer`
+ - NLVR: :class:`NLVRInferencer`
+
+ Args:
+ model (BaseModel | str | Config): The loaded model, the model
+ name or the config of the model.
+ *args: Positional arguments to call the inferencer.
+ **kwargs: Other keyword arguments to initialize and call the
+        corresponding inferencer.
+
+ Returns:
+ result (dict): The inference results.
+ """ # noqa: E501
+ from mmengine.model import BaseModel
+
+ if isinstance(model, BaseModel):
+ metainfo = getattr(model, '_metainfo', None)
+ else:
+ metainfo = ModelHub.get(model)
+
+ from inspect import signature
+
+ from .image_caption import ImageCaptionInferencer
+ from .image_classification import ImageClassificationInferencer
+ from .image_retrieval import ImageRetrievalInferencer
+ from .multimodal_retrieval import (ImageToTextRetrievalInferencer,
+ TextToImageRetrievalInferencer)
+ from .nlvr import NLVRInferencer
+ from .visual_grounding import VisualGroundingInferencer
+ from .visual_question_answering import VisualQuestionAnsweringInferencer
+ task_mapping = {
+ 'Image Classification': ImageClassificationInferencer,
+ 'Image Retrieval': ImageRetrievalInferencer,
+ 'Image Caption': ImageCaptionInferencer,
+ 'Visual Question Answering': VisualQuestionAnsweringInferencer,
+ 'Visual Grounding': VisualGroundingInferencer,
+ 'Text-To-Image Retrieval': TextToImageRetrievalInferencer,
+ 'Image-To-Text Retrieval': ImageToTextRetrievalInferencer,
+ 'NLVR': NLVRInferencer,
+ }
+
+ inferencer_type = None
+
+ if metainfo is not None and metainfo.results is not None:
+ tasks = set(result.task for result in metainfo.results)
+ inferencer_type = [
+ task_mapping.get(task) for task in tasks if task in task_mapping
+ ]
+ if len(inferencer_type) > 1:
+ inferencer_names = [cls.__name__ for cls in inferencer_type]
+            warnings.warn('The model supports multiple tasks; automatically '
+                          f'selected {inferencer_names[0]}. You can also use '
+                          f'other inferencers in {inferencer_names} '
+                          'directly.')
+ inferencer_type = inferencer_type[0]
+
+ if inferencer_type is None:
+ raise NotImplementedError('No available inferencer for the model')
+
+ init_kwargs = {
+ k: kwargs.pop(k)
+ for k in list(kwargs)
+ if k in signature(inferencer_type).parameters.keys()
+ }
+
+ inferencer = inferencer_type(model, **init_kwargs)
+ return inferencer(*args, **kwargs)[0]
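+
+
+# A minimal sketch of the keyword splitting above (model name and paths follow
+# the retrieval docstring examples): arguments accepted by the selected
+# inferencer's ``__init__`` are used to build it, and the rest are forwarded
+# to the call.
+def _example_inference_model_usage():
+    result = inference_model(
+        'resnet50-arcface_inshop',  # -> ImageRetrievalInferencer
+        'demo/cat-dog.png',  # positional input forwarded to the call
+        prototype='./demo/',  # matches the inferencer __init__ signature
+        topk=2)  # not an __init__ argument, forwarded to the call
+    print(result)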
diff --git a/mmpretrain/apis/multimodal_retrieval.py b/mmpretrain/apis/multimodal_retrieval.py
new file mode 100644
index 0000000000000000000000000000000000000000..5eb9c859aca309306c1e775b7a03bf3bbc1c7717
--- /dev/null
+++ b/mmpretrain/apis/multimodal_retrieval.py
@@ -0,0 +1,603 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from copy import deepcopy
+from pathlib import Path
+from typing import Callable, List, Optional, Tuple, Union
+
+import mmengine
+import numpy as np
+import torch
+from mmcv.image import imread
+from mmengine.config import Config
+from mmengine.dataset import BaseDataset, Compose, default_collate
+
+from mmpretrain.registry import TRANSFORMS
+from mmpretrain.structures import DataSample
+from mmpretrain.utils import track
+from .base import BaseInferencer
+from .base import InputType as ImageType
+from .base import ModelType
+from .model import list_models
+
+
+def filter_transforms(transforms: list, data_info: dict):
+ """Filter pipeline to avoid KeyError with partial data info."""
+ data_info = deepcopy(data_info)
+ filtered_transforms = []
+ for t in transforms:
+ try:
+ data_info = t(data_info)
+ filtered_transforms.append(t)
+ except KeyError:
+ pass
+ return filtered_transforms
+
+
+class TextToImageRetrievalInferencer(BaseInferencer):
+ """The inferencer for text to image retrieval.
+
+ Args:
+ model (BaseModel | str | Config): A model name or a path to the config
+ file, or a :obj:`BaseModel` object. The model name can be found
+ by ``TextToImageRetrievalInferencer.list_models()`` and you can also
+ query it in :doc:`/modelzoo_statistics`.
+ prototype (str | list | dict | DataLoader | BaseDataset): The images to
+ be retrieved. It can be the following types:
+
+            - str: The directory of the images.
+            - list: A list of paths of the images.
+            - dict: A config dict of a prototype dataset.
+ - BaseDataset: A prototype dataset.
+ - DataLoader: A data loader to load the prototype data.
+
+        prototype_cache (str, optional): The path of the generated prototype
+            features. If the file exists, the cached features are loaded
+            directly instead of being re-generated; otherwise, the generated
+            features are saved to this path. Defaults to None.
+        fast_match (bool): Some algorithms will record extra image features
+            for further matching, which may consume a large amount of memory;
+            set it to True to avoid this behavior. Defaults to True.
+ pretrained (str, optional): Path to the checkpoint. If None, it will
+ try to find a pre-defined weight from the model you specified
+ (only work if the ``model`` is a model name). Defaults to None.
+ device (str, optional): Device to run inference. If None, the available
+ device will be automatically used. Defaults to None.
+ **kwargs: Other keyword arguments to initialize the model (only work if
+ the ``model`` is a model name).
+
+ Example:
+ >>> from mmpretrain import TextToImageRetrievalInferencer
+ >>> inferencer = TextToImageRetrievalInferencer(
+ ... 'blip-base_3rdparty_retrieval',
+ ... prototype='./demo/',
+ ... prototype_cache='t2i_retri.pth')
+ >>> inferencer('A cat and a dog.')[0]
+ {'match_score': tensor(0.3855, device='cuda:0'),
+ 'sample_idx': 1,
+ 'sample': {'img_path': './demo/cat-dog.png'}}
+ """ # noqa: E501
+
+ visualize_kwargs: set = {
+ 'draw_score', 'show_dir', 'show', 'wait_time', 'figsize', 'topk'
+ }
+ postprocess_kwargs: set = {'topk'}
+
+ def __init__(self,
+ model: ModelType,
+ prototype,
+ prototype_cache=None,
+ fast_match=True,
+ prepare_batch_size=8,
+ pretrained: Union[bool, str] = True,
+ device: Union[str, torch.device, None] = None,
+ **kwargs) -> None:
+ super().__init__(
+ model=model, pretrained=pretrained, device=device, **kwargs)
+
+ self.img_pipeline, self.text_pipeline = self.pipeline
+
+ if hasattr(self.model, 'fast_match'):
+ self.model.fast_match = fast_match
+
+ self.prototype_dataset = self._prepare_prototype(
+ prototype, prototype_cache, batch_size=prepare_batch_size)
+
+ def _prepare_prototype(self, prototype, cache=None, batch_size=8):
+ from mmengine.dataset import DefaultSampler
+ from torch.utils.data import DataLoader
+
+ def build_dataloader(dataset):
+ return DataLoader(
+ dataset,
+ batch_size=batch_size,
+ collate_fn=default_collate,
+ sampler=DefaultSampler(dataset, shuffle=False),
+ persistent_workers=False,
+ )
+
+ if isinstance(prototype, str):
+ # A directory path of images
+ prototype = dict(
+ type='CustomDataset', with_label=False, data_root=prototype)
+
+ if isinstance(prototype, list):
+ test_pipeline = [dict(type='LoadImageFromFile'), self.img_pipeline]
+ dataset = BaseDataset(
+ lazy_init=True, serialize_data=False, pipeline=test_pipeline)
+ dataset.data_list = [{
+ 'sample_idx': i,
+ 'img_path': file
+ } for i, file in enumerate(prototype)]
+ dataset._fully_initialized = True
+ dataloader = build_dataloader(dataset)
+ elif isinstance(prototype, dict):
+ # A config of dataset
+ from mmpretrain.registry import DATASETS
+ test_pipeline = [dict(type='LoadImageFromFile'), self.img_pipeline]
+ prototype.setdefault('pipeline', test_pipeline)
+ dataset = DATASETS.build(prototype)
+ dataloader = build_dataloader(dataset)
+ elif isinstance(prototype, DataLoader):
+ dataset = prototype.dataset
+ dataloader = prototype
+ elif isinstance(prototype, BaseDataset):
+ dataset = prototype
+ dataloader = build_dataloader(dataset)
+ else:
+ raise TypeError(f'Unsupported prototype type {type(prototype)}.')
+
+ if cache is not None and Path(cache).exists():
+ self.prototype = torch.load(cache)
+ else:
+ prototype = []
+ for data_batch in track(dataloader, 'Prepare prototype...'):
+ with torch.no_grad():
+ data_batch = self.model.data_preprocessor(
+ data_batch, False)
+ feats = self.model._run_forward(data_batch, mode='tensor')
+ prototype.append(feats)
+ prototype = {
+ k: torch.cat([d[k] for d in prototype])
+ for k in prototype[0]
+ }
+ self.prototype = prototype
+
+ from mmengine.logging import MMLogger
+ logger = MMLogger.get_current_instance()
+ if cache is None:
+            logger.info('The prototype has been prepared; you can use '
+                        '`save_prototype` to dump it into a pickle '
+                        'file for future use.')
+ elif not Path(cache).exists():
+ self.save_prototype(cache)
+ logger.info(f'The prototype has been saved at {cache}.')
+
+ return dataset
+
+ def save_prototype(self, path):
+ torch.save(self.prototype, path)
+
+ def __call__(self,
+ inputs: ImageType,
+ return_datasamples: bool = False,
+ batch_size: int = 1,
+ **kwargs) -> dict:
+ """Call the inferencer.
+
+ Args:
+ inputs (str | array | list): The image path or array, or a list of
+ images.
+ return_datasamples (bool): Whether to return results as
+ :obj:`DataSample`. Defaults to False.
+ batch_size (int): Batch size. Defaults to 1.
+ resize (int, optional): Resize the long edge of the image to the
+ specified length before visualization. Defaults to None.
+ draw_score (bool): Whether to draw the match scores.
+ Defaults to True.
+ show (bool): Whether to display the visualization result in a
+ window. Defaults to False.
+ wait_time (float): The display time (s). Defaults to 0, which means
+ "forever".
+ show_dir (str, optional): If not None, save the visualization
+ results in the specified directory. Defaults to None.
+
+ Returns:
+ list: The inference results.
+ """
+ return super().__call__(inputs, return_datasamples, batch_size,
+ **kwargs)
+
+ @torch.no_grad()
+ def forward(self, data: dict, **kwargs):
+ """Feed the inputs to the model."""
+ data = self.model.data_preprocessor(data, False)
+ data_samples = data['data_samples']
+ feats = self.prototype.copy()
+ feats.update(self.model.extract_feat(data_samples=data_samples))
+ return self.model.predict_all(feats, data_samples, cal_i2t=False)[0]
+
+ def _init_pipeline(self, cfg: Config) -> Callable:
+ test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
+        test_transforms = [TRANSFORMS.build(t) for t in test_pipeline_cfg]
+        img_info = {'img': np.zeros((224, 224, 3), dtype=np.uint8)}
+        text_info = {'text': 'example'}
+        img_pipeline = Compose(filter_transforms(test_transforms, img_info))
+        text_pipeline = Compose(filter_transforms(test_transforms, text_info))
+ return img_pipeline, text_pipeline
+
+ def preprocess(self, inputs: List[str], batch_size: int = 1):
+
+ def process_text(input_: str):
+ return self.text_pipeline({'text': input_})
+
+ chunked_data = self._get_chunk_data(
+ map(process_text, inputs), batch_size)
+ yield from map(default_collate, chunked_data)
+
+ def visualize(self,
+ ori_inputs: List[str],
+ preds: List[DataSample],
+ topk: int = 3,
+ figsize: Tuple[int, int] = (16, 9),
+ show: bool = False,
+ wait_time: int = 0,
+ draw_score=True,
+ show_dir=None):
+ if not show and show_dir is None:
+ return None
+
+ if self.visualizer is None:
+ from mmpretrain.visualization import UniversalVisualizer
+ self.visualizer = UniversalVisualizer()
+
+ visualization = []
+ for i, (text, data_sample) in enumerate(zip(ori_inputs, preds)):
+ name = str(i)
+
+ if show_dir is not None:
+ show_dir = Path(show_dir)
+ show_dir.mkdir(exist_ok=True)
+ out_file = str((show_dir / name).with_suffix('.png'))
+ else:
+ out_file = None
+
+ self.visualizer.visualize_t2i_retrieval(
+ text,
+ data_sample,
+ self.prototype_dataset,
+ topk=topk,
+ fig_cfg=dict(figsize=figsize),
+ draw_score=draw_score,
+ show=show,
+ wait_time=wait_time,
+ name=name,
+ out_file=out_file)
+ visualization.append(self.visualizer.get_image())
+ if show:
+ self.visualizer.close()
+ return visualization
+
+ def postprocess(
+ self,
+ preds: List[DataSample],
+ visualization: List[np.ndarray],
+ return_datasamples=False,
+ topk=1,
+ ) -> dict:
+ if return_datasamples:
+ return preds
+
+ results = []
+ for data_sample in preds:
+ match_scores, indices = torch.topk(data_sample.pred_score, k=topk)
+ matches = []
+ for match_score, sample_idx in zip(match_scores, indices):
+ sample = self.prototype_dataset.get_data_info(
+ sample_idx.item())
+ sample_idx = sample.pop('sample_idx')
+ matches.append({
+ 'match_score': match_score,
+ 'sample_idx': sample_idx,
+ 'sample': sample
+ })
+ results.append(matches)
+
+ return results
+
+ @staticmethod
+ def list_models(pattern: Optional[str] = None):
+ """List all available model names.
+
+ Args:
+ pattern (str | None): A wildcard pattern to match model names.
+
+ Returns:
+ List[str]: a list of model names.
+ """
+ return list_models(pattern=pattern, task='Text-To-Image Retrieval')
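+
+
+# A minimal sketch of the prototype caching flow in ``_prepare_prototype``
+# (model name and paths follow the class docstring example): the first run
+# extracts the prototype features and dumps them to the cache path, and later
+# runs with the same path load them instead of re-extracting.
+def _example_t2i_prototype_cache():
+    inferencer = TextToImageRetrievalInferencer(
+        'blip-base_3rdparty_retrieval',
+        prototype='./demo/',
+        prototype_cache='t2i_retri.pth')
+    print(inferencer('A cat and a dog.', topk=1)[0])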
+
+
+class ImageToTextRetrievalInferencer(BaseInferencer):
+ """The inferencer for image to text retrieval.
+
+ Args:
+ model (BaseModel | str | Config): A model name or a path to the config
+ file, or a :obj:`BaseModel` object. The model name can be found
+ by ``ImageToTextRetrievalInferencer.list_models()`` and you can
+ also query it in :doc:`/modelzoo_statistics`.
+        prototype (str | list): The texts to be retrieved. It can be the
+            following types:
+
+            - str: The file path to load the string list.
+            - list: A list of strings.
+
+ prototype_cache (str, optional): The path of the generated prototype
+ features. If exists, directly load the cache instead of re-generate
+ the prototype features. If not exists, save the generated features
+ to the path. Defaults to None.
+        fast_match (bool): Some algorithms will record extra image features
+            for further matching, which may consume a large amount of memory;
+            set it to True to avoid this behavior. Defaults to True.
+ pretrained (str, optional): Path to the checkpoint. If None, it will
+ try to find a pre-defined weight from the model you specified
+ (only work if the ``model`` is a model name). Defaults to None.
+ device (str, optional): Device to run inference. If None, the available
+ device will be automatically used. Defaults to None.
+ **kwargs: Other keyword arguments to initialize the model (only work if
+ the ``model`` is a model name).
+
+ Example:
+ >>> from mmpretrain import ImageToTextRetrievalInferencer
+ >>> inferencer = ImageToTextRetrievalInferencer(
+ ... 'blip-base_3rdparty_retrieval',
+ ... prototype=['cat', 'dog', 'snake', 'bird'],
+ ... prototype_cache='i2t_retri.pth')
+ >>> inferencer('demo/bird.JPEG')[0]
+ {'match_score': tensor(0.3855, device='cuda:0'),
+ 'sample_idx': 1,
+        'text': 'dog'}
+ """ # noqa: E501
+
+ visualize_kwargs: set = {
+ 'draw_score', 'resize', 'show_dir', 'show', 'wait_time', 'topk'
+ }
+ postprocess_kwargs: set = {'topk'}
+
+ def __init__(self,
+ model: ModelType,
+ prototype,
+ prototype_cache=None,
+ fast_match=True,
+ prepare_batch_size=8,
+ pretrained: Union[bool, str] = True,
+ device: Union[str, torch.device, None] = None,
+ **kwargs) -> None:
+ super().__init__(
+ model=model, pretrained=pretrained, device=device, **kwargs)
+
+ self.img_pipeline, self.text_pipeline = self.pipeline
+
+ if hasattr(self.model, 'fast_match'):
+ self.model.fast_match = fast_match
+
+ self.prototype_dataset = self._prepare_prototype(
+ prototype, cache=prototype_cache, batch_size=prepare_batch_size)
+
+ def _prepare_prototype(self, prototype, cache=None, batch_size=8):
+ from mmengine.dataset import DefaultSampler
+ from torch.utils.data import DataLoader
+
+ def build_dataloader(dataset):
+ return DataLoader(
+ [
+ self.text_pipeline({
+ 'sample_idx': i,
+ 'text': text
+ }) for i, text in enumerate(dataset)
+ ],
+ batch_size=batch_size,
+ collate_fn=default_collate,
+ sampler=DefaultSampler(dataset, shuffle=False),
+ persistent_workers=False,
+ )
+
+ if isinstance(prototype, str):
+ # A file path of a list of string
+ dataset = mmengine.list_from_file(prototype)
+ elif mmengine.utils.is_seq_of(prototype, str):
+ dataset = prototype
+ else:
+ raise TypeError(f'Unsupported prototype type {type(prototype)}.')
+
+ dataloader = build_dataloader(dataset)
+
+ if cache is not None and Path(cache).exists():
+ self.prototype = torch.load(cache)
+ else:
+ prototype = []
+ for data_batch in track(dataloader, 'Prepare prototype...'):
+ with torch.no_grad():
+ data_batch = self.model.data_preprocessor(
+ data_batch, False)
+ feats = self.model._run_forward(data_batch, mode='tensor')
+ prototype.append(feats)
+ prototype = {
+ k: torch.cat([d[k] for d in prototype])
+ for k in prototype[0]
+ }
+ self.prototype = prototype
+
+ from mmengine.logging import MMLogger
+ logger = MMLogger.get_current_instance()
+ if cache is None:
+            logger.info('The prototype has been prepared; you can use '
+                        '`save_prototype` to dump it into a pickle '
+                        'file for future use.')
+ elif not Path(cache).exists():
+ self.save_prototype(cache)
+ logger.info(f'The prototype has been saved at {cache}.')
+
+ return dataset
+
+ def save_prototype(self, path):
+ torch.save(self.prototype, path)
+
+ def __call__(self,
+ inputs: ImageType,
+ return_datasamples: bool = False,
+ batch_size: int = 1,
+ **kwargs) -> dict:
+ """Call the inferencer.
+
+ Args:
+ inputs (str | array | list): The image path or array, or a list of
+ images.
+ return_datasamples (bool): Whether to return results as
+ :obj:`DataSample`. Defaults to False.
+ batch_size (int): Batch size. Defaults to 1.
+ resize (int, optional): Resize the long edge of the image to the
+ specified length before visualization. Defaults to None.
+ draw_score (bool): Whether to draw the match scores.
+ Defaults to True.
+ show (bool): Whether to display the visualization result in a
+ window. Defaults to False.
+ wait_time (float): The display time (s). Defaults to 0, which means
+ "forever".
+ show_dir (str, optional): If not None, save the visualization
+ results in the specified directory. Defaults to None.
+
+ Returns:
+ list: The inference results.
+ """
+ return super().__call__(inputs, return_datasamples, batch_size,
+ **kwargs)
+
+ @torch.no_grad()
+ def forward(self, data: dict, **kwargs):
+ """Feed the inputs to the model."""
+ data = self.model.data_preprocessor(data, False)
+ feats = self.prototype.copy()
+ feats.update(self.model.extract_feat(images=data['images']))
+ return self.model.predict_all(
+ feats, data['data_samples'], cal_t2i=False)[0]
+
+ def _init_pipeline(self, cfg: Config) -> Callable:
+ test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
+        test_transforms = [TRANSFORMS.build(t) for t in test_pipeline_cfg]
+        img_info = {'img': np.zeros((224, 224, 3), dtype=np.uint8)}
+        text_info = {'text': 'example'}
+        img_pipeline = Compose(filter_transforms(test_transforms, img_info))
+        text_pipeline = Compose(filter_transforms(test_transforms, text_info))
+ return img_pipeline, text_pipeline
+
+ def preprocess(self, inputs: List[ImageType], batch_size: int = 1):
+
+ def load_image(input_):
+ img = imread(input_)
+ if img is None:
+ raise ValueError(f'Failed to read image {input_}.')
+ return dict(
+ img=img,
+ img_shape=img.shape[:2],
+ ori_shape=img.shape[:2],
+ )
+
+ pipeline = Compose([load_image, self.img_pipeline])
+
+ chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size)
+ yield from map(default_collate, chunked_data)
+
+ def visualize(self,
+ ori_inputs: List[ImageType],
+ preds: List[DataSample],
+ topk: int = 3,
+ resize: Optional[int] = 224,
+ show: bool = False,
+ wait_time: int = 0,
+ draw_score=True,
+ show_dir=None):
+ if not show and show_dir is None:
+ return None
+
+ if self.visualizer is None:
+ from mmpretrain.visualization import UniversalVisualizer
+ self.visualizer = UniversalVisualizer()
+
+ visualization = []
+ for i, (input_, data_sample) in enumerate(zip(ori_inputs, preds)):
+ image = imread(input_)
+ if isinstance(input_, str):
+ # The image loaded from path is BGR format.
+ image = image[..., ::-1]
+ name = Path(input_).stem
+ else:
+ name = str(i)
+
+ if show_dir is not None:
+ show_dir = Path(show_dir)
+ show_dir.mkdir(exist_ok=True)
+ out_file = str((show_dir / name).with_suffix('.png'))
+ else:
+ out_file = None
+
+ self.visualizer.visualize_i2t_retrieval(
+ image,
+ data_sample,
+ self.prototype_dataset,
+ topk=topk,
+ resize=resize,
+ draw_score=draw_score,
+ show=show,
+ wait_time=wait_time,
+ name=name,
+ out_file=out_file)
+ visualization.append(self.visualizer.get_image())
+ if show:
+ self.visualizer.close()
+ return visualization
+
+ def postprocess(
+ self,
+ preds: List[DataSample],
+ visualization: List[np.ndarray],
+ return_datasamples=False,
+ topk=1,
+ ) -> dict:
+ if return_datasamples:
+ return preds
+
+ results = []
+ for data_sample in preds:
+ match_scores, indices = torch.topk(data_sample.pred_score, k=topk)
+ matches = []
+ for match_score, sample_idx in zip(match_scores, indices):
+ text = self.prototype_dataset[sample_idx.item()]
+ matches.append({
+ 'match_score': match_score,
+ 'sample_idx': sample_idx,
+ 'text': text
+ })
+ results.append(matches)
+
+ return results
+
+ @staticmethod
+ def list_models(pattern: Optional[str] = None):
+ """List all available model names.
+
+ Args:
+ pattern (str | None): A wildcard pattern to match model names.
+
+ Returns:
+ List[str]: a list of model names.
+ """
+ return list_models(pattern=pattern, task='Image-To-Text Retrieval')
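+
+
+# A minimal usage sketch (model name, prototype texts and query image follow
+# the class docstring example): matches returned by ``postprocess`` above
+# carry the matched text rather than a sample dict.
+def _example_i2t_retrieval_usage():
+    inferencer = ImageToTextRetrievalInferencer(
+        'blip-base_3rdparty_retrieval',
+        prototype=['cat', 'dog', 'snake', 'bird'])
+    for match in inferencer('demo/bird.JPEG', topk=2)[0]:
+        print(float(match['match_score']), match['text'])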
diff --git a/mmpretrain/apis/nlvr.py b/mmpretrain/apis/nlvr.py
new file mode 100644
index 0000000000000000000000000000000000000000..9977c3b06f36fa61a3cd2edf36077a993b2030cd
--- /dev/null
+++ b/mmpretrain/apis/nlvr.py
@@ -0,0 +1,150 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from copy import deepcopy
+from typing import Callable, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from mmcv.image import imread
+from mmengine.config import Config
+from mmengine.dataset import Compose, default_collate
+
+from mmpretrain.registry import TRANSFORMS
+from mmpretrain.structures import DataSample
+from .base import BaseInferencer
+from .model import list_models
+
+InputType = Tuple[Union[str, np.ndarray], Union[str, np.ndarray], str]
+InputsType = Union[List[InputType], InputType]
+
+
+class NLVRInferencer(BaseInferencer):
+ """The inferencer for Natural Language for Visual Reasoning.
+
+ Args:
+ model (BaseModel | str | Config): A model name or a path to the config
+ file, or a :obj:`BaseModel` object. The model name can be found
+ by ``NLVRInferencer.list_models()`` and you can also
+ query it in :doc:`/modelzoo_statistics`.
+ pretrained (str, optional): Path to the checkpoint. If None, it will
+ try to find a pre-defined weight from the model you specified
+ (only work if the ``model`` is a model name). Defaults to None.
+ device (str, optional): Device to run inference. If None, the available
+ device will be automatically used. Defaults to None.
+ **kwargs: Other keyword arguments to initialize the model (only work if
+ the ``model`` is a model name).
+ """
+
+ visualize_kwargs: set = {
+ 'resize', 'draw_score', 'show', 'show_dir', 'wait_time'
+ }
+
+ def __call__(self,
+ inputs: InputsType,
+ return_datasamples: bool = False,
+ batch_size: int = 1,
+ **kwargs) -> dict:
+ """Call the inferencer.
+
+ Args:
+ inputs (tuple, List[tuple]): The input data tuples, every tuple
+ should include three items (left image, right image, text).
+ The image can be a path or numpy array.
+ return_datasamples (bool): Whether to return results as
+ :obj:`DataSample`. Defaults to False.
+ batch_size (int): Batch size. Defaults to 1.
+ resize (int, optional): Resize the short edge of the image to the
+ specified length before visualization. Defaults to None.
+ draw_score (bool): Whether to draw the prediction scores
+ of prediction categories. Defaults to True.
+ show (bool): Whether to display the visualization result in a
+ window. Defaults to False.
+ wait_time (float): The display time (s). Defaults to 0, which means
+ "forever".
+ show_dir (str, optional): If not None, save the visualization
+ results in the specified directory. Defaults to None.
+
+ Returns:
+ list: The inference results.
+ """
+ assert isinstance(inputs, (tuple, list))
+ if isinstance(inputs, tuple):
+ inputs = [inputs]
+ for input_ in inputs:
+ assert isinstance(input_, tuple)
+ assert len(input_) == 3
+
+ return super().__call__(
+ inputs,
+ return_datasamples=return_datasamples,
+ batch_size=batch_size,
+ **kwargs)
+
+ def _init_pipeline(self, cfg: Config) -> Callable:
+ test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
+ assert test_pipeline_cfg[0]['type'] == 'ApplyToList'
+
+ list_pipeline = deepcopy(test_pipeline_cfg[0])
+ if list_pipeline.scatter_key == 'img_path':
+ # Remove `LoadImageFromFile`
+ list_pipeline.transforms.pop(0)
+ list_pipeline.scatter_key = 'img'
+
+ test_pipeline = Compose(
+ [TRANSFORMS.build(list_pipeline)] +
+ [TRANSFORMS.build(t) for t in test_pipeline_cfg[1:]])
+ return test_pipeline
+
+ def preprocess(self, inputs: InputsType, batch_size: int = 1):
+
+ def load_image(input_):
+ img1 = imread(input_[0])
+ img2 = imread(input_[1])
+ text = input_[2]
+ if img1 is None:
+ raise ValueError(f'Failed to read image {input_[0]}.')
+ if img2 is None:
+ raise ValueError(f'Failed to read image {input_[1]}.')
+ return dict(
+ img=[img1, img2],
+ img_shape=[img1.shape[:2], img2.shape[:2]],
+ ori_shape=[img1.shape[:2], img2.shape[:2]],
+ text=text,
+ )
+
+ pipeline = Compose([load_image, self.pipeline])
+
+ chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size)
+ yield from map(default_collate, chunked_data)
+
+ def postprocess(self,
+ preds: List[DataSample],
+ visualization: List[np.ndarray],
+ return_datasamples=False) -> dict:
+ if return_datasamples:
+ return preds
+
+ results = []
+ for data_sample in preds:
+ pred_scores = data_sample.pred_score
+ pred_score = float(torch.max(pred_scores).item())
+ pred_label = torch.argmax(pred_scores).item()
+ result = {
+ 'pred_scores': pred_scores.detach().cpu().numpy(),
+ 'pred_label': pred_label,
+ 'pred_score': pred_score,
+ }
+ results.append(result)
+
+ return results
+
+ @staticmethod
+ def list_models(pattern: Optional[str] = None):
+ """List all available model names.
+
+ Args:
+ pattern (str | None): A wildcard pattern to match model names.
+
+ Returns:
+ List[str]: a list of model names.
+ """
+ return list_models(pattern=pattern, task='NLVR')
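+
+
+# A minimal usage sketch (the image paths and the statement are placeholders;
+# the model name is taken from ``list_models``): each input is a
+# (left image, right image, text) triple and each result carries the
+# predicted label and score from ``postprocess`` above.
+def _example_nlvr_usage():
+    model_name = NLVRInferencer.list_models()[0]
+    inferencer = NLVRInferencer(model_name)
+    result = inferencer(
+        ('demo/cat-dog.png', 'demo/bird.JPEG',
+         'There is a bird in the right image.'))[0]
+    print(result['pred_label'], result['pred_score'])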
diff --git a/mmpretrain/apis/utils.py b/mmpretrain/apis/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..83e76325472f6925f78c746e3a10f3a58b0e6de4
--- /dev/null
+++ b/mmpretrain/apis/utils.py
@@ -0,0 +1,270 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+from collections import defaultdict
+from contextlib import contextmanager
+from itertools import chain
+from typing import Dict, List, Optional, Union
+
+import torch
+import torch.nn as nn
+
+from mmpretrain.utils import require
+
+
+@require('torch>=1.9.0', 'https://pytorch.org/get-started/locally/')
+@require('accelerate')
+def dispatch_model(
+ model,
+ device_map: Union[str, dict],
+ max_memory: Optional[dict] = None,
+ no_split_module_classes: Optional[List[str]] = None,
+ offload_folder: str = None,
+ offload_buffers: bool = False,
+ preload_module_classes: Optional[List[str]] = None,
+):
+ """Split and dispatch a model across devices.
+
+    The function depends on the `accelerate` package. Refer to
+    https://huggingface.co/docs/accelerate/main/en/usage_guides/big_modeling
+
+ Args:
+ model (torch.nn.Module): The model to dispatch.
+ device_map (str | dict | None): A map that specifies where each
+ submodule should go. It doesn't need to be refined to each
+            parameter/buffer name; once a given module name is included, all
+            of its submodules will be sent to the same device. You can use
+ `device_map="auto"` to automatically generate the device map.
+ Defaults to None.
+        max_memory (dict | None): A dictionary mapping device identifiers to
+            their maximum memory. Will default to the maximum memory available
+            for each GPU and the available CPU RAM if unset. Defaults to None.
+ no_split_module_classes (List[str] | None): A list of layer class names
+            that should never be split across devices (for instance any layer
+ that has a residual connection). If None, try to get the settings
+ from the model class. Defaults to None.
+ offload_folder (str | None): If the `device_map` contains any value
+ `"disk"`, the folder where we will offload weights.
+ offload_buffers (bool): In the layers that are offloaded on the CPU
+ or the hard drive, whether or not to offload the buffers as
+ well as the parameters. Defaults to False.
+ preload_module_classes (List[str] | None): A list of classes whose
+ instances should load all their weights (even in the submodules) at
+ the beginning of the forward. This should only be used for classes
+ that have submodules which are registered but not called directly
+ during the forward, for instance if a `dense` linear layer is
+ registered, but at forward, `dense.weight` and `dense.bias` are
+ used in some operations instead of calling `dense` directly.
+ Defaults to None.
+ """
+ from accelerate import dispatch_model, infer_auto_device_map
+
+ # Check valid device_map string.
+ valid_map_option = ['auto', 'balanced', 'balanced_low_0', 'sequential']
+ if isinstance(device_map, str) and device_map not in valid_map_option:
+ raise ValueError('If passing a string for `device_map`, please choose '
+ f'from {valid_map_option}.')
+
+ # Generate device map automatically
+ if isinstance(device_map, str):
+ if no_split_module_classes is None:
+ no_split_module_classes = getattr(model, '_no_split_modules', None)
+ if no_split_module_classes is None:
+ raise ValueError(f'{model.__class__.__name__} does not support '
+ f"`device_map='{device_map}'` yet.")
+
+ if device_map != 'sequential':
+ from accelerate.utils import get_balanced_memory
+ max_memory = get_balanced_memory(
+ model,
+ max_memory=max_memory,
+ no_split_module_classes=no_split_module_classes,
+ dtype=None,
+ low_zero=(device_map == 'balanced_low_0'),
+ )
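+            # Shrink device 0's memory budget by 10% to leave some headroom.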
+ max_memory[0] *= 0.9
+ device_map = infer_auto_device_map(
+ model,
+ max_memory=max_memory,
+ no_split_module_classes=no_split_module_classes,
+ dtype=None,
+ )
+
+ if 'disk' in device_map.values():
+ if offload_folder is None:
+ raise ValueError(
+ 'The current `device_map` had weights offloaded to the disk. '
+ 'Please provide an `offload_folder` for them.')
+ os.makedirs(offload_folder, exist_ok=True)
+
+ main_device = next(
+ (d for d in device_map.values() if d not in ['cpu', 'disk']), 'cpu')
+
+ model = dispatch_model(
+ model,
+ device_map=device_map,
+ main_device=main_device,
+ offload_dir=offload_folder,
+ offload_buffers=offload_buffers,
+ preload_module_classes=preload_module_classes,
+ )
+ if hasattr(model, 'data_preprocessor'):
+ model.data_preprocessor._device = torch.device(main_device)
+ return model
+
+
+@contextmanager
+def init_empty_weights(include_buffers: bool = False):
+ """A context manager under which models are initialized with all parameters
+ on the meta device.
+
+ With this context manager, we can create an empty model. Useful when just
+ initializing the model would blow the available RAM.
+
+    Besides moving the parameters to the meta device, this context manager
+    also avoids loading checkpoints through `mmengine.runner.load_checkpoint`
+    and `transformers.PreTrainedModel.from_pretrained`.
+
+ Modified from https://github.com/huggingface/accelerate
+
+ Args:
+        include_buffers (bool): Whether to put all buffers on the meta device
+ during initialization.
+ """
+ device = torch.device('meta')
+
+ # move parameter and buffer to meta device
+ old_register_parameter = nn.Module.register_parameter
+ if include_buffers:
+ old_register_buffer = nn.Module.register_buffer
+ # See https://github.com/huggingface/accelerate/pull/699
+ tensor_constructors_to_patch = {
+ torch_function_name: getattr(torch, torch_function_name)
+ for torch_function_name in ['empty', 'zeros', 'ones', 'full']
+ }
+
+ def register_parameter(module, name, param):
+ old_register_parameter(module, name, param)
+ if param is not None:
+ param_cls = type(module._parameters[name])
+ kwargs = module._parameters[name].__dict__
+ module._parameters[name] = param_cls(
+ module._parameters[name].to(device), **kwargs)
+
+ def register_buffer(module, name, buffer, *args, **kwargs):
+ old_register_buffer(module, name, buffer, *args, **kwargs)
+ if buffer is not None:
+ module._buffers[name] = module._buffers[name].to(device)
+
+ def patch_tensor_constructor(fn):
+
+ def wrapper(*args, **kwargs):
+ kwargs['device'] = device
+ return fn(*args, **kwargs)
+
+ return wrapper
+
+ # Patch load_checkpoint
+ import mmengine.runner.checkpoint as mmengine_load
+ old_load_checkpoint = mmengine_load.load_checkpoint
+
+ def patch_load_checkpoint(*args, **kwargs):
+ return {}
+
+ # Patch transformers from pretrained
+ try:
+ from transformers import PreTrainedModel
+ from transformers.models.auto.auto_factory import (AutoConfig,
+ _BaseAutoModelClass)
+ with_transformers = True
+ except ImportError:
+ with_transformers = False
+
+ @classmethod
+ def patch_auto_model(cls, pretrained_model_name_or_path, *model_args,
+ **kwargs):
+ cfg = AutoConfig.from_pretrained(pretrained_model_name_or_path,
+ *model_args, **kwargs)
+ return cls.from_config(cfg)
+
+ @classmethod
+ def patch_pretrained_model(cls, pretrained_model_name_or_path, *model_args,
+ **kwargs):
+ cfg = cls.config_class.from_pretrained(pretrained_model_name_or_path,
+ *model_args, **kwargs)
+ return cls(cfg)
+
+ if with_transformers:
+ old_pretrained_model = PreTrainedModel.from_pretrained
+ old_auto_model = _BaseAutoModelClass.from_pretrained
+
+ try:
+ nn.Module.register_parameter = register_parameter
+ mmengine_load.load_checkpoint = patch_load_checkpoint
+ if with_transformers:
+ PreTrainedModel.from_pretrained = patch_pretrained_model
+ _BaseAutoModelClass.from_pretrained = patch_auto_model
+ if include_buffers:
+ nn.Module.register_buffer = register_buffer
+ for func in tensor_constructors_to_patch.keys():
+ tensor_constructor = patch_tensor_constructor(
+ getattr(torch, func))
+ setattr(torch, func, tensor_constructor)
+ yield
+ finally:
+ nn.Module.register_parameter = old_register_parameter
+ mmengine_load.load_checkpoint = old_load_checkpoint
+ if with_transformers:
+ PreTrainedModel.from_pretrained = old_pretrained_model
+ _BaseAutoModelClass.from_pretrained = old_auto_model
+ if include_buffers:
+ nn.Module.register_buffer = old_register_buffer
+ for func, ori in tensor_constructors_to_patch.items():
+ setattr(torch, func, ori)
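+
+
+# A minimal sketch of the context manager above (the model name is a
+# placeholder from the docstrings in this package): the model skeleton is
+# built on the meta device and checkpoint loading is skipped.
+def _example_init_empty_weights():
+    from mmpretrain import get_model
+    with init_empty_weights():
+        model = get_model('resnet50_8xb32_in1k', pretrained=True)
+    # Parameters live on the meta device and occupy no real memory yet.
+    print(next(model.parameters()).device)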
+
+
+def compute_module_sizes(
+ model: nn.Module,
+ dtype: Union[str, torch.dtype, None] = None,
+ special_dtypes: Optional[Dict[str, Union[str, torch.dtype]]] = None):
+ """Compute the size of each submodule of a given model."""
+
+ def get_dtype(dtype):
+ if isinstance(dtype, str):
+ dtype = getattr(torch, dtype)
+ if dtype is not None:
+            assert isinstance(dtype, torch.dtype)
+ return dtype
+
+ def dtype_bytes(dtype: torch.dtype):
+ if dtype is torch.bool:
+ return 1
+ if dtype.is_floating_point:
+ return torch.finfo(dtype).bits / 8
+ else:
+ return torch.iinfo(dtype).bits / 8
+
+ if dtype is not None:
+ dtype = get_dtype(dtype)
+ dtype_size = dtype_bytes(dtype)
+
+ if special_dtypes is not None:
+ special_dtypes = {
+ key: dtype_bytes(dtype)
+ for key, dtype in special_dtypes.items()
+ }
+
+ module_sizes = defaultdict(int)
+ for name, tensor in chain(
+ model.named_parameters(recurse=True),
+ model.named_buffers(recurse=True)):
+ if special_dtypes is not None and name in special_dtypes:
+ size = tensor.numel() * special_dtypes[name]
+ elif dtype is None:
+ size = tensor.numel() * tensor.element_size()
+ else:
+ size = tensor.numel() * min(dtype_size, tensor.element_size())
+ name_parts = name.split('.')
+ for idx in range(len(name_parts) + 1):
+ module_sizes['.'.join(name_parts[:idx])] += size
+
+ return module_sizes
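
For reference, here is a small, self-contained usage sketch (not part of the patch) of how `compute_module_sizes` defined above might be used to inspect per-module memory before placing a model on devices. The import path is an assumption based on the file being patched and may differ in your checkout.

import torch
import torch.nn as nn

# Assumed import path for the helper added above; adjust if needed.
from mmpretrain.apis.utils import compute_module_sizes

model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 10))

# Sizes are reported in bytes; the empty key '' aggregates the whole model.
sizes = compute_module_sizes(model, dtype=torch.float16)
print(f"total: {sizes[''] / 1024:.1f} KiB")
for name, size in sizes.items():
    if name:  # skip the root aggregate
        print(f'{name}: {size:.0f} bytes')
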
diff --git a/mmpretrain/apis/visual_grounding.py b/mmpretrain/apis/visual_grounding.py
new file mode 100644
index 0000000000000000000000000000000000000000..0153d56f5ca10a32e9fd2ccabb0d15c1135e213d
--- /dev/null
+++ b/mmpretrain/apis/visual_grounding.py
@@ -0,0 +1,182 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from pathlib import Path
+from typing import Callable, List, Optional, Union
+
+import numpy as np
+from mmcv.image import imread
+from mmengine.config import Config
+from mmengine.dataset import Compose, default_collate
+
+from mmpretrain.registry import TRANSFORMS
+from mmpretrain.structures import DataSample
+from .base import BaseInferencer
+from .model import list_models
+
+
+class VisualGroundingInferencer(BaseInferencer):
+ """The inferencer for visual grounding.
+
+ Args:
+ model (BaseModel | str | Config): A model name or a path to the config
+ file, or a :obj:`BaseModel` object. The model name can be found
+ by ``VisualGroundingInferencer.list_models()`` and you can also
+ query it in :doc:`/modelzoo_statistics`.
+ pretrained (str, optional): Path to the checkpoint. If None, it will
+ try to find a pre-defined weight from the model you specified
+ (only works if ``model`` is a model name). Defaults to None.
+ device (str, optional): Device to run inference. If None, the available
+ device will be used automatically. Defaults to None.
+ **kwargs: Other keyword arguments to initialize the model (only works
+ if ``model`` is a model name).
+
+ Example:
+ >>> from mmpretrain import VisualGroundingInferencer
+ >>> inferencer = VisualGroundingInferencer('ofa-base_3rdparty_refcoco')
+ >>> inferencer('demo/cat-dog.png', 'dog')[0]
+ {'pred_bboxes': tensor([[ 36.6000, 29.6000, 355.8000, 395.2000]])}
+ """ # noqa: E501
+
+ visualize_kwargs: set = {
+ 'resize', 'show', 'show_dir', 'wait_time', 'line_width', 'bbox_color'
+ }
+
+ def __call__(self,
+ images: Union[str, np.ndarray, list],
+ texts: Union[str, list],
+ return_datasamples: bool = False,
+ batch_size: int = 1,
+ **kwargs) -> dict:
+ """Call the inferencer.
+
+ Args:
+ images (str | array | list): The image path or array, or a list of
+ images.
+ texts (str | list): The text to do visual grounding.
+ return_datasamples (bool): Whether to return results as
+ :obj:`DataSample`. Defaults to False.
+ batch_size (int): Batch size. Defaults to 1.
+ resize (int, optional): Resize the short edge of the image to the
+ specified length before visualization. Defaults to None.
+ show (bool): Whether to display the visualization result in a
+ window. Defaults to False.
+ wait_time (float): The display time (s). Defaults to 0, which means
+ "forever".
+ show_dir (str, optional): If not None, save the visualization
+ results in the specified directory. Defaults to None.
+ line_width (int): The line width of the bbox. Defaults to 3.
+ bbox_color (str | tuple): The color of the bbox.
+ Defaults to 'green'.
+
+ Returns:
+ list: The inference results.
+ """
+ if not isinstance(images, (list, tuple)):
+ assert isinstance(texts, str)
+ inputs = [{'img': images, 'text': texts}]
+ else:
+ inputs = []
+ for i in range(len(images)):
+ input_ = {'img': images[i], 'text': texts[i]}
+ inputs.append(input_)
+
+ return super().__call__(inputs, return_datasamples, batch_size,
+ **kwargs)
+
+ def _init_pipeline(self, cfg: Config) -> Callable:
+ test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
+ from mmpretrain.datasets import remove_transform
+
+ # Image loading is finished in `self.preprocess`.
+ test_pipeline_cfg = remove_transform(test_pipeline_cfg,
+ 'LoadImageFromFile')
+ test_pipeline = Compose(
+ [TRANSFORMS.build(t) for t in test_pipeline_cfg])
+ return test_pipeline
+
+ def preprocess(self, inputs: List[dict], batch_size: int = 1):
+
+ def load_image(input_: dict):
+ img = imread(input_['img'])
+ if img is None:
+ raise ValueError(f'Failed to read image {input_}.')
+ return {**input_, 'img': img}
+
+ pipeline = Compose([load_image, self.pipeline])
+
+ chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size)
+ yield from map(default_collate, chunked_data)
+
+ def visualize(self,
+ ori_inputs: List[dict],
+ preds: List[DataSample],
+ show: bool = False,
+ wait_time: int = 0,
+ resize: Optional[int] = None,
+ line_width: int = 3,
+ bbox_color: Union[str, tuple] = 'green',
+ show_dir=None):
+ if not show and show_dir is None:
+ return None
+
+ if self.visualizer is None:
+ from mmpretrain.visualization import UniversalVisualizer
+ self.visualizer = UniversalVisualizer()
+
+ visualization = []
+ for i, (input_, data_sample) in enumerate(zip(ori_inputs, preds)):
+ image = imread(input_['img'])
+ if isinstance(input_['img'], str):
+ # The image loaded from path is BGR format.
+ image = image[..., ::-1]
+ name = Path(input_['img']).stem
+ else:
+ name = str(i)
+
+ if show_dir is not None:
+ show_dir = Path(show_dir)
+ show_dir.mkdir(exist_ok=True)
+ out_file = str((show_dir / name).with_suffix('.png'))
+ else:
+ out_file = None
+
+ self.visualizer.visualize_visual_grounding(
+ image,
+ data_sample,
+ resize=resize,
+ show=show,
+ wait_time=wait_time,
+ line_width=line_width,
+ bbox_color=bbox_color,
+ name=name,
+ out_file=out_file)
+ visualization.append(self.visualizer.get_image())
+ if show:
+ self.visualizer.close()
+ return visualization
+
+ def postprocess(self,
+ preds: List[DataSample],
+ visualization: List[np.ndarray],
+ return_datasamples=False) -> dict:
+ if return_datasamples:
+ return preds
+
+ results = []
+ for data_sample in preds:
+ results.append({'pred_bboxes': data_sample.get('pred_bboxes')})
+
+ return results
+
+ @staticmethod
+ def list_models(pattern: Optional[str] = None):
+ """List all available model names.
+
+ Args:
+ pattern (str | None): A wildcard pattern to match model names.
+
+ Returns:
+ List[str]: a list of model names.
+ """
+ return list_models(pattern=pattern, task='Visual Grounding')
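
A short usage sketch for the inferencer above (not part of the patch); the model name mirrors the docstring example, and the image paths and output directory are illustrative.

from mmpretrain import VisualGroundingInferencer

inferencer = VisualGroundingInferencer('ofa-base_3rdparty_refcoco')

# Single image, single referring phrase.
result = inferencer('demo/cat-dog.png', 'dog')[0]
print(result['pred_bboxes'])  # (x1, y1, x2, y2) boxes, as in the docstring

# Batched inference: parallel lists of images and phrases, with the drawn
# boxes written to `show_dir` instead of an interactive window.
results = inferencer(
    ['demo/cat-dog.png', 'demo/cat-dog.png'],
    ['the dog', 'the cat'],
    batch_size=2,
    line_width=2,
    show_dir='./vg_vis/')
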
diff --git a/mmpretrain/apis/visual_question_answering.py b/mmpretrain/apis/visual_question_answering.py
new file mode 100644
index 0000000000000000000000000000000000000000..616e1edc66709401df83cb5253590325e727aa98
--- /dev/null
+++ b/mmpretrain/apis/visual_question_answering.py
@@ -0,0 +1,183 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from pathlib import Path
+from typing import Callable, List, Optional, Union
+
+import numpy as np
+from mmcv.image import imread
+from mmengine.config import Config
+from mmengine.dataset import Compose, default_collate
+
+from mmpretrain.registry import TRANSFORMS
+from mmpretrain.structures import DataSample
+from .base import BaseInferencer
+from .model import list_models
+
+
+class VisualQuestionAnsweringInferencer(BaseInferencer):
+ """The inferencer for visual question answering.
+
+ Args:
+ model (BaseModel | str | Config): A model name or a path to the config
+ file, or a :obj:`BaseModel` object. The model name can be found
+ by ``VisualQuestionAnsweringInferencer.list_models()`` and you can
+ also query it in :doc:`/modelzoo_statistics`.
+ pretrained (str, optional): Path to the checkpoint. If None, it will
+ try to find a pre-defined weight from the model you specified
+ (only works if ``model`` is a model name). Defaults to None.
+ device (str, optional): Device to run inference. If None, the available
+ device will be used automatically. Defaults to None.
+ **kwargs: Other keyword arguments to initialize the model (only works
+ if ``model`` is a model name).
+
+ Example:
+ >>> from mmpretrain import VisualQuestionAnsweringInferencer
+ >>> inferencer = VisualQuestionAnsweringInferencer('ofa-base_3rdparty-zeroshot_vqa')
+ >>> inferencer('demo/cat-dog.png', "What's the animal next to the dog?")[0]
+ {'question': "What's the animal next to the dog?", 'pred_answer': 'cat'}
+ """ # noqa: E501
+
+ visualize_kwargs: set = {'resize', 'show', 'show_dir', 'wait_time'}
+
+ def __call__(self,
+ images: Union[str, np.ndarray, list],
+ questions: Union[str, list],
+ return_datasamples: bool = False,
+ batch_size: int = 1,
+ objects: Optional[List[str]] = None,
+ **kwargs) -> dict:
+ """Call the inferencer.
+
+ Args:
+ images (str | array | list): The image path or array, or a list of
+ images.
+ questions (str | list): The question corresponding to each image.
+ return_datasamples (bool): Whether to return results as
+ :obj:`DataSample`. Defaults to False.
+ batch_size (int): Batch size. Defaults to 1.
+ objects (List[List[str]], optional): Some algorithms, such as
+ fine-tuned OFA VQA models, require an extra object description
+ list for every image. Defaults to None.
+ resize (int, optional): Resize the short edge of the image to the
+ specified length before visualization. Defaults to None.
+ show (bool): Whether to display the visualization result in a
+ window. Defaults to False.
+ wait_time (float): The display time (s). Defaults to 0, which means
+ "forever".
+ show_dir (str, optional): If not None, save the visualization
+ results in the specified directory. Defaults to None.
+
+ Returns:
+ list: The inference results.
+ """
+ if not isinstance(images, (list, tuple)):
+ assert isinstance(questions, str)
+ inputs = [{'img': images, 'question': questions}]
+ if objects is not None:
+ assert isinstance(objects[0], str)
+ inputs[0]['objects'] = objects
+ else:
+ inputs = []
+ for i in range(len(images)):
+ input_ = {'img': images[i], 'question': questions[i]}
+ if objects is not None:
+ input_['objects'] = objects[i]
+ inputs.append(input_)
+
+ return super().__call__(inputs, return_datasamples, batch_size,
+ **kwargs)
+
+ def _init_pipeline(self, cfg: Config) -> Callable:
+ test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
+ from mmpretrain.datasets import remove_transform
+
+ # Image loading is finished in `self.preprocess`.
+ test_pipeline_cfg = remove_transform(test_pipeline_cfg,
+ 'LoadImageFromFile')
+ test_pipeline = Compose(
+ [TRANSFORMS.build(t) for t in test_pipeline_cfg])
+ return test_pipeline
+
+ def preprocess(self, inputs: List[dict], batch_size: int = 1):
+
+ def load_image(input_: dict):
+ img = imread(input_['img'])
+ if img is None:
+ raise ValueError(f'Failed to read image {input_}.')
+ return {**input_, 'img': img}
+
+ pipeline = Compose([load_image, self.pipeline])
+
+ chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size)
+ yield from map(default_collate, chunked_data)
+
+ def visualize(self,
+ ori_inputs: List[dict],
+ preds: List[DataSample],
+ show: bool = False,
+ wait_time: int = 0,
+ resize: Optional[int] = None,
+ show_dir=None):
+ if not show and show_dir is None:
+ return None
+
+ if self.visualizer is None:
+ from mmpretrain.visualization import UniversalVisualizer
+ self.visualizer = UniversalVisualizer()
+
+ visualization = []
+ for i, (input_, data_sample) in enumerate(zip(ori_inputs, preds)):
+ image = imread(input_['img'])
+ if isinstance(input_['img'], str):
+ # The image loaded from path is BGR format.
+ image = image[..., ::-1]
+ name = Path(input_['img']).stem
+ else:
+ name = str(i)
+
+ if show_dir is not None:
+ show_dir = Path(show_dir)
+ show_dir.mkdir(exist_ok=True)
+ out_file = str((show_dir / name).with_suffix('.png'))
+ else:
+ out_file = None
+
+ self.visualizer.visualize_vqa(
+ image,
+ data_sample,
+ resize=resize,
+ show=show,
+ wait_time=wait_time,
+ name=name,
+ out_file=out_file)
+ visualization.append(self.visualizer.get_image())
+ if show:
+ self.visualizer.close()
+ return visualization
+
+ def postprocess(self,
+ preds: List[DataSample],
+ visualization: List[np.ndarray],
+ return_datasamples=False) -> dict:
+ if return_datasamples:
+ return preds
+
+ results = []
+ for data_sample in preds:
+ results.append({
+ 'question': data_sample.get('question'),
+ 'pred_answer': data_sample.get('pred_answer'),
+ })
+
+ return results
+
+ @staticmethod
+ def list_models(pattern: Optional[str] = None):
+ """List all available model names.
+
+ Args:
+ pattern (str | None): A wildcard pattern to match model names.
+
+ Returns:
+ List[str]: a list of model names.
+ """
+ return list_models(pattern=pattern, task='Visual Question Answering')
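
Similarly, a hedged usage sketch for the VQA inferencer above (not part of the patch); the optional `objects` argument is only consumed by some fine-tuned OFA models, and the image paths are illustrative.

from mmpretrain import VisualQuestionAnsweringInferencer

inferencer = VisualQuestionAnsweringInferencer('ofa-base_3rdparty-zeroshot_vqa')

# Single image / question, as in the class docstring.
print(inferencer('demo/cat-dog.png', "What's the animal next to the dog?")[0])

# Batched inference with per-image object descriptions (optional).
results = inferencer(
    ['demo/cat-dog.png', 'demo/cat-dog.png'],
    ['What color is the cat?', 'How many animals are there?'],
    objects=[['cat', 'dog'], ['cat', 'dog']],
    batch_size=2)
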
diff --git a/mmpretrain/configs/_base_/datasets/imagenet_bs32.py b/mmpretrain/configs/_base_/datasets/imagenet_bs32.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d074008cc204f4ac486dc04fb3f1c638fb9e161
--- /dev/null
+++ b/mmpretrain/configs/_base_/datasets/imagenet_bs32.py
@@ -0,0 +1,62 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# This is a BETA new format config file, and the usage may change in the future.
+from mmengine.dataset import DefaultSampler
+
+from mmpretrain.datasets import (CenterCrop, ImageNet, LoadImageFromFile,
+ PackInputs, RandomFlip, RandomResizedCrop,
+ ResizeEdge)
+from mmpretrain.evaluation import Accuracy
+
+# dataset settings
+dataset_type = ImageNet
+data_preprocessor = dict(
+ num_classes=1000,
+ # RGB format normalization parameters
+ mean=[123.675, 116.28, 103.53],
+ std=[58.395, 57.12, 57.375],
+ # convert image from BGR to RGB
+ to_rgb=True,
+)
+
+train_pipeline = [
+ dict(type=LoadImageFromFile),
+ dict(type=RandomResizedCrop, scale=224),
+ dict(type=RandomFlip, prob=0.5, direction='horizontal'),
+ dict(type=PackInputs),
+]
+
+test_pipeline = [
+ dict(type=LoadImageFromFile),
+ dict(type=ResizeEdge, scale=256, edge='short'),
+ dict(type=CenterCrop, crop_size=224),
+ dict(type=PackInputs),
+]
+
+train_dataloader = dict(
+ batch_size=32,
+ num_workers=5,
+ dataset=dict(
+ type=dataset_type,
+ data_root='data/imagenet',
+ ann_file='meta/train.txt',
+ data_prefix='train',
+ pipeline=train_pipeline),
+ sampler=dict(type=DefaultSampler, shuffle=True),
+)
+
+val_dataloader = dict(
+ batch_size=32,
+ num_workers=5,
+ dataset=dict(
+ type=dataset_type,
+ data_root='data/imagenet',
+ ann_file='meta/val.txt',
+ data_prefix='val',
+ pipeline=test_pipeline),
+ sampler=dict(type=DefaultSampler, shuffle=False),
+)
+val_evaluator = dict(type=Accuracy, topk=(1, 5))
+
+# If you want a standard test, please configure the test dataset manually.
+test_dataloader = val_dataloader
+test_evaluator = val_evaluator
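
The base file above is meant to be consumed through the new-format `read_base()` mechanism. The sketch below is a hypothetical downstream config (not part of the patch) showing the override pattern that the configs later in this diff also use.

from mmengine.config import read_base

with read_base():
    from .._base_.datasets.imagenet_bs32 import *  # noqa: F401,F403

# Override individual fields in place after the import.
train_dataloader.update(batch_size=64)
val_dataloader.update(batch_size=64)
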
diff --git a/mmpretrain/configs/_base_/datasets/imagenet_bs32_simclr.py b/mmpretrain/configs/_base_/datasets/imagenet_bs32_simclr.py
new file mode 100644
index 0000000000000000000000000000000000000000..b687a06fef86827a76a472371f4de7dc2a5f2ac1
--- /dev/null
+++ b/mmpretrain/configs/_base_/datasets/imagenet_bs32_simclr.py
@@ -0,0 +1,63 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# This is a BETA new format config file, and the usage may change in the future.
+from mmcv.transforms import (LoadImageFromFile, RandomApply, RandomFlip,
+ RandomGrayscale)
+from mmengine.dataset import DefaultSampler, default_collate
+
+from mmpretrain.datasets import (ColorJitter, GaussianBlur, ImageNet,
+ MultiView, PackInputs, RandomResizedCrop)
+from mmpretrain.models import SelfSupDataPreprocessor
+
+# dataset settings
+dataset_type = 'ImageNet'
+data_root = 'data/imagenet/'
+data_preprocessor = dict(
+ type=SelfSupDataPreprocessor,
+ mean=[123.675, 116.28, 103.53],
+ std=[58.395, 57.12, 57.375],
+ to_rgb=True)
+
+view_pipeline = [
+ dict(type=RandomResizedCrop, scale=224, backend='pillow'),
+ dict(type=RandomFlip, prob=0.5),
+ dict(
+ type=RandomApply,
+ transforms=[
+ dict(
+ type=ColorJitter,
+ brightness=0.8,
+ contrast=0.8,
+ saturation=0.8,
+ hue=0.2)
+ ],
+ prob=0.8),
+ dict(
+ type=RandomGrayscale,
+ prob=0.2,
+ keep_channels=True,
+ channel_weights=(0.114, 0.587, 0.2989)),
+ dict(
+ type=GaussianBlur,
+ magnitude_range=(0.1, 2.0),
+ magnitude_std='inf',
+ prob=0.5),
+]
+
+train_pipeline = [
+ dict(type=LoadImageFromFile),
+ dict(type=MultiView, num_views=2, transforms=[view_pipeline]),
+ dict(type=PackInputs)
+]
+
+train_dataloader = dict(
+ batch_size=32,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type=DefaultSampler, shuffle=True),
+ collate_fn=dict(type=default_collate),
+ dataset=dict(
+ type=ImageNet,
+ data_root=data_root,
+ ann_file='meta/train.txt',
+ data_prefix=dict(img_path='train/'),
+ pipeline=train_pipeline))
diff --git a/mmpretrain/configs/_base_/datasets/imagenet_bs512_mae.py b/mmpretrain/configs/_base_/datasets/imagenet_bs512_mae.py
new file mode 100644
index 0000000000000000000000000000000000000000..f64d5eac39902bdf8c95a0e7f7d589d069951df5
--- /dev/null
+++ b/mmpretrain/configs/_base_/datasets/imagenet_bs512_mae.py
@@ -0,0 +1,41 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# This is a BETA new format config file, and the usage may change in the future.
+from mmcv.transforms import LoadImageFromFile, RandomFlip
+from mmengine.dataset.sampler import DefaultSampler
+
+from mmpretrain.datasets import ImageNet, PackInputs, RandomResizedCrop
+from mmpretrain.models import SelfSupDataPreprocessor
+
+# dataset settings
+dataset_type = 'ImageNet'
+data_root = 'data/imagenet/'
+data_preprocessor = dict(
+ type=SelfSupDataPreprocessor,
+ mean=[123.675, 116.28, 103.53],
+ std=[58.395, 57.12, 57.375],
+ to_rgb=True)
+
+train_pipeline = [
+ dict(type=LoadImageFromFile),
+ dict(
+ type=RandomResizedCrop,
+ scale=224,
+ crop_ratio_range=(0.2, 1.0),
+ backend='pillow',
+ interpolation='bicubic'),
+ dict(type=RandomFlip, prob=0.5),
+ dict(type=PackInputs)
+]
+
+train_dataloader = dict(
+ batch_size=512,
+ num_workers=8,
+ persistent_workers=True,
+ sampler=dict(type=DefaultSampler, shuffle=True),
+ collate_fn=dict(type='default_collate'),
+ dataset=dict(
+ type=ImageNet,
+ data_root=data_root,
+ ann_file='meta/train.txt',
+ data_prefix=dict(img_path='train/'),
+ pipeline=train_pipeline))
diff --git a/mmpretrain/configs/_base_/datasets/imagenet_bs64_swin_384.py b/mmpretrain/configs/_base_/datasets/imagenet_bs64_swin_384.py
new file mode 100644
index 0000000000000000000000000000000000000000..85aeb1e2c131109f3f6d75d21e2cc1c782c82b7f
--- /dev/null
+++ b/mmpretrain/configs/_base_/datasets/imagenet_bs64_swin_384.py
@@ -0,0 +1,64 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# This is a BETA new format config file, and the usage may change in the future.
+from mmengine.dataset import DefaultSampler
+
+from mmpretrain.datasets import (ImageNet, LoadImageFromFile, PackInputs,
+ RandomFlip, RandomResizedCrop, Resize)
+from mmpretrain.evaluation import Accuracy
+
+# dataset settings
+dataset_type = ImageNet
+data_preprocessor = dict(
+ num_classes=1000,
+ # RGB format normalization parameters
+ mean=[123.675, 116.28, 103.53],
+ std=[58.395, 57.12, 57.375],
+ # convert image from BGR to RGB
+ to_rgb=True,
+)
+
+train_pipeline = [
+ dict(type=LoadImageFromFile),
+ dict(
+ type=RandomResizedCrop,
+ scale=384,
+ backend='pillow',
+ interpolation='bicubic'),
+ dict(type=RandomFlip, prob=0.5, direction='horizontal'),
+ dict(type=PackInputs),
+]
+
+test_pipeline = [
+ dict(type=LoadImageFromFile),
+ dict(type=Resize, scale=384, backend='pillow', interpolation='bicubic'),
+ dict(type=PackInputs),
+]
+
+train_dataloader = dict(
+ batch_size=64,
+ num_workers=5,
+ dataset=dict(
+ type=dataset_type,
+ data_root='data/imagenet',
+ ann_file='meta/train.txt',
+ data_prefix='train',
+ pipeline=train_pipeline),
+ sampler=dict(type=DefaultSampler, shuffle=True),
+)
+
+val_dataloader = dict(
+ batch_size=64,
+ num_workers=5,
+ dataset=dict(
+ type=dataset_type,
+ data_root='data/imagenet',
+ ann_file='meta/val.txt',
+ data_prefix='val',
+ pipeline=test_pipeline),
+ sampler=dict(type=DefaultSampler, shuffle=False),
+)
+val_evaluator = dict(type=Accuracy, topk=(1, 5))
+
+# If you want a standard test, please configure the test dataset manually.
+test_dataloader = val_dataloader
+test_evaluator = val_evaluator
diff --git a/mmpretrain/configs/_base_/default_runtime.py b/mmpretrain/configs/_base_/default_runtime.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5c748eb84b3e50d7c6b30efaa87cd3c1f2f1827
--- /dev/null
+++ b/mmpretrain/configs/_base_/default_runtime.py
@@ -0,0 +1,61 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# This is a BETA new format config file, and the usage may change in the future.
+from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
+ LoggerHook, ParamSchedulerHook)
+from mmengine.visualization import LocalVisBackend
+
+from mmpretrain.engine.hooks import VisualizationHook
+from mmpretrain.visualization import UniversalVisualizer
+
+# configure default hooks
+default_hooks = dict(
+ # record the time of every iteration.
+ timer=dict(type=IterTimerHook),
+
+ # print log every 100 iterations.
+ logger=dict(type=LoggerHook, interval=100),
+
+ # enable the parameter scheduler.
+ param_scheduler=dict(type=ParamSchedulerHook),
+
+ # save checkpoint per epoch.
+ checkpoint=dict(type=CheckpointHook, interval=1),
+
+ # set sampler seed in distributed environment.
+ sampler_seed=dict(type=DistSamplerSeedHook),
+
+ # visualization of validation results; set `enable=True` to turn it on.
+ visualization=dict(type=VisualizationHook, enable=False),
+)
+
+# configure environment
+env_cfg = dict(
+ # whether to enable cudnn benchmark
+ cudnn_benchmark=False,
+
+ # set multi process parameters
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
+
+ # set distributed parameters
+ dist_cfg=dict(backend='nccl'),
+)
+
+# set visualizer
+vis_backends = [dict(type=LocalVisBackend)]
+visualizer = dict(type=UniversalVisualizer, vis_backends=vis_backends)
+
+# set log level
+log_level = 'INFO'
+
+# load from which checkpoint
+load_from = None
+
+# whether to resume training from the loaded checkpoint
+resume = False
+
+# By default, use a random seed and disable `deterministic`.
+randomness = dict(seed=None, deterministic=False)
+
+# There is no need to specify default_scope with the new config format, so
+# set it to None to avoid BC-breaking.
+default_scope = None
diff --git a/mmpretrain/configs/_base_/models/convnext_base.py b/mmpretrain/configs/_base_/models/convnext_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..6315b2f1966d2484739087e1e131fe8dd9a2ad56
--- /dev/null
+++ b/mmpretrain/configs/_base_/models/convnext_base.py
@@ -0,0 +1,25 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# This is a BETA new format config file, and the usage may change in the future.
+from mmengine.model import TruncNormalInit
+
+from mmpretrain.models import (ConvNeXt, CutMix, ImageClassifier,
+ LabelSmoothLoss, LinearClsHead, Mixup)
+
+# Model settings
+model = dict(
+ type=ImageClassifier,
+ backbone=dict(type=ConvNeXt, arch='base', drop_path_rate=0.5),
+ head=dict(
+ type=LinearClsHead,
+ num_classes=1000,
+ in_channels=1024,
+ loss=dict(type=LabelSmoothLoss, label_smooth_val=0.1, mode='original'),
+ init_cfg=None,
+ ),
+ init_cfg=dict(
+ type=TruncNormalInit, layer=['Conv2d', 'Linear'], std=.02, bias=0.),
+ train_cfg=dict(augments=[
+ dict(type=Mixup, alpha=0.8),
+ dict(type=CutMix, alpha=1.0),
+ ]),
+)
diff --git a/mmpretrain/configs/_base_/models/mae_vit_base_p16.py b/mmpretrain/configs/_base_/models/mae_vit_base_p16.py
new file mode 100644
index 0000000000000000000000000000000000000000..9347d1e8810e553ef5563a96198794ec139ea3a4
--- /dev/null
+++ b/mmpretrain/configs/_base_/models/mae_vit_base_p16.py
@@ -0,0 +1,28 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# This is a BETA new format config file, and the usage may change in the future.
+from mmpretrain.models import (MAE, MAEPretrainDecoder, MAEPretrainHead,
+ MAEViT, PixelReconstructionLoss)
+
+# model settings
+model = dict(
+ type=MAE,
+ backbone=dict(type=MAEViT, arch='b', patch_size=16, mask_ratio=0.75),
+ neck=dict(
+ type=MAEPretrainDecoder,
+ patch_size=16,
+ in_chans=3,
+ embed_dim=768,
+ decoder_embed_dim=512,
+ decoder_depth=8,
+ decoder_num_heads=16,
+ mlp_ratio=4.,
+ ),
+ head=dict(
+ type=MAEPretrainHead,
+ norm_pix=True,
+ patch_size=16,
+ loss=dict(type=PixelReconstructionLoss, criterion='L2')),
+ init_cfg=[
+ dict(type='Xavier', layer='Linear', distribution='uniform'),
+ dict(type='Constant', layer='LayerNorm', val=1.0, bias=0.0)
+ ])
diff --git a/mmpretrain/configs/_base_/models/resnet18.py b/mmpretrain/configs/_base_/models/resnet18.py
new file mode 100644
index 0000000000000000000000000000000000000000..30b8f65148611c5602858b875b9be89b31f225cb
--- /dev/null
+++ b/mmpretrain/configs/_base_/models/resnet18.py
@@ -0,0 +1,22 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# This is a BETA new format config file, and the usage may change in the future.
+from mmpretrain.models import (CrossEntropyLoss, GlobalAveragePooling,
+ ImageClassifier, LinearClsHead, ResNet)
+
+# model settings
+model = dict(
+ type=ImageClassifier,
+ backbone=dict(
+ type=ResNet,
+ depth=18,
+ num_stages=4,
+ out_indices=(3, ),
+ style='pytorch'),
+ neck=dict(type=GlobalAveragePooling),
+ head=dict(
+ type=LinearClsHead,
+ num_classes=1000,
+ in_channels=512,
+ loss=dict(type=CrossEntropyLoss, loss_weight=1.0),
+ topk=(1, 5),
+ ))
diff --git a/mmpretrain/configs/_base_/schedules/imagenet_bs1024_adamw_swin.py b/mmpretrain/configs/_base_/schedules/imagenet_bs1024_adamw_swin.py
new file mode 100644
index 0000000000000000000000000000000000000000..60ccaa0e25ec69aa618430f51a60d949506fc406
--- /dev/null
+++ b/mmpretrain/configs/_base_/schedules/imagenet_bs1024_adamw_swin.py
@@ -0,0 +1,46 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# This is a BETA new format config file, and the usage may change in the future.
+from mmengine.optim import CosineAnnealingLR, LinearLR
+from torch.optim import AdamW
+
+# With a batch size of 128 per GPU on 8 GPUs (1024 in total):
+# lr = 5e-4 * 1024 / 512 = 0.001
+optim_wrapper = dict(
+ optimizer=dict(
+ type=AdamW,
+ lr=5e-4 * 1024 / 512,
+ weight_decay=0.05,
+ eps=1e-8,
+ betas=(0.9, 0.999)),
+ paramwise_cfg=dict(
+ norm_decay_mult=0.0,
+ bias_decay_mult=0.0,
+ flat_decay_mult=0.0,
+ custom_keys={
+ '.absolute_pos_embed': dict(decay_mult=0.0),
+ '.relative_position_bias_table': dict(decay_mult=0.0)
+ }),
+)
+
+# learning policy
+param_scheduler = [
+ # warm up learning rate scheduler
+ dict(
+ type=LinearLR,
+ start_factor=1e-3,
+ by_epoch=True,
+ end=20,
+ # update by iter
+ convert_to_iter_based=True),
+ # main learning rate scheduler
+ dict(type=CosineAnnealingLR, eta_min=1e-5, by_epoch=True, begin=20)
+]
+
+# train, val, test setting
+train_cfg = dict(by_epoch=True, max_epochs=300, val_interval=1)
+val_cfg = dict()
+test_cfg = dict()
+
+# NOTE: `auto_scale_lr` is for automatically scaling LR,
+# based on the actual training batch size.
+auto_scale_lr = dict(base_batch_size=1024)
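
For clarity, the `auto_scale_lr` entry above encodes the linear scaling rule. A small arithmetic sketch (assuming LR auto-scaling is enabled at launch time) looks like this:

base_batch_size = 1024          # declared in the schedule above
base_lr = 5e-4 * 1024 / 512     # the AdamW lr set above (1e-3)

# e.g. 4 GPUs x 128 samples per GPU
real_batch_size = 4 * 128
scaled_lr = base_lr * real_batch_size / base_batch_size
print(scaled_lr)                # 0.0005
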
diff --git a/mmpretrain/configs/_base_/schedules/imagenet_bs256.py b/mmpretrain/configs/_base_/schedules/imagenet_bs256.py
new file mode 100644
index 0000000000000000000000000000000000000000..95afa2ad292c277a84aa274786ee34a9d6b8b0ef
--- /dev/null
+++ b/mmpretrain/configs/_base_/schedules/imagenet_bs256.py
@@ -0,0 +1,21 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# This is a BETA new format config file, and the usage may change in the future.
+from mmengine.optim import MultiStepLR
+from torch.optim import SGD
+
+# optimizer
+optim_wrapper = dict(
+ optimizer=dict(type=SGD, lr=0.1, momentum=0.9, weight_decay=0.0001))
+
+# learning policy
+param_scheduler = dict(
+ type=MultiStepLR, by_epoch=True, milestones=[30, 60, 90], gamma=0.1)
+
+# train, val, test setting
+train_cfg = dict(by_epoch=True, max_epochs=100, val_interval=1)
+val_cfg = dict()
+test_cfg = dict()
+
+# NOTE: `auto_scale_lr` is for automatically scaling LR,
+# based on the actual training batch size.
+auto_scale_lr = dict(base_batch_size=256)
diff --git a/mmpretrain/configs/_base_/schedules/imagenet_lars_coslr_200e.py b/mmpretrain/configs/_base_/schedules/imagenet_lars_coslr_200e.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c7e6171e2aeb20a94277e7ca4d02b2598d73b8e
--- /dev/null
+++ b/mmpretrain/configs/_base_/schedules/imagenet_lars_coslr_200e.py
@@ -0,0 +1,27 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# This is a BETA new format config file, and the usage may change in the future.
+from mmengine.optim.optimizer.optimizer_wrapper import OptimWrapper
+from mmengine.optim.scheduler.lr_scheduler import CosineAnnealingLR, LinearLR
+from mmengine.runner.loops import EpochBasedTrainLoop
+
+from mmpretrain.engine.optimizers.lars import LARS
+
+# optimizer wrapper
+optim_wrapper = dict(
+ type=OptimWrapper,
+ optimizer=dict(type=LARS, lr=4.8, weight_decay=1e-6, momentum=0.9))
+
+# learning rate scheduler
+param_scheduler = [
+ dict(
+ type=LinearLR,
+ start_factor=1e-4,
+ by_epoch=True,
+ begin=0,
+ end=10,
+ convert_to_iter_based=True),
+ dict(type=CosineAnnealingLR, T_max=190, by_epoch=True, begin=10, end=200)
+]
+
+# runtime settings
+train_cfg = dict(type=EpochBasedTrainLoop, max_epochs=200)
diff --git a/mmpretrain/configs/beit/beit_beit_base_p16_8xb256_amp_coslr_300e_in1k.py b/mmpretrain/configs/beit/beit_beit_base_p16_8xb256_amp_coslr_300e_in1k.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe9c329abba18ecb5bad1875090f7de667a77391
--- /dev/null
+++ b/mmpretrain/configs/beit/beit_beit_base_p16_8xb256_amp_coslr_300e_in1k.py
@@ -0,0 +1,146 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# This is a BETA new format config file, and the usage may change in the future.
+from mmengine.config import read_base
+
+with read_base():
+ from .._base_.default_runtime import *
+
+from mmengine.dataset import DefaultSampler, default_collate
+from mmengine.hooks import CheckpointHook
+from mmengine.model import ConstantInit, PretrainedInit, TruncNormalInit
+from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
+from mmengine.runner import EpochBasedTrainLoop
+from torch.optim import AdamW
+
+from mmpretrain.datasets import (BEiTMaskGenerator, ColorJitter, ImageNet,
+ LoadImageFromFile, PackInputs, RandomFlip,
+ RandomResizedCropAndInterpolationWithTwoPic)
+from mmpretrain.models import (BEiT, BEiTPretrainViT, BEiTV1Head,
+ CrossEntropyLoss, DALLEEncoder,
+ TwoNormDataPreprocessor)
+
+# dataset settings
+dataset_type = ImageNet
+data_root = 'data/imagenet/'
+data_preprocessor = dict(
+ type=TwoNormDataPreprocessor,
+ mean=[123.675, 116.28, 103.53],
+ std=[58.395, 57.12, 57.375],
+ second_mean=[-31.875, -31.875, -31.875],
+ second_std=[318.75, 318.75, 318.75],
+ to_rgb=True)
+
+train_pipeline = [
+ dict(type=LoadImageFromFile),
+ dict(
+ type=ColorJitter, brightness=0.4, contrast=0.4, saturation=0.4,
+ hue=0.),
+ dict(type=RandomFlip, prob=0.5, direction='horizontal'),
+ dict(
+ type=RandomResizedCropAndInterpolationWithTwoPic,
+ size=224,
+ second_size=112,
+ interpolation='bicubic',
+ second_interpolation='lanczos',
+ scale=(0.08, 1.0)),
+ dict(
+ type=BEiTMaskGenerator,
+ input_size=(14, 14),
+ num_masking_patches=75,
+ max_num_patches=None,
+ min_num_patches=16),
+ dict(type=PackInputs)
+]
+train_dataloader = dict(
+ batch_size=256,
+ num_workers=8,
+ persistent_workers=True,
+ sampler=dict(type=DefaultSampler, shuffle=True),
+ collate_fn=dict(type=default_collate),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file='meta/train.txt',
+ data_prefix=dict(img_path='train/'),
+ pipeline=train_pipeline))
+
+# model settings
+model = dict(
+ type=BEiT,
+ backbone=dict(
+ type=BEiTPretrainViT,
+ arch='base',
+ patch_size=16,
+ drop_path_rate=0.1,
+ final_norm=True,
+ out_type='raw',
+ layer_scale_init_value=0.1,
+ init_cfg=[
+ dict(type=TruncNormalInit, std=0.02, layer='Linear'),
+ dict(type=TruncNormalInit, std=0.02, layer='Conv2d'),
+ dict(type=ConstantInit, layer='LayerNorm', val=1.0, bias=0.0)
+ ]),
+ neck=None,
+ head=dict(
+ type=BEiTV1Head,
+ embed_dims=768,
+ num_embed=8192,
+ loss=dict(type=CrossEntropyLoss)),
+ target_generator=dict(
+ type=DALLEEncoder,
+ init_cfg=dict(
+ type=PretrainedInit,
+ checkpoint= # noqa: E251
+ 'https://download.openmmlab.com/mmselfsup/1.x/target_generator_ckpt/dalle_encoder.pth', # noqa: E501
+ )))
+
+# optimizer wrapper
+optim_wrapper = dict(
+ type=AmpOptimWrapper,
+ loss_scale='dynamic',
+ optimizer=dict(
+ type=AdamW, lr=1.5e-3, betas=(0.9, 0.999), weight_decay=0.05),
+ clip_grad=dict(max_norm=3.0),
+ paramwise_cfg=dict(
+ custom_keys={
+ # the following configurations are designed for BEiT
+ '.ln': dict(decay_mult=0.0),
+ '.bias': dict(decay_mult=0.0),
+ 'q_bias': dict(decay_mult=0.0),
+ 'v_bias': dict(decay_mult=0.0),
+ '.cls_token': dict(decay_mult=0.0),
+ '.pos_embed': dict(decay_mult=0.0),
+ '.gamma': dict(decay_mult=0.0),
+ }))
+
+# learning rate scheduler
+param_scheduler = [
+ dict(
+ type=LinearLR,
+ start_factor=1e-4,
+ by_epoch=True,
+ begin=0,
+ end=10,
+ convert_to_iter_based=True),
+ dict(
+ type=CosineAnnealingLR,
+ eta_min=1e-5,
+ by_epoch=True,
+ begin=10,
+ end=300,
+ convert_to_iter_based=True)
+]
+
+# runtime settings
+train_cfg = dict(type=EpochBasedTrainLoop, max_epochs=300)
+default_hooks.update(
+ # only keeps the latest 3 checkpoints
+ checkpoint=dict(type=CheckpointHook, interval=1, max_keep_ckpts=3))
+
+randomness.update(seed=0, diff_rank_seed=True)
+
+find_unused_parameters = True
+
+# NOTE: `auto_scale_lr` is for automatically scaling LR
+# based on the actual training batch size.
+auto_scale_lr = dict(base_batch_size=2048)
diff --git a/mmpretrain/configs/convnext/convnext_base_32xb128_in1k_384px.py b/mmpretrain/configs/convnext/convnext_base_32xb128_in1k_384px.py
new file mode 100644
index 0000000000000000000000000000000000000000..46be2ca6dcb0636469ca19dda51affd003a8b812
--- /dev/null
+++ b/mmpretrain/configs/convnext/convnext_base_32xb128_in1k_384px.py
@@ -0,0 +1,28 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# This is a BETA new format config file, and the usage may change in the future.
+from mmengine.config import read_base
+
+with read_base():
+ from .._base_.datasets.imagenet_bs64_swin_384 import *
+ from .._base_.default_runtime import *
+ from .._base_.models.convnext_base import *
+ from .._base_.schedules.imagenet_bs1024_adamw_swin import *
+
+from mmpretrain.engine import EMAHook
+
+# dataset setting
+train_dataloader.update(batch_size=128)
+
+# schedule setting
+optim_wrapper.update(
+ optimizer=dict(lr=4e-3),
+ clip_grad=dict(max_norm=5.0),
+)
+
+# runtime setting
+custom_hooks = [dict(type=EMAHook, momentum=4e-5, priority='ABOVE_NORMAL')]
+
+# NOTE: `auto_scale_lr` is for automatically scaling LR
+# based on the actual training batch size.
+# base_batch_size = (32 GPUs) x (128 samples per GPU)
+auto_scale_lr = dict(base_batch_size=4096)
diff --git a/mmpretrain/configs/eva/eva_mae_style_vit_base_p16_16xb256_coslr_400e_in1k.py b/mmpretrain/configs/eva/eva_mae_style_vit_base_p16_16xb256_coslr_400e_in1k.py
new file mode 100644
index 0000000000000000000000000000000000000000..a254ac8a84d94acdd1ec5f84059c7e75abc3cbc4
--- /dev/null
+++ b/mmpretrain/configs/eva/eva_mae_style_vit_base_p16_16xb256_coslr_400e_in1k.py
@@ -0,0 +1,92 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# This is a BETA new format config file, and the usage may change in the future.
+from mmengine.config import read_base
+
+with read_base():
+ from .._base_.models.mae_vit_base_p16 import *
+ from .._base_.datasets.imagenet_bs512_mae import *
+ from .._base_.default_runtime import *
+
+from mmengine.hooks import CheckpointHook
+from mmengine.optim import CosineAnnealingLR, LinearLR, OptimWrapper
+from mmengine.runner import EpochBasedTrainLoop
+from torch.optim import AdamW
+
+from mmpretrain.models import (EVA, CLIPGenerator, CosineSimilarityLoss,
+ MAEPretrainDecoder, MIMHead)
+
+# dataset settings
+train_dataloader.batch_size = 256
+
+# model settings
+model.type = EVA
+model.init_cfg = None
+model.backbone.update(init_cfg=[
+ dict(type='Xavier', distribution='uniform', layer='Linear'),
+ dict(type='Constant', layer='LayerNorm', val=1.0, bias=0.0)
+])
+model.neck.update(
+ type=MAEPretrainDecoder,
+ predict_feature_dim=512,
+ init_cfg=[
+ dict(type='Xavier', distribution='uniform', layer='Linear'),
+ dict(type='Constant', layer='LayerNorm', val=1.0, bias=0.0)
+ ])
+model.head = dict(
+ type=MIMHead,
+ loss=dict(type=CosineSimilarityLoss, shift_factor=2.0, scale_factor=2.0))
+model.target_generator = dict(
+ type=CLIPGenerator,
+ tokenizer_path= # noqa
+ 'https://download.openmmlab.com/mmselfsup/1.x/target_generator_ckpt/clip_vit_base_16.pth.tar' # noqa
+)
+
+# optimizer wrapper
+optim_wrapper = dict(
+ type=OptimWrapper,
+ optimizer=dict(
+ type=AdamW,
+ lr=1.5e-4 * 4096 / 256,
+ betas=(0.9, 0.95),
+ weight_decay=0.05),
+ paramwise_cfg=dict(
+ custom_keys={
+ 'ln': dict(decay_mult=0.0),
+ 'bias': dict(decay_mult=0.0),
+ 'pos_embed': dict(decay_mult=0.),
+ 'mask_token': dict(decay_mult=0.),
+ 'cls_token': dict(decay_mult=0.)
+ }))
+find_unused_parameters = True
+
+# learning rate scheduler
+param_scheduler = [
+ dict(
+ type=LinearLR,
+ start_factor=1e-4,
+ by_epoch=True,
+ begin=0,
+ end=40,
+ convert_to_iter_based=True),
+ dict(
+ type=CosineAnnealingLR,
+ T_max=360,
+ by_epoch=True,
+ begin=40,
+ end=400,
+ convert_to_iter_based=True)
+]
+
+# runtime settings
+train_cfg = dict(type=EpochBasedTrainLoop, max_epochs=400)
+default_hooks.checkpoint = dict(
+ type=CheckpointHook, interval=1, max_keep_ckpts=3)
+
+randomness.update(dict(seed=0, diff_rank_seed=True))
+
+# auto resume
+resume = True
+
+# NOTE: `auto_scale_lr` is for automatically scaling LR
+# based on the actual training batch size.
+auto_scale_lr = dict(base_batch_size=4096)
diff --git a/mmpretrain/configs/mae/mae_vit_base_p16_8xb512_amp_coslr_300e_in1k.py b/mmpretrain/configs/mae/mae_vit_base_p16_8xb512_amp_coslr_300e_in1k.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cee3bc93fd8b1c65263ac415422b9b73628e88d
--- /dev/null
+++ b/mmpretrain/configs/mae/mae_vit_base_p16_8xb512_amp_coslr_300e_in1k.py
@@ -0,0 +1,65 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# This is a BETA new format config file, and the usage may change in the future.
+from mmengine.config import read_base
+
+with read_base():
+ from .._base_.models.mae_vit_base_p16 import *
+ from .._base_.datasets.imagenet_bs512_mae import *
+ from .._base_.default_runtime import *
+
+from mmengine.hooks.checkpoint_hook import CheckpointHook
+from mmengine.optim.optimizer.amp_optimizer_wrapper import AmpOptimWrapper
+from mmengine.optim.scheduler.lr_scheduler import CosineAnnealingLR, LinearLR
+from mmengine.runner.loops import EpochBasedTrainLoop
+from torch.optim.adamw import AdamW
+
+# optimizer wrapper
+optim_wrapper = dict(
+ type=AmpOptimWrapper,
+ loss_scale='dynamic',
+ optimizer=dict(
+ type=AdamW,
+ lr=1.5e-4 * 4096 / 256,
+ betas=(0.9, 0.95),
+ weight_decay=0.05),
+ paramwise_cfg=dict(
+ custom_keys={
+ 'ln': dict(decay_mult=0.0),
+ 'bias': dict(decay_mult=0.0),
+ 'pos_embed': dict(decay_mult=0.),
+ 'mask_token': dict(decay_mult=0.),
+ 'cls_token': dict(decay_mult=0.)
+ }))
+
+# learning rate scheduler
+param_scheduler = [
+ dict(
+ type=LinearLR,
+ start_factor=0.0001,
+ by_epoch=True,
+ begin=0,
+ end=40,
+ convert_to_iter_based=True),
+ dict(
+ type=CosineAnnealingLR,
+ T_max=260,
+ by_epoch=True,
+ begin=40,
+ end=300,
+ convert_to_iter_based=True)
+]
+
+# runtime settings
+train_cfg = dict(type=EpochBasedTrainLoop, max_epochs=300)
+# only keeps the latest 3 checkpoints
+default_hooks.checkpoint = dict(
+ type=CheckpointHook, interval=1, max_keep_ckpts=3)
+
+randomness.update(seed=0, diff_rank_seed=True)
+
+# auto resume
+resume = True
+
+# NOTE: `auto_scale_lr` is for automatically scaling LR
+# based on the actual training batch size.
+auto_scale_lr = dict(base_batch_size=4096)
diff --git a/mmpretrain/configs/resnet/resnet18_8xb32_in1k.py b/mmpretrain/configs/resnet/resnet18_8xb32_in1k.py
new file mode 100644
index 0000000000000000000000000000000000000000..f16d248b6988c924e8540a7782dabee4997baba1
--- /dev/null
+++ b/mmpretrain/configs/resnet/resnet18_8xb32_in1k.py
@@ -0,0 +1,9 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# This is a BETA new format config file, and the usage may change in the future.
+from mmengine.config import read_base
+
+with read_base():
+ from .._base_.datasets.imagenet_bs32 import *
+ from .._base_.default_runtime import *
+ from .._base_.models.resnet18 import *
+ from .._base_.schedules.imagenet_bs256 import *
diff --git a/mmpretrain/configs/simclr/simclr_resnet50_16xb256_coslr_200e_in1k.py b/mmpretrain/configs/simclr/simclr_resnet50_16xb256_coslr_200e_in1k.py
new file mode 100644
index 0000000000000000000000000000000000000000..09c738f219e561e5863dd2ce8246af005502bc83
--- /dev/null
+++ b/mmpretrain/configs/simclr/simclr_resnet50_16xb256_coslr_200e_in1k.py
@@ -0,0 +1,58 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# This is a BETA new format config file, and the usage may change in the future.
+from mmengine.config import read_base
+
+with read_base():
+ from .._base_.datasets.imagenet_bs32_simclr import *
+ from .._base_.schedules.imagenet_lars_coslr_200e import *
+ from .._base_.default_runtime import *
+
+from mmengine.hooks.checkpoint_hook import CheckpointHook
+from mmengine.optim.optimizer.optimizer_wrapper import OptimWrapper
+
+from mmpretrain.engine.optimizers.lars import LARS
+from mmpretrain.models.backbones.resnet import ResNet
+from mmpretrain.models.heads.contrastive_head import ContrastiveHead
+from mmpretrain.models.losses.cross_entropy_loss import CrossEntropyLoss
+from mmpretrain.models.necks.nonlinear_neck import NonLinearNeck
+from mmpretrain.models.selfsup.simclr import SimCLR
+
+# dataset settings
+train_dataloader.merge(dict(batch_size=256))
+
+# model settings
+model = dict(
+ type=SimCLR,
+ backbone=dict(
+ type=ResNet,
+ depth=50,
+ norm_cfg=dict(type='SyncBN'),
+ zero_init_residual=True),
+ neck=dict(
+ type=NonLinearNeck, # SimCLR non-linear neck
+ in_channels=2048,
+ hid_channels=2048,
+ out_channels=128,
+ num_layers=2,
+ with_avg_pool=True),
+ head=dict(
+ type=ContrastiveHead,
+ loss=dict(type=CrossEntropyLoss),
+ temperature=0.1),
+)
+
+# optimizer
+optim_wrapper = dict(
+ type=OptimWrapper,
+ optimizer=dict(type=LARS, lr=4.8, momentum=0.9, weight_decay=1e-6),
+ paramwise_cfg=dict(
+ custom_keys={
+ 'bn': dict(decay_mult=0, lars_exclude=True),
+ 'bias': dict(decay_mult=0, lars_exclude=True),
+ # bn layer in ResNet block downsample module
+ 'downsample.1': dict(decay_mult=0, lars_exclude=True)
+ }))
+
+# runtime settings
+default_hooks.checkpoint = dict(
+ type=CheckpointHook, interval=10, max_keep_ckpts=3)
diff --git a/mmpretrain/datasets/__init__.py b/mmpretrain/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7b6be47dceb288acc2517a7ba2c890f2a38b671
--- /dev/null
+++ b/mmpretrain/datasets/__init__.py
@@ -0,0 +1,58 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmpretrain.utils.dependency import WITH_MULTIMODAL
+from .base_dataset import BaseDataset
+from .builder import build_dataset
+from .caltech101 import Caltech101
+from .cifar import CIFAR10, CIFAR100
+from .cub import CUB
+from .custom import CustomDataset
+from .dataset_wrappers import KFoldDataset
+from .dtd import DTD
+from .fgvcaircraft import FGVCAircraft
+from .flowers102 import Flowers102
+from .food101 import Food101
+from .imagenet import ImageNet, ImageNet21k
+from .inshop import InShop
+from .mnist import MNIST, FashionMNIST
+from .multi_label import MultiLabelDataset
+from .multi_task import MultiTaskDataset
+from .nlvr2 import NLVR2
+from .oxfordiiitpet import OxfordIIITPet
+from .places205 import Places205
+from .samplers import * # noqa: F401,F403
+from .stanfordcars import StanfordCars
+from .sun397 import SUN397
+from .transforms import * # noqa: F401,F403
+from .voc import VOC
+
+__all__ = [
+ 'BaseDataset', 'CIFAR10', 'CIFAR100', 'CUB', 'Caltech101', 'CustomDataset',
+ 'DTD', 'FGVCAircraft', 'FashionMNIST', 'Flowers102', 'Food101', 'ImageNet',
+ 'ImageNet21k', 'InShop', 'KFoldDataset', 'MNIST', 'MultiLabelDataset',
+ 'MultiTaskDataset', 'NLVR2', 'OxfordIIITPet', 'Places205', 'SUN397',
+ 'StanfordCars', 'VOC', 'build_dataset'
+]
+
+if WITH_MULTIMODAL:
+ from .coco_caption import COCOCaption
+ from .coco_retrieval import COCORetrieval
+ from .coco_vqa import COCOVQA
+ from .flamingo import FlamingoEvalCOCOCaption, FlamingoEvalCOCOVQA
+ from .flickr30k_caption import Flickr30kCaption
+ from .flickr30k_retrieval import Flickr30kRetrieval
+ from .gqa_dataset import GQA
+ from .nocaps import NoCaps
+ from .ocr_vqa import OCRVQA
+ from .refcoco import RefCOCO
+ from .scienceqa import ScienceQA
+ from .textvqa import TextVQA
+ from .visual_genome import VisualGenomeQA
+ from .vizwiz import VizWiz
+ from .vsr import VSR
+
+ __all__.extend([
+ 'COCOCaption', 'COCORetrieval', 'COCOVQA', 'FlamingoEvalCOCOCaption',
+ 'FlamingoEvalCOCOVQA', 'Flickr30kCaption', 'Flickr30kRetrieval',
+ 'RefCOCO', 'VisualGenomeQA', 'ScienceQA', 'NoCaps', 'GQA', 'TextVQA',
+ 'VSR', 'VizWiz', 'OCRVQA'
+ ])
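
Because the multimodal datasets above are exported only when the optional dependencies are installed, downstream code can guard on the same flag instead of catching ImportError. A minimal sketch, using only names defined in this file:

from mmpretrain.utils.dependency import WITH_MULTIMODAL

if WITH_MULTIMODAL:
    from mmpretrain.datasets import COCOVQA, RefCOCO  # multimodal extras
else:
    print('Multimodal dependencies are not installed; '
          'COCOVQA / RefCOCO are unavailable.')
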
diff --git a/mmpretrain/datasets/__pycache__/__init__.cpython-38.pyc b/mmpretrain/datasets/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..86138c4e807baf0d5df8517dc7daddfd17743079
Binary files /dev/null and b/mmpretrain/datasets/__pycache__/__init__.cpython-38.pyc differ
diff --git a/mmpretrain/datasets/__pycache__/base_dataset.cpython-38.pyc b/mmpretrain/datasets/__pycache__/base_dataset.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..987fc2bbc7c5277a9d87cd9ad6cfc40f1a68b336
Binary files /dev/null and b/mmpretrain/datasets/__pycache__/base_dataset.cpython-38.pyc differ
diff --git a/mmpretrain/datasets/__pycache__/builder.cpython-38.pyc b/mmpretrain/datasets/__pycache__/builder.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..49c88bf76c4080738ef2ef9983116de35e01afba
Binary files /dev/null and b/mmpretrain/datasets/__pycache__/builder.cpython-38.pyc differ
diff --git a/mmpretrain/datasets/__pycache__/caltech101.cpython-38.pyc b/mmpretrain/datasets/__pycache__/caltech101.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..13b26c064f946f5a9edcbc7f98c4deb11209923e
Binary files /dev/null and b/mmpretrain/datasets/__pycache__/caltech101.cpython-38.pyc differ
diff --git a/mmpretrain/datasets/__pycache__/categories.cpython-38.pyc b/mmpretrain/datasets/__pycache__/categories.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d7637efbe554357d50e85c420574d63e9f18db8f
Binary files /dev/null and b/mmpretrain/datasets/__pycache__/categories.cpython-38.pyc differ
diff --git a/mmpretrain/datasets/__pycache__/cifar.cpython-38.pyc b/mmpretrain/datasets/__pycache__/cifar.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fdff96ed0d0cdcc97fefc21d7e6ea00a152aba2e
Binary files /dev/null and b/mmpretrain/datasets/__pycache__/cifar.cpython-38.pyc differ
diff --git a/mmpretrain/datasets/__pycache__/coco_caption.cpython-38.pyc b/mmpretrain/datasets/__pycache__/coco_caption.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6b5c4f877657db08f61ad957666cc88250418e21
Binary files /dev/null and b/mmpretrain/datasets/__pycache__/coco_caption.cpython-38.pyc differ
diff --git a/mmpretrain/datasets/__pycache__/coco_retrieval.cpython-38.pyc b/mmpretrain/datasets/__pycache__/coco_retrieval.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..56d7423d9fd3fa4bda43508c0ee830c5cdc89b2c
Binary files /dev/null and b/mmpretrain/datasets/__pycache__/coco_retrieval.cpython-38.pyc differ
diff --git a/mmpretrain/datasets/__pycache__/coco_vqa.cpython-38.pyc b/mmpretrain/datasets/__pycache__/coco_vqa.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fd5f207c61d7dac58cf3c47d404355b282e61d0d
Binary files /dev/null and b/mmpretrain/datasets/__pycache__/coco_vqa.cpython-38.pyc differ
diff --git a/mmpretrain/datasets/__pycache__/cub.cpython-38.pyc b/mmpretrain/datasets/__pycache__/cub.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5f896f234938963213e7797506039e46460b4b23
Binary files /dev/null and b/mmpretrain/datasets/__pycache__/cub.cpython-38.pyc differ
diff --git a/mmpretrain/datasets/__pycache__/custom.cpython-38.pyc b/mmpretrain/datasets/__pycache__/custom.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3fef35c6b51eb9d211fefd68671d6bc3de1599a7
Binary files /dev/null and b/mmpretrain/datasets/__pycache__/custom.cpython-38.pyc differ
diff --git a/mmpretrain/datasets/__pycache__/dataset_wrappers.cpython-38.pyc b/mmpretrain/datasets/__pycache__/dataset_wrappers.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d39db136982d2b0b589fc4913a170d5f514f0249
Binary files /dev/null and b/mmpretrain/datasets/__pycache__/dataset_wrappers.cpython-38.pyc differ
diff --git a/mmpretrain/datasets/__pycache__/dtd.cpython-38.pyc b/mmpretrain/datasets/__pycache__/dtd.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dee44a6f64bbed91555b2c3fb83a1d867bf5aa60
Binary files /dev/null and b/mmpretrain/datasets/__pycache__/dtd.cpython-38.pyc differ
diff --git a/mmpretrain/datasets/__pycache__/fgvcaircraft.cpython-38.pyc b/mmpretrain/datasets/__pycache__/fgvcaircraft.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f80dd69f1597975f40557a904475510ec7df398f
Binary files /dev/null and b/mmpretrain/datasets/__pycache__/fgvcaircraft.cpython-38.pyc differ
diff --git a/mmpretrain/datasets/__pycache__/flamingo.cpython-38.pyc b/mmpretrain/datasets/__pycache__/flamingo.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f05cbfafe94af3dfa4a96d6b40e3c045e1f5d851
Binary files /dev/null and b/mmpretrain/datasets/__pycache__/flamingo.cpython-38.pyc differ
diff --git a/mmpretrain/datasets/__pycache__/flickr30k_caption.cpython-38.pyc b/mmpretrain/datasets/__pycache__/flickr30k_caption.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c5957d30709875a8da03bd04be3b84845941fdbe
Binary files /dev/null and b/mmpretrain/datasets/__pycache__/flickr30k_caption.cpython-38.pyc differ
diff --git a/mmpretrain/datasets/__pycache__/flickr30k_retrieval.cpython-38.pyc b/mmpretrain/datasets/__pycache__/flickr30k_retrieval.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d35880f3e9ed70debe4a0e2b226fe64ca739a184
Binary files /dev/null and b/mmpretrain/datasets/__pycache__/flickr30k_retrieval.cpython-38.pyc differ
diff --git a/mmpretrain/datasets/__pycache__/flowers102.cpython-38.pyc b/mmpretrain/datasets/__pycache__/flowers102.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..44ac082ff98e4f43902dd42c1b2cd491ead1103a
Binary files /dev/null and b/mmpretrain/datasets/__pycache__/flowers102.cpython-38.pyc differ
diff --git a/mmpretrain/datasets/__pycache__/food101.cpython-38.pyc b/mmpretrain/datasets/__pycache__/food101.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0db483d2f9daaaf127fc3ac5bb9c1233cd41869f
Binary files /dev/null and b/mmpretrain/datasets/__pycache__/food101.cpython-38.pyc differ
diff --git a/mmpretrain/datasets/__pycache__/gqa_dataset.cpython-38.pyc b/mmpretrain/datasets/__pycache__/gqa_dataset.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bde140af94e82c03c8a7e93b347dcc30c0a04010
Binary files /dev/null and b/mmpretrain/datasets/__pycache__/gqa_dataset.cpython-38.pyc differ
diff --git a/mmpretrain/datasets/__pycache__/imagenet.cpython-38.pyc b/mmpretrain/datasets/__pycache__/imagenet.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3aa74c0f2a722575fd09fbd77b38d7503b94ccdc
Binary files /dev/null and b/mmpretrain/datasets/__pycache__/imagenet.cpython-38.pyc differ
diff --git a/mmpretrain/datasets/__pycache__/inshop.cpython-38.pyc b/mmpretrain/datasets/__pycache__/inshop.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c2050be3585e316e10dc8204b3ceec3ab016c0dc
Binary files /dev/null and b/mmpretrain/datasets/__pycache__/inshop.cpython-38.pyc differ
diff --git a/mmpretrain/datasets/__pycache__/mnist.cpython-38.pyc b/mmpretrain/datasets/__pycache__/mnist.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..755bc57048b50ce4424689b7fe1e12ef82b75777
Binary files /dev/null and b/mmpretrain/datasets/__pycache__/mnist.cpython-38.pyc differ
diff --git a/mmpretrain/datasets/__pycache__/multi_label.cpython-38.pyc b/mmpretrain/datasets/__pycache__/multi_label.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a986a8faa7d92b4cfe6319ab56c25d739274ec8c
Binary files /dev/null and b/mmpretrain/datasets/__pycache__/multi_label.cpython-38.pyc differ
diff --git a/mmpretrain/datasets/__pycache__/multi_task.cpython-38.pyc b/mmpretrain/datasets/__pycache__/multi_task.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..97af6d4ca34c7850590483aafd73bdce8879a64b
Binary files /dev/null and b/mmpretrain/datasets/__pycache__/multi_task.cpython-38.pyc differ
diff --git a/mmpretrain/datasets/__pycache__/nlvr2.cpython-38.pyc b/mmpretrain/datasets/__pycache__/nlvr2.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..667bbe24e5ca3b43e4a6f760eaa898cf4a0114ae
Binary files /dev/null and b/mmpretrain/datasets/__pycache__/nlvr2.cpython-38.pyc differ
diff --git a/mmpretrain/datasets/__pycache__/nocaps.cpython-38.pyc b/mmpretrain/datasets/__pycache__/nocaps.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1d87383515aa74ab9076ce5780b82e8340a55c40
Binary files /dev/null and b/mmpretrain/datasets/__pycache__/nocaps.cpython-38.pyc differ
diff --git a/mmpretrain/datasets/__pycache__/ocr_vqa.cpython-38.pyc b/mmpretrain/datasets/__pycache__/ocr_vqa.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..84dc2861118e17089eeefec33001d5ccf2cad2e8
Binary files /dev/null and b/mmpretrain/datasets/__pycache__/ocr_vqa.cpython-38.pyc differ
diff --git a/mmpretrain/datasets/__pycache__/oxfordiiitpet.cpython-38.pyc b/mmpretrain/datasets/__pycache__/oxfordiiitpet.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6da9259d9336703e576999ca00473aa03897f986
Binary files /dev/null and b/mmpretrain/datasets/__pycache__/oxfordiiitpet.cpython-38.pyc differ
diff --git a/mmpretrain/datasets/__pycache__/places205.cpython-38.pyc b/mmpretrain/datasets/__pycache__/places205.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1ed784a999d2ad479f2faf17dbeb9a834bb041ab
Binary files /dev/null and b/mmpretrain/datasets/__pycache__/places205.cpython-38.pyc differ
diff --git a/mmpretrain/datasets/__pycache__/refcoco.cpython-38.pyc b/mmpretrain/datasets/__pycache__/refcoco.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d4e4425caa2e0485304f78406601159b32a0cd3d
Binary files /dev/null and b/mmpretrain/datasets/__pycache__/refcoco.cpython-38.pyc differ
diff --git a/mmpretrain/datasets/__pycache__/scienceqa.cpython-38.pyc b/mmpretrain/datasets/__pycache__/scienceqa.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3af13ffd3dcce838a11a86889345e3cd61b74939
Binary files /dev/null and b/mmpretrain/datasets/__pycache__/scienceqa.cpython-38.pyc differ
diff --git a/mmpretrain/datasets/__pycache__/stanfordcars.cpython-38.pyc b/mmpretrain/datasets/__pycache__/stanfordcars.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d58b147212c13fc66e8de94cf160de41c39fefdc
Binary files /dev/null and b/mmpretrain/datasets/__pycache__/stanfordcars.cpython-38.pyc differ
diff --git a/mmpretrain/datasets/__pycache__/sun397.cpython-38.pyc b/mmpretrain/datasets/__pycache__/sun397.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..efbc882a29c11024373d72f520f1c151ddcf9cb3
Binary files /dev/null and b/mmpretrain/datasets/__pycache__/sun397.cpython-38.pyc differ
diff --git a/mmpretrain/datasets/__pycache__/textvqa.cpython-38.pyc b/mmpretrain/datasets/__pycache__/textvqa.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..46679ddf891e11c4054ddeb4ed83b16a20392729
Binary files /dev/null and b/mmpretrain/datasets/__pycache__/textvqa.cpython-38.pyc differ
diff --git a/mmpretrain/datasets/__pycache__/utils.cpython-38.pyc b/mmpretrain/datasets/__pycache__/utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2a2ce2a1cbae66c671f9e1f43dbdedd64b40e84d
Binary files /dev/null and b/mmpretrain/datasets/__pycache__/utils.cpython-38.pyc differ
diff --git a/mmpretrain/datasets/__pycache__/visual_genome.cpython-38.pyc b/mmpretrain/datasets/__pycache__/visual_genome.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b3323fa33ff2add992f07c7b00bd0744b9f12032
Binary files /dev/null and b/mmpretrain/datasets/__pycache__/visual_genome.cpython-38.pyc differ
diff --git a/mmpretrain/datasets/__pycache__/vizwiz.cpython-38.pyc b/mmpretrain/datasets/__pycache__/vizwiz.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1f8dd8cd38ca08d736e3dd293f78c22caf63a882
Binary files /dev/null and b/mmpretrain/datasets/__pycache__/vizwiz.cpython-38.pyc differ
diff --git a/mmpretrain/datasets/__pycache__/voc.cpython-38.pyc b/mmpretrain/datasets/__pycache__/voc.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..70ab39b8464c043699f1936c5edb96cbdda946c0
Binary files /dev/null and b/mmpretrain/datasets/__pycache__/voc.cpython-38.pyc differ
diff --git a/mmpretrain/datasets/__pycache__/vsr.cpython-38.pyc b/mmpretrain/datasets/__pycache__/vsr.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0a4eb2afdd88ae66f6bd59f35d79701d49790735
Binary files /dev/null and b/mmpretrain/datasets/__pycache__/vsr.cpython-38.pyc differ
diff --git a/mmpretrain/datasets/base_dataset.py b/mmpretrain/datasets/base_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..dffdf04772163b5fa55afabc8e15ac8c118aadd2
--- /dev/null
+++ b/mmpretrain/datasets/base_dataset.py
@@ -0,0 +1,219 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+from os import PathLike
+from typing import List, Optional, Sequence, Union
+
+import mmengine
+import numpy as np
+from mmengine.dataset import BaseDataset as _BaseDataset
+
+from mmpretrain.registry import DATASETS, TRANSFORMS
+
+
+def expanduser(path):
+ """Expand ~ and ~user constructions.
+
+ If user or $HOME is unknown, do nothing.
+ """
+ if isinstance(path, (str, PathLike)):
+ return osp.expanduser(path)
+ else:
+ return path
+
+
+@DATASETS.register_module()
+class BaseDataset(_BaseDataset):
+ """Base dataset for image classification task.
+
+ This dataset support annotation file in `OpenMMLab 2.0 style annotation
+ format`.
+
+ .. _OpenMMLab 2.0 style annotation format:
+ https://github.com/open-mmlab/mmengine/blob/main/docs/zh_cn/tutorials/basedataset.md
+
+ Compared with :class:`mmengine.BaseDataset`, this class implements
+ several useful extra methods.
+
+ Args:
+ ann_file (str): Annotation file path.
+ metainfo (dict, optional): Meta information for dataset, such as class
+ information. Defaults to None.
+ data_root (str): The root directory for ``data_prefix`` and
+ ``ann_file``. Defaults to ''.
+ data_prefix (str | dict): Prefix for training data. Defaults to ''.
+ filter_cfg (dict, optional): Config for filter data. Defaults to None.
+ indices (int or Sequence[int], optional): Support using only the first
+ few samples in the annotation file to facilitate training/testing on
+ a smaller dataset. Defaults to None, which means using all
+ ``data_infos``.
+ serialize_data (bool): Whether to hold memory using serialized objects.
+ When enabled, data loader workers can use shared RAM from the master
+ process instead of making a copy. Defaults to True.
+ pipeline (Sequence): Processing pipeline. Defaults to an empty tuple.
+ test_mode (bool, optional): ``test_mode=True`` means in the test phase,
+ where an error will be raised if getting an item fails;
+ ``test_mode=False`` means in the training phase, where another item
+ will be returned randomly instead. Defaults to False.
+ lazy_init (bool): Whether to defer loading the annotation file until it
+ is actually needed. In some cases, such as visualization, only the
+ meta information of the dataset is needed, so loading the annotation
+ file is unnecessary. ``BaseDataset`` can skip loading annotations to
+ save time by setting ``lazy_init=True``. Defaults to False.
+ max_refetch (int): The maximum number of extra cycles to fetch a valid
+ image when ``BaseDataset.prepare_data`` returns a None image.
+ Defaults to 1000.
+ classes (str | Sequence[str], optional): Specify the names of classes.
+
+ - If it is a string, it should be a file path, and every line of
+ the file is the name of a class.
+ - If it is a sequence of strings, every item is the name of a class.
+ - If it is None, use the categories information in the ``metainfo``
+ argument, the annotation file or the class attribute ``METAINFO``.
+
+ Defaults to None.
+ """ # noqa: E501
+
+ def __init__(self,
+ ann_file: str,
+ metainfo: Optional[dict] = None,
+ data_root: str = '',
+ data_prefix: Union[str, dict] = '',
+ filter_cfg: Optional[dict] = None,
+ indices: Optional[Union[int, Sequence[int]]] = None,
+ serialize_data: bool = True,
+ pipeline: Sequence = (),
+ test_mode: bool = False,
+ lazy_init: bool = False,
+ max_refetch: int = 1000,
+ classes: Union[str, Sequence[str], None] = None):
+ if isinstance(data_prefix, str):
+ data_prefix = dict(img_path=expanduser(data_prefix))
+
+ ann_file = expanduser(ann_file)
+ metainfo = self._compat_classes(metainfo, classes)
+
+ transforms = []
+ for transform in pipeline:
+ if isinstance(transform, dict):
+ transforms.append(TRANSFORMS.build(transform))
+ else:
+ transforms.append(transform)
+
+ super().__init__(
+ ann_file=ann_file,
+ metainfo=metainfo,
+ data_root=data_root,
+ data_prefix=data_prefix,
+ filter_cfg=filter_cfg,
+ indices=indices,
+ serialize_data=serialize_data,
+ pipeline=transforms,
+ test_mode=test_mode,
+ lazy_init=lazy_init,
+ max_refetch=max_refetch)
+
+ @property
+ def img_prefix(self):
+ """The prefix of images."""
+ return self.data_prefix['img_path']
+
+ @property
+ def CLASSES(self):
+ """Return all categories names."""
+ return self._metainfo.get('classes', None)
+
+ @property
+ def class_to_idx(self):
+ """Map mapping class name to class index.
+
+ Returns:
+ dict: mapping from class name to class index.
+ """
+
+ return {cat: i for i, cat in enumerate(self.CLASSES)}
+
+ def get_gt_labels(self):
+ """Get all ground-truth labels (categories).
+
+ Returns:
+ np.ndarray: categories for all images.
+ """
+
+ gt_labels = np.array(
+ [self.get_data_info(i)['gt_label'] for i in range(len(self))])
+ return gt_labels
+
+ def get_cat_ids(self, idx: int) -> List[int]:
+ """Get category id by index.
+
+ Args:
+ idx (int): Index of data.
+
+ Returns:
+ cat_ids (List[int]): The image category of the specified index.
+ """
+
+ return [int(self.get_data_info(idx)['gt_label'])]
+
+ def _compat_classes(self, metainfo, classes):
+ """Merge the old style ``classes`` arguments to ``metainfo``."""
+ if isinstance(classes, str):
+ # take it as a file path
+ class_names = mmengine.list_from_file(expanduser(classes))
+ elif isinstance(classes, (tuple, list)):
+ class_names = classes
+ elif classes is not None:
+ raise ValueError(f'Unsupported type {type(classes)} of classes.')
+
+ if metainfo is None:
+ metainfo = {}
+
+ if classes is not None:
+ metainfo = {'classes': tuple(class_names), **metainfo}
+
+ return metainfo
+
+ def full_init(self):
+ """Load annotation file and set ``BaseDataset._fully_initialized`` to
+ True."""
+ super().full_init()
+
+ # Support the standard OpenMMLab 2.0 annotation format: generate the
+ # internal-format metainfo from the standard metainfo format.
+ if 'categories' in self._metainfo and 'classes' not in self._metainfo:
+ categories = sorted(
+ self._metainfo['categories'], key=lambda x: x['id'])
+ self._metainfo['classes'] = tuple(
+ [cat['category_name'] for cat in categories])
+
+ def __repr__(self):
+ """Print the basic information of the dataset.
+
+ Returns:
+ str: Formatted string.
+ """
+ head = 'Dataset ' + self.__class__.__name__
+ body = []
+ if self._fully_initialized:
+ body.append(f'Number of samples: \t{self.__len__()}')
+ else:
+ body.append("Haven't been initialized")
+
+ if self.CLASSES is not None:
+ body.append(f'Number of categories: \t{len(self.CLASSES)}')
+
+ body.extend(self.extra_repr())
+
+ if len(self.pipeline.transforms) > 0:
+ body.append('With transforms:')
+ for t in self.pipeline.transforms:
+ body.append(f' {t}')
+
+ lines = [head] + [' ' * 4 + line for line in body]
+ return '\n'.join(lines)
+
+ def extra_repr(self) -> List[str]:
+ """The extra repr information of the dataset."""
+ body = []
+ body.append(f'Annotation file: \t{self.ann_file}')
+ body.append(f'Prefix of images: \t{self.img_prefix}')
+ return body
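The base class above leaves ``load_data_list`` to its subclasses. Below is a minimal sketch (not part of the patch) of how a concrete dataset might plug into it; the ``ToyDataset`` name, the annotation layout of one ``img_path gt_label`` pair per line, and the file paths are hypothetical assumptions chosen only for illustration.

    # Hypothetical subclass of the BaseDataset added above; only
    # load_data_list is overridden, mirroring the pattern that the
    # Caltech101 dataset further down in this diff also follows.
    from mmengine import list_from_file

    from mmpretrain.datasets.base_dataset import BaseDataset
    from mmpretrain.registry import DATASETS


    @DATASETS.register_module()
    class ToyDataset(BaseDataset):
        """Toy dataset whose annotation file lists ``img_path gt_label`` pairs."""

        def load_data_list(self):
            data_list = []
            for line in list_from_file(self.ann_file):
                path, gt_label = line.split()
                data_list.append(dict(img_path=path, gt_label=int(gt_label)))
            return data_list


    # Assumed usage: data/toy/meta/train.txt is a hypothetical annotation file.
    dataset = ToyDataset(
        ann_file='meta/train.txt',
        data_root='data/toy',
        classes=['cat', 'dog'])
    print(dataset.class_to_idx)     # {'cat': 0, 'dog': 1}
    print(dataset.get_gt_labels())  # np.ndarray of integer labels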
diff --git a/mmpretrain/datasets/builder.py b/mmpretrain/datasets/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfa3872fe9931a4946368f07dfc5f5913a3e1f9f
--- /dev/null
+++ b/mmpretrain/datasets/builder.py
@@ -0,0 +1,25 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmpretrain.registry import DATASETS
+
+
+def build_dataset(cfg):
+ """Build dataset.
+
+ Examples:
+ >>> from mmpretrain.datasets import build_dataset
+ >>> mnist_train = build_dataset(
+ ... dict(type='MNIST', data_prefix='data/mnist/', test_mode=False))
+ >>> print(mnist_train)
+ Dataset MNIST
+ Number of samples: 60000
+ Number of categories: 10
+ Prefix of data: data/mnist/
+ >>> mnist_test = build_dataset(
+ ... dict(type='MNIST', data_prefix='data/mnist/', test_mode=True))
+ >>> print(mnist_test)
+ Dataset MNIST
+ Number of samples: 10000
+ Number of categories: 10
+ Prefix of data: data/mnist/
+ """
+ return DATASETS.build(cfg)
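For illustration, ``build_dataset`` can also be driven from a plain config dict like the ones used in config files. The concrete type, data root and pipeline below are placeholder assumptions, not values shipped with this patch; the transforms are assumed to be registered in the TRANSFORMS registry.

    from mmpretrain.datasets import build_dataset

    # Hypothetical config: 'Caltech101' is registered further down in this
    # diff; data_root and the pipeline entries are illustrative placeholders.
    train_cfg = dict(
        type='Caltech101',
        data_root='data/caltech-101',
        split='train',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='ResizeEdge', scale=256),
            dict(type='CenterCrop', crop_size=224),
            dict(type='PackInputs'),
        ])

    train_dataset = build_dataset(train_cfg)  # equivalent to DATASETS.build(train_cfg)
    print(train_dataset)  # repr shows sample and category counts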
diff --git a/mmpretrain/datasets/caltech101.py b/mmpretrain/datasets/caltech101.py
new file mode 100644
index 0000000000000000000000000000000000000000..71e5de85ff3bbf73c387a071f47113b46be36e2a
--- /dev/null
+++ b/mmpretrain/datasets/caltech101.py
@@ -0,0 +1,113 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List
+
+from mmengine import get_file_backend, list_from_file
+
+from mmpretrain.registry import DATASETS
+from .base_dataset import BaseDataset
+from .categories import CALTECH101_CATEGORIES
+
+
+@DATASETS.register_module()
+class Caltech101(BaseDataset):
+ """The Caltech101 Dataset.
+
+ Support the `Caltech101 `_ Dataset.
+ After downloading and decompressing the data, the dataset directory structure is as follows.
+
+ Caltech101 dataset directory: ::
+
+ caltech-101
+ ├── 101_ObjectCategories
+ │ ├── class_x
+ │ │ ├── xx1.jpg
+ │ │ ├── xx2.jpg
+ │ │ └── ...
+ │ ├── class_y
+ │ │ ├── yy1.jpg
+ │ │ ├── yy2.jpg
+ │ │ └── ...
+ │ └── ...
+ ├── Annotations
+ │ ├── class_x
+ │ │ ├── xx1.mat
+ │ │ └── ...
+ │ └── ...
+ ├── meta
+ │ ├── train.txt
+ │ └── test.txt
+ └── ....
+
+ Please note that since there is no official split for the training and
+ test sets, you can use the train.txt and test.txt provided by us or
+ create your own annotation files. Here is the download
+ `link `_
+ for the annotations.
+
+ Args:
+ data_root (str): The root directory for the Caltech101 dataset.
+ split (str, optional): The dataset split, supports "train" and "test".
+ Default to "train".
+
+ Examples:
+ >>> from mmpretrain.datasets import Caltech101
+ >>> train_dataset = Caltech101(data_root='data/caltech-101', split='train')
+ >>> train_dataset
+ Dataset Caltech101
+ Number of samples: 3060
+ Number of categories: 102
+ Root of dataset: data/caltech-101
+ >>> test_dataset = Caltech101(data_root='data/caltech-101', split='test')
+ >>> test_dataset
+ Dataset Caltech101
+ Number of samples: 6728
+ Number of categories: 102
+ Root of dataset: data/caltech-101
+ """ # noqa: E501
+
+ METAINFO = {'classes': CALTECH101_CATEGORIES}
+
+ def __init__(self, data_root: str, split: str = 'train', **kwargs):
+
+ splits = ['train', 'test']
+ assert split in splits, \
+ f"The split must be one of {splits}, but get '{split}'"
+ self.split = split
+
+ self.backend = get_file_backend(data_root, enable_singleton=True)
+
+ if split == 'train':
+ ann_file = self.backend.join_path('meta', 'train.txt')
+ else:
+ ann_file = self.backend.join_path('meta', 'test.txt')
+
+ data_prefix = '101_ObjectCategories'
+ test_mode = split == 'test'
+
+ super(Caltech101, self).__init__(
+ ann_file=ann_file,
+ data_root=data_root,
+ data_prefix=data_prefix,
+ test_mode=test_mode,
+ **kwargs)
+
+ def load_data_list(self):
+ """Load images and ground truth labels."""
+
+ pairs = list_from_file(self.ann_file)
+ data_list = []
+
+ for pair in pairs:
+ path, gt_label = pair.split()
+ img_path = self.backend.join_path(self.img_prefix, path)
+ info = dict(img_path=img_path, gt_label=int(gt_label))
+ data_list.append(info)
+
+ return data_list
+
+ def extra_repr(self) -> List[str]:
+ """The extra repr information of the dataset."""
+ body = [
+ f'Root of dataset: \t{self.data_root}',
+ ]
+ return body
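A short sketch of how the class above is expected to be used. The annotation layout shown follows what ``load_data_list`` parses (one relative image path and an integer label per line); the concrete file contents and label values are placeholders, not part of the patch.

    # meta/train.txt is assumed to contain lines such as:
    #     class_x/xx1.jpg 0
    #     class_y/yy1.jpg 1
    from mmpretrain.datasets import Caltech101

    train_set = Caltech101(data_root='data/caltech-101', split='train')
    print(len(train_set))              # number of training samples
    print(train_set.get_data_info(0))  # dict with 'img_path' and 'gt_label'
    print(train_set.CLASSES[:5])       # first few category names from METAINFO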
diff --git a/mmpretrain/datasets/categories.py b/mmpretrain/datasets/categories.py
new file mode 100644
index 0000000000000000000000000000000000000000..011ee5c1609ee01614c485abfa69cf0d4fc35417
--- /dev/null
+++ b/mmpretrain/datasets/categories.py
@@ -0,0 +1,1440 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# Pre-defined categories names of various datasets.
+
+VOC2007_CATEGORIES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
+ 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog',
+ 'horse', 'motorbike', 'person', 'pottedplant', 'sheep',
+ 'sofa', 'train', 'tvmonitor')
+
+CUB_CATEGORIES = (
+ 'Black_footed_Albatross', 'Laysan_Albatross', 'Sooty_Albatross',
+ 'Groove_billed_Ani', 'Crested_Auklet', 'Least_Auklet', 'Parakeet_Auklet',
+ 'Rhinoceros_Auklet', 'Brewer_Blackbird', 'Red_winged_Blackbird',
+ 'Rusty_Blackbird', 'Yellow_headed_Blackbird', 'Bobolink', 'Indigo_Bunting',
+ 'Lazuli_Bunting', 'Painted_Bunting', 'Cardinal', 'Spotted_Catbird',
+ 'Gray_Catbird', 'Yellow_breasted_Chat', 'Eastern_Towhee',
+ 'Chuck_will_Widow', 'Brandt_Cormorant', 'Red_faced_Cormorant',
+ 'Pelagic_Cormorant', 'Bronzed_Cowbird', 'Shiny_Cowbird', 'Brown_Creeper',
+ 'American_Crow', 'Fish_Crow', 'Black_billed_Cuckoo', 'Mangrove_Cuckoo',
+ 'Yellow_billed_Cuckoo', 'Gray_crowned_Rosy_Finch', 'Purple_Finch',
+ 'Northern_Flicker', 'Acadian_Flycatcher', 'Great_Crested_Flycatcher',
+ 'Least_Flycatcher', 'Olive_sided_Flycatcher', 'Scissor_tailed_Flycatcher',
+ 'Vermilion_Flycatcher', 'Yellow_bellied_Flycatcher', 'Frigatebird',
+ 'Northern_Fulmar', 'Gadwall', 'American_Goldfinch', 'European_Goldfinch',
+ 'Boat_tailed_Grackle', 'Eared_Grebe', 'Horned_Grebe', 'Pied_billed_Grebe',
+ 'Western_Grebe', 'Blue_Grosbeak', 'Evening_Grosbeak', 'Pine_Grosbeak',
+ 'Rose_breasted_Grosbeak', 'Pigeon_Guillemot', 'California_Gull',
+ 'Glaucous_winged_Gull', 'Heermann_Gull', 'Herring_Gull', 'Ivory_Gull',
+ 'Ring_billed_Gull', 'Slaty_backed_Gull', 'Western_Gull',
+ 'Anna_Hummingbird', 'Ruby_throated_Hummingbird', 'Rufous_Hummingbird',
+ 'Green_Violetear', 'Long_tailed_Jaeger', 'Pomarine_Jaeger', 'Blue_Jay',
+ 'Florida_Jay', 'Green_Jay', 'Dark_eyed_Junco', 'Tropical_Kingbird',
+ 'Gray_Kingbird', 'Belted_Kingfisher', 'Green_Kingfisher',
+ 'Pied_Kingfisher', 'Ringed_Kingfisher', 'White_breasted_Kingfisher',
+ 'Red_legged_Kittiwake', 'Horned_Lark', 'Pacific_Loon', 'Mallard',
+ 'Western_Meadowlark', 'Hooded_Merganser', 'Red_breasted_Merganser',
+ 'Mockingbird', 'Nighthawk', 'Clark_Nutcracker', 'White_breasted_Nuthatch',
+ 'Baltimore_Oriole', 'Hooded_Oriole', 'Orchard_Oriole', 'Scott_Oriole',
+ 'Ovenbird', 'Brown_Pelican', 'White_Pelican', 'Western_Wood_Pewee',
+ 'Sayornis', 'American_Pipit', 'Whip_poor_Will', 'Horned_Puffin',
+ 'Common_Raven', 'White_necked_Raven', 'American_Redstart', 'Geococcyx',
+ 'Loggerhead_Shrike', 'Great_Grey_Shrike', 'Baird_Sparrow',
+ 'Black_throated_Sparrow', 'Brewer_Sparrow', 'Chipping_Sparrow',
+ 'Clay_colored_Sparrow', 'House_Sparrow', 'Field_Sparrow', 'Fox_Sparrow',
+ 'Grasshopper_Sparrow', 'Harris_Sparrow', 'Henslow_Sparrow',
+ 'Le_Conte_Sparrow', 'Lincoln_Sparrow', 'Nelson_Sharp_tailed_Sparrow',
+ 'Savannah_Sparrow', 'Seaside_Sparrow', 'Song_Sparrow', 'Tree_Sparrow',
+ 'Vesper_Sparrow', 'White_crowned_Sparrow', 'White_throated_Sparrow',
+ 'Cape_Glossy_Starling', 'Bank_Swallow', 'Barn_Swallow', 'Cliff_Swallow',
+ 'Tree_Swallow', 'Scarlet_Tanager', 'Summer_Tanager', 'Artic_Tern',
+ 'Black_Tern', 'Caspian_Tern', 'Common_Tern', 'Elegant_Tern',
+ 'Forsters_Tern', 'Least_Tern', 'Green_tailed_Towhee', 'Brown_Thrasher',
+ 'Sage_Thrasher', 'Black_capped_Vireo', 'Blue_headed_Vireo',
+ 'Philadelphia_Vireo', 'Red_eyed_Vireo', 'Warbling_Vireo',
+ 'White_eyed_Vireo', 'Yellow_throated_Vireo', 'Bay_breasted_Warbler',
+ 'Black_and_white_Warbler', 'Black_throated_Blue_Warbler',
+ 'Blue_winged_Warbler', 'Canada_Warbler', 'Cape_May_Warbler',
+ 'Cerulean_Warbler', 'Chestnut_sided_Warbler', 'Golden_winged_Warbler',
+ 'Hooded_Warbler', 'Kentucky_Warbler', 'Magnolia_Warbler',
+ 'Mourning_Warbler', 'Myrtle_Warbler', 'Nashville_Warbler',
+ 'Orange_crowned_Warbler', 'Palm_Warbler', 'Pine_Warbler',
+ 'Prairie_Warbler', 'Prothonotary_Warbler', 'Swainson_Warbler',
+ 'Tennessee_Warbler', 'Wilson_Warbler', 'Worm_eating_Warbler',
+ 'Yellow_Warbler', 'Northern_Waterthrush', 'Louisiana_Waterthrush',
+ 'Bohemian_Waxwing', 'Cedar_Waxwing', 'American_Three_toed_Woodpecker',
+ 'Pileated_Woodpecker', 'Red_bellied_Woodpecker', 'Red_cockaded_Woodpecker',
+ 'Red_headed_Woodpecker', 'Downy_Woodpecker', 'Bewick_Wren', 'Cactus_Wren',
+ 'Carolina_Wren', 'House_Wren', 'Marsh_Wren', 'Rock_Wren', 'Winter_Wren',
+ 'Common_Yellowthroat')
+
+IMAGENET_CATEGORIES = (
+ 'tench, Tinca tinca',
+ 'goldfish, Carassius auratus',
+ 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias', # noqa: E501
+ 'tiger shark, Galeocerdo cuvieri',
+ 'hammerhead, hammerhead shark',
+ 'electric ray, crampfish, numbfish, torpedo',
+ 'stingray',
+ 'cock',
+ 'hen',
+ 'ostrich, Struthio camelus',
+ 'brambling, Fringilla montifringilla',
+ 'goldfinch, Carduelis carduelis',
+ 'house finch, linnet, Carpodacus mexicanus',
+ 'junco, snowbird',
+ 'indigo bunting, indigo finch, indigo bird, Passerina cyanea',
+ 'robin, American robin, Turdus migratorius',
+ 'bulbul',
+ 'jay',
+ 'magpie',
+ 'chickadee',
+ 'water ouzel, dipper',
+ 'kite',
+ 'bald eagle, American eagle, Haliaeetus leucocephalus',
+ 'vulture',
+ 'great grey owl, great gray owl, Strix nebulosa',
+ 'European fire salamander, Salamandra salamandra',
+ 'common newt, Triturus vulgaris',
+ 'eft',
+ 'spotted salamander, Ambystoma maculatum',
+ 'axolotl, mud puppy, Ambystoma mexicanum',
+ 'bullfrog, Rana catesbeiana',
+ 'tree frog, tree-frog',
+ 'tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui',
+ 'loggerhead, loggerhead turtle, Caretta caretta',
+ 'leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea', # noqa: E501
+ 'mud turtle',
+ 'terrapin',
+ 'box turtle, box tortoise',
+ 'banded gecko',
+ 'common iguana, iguana, Iguana iguana',
+ 'American chameleon, anole, Anolis carolinensis',
+ 'whiptail, whiptail lizard',
+ 'agama',
+ 'frilled lizard, Chlamydosaurus kingi',
+ 'alligator lizard',
+ 'Gila monster, Heloderma suspectum',
+ 'green lizard, Lacerta viridis',
+ 'African chameleon, Chamaeleo chamaeleon',
+ 'Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis', # noqa: E501
+ 'African crocodile, Nile crocodile, Crocodylus niloticus',
+ 'American alligator, Alligator mississipiensis',
+ 'triceratops',
+ 'thunder snake, worm snake, Carphophis amoenus',
+ 'ringneck snake, ring-necked snake, ring snake',
+ 'hognose snake, puff adder, sand viper',
+ 'green snake, grass snake',
+ 'king snake, kingsnake',
+ 'garter snake, grass snake',
+ 'water snake',
+ 'vine snake',
+ 'night snake, Hypsiglena torquata',
+ 'boa constrictor, Constrictor constrictor',
+ 'rock python, rock snake, Python sebae',
+ 'Indian cobra, Naja naja',
+ 'green mamba',
+ 'sea snake',
+ 'horned viper, cerastes, sand viper, horned asp, Cerastes cornutus',
+ 'diamondback, diamondback rattlesnake, Crotalus adamanteus',
+ 'sidewinder, horned rattlesnake, Crotalus cerastes',
+ 'trilobite',
+ 'harvestman, daddy longlegs, Phalangium opilio',
+ 'scorpion',
+ 'black and gold garden spider, Argiope aurantia',
+ 'barn spider, Araneus cavaticus',
+ 'garden spider, Aranea diademata',
+ 'black widow, Latrodectus mactans',
+ 'tarantula',
+ 'wolf spider, hunting spider',
+ 'tick',
+ 'centipede',
+ 'black grouse',
+ 'ptarmigan',
+ 'ruffed grouse, partridge, Bonasa umbellus',
+ 'prairie chicken, prairie grouse, prairie fowl',
+ 'peacock',
+ 'quail',
+ 'partridge',
+ 'African grey, African gray, Psittacus erithacus',
+ 'macaw',
+ 'sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita',
+ 'lorikeet',
+ 'coucal',
+ 'bee eater',
+ 'hornbill',
+ 'hummingbird',
+ 'jacamar',
+ 'toucan',
+ 'drake',
+ 'red-breasted merganser, Mergus serrator',
+ 'goose',
+ 'black swan, Cygnus atratus',
+ 'tusker',
+ 'echidna, spiny anteater, anteater',
+ 'platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus', # noqa: E501
+ 'wallaby, brush kangaroo',
+ 'koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus', # noqa: E501
+ 'wombat',
+ 'jellyfish',
+ 'sea anemone, anemone',
+ 'brain coral',
+ 'flatworm, platyhelminth',
+ 'nematode, nematode worm, roundworm',
+ 'conch',
+ 'snail',
+ 'slug',
+ 'sea slug, nudibranch',
+ 'chiton, coat-of-mail shell, sea cradle, polyplacophore',
+ 'chambered nautilus, pearly nautilus, nautilus',
+ 'Dungeness crab, Cancer magister',
+ 'rock crab, Cancer irroratus',
+ 'fiddler crab',
+ 'king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica', # noqa: E501
+ 'American lobster, Northern lobster, Maine lobster, Homarus americanus', # noqa: E501
+ 'spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish', # noqa: E501
+ 'crayfish, crawfish, crawdad, crawdaddy',
+ 'hermit crab',
+ 'isopod',
+ 'white stork, Ciconia ciconia',
+ 'black stork, Ciconia nigra',
+ 'spoonbill',
+ 'flamingo',
+ 'little blue heron, Egretta caerulea',
+ 'American egret, great white heron, Egretta albus',
+ 'bittern',
+ 'crane',
+ 'limpkin, Aramus pictus',
+ 'European gallinule, Porphyrio porphyrio',
+ 'American coot, marsh hen, mud hen, water hen, Fulica americana',
+ 'bustard',
+ 'ruddy turnstone, Arenaria interpres',
+ 'red-backed sandpiper, dunlin, Erolia alpina',
+ 'redshank, Tringa totanus',
+ 'dowitcher',
+ 'oystercatcher, oyster catcher',
+ 'pelican',
+ 'king penguin, Aptenodytes patagonica',
+ 'albatross, mollymawk',
+ 'grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus', # noqa: E501
+ 'killer whale, killer, orca, grampus, sea wolf, Orcinus orca',
+ 'dugong, Dugong dugon',
+ 'sea lion',
+ 'Chihuahua',
+ 'Japanese spaniel',
+ 'Maltese dog, Maltese terrier, Maltese',
+ 'Pekinese, Pekingese, Peke',
+ 'Shih-Tzu',
+ 'Blenheim spaniel',
+ 'papillon',
+ 'toy terrier',
+ 'Rhodesian ridgeback',
+ 'Afghan hound, Afghan',
+ 'basset, basset hound',
+ 'beagle',
+ 'bloodhound, sleuthhound',
+ 'bluetick',
+ 'black-and-tan coonhound',
+ 'Walker hound, Walker foxhound',
+ 'English foxhound',
+ 'redbone',
+ 'borzoi, Russian wolfhound',
+ 'Irish wolfhound',
+ 'Italian greyhound',
+ 'whippet',
+ 'Ibizan hound, Ibizan Podenco',
+ 'Norwegian elkhound, elkhound',
+ 'otterhound, otter hound',
+ 'Saluki, gazelle hound',
+ 'Scottish deerhound, deerhound',
+ 'Weimaraner',
+ 'Staffordshire bullterrier, Staffordshire bull terrier',
+ 'American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier', # noqa: E501
+ 'Bedlington terrier',
+ 'Border terrier',
+ 'Kerry blue terrier',
+ 'Irish terrier',
+ 'Norfolk terrier',
+ 'Norwich terrier',
+ 'Yorkshire terrier',
+ 'wire-haired fox terrier',
+ 'Lakeland terrier',
+ 'Sealyham terrier, Sealyham',
+ 'Airedale, Airedale terrier',
+ 'cairn, cairn terrier',
+ 'Australian terrier',
+ 'Dandie Dinmont, Dandie Dinmont terrier',
+ 'Boston bull, Boston terrier',
+ 'miniature schnauzer',
+ 'giant schnauzer',
+ 'standard schnauzer',
+ 'Scotch terrier, Scottish terrier, Scottie',
+ 'Tibetan terrier, chrysanthemum dog',
+ 'silky terrier, Sydney silky',
+ 'soft-coated wheaten terrier',
+ 'West Highland white terrier',
+ 'Lhasa, Lhasa apso',
+ 'flat-coated retriever',
+ 'curly-coated retriever',
+ 'golden retriever',
+ 'Labrador retriever',
+ 'Chesapeake Bay retriever',
+ 'German short-haired pointer',
+ 'vizsla, Hungarian pointer',
+ 'English setter',
+ 'Irish setter, red setter',
+ 'Gordon setter',
+ 'Brittany spaniel',
+ 'clumber, clumber spaniel',
+ 'English springer, English springer spaniel',
+ 'Welsh springer spaniel',
+ 'cocker spaniel, English cocker spaniel, cocker',
+ 'Sussex spaniel',
+ 'Irish water spaniel',
+ 'kuvasz',
+ 'schipperke',
+ 'groenendael',
+ 'malinois',
+ 'briard',
+ 'kelpie',
+ 'komondor',
+ 'Old English sheepdog, bobtail',
+ 'Shetland sheepdog, Shetland sheep dog, Shetland',
+ 'collie',
+ 'Border collie',
+ 'Bouvier des Flandres, Bouviers des Flandres',
+ 'Rottweiler',
+ 'German shepherd, German shepherd dog, German police dog, alsatian',
+ 'Doberman, Doberman pinscher',
+ 'miniature pinscher',
+ 'Greater Swiss Mountain dog',
+ 'Bernese mountain dog',
+ 'Appenzeller',
+ 'EntleBucher',
+ 'boxer',
+ 'bull mastiff',
+ 'Tibetan mastiff',
+ 'French bulldog',
+ 'Great Dane',
+ 'Saint Bernard, St Bernard',
+ 'Eskimo dog, husky',
+ 'malamute, malemute, Alaskan malamute',
+ 'Siberian husky',
+ 'dalmatian, coach dog, carriage dog',
+ 'affenpinscher, monkey pinscher, monkey dog',
+ 'basenji',
+ 'pug, pug-dog',
+ 'Leonberg',
+ 'Newfoundland, Newfoundland dog',
+ 'Great Pyrenees',
+ 'Samoyed, Samoyede',
+ 'Pomeranian',
+ 'chow, chow chow',
+ 'keeshond',
+ 'Brabancon griffon',
+ 'Pembroke, Pembroke Welsh corgi',
+ 'Cardigan, Cardigan Welsh corgi',
+ 'toy poodle',
+ 'miniature poodle',
+ 'standard poodle',
+ 'Mexican hairless',
+ 'timber wolf, grey wolf, gray wolf, Canis lupus',
+ 'white wolf, Arctic wolf, Canis lupus tundrarum',
+ 'red wolf, maned wolf, Canis rufus, Canis niger',
+ 'coyote, prairie wolf, brush wolf, Canis latrans',
+ 'dingo, warrigal, warragal, Canis dingo',
+ 'dhole, Cuon alpinus',
+ 'African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus',
+ 'hyena, hyaena',
+ 'red fox, Vulpes vulpes',
+ 'kit fox, Vulpes macrotis',
+ 'Arctic fox, white fox, Alopex lagopus',
+ 'grey fox, gray fox, Urocyon cinereoargenteus',
+ 'tabby, tabby cat',
+ 'tiger cat',
+ 'Persian cat',
+ 'Siamese cat, Siamese',
+ 'Egyptian cat',
+ 'cougar, puma, catamount, mountain lion, painter, panther, Felis concolor', # noqa: E501
+ 'lynx, catamount',
+ 'leopard, Panthera pardus',
+ 'snow leopard, ounce, Panthera uncia',
+ 'jaguar, panther, Panthera onca, Felis onca',
+ 'lion, king of beasts, Panthera leo',
+ 'tiger, Panthera tigris',
+ 'cheetah, chetah, Acinonyx jubatus',
+ 'brown bear, bruin, Ursus arctos',
+ 'American black bear, black bear, Ursus americanus, Euarctos americanus', # noqa: E501
+ 'ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus',
+ 'sloth bear, Melursus ursinus, Ursus ursinus',
+ 'mongoose',
+ 'meerkat, mierkat',
+ 'tiger beetle',
+ 'ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle',
+ 'ground beetle, carabid beetle',
+ 'long-horned beetle, longicorn, longicorn beetle',
+ 'leaf beetle, chrysomelid',
+ 'dung beetle',
+ 'rhinoceros beetle',
+ 'weevil',
+ 'fly',
+ 'bee',
+ 'ant, emmet, pismire',
+ 'grasshopper, hopper',
+ 'cricket',
+ 'walking stick, walkingstick, stick insect',
+ 'cockroach, roach',
+ 'mantis, mantid',
+ 'cicada, cicala',
+ 'leafhopper',
+ 'lacewing, lacewing fly',
+ "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk", # noqa: E501
+ 'damselfly',
+ 'admiral',
+ 'ringlet, ringlet butterfly',
+ 'monarch, monarch butterfly, milkweed butterfly, Danaus plexippus',
+ 'cabbage butterfly',
+ 'sulphur butterfly, sulfur butterfly',
+ 'lycaenid, lycaenid butterfly',
+ 'starfish, sea star',
+ 'sea urchin',
+ 'sea cucumber, holothurian',
+ 'wood rabbit, cottontail, cottontail rabbit',
+ 'hare',
+ 'Angora, Angora rabbit',
+ 'hamster',
+ 'porcupine, hedgehog',
+ 'fox squirrel, eastern fox squirrel, Sciurus niger',
+ 'marmot',
+ 'beaver',
+ 'guinea pig, Cavia cobaya',
+ 'sorrel',
+ 'zebra',
+ 'hog, pig, grunter, squealer, Sus scrofa',
+ 'wild boar, boar, Sus scrofa',
+ 'warthog',
+ 'hippopotamus, hippo, river horse, Hippopotamus amphibius',
+ 'ox',
+ 'water buffalo, water ox, Asiatic buffalo, Bubalus bubalis',
+ 'bison',
+ 'ram, tup',
+ 'bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis', # noqa: E501
+ 'ibex, Capra ibex',
+ 'hartebeest',
+ 'impala, Aepyceros melampus',
+ 'gazelle',
+ 'Arabian camel, dromedary, Camelus dromedarius',
+ 'llama',
+ 'weasel',
+ 'mink',
+ 'polecat, fitch, foulmart, foumart, Mustela putorius',
+ 'black-footed ferret, ferret, Mustela nigripes',
+ 'otter',
+ 'skunk, polecat, wood pussy',
+ 'badger',
+ 'armadillo',
+ 'three-toed sloth, ai, Bradypus tridactylus',
+ 'orangutan, orang, orangutang, Pongo pygmaeus',
+ 'gorilla, Gorilla gorilla',
+ 'chimpanzee, chimp, Pan troglodytes',
+ 'gibbon, Hylobates lar',
+ 'siamang, Hylobates syndactylus, Symphalangus syndactylus',
+ 'guenon, guenon monkey',
+ 'patas, hussar monkey, Erythrocebus patas',
+ 'baboon',
+ 'macaque',
+ 'langur',
+ 'colobus, colobus monkey',
+ 'proboscis monkey, Nasalis larvatus',
+ 'marmoset',
+ 'capuchin, ringtail, Cebus capucinus',
+ 'howler monkey, howler',
+ 'titi, titi monkey',
+ 'spider monkey, Ateles geoffroyi',
+ 'squirrel monkey, Saimiri sciureus',
+ 'Madagascar cat, ring-tailed lemur, Lemur catta',
+ 'indri, indris, Indri indri, Indri brevicaudatus',
+ 'Indian elephant, Elephas maximus',
+ 'African elephant, Loxodonta africana',
+ 'lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens',
+ 'giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca',
+ 'barracouta, snoek',
+ 'eel',
+ 'coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch', # noqa: E501
+ 'rock beauty, Holocanthus tricolor',
+ 'anemone fish',
+ 'sturgeon',
+ 'gar, garfish, garpike, billfish, Lepisosteus osseus',
+ 'lionfish',
+ 'puffer, pufferfish, blowfish, globefish',
+ 'abacus',
+ 'abaya',
+ "academic gown, academic robe, judge's robe",
+ 'accordion, piano accordion, squeeze box',
+ 'acoustic guitar',
+ 'aircraft carrier, carrier, flattop, attack aircraft carrier',
+ 'airliner',
+ 'airship, dirigible',
+ 'altar',
+ 'ambulance',
+ 'amphibian, amphibious vehicle',
+ 'analog clock',
+ 'apiary, bee house',
+ 'apron',
+ 'ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin', # noqa: E501
+ 'assault rifle, assault gun',
+ 'backpack, back pack, knapsack, packsack, rucksack, haversack',
+ 'bakery, bakeshop, bakehouse',
+ 'balance beam, beam',
+ 'balloon',
+ 'ballpoint, ballpoint pen, ballpen, Biro',
+ 'Band Aid',
+ 'banjo',
+ 'bannister, banister, balustrade, balusters, handrail',
+ 'barbell',
+ 'barber chair',
+ 'barbershop',
+ 'barn',
+ 'barometer',
+ 'barrel, cask',
+ 'barrow, garden cart, lawn cart, wheelbarrow',
+ 'baseball',
+ 'basketball',
+ 'bassinet',
+ 'bassoon',
+ 'bathing cap, swimming cap',
+ 'bath towel',
+ 'bathtub, bathing tub, bath, tub',
+ 'beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon', # noqa: E501
+ 'beacon, lighthouse, beacon light, pharos',
+ 'beaker',
+ 'bearskin, busby, shako',
+ 'beer bottle',
+ 'beer glass',
+ 'bell cote, bell cot',
+ 'bib',
+ 'bicycle-built-for-two, tandem bicycle, tandem',
+ 'bikini, two-piece',
+ 'binder, ring-binder',
+ 'binoculars, field glasses, opera glasses',
+ 'birdhouse',
+ 'boathouse',
+ 'bobsled, bobsleigh, bob',
+ 'bolo tie, bolo, bola tie, bola',
+ 'bonnet, poke bonnet',
+ 'bookcase',
+ 'bookshop, bookstore, bookstall',
+ 'bottlecap',
+ 'bow',
+ 'bow tie, bow-tie, bowtie',
+ 'brass, memorial tablet, plaque',
+ 'brassiere, bra, bandeau',
+ 'breakwater, groin, groyne, mole, bulwark, seawall, jetty',
+ 'breastplate, aegis, egis',
+ 'broom',
+ 'bucket, pail',
+ 'buckle',
+ 'bulletproof vest',
+ 'bullet train, bullet',
+ 'butcher shop, meat market',
+ 'cab, hack, taxi, taxicab',
+ 'caldron, cauldron',
+ 'candle, taper, wax light',
+ 'cannon',
+ 'canoe',
+ 'can opener, tin opener',
+ 'cardigan',
+ 'car mirror',
+ 'carousel, carrousel, merry-go-round, roundabout, whirligig',
+ "carpenter's kit, tool kit",
+ 'carton',
+ 'car wheel',
+ 'cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM', # noqa: E501
+ 'cassette',
+ 'cassette player',
+ 'castle',
+ 'catamaran',
+ 'CD player',
+ 'cello, violoncello',
+ 'cellular telephone, cellular phone, cellphone, cell, mobile phone',
+ 'chain',
+ 'chainlink fence',
+ 'chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour', # noqa: E501
+ 'chain saw, chainsaw',
+ 'chest',
+ 'chiffonier, commode',
+ 'chime, bell, gong',
+ 'china cabinet, china closet',
+ 'Christmas stocking',
+ 'church, church building',
+ 'cinema, movie theater, movie theatre, movie house, picture palace',
+ 'cleaver, meat cleaver, chopper',
+ 'cliff dwelling',
+ 'cloak',
+ 'clog, geta, patten, sabot',
+ 'cocktail shaker',
+ 'coffee mug',
+ 'coffeepot',
+ 'coil, spiral, volute, whorl, helix',
+ 'combination lock',
+ 'computer keyboard, keypad',
+ 'confectionery, confectionary, candy store',
+ 'container ship, containership, container vessel',
+ 'convertible',
+ 'corkscrew, bottle screw',
+ 'cornet, horn, trumpet, trump',
+ 'cowboy boot',
+ 'cowboy hat, ten-gallon hat',
+ 'cradle',
+ 'crane',
+ 'crash helmet',
+ 'crate',
+ 'crib, cot',
+ 'Crock Pot',
+ 'croquet ball',
+ 'crutch',
+ 'cuirass',
+ 'dam, dike, dyke',
+ 'desk',
+ 'desktop computer',
+ 'dial telephone, dial phone',
+ 'diaper, nappy, napkin',
+ 'digital clock',
+ 'digital watch',
+ 'dining table, board',
+ 'dishrag, dishcloth',
+ 'dishwasher, dish washer, dishwashing machine',
+ 'disk brake, disc brake',
+ 'dock, dockage, docking facility',
+ 'dogsled, dog sled, dog sleigh',
+ 'dome',
+ 'doormat, welcome mat',
+ 'drilling platform, offshore rig',
+ 'drum, membranophone, tympan',
+ 'drumstick',
+ 'dumbbell',
+ 'Dutch oven',
+ 'electric fan, blower',
+ 'electric guitar',
+ 'electric locomotive',
+ 'entertainment center',
+ 'envelope',
+ 'espresso maker',
+ 'face powder',
+ 'feather boa, boa',
+ 'file, file cabinet, filing cabinet',
+ 'fireboat',
+ 'fire engine, fire truck',
+ 'fire screen, fireguard',
+ 'flagpole, flagstaff',
+ 'flute, transverse flute',
+ 'folding chair',
+ 'football helmet',
+ 'forklift',
+ 'fountain',
+ 'fountain pen',
+ 'four-poster',
+ 'freight car',
+ 'French horn, horn',
+ 'frying pan, frypan, skillet',
+ 'fur coat',
+ 'garbage truck, dustcart',
+ 'gasmask, respirator, gas helmet',
+ 'gas pump, gasoline pump, petrol pump, island dispenser',
+ 'goblet',
+ 'go-kart',
+ 'golf ball',
+ 'golfcart, golf cart',
+ 'gondola',
+ 'gong, tam-tam',
+ 'gown',
+ 'grand piano, grand',
+ 'greenhouse, nursery, glasshouse',
+ 'grille, radiator grille',
+ 'grocery store, grocery, food market, market',
+ 'guillotine',
+ 'hair slide',
+ 'hair spray',
+ 'half track',
+ 'hammer',
+ 'hamper',
+ 'hand blower, blow dryer, blow drier, hair dryer, hair drier',
+ 'hand-held computer, hand-held microcomputer',
+ 'handkerchief, hankie, hanky, hankey',
+ 'hard disc, hard disk, fixed disk',
+ 'harmonica, mouth organ, harp, mouth harp',
+ 'harp',
+ 'harvester, reaper',
+ 'hatchet',
+ 'holster',
+ 'home theater, home theatre',
+ 'honeycomb',
+ 'hook, claw',
+ 'hoopskirt, crinoline',
+ 'horizontal bar, high bar',
+ 'horse cart, horse-cart',
+ 'hourglass',
+ 'iPod',
+ 'iron, smoothing iron',
+ "jack-o'-lantern",
+ 'jean, blue jean, denim',
+ 'jeep, landrover',
+ 'jersey, T-shirt, tee shirt',
+ 'jigsaw puzzle',
+ 'jinrikisha, ricksha, rickshaw',
+ 'joystick',
+ 'kimono',
+ 'knee pad',
+ 'knot',
+ 'lab coat, laboratory coat',
+ 'ladle',
+ 'lampshade, lamp shade',
+ 'laptop, laptop computer',
+ 'lawn mower, mower',
+ 'lens cap, lens cover',
+ 'letter opener, paper knife, paperknife',
+ 'library',
+ 'lifeboat',
+ 'lighter, light, igniter, ignitor',
+ 'limousine, limo',
+ 'liner, ocean liner',
+ 'lipstick, lip rouge',
+ 'Loafer',
+ 'lotion',
+ 'loudspeaker, speaker, speaker unit, loudspeaker system, speaker system', # noqa: E501
+ "loupe, jeweler's loupe",
+ 'lumbermill, sawmill',
+ 'magnetic compass',
+ 'mailbag, postbag',
+ 'mailbox, letter box',
+ 'maillot',
+ 'maillot, tank suit',
+ 'manhole cover',
+ 'maraca',
+ 'marimba, xylophone',
+ 'mask',
+ 'matchstick',
+ 'maypole',
+ 'maze, labyrinth',
+ 'measuring cup',
+ 'medicine chest, medicine cabinet',
+ 'megalith, megalithic structure',
+ 'microphone, mike',
+ 'microwave, microwave oven',
+ 'military uniform',
+ 'milk can',
+ 'minibus',
+ 'miniskirt, mini',
+ 'minivan',
+ 'missile',
+ 'mitten',
+ 'mixing bowl',
+ 'mobile home, manufactured home',
+ 'Model T',
+ 'modem',
+ 'monastery',
+ 'monitor',
+ 'moped',
+ 'mortar',
+ 'mortarboard',
+ 'mosque',
+ 'mosquito net',
+ 'motor scooter, scooter',
+ 'mountain bike, all-terrain bike, off-roader',
+ 'mountain tent',
+ 'mouse, computer mouse',
+ 'mousetrap',
+ 'moving van',
+ 'muzzle',
+ 'nail',
+ 'neck brace',
+ 'necklace',
+ 'nipple',
+ 'notebook, notebook computer',
+ 'obelisk',
+ 'oboe, hautboy, hautbois',
+ 'ocarina, sweet potato',
+ 'odometer, hodometer, mileometer, milometer',
+ 'oil filter',
+ 'organ, pipe organ',
+ 'oscilloscope, scope, cathode-ray oscilloscope, CRO',
+ 'overskirt',
+ 'oxcart',
+ 'oxygen mask',
+ 'packet',
+ 'paddle, boat paddle',
+ 'paddlewheel, paddle wheel',
+ 'padlock',
+ 'paintbrush',
+ "pajama, pyjama, pj's, jammies",
+ 'palace',
+ 'panpipe, pandean pipe, syrinx',
+ 'paper towel',
+ 'parachute, chute',
+ 'parallel bars, bars',
+ 'park bench',
+ 'parking meter',
+ 'passenger car, coach, carriage',
+ 'patio, terrace',
+ 'pay-phone, pay-station',
+ 'pedestal, plinth, footstall',
+ 'pencil box, pencil case',
+ 'pencil sharpener',
+ 'perfume, essence',
+ 'Petri dish',
+ 'photocopier',
+ 'pick, plectrum, plectron',
+ 'pickelhaube',
+ 'picket fence, paling',
+ 'pickup, pickup truck',
+ 'pier',
+ 'piggy bank, penny bank',
+ 'pill bottle',
+ 'pillow',
+ 'ping-pong ball',
+ 'pinwheel',
+ 'pirate, pirate ship',
+ 'pitcher, ewer',
+ "plane, carpenter's plane, woodworking plane",
+ 'planetarium',
+ 'plastic bag',
+ 'plate rack',
+ 'plow, plough',
+ "plunger, plumber's helper",
+ 'Polaroid camera, Polaroid Land camera',
+ 'pole',
+ 'police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria', # noqa: E501
+ 'poncho',
+ 'pool table, billiard table, snooker table',
+ 'pop bottle, soda bottle',
+ 'pot, flowerpot',
+ "potter's wheel",
+ 'power drill',
+ 'prayer rug, prayer mat',
+ 'printer',
+ 'prison, prison house',
+ 'projectile, missile',
+ 'projector',
+ 'puck, hockey puck',
+ 'punching bag, punch bag, punching ball, punchball',
+ 'purse',
+ 'quill, quill pen',
+ 'quilt, comforter, comfort, puff',
+ 'racer, race car, racing car',
+ 'racket, racquet',
+ 'radiator',
+ 'radio, wireless',
+ 'radio telescope, radio reflector',
+ 'rain barrel',
+ 'recreational vehicle, RV, R.V.',
+ 'reel',
+ 'reflex camera',
+ 'refrigerator, icebox',
+ 'remote control, remote',
+ 'restaurant, eating house, eating place, eatery',
+ 'revolver, six-gun, six-shooter',
+ 'rifle',
+ 'rocking chair, rocker',
+ 'rotisserie',
+ 'rubber eraser, rubber, pencil eraser',
+ 'rugby ball',
+ 'rule, ruler',
+ 'running shoe',
+ 'safe',
+ 'safety pin',
+ 'saltshaker, salt shaker',
+ 'sandal',
+ 'sarong',
+ 'sax, saxophone',
+ 'scabbard',
+ 'scale, weighing machine',
+ 'school bus',
+ 'schooner',
+ 'scoreboard',
+ 'screen, CRT screen',
+ 'screw',
+ 'screwdriver',
+ 'seat belt, seatbelt',
+ 'sewing machine',
+ 'shield, buckler',
+ 'shoe shop, shoe-shop, shoe store',
+ 'shoji',
+ 'shopping basket',
+ 'shopping cart',
+ 'shovel',
+ 'shower cap',
+ 'shower curtain',
+ 'ski',
+ 'ski mask',
+ 'sleeping bag',
+ 'slide rule, slipstick',
+ 'sliding door',
+ 'slot, one-armed bandit',
+ 'snorkel',
+ 'snowmobile',
+ 'snowplow, snowplough',
+ 'soap dispenser',
+ 'soccer ball',
+ 'sock',
+ 'solar dish, solar collector, solar furnace',
+ 'sombrero',
+ 'soup bowl',
+ 'space bar',
+ 'space heater',
+ 'space shuttle',
+ 'spatula',
+ 'speedboat',
+ "spider web, spider's web",
+ 'spindle',
+ 'sports car, sport car',
+ 'spotlight, spot',
+ 'stage',
+ 'steam locomotive',
+ 'steel arch bridge',
+ 'steel drum',
+ 'stethoscope',
+ 'stole',
+ 'stone wall',
+ 'stopwatch, stop watch',
+ 'stove',
+ 'strainer',
+ 'streetcar, tram, tramcar, trolley, trolley car',
+ 'stretcher',
+ 'studio couch, day bed',
+ 'stupa, tope',
+ 'submarine, pigboat, sub, U-boat',
+ 'suit, suit of clothes',
+ 'sundial',
+ 'sunglass',
+ 'sunglasses, dark glasses, shades',
+ 'sunscreen, sunblock, sun blocker',
+ 'suspension bridge',
+ 'swab, swob, mop',
+ 'sweatshirt',
+ 'swimming trunks, bathing trunks',
+ 'swing',
+ 'switch, electric switch, electrical switch',
+ 'syringe',
+ 'table lamp',
+ 'tank, army tank, armored combat vehicle, armoured combat vehicle',
+ 'tape player',
+ 'teapot',
+ 'teddy, teddy bear',
+ 'television, television system',
+ 'tennis ball',
+ 'thatch, thatched roof',
+ 'theater curtain, theatre curtain',
+ 'thimble',
+ 'thresher, thrasher, threshing machine',
+ 'throne',
+ 'tile roof',
+ 'toaster',
+ 'tobacco shop, tobacconist shop, tobacconist',
+ 'toilet seat',
+ 'torch',
+ 'totem pole',
+ 'tow truck, tow car, wrecker',
+ 'toyshop',
+ 'tractor',
+ 'trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi', # noqa: E501
+ 'tray',
+ 'trench coat',
+ 'tricycle, trike, velocipede',
+ 'trimaran',
+ 'tripod',
+ 'triumphal arch',
+ 'trolleybus, trolley coach, trackless trolley',
+ 'trombone',
+ 'tub, vat',
+ 'turnstile',
+ 'typewriter keyboard',
+ 'umbrella',
+ 'unicycle, monocycle',
+ 'upright, upright piano',
+ 'vacuum, vacuum cleaner',
+ 'vase',
+ 'vault',
+ 'velvet',
+ 'vending machine',
+ 'vestment',
+ 'viaduct',
+ 'violin, fiddle',
+ 'volleyball',
+ 'waffle iron',
+ 'wall clock',
+ 'wallet, billfold, notecase, pocketbook',
+ 'wardrobe, closet, press',
+ 'warplane, military plane',
+ 'washbasin, handbasin, washbowl, lavabo, wash-hand basin',
+ 'washer, automatic washer, washing machine',
+ 'water bottle',
+ 'water jug',
+ 'water tower',
+ 'whiskey jug',
+ 'whistle',
+ 'wig',
+ 'window screen',
+ 'window shade',
+ 'Windsor tie',
+ 'wine bottle',
+ 'wing',
+ 'wok',
+ 'wooden spoon',
+ 'wool, woolen, woollen',
+ 'worm fence, snake fence, snake-rail fence, Virginia fence',
+ 'wreck',
+ 'yawl',
+ 'yurt',
+ 'web site, website, internet site, site',
+ 'comic book',
+ 'crossword puzzle, crossword',
+ 'street sign',
+ 'traffic light, traffic signal, stoplight',
+ 'book jacket, dust cover, dust jacket, dust wrapper',
+ 'menu',
+ 'plate',
+ 'guacamole',
+ 'consomme',
+ 'hot pot, hotpot',
+ 'trifle',
+ 'ice cream, icecream',
+ 'ice lolly, lolly, lollipop, popsicle',
+ 'French loaf',
+ 'bagel, beigel',
+ 'pretzel',
+ 'cheeseburger',
+ 'hotdog, hot dog, red hot',
+ 'mashed potato',
+ 'head cabbage',
+ 'broccoli',
+ 'cauliflower',
+ 'zucchini, courgette',
+ 'spaghetti squash',
+ 'acorn squash',
+ 'butternut squash',
+ 'cucumber, cuke',
+ 'artichoke, globe artichoke',
+ 'bell pepper',
+ 'cardoon',
+ 'mushroom',
+ 'Granny Smith',
+ 'strawberry',
+ 'orange',
+ 'lemon',
+ 'fig',
+ 'pineapple, ananas',
+ 'banana',
+ 'jackfruit, jak, jack',
+ 'custard apple',
+ 'pomegranate',
+ 'hay',
+ 'carbonara',
+ 'chocolate sauce, chocolate syrup',
+ 'dough',
+ 'meat loaf, meatloaf',
+ 'pizza, pizza pie',
+ 'potpie',
+ 'burrito',
+ 'red wine',
+ 'espresso',
+ 'cup',
+ 'eggnog',
+ 'alp',
+ 'bubble',
+ 'cliff, drop, drop-off',
+ 'coral reef',
+ 'geyser',
+ 'lakeside, lakeshore',
+ 'promontory, headland, head, foreland',
+ 'sandbar, sand bar',
+ 'seashore, coast, seacoast, sea-coast',
+ 'valley, vale',
+ 'volcano',
+ 'ballplayer, baseball player',
+ 'groom, bridegroom',
+ 'scuba diver',
+ 'rapeseed',
+ 'daisy',
+ "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum", # noqa: E501
+ 'corn',
+ 'acorn',
+ 'hip, rose hip, rosehip',
+ 'buckeye, horse chestnut, conker',
+ 'coral fungus',
+ 'agaric',
+ 'gyromitra',
+ 'stinkhorn, carrion fungus',
+ 'earthstar',
+ 'hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa', # noqa: E501
+ 'bolete',
+ 'ear, spike, capitulum',
+ 'toilet tissue, toilet paper, bathroom tissue')
+
+CIFAR10_CATEGORIES = ('airplane', 'automobile', 'bird', 'cat', 'deer', 'dog',
+ 'frog', 'horse', 'ship', 'truck')
+
+CIFAR100_CATEGORIES = (
+ 'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle',
+ 'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel',
+ 'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock',
+ 'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
+ 'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster',
+ 'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion',
+ 'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain',
+ 'mouse', 'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree',
+ 'pear', 'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy',
+ 'porcupine', 'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket',
+ 'rose', 'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail',
+ 'snake', 'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper',
+ 'table', 'tank', 'telephone', 'television', 'tiger', 'tractor', 'train',
+ 'trout', 'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf',
+ 'woman', 'worm')
+
+MNIST_CATEGORITES = ('0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
+ '5 - five', '6 - six', '7 - seven', '8 - eight',
+ '9 - nine')
+
+FASHIONMNIST_CATEGORITES = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress',
+ 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag',
+ 'Ankle boot')
+
+PLACES205_CATEGORIES = (
+ 'abbey', 'airport_terminal', 'alley', 'amphitheater', 'amusement_park',
+ 'aquarium', 'aqueduct', 'arch', 'art_gallery', 'art_studio',
+ 'assembly_line', 'attic', 'auditorium', 'apartment_building/outdoor',
+ 'badlands', 'ballroom', 'bamboo_forest', 'banquet_hall', 'bar',
+ 'baseball_field', 'basement', 'basilica', 'bayou', 'beauty_salon',
+ 'bedroom', 'boardwalk', 'boat_deck', 'bookstore', 'botanical_garden',
+ 'bowling_alley', 'boxing_ring', 'bridge', 'building_facade',
+ 'bus_interior', 'butchers_shop', 'butte', 'bakery/shop', 'cafeteria',
+ 'campsite', 'candy_store', 'canyon', 'castle', 'cemetery', 'chalet',
+ 'classroom', 'closet', 'clothing_store', 'coast', 'cockpit', 'coffee_shop',
+ 'conference_center', 'conference_room', 'construction_site', 'corn_field',
+ 'corridor', 'cottage_garden', 'courthouse', 'courtyard', 'creek',
+ 'crevasse', 'crosswalk', 'cathedral/outdoor', 'church/outdoor', 'dam',
+ 'dining_room', 'dock', 'dorm_room', 'driveway', 'desert/sand',
+ 'desert/vegetation', 'dinette/home', 'doorway/outdoor', 'engine_room',
+ 'excavation', 'fairway', 'fire_escape', 'fire_station', 'food_court',
+ 'forest_path', 'forest_road', 'formal_garden', 'fountain',
+ 'field/cultivated', 'field/wild', 'galley', 'game_room', 'garbage_dump',
+ 'gas_station', 'gift_shop', 'golf_course', 'harbor', 'herb_garden',
+ 'highway', 'home_office', 'hospital', 'hospital_room', 'hot_spring',
+ 'hotel_room', 'hotel/outdoor', 'ice_cream_parlor', 'iceberg', 'igloo',
+ 'islet', 'ice_skating_rink/outdoor', 'inn/outdoor', 'jail_cell', 'kasbah',
+ 'kindergarden_classroom', 'kitchen', 'kitchenette', 'laundromat',
+ 'lighthouse', 'living_room', 'lobby', 'locker_room', 'mansion', 'marsh',
+ 'martial_arts_gym', 'mausoleum', 'medina', 'motel', 'mountain',
+ 'mountain_snowy', 'music_studio', 'market/outdoor', 'monastery/outdoor',
+ 'museum/indoor', 'nursery', 'ocean', 'office', 'office_building',
+ 'orchard', 'pagoda', 'palace', 'pantry', 'parking_lot', 'parlor',
+ 'pasture', 'patio', 'pavilion', 'phone_booth', 'picnic_area', 'playground',
+ 'plaza', 'pond', 'pulpit', 'racecourse', 'raft', 'railroad_track',
+ 'rainforest', 'reception', 'residential_neighborhood', 'restaurant',
+ 'restaurant_kitchen', 'restaurant_patio', 'rice_paddy', 'river',
+ 'rock_arch', 'rope_bridge', 'ruin', 'runway', 'sandbar', 'schoolhouse',
+ 'sea_cliff', 'shed', 'shoe_shop', 'shopfront', 'shower', 'ski_resort',
+ 'ski_slope', 'sky', 'skyscraper', 'slum', 'snowfield', 'staircase',
+ 'supermarket', 'swamp', 'stadium/baseball', 'stadium/football',
+ 'stage/indoor', 'subway_station/platform', 'swimming_pool/outdoor',
+ 'television_studio', 'topiary_garden', 'tower', 'train_railway',
+ 'tree_farm', 'trench', 'temple/east_asia', 'temple/south_asia',
+ 'track/outdoor', 'train_station/platform', 'underwater/coral_reef',
+ 'valley', 'vegetable_garden', 'veranda', 'viaduct', 'volcano',
+ 'waiting_room', 'water_tower', 'watering_hole', 'wheat_field', 'wind_farm',
+ 'windmill', 'yard')
+
+OxfordIIITPet_CATEGORIES = (
+ 'Abyssinian', 'american_bulldog', 'american_pit_bull_terrier',
+ 'basset_hound', 'beagle', 'Bengal', 'Birman', 'Bombay', 'boxer',
+ 'British_Shorthair', 'chihuahua', 'Egyptian_Mau', 'english_cocker_spaniel',
+ 'english_setter', 'german_shorthaired', 'great_pyrenees', 'havanese',
+ 'japanese_chin', 'keeshond', 'leonberger', 'Maine_Coon',
+ 'miniature_pinscher', 'newfoundland', 'Persian', 'pomeranian', 'pug',
+ 'Ragdoll', 'Russian_Blue', 'saint_bernard', 'samoyed', 'scottish_terrier',
+ 'shiba_inu', 'Siamese', 'Sphynx', 'staffordshire_bull_terrier',
+ 'wheaten_terrier', 'yorkshire_terrier')
+
+DTD_CATEGORIES = ('banded', 'blotchy', 'braided', 'bubbly', 'bumpy',
+ 'chequered', 'cobwebbed', 'cracked', 'crosshatched',
+ 'crystalline', 'dotted', 'fibrous', 'flecked', 'freckled',
+ 'frilly', 'gauzy', 'grid', 'grooved', 'honeycombed',
+ 'interlaced', 'knitted', 'lacelike', 'lined', 'marbled',
+ 'matted', 'meshed', 'paisley', 'perforated', 'pitted',
+ 'pleated', 'polka-dotted', 'porous', 'potholed', 'scaly',
+ 'smeared', 'spiralled', 'sprinkled', 'stained', 'stratified',
+ 'striped', 'studded', 'swirly', 'veined', 'waffled', 'woven',
+ 'wrinkled', 'zigzagged')
+
+FGVCAIRCRAFT_CATEGORIES = (
+ '707-320', '727-200', '737-200', '737-300', '737-400', '737-500',
+ '737-600', '737-700', '737-800', '737-900', '747-100', '747-200',
+ '747-300', '747-400', '757-200', '757-300', '767-200', '767-300',
+ '767-400', '777-200', '777-300', 'A300B4', 'A310', 'A318', 'A319', 'A320',
+ 'A321', 'A330-200', 'A330-300', 'A340-200', 'A340-300', 'A340-500',
+ 'A340-600', 'A380', 'ATR-42', 'ATR-72', 'An-12', 'BAE 146-200',
+ 'BAE 146-300', 'BAE-125', 'Beechcraft 1900', 'Boeing 717', 'C-130', 'C-47',
+ 'CRJ-200', 'CRJ-700', 'CRJ-900', 'Cessna 172', 'Cessna 208', 'Cessna 525',
+ 'Cessna 560', 'Challenger 600', 'DC-10', 'DC-3', 'DC-6', 'DC-8', 'DC-9-30',
+ 'DH-82', 'DHC-1', 'DHC-6', 'DHC-8-100', 'DHC-8-300', 'DR-400',
+ 'Dornier 328', 'E-170', 'E-190', 'E-195', 'EMB-120', 'ERJ 135', 'ERJ 145',
+ 'Embraer Legacy 600', 'Eurofighter Typhoon', 'F-16A/B', 'F/A-18',
+ 'Falcon 2000', 'Falcon 900', 'Fokker 100', 'Fokker 50', 'Fokker 70',
+ 'Global Express', 'Gulfstream IV', 'Gulfstream V', 'Hawk T1', 'Il-76',
+ 'L-1011', 'MD-11', 'MD-80', 'MD-87', 'MD-90', 'Metroliner', 'Model B200',
+ 'PA-28', 'SR-20', 'Saab 2000', 'Saab 340', 'Spitfire', 'Tornado', 'Tu-134',
+ 'Tu-154', 'Yak-42')
+
+STANFORDCARS_CATEGORIES = (
+ 'AM General Hummer SUV 2000', 'Acura RL Sedan 2012', 'Acura TL Sedan 2012',
+ 'Acura TL Type-S 2008', 'Acura TSX Sedan 2012',
+ 'Acura Integra Type R 2001', 'Acura ZDX Hatchback 2012',
+ 'Aston Martin V8 Vantage Convertible 2012',
+ 'Aston Martin V8 Vantage Coupe 2012',
+ 'Aston Martin Virage Convertible 2012', 'Aston Martin Virage Coupe 2012',
+ 'Audi RS 4 Convertible 2008', 'Audi A5 Coupe 2012', 'Audi TTS Coupe 2012',
+ 'Audi R8 Coupe 2012', 'Audi V8 Sedan 1994', 'Audi 100 Sedan 1994',
+ 'Audi 100 Wagon 1994', 'Audi TT Hatchback 2011', 'Audi S6 Sedan 2011',
+ 'Audi S5 Convertible 2012', 'Audi S5 Coupe 2012', 'Audi S4 Sedan 2012',
+ 'Audi S4 Sedan 2007', 'Audi TT RS Coupe 2012',
+ 'BMW ActiveHybrid 5 Sedan 2012', 'BMW 1 Series Convertible 2012',
+ 'BMW 1 Series Coupe 2012', 'BMW 3 Series Sedan 2012',
+ 'BMW 3 Series Wagon 2012', 'BMW 6 Series Convertible 2007',
+ 'BMW X5 SUV 2007', 'BMW X6 SUV 2012', 'BMW M3 Coupe 2012',
+ 'BMW M5 Sedan 2010', 'BMW M6 Convertible 2010', 'BMW X3 SUV 2012',
+ 'BMW Z4 Convertible 2012',
+ 'Bentley Continental Supersports Conv. Convertible 2012',
+ 'Bentley Arnage Sedan 2009', 'Bentley Mulsanne Sedan 2011',
+ 'Bentley Continental GT Coupe 2012', 'Bentley Continental GT Coupe 2007',
+ 'Bentley Continental Flying Spur Sedan 2007',
+ 'Bugatti Veyron 16.4 Convertible 2009', 'Bugatti Veyron 16.4 Coupe 2009',
+ 'Buick Regal GS 2012', 'Buick Rainier SUV 2007', 'Buick Verano Sedan 2012',
+ 'Buick Enclave SUV 2012', 'Cadillac CTS-V Sedan 2012',
+ 'Cadillac SRX SUV 2012', 'Cadillac Escalade EXT Crew Cab 2007',
+ 'Chevrolet Silverado 1500 Hybrid Crew Cab 2012',
+ 'Chevrolet Corvette Convertible 2012', 'Chevrolet Corvette ZR1 2012',
+ 'Chevrolet Corvette Ron Fellows Edition Z06 2007',
+ 'Chevrolet Traverse SUV 2012', 'Chevrolet Camaro Convertible 2012',
+ 'Chevrolet HHR SS 2010', 'Chevrolet Impala Sedan 2007',
+ 'Chevrolet Tahoe Hybrid SUV 2012', 'Chevrolet Sonic Sedan 2012',
+ 'Chevrolet Express Cargo Van 2007', 'Chevrolet Avalanche Crew Cab 2012',
+ 'Chevrolet Cobalt SS 2010', 'Chevrolet Malibu Hybrid Sedan 2010',
+ 'Chevrolet TrailBlazer SS 2009',
+ 'Chevrolet Silverado 2500HD Regular Cab 2012',
+ 'Chevrolet Silverado 1500 Classic Extended Cab 2007',
+ 'Chevrolet Express Van 2007', 'Chevrolet Monte Carlo Coupe 2007',
+ 'Chevrolet Malibu Sedan 2007',
+ 'Chevrolet Silverado 1500 Extended Cab 2012',
+ 'Chevrolet Silverado 1500 Regular Cab 2012', 'Chrysler Aspen SUV 2009',
+ 'Chrysler Sebring Convertible 2010',
+ 'Chrysler Town and Country Minivan 2012', 'Chrysler 300 SRT-8 2010',
+ 'Chrysler Crossfire Convertible 2008',
+ 'Chrysler PT Cruiser Convertible 2008', 'Daewoo Nubira Wagon 2002',
+ 'Dodge Caliber Wagon 2012', 'Dodge Caliber Wagon 2007',
+ 'Dodge Caravan Minivan 1997', 'Dodge Ram Pickup 3500 Crew Cab 2010',
+ 'Dodge Ram Pickup 3500 Quad Cab 2009', 'Dodge Sprinter Cargo Van 2009',
+ 'Dodge Journey SUV 2012', 'Dodge Dakota Crew Cab 2010',
+ 'Dodge Dakota Club Cab 2007', 'Dodge Magnum Wagon 2008',
+ 'Dodge Challenger SRT8 2011', 'Dodge Durango SUV 2012',
+ 'Dodge Durango SUV 2007', 'Dodge Charger Sedan 2012',
+ 'Dodge Charger SRT-8 2009', 'Eagle Talon Hatchback 1998',
+ 'FIAT 500 Abarth 2012', 'FIAT 500 Convertible 2012',
+ 'Ferrari FF Coupe 2012', 'Ferrari California Convertible 2012',
+ 'Ferrari 458 Italia Convertible 2012', 'Ferrari 458 Italia Coupe 2012',
+ 'Fisker Karma Sedan 2012', 'Ford F-450 Super Duty Crew Cab 2012',
+ 'Ford Mustang Convertible 2007', 'Ford Freestar Minivan 2007',
+ 'Ford Expedition EL SUV 2009', 'Ford Edge SUV 2012',
+ 'Ford Ranger SuperCab 2011', 'Ford GT Coupe 2006',
+ 'Ford F-150 Regular Cab 2012', 'Ford F-150 Regular Cab 2007',
+ 'Ford Focus Sedan 2007', 'Ford E-Series Wagon Van 2012',
+ 'Ford Fiesta Sedan 2012', 'GMC Terrain SUV 2012', 'GMC Savana Van 2012',
+ 'GMC Yukon Hybrid SUV 2012', 'GMC Acadia SUV 2012',
+ 'GMC Canyon Extended Cab 2012', 'Geo Metro Convertible 1993',
+ 'HUMMER H3T Crew Cab 2010', 'HUMMER H2 SUT Crew Cab 2009',
+ 'Honda Odyssey Minivan 2012', 'Honda Odyssey Minivan 2007',
+ 'Honda Accord Coupe 2012', 'Honda Accord Sedan 2012',
+ 'Hyundai Veloster Hatchback 2012', 'Hyundai Santa Fe SUV 2012',
+ 'Hyundai Tucson SUV 2012', 'Hyundai Veracruz SUV 2012',
+ 'Hyundai Sonata Hybrid Sedan 2012', 'Hyundai Elantra Sedan 2007',
+ 'Hyundai Accent Sedan 2012', 'Hyundai Genesis Sedan 2012',
+ 'Hyundai Sonata Sedan 2012', 'Hyundai Elantra Touring Hatchback 2012',
+ 'Hyundai Azera Sedan 2012', 'Infiniti G Coupe IPL 2012',
+ 'Infiniti QX56 SUV 2011', 'Isuzu Ascender SUV 2008', 'Jaguar XK XKR 2012',
+ 'Jeep Patriot SUV 2012', 'Jeep Wrangler SUV 2012', 'Jeep Liberty SUV 2012',
+ 'Jeep Grand Cherokee SUV 2012', 'Jeep Compass SUV 2012',
+ 'Lamborghini Reventon Coupe 2008', 'Lamborghini Aventador Coupe 2012',
+ 'Lamborghini Gallardo LP 570-4 Superleggera 2012',
+ 'Lamborghini Diablo Coupe 2001', 'Land Rover Range Rover SUV 2012',
+ 'Land Rover LR2 SUV 2012', 'Lincoln Town Car Sedan 2011',
+ 'MINI Cooper Roadster Convertible 2012',
+ 'Maybach Landaulet Convertible 2012', 'Mazda Tribute SUV 2011',
+ 'McLaren MP4-12C Coupe 2012', 'Mercedes-Benz 300-Class Convertible 1993',
+ 'Mercedes-Benz C-Class Sedan 2012', 'Mercedes-Benz SL-Class Coupe 2009',
+ 'Mercedes-Benz E-Class Sedan 2012', 'Mercedes-Benz S-Class Sedan 2012',
+ 'Mercedes-Benz Sprinter Van 2012', 'Mitsubishi Lancer Sedan 2012',
+ 'Nissan Leaf Hatchback 2012', 'Nissan NV Passenger Van 2012',
+ 'Nissan Juke Hatchback 2012', 'Nissan 240SX Coupe 1998',
+ 'Plymouth Neon Coupe 1999', 'Porsche Panamera Sedan 2012',
+ 'Ram C/V Cargo Van Minivan 2012',
+ 'Rolls-Royce Phantom Drophead Coupe Convertible 2012',
+ 'Rolls-Royce Ghost Sedan 2012', 'Rolls-Royce Phantom Sedan 2012',
+ 'Scion xD Hatchback 2012', 'Spyker C8 Convertible 2009',
+ 'Spyker C8 Coupe 2009', 'Suzuki Aerio Sedan 2007',
+ 'Suzuki Kizashi Sedan 2012', 'Suzuki SX4 Hatchback 2012',
+ 'Suzuki SX4 Sedan 2012', 'Tesla Model S Sedan 2012',
+ 'Toyota Sequoia SUV 2012', 'Toyota Camry Sedan 2012',
+ 'Toyota Corolla Sedan 2012', 'Toyota 4Runner SUV 2012',
+ 'Volkswagen Golf Hatchback 2012', 'Volkswagen Golf Hatchback 1991',
+ 'Volkswagen Beetle Hatchback 2012', 'Volvo C30 Hatchback 2012',
+ 'Volvo 240 Sedan 1993', 'Volvo XC90 SUV 2007',
+ 'smart fortwo Convertible 2012')
+
+SUN397_CATEGORIES = (
+ 'abbey', 'airplane_cabin', 'airport_terminal', 'alley', 'amphitheater',
+ 'amusement_arcade', 'amusement_park', 'anechoic_chamber',
+ 'apartment_building_outdoor', 'apse_indoor', 'aquarium', 'aqueduct',
+ 'arch', 'archive', 'arrival_gate_outdoor', 'art_gallery', 'art_school',
+ 'art_studio', 'assembly_line', 'athletic_field_outdoor', 'atrium_public',
+ 'attic', 'auditorium', 'auto_factory', 'badlands',
+ 'badminton_court_indoor', 'baggage_claim', 'bakery_shop',
+ 'balcony_exterior', 'balcony_interior', 'ball_pit', 'ballroom',
+ 'bamboo_forest', 'banquet_hall', 'bar', 'barn', 'barndoor',
+ 'baseball_field', 'basement', 'basilica', 'basketball_court_outdoor',
+ 'bathroom', 'batters_box', 'bayou', 'bazaar_indoor', 'bazaar_outdoor',
+ 'beach', 'beauty_salon', 'bedroom', 'berth', 'biology_laboratory',
+ 'bistro_indoor', 'boardwalk', 'boat_deck', 'boathouse', 'bookstore',
+ 'booth_indoor', 'botanical_garden', 'bow_window_indoor',
+ 'bow_window_outdoor', 'bowling_alley', 'boxing_ring', 'brewery_indoor',
+ 'bridge', 'building_facade', 'bullring', 'burial_chamber', 'bus_interior',
+ 'butchers_shop', 'butte', 'cabin_outdoor', 'cafeteria', 'campsite',
+ 'campus', 'canal_natural', 'canal_urban', 'candy_store', 'canyon',
+ 'car_interior_backseat', 'car_interior_frontseat', 'carrousel',
+ 'casino_indoor', 'castle', 'catacomb', 'cathedral_indoor',
+ 'cathedral_outdoor', 'cavern_indoor', 'cemetery', 'chalet',
+ 'cheese_factory', 'chemistry_lab', 'chicken_coop_indoor',
+ 'chicken_coop_outdoor', 'childs_room', 'church_indoor', 'church_outdoor',
+ 'classroom', 'clean_room', 'cliff', 'cloister_indoor', 'closet',
+ 'clothing_store', 'coast', 'cockpit', 'coffee_shop', 'computer_room',
+ 'conference_center', 'conference_room', 'construction_site',
+ 'control_room', 'control_tower_outdoor', 'corn_field', 'corral',
+ 'corridor', 'cottage_garden', 'courthouse', 'courtroom', 'courtyard',
+ 'covered_bridge_exterior', 'creek', 'crevasse', 'crosswalk',
+ 'cubicle_office', 'dam', 'delicatessen', 'dentists_office', 'desert_sand',
+ 'desert_vegetation', 'diner_indoor', 'diner_outdoor', 'dinette_home',
+ 'dinette_vehicle', 'dining_car', 'dining_room', 'discotheque', 'dock',
+ 'doorway_outdoor', 'dorm_room', 'driveway', 'driving_range_outdoor',
+ 'drugstore', 'electrical_substation', 'elevator_door', 'elevator_interior',
+ 'elevator_shaft', 'engine_room', 'escalator_indoor', 'excavation',
+ 'factory_indoor', 'fairway', 'fastfood_restaurant', 'field_cultivated',
+ 'field_wild', 'fire_escape', 'fire_station', 'firing_range_indoor',
+ 'fishpond', 'florist_shop_indoor', 'food_court', 'forest_broadleaf',
+ 'forest_needleleaf', 'forest_path', 'forest_road', 'formal_garden',
+ 'fountain', 'galley', 'game_room', 'garage_indoor', 'garbage_dump',
+ 'gas_station', 'gazebo_exterior', 'general_store_indoor',
+ 'general_store_outdoor', 'gift_shop', 'golf_course', 'greenhouse_indoor',
+ 'greenhouse_outdoor', 'gymnasium_indoor', 'hangar_indoor',
+ 'hangar_outdoor', 'harbor', 'hayfield', 'heliport', 'herb_garden',
+ 'highway', 'hill', 'home_office', 'hospital', 'hospital_room',
+ 'hot_spring', 'hot_tub_outdoor', 'hotel_outdoor', 'hotel_room', 'house',
+ 'hunting_lodge_outdoor', 'ice_cream_parlor', 'ice_floe', 'ice_shelf',
+ 'ice_skating_rink_indoor', 'ice_skating_rink_outdoor', 'iceberg', 'igloo',
+ 'industrial_area', 'inn_outdoor', 'islet', 'jacuzzi_indoor', 'jail_indoor',
+ 'jail_cell', 'jewelry_shop', 'kasbah', 'kennel_indoor', 'kennel_outdoor',
+ 'kindergarden_classroom', 'kitchen', 'kitchenette', 'labyrinth_outdoor',
+ 'lake_natural', 'landfill', 'landing_deck', 'laundromat', 'lecture_room',
+ 'library_indoor', 'library_outdoor', 'lido_deck_outdoor', 'lift_bridge',
+ 'lighthouse', 'limousine_interior', 'living_room', 'lobby', 'lock_chamber',
+ 'locker_room', 'mansion', 'manufactured_home', 'market_indoor',
+ 'market_outdoor', 'marsh', 'martial_arts_gym', 'mausoleum', 'medina',
+ 'moat_water', 'monastery_outdoor', 'mosque_indoor', 'mosque_outdoor',
+ 'motel', 'mountain', 'mountain_snowy', 'movie_theater_indoor',
+ 'museum_indoor', 'music_store', 'music_studio',
+ 'nuclear_power_plant_outdoor', 'nursery', 'oast_house',
+ 'observatory_outdoor', 'ocean', 'office', 'office_building',
+ 'oil_refinery_outdoor', 'oilrig', 'operating_room', 'orchard',
+ 'outhouse_outdoor', 'pagoda', 'palace', 'pantry', 'park',
+ 'parking_garage_indoor', 'parking_garage_outdoor', 'parking_lot', 'parlor',
+ 'pasture', 'patio', 'pavilion', 'pharmacy', 'phone_booth',
+ 'physics_laboratory', 'picnic_area', 'pilothouse_indoor',
+ 'planetarium_outdoor', 'playground', 'playroom', 'plaza', 'podium_indoor',
+ 'podium_outdoor', 'pond', 'poolroom_establishment', 'poolroom_home',
+ 'power_plant_outdoor', 'promenade_deck', 'pub_indoor', 'pulpit',
+ 'putting_green', 'racecourse', 'raceway', 'raft', 'railroad_track',
+ 'rainforest', 'reception', 'recreation_room', 'residential_neighborhood',
+ 'restaurant', 'restaurant_kitchen', 'restaurant_patio', 'rice_paddy',
+ 'riding_arena', 'river', 'rock_arch', 'rope_bridge', 'ruin', 'runway',
+ 'sandbar', 'sandbox', 'sauna', 'schoolhouse', 'sea_cliff', 'server_room',
+ 'shed', 'shoe_shop', 'shopfront', 'shopping_mall_indoor', 'shower',
+ 'skatepark', 'ski_lodge', 'ski_resort', 'ski_slope', 'sky', 'skyscraper',
+ 'slum', 'snowfield', 'squash_court', 'stable', 'stadium_baseball',
+ 'stadium_football', 'stage_indoor', 'staircase', 'street',
+ 'subway_interior', 'subway_station_platform', 'supermarket', 'sushi_bar',
+ 'swamp', 'swimming_pool_indoor', 'swimming_pool_outdoor',
+ 'synagogue_indoor', 'synagogue_outdoor', 'television_studio',
+ 'temple_east_asia', 'temple_south_asia', 'tennis_court_indoor',
+ 'tennis_court_outdoor', 'tent_outdoor', 'theater_indoor_procenium',
+ 'theater_indoor_seats', 'thriftshop', 'throne_room', 'ticket_booth',
+ 'toll_plaza', 'topiary_garden', 'tower', 'toyshop', 'track_outdoor',
+ 'train_railway', 'train_station_platform', 'tree_farm', 'tree_house',
+ 'trench', 'underwater_coral_reef', 'utility_room', 'valley',
+ 'van_interior', 'vegetable_garden', 'veranda', 'veterinarians_office',
+ 'viaduct', 'videostore', 'village', 'vineyard', 'volcano',
+ 'volleyball_court_indoor', 'volleyball_court_outdoor', 'waiting_room',
+ 'warehouse_indoor', 'water_tower', 'waterfall_block', 'waterfall_fan',
+ 'waterfall_plunge', 'watering_hole', 'wave', 'wet_bar', 'wheat_field',
+ 'wind_farm', 'windmill', 'wine_cellar_barrel_storage',
+ 'wine_cellar_bottle_storage', 'wrestling_ring_indoor', 'yard',
+ 'youth_hostel')
+
+CALTECH101_CATEGORIES = (
+ 'BACKGROUND_Google', 'Faces', 'Faces_easy', 'Leopards', 'Motorbikes',
+ 'accordion', 'airplanes', 'anchor', 'ant', 'barrel', 'bass', 'beaver',
+ 'binocular', 'bonsai', 'brain', 'brontosaurus', 'buddha', 'butterfly',
+ 'camera', 'cannon', 'car_side', 'ceiling_fan', 'cellphone', 'chair',
+ 'chandelier', 'cougar_body', 'cougar_face', 'crab', 'crayfish',
+ 'crocodile', 'crocodile_head', 'cup', 'dalmatian', 'dollar_bill',
+ 'dolphin', 'dragonfly', 'electric_guitar', 'elephant', 'emu', 'euphonium',
+ 'ewer', 'ferry', 'flamingo', 'flamingo_head', 'garfield', 'gerenuk',
+ 'gramophone', 'grand_piano', 'hawksbill', 'headphone', 'hedgehog',
+ 'helicopter', 'ibis', 'inline_skate', 'joshua_tree', 'kangaroo', 'ketch',
+ 'lamp', 'laptop', 'llama', 'lobster', 'lotus', 'mandolin', 'mayfly',
+ 'menorah', 'metronome', 'minaret', 'nautilus', 'octopus', 'okapi',
+ 'pagoda', 'panda', 'pigeon', 'pizza', 'platypus', 'pyramid', 'revolver',
+ 'rhino', 'rooster', 'saxophone', 'schooner', 'scissors', 'scorpion',
+ 'sea_horse', 'snoopy', 'soccer_ball', 'stapler', 'starfish', 'stegosaurus',
+ 'stop_sign', 'strawberry', 'sunflower', 'tick', 'trilobite', 'umbrella',
+ 'watch', 'water_lilly', 'wheelchair', 'wild_cat', 'windsor_chair',
+ 'wrench', 'yin_yang')
+
+FOOD101_CATEGORIES = (
+ 'apple_pie', 'baby_back_ribs', 'baklava', 'beef_carpaccio', 'beef_tartare',
+ 'beet_salad', 'beignets', 'bibimbap', 'bread_pudding', 'breakfast_burrito',
+ 'bruschetta', 'caesar_salad', 'cannoli', 'caprese_salad', 'carrot_cake',
+ 'ceviche', 'cheesecake', 'cheese_plate', 'chicken_curry',
+ 'chicken_quesadilla', 'chicken_wings', 'chocolate_cake',
+ 'chocolate_mousse', 'churros', 'clam_chowder', 'club_sandwich',
+ 'crab_cakes', 'creme_brulee', 'croque_madame', 'cup_cakes', 'deviled_eggs',
+ 'donuts', 'dumplings', 'edamame', 'eggs_benedict', 'escargots', 'falafel',
+ 'filet_mignon', 'fish_and_chips', 'foie_gras', 'french_fries',
+ 'french_onion_soup', 'french_toast', 'fried_calamari', 'fried_rice',
+ 'frozen_yogurt', 'garlic_bread', 'gnocchi', 'greek_salad',
+ 'grilled_cheese_sandwich', 'grilled_salmon', 'guacamole', 'gyoza',
+ 'hamburger', 'hot_and_sour_soup', 'hot_dog', 'huevos_rancheros', 'hummus',
+ 'ice_cream', 'lasagna', 'lobster_bisque', 'lobster_roll_sandwich',
+ 'macaroni_and_cheese', 'macarons', 'miso_soup', 'mussels', 'nachos',
+ 'omelette', 'onion_rings', 'oysters', 'pad_thai', 'paella', 'pancakes',
+ 'panna_cotta', 'peking_duck', 'pho', 'pizza', 'pork_chop', 'poutine',
+ 'prime_rib', 'pulled_pork_sandwich', 'ramen', 'ravioli', 'red_velvet_cake',
+ 'risotto', 'samosa', 'sashimi', 'scallops', 'seaweed_salad',
+ 'shrimp_and_grits', 'spaghetti_bolognese', 'spaghetti_carbonara',
+ 'spring_rolls', 'steak', 'strawberry_shortcake', 'sushi', 'tacos',
+ 'takoyaki', 'tiramisu', 'tuna_tartare', 'waffles')
+
+CIFAR100_CATEGORIES_CN = (
+ '苹果', '水族馆鱼', '婴儿', '熊', '河狸', '床', '蜜蜂', '甲虫', '自行车', '瓶子', '碗', '小男孩',
+ '桥', '公共汽车', '蝴蝶', '骆驼', '易拉罐', '城堡', '毛毛虫', '牛', '椅子', '猩猩', '钟', '白云',
+ '蟑螂', '沙发', '螃蟹', '鳄鱼', '杯子', '恐龙', '海豚', '大象', '比目鱼', '森林', '狐狸', '小女孩',
+ '仓鼠', '屋子', '袋鼠', '键盘', '台灯', '割草机', '猎豹', '狮子', '蜥蜴', '龙虾', '男人', '枫树',
+ '摩托车', '山', '老鼠', '蘑菇', '橡树', '橙子橘子', '兰花', '水獭', '棕榈树', '梨', '皮卡车', '松树',
+ '田野', '盘子', '罂粟', '豪猪', '负鼠', '兔子', '浣熊', '鳐鱼', '公路', '火箭', '玫瑰', '大海',
+ '海豹', '鲨鱼', '尖嘴小鼠', '臭鼬', '摩天大楼', '蜗牛', '蛇', '蜘蛛', '松鼠', '电车', '向日葵', '甜椒',
+ '桌子', '坦克', '电话', '电视', '老虎', '拖拉机', '火车', '鳟鱼', '郁金香', '乌龟', '衣柜', '鲸鱼',
+ '柳树', '狼', '女人', '蠕虫')
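
For reference, these category tuples are ordered so that an integer class index maps directly to a human-readable name. A minimal lookup sketch follows; the predicted indices are made up and the import path assumes the `categories` module added above:

```python
# Minimal sketch: map predicted class indices to names using the tuples above.
from mmpretrain.datasets.categories import FOOD101_CATEGORIES, SUN397_CATEGORIES

pred_indices = [3, 57, 100]  # hypothetical classifier outputs for Food-101
print([FOOD101_CATEGORIES[i] for i in pred_indices])
print(len(SUN397_CATEGORIES))  # number of SUN397 scene categories
```
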
diff --git a/mmpretrain/datasets/cifar.py b/mmpretrain/datasets/cifar.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a011daee0d74e6b06613106f7587b8ad8a7ed90
--- /dev/null
+++ b/mmpretrain/datasets/cifar.py
@@ -0,0 +1,210 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import pickle
+from typing import List, Optional
+
+import mmengine.dist as dist
+import numpy as np
+from mmengine.fileio import (LocalBackend, exists, get, get_file_backend,
+ join_path)
+from mmengine.logging import MMLogger
+
+from mmpretrain.registry import DATASETS
+from .base_dataset import BaseDataset
+from .categories import CIFAR10_CATEGORIES, CIFAR100_CATEGORIES
+from .utils import check_md5, download_and_extract_archive
+
+
+@DATASETS.register_module()
+class CIFAR10(BaseDataset):
+ """`CIFAR10 `_ Dataset.
+
+ This implementation is modified from
+ https://github.com/pytorch/vision/blob/master/torchvision/datasets/cifar.py
+
+ Args:
+ data_root (str): The root directory of the CIFAR Dataset.
+ split (str, optional): The dataset split, supports "train" and "test".
+            Defaults to "train".
+ metainfo (dict, optional): Meta information for dataset, such as
+ categories information. Defaults to None.
+        download (bool): Whether to download the dataset if it does not exist.
+ Defaults to True.
+ **kwargs: Other keyword arguments in :class:`BaseDataset`.
+ """ # noqa: E501
+
+ base_folder = 'cifar-10-batches-py'
+ url = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
+ filename = 'cifar-10-python.tar.gz'
+ tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
+ train_list = [
+ ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
+ ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
+ ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
+ ['data_batch_4', '634d18415352ddfa80567beed471001a'],
+ ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
+ ]
+
+ test_list = [
+ ['test_batch', '40351d587109b95175f43aff81a1287e'],
+ ]
+ meta = {
+ 'filename': 'batches.meta',
+ 'key': 'label_names',
+ 'md5': '5ff9c542aee3614f3951f8cda6e48888',
+ }
+ METAINFO = {'classes': CIFAR10_CATEGORIES}
+
+ def __init__(self,
+ data_root: str = '',
+ split: str = 'train',
+ metainfo: Optional[dict] = None,
+ download: bool = True,
+ data_prefix: str = '',
+ test_mode: bool = False,
+ **kwargs):
+
+ splits = ['train', 'test']
+ assert split in splits, \
+ f"The split must be one of {splits}, but get '{split}'"
+ self.split = split
+
+ # To handle the BC-breaking
+ if split == 'train' and test_mode:
+ logger = MMLogger.get_current_instance()
+ logger.warning('split="train" but test_mode=True. '
+ 'The training set will be used.')
+
+ if not data_root and not data_prefix:
+            raise RuntimeError('Please set ``data_root`` to '
+                               'specify the dataset path.')
+
+ self.download = download
+ super().__init__(
+ # The CIFAR dataset doesn't need specify annotation file
+ ann_file='',
+ metainfo=metainfo,
+ data_root=data_root,
+ data_prefix=dict(root=data_prefix),
+ test_mode=test_mode,
+ **kwargs)
+
+ def load_data_list(self):
+ """Load images and ground truth labels."""
+ root = self.data_prefix['root']
+ backend = get_file_backend(root, enable_singleton=True)
+
+ if dist.is_main_process() and not self._check_integrity():
+ if not isinstance(backend, LocalBackend):
+                raise RuntimeError(f'The dataset on {root} failed the '
+                                   'integrity check, please prepare it '
+                                   'manually.')
+
+ if self.download:
+ download_and_extract_archive(
+ self.url, root, filename=self.filename, md5=self.tgz_md5)
+ else:
+ raise RuntimeError(
+ f'Cannot find {self.__class__.__name__} dataset in '
+ f"{self.data_prefix['root']}, you can specify "
+ '`download=True` to download automatically.')
+
+ dist.barrier()
+ assert self._check_integrity(), \
+ 'Download failed or shared storage is unavailable. Please ' \
+ f'download the dataset manually through {self.url}.'
+
+ if self.split == 'train':
+ downloaded_list = self.train_list
+ else:
+ downloaded_list = self.test_list
+
+ imgs = []
+ gt_labels = []
+
+ # load the picked numpy arrays
+ for file_name, _ in downloaded_list:
+ file_path = join_path(root, self.base_folder, file_name)
+ entry = pickle.loads(get(file_path), encoding='latin1')
+ imgs.append(entry['data'])
+ if 'labels' in entry:
+ gt_labels.extend(entry['labels'])
+ else:
+ gt_labels.extend(entry['fine_labels'])
+
+ imgs = np.vstack(imgs).reshape(-1, 3, 32, 32)
+ imgs = imgs.transpose((0, 2, 3, 1)) # convert to HWC
+
+ if self.CLASSES is None:
+ # The metainfo in the file has the lowest priority, therefore
+ # we only need to load it if classes is not specified.
+ self._load_meta()
+
+ data_list = []
+ for img, gt_label in zip(imgs, gt_labels):
+ info = {'img': img, 'gt_label': int(gt_label)}
+ data_list.append(info)
+ return data_list
+
+ def _load_meta(self):
+ """Load categories information from metafile."""
+ root = self.data_prefix['root']
+
+ path = join_path(root, self.base_folder, self.meta['filename'])
+ md5 = self.meta.get('md5', None)
+ if not exists(path) or (md5 is not None and not check_md5(path, md5)):
+ raise RuntimeError(
+ 'Dataset metadata file not found or corrupted.' +
+ ' You can use `download=True` to download it')
+ data = pickle.loads(get(path), encoding='latin1')
+ self._metainfo.setdefault('classes', data[self.meta['key']])
+
+ def _check_integrity(self):
+ """Check the integrity of data files."""
+ root = self.data_prefix['root']
+
+ for fentry in (self.train_list + self.test_list):
+ filename, md5 = fentry[0], fentry[1]
+ fpath = join_path(root, self.base_folder, filename)
+ if not exists(fpath):
+ return False
+ if md5 is not None and not check_md5(fpath, md5):
+ return False
+ return True
+
+ def extra_repr(self) -> List[str]:
+ """The extra repr information of the dataset."""
+ body = [f"Prefix of data: \t{self.data_prefix['root']}"]
+ return body
+
+
+@DATASETS.register_module()
+class CIFAR100(CIFAR10):
+ """`CIFAR100 `_ Dataset.
+
+ Args:
+ data_root (str): The root directory of the CIFAR Dataset.
+ split (str, optional): The dataset split, supports "train" and "test".
+            Defaults to "train".
+ metainfo (dict, optional): Meta information for dataset, such as
+ categories information. Defaults to None.
+        download (bool): Whether to download the dataset if it does not exist.
+ Defaults to True.
+ **kwargs: Other keyword arguments in :class:`BaseDataset`.
+ """
+
+ base_folder = 'cifar-100-python'
+ url = 'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz'
+ filename = 'cifar-100-python.tar.gz'
+ tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
+ train_list = [
+ ['train', '16019d7e3df5f24257cddd939b257f8d'],
+ ]
+
+ test_list = [
+ ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
+ ]
+ meta = {
+ 'filename': 'meta',
+ 'key': 'fine_label_names',
+ 'md5': '7973b15100ade9c7d40fb424638fde48',
+ }
+ METAINFO = {'classes': CIFAR100_CATEGORIES}
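
As a quick sanity check, a usage sketch for the classes above. The `data/cifar10` path is a hypothetical placeholder, the pipeline is left empty, and the import assumes the classes are exported from `mmpretrain.datasets` as in upstream mmpretrain; with `download=True` the archive is fetched and verified against the MD5 sums in `train_list`/`test_list`:

```python
# Minimal sketch, assuming CIFAR10 is exported from mmpretrain.datasets.
from mmpretrain.datasets import CIFAR10

train_set = CIFAR10(data_root='data/cifar10', split='train', download=True)
test_set = CIFAR10(data_root='data/cifar10', split='test', download=True)

print(len(train_set), len(test_set))  # 50000 / 10000 samples
sample = train_set[0]                 # with an empty pipeline: the raw info dict
print(sample['img'].shape, sample['gt_label'])  # (32, 32, 3) HWC image, int label
```
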
diff --git a/mmpretrain/datasets/coco_caption.py b/mmpretrain/datasets/coco_caption.py
new file mode 100644
index 0000000000000000000000000000000000000000..541cda80398f7fcc7d3304d3d9f43155685ebe57
--- /dev/null
+++ b/mmpretrain/datasets/coco_caption.py
@@ -0,0 +1,42 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from pathlib import Path
+from typing import List
+
+import mmengine
+from mmengine.dataset import BaseDataset
+from mmengine.fileio import get_file_backend
+
+from mmpretrain.registry import DATASETS
+
+
+@DATASETS.register_module()
+class COCOCaption(BaseDataset):
+ """COCO Caption dataset.
+
+ Args:
+ data_root (str): The root directory for ``data_prefix`` and
+            ``ann_file``.
+ ann_file (str): Annotation file path.
+ data_prefix (dict): Prefix for data field. Defaults to
+ ``dict(img_path='')``.
+ pipeline (Sequence): Processing pipeline. Defaults to an empty tuple.
+ **kwargs: Other keyword arguments in :class:`BaseDataset`.
+ """
+
+ def load_data_list(self) -> List[dict]:
+ """Load data list."""
+ img_prefix = self.data_prefix['img_path']
+ annotations = mmengine.load(self.ann_file)
+ file_backend = get_file_backend(img_prefix)
+
+ data_list = []
+ for ann in annotations:
+ data_info = {
+ 'image_id': Path(ann['image']).stem.split('_')[-1],
+ 'img_path': file_backend.join_path(img_prefix, ann['image']),
+ 'gt_caption': ann['caption'],
+ }
+
+ data_list.append(data_info)
+
+ return data_list
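
A sketch of how `COCOCaption` consumes its annotation file: each entry only needs the `image` and `caption` keys read in `load_data_list`. The file names and paths below are hypothetical:

```python
# Hypothetical annotation entry (Karpathy-style caption json):
# [{"image": "val2014/COCO_val2014_000000391895.jpg",
#   "caption": "A man riding a motorcycle on a dirt road."}, ...]
from mmpretrain.datasets import COCOCaption  # assumes the class is exported

dataset = COCOCaption(
    data_root='data/coco',                         # hypothetical layout
    ann_file='annotations/coco_karpathy_val.json',
    data_prefix=dict(img_path='images'),
)
info = dataset[0]
print(info['image_id'], info['img_path'], info['gt_caption'])
```
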
diff --git a/mmpretrain/datasets/coco_retrieval.py b/mmpretrain/datasets/coco_retrieval.py
new file mode 100644
index 0000000000000000000000000000000000000000..60d1586ad8672a4b57fcdc62740b3e08c3e2e20e
--- /dev/null
+++ b/mmpretrain/datasets/coco_retrieval.py
@@ -0,0 +1,77 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import json
+from collections import OrderedDict
+from typing import List
+
+from mmengine import get_file_backend
+
+from mmpretrain.registry import DATASETS
+from .base_dataset import BaseDataset
+
+
+@DATASETS.register_module()
+class COCORetrieval(BaseDataset):
+ """COCO Retrieval dataset.
+
+ Args:
+ ann_file (str): Annotation file path.
+        test_mode (bool): Whether the dataset is used for evaluation. This
+            decides the annotation format of the data list items.
+ Defaults to False.
+ data_root (str): The root directory for ``data_prefix`` and
+ ``ann_file``. Defaults to ''.
+ data_prefix (str | dict): Prefix for training data. Defaults to ''.
+ pipeline (Sequence): Processing pipeline. Defaults to an empty tuple.
+ **kwargs: Other keyword arguments in :class:`BaseDataset`.
+ """
+
+ def load_data_list(self) -> List[dict]:
+ """Load data list."""
+ # get file backend
+ img_prefix = self.data_prefix['img_path']
+ file_backend = get_file_backend(img_prefix)
+
+ anno_info = json.load(open(self.ann_file, 'r'))
+ # mapping img_id to img filename
+ img_dict = OrderedDict()
+ for idx, img in enumerate(anno_info['images']):
+ if img['id'] not in img_dict:
+ img_rel_path = img['coco_url'].rsplit('/', 2)[-2:]
+ img_path = file_backend.join_path(img_prefix, *img_rel_path)
+
+ # create new idx for image
+ img_dict[img['id']] = dict(
+ ori_id=img['id'],
+ image_id=idx, # will be used for evaluation
+ img_path=img_path,
+ text=[],
+ gt_text_id=[],
+ gt_image_id=[],
+ )
+
+ train_list = []
+ for idx, anno in enumerate(anno_info['annotations']):
+ anno['text'] = anno.pop('caption')
+ anno['ori_id'] = anno.pop('id')
+ anno['text_id'] = idx # will be used for evaluation
+ # 1. prepare train data list item
+ train_data = anno.copy()
+ train_image = img_dict[train_data['image_id']]
+ train_data['img_path'] = train_image['img_path']
+ train_data['image_ori_id'] = train_image['ori_id']
+ train_data['image_id'] = train_image['image_id']
+ train_data['is_matched'] = True
+ train_list.append(train_data)
+ # 2. prepare eval data list item based on img dict
+ img_dict[anno['image_id']]['gt_text_id'].append(anno['text_id'])
+ img_dict[anno['image_id']]['text'].append(anno['text'])
+ img_dict[anno['image_id']]['gt_image_id'].append(
+ train_image['image_id'])
+
+ self.img_size = len(img_dict)
+ self.text_size = len(anno_info['annotations'])
+
+ # return needed format data list
+ if self.test_mode:
+ return list(img_dict.values())
+ return train_list
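
A usage sketch for `COCORetrieval`: it parses standard COCO caption annotations (`images[*].coco_url`, `annotations[*].caption`), yielding one item per caption for training and one item per image (with all matched captions) when `test_mode=True`. Paths are placeholders and the import assumes the class is exported:

```python
from mmpretrain.datasets import COCORetrieval

# Training view: one data item per caption.
train_set = COCORetrieval(
    data_root='data/coco',                           # hypothetical root
    ann_file='annotations/captions_train2014.json',
    test_mode=False,
)

# Evaluation view: one data item per image, carrying gt_text_id / gt_image_id.
val_set = COCORetrieval(
    data_root='data/coco',
    ann_file='annotations/captions_val2014.json',
    test_mode=True,
)
print(len(train_set), val_set.img_size, val_set.text_size)
```
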
diff --git a/mmpretrain/datasets/coco_vqa.py b/mmpretrain/datasets/coco_vqa.py
new file mode 100644
index 0000000000000000000000000000000000000000..85f4bdcf39ef82ec47a2072dc198e6b8792d8768
--- /dev/null
+++ b/mmpretrain/datasets/coco_vqa.py
@@ -0,0 +1,114 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import re
+from collections import Counter
+from typing import List
+
+import mmengine
+from mmengine.dataset import BaseDataset
+
+from mmpretrain.registry import DATASETS
+
+
+@DATASETS.register_module()
+class COCOVQA(BaseDataset):
+ """VQAv2 dataset.
+
+ Args:
+ data_root (str): The root directory for ``data_prefix``, ``ann_file``
+ and ``question_file``.
+ data_prefix (str): The directory of images.
+ question_file (str): Question file path.
+ ann_file (str, optional): Annotation file path for training and
+ validation. Defaults to an empty string.
+ **kwargs: Other keyword arguments in :class:`BaseDataset`.
+ """
+
+ def __init__(self,
+ data_root: str,
+ data_prefix: str,
+ question_file: str,
+ ann_file: str = '',
+ **kwarg):
+ self.question_file = question_file
+ super().__init__(
+ data_root=data_root,
+ data_prefix=dict(img_path=data_prefix),
+ ann_file=ann_file,
+ **kwarg,
+ )
+
+ def _join_prefix(self):
+ if not mmengine.is_abs(self.question_file) and self.question_file:
+ self.question_file = osp.join(self.data_root, self.question_file)
+
+ return super()._join_prefix()
+
+ def _create_image_index(self):
+ img_prefix = self.data_prefix['img_path']
+
+ files = mmengine.list_dir_or_file(img_prefix, list_dir=False)
+ image_index = {}
+ for file in files:
+ image_id = re.findall(r'\d{12}', file)
+ if len(image_id) > 0:
+ image_id = int(image_id[-1])
+ image_index[image_id] = mmengine.join_path(img_prefix, file)
+
+ return image_index
+
+ def load_data_list(self) -> List[dict]:
+ """Load data list."""
+ questions = mmengine.load(self.question_file)['questions']
+ if self.ann_file:
+ annotations = mmengine.load(self.ann_file)['annotations']
+ assert len(questions) == len(annotations)
+ else:
+ annotations = [None] * len(questions)
+
+ # The original VQAv2 annotation file and question file includes
+ # only image id but no image file paths.
+ self.image_index = self._create_image_index()
+
+ data_list = []
+ for question, ann in zip(questions, annotations):
+ # question example
+ # {
+ # 'image_id': 262144,
+ # 'question': "Is the ball flying towards the batter?",
+ # 'question_id': 262144000
+ # }
+ #
+ # ann example
+ # {
+ # 'question_type': "what are the",
+ # 'answer_type': "other",
+ # 'answers': [
+ # {'answer': 'watching',
+ # 'answer_id': 1,
+ # 'answer_confidence': 'yes'},
+ # ...
+ # ],
+ # 'image_id': 262148,
+ # 'question_id': 262148000,
+ # 'multiple_choice_answer': 'watching',
+ # }
+
+ data_info = question
+ data_info['img_path'] = self.image_index[question['image_id']]
+
+ if ann is not None:
+ assert ann['question_id'] == question['question_id']
+
+ # add answer_weight & answer_count, delete duplicate answer
+ answers = [item['answer'] for item in ann.pop('answers')]
+ count = Counter(answers)
+ answer_weight = [i / len(answers) for i in count.values()]
+ data_info['gt_answer'] = list(count.keys())
+ data_info['gt_answer_weight'] = answer_weight
+ data_info.update(ann)
+
+ data_list.append(data_info)
+
+ return data_list
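
A wiring sketch for `COCOVQA`; the file names follow the official VQAv2 release but, like the directory layout, are placeholders here. Omitting `ann_file` yields question-only items (e.g. for the test split):

```python
from mmpretrain.datasets import COCOVQA  # assumes the class is exported

val_set = COCOVQA(
    data_root='data/vqav2',                                         # hypothetical
    data_prefix='images/val2014',
    question_file='annotations/v2_OpenEnded_mscoco_val2014_questions.json',
    ann_file='annotations/v2_mscoco_val2014_annotations.json',
)
info = val_set[0]
# De-duplicated answers and their relative frequencies among the annotators.
print(info['question'], info['gt_answer'], info['gt_answer_weight'])
```
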
diff --git a/mmpretrain/datasets/cub.py b/mmpretrain/datasets/cub.py
new file mode 100644
index 0000000000000000000000000000000000000000..8db126216fb3408e2dd18255db04a851eb5fe08f
--- /dev/null
+++ b/mmpretrain/datasets/cub.py
@@ -0,0 +1,142 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List
+
+from mmengine import get_file_backend, list_from_file
+from mmengine.logging import MMLogger
+
+from mmpretrain.registry import DATASETS
+from .base_dataset import BaseDataset
+from .categories import CUB_CATEGORIES
+
+
+@DATASETS.register_module()
+class CUB(BaseDataset):
+ """The CUB-200-2011 Dataset.
+
+    Support the `CUB-200-2011 <https://www.vision.caltech.edu/datasets/cub_200_2011/>`_ Dataset.
+    Compared with the earlier CUB-200 dataset, `CUB-200-2011` contains many more
+    images. After downloading and decompression, the dataset
+ directory structure is as follows.
+
+ CUB dataset directory: ::
+
+ CUB_200_2011
+ ├── images
+ │ ├── class_x
+ │ │ ├── xx1.jpg
+ │ │ ├── xx2.jpg
+ │ │ └── ...
+ │ ├── class_y
+ │ │ ├── yy1.jpg
+ │ │ ├── yy2.jpg
+ │ │ └── ...
+ │ └── ...
+ ├── images.txt
+ ├── image_class_labels.txt
+ ├── train_test_split.txt
+ └── ....
+
+ Args:
+ data_root (str): The root directory for CUB-200-2011 dataset.
+ split (str, optional): The dataset split, supports "train" and "test".
+            Defaults to "train".
+
+ Examples:
+ >>> from mmpretrain.datasets import CUB
+ >>> train_dataset = CUB(data_root='data/CUB_200_2011', split='train')
+ >>> train_dataset
+ Dataset CUB
+ Number of samples: 5994
+ Number of categories: 200
+ Root of dataset: data/CUB_200_2011
+ >>> test_dataset = CUB(data_root='data/CUB_200_2011', split='test')
+ >>> test_dataset
+ Dataset CUB
+ Number of samples: 5794
+ Number of categories: 200
+ Root of dataset: data/CUB_200_2011
+ """ # noqa: E501
+
+ METAINFO = {'classes': CUB_CATEGORIES}
+
+ def __init__(self,
+ data_root: str,
+ split: str = 'train',
+ test_mode: bool = False,
+ **kwargs):
+
+ splits = ['train', 'test']
+ assert split in splits, \
+ f"The split must be one of {splits}, but get '{split}'"
+ self.split = split
+
+ # To handle the BC-breaking
+ if split == 'train' and test_mode:
+ logger = MMLogger.get_current_instance()
+ logger.warning('split="train" but test_mode=True. '
+ 'The training set will be used.')
+
+ ann_file = 'images.txt'
+ data_prefix = 'images'
+ image_class_labels_file = 'image_class_labels.txt'
+ train_test_split_file = 'train_test_split.txt'
+
+ self.backend = get_file_backend(data_root, enable_singleton=True)
+ self.image_class_labels_file = self.backend.join_path(
+ data_root, image_class_labels_file)
+ self.train_test_split_file = self.backend.join_path(
+ data_root, train_test_split_file)
+ super(CUB, self).__init__(
+ ann_file=ann_file,
+ data_root=data_root,
+ data_prefix=data_prefix,
+ test_mode=test_mode,
+ **kwargs)
+
+ def _load_data_from_txt(self, filepath):
+ """load data from CUB txt file, the every line of the file is idx and a
+ data item."""
+ pairs = list_from_file(filepath)
+ data_dict = dict()
+ for pair in pairs:
+ idx, data_item = pair.split()
+ # all the index starts from 1 in CUB files,
+ # here we need to '- 1' to let them start from 0.
+ data_dict[int(idx) - 1] = data_item
+ return data_dict
+
+ def load_data_list(self):
+ """Load images and ground truth labels."""
+ sample_dict = self._load_data_from_txt(self.ann_file)
+
+ label_dict = self._load_data_from_txt(self.image_class_labels_file)
+
+ split_dict = self._load_data_from_txt(self.train_test_split_file)
+
+ assert sample_dict.keys() == label_dict.keys() == split_dict.keys(),\
+ f'sample_ids should be same in files {self.ann_file}, ' \
+ f'{self.image_class_labels_file} and {self.train_test_split_file}'
+
+ data_list = []
+ for sample_id in sample_dict.keys():
+ if split_dict[sample_id] == '1' and self.split == 'test':
+ # skip train samples when split='test'
+ continue
+ elif split_dict[sample_id] == '0' and self.split == 'train':
+ # skip test samples when split='train'
+ continue
+
+ img_path = self.backend.join_path(self.img_prefix,
+ sample_dict[sample_id])
+ gt_label = int(label_dict[sample_id]) - 1
+ info = dict(img_path=img_path, gt_label=gt_label)
+ data_list.append(info)
+
+ return data_list
+
+ def extra_repr(self) -> List[str]:
+ """The extra repr information of the dataset."""
+ body = [
+ f'Root of dataset: \t{self.data_root}',
+ ]
+ return body
diff --git a/mmpretrain/datasets/custom.py b/mmpretrain/datasets/custom.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb491ff0cc7f816f629603d3b8be55e3f787c373
--- /dev/null
+++ b/mmpretrain/datasets/custom.py
@@ -0,0 +1,287 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
+
+from mmengine.fileio import (BaseStorageBackend, get_file_backend,
+ list_from_file)
+from mmengine.logging import MMLogger
+
+from mmpretrain.registry import DATASETS
+from .base_dataset import BaseDataset
+
+
+def find_folders(
+ root: str,
+ backend: Optional[BaseStorageBackend] = None
+) -> Tuple[List[str], Dict[str, int]]:
+ """Find classes by folders under a root.
+
+ Args:
+ root (string): root directory of folders
+ backend (BaseStorageBackend | None): The file backend of the root.
+ If None, auto infer backend from the root path. Defaults to None.
+
+ Returns:
+ Tuple[List[str], Dict[str, int]]:
+
+ - folders: The name of sub folders under the root.
+ - folder_to_idx: The map from folder name to class idx.
+ """
+ # Pre-build file backend to prevent verbose file backend inference.
+ backend = backend or get_file_backend(root, enable_singleton=True)
+ folders = list(
+ backend.list_dir_or_file(
+ root,
+ list_dir=True,
+ list_file=False,
+ recursive=False,
+ ))
+ folders.sort()
+ folder_to_idx = {folders[i]: i for i in range(len(folders))}
+ return folders, folder_to_idx
+
+
+def get_samples(
+ root: str,
+ folder_to_idx: Dict[str, int],
+ is_valid_file: Callable,
+ backend: Optional[BaseStorageBackend] = None,
+):
+ """Make dataset by walking all images under a root.
+
+ Args:
+ root (string): root directory of folders
+ folder_to_idx (dict): the map from class name to class idx
+        is_valid_file (Callable): A function that takes the path of a file
+            and checks whether it is a valid sample file.
+ backend (BaseStorageBackend | None): The file backend of the root.
+ If None, auto infer backend from the root path. Defaults to None.
+
+ Returns:
+ Tuple[list, set]:
+
+ - samples: a list of tuple where each element is (image, class_idx)
+        - empty_folders: The folders that don't contain any valid files.
+ """
+ samples = []
+ available_classes = set()
+ # Pre-build file backend to prevent verbose file backend inference.
+ backend = backend or get_file_backend(root, enable_singleton=True)
+
+ if folder_to_idx is not None:
+ for folder_name in sorted(list(folder_to_idx.keys())):
+ _dir = backend.join_path(root, folder_name)
+ files = backend.list_dir_or_file(
+ _dir,
+ list_dir=False,
+ list_file=True,
+ recursive=True,
+ )
+ for file in sorted(list(files)):
+ if is_valid_file(file):
+ path = backend.join_path(folder_name, file)
+ item = (path, folder_to_idx[folder_name])
+ samples.append(item)
+ available_classes.add(folder_name)
+ empty_folders = set(folder_to_idx.keys()) - available_classes
+ else:
+ files = backend.list_dir_or_file(
+ root,
+ list_dir=False,
+ list_file=True,
+ recursive=True,
+ )
+ samples = [file for file in sorted(list(files)) if is_valid_file(file)]
+ empty_folders = None
+
+ return samples, empty_folders
+
+
+@DATASETS.register_module()
+class CustomDataset(BaseDataset):
+ """A generic dataset for multiple tasks.
+
+    The dataset supports two annotation styles.
+
+ 1. Use an annotation file to specify all samples, and each line indicates a
+ sample:
+
+    The annotation file (for ``with_label=True``, supervised tasks): ::
+
+ folder_1/xxx.png 0
+ folder_1/xxy.png 1
+ 123.png 4
+ nsdf3.png 3
+ ...
+
+    The annotation file (for ``with_label=False``, unsupervised tasks): ::
+
+ folder_1/xxx.png
+ folder_1/xxy.png
+ 123.png
+ nsdf3.png
+ ...
+
+ Sample files: ::
+
+ data_prefix/
+ ├── folder_1
+ │ ├── xxx.png
+ │ ├── xxy.png
+ │ └── ...
+ ├── 123.png
+ ├── nsdf3.png
+ └── ...
+
+ Please use the argument ``metainfo`` to specify extra information for
+ the task, like ``{'classes': ('bird', 'cat', 'deer', 'dog', 'frog')}``.
+
+ 2. Place all samples in one folder as below:
+
+    Sample files (for ``with_label=True``, supervised tasks, we use the names
+    of the sub-folders as the category names): ::
+
+ data_prefix/
+ ├── class_x
+ │ ├── xxx.png
+ │ ├── xxy.png
+ │ └── ...
+ │ └── xxz.png
+ └── class_y
+ ├── 123.png
+ ├── nsdf3.png
+ ├── ...
+ └── asd932_.png
+
+ Sample files (for ``with_label=False``, unsupervised tasks, we use all
+ sample files under the specified folder): ::
+
+ data_prefix/
+ ├── folder_1
+ │ ├── xxx.png
+ │ ├── xxy.png
+ │ └── ...
+ ├── 123.png
+ ├── nsdf3.png
+ └── ...
+
+    If ``ann_file`` is specified, the dataset is generated in the first way;
+    otherwise, the second way is used.
+
+ Args:
+ data_root (str): The root directory for ``data_prefix`` and
+ ``ann_file``. Defaults to ''.
+ data_prefix (str | dict): Prefix for the data. Defaults to ''.
+ ann_file (str): Annotation file path. Defaults to ''.
+ with_label (bool): Whether the annotation file includes ground truth
+ labels, or use sub-folders to specify categories.
+ Defaults to True.
+ extensions (Sequence[str]): A sequence of allowed extensions. Defaults
+ to ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif').
+ metainfo (dict, optional): Meta information for dataset, such as class
+ information. Defaults to None.
+        lazy_init (bool): Whether to defer loading the annotation file until
+            the dataset is actually used. In some cases, such as visualization,
+            only the meta information of the dataset is needed, so loading the
+            annotation file is unnecessary. ``BaseDataset`` can skip loading
+            annotations to save time by setting ``lazy_init=True``.
+            Defaults to False.
+ **kwargs: Other keyword arguments in :class:`BaseDataset`.
+ """
+
+ def __init__(self,
+ data_root: str = '',
+ data_prefix: Union[str, dict] = '',
+ ann_file: str = '',
+ with_label=True,
+ extensions: Sequence[str] = ('.jpg', '.jpeg', '.png', '.ppm',
+ '.bmp', '.pgm', '.tif'),
+ metainfo: Optional[dict] = None,
+ lazy_init: bool = False,
+ **kwargs):
+ assert (ann_file or data_prefix or data_root), \
+ 'One of `ann_file`, `data_root` and `data_prefix` must '\
+ 'be specified.'
+
+ self.extensions = tuple(set([i.lower() for i in extensions]))
+ self.with_label = with_label
+
+ super().__init__(
+ # The base class requires string ann_file but this class doesn't
+ ann_file=ann_file,
+ metainfo=metainfo,
+ data_root=data_root,
+ data_prefix=data_prefix,
+ # Force to lazy_init for some modification before loading data.
+ lazy_init=True,
+ **kwargs)
+
+ # Full initialize the dataset.
+ if not lazy_init:
+ self.full_init()
+
+ def _find_samples(self):
+ """find samples from ``data_prefix``."""
+ if self.with_label:
+ classes, folder_to_idx = find_folders(self.img_prefix)
+ samples, empty_classes = get_samples(
+ self.img_prefix,
+ folder_to_idx,
+ is_valid_file=self.is_valid_file,
+ )
+
+ self.folder_to_idx = folder_to_idx
+
+ if self.CLASSES is not None:
+ assert len(self.CLASSES) == len(classes), \
+ f"The number of subfolders ({len(classes)}) doesn't " \
+ f'match the number of specified classes ' \
+ f'({len(self.CLASSES)}). Please check the data folder.'
+ else:
+ self._metainfo['classes'] = tuple(classes)
+ else:
+ samples, empty_classes = get_samples(
+ self.img_prefix,
+ None,
+ is_valid_file=self.is_valid_file,
+ )
+
+ if len(samples) == 0:
+ raise RuntimeError(
+ f'Found 0 files in subfolders of: {self.data_prefix}. '
+ f'Supported extensions are: {",".join(self.extensions)}')
+
+ if empty_classes:
+ logger = MMLogger.get_current_instance()
+ logger.warning(
+ 'Found no valid file in the folder '
+ f'{", ".join(empty_classes)}. '
+ f"Supported extensions are: {', '.join(self.extensions)}")
+
+ return samples
+
+ def load_data_list(self):
+ """Load image paths and gt_labels."""
+ if not self.ann_file:
+ samples = self._find_samples()
+ elif self.with_label:
+ lines = list_from_file(self.ann_file)
+ samples = [x.strip().rsplit(' ', 1) for x in lines]
+ else:
+ samples = list_from_file(self.ann_file)
+
+ # Pre-build file backend to prevent verbose file backend inference.
+ backend = get_file_backend(self.img_prefix, enable_singleton=True)
+ data_list = []
+ for sample in samples:
+ if self.with_label:
+ filename, gt_label = sample
+ img_path = backend.join_path(self.img_prefix, filename)
+ info = {'img_path': img_path, 'gt_label': int(gt_label)}
+ else:
+ img_path = backend.join_path(self.img_prefix, sample)
+ info = {'img_path': img_path}
+ data_list.append(info)
+ return data_list
+
+ def is_valid_file(self, filename: str) -> bool:
+ """Check if a file is a valid sample."""
+ return filename.lower().endswith(self.extensions)
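
The two annotation styles described in the `CustomDataset` docstring map to the two instantiations sketched below; all paths and class names are hypothetical, and the import assumes the class is exported from `mmpretrain.datasets`:

```python
from mmpretrain.datasets import CustomDataset

# Style 1: an annotation file with one "relative/path label" pair per line.
ds_from_ann = CustomDataset(
    data_root='data/my_dataset',            # hypothetical
    data_prefix='images',
    ann_file='meta/train.txt',
    metainfo={'classes': ('cat', 'dog')},
)

# Style 2: no annotation file; sub-folder names become the class names.
ds_from_folders = CustomDataset(
    data_root='data/my_dataset',
    data_prefix='train',
)
print(ds_from_ann.CLASSES, ds_from_folders.CLASSES)
```
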
diff --git a/mmpretrain/datasets/dataset_wrappers.py b/mmpretrain/datasets/dataset_wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..1adff10beb024940f9066a407cc76ddb06b27404
--- /dev/null
+++ b/mmpretrain/datasets/dataset_wrappers.py
@@ -0,0 +1,176 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+
+import numpy as np
+from mmengine.dataset import BaseDataset, force_full_init
+
+from mmpretrain.registry import DATASETS
+
+
+@DATASETS.register_module()
+class KFoldDataset:
+ """A wrapper of dataset for K-Fold cross-validation.
+
+    K-Fold cross-validation divides all the samples into k groups of almost
+    equal size, called folds. We then use k-1 of the folds for training and
+    the remaining fold for validation.
+
+ Args:
+ dataset (:obj:`mmengine.dataset.BaseDataset` | dict): The dataset to be
+ divided
+ fold (int): The fold used to do validation. Defaults to 0.
+ num_splits (int): The number of all folds. Defaults to 5.
+ test_mode (bool): Use the training dataset or validation dataset.
+ Defaults to False.
+ seed (int, optional): The seed to shuffle the dataset before splitting.
+            If None, the dataset is not shuffled. Defaults to None.
+ """
+
+ def __init__(self,
+ dataset,
+ fold=0,
+ num_splits=5,
+ test_mode=False,
+ seed=None):
+ if isinstance(dataset, dict):
+ self.dataset = DATASETS.build(dataset)
+ # Init the dataset wrapper lazily according to the dataset setting.
+ lazy_init = dataset.get('lazy_init', False)
+        elif isinstance(dataset, BaseDataset):
+            self.dataset = dataset
+            # Keep `lazy_init` defined when a built dataset instance is passed
+            # in directly, so the `full_init` check below does not fail.
+            lazy_init = False
+ else:
+ raise TypeError(f'Unsupported dataset type {type(dataset)}.')
+
+ self._metainfo = getattr(self.dataset, 'metainfo', {})
+ self.fold = fold
+ self.num_splits = num_splits
+ self.test_mode = test_mode
+ self.seed = seed
+
+ self._fully_initialized = False
+ if not lazy_init:
+ self.full_init()
+
+ @property
+ def metainfo(self) -> dict:
+ """Get the meta information of ``self.dataset``.
+
+ Returns:
+ dict: Meta information of the dataset.
+ """
+ # Prevent `self._metainfo` from being modified by outside.
+ return copy.deepcopy(self._metainfo)
+
+ def full_init(self):
+ """fully initialize the dataset."""
+ if self._fully_initialized:
+ return
+
+ self.dataset.full_init()
+ ori_len = len(self.dataset)
+ indices = list(range(ori_len))
+ if self.seed is not None:
+ rng = np.random.default_rng(self.seed)
+ rng.shuffle(indices)
+
+ test_start = ori_len * self.fold // self.num_splits
+ test_end = ori_len * (self.fold + 1) // self.num_splits
+ if self.test_mode:
+ indices = indices[test_start:test_end]
+ else:
+ indices = indices[:test_start] + indices[test_end:]
+
+ self._ori_indices = indices
+ self.dataset = self.dataset.get_subset(indices)
+
+ self._fully_initialized = True
+
+ @force_full_init
+ def _get_ori_dataset_idx(self, idx: int) -> int:
+ """Convert global idx to local index.
+
+ Args:
+ idx (int): Global index of ``KFoldDataset``.
+
+ Returns:
+ int: The original index in the whole dataset.
+ """
+ return self._ori_indices[idx]
+
+ @force_full_init
+ def get_data_info(self, idx: int) -> dict:
+ """Get annotation by index.
+
+ Args:
+ idx (int): Global index of ``KFoldDataset``.
+
+ Returns:
+ dict: The idx-th annotation of the datasets.
+ """
+ return self.dataset.get_data_info(idx)
+
+ @force_full_init
+ def __len__(self):
+ return len(self.dataset)
+
+ @force_full_init
+ def __getitem__(self, idx):
+ return self.dataset[idx]
+
+ @force_full_init
+ def get_cat_ids(self, idx):
+ return self.dataset.get_cat_ids(idx)
+
+ @force_full_init
+ def get_gt_labels(self):
+ return self.dataset.get_gt_labels()
+
+ @property
+ def CLASSES(self):
+ """Return all categories names."""
+ return self._metainfo.get('classes', None)
+
+ @property
+ def class_to_idx(self):
+ """Map mapping class name to class index.
+
+ Returns:
+ dict: mapping from class name to class index.
+ """
+
+ return {cat: i for i, cat in enumerate(self.CLASSES)}
+
+ def __repr__(self):
+ """Print the basic information of the dataset.
+
+ Returns:
+ str: Formatted string.
+ """
+ head = 'Dataset ' + self.__class__.__name__
+ body = []
+ type_ = 'test' if self.test_mode else 'training'
+ body.append(f'Type: \t{type_}')
+ body.append(f'Seed: \t{self.seed}')
+
+ def ordinal(n):
+ # Copy from https://codegolf.stackexchange.com/a/74047
+ suffix = 'tsnrhtdd'[(n // 10 % 10 != 1) * (n % 10 < 4) * n % 10::4]
+ return f'{n}{suffix}'
+
+ body.append(
+ f'Fold: \t{ordinal(self.fold+1)} of {self.num_splits}-fold')
+ if self._fully_initialized:
+ body.append(f'Number of samples: \t{self.__len__()}')
+ else:
+ body.append("Haven't been initialized")
+
+ if self.CLASSES is not None:
+ body.append(f'Number of categories: \t{len(self.CLASSES)}')
+ else:
+ body.append('The `CLASSES` meta info is not set.')
+
+ body.append(
+ f'Original dataset type:\t{self.dataset.__class__.__name__}')
+
+ lines = [head] + [' ' * 4 + line for line in body]
+ return '\n'.join(lines)
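
A sketch of `KFoldDataset` in use: the same seed must be shared between the training and validation wrappers so that both see the same shuffled split. The wrapped dataset config and paths are placeholders:

```python
from mmpretrain.datasets import KFoldDataset  # assumes the class is exported

base = dict(
    type='CustomDataset',
    data_root='data/my_dataset',   # hypothetical
    data_prefix='images',
    ann_file='meta/all.txt',
)

# Fold 0 of a 5-fold split: 4/5 of the samples for training, 1/5 for validation.
train_ds = KFoldDataset(dataset=base, fold=0, num_splits=5, test_mode=False, seed=42)
val_ds = KFoldDataset(dataset=base, fold=0, num_splits=5, test_mode=True, seed=42)
print(len(train_ds), len(val_ds), repr(val_ds))
```
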
diff --git a/mmpretrain/datasets/dtd.py b/mmpretrain/datasets/dtd.py
new file mode 100644
index 0000000000000000000000000000000000000000..034d0b1b444afebfc420eeff7e138072f7d7ee1f
--- /dev/null
+++ b/mmpretrain/datasets/dtd.py
@@ -0,0 +1,116 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List
+
+import mat4py
+from mmengine import get_file_backend
+
+from mmpretrain.registry import DATASETS
+from .base_dataset import BaseDataset
+from .categories import DTD_CATEGORIES
+
+
+@DATASETS.register_module()
+class DTD(BaseDataset):
+ """The Describable Texture Dataset (DTD).
+
+    Support the `Describable Texture Dataset <https://www.robots.ox.ac.uk/~vgg/data/dtd/>`_ Dataset.
+ After downloading and decompression, the dataset directory structure is as follows.
+
+ DTD dataset directory: ::
+
+ dtd
+ ├── images
+ │ ├── banded
+ | | ├──banded_0002.jpg
+ | | ├──banded_0004.jpg
+ | | └── ...
+ │ └── ...
+ ├── imdb
+ │ └── imdb.mat
+ ├── labels
+ | | ├──labels_joint_anno.txt
+ | | ├──test1.txt
+ | | ├──test2.txt
+ | | └── ...
+ │ └── ...
+ └── ....
+
+ Args:
+ data_root (str): The root directory for Describable Texture dataset.
+ split (str, optional): The dataset split, supports "train",
+ "val", "trainval", and "test". Default to "trainval".
+
+ Examples:
+ >>> from mmpretrain.datasets import DTD
+ >>> train_dataset = DTD(data_root='data/dtd', split='trainval')
+ >>> train_dataset
+ Dataset DTD
+ Number of samples: 3760
+ Number of categories: 47
+ Root of dataset: data/dtd
+ >>> test_dataset = DTD(data_root='data/dtd', split='test')
+ >>> test_dataset
+ Dataset DTD
+ Number of samples: 1880
+ Number of categories: 47
+ Root of dataset: data/dtd
+ """ # noqa: E501
+
+ METAINFO = {'classes': DTD_CATEGORIES}
+
+ def __init__(self, data_root: str, split: str = 'trainval', **kwargs):
+
+ splits = ['train', 'val', 'trainval', 'test']
+ assert split in splits, \
+ f"The split must be one of {splits}, but get '{split}'"
+ self.split = split
+
+ data_prefix = 'images'
+ test_mode = split == 'test'
+
+ self.backend = get_file_backend(data_root, enable_singleton=True)
+ ann_file = self.backend.join_path('imdb', 'imdb.mat')
+
+ super(DTD, self).__init__(
+ ann_file=ann_file,
+ data_root=data_root,
+ data_prefix=data_prefix,
+ test_mode=test_mode,
+ **kwargs)
+
+ def load_data_list(self):
+ """Load images and ground truth labels."""
+
+ data = mat4py.loadmat(self.ann_file)['images']
+ names = data['name']
+ labels = data['class']
+ parts = data['set']
+ num = len(names)
+        assert num == len(labels) == len(parts), 'Inconsistent annotation file'
+
+ if self.split == 'train':
+ target_set = {1}
+ elif self.split == 'val':
+ target_set = {2}
+ elif self.split == 'test':
+ target_set = {3}
+ else:
+ target_set = {1, 2}
+
+ data_list = []
+ for i in range(num):
+ if parts[i] in target_set:
+ img_name = names[i]
+ img_path = self.backend.join_path(self.img_prefix, img_name)
+ gt_label = labels[i] - 1
+ info = dict(img_path=img_path, gt_label=gt_label)
+ data_list.append(info)
+
+ return data_list
+
+ def extra_repr(self) -> List[str]:
+ """The extra repr information of the dataset."""
+ body = [
+ f'Root of dataset: \t{self.data_root}',
+ ]
+ return body
diff --git a/mmpretrain/datasets/fgvcaircraft.py b/mmpretrain/datasets/fgvcaircraft.py
new file mode 100644
index 0000000000000000000000000000000000000000..696992c06bbf02f097d017a519d42f758ba5f16f
--- /dev/null
+++ b/mmpretrain/datasets/fgvcaircraft.py
@@ -0,0 +1,98 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List
+
+from mmengine import get_file_backend, list_from_file
+
+from mmpretrain.registry import DATASETS
+from .base_dataset import BaseDataset
+from .categories import FGVCAIRCRAFT_CATEGORIES
+
+
+@DATASETS.register_module()
+class FGVCAircraft(BaseDataset):
+ """The FGVC_Aircraft Dataset.
+
+    Support the `FGVC_Aircraft <https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/>`_ Dataset.
+ After downloading and decompression, the dataset directory structure is as follows.
+
+ FGVC_Aircraft dataset directory: ::
+
+ fgvc-aircraft-2013b
+ └── data
+ ├── images
+ │ ├── 1.jpg
+ │ ├── 2.jpg
+ │ └── ...
+ ├── images_variant_train.txt
+ ├── images_variant_test.txt
+ ├── images_variant_trainval.txt
+ ├── images_variant_val.txt
+ ├── variants.txt
+ └── ....
+
+ Args:
+ data_root (str): The root directory for FGVC_Aircraft dataset.
+ split (str, optional): The dataset split, supports "train",
+ "val", "trainval", and "test". Default to "trainval".
+
+ Examples:
+ >>> from mmpretrain.datasets import FGVCAircraft
+ >>> train_dataset = FGVCAircraft(data_root='data/fgvc-aircraft-2013b', split='trainval')
+ >>> train_dataset
+ Dataset FGVCAircraft
+ Number of samples: 6667
+ Number of categories: 100
+ Root of dataset: data/fgvc-aircraft-2013b
+ >>> test_dataset = FGVCAircraft(data_root='data/fgvc-aircraft-2013b', split='test')
+ >>> test_dataset
+ Dataset FGVCAircraft
+ Number of samples: 3333
+ Number of categories: 100
+ Root of dataset: data/fgvc-aircraft-2013b
+ """ # noqa: E501
+
+ METAINFO = {'classes': FGVCAIRCRAFT_CATEGORIES}
+
+ def __init__(self, data_root: str, split: str = 'trainval', **kwargs):
+
+ splits = ['train', 'val', 'trainval', 'test']
+ assert split in splits, \
+ f"The split must be one of {splits}, but get '{split}'"
+ self.split = split
+
+ self.backend = get_file_backend(data_root, enable_singleton=True)
+ ann_file = self.backend.join_path('data',
+ f'images_variant_{split}.txt')
+ data_prefix = self.backend.join_path('data', 'images')
+ test_mode = split == 'test'
+
+ super(FGVCAircraft, self).__init__(
+ ann_file=ann_file,
+ data_root=data_root,
+ test_mode=test_mode,
+ data_prefix=data_prefix,
+ **kwargs)
+
+ def load_data_list(self):
+ """Load images and ground truth labels."""
+
+ pairs = list_from_file(self.ann_file)
+ data_list = []
+ for pair in pairs:
+ pair = pair.split()
+ img_name = pair[0]
+ class_name = ' '.join(pair[1:])
+ img_name = f'{img_name}.jpg'
+ img_path = self.backend.join_path(self.img_prefix, img_name)
+ gt_label = self.METAINFO['classes'].index(class_name)
+ info = dict(img_path=img_path, gt_label=gt_label)
+ data_list.append(info)
+
+ return data_list
+
+ def extra_repr(self) -> List[str]:
+ """The extra repr information of the dataset."""
+ body = [
+ f'Root of dataset: \t{self.data_root}',
+ ]
+ return body
diff --git a/mmpretrain/datasets/flamingo.py b/mmpretrain/datasets/flamingo.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b5745a1437537fccbc304d158a0f0c8d09f032a
--- /dev/null
+++ b/mmpretrain/datasets/flamingo.py
@@ -0,0 +1,295 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import random
+from abc import abstractmethod
+from collections import Counter
+from typing import List
+
+import mmengine
+import numpy as np
+from mmengine.dataset import BaseDataset
+from pycocotools.coco import COCO
+
+from mmpretrain.registry import DATASETS
+from .coco_vqa import COCOVQA
+
+
+class FlamingoFewShotMixin:
+ """Flamingo fewshot eval dataset minin.
+
+ Args:
+ num_shots (int): Number of shots to perform evaluation.
+ Defaults to 0.
+            Note: 0 does not mean a strict zero-shot in the Flamingo setting.
+            It will use 2 text-only prompts without in-context images.
+ num_support_examples (int): Number of support examples to get the
+ few shots from. Defaults to 2048.
+ num_query_examples (int): Number of query examples to perform the
+ final evaluation. Defaults to 5000.
+        incontext_prompt_temp (str): In-context prompt template for few-shot
+            examples. Defaults to ''.
+ final_prompt_temp (str): Final query prompt template. Defaults to ''.
+ **kwargs: Other keyword arguments in :class:`BaseDataset`.
+ """
+
+ def __init__(self,
+ num_shots: int = 0,
+ num_support_examples: int = 2048,
+ num_query_examples: int = 5000,
+ incontext_prompt_temp: str = '',
+ final_prompt_temp: str = '',
+ **kwarg):
+ self.num_shots = num_shots
+ self.num_support_examples = num_support_examples
+ self.num_query_examples = num_query_examples
+ self.incontext_prompt_temp = incontext_prompt_temp
+ self.final_prompt_temp = final_prompt_temp
+ super().__init__(**kwarg)
+
+ def get_subset_idx(self, total_num):
+ random_idx = np.random.choice(
+ total_num,
+ self.num_support_examples + self.num_query_examples,
+ replace=False)
+
+ support_idx = random_idx[:self.num_support_examples]
+ query_idx = random_idx[self.num_support_examples:]
+ return support_idx, query_idx
+
+ @abstractmethod
+ def parse_basic_anno(self, anno: dict) -> dict:
+ """Parse basic annotation for support and query set."""
+ pass
+
+ @abstractmethod
+ def parse_fewshot_anno(self, anno: dict, support_list: List) -> dict:
+ """Parse fewshot related annotation for query set with support list."""
+ pass
+
+
+@DATASETS.register_module()
+class FlamingoEvalCOCOVQA(FlamingoFewShotMixin, COCOVQA):
+ """Flamingo few shot VQAv2 dataset.
+
+ Args:
+ data_root (str): The root directory for ``data_prefix`` and
+ ``ann_file``.
+ ann_file (str): Annotation file path.
+ question_file (str): Question file path.
+ num_shots (int): Number of shots to perform evaluation.
+ Defaults to 0.
+            Note: 0 does not mean a strict zero-shot in the Flamingo setting.
+            It will use 2 text-only prompts without in-context images.
+ num_support_examples (int): Number of support examples to get the
+ few shots from. Defaults to 2048.
+ num_query_examples (int): Number of query examples to perform the
+ final evaluation. Defaults to 5000.
+ **kwargs: Other keyword arguments in :class:`BaseDataset`.
+ """
+
+ def __init__(self,
+ data_root: str,
+ question_file: str,
+ ann_file: str = '',
+ num_shots: int = 0,
+ num_support_examples: int = 2048,
+ num_query_examples: int = 5000,
+ **kwarg):
+ super().__init__(
+ data_root=data_root,
+ question_file=question_file,
+ ann_file=ann_file,
+ num_shots=num_shots,
+ num_support_examples=num_support_examples,
+ num_query_examples=num_query_examples,
+ **kwarg)
+
+ def parse_basic_anno(self, ann: dict) -> dict:
+ """Parse basic annotation for support and query set.
+
+ Args:
+            ann (dict): Annotation for a single example.
+
+ Return:
+ dict: Parsed annotation for single example.
+ """
+ if ann is None:
+ return {}
+
+ answers = [a['answer'] for a in ann['answers']]
+ count = Counter(answers)
+ answer_weight = [i / len(answers) for i in count.values()]
+ answer_info = {
+ 'gt_answer': list(count.keys()),
+ 'gt_answer_weight': answer_weight
+ }
+ return answer_info
+
+ def parse_fewshot_anno(self, query: dict, support_list: List) -> dict:
+ """Parse fewshot related annotation for query set with support list.
+
+ Args:
+            query (dict): Annotation for a single query example.
+            support_list (List): List of support examples to subsample few
+                shots from.
+
+ Return:
+ dict: Parsed annotation for single example.
+ """
+ # prepare n shots examples
+ shots = random.sample(support_list, self.num_shots)
+
+ # append image path for n shots
+ img_path = [shot['img_path'] for shot in shots]
+ img_path.append(query['img_path'])
+ query['img_path'] = img_path
+
+ query['shots'] = [
+ dict(
+ question=item['question'],
+ answer=item['gt_answer'][0],
+ ) for item in shots
+ ]
+ return query
+
+ def load_data_list(self) -> List[dict]:
+ """Load data list."""
+ questions = mmengine.load(self.question_file)['questions']
+ if self.ann_file:
+ annotations = mmengine.load(self.ann_file)['annotations']
+ assert len(questions) == len(annotations)
+ else:
+ annotations = [None] * len(questions)
+ if self.num_shots > 0:
+ raise ValueError('Unable to construct few-shot examples '
+                                 'since no annotation file is provided.')
+
+ # The original VQAv2 annotation file and question file includes
+ # only image id but no image file paths.
+ self.image_index = self._create_image_index()
+
+ num_data = len(questions)
+ support_idx, query_idx = self.get_subset_idx(num_data)
+
+ # prepare support subset
+ if self.num_shots > 0:
+ support_list = []
+ for idx in support_idx:
+ question = questions[idx]
+ ann = annotations[idx]
+ support = {**question, **self.parse_basic_anno(ann)}
+ support['img_path'] = self.image_index[question['image_id']]
+ support_list.append(support)
+
+ # prepare query subset
+ data_list = []
+ for idx in query_idx:
+ question = questions[idx]
+ ann = annotations[idx]
+ data_info = {**question, **self.parse_basic_anno(ann)}
+ data_info['img_path'] = self.image_index[question['image_id']]
+ if self.num_shots > 0:
+ data_info = self.parse_fewshot_anno(data_info, support_list)
+ data_list.append(data_info)
+
+ return data_list
+
+
+@DATASETS.register_module()
+class FlamingoEvalCOCOCaption(FlamingoFewShotMixin, BaseDataset):
+ """Flamingo few shot COCO Caption dataset.
+
+ Args:
+ data_root (str): The root directory for ``data_prefix`` and
+ ``ann_file``.
+ ann_file (str): Annotation file path.
+ data_prefix (dict): Prefix for data field. Defaults to
+ ``dict(img_path='')``.
+ num_shots (int): Number of shots to perform evaluation.
+ Defaults to 0.
+ num_support_examples (int): Number of support examples to get the
+ few shots from. Defaults to 2048.
+ num_query_examples (int): Number of query examples to perform the
+ final evaluation. Defaults to 5000.
+ **kwargs: Other keyword arguments in :class:`BaseDataset`.
+ """
+
+ def __init__(self,
+ data_root: str,
+ ann_file: str,
+ num_shots: int = 0,
+ num_support_examples: int = 2048,
+ num_query_examples: int = 5000,
+ **kwarg):
+ super().__init__(
+ data_root=data_root,
+ ann_file=ann_file,
+ num_shots=num_shots,
+ num_support_examples=num_support_examples,
+ num_query_examples=num_query_examples,
+ **kwarg)
+
+ def parse_basic_anno(self, ann: dict, coco: COCO) -> dict:
+ """Parse basic annotation for support and query set.
+
+ Args:
+            ann (dict): Annotation for a single example.
+ coco (COCO): The coco dataset.
+
+ Return:
+ dict: Parsed annotation for single example.
+ """
+ img_prefix = self.data_prefix['img_path']
+ img = coco.imgs[ann['image_id']]
+ data_info = dict(
+ img_path=mmengine.join_path(img_prefix, img['file_name']),
+ gt_caption=ann['caption'],
+ image_id=ann['image_id'],
+ )
+ return data_info
+
+ def parse_fewshot_anno(self, query: dict, support_list: List) -> dict:
+ """Parse fewshot related annotation for query set with support list.
+
+ Args:
+            query (dict): Annotation for a single query example.
+            support_list (List): List of support examples to subsample few
+                shots from.
+
+ Return:
+ dict: Parsed annotation for single example.
+ """
+ # prepare n shots examples
+ shots = random.sample(support_list, self.num_shots)
+
+ # append image path for n shots
+ img_path = [shot['img_path'] for shot in shots]
+ img_path.append(query['img_path'])
+ query['img_path'] = img_path
+
+ query['shots'] = [dict(caption=item['gt_caption']) for item in shots]
+ return query
+
+ def load_data_list(self) -> List[dict]:
+ """Load data list."""
+ with mmengine.get_local_path(self.ann_file) as ann_file:
+ coco = COCO(ann_file)
+
+ num_data = len(coco.anns)
+ support_idx, query_idx = self.get_subset_idx(num_data)
+ ann_ids = list(coco.anns)
+
+ # prepare support subset
+ if self.num_shots > 0:
+ support_list = []
+ for idx in support_idx:
+ support = self.parse_basic_anno(coco.anns[ann_ids[idx]], coco)
+ support_list.append(support)
+
+ # prepare query subset
+ query_list = []
+ for idx in query_idx:
+ data_info = self.parse_basic_anno(coco.anns[ann_ids[idx]], coco)
+ if self.num_shots > 0:
+ data_info = self.parse_fewshot_anno(data_info, support_list)
+ query_list.append(data_info)
+
+ return query_list
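+
+# Illustrative usage (a sketch; the data paths below are assumptions, not
+# part of this module):
+#
+#     dataset = FlamingoEvalCOCOCaption(
+#         data_root='data/coco',
+#         ann_file='annotations/captions_val2014.json',
+#         data_prefix=dict(img_path='val2014'),
+#         num_shots=2)
+#     # With num_shots > 0, each query sample carries a list of image paths
+#     # (shot images followed by the query image) and the shot captions
+#     # under the 'shots' key.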
diff --git a/mmpretrain/datasets/flickr30k_caption.py b/mmpretrain/datasets/flickr30k_caption.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0f6841a2c87a0b3eaa3a7abd5b8fda1cb235bc0
--- /dev/null
+++ b/mmpretrain/datasets/flickr30k_caption.py
@@ -0,0 +1,77 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List
+
+import mmengine
+from mmengine.dataset import BaseDataset
+from mmengine.fileio import get_file_backend
+
+from mmpretrain.registry import DATASETS
+
+
+@DATASETS.register_module()
+class Flickr30kCaption(BaseDataset):
+ """Flickr30k Caption dataset. To generate coco-style GT annotation for
+ evaluation, please refer to
+ tools/dataset_converters/convert_flickr30k_ann.py.
+
+ Args:
+ data_root (str): The root directory for ``data_prefix``, ``ann_file``
+ and ``question_file``.
+ data_prefix (str): The directory of images.
+ ann_file (str): Annotation file path for training and validation.
+ split (str): 'train', 'val' or 'test'.
+ **kwargs: Other keyword arguments in :class:`BaseDataset`.
+ """
+
+ def __init__(self, data_root: str, data_prefix: str, ann_file: str,
+ split: str, **kwarg):
+
+ assert split in ['train', 'val', 'test'], \
+ '`split` must be train, val or test'
+ self.split = split
+ super().__init__(
+ data_root=data_root,
+ data_prefix=dict(img_path=data_prefix),
+ ann_file=ann_file,
+ **kwarg,
+ )
+
+ def load_data_list(self) -> List[dict]:
+ """Load data list."""
+ img_prefix = self.data_prefix['img_path']
+ annotations = mmengine.load(self.ann_file)
+ file_backend = get_file_backend(img_prefix)
+
+ data_list = []
+
+ for img in annotations['images']:
+
+ # img_example={
+ # "sentids": [0, 1, 2],
+ # "imgid": 0,
+ # "sentences": [
+ # {"raw": "Two men in green shirts standing in a yard.",
+ # "imgid": 0, "sentid": 0},
+ # {"raw": "A man in a blue shirt standing in a garden.",
+ # "imgid": 0, "sentid": 1},
+ # {"raw": "Two friends enjoy time spent together.",
+ # "imgid": 0, "sentid": 2}
+ # ],
+ # "split": "train",
+ # "filename": "1000092795.jpg"
+ # },
+
+ if img['split'] != self.split:
+ continue
+
+ for sentence in img['sentences']:
+ data_info = {
+ 'image_id': img['imgid'],
+ 'img_path': file_backend.join_path(img_prefix,
+ img['filename']),
+ 'gt_caption': sentence['raw']
+ }
+
+ data_list.append(data_info)
+
+ return data_list
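+
+# Illustrative usage (a sketch; the paths are assumptions, the annotation
+# file is expected to follow the format shown in the comment above):
+#
+#     dataset = Flickr30kCaption(
+#         data_root='data/flickr30k',
+#         data_prefix='images',
+#         ann_file='annotations/dataset_flickr30k.json',
+#         split='train')
+#     # Each sample provides 'image_id', 'img_path' and one 'gt_caption'.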
diff --git a/mmpretrain/datasets/flickr30k_retrieval.py b/mmpretrain/datasets/flickr30k_retrieval.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f43c151b2079b3f72cf620577923efc57987316
--- /dev/null
+++ b/mmpretrain/datasets/flickr30k_retrieval.py
@@ -0,0 +1,110 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from collections import OrderedDict
+from typing import List
+
+import mmengine
+from mmengine import get_file_backend
+
+from mmpretrain.registry import DATASETS
+from .base_dataset import BaseDataset
+
+
+@DATASETS.register_module()
+class Flickr30kRetrieval(BaseDataset):
+ """Flickr30k Retrieval dataset.
+
+ Args:
+ data_root (str): The root directory for ``data_prefix``, ``ann_file``
+ and ``question_file``.
+ data_prefix (str): The directory of images.
+ ann_file (str): Annotation file path for training and validation.
+ split (str): 'train', 'val' or 'test'.
+ **kwargs: Other keyword arguments in :class:`BaseDataset`.
+ """
+
+ def __init__(self, data_root: str, data_prefix: str, ann_file: str,
+ split: str, **kwarg):
+
+ assert split in ['train', 'val', 'test'], \
+ '`split` must be train, val or test'
+ self.split = split
+ super().__init__(
+ data_root=data_root,
+ data_prefix=dict(img_path=data_prefix),
+ ann_file=ann_file,
+ **kwarg,
+ )
+
+ def load_data_list(self) -> List[dict]:
+ """Load data list."""
+ # get file backend
+ img_prefix = self.data_prefix['img_path']
+ file_backend = get_file_backend(img_prefix)
+
+ annotations = mmengine.load(self.ann_file)
+
+ # mapping img_id to img filename
+ img_dict = OrderedDict()
+ img_idx = 0
+ sentence_idx = 0
+ train_list = []
+ for img in annotations['images']:
+
+ # img_example={
+ # "sentids": [0, 1, 2],
+ # "imgid": 0,
+ # "sentences": [
+ # {"raw": "Two men in green shirts standing in a yard.",
+ # "imgid": 0, "sentid": 0},
+ # {"raw": "A man in a blue shirt standing in a garden.",
+ # "imgid": 0, "sentid": 1},
+ # {"raw": "Two friends enjoy time spent together.",
+ # "imgid": 0, "sentid": 2}
+ # ],
+ # "split": "train",
+ # "filename": "1000092795.jpg"
+ # },
+
+ if img['split'] != self.split:
+ continue
+
+ # create new idx for image
+ train_image = dict(
+ ori_id=img['imgid'],
+ image_id=img_idx, # used for evaluation
+ img_path=file_backend.join_path(img_prefix, img['filename']),
+ text=[],
+ gt_text_id=[],
+ gt_image_id=[],
+ )
+
+ for sentence in img['sentences']:
+ ann = {}
+ ann['text'] = sentence['raw']
+ ann['ori_id'] = sentence['sentid']
+ ann['text_id'] = sentence_idx # used for evaluation
+
+ ann['image_ori_id'] = train_image['ori_id']
+ ann['image_id'] = train_image['image_id']
+ ann['img_path'] = train_image['img_path']
+ ann['is_matched'] = True
+
+ # 1. prepare train data list item
+ train_list.append(ann)
+ # 2. prepare eval data list item based on img dict
+ train_image['text'].append(ann['text'])
+ train_image['gt_text_id'].append(ann['text_id'])
+ train_image['gt_image_id'].append(ann['image_id'])
+
+ sentence_idx += 1
+
+ img_dict[img['imgid']] = train_image
+ img_idx += 1
+
+ self.img_size = len(img_dict)
+ self.text_size = len(train_list)
+
+ # return needed format data list
+ if self.test_mode:
+ return list(img_dict.values())
+ return train_list
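+
+# Illustrative usage (a sketch; the paths are assumptions). In test mode the
+# data list has one entry per image with all of its paired captions,
+# otherwise one entry per caption:
+#
+#     gallery = Flickr30kRetrieval(
+#         data_root='data/flickr30k',
+#         data_prefix='images',
+#         ann_file='annotations/dataset_flickr30k.json',
+#         split='test',
+#         test_mode=True)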
diff --git a/mmpretrain/datasets/flowers102.py b/mmpretrain/datasets/flowers102.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe76dcc8422c8692261800b204a6262b60002e81
--- /dev/null
+++ b/mmpretrain/datasets/flowers102.py
@@ -0,0 +1,104 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List
+
+import mat4py
+from mmengine import get_file_backend
+
+from mmpretrain.registry import DATASETS
+from .base_dataset import BaseDataset
+
+
+@DATASETS.register_module()
+class Flowers102(BaseDataset):
+ """The Oxford 102 Flower Dataset.
+
+    Support the Oxford 102 Flower dataset. After downloading and
+    decompressing, the dataset directory structure is as follows.
+
+ Flowers102 dataset directory: ::
+
+ Flowers102
+ ├── jpg
+ │ ├── image_00001.jpg
+ │ ├── image_00002.jpg
+ │ └── ...
+ ├── imagelabels.mat
+ ├── setid.mat
+ └── ...
+
+ Args:
+ data_root (str): The root directory for Oxford 102 Flowers dataset.
+ split (str, optional): The dataset split, supports "train",
+ "val", "trainval", and "test". Default to "trainval".
+
+ Examples:
+ >>> from mmpretrain.datasets import Flowers102
+ >>> train_dataset = Flowers102(data_root='data/Flowers102', split='trainval')
+ >>> train_dataset
+ Dataset Flowers102
+ Number of samples: 2040
+ Root of dataset: data/Flowers102
+ >>> test_dataset = Flowers102(data_root='data/Flowers102', split='test')
+ >>> test_dataset
+ Dataset Flowers102
+ Number of samples: 6149
+ Root of dataset: data/Flowers102
+ """ # noqa: E501
+
+ def __init__(self, data_root: str, split: str = 'trainval', **kwargs):
+ splits = ['train', 'val', 'trainval', 'test']
+ assert split in splits, \
+ f"The split must be one of {splits}, but get '{split}'"
+ self.split = split
+
+ ann_file = 'imagelabels.mat'
+ data_prefix = 'jpg'
+ train_test_split_file = 'setid.mat'
+ test_mode = split == 'test'
+
+ self.backend = get_file_backend(data_root, enable_singleton=True)
+
+ self.train_test_split_file = self.backend.join_path(
+ data_root, train_test_split_file)
+
+ super(Flowers102, self).__init__(
+ ann_file=ann_file,
+ data_root=data_root,
+ data_prefix=data_prefix,
+ test_mode=test_mode,
+ **kwargs)
+
+ def load_data_list(self):
+ """Load images and ground truth labels."""
+
+ label_dict = mat4py.loadmat(self.ann_file)['labels']
+ split_list = mat4py.loadmat(self.train_test_split_file)
+
+ if self.split == 'train':
+ split_list = split_list['trnid']
+ elif self.split == 'val':
+ split_list = split_list['valid']
+ elif self.split == 'test':
+ split_list = split_list['tstid']
+ else:
+ train_ids = split_list['trnid']
+ val_ids = split_list['valid']
+ train_ids.extend(val_ids)
+ split_list = train_ids
+
+ data_list = []
+ for sample_id in split_list:
+ img_name = 'image_%05d.jpg' % (sample_id)
+ img_path = self.backend.join_path(self.img_prefix, img_name)
+ gt_label = int(label_dict[sample_id - 1]) - 1
+ info = dict(img_path=img_path, gt_label=gt_label)
+ data_list.append(info)
+
+ return data_list
+
+ def extra_repr(self) -> List[str]:
+ """The extra repr information of the dataset."""
+ body = [
+ f'Root of dataset: \t{self.data_root}',
+ ]
+ return body
diff --git a/mmpretrain/datasets/food101.py b/mmpretrain/datasets/food101.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ce7ffeee91c6843c259149770e9de4ad9f4317a
--- /dev/null
+++ b/mmpretrain/datasets/food101.py
@@ -0,0 +1,102 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List
+
+from mmengine import get_file_backend, list_from_file
+
+from mmpretrain.registry import DATASETS
+from .base_dataset import BaseDataset
+from .categories import FOOD101_CATEGORIES
+
+
+@DATASETS.register_module()
+class Food101(BaseDataset):
+ """The Food101 Dataset.
+
+    Support the Food101 dataset. After downloading and
+    decompressing, the dataset directory structure is as follows.
+
+ Food101 dataset directory: ::
+
+ food-101
+ ├── images
+ │ ├── class_x
+ │ │ ├── xx1.jpg
+ │ │ ├── xx2.jpg
+ │ │ └── ...
+ │ ├── class_y
+ │ │ ├── yy1.jpg
+ │ │ ├── yy2.jpg
+ │ │ └── ...
+ │ └── ...
+ ├── meta
+ │ ├── train.txt
+ │ └── test.txt
+ └── ....
+
+ Args:
+ data_root (str): The root directory for Food101 dataset.
+ split (str, optional): The dataset split, supports "train" and "test".
+ Default to "train".
+
+ Examples:
+ >>> from mmpretrain.datasets import Food101
+ >>> train_dataset = Food101(data_root='data/food-101', split='train')
+ >>> train_dataset
+ Dataset Food101
+ Number of samples: 75750
+ Number of categories: 101
+ Root of dataset: data/food-101
+ >>> test_dataset = Food101(data_root='data/food-101', split='test')
+ >>> test_dataset
+ Dataset Food101
+ Number of samples: 25250
+ Number of categories: 101
+ Root of dataset: data/food-101
+ """ # noqa: E501
+
+ METAINFO = {'classes': FOOD101_CATEGORIES}
+
+ def __init__(self, data_root: str, split: str = 'train', **kwargs):
+
+ splits = ['train', 'test']
+ assert split in splits, \
+ f"The split must be one of {splits}, but get '{split}'"
+ self.split = split
+
+ self.backend = get_file_backend(data_root, enable_singleton=True)
+ if split == 'train':
+ ann_file = self.backend.join_path('meta', 'train.txt')
+ else:
+ ann_file = self.backend.join_path('meta', 'test.txt')
+
+ test_mode = split == 'test'
+ data_prefix = 'images'
+
+ super(Food101, self).__init__(
+ ann_file=ann_file,
+ data_root=data_root,
+ test_mode=test_mode,
+ data_prefix=data_prefix,
+ **kwargs)
+
+ def load_data_list(self):
+ """Load images and ground truth labels."""
+
+ pairs = list_from_file(self.ann_file)
+ data_list = []
+ for pair in pairs:
+ class_name, img_name = pair.split('/')
+ img_name = f'{img_name}.jpg'
+ img_path = self.backend.join_path(self.img_prefix, class_name,
+ img_name)
+ gt_label = self.METAINFO['classes'].index(class_name)
+ info = dict(img_path=img_path, gt_label=gt_label)
+ data_list.append(info)
+ return data_list
+
+ def extra_repr(self) -> List[str]:
+ """The extra repr information of the dataset."""
+ body = [
+ f'Root of dataset: \t{self.data_root}',
+ ]
+ return body
diff --git a/mmpretrain/datasets/gqa_dataset.py b/mmpretrain/datasets/gqa_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..741791bc2bb51f768e8907aac7f002f0e730aeea
--- /dev/null
+++ b/mmpretrain/datasets/gqa_dataset.py
@@ -0,0 +1,70 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+from typing import List
+
+import mmengine
+from mmengine.dataset import BaseDataset
+
+from mmpretrain.registry import DATASETS
+
+
+@DATASETS.register_module()
+class GQA(BaseDataset):
+ """GQA dataset.
+
+    We use the annotation file from LAVIS, and you can download all annotation files from the following links: # noqa: E501
+
+ train:
+ https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/gqa/train_balanced_questions.json # noqa: E501
+ val:
+ https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/gqa/testdev_balanced_questions.json # noqa: E501
+ test:
+ https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/gqa/test_balanced_questions.json # noqa: E501
+
+ and images from the official website:
+ https://cs.stanford.edu/people/dorarad/gqa/index.html
+
+ Args:
+ data_root (str): The root directory for ``data_prefix``, ``ann_file``
+ and ``question_file``.
+ data_prefix (str): The directory of images.
+ ann_file (str, optional): Annotation file path for training and
+ validation. Defaults to an empty string.
+ **kwargs: Other keyword arguments in :class:`BaseDataset`.
+ """
+
+ def __init__(self,
+ data_root: str,
+ data_prefix: str,
+ ann_file: str = '',
+ **kwarg):
+ super().__init__(
+ data_root=data_root,
+ data_prefix=dict(img_path=data_prefix),
+ ann_file=ann_file,
+ **kwarg,
+ )
+
+ def load_data_list(self) -> List[dict]:
+ """Load data list."""
+ annotations = mmengine.load(self.ann_file)
+
+ data_list = []
+ for ann in annotations:
+ # ann example
+ # {
+ # 'question': "Is it overcast?",
+            #     'answer': 'no',
+            #     'image': 'n161313.jpg',
+ # 'question_id': 262148000,
+ # ....
+ # }
+ data_info = dict()
+ data_info['img_path'] = osp.join(self.data_prefix['img_path'],
+ ann['image'])
+ data_info['question'] = ann['question']
+ data_info['gt_answer'] = ann['answer']
+
+ data_list.append(data_info)
+
+ return data_list
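+
+# Illustrative usage (a sketch; the local paths are assumptions, the
+# annotation file name comes from the LAVIS links in the class docstring):
+#
+#     dataset = GQA(
+#         data_root='data/gqa',
+#         data_prefix='images',
+#         ann_file='annotations/testdev_balanced_questions.json')
+#     # Each sample provides 'img_path', 'question' and 'gt_answer'.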
diff --git a/mmpretrain/datasets/imagenet.py b/mmpretrain/datasets/imagenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..771d6ee454e3dc094962ca09036888f97ffb2d21
--- /dev/null
+++ b/mmpretrain/datasets/imagenet.py
@@ -0,0 +1,235 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional, Union
+
+from mmengine import fileio
+from mmengine.logging import MMLogger
+
+from mmpretrain.registry import DATASETS
+from .categories import IMAGENET_CATEGORIES
+from .custom import CustomDataset
+
+
+@DATASETS.register_module()
+class ImageNet(CustomDataset):
+ """`ImageNet `_ Dataset.
+
+ The dataset supports two kinds of directory format,
+
+ ::
+
+ imagenet
+ ├── train
+ │ ├──class_x
+ | | ├── x1.jpg
+ | | ├── x2.jpg
+ | | └── ...
+ │ ├── class_y
+ | | ├── y1.jpg
+ | | ├── y2.jpg
+ | | └── ...
+ | └── ...
+ ├── val
+ │ ├──class_x
+ | | └── ...
+ │ ├── class_y
+ | | └── ...
+ | └── ...
+ └── test
+ ├── test1.jpg
+ ├── test2.jpg
+ └── ...
+
+ or ::
+
+ imagenet
+ ├── train
+ │ ├── x1.jpg
+ │ ├── y1.jpg
+ │ └── ...
+ ├── val
+ │ ├── x3.jpg
+ │ ├── y3.jpg
+ │ └── ...
+ ├── test
+ │ ├── test1.jpg
+ │ ├── test2.jpg
+ │ └── ...
+ └── meta
+ ├── train.txt
+ └── val.txt
+
+
+ Args:
+ data_root (str): The root directory for ``data_prefix`` and
+ ``ann_file``. Defaults to ''.
+ split (str): The dataset split, supports "train", "val" and "test".
+ Default to ''.
+ data_prefix (str | dict): Prefix for training data. Defaults to ''.
+ ann_file (str): Annotation file path. Defaults to ''.
+ metainfo (dict, optional): Meta information for dataset, such as class
+ information. Defaults to None.
+ **kwargs: Other keyword arguments in :class:`CustomDataset` and
+ :class:`BaseDataset`.
+
+
+ Examples:
+ >>> from mmpretrain.datasets import ImageNet
+ >>> train_dataset = ImageNet(data_root='data/imagenet', split='train')
+ >>> train_dataset
+ Dataset ImageNet
+ Number of samples: 1281167
+ Number of categories: 1000
+ Root of dataset: data/imagenet
+ >>> test_dataset = ImageNet(data_root='data/imagenet', split='val')
+ >>> test_dataset
+ Dataset ImageNet
+ Number of samples: 50000
+ Number of categories: 1000
+ Root of dataset: data/imagenet
+ """ # noqa: E501
+
+ IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif')
+ METAINFO = {'classes': IMAGENET_CATEGORIES}
+
+ def __init__(self,
+ data_root: str = '',
+ split: str = '',
+ data_prefix: Union[str, dict] = '',
+ ann_file: str = '',
+ metainfo: Optional[dict] = None,
+ **kwargs):
+ kwargs = {'extensions': self.IMG_EXTENSIONS, **kwargs}
+
+ if split:
+ splits = ['train', 'val', 'test']
+ assert split in splits, \
+ f"The split must be one of {splits}, but get '{split}'"
+
+ if split == 'test':
+ logger = MMLogger.get_current_instance()
+ logger.info(
+                    'Since the ImageNet1k test set does not provide label '
+                    'annotations, `with_label` is set to False')
+ kwargs['with_label'] = False
+
+ data_prefix = split if data_prefix == '' else data_prefix
+
+ if ann_file == '':
+ _ann_path = fileio.join_path(data_root, 'meta', f'{split}.txt')
+ if fileio.exists(_ann_path):
+ ann_file = fileio.join_path('meta', f'{split}.txt')
+
+ super().__init__(
+ data_root=data_root,
+ data_prefix=data_prefix,
+ ann_file=ann_file,
+ metainfo=metainfo,
+ **kwargs)
+
+ def extra_repr(self) -> List[str]:
+ """The extra repr information of the dataset."""
+ body = [
+ f'Root of dataset: \t{self.data_root}',
+ ]
+ return body
+
+
+@DATASETS.register_module()
+class ImageNet21k(CustomDataset):
+ """ImageNet21k Dataset.
+
+    Since the ImageNet21k dataset is extremely large, with 21k+ classes
+    and 1.4B files, we do not provide the default categories list. Please
+    specify it via the ``classes`` argument.
+ The dataset directory structure is as follows,
+
+ ImageNet21k dataset directory ::
+
+ imagenet21k
+ ├── train
+ │ ├──class_x
+ | | ├── x1.jpg
+ | | ├── x2.jpg
+ | | └── ...
+ │ ├── class_y
+ | | ├── y1.jpg
+ | | ├── y2.jpg
+ | | └── ...
+ | └── ...
+ └── meta
+ └── train.txt
+
+
+ Args:
+ data_root (str): The root directory for ``data_prefix`` and
+ ``ann_file``. Defaults to ''.
+ data_prefix (str | dict): Prefix for training data. Defaults to ''.
+ ann_file (str): Annotation file path. Defaults to ''.
+ metainfo (dict, optional): Meta information for dataset, such as class
+ information. Defaults to None.
+        multi_label (bool): Whether to use multi-label annotations. Not
+            implemented yet.
+ Defaults to False.
+ **kwargs: Other keyword arguments in :class:`CustomDataset` and
+ :class:`BaseDataset`.
+
+ Examples:
+ >>> from mmpretrain.datasets import ImageNet21k
+ >>> train_dataset = ImageNet21k(data_root='data/imagenet21k', split='train')
+ >>> train_dataset
+ Dataset ImageNet21k
+ Number of samples: 14197088
+ Annotation file: data/imagenet21k/meta/train.txt
+ Prefix of images: data/imagenet21k/train
+ """ # noqa: E501
+
+ IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif')
+
+ def __init__(self,
+ data_root: str = '',
+ split: str = '',
+ data_prefix: Union[str, dict] = '',
+ ann_file: str = '',
+ metainfo: Optional[dict] = None,
+ multi_label: bool = False,
+ **kwargs):
+ if multi_label:
+ raise NotImplementedError(
+                'The `multi_label` option is not supported yet.')
+ self.multi_label = multi_label
+
+ if split:
+ splits = ['train']
+ assert split in splits, \
+ f"The split must be one of {splits}, but get '{split}'.\
+ If you want to specify your own validation set or test set,\
+ please set split to None."
+
+ self.split = split
+ data_prefix = split if data_prefix == '' else data_prefix
+
+ if not ann_file:
+ _ann_path = fileio.join_path(data_root, 'meta', f'{split}.txt')
+ if fileio.exists(_ann_path):
+ ann_file = fileio.join_path('meta', f'{split}.txt')
+
+ logger = MMLogger.get_current_instance()
+
+ if not ann_file:
+ logger.warning(
+                'The ImageNet21k dataset is large, and scanning the directory '
+                'may take a long time. Consider specifying the `ann_file` to '
+                'accelerate the initialization.')
+
+ kwargs = {'extensions': self.IMG_EXTENSIONS, **kwargs}
+ super().__init__(
+ data_root=data_root,
+ data_prefix=data_prefix,
+ ann_file=ann_file,
+ metainfo=metainfo,
+ **kwargs)
+
+ if self.CLASSES is None:
+ logger.warning(
+ 'The CLASSES is not stored in the `ImageNet21k` class. '
+                'Consider specifying the `classes` argument if you need to '
+                'do inference on the ImageNet-21k dataset.')
diff --git a/mmpretrain/datasets/inshop.py b/mmpretrain/datasets/inshop.py
new file mode 100644
index 0000000000000000000000000000000000000000..f64f1779632d4a98d0e36d59750f4a1e8cbd4aed
--- /dev/null
+++ b/mmpretrain/datasets/inshop.py
@@ -0,0 +1,157 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmengine import get_file_backend, list_from_file
+
+from mmpretrain.registry import DATASETS
+from .base_dataset import BaseDataset
+
+
+@DATASETS.register_module()
+class InShop(BaseDataset):
+ """InShop Dataset for Image Retrieval.
+
+ Please download the images from the homepage
+ 'https://mmlab.ie.cuhk.edu.hk/projects/DeepFashion/InShopRetrieval.html'
+ (In-shop Clothes Retrieval Benchmark -> Img -> img.zip,
+    Eval/list_eval_partition.txt), and organize them in the following way: ::
+
+ In-shop Clothes Retrieval Benchmark (data_root)/
+ ├── Eval /
+ │ └── list_eval_partition.txt (ann_file)
+ ├── Img (img_prefix)
+ │ └── img/
+ ├── README.txt
+ └── .....
+
+ Args:
+ data_root (str): The root directory for dataset.
+ split (str): Choose from 'train', 'query' and 'gallery'.
+ Defaults to 'train'.
+ data_prefix (str | dict): Prefix for training data.
+ Defaults to 'Img'.
+ ann_file (str): Annotation file path, path relative to
+ ``data_root``. Defaults to 'Eval/list_eval_partition.txt'.
+ **kwargs: Other keyword arguments in :class:`BaseDataset`.
+
+ Examples:
+ >>> from mmpretrain.datasets import InShop
+ >>>
+ >>> # build train InShop dataset
+ >>> inshop_train_cfg = dict(data_root='data/inshop', split='train')
+ >>> inshop_train = InShop(**inshop_train_cfg)
+ >>> inshop_train
+ Dataset InShop
+ Number of samples: 25882
+ The `CLASSES` meta info is not set.
+ Root of dataset: data/inshop
+ >>>
+ >>> # build query InShop dataset
+ >>> inshop_query_cfg = dict(data_root='data/inshop', split='query')
+ >>> inshop_query = InShop(**inshop_query_cfg)
+ >>> inshop_query
+ Dataset InShop
+ Number of samples: 14218
+ The `CLASSES` meta info is not set.
+ Root of dataset: data/inshop
+ >>>
+ >>> # build gallery InShop dataset
+ >>> inshop_gallery_cfg = dict(data_root='data/inshop', split='gallery')
+ >>> inshop_gallery = InShop(**inshop_gallery_cfg)
+ >>> inshop_gallery
+ Dataset InShop
+ Number of samples: 12612
+ The `CLASSES` meta info is not set.
+ Root of dataset: data/inshop
+ """
+
+ def __init__(self,
+ data_root: str,
+ split: str = 'train',
+ data_prefix: str = 'Img',
+ ann_file: str = 'Eval/list_eval_partition.txt',
+ **kwargs):
+
+ assert split in ('train', 'query', 'gallery'), "'split' of `InShop`" \
+ f" must be one of ['train', 'query', 'gallery'], bu get '{split}'"
+ self.backend = get_file_backend(data_root, enable_singleton=True)
+ self.split = split
+ super().__init__(
+ data_root=data_root,
+ data_prefix=data_prefix,
+ ann_file=ann_file,
+ **kwargs)
+
+ def _process_annotations(self):
+ lines = list_from_file(self.ann_file)
+
+ anno_train = dict(metainfo=dict(), data_list=list())
+ anno_gallery = dict(metainfo=dict(), data_list=list())
+
+ # item_id to label, each item corresponds to one class label
+ class_num = 0
+ gt_label_train = {}
+
+ # item_id to label, each label corresponds to several items
+ gallery_num = 0
+ gt_label_gallery = {}
+
+        # (lines[0], lines[1]) are the image number and the field names;
+        # each remaining line is formatted as
+        # 'image_name item_id evaluation_status'
+ for line in lines[2:]:
+ img_name, item_id, status = line.split()
+ img_path = self.backend.join_path(self.img_prefix, img_name)
+ if status == 'train':
+ if item_id not in gt_label_train:
+ gt_label_train[item_id] = class_num
+ class_num += 1
+ # item_id to class_id (for the training set)
+ anno_train['data_list'].append(
+ dict(img_path=img_path, gt_label=gt_label_train[item_id]))
+ elif status == 'gallery':
+ if item_id not in gt_label_gallery:
+ gt_label_gallery[item_id] = []
+ # Since there are multiple images for each item,
+ # record the corresponding item for each image.
+ gt_label_gallery[item_id].append(gallery_num)
+ anno_gallery['data_list'].append(
+ dict(img_path=img_path, sample_idx=gallery_num))
+ gallery_num += 1
+
+ if self.split == 'train':
+ anno_train['metainfo']['class_number'] = class_num
+ anno_train['metainfo']['sample_number'] = \
+ len(anno_train['data_list'])
+ return anno_train
+ elif self.split == 'gallery':
+ anno_gallery['metainfo']['sample_number'] = gallery_num
+ return anno_gallery
+
+ # Generate the label for the query(val) set
+ anno_query = dict(metainfo=dict(), data_list=list())
+ query_num = 0
+ for line in lines[2:]:
+ img_name, item_id, status = line.split()
+ img_path = self.backend.join_path(self.img_prefix, img_name)
+ if status == 'query':
+ anno_query['data_list'].append(
+ dict(
+ img_path=img_path, gt_label=gt_label_gallery[item_id]))
+ query_num += 1
+
+ anno_query['metainfo']['sample_number'] = query_num
+ return anno_query
+
+ def load_data_list(self):
+ """load data list.
+
+ For the train set, return image and ground truth label. For the query
+ set, return image and ids of images in gallery. For the gallery set,
+ return image and its id.
+ """
+ data_info = self._process_annotations()
+ data_list = data_info['data_list']
+ return data_list
+
+ def extra_repr(self):
+ """The extra repr information of the dataset."""
+ body = [f'Root of dataset: \t{self.data_root}']
+ return body
diff --git a/mmpretrain/datasets/mnist.py b/mmpretrain/datasets/mnist.py
new file mode 100644
index 0000000000000000000000000000000000000000..425267fe8034860d3b78c6af5b565ddb6efc7c10
--- /dev/null
+++ b/mmpretrain/datasets/mnist.py
@@ -0,0 +1,234 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import codecs
+from typing import List, Optional
+from urllib.parse import urljoin
+
+import mmengine.dist as dist
+import numpy as np
+import torch
+from mmengine.fileio import LocalBackend, exists, get_file_backend, join_path
+from mmengine.logging import MMLogger
+
+from mmpretrain.registry import DATASETS
+from .base_dataset import BaseDataset
+from .categories import FASHIONMNIST_CATEGORITES, MNIST_CATEGORITES
+from .utils import (download_and_extract_archive, open_maybe_compressed_file,
+ rm_suffix)
+
+
+@DATASETS.register_module()
+class MNIST(BaseDataset):
+ """`MNIST `_ Dataset.
+
+ This implementation is modified from
+ https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py
+
+ Args:
+ data_root (str): The root directory of the MNIST Dataset.
+ split (str, optional): The dataset split, supports "train" and "test".
+ Default to "train".
+ metainfo (dict, optional): Meta information for dataset, such as
+ categories information. Defaults to None.
+ download (bool): Whether to download the dataset if not exists.
+ Defaults to True.
+ **kwargs: Other keyword arguments in :class:`BaseDataset`.
+ """ # noqa: E501
+
+ url_prefix = 'http://yann.lecun.com/exdb/mnist/'
+ # train images and labels
+ train_list = [
+ ['train-images-idx3-ubyte.gz', 'f68b3c2dcbeaaa9fbdd348bbdeb94873'],
+ ['train-labels-idx1-ubyte.gz', 'd53e105ee54ea40749a09fcbcd1e9432'],
+ ]
+ # test images and labels
+ test_list = [
+ ['t10k-images-idx3-ubyte.gz', '9fb629c4189551a2d022fa330f9573f3'],
+ ['t10k-labels-idx1-ubyte.gz', 'ec29112dd5afa0611ce80d1b7f02629c'],
+ ]
+ METAINFO = {'classes': MNIST_CATEGORITES}
+
+ def __init__(self,
+ data_root: str = '',
+ split: str = 'train',
+ metainfo: Optional[dict] = None,
+ download: bool = True,
+ data_prefix: str = '',
+ test_mode: bool = False,
+ **kwargs):
+
+ splits = ['train', 'test']
+ assert split in splits, \
+ f"The split must be one of {splits}, but get '{split}'"
+ self.split = split
+
+ # To handle the BC-breaking
+ if split == 'train' and test_mode:
+ logger = MMLogger.get_current_instance()
+ logger.warning('split="train" but test_mode=True. '
+ 'The training set will be used.')
+
+ if not data_root and not data_prefix:
+            raise RuntimeError('Please set ``data_root`` to '
+ 'specify the dataset path')
+
+ self.download = download
+ super().__init__(
+ # The MNIST dataset doesn't need specify annotation file
+ ann_file='',
+ metainfo=metainfo,
+ data_root=data_root,
+ data_prefix=dict(root=data_prefix),
+ test_mode=test_mode,
+ **kwargs)
+
+ def load_data_list(self):
+ """Load images and ground truth labels."""
+ root = self.data_prefix['root']
+ backend = get_file_backend(root, enable_singleton=True)
+
+ if dist.is_main_process() and not self._check_exists():
+ if not isinstance(backend, LocalBackend):
+ raise RuntimeError(f'The dataset on {root} is not integrated, '
+ f'please manually handle it.')
+
+ if self.download:
+ self._download()
+ else:
+ raise RuntimeError(
+ f'Cannot find {self.__class__.__name__} dataset in '
+ f"{self.data_prefix['root']}, you can specify "
+ '`download=True` to download automatically.')
+
+ dist.barrier()
+ assert self._check_exists(), \
+ 'Download failed or shared storage is unavailable. Please ' \
+ f'download the dataset manually through {self.url_prefix}.'
+
+ if not self.test_mode:
+ file_list = self.train_list
+ else:
+ file_list = self.test_list
+
+ # load data from SN3 files
+ imgs = read_image_file(join_path(root, rm_suffix(file_list[0][0])))
+ gt_labels = read_label_file(
+ join_path(root, rm_suffix(file_list[1][0])))
+
+ data_infos = []
+ for img, gt_label in zip(imgs, gt_labels):
+ gt_label = np.array(gt_label, dtype=np.int64)
+ info = {'img': img.numpy(), 'gt_label': gt_label}
+ data_infos.append(info)
+ return data_infos
+
+ def _check_exists(self):
+ """Check the exists of data files."""
+ root = self.data_prefix['root']
+
+ for filename, _ in (self.train_list + self.test_list):
+ # get extracted filename of data
+ extract_filename = rm_suffix(filename)
+ fpath = join_path(root, extract_filename)
+ if not exists(fpath):
+ return False
+ return True
+
+ def _download(self):
+ """Download and extract data files."""
+ root = self.data_prefix['root']
+
+ for filename, md5 in (self.train_list + self.test_list):
+ url = urljoin(self.url_prefix, filename)
+ download_and_extract_archive(
+ url, download_root=root, filename=filename, md5=md5)
+
+ def extra_repr(self) -> List[str]:
+ """The extra repr information of the dataset."""
+ body = [f"Prefix of data: \t{self.data_prefix['root']}"]
+ return body
+
+
+@DATASETS.register_module()
+class FashionMNIST(MNIST):
+ """`Fashion-MNIST `_
+ Dataset.
+
+ Args:
+ data_root (str): The root directory of the MNIST Dataset.
+ split (str, optional): The dataset split, supports "train" and "test".
+ Default to "train".
+ metainfo (dict, optional): Meta information for dataset, such as
+ categories information. Defaults to None.
+ download (bool): Whether to download the dataset if not exists.
+ Defaults to True.
+ **kwargs: Other keyword arguments in :class:`BaseDataset`.
+ """
+
+ url_prefix = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/'
+ # train images and labels
+ train_list = [
+ ['train-images-idx3-ubyte.gz', '8d4fb7e6c68d591d4c3dfef9ec88bf0d'],
+ ['train-labels-idx1-ubyte.gz', '25c81989df183df01b3e8a0aad5dffbe'],
+ ]
+ # test images and labels
+ test_list = [
+ ['t10k-images-idx3-ubyte.gz', 'bef4ecab320f06d8554ea6380940ec79'],
+ ['t10k-labels-idx1-ubyte.gz', 'bb300cfdad3c16e7a12a480ee83cd310'],
+ ]
+ METAINFO = {'classes': FASHIONMNIST_CATEGORITES}
+
+
+def get_int(b: bytes) -> int:
+ """Convert bytes to int."""
+ return int(codecs.encode(b, 'hex'), 16)
+
+
+def read_sn3_pascalvincent_tensor(path: str,
+ strict: bool = True) -> torch.Tensor:
+ """Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-
+ io.lsh').
+
+ Argument may be a filename, compressed filename, or file object.
+ """
+ # typemap
+ if not hasattr(read_sn3_pascalvincent_tensor, 'typemap'):
+ read_sn3_pascalvincent_tensor.typemap = {
+ 8: (torch.uint8, np.uint8, np.uint8),
+ 9: (torch.int8, np.int8, np.int8),
+ 11: (torch.int16, np.dtype('>i2'), 'i2'),
+ 12: (torch.int32, np.dtype('>i4'), 'i4'),
+ 13: (torch.float32, np.dtype('>f4'), 'f4'),
+ 14: (torch.float64, np.dtype('>f8'), 'f8')
+ }
+ # read
+ with open_maybe_compressed_file(path) as f:
+ data = f.read()
+ # parse
+ magic = get_int(data[0:4])
+ nd = magic % 256
+ ty = magic // 256
+ assert nd >= 1 and nd <= 3
+ assert ty >= 8 and ty <= 14
+ m = read_sn3_pascalvincent_tensor.typemap[ty]
+ s = [get_int(data[4 * (i + 1):4 * (i + 2)]) for i in range(nd)]
+ parsed = np.frombuffer(data, dtype=m[1], offset=(4 * (nd + 1)))
+ assert parsed.shape[0] == np.prod(s) or not strict
+ return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
+
+
+def read_label_file(path: str) -> torch.Tensor:
+ """Read labels from SN3 label file."""
+ with open(path, 'rb') as f:
+ x = read_sn3_pascalvincent_tensor(f, strict=False)
+ assert (x.dtype == torch.uint8)
+ assert (x.ndimension() == 1)
+ return x.long()
+
+
+def read_image_file(path: str) -> torch.Tensor:
+ """Read images from SN3 image file."""
+ with open(path, 'rb') as f:
+ x = read_sn3_pascalvincent_tensor(f, strict=False)
+ assert (x.dtype == torch.uint8)
+ assert (x.ndimension() == 3)
+ return x
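+
+# Illustrative usage (a sketch; the data roots are assumptions):
+#
+#     train_set = MNIST(data_root='data/mnist', split='train', download=True)
+#     test_set = FashionMNIST(data_root='data/fashion-mnist', split='test')
+#     # Samples carry the decoded image array under 'img' and the label
+#     # under 'gt_label'; the archives are downloaded and extracted on
+#     # first use when download=True.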
diff --git a/mmpretrain/datasets/multi_label.py b/mmpretrain/datasets/multi_label.py
new file mode 100644
index 0000000000000000000000000000000000000000..58a9c7cd5f097689d29700004e2ed815934a1594
--- /dev/null
+++ b/mmpretrain/datasets/multi_label.py
@@ -0,0 +1,85 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List
+
+from mmpretrain.registry import DATASETS
+from .base_dataset import BaseDataset
+
+
+@DATASETS.register_module()
+class MultiLabelDataset(BaseDataset):
+ """Multi-label Dataset.
+
+    This dataset supports annotation files in the `OpenMMLab 2.0 style
+    annotation format`.
+
+ The annotation format is shown as follows.
+
+ .. code-block:: none
+
+ {
+ "metainfo":
+ {
+ "classes":['A', 'B', 'C'....]
+ },
+ "data_list":
+ [
+ {
+ "img_path": "test_img1.jpg",
+ 'gt_label': [0, 1],
+ },
+ {
+ "img_path": "test_img2.jpg",
+ 'gt_label': [2],
+ },
+ ]
+ ....
+ }
+
+
+ Args:
+ ann_file (str): Annotation file path.
+ metainfo (dict, optional): Meta information for dataset, such as class
+ information. Defaults to None.
+ data_root (str): The root directory for ``data_prefix`` and
+ ``ann_file``. Defaults to ''.
+ data_prefix (str | dict): Prefix for training data. Defaults to ''.
+ filter_cfg (dict, optional): Config for filter data. Defaults to None.
+ indices (int or Sequence[int], optional): Support using first few
+ data in annotation file to facilitate training/testing on a smaller
+ dataset. Defaults to None which means using all ``data_infos``.
+ serialize_data (bool, optional): Whether to hold memory using
+ serialized objects, when enabled, data loader workers can use
+ shared RAM from master process instead of making a copy. Defaults
+ to True.
+ pipeline (list, optional): Processing pipeline. Defaults to [].
+ test_mode (bool, optional): ``test_mode=True`` means in test phase.
+ Defaults to False.
+        lazy_init (bool, optional): Whether to defer loading the annotation
+            file. In some cases, such as visualization, only the meta
+            information of the dataset is needed, so there is no need to
+            load the annotation file. ``BaseDataset`` can skip loading
+            annotations to save time by setting ``lazy_init=True``.
+            Defaults to False.
+        max_refetch (int, optional): If ``BaseDataset.prepare_data`` gets a
+            None image, the maximum number of extra cycles to fetch a valid
+            image. Defaults to 1000.
+ classes (str | Sequence[str], optional): Specify names of classes.
+
+            - If it is a string, it should be a file path, and every line of
+              the file is the name of a class.
+            - If it is a sequence of strings, every item is the name of a
+              class.
+            - If it is None, use the categories information in the
+              ``metainfo`` argument, the annotation file or the class
+              attribute ``METAINFO``.
+
+ Defaults to None.
+ """
+
+ def get_cat_ids(self, idx: int) -> List[int]:
+ """Get category ids by index.
+
+ Args:
+ idx (int): Index of data.
+
+ Returns:
+ cat_ids (List[int]): Image categories of specified index.
+ """
+ return self.get_data_info(idx)['gt_label']
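+
+# Illustrative usage (a sketch; the paths are assumptions, the annotation
+# file follows the format shown in the class docstring):
+#
+#     dataset = MultiLabelDataset(
+#         data_root='data/my_multilabel',
+#         ann_file='annotations/train.json',
+#         data_prefix='images')
+#     dataset.get_cat_ids(0)  # e.g. [0, 1], the category ids of sample 0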
diff --git a/mmpretrain/datasets/multi_task.py b/mmpretrain/datasets/multi_task.py
new file mode 100644
index 0000000000000000000000000000000000000000..443df0e7d7de11962d472d33b25b4bbff562524f
--- /dev/null
+++ b/mmpretrain/datasets/multi_task.py
@@ -0,0 +1,337 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import os.path as osp
+from os import PathLike
+from typing import Optional, Sequence
+
+import mmengine
+from mmcv.transforms import Compose
+from mmengine.fileio import get_file_backend
+
+from .builder import DATASETS
+
+
+def expanduser(path):
+ if isinstance(path, (str, PathLike)):
+ return osp.expanduser(path)
+ else:
+ return path
+
+
+def isabs(uri):
+ return osp.isabs(uri) or ('://' in uri)
+
+
+@DATASETS.register_module()
+class MultiTaskDataset:
+ """Custom dataset for multi-task dataset.
+
+ To use the dataset, please generate and provide an annotation file in the
+ below format:
+
+ .. code-block:: json
+
+ {
+ "metainfo": {
+ "tasks":
+ [
+ 'gender'
+ 'wear'
+ ]
+ },
+ "data_list": [
+ {
+ "img_path": "a.jpg",
+ gt_label:{
+ "gender": 0,
+ "wear": [1, 0, 1, 0]
+ }
+ },
+ {
+ "img_path": "b.jpg",
+ gt_label:{
+ "gender": 1,
+ "wear": [1, 0, 1, 0]
+ }
+ }
+ ]
+ }
+
+ Assume we put our dataset in the ``data/mydataset`` folder in the
+ repository and organize it as the below format: ::
+
+ mmpretrain/
+ └── data
+ └── mydataset
+ ├── annotation
+ │ ├── train.json
+ │ ├── test.json
+ │ └── val.json
+ ├── train
+ │ ├── a.jpg
+ │ └── ...
+ ├── test
+ │ ├── b.jpg
+ │ └── ...
+ └── val
+ ├── c.jpg
+ └── ...
+
+ We can use the below config to build datasets:
+
+ .. code:: python
+
+ >>> from mmpretrain.datasets import build_dataset
+ >>> train_cfg = dict(
+ ... type="MultiTaskDataset",
+ ... ann_file="annotation/train.json",
+ ... data_root="data/mydataset",
+ ... # The `img_path` field in the train annotation file is relative
+ ... # to the `train` folder.
+ ... data_prefix='train',
+ ... )
+ >>> train_dataset = build_dataset(train_cfg)
+
+ Or we can put all files in the same folder: ::
+
+ mmpretrain/
+ └── data
+ └── mydataset
+ ├── train.json
+ ├── test.json
+ ├── val.json
+ ├── a.jpg
+ ├── b.jpg
+ ├── c.jpg
+ └── ...
+
+ And we can use the below config to build datasets:
+
+ .. code:: python
+
+ >>> from mmpretrain.datasets import build_dataset
+ >>> train_cfg = dict(
+ ... type="MultiTaskDataset",
+ ... ann_file="train.json",
+ ... data_root="data/mydataset",
+ ... # the `data_prefix` is not required since all paths are
+ ... # relative to the `data_root`.
+ ... )
+ >>> train_dataset = build_dataset(train_cfg)
+
+
+ Args:
+ ann_file (str): The annotation file path. It can be either absolute
+ path or relative path to the ``data_root``.
+ metainfo (dict, optional): The extra meta information. It should be
+ a dict with the same format as the ``"metainfo"`` field in the
+ annotation file. Defaults to None.
+ data_root (str, optional): The root path of the data directory. It's
+ the prefix of the ``data_prefix`` and the ``ann_file``. And it can
+ be a remote path like "s3://openmmlab/xxx/". Defaults to None.
+ data_prefix (str, optional): The base folder relative to the
+ ``data_root`` for the ``"img_path"`` field in the annotation file.
+ Defaults to None.
+        pipeline (Sequence[dict]): A list of dicts, where each element
+            represents an operation defined in
+            :mod:`mmpretrain.datasets.pipelines`. Defaults to an empty tuple.
+        test_mode (bool): Whether in test mode. Defaults to False.
+ """
+ METAINFO = dict()
+
+ def __init__(self,
+ ann_file: str,
+ metainfo: Optional[dict] = None,
+ data_root: Optional[str] = None,
+ data_prefix: Optional[str] = None,
+ pipeline: Sequence = (),
+ test_mode: bool = False):
+
+ self.data_root = expanduser(data_root)
+
+        # Infer the file backend from the data root
+ if self.data_root is not None:
+ self.file_backend = get_file_backend(uri=self.data_root)
+ else:
+ self.file_backend = None
+
+ self.ann_file = self._join_root(expanduser(ann_file))
+ self.data_prefix = self._join_root(data_prefix)
+
+ self.test_mode = test_mode
+ self.pipeline = Compose(pipeline)
+ self.data_list = self.load_data_list(self.ann_file, metainfo)
+
+ def _join_root(self, path):
+ """Join ``self.data_root`` with the specified path.
+
+ If the path is an absolute path, just return the path. And if the
+ path is None, return ``self.data_root``.
+
+ Examples:
+ >>> self.data_root = 'a/b/c'
+ >>> self._join_root('d/e/')
+ 'a/b/c/d/e'
+ >>> self._join_root('https://openmmlab.com')
+ 'https://openmmlab.com'
+ >>> self._join_root(None)
+ 'a/b/c'
+ """
+ if path is None:
+ return self.data_root
+ if isabs(path):
+ return path
+
+ joined_path = self.file_backend.join_path(self.data_root, path)
+ return joined_path
+
+ @classmethod
+ def _get_meta_info(cls, in_metainfo: dict = None) -> dict:
+ """Collect meta information from the dictionary of meta.
+
+ Args:
+ in_metainfo (dict): Meta information dict.
+
+ Returns:
+ dict: Parsed meta information.
+ """
+ # `cls.METAINFO` will be overwritten by in_meta
+ metainfo = copy.deepcopy(cls.METAINFO)
+ if in_metainfo is None:
+ return metainfo
+
+ metainfo.update(in_metainfo)
+
+ return metainfo
+
+ def load_data_list(self, ann_file, metainfo_override=None):
+ """Load annotations from an annotation file.
+
+ Args:
+            ann_file (str): The absolute annotation file path if
+                ``self.data_root`` is None, otherwise a path relative to
+                ``self.data_root``.
+
+ Returns:
+ list[dict]: A list of annotation.
+ """
+ annotations = mmengine.load(ann_file)
+ if not isinstance(annotations, dict):
+ raise TypeError(f'The annotations loaded from annotation file '
+ f'should be a dict, but got {type(annotations)}!')
+ if 'data_list' not in annotations:
+ raise ValueError('The annotation file must have the `data_list` '
+ 'field.')
+ metainfo = annotations.get('metainfo', {})
+ raw_data_list = annotations['data_list']
+
+ # Set meta information.
+ assert isinstance(metainfo, dict), 'The `metainfo` field in the '\
+ f'annotation file should be a dict, but got {type(metainfo)}'
+ if metainfo_override is not None:
+ assert isinstance(metainfo_override, dict), 'The `metainfo` ' \
+ f'argument should be a dict, but got {type(metainfo_override)}'
+ metainfo.update(metainfo_override)
+ self._metainfo = self._get_meta_info(metainfo)
+
+ data_list = []
+ for i, raw_data in enumerate(raw_data_list):
+ try:
+ data_list.append(self.parse_data_info(raw_data))
+ except AssertionError as e:
+ raise RuntimeError(
+                    f'The format check failed while parsing item {i} of '
+ f'the annotation file with error: {e}')
+ return data_list
+
+ def parse_data_info(self, raw_data):
+ """Parse raw annotation to target format.
+
+ This method will return a dict which contains the data information of a
+ sample.
+
+ Args:
+ raw_data (dict): Raw data information load from ``ann_file``
+
+ Returns:
+ dict: Parsed annotation.
+ """
+ assert isinstance(raw_data, dict), \
+ f'The item should be a dict, but got {type(raw_data)}'
+ assert 'img_path' in raw_data, \
+ "The item doesn't have `img_path` field."
+ data = dict(
+ img_path=self._join_root(raw_data['img_path']),
+ gt_label=raw_data['gt_label'],
+ )
+ return data
+
+ @property
+ def metainfo(self) -> dict:
+ """Get meta information of dataset.
+
+ Returns:
+ dict: meta information collected from ``cls.METAINFO``,
+ annotation file and metainfo argument during instantiation.
+ """
+ return copy.deepcopy(self._metainfo)
+
+ def prepare_data(self, idx):
+ """Get data processed by ``self.pipeline``.
+
+ Args:
+ idx (int): The index of ``data_info``.
+
+ Returns:
+ Any: Depends on ``self.pipeline``.
+ """
+ results = copy.deepcopy(self.data_list[idx])
+ return self.pipeline(results)
+
+ def __len__(self):
+ """Get the length of the whole dataset.
+
+ Returns:
+ int: The length of filtered dataset.
+ """
+ return len(self.data_list)
+
+ def __getitem__(self, idx):
+ """Get the idx-th image and data information of dataset after
+ ``self.pipeline``.
+
+ Args:
+            idx (int): The index of the data.
+
+ Returns:
+ dict: The idx-th image and data information after
+ ``self.pipeline``.
+ """
+ return self.prepare_data(idx)
+
+ def __repr__(self):
+ """Print the basic information of the dataset.
+
+ Returns:
+ str: Formatted string.
+ """
+ head = 'Dataset ' + self.__class__.__name__
+ body = [f'Number of samples: \t{self.__len__()}']
+ if self.data_root is not None:
+ body.append(f'Root location: \t{self.data_root}')
+ body.append(f'Annotation file: \t{self.ann_file}')
+ if self.data_prefix is not None:
+ body.append(f'Prefix of images: \t{self.data_prefix}')
+ # -------------------- extra repr --------------------
+ tasks = self.metainfo['tasks']
+ body.append(f'For {len(tasks)} tasks')
+ for task in tasks:
+ body.append(f' {task} ')
+ # ----------------------------------------------------
+
+ if len(self.pipeline.transforms) > 0:
+ body.append('With transforms:')
+ for t in self.pipeline.transforms:
+ body.append(f' {t}')
+
+ lines = [head] + [' ' * 4 + line for line in body]
+ return '\n'.join(lines)
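+
+# A minimal sketch of generating the annotation file described in the class
+# docstring (the task names and labels below are just placeholders):
+#
+#     import mmengine
+#     ann = dict(
+#         metainfo=dict(tasks=['gender', 'wear']),
+#         data_list=[
+#             dict(img_path='a.jpg',
+#                  gt_label=dict(gender=0, wear=[1, 0, 1, 0])),
+#         ])
+#     mmengine.dump(ann, 'data/mydataset/annotation/train.json')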
diff --git a/mmpretrain/datasets/nlvr2.py b/mmpretrain/datasets/nlvr2.py
new file mode 100644
index 0000000000000000000000000000000000000000..0063090657714406049a6daa6fa3c0d868422590
--- /dev/null
+++ b/mmpretrain/datasets/nlvr2.py
@@ -0,0 +1,36 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import json
+from typing import List
+
+from mmengine.fileio import get_file_backend, list_from_file
+
+from mmpretrain.registry import DATASETS
+from .base_dataset import BaseDataset
+
+
+@DATASETS.register_module()
+class NLVR2(BaseDataset):
+ """COCO Caption dataset."""
+
+ def load_data_list(self) -> List[dict]:
+ """Load data list."""
+
+ data_list = []
+ img_prefix = self.data_prefix['img_path']
+ file_backend = get_file_backend(img_prefix)
+ examples = list_from_file(self.ann_file)
+
+ for example in examples:
+ example = json.loads(example)
+ prefix = example['identifier'].rsplit('-', 1)[0]
+ train_data = {}
+ train_data['text'] = example['sentence']
+ train_data['gt_label'] = {'True': 1, 'False': 0}[example['label']]
+ train_data['img_path'] = [
+ file_backend.join_path(img_prefix, prefix + f'-img{i}.png')
+ for i in range(2)
+ ]
+
+ data_list.append(train_data)
+
+ return data_list
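+
+# Each line of ``ann_file`` is expected to be a JSON object like the
+# following (illustrative values):
+#
+#     {"identifier": "dev-850-0-0", "sentence": "...", "label": "True"}
+#
+# The two image paths are derived by dropping the trailing index of the
+# identifier, e.g. 'dev-850-0-img0.png' and 'dev-850-0-img1.png'.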
diff --git a/mmpretrain/datasets/nocaps.py b/mmpretrain/datasets/nocaps.py
new file mode 100644
index 0000000000000000000000000000000000000000..65116e9cecc2d9983ef72ca3eee24ff7baedacc0
--- /dev/null
+++ b/mmpretrain/datasets/nocaps.py
@@ -0,0 +1,46 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List
+
+import mmengine
+from mmengine.dataset import BaseDataset
+from mmengine.fileio import get_file_backend
+from pycocotools.coco import COCO
+
+from mmpretrain.registry import DATASETS
+
+
+@DATASETS.register_module()
+class NoCaps(BaseDataset):
+ """NoCaps dataset.
+
+ Args:
+ data_root (str): The root directory for ``data_prefix`` and
+            ``ann_file``.
+ ann_file (str): Annotation file path.
+ data_prefix (dict): Prefix for data field. Defaults to
+ ``dict(img_path='')``.
+ pipeline (Sequence): Processing pipeline. Defaults to an empty tuple.
+ **kwargs: Other keyword arguments in :class:`BaseDataset`.
+ """
+
+ def load_data_list(self) -> List[dict]:
+ """Load data list."""
+ img_prefix = self.data_prefix['img_path']
+ with mmengine.get_local_path(self.ann_file) as ann_file:
+ coco = COCO(ann_file)
+
+ file_backend = get_file_backend(img_prefix)
+ data_list = []
+ for ann in coco.anns.values():
+ image_id = ann['image_id']
+ image_path = file_backend.join_path(
+ img_prefix, coco.imgs[image_id]['file_name'])
+ data_info = {
+ 'image_id': image_id,
+ 'img_path': image_path,
+ 'gt_caption': None
+ }
+
+ data_list.append(data_info)
+
+ return data_list
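+
+# Illustrative usage (a sketch; the paths are assumptions):
+#
+#     dataset = NoCaps(
+#         data_root='data/nocaps',
+#         ann_file='annotations/nocaps_val_4500_captions.json',
+#         data_prefix=dict(img_path='images'))
+#     # 'gt_caption' is set to None, so ground-truth captions are not
+#     # loaded here.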
diff --git a/mmpretrain/datasets/ocr_vqa.py b/mmpretrain/datasets/ocr_vqa.py
new file mode 100644
index 0000000000000000000000000000000000000000..55aa6913e3c4464444e8b971ccabf68aa2d99904
--- /dev/null
+++ b/mmpretrain/datasets/ocr_vqa.py
@@ -0,0 +1,91 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+from typing import List
+
+import mmengine
+from mmengine.dataset import BaseDataset
+
+from mmpretrain.registry import DATASETS
+
+
+@DATASETS.register_module()
+class OCRVQA(BaseDataset):
+ """OCR-VQA dataset.
+
+ Args:
+ data_root (str): The root directory for ``data_prefix``, ``ann_file``
+ and ``question_file``.
+ data_prefix (str): The directory of images.
+ ann_file (str): Annotation file path for training and validation.
+ split (str): 'train', 'val' or 'test'.
+ **kwargs: Other keyword arguments in :class:`BaseDataset`.
+ """
+
+ def __init__(self, data_root: str, data_prefix: str, ann_file: str,
+ split: str, **kwarg):
+
+ assert split in ['train', 'val', 'test'], \
+ '`split` must be train, val or test'
+ self.split = split
+ super().__init__(
+ data_root=data_root,
+ data_prefix=dict(img_path=data_prefix),
+ ann_file=ann_file,
+ **kwarg,
+ )
+
+ def load_data_list(self) -> List[dict]:
+ """Load data list."""
+
+ split_dict = {1: 'train', 2: 'val', 3: 'test'}
+
+ annotations = mmengine.load(self.ann_file)
+
+ # ann example
+ # "761183272": {
+ # "imageURL": \
+ # "http://ecx.images-amazon.com/images/I/61Y5cOdHJbL.jpg",
+ # "questions": [
+ # "Who wrote this book?",
+ # "What is the title of this book?",
+ # "What is the genre of this book?",
+ # "Is this a games related book?",
+ # "What is the year printed on this calendar?"],
+ # "answers": [
+ # "Sandra Boynton",
+ # "Mom's Family Wall Calendar 2016",
+ # "Calendars",
+ # "No",
+ # "2016"],
+ # "title": "Mom's Family Wall Calendar 2016",
+ # "authorName": "Sandra Boynton",
+ # "genre": "Calendars",
+ # "split": 1
+ # },
+
+ data_list = []
+
+ for key, ann in annotations.items():
+ if self.split != split_dict[ann['split']]:
+ continue
+
+ extension = osp.splitext(ann['imageURL'])[1]
+ if extension not in ['.jpg', '.png']:
+ continue
+ img_path = mmengine.join_path(self.data_prefix['img_path'],
+ key + extension)
+ for question, answer in zip(ann['questions'], ann['answers']):
+ data_info = {}
+ data_info['img_path'] = img_path
+ data_info['question'] = question
+ data_info['gt_answer'] = answer
+ data_info['gt_answer_weight'] = [1.0]
+
+ data_info['imageURL'] = ann['imageURL']
+ data_info['title'] = ann['title']
+ data_info['authorName'] = ann['authorName']
+ data_info['genre'] = ann['genre']
+
+ data_list.append(data_info)
+
+ return data_list
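+
+# Illustrative usage (a sketch; the paths are assumptions, the annotation
+# file is the per-key dict shown in the comment above):
+#
+#     dataset = OCRVQA(
+#         data_root='data/ocrvqa',
+#         data_prefix='images',
+#         ann_file='annotations/dataset.json',
+#         split='train')
+#     # Each question/answer pair of an image becomes one sample.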
diff --git a/mmpretrain/datasets/oxfordiiitpet.py b/mmpretrain/datasets/oxfordiiitpet.py
new file mode 100644
index 0000000000000000000000000000000000000000..23c8b7db8679e99c6ed2698b9eb140cd6151d445
--- /dev/null
+++ b/mmpretrain/datasets/oxfordiiitpet.py
@@ -0,0 +1,97 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List
+
+from mmengine import get_file_backend, list_from_file
+
+from mmpretrain.registry import DATASETS
+from .base_dataset import BaseDataset
+from .categories import OxfordIIITPet_CATEGORIES
+
+
+@DATASETS.register_module()
+class OxfordIIITPet(BaseDataset):
+ """The Oxford-IIIT Pets Dataset.
+
+    Support the Oxford-IIIT Pets dataset. After downloading and
+    decompressing, the dataset directory structure is as follows.
+
+ Oxford-IIIT_Pets dataset directory: ::
+
+ Oxford-IIIT_Pets
+ ├── images
+ │ ├── Abyssinian_1.jpg
+ │ ├── Abyssinian_2.jpg
+ │ └── ...
+ ├── annotations
+ │ ├── trainval.txt
+ │ ├── test.txt
+ │ ├── list.txt
+ │ └── ...
+ └── ....
+
+ Args:
+ data_root (str): The root directory for Oxford-IIIT Pets dataset.
+ split (str, optional): The dataset split, supports "trainval" and "test".
+ Default to "trainval".
+
+ Examples:
+ >>> from mmpretrain.datasets import OxfordIIITPet
+ >>> train_dataset = OxfordIIITPet(data_root='data/Oxford-IIIT_Pets', split='trainval')
+ >>> train_dataset
+ Dataset OxfordIIITPet
+ Number of samples: 3680
+ Number of categories: 37
+ Root of dataset: data/Oxford-IIIT_Pets
+ >>> test_dataset = OxfordIIITPet(data_root='data/Oxford-IIIT_Pets', split='test')
+ >>> test_dataset
+ Dataset OxfordIIITPet
+ Number of samples: 3669
+ Number of categories: 37
+ Root of dataset: data/Oxford-IIIT_Pets
+ """ # noqa: E501
+
+ METAINFO = {'classes': OxfordIIITPet_CATEGORIES}
+
+ def __init__(self, data_root: str, split: str = 'trainval', **kwargs):
+
+ splits = ['trainval', 'test']
+ assert split in splits, \
+ f"The split must be one of {splits}, but get '{split}'"
+ self.split = split
+
+ self.backend = get_file_backend(data_root, enable_singleton=True)
+ if split == 'trainval':
+ ann_file = self.backend.join_path('annotations', 'trainval.txt')
+ else:
+ ann_file = self.backend.join_path('annotations', 'test.txt')
+
+ data_prefix = 'images'
+ test_mode = split == 'test'
+
+ super(OxfordIIITPet, self).__init__(
+ ann_file=ann_file,
+ data_root=data_root,
+ data_prefix=data_prefix,
+ test_mode=test_mode,
+ **kwargs)
+
+ def load_data_list(self):
+ """Load images and ground truth labels."""
+
+ pairs = list_from_file(self.ann_file)
+ data_list = []
+ for pair in pairs:
+ img_name, class_id, _, _ = pair.split()
+ img_name = f'{img_name}.jpg'
+ img_path = self.backend.join_path(self.img_prefix, img_name)
+ gt_label = int(class_id) - 1
+ info = dict(img_path=img_path, gt_label=gt_label)
+ data_list.append(info)
+ return data_list
+
+ def extra_repr(self) -> List[str]:
+ """The extra repr information of the dataset."""
+ body = [
+ f'Root of dataset: \t{self.data_root}',
+ ]
+ return body
diff --git a/mmpretrain/datasets/places205.py b/mmpretrain/datasets/places205.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3ba1ff631a7a4840b66cf63ec53585ec064560d
--- /dev/null
+++ b/mmpretrain/datasets/places205.py
@@ -0,0 +1,40 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional, Union
+
+from mmpretrain.registry import DATASETS
+from .categories import PLACES205_CATEGORIES
+from .custom import CustomDataset
+
+
+@DATASETS.register_module()
+class Places205(CustomDataset):
+ """`Places205 `_ Dataset.
+
+ Args:
+ data_root (str): The root directory for ``data_prefix`` and
+ ``ann_file``. Defaults to ''.
+ data_prefix (str | dict): Prefix for training data. Defaults
+ to ''.
+ ann_file (str): Annotation file path. Defaults to ''.
+ metainfo (dict, optional): Meta information for dataset, such as class
+ information. Defaults to None.
+ **kwargs: Other keyword arguments in :class:`CustomDataset` and
+ :class:`BaseDataset`.
+ """
+
+ IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif')
+ METAINFO = {'classes': PLACES205_CATEGORIES}
+
+ def __init__(self,
+ data_root: str = '',
+ data_prefix: Union[str, dict] = '',
+ ann_file: str = '',
+ metainfo: Optional[dict] = None,
+ **kwargs):
+ kwargs = {'extensions': self.IMG_EXTENSIONS, **kwargs}
+ super().__init__(
+ data_root=data_root,
+ data_prefix=data_prefix,
+ ann_file=ann_file,
+ metainfo=metainfo,
+ **kwargs)
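+
+# Illustrative usage (a sketch; the paths are assumptions, and ``ann_file``
+# is assumed to follow the CustomDataset 'path label' line format):
+#
+#     dataset = Places205(
+#         data_root='data/places205',
+#         data_prefix='images',
+#         ann_file='meta/train.txt')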
diff --git a/mmpretrain/datasets/refcoco.py b/mmpretrain/datasets/refcoco.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4f2a943f73fdab493a47bbcd1d0ea6385ec60fa
--- /dev/null
+++ b/mmpretrain/datasets/refcoco.py
@@ -0,0 +1,81 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+from typing import List
+
+import mmengine
+import numpy as np
+from mmengine.dataset import BaseDataset
+from pycocotools.coco import COCO
+
+from mmpretrain.registry import DATASETS
+
+
+@DATASETS.register_module()
+class RefCOCO(BaseDataset):
+ """RefCOCO dataset.
+
+ Args:
+ data_root (str): The root directory for ``data_prefix``,
+ ``ann_file`` and ``split_file``.
+ ann_file (str): Annotation file path.
+ data_prefix (str): The directory of images.
+ split_file (str): The split file path, which contains the split
+ information of the referring expressions.
+ split (str): The split of the dataset. Defaults to 'train'.
+ pipeline (Sequence): Processing pipeline. Defaults to an empty tuple.
+ **kwargs: Other keyword arguments in :class:`BaseDataset`.
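+
+ Examples:
+ A construction sketch; the file names follow the common RefCOCO
+ release layout and are illustrative only:
+
+ >>> from mmpretrain.datasets import RefCOCO
+ >>> dataset = RefCOCO(
+ ... data_root='data/refcoco',
+ ... data_prefix='train2014',
+ ... ann_file='refcoco/instances.json',
+ ... split_file='refcoco/refs(unc).p',
+ ... split='train')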
+ """
+
+ def __init__(self,
+ data_root,
+ ann_file,
+ data_prefix,
+ split_file,
+ split='train',
+ **kwargs):
+ self.split_file = split_file
+ self.split = split
+
+ super().__init__(
+ data_root=data_root,
+ data_prefix=dict(img_path=data_prefix),
+ ann_file=ann_file,
+ **kwargs,
+ )
+
+ def _join_prefix(self):
+ if not mmengine.is_abs(self.split_file) and self.split_file:
+ self.split_file = osp.join(self.data_root, self.split_file)
+
+ return super()._join_prefix()
+
+ def load_data_list(self) -> List[dict]:
+ """Load data list."""
+ with mmengine.get_local_path(self.ann_file) as ann_file:
+ coco = COCO(ann_file)
+ splits = mmengine.load(self.split_file, file_format='pkl')
+ img_prefix = self.data_prefix['img_path']
+
+ data_list = []
+ join_path = mmengine.fileio.get_file_backend(img_prefix).join_path
+ for refer in splits:
+ if refer['split'] != self.split:
+ continue
+
+ ann = coco.anns[refer['ann_id']]
+ img = coco.imgs[ann['image_id']]
+ sentences = refer['sentences']
+ bbox = np.array(ann['bbox'], dtype=np.float32)
+ bbox[2:4] = bbox[0:2] + bbox[2:4] # XYWH -> XYXY
+
+ for sent in sentences:
+ data_info = {
+ 'img_path': join_path(img_prefix, img['file_name']),
+ 'image_id': ann['image_id'],
+ 'ann_id': ann['id'],
+ 'text': sent['sent'],
+ 'gt_bboxes': bbox[None, :],
+ }
+ data_list.append(data_info)
+
+ if len(data_list) == 0:
+ raise ValueError(f'No sample in split "{self.split}".')
+
+ return data_list
diff --git a/mmpretrain/datasets/samplers/__init__.py b/mmpretrain/datasets/samplers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2bccf9c34659e19764871a696260cf5884696ca1
--- /dev/null
+++ b/mmpretrain/datasets/samplers/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .repeat_aug import RepeatAugSampler
+from .sequential import SequentialSampler
+
+__all__ = ['RepeatAugSampler', 'SequentialSampler']
diff --git a/mmpretrain/datasets/samplers/__pycache__/__init__.cpython-38.pyc b/mmpretrain/datasets/samplers/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9672735d3373a8d9e4d0859cd30310e799556d05
Binary files /dev/null and b/mmpretrain/datasets/samplers/__pycache__/__init__.cpython-38.pyc differ
diff --git a/mmpretrain/datasets/samplers/__pycache__/repeat_aug.cpython-38.pyc b/mmpretrain/datasets/samplers/__pycache__/repeat_aug.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..de91371076f2d2a1df328f5b24cf9c3aea945235
Binary files /dev/null and b/mmpretrain/datasets/samplers/__pycache__/repeat_aug.cpython-38.pyc differ
diff --git a/mmpretrain/datasets/samplers/__pycache__/sequential.cpython-38.pyc b/mmpretrain/datasets/samplers/__pycache__/sequential.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c150cc6be8d91701e24d7e2662e5567108004164
Binary files /dev/null and b/mmpretrain/datasets/samplers/__pycache__/sequential.cpython-38.pyc differ
diff --git a/mmpretrain/datasets/samplers/repeat_aug.py b/mmpretrain/datasets/samplers/repeat_aug.py
new file mode 100644
index 0000000000000000000000000000000000000000..d833a1954d7d9d181c368d5b3b956c25df241c1a
--- /dev/null
+++ b/mmpretrain/datasets/samplers/repeat_aug.py
@@ -0,0 +1,101 @@
+import math
+from typing import Iterator, Optional, Sized
+
+import torch
+from mmengine.dist import get_dist_info, is_main_process, sync_random_seed
+from torch.utils.data import Sampler
+
+from mmpretrain.registry import DATA_SAMPLERS
+
+
+@DATA_SAMPLERS.register_module()
+class RepeatAugSampler(Sampler):
+ """Sampler that restricts data loading to a subset of the dataset for
+ distributed, with repeated augmentation. It ensures that different each
+ augmented version of a sample will be visible to a different process (GPU).
+ Heavily based on torch.utils.data.DistributedSampler.
+
+ This sampler was taken from
+ https://github.com/facebookresearch/deit/blob/0c4b8f60/samplers.py
+ Used in
+ Copyright (c) 2015-present, Facebook, Inc.
+
+ Args:
+ dataset (Sized): The dataset.
+ shuffle (bool): Whether shuffle the dataset or not. Defaults to True.
+ num_repeats (int): The repeat times of every sample. Defaults to 3.
+ seed (int, optional): Random seed used to shuffle the sampler if
+ :attr:`shuffle=True`. This number should be identical across all
+ processes in the distributed group. Defaults to None.
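+
+ Examples:
+ A dataloader config sketch; the dataset settings here are
+ illustrative placeholders, not a complete config:
+
+ >>> train_dataloader = dict(
+ ... batch_size=64,
+ ... num_workers=4,
+ ... sampler=dict(type='RepeatAugSampler', num_repeats=3),
+ ... dataset=dict(type='ImageNet', data_root='data/imagenet'),
+ ... )
+
+ With ``num_repeats=3`` on 2 GPUs, every rank yields
+ ``ceil(len(dataset) / 2)`` indices per epoch, drawn from the
+ repeated and shuffled index list.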
+ """
+
+ def __init__(self,
+ dataset: Sized,
+ shuffle: bool = True,
+ num_repeats: int = 3,
+ seed: Optional[int] = None):
+ rank, world_size = get_dist_info()
+ self.rank = rank
+ self.world_size = world_size
+
+ self.dataset = dataset
+ self.shuffle = shuffle
+ if not self.shuffle and is_main_process():
+ from mmengine.logging import MMLogger
+ logger = MMLogger.get_current_instance()
+ logger.warning('The RepeatAugSampler always picks a '
+ 'fixed part of data if `shuffle=False`.')
+
+ if seed is None:
+ seed = sync_random_seed()
+ self.seed = seed
+ self.epoch = 0
+ self.num_repeats = num_repeats
+
+ # The number of repeated samples in the rank
+ self.num_samples = math.ceil(
+ len(self.dataset) * num_repeats / world_size)
+ # The total number of repeated samples in all ranks.
+ self.total_size = self.num_samples * world_size
+ # The number of selected samples in the rank
+ self.num_selected_samples = math.ceil(len(self.dataset) / world_size)
+
+ def __iter__(self) -> Iterator[int]:
+ """Iterate the indices."""
+ # deterministically shuffle based on epoch and seed
+ if self.shuffle:
+ g = torch.Generator()
+ g.manual_seed(self.seed + self.epoch)
+ indices = torch.randperm(len(self.dataset), generator=g).tolist()
+ else:
+ indices = list(range(len(self.dataset)))
+
+ # produce repeats e.g. [0, 0, 0, 1, 1, 1, 2, 2, 2....]
+ indices = [x for x in indices for _ in range(self.num_repeats)]
+ # add extra samples to make it evenly divisible
+ padding_size = self.total_size - len(indices)
+ indices += indices[:padding_size]
+ assert len(indices) == self.total_size
+
+ # subsample per rank
+ indices = indices[self.rank:self.total_size:self.world_size]
+ assert len(indices) == self.num_samples
+
+ # return up to num selected samples
+ return iter(indices[:self.num_selected_samples])
+
+ def __len__(self) -> int:
+ """The number of samples in this rank."""
+ return self.num_selected_samples
+
+ def set_epoch(self, epoch: int) -> None:
+ """Sets the epoch for this sampler.
+
+ When :attr:`shuffle=True`, this ensures all replicas use a different
+ random ordering for each epoch. Otherwise, the next iteration of this
+ sampler will yield the same ordering.
+
+ Args:
+ epoch (int): Epoch number.
+ """
+ self.epoch = epoch
diff --git a/mmpretrain/datasets/samplers/sequential.py b/mmpretrain/datasets/samplers/sequential.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3b940c2eabc2ab9c2401cd1923776fc067e9f6c
--- /dev/null
+++ b/mmpretrain/datasets/samplers/sequential.py
@@ -0,0 +1,56 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Iterator
+
+import torch
+from mmengine.dataset import DefaultSampler
+
+from mmpretrain.registry import DATA_SAMPLERS
+
+
+@DATA_SAMPLERS.register_module()
+class SequentialSampler(DefaultSampler):
+ """Sequential sampler which supports different subsample policy.
+
+ Args:
+ dataset (Sized): The dataset.
+ round_up (bool): Whether to add extra samples to make the number of
+ samples evenly divisible by the world size. Defaults to True.
+ subsample_type (str): The method to subsample data on different rank.
+ Supported type:
+
+ - ``'default'``: Original torch behavior. Sample the examples one
+ by one for each GPU in turn. For instance, with 8 examples on 2
+ GPUs, GPU0 gets [0, 2, 4, 6] and GPU1 gets [1, 3, 5, 7].
+ - ``'sequential'``: Split all examples into ``world_size`` chunks
+ sequentially. For instance, with 8 examples on 2 GPUs, GPU0 gets
+ [0, 1, 2, 3] and GPU1 gets [4, 5, 6, 7] (see the sketch below).
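+
+ Examples:
+ A minimal sketch of the two policies for 8 examples on 2 ranks,
+ written with plain list slicing (independent of this class):
+
+ >>> indices = list(range(8))
+ >>> world_size, rank = 2, 0
+ >>> indices[rank::world_size] # 'default'
+ [0, 2, 4, 6]
+ >>> n = len(indices) // world_size
+ >>> indices[rank * n:(rank + 1) * n] # 'sequential'
+ [0, 1, 2, 3]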
+ """
+
+ def __init__(self, subsample_type: str = 'default', **kwargs) -> None:
+ super().__init__(shuffle=False, **kwargs)
+
+ if subsample_type not in ['default', 'sequential']:
+ raise ValueError(f'Unsupported subsample type "{subsample_type}",'
+ ' please choose from ["default", "sequential"]')
+ self.subsample_type = subsample_type
+
+ def __iter__(self) -> Iterator[int]:
+ """Iterate the indices."""
+ indices = torch.arange(len(self.dataset)).tolist()
+
+ # add extra samples to make it evenly divisible
+ if self.round_up:
+ indices = (
+ indices *
+ int(self.total_size / len(indices) + 1))[:self.total_size]
+
+ # subsample
+ if self.subsample_type == 'default':
+ indices = indices[self.rank:self.total_size:self.world_size]
+ elif self.subsample_type == 'sequential':
+ num_samples_per_rank = self.total_size // self.world_size
+ indices = indices[self.rank *
+ num_samples_per_rank:(self.rank + 1) *
+ num_samples_per_rank]
+
+ return iter(indices)
diff --git a/mmpretrain/datasets/scienceqa.py b/mmpretrain/datasets/scienceqa.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e442491be85540980c0309b65d32a12c9c85542
--- /dev/null
+++ b/mmpretrain/datasets/scienceqa.py
@@ -0,0 +1,109 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+from typing import Callable, List, Sequence
+
+import mmengine
+from mmengine.dataset import BaseDataset
+from mmengine.fileio import get_file_backend
+
+from mmpretrain.registry import DATASETS
+
+
+@DATASETS.register_module()
+class ScienceQA(BaseDataset):
+ """ScienceQA dataset.
+
+ This dataset is used to load the multimodal data of ScienceQA dataset.
+
+ Args:
+ data_root (str): The root directory for ``data_prefix`` and
+ ``ann_file``.
+ split (str): The split of dataset. Options: ``train``, ``val``,
+ ``test``, ``trainval``, ``minival``, and ``minitest``.
+ split_file (str): The split file of dataset, which contains the
+ ids of data samples in the split.
+ ann_file (str): Annotation file path.
+ image_only (bool): Whether only to load data with image. Defaults to
+ False.
+ data_prefix (dict): Prefix for data field. Defaults to
+ ``dict(img_path='')``.
+ pipeline (Sequence): Processing pipeline. Defaults to an empty tuple.
+ **kwargs: Other keyword arguments in :class:`BaseDataset`.
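+
+ Examples:
+ A construction sketch; the file and folder names are illustrative
+ and follow the layout of the official ScienceQA release:
+
+ >>> from mmpretrain.datasets import ScienceQA
+ >>> dataset = ScienceQA(
+ ... data_root='data/scienceqa',
+ ... split='val',
+ ... split_file='pid_splits.json',
+ ... ann_file='problems.json',
+ ... image_only=True,
+ ... data_prefix=dict(img_path='val'))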
+ """
+
+ def __init__(self,
+ data_root: str,
+ split: str,
+ split_file: str,
+ ann_file: str,
+ image_only: bool = False,
+ data_prefix: dict = dict(img_path=''),
+ pipeline: Sequence[Callable] = (),
+ **kwargs):
+ assert split in [
+ 'train', 'val', 'test', 'trainval', 'minival', 'minitest'
+ ], f'Invalid split {split}'
+ self.split = split
+ self.split_file = os.path.join(data_root, split_file)
+ self.image_only = image_only
+
+ super().__init__(
+ data_root=data_root,
+ ann_file=ann_file,
+ data_prefix=data_prefix,
+ pipeline=pipeline,
+ **kwargs)
+
+ def load_data_list(self) -> List[dict]:
+ """Load data list."""
+ img_prefix = self.data_prefix['img_path']
+ annotations = mmengine.load(self.ann_file)
+ current_data_split = mmengine.load(self.split_file)[self.split] # noqa
+
+ file_backend = get_file_backend(img_prefix)
+
+ data_list = []
+ for data_id in current_data_split:
+ ann = annotations[data_id]
+ if self.image_only and ann['image'] is None:
+ continue
+ data_info = {
+ 'image_id': data_id,
+ 'question': ann['question'],
+ 'choices': ann['choices'],
+ 'gt_answer': ann['answer'],
+ 'hint': ann['hint'],
+ 'image_name': ann['image'],
+ 'task': ann['task'],
+ 'grade': ann['grade'],
+ 'subject': ann['subject'],
+ 'topic': ann['topic'],
+ 'category': ann['category'],
+ 'skill': ann['skill'],
+ 'lecture': ann['lecture'],
+ 'solution': ann['solution'],
+ 'split': ann['split'],
+ 'img_path':
+ file_backend.join_path(img_prefix, data_id, ann['image'])
+ if ann['image'] is not None else None,
+ 'has_image': ann['image'] is not None,
+ }
+ data_list.append(data_info)
+
+ return data_list
diff --git a/mmpretrain/datasets/stanfordcars.py b/mmpretrain/datasets/stanfordcars.py
new file mode 100644
index 0000000000000000000000000000000000000000..355697943cf693869f35f2a0bd71abdfa0396722
--- /dev/null
+++ b/mmpretrain/datasets/stanfordcars.py
@@ -0,0 +1,148 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List
+
+import mat4py
+from mmengine import get_file_backend
+
+from mmpretrain.registry import DATASETS
+from .base_dataset import BaseDataset
+from .categories import STANFORDCARS_CATEGORIES
+
+
+@DATASETS.register_module()
+class StanfordCars(BaseDataset):
+ """The Stanford Cars Dataset.
+
+ Support the Stanford Cars Dataset.
+ The official website provides two ways to organize the dataset, so after
+ downloading and decompression, the dataset directory structure is one of the following.
+
+ Stanford Cars dataset directory: ::
+
+ Stanford_Cars
+ ├── car_ims
+ │ ├── 00001.jpg
+ │ ├── 00002.jpg
+ │ └── ...
+ └── cars_annos.mat
+
+ or ::
+
+ Stanford_Cars
+ ├── cars_train
+ │ ├── 00001.jpg
+ │ ├── 00002.jpg
+ │ └── ...
+ ├── cars_test
+ │ ├── 00001.jpg
+ │ ├── 00002.jpg
+ │ └── ...
+ └── devkit
+ ├── cars_meta.mat
+ ├── cars_train_annos.mat
+ ├── cars_test_annos.mat
+ ├── cars_test_annos_withlabels.mat
+ ├── eval_train.m
+ └── train_perfect_preds.txt
+
+ Args:
+ data_root (str): The root directory for Stanford Cars dataset.
+ split (str, optional): The dataset split, supports "train"
+ and "test". Defaults to "train".
+
+ Examples:
+ >>> from mmpretrain.datasets import StanfordCars
+ >>> train_dataset = StanfordCars(data_root='data/Stanford_Cars', split='train')
+ >>> train_dataset
+ Dataset StanfordCars
+ Number of samples: 8144
+ Number of categories: 196
+ Root of dataset: data/Stanford_Cars
+ >>> test_dataset = StanfordCars(data_root='data/Stanford_Cars', split='test')
+ >>> test_dataset
+ Dataset StanfordCars
+ Number of samples: 8041
+ Number of categories: 196
+ Root of dataset: data/Stanford_Cars
+ """ # noqa: E501
+
+ METAINFO = {'classes': STANFORDCARS_CATEGORIES}
+
+ def __init__(self, data_root: str, split: str = 'train', **kwargs):
+
+ splits = ['train', 'test']
+ assert split in splits, \
+ f"The split must be one of {splits}, but get '{split}'"
+ self.split = split
+
+ test_mode = split == 'test'
+ self.backend = get_file_backend(data_root, enable_singleton=True)
+
+ anno_file_path = self.backend.join_path(data_root, 'cars_annos.mat')
+ if self.backend.exists(anno_file_path):
+ ann_file = 'cars_annos.mat'
+ data_prefix = ''
+ else:
+ if test_mode:
+ ann_file = self.backend.join_path(
+ 'devkit', 'cars_test_annos_withlabels.mat')
+ data_prefix = 'cars_test'
+ else:
+ ann_file = self.backend.join_path('devkit',
+ 'cars_train_annos.mat')
+ data_prefix = 'cars_train'
+
+ if not self.backend.exists(
+ self.backend.join_path(data_root, ann_file)):
+ doc_url = 'https://mmpretrain.readthedocs.io/en/latest/api/datasets.html#stanfordcars' # noqa: E501
+ raise RuntimeError(
+ 'The dataset is incorrectly organized, please refer to '
+ f'{doc_url} and reorganize your folders.')
+
+ super(StanfordCars, self).__init__(
+ ann_file=ann_file,
+ data_root=data_root,
+ data_prefix=data_prefix,
+ test_mode=test_mode,
+ **kwargs)
+
+ def load_data_list(self):
+ data = mat4py.loadmat(self.ann_file)['annotations']
+
+ data_list = []
+ if 'test' in data.keys():
+ # first way
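+ # cars_annos.mat stores every sample together with a `test` flag,
+ # so keep only the entries that belong to the requested split.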
+ img_paths, labels, test = data['relative_im_path'], data[
+ 'class'], data['test']
+ num = len(img_paths)
+ assert num == len(labels) == len(test), 'Invalid annotation file.'
+ for i in range(num):
+ if not self.test_mode and test[i] == 1:
+ continue
+ if self.test_mode and test[i] == 0:
+ continue
+ img_path = self.backend.join_path(self.img_prefix,
+ img_paths[i])
+ gt_label = labels[i] - 1
+ info = dict(img_path=img_path, gt_label=gt_label)
+ data_list.append(info)
+ else:
+ # second way
+ img_names, labels = data['fname'], data['class']
+ num = len(img_names)
+ assert num == len(labels), 'Invalid annotation file.'
+ for i in range(num):
+ img_path = self.backend.join_path(self.img_prefix,
+ img_names[i])
+ gt_label = labels[i] - 1
+ info = dict(img_path=img_path, gt_label=gt_label)
+ data_list.append(info)
+
+ return data_list
+
+ def extra_repr(self) -> List[str]:
+ """The extra repr information of the dataset."""
+ body = [
+ f'Root of dataset: \t{self.data_root}',
+ ]
+ return body
diff --git a/mmpretrain/datasets/sun397.py b/mmpretrain/datasets/sun397.py
new file mode 100644
index 0000000000000000000000000000000000000000..1039a0690f8096082d5c55f89d743478fdf5b22d
--- /dev/null
+++ b/mmpretrain/datasets/sun397.py
@@ -0,0 +1,125 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List
+
+from mmengine import get_file_backend, list_from_file
+
+from mmpretrain.registry import DATASETS
+from .base_dataset import BaseDataset
+from .categories import SUN397_CATEGORIES
+
+
+@DATASETS.register_module()
+class SUN397(BaseDataset):
+ """The SUN397 Dataset.
+
+ Support the SUN397 Dataset.
+ After downloading and decompression, the dataset directory structure is as follows.
+
+ SUN397 dataset directory: ::
+
+ SUN397
+ ├── SUN397
+ │ ├── a
+ │ │ ├── abbey
+ │ | | ├── sun_aaalbzqrimafwbiv.jpg
+ │ | | └── ...
+ │ │ ├── airplane_cabin
+ │ | | ├── sun_aadqdkqaslqqoblu.jpg
+ │ | | └── ...
+ │ | └── ...
+ │ ├── b
+ │ │ └── ...
+ │ ├── c
+ │ │ └── ...
+ │ └── ...
+ └── Partitions
+ ├── ClassName.txt
+ ├── Training_01.txt
+ ├── Testing_01.txt
+ └── ...
+
+ Args:
+ data_root (str): The root directory for the SUN397 dataset.
+ split (str, optional): The dataset split, supports "train" and "test".
+ Defaults to "train".
+
+ Examples:
+ >>> from mmpretrain.datasets import SUN397
+ >>> train_dataset = SUN397(data_root='data/SUN397', split='train')
+ >>> train_dataset
+ Dataset SUN397
+ Number of samples: 19850
+ Number of categories: 397
+ Root of dataset: data/SUN397
+ >>> test_dataset = SUN397(data_root='data/SUN397', split='test')
+ >>> test_dataset
+ Dataset SUN397
+ Number of samples: 19850
+ Number of categories: 397
+ Root of dataset: data/SUN397
+
+ **Note that some images are not actual JPEG files although their names end with ".jpg".
+ To read these images properly, set the image decoding backend to "pillow" as below:**
+
+ .. code-block:: python
+
+ pipeline = [
+ dict(type='LoadImageFromFile', imdecode_backend='pillow'),
+ dict(type='RandomResizedCrop', scale=224),
+ dict(type='PackInputs')
+ ]
+ """ # noqa: E501
+
+ METAINFO = {'classes': SUN397_CATEGORIES}
+
+ def __init__(self, data_root: str, split: str = 'train', **kwargs):
+
+ splits = ['train', 'test']
+ assert split in splits, \
+ f"The split must be one of {splits}, but get '{split}'"
+ self.split = split
+
+ self.backend = get_file_backend(data_root, enable_singleton=True)
+ if split == 'train':
+ ann_file = self.backend.join_path('Partitions', 'Training_01.txt')
+ else:
+ ann_file = self.backend.join_path('Partitions', 'Testing_01.txt')
+
+ data_prefix = 'SUN397'
+ test_mode = split == 'test'
+
+ super(SUN397, self).__init__(
+ ann_file=ann_file,
+ data_root=data_root,
+ test_mode=test_mode,
+ data_prefix=data_prefix,
+ **kwargs)
+
+ def load_data_list(self):
+ pairs = list_from_file(self.ann_file)
+ data_list = []
+ for pair in pairs:
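+ # Each line looks like '/a/abbey/sun_xxx.jpg': strip the leading
+ # '/', and join the directory parts after the single-letter bucket
+ # with '_' to get the class name (e.g. 'abbey', 'airplane_cabin').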
+ img_path = self.backend.join_path(self.img_prefix, pair[1:])
+ items = pair.split('/')
+ class_name = '_'.join(items[2:-1])
+ gt_label = self.METAINFO['classes'].index(class_name)
+ info = dict(img_path=img_path, gt_label=gt_label)
+ data_list.append(info)
+
+ return data_list
+
+ def __getitem__(self, idx: int) -> dict:
+ try:
+ return super().__getitem__(idx)
+ except AttributeError:
+ raise RuntimeError(
+ 'Some images in the SUN397 dataset are not a jpg file '
+ 'although the name ends with ".jpg". The backend of SUN397 '
+ 'should be "pillow" to read these images properly.')
+
+ def extra_repr(self) -> List[str]:
+ """The extra repr information of the dataset."""
+ body = [
+ f'Root of dataset: \t{self.data_root}',
+ ]
+ return body
diff --git a/mmpretrain/datasets/textvqa.py b/mmpretrain/datasets/textvqa.py
new file mode 100644
index 0000000000000000000000000000000000000000..48a82b45ef1a4cc0bad2ab45b32b8ba8d28b2a60
--- /dev/null
+++ b/mmpretrain/datasets/textvqa.py
@@ -0,0 +1,105 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from collections import Counter
+from typing import List
+
+import mmengine
+from mmengine.dataset import BaseDataset
+
+from mmpretrain.registry import DATASETS
+
+
+@DATASETS.register_module()
+class TextVQA(BaseDataset):
+ """TextVQA dataset.
+
+ val image:
+ https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip
+ test image:
+ https://dl.fbaipublicfiles.com/textvqa/images/test_images.zip
+ val json:
+ https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_0.5.1_val.json
+ test json:
+ https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_0.5.1_test.json
+
+ folder structure:
+ data/textvqa
+ ├── annotations
+ │ ├── TextVQA_0.5.1_test.json
+ │ └── TextVQA_0.5.1_val.json
+ └── images
+ ├── test_images
+ └── train_images
+
+ Args:
+ data_root (str): The root directory for ``data_prefix``, ``ann_file``
+ and ``question_file``.
+ data_prefix (str): The directory of images.
+ question_file (str): Question file path.
+ ann_file (str, optional): Annotation file path for training and
+ validation. Defaults to an empty string.
+ **kwargs: Other keyword arguments in :class:`BaseDataset`.
+ """
+
+ def __init__(self,
+ data_root: str,
+ data_prefix: str,
+ ann_file: str = '',
+ **kwargs):
+ super().__init__(
+ data_root=data_root,
+ data_prefix=dict(img_path=data_prefix),
+ ann_file=ann_file,
+ **kwargs,
+ )
+
+ def load_data_list(self) -> List[dict]:
+ """Load data list."""
+ annotations = mmengine.load(self.ann_file)['data']
+
+ data_list = []
+
+ for ann in annotations:
+
+ # ann example
+ # {
+ # 'question': 'what is the brand of...is camera?',
+ # 'image_id': '003a8ae2ef43b901',
+ # 'image_classes': [
+ # 'Cassette deck', 'Printer', ...
+ # ],
+ # 'flickr_original_url': 'https://farm2.static...04a6_o.jpg',
+ # 'flickr_300k_url': 'https://farm2.static...04a6_o.jpg',
+ # 'image_width': 1024,
+ # 'image_height': 664,
+ # 'answers': [
+ # 'nous les gosses',
+ # 'dakota',
+ # 'clos culombu',
+ # 'dakota digital' ...
+ # ],
+ # 'question_tokens':
+ # ['what', 'is', 'the', 'brand', 'of', 'this', 'camera'],
+ # 'question_id': 34602,
+ # 'set_name': 'val'
+ # }
+
+ data_info = dict(question=ann['question'])
+ data_info['question_id'] = ann['question_id']
+ data_info['image_id'] = ann['image_id']
+
+ img_path = mmengine.join_path(self.data_prefix['img_path'],
+ ann['image_id'] + '.jpg')
+ data_info['img_path'] = img_path
+
+ if 'answers' in ann:
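+ # Collapse the multiple human answers into a soft label: each
+ # unique answer is weighted by its frequency among all
+ # annotated answers.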
+ answers = ann.pop('answers')
+ count = Counter(answers)
+ answer_weight = [i / len(answers) for i in count.values()]
+ data_info['gt_answer'] = list(count.keys())
+ data_info['gt_answer_weight'] = answer_weight
+
+ data_list.append(data_info)
+
+ return data_list
diff --git a/mmpretrain/datasets/transforms/__init__.py b/mmpretrain/datasets/transforms/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8cda99d59c9b147c1842adbffa6bb215f657c33c
--- /dev/null
+++ b/mmpretrain/datasets/transforms/__init__.py
@@ -0,0 +1,40 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmcv.transforms import (CenterCrop, LoadImageFromFile, Normalize,
+ RandomFlip, RandomGrayscale, RandomResize, Resize)
+
+from mmpretrain.registry import TRANSFORMS
+from .auto_augment import (AutoAugment, AutoContrast, BaseAugTransform,
+ Brightness, ColorTransform, Contrast, Cutout,
+ Equalize, GaussianBlur, Invert, Posterize,
+ RandAugment, Rotate, Sharpness, Shear, Solarize,
+ SolarizeAdd, Translate)
+from .formatting import (Collect, NumpyToPIL, PackInputs, PackMultiTaskInputs,
+ PILToNumpy, Transpose)
+from .processing import (Albumentations, BEiTMaskGenerator, CleanCaption,
+ ColorJitter, EfficientNetCenterCrop,
+ EfficientNetRandomCrop, Lighting, RandomCrop,
+ RandomErasing, RandomResizedCrop,
+ RandomResizedCropAndInterpolationWithTwoPic,
+ RandomTranslatePad, ResizeEdge, SimMIMMaskGenerator)
+from .utils import get_transform_idx, remove_transform
+from .wrappers import ApplyToList, MultiView
+
+for t in (CenterCrop, LoadImageFromFile, Normalize, RandomFlip,
+ RandomGrayscale, RandomResize, Resize):
+ TRANSFORMS.register_module(module=t)
+
+__all__ = [
+ 'NumpyToPIL', 'PILToNumpy', 'Transpose', 'Collect', 'RandomCrop',
+ 'RandomResizedCrop', 'Shear', 'Translate', 'Rotate', 'Invert',
+ 'ColorTransform', 'Solarize', 'Posterize', 'AutoContrast', 'Equalize',
+ 'Contrast', 'Brightness', 'Sharpness', 'AutoAugment', 'SolarizeAdd',
+ 'Cutout', 'RandAugment', 'Lighting', 'ColorJitter', 'RandomErasing',
+ 'PackInputs', 'Albumentations', 'EfficientNetRandomCrop',
+ 'EfficientNetCenterCrop', 'ResizeEdge', 'BaseAugTransform',
+ 'PackMultiTaskInputs', 'GaussianBlur', 'BEiTMaskGenerator',
+ 'SimMIMMaskGenerator', 'CenterCrop', 'LoadImageFromFile', 'Normalize',
+ 'RandomFlip', 'RandomGrayscale', 'RandomResize', 'Resize', 'MultiView',
+ 'ApplyToList', 'CleanCaption', 'RandomTranslatePad',
+ 'RandomResizedCropAndInterpolationWithTwoPic', 'get_transform_idx',
+ 'remove_transform'
+]
diff --git a/mmpretrain/datasets/transforms/__pycache__/__init__.cpython-38.pyc b/mmpretrain/datasets/transforms/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cdfacfafafee9f52afaf2f31b9b44995cfb4d14f
Binary files /dev/null and b/mmpretrain/datasets/transforms/__pycache__/__init__.cpython-38.pyc differ
diff --git a/mmpretrain/datasets/transforms/__pycache__/auto_augment.cpython-38.pyc b/mmpretrain/datasets/transforms/__pycache__/auto_augment.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4ad2b3fb23e081a2d5a876bc1df9fee7c5ac056a
Binary files /dev/null and b/mmpretrain/datasets/transforms/__pycache__/auto_augment.cpython-38.pyc differ
diff --git a/mmpretrain/datasets/transforms/__pycache__/formatting.cpython-38.pyc b/mmpretrain/datasets/transforms/__pycache__/formatting.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1e0621832bc6aa3de7af7d311bf6cfe62db2b264
Binary files /dev/null and b/mmpretrain/datasets/transforms/__pycache__/formatting.cpython-38.pyc differ
diff --git a/mmpretrain/datasets/transforms/__pycache__/processing.cpython-38.pyc b/mmpretrain/datasets/transforms/__pycache__/processing.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3df804deb0c6b62eeff44e650497d8c98fb3a207
Binary files /dev/null and b/mmpretrain/datasets/transforms/__pycache__/processing.cpython-38.pyc differ
diff --git a/mmpretrain/datasets/transforms/__pycache__/utils.cpython-38.pyc b/mmpretrain/datasets/transforms/__pycache__/utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..55796fbf9d0df64784b71a70c978c374036a2feb
Binary files /dev/null and b/mmpretrain/datasets/transforms/__pycache__/utils.cpython-38.pyc differ
diff --git a/mmpretrain/datasets/transforms/__pycache__/wrappers.cpython-38.pyc b/mmpretrain/datasets/transforms/__pycache__/wrappers.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ca2bd62c3a4cc8ce1471d00aa950c317883b8ef4
Binary files /dev/null and b/mmpretrain/datasets/transforms/__pycache__/wrappers.cpython-38.pyc differ
diff --git a/mmpretrain/datasets/transforms/auto_augment.py b/mmpretrain/datasets/transforms/auto_augment.py
new file mode 100644
index 0000000000000000000000000000000000000000..03b057b850a4fd797f8f5c0672f60c6c20e44273
--- /dev/null
+++ b/mmpretrain/datasets/transforms/auto_augment.py
@@ -0,0 +1,1244 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import inspect
+from copy import deepcopy
+from math import ceil
+from numbers import Number
+from typing import List, Optional, Sequence, Tuple, Union
+
+import mmcv
+import numpy as np
+from mmcv.transforms import BaseTransform, Compose, RandomChoice
+from mmcv.transforms.utils import cache_randomness
+from mmengine.utils import is_list_of, is_seq_of
+from PIL import Image, ImageFilter
+
+from mmpretrain.registry import TRANSFORMS
+
+
+def merge_hparams(policy: dict, hparams: dict) -> dict:
+ """Merge hyperparameters into policy config.
+
+ Only merge partial hyperparameters required of the policy.
+
+ Args:
+ policy (dict): Original policy config dict.
+ hparams (dict): Hyperparameters need to be merged.
+
+ Returns:
+ dict: Policy config dict after adding ``hparams``.
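+
+ Examples:
+ A minimal illustration with the ``Rotate`` transform registered
+ below; ``unused_key`` is a made-up key just to show that
+ hyperparameters not accepted by ``Rotate.__init__`` are ignored:
+
+ >>> policy = dict(type='Rotate', magnitude_range=(0, 30))
+ >>> merge_hparams(policy, dict(pad_val=128, unused_key=1))
+ {'type': 'Rotate', 'magnitude_range': (0, 30), 'pad_val': 128}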
+ """
+ policy = deepcopy(policy)
+ op = TRANSFORMS.get(policy['type'])
+ assert op is not None, f'Invalid policy type "{policy["type"]}".'
+
+ op_args = inspect.getfullargspec(op.__init__).args
+ for key, value in hparams.items():
+ if key in op_args and key not in policy:
+ policy[key] = value
+ return policy
+
+
+@TRANSFORMS.register_module()
+class AutoAugment(RandomChoice):
+ """Auto augmentation.
+
+ This data augmentation is proposed in `AutoAugment: Learning Augmentation
+ Policies from Data <https://arxiv.org/abs/1805.09501>`_.
+
+ Args:
+ policies (str | list[list[dict]]): The policies of auto augmentation.
+ If string, use a preset policies collection like "imagenet". If list,
+ each item is a sub-policy, composed of several augmentation
+ policy dicts. When AutoAugment is called, a random sub-policy in
+ ``policies`` will be selected to augment the image.
+ hparams (dict): Configs of hyperparameters. Hyperparameters will be
+ used in policies that require these arguments if these arguments
+ are not set in policy dicts. Defaults to ``dict(pad_val=128)``.
+
+ .. admonition:: Available preset policies
+
+ - ``"imagenet"``: Policy for ImageNet, come from
+ `DeepVoltaire/AutoAugment`_
+
+ .. _DeepVoltaire/AutoAugment: https://github.com/DeepVoltaire/AutoAugment
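+
+ Examples:
+ A minimal usage sketch with the preset ImageNet policies; the input
+ image here is random data for illustration only:
+
+ >>> import numpy as np
+ >>> from mmpretrain.datasets import AutoAugment
+ >>> transform = AutoAugment(policies='imagenet')
+ >>> data = {'img': np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)}
+ >>> results = transform(data)
+ >>> results['img'].shape
+ (224, 224, 3)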
+ """
+
+ def __init__(self,
+ policies: Union[str, List[List[dict]]],
+ hparams: dict = dict(pad_val=128)):
+ if isinstance(policies, str):
+ assert policies in AUTOAUG_POLICIES, 'Invalid policies, ' \
+ f'please choose from {list(AUTOAUG_POLICIES.keys())}.'
+ policies = AUTOAUG_POLICIES[policies]
+ self.hparams = hparams
+ self.policies = [[merge_hparams(t, hparams) for t in sub]
+ for sub in policies]
+ transforms = [[TRANSFORMS.build(t) for t in sub] for sub in policies]
+
+ super().__init__(transforms=transforms)
+
+ def __repr__(self) -> str:
+ policies_str = ''
+ for sub in self.policies:
+ policies_str += '\n ' + ', \t'.join([t['type'] for t in sub])
+
+ repr_str = self.__class__.__name__
+ repr_str += f'(policies:{policies_str}\n)'
+ return repr_str
+
+
+@TRANSFORMS.register_module()
+class RandAugment(BaseTransform):
+ r"""Random augmentation.
+
+ This data augmentation is proposed in `RandAugment: Practical automated
+ data augmentation with a reduced search space
+ <https://arxiv.org/abs/1909.13719>`_.
+
+ Args:
+ policies (str | list[dict]): The policies of random augmentation.
+ If string, use a preset policies collection like "timm_increasing".
+ If list, each item is one specific augmentation policy dict.
+ The policy dict should have these keys:
+
+ - ``type`` (str), The type of augmentation.
+ - ``magnitude_range`` (Sequence[number], optional): For those
+ augmentations that have a magnitude, you need to specify the
+ magnitude level mapping range. For example, assume ``total_level``
+ is 10, ``magnitude_level=3`` specifies a magnitude of 3 if
+ ``magnitude_range=(0, 10)``, but a magnitude of 7 if
+ ``magnitude_range=(10, 0)``.
+ - other keyword arguments of the augmentation.
+
+ num_policies (int): Number of policies to select from policies each
+ time.
+ magnitude_level (int | float): Magnitude level for all the augmentation
+ selected.
+ magnitude_std (Number | str): Deviation of magnitude noise applied.
+
+ - If positive number, the magnitude obeys normal distribution
+ :math:`\mathcal{N}(\text{magnitude\_level}, \text{magnitude\_std})`.
+ - If 0 or negative number, magnitude remains unchanged.
+ - If str "inf", the magnitude obeys uniform distribution
+ :math:`\mathcal{U}(0, \text{magnitude\_level})`.
+ total_level (int | float): Total level for the magnitude. Defaults to
+ 10.
+ hparams (dict): Configs of hyperparameters. Hyperparameters will be
+ used in policies that require these arguments if these arguments
+ are not set in policy dicts. Defaults to ``dict(pad_val=128)``.
+
+ .. admonition:: Available preset policies
+
+ - ``"timm_increasing"``: The ``_RAND_INCREASING_TRANSFORMS`` policy
+ from `timm`_
+
+ .. _timm: https://github.com/rwightman/pytorch-image-models
+
+ Examples:
+
+ To use "timm-increasing" policies collection, select two policies every
+ time, and magnitude_level of every policy is 6 (total is 10 by default)
+
+ >>> import numpy as np
+ >>> from mmpretrain.datasets import RandAugment
+ >>> transform = RandAugment(
+ ... policies='timm_increasing',
+ ... num_policies=2,
+ ... magnitude_level=6,
+ ... )
+ >>> data = {'img': np.random.randint(0, 256, (224, 224, 3))}
+ >>> results = transform(data)
+ >>> print(results['img'].shape)
+ (224, 224, 3)
+
+ If you want the ``magnitude_level`` randomly changes every time, you
+ can use ``magnitude_std`` to specify the random distribution. For
+ example, a normal distribution :math:`\mathcal{N}(6, 0.5)`.
+
+ >>> transform = RandAugment(
+ ... policies='timm_increasing',
+ ... num_policies=2,
+ ... magnitude_level=6,
+ ... magnitude_std=0.5,
+ ... )
+
+ You can also use your own policies:
+
+ >>> policies = [
+ ... dict(type='AutoContrast'),
+ ... dict(type='Rotate', magnitude_range=(0, 30)),
+ ... dict(type='ColorTransform', magnitude_range=(0, 0.9)),
+ ... ]
+ >>> transform = RandAugment(
+ ... policies=policies,
+ ... num_policies=2,
+ ... magnitude_level=6
+ ... )
+
+ Note:
+ ``magnitude_std`` will introduce some randomness to the policy, following
+ the implementation in https://github.com/rwightman/pytorch-image-models.
+
+ When magnitude_std=0, we calculate the magnitude as follows:
+
+ .. math::
+ \text{magnitude} = \frac{\text{magnitude\_level}}
+ {\text{total\_level}} \times (\text{val}_2 - \text{val}_1)
+ + \text{val}_1
+ """
+
+ def __init__(self,
+ policies: Union[str, List[dict]],
+ num_policies: int,
+ magnitude_level: int,
+ magnitude_std: Union[Number, str] = 0.,
+ total_level: int = 10,
+ hparams: dict = dict(pad_val=128)):
+ if isinstance(policies, str):
+ assert policies in RANDAUG_POLICIES, 'Invalid policies, ' \
+ f'please choose from {list(RANDAUG_POLICIES.keys())}.'
+ policies = RANDAUG_POLICIES[policies]
+
+ assert is_list_of(policies, dict), 'policies must be a list of dict.'
+
+ assert isinstance(magnitude_std, (Number, str)), \
+ '`magnitude_std` must be of number or str type, ' \
+ f'got {type(magnitude_std)} instead.'
+ if isinstance(magnitude_std, str):
+ assert magnitude_std == 'inf', \
+ '`magnitude_std` must be of number or "inf", ' \
+ f'got "{magnitude_std}" instead.'
+
+ assert num_policies > 0, 'num_policies must be greater than 0.'
+ assert magnitude_level >= 0, 'magnitude_level must be no less than 0.'
+ assert total_level > 0, 'total_level must be greater than 0.'
+
+ self.num_policies = num_policies
+ self.magnitude_level = magnitude_level
+ self.magnitude_std = magnitude_std
+ self.total_level = total_level
+ self.hparams = hparams
+ self.policies = []
+ self.transforms = []
+
+ randaug_cfg = dict(
+ magnitude_level=magnitude_level,
+ total_level=total_level,
+ magnitude_std=magnitude_std)
+
+ for policy in policies:
+ self._check_policy(policy)
+ policy = merge_hparams(policy, hparams)
+ policy.pop('magnitude_key', None) # For backward compatibility
+ if 'magnitude_range' in policy:
+ policy.update(randaug_cfg)
+ self.policies.append(policy)
+ self.transforms.append(TRANSFORMS.build(policy))
+
+ def __iter__(self):
+ """Iterate all transforms."""
+ return iter(self.transforms)
+
+ def _check_policy(self, policy):
+ """Check whether the sub-policy dict is available."""
+ assert isinstance(policy, dict) and 'type' in policy, \
+ 'Each policy must be a dict with key "type".'
+ type_name = policy['type']
+
+ if 'magnitude_range' in policy:
+ magnitude_range = policy['magnitude_range']
+ assert is_seq_of(magnitude_range, Number), \
+ f'`magnitude_range` of RandAugment policy {type_name} ' \
+ 'should be a sequence with two numbers.'
+
+ @cache_randomness
+ def random_policy_indices(self) -> np.ndarray:
+ """Return the random chosen transform indices."""
+ indices = np.arange(len(self.policies))
+ return np.random.choice(indices, size=self.num_policies).tolist()
+
+ def transform(self, results: dict) -> Optional[dict]:
+ """Randomly choose a sub-policy to apply."""
+
+ chosen_policies = [
+ self.transforms[i] for i in self.random_policy_indices()
+ ]
+
+ sub_pipeline = Compose(chosen_policies)
+ return sub_pipeline(results)
+
+ def __repr__(self) -> str:
+ policies_str = ''
+ for policy in self.policies:
+ policies_str += '\n ' + f'{policy["type"]}'
+ if 'magnitude_range' in policy:
+ val1, val2 = policy['magnitude_range']
+ policies_str += f' ({val1}, {val2})'
+
+ repr_str = self.__class__.__name__
+ repr_str += f'(num_policies={self.num_policies}, '
+ repr_str += f'magnitude_level={self.magnitude_level}, '
+ repr_str += f'total_level={self.total_level}, '
+ repr_str += f'policies:{policies_str}\n)'
+ return repr_str
+
+
+class BaseAugTransform(BaseTransform):
+ r"""The base class of augmentation transform for RandAugment.
+
+ This class provides several common attributions and methods to support the
+ magnitude level mapping and magnitude level randomness in
+ :class:`RandAugment`.
+
+ Args:
+ magnitude_level (int | float): Magnitude level.
+ magnitude_range (Sequence[number], optional): For augmentations that
+ have a magnitude argument (such as "magnitude" or "angle"), you can
+ specify the magnitude level mapping range to generate the magnitude
+ argument. For example, assume ``total_level`` is 10,
+ ``magnitude_level=3`` specifies a magnitude of 3 if
+ ``magnitude_range=(0, 10)``, but a magnitude of 7 if
+ ``magnitude_range=(10, 0)``. Defaults to None.
+ magnitude_std (Number | str): Deviation of magnitude noise applied.
+
+ - If positive number, the magnitude obeys normal distribution
+ :math:`\mathcal{N}(\text{magnitude}, \text{magnitude\_std})`.
+ - If 0 or negative number, magnitude remains unchanged.
+ - If str "inf", the magnitude obeys uniform distribution
+ :math:`\mathcal{U}(0, \text{magnitude})`.
+
+ Defaults to 0.
+ total_level (int | float): Total level for the magnitude. Defaults to
+ 10.
+ prob (float): The probability of performing the transformation, which
+ should be in range [0, 1]. Defaults to 0.5.
+ random_negative_prob (float): The probability that turns the magnitude
+ negative, which should be in range [0, 1]. Defaults to 0.
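+
+ Examples:
+ How the magnitude mapping works when ``magnitude_std=0``: with
+ ``magnitude_level=3``, ``total_level=10`` and
+ ``magnitude_range=(0, 30)``, :meth:`random_magnitude` returns
+ ``3 / 10 * (30 - 0) + 0 = 9`` (before the optional random negation
+ in :meth:`random_negative` is applied).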
+ """
+
+ def __init__(self,
+ magnitude_level: int = 10,
+ magnitude_range: Tuple[float, float] = None,
+ magnitude_std: Union[str, float] = 0.,
+ total_level: int = 10,
+ prob: float = 0.5,
+ random_negative_prob: float = 0.5):
+ self.magnitude_level = magnitude_level
+ self.magnitude_range = magnitude_range
+ self.magnitude_std = magnitude_std
+ self.total_level = total_level
+ self.prob = prob
+ self.random_negative_prob = random_negative_prob
+
+ @cache_randomness
+ def random_disable(self):
+ """Randomly disable the transform."""
+ return np.random.rand() > self.prob
+
+ @cache_randomness
+ def random_magnitude(self):
+ """Randomly generate magnitude."""
+ magnitude = self.magnitude_level
+ # if magnitude_std is positive number or 'inf', move
+ # magnitude_value randomly.
+ if self.magnitude_std == 'inf':
+ magnitude = np.random.uniform(0, magnitude)
+ elif self.magnitude_std > 0:
+ magnitude = np.random.normal(magnitude, self.magnitude_std)
+ magnitude = np.clip(magnitude, 0, self.total_level)
+
+ val1, val2 = self.magnitude_range
+ magnitude = (magnitude / self.total_level) * (val2 - val1) + val1
+ return magnitude
+
+ @cache_randomness
+ def random_negative(self, value):
+ """Randomly negative the value."""
+ if np.random.rand() < self.random_negative_prob:
+ return -value
+ else:
+ return value
+
+ def extra_repr(self):
+ """Extra repr string when auto-generating magnitude is enabled."""
+ if self.magnitude_range is not None:
+ repr_str = f', magnitude_level={self.magnitude_level}, '
+ repr_str += f'magnitude_range={self.magnitude_range}, '
+ repr_str += f'magnitude_std={self.magnitude_std}, '
+ repr_str += f'total_level={self.total_level}, '
+ return repr_str
+ else:
+ return ''
+
+
+@TRANSFORMS.register_module()
+class Shear(BaseAugTransform):
+ """Shear images.
+
+ Args:
+ magnitude (int | float | None): The magnitude used for shear. If None,
+ generate from ``magnitude_range``, see :class:`BaseAugTransform`.
+ Defaults to None.
+ pad_val (int, Sequence[int]): Pixel value for constant fill.
+ If a sequence of length 3, it is used to pad the R, G, B channels
+ respectively. Defaults to 128.
+ prob (float): The probability of performing shear, which should be
+ in range [0, 1]. Defaults to 0.5.
+ direction (str): The shearing direction. Options are 'horizontal' and
+ 'vertical'. Defaults to 'horizontal'.
+ random_negative_prob (float): The probability that turns the magnitude
+ negative, which should be in range [0,1]. Defaults to 0.5.
+ interpolation (str): Interpolation method. Options are 'nearest',
+ 'bilinear', 'bicubic', 'area', 'lanczos'. Defaults to 'bicubic'.
+ **kwargs: Other keyword arguments of :class:`BaseAugTransform`.
+ """
+
+ def __init__(self,
+ magnitude: Union[int, float, None] = None,
+ pad_val: Union[int, Sequence[int]] = 128,
+ prob: float = 0.5,
+ direction: str = 'horizontal',
+ random_negative_prob: float = 0.5,
+ interpolation: str = 'bicubic',
+ **kwargs):
+ super().__init__(
+ prob=prob, random_negative_prob=random_negative_prob, **kwargs)
+ assert (magnitude is None) ^ (self.magnitude_range is None), \
+ 'Please specify only one of `magnitude` and `magnitude_range`.'
+
+ self.magnitude = magnitude
+ if isinstance(pad_val, Sequence):
+ self.pad_val = tuple(pad_val)
+ else:
+ self.pad_val = pad_val
+
+ assert direction in ('horizontal', 'vertical'), 'direction must be ' \
+ f'either "horizontal" or "vertical", got "{direction}" instead.'
+ self.direction = direction
+
+ self.interpolation = interpolation
+
+ def transform(self, results):
+ """Apply transform to results."""
+ if self.random_disable():
+ return results
+
+ if self.magnitude is not None:
+ magnitude = self.random_negative(self.magnitude)
+ else:
+ magnitude = self.random_negative(self.random_magnitude())
+
+ img = results['img']
+ img_sheared = mmcv.imshear(
+ img,
+ magnitude,
+ direction=self.direction,
+ border_value=self.pad_val,
+ interpolation=self.interpolation)
+ results['img'] = img_sheared.astype(img.dtype)
+
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(magnitude={self.magnitude}, '
+ repr_str += f'pad_val={self.pad_val}, '
+ repr_str += f'prob={self.prob}, '
+ repr_str += f'direction={self.direction}, '
+ repr_str += f'random_negative_prob={self.random_negative_prob}, '
+ repr_str += f'interpolation={self.interpolation}{self.extra_repr()})'
+ return repr_str
+
+
+@TRANSFORMS.register_module()
+class Translate(BaseAugTransform):
+ """Translate images.
+
+ Args:
+ magnitude (int | float | None): The magnitude used for translate. Note
+ that the offset is calculated by magnitude * size in the
+ corresponding direction. With a magnitude of 1, the whole image
+ will be moved out of the range. If None, generate from
+ ``magnitude_range``, see :class:`BaseAugTransform`.
+ pad_val (int, Sequence[int]): Pixel value for constant fill.
+ If a sequence of length 3, it is used to pad the R, G, B channels
+ respectively. Defaults to 128.
+ prob (float): The probability of performing translation, which should
+ be in range [0, 1]. Defaults to 0.5.
+ direction (str): The translating direction. Options are 'horizontal'
+ and 'vertical'. Defaults to 'horizontal'.
+ random_negative_prob (float): The probability that turns the magnitude
+ negative, which should be in range [0,1]. Defaults to 0.5.
+ interpolation (str): Interpolation method. Options are 'nearest',
+ 'bilinear', 'bicubic', 'area', 'lanczos'. Defaults to 'nearest'.
+ **kwargs: Other keyword arguments of :class:`BaseAugTransform`.
+ """
+
+ def __init__(self,
+ magnitude: Union[int, float, None] = None,
+ pad_val: Union[int, Sequence[int]] = 128,
+ prob: float = 0.5,
+ direction: str = 'horizontal',
+ random_negative_prob: float = 0.5,
+ interpolation: str = 'nearest',
+ **kwargs):
+ super().__init__(
+ prob=prob, random_negative_prob=random_negative_prob, **kwargs)
+ assert (magnitude is None) ^ (self.magnitude_range is None), \
+ 'Please specify only one of `magnitude` and `magnitude_range`.'
+
+ self.magnitude = magnitude
+ if isinstance(pad_val, Sequence):
+ self.pad_val = tuple(pad_val)
+ else:
+ self.pad_val = pad_val
+
+ assert direction in ('horizontal', 'vertical'), 'direction must be ' \
+ f'either "horizontal" or "vertical", got "{direction}" instead.'
+ self.direction = direction
+
+ self.interpolation = interpolation
+
+ def transform(self, results):
+ """Apply transform to results."""
+ if self.random_disable():
+ return results
+
+ if self.magnitude is not None:
+ magnitude = self.random_negative(self.magnitude)
+ else:
+ magnitude = self.random_negative(self.random_magnitude())
+
+ img = results['img']
+ height, width = img.shape[:2]
+ if self.direction == 'horizontal':
+ offset = magnitude * width
+ else:
+ offset = magnitude * height
+ img_translated = mmcv.imtranslate(
+ img,
+ offset,
+ direction=self.direction,
+ border_value=self.pad_val,
+ interpolation=self.interpolation)
+ results['img'] = img_translated.astype(img.dtype)
+
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(magnitude={self.magnitude}, '
+ repr_str += f'pad_val={self.pad_val}, '
+ repr_str += f'prob={self.prob}, '
+ repr_str += f'direction={self.direction}, '
+ repr_str += f'random_negative_prob={self.random_negative_prob}, '
+ repr_str += f'interpolation={self.interpolation}{self.extra_repr()})'
+ return repr_str
+
+
+@TRANSFORMS.register_module()
+class Rotate(BaseAugTransform):
+ """Rotate images.
+
+ Args:
+ angle (float, optional): The angle used for rotate. Positive values
+ stand for clockwise rotation. If None, generate from
+ ``magnitude_range``, see :class:`BaseAugTransform`.
+ Defaults to None.
+ center (tuple[float], optional): Center point (w, h) of the rotation in
+ the source image. If None, the center of the image will be used.
+ Defaults to None.
+ scale (float): Isotropic scale factor. Defaults to 1.0.
+ pad_val (int, Sequence[int]): Pixel value for constant fill.
+ If a sequence of length 3, it is used to pad the R, G, B channels
+ respectively. Defaults to 128.
+ prob (float): The probability of performing rotation, which should be
+ in range [0, 1]. Defaults to 0.5.
+ random_negative_prob (float): The probability that turns the angle
+ negative, which should be in range [0,1]. Defaults to 0.5.
+ interpolation (str): Interpolation method. Options are 'nearest',
+ 'bilinear', 'bicubic', 'area', 'lanczos'. Defaults to 'nearest'.
+ **kwargs: Other keyword arguments of :class:`BaseAugTransform`.
+ """
+
+ def __init__(self,
+ angle: Optional[float] = None,
+ center: Optional[Tuple[float]] = None,
+ scale: float = 1.0,
+ pad_val: Union[int, Sequence[int]] = 128,
+ prob: float = 0.5,
+ random_negative_prob: float = 0.5,
+ interpolation: str = 'nearest',
+ **kwargs):
+ super().__init__(
+ prob=prob, random_negative_prob=random_negative_prob, **kwargs)
+ assert (angle is None) ^ (self.magnitude_range is None), \
+ 'Please specify only one of `angle` and `magnitude_range`.'
+
+ self.angle = angle
+ self.center = center
+ self.scale = scale
+ if isinstance(pad_val, Sequence):
+ self.pad_val = tuple(pad_val)
+ else:
+ self.pad_val = pad_val
+
+ self.interpolation = interpolation
+
+ def transform(self, results):
+ """Apply transform to results."""
+ if self.random_disable():
+ return results
+
+ if self.angle is not None:
+ angle = self.random_negative(self.angle)
+ else:
+ angle = self.random_negative(self.random_magnitude())
+
+ img = results['img']
+ img_rotated = mmcv.imrotate(
+ img,
+ angle,
+ center=self.center,
+ scale=self.scale,
+ border_value=self.pad_val,
+ interpolation=self.interpolation)
+ results['img'] = img_rotated.astype(img.dtype)
+
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(angle={self.angle}, '
+ repr_str += f'center={self.center}, '
+ repr_str += f'scale={self.scale}, '
+ repr_str += f'pad_val={self.pad_val}, '
+ repr_str += f'prob={self.prob}, '
+ repr_str += f'random_negative_prob={self.random_negative_prob}, '
+ repr_str += f'interpolation={self.interpolation}{self.extra_repr()})'
+ return repr_str
+
+
+@TRANSFORMS.register_module()
+class AutoContrast(BaseAugTransform):
+ """Auto adjust image contrast.
+
+ Args:
+ prob (float): The probability of performing auto contrast, which
+ should be in range [0, 1]. Defaults to 0.5.
+ **kwargs: Other keyword arguments of :class:`BaseAugTransform`.
+ """
+
+ def __init__(self, prob: float = 0.5, **kwargs):
+ super().__init__(prob=prob, **kwargs)
+
+ def transform(self, results):
+ """Apply transform to results."""
+ if self.random_disable():
+ return results
+
+ img = results['img']
+ img_contrasted = mmcv.auto_contrast(img)
+ results['img'] = img_contrasted.astype(img.dtype)
+
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(prob={self.prob})'
+ return repr_str
+
+
+@TRANSFORMS.register_module()
+class Invert(BaseAugTransform):
+ """Invert images.
+
+ Args:
+ prob (float): The probability of performing inversion, which should
+ be in range [0, 1]. Defaults to 0.5.
+ **kwargs: Other keyword arguments of :class:`BaseAugTransform`.
+ """
+
+ def __init__(self, prob: float = 0.5, **kwargs):
+ super().__init__(prob=prob, **kwargs)
+
+ def transform(self, results):
+ """Apply transform to results."""
+ if self.random_disable():
+ return results
+
+ img = results['img']
+ img_inverted = mmcv.iminvert(img)
+ results['img'] = img_inverted.astype(img.dtype)
+
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(prob={self.prob})'
+ return repr_str
+
+
+@TRANSFORMS.register_module()
+class Equalize(BaseAugTransform):
+ """Equalize the image histogram.
+
+ Args:
+ prob (float): The probability of performing equalization, which should
+ be in range [0, 1]. Defaults to 0.5.
+ **kwargs: Other keyword arguments of :class:`BaseAugTransform`.
+ """
+
+ def __init__(self, prob: float = 0.5, **kwargs):
+ super().__init__(prob=prob, **kwargs)
+
+ def transform(self, results):
+ """Apply transform to results."""
+ if self.random_disable():
+ return results
+
+ img = results['img']
+ img_equalized = mmcv.imequalize(img)
+ results['img'] = img_equalized.astype(img.dtype)
+
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(prob={self.prob})'
+ return repr_str
+
+
+@TRANSFORMS.register_module()
+class Solarize(BaseAugTransform):
+ """Solarize images (invert all pixel values above a threshold).
+
+ Args:
+ thr (int | float | None): The threshold above which the pixel values
+ will be inverted. If None, generate from ``magnitude_range``,
+ see :class:`BaseAugTransform`. Defaults to None.
+ prob (float): The probability of solarizing, which should be in
+ range [0, 1]. Defaults to 0.5.
+ **kwargs: Other keyword arguments of :class:`BaseAugTransform`.
+ """
+
+ def __init__(self,
+ thr: Union[int, float, None] = None,
+ prob: float = 0.5,
+ **kwargs):
+ super().__init__(prob=prob, random_negative_prob=0., **kwargs)
+ assert (thr is None) ^ (self.magnitude_range is None), \
+ 'Please specify only one of `thr` and `magnitude_range`.'
+
+ self.thr = thr
+
+ def transform(self, results):
+ """Apply transform to results."""
+ if self.random_disable():
+ return results
+
+ if self.thr is not None:
+ thr = self.thr
+ else:
+ thr = self.random_magnitude()
+
+ img = results['img']
+ img_solarized = mmcv.solarize(img, thr=thr)
+ results['img'] = img_solarized.astype(img.dtype)
+
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(thr={self.thr}, '
+ repr_str += f'prob={self.prob}{self.extra_repr()})'
+ return repr_str
+
+
+@TRANSFORMS.register_module()
+class SolarizeAdd(BaseAugTransform):
+ """SolarizeAdd images (add a certain value to pixels below a threshold).
+
+ Args:
+ magnitude (int | float | None): The value to be added to pixels below
+ the thr. If None, generate from ``magnitude_range``, see
+ :class:`BaseAugTransform`. Defaults to None.
+ thr (int | float): The threshold below which the pixel values will be
+ adjusted. Defaults to 128.
+ prob (float): The probability of performing this transform, which
+ should be in range [0, 1]. Defaults to 0.5.
+ **kwargs: Other keyword arguments of :class:`BaseAugTransform`.
+ """
+
+ def __init__(self,
+ magnitude: Union[int, float, None] = None,
+ thr: Union[int, float] = 128,
+ prob: float = 0.5,
+ **kwargs):
+ super().__init__(prob=prob, random_negative_prob=0., **kwargs)
+ assert (magnitude is None) ^ (self.magnitude_range is None), \
+ 'Please specify only one of `magnitude` and `magnitude_range`.'
+
+ self.magnitude = magnitude
+
+ assert isinstance(thr, (int, float)), 'The thr type must '\
+ f'be int or float, but got {type(thr)} instead.'
+ self.thr = thr
+
+ def transform(self, results):
+ """Apply transform to results."""
+ if self.random_disable():
+ return results
+
+ if self.magnitude is not None:
+ magnitude = self.magnitude
+ else:
+ magnitude = self.random_magnitude()
+
+ img = results['img']
+ img_solarized = np.where(img < self.thr,
+ np.minimum(img + magnitude, 255), img)
+ results['img'] = img_solarized.astype(img.dtype)
+
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(magnitude={self.magnitude}, '
+ repr_str += f'thr={self.thr}, '
+ repr_str += f'prob={self.prob}{self.extra_repr()})'
+ return repr_str
+
+
+@TRANSFORMS.register_module()
+class Posterize(BaseAugTransform):
+ """Posterize images (reduce the number of bits for each color channel).
+
+ Args:
+ bits (int, optional): Number of bits for each pixel in the output img,
+ which should be less than or equal to 8. If None, generate from
+ ``magnitude_range``, see :class:`BaseAugTransform`.
+ Defaults to None.
+ prob (float): The probability of posterizing, which should be in
+ range [0, 1]. Defaults to 0.5.
+ **kwargs: Other keyword arguments of :class:`BaseAugTransform`.
+ """
+
+ def __init__(self,
+ bits: Optional[int] = None,
+ prob: float = 0.5,
+ **kwargs):
+ super().__init__(prob=prob, random_negative_prob=0., **kwargs)
+ assert (bits is None) ^ (self.magnitude_range is None), \
+ 'Please specify only one of `bits` and `magnitude_range`.'
+
+ if bits is not None:
+ assert bits <= 8, \
+ f'The bits must be no more than 8, got {bits} instead.'
+ self.bits = bits
+
+ def transform(self, results):
+ """Apply transform to results."""
+ if self.random_disable():
+ return results
+
+ if self.bits is not None:
+ bits = self.bits
+ else:
+ bits = self.random_magnitude()
+
+ # To align with the timm version, we need to round up to an integer here.
+ bits = ceil(bits)
+
+ img = results['img']
+ img_posterized = mmcv.posterize(img, bits=bits)
+ results['img'] = img_posterized.astype(img.dtype)
+
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(bits={self.bits}, '
+ repr_str += f'prob={self.prob}{self.extra_repr()})'
+ return repr_str
+
+
+@TRANSFORMS.register_module()
+class Contrast(BaseAugTransform):
+ """Adjust images contrast.
+
+ Args:
+ magnitude (int | float | None): The magnitude used for adjusting
+ contrast. A positive magnitude would enhance the contrast and
+ a negative magnitude would make the image grayer. A magnitude=0
+ gives the original image. If None, generate from ``magnitude_range``,
+ see :class:`BaseAugTransform`. Defaults to None.
+ prob (float): The probability of adjusting the contrast, which
+ should be in range [0, 1]. Defaults to 0.5.
+ random_negative_prob (float): The probability that turns the magnitude
+ negative, which should be in range [0, 1]. Defaults to 0.5.
+ """
+
+ def __init__(self,
+ magnitude: Union[int, float, None] = None,
+ prob: float = 0.5,
+ random_negative_prob: float = 0.5,
+ **kwargs):
+ super().__init__(
+ prob=prob, random_negative_prob=random_negative_prob, **kwargs)
+ assert (magnitude is None) ^ (self.magnitude_range is None), \
+ 'Please specify only one of `magnitude` and `magnitude_range`.'
+
+ self.magnitude = magnitude
+
+ def transform(self, results):
+ """Apply transform to results."""
+ if self.random_disable():
+ return results
+
+ if self.magnitude is not None:
+ magnitude = self.random_negative(self.magnitude)
+ else:
+ magnitude = self.random_negative(self.random_magnitude())
+
+ img = results['img']
+ img_contrasted = mmcv.adjust_contrast(img, factor=1 + magnitude)
+ results['img'] = img_contrasted.astype(img.dtype)
+
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(magnitude={self.magnitude}, '
+ repr_str += f'prob={self.prob}, '
+ repr_str += f'random_negative_prob={self.random_negative_prob}'
+ repr_str += f'{self.extra_repr()})'
+ return repr_str
+
+
+@TRANSFORMS.register_module()
+class ColorTransform(BaseAugTransform):
+ """Adjust images color balance.
+
+ Args:
+ magnitude (int | float | None): The magnitude used for color transform.
+ A positive magnitude would enhance the color and a negative
+ magnitude would make the image grayer. A magnitude=0 gives the
+ original image. If None, generate from ``magnitude_range``, see
+ :class:`BaseAugTransform`. Defaults to None.
+ prob (float): The probability of performing ColorTransform, which
+ should be in range [0, 1]. Defaults to 0.5.
+ random_negative_prob (float): The probability that turns the magnitude
+ negative, which should be in range [0,1]. Defaults to 0.5.
+ **kwargs: Other keyword arguments of :class:`BaseAugTransform`.
+ """
+
+ def __init__(self,
+ magnitude: Union[int, float, None] = None,
+ prob: float = 0.5,
+ random_negative_prob: float = 0.5,
+ **kwargs):
+ super().__init__(
+ prob=prob, random_negative_prob=random_negative_prob, **kwargs)
+ assert (magnitude is None) ^ (self.magnitude_range is None), \
+ 'Please specify only one of `magnitude` and `magnitude_range`.'
+
+ self.magnitude = magnitude
+
+ def transform(self, results):
+ """Apply transform to results."""
+ if self.random_disable():
+ return results
+
+ if self.magnitude is not None:
+ magnitude = self.random_negative(self.magnitude)
+ else:
+ magnitude = self.random_negative(self.random_magnitude())
+
+ img = results['img']
+ img_color_adjusted = mmcv.adjust_color(img, alpha=1 + magnitude)
+ results['img'] = img_color_adjusted.astype(img.dtype)
+
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(magnitude={self.magnitude}, '
+ repr_str += f'prob={self.prob}, '
+ repr_str += f'random_negative_prob={self.random_negative_prob}'
+ repr_str += f'{self.extra_repr()})'
+ return repr_str
+
+
+@TRANSFORMS.register_module()
+class Brightness(BaseAugTransform):
+ """Adjust images brightness.
+
+ Args:
+ magnitude (int | float | None): The magnitude used for adjusting
+ brightness. A positive magnitude would enhance the brightness and a
+ negative magnitude would make the image darker. A magnitude=0 gives
+ the original image. If None, generate from ``magnitude_range``, see
+ :class:`BaseAugTransform`. Defaults to None.
+ prob (float): The probability of adjusting the brightness, which
+ should be in range [0, 1]. Defaults to 0.5.
+ random_negative_prob (float): The probability that turns the magnitude
+ negative, which should be in range [0,1]. Defaults to 0.5.
+ **kwargs: Other keyword arguments of :class:`BaseAugTransform`.
+ """
+
+ def __init__(self,
+ magnitude: Union[int, float, None] = None,
+ prob: float = 0.5,
+ random_negative_prob: float = 0.5,
+ **kwargs):
+ super().__init__(
+ prob=prob, random_negative_prob=random_negative_prob, **kwargs)
+ assert (magnitude is None) ^ (self.magnitude_range is None), \
+ 'Please specify only one of `magnitude` and `magnitude_range`.'
+
+ self.magnitude = magnitude
+
+ def transform(self, results):
+ """Apply transform to results."""
+ if self.random_disable():
+ return results
+
+ if self.magnitude is not None:
+ magnitude = self.random_negative(self.magnitude)
+ else:
+ magnitude = self.random_negative(self.random_magnitude())
+
+ img = results['img']
+ img_brightened = mmcv.adjust_brightness(img, factor=1 + magnitude)
+ results['img'] = img_brightened.astype(img.dtype)
+
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(magnitude={self.magnitude}, '
+ repr_str += f'prob={self.prob}, '
+ repr_str += f'random_negative_prob={self.random_negative_prob}'
+ repr_str += f'{self.extra_repr()})'
+ return repr_str
+
+
+@TRANSFORMS.register_module()
+class Sharpness(BaseAugTransform):
+ """Adjust images sharpness.
+
+ Args:
+ magnitude (int | float | None): The magnitude used for adjusting
+ sharpness. A positive magnitude would enhance the sharpness and a
+ negative magnitude would make the image blurrier. A magnitude=0 gives
+ the original image. If None, generate from ``magnitude_range``, see
+ :class:`BaseAugTransform`. Defaults to None.
+ prob (float): The probability of adjusting the sharpness, which
+ should be in range [0, 1]. Defaults to 0.5.
+ random_negative_prob (float): The probability that turns the magnitude
+ negative, which should be in range [0,1]. Defaults to 0.5.
+ **kwargs: Other keyword arguments of :class:`BaseAugTransform`.
+ """
+
+ def __init__(self,
+ magnitude: Union[int, float, None] = None,
+ prob: float = 0.5,
+ random_negative_prob: float = 0.5,
+ **kwargs):
+ super().__init__(
+ prob=prob, random_negative_prob=random_negative_prob, **kwargs)
+ assert (magnitude is None) ^ (self.magnitude_range is None), \
+ 'Please specify only one of `magnitude` and `magnitude_range`.'
+
+ self.magnitude = magnitude
+
+ def transform(self, results):
+ """Apply transform to results."""
+ if self.random_disable():
+ return results
+
+ if self.magnitude is not None:
+ magnitude = self.random_negative(self.magnitude)
+ else:
+ magnitude = self.random_negative(self.random_magnitude())
+
+ img = results['img']
+ img_sharpened = mmcv.adjust_sharpness(img, factor=1 + magnitude)
+ results['img'] = img_sharpened.astype(img.dtype)
+
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(magnitude={self.magnitude}, '
+ repr_str += f'prob={self.prob}, '
+ repr_str += f'random_negative_prob={self.random_negative_prob}'
+ repr_str += f'{self.extra_repr()})'
+ return repr_str
+
+
+@TRANSFORMS.register_module()
+class Cutout(BaseAugTransform):
+ """Cutout images.
+
+ Args:
+ shape (int | tuple(int) | None): Expected cutout shape (h, w).
+ If given as a single value, the value will be used for both h and
+ w. If None, generate from ``magnitude_range``, see
+ :class:`BaseAugTransform`. Defaults to None.
+ pad_val (int, Sequence[int]): Pixel pad_val value for constant fill.
+ If it is a sequence, it must have the same length as the image
+ channels. Defaults to 128.
+ prob (float): The probability of performing cutout, which should
+ be in range [0, 1]. Defaults to 0.5.
+ **kwargs: Other keyword arguments of :class:`BaseAugTransform`.
+ """
+
+ def __init__(self,
+ shape: Union[int, Tuple[int], None] = None,
+ pad_val: Union[int, Sequence[int]] = 128,
+ prob: float = 0.5,
+ **kwargs):
+ super().__init__(prob=prob, random_negative_prob=0., **kwargs)
+ assert (shape is None) ^ (self.magnitude_range is None), \
+ 'Please specify only one of `shape` and `magnitude_range`.'
+
+ self.shape = shape
+ if isinstance(pad_val, Sequence):
+ self.pad_val = tuple(pad_val)
+ else:
+ self.pad_val = pad_val
+
+ def transform(self, results):
+ """Apply transform to results."""
+ if self.random_disable():
+ return results
+
+ if self.shape is not None:
+ shape = self.shape
+ else:
+ shape = int(self.random_magnitude())
+
+ img = results['img']
+ img_cutout = mmcv.cutout(img, shape, pad_val=self.pad_val)
+ results['img'] = img_cutout.astype(img.dtype)
+
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(shape={self.shape}, '
+ repr_str += f'pad_val={self.pad_val}, '
+ repr_str += f'prob={self.prob}{self.extra_repr()})'
+ return repr_str
+
+
+@TRANSFORMS.register_module()
+class GaussianBlur(BaseAugTransform):
+ """Gaussian blur images.
+
+ Args:
+ radius (int, float, optional): The blur radius. If None, generate from
+ ``magnitude_range``, see :class:`BaseAugTransform`.
+ Defaults to None.
+ prob (float): The probability of applying Gaussian blur, which
+ should be in range [0, 1]. Defaults to 0.5.
+ **kwargs: Other keyword arguments of :class:`BaseAugTransform`.
+ """
+
+ def __init__(self,
+ radius: Union[int, float, None] = None,
+ prob: float = 0.5,
+ **kwargs):
+ super().__init__(prob=prob, random_negative_prob=0., **kwargs)
+ assert (radius is None) ^ (self.magnitude_range is None), \
+ 'Please specify only one of `radius` and `magnitude_range`.'
+
+ self.radius = radius
+
+ def transform(self, results):
+ """Apply transform to results."""
+ if self.random_disable():
+ return results
+
+ if self.radius is not None:
+ radius = self.radius
+ else:
+ radius = self.random_magnitude()
+
+ img = results['img']
+ pil_img = Image.fromarray(img)
+ # ``Image.filter`` returns a new image, so the result must be re-assigned.
+ pil_img = pil_img.filter(ImageFilter.GaussianBlur(radius=radius))
+ results['img'] = np.array(pil_img, dtype=img.dtype)
+
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(radius={self.radius}, '
+ repr_str += f'prob={self.prob}{self.extra_repr()})'
+ return repr_str
+
+
+# yapf: disable
+# flake8: noqa
+AUTOAUG_POLICIES = {
+ # Policy for ImageNet, refers to
+ # https://github.com/DeepVoltaire/AutoAugment/blame/master/autoaugment.py
+ 'imagenet': [
+ [dict(type='Posterize', bits=4, prob=0.4), dict(type='Rotate', angle=30., prob=0.6)],
+ [dict(type='Solarize', thr=256 / 9 * 4, prob=0.6), dict(type='AutoContrast', prob=0.6)],
+ [dict(type='Equalize', prob=0.8), dict(type='Equalize', prob=0.6)],
+ [dict(type='Posterize', bits=5, prob=0.6), dict(type='Posterize', bits=5, prob=0.6)],
+ [dict(type='Equalize', prob=0.4), dict(type='Solarize', thr=256 / 9 * 5, prob=0.2)],
+ [dict(type='Equalize', prob=0.4), dict(type='Rotate', angle=30 / 9 * 8, prob=0.8)],
+ [dict(type='Solarize', thr=256 / 9 * 6, prob=0.6), dict(type='Equalize', prob=0.6)],
+ [dict(type='Posterize', bits=6, prob=0.8), dict(type='Equalize', prob=1.)],
+ [dict(type='Rotate', angle=10., prob=0.2), dict(type='Solarize', thr=256 / 9, prob=0.6)],
+ [dict(type='Equalize', prob=0.6), dict(type='Posterize', bits=5, prob=0.4)],
+ [dict(type='Rotate', angle=30 / 9 * 8, prob=0.8), dict(type='ColorTransform', magnitude=0., prob=0.4)],
+ [dict(type='Rotate', angle=30., prob=0.4), dict(type='Equalize', prob=0.6)],
+ [dict(type='Equalize', prob=0.0), dict(type='Equalize', prob=0.8)],
+ [dict(type='Invert', prob=0.6), dict(type='Equalize', prob=1.)],
+ [dict(type='ColorTransform', magnitude=0.4, prob=0.6), dict(type='Contrast', magnitude=0.8, prob=1.)],
+ [dict(type='Rotate', angle=30 / 9 * 8, prob=0.8), dict(type='ColorTransform', magnitude=0.2, prob=1.)],
+ [dict(type='ColorTransform', magnitude=0.8, prob=0.8), dict(type='Solarize', thr=256 / 9 * 2, prob=0.8)],
+ [dict(type='Sharpness', magnitude=0.7, prob=0.4), dict(type='Invert', prob=0.6)],
+ [dict(type='Shear', magnitude=0.3 / 9 * 5, prob=0.6, direction='horizontal'), dict(type='Equalize', prob=1.)],
+ [dict(type='ColorTransform', magnitude=0., prob=0.4), dict(type='Equalize', prob=0.6)],
+ [dict(type='Equalize', prob=0.4), dict(type='Solarize', thr=256 / 9 * 5, prob=0.2)],
+ [dict(type='Solarize', thr=256 / 9 * 4, prob=0.6), dict(type='AutoContrast', prob=0.6)],
+ [dict(type='Invert', prob=0.6), dict(type='Equalize', prob=1.)],
+ [dict(type='ColorTransform', magnitude=0.4, prob=0.6), dict(type='Contrast', magnitude=0.8, prob=1.)],
+ [dict(type='Equalize', prob=0.8), dict(type='Equalize', prob=0.6)],
+ ],
+}
+
+RANDAUG_POLICIES = {
+ # Refers to `_RAND_INCREASING_TRANSFORMS` in pytorch-image-models
+ 'timm_increasing': [
+ dict(type='AutoContrast'),
+ dict(type='Equalize'),
+ dict(type='Invert'),
+ dict(type='Rotate', magnitude_range=(0, 30)),
+ dict(type='Posterize', magnitude_range=(4, 0)),
+ dict(type='Solarize', magnitude_range=(256, 0)),
+ dict(type='SolarizeAdd', magnitude_range=(0, 110)),
+ dict(type='ColorTransform', magnitude_range=(0, 0.9)),
+ dict(type='Contrast', magnitude_range=(0, 0.9)),
+ dict(type='Brightness', magnitude_range=(0, 0.9)),
+ dict(type='Sharpness', magnitude_range=(0, 0.9)),
+ dict(type='Shear', magnitude_range=(0, 0.3), direction='horizontal'),
+ dict(type='Shear', magnitude_range=(0, 0.3), direction='vertical'),
+ dict(type='Translate', magnitude_range=(0, 0.45), direction='horizontal'),
+ dict(type='Translate', magnitude_range=(0, 0.45), direction='vertical'),
+ ],
+ 'simple_increasing': [
+ dict(type='AutoContrast'),
+ dict(type='Equalize'),
+ dict(type='Rotate', magnitude_range=(0, 30)),
+ dict(type='Shear', magnitude_range=(0, 0.3), direction='horizontal'),
+ dict(type='Shear', magnitude_range=(0, 0.3), direction='vertical'),
+ ],
+}
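+
+# The presets above are plain lists of transform configs. As an illustrative
+# sketch (comment only, not executed here), each entry can be built directly
+# from the registry, e.g.:
+#
+#   from mmpretrain.registry import TRANSFORMS
+#   sub_policy = AUTOAUG_POLICIES['imagenet'][0]
+#   transforms = [TRANSFORMS.build(cfg) for cfg in sub_policy]
+#
+# In practice the presets are consumed by policy-driven wrapper transforms
+# (such as AutoAugment / RandAugment wrappers, assumed to be defined elsewhere
+# in this package), which pick and apply sub-policies at runtime.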
diff --git a/mmpretrain/datasets/transforms/formatting.py b/mmpretrain/datasets/transforms/formatting.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4d331636a883ce602e419e0867aea7b513b4d87
--- /dev/null
+++ b/mmpretrain/datasets/transforms/formatting.py
@@ -0,0 +1,353 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from collections import defaultdict
+from collections.abc import Sequence
+
+import cv2
+import numpy as np
+import torch
+import torchvision.transforms.functional as F
+from mmcv.transforms import BaseTransform
+from mmengine.utils import is_str
+from PIL import Image
+
+from mmpretrain.registry import TRANSFORMS
+from mmpretrain.structures import DataSample, MultiTaskDataSample
+
+
+def to_tensor(data):
+ """Convert objects of various python types to :obj:`torch.Tensor`.
+
+ Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
+ :class:`Sequence`, :class:`int` and :class:`float`.
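+
+ Example:
+ A small illustration of the supported conversions:
+
+ >>> import numpy as np
+ >>> to_tensor(np.zeros((2, 3))).shape
+ torch.Size([2, 3])
+ >>> to_tensor(1)
+ tensor([1])
+ >>> to_tensor(0.5)
+ tensor([0.5000])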
+ """
+ if isinstance(data, torch.Tensor):
+ return data
+ elif isinstance(data, np.ndarray):
+ return torch.from_numpy(data)
+ elif isinstance(data, Sequence) and not is_str(data):
+ return torch.tensor(data)
+ elif isinstance(data, int):
+ return torch.LongTensor([data])
+ elif isinstance(data, float):
+ return torch.FloatTensor([data])
+ else:
+ raise TypeError(
+ f'Type {type(data)} cannot be converted to tensor.'
+ 'Supported types are: `numpy.ndarray`, `torch.Tensor`, '
+ '`Sequence`, `int` and `float`')
+
+
+@TRANSFORMS.register_module()
+class PackInputs(BaseTransform):
+ """Pack the inputs data.
+
+ **Required Keys:**
+
+ - ``input_key``
+ - ``*algorithm_keys``
+ - ``*meta_keys``
+
+ **Deleted Keys:**
+
+ All other keys in the dict.
+
+ **Added Keys:**
+
+ - inputs (:obj:`torch.Tensor`): The forward data of models.
+ - data_samples (:obj:`~mmpretrain.structures.DataSample`): The
+ annotation info of the sample.
+
+ Args:
+ input_key (str): The key of element to feed into the model forwarding.
+ Defaults to 'img'.
+ algorithm_keys (Sequence[str]): The keys of custom elements to be used
+ in the algorithm. Defaults to an empty tuple.
+ meta_keys (Sequence[str]): The keys of meta information to be saved in
+ the data sample. Defaults to :attr:`PackInputs.DEFAULT_META_KEYS`.
+
+ .. admonition:: Default algorithm keys
+
+ Besides the specified ``algorithm_keys``, we will set some default keys
+ into the output data sample and do some formatting. Therefore, you
+ don't need to set these keys in the ``algorithm_keys``.
+
+ - ``gt_label``: The ground-truth label. The value will be converted
+ into a 1-D tensor.
+ - ``gt_score``: The ground-truth score. The value will be converted
+ into a 1-D tensor.
+ - ``mask``: The mask for some self-supervise tasks. The value will
+ be converted into a tensor.
+
+ .. admonition:: Default meta keys
+
+ - ``sample_idx``: The id of the image sample.
+ - ``img_path``: The path to the image file.
+ - ``ori_shape``: The original shape of the image as a tuple (H, W).
+ - ``img_shape``: The shape of the image after the pipeline as a
+ tuple (H, W).
+ - ``scale_factor``: The scale factor between the resized image and
+ the original image.
+ - ``flip``: A boolean indicating if image flip transform was used.
+ - ``flip_direction``: The flipping direction.
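+
+ Example:
+ A minimal, illustrative sketch (the import path follows the package
+ layout and is an assumption):
+
+ >>> import numpy as np
+ >>> from mmpretrain.datasets import PackInputs
+ >>> results = {
+ ... 'img': np.random.randint(0, 256, (224, 224, 3), np.uint8),
+ ... 'gt_label': 3,
+ ... 'img_path': 'demo/demo.JPEG',
+ ... }
+ >>> packed = PackInputs()(results)
+ >>> sorted(packed.keys())
+ ['data_samples', 'inputs']
+ >>> packed['inputs'].shape
+ torch.Size([3, 224, 224])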
+ """
+
+ DEFAULT_META_KEYS = ('sample_idx', 'img_path', 'ori_shape', 'img_shape',
+ 'scale_factor', 'flip', 'flip_direction')
+
+ def __init__(self,
+ input_key='img',
+ algorithm_keys=(),
+ meta_keys=DEFAULT_META_KEYS):
+ self.input_key = input_key
+ self.algorithm_keys = algorithm_keys
+ self.meta_keys = meta_keys
+
+ @staticmethod
+ def format_input(input_):
+ if isinstance(input_, list):
+ return [PackInputs.format_input(item) for item in input_]
+ elif isinstance(input_, np.ndarray):
+ if input_.ndim == 2: # For grayscale image.
+ input_ = np.expand_dims(input_, -1)
+ if input_.ndim == 3 and not input_.flags.c_contiguous:
+ input_ = np.ascontiguousarray(input_.transpose(2, 0, 1))
+ input_ = to_tensor(input_)
+ elif input_.ndim == 3:
+ # convert to tensor first to accelerate, see
+ # https://github.com/open-mmlab/mmdetection/pull/9533
+ input_ = to_tensor(input_).permute(2, 0, 1).contiguous()
+ else:
+ # convert input with other shape to tensor without permute,
+ # like video input (num_crops, C, T, H, W).
+ input_ = to_tensor(input_)
+ elif isinstance(input_, Image.Image):
+ input_ = F.pil_to_tensor(input_)
+ elif not isinstance(input_, torch.Tensor):
+ raise TypeError(f'Unsupported input type {type(input_)}.')
+
+ return input_
+
+ def transform(self, results: dict) -> dict:
+ """Method to pack the input data."""
+
+ packed_results = dict()
+ if self.input_key in results:
+ input_ = results[self.input_key]
+ packed_results['inputs'] = self.format_input(input_)
+
+ data_sample = DataSample()
+
+ # Set default keys
+ if 'gt_label' in results:
+ data_sample.set_gt_label(results['gt_label'])
+ if 'gt_score' in results:
+ data_sample.set_gt_score(results['gt_score'])
+ if 'mask' in results:
+ data_sample.set_mask(results['mask'])
+
+ # Set custom algorithm keys
+ for key in self.algorithm_keys:
+ if key in results:
+ data_sample.set_field(results[key], key)
+
+ # Set meta keys
+ for key in self.meta_keys:
+ if key in results:
+ data_sample.set_field(results[key], key, field_type='metainfo')
+
+ packed_results['data_samples'] = data_sample
+ return packed_results
+
+ def __repr__(self) -> str:
+ repr_str = self.__class__.__name__
+ repr_str += f"(input_key='{self.input_key}', "
+ repr_str += f'algorithm_keys={self.algorithm_keys}, '
+ repr_str += f'meta_keys={self.meta_keys})'
+ return repr_str
+
+
+@TRANSFORMS.register_module()
+class PackMultiTaskInputs(BaseTransform):
+ """Convert all image labels of multi-task dataset to a dict of tensor.
+
+ Args:
+ multi_task_fields (Sequence[str]): The fields whose values are dicts
+ keyed by task name; they will be split and dispatched to the
+ data sample of each task.
+ input_key (str): The key of the element to feed into the model
+ forwarding. Defaults to 'img'.
+ task_handlers (dict): A mapping from task name to the pack transform
+ config used for that task. Tasks without a handler use
+ :class:`PackInputs` by default. Defaults to an empty dict.
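+
+ Example:
+ An illustrative usage sketch (import path assumed):
+
+ >>> import numpy as np
+ >>> from mmpretrain.datasets import PackMultiTaskInputs
+ >>> results = {
+ ... 'img': np.random.randint(0, 256, (224, 224, 3), np.uint8),
+ ... 'gt_label': {'task1': 1, 'task3': 3},
+ ... }
+ >>> packed = PackMultiTaskInputs(multi_task_fields=['gt_label'])(results)
+ >>> sorted(packed.keys())
+ ['data_samples', 'inputs']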
+ """
+
+ def __init__(self,
+ multi_task_fields,
+ input_key='img',
+ task_handlers=dict()):
+ self.multi_task_fields = multi_task_fields
+ self.input_key = input_key
+ self.task_handlers = defaultdict(PackInputs)
+ for task_name, task_handler in task_handlers.items():
+ self.task_handlers[task_name] = TRANSFORMS.build(task_handler)
+
+ def transform(self, results: dict) -> dict:
+ """Method to pack the input data.
+
+ An example of the input ``results``::
+
+ {'img_path': 'a.png', 'gt_label': {'task1': 1, 'task3': 3},
+ 'img': array(...)}
+ """
+ packed_results = dict()
+ results = results.copy()
+
+ if self.input_key in results:
+ input_ = results[self.input_key]
+ packed_results['inputs'] = PackInputs.format_input(input_)
+
+ task_results = defaultdict(dict)
+ for field in self.multi_task_fields:
+ if field in results:
+ value = results.pop(field)
+ for k, v in value.items():
+ task_results[k].update({field: v})
+
+ data_sample = MultiTaskDataSample()
+ for task_name, task_result in task_results.items():
+ task_handler = self.task_handlers[task_name]
+ task_pack_result = task_handler({**results, **task_result})
+ data_sample.set_field(task_pack_result['data_samples'], task_name)
+
+ packed_results['data_samples'] = data_sample
+ return packed_results
+
+ def __repr__(self):
+ repr = self.__class__.__name__
+ task_handlers = ', '.join(
+ f"'{name}': {handler.__class__.__name__}"
+ for name, handler in self.task_handlers.items())
+ repr += f'(multi_task_fields={self.multi_task_fields}, '
+ repr += f"input_key='{self.input_key}', "
+ repr += f'task_handlers={{{task_handlers}}})'
+ return repr
+
+
+@TRANSFORMS.register_module()
+class Transpose(BaseTransform):
+ """Transpose numpy array.
+
+ **Required Keys:**
+
+ - ``*keys``
+
+ **Modified Keys:**
+
+ - ``*keys``
+
+ Args:
+ keys (List[str]): The fields to convert to tensor.
+ order (List[int]): The output dimensions order.
+ """
+
+ def __init__(self, keys, order):
+ self.keys = keys
+ self.order = order
+
+ def transform(self, results):
+ """Method to transpose array."""
+ for key in self.keys:
+ results[key] = results[key].transpose(self.order)
+ return results
+
+ def __repr__(self):
+ return self.__class__.__name__ + \
+ f'(keys={self.keys}, order={self.order})'
+
+
+@TRANSFORMS.register_module(('NumpyToPIL', 'ToPIL'))
+class NumpyToPIL(BaseTransform):
+ """Convert the image from OpenCV format to :obj:`PIL.Image.Image`.
+
+ **Required Keys:**
+
+ - ``img``
+
+ **Modified Keys:**
+
+ - ``img``
+
+ Args:
+ to_rgb (bool): Whether to convert the image to RGB. Defaults to False.
+ """
+
+ def __init__(self, to_rgb: bool = False) -> None:
+ self.to_rgb = to_rgb
+
+ def transform(self, results: dict) -> dict:
+ """Method to convert images to :obj:`PIL.Image.Image`."""
+ img = results['img']
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) if self.to_rgb else img
+
+ results['img'] = Image.fromarray(img)
+ return results
+
+ def __repr__(self) -> str:
+ return self.__class__.__name__ + f'(to_rgb={self.to_rgb})'
+
+
+@TRANSFORMS.register_module(('PILToNumpy', 'ToNumpy'))
+class PILToNumpy(BaseTransform):
+ """Convert img to :obj:`numpy.ndarray`.
+
+ **Required Keys:**
+
+ - ``img``
+
+ **Modified Keys:**
+
+ - ``img``
+
+ Args:
+ to_bgr (bool): Whether to convert the image to BGR. Defaults to False.
+ dtype (str, optional): The dtype of the converted numpy array.
+ Defaults to None.
+ """
+
+ def __init__(self, to_bgr: bool = False, dtype=None) -> None:
+ self.to_bgr = to_bgr
+ self.dtype = dtype
+
+ def transform(self, results: dict) -> dict:
+ """Method to convert img to :obj:`numpy.ndarray`."""
+ img = np.array(results['img'], dtype=self.dtype)
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) if self.to_bgr else img
+
+ results['img'] = img
+ return results
+
+ def __repr__(self) -> str:
+ return self.__class__.__name__ + \
+ f'(to_bgr={self.to_bgr}, dtype={self.dtype})'
+
+
+@TRANSFORMS.register_module()
+class Collect(BaseTransform):
+ """Collect and only reserve the specified fields.
+
+ **Required Keys:**
+
+ - ``*keys``
+
+ **Deleted Keys:**
+
+ All keys except those in the argument ``*keys``.
+
+ Args:
+ keys (Sequence[str]): The keys of the fields to be collected.
+ """
+
+ def __init__(self, keys):
+ self.keys = keys
+
+ def transform(self, results):
+ data = {}
+ for key in self.keys:
+ data[key] = results[key]
+ return data
+
+ def __repr__(self):
+ return self.__class__.__name__ + f'(keys={self.keys})'
diff --git a/mmpretrain/datasets/transforms/processing.py b/mmpretrain/datasets/transforms/processing.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbc04da74ec8032690f1a59126e09a323e9a0036
--- /dev/null
+++ b/mmpretrain/datasets/transforms/processing.py
@@ -0,0 +1,1742 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import inspect
+import math
+import numbers
+import re
+import string
+from enum import EnumMeta
+from numbers import Number
+from typing import Dict, List, Optional, Sequence, Tuple, Union
+
+import mmcv
+import mmengine
+import numpy as np
+import torchvision
+from mmcv.transforms import BaseTransform
+from mmcv.transforms.utils import cache_randomness
+from torchvision.transforms.transforms import InterpolationMode
+
+from mmpretrain.registry import TRANSFORMS
+
+try:
+ import albumentations
+except ImportError:
+ albumentations = None
+
+
+def _str_to_torch_dtype(t: str):
+ """mapping str format dtype to torch.dtype."""
+ import torch # noqa: F401,F403
+ return eval(f'torch.{t}')
+
+
+def _interpolation_modes_from_str(t: str):
+ """mapping str format to Interpolation."""
+ t = t.lower()
+ inverse_modes_mapping = {
+ 'nearest': InterpolationMode.NEAREST,
+ 'bilinear': InterpolationMode.BILINEAR,
+ 'bicubic': InterpolationMode.BICUBIC,
+ 'box': InterpolationMode.BOX,
+ 'hamming': InterpolationMode.HAMMING,
+ 'lanczos': InterpolationMode.LANCZOS,
+ }
+ return inverse_modes_mapping[t]
+
+
+class TorchVisonTransformWrapper:
+
+ def __init__(self, transform, *args, **kwargs):
+ if 'interpolation' in kwargs and isinstance(kwargs['interpolation'],
+ str):
+ kwargs['interpolation'] = _interpolation_modes_from_str(
+ kwargs['interpolation'])
+ if 'dtype' in kwargs and isinstance(kwargs['dtype'], str):
+ kwargs['dtype'] = _str_to_torch_dtype(kwargs['dtype'])
+ self.t = transform(*args, **kwargs)
+
+ def __call__(self, results):
+ results['img'] = self.t(results['img'])
+ return results
+
+ def __repr__(self) -> str:
+ return f'TorchVision{repr(self.t)}'
+
+
+def register_vision_transforms() -> List[str]:
+ """Register transforms in ``torchvision.transforms`` to the ``TRANSFORMS``
+ registry.
+
+ Returns:
+ List[str]: A list of registered transforms' name.
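+
+ Example:
+ An illustrative sketch: once registered, a wrapped torchvision
+ transform can be built from the ``TRANSFORMS`` registry under its
+ prefixed name (``RandomHorizontalFlip`` and its argument below are
+ only an example):
+
+ >>> import torch
+ >>> from mmpretrain.registry import TRANSFORMS
+ >>> flip = TRANSFORMS.build(
+ ... dict(type='torchvision/RandomHorizontalFlip', p=0.5))
+ >>> results = flip({'img': torch.rand(3, 32, 32)})
+ >>> results['img'].shape
+ torch.Size([3, 32, 32])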
+ """
+ vision_transforms = []
+ for module_name in dir(torchvision.transforms):
+ if not re.match('[A-Z]', module_name):
+ # must start with a capital letter
+ continue
+ _transform = getattr(torchvision.transforms, module_name)
+ if inspect.isclass(_transform) and callable(
+ _transform) and not isinstance(_transform, (EnumMeta)):
+ from functools import partial
+ TRANSFORMS.register_module(
+ module=partial(
+ TorchVisonTransformWrapper, transform=_transform),
+ name=f'torchvision/{module_name}')
+ vision_transforms.append(f'torchvision/{module_name}')
+ return vision_transforms
+
+
+# register all the transforms in torchvision by using a transform wrapper
+VISION_TRANSFORMS = register_vision_transforms()
+
+
+@TRANSFORMS.register_module()
+class RandomCrop(BaseTransform):
+ """Crop the given Image at a random location.
+
+ **Required Keys:**
+
+ - img
+
+ **Modified Keys:**
+
+ - img
+ - img_shape
+
+ Args:
+ crop_size (int | Sequence): Desired output size of the crop. If
+ crop_size is an int instead of sequence like (h, w), a square crop
+ (crop_size, crop_size) is made.
+ padding (int | Sequence, optional): Optional padding on each border
+ of the image. If a sequence of length 4 is provided, it is used to
+ pad left, top, right, bottom borders respectively. If a sequence
+ of length 2 is provided, it is used to pad left/right, top/bottom
+ borders, respectively. Default: None, which means no padding.
+ pad_if_needed (bool): It will pad the image if smaller than the
+ desired size to avoid raising an exception. Since cropping is done
+ after padding, the padding seems to be done at a random offset.
+ Default: False.
+ pad_val (Number | Sequence[Number]): Pixel pad_val value for constant
+ fill. If a tuple of length 3, it is used to pad_val R, G, B
+ channels respectively. Default: 0.
+ padding_mode (str): Type of padding. Defaults to "constant". Should
+ be one of the following:
+
+ - ``constant``: Pads with a constant value, this value is specified
+ with pad_val.
+ - ``edge``: pads with the last value at the edge of the image.
+ - ``reflect``: Pads with reflection of image without repeating the
+ last value on the edge. For example, padding [1, 2, 3, 4]
+ with 2 elements on both sides in reflect mode will result
+ in [3, 2, 1, 2, 3, 4, 3, 2].
+ - ``symmetric``: Pads with reflection of image repeating the last
+ value on the edge. For example, padding [1, 2, 3, 4] with
+ 2 elements on both sides in symmetric mode will result in
+ [2, 1, 1, 2, 3, 4, 4, 3].
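+
+ Example:
+ A minimal, illustrative sketch (import path assumed):
+
+ >>> import numpy as np
+ >>> from mmpretrain.datasets import RandomCrop
+ >>> results = {'img': np.random.randint(0, 256, (256, 256, 3), np.uint8)}
+ >>> results = RandomCrop(crop_size=224)(results)
+ >>> results['img'].shape
+ (224, 224, 3)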
+ """
+
+ def __init__(self,
+ crop_size: Union[Sequence, int],
+ padding: Optional[Union[Sequence, int]] = None,
+ pad_if_needed: bool = False,
+ pad_val: Union[Number, Sequence[Number]] = 0,
+ padding_mode: str = 'constant'):
+ if isinstance(crop_size, Sequence):
+ assert len(crop_size) == 2
+ assert crop_size[0] > 0 and crop_size[1] > 0
+ self.crop_size = crop_size
+ else:
+ assert crop_size > 0
+ self.crop_size = (crop_size, crop_size)
+ # check padding mode
+ assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric']
+ self.padding = padding
+ self.pad_if_needed = pad_if_needed
+ self.pad_val = pad_val
+ self.padding_mode = padding_mode
+
+ @cache_randomness
+ def rand_crop_params(self, img: np.ndarray):
+ """Get parameters for ``crop`` for a random crop.
+
+ Args:
+ img (ndarray): Image to be cropped.
+
+ Returns:
+ tuple: Params (offset_h, offset_w, target_h, target_w) to be
+ passed to ``crop`` for random crop.
+ """
+ h, w = img.shape[:2]
+ target_h, target_w = self.crop_size
+ if w == target_w and h == target_h:
+ return 0, 0, h, w
+ elif w < target_w or h < target_h:
+ target_w = min(w, target_w)
+ target_h = min(h, target_h)
+
+ offset_h = np.random.randint(0, h - target_h + 1)
+ offset_w = np.random.randint(0, w - target_w + 1)
+
+ return offset_h, offset_w, target_h, target_w
+
+ def transform(self, results: dict) -> dict:
+ """Transform function to randomly crop images.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Randomly cropped results, 'img_shape'
+ key in result dict is updated according to crop size.
+ """
+ img = results['img']
+ if self.padding is not None:
+ img = mmcv.impad(img, padding=self.padding, pad_val=self.pad_val)
+
+ # pad img if needed
+ if self.pad_if_needed:
+ h_pad = math.ceil(max(0, self.crop_size[0] - img.shape[0]) / 2)
+ w_pad = math.ceil(max(0, self.crop_size[1] - img.shape[1]) / 2)
+
+ img = mmcv.impad(
+ img,
+ padding=(w_pad, h_pad, w_pad, h_pad),
+ pad_val=self.pad_val,
+ padding_mode=self.padding_mode)
+
+ offset_h, offset_w, target_h, target_w = self.rand_crop_params(img)
+ img = mmcv.imcrop(
+ img,
+ np.array([
+ offset_w,
+ offset_h,
+ offset_w + target_w - 1,
+ offset_h + target_h - 1,
+ ]))
+ results['img'] = img
+ results['img_shape'] = img.shape
+
+ return results
+
+ def __repr__(self):
+ """Print the basic information of the transform.
+
+ Returns:
+ str: Formatted string.
+ """
+ repr_str = self.__class__.__name__ + f'(crop_size={self.crop_size}'
+ repr_str += f', padding={self.padding}'
+ repr_str += f', pad_if_needed={self.pad_if_needed}'
+ repr_str += f', pad_val={self.pad_val}'
+ repr_str += f', padding_mode={self.padding_mode})'
+ return repr_str
+
+
+@TRANSFORMS.register_module()
+class RandomResizedCrop(BaseTransform):
+ """Crop the given image to random scale and aspect ratio.
+
+ A crop of random size (default: of 0.08 to 1.0) of the original size and a
+ random aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio
+ is made. This crop is finally resized to given size.
+
+ **Required Keys:**
+
+ - img
+
+ **Modified Keys:**
+
+ - img
+ - img_shape
+
+ Args:
+ scale (sequence | int): Desired output scale of the crop. If size is an
+ int instead of sequence like (h, w), a square crop (size, size) is
+ made.
+ crop_ratio_range (tuple): Range of the random size of the cropped
+ image compared to the original image. Defaults to (0.08, 1.0).
+ aspect_ratio_range (tuple): Range of the random aspect ratio of the
+ cropped image compared to the original image.
+ Defaults to (3. / 4., 4. / 3.).
+ max_attempts (int): Maximum number of attempts before falling back to
+ Central Crop. Defaults to 10.
+ interpolation (str): Interpolation method, accepted values are
+ 'nearest', 'bilinear', 'bicubic', 'area', 'lanczos'. Defaults to
+ 'bilinear'.
+ backend (str): The image resize backend type, accepted values are
+ 'cv2' and 'pillow'. Defaults to 'cv2'.
+ """
+
+ def __init__(self,
+ scale: Union[Sequence, int],
+ crop_ratio_range: Tuple[float, float] = (0.08, 1.0),
+ aspect_ratio_range: Tuple[float, float] = (3. / 4., 4. / 3.),
+ max_attempts: int = 10,
+ interpolation: str = 'bilinear',
+ backend: str = 'cv2') -> None:
+ if isinstance(scale, Sequence):
+ assert len(scale) == 2
+ assert scale[0] > 0 and scale[1] > 0
+ self.scale = scale
+ else:
+ assert scale > 0
+ self.scale = (scale, scale)
+ if (crop_ratio_range[0] > crop_ratio_range[1]) or (
+ aspect_ratio_range[0] > aspect_ratio_range[1]):
+ raise ValueError(
+ 'range should be of kind (min, max). '
+ f'But received crop_ratio_range {crop_ratio_range} '
+ f'and aspect_ratio_range {aspect_ratio_range}.')
+ assert isinstance(max_attempts, int) and max_attempts >= 0, \
+ 'max_attempts must be int and no less than 0.'
+ assert interpolation in ('nearest', 'bilinear', 'bicubic', 'area',
+ 'lanczos')
+
+ self.crop_ratio_range = crop_ratio_range
+ self.aspect_ratio_range = aspect_ratio_range
+ self.max_attempts = max_attempts
+ self.interpolation = interpolation
+ self.backend = backend
+
+ @cache_randomness
+ def rand_crop_params(self, img: np.ndarray) -> Tuple[int, int, int, int]:
+ """Get parameters for ``crop`` for a random sized crop.
+
+ Args:
+ img (ndarray): Image to be cropped.
+
+ Returns:
+ tuple: Params (offset_h, offset_w, target_h, target_w) to be
+ passed to `crop` for a random sized crop.
+ """
+ h, w = img.shape[:2]
+ area = h * w
+
+ for _ in range(self.max_attempts):
+ target_area = np.random.uniform(*self.crop_ratio_range) * area
+ log_ratio = (math.log(self.aspect_ratio_range[0]),
+ math.log(self.aspect_ratio_range[1]))
+ aspect_ratio = math.exp(np.random.uniform(*log_ratio))
+ target_w = int(round(math.sqrt(target_area * aspect_ratio)))
+ target_h = int(round(math.sqrt(target_area / aspect_ratio)))
+
+ if 0 < target_w <= w and 0 < target_h <= h:
+ offset_h = np.random.randint(0, h - target_h + 1)
+ offset_w = np.random.randint(0, w - target_w + 1)
+
+ return offset_h, offset_w, target_h, target_w
+
+ # Fallback to central crop
+ in_ratio = float(w) / float(h)
+ if in_ratio < min(self.aspect_ratio_range):
+ target_w = w
+ target_h = int(round(target_w / min(self.aspect_ratio_range)))
+ elif in_ratio > max(self.aspect_ratio_range):
+ target_h = h
+ target_w = int(round(target_h * max(self.aspect_ratio_range)))
+ else: # whole image
+ target_w = w
+ target_h = h
+ offset_h = (h - target_h) // 2
+ offset_w = (w - target_w) // 2
+ return offset_h, offset_w, target_h, target_w
+
+ def transform(self, results: dict) -> dict:
+ """Transform function to randomly resized crop images.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Randomly resized cropped results, 'img_shape'
+ key in result dict is updated according to crop size.
+ """
+ img = results['img']
+ offset_h, offset_w, target_h, target_w = self.rand_crop_params(img)
+ img = mmcv.imcrop(
+ img,
+ bboxes=np.array([
+ offset_w, offset_h, offset_w + target_w - 1,
+ offset_h + target_h - 1
+ ]))
+ img = mmcv.imresize(
+ img,
+ tuple(self.scale[::-1]),
+ interpolation=self.interpolation,
+ backend=self.backend)
+ results['img'] = img
+ results['img_shape'] = img.shape
+
+ return results
+
+ def __repr__(self):
+ """Print the basic information of the transform.
+
+ Returns:
+ str: Formatted string.
+ """
+ repr_str = self.__class__.__name__ + f'(scale={self.scale}'
+ repr_str += ', crop_ratio_range='
+ repr_str += f'{tuple(round(s, 4) for s in self.crop_ratio_range)}'
+ repr_str += ', aspect_ratio_range='
+ repr_str += f'{tuple(round(r, 4) for r in self.aspect_ratio_range)}'
+ repr_str += f', max_attempts={self.max_attempts}'
+ repr_str += f', interpolation={self.interpolation}'
+ repr_str += f', backend={self.backend})'
+ return repr_str
+
+
+@TRANSFORMS.register_module()
+class EfficientNetRandomCrop(RandomResizedCrop):
+ """EfficientNet style RandomResizedCrop.
+
+ **Required Keys:**
+
+ - img
+
+ **Modified Keys:**
+
+ - img
+ - img_shape
+
+ Args:
+ scale (int): Desired output scale of the crop. Only int size is
+ accepted, a square crop (size, size) is made.
+ min_covered (Number): Minimum ratio of the cropped area to the original
+ area. Defaults to 0.1.
+ crop_padding (int): The crop padding parameter in efficientnet style
+ center crop. Defaults to 32.
+ crop_ratio_range (tuple): Range of the random size of the cropped
+ image compared to the original image. Defaults to (0.08, 1.0).
+ aspect_ratio_range (tuple): Range of the random aspect ratio of the
+ cropped image compared to the original image.
+ Defaults to (3. / 4., 4. / 3.).
+ max_attempts (int): Maximum number of attempts before falling back to
+ Central Crop. Defaults to 10.
+ interpolation (str): Interpolation method, accepted values are
+ 'nearest', 'bilinear', 'bicubic', 'area', 'lanczos'. Defaults to
+ 'bicubic'.
+ backend (str): The image resize backend type, accepted values are
+ 'cv2' and 'pillow'. Defaults to 'cv2'.
+ """
+
+ def __init__(self,
+ scale: int,
+ min_covered: float = 0.1,
+ crop_padding: int = 32,
+ interpolation: str = 'bicubic',
+ **kwarg):
+ assert isinstance(scale, int)
+ super().__init__(scale, interpolation=interpolation, **kwarg)
+ assert min_covered >= 0, 'min_covered should be no less than 0.'
+ assert crop_padding >= 0, 'crop_padding should be no less than 0.'
+
+ self.min_covered = min_covered
+ self.crop_padding = crop_padding
+
+ # https://github.com/kakaobrain/fast-autoaugment/blob/master/FastAutoAugment/data.py # noqa
+ @cache_randomness
+ def rand_crop_params(self, img: np.ndarray) -> Tuple[int, int, int, int]:
+ """Get parameters for ``crop`` for a random sized crop.
+
+ Args:
+ img (ndarray): Image to be cropped.
+
+ Returns:
+ tuple: Params (offset_h, offset_w, target_h, target_w) to be
+ passed to `crop` for a random sized crop.
+ """
+ h, w = img.shape[:2]
+ area = h * w
+ min_target_area = self.crop_ratio_range[0] * area
+ max_target_area = self.crop_ratio_range[1] * area
+
+ for _ in range(self.max_attempts):
+ aspect_ratio = np.random.uniform(*self.aspect_ratio_range)
+ min_target_h = int(
+ round(math.sqrt(min_target_area / aspect_ratio)))
+ max_target_h = int(
+ round(math.sqrt(max_target_area / aspect_ratio)))
+
+ if max_target_h * aspect_ratio > w:
+ max_target_h = int((w + 0.5 - 1e-7) / aspect_ratio)
+ if max_target_h * aspect_ratio > w:
+ max_target_h -= 1
+
+ max_target_h = min(max_target_h, h)
+ min_target_h = min(max_target_h, min_target_h)
+
+ # slightly differs from tf implementation
+ target_h = int(
+ round(np.random.uniform(min_target_h, max_target_h)))
+ target_w = int(round(target_h * aspect_ratio))
+ target_area = target_h * target_w
+
+ # slight differs from tf. In tf, if target_area > max_target_area,
+ # area will be recalculated
+ if (target_area < min_target_area or target_area > max_target_area
+ or target_w > w or target_h > h
+ or target_area < self.min_covered * area):
+ continue
+
+ offset_h = np.random.randint(0, h - target_h + 1)
+ offset_w = np.random.randint(0, w - target_w + 1)
+
+ return offset_h, offset_w, target_h, target_w
+
+ # Fallback to central crop
+ img_short = min(h, w)
+ crop_size = self.scale[0] / (self.scale[0] +
+ self.crop_padding) * img_short
+
+ offset_h = max(0, int(round((h - crop_size) / 2.)))
+ offset_w = max(0, int(round((w - crop_size) / 2.)))
+ return offset_h, offset_w, crop_size, crop_size
+
+ def __repr__(self):
+ """Print the basic information of the transform.
+
+ Returns:
+ str: Formatted string.
+ """
+ repr_str = super().__repr__()[:-1]
+ repr_str += f', min_covered={self.min_covered}'
+ repr_str += f', crop_padding={self.crop_padding})'
+ return repr_str
+
+
+@TRANSFORMS.register_module()
+class RandomErasing(BaseTransform):
+ """Randomly selects a rectangle region in an image and erase pixels.
+
+ **Required Keys:**
+
+ - img
+
+ **Modified Keys:**
+
+ - img
+
+ Args:
+ erase_prob (float): Probability that image will be randomly erased.
+ Default: 0.5
+ min_area_ratio (float): Minimum erased area / input image area
+ Default: 0.02
+ max_area_ratio (float): Maximum erased area / input image area
+ Default: 0.4
+ aspect_range (sequence | float): Aspect ratio range of erased area.
+ if float, it will be converted to (aspect_ratio, 1/aspect_ratio)
+ Default: (3/10, 10/3)
+ mode (str): Fill method in erased area, can be:
+
+ - const (default): All pixels are assigned the same value.
+ - rand: Each pixel is assigned a random value in [0, 255].
+
+ fill_color (sequence | Number): Base color filled in erased area.
+ Defaults to (128, 128, 128).
+ fill_std (sequence | Number, optional): If set and ``mode`` is 'rand',
+ fill erased area with random color from normal distribution
+ (mean=fill_color, std=fill_std); If not set, fill erased area with
+ random color from uniform distribution (0~255). Defaults to None.
+
+ Note:
+ See `Random Erasing Data Augmentation
+ <https://arxiv.org/abs/1708.04896>`_
+
+ This paper provides 4 modes: RE-R, RE-M, RE-0 and RE-255, and uses
+ RE-M as the default. The configs of these 4 modes are:
+
+ - RE-R: RandomErasing(mode='rand')
+ - RE-M: RandomErasing(mode='const', fill_color=(123.67, 116.3, 103.5))
+ - RE-0: RandomErasing(mode='const', fill_color=0)
+ - RE-255: RandomErasing(mode='const', fill_color=255)
+ """
+
+ def __init__(self,
+ erase_prob=0.5,
+ min_area_ratio=0.02,
+ max_area_ratio=0.4,
+ aspect_range=(3 / 10, 10 / 3),
+ mode='const',
+ fill_color=(128, 128, 128),
+ fill_std=None):
+ assert isinstance(erase_prob, float) and 0. <= erase_prob <= 1.
+ assert isinstance(min_area_ratio, float) and 0. <= min_area_ratio <= 1.
+ assert isinstance(max_area_ratio, float) and 0. <= max_area_ratio <= 1.
+ assert min_area_ratio <= max_area_ratio, \
+ 'min_area_ratio should be smaller than max_area_ratio'
+ if isinstance(aspect_range, float):
+ aspect_range = min(aspect_range, 1 / aspect_range)
+ aspect_range = (aspect_range, 1 / aspect_range)
+ assert isinstance(aspect_range, Sequence) and len(aspect_range) == 2 \
+ and all(isinstance(x, float) for x in aspect_range), \
+ 'aspect_range should be a float or Sequence with two float.'
+ assert all(x > 0 for x in aspect_range), \
+ 'aspect_range should be positive.'
+ assert aspect_range[0] <= aspect_range[1], \
+ 'In aspect_range (min, max), min should be smaller than max.'
+ assert mode in ['const', 'rand'], \
+ 'Please select `mode` from ["const", "rand"].'
+ if isinstance(fill_color, Number):
+ fill_color = [fill_color] * 3
+ assert isinstance(fill_color, Sequence) and len(fill_color) == 3 \
+ and all(isinstance(x, Number) for x in fill_color), \
+ 'fill_color should be a float or Sequence with three int.'
+ if fill_std is not None:
+ if isinstance(fill_std, Number):
+ fill_std = [fill_std] * 3
+ assert isinstance(fill_std, Sequence) and len(fill_std) == 3 \
+ and all(isinstance(x, Number) for x in fill_std), \
+ 'fill_std should be a float or Sequence with three int.'
+
+ self.erase_prob = erase_prob
+ self.min_area_ratio = min_area_ratio
+ self.max_area_ratio = max_area_ratio
+ self.aspect_range = aspect_range
+ self.mode = mode
+ self.fill_color = fill_color
+ self.fill_std = fill_std
+
+ def _fill_pixels(self, img, top, left, h, w):
+ """Fill pixels to the patch of image."""
+ if self.mode == 'const':
+ patch = np.empty((h, w, 3), dtype=np.uint8)
+ patch[:, :] = np.array(self.fill_color, dtype=np.uint8)
+ elif self.fill_std is None:
+ # Uniform distribution
+ patch = np.random.uniform(0, 256, (h, w, 3)).astype(np.uint8)
+ else:
+ # Normal distribution
+ patch = np.random.normal(self.fill_color, self.fill_std, (h, w, 3))
+ patch = np.clip(patch.astype(np.int32), 0, 255).astype(np.uint8)
+
+ img[top:top + h, left:left + w] = patch
+ return img
+
+ @cache_randomness
+ def random_disable(self):
+ """Randomly disable the transform."""
+ return np.random.rand() > self.erase_prob
+
+ @cache_randomness
+ def random_patch(self, img_h, img_w):
+ """Randomly generate patch the erase."""
+ # convert the aspect ratio to log space to equally handle width and
+ # height.
+ log_aspect_range = np.log(
+ np.array(self.aspect_range, dtype=np.float32))
+ aspect_ratio = np.exp(np.random.uniform(*log_aspect_range))
+ area = img_h * img_w
+ area *= np.random.uniform(self.min_area_ratio, self.max_area_ratio)
+
+ h = min(int(round(np.sqrt(area * aspect_ratio))), img_h)
+ w = min(int(round(np.sqrt(area / aspect_ratio))), img_w)
+ top = np.random.randint(0, img_h - h) if img_h > h else 0
+ left = np.random.randint(0, img_w - w) if img_w > w else 0
+ return top, left, h, w
+
+ def transform(self, results):
+ """
+ Args:
+ results (dict): Results dict from pipeline
+
+ Returns:
+ dict: Results after the transformation.
+ """
+ if self.random_disable():
+ return results
+
+ img = results['img']
+ img_h, img_w = img.shape[:2]
+
+ img = self._fill_pixels(img, *self.random_patch(img_h, img_w))
+
+ results['img'] = img
+
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(erase_prob={self.erase_prob}, '
+ repr_str += f'min_area_ratio={self.min_area_ratio}, '
+ repr_str += f'max_area_ratio={self.max_area_ratio}, '
+ repr_str += f'aspect_range={self.aspect_range}, '
+ repr_str += f'mode={self.mode}, '
+ repr_str += f'fill_color={self.fill_color}, '
+ repr_str += f'fill_std={self.fill_std})'
+ return repr_str
+
+
+@TRANSFORMS.register_module()
+class EfficientNetCenterCrop(BaseTransform):
+ r"""EfficientNet style center crop.
+
+ **Required Keys:**
+
+ - img
+
+ **Modified Keys:**
+
+ - img
+ - img_shape
+
+ Args:
+ crop_size (int): Expected size after cropping; a square crop of
+ (crop_size, crop_size) is made.
+ crop_padding (int): The crop padding parameter in efficientnet style
+ center crop. Defaults to 32.
+ interpolation (str): Interpolation method, accepted values are
+ 'nearest', 'bilinear', 'bicubic', 'area', 'lanczos'.
+ Defaults to 'bicubic'.
+ backend (str): The image resize backend type, accepted values are
+ `cv2` and `pillow`. Defaults to `cv2`.
+
+ Notes:
+ - If the image is smaller than the crop size, return the original
+ image.
+ - The pipeline will first perform a center crop with the
+ ``crop_size_`` computed as:
+
+ .. math::
+
+ \text{crop_size_} = \frac{\text{crop_size}}{\text{crop_size} +
+ \text{crop_padding}} \times \text{short_edge}
+
+ And then the pipeline resizes the img to the input crop size.
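+
+ For example, with ``crop_size=224`` and ``crop_padding=32`` on an image
+ whose short edge is 512, the intermediate crop size is
+ 224 / (224 + 32) * 512 = 448, so a 448x448 center crop is taken and then
+ resized to the final 224x224 output.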
+ """
+
+ def __init__(self,
+ crop_size: int,
+ crop_padding: int = 32,
+ interpolation: str = 'bicubic',
+ backend: str = 'cv2'):
+ assert isinstance(crop_size, int)
+ assert crop_size > 0
+ assert crop_padding >= 0
+ assert interpolation in ('nearest', 'bilinear', 'bicubic', 'area',
+ 'lanczos')
+
+ self.crop_size = crop_size
+ self.crop_padding = crop_padding
+ self.interpolation = interpolation
+ self.backend = backend
+
+ def transform(self, results: dict) -> dict:
+ """Transform function to randomly resized crop images.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: EfficientNet style center cropped results, 'img_shape'
+ key in result dict is updated according to crop size.
+ """
+ img = results['img']
+ h, w = img.shape[:2]
+
+ # https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/preprocessing.py#L118 # noqa
+ img_short = min(h, w)
+ crop_size = self.crop_size / (self.crop_size +
+ self.crop_padding) * img_short
+
+ offset_h = max(0, int(round((h - crop_size) / 2.)))
+ offset_w = max(0, int(round((w - crop_size) / 2.)))
+
+ # crop the image
+ img = mmcv.imcrop(
+ img,
+ bboxes=np.array([
+ offset_w, offset_h, offset_w + crop_size - 1,
+ offset_h + crop_size - 1
+ ]))
+ # resize image
+ img = mmcv.imresize(
+ img, (self.crop_size, self.crop_size),
+ interpolation=self.interpolation,
+ backend=self.backend)
+ results['img'] = img
+ results['img_shape'] = img.shape
+
+ return results
+
+ def __repr__(self):
+ """Print the basic information of the transform.
+
+ Returns:
+ str: Formatted string.
+ """
+ repr_str = self.__class__.__name__ + f'(crop_size={self.crop_size}'
+ repr_str += f', crop_padding={self.crop_padding}'
+ repr_str += f', interpolation={self.interpolation}'
+ repr_str += f', backend={self.backend})'
+ return repr_str
+
+
+@TRANSFORMS.register_module()
+class ResizeEdge(BaseTransform):
+ """Resize images along the specified edge.
+
+ **Required Keys:**
+
+ - img
+
+ **Modified Keys:**
+
+ - img
+ - img_shape
+
+ **Added Keys:**
+
+ - scale
+ - scale_factor
+
+ Args:
+ scale (int): The edge scale to resizing.
+ edge (str): The edge to resize. Defaults to 'short'.
+ backend (str): Image resize backend, choices are 'cv2' and 'pillow'.
+ These two backends generate slightly different results.
+ Defaults to 'cv2'.
+ interpolation (str): Interpolation method, accepted values are
+ "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
+ backend, "nearest", "bilinear" for 'pillow' backend.
+ Defaults to 'bilinear'.
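+
+ Example:
+ An illustrative sketch (import path assumed); resizing the short edge
+ of a 480x640 image to 256 keeps the aspect ratio:
+
+ >>> import numpy as np
+ >>> from mmpretrain.datasets import ResizeEdge
+ >>> results = {'img': np.random.randint(0, 256, (480, 640, 3), np.uint8)}
+ >>> results = ResizeEdge(scale=256, edge='short')(results)
+ >>> results['img_shape']
+ (256, 341)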
+ """
+
+ def __init__(self,
+ scale: int,
+ edge: str = 'short',
+ backend: str = 'cv2',
+ interpolation: str = 'bilinear') -> None:
+ allow_edges = ['short', 'long', 'width', 'height']
+ assert edge in allow_edges, \
+ f'Invalid edge "{edge}", please specify from {allow_edges}.'
+ self.edge = edge
+ self.scale = scale
+ self.backend = backend
+ self.interpolation = interpolation
+
+ def _resize_img(self, results: dict) -> None:
+ """Resize images with ``results['scale']``."""
+
+ img, w_scale, h_scale = mmcv.imresize(
+ results['img'],
+ results['scale'],
+ interpolation=self.interpolation,
+ return_scale=True,
+ backend=self.backend)
+ results['img'] = img
+ results['img_shape'] = img.shape[:2]
+ results['scale'] = img.shape[:2][::-1]
+ results['scale_factor'] = (w_scale, h_scale)
+
+ def transform(self, results: Dict) -> Dict:
+ """Transform function to resize images.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Resized results, 'img', 'scale', 'scale_factor',
+ 'img_shape' keys are updated in result dict.
+ """
+ assert 'img' in results, 'No `img` field in the input.'
+
+ h, w = results['img'].shape[:2]
+ if any([
+ # conditions to resize the width
+ self.edge == 'short' and w < h,
+ self.edge == 'long' and w > h,
+ self.edge == 'width',
+ ]):
+ width = self.scale
+ height = int(self.scale * h / w)
+ else:
+ height = self.scale
+ width = int(self.scale * w / h)
+ results['scale'] = (width, height)
+
+ self._resize_img(results)
+ return results
+
+ def __repr__(self):
+ """Print the basic information of the transform.
+
+ Returns:
+ str: Formatted string.
+ """
+ repr_str = self.__class__.__name__
+ repr_str += f'(scale={self.scale}, '
+ repr_str += f'edge={self.edge}, '
+ repr_str += f'backend={self.backend}, '
+ repr_str += f'interpolation={self.interpolation})'
+ return repr_str
+
+
+@TRANSFORMS.register_module()
+class ColorJitter(BaseTransform):
+ """Randomly change the brightness, contrast and saturation of an image.
+
+ Modified from
+ https://github.com/pytorch/vision/blob/main/torchvision/transforms/transforms.py
+ Licensed under the BSD 3-Clause License.
+
+ **Required Keys:**
+
+ - img
+
+ **Modified Keys:**
+
+ - img
+
+ Args:
+ brightness (float | Sequence[float] (min, max)): How much to jitter
+ brightness. brightness_factor is chosen uniformly from
+ ``[max(0, 1 - brightness), 1 + brightness]`` or the given
+ ``[min, max]``. Should be non negative numbers. Defaults to 0.
+ contrast (float | Sequence[float] (min, max)): How much to jitter
+ contrast. contrast_factor is chosen uniformly from
+ ``[max(0, 1 - contrast), 1 + contrast]`` or the given
+ ``[min, max]``. Should be non negative numbers. Defaults to 0.
+ saturation (float | Sequence[float] (min, max)): How much to jitter
+ saturation. saturation_factor is chosen uniformly from
+ ``[max(0, 1 - saturation), 1 + saturation]`` or the given
+ ``[min, max]``. Should be non negative numbers. Defaults to 0.
+ hue (float | Sequence[float] (min, max)): How much to jitter hue.
+ hue_factor is chosen uniformly from ``[-hue, hue]`` (0 <= hue
+ <= 0.5) or the given ``[min, max]`` (-0.5 <= min <= max <= 0.5).
+ Defaults to 0.
+ backend (str): The backend used to process the image. Defaults to 'pillow'.
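+
+ Example:
+ An illustrative usage sketch (import path assumed):
+
+ >>> import numpy as np
+ >>> from mmpretrain.datasets import ColorJitter
+ >>> results = {'img': np.random.randint(0, 256, (224, 224, 3), np.uint8)}
+ >>> jitter = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4)
+ >>> results = jitter(results)
+ >>> results['img'].shape
+ (224, 224, 3)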
+ """
+
+ def __init__(self,
+ brightness: Union[float, Sequence[float]] = 0.,
+ contrast: Union[float, Sequence[float]] = 0.,
+ saturation: Union[float, Sequence[float]] = 0.,
+ hue: Union[float, Sequence[float]] = 0.,
+ backend='pillow'):
+ self.brightness = self._set_range(brightness, 'brightness')
+ self.contrast = self._set_range(contrast, 'contrast')
+ self.saturation = self._set_range(saturation, 'saturation')
+ self.hue = self._set_range(hue, 'hue', center=0, bound=(-0.5, 0.5))
+ self.backend = backend
+
+ def _set_range(self, value, name, center=1, bound=(0, float('inf'))):
+ """Set the range of magnitudes."""
+ if isinstance(value, numbers.Number):
+ if value < 0:
+ raise ValueError(
+ f'If {name} is a single number, it must be non negative.')
+ value = (center - float(value), center + float(value))
+
+ if isinstance(value, (tuple, list)) and len(value) == 2:
+ if not bound[0] <= value[0] <= value[1] <= bound[1]:
+ value = np.clip(value, bound[0], bound[1])
+ from mmengine.logging import MMLogger
+ logger = MMLogger.get_current_instance()
+ logger.warning(f'ColorJitter {name} values exceed the bound '
+ f'{bound}, clipped to the bound.')
+ else:
+ raise TypeError(f'{name} should be a single number '
+ 'or a list/tuple with length 2.')
+
+ # if value is 0 or (1., 1.) for brightness/contrast/saturation
+ # or (0., 0.) for hue, do nothing
+ if value[0] == value[1] == center:
+ value = None
+ else:
+ value = tuple(value)
+
+ return value
+
+ @cache_randomness
+ def _rand_params(self):
+ """Get random parameters including magnitudes and indices of
+ transforms."""
+ trans_inds = np.random.permutation(4)
+ b, c, s, h = (None, ) * 4
+
+ if self.brightness is not None:
+ b = np.random.uniform(self.brightness[0], self.brightness[1])
+ if self.contrast is not None:
+ c = np.random.uniform(self.contrast[0], self.contrast[1])
+ if self.saturation is not None:
+ s = np.random.uniform(self.saturation[0], self.saturation[1])
+ if self.hue is not None:
+ h = np.random.uniform(self.hue[0], self.hue[1])
+
+ return trans_inds, b, c, s, h
+
+ def transform(self, results: Dict) -> Dict:
+ """Transform function to resize images.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: ColorJitter results, 'img' key is updated in result dict.
+ """
+ img = results['img']
+ trans_inds, brightness, contrast, saturation, hue = self._rand_params()
+
+ for index in trans_inds:
+ if index == 0 and brightness is not None:
+ img = mmcv.adjust_brightness(
+ img, brightness, backend=self.backend)
+ elif index == 1 and contrast is not None:
+ img = mmcv.adjust_contrast(img, contrast, backend=self.backend)
+ elif index == 2 and saturation is not None:
+ img = mmcv.adjust_color(
+ img, alpha=saturation, backend=self.backend)
+ elif index == 3 and hue is not None:
+ img = mmcv.adjust_hue(img, hue, backend=self.backend)
+
+ results['img'] = img
+ return results
+
+ def __repr__(self):
+ """Print the basic information of the transform.
+
+ Returns:
+ str: Formatted string.
+ """
+ repr_str = self.__class__.__name__
+ repr_str += f'(brightness={self.brightness}, '
+ repr_str += f'contrast={self.contrast}, '
+ repr_str += f'saturation={self.saturation}, '
+ repr_str += f'hue={self.hue})'
+ return repr_str
+
+
+@TRANSFORMS.register_module()
+class Lighting(BaseTransform):
+ """Adjust images lighting using AlexNet-style PCA jitter.
+
+ **Required Keys:**
+
+ - img
+
+ **Modified Keys:**
+
+ - img
+
+ Args:
+ eigval (Sequence[float]): the eigenvalues of the covariance matrix
+ of pixel values.
+ eigvec (list[list]): the eigenvectors of the covariance matrix of
+ pixel values.
+ alphastd (float): The standard deviation for distribution of alpha.
+ Defaults to 0.1.
+ to_rgb (bool): Whether to convert img to rgb. Defaults to False.
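+
+ Example:
+ An illustrative sketch with toy eigenvalues/eigenvectors (placeholder
+ numbers, not dataset statistics; the import path is an assumption):
+
+ >>> import numpy as np
+ >>> from mmpretrain.datasets import Lighting
+ >>> img = np.random.randint(0, 256, (8, 8, 3), np.uint8).astype(np.float32)
+ >>> lighting = Lighting(eigval=[0.1, 0.05, 0.01],
+ ... eigvec=[[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]])
+ >>> results = lighting({'img': img})
+ >>> results['img'].shape
+ (8, 8, 3)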
+ """
+
+ def __init__(self,
+ eigval: Sequence[float],
+ eigvec: Sequence[float],
+ alphastd: float = 0.1,
+ to_rgb: bool = False):
+ assert isinstance(eigval, Sequence), \
+ f'eigval must be Sequence, got {type(eigval)} instead.'
+ assert isinstance(eigvec, Sequence), \
+ f'eigvec must be Sequence, got {type(eigvec)} instead.'
+ for vec in eigvec:
+ assert isinstance(vec, Sequence) and len(vec) == len(eigvec[0]), \
+ 'eigvec must contains lists with equal length.'
+ assert isinstance(alphastd, float), 'alphastd should be of type ' \
+ f'float, got {type(alphastd)} instead.'
+
+ self.eigval = np.array(eigval)
+ self.eigvec = np.array(eigvec)
+ self.alphastd = alphastd
+ self.to_rgb = to_rgb
+
+ def transform(self, results: Dict) -> Dict:
+ """Transform function to resize images.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Lighting adjusted results, 'img' key is updated in result dict.
+ """
+ assert 'img' in results, 'No `img` field in the input.'
+
+ img = results['img']
+ img_lighting = mmcv.adjust_lighting(
+ img,
+ self.eigval,
+ self.eigvec,
+ alphastd=self.alphastd,
+ to_rgb=self.to_rgb)
+ results['img'] = img_lighting.astype(img.dtype)
+ return results
+
+ def __repr__(self):
+ """Print the basic information of the transform.
+
+ Returns:
+ str: Formatted string.
+ """
+ repr_str = self.__class__.__name__
+ repr_str += f'(eigval={self.eigval.tolist()}, '
+ repr_str += f'eigvec={self.eigvec.tolist()}, '
+ repr_str += f'alphastd={self.alphastd}, '
+ repr_str += f'to_rgb={self.to_rgb})'
+ return repr_str
+
+
+# 'Albu' is used in previous versions of mmpretrain, here is for compatibility
+# users can use both 'Albumentations' and 'Albu'.
+@TRANSFORMS.register_module(['Albumentations', 'Albu'])
+class Albumentations(BaseTransform):
+ """Wrapper to use augmentation from albumentations library.
+
+ **Required Keys:**
+
+ - img
+
+ **Modified Keys:**
+
+ - img
+ - img_shape
+
+    Adds custom transformations from the albumentations library.
+    More details can be found in the albumentations documentation.
+ An example of ``transforms`` is as followed:
+
+ .. code-block::
+
+ [
+ dict(
+ type='ShiftScaleRotate',
+ shift_limit=0.0625,
+ scale_limit=0.0,
+ rotate_limit=0,
+ interpolation=1,
+ p=0.5),
+ dict(
+ type='RandomBrightnessContrast',
+ brightness_limit=[0.1, 0.3],
+ contrast_limit=[0.1, 0.3],
+ p=0.2),
+ dict(type='ChannelShuffle', p=0.1),
+ dict(
+ type='OneOf',
+ transforms=[
+ dict(type='Blur', blur_limit=3, p=1.0),
+ dict(type='MedianBlur', blur_limit=3, p=1.0)
+ ],
+ p=0.1),
+ ]
+
+ Args:
+ transforms (List[Dict]): List of albumentations transform configs.
+ keymap (Optional[Dict]): Mapping of mmpretrain to albumentations
+ fields, in format {'input key':'albumentation-style key'}.
+ Defaults to None.
+
+ Example:
+ >>> import mmcv
+ >>> from mmpretrain.datasets import Albumentations
+ >>> transforms = [
+ ... dict(
+ ... type='ShiftScaleRotate',
+ ... shift_limit=0.0625,
+ ... scale_limit=0.0,
+ ... rotate_limit=0,
+ ... interpolation=1,
+ ... p=0.5),
+ ... dict(
+ ... type='RandomBrightnessContrast',
+ ... brightness_limit=[0.1, 0.3],
+ ... contrast_limit=[0.1, 0.3],
+ ... p=0.2),
+ ... dict(type='ChannelShuffle', p=0.1),
+ ... dict(
+ ... type='OneOf',
+ ... transforms=[
+ ... dict(type='Blur', blur_limit=3, p=1.0),
+ ... dict(type='MedianBlur', blur_limit=3, p=1.0)
+ ... ],
+ ... p=0.1),
+ ... ]
+ >>> albu = Albumentations(transforms)
+ >>> data = {'img': mmcv.imread('./demo/demo.JPEG')}
+ >>> data = albu(data)
+ >>> print(data['img'].shape)
+ (375, 500, 3)
+ """
+
+ def __init__(self, transforms: List[Dict], keymap: Optional[Dict] = None):
+ if albumentations is None:
+ raise RuntimeError('albumentations is not installed')
+ else:
+ from albumentations import Compose as albu_Compose
+
+ assert isinstance(transforms, list), 'transforms must be a list.'
+ if keymap is not None:
+ assert isinstance(keymap, dict), 'keymap must be None or a dict. '
+
+ self.transforms = transforms
+
+ self.aug = albu_Compose(
+ [self.albu_builder(t) for t in self.transforms])
+
+ if not keymap:
+ self.keymap_to_albu = dict(img='image')
+ else:
+ self.keymap_to_albu = keymap
+ self.keymap_back = {v: k for k, v in self.keymap_to_albu.items()}
+
+ def albu_builder(self, cfg: Dict):
+ """Import a module from albumentations.
+
+ It inherits some of :func:`build_from_cfg` logic.
+ Args:
+ cfg (dict): Config dict. It should at least contain the key "type".
+ Returns:
+ obj: The constructed object.
+ """
+
+ assert isinstance(cfg, dict) and 'type' in cfg, 'each item in ' \
+ "transforms must be a dict with keyword 'type'."
+ args = cfg.copy()
+
+ obj_type = args.pop('type')
+ if mmengine.is_str(obj_type):
+ obj_cls = getattr(albumentations, obj_type)
+ elif inspect.isclass(obj_type):
+ obj_cls = obj_type
+ else:
+ raise TypeError(
+ f'type must be a str or valid type, but got {type(obj_type)}')
+
+ if 'transforms' in args:
+ args['transforms'] = [
+ self.albu_builder(transform)
+ for transform in args['transforms']
+ ]
+
+ return obj_cls(**args)
+
+ @staticmethod
+ def mapper(d, keymap):
+ """Dictionary mapper.
+
+ Renames keys according to keymap provided.
+ Args:
+ d (dict): old dict
+ keymap (dict): {'old_key':'new_key'}
+ Returns:
+ dict: new dict.
+ """
+
+ updated_dict = {}
+        for k, v in d.items():
+            new_k = keymap.get(k, k)
+            updated_dict[new_k] = v
+ return updated_dict
+
+ def transform(self, results: Dict) -> Dict:
+ """Transform function to perform albumentations transforms.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Transformed results, 'img' and 'img_shape' keys are
+ updated in result dict.
+ """
+ assert 'img' in results, 'No `img` field in the input.'
+
+ # dict to albumentations format
+ results = self.mapper(results, self.keymap_to_albu)
+ results = self.aug(**results)
+
+ # back to the original format
+ results = self.mapper(results, self.keymap_back)
+ results['img_shape'] = results['img'].shape[:2]
+
+ return results
+
+ def __repr__(self):
+ """Print the basic information of the transform.
+
+ Returns:
+ str: Formatted string.
+ """
+ repr_str = self.__class__.__name__
+ repr_str += f'(transforms={repr(self.transforms)})'
+ return repr_str
+
+
+@TRANSFORMS.register_module()
+class SimMIMMaskGenerator(BaseTransform):
+ """Generate random block mask for each Image.
+
+ **Added Keys**:
+
+ - mask
+
+ This module is used in SimMIM to generate masks.
+
+ Args:
+ input_size (int): Size of input image. Defaults to 192.
+ mask_patch_size (int): Size of each block mask. Defaults to 32.
+ model_patch_size (int): Patch size of each token. Defaults to 4.
+ mask_ratio (float): The mask ratio of image. Defaults to 0.6.
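+
+    Example:
+        An illustrative sketch (not from the original code), using the
+        default sizes listed above.
+
+        >>> generator = SimMIMMaskGenerator(
+        ...     input_size=192, mask_patch_size=32, model_patch_size=4)
+        >>> mask = generator(dict())['mask']
+        >>> mask.shape  # (192 // 32) * (32 // 4) = 48 cells per side
+        (48, 48)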
+ """
+
+ def __init__(self,
+ input_size: int = 192,
+ mask_patch_size: int = 32,
+ model_patch_size: int = 4,
+ mask_ratio: float = 0.6):
+ self.input_size = input_size
+ self.mask_patch_size = mask_patch_size
+ self.model_patch_size = model_patch_size
+ self.mask_ratio = mask_ratio
+
+ assert self.input_size % self.mask_patch_size == 0
+ assert self.mask_patch_size % self.model_patch_size == 0
+
+ self.rand_size = self.input_size // self.mask_patch_size
+ self.scale = self.mask_patch_size // self.model_patch_size
+
+ self.token_count = self.rand_size**2
+ self.mask_count = int(np.ceil(self.token_count * self.mask_ratio))
+
+ def transform(self, results: dict) -> dict:
+ """Method to generate random block mask for each Image in SimMIM.
+
+ Args:
+ results (dict): Result dict from previous pipeline.
+
+ Returns:
+ dict: Result dict with added key ``mask``.
+ """
+ mask_idx = np.random.permutation(self.token_count)[:self.mask_count]
+ mask = np.zeros(self.token_count, dtype=int)
+ mask[mask_idx] = 1
+
+ mask = mask.reshape((self.rand_size, self.rand_size))
+ mask = mask.repeat(self.scale, axis=0).repeat(self.scale, axis=1)
+
+ results.update({'mask': mask})
+
+ return results
+
+ def __repr__(self) -> str:
+ repr_str = self.__class__.__name__
+ repr_str += f'(input_size={self.input_size}, '
+ repr_str += f'mask_patch_size={self.mask_patch_size}, '
+ repr_str += f'model_patch_size={self.model_patch_size}, '
+ repr_str += f'mask_ratio={self.mask_ratio})'
+ return repr_str
+
+
+@TRANSFORMS.register_module()
+class BEiTMaskGenerator(BaseTransform):
+ """Generate mask for image.
+
+ **Added Keys**:
+
+ - mask
+
+ This module is borrowed from
+ https://github.com/microsoft/unilm/tree/master/beit
+
+ Args:
+ input_size (int): The size of input image.
+ num_masking_patches (int): The number of patches to be masked.
+ min_num_patches (int): The minimum number of patches to be masked
+ in the process of generating mask. Defaults to 4.
+ max_num_patches (int, optional): The maximum number of patches to be
+ masked in the process of generating mask. Defaults to None.
+ min_aspect (float): The minimum aspect ratio of mask blocks. Defaults
+ to 0.3.
+        max_aspect (float, optional): The maximum aspect ratio of mask blocks.
+            Defaults to None.
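+
+    Example:
+        An illustrative sketch (not from the original code), masking 75 of
+        the 14 x 14 = 196 patches.
+
+        >>> generator = BEiTMaskGenerator(
+        ...     input_size=14, num_masking_patches=75)
+        >>> mask = generator(dict())['mask']
+        >>> mask.shape, int(mask.sum())
+        ((14, 14), 75)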
+ """
+
+ def __init__(self,
+ input_size: int,
+ num_masking_patches: int,
+ min_num_patches: int = 4,
+ max_num_patches: Optional[int] = None,
+ min_aspect: float = 0.3,
+ max_aspect: Optional[float] = None) -> None:
+ if not isinstance(input_size, tuple):
+ input_size = (input_size, ) * 2
+ self.height, self.width = input_size
+
+ self.num_patches = self.height * self.width
+
+ self.num_masking_patches = num_masking_patches
+ self.min_num_patches = min_num_patches
+ self.max_num_patches = num_masking_patches if max_num_patches is None \
+ else max_num_patches
+
+ max_aspect = max_aspect or 1 / min_aspect
+ self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
+
+ def _mask(self, mask: np.ndarray, max_mask_patches: int) -> int:
+ """Generate mask recursively.
+
+ Args:
+ mask (np.ndarray): The mask to be generated.
+ max_mask_patches (int): The maximum number of patches to be masked.
+
+ Returns:
+ int: The number of patches masked.
+ """
+ delta = 0
+ for _ in range(10):
+ target_area = np.random.uniform(self.min_num_patches,
+ max_mask_patches)
+ aspect_ratio = math.exp(np.random.uniform(*self.log_aspect_ratio))
+ h = int(round(math.sqrt(target_area * aspect_ratio)))
+ w = int(round(math.sqrt(target_area / aspect_ratio)))
+ if w < self.width and h < self.height:
+ top = np.random.randint(0, self.height - h)
+ left = np.random.randint(0, self.width - w)
+
+ num_masked = mask[top:top + h, left:left + w].sum()
+ # Overlap
+ if 0 < h * w - num_masked <= max_mask_patches:
+ for i in range(top, top + h):
+ for j in range(left, left + w):
+ if mask[i, j] == 0:
+ mask[i, j] = 1
+ delta += 1
+ if delta > 0:
+ break
+ return delta
+
+ def transform(self, results: dict) -> dict:
+ """Method to generate random block mask for each Image in BEiT.
+
+ Args:
+ results (dict): Result dict from previous pipeline.
+
+ Returns:
+ dict: Result dict with added key ``mask``.
+ """
+ mask = np.zeros(shape=(self.height, self.width), dtype=int)
+
+ mask_count = 0
+ while mask_count != self.num_masking_patches:
+ max_mask_patches = self.num_masking_patches - mask_count
+ max_mask_patches = min(max_mask_patches, self.max_num_patches)
+
+ delta = self._mask(mask, max_mask_patches)
+ mask_count += delta
+        results.update({'mask': mask})
+
+ return results
+
+ def __repr__(self) -> str:
+ repr_str = self.__class__.__name__
+ repr_str += f'(height={self.height}, '
+ repr_str += f'width={self.width}, '
+ repr_str += f'num_patches={self.num_patches}, '
+ repr_str += f'num_masking_patches={self.num_masking_patches}, '
+ repr_str += f'min_num_patches={self.min_num_patches}, '
+ repr_str += f'max_num_patches={self.max_num_patches}, '
+ repr_str += f'log_aspect_ratio={self.log_aspect_ratio})'
+ return repr_str
+
+
+@TRANSFORMS.register_module()
+class RandomResizedCropAndInterpolationWithTwoPic(BaseTransform):
+ """Crop the given PIL Image to random size and aspect ratio with random
+ interpolation.
+
+ **Required Keys**:
+
+ - img
+
+ **Modified Keys**:
+
+ - img
+
+ **Added Keys**:
+
+ - target_img
+
+ This module is borrowed from
+ https://github.com/microsoft/unilm/tree/master/beit.
+
+ A crop of random size (default: of 0.08 to 1.0) of the original size and a
+ random aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio
+    is made. This crop is finally resized to the given size. This is
+    popularly used to train the Inception networks. This module first crops
+    the image and resizes the crop to two different sizes.
+
+ Args:
+ size (Union[tuple, int]): Expected output size of each edge of the
+ first image.
+ second_size (Union[tuple, int], optional): Expected output size of each
+ edge of the second image.
+ scale (tuple[float, float]): Range of size of the origin size cropped.
+ Defaults to (0.08, 1.0).
+ ratio (tuple[float, float]): Range of aspect ratio of the origin aspect
+ ratio cropped. Defaults to (3./4., 4./3.).
+ interpolation (str): The interpolation for the first image. Defaults
+ to ``bilinear``.
+ second_interpolation (str): The interpolation for the second image.
+ Defaults to ``lanczos``.
+ """
+
+ def __init__(self,
+ size: Union[tuple, int],
+ second_size=None,
+ scale=(0.08, 1.0),
+ ratio=(3. / 4., 4. / 3.),
+ interpolation='bilinear',
+ second_interpolation='lanczos') -> None:
+ if isinstance(size, tuple):
+ self.size = size
+ else:
+ self.size = (size, size)
+ if second_size is not None:
+ if isinstance(second_size, tuple):
+ self.second_size = second_size
+ else:
+ self.second_size = (second_size, second_size)
+ else:
+ self.second_size = None
+        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
+            raise ValueError('range should be of kind (min, max)')
+
+ if interpolation == 'random':
+ self.interpolation = ('bilinear', 'bicubic')
+ else:
+ self.interpolation = interpolation
+ self.second_interpolation = second_interpolation
+ self.scale = scale
+ self.ratio = ratio
+
+ @staticmethod
+ def get_params(img: np.ndarray, scale: tuple,
+ ratio: tuple) -> Sequence[int]:
+ """Get parameters for ``crop`` for a random sized crop.
+
+ Args:
+ img (np.ndarray): Image to be cropped.
+ scale (tuple): range of size of the origin size cropped
+ ratio (tuple): range of aspect ratio of the origin aspect
+ ratio cropped
+
+ Returns:
+ tuple: params (i, j, h, w) to be passed to ``crop`` for a random
+ sized crop.
+ """
+ img_h, img_w = img.shape[:2]
+ area = img_h * img_w
+
+ for _ in range(10):
+ target_area = np.random.uniform(*scale) * area
+ log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
+ aspect_ratio = math.exp(np.random.uniform(*log_ratio))
+
+ w = int(round(math.sqrt(target_area * aspect_ratio)))
+ h = int(round(math.sqrt(target_area / aspect_ratio)))
+
+ if w < img_w and h < img_h:
+ i = np.random.randint(0, img_h - h)
+ j = np.random.randint(0, img_w - w)
+ return i, j, h, w
+
+ # Fallback to central crop
+ in_ratio = img_w / img_h
+ if in_ratio < min(ratio):
+ w = img_w
+ h = int(round(w / min(ratio)))
+ elif in_ratio > max(ratio):
+ h = img_h
+ w = int(round(h * max(ratio)))
+ else: # whole image
+ w = img_w
+ h = img_h
+ i = (img_h - h) // 2
+ j = (img_w - w) // 2
+ return i, j, h, w
+
+ def transform(self, results: dict) -> dict:
+ """Crop the given image and resize it to two different sizes.
+
+        This module randomly crops the given image and resizes the crop to
+        two different sizes. This is popularly used in BEiT-style masked
+        image modeling, where an off-the-shelf model provides the target.
+
+ Args:
+ results (dict): Results from previous pipeline.
+
+ Returns:
+ dict: Results after applying this transformation.
+ """
+ img = results['img']
+ i, j, h, w = self.get_params(img, self.scale, self.ratio)
+ if isinstance(self.interpolation, (tuple, list)):
+ interpolation = np.random.choice(self.interpolation)
+ else:
+ interpolation = self.interpolation
+ if self.second_size is None:
+ img = img[i:i + h, j:j + w]
+ img = mmcv.imresize(img, self.size, interpolation=interpolation)
+ results.update({'img': img})
+ else:
+ img = img[i:i + h, j:j + w]
+ img_sample = mmcv.imresize(
+ img, self.size, interpolation=interpolation)
+ img_target = mmcv.imresize(
+ img, self.second_size, interpolation=self.second_interpolation)
+ results.update({'img': [img_sample, img_target]})
+ return results
+
+ def __repr__(self) -> str:
+ repr_str = self.__class__.__name__
+ repr_str += f'(size={self.size}, '
+ repr_str += f'second_size={self.second_size}, '
+ repr_str += f'interpolation={self.interpolation}, '
+ repr_str += f'second_interpolation={self.second_interpolation}, '
+ repr_str += f'scale={self.scale}, '
+ repr_str += f'ratio={self.ratio})'
+ return repr_str
+
+
+@TRANSFORMS.register_module()
+class CleanCaption(BaseTransform):
+ """Clean caption text.
+
+ Remove some useless punctuation for the caption task.
+
+ **Required Keys:**
+
+ - ``*keys``
+
+ **Modified Keys:**
+
+ - ``*keys``
+
+ Args:
+ keys (Sequence[str], optional): The keys of text to be cleaned.
+ Defaults to 'gt_caption'.
+ remove_chars (str): The characters to be removed. Defaults to
+ :py:attr:`string.punctuation`.
+ lowercase (bool): Whether to convert the text to lowercase.
+ Defaults to True.
+ remove_dup_space (bool): Whether to remove duplicated whitespaces.
+ Defaults to True.
+ strip (bool): Whether to remove leading and trailing whitespaces.
+ Defaults to True.
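+
+    Example:
+        An illustrative sketch (not from the original code), using the
+        default settings:
+
+        >>> transform = CleanCaption()
+        >>> transform({'gt_caption': ' A dog,  running! '})
+        {'gt_caption': 'a dog running'}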
+ """
+
+ def __init__(
+ self,
+ keys='gt_caption',
+ remove_chars=string.punctuation,
+ lowercase=True,
+ remove_dup_space=True,
+ strip=True,
+ ):
+ if isinstance(keys, str):
+ keys = [keys]
+ self.keys = keys
+ self.transtab = str.maketrans({ch: None for ch in remove_chars})
+ self.lowercase = lowercase
+ self.remove_dup_space = remove_dup_space
+ self.strip = strip
+
+ def _clean(self, text):
+ """Perform text cleaning before tokenizer."""
+
+ if self.strip:
+ text = text.strip()
+
+ text = text.translate(self.transtab)
+
+ if self.remove_dup_space:
+ text = re.sub(r'\s{2,}', ' ', text)
+
+ if self.lowercase:
+ text = text.lower()
+
+ return text
+
+ def clean(self, text):
+ """Perform text cleaning before tokenizer."""
+ if isinstance(text, (list, tuple)):
+ return [self._clean(item) for item in text]
+ elif isinstance(text, str):
+ return self._clean(text)
+ else:
+ raise TypeError('text must be a string or a list of strings')
+
+ def transform(self, results: dict) -> dict:
+ """Method to clean the input text data."""
+ for key in self.keys:
+ results[key] = self.clean(results[key])
+ return results
+
+
+@TRANSFORMS.register_module()
+class OFAAddObjects(BaseTransform):
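+    """Append the VinVL-generated object descriptions to the question to
+    build the prompt used by OFA fine-tuned models.
+
+    **Required Keys:**
+
+    - objects
+
+    **Modified Keys:**
+
+    - question
+
+    **Added Keys:**
+
+    - decoder_prompt
+    """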
+
+ def transform(self, results: dict) -> dict:
+ if 'objects' not in results:
+ raise ValueError(
+                'Some OFA fine-tuned models require an `objects` field in the '
+ 'dataset, which is generated by VinVL. Or please use '
+ 'zero-shot configs. See '
+ 'https://github.com/OFA-Sys/OFA/issues/189')
+
+ if 'question' in results:
+ prompt = '{} object: {}'.format(
+ results['question'],
+ ' '.join(results['objects']),
+ )
+ results['decoder_prompt'] = prompt
+            results['question'] = prompt
+
+        return results
+
+
+@TRANSFORMS.register_module()
+class RandomTranslatePad(BaseTransform):
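+    """Randomly translate the image within a fixed-size square canvas and
+    zero-pad the rest, shifting ``gt_bboxes`` accordingly.
+
+    **Required Keys:**
+
+    - img
+
+    **Modified Keys:**
+
+    - img
+    - img_shape
+    - gt_bboxes (optional)
+
+    Args:
+        size (int): The output canvas size. Defaults to 640.
+        aug_translate (bool): Whether to sample a random top-left offset
+            instead of centering the image. Defaults to False.
+    """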
+
+ def __init__(self, size=640, aug_translate=False):
+ self.size = size
+ self.aug_translate = aug_translate
+
+ @cache_randomness
+ def rand_translate_params(self, dh, dw):
+ top = np.random.randint(0, dh)
+ left = np.random.randint(0, dw)
+ return top, left
+
+ def transform(self, results: dict) -> dict:
+ img = results['img']
+ h, w = img.shape[:-1]
+ dw = self.size - w
+ dh = self.size - h
+ if self.aug_translate:
+ top, left = self.rand_translate_params(dh, dw)
+ else:
+ top = round(dh / 2.0 - 0.1)
+ left = round(dw / 2.0 - 0.1)
+
+ out_img = np.zeros((self.size, self.size, 3), dtype=np.float32)
+ out_img[top:top + h, left:left + w, :] = img
+ results['img'] = out_img
+ results['img_shape'] = (self.size, self.size)
+
+ # translate box
+ if 'gt_bboxes' in results.keys():
+ for i in range(len(results['gt_bboxes'])):
+ box = results['gt_bboxes'][i]
+ box[0], box[2] = box[0] + left, box[2] + left
+ box[1], box[3] = box[1] + top, box[3] + top
+ results['gt_bboxes'][i] = box
+
+ return results
diff --git a/mmpretrain/datasets/transforms/utils.py b/mmpretrain/datasets/transforms/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7940486fc9904c14f5a5a4a959022c11456c968
--- /dev/null
+++ b/mmpretrain/datasets/transforms/utils.py
@@ -0,0 +1,53 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+from typing import List, Union
+
+from mmcv.transforms import BaseTransform
+
+PIPELINE_TYPE = List[Union[dict, BaseTransform]]
+
+
+def get_transform_idx(pipeline: PIPELINE_TYPE, target: str) -> int:
+ """Returns the index of the transform in a pipeline.
+
+ Args:
+ pipeline (List[dict] | List[BaseTransform]): The transforms list.
+ target (str): The target transform class name.
+
+ Returns:
+ int: The transform index. Returns -1 if not found.
+ """
+ for i, transform in enumerate(pipeline):
+ if isinstance(transform, dict):
+ if isinstance(transform['type'], type):
+ if transform['type'].__name__ == target:
+ return i
+ else:
+ if transform['type'] == target:
+ return i
+ else:
+ if transform.__class__.__name__ == target:
+ return i
+
+ return -1
+
+
+def remove_transform(pipeline: PIPELINE_TYPE, target: str, inplace=False):
+ """Remove the target transform type from the pipeline.
+
+ Args:
+ pipeline (List[dict] | List[BaseTransform]): The transforms list.
+ target (str): The target transform class name.
+ inplace (bool): Whether to modify the pipeline inplace.
+
+ Returns:
+        The modified transform pipeline.
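+
+    Example:
+        An illustrative sketch (not from the original code):
+
+        >>> pipeline = [
+        ...     dict(type='LoadImageFromFile'),
+        ...     dict(type='PackInputs'),
+        ... ]
+        >>> remove_transform(pipeline, 'PackInputs')
+        [{'type': 'LoadImageFromFile'}]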
+ """
+ idx = get_transform_idx(pipeline, target)
+ if not inplace:
+ pipeline = copy.deepcopy(pipeline)
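+    # Keep removing until no occurrence of the target transform is left.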
+ while idx >= 0:
+ pipeline.pop(idx)
+ idx = get_transform_idx(pipeline, target)
+
+ return pipeline
diff --git a/mmpretrain/datasets/transforms/wrappers.py b/mmpretrain/datasets/transforms/wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0dfd730b4db0dc80ed315b79658cfbf683e4035
--- /dev/null
+++ b/mmpretrain/datasets/transforms/wrappers.py
@@ -0,0 +1,144 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+from typing import Callable, List, Union
+
+from mmcv.transforms import BaseTransform, Compose
+
+from mmpretrain.registry import TRANSFORMS
+
+# Define type of transform or transform config
+Transform = Union[dict, Callable[[dict], dict]]
+
+
+@TRANSFORMS.register_module()
+class MultiView(BaseTransform):
+ """A transform wrapper for multiple views of an image.
+
+ Args:
+        transforms (list[list[dict | callable]]): A list of view pipelines.
+            Each pipeline is a sequence of transform objects or config dicts
+            to be composed and applied to the image.
+        num_views (int | list[int]): The number of views generated by each
+            pipeline. If an int is given, ``transforms`` must contain exactly
+            one pipeline. The length of ``num_views`` must equal the length
+            of ``transforms``.
+
+ Examples:
+ >>> # Example 1: MultiViews 1 pipeline with 2 views
+ >>> pipeline = [
+ >>> dict(type='MultiView',
+ >>> num_views=2,
+ >>> transforms=[
+ >>> [
+        >>>                dict(type='Resize', scale=224)],
+ >>> ])
+ >>> ]
+ >>> # Example 2: MultiViews 2 pipelines, the first with 2 views,
+ >>> # the second with 6 views
+ >>> pipeline = [
+ >>> dict(type='MultiView',
+ >>> num_views=[2, 6],
+ >>> transforms=[
+ >>> [
+ >>> dict(type='Resize', scale=224)],
+ >>> [
+ >>> dict(type='Resize', scale=224),
+ >>> dict(type='RandomSolarize')],
+ >>> ])
+ >>> ]
+ """
+
+ def __init__(self, transforms: List[List[Transform]],
+ num_views: Union[int, List[int]]) -> None:
+
+ if isinstance(num_views, int):
+ num_views = [num_views]
+ assert isinstance(num_views, List)
+ assert len(num_views) == len(transforms)
+ self.num_views = num_views
+
+ self.pipelines = []
+ for trans in transforms:
+ pipeline = Compose(trans)
+ self.pipelines.append(pipeline)
+
+ self.transforms = []
+ for i in range(len(num_views)):
+ self.transforms.extend([self.pipelines[i]] * num_views[i])
+
+ def transform(self, results: dict) -> dict:
+ """Apply transformation to inputs.
+
+ Args:
+ results (dict): Result dict from previous pipelines.
+
+ Returns:
+ dict: Transformed results.
+ """
+ multi_views_outputs = dict(img=[])
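+        # Apply each view's pipeline to an independent copy of the inputs
+        # and gather the transformed images into a single list under 'img'.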
+ for trans in self.transforms:
+ inputs = copy.deepcopy(results)
+ outputs = trans(inputs)
+
+ multi_views_outputs['img'].append(outputs['img'])
+ results.update(multi_views_outputs)
+ return results
+
+ def __repr__(self) -> str:
+ repr_str = self.__class__.__name__ + '('
+ for i, p in enumerate(self.pipelines):
+ repr_str += f'\nPipeline {i + 1} with {self.num_views[i]} views:\n'
+ repr_str += str(p)
+ repr_str += ')'
+ return repr_str
+
+
+@TRANSFORMS.register_module()
+class ApplyToList(BaseTransform):
+ """A transform wrapper to apply the wrapped transforms to a list of items.
+ For example, to load and resize a list of images.
+
+ Args:
+ transforms (list[dict | callable]): Sequence of transform config dict
+ to be wrapped.
+ scatter_key (str): The key to scatter data dict. If the field is a
+ list, scatter the list to multiple data dicts to do transformation.
+ collate_keys (List[str]): The keys to collate from multiple data dicts.
+ The fields in ``collate_keys`` will be composed into a list after
+ transformation, and the other fields will be adopted from the
+ first data dict.
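+
+    Example:
+        An illustrative sketch (not from the original code); the
+        ``LoadImageFromFile`` transform and the file paths are placeholders.
+
+        >>> apply_to_list = ApplyToList(
+        ...     transforms=[dict(type='LoadImageFromFile')],
+        ...     scatter_key='img_path',
+        ...     collate_keys=['img', 'img_shape'])
+        >>> results = dict(img_path=['a.jpg', 'b.jpg'])
+        >>> results = apply_to_list(results)
+        >>> # results['img'] is now a list with one image per input path.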
+ """
+
+ def __init__(self, transforms, scatter_key, collate_keys):
+ super().__init__()
+
+ self.transforms = Compose([TRANSFORMS.build(t) for t in transforms])
+ self.scatter_key = scatter_key
+ self.collate_keys = set(collate_keys)
+ self.collate_keys.add(self.scatter_key)
+
+ def transform(self, results: dict):
+ scatter_field = results.get(self.scatter_key)
+
+ if isinstance(scatter_field, list):
+ scattered_results = []
+ for item in scatter_field:
+ single_results = copy.deepcopy(results)
+ single_results[self.scatter_key] = item
+ scattered_results.append(self.transforms(single_results))
+
+ final_output = scattered_results[0]
+
+ # merge output list to single output
+ for key in scattered_results[0].keys():
+ if key in self.collate_keys:
+ final_output[key] = [
+ single[key] for single in scattered_results
+ ]
+ return final_output
+ else:
+ return self.transforms(results)
diff --git a/mmpretrain/datasets/utils.py b/mmpretrain/datasets/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..fcb60e432c374c1a904700a7348f706fa0e523eb
--- /dev/null
+++ b/mmpretrain/datasets/utils.py
@@ -0,0 +1,243 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import gzip
+import hashlib
+import os
+import os.path
+import shutil
+import tarfile
+import tempfile
+import urllib.error
+import urllib.request
+import zipfile
+
+from mmengine.fileio import LocalBackend, get_file_backend
+
+__all__ = [
+ 'rm_suffix', 'check_integrity', 'download_and_extract_archive',
+ 'open_maybe_compressed_file'
+]
+
+
+def rm_suffix(s, suffix=None):
+ if suffix is None:
+ return s[:s.rfind('.')]
+ else:
+ return s[:s.rfind(suffix)]
+
+
+def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024):
+ md5 = hashlib.md5()
+ backend = get_file_backend(fpath, enable_singleton=True)
+ if isinstance(backend, LocalBackend):
+ # Enable chunk update for local file.
+ with open(fpath, 'rb') as f:
+ for chunk in iter(lambda: f.read(chunk_size), b''):
+ md5.update(chunk)
+ else:
+ md5.update(backend.get(fpath))
+ return md5.hexdigest()
+
+
+def check_md5(fpath, md5, **kwargs):
+ return md5 == calculate_md5(fpath, **kwargs)
+
+
+def check_integrity(fpath, md5=None):
+ if not os.path.isfile(fpath):
+ return False
+ if md5 is None:
+ return True
+ return check_md5(fpath, md5)
+
+
+def download_url_to_file(url, dst, hash_prefix=None, progress=True):
+ """Download object at the given URL to a local path.
+
+ Modified from
+ https://pytorch.org/docs/stable/hub.html#torch.hub.download_url_to_file
+
+ Args:
+ url (str): URL of the object to download
+ dst (str): Full path where object will be saved,
+ e.g. ``/tmp/temporary_file``
+        hash_prefix (string, optional): If not None, the SHA256 of the
+            downloaded file should start with ``hash_prefix``.
+            Defaults to None.
+ progress (bool): whether or not to display a progress bar to stderr.
+ Defaults to True
+ """
+ file_size = None
+ req = urllib.request.Request(url)
+ u = urllib.request.urlopen(req)
+ meta = u.info()
+ if hasattr(meta, 'getheaders'):
+ content_length = meta.getheaders('Content-Length')
+ else:
+ content_length = meta.get_all('Content-Length')
+ if content_length is not None and len(content_length) > 0:
+ file_size = int(content_length[0])
+
+ # We deliberately save it in a temp file and move it after download is
+ # complete. This prevents a local file being overridden by a broken
+ # download.
+ dst = os.path.expanduser(dst)
+ dst_dir = os.path.dirname(dst)
+ f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)
+
+ import rich.progress
+ columns = [
+ rich.progress.DownloadColumn(),
+ rich.progress.BarColumn(bar_width=None),
+ rich.progress.TimeRemainingColumn(),
+ ]
+ try:
+ if hash_prefix is not None:
+ sha256 = hashlib.sha256()
+ with rich.progress.Progress(*columns) as pbar:
+ task = pbar.add_task('download', total=file_size, visible=progress)
+ while True:
+ buffer = u.read(8192)
+ if len(buffer) == 0:
+ break
+ f.write(buffer)
+ if hash_prefix is not None:
+ sha256.update(buffer)
+ pbar.update(task, advance=len(buffer))
+
+ f.close()
+ if hash_prefix is not None:
+ digest = sha256.hexdigest()
+ if digest[:len(hash_prefix)] != hash_prefix:
+ raise RuntimeError(
+ 'invalid hash value (expected "{}", got "{}")'.format(
+ hash_prefix, digest))
+ shutil.move(f.name, dst)
+ finally:
+ f.close()
+ if os.path.exists(f.name):
+ os.remove(f.name)
+
+
+def download_url(url, root, filename=None, md5=None):
+ """Download a file from a url and place it in root.
+
+ Args:
+ url (str): URL to download file from.
+ root (str): Directory to place downloaded file in.
+ filename (str | None): Name to save the file under.
+ If filename is None, use the basename of the URL.
+ md5 (str | None): MD5 checksum of the download.
+ If md5 is None, download without md5 check.
+ """
+ root = os.path.expanduser(root)
+ if not filename:
+ filename = os.path.basename(url)
+ fpath = os.path.join(root, filename)
+
+ os.makedirs(root, exist_ok=True)
+
+ if check_integrity(fpath, md5):
+ print(f'Using downloaded and verified file: {fpath}')
+ else:
+ try:
+ print(f'Downloading {url} to {fpath}')
+ download_url_to_file(url, fpath)
+ except (urllib.error.URLError, IOError) as e:
+ if url[:5] == 'https':
+ url = url.replace('https:', 'http:')
+ print('Failed download. Trying https -> http instead.'
+ f' Downloading {url} to {fpath}')
+ download_url_to_file(url, fpath)
+ else:
+ raise e
+ # check integrity of downloaded file
+ if not check_integrity(fpath, md5):
+ raise RuntimeError('File not found or corrupted.')
+
+
+def _is_tarxz(filename):
+ return filename.endswith('.tar.xz')
+
+
+def _is_tar(filename):
+ return filename.endswith('.tar')
+
+
+def _is_targz(filename):
+ return filename.endswith('.tar.gz')
+
+
+def _is_tgz(filename):
+ return filename.endswith('.tgz')
+
+
+def _is_gzip(filename):
+ return filename.endswith('.gz') and not filename.endswith('.tar.gz')
+
+
+def _is_zip(filename):
+ return filename.endswith('.zip')
+
+
+def extract_archive(from_path, to_path=None, remove_finished=False):
+ if to_path is None:
+ to_path = os.path.dirname(from_path)
+
+ if _is_tar(from_path):
+ with tarfile.open(from_path, 'r') as tar:
+ tar.extractall(path=to_path)
+ elif _is_targz(from_path) or _is_tgz(from_path):
+ with tarfile.open(from_path, 'r:gz') as tar:
+ tar.extractall(path=to_path)
+ elif _is_tarxz(from_path):
+ with tarfile.open(from_path, 'r:xz') as tar:
+ tar.extractall(path=to_path)
+ elif _is_gzip(from_path):
+ to_path = os.path.join(
+ to_path,
+ os.path.splitext(os.path.basename(from_path))[0])
+ with open(to_path, 'wb') as out_f, gzip.GzipFile(from_path) as zip_f:
+ out_f.write(zip_f.read())
+ elif _is_zip(from_path):
+ with zipfile.ZipFile(from_path, 'r') as z:
+ z.extractall(to_path)
+ else:
+ raise ValueError(f'Extraction of {from_path} not supported')
+
+ if remove_finished:
+ os.remove(from_path)
+
+
+def download_and_extract_archive(url,
+ download_root,
+ extract_root=None,
+ filename=None,
+ md5=None,
+ remove_finished=False):
+ download_root = os.path.expanduser(download_root)
+ if extract_root is None:
+ extract_root = download_root
+ if not filename:
+ filename = os.path.basename(url)
+
+ download_url(url, download_root, filename, md5)
+
+ archive = os.path.join(download_root, filename)
+ print(f'Extracting {archive} to {extract_root}')
+ extract_archive(archive, extract_root, remove_finished)
+
+
+def open_maybe_compressed_file(path: str):
+ """Return a file object that possibly decompresses 'path' on the fly.
+
+ Decompression occurs when argument `path` is a string and ends with '.gz'
+ or '.xz'.
+ """
+ if not isinstance(path, str):
+ return path
+ if path.endswith('.gz'):
+ import gzip
+ return gzip.open(path, 'rb')
+ if path.endswith('.xz'):
+ import lzma
+ return lzma.open(path, 'rb')
+ return open(path, 'rb')
diff --git a/mmpretrain/datasets/vg_vqa.py b/mmpretrain/datasets/vg_vqa.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d83884c804086c060bcfe27e833bff28dc28e9e
--- /dev/null
+++ b/mmpretrain/datasets/vg_vqa.py
@@ -0,0 +1,77 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List
+
+from mmengine.fileio import load
+
+from mmpretrain.registry import DATASETS
+from .base_dataset import BaseDataset
+
+
+@DATASETS.register_module()
+class VGVQA(BaseDataset):
+ """Visual Genome VQA dataset."""
+
+ def load_data_list(self) -> List[dict]:
+ """Load data list.
+
+        Compared with ``BaseDataset``, the only difference is that the VQA
+        annotation file is already a list of data. There is no 'metainfo'.
+ """
+
+ raw_data_list = load(self.ann_file)
+ if not isinstance(raw_data_list, list):
+ raise TypeError(
+ f'The VQA annotations loaded from annotation file '
+                f'should be a list, but got {type(raw_data_list)}!')
+
+ # load and parse data_infos.
+ data_list = []
+ for raw_data_info in raw_data_list:
+ # parse raw data information to target format
+ data_info = self.parse_data_info(raw_data_info)
+ if isinstance(data_info, dict):
+ # For VQA tasks, each `data_info` looks like:
+ # {
+ # "question_id": 986769,
+ # "question": "How many people are there?",
+ # "answer": "two",
+ # "image": "image/1.jpg",
+ # "dataset": "vg"
+ # }
+
+ # change 'image' key to 'img_path'
+                # TODO: This process will be removed after the annotation
+                # file is preprocessed.
+ data_info['img_path'] = data_info['image']
+ del data_info['image']
+
+ if 'answer' in data_info:
+ # add answer_weight & answer_count, delete duplicate answer
+ if data_info['dataset'] == 'vqa':
+ answer_weight = {}
+ for answer in data_info['answer']:
+ if answer in answer_weight.keys():
+ answer_weight[answer] += 1 / len(
+ data_info['answer'])
+ else:
+ answer_weight[answer] = 1 / len(
+ data_info['answer'])
+
+ data_info['answer'] = list(answer_weight.keys())
+ data_info['answer_weight'] = list(
+ answer_weight.values())
+ data_info['answer_count'] = len(answer_weight)
+
+ elif data_info['dataset'] == 'vg':
+ data_info['answers'] = [data_info['answer']]
+ data_info['answer_weight'] = [0.2]
+ data_info['answer_count'] = 1
+
+ data_list.append(data_info)
+
+ else:
+ raise TypeError(
+ f'Each VQA data element loaded from annotation file '
+ f'should be a dict, but got {type(data_info)}!')
+
+ return data_list
diff --git a/mmpretrain/datasets/visual_genome.py b/mmpretrain/datasets/visual_genome.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c33b86c4f81d0be0f2830618ad100196b461dcf
--- /dev/null
+++ b/mmpretrain/datasets/visual_genome.py
@@ -0,0 +1,95 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import re
+from itertools import chain
+from typing import List
+
+import mmengine
+from mmengine.dataset import BaseDataset
+
+from mmpretrain.registry import DATASETS
+
+
+@DATASETS.register_module()
+class VisualGenomeQA(BaseDataset):
+ """Visual Genome Question Answering dataset.
+
+ dataset structure: ::
+
+ data_root
+ ├── image
+ │ ├── 1.jpg
+ │ ├── 2.jpg
+ │ └── ...
+ └── question_answers.json
+
+ Args:
+ data_root (str): The root directory for ``data_prefix``, ``ann_file``
+ and ``question_file``.
+ data_prefix (str): The directory of images. Defaults to ``"image"``.
+ ann_file (str, optional): Annotation file path for training and
+ validation. Defaults to ``"question_answers.json"``.
+ **kwargs: Other keyword arguments in :class:`BaseDataset`.
+ """
+
+ def __init__(self,
+ data_root: str,
+ data_prefix: str = 'image',
+ ann_file: str = 'question_answers.json',
+ **kwarg):
+ super().__init__(
+ data_root=data_root,
+ data_prefix=dict(img_path=data_prefix),
+ ann_file=ann_file,
+ **kwarg,
+ )
+
+ def _create_image_index(self):
+ img_prefix = self.data_prefix['img_path']
+
+ files = mmengine.list_dir_or_file(img_prefix, list_dir=False)
+ image_index = {}
+ for file in files:
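+            # Image files are named by their numeric id (e.g. "1.jpg"), so
+            # take the trailing number in the file name as the image id.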
+ image_id = re.findall(r'\d+', file)
+ if len(image_id) > 0:
+ image_id = int(image_id[-1])
+ image_index[image_id] = mmengine.join_path(img_prefix, file)
+
+ return image_index
+
+ def load_data_list(self) -> List[dict]:
+ """Load data list."""
+ annotations = mmengine.load(self.ann_file)
+
+ # The original Visual Genome annotation file and question file includes
+ # only image id but no image file paths.
+ self.image_index = self._create_image_index()
+
+ data_list = []
+ for qas in chain.from_iterable(ann['qas'] for ann in annotations):
+ # ann example
+ # {
+ # 'id': 1,
+ # 'qas': [
+ # {
+ # 'a_objects': [],
+ # 'question': 'What color is the clock?',
+ # 'image_id': 1,
+ # 'qa_id': 986768,
+ # 'answer': 'Two.',
+ # 'q_objects': [],
+ # }
+ # ...
+ # ]
+ # }
+
+ data_info = {
+ 'img_path': self.image_index[qas['image_id']],
+                'question': qas['question'],
+                'question_id': qas['qa_id'],
+ 'image_id': qas['image_id'],
+ 'gt_answer': [qas['answer']],
+ }
+
+ data_list.append(data_info)
+
+ return data_list
diff --git a/mmpretrain/datasets/vizwiz.py b/mmpretrain/datasets/vizwiz.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b5dd394524cac5ad514351ac2a93286c75e1b17
--- /dev/null
+++ b/mmpretrain/datasets/vizwiz.py
@@ -0,0 +1,112 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from collections import Counter
+from typing import List
+
+import mmengine
+from mmengine.dataset import BaseDataset
+
+from mmpretrain.registry import DATASETS
+
+
+@DATASETS.register_module()
+class VizWiz(BaseDataset):
+ """VizWiz dataset.
+
+ Args:
+ data_root (str): The root directory for ``data_prefix``, ``ann_file``
+ and ``question_file``.
+ data_prefix (str): The directory of images.
+ ann_file (str, optional): Annotation file path for training and
+ validation. Defaults to an empty string.
+ **kwargs: Other keyword arguments in :class:`BaseDataset`.
+ """
+
+ def __init__(self,
+ data_root: str,
+ data_prefix: str,
+ ann_file: str = '',
+ **kwarg):
+ super().__init__(
+ data_root=data_root,
+ data_prefix=dict(img_path=data_prefix),
+ ann_file=ann_file,
+ **kwarg,
+ )
+
+ def load_data_list(self) -> List[dict]:
+ """Load data list."""
+ annotations = mmengine.load(self.ann_file)
+
+ data_list = []
+ for ann in annotations:
+ # {
+ # "image": "VizWiz_val_00000001.jpg",
+ # "question": "Can you tell me what this medicine is please?",
+ # "answers": [
+ # {
+ # "answer": "no",
+ # "answer_confidence": "yes"
+ # },
+ # {
+ # "answer": "unanswerable",
+ # "answer_confidence": "yes"
+ # },
+ # {
+ # "answer": "night time",
+ # "answer_confidence": "maybe"
+ # },
+ # {
+ # "answer": "unanswerable",
+ # "answer_confidence": "yes"
+ # },
+ # {
+ # "answer": "night time",
+ # "answer_confidence": "maybe"
+ # },
+ # {
+ # "answer": "night time cold medicine",
+ # "answer_confidence": "maybe"
+ # },
+ # {
+ # "answer": "night time",
+ # "answer_confidence": "maybe"
+ # },
+ # {
+ # "answer": "night time",
+ # "answer_confidence": "maybe"
+ # },
+ # {
+ # "answer": "night time",
+ # "answer_confidence": "maybe"
+ # },
+ # {
+ # "answer": "night time medicine",
+ # "answer_confidence": "yes"
+ # }
+ # ],
+ # "answer_type": "other",
+ # "answerable": 1
+ # },
+ data_info = dict()
+ data_info['question'] = ann['question']
+ data_info['img_path'] = mmengine.join_path(
+ self.data_prefix['img_path'], ann['image'])
+
+ if 'answerable' not in ann:
+ data_list.append(data_info)
+ else:
+ if ann['answerable'] == 1:
+ # add answer_weight & answer_count, delete duplicate answer
+ answers = []
+ for item in ann.pop('answers'):
+ if item['answer_confidence'] == 'yes' and item[
+ 'answer'] != 'unanswerable':
+ answers.append(item['answer'])
+ count = Counter(answers)
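+                    # Weight each unique answer by its frequency among the
+                    # confident (non-"unanswerable") answers.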
+ answer_weight = [i / len(answers) for i in count.values()]
+ data_info['gt_answer'] = list(count.keys())
+ data_info['gt_answer_weight'] = answer_weight
+ # data_info.update(ann)
+ data_list.append(data_info)
+
+ return data_list
diff --git a/mmpretrain/datasets/voc.py b/mmpretrain/datasets/voc.py
new file mode 100644
index 0000000000000000000000000000000000000000..39544de7a1794a2d965189c692f652cc56b218f9
--- /dev/null
+++ b/mmpretrain/datasets/voc.py
@@ -0,0 +1,195 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import xml.etree.ElementTree as ET
+from typing import List, Optional, Union
+
+from mmengine import get_file_backend, list_from_file
+from mmengine.logging import MMLogger
+
+from mmpretrain.registry import DATASETS
+from .base_dataset import expanduser
+from .categories import VOC2007_CATEGORIES
+from .multi_label import MultiLabelDataset
+
+
+@DATASETS.register_module()
+class VOC(MultiLabelDataset):
+ """`Pascal VOC `_ Dataset.
+
+ After decompression, the dataset directory structure is as follows:
+
+ VOC dataset directory: ::
+
+ VOC2007
+ ├── JPEGImages
+ │ ├── xxx.jpg
+ │ ├── xxy.jpg
+ │ └── ...
+ ├── Annotations
+ │ ├── xxx.xml
+ │ ├── xxy.xml
+ │ └── ...
+ └── ImageSets
+ └── Main
+ ├── train.txt
+ ├── val.txt
+ ├── trainval.txt
+ ├── test.txt
+ └── ...
+
+    Extra ``difficult`` labels exist in VOC annotations. We use
+    ``gt_label_difficult`` to record the difficult labels of each sample, and
+    the corresponding evaluation should take this field into account when
+    computing metrics. By default, difficult labels are treated as negative.
+
+ Args:
+ data_root (str): The root directory for VOC dataset.
+ split (str, optional): The dataset split, supports "train",
+ "val", "trainval", and "test". Default to "trainval".
+ image_set_path (str, optional): The path of image set, The file which
+ lists image ids of the sub dataset, and this path is relative
+ to ``data_root``. Default to ''.
+ data_prefix (dict): Prefix for data and annotation, keyword
+ 'img_path' and 'ann_path' can be set. Defaults to be
+ ``dict(img_path='JPEGImages', ann_path='Annotations')``.
+ metainfo (dict, optional): Meta information for dataset, such as
+ categories information. Defaults to None.
+ **kwargs: Other keyword arguments in :class:`BaseDataset`.
+
+ Examples:
+ >>> from mmpretrain.datasets import VOC
+ >>> train_dataset = VOC(data_root='data/VOC2007', split='trainval')
+ >>> train_dataset
+ Dataset VOC
+ Number of samples: 5011
+ Number of categories: 20
+ Prefix of dataset: data/VOC2007
+ Path of image set: data/VOC2007/ImageSets/Main/trainval.txt
+ Prefix of images: data/VOC2007/JPEGImages
+ Prefix of annotations: data/VOC2007/Annotations
+ >>> test_dataset = VOC(data_root='data/VOC2007', split='test')
+ >>> test_dataset
+ Dataset VOC
+ Number of samples: 4952
+ Number of categories: 20
+ Prefix of dataset: data/VOC2007
+ Path of image set: data/VOC2007/ImageSets/Main/test.txt
+ Prefix of images: data/VOC2007/JPEGImages
+ Prefix of annotations: data/VOC2007/Annotations
+ """ # noqa: E501
+
+ METAINFO = {'classes': VOC2007_CATEGORIES}
+
+ def __init__(self,
+ data_root: str,
+ split: str = 'trainval',
+ image_set_path: str = '',
+ data_prefix: Union[str, dict] = dict(
+ img_path='JPEGImages', ann_path='Annotations'),
+ test_mode: bool = False,
+ metainfo: Optional[dict] = None,
+ **kwargs):
+
+ self.backend = get_file_backend(data_root, enable_singleton=True)
+
+ if split:
+ splits = ['train', 'val', 'trainval', 'test']
+ assert split in splits, \
+ f"The split must be one of {splits}, but get '{split}'"
+ self.split = split
+
+ if not data_prefix:
+ data_prefix = dict(
+ img_path='JPEGImages', ann_path='Annotations')
+ if not image_set_path:
+ image_set_path = self.backend.join_path(
+ 'ImageSets', 'Main', f'{split}.txt')
+
+ # To handle the BC-breaking
+ if (split == 'train' or split == 'trainval') and test_mode:
+ logger = MMLogger.get_current_instance()
+ logger.warning(f'split="{split}" but test_mode=True. '
+ f'The {split} set will be used.')
+
+ if isinstance(data_prefix, str):
+ data_prefix = dict(img_path=expanduser(data_prefix))
+ assert isinstance(data_prefix, dict) and 'img_path' in data_prefix, \
+ '`data_prefix` must be a dict with key img_path'
+
+ if (split and split not in ['val', 'test']) or not test_mode:
+ assert 'ann_path' in data_prefix and data_prefix[
+ 'ann_path'] is not None, \
+ '"ann_path" must be set in `data_prefix`' \
+ 'when validation or test set is used.'
+
+ self.data_root = data_root
+ self.image_set_path = self.backend.join_path(data_root, image_set_path)
+
+ super().__init__(
+ ann_file='',
+ metainfo=metainfo,
+ data_root=data_root,
+ data_prefix=data_prefix,
+ test_mode=test_mode,
+ **kwargs)
+
+ @property
+ def ann_prefix(self):
+ """The prefix of images."""
+ if 'ann_path' in self.data_prefix:
+ return self.data_prefix['ann_path']
+ else:
+ return None
+
+ def _get_labels_from_xml(self, img_id):
+ """Get gt_labels and labels_difficult from xml file."""
+ xml_path = self.backend.join_path(self.ann_prefix, f'{img_id}.xml')
+ content = self.backend.get(xml_path)
+ root = ET.fromstring(content)
+
+ labels, labels_difficult = set(), set()
+ for obj in root.findall('object'):
+ label_name = obj.find('name').text
+ # in case customized dataset has wrong labels
+ # or CLASSES has been override.
+ if label_name not in self.CLASSES:
+ continue
+ label = self.class_to_idx[label_name]
+ difficult = int(obj.find('difficult').text)
+ if difficult:
+ labels_difficult.add(label)
+ else:
+ labels.add(label)
+
+ return list(labels), list(labels_difficult)
+
+ def load_data_list(self):
+ """Load images and ground truth labels."""
+ data_list = []
+ img_ids = list_from_file(self.image_set_path)
+
+ for img_id in img_ids:
+ img_path = self.backend.join_path(self.img_prefix, f'{img_id}.jpg')
+
+ labels, labels_difficult = None, None
+ if self.ann_prefix is not None:
+ labels, labels_difficult = self._get_labels_from_xml(img_id)
+
+ info = dict(
+ img_path=img_path,
+ gt_label=labels,
+ gt_label_difficult=labels_difficult)
+ data_list.append(info)
+
+ return data_list
+
+ def extra_repr(self) -> List[str]:
+ """The extra repr information of the dataset."""
+ body = [
+ f'Prefix of dataset: \t{self.data_root}',
+ f'Path of image set: \t{self.image_set_path}',
+ f'Prefix of images: \t{self.img_prefix}',
+ f'Prefix of annotations: \t{self.ann_prefix}'
+ ]
+
+ return body
diff --git a/mmpretrain/datasets/vsr.py b/mmpretrain/datasets/vsr.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b109592bd020d57e3db8f2ff610901e2a1d9f31
--- /dev/null
+++ b/mmpretrain/datasets/vsr.py
@@ -0,0 +1,55 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List
+
+import mmengine
+from mmengine.dataset import BaseDataset
+
+from mmpretrain.registry import DATASETS
+
+
+@DATASETS.register_module()
+class VSR(BaseDataset):
+ """VSR: Visual Spatial Reasoning dataset.
+
+ Args:
+ data_root (str): The root directory for ``data_prefix``, ``ann_file``
+ and ``question_file``.
+ data_prefix (str): The directory of images.
+ ann_file (str, optional): Annotation file path for training and
+ validation. Defaults to an empty string.
+ **kwargs: Other keyword arguments in :class:`BaseDataset`.
+ """
+
+ def __init__(self,
+ data_root: str,
+ data_prefix: str,
+ ann_file: str = '',
+ **kwarg):
+ super().__init__(
+ data_root=data_root,
+ data_prefix=dict(img_path=data_prefix),
+ ann_file=ann_file,
+ **kwarg,
+ )
+
+ def load_data_list(self) -> List[dict]:
+ """Load data list."""
+ annotations = mmengine.load(self.ann_file)
+
+ data_list = []
+ for ann in annotations:
+ # ann example
+ # {
+ # "image": "train2017/000000372029.jpg",
+ # "question": "The dog is on the surfboard.",
+ # "answer": true
+ # }
+ data_info = dict()
+ data_info['img_path'] = mmengine.join_path(
+ self.data_prefix['img_path'], ann['image'])
+ data_info['question'] = ann['question']
+ data_info['gt_answer'] = 'yes' if ann['answer'] else 'no'
+
+ data_list.append(data_info)
+
+ return data_list
diff --git a/mmpretrain/engine/__init__.py b/mmpretrain/engine/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..332fea0909b4abdc6a83cf7662ea916a777d99dd
--- /dev/null
+++ b/mmpretrain/engine/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .hooks import * # noqa: F401, F403
+from .optimizers import * # noqa: F401, F403
+from .runners import * # noqa: F401, F403
+from .schedulers import * # noqa: F401, F403
diff --git a/mmpretrain/engine/hooks/__init__.py b/mmpretrain/engine/hooks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc9e22be7e96d636f202066f2e00e7699b730619
--- /dev/null
+++ b/mmpretrain/engine/hooks/__init__.py
@@ -0,0 +1,19 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .class_num_check_hook import ClassNumCheckHook
+from .densecl_hook import DenseCLHook
+from .ema_hook import EMAHook
+from .margin_head_hooks import SetAdaptiveMarginsHook
+from .precise_bn_hook import PreciseBNHook
+from .retriever_hooks import PrepareProtoBeforeValLoopHook
+from .simsiam_hook import SimSiamHook
+from .swav_hook import SwAVHook
+from .switch_recipe_hook import SwitchRecipeHook
+from .visualization_hook import VisualizationHook
+from .warmup_param_hook import WarmupParamHook
+
+__all__ = [
+ 'ClassNumCheckHook', 'PreciseBNHook', 'VisualizationHook',
+ 'SwitchRecipeHook', 'PrepareProtoBeforeValLoopHook',
+ 'SetAdaptiveMarginsHook', 'EMAHook', 'SimSiamHook', 'DenseCLHook',
+ 'SwAVHook', 'WarmupParamHook'
+]
diff --git a/mmpretrain/engine/hooks/class_num_check_hook.py b/mmpretrain/engine/hooks/class_num_check_hook.py
new file mode 100644
index 0000000000000000000000000000000000000000..38170d6604810c575aa5c2c9435c0b75cfa761b2
--- /dev/null
+++ b/mmpretrain/engine/hooks/class_num_check_hook.py
@@ -0,0 +1,63 @@
+# Copyright (c) OpenMMLab. All rights reserved
+from mmengine.hooks import Hook
+from mmengine.utils import is_seq_of
+
+from mmpretrain.registry import HOOKS
+
+
+@HOOKS.register_module()
+class ClassNumCheckHook(Hook):
+ """Class Number Check HOOK."""
+
+ def _check_head(self, runner, dataset):
+ """Check whether the `num_classes` in head matches the length of
+ `CLASSES` in `dataset`.
+
+ Args:
+ runner (obj:`Runner`): runner object.
+ dataset (obj: `BaseDataset`): the dataset to check.
+ """
+ model = runner.model
+ if dataset.CLASSES is None:
+ runner.logger.warning(
+ f'Please set class information in `metainfo` '
+                f'in the {dataset.__class__.__name__} and '
+                f'check if it is consistent with the `num_classes` '
+                f'of the head.')
+ else:
+ assert is_seq_of(dataset.CLASSES, str), \
+ (f'Class information in `metainfo` in '
+ f'{dataset.__class__.__name__} should be a tuple of str.')
+ for _, module in model.named_modules():
+ if hasattr(module, 'num_classes'):
+ assert module.num_classes == len(dataset.CLASSES), \
+ (f'The `num_classes` ({module.num_classes}) in '
+ f'{module.__class__.__name__} of '
+                     f'{model.__class__.__name__} does not match '
+                     f'the length of class information in `metainfo` '
+                     f'({len(dataset.CLASSES)}) in '
+ f'{dataset.__class__.__name__}')
+
+ def before_train(self, runner):
+ """Check whether the training dataset is compatible with head.
+
+ Args:
+ runner (obj: `IterBasedRunner`): Iter based Runner.
+ """
+ self._check_head(runner, runner.train_dataloader.dataset)
+
+ def before_val(self, runner):
+ """Check whether the validation dataset is compatible with head.
+
+ Args:
+ runner (obj:`IterBasedRunner`): Iter based Runner.
+ """
+ self._check_head(runner, runner.val_dataloader.dataset)
+
+ def before_test(self, runner):
+ """Check whether the test dataset is compatible with head.
+
+ Args:
+ runner (obj:`IterBasedRunner`): Iter based Runner.
+ """
+ self._check_head(runner, runner.test_dataloader.dataset)
diff --git a/mmpretrain/engine/hooks/densecl_hook.py b/mmpretrain/engine/hooks/densecl_hook.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c7e17d3419cbc2a540d3aecd81e223eed670df2
--- /dev/null
+++ b/mmpretrain/engine/hooks/densecl_hook.py
@@ -0,0 +1,42 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional, Sequence
+
+from mmengine.hooks import Hook
+
+from mmpretrain.registry import HOOKS
+from mmpretrain.utils import get_ori_model
+
+
+@HOOKS.register_module()
+class DenseCLHook(Hook):
+ """Hook for DenseCL.
+
+ This hook includes ``loss_lambda`` warmup in DenseCL.
+    Borrowed from the authors' original DenseCL implementation.
+
+ Args:
+ start_iters (int): The number of warmup iterations to set
+ ``loss_lambda=0``. Defaults to 1000.
+ """
+
+ def __init__(self, start_iters: int = 1000) -> None:
+ self.start_iters = start_iters
+
+ def before_train(self, runner) -> None:
+ """Obtain ``loss_lambda`` from algorithm."""
+ assert hasattr(get_ori_model(runner.model), 'loss_lambda'), \
+ "The runner must have attribute \"loss_lambda\" in DenseCL."
+ self.loss_lambda = get_ori_model(runner.model).loss_lambda
+
+ def before_train_iter(self,
+ runner,
+ batch_idx: int,
+ data_batch: Optional[Sequence[dict]] = None) -> None:
+ """Adjust ``loss_lambda`` every train iter."""
+ assert hasattr(get_ori_model(runner.model), 'loss_lambda'), \
+ "The runner must have attribute \"loss_lambda\" in DenseCL."
+ cur_iter = runner.iter
+ if cur_iter >= self.start_iters:
+ get_ori_model(runner.model).loss_lambda = self.loss_lambda
+ else:
+ get_ori_model(runner.model).loss_lambda = 0.
diff --git a/mmpretrain/engine/hooks/ema_hook.py b/mmpretrain/engine/hooks/ema_hook.py
new file mode 100644
index 0000000000000000000000000000000000000000..284d211b628c411f0eb712d1c558dc6aa2eb8996
--- /dev/null
+++ b/mmpretrain/engine/hooks/ema_hook.py
@@ -0,0 +1,216 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import itertools
+import warnings
+from typing import Dict, Optional
+
+from mmengine.hooks import EMAHook as BaseEMAHook
+from mmengine.logging import MMLogger
+from mmengine.runner import Runner
+
+from mmpretrain.registry import HOOKS
+
+
+@HOOKS.register_module()
+class EMAHook(BaseEMAHook):
+ """A Hook to apply Exponential Moving Average (EMA) on the model during
+ training.
+
+ Comparing with :class:`mmengine.hooks.EMAHook`, this hook accepts
+ ``evaluate_on_ema`` and ``evaluate_on_origin`` arguments. By default, the
+ ``evaluate_on_ema`` is enabled, and if you want to do validation and
+ testing on both original and EMA models, please set both arguments
+ ``True``.
+
+ Note:
+ - EMAHook takes priority over CheckpointHook.
+        - The original model parameters are actually saved in the ema field
+          after training.
+ - ``begin_iter`` and ``begin_epoch`` cannot be set at the same time.
+
+ Args:
+ ema_type (str): The type of EMA strategy to use. You can find the
+ supported strategies in :mod:`mmengine.model.averaged_model`.
+ Defaults to 'ExponentialMovingAverage'.
+ strict_load (bool): Whether to strictly enforce that the keys of
+ ``state_dict`` in checkpoint match the keys returned by
+ ``self.module.state_dict``. Defaults to False.
+ Changed in v0.3.0.
+        begin_iter (int): The iteration at which to enable ``EMAHook``.
+            Defaults to 0.
+        begin_epoch (int): The epoch at which to enable ``EMAHook``.
+            Defaults to 0.
+ evaluate_on_ema (bool): Whether to evaluate (validate and test)
+ on EMA model during val-loop and test-loop. Defaults to True.
+ evaluate_on_origin (bool): Whether to evaluate (validate and test)
+ on the original model during val-loop and test-loop.
+ Defaults to False.
+ **kwargs: Keyword arguments passed to subclasses of
+ :obj:`BaseAveragedModel`
+ """
+
+ priority = 'NORMAL'
+
+ def __init__(self,
+ ema_type: str = 'ExponentialMovingAverage',
+ strict_load: bool = False,
+ begin_iter: int = 0,
+ begin_epoch: int = 0,
+ evaluate_on_ema: bool = True,
+ evaluate_on_origin: bool = False,
+ **kwargs):
+ super().__init__(
+ ema_type=ema_type,
+ strict_load=strict_load,
+ begin_iter=begin_iter,
+ begin_epoch=begin_epoch,
+ **kwargs)
+
+ if not evaluate_on_ema and not evaluate_on_origin:
+ warnings.warn(
+ 'Automatically set `evaluate_on_origin=True` since the '
+ '`evaluate_on_ema` is disabled. If you want to disable '
+ 'all validation, please modify the `val_interval` of '
+ 'the `train_cfg`.', UserWarning)
+ evaluate_on_origin = True
+
+ self.evaluate_on_ema = evaluate_on_ema
+ self.evaluate_on_origin = evaluate_on_origin
+ self.load_ema_from_ckpt = False
+
+ def before_train(self, runner) -> None:
+ super().before_train(runner)
+ if not runner._resume and self.load_ema_from_ckpt:
+            # If the EMA state dict was loaded but training is not resumed,
+            # overwrite the EMA state dict with the source model.
+ MMLogger.get_current_instance().info(
+ 'Load from a checkpoint with EMA parameters but not '
+ 'resume training. Initialize the model parameters with '
+ 'EMA parameters')
+ for p_ema, p_src in zip(self._ema_params, self._src_params):
+ p_src.data.copy_(p_ema.data)
+
+ def before_val_epoch(self, runner) -> None:
+ """We load parameter values from ema model to source model before
+ validation.
+
+ Args:
+ runner (Runner): The runner of the training process.
+ """
+ if self.evaluate_on_ema:
+ # Swap when evaluate on ema
+ self._swap_ema_parameters()
+
+ def after_val_epoch(self,
+ runner,
+ metrics: Optional[Dict[str, float]] = None) -> None:
+ """We recover source model's parameter from ema model after validation.
+
+ Args:
+ runner (Runner): The runner of the validation process.
+ metrics (Dict[str, float], optional): Evaluation results of all
+ metrics on validation dataset. The keys are the names of the
+ metrics, and the values are corresponding results.
+ """
+ if self.evaluate_on_ema:
+ # Swap when evaluate on ema
+ self._swap_ema_parameters()
+
+ if self.evaluate_on_ema and self.evaluate_on_origin:
+ # Re-evaluate if evaluate on both ema and origin.
+ val_loop = runner.val_loop
+
+ runner.model.eval()
+ for idx, data_batch in enumerate(val_loop.dataloader):
+ val_loop.run_iter(idx, data_batch)
+
+ # compute metrics
+ origin_metrics = val_loop.evaluator.evaluate(
+ len(val_loop.dataloader.dataset))
+
+ for k, v in origin_metrics.items():
+ runner.message_hub.update_scalar(f'val/{k}_origin', v)
+
+ def before_test_epoch(self, runner) -> None:
+ """We load parameter values from ema model to source model before test.
+
+ Args:
+ runner (Runner): The runner of the training process.
+ """
+ if self.evaluate_on_ema:
+ # Swap when evaluate on ema
+ self._swap_ema_parameters()
+ MMLogger.get_current_instance().info('Start testing on EMA model.')
+ else:
+ MMLogger.get_current_instance().info(
+ 'Start testing on the original model.')
+
+ def after_test_epoch(self,
+ runner: Runner,
+ metrics: Optional[Dict[str, float]] = None) -> None:
+ """We recover source model's parameter from ema model after test.
+
+ Args:
+ runner (Runner): The runner of the testing process.
+ metrics (Dict[str, float], optional): Evaluation results of all
+ metrics on test dataset. The keys are the names of the
+ metrics, and the values are corresponding results.
+ """
+ if self.evaluate_on_ema:
+ # Swap when evaluate on ema
+ self._swap_ema_parameters()
+
+ if self.evaluate_on_ema and self.evaluate_on_origin:
+ # Re-evaluate if evaluate on both ema and origin.
+ MMLogger.get_current_instance().info(
+ 'Start testing on the original model.')
+ test_loop = runner.test_loop
+
+ runner.model.eval()
+ for idx, data_batch in enumerate(test_loop.dataloader):
+ test_loop.run_iter(idx, data_batch)
+
+ # compute metrics
+ origin_metrics = test_loop.evaluator.evaluate(
+ len(test_loop.dataloader.dataset))
+
+ for k, v in origin_metrics.items():
+ runner.message_hub.update_scalar(f'test/{k}_origin', v)
+
+ def after_load_checkpoint(self, runner, checkpoint: dict) -> None:
+ """Resume ema parameters from checkpoint.
+
+ Args:
+ runner (Runner): The runner of the testing process.
+ """
+ from mmengine.runner.checkpoint import load_state_dict
+ if 'ema_state_dict' in checkpoint:
+            # The original model parameters are actually saved in the ema
+            # field; swap the weights back to resume the EMA state.
+ self._swap_ema_state_dict(checkpoint)
+ self.ema_model.load_state_dict(
+ checkpoint['ema_state_dict'], strict=self.strict_load)
+ self.load_ema_from_ckpt = True
+
+ # Support load checkpoint without ema state dict.
+ else:
+ load_state_dict(
+ self.ema_model.module,
+ copy.deepcopy(checkpoint['state_dict']),
+ strict=self.strict_load)
+
+ @property
+ def _src_params(self):
+ if self.ema_model.update_buffers:
+ return itertools.chain(self.src_model.parameters(),
+ self.src_model.buffers())
+ else:
+ return self.src_model.parameters()
+
+ @property
+ def _ema_params(self):
+ if self.ema_model.update_buffers:
+ return itertools.chain(self.ema_model.module.parameters(),
+ self.ema_model.module.buffers())
+ else:
+ return self.ema_model.module.parameters()
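
A minimal sketch of how this hook could be enabled from a config file, assuming the
usual MMEngine ``custom_hooks`` mechanism; the momentum value is illustrative and is
forwarded through ``**kwargs`` to the averaged model.

.. code:: python

    custom_hooks = [
        dict(
            type='EMAHook',
            ema_type='ExponentialMovingAverage',
            momentum=0.0002,          # illustrative; passed via **kwargs
            evaluate_on_ema=True,     # validate/test on the EMA weights
            evaluate_on_origin=True,  # and re-run evaluation on the original weights
        )
    ]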
diff --git a/mmpretrain/engine/hooks/margin_head_hooks.py b/mmpretrain/engine/hooks/margin_head_hooks.py
new file mode 100644
index 0000000000000000000000000000000000000000..fbeae7a347453153ff4ab3bef958acb549623f6f
--- /dev/null
+++ b/mmpretrain/engine/hooks/margin_head_hooks.py
@@ -0,0 +1,61 @@
+# Copyright (c) OpenMMLab. All rights reserved
+import numpy as np
+from mmengine.hooks import Hook
+from mmengine.model import is_model_wrapper
+
+from mmpretrain.models.heads import ArcFaceClsHead
+from mmpretrain.registry import HOOKS
+
+
+@HOOKS.register_module()
+class SetAdaptiveMarginsHook(Hook):
+ r"""Set adaptive-margins in ArcFaceClsHead based on the power of
+ category-wise count.
+
+    A PyTorch implementation of the paper "Google Landmark Recognition 2020
+    Competition Third Place Solution". The margins are computed as
+    :math:`\text{f}(n) = (\text{margin}_{max} - \text{margin}_{min}) \cdot
+    \text{norm}(n^p) + \text{margin}_{min}`,
+    where :math:`n` is the number of occurrences of a category.
+
+ Args:
+ margin_min (float): Lower bound of margins. Defaults to 0.05.
+ margin_max (float): Upper bound of margins. Defaults to 0.5.
+        power (float): The power of the category frequency. Defaults to -0.25.
+ """
+
+ def __init__(self, margin_min=0.05, margin_max=0.5, power=-0.25) -> None:
+ self.margin_min = margin_min
+ self.margin_max = margin_max
+ self.margin_range = margin_max - margin_min
+ self.p = power
+
+ def before_train(self, runner):
+ """change the margins in ArcFaceClsHead.
+
+ Args:
+ runner (obj: `Runner`): Runner.
+ """
+ model = runner.model
+ if is_model_wrapper(model):
+ model = model.module
+
+ if (hasattr(model, 'head')
+ and not isinstance(model.head, ArcFaceClsHead)):
+ raise ValueError(
+                'Hook ``SetAdaptiveMarginsHook`` can only be used '
+                f'with ``ArcFaceClsHead``, but got {type(model.head)}')
+
+ # generate margins base on the dataset.
+ gt_labels = runner.train_dataloader.dataset.get_gt_labels()
+ label_count = np.bincount(gt_labels)
+ label_count[label_count == 0] = 1 # At least one occurrence
+ pow_freq = np.power(label_count, self.p)
+
+ min_f, max_f = pow_freq.min(), pow_freq.max()
+        normalized_pow_freq = (pow_freq - min_f) / (max_f - min_f)
+        margins = normalized_pow_freq * self.margin_range + self.margin_min
+
+        assert len(margins) == model.head.num_classes
+
+ model.head.set_margins(margins)
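
A standalone sketch of the margin rule above, with hypothetical class counts, showing
that frequent classes end up near ``margin_min`` and rare classes near ``margin_max``:

.. code:: python

    import numpy as np

    # Hypothetical per-class sample counts; power, margin_min and margin_max
    # mirror the hook defaults.
    label_count = np.array([1000, 100, 10])
    pow_freq = np.power(label_count, -0.25)
    norm = (pow_freq - pow_freq.min()) / (pow_freq.max() - pow_freq.min())
    margins = norm * (0.5 - 0.05) + 0.05
    print(margins)  # ~[0.05, 0.22, 0.5]: rarer classes get larger margins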
diff --git a/mmpretrain/engine/hooks/precise_bn_hook.py b/mmpretrain/engine/hooks/precise_bn_hook.py
new file mode 100644
index 0000000000000000000000000000000000000000..4fb0e4c419e4ed2af23574769815aaecbcd629c0
--- /dev/null
+++ b/mmpretrain/engine/hooks/precise_bn_hook.py
@@ -0,0 +1,223 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# Adapted from https://github.com/facebookresearch/pycls/blob/f8cd962737e33ce9e19b3083a33551da95c2d9c0/pycls/core/net.py # noqa: E501
+# Original licence: Copyright (c) 2019 Facebook, Inc under the Apache License 2.0 # noqa: E501
+
+import itertools
+import logging
+from typing import List, Optional, Sequence, Union
+
+import mmengine
+import torch
+import torch.nn as nn
+from mmengine.hooks import Hook
+from mmengine.logging import print_log
+from mmengine.model import is_model_wrapper
+from mmengine.runner import EpochBasedTrainLoop, IterBasedTrainLoop, Runner
+from mmengine.utils import ProgressBar
+from torch.functional import Tensor
+from torch.nn import GroupNorm
+from torch.nn.modules.batchnorm import _BatchNorm
+from torch.nn.modules.instancenorm import _InstanceNorm
+from torch.utils.data import DataLoader
+
+from mmpretrain.registry import HOOKS
+
+DATA_BATCH = Optional[Sequence[dict]]
+
+
+def scaled_all_reduce(tensors: List[Tensor], num_gpus: int) -> List[Tensor]:
+ """Performs the scaled all_reduce operation on the provided tensors.
+
+ The input tensors are modified in-place. Currently supports only the sum
+ reduction operator. The reduced values are scaled by the inverse size of
+ the process group.
+
+ Args:
+ tensors (List[torch.Tensor]): The tensors to process.
+ num_gpus (int): The number of gpus to use
+ Returns:
+ List[torch.Tensor]: The processed tensors.
+ """
+ # There is no need for reduction in the single-proc case
+ if num_gpus == 1:
+ return tensors
+ # Queue the reductions
+ reductions = []
+ for tensor in tensors:
+ reduction = torch.distributed.all_reduce(tensor, async_op=True)
+ reductions.append(reduction)
+ # Wait for reductions to finish
+ for reduction in reductions:
+ reduction.wait()
+ # Scale the results
+ for tensor in tensors:
+ tensor.mul_(1.0 / num_gpus)
+ return tensors
+
+
+@torch.no_grad()
+def update_bn_stats(
+ model: nn.Module,
+ loader: DataLoader,
+ num_samples: int = 8192,
+ logger: Optional[Union[logging.Logger, str]] = None) -> None:
+ """Computes precise BN stats on training data.
+
+ Args:
+ model (nn.module): The model whose bn stats will be recomputed.
+        loader (DataLoader): PyTorch dataloader.
+ num_samples (int): The number of samples to update the bn stats.
+ Defaults to 8192.
+ logger (logging.Logger or str, optional): If the type of logger is
+ ``logging.Logger``, we directly use logger to log messages.
+ Some special loggers are:
+ - "silent": No message will be printed.
+ - "current": Use latest created logger to log message.
+ - other str: Instance name of logger. The corresponding logger
+ will log message if it has been created, otherwise will raise a
+ `ValueError`.
+ - None: The `print()` method will be used to print log messages.
+ """
+ if is_model_wrapper(model):
+ model = model.module
+
+ # get dist info
+ rank, world_size = mmengine.dist.get_dist_info()
+ # Compute the number of mini-batches to use, if the size of dataloader is
+ # less than num_iters, use all the samples in dataloader.
+ num_iter = num_samples // (loader.batch_size * world_size)
+ num_iter = min(num_iter, len(loader))
+ # Retrieve the BN layers
+ bn_layers = [
+ m for m in model.modules()
+ if m.training and isinstance(m, (_BatchNorm))
+ ]
+ if len(bn_layers) == 0:
+ print_log('No BN found in model', logger=logger, level=logging.WARNING)
+ return
+ print_log(
+ f'{len(bn_layers)} BN found, run {num_iter} iters...', logger=logger)
+
+ # Finds all the other norm layers with training=True.
+ other_norm_layers = [
+ m for m in model.modules()
+ if m.training and isinstance(m, (_InstanceNorm, GroupNorm))
+ ]
+ if len(other_norm_layers) > 0:
+ print_log(
+ 'IN/GN stats will not be updated in PreciseHook.',
+ logger=logger,
+ level=logging.INFO)
+
+ # Initialize BN stats storage for computing
+ # mean(mean(batch)) and mean(var(batch))
+ running_means = [torch.zeros_like(bn.running_mean) for bn in bn_layers]
+ running_vars = [torch.zeros_like(bn.running_var) for bn in bn_layers]
+ # Remember momentum values
+ momentums = [bn.momentum for bn in bn_layers]
+ # Set momentum to 1.0 to compute BN stats that reflect the current batch
+ for bn in bn_layers:
+ bn.momentum = 1.0
+ # Average the BN stats for each BN layer over the batches
+ if rank == 0:
+ prog_bar = ProgressBar(num_iter)
+
+ for data in itertools.islice(loader, num_iter):
+ data = model.data_preprocessor(data, False)
+ model(**data)
+
+ for i, bn in enumerate(bn_layers):
+ running_means[i] += bn.running_mean / num_iter
+ running_vars[i] += bn.running_var / num_iter
+ if rank == 0:
+ prog_bar.update()
+
+ # Sync BN stats across GPUs (no reduction if 1 GPU used)
+ running_means = scaled_all_reduce(running_means, world_size)
+ running_vars = scaled_all_reduce(running_vars, world_size)
+ # Set BN stats and restore original momentum values
+ for i, bn in enumerate(bn_layers):
+ bn.running_mean = running_means[i]
+ bn.running_var = running_vars[i]
+ bn.momentum = momentums[i]
+
+
+@HOOKS.register_module()
+class PreciseBNHook(Hook):
+ """Precise BN hook.
+
+    Recompute and update the batch norm stats to make them more precise.
+    During training, both the BN stats and the weights change after every
+    iteration, so the running averages cannot precisely reflect the actual
+    stats of the current model.
+
+    With this hook, the BN stats are recomputed with fixed weights to make
+    the running averages more precise. Specifically, it computes the true
+    average of the per-batch mean/variance instead of the running average.
+    See Sec. 3 of the paper "Rethinking Batch in BatchNorm" for details.
+
+ This hook will update BN stats, so it should be executed before
+ ``CheckpointHook`` and ``EMAHook``, generally set its priority to
+ "ABOVE_NORMAL".
+
+ Args:
+ num_samples (int): The number of samples to update the bn stats.
+ Defaults to 8192.
+        interval (int): The interval at which to perform Precise BN. If the
+            train loop is `EpochBasedTrainLoop` or `by_epoch=True`, its unit
+            is 'epoch'; if the train loop is `IterBasedTrainLoop` or
+            `by_epoch=False`, its unit is 'iter'. Defaults to 1.
+ """
+
+ def __init__(self, num_samples: int = 8192, interval: int = 1) -> None:
+ assert interval > 0 and num_samples > 0, "'interval' and " \
+ "'num_samples' must be bigger than 0."
+
+ self.interval = interval
+ self.num_samples = num_samples
+
+ def _perform_precise_bn(self, runner: Runner) -> None:
+ """perform precise bn."""
+ print_log(
+ f'Running Precise BN for {self.num_samples} samples...',
+ logger=runner.logger)
+ update_bn_stats(
+ runner.model,
+ runner.train_loop.dataloader,
+ self.num_samples,
+ logger=runner.logger)
+ print_log('Finish Precise BN, BN stats updated.', logger=runner.logger)
+
+ def after_train_epoch(self, runner: Runner) -> None:
+ """Calculate prcise BN and broadcast BN stats across GPUs.
+
+ Args:
+ runner (obj:`Runner`): The runner of the training process.
+ """
+        # If using `EpochBasedTrainLoop`, perform Precise BN every
+        # `self.interval` epochs.
+ if isinstance(runner.train_loop,
+ EpochBasedTrainLoop) and self.every_n_epochs(
+ runner, self.interval):
+ self._perform_precise_bn(runner)
+
+ def after_train_iter(self,
+ runner,
+ batch_idx: int,
+ data_batch: DATA_BATCH = None,
+ outputs: Optional[dict] = None) -> None:
+ """Calculate prcise BN and broadcast BN stats across GPUs.
+
+ Args:
+ runner (obj:`Runner`): The runner of the training process.
+ batch_idx (int): The index of the current batch in the train loop.
+ data_batch (Sequence[dict], optional): Data from dataloader.
+ Defaults to None.
+ """
+        # If using `IterBasedTrainLoop`, perform Precise BN every
+        # `self.interval` iterations.
+ if isinstance(runner.train_loop,
+ IterBasedTrainLoop) and self.every_n_train_iters(
+ runner, self.interval):
+ self._perform_precise_bn(runner)
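
A hedged config sketch for enabling this hook; as the docstring notes, it should run
before ``CheckpointHook`` and ``EMAHook``, hence the ``ABOVE_NORMAL`` priority (values
are illustrative).

.. code:: python

    custom_hooks = [
        dict(
            type='PreciseBNHook',
            num_samples=8192,        # samples used to re-estimate the BN stats
            interval=1,              # per epoch (or per iter for IterBasedTrainLoop)
            priority='ABOVE_NORMAL',
        )
    ]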
diff --git a/mmpretrain/engine/hooks/retriever_hooks.py b/mmpretrain/engine/hooks/retriever_hooks.py
new file mode 100644
index 0000000000000000000000000000000000000000..6bd7c7aaff3175491b1ea1508e33b07b7c2ea8d4
--- /dev/null
+++ b/mmpretrain/engine/hooks/retriever_hooks.py
@@ -0,0 +1,32 @@
+# Copyright (c) OpenMMLab. All rights reserved
+import warnings
+
+from mmengine.hooks import Hook
+from mmengine.model import is_model_wrapper
+
+from mmpretrain.models import BaseRetriever
+from mmpretrain.registry import HOOKS
+
+
+@HOOKS.register_module()
+class PrepareProtoBeforeValLoopHook(Hook):
+ """The hook to prepare the prototype in retrievers.
+
+    Since the encoders of the retriever change during training, the prototype
+    changes accordingly, so the `prototype_vecs` needs to be regenerated
+    before the validation loop.
+ """
+
+ def before_val(self, runner) -> None:
+ model = runner.model
+ if is_model_wrapper(model):
+ model = model.module
+
+ if isinstance(model, BaseRetriever):
+ if hasattr(model, 'prepare_prototype'):
+ model.prepare_prototype()
+        else:
+            warnings.warn(
+                'Only the `mmpretrain.models.retrievers.BaseRetriever` '
+                'can execute `PrepareProtoBeforeValLoopHook`, but got '
+                f'`{type(model)}`')
diff --git a/mmpretrain/engine/hooks/simsiam_hook.py b/mmpretrain/engine/hooks/simsiam_hook.py
new file mode 100644
index 0000000000000000000000000000000000000000..fabc4faca02bb78b92c39de68fa8a18e56d544f5
--- /dev/null
+++ b/mmpretrain/engine/hooks/simsiam_hook.py
@@ -0,0 +1,48 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional, Sequence
+
+from mmengine.hooks import Hook
+
+from mmpretrain.registry import HOOKS
+
+
+@HOOKS.register_module()
+class SimSiamHook(Hook):
+ """Hook for SimSiam.
+
+    This hook is for SimSiam to fix the learning rate of the predictor.
+
+ Args:
+ fix_pred_lr (bool): whether to fix the lr of predictor or not.
+ lr (float): the value of fixed lr.
+ adjust_by_epoch (bool, optional): whether to set lr by epoch or iter.
+ Defaults to True.
+ """
+
+ def __init__(self,
+ fix_pred_lr: bool,
+ lr: float,
+ adjust_by_epoch: Optional[bool] = True) -> None:
+ self.fix_pred_lr = fix_pred_lr
+ self.lr = lr
+ self.adjust_by_epoch = adjust_by_epoch
+
+ def before_train_iter(self,
+ runner,
+ batch_idx: int,
+ data_batch: Optional[Sequence[dict]] = None) -> None:
+ """fix lr of predictor by iter."""
+ if self.adjust_by_epoch:
+ return
+ else:
+ if self.fix_pred_lr:
+ for param_group in runner.optim_wrapper.optimizer.param_groups:
+ if 'fix_lr' in param_group and param_group['fix_lr']:
+ param_group['lr'] = self.lr
+
+ def before_train_epoch(self, runner) -> None:
+ """fix lr of predictor by epoch."""
+ if self.fix_pred_lr:
+ for param_group in runner.optim_wrapper.optimizer.param_groups:
+ if 'fix_lr' in param_group and param_group['fix_lr']:
+ param_group['lr'] = self.lr
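
A standalone sketch of the ``fix_lr`` convention the hook relies on: parameter groups
tagged with ``fix_lr=True`` keep a constant learning rate while the other groups follow
the schedule. The model and values here are purely illustrative.

.. code:: python

    import torch
    from torch import nn

    backbone, predictor = nn.Linear(8, 8), nn.Linear(8, 8)
    optimizer = torch.optim.SGD([
        dict(params=backbone.parameters()),
        dict(params=predictor.parameters(), fix_lr=True),  # tag read by the hook
    ], lr=0.05)

    # What before_train_epoch effectively does when fix_pred_lr=True:
    for param_group in optimizer.param_groups:
        if 'fix_lr' in param_group and param_group['fix_lr']:
            param_group['lr'] = 0.05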
diff --git a/mmpretrain/engine/hooks/swav_hook.py b/mmpretrain/engine/hooks/swav_hook.py
new file mode 100644
index 0000000000000000000000000000000000000000..71c82ad1e0c47114cccdd90f26a3d6c086e36d18
--- /dev/null
+++ b/mmpretrain/engine/hooks/swav_hook.py
@@ -0,0 +1,116 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+from typing import Dict, List, Optional, Sequence
+
+import torch
+from mmengine.dist import get_rank, get_world_size, is_distributed
+from mmengine.hooks import Hook
+from mmengine.logging import MMLogger
+
+from mmpretrain.registry import HOOKS
+from mmpretrain.utils import get_ori_model
+
+
+@HOOKS.register_module()
+class SwAVHook(Hook):
+ """Hook for SwAV.
+
+ This hook builds the queue in SwAV according to ``epoch_queue_starts``.
+    The queue is saved in ``runner.work_dir``, and it will be loaded at the
+    start epoch if queues have been saved in that folder before.
+
+ Args:
+ batch_size (int): the batch size per GPU for computing.
+ epoch_queue_starts (int, optional): from this epoch, starts to use the
+ queue. Defaults to 15.
+ crops_for_assign (list[int], optional): list of crops id used for
+ computing assignments. Defaults to [0, 1].
+ feat_dim (int, optional): feature dimension of output vector.
+ Defaults to 128.
+ queue_length (int, optional): length of the queue (0 for no queue).
+ Defaults to 0.
+ interval (int, optional): the interval to save the queue.
+ Defaults to 1.
+ frozen_layers_cfg (dict, optional): Dict to config frozen layers.
+ The key-value pair is layer name and its frozen iters. If frozen,
+ the layers don't need gradient. Defaults to dict().
+ """
+
+ def __init__(
+ self,
+ batch_size: int,
+ epoch_queue_starts: Optional[int] = 15,
+ crops_for_assign: Optional[List[int]] = [0, 1],
+ feat_dim: Optional[int] = 128,
+ queue_length: Optional[int] = 0,
+ interval: Optional[int] = 1,
+ frozen_layers_cfg: Optional[Dict] = dict()
+ ) -> None:
+ self.batch_size = batch_size * get_world_size()
+ self.epoch_queue_starts = epoch_queue_starts
+ self.crops_for_assign = crops_for_assign
+ self.feat_dim = feat_dim
+ self.queue_length = queue_length
+ self.interval = interval
+ self.frozen_layers_cfg = frozen_layers_cfg
+ self.requires_grad = True
+ self.queue = None
+
+ def before_run(self, runner) -> None:
+ """Check whether the queues exist locally or not."""
+ if is_distributed():
+ self.queue_path = osp.join(runner.work_dir,
+ 'queue' + str(get_rank()) + '.pth')
+ else:
+ self.queue_path = osp.join(runner.work_dir, 'queue.pth')
+
+ # load the queues if queues exist locally
+ if osp.isfile(self.queue_path):
+ self.queue = torch.load(self.queue_path)['queue']
+ get_ori_model(runner.model).head.loss_module.queue = self.queue
+ MMLogger.get_current_instance().info(
+ f'Load queue from file: {self.queue_path}')
+
+ # the queue needs to be divisible by the batch size
+ self.queue_length -= self.queue_length % self.batch_size
+
+ def before_train_iter(self,
+ runner,
+ batch_idx: int,
+ data_batch: Optional[Sequence[dict]] = None) -> None:
+ """Freeze layers before specific iters according to the config."""
+ for layer, frozen_iters in self.frozen_layers_cfg.items():
+ if runner.iter < frozen_iters and self.requires_grad:
+ self.requires_grad = False
+ for name, p in get_ori_model(runner.model).named_parameters():
+ if layer in name:
+ p.requires_grad = False
+ elif runner.iter >= frozen_iters and not self.requires_grad:
+ self.requires_grad = True
+ for name, p in get_ori_model(runner.model).named_parameters():
+ if layer in name:
+ p.requires_grad = True
+
+ def before_train_epoch(self, runner) -> None:
+ """Check the queues' state."""
+ # optionally starts a queue
+ if self.queue_length > 0 \
+ and runner.epoch >= self.epoch_queue_starts \
+ and self.queue is None:
+ self.queue = torch.zeros(
+ len(self.crops_for_assign),
+ self.queue_length // runner.world_size,
+ self.feat_dim,
+ ).cuda()
+
+        # reset the flag that controls whether the queue is used
+ get_ori_model(runner.model).head.loss_module.queue = self.queue
+ get_ori_model(runner.model).head.loss_module.use_queue = False
+
+ def after_train_epoch(self, runner) -> None:
+ """Save the queues locally."""
+ self.queue = get_ori_model(runner.model).head.loss_module.queue
+
+ if self.queue is not None and self.every_n_epochs(
+ runner, self.interval):
+ torch.save({'queue': self.queue}, self.queue_path)
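
A hedged config sketch for this hook: ``queue_length`` is truncated to a multiple of the
global batch size in ``before_run``, and ``frozen_layers_cfg`` freezes the named layers
for the given number of iterations. The layer name and all values are illustrative.

.. code:: python

    custom_hooks = [
        dict(
            type='SwAVHook',
            batch_size=32,                 # per-GPU batch size
            epoch_queue_starts=15,
            queue_length=3840,
            frozen_layers_cfg=dict(prototypes=5005),  # hypothetical layer / iters
        )
    ]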
diff --git a/mmpretrain/engine/hooks/switch_recipe_hook.py b/mmpretrain/engine/hooks/switch_recipe_hook.py
new file mode 100644
index 0000000000000000000000000000000000000000..914b9572eb22d2cd2f54c519273c86baf2e0894d
--- /dev/null
+++ b/mmpretrain/engine/hooks/switch_recipe_hook.py
@@ -0,0 +1,169 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from collections import OrderedDict
+from copy import deepcopy
+
+from mmcv.transforms import Compose
+from mmengine.hooks import Hook
+from mmengine.model import is_model_wrapper
+
+from mmpretrain.models.utils import RandomBatchAugment
+from mmpretrain.registry import HOOKS, MODEL_WRAPPERS, MODELS
+
+
+@HOOKS.register_module()
+class SwitchRecipeHook(Hook):
+ """switch recipe during the training loop, including train pipeline, batch
+ augments and loss currently.
+
+ Args:
+ schedule (list): Every item of the schedule list should be a dict, and
+ the dict should have ``action_epoch`` and some of
+ ``train_pipeline``, ``train_augments`` and ``loss`` keys:
+
+ - ``action_epoch`` (int): switch training recipe at which epoch.
+ - ``train_pipeline`` (list, optional): The new data pipeline of the
+ train dataset. If not specified, keep the original settings.
+ - ``batch_augments`` (dict | None, optional): The new batch
+            augmentations used during training. See the batch augmentations
+            module for more details.
+ If None, disable batch augmentations. If not specified, keep the
+ original settings.
+ - ``loss`` (dict, optional): The new loss module config. If not
+ specified, keep the original settings.
+
+ Example:
+ To use this hook in config files.
+
+ .. code:: python
+
+ custom_hooks = [
+ dict(
+ type='SwitchRecipeHook',
+ schedule=[
+ dict(
+ action_epoch=30,
+ train_pipeline=pipeline_after_30e,
+ batch_augments=batch_augments_after_30e,
+ loss=loss_after_30e,
+ ),
+ dict(
+ action_epoch=60,
+ # Disable batch augmentations after 60e
+ # and keep other settings.
+ batch_augments=None,
+ ),
+ ]
+ )
+ ]
+ """
+ priority = 'NORMAL'
+
+ def __init__(self, schedule):
+ recipes = {}
+ for recipe in schedule:
+ assert 'action_epoch' in recipe, \
+ 'Please set `action_epoch` in every item ' \
+ 'of the `schedule` in the SwitchRecipeHook.'
+ recipe = deepcopy(recipe)
+ if 'train_pipeline' in recipe:
+ recipe['train_pipeline'] = Compose(recipe['train_pipeline'])
+ if 'batch_augments' in recipe:
+ batch_augments = recipe['batch_augments']
+ if isinstance(batch_augments, dict):
+ batch_augments = RandomBatchAugment(**batch_augments)
+ recipe['batch_augments'] = batch_augments
+ if 'loss' in recipe:
+ loss = recipe['loss']
+ if isinstance(loss, dict):
+ loss = MODELS.build(loss)
+ recipe['loss'] = loss
+
+ action_epoch = recipe.pop('action_epoch')
+ assert action_epoch not in recipes, \
+ f'The `action_epoch` {action_epoch} is repeated ' \
+ 'in the SwitchRecipeHook.'
+ recipes[action_epoch] = recipe
+ self.schedule = OrderedDict(sorted(recipes.items()))
+
+ def before_train(self, runner) -> None:
+ """before run setting. If resume form a checkpoint, do all switch
+ before the current epoch.
+
+ Args:
+ runner (Runner): The runner of the training, validation or testing
+ process.
+ """
+ if runner._resume:
+ for action_epoch, recipe in self.schedule.items():
+ if action_epoch >= runner.epoch + 1:
+ break
+ self._do_switch(runner, recipe,
+ f' (resume recipe of epoch {action_epoch})')
+
+ def before_train_epoch(self, runner):
+ """do before train epoch."""
+ recipe = self.schedule.get(runner.epoch + 1, None)
+ if recipe is not None:
+ self._do_switch(runner, recipe, f' at epoch {runner.epoch + 1}')
+
+ def _do_switch(self, runner, recipe, extra_info=''):
+ """do the switch aug process."""
+ if 'batch_augments' in recipe:
+ self._switch_batch_augments(runner, recipe['batch_augments'])
+ runner.logger.info(f'Switch batch augments{extra_info}.')
+
+ if 'train_pipeline' in recipe:
+ self._switch_train_pipeline(runner, recipe['train_pipeline'])
+ runner.logger.info(f'Switch train pipeline{extra_info}.')
+
+ if 'loss' in recipe:
+ self._switch_loss(runner, recipe['loss'])
+ runner.logger.info(f'Switch loss{extra_info}.')
+
+ @staticmethod
+ def _switch_batch_augments(runner, batch_augments):
+ """switch the train augments."""
+ model = runner.model
+ if is_model_wrapper(model):
+ model = model.module
+
+ model.data_preprocessor.batch_augments = batch_augments
+
+ @staticmethod
+ def _switch_train_pipeline(runner, train_pipeline):
+ """switch the train loader dataset pipeline."""
+
+ def switch_pipeline(dataset, pipeline):
+ if hasattr(dataset, 'pipeline'):
+ # for usual dataset
+ dataset.pipeline = pipeline
+ elif hasattr(dataset, 'datasets'):
+ # for concat dataset wrapper
+ for ds in dataset.datasets:
+ switch_pipeline(ds, pipeline)
+ elif hasattr(dataset, 'dataset'):
+ # for other dataset wrappers
+ switch_pipeline(dataset.dataset, pipeline)
+ else:
+ raise RuntimeError(
+ 'Cannot access the `pipeline` of the dataset.')
+
+ train_loader = runner.train_loop.dataloader
+ switch_pipeline(train_loader.dataset, train_pipeline)
+
+ # To restart the iterator of dataloader when `persistent_workers=True`
+ train_loader._iterator = None
+
+ @staticmethod
+ def _switch_loss(runner, loss_module):
+ """switch the loss module."""
+ model = runner.model
+ if is_model_wrapper(model, MODEL_WRAPPERS):
+ model = model.module
+
+ if hasattr(model, 'loss_module'):
+ model.loss_module = loss_module
+ elif hasattr(model, 'head') and hasattr(model.head, 'loss_module'):
+ model.head.loss_module = loss_module
+ else:
+ raise RuntimeError('Cannot access the `loss_module` of the model.')
diff --git a/mmpretrain/engine/hooks/visualization_hook.py b/mmpretrain/engine/hooks/visualization_hook.py
new file mode 100644
index 0000000000000000000000000000000000000000..64d2230a79db971bef78d77bcf80c40365bddb15
--- /dev/null
+++ b/mmpretrain/engine/hooks/visualization_hook.py
@@ -0,0 +1,126 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+import os.path as osp
+from typing import Optional, Sequence
+
+from mmengine.fileio import join_path
+from mmengine.hooks import Hook
+from mmengine.runner import EpochBasedTrainLoop, Runner
+from mmengine.visualization import Visualizer
+
+from mmpretrain.registry import HOOKS
+from mmpretrain.structures import DataSample
+
+
+@HOOKS.register_module()
+class VisualizationHook(Hook):
+ """Classification Visualization Hook. Used to visualize validation and
+ testing prediction results.
+
+ - If ``out_dir`` is specified, all storage backends are ignored
+ and save the image to the ``out_dir``.
+ - If ``show`` is True, plot the result image in a window, please
+ confirm you are able to access the graphical interface.
+
+ Args:
+ enable (bool): Whether to enable this hook. Defaults to False.
+ interval (int): The interval of samples to visualize. Defaults to 5000.
+ show (bool): Whether to display the drawn image. Defaults to False.
+ out_dir (str, optional): directory where painted images will be saved
+ in the testing process. If None, handle with the backends of the
+ visualizer. Defaults to None.
+ **kwargs: other keyword arguments of
+ :meth:`mmpretrain.visualization.UniversalVisualizer.visualize_cls`.
+ """
+
+ def __init__(self,
+ enable=False,
+ interval: int = 5000,
+ show: bool = False,
+ out_dir: Optional[str] = None,
+ **kwargs):
+ self._visualizer: Visualizer = Visualizer.get_current_instance()
+
+ self.enable = enable
+ self.interval = interval
+ self.show = show
+ self.out_dir = out_dir
+
+ self.draw_args = {**kwargs, 'show': show}
+
+ def _draw_samples(self,
+ batch_idx: int,
+ data_batch: dict,
+ data_samples: Sequence[DataSample],
+ step: int = 0) -> None:
+ """Visualize every ``self.interval`` samples from a data batch.
+
+ Args:
+ batch_idx (int): The index of the current batch in the val loop.
+ data_batch (dict): Data from dataloader.
+            data_samples (Sequence[:obj:`DataSample`]): Outputs from model.
+ step (int): Global step value to record. Defaults to 0.
+ """
+ if self.enable is False:
+ return
+
+ batch_size = len(data_samples)
+ images = data_batch['inputs']
+ start_idx = batch_size * batch_idx
+ end_idx = start_idx + batch_size
+
+ # The first index divisible by the interval, after the start index
+ first_sample_id = math.ceil(start_idx / self.interval) * self.interval
+
+ for sample_id in range(first_sample_id, end_idx, self.interval):
+ image = images[sample_id - start_idx]
+ image = image.permute(1, 2, 0).cpu().numpy().astype('uint8')
+
+ data_sample = data_samples[sample_id - start_idx]
+ if 'img_path' in data_sample:
+                # osp.basename works across platforms, even with file clients.
+ sample_name = osp.basename(data_sample.get('img_path'))
+ else:
+ sample_name = str(sample_id)
+
+            draw_args = self.draw_args.copy()
+ if self.out_dir is not None:
+ draw_args['out_file'] = join_path(self.out_dir,
+ f'{sample_name}_{step}.png')
+
+ self._visualizer.visualize_cls(
+ image=image,
+ data_sample=data_sample,
+ step=step,
+ name=sample_name,
+                **draw_args,
+ )
+
+ def after_val_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
+ outputs: Sequence[DataSample]) -> None:
+ """Visualize every ``self.interval`` samples during validation.
+
+ Args:
+ runner (:obj:`Runner`): The runner of the validation process.
+ batch_idx (int): The index of the current batch in the val loop.
+ data_batch (dict): Data from dataloader.
+ outputs (Sequence[:obj:`DataSample`]): Outputs from model.
+ """
+ if isinstance(runner.train_loop, EpochBasedTrainLoop):
+ step = runner.epoch
+ else:
+ step = runner.iter
+
+ self._draw_samples(batch_idx, data_batch, outputs, step=step)
+
+ def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
+ outputs: Sequence[DataSample]) -> None:
+ """Visualize every ``self.interval`` samples during test.
+
+ Args:
+ runner (:obj:`Runner`): The runner of the testing process.
+ batch_idx (int): The index of the current batch in the test loop.
+ data_batch (dict): Data from dataloader.
+            outputs (Sequence[:obj:`DataSample`]): Outputs from model.
+ """
+ self._draw_samples(batch_idx, data_batch, outputs, step=0)
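
A small self-contained check of the sampling arithmetic in ``_draw_samples``: with
``interval=5000`` and a batch size of 64, only the batch that contains a multiple of
5000 produces a drawn sample.

.. code:: python

    import math

    interval, batch_size = 5000, 64
    for batch_idx in (77, 78, 79):
        start_idx = batch_size * batch_idx
        end_idx = start_idx + batch_size
        first_sample_id = math.ceil(start_idx / interval) * interval
        picked = list(range(first_sample_id, end_idx, interval))
        print(batch_idx, picked)  # 77 -> [], 78 -> [5000], 79 -> []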
diff --git a/mmpretrain/engine/hooks/warmup_param_hook.py b/mmpretrain/engine/hooks/warmup_param_hook.py
new file mode 100644
index 0000000000000000000000000000000000000000..b45d8918dbbcb9cf5d12c252621908f0b6c1f251
--- /dev/null
+++ b/mmpretrain/engine/hooks/warmup_param_hook.py
@@ -0,0 +1,66 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import operator as op
+from typing import Any, Optional, Union
+
+from mmengine.hooks import Hook
+
+from mmpretrain.registry import HOOKS
+from mmpretrain.utils import get_ori_model
+
+
+@HOOKS.register_module()
+class WarmupParamHook(Hook):
+ """This is a hook used for changing the parameters other than optimizations
+ that need to warmup inside the module.
+
+ This hook can extend with more detailed warmup rule if necessary.
+
+ Args:
+ param_name (str): The parameter name that needs to be altered.
+ module_name (str): Module name that belongs to the model. Such as
+ `head`, `head.loss`, etc.
+ warmup_epochs (int): The warmup epochs for this parameter.
+ """
+
+ def __init__(
+ self,
+ param_name: str,
+ module_name: str,
+ warmup_epochs: int,
+ ) -> None:
+ self.param_name = param_name
+ self.warmup_epochs = warmup_epochs
+ # getter for module which saves the changed parameter
+ self.module_getter = op.attrgetter(module_name)
+
+ def get_param(self, runner) -> Any:
+ """Get the parameter."""
+ try:
+ module = self.module_getter(get_ori_model(runner.model))
+ return getattr(module, self.param_name)
+ except AttributeError as e:
+ raise AttributeError(f'{e}. Please check hook settings.')
+
+ def set_param(self, runner, value) -> None:
+ """Set the parameter."""
+ try:
+ module = self.module_getter(get_ori_model(runner.model))
+ setattr(module, self.param_name, value)
+ except AttributeError as e:
+ raise AttributeError(f'{e}. Please check hook settings.')
+
+ def before_train(self, runner) -> None:
+ """Get the original value before train."""
+ self.ori_val = self.get_param(runner)
+
+ def before_train_iter(
+ self,
+ runner,
+ batch_idx: int,
+ data_batch: Optional[Union[dict, tuple, list]] = None) -> None:
+ """Set the warmup value before each train iter."""
+ cur_iter = runner.iter
+ iters_per_epoch = runner.max_iters / runner.max_epochs
+ new_val = self.ori_val * min(
+ 1, cur_iter / (self.warmup_epochs * iters_per_epoch))
+ self.set_param(runner, new_val)
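
A standalone sketch of the warmup rule in ``before_train_iter``: the parameter ramps
linearly from 0 to its original value over ``warmup_epochs`` epochs and is then capped.
All numbers are illustrative.

.. code:: python

    ori_val = 0.2                                   # e.g. a loss weight on the head
    warmup_epochs, max_epochs, max_iters = 5, 100, 100 * 500
    iters_per_epoch = max_iters / max_epochs
    for cur_iter in (0, 1250, 2500, 5000):
        new_val = ori_val * min(1, cur_iter / (warmup_epochs * iters_per_epoch))
        print(cur_iter, new_val)                    # 0.0, 0.1, 0.2, 0.2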
diff --git a/mmpretrain/engine/optimizers/__init__.py b/mmpretrain/engine/optimizers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd53a37630b2a0dfbb69b1020518b9ec4ff03715
--- /dev/null
+++ b/mmpretrain/engine/optimizers/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .adan_t import Adan
+from .lamb import Lamb
+from .lars import LARS
+from .layer_decay_optim_wrapper_constructor import \
+ LearningRateDecayOptimWrapperConstructor
+
+__all__ = ['Lamb', 'Adan', 'LARS', 'LearningRateDecayOptimWrapperConstructor']
diff --git a/mmpretrain/engine/optimizers/adan_t.py b/mmpretrain/engine/optimizers/adan_t.py
new file mode 100644
index 0000000000000000000000000000000000000000..571a71b6fe561fb33053af2fd6d2161a775918e4
--- /dev/null
+++ b/mmpretrain/engine/optimizers/adan_t.py
@@ -0,0 +1,312 @@
+# Copyright 2022 Garena Online Private Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import List
+
+import torch
+from torch import Tensor
+from torch.optim.optimizer import Optimizer
+
+from mmpretrain.registry import OPTIMIZERS
+
+
+@OPTIMIZERS.register_module()
+class Adan(Optimizer):
+ """Implements a pytorch variant of Adan.
+
+ Adan was proposed in
+ Adan : Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models. # noqa
+ https://arxiv.org/abs/2208.06677
+ Arguments:
+ params (iterable): iterable of parameters to optimize
+ or dicts defining parameter groups.
+ lr (float, optional): learning rate. (default: 1e-3)
+        betas (Tuple[float, float, float], optional): coefficients used
+ for computing running averages of gradient.
+ (default: (0.98, 0.92, 0.99))
+ eps (float, optional): term added to the denominator to improve
+ numerical stability. (default: 1e-8)
+ weight_decay (float, optional): decoupled weight decay
+ (L2 penalty) (default: 0)
+ max_grad_norm (float, optional): value used to clip
+ global grad norm (default: 0.0 no clip)
+        no_prox (bool): whether to perform the weight decay update in the
+            decoupled AdamW style instead of the proximal style.
+            (default: False)
+ foreach (bool): if True would use torch._foreach implementation.
+ It's faster but uses slightly more memory.
+ """
+
+ def __init__(self,
+ params,
+ lr=1e-3,
+ betas=(0.98, 0.92, 0.99),
+ eps=1e-8,
+ weight_decay=0.0,
+ max_grad_norm=0.0,
+ no_prox=False,
+ foreach: bool = True):
+ if not 0.0 <= max_grad_norm:
+ raise ValueError('Invalid Max grad norm: {}'.format(max_grad_norm))
+ if not 0.0 <= lr:
+ raise ValueError('Invalid learning rate: {}'.format(lr))
+ if not 0.0 <= eps:
+ raise ValueError('Invalid epsilon value: {}'.format(eps))
+ if not 0.0 <= betas[0] < 1.0:
+ raise ValueError('Invalid beta parameter at index 0: {}'.format(
+ betas[0]))
+ if not 0.0 <= betas[1] < 1.0:
+ raise ValueError('Invalid beta parameter at index 1: {}'.format(
+ betas[1]))
+ if not 0.0 <= betas[2] < 1.0:
+ raise ValueError('Invalid beta parameter at index 2: {}'.format(
+ betas[2]))
+ defaults = dict(
+ lr=lr,
+ betas=betas,
+ eps=eps,
+ weight_decay=weight_decay,
+ max_grad_norm=max_grad_norm,
+ no_prox=no_prox,
+ foreach=foreach)
+ super().__init__(params, defaults)
+
+ def __setstate__(self, state):
+ super(Adan, self).__setstate__(state)
+ for group in self.param_groups:
+ group.setdefault('no_prox', False)
+
+ @torch.no_grad()
+ def restart_opt(self):
+ for group in self.param_groups:
+ group['step'] = 0
+ for p in group['params']:
+ if p.requires_grad:
+ state = self.state[p]
+ # State initialization
+
+ # Exponential moving average of gradient values
+ state['exp_avg'] = torch.zeros_like(p)
+ # Exponential moving average of squared gradient values
+ state['exp_avg_sq'] = torch.zeros_like(p)
+ # Exponential moving average of gradient difference
+ state['exp_avg_diff'] = torch.zeros_like(p)
+
+ @torch.no_grad()
+ def step(self):
+ """Performs a single optimization step."""
+ if self.defaults['max_grad_norm'] > 0:
+ device = self.param_groups[0]['params'][0].device
+ global_grad_norm = torch.zeros(1, device=device)
+
+ max_grad_norm = torch.tensor(
+ self.defaults['max_grad_norm'], device=device)
+ for group in self.param_groups:
+
+ for p in group['params']:
+ if p.grad is not None:
+ grad = p.grad
+ global_grad_norm.add_(grad.pow(2).sum())
+
+ global_grad_norm = torch.sqrt(global_grad_norm) + group['eps']
+
+ clip_global_grad_norm = \
+ torch.clamp(max_grad_norm / global_grad_norm, max=1.0)
+ else:
+ clip_global_grad_norm = 1.0
+
+ for group in self.param_groups:
+ params_with_grad = []
+ grads = []
+ exp_avgs = []
+ exp_avg_sqs = []
+ exp_avg_diffs = []
+ pre_grads = []
+
+ beta1, beta2, beta3 = group['betas']
+            # Assume the same step across the group for now to simplify
+            # things; a per-parameter step can easily be supported by making
+            # it a tensor, or by passing a list into the kernel.
+ if 'step' in group:
+ group['step'] += 1
+ else:
+ group['step'] = 1
+
+ bias_correction1 = 1.0 - beta1**group['step']
+ bias_correction2 = 1.0 - beta2**group['step']
+ bias_correction3 = 1.0 - beta3**group['step']
+
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ params_with_grad.append(p)
+ grads.append(p.grad)
+
+ state = self.state[p]
+ if len(state) == 0:
+ state['exp_avg'] = torch.zeros_like(p)
+ state['exp_avg_sq'] = torch.zeros_like(p)
+ state['exp_avg_diff'] = torch.zeros_like(p)
+
+ if 'pre_grad' not in state or group['step'] == 1:
+                    # At the first step, the grad isn't clipped by
+                    # `clip_global_grad_norm`; this is only to simplify the
+                    # implementation.
+ state['pre_grad'] = p.grad
+
+ exp_avgs.append(state['exp_avg'])
+ exp_avg_sqs.append(state['exp_avg_sq'])
+ exp_avg_diffs.append(state['exp_avg_diff'])
+ pre_grads.append(state['pre_grad'])
+
+ kwargs = dict(
+ params=params_with_grad,
+ grads=grads,
+ exp_avgs=exp_avgs,
+ exp_avg_sqs=exp_avg_sqs,
+ exp_avg_diffs=exp_avg_diffs,
+ pre_grads=pre_grads,
+ beta1=beta1,
+ beta2=beta2,
+ beta3=beta3,
+ bias_correction1=bias_correction1,
+ bias_correction2=bias_correction2,
+ bias_correction3_sqrt=math.sqrt(bias_correction3),
+ lr=group['lr'],
+ weight_decay=group['weight_decay'],
+ eps=group['eps'],
+ no_prox=group['no_prox'],
+ clip_global_grad_norm=clip_global_grad_norm,
+ )
+ if group['foreach']:
+ copy_grads = _multi_tensor_adan(**kwargs)
+ else:
+ copy_grads = _single_tensor_adan(**kwargs)
+
+ for p, copy_grad in zip(params_with_grad, copy_grads):
+ self.state[p]['pre_grad'] = copy_grad
+
+
+def _single_tensor_adan(
+ params: List[Tensor],
+ grads: List[Tensor],
+ exp_avgs: List[Tensor],
+ exp_avg_sqs: List[Tensor],
+ exp_avg_diffs: List[Tensor],
+ pre_grads: List[Tensor],
+ *,
+ beta1: float,
+ beta2: float,
+ beta3: float,
+ bias_correction1: float,
+ bias_correction2: float,
+ bias_correction3_sqrt: float,
+ lr: float,
+ weight_decay: float,
+ eps: float,
+ no_prox: bool,
+ clip_global_grad_norm: Tensor,
+):
+ copy_grads = []
+ for i, param in enumerate(params):
+ grad = grads[i]
+ exp_avg = exp_avgs[i]
+ exp_avg_sq = exp_avg_sqs[i]
+ exp_avg_diff = exp_avg_diffs[i]
+ pre_grad = pre_grads[i]
+
+ grad = grad.mul_(clip_global_grad_norm)
+ copy_grads.append(grad.clone())
+
+ diff = grad - pre_grad
+ update = grad + beta2 * diff
+
+ exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) # m_t
+ exp_avg_diff.mul_(beta2).add_(diff, alpha=1 - beta2) # diff_t
+ exp_avg_sq.mul_(beta3).addcmul_(update, update, value=1 - beta3) # n_t
+
+ denom = (exp_avg_sq.sqrt() / bias_correction3_sqrt).add_(eps)
+ update = exp_avg / bias_correction1
+ update.add_(beta2 * exp_avg_diff / bias_correction2).div_(denom)
+
+ if no_prox:
+ param.mul_(1 - lr * weight_decay)
+ param.add_(update, alpha=-lr)
+ else:
+ param.add_(update, alpha=-lr)
+ param.div_(1 + lr * weight_decay)
+ return copy_grads
+
+
+def _multi_tensor_adan(
+ params: List[Tensor],
+ grads: List[Tensor],
+ exp_avgs: List[Tensor],
+ exp_avg_sqs: List[Tensor],
+ exp_avg_diffs: List[Tensor],
+ pre_grads: List[Tensor],
+ *,
+ beta1: float,
+ beta2: float,
+ beta3: float,
+ bias_correction1: float,
+ bias_correction2: float,
+ bias_correction3_sqrt: float,
+ lr: float,
+ weight_decay: float,
+ eps: float,
+ no_prox: bool,
+ clip_global_grad_norm: Tensor,
+):
+ if clip_global_grad_norm < 1.0:
+ torch._foreach_mul_(grads, clip_global_grad_norm.item())
+ copy_grads = [g.clone() for g in grads]
+
+ diff = torch._foreach_sub(grads, pre_grads)
+ # NOTE: line below while looking identical gives different result,
+ # due to float precision errors.
+ # using mul+add produces identical results to single-tensor,
+ # using add+alpha doesn't
+ # update = torch._foreach_add(grads, torch._foreach_mul(diff, beta2))
+ update = torch._foreach_add(grads, diff, alpha=beta2)
+
+ torch._foreach_mul_(exp_avgs, beta1)
+ torch._foreach_add_(exp_avgs, grads, alpha=1 - beta1) # m_t
+
+ torch._foreach_mul_(exp_avg_diffs, beta2)
+ torch._foreach_add_(exp_avg_diffs, diff, alpha=1 - beta2) # diff_t
+
+ torch._foreach_mul_(exp_avg_sqs, beta3)
+ torch._foreach_addcmul_(
+ exp_avg_sqs, update, update, value=1 - beta3) # n_t
+
+ denom = torch._foreach_sqrt(exp_avg_sqs)
+ torch._foreach_div_(denom, bias_correction3_sqrt)
+ torch._foreach_add_(denom, eps)
+
+ update = torch._foreach_div(exp_avgs, bias_correction1)
+ # NOTE: same issue as above.
+ # beta2 * diff / bias_correction2 != diff * (beta2 / bias_correction2) # noqa
+ # using faster version by default. uncomment for tests to pass
+ # torch._foreach_add_(update, torch._foreach_div(torch._foreach_mul(exp_avg_diffs, beta2), bias_correction2)) # noqa
+ torch._foreach_add_(
+ update, torch._foreach_mul(exp_avg_diffs, beta2 / bias_correction2))
+ torch._foreach_div_(update, denom)
+
+    if no_prox:
+        torch._foreach_mul_(params, 1 - lr * weight_decay)
+        torch._foreach_add_(params, update, alpha=-lr)
+    else:
+        torch._foreach_add_(params, update, alpha=-lr)
+        torch._foreach_div_(params, 1 + lr * weight_decay)
+ return copy_grads
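
A hedged usage sketch of this optimizer on a toy model; the hyper-parameters are
illustrative, and note that ``step()`` here takes no closure.

.. code:: python

    import torch
    from torch import nn

    from mmpretrain.engine.optimizers import Adan

    model = nn.Linear(16, 4)
    optim = Adan(model.parameters(), lr=1e-3, weight_decay=0.02, max_grad_norm=5.0)

    x, y = torch.randn(8, 16), torch.randn(8, 4)
    for _ in range(3):
        optim.zero_grad()
        nn.functional.mse_loss(model(x), y).backward()
        optim.step()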
diff --git a/mmpretrain/engine/optimizers/lamb.py b/mmpretrain/engine/optimizers/lamb.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b44a1c168e03fa7f569388beec206fe68c64749
--- /dev/null
+++ b/mmpretrain/engine/optimizers/lamb.py
@@ -0,0 +1,228 @@
+"""PyTorch Lamb optimizer w/ behaviour similar to NVIDIA FusedLamb.
+
+This optimizer code was adapted from the following (starting with latest)
+* https://github.com/HabanaAI/Model-References/blob/
+2b435114fe8e31f159b1d3063b8280ae37af7423/PyTorch/nlp/bert/pretraining/lamb.py
+* https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/
+LanguageModeling/Transformer-XL/pytorch/lamb.py
+* https://github.com/cybertronai/pytorch-lamb
+
+Use FusedLamb if you can (GPU). The reason for including this variant of Lamb
+is to have a version that is
+similar in behaviour to APEX FusedLamb if you aren't using NVIDIA GPUs or
+cannot install/use APEX.
+
+In addition to some cleanup, this Lamb impl has been modified to support
+PyTorch XLA and has been tested on TPU.
+
+Original copyrights for above sources are below.
+
+Modifications Copyright 2021 Ross Wightman
+"""
+# Copyright (c) 2021, Habana Labs Ltd. All rights reserved.
+
+# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# MIT License
+#
+# Copyright (c) 2019 cybertronai
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+import math
+
+import torch
+from torch.optim import Optimizer
+
+from mmpretrain.registry import OPTIMIZERS
+
+
+@OPTIMIZERS.register_module()
+class Lamb(Optimizer):
+ """A pure pytorch variant of FuseLAMB (NvLamb variant) optimizer.
+
+ This class is copied from `timm`_. The LAMB was proposed in `Large Batch
+ Optimization for Deep Learning - Training BERT in 76 minutes`_.
+
+ .. _timm:
+ https://github.com/rwightman/pytorch-image-models/blob/master/timm/optim/lamb.py
+ .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:
+ https://arxiv.org/abs/1904.00962
+
+ Arguments:
+ params (iterable): iterable of parameters to optimize or dicts defining
+ parameter groups.
+ lr (float, optional): learning rate. (default: 1e-3)
+ betas (Tuple[float, float], optional): coefficients used for computing
+ running averages of gradient and its norm. (default: (0.9, 0.999))
+ eps (float, optional): term added to the denominator to improve
+ numerical stability. (default: 1e-8)
+ weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+ grad_averaging (bool, optional): whether apply (1-beta2) to grad when
+ calculating running averages of gradient. (default: True)
+ max_grad_norm (float, optional): value used to clip global grad norm
+ (default: 1.0)
+ trust_clip (bool): enable LAMBC trust ratio clipping (default: False)
+ always_adapt (boolean, optional): Apply adaptive learning rate to 0.0
+ weight decay parameter (default: False)
+ """ # noqa: E501
+
+ def __init__(self,
+ params,
+ lr=1e-3,
+ bias_correction=True,
+ betas=(0.9, 0.999),
+ eps=1e-6,
+ weight_decay=0.01,
+ grad_averaging=True,
+ max_grad_norm=1.0,
+ trust_clip=False,
+ always_adapt=False):
+ defaults = dict(
+ lr=lr,
+ bias_correction=bias_correction,
+ betas=betas,
+ eps=eps,
+ weight_decay=weight_decay,
+ grad_averaging=grad_averaging,
+ max_grad_norm=max_grad_norm,
+ trust_clip=trust_clip,
+ always_adapt=always_adapt)
+ super().__init__(params, defaults)
+
+ @torch.no_grad()
+ def step(self, closure=None):
+ """Performs a single optimization step.
+
+ Arguments:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ device = self.param_groups[0]['params'][0].device
+ one_tensor = torch.tensor(
+ 1.0, device=device
+ ) # because torch.where doesn't handle scalars correctly
+ global_grad_norm = torch.zeros(1, device=device)
+ for group in self.param_groups:
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ grad = p.grad
+ if grad.is_sparse:
+ raise RuntimeError(
+ 'Lamb does not support sparse gradients, consider '
+ 'SparseAdam instead.')
+ global_grad_norm.add_(grad.pow(2).sum())
+
+ global_grad_norm = torch.sqrt(global_grad_norm)
+ # FIXME it'd be nice to remove explicit tensor conversion of scalars
+ # when torch.where promotes
+ # scalar types properly https://github.com/pytorch/pytorch/issues/9190
+ max_grad_norm = torch.tensor(
+ self.defaults['max_grad_norm'], device=device)
+ clip_global_grad_norm = torch.where(global_grad_norm > max_grad_norm,
+ global_grad_norm / max_grad_norm,
+ one_tensor)
+
+ for group in self.param_groups:
+ bias_correction = 1 if group['bias_correction'] else 0
+ beta1, beta2 = group['betas']
+ grad_averaging = 1 if group['grad_averaging'] else 0
+ beta3 = 1 - beta1 if grad_averaging else 1.0
+
+            # Assume the same step across the group for now to simplify
+            # things; a per-parameter step can easily be supported by making
+            # it a tensor, or by passing a list into the kernel.
+ if 'step' in group:
+ group['step'] += 1
+ else:
+ group['step'] = 1
+
+ if bias_correction:
+ bias_correction1 = 1 - beta1**group['step']
+ bias_correction2 = 1 - beta2**group['step']
+ else:
+ bias_correction1, bias_correction2 = 1.0, 1.0
+
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ grad = p.grad.div_(clip_global_grad_norm)
+ state = self.state[p]
+
+ # State initialization
+ if len(state) == 0:
+                    # Exponential moving average of gradient values
+ state['exp_avg'] = torch.zeros_like(p)
+ # Exponential moving average of squared gradient values
+ state['exp_avg_sq'] = torch.zeros_like(p)
+
+ exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+
+ # Decay the first and second moment running average coefficient
+ exp_avg.mul_(beta1).add_(grad, alpha=beta3) # m_t
+ exp_avg_sq.mul_(beta2).addcmul_(
+ grad, grad, value=1 - beta2) # v_t
+
+ denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(
+ group['eps'])
+ update = (exp_avg / bias_correction1).div_(denom)
+
+ weight_decay = group['weight_decay']
+ if weight_decay != 0:
+ update.add_(p, alpha=weight_decay)
+
+ if weight_decay != 0 or group['always_adapt']:
+                    # Layer-wise LR adaptation. By default, skip adaptation on
+                    # parameters that are excluded from weight decay, unless
+                    # always_adapt == True, in which case it is always enabled.
+ w_norm = p.norm(2.0)
+ g_norm = update.norm(2.0)
+ # FIXME nested where required since logical and/or not
+ # working in PT XLA
+ trust_ratio = torch.where(
+ w_norm > 0,
+ torch.where(g_norm > 0, w_norm / g_norm, one_tensor),
+ one_tensor,
+ )
+ if group['trust_clip']:
+ # LAMBC trust clipping, upper bound fixed at one
+ trust_ratio = torch.minimum(trust_ratio, one_tensor)
+ update.mul_(trust_ratio)
+
+ p.add_(update, alpha=-group['lr'])
+
+ return loss
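
A hedged config sketch for selecting this optimizer through the registry, assuming the
standard MMEngine optim wrapper; the values are illustrative.

.. code:: python

    optim_wrapper = dict(
        optimizer=dict(
            type='Lamb',
            lr=5e-3,
            weight_decay=0.02,
            max_grad_norm=1.0,  # global grad-norm clipping built into step()
        ))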
diff --git a/mmpretrain/engine/optimizers/lars.py b/mmpretrain/engine/optimizers/lars.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e388878374e3d1e7408861a5f1830b00df5664b
--- /dev/null
+++ b/mmpretrain/engine/optimizers/lars.py
@@ -0,0 +1,130 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Iterable
+
+import torch
+from torch.optim.optimizer import Optimizer
+
+from mmpretrain.registry import OPTIMIZERS
+
+
+@OPTIMIZERS.register_module()
+class LARS(Optimizer):
+ """Implements layer-wise adaptive rate scaling for SGD.
+
+    Based on Algorithm 1 of the paper "Large Batch Training of Convolutional
+    Networks" by You, Gitman, and Ginsburg.
+
+ Args:
+ params (Iterable): Iterable of parameters to optimize or dicts defining
+ parameter groups.
+ lr (float): Base learning rate.
+ momentum (float): Momentum factor. Defaults to 0.
+ weight_decay (float): Weight decay (L2 penalty). Defaults to 0.
+ dampening (float): Dampening for momentum. Defaults to 0.
+ eta (float): LARS coefficient. Defaults to 0.001.
+ nesterov (bool): Enables Nesterov momentum. Defaults to False.
+        eps (float): A small number to avoid dividing by zero. Defaults to 1e-8.
+
+ Example:
+ >>> optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9,
+ >>> weight_decay=1e-4, eta=1e-3)
+ >>> optimizer.zero_grad()
+ >>> loss_fn(model(input), target).backward()
+ >>> optimizer.step()
+ """
+
+ def __init__(self,
+ params: Iterable,
+ lr: float,
+ momentum: float = 0,
+ weight_decay: float = 0,
+ dampening: float = 0,
+ eta: float = 0.001,
+ nesterov: bool = False,
+ eps: float = 1e-8) -> None:
+        if not isinstance(lr, float) or lr < 0.0:
+ raise ValueError(f'Invalid learning rate: {lr}')
+ if momentum < 0.0:
+ raise ValueError(f'Invalid momentum value: {momentum}')
+ if weight_decay < 0.0:
+ raise ValueError(f'Invalid weight_decay value: {weight_decay}')
+ if eta < 0.0:
+ raise ValueError(f'Invalid LARS coefficient value: {eta}')
+
+ defaults = dict(
+ lr=lr,
+ momentum=momentum,
+ dampening=dampening,
+ weight_decay=weight_decay,
+ nesterov=nesterov,
+ eta=eta)
+ if nesterov and (momentum <= 0 or dampening != 0):
+ raise ValueError(
+ 'Nesterov momentum requires a momentum and zero dampening')
+
+ self.eps = eps
+ super().__init__(params, defaults)
+
+ def __setstate__(self, state) -> None:
+ super().__setstate__(state)
+ for group in self.param_groups:
+ group.setdefault('nesterov', False)
+
+ @torch.no_grad()
+ def step(self, closure=None) -> torch.Tensor:
+ """Performs a single optimization step.
+
+ Args:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ for group in self.param_groups:
+ weight_decay = group['weight_decay']
+ momentum = group['momentum']
+ dampening = group['dampening']
+ eta = group['eta']
+ nesterov = group['nesterov']
+ lr = group['lr']
+ lars_exclude = group.get('lars_exclude', False)
+
+ for p in group['params']:
+ if p.grad is None:
+ continue
+
+ d_p = p.grad
+
+ if lars_exclude:
+ local_lr = 1.
+ else:
+ weight_norm = torch.norm(p).item()
+ grad_norm = torch.norm(d_p).item()
+ if weight_norm != 0 and grad_norm != 0:
+ # Compute local learning rate for this layer
+ local_lr = eta * weight_norm / \
+ (grad_norm + weight_decay * weight_norm + self.eps)
+ else:
+ local_lr = 1.
+
+ actual_lr = local_lr * lr
+ d_p = d_p.add(p, alpha=weight_decay).mul(actual_lr)
+ if momentum != 0:
+ param_state = self.state[p]
+ if 'momentum_buffer' not in param_state:
+ buf = param_state['momentum_buffer'] = \
+ torch.clone(d_p).detach()
+ else:
+ buf = param_state['momentum_buffer']
+ buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
+ if nesterov:
+ d_p = d_p.add(buf, alpha=momentum)
+ else:
+ d_p = buf
+ p.add_(-d_p)
+
+ return loss
diff --git a/mmpretrain/engine/optimizers/layer_decay_optim_wrapper_constructor.py b/mmpretrain/engine/optimizers/layer_decay_optim_wrapper_constructor.py
new file mode 100644
index 0000000000000000000000000000000000000000..09c6abc54a9f49cc789bf91d2bf74b0ec68902c4
--- /dev/null
+++ b/mmpretrain/engine/optimizers/layer_decay_optim_wrapper_constructor.py
@@ -0,0 +1,166 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from collections import defaultdict
+from typing import Callable, List, Optional
+
+from mmengine.logging import MMLogger
+from mmengine.optim import DefaultOptimWrapperConstructor
+from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm, _InstanceNorm
+from torch import nn
+from torch.nn import GroupNorm, LayerNorm
+
+from mmpretrain.registry import OPTIM_WRAPPER_CONSTRUCTORS
+
+
+@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
+class LearningRateDecayOptimWrapperConstructor(DefaultOptimWrapperConstructor):
+ """Different learning rates are set for different layers of backbone.
+
+ By default, each parameter shares the same optimizer settings, and we
+ provide an argument ``paramwise_cfg`` to specify parameter-wise settings.
+ It is a dict and may contain the following fields:
+
+ - ``layer_decay_rate`` (float): The learning rate of a parameter is
+ scaled by a power of this value according to the layer depth of the
+ parameter. Usually it is less than 1, so that earlier layers get a
+ lower learning rate. Defaults to 1.
+ - ``bias_decay_mult`` (float): It will be multiplied to the weight
+ decay for all bias parameters (except for those in normalization layers).
+ - ``norm_decay_mult`` (float): It will be multiplied to the weight
+ decay for all weight and bias parameters of normalization layers.
+ - ``flat_decay_mult`` (float): It will be multiplied to the weight
+ decay for all one-dimensional parameters.
+ - ``custom_keys`` (dict): Specify parameter-wise settings by keys. If
+ one of the keys in ``custom_keys`` is a substring of the name of a
+ parameter, the setting of that parameter will be specified by
+ ``custom_keys[key]`` and other settings like ``bias_decay_mult`` will
+ be ignored. It should be a dict and may contain the field
+ ``decay_mult``. (The ``lr_mult`` is disabled in this constructor.)
+
+ Example:
+
+ In the config file, you can use this constructor as below:
+
+ .. code:: python
+
+ optim_wrapper = dict(
+ optimizer=dict(
+ type='AdamW',
+ lr=4e-3,
+ weight_decay=0.05,
+ eps=1e-8,
+ betas=(0.9, 0.999)),
+ constructor='LearningRateDecayOptimWrapperConstructor',
+ paramwise_cfg=dict(
+ layer_decay_rate=0.75, # layer-wise lr decay factor
+ norm_decay_mult=0.,
+ flat_decay_mult=0.,
+ custom_keys={
+ '.cls_token': dict(decay_mult=0.0),
+ '.pos_embed': dict(decay_mult=0.0)
+ }))
+ """
+
+ def add_params(self,
+ params: List[dict],
+ module: nn.Module,
+ prefix: str = '',
+ get_layer_depth: Optional[Callable] = None,
+ **kwargs) -> None:
+ """Add all parameters of module to the params list.
+
+ The parameters of the given module will be added to the list of param
+ groups, with specific rules defined by paramwise_cfg.
+
+ Args:
+ params (List[dict]): A list of param groups, it will be modified
+ in place.
+ module (nn.Module): The module to be added.
+ prefix (str): The prefix of the module.
+ get_layer_depth (Callable, optional): The function to compute the
+ layer id and the maximum depth of a parameter name. Defaults to
+ None, which means using the ``get_layer_depth`` method of the model.
+ """
+ # get param-wise options
+ custom_keys = self.paramwise_cfg.get('custom_keys', {})
+ # first sort with alphabet order and then sort with reversed len of str
+ sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True)
+ logger = MMLogger.get_current_instance()
+
+ # The model should have `get_layer_depth` method
+ if get_layer_depth is None and not hasattr(module, 'get_layer_depth'):
+ raise NotImplementedError('The layer-wise learning rate decay needs'
+ f' the model {type(module)} to have a'
+ ' `get_layer_depth` method.')
+ else:
+ get_layer_depth = get_layer_depth or module.get_layer_depth
+
+ bias_decay_mult = self.paramwise_cfg.get('bias_decay_mult', None)
+ norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', None)
+ flat_decay_mult = self.paramwise_cfg.get('flat_decay_mult', None)
+ decay_rate = self.paramwise_cfg.get('layer_decay_rate', 1.0)
+
+ # special rules for norm layers and depth-wise conv layers
+ is_norm = isinstance(module,
+ (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm))
+
+ for name, param in module.named_parameters(recurse=False):
+ param_group = {'params': [param]}
+ param_name = prefix + name
+ if not param.requires_grad:
+ continue
+
+ if self.base_wd is not None:
+ base_wd = self.base_wd
+ custom_key = next(
+ filter(lambda k: k in param_name, sorted_keys), None)
+ # custom parameters decay
+ if custom_key is not None:
+ custom_cfg = custom_keys[custom_key].copy()
+ decay_mult = custom_cfg.pop('decay_mult', 1.)
+
+ param_group['weight_decay'] = base_wd * decay_mult
+ # add custom settings to param_group
+ param_group.update(custom_cfg)
+ # norm decay
+ elif is_norm and norm_decay_mult is not None:
+ param_group['weight_decay'] = base_wd * norm_decay_mult
+ # bias decay
+ elif name == 'bias' and bias_decay_mult is not None:
+ param_group['weight_decay'] = base_wd * bias_decay_mult
+ # flatten parameters decay
+ elif param.ndim == 1 and flat_decay_mult is not None:
+ param_group['weight_decay'] = base_wd * flat_decay_mult
+ else:
+ param_group['weight_decay'] = base_wd
+
+ layer_id, max_id = get_layer_depth(param_name)
+ scale = decay_rate**(max_id - layer_id - 1)
+ param_group['lr'] = self.base_lr * scale
+ param_group['lr_scale'] = scale
+ param_group['layer_id'] = layer_id
+ param_group['param_name'] = param_name
+
+ params.append(param_group)
+
+ for child_name, child_mod in module.named_children():
+ child_prefix = f'{prefix}{child_name}.'
+ self.add_params(
+ params,
+ child_mod,
+ prefix=child_prefix,
+ get_layer_depth=get_layer_depth,
+ )
+
+ if prefix == '':
+ layer_params = defaultdict(list)
+ for param in params:
+ layer_params[param['layer_id']].append(param)
+ for layer_id, layer_params in layer_params.items():
+ lr_scale = layer_params[0]['lr_scale']
+ lr = layer_params[0]['lr']
+ msg = [
+ f'layer {layer_id} params '
+ f'(lr={lr:.3g}, lr_scale={lr_scale:.3g}):'
+ ]
+ for param in layer_params:
+ msg.append(f'\t{param["param_name"]}: '
+ f'weight_decay={param["weight_decay"]:.3g}')
+ logger.debug('\n'.join(msg))
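To make the scaling rule above concrete, here is a small numeric illustration (the values are assumptions, not from the patch) of how ``scale = layer_decay_rate ** (max_id - layer_id - 1)`` turns into per-layer learning rates.

.. code:: python

    # Illustrative numbers only: how the per-parameter LR scale is derived.
    base_lr = 4e-3             # matches the AdamW lr in the config example
    layer_decay_rate = 0.75
    max_id = 14                # hypothetical depth returned by get_layer_depth()

    for layer_id in (0, 6, 13):
        scale = layer_decay_rate ** (max_id - layer_id - 1)
        print(f'layer {layer_id:2d}: scale={scale:.3f}, lr={base_lr * scale:.2e}')
    # Earlier layers (small layer_id) get smaller scales; the last layer gets 1.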
diff --git a/mmpretrain/engine/runners/__init__.py b/mmpretrain/engine/runners/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..23206e1ea7c83fa1d547c677b3fe5203f8c5485f
--- /dev/null
+++ b/mmpretrain/engine/runners/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .retrieval_loop import RetrievalTestLoop, RetrievalValLoop
+
+__all__ = ['RetrievalTestLoop', 'RetrievalValLoop']
diff --git a/mmpretrain/engine/runners/retrieval_loop.py b/mmpretrain/engine/runners/retrieval_loop.py
new file mode 100644
index 0000000000000000000000000000000000000000..d15387eddeb9075c23949f95a77ed59006bb9a38
--- /dev/null
+++ b/mmpretrain/engine/runners/retrieval_loop.py
@@ -0,0 +1,168 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+import torch
+from mmengine.model import is_model_wrapper
+from mmengine.runner import TestLoop, ValLoop, autocast
+
+from mmpretrain.registry import LOOPS
+
+
+@LOOPS.register_module()
+class RetrievalValLoop(ValLoop):
+ """Loop for multimodal retrieval val.
+
+ Args:
+ runner (Runner): A reference of runner.
+ dataloader (Dataloader or dict): A dataloader object or a dict to
+ build a dataloader.
+ evaluator (Evaluator or dict or list): Used for computing metrics.
+ fp16 (bool): Whether to enable fp16 validation. Defaults to
+ False.
+ """
+
+ def run(self) -> dict:
+ """Launch val."""
+ self.runner.call_hook('before_val')
+ self.runner.call_hook('before_val_epoch')
+ self.runner.model.eval()
+
+ feats_local = []
+ data_samples_local = []
+
+ for idx, data_batch in enumerate(self.dataloader):
+ with torch.no_grad():
+ self.runner.call_hook(
+ 'before_val_iter', batch_idx=idx, data_batch=data_batch)
+ # predictions should be sequence of BaseDataElement
+ with autocast(enabled=self.fp16):
+ if is_model_wrapper(self.runner.model):
+ data_preprocessor = self.runner.model.module.data_preprocessor # noqa: E501
+ else:
+ data_preprocessor = self.runner.model.data_preprocessor
+
+ # get features for retrieval instead of data samples
+ data_batch = data_preprocessor(data_batch, False)
+ feats = self.runner.model._run_forward(
+ data_batch, mode='tensor')
+ feats_local.append(feats)
+ data_samples_local.extend(data_batch['data_samples'])
+ self.runner.call_hook(
+ 'after_val_iter',
+ batch_idx=idx,
+ data_batch=data_batch,
+ outputs=feats)
+
+ # concatenate different features
+ feats_local = {
+ k: torch.cat([dic[k] for dic in feats_local])
+ for k in feats_local[0]
+ }
+
+ # get predictions
+ if is_model_wrapper(self.runner.model):
+ predict_all_fn = self.runner.model.module.predict_all
+ else:
+ predict_all_fn = self.runner.model.predict_all
+
+ img_size = self.dataloader.dataset.img_size
+ text_size = self.dataloader.dataset.text_size
+ with torch.no_grad():
+ i2t_data_samples, t2i_data_samples = predict_all_fn(
+ feats_local,
+ data_samples_local,
+ num_images=img_size,
+ num_texts=text_size,
+ )
+
+ # process in evaluator and compute metrics
+ self.evaluator.process(i2t_data_samples, None)
+ i2t_metrics = self.evaluator.evaluate(img_size)
+ i2t_metrics = {f'i2t/{k}': v for k, v in i2t_metrics.items()}
+ self.evaluator.process(t2i_data_samples, None)
+ t2i_metrics = self.evaluator.evaluate(text_size)
+ t2i_metrics = {f't2i/{k}': v for k, v in t2i_metrics.items()}
+ metrics = {**i2t_metrics, **t2i_metrics}
+
+ self.runner.call_hook('after_val_epoch', metrics=metrics)
+ self.runner.call_hook('after_val')
+ return metrics
+
+
+@LOOPS.register_module()
+class RetrievalTestLoop(TestLoop):
+ """Loop for multimodal retrieval test.
+
+ Args:
+ runner (Runner): A reference of runner.
+ dataloader (Dataloader or dict): A dataloader object or a dict to
+ build a dataloader.
+ evaluator (Evaluator or dict or list): Used for computing metrics.
+ fp16 (bool): Whether to enable fp16 testing. Defaults to
+ False.
+ """
+
+ def run(self) -> dict:
+ """Launch test."""
+ self.runner.call_hook('before_test')
+ self.runner.call_hook('before_test_epoch')
+ self.runner.model.eval()
+
+ feats_local = []
+ data_samples_local = []
+
+ for idx, data_batch in enumerate(self.dataloader):
+ with torch.no_grad():
+ self.runner.call_hook(
+ 'before_test_iter', batch_idx=idx, data_batch=data_batch)
+ # predictions should be sequence of BaseDataElement
+ with autocast(enabled=self.fp16):
+ if is_model_wrapper(self.runner.model):
+ data_preprocessor = self.runner.model.module.data_preprocessor # noqa: E501
+ else:
+ data_preprocessor = self.runner.model.data_preprocessor
+ # get features for retrieval instead of data samples
+ data_batch = data_preprocessor(data_batch, False)
+ feats = self.runner.model._run_forward(
+ data_batch, mode='tensor')
+ feats_local.append(feats)
+ data_samples_local.extend(data_batch['data_samples'])
+ self.runner.call_hook(
+ 'after_test_iter',
+ batch_idx=idx,
+ data_batch=data_batch,
+ outputs=feats)
+
+ # concatenate different features
+ feats_local = {
+ k: torch.cat([dic[k] for dic in feats_local])
+ for k in feats_local[0]
+ }
+
+ # get predictions
+ if is_model_wrapper(self.runner.model):
+ predict_all_fn = self.runner.model.module.predict_all
+ else:
+ predict_all_fn = self.runner.model.predict_all
+
+ img_size = self.dataloader.dataset.img_size
+ text_size = self.dataloader.dataset.text_size
+ with torch.no_grad():
+ i2t_data_samples, t2i_data_samples = predict_all_fn(
+ feats_local,
+ data_samples_local,
+ num_images=img_size,
+ num_texts=text_size,
+ )
+
+ # process in evaluator and compute metrics
+ self.evaluator.process(i2t_data_samples, None)
+ i2t_metrics = self.evaluator.evaluate(img_size)
+ i2t_metrics = {f'i2t/{k}': v for k, v in i2t_metrics.items()}
+ self.evaluator.process(t2i_data_samples, None)
+ t2i_metrics = self.evaluator.evaluate(text_size)
+ t2i_metrics = {f't2i/{k}': v for k, v in t2i_metrics.items()}
+ metrics = {**i2t_metrics, **t2i_metrics}
+
+ self.runner.call_hook('after_test_epoch', metrics=metrics)
+ self.runner.call_hook('after_test')
+ return metrics
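A hedged config sketch (not part of this patch) showing how these loops could be selected in a config file; any surrounding dataloader and evaluator settings are placeholders. Both loops assume the dataset behind the dataloader exposes ``img_size`` and ``text_size`` attributes and that the model provides a ``predict_all`` method.

.. code:: python

    # Placeholder config fragment; other settings are assumptions.
    val_cfg = dict(type='RetrievalValLoop', fp16=False)
    test_cfg = dict(type='RetrievalTestLoop', fp16=False)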
diff --git a/mmpretrain/engine/schedulers/__init__.py b/mmpretrain/engine/schedulers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..68b6a5477b84a53e060e0e6d43fdac830adebffb
--- /dev/null
+++ b/mmpretrain/engine/schedulers/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .weight_decay_scheduler import CosineAnnealingWeightDecay
+
+__all__ = ['CosineAnnealingWeightDecay']
diff --git a/mmpretrain/engine/schedulers/weight_decay_scheduler.py b/mmpretrain/engine/schedulers/weight_decay_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e725a4c3f53856cf848ed7e6a225a178b36ab98
--- /dev/null
+++ b/mmpretrain/engine/schedulers/weight_decay_scheduler.py
@@ -0,0 +1,64 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+from mmengine.optim.scheduler import CosineAnnealingParamScheduler
+
+from mmpretrain.registry import PARAM_SCHEDULERS
+
+
+class WeightDecaySchedulerMixin:
+ """A mixin class for learning rate schedulers."""
+
+ def __init__(self, optimizer, *args, **kwargs):
+ super().__init__(optimizer, 'weight_decay', *args, **kwargs)
+
+
+@PARAM_SCHEDULERS.register_module()
+class CosineAnnealingWeightDecay(WeightDecaySchedulerMixin,
+ CosineAnnealingParamScheduler):
+ """Set the weight decay value of each parameter group using a cosine
+ annealing schedule.
+
+ If the weight decay of a parameter group is 0 initially, its weight
+ decay will remain 0 throughout training.
+ """
+
+ def _get_value(self) -> list:
+ """Compute value using chainable form of the scheduler."""
+
+ def _get_eta_min(base_value):
+ if self.eta_min_ratio is None:
+ return self.eta_min
+ return base_value * self.eta_min_ratio
+
+ if self.last_step == 0:
+ return [
+ group[self.param_name] for group in self.optimizer.param_groups
+ ]
+ elif (self.last_step - 1 - self.T_max) % (2 * self.T_max) == 0:
+ weight_decay_value_list = []
+ for base_value, group in zip(self.base_values,
+ self.optimizer.param_groups):
+ if base_value == 0:
+ group_value = 0
+ else:
+ group_value = group[self.param_name] + (
+ base_value - _get_eta_min(base_value)) * (
+ 1 - math.cos(math.pi / self.T_max)) / 2
+ weight_decay_value_list.append(group_value)
+ return weight_decay_value_list
+
+ weight_decay_value_list = []
+ for base_value, group in zip(self.base_values,
+ self.optimizer.param_groups):
+ if base_value == 0:
+ group_value = 0
+ else:
+ group_value = (
+ 1 + math.cos(math.pi * self.last_step / self.T_max)) / (
+ 1 + math.cos(math.pi *
+ (self.last_step - 1) / self.T_max)
+ ) * (group[self.param_name] -
+ _get_eta_min(base_value)) + _get_eta_min(base_value)
+ weight_decay_value_list.append(group_value)
+ return weight_decay_value_list
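As a usage illustration (assumed values, not from the patch), the scheduler can sit in a config's ``param_scheduler`` list; arguments such as ``T_max``, ``eta_min`` and ``by_epoch`` are inherited unchanged from ``CosineAnnealingParamScheduler``.

.. code:: python

    # Hedged sketch: anneal weight decay from its initial value down to 1e-4.
    param_scheduler = [
        dict(
            type='CosineAnnealingWeightDecay',
            eta_min=1e-4,     # final weight decay (illustrative)
            T_max=300,        # schedule length in epochs (illustrative)
            by_epoch=True,
            begin=0,
            end=300),
    ]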
diff --git a/mmpretrain/evaluation/__init__.py b/mmpretrain/evaluation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f70dc226d30f7b8e4ee5a44ca163ad1ae04eabf5
--- /dev/null
+++ b/mmpretrain/evaluation/__init__.py
@@ -0,0 +1,3 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .functional import * # noqa: F401,F403
+from .metrics import * # noqa: F401,F403
diff --git a/mmpretrain/evaluation/functional/__init__.py b/mmpretrain/evaluation/functional/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef101fec61e72abc0eb90266d453b5b22331378d
--- /dev/null
+++ b/mmpretrain/evaluation/functional/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) OpenMMLab. All rights reserved.
diff --git a/mmpretrain/evaluation/metrics/__init__.py b/mmpretrain/evaluation/metrics/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd46de7833ff5c2c5fb58a077eee8785710ff37c
--- /dev/null
+++ b/mmpretrain/evaluation/metrics/__init__.py
@@ -0,0 +1,21 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .caption import COCOCaption
+from .gqa import GQAAcc
+from .multi_label import AveragePrecision, MultiLabelMetric
+from .multi_task import MultiTasksMetric
+from .nocaps import NocapsSave
+from .retrieval import RetrievalAveragePrecision, RetrievalRecall
+from .scienceqa import ScienceQAMetric
+from .shape_bias_label import ShapeBiasMetric
+from .single_label import Accuracy, ConfusionMatrix, SingleLabelMetric
+from .visual_grounding_eval import VisualGroundingMetric
+from .voc_multi_label import VOCAveragePrecision, VOCMultiLabelMetric
+from .vqa import ReportVQA, VQAAcc
+
+__all__ = [
+ 'Accuracy', 'SingleLabelMetric', 'MultiLabelMetric', 'AveragePrecision',
+ 'MultiTasksMetric', 'VOCAveragePrecision', 'VOCMultiLabelMetric',
+ 'ConfusionMatrix', 'RetrievalRecall', 'VQAAcc', 'ReportVQA', 'COCOCaption',
+ 'VisualGroundingMetric', 'ScienceQAMetric', 'GQAAcc', 'NocapsSave',
+ 'RetrievalAveragePrecision', 'ShapeBiasMetric'
+]
diff --git a/mmpretrain/evaluation/metrics/caption.py b/mmpretrain/evaluation/metrics/caption.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4bffabfa97a9c6faec7ecc0ffb6d9ba2f435b97
--- /dev/null
+++ b/mmpretrain/evaluation/metrics/caption.py
@@ -0,0 +1,136 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import json
+import os
+import tempfile
+from typing import List, Optional
+
+from mmengine.evaluator import BaseMetric
+from mmengine.utils import track_iter_progress
+
+from mmpretrain.registry import METRICS
+from mmpretrain.utils import require
+
+try:
+ from pycocoevalcap.eval import COCOEvalCap
+ from pycocotools.coco import COCO
+except ImportError:
+ COCOEvalCap = None
+ COCO = None
+
+
+@METRICS.register_module()
+class COCOCaption(BaseMetric):
+ """Coco Caption evaluation wrapper.
+
+ Save the generated captions, transform them into COCO format and call
+ the COCO API to compute caption metrics.
+
+ Args:
+ ann_file (str): the path for the COCO format caption ground truth
+ json file, load for evaluations.
+ collect_device (str): Device name used for collecting results from
+ different ranks during distributed training. Must be 'cpu' or
+ 'gpu'. Defaults to 'cpu'.
+ prefix (str, optional): The prefix that will be added in the metric
+ names to disambiguate homonymous metrics of different evaluators.
+ If prefix is not provided in the argument, self.default_prefix
+ will be used instead. Defaults to None.
+ """
+
+ @require('pycocoevalcap')
+ def __init__(self,
+ ann_file: str,
+ collect_device: str = 'cpu',
+ prefix: Optional[str] = None):
+ super().__init__(collect_device=collect_device, prefix=prefix)
+ self.ann_file = ann_file
+
+ def process(self, data_batch, data_samples):
+ """Process one batch of data samples.
+
+ The processed results should be stored in ``self.results``, which will
+ be used to compute the metrics when all batches have been processed.
+
+ Args:
+ data_batch: A batch of data from the dataloader.
+ data_samples (Sequence[dict]): A batch of outputs from the model.
+ """
+
+ for data_sample in data_samples:
+ result = dict()
+
+ result['caption'] = data_sample.get('pred_caption')
+ result['image_id'] = int(data_sample.get('image_id'))
+
+ # Save the result to `self.results`.
+ self.results.append(result)
+
+ def compute_metrics(self, results: List):
+ """Compute the metrics from processed results.
+
+ Args:
+ results (list): The processed results of each batch.
+
+ Returns:
+ Dict: The computed metrics. The keys are the names of the metrics,
+ and the values are corresponding results.
+ """
+ # NOTICE: don't access `self.results` from the method.
+
+ with tempfile.TemporaryDirectory() as temp_dir:
+
+ eval_result_file = save_result(
+ result=results,
+ result_dir=temp_dir,
+ filename='m4-caption_pred',
+ remove_duplicate='image_id',
+ )
+
+ coco_val = coco_caption_eval(eval_result_file, self.ann_file)
+
+ return coco_val
+
+
+def save_result(result, result_dir, filename, remove_duplicate=''):
+ """Saving predictions as json file for evaluation."""
+
+ # combine results from all processes
+ result_new = []
+
+ if remove_duplicate:
+ result_new = []
+ id_list = []
+ for res in track_iter_progress(result):
+ if res[remove_duplicate] not in id_list:
+ id_list.append(res[remove_duplicate])
+ result_new.append(res)
+ result = result_new
+
+ final_result_file_url = os.path.join(result_dir, '%s.json' % filename)
+ print(f'result file saved to {final_result_file_url}')
+ json.dump(result, open(final_result_file_url, 'w'))
+
+ return final_result_file_url
+
+
+def coco_caption_eval(results_file, ann_file):
+ """Evaluation between gt json and prediction json files."""
+ # create coco object and coco_result object
+ coco = COCO(ann_file)
+ coco_result = coco.loadRes(results_file)
+
+ # create coco_eval object by taking coco and coco_result
+ coco_eval = COCOEvalCap(coco, coco_result)
+
+ # make sure the image ids are the same
+ coco_eval.params['image_id'] = coco_result.getImgIds()
+
+ # This may take some time on the first run
+ coco_eval.evaluate()
+
+ # print output evaluation scores
+ for metric, score in coco_eval.eval.items():
+ print(f'{metric}: {score:.3f}')
+
+ return coco_eval.eval
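A hedged evaluator sketch for the metric above; the annotation path is an assumed placeholder. Each data sample processed by the metric is expected to carry a ``pred_caption`` string and an ``image_id`` convertible to int, matching what ``process()`` reads.

.. code:: python

    # Placeholder config fragment; the ann_file path is an assumption.
    val_evaluator = dict(
        type='COCOCaption',
        ann_file='data/coco/annotations/coco_karpathy_val_gt.json')
    test_evaluator = val_evaluator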
diff --git a/mmpretrain/evaluation/metrics/gqa.py b/mmpretrain/evaluation/metrics/gqa.py
new file mode 100644
index 0000000000000000000000000000000000000000..d5e8b0725524839c5b0a15a8ba6fb4eed689e589
--- /dev/null
+++ b/mmpretrain/evaluation/metrics/gqa.py
@@ -0,0 +1,78 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional
+
+from mmengine.evaluator import BaseMetric
+
+from mmpretrain.evaluation.metrics.vqa import (_process_digit_article,
+ _process_punctuation)
+from mmpretrain.registry import METRICS
+
+
+@METRICS.register_module()
+class GQAAcc(BaseMetric):
+ """GQA Acc metric.
+
+ Compute GQA accuracy.
+
+ Args:
+ collect_device (str): Device name used for collecting results from
+ different ranks during distributed training. Must be 'cpu' or
+ 'gpu'. Defaults to 'cpu'.
+ prefix (str, optional): The prefix that will be added in the metric
+ names to disambiguate homonymous metrics of different evaluators.
+ If prefix is not provided in the argument, self.default_prefix
+ will be used instead. Defaults to None.
+ """
+ default_prefix = 'GQA'
+
+ def __init__(self,
+ collect_device: str = 'cpu',
+ prefix: Optional[str] = None) -> None:
+ super().__init__(collect_device=collect_device, prefix=prefix)
+
+ def process(self, data_batch, data_samples) -> None:
+ """Process one batch of data samples.
+
+ The processed results should be stored in ``self.results``, which will
+ be used to compute the metrics when all batches have been processed.
+
+ Args:
+ data_batch: A batch of data from the dataloader.
+ data_samples (Sequence[dict]): A batch of outputs from the model.
+ """
+ for sample in data_samples:
+ gt_answer = sample.get('gt_answer')
+ result = {
+ 'pred_answer': sample.get('pred_answer'),
+ 'gt_answer': gt_answer
+ }
+
+ self.results.append(result)
+
+ def compute_metrics(self, results: List) -> dict:
+ """Compute the metrics from processed results.
+
+ Args:
+ results (list): The processed results of each batch.
+
+ Returns:
+ Dict: The computed metrics. The keys are the names of the metrics,
+ and the values are corresponding results.
+ """
+ acc = []
+ for result in results:
+ pred_answer = self._process_answer(result['pred_answer'])
+ gt_answer = self._process_answer(result['gt_answer'])
+ gqa_acc = 1 if pred_answer == gt_answer else 0
+ acc.append(gqa_acc)
+
+ accuracy = sum(acc) / len(acc)
+
+ metrics = {'acc': accuracy}
+ return metrics
+
+ def _process_answer(self, answer) -> str:
+ answer = _process_punctuation(answer)
+ answer = _process_digit_article(answer)
+ return answer
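A small, self-contained sketch of how the metric above scores predictions; the sample answers are made up, and the expected output assumes the punctuation/article normalization shared with the VQA metric strips the trailing period.

.. code:: python

    # Illustrative check of GQAAcc: '2 dogs.' is expected to normalize to '2 dogs'.
    from mmpretrain.evaluation.metrics import GQAAcc  # assumed import path

    metric = GQAAcc()
    metric.process(None, [
        dict(pred_answer='2 dogs.', gt_answer='2 dogs'),  # match    -> 1
        dict(pred_answer='yes', gt_answer='no'),          # mismatch -> 0
    ])
    print(metric.compute_metrics(metric.results))  # expected: {'acc': 0.5}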
diff --git a/mmpretrain/evaluation/metrics/multi_label.py b/mmpretrain/evaluation/metrics/multi_label.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd91aac4449c845fbed514ed5f800bd971236ade
--- /dev/null
+++ b/mmpretrain/evaluation/metrics/multi_label.py
@@ -0,0 +1,599 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional, Sequence, Union
+
+import numpy as np
+import torch
+from mmengine.evaluator import BaseMetric
+from mmengine.logging import MMLogger
+
+from mmpretrain.registry import METRICS
+from mmpretrain.structures import label_to_onehot
+from .single_label import _precision_recall_f1_support, to_tensor
+
+
+@METRICS.register_module()
+class MultiLabelMetric(BaseMetric):
+ r"""A collection of precision, recall, f1-score and support for
+ multi-label tasks.
+
+ The collection of metrics is for multi-label classification.
+ And all these metrics are based on the confusion matrix of every category:
+
+ .. image:: ../../_static/image/confusion-matrix.png
+ :width: 60%
+ :align: center
+
+ All metrics can be formulated using the variables above:
+
+ **Precision** is the fraction of correct predictions in all predictions:
+
+ .. math::
+ \text{Precision} = \frac{TP}{TP+FP}
+
+ **Recall** is the fraction of correct predictions in all targets:
+
+ .. math::
+ \text{Recall} = \frac{TP}{TP+FN}
+
+ **F1-score** is the harmonic mean of the precision and recall:
+
+ .. math::
+ \text{F1-score} = \frac{2\times\text{Recall}\times\text{Precision}}{\text{Recall}+\text{Precision}}
+
+ **Support** is the number of samples:
+
+ .. math::
+ \text{Support} = TP + TN + FN + FP
+
+ Args:
+ thr (float, optional): Predictions with scores under the threshold
+ are considered as negative. If None, the ``topk`` predictions will
+ be considered as positive. If the ``topk`` is also None, use
+ ``thr=0.5`` as default. Defaults to None.
+ topk (int, optional): Predictions with the k-th highest scores are
+ considered as positive. If None, use ``thr`` to determine positive
+ predictions. If both ``thr`` and ``topk`` are not None, use
+ ``thr``. Defaults to None.
+ items (Sequence[str]): The detailed metric items to evaluate, select
+ from "precision", "recall", "f1-score" and "support".
+ Defaults to ``('precision', 'recall', 'f1-score')``.
+ average (str | None): How to calculate the final metrics from the
+ confusion matrix of every category. It supports three modes:
+
+ - `"macro"`: Calculate metrics for each category, and calculate
+ the mean value over all categories.
+ - `"micro"`: Average the confusion matrix over all categories and
+ calculate metrics on the mean confusion matrix.
+ - `None`: Calculate metrics of every category and output directly.
+
+ Defaults to "macro".
+ collect_device (str): Device name used for collecting results from
+ different ranks during distributed training. Must be 'cpu' or
+ 'gpu'. Defaults to 'cpu'.
+ prefix (str, optional): The prefix that will be added in the metric
+ names to disambiguate homonymous metrics of different evaluators.
+ If prefix is not provided in the argument, self.default_prefix
+ will be used instead. Defaults to None.
+
+ Examples:
+ >>> import torch
+ >>> from mmpretrain.evaluation import MultiLabelMetric
+ >>> # ------ The Basic Usage for category indices labels -------
+ >>> y_pred = [[0], [1], [0, 1], [3]]
+ >>> y_true = [[0, 3], [0, 2], [1], [3]]
+ >>> # Output precision, recall, f1-score and support
+ >>> MultiLabelMetric.calculate(
+ ... y_pred, y_true, pred_indices=True, target_indices=True, num_classes=4)
+ (tensor(50.), tensor(50.), tensor(45.8333), tensor(6))
+ >>> # ----------- The Basic Usage for one-hot labels -----------
+ >>> y_pred = torch.tensor([[1, 1, 0, 0],
+ ... [1, 1, 0, 0],
+ ... [0, 0, 1, 0],
+ ... [0, 1, 0, 0],
+ ... [0, 1, 0, 0]])
+ >>> y_true = torch.Tensor([[1, 1, 0, 0],
+ ... [0, 0, 1, 0],
+ ... [1, 1, 1, 0],
+ ... [1, 0, 0, 0],
+ ... [1, 0, 0, 0]])
+ >>> MultiLabelMetric.calculate(y_pred, y_true)
+ (tensor(43.7500), tensor(31.2500), tensor(33.3333), tensor(8))
+ >>> # --------- The Basic Usage for one-hot pred scores ---------
+ >>> y_pred = torch.rand(y_true.size())
+ >>> y_pred
+ tensor([[0.4575, 0.7335, 0.3934, 0.2572],
+ [0.1318, 0.1004, 0.8248, 0.6448],
+ [0.8349, 0.6294, 0.7896, 0.2061],
+ [0.4037, 0.7308, 0.6713, 0.8374],
+ [0.3779, 0.4836, 0.0313, 0.0067]])
+ >>> # Calculate with different threshold.
+ >>> MultiLabelMetric.calculate(y_pred, y_true, thr=0.1)
+ (tensor(42.5000), tensor(75.), tensor(53.1746), tensor(8))
+ >>> # Calculate with topk.
+ >>> MultiLabelMetric.calculate(y_pred, y_true, topk=1)
+ (tensor(62.5000), tensor(31.2500), tensor(39.1667), tensor(8))
+ >>>
+ >>> # ------------------- Use with Evaluator -------------------
+ >>> from mmpretrain.structures import DataSample
+ >>> from mmengine.evaluator import Evaluator
+ >>> data_samples = [
+ ... DataSample().set_pred_score(pred).set_gt_score(gt)
+ ... for pred, gt in zip(torch.rand(1000, 5), torch.randint(0, 2, (1000, 5)))]
+ >>> evaluator = Evaluator(metrics=MultiLabelMetric(thr=0.5))
+ >>> evaluator.process(data_samples)
+ >>> evaluator.evaluate(1000)
+ {
+ 'multi-label/precision': 50.72898037055408,
+ 'multi-label/recall': 50.06836461357571,
+ 'multi-label/f1-score': 50.384466955258475
+ }
+ >>> # Evaluate on each class by using topk strategy
+ >>> evaluator = Evaluator(metrics=MultiLabelMetric(topk=1, average=None))
+ >>> evaluator.process(data_samples)
+ >>> evaluator.evaluate(1000)
+ {
+ 'multi-label/precision_top1_classwise': [48.22, 50.54, 50.99, 44.18, 52.5],
+ 'multi-label/recall_top1_classwise': [18.92, 19.22, 19.92, 20.0, 20.27],
+ 'multi-label/f1-score_top1_classwise': [27.18, 27.85, 28.65, 27.54, 29.25]
+ }
+ """ # noqa: E501
+ default_prefix: Optional[str] = 'multi-label'
+
+ def __init__(self,
+ thr: Optional[float] = None,
+ topk: Optional[int] = None,
+ items: Sequence[str] = ('precision', 'recall', 'f1-score'),
+ average: Optional[str] = 'macro',
+ collect_device: str = 'cpu',
+ prefix: Optional[str] = None) -> None:
+
+ logger = MMLogger.get_current_instance()
+ if thr is None and topk is None:
+ thr = 0.5
+ logger.warning('Neither thr nor topk is given, set thr as 0.5 by '
+ 'default.')
+ elif thr is not None and topk is not None:
+ logger.warning('Both thr and topk are given, '
+ 'use threshold in favor of top-k.')
+
+ self.thr = thr
+ self.topk = topk
+ self.average = average
+
+ for item in items:
+ assert item in ['precision', 'recall', 'f1-score', 'support'], \
+ f'The metric {item} is not supported by `MultiLabelMetric`,' \
+ ' please choose from "precision", "recall", "f1-score" and ' \
+ '"support".'
+ self.items = tuple(items)
+
+ super().__init__(collect_device=collect_device, prefix=prefix)
+
+ def process(self, data_batch, data_samples: Sequence[dict]):
+ """Process one batch of data samples.
+
+ The processed results should be stored in ``self.results``, which will
+ be used to compute the metrics when all batches have been processed.
+
+ Args:
+ data_batch: A batch of data from the dataloader.
+ data_samples (Sequence[dict]): A batch of outputs from the model.
+ """
+ for data_sample in data_samples:
+ result = dict()
+
+ result['pred_score'] = data_sample['pred_score'].clone()
+ num_classes = result['pred_score'].size()[-1]
+
+ if 'gt_score' in data_sample:
+ result['gt_score'] = data_sample['gt_score'].clone()
+ else:
+ result['gt_score'] = label_to_onehot(data_sample['gt_label'],
+ num_classes)
+
+ # Save the result to `self.results`.
+ self.results.append(result)
+
+ def compute_metrics(self, results: List):
+ """Compute the metrics from processed results.
+
+ Args:
+ results (list): The processed results of each batch.
+
+ Returns:
+ Dict: The computed metrics. The keys are the names of the metrics,
+ and the values are corresponding results.
+ """
+ # NOTICE: don't access `self.results` from the method. `self.results`
+ # are a list of results from multiple batch, while the input `results`
+ # are the collected results.
+ metrics = {}
+
+ target = torch.stack([res['gt_score'] for res in results])
+ pred = torch.stack([res['pred_score'] for res in results])
+
+ metric_res = self.calculate(
+ pred,
+ target,
+ pred_indices=False,
+ target_indices=False,
+ average=self.average,
+ thr=self.thr,
+ topk=self.topk)
+
+ def pack_results(precision, recall, f1_score, support):
+ single_metrics = {}
+ if 'precision' in self.items:
+ single_metrics['precision'] = precision
+ if 'recall' in self.items:
+ single_metrics['recall'] = recall
+ if 'f1-score' in self.items:
+ single_metrics['f1-score'] = f1_score
+ if 'support' in self.items:
+ single_metrics['support'] = support
+ return single_metrics
+
+ if self.thr:
+ suffix = '' if self.thr == 0.5 else f'_thr-{self.thr:.2f}'
+ for k, v in pack_results(*metric_res).items():
+ metrics[k + suffix] = v
+ else:
+ for k, v in pack_results(*metric_res).items():
+ metrics[k + f'_top{self.topk}'] = v
+
+ result_metrics = dict()
+ for k, v in metrics.items():
+ if self.average is None:
+ result_metrics[k + '_classwise'] = v.detach().cpu().tolist()
+ elif self.average == 'macro':
+ result_metrics[k] = v.item()
+ else:
+ result_metrics[k + f'_{self.average}'] = v.item()
+ return result_metrics
+
+ @staticmethod
+ def calculate(
+ pred: Union[torch.Tensor, np.ndarray, Sequence],
+ target: Union[torch.Tensor, np.ndarray, Sequence],
+ pred_indices: bool = False,
+ target_indices: bool = False,
+ average: Optional[str] = 'macro',
+ thr: Optional[float] = None,
+ topk: Optional[int] = None,
+ num_classes: Optional[int] = None
+ ) -> Union[torch.Tensor, List[torch.Tensor]]:
+ """Calculate the precision, recall, f1-score.
+
+ Args:
+ pred (torch.Tensor | np.ndarray | Sequence): The prediction
+ results. A :obj:`torch.Tensor` or :obj:`np.ndarray` with
+ shape ``(N, num_classes)`` or a sequence of index/onehot
+ format labels.
+ target (torch.Tensor | np.ndarray | Sequence): The target of
+ predictions. A :obj:`torch.Tensor` or :obj:`np.ndarray` with
+ shape ``(N, num_classes)`` or a sequence of index/onehot
+ format labels.
+ pred_indices (bool): Whether the ``pred`` is a sequence of
+ category index labels. If True, ``num_classes`` must be set.
+ Defaults to False.
+ target_indices (bool): Whether the ``target`` is a sequence of
+ category index labels. If True, ``num_classes`` must be set.
+ Defaults to False.
+ average (str | None): How to calculate the final metrics from
+ the confusion matrix of every category. It supports three
+ modes:
+
+ - `"macro"`: Calculate metrics for each category, and calculate
+ the mean value over all categories.
+ - `"micro"`: Average the confusion matrix over all categories
+ and calculate metrics on the mean confusion matrix.
+ - `None`: Calculate metrics of every category and output
+ directly.
+
+ Defaults to "macro".
+ thr (float, optional): Predictions with scores under the threshold
+ are considered as negative. Defaults to None.
+ topk (int, optional): Predictions with the k-th highest scores are
+ considered as positive. Defaults to None.
+ num_classes (Optional, int): The number of classes. If the ``pred``
+ is indices instead of onehot, this argument is required.
+ Defaults to None.
+
+ Returns:
+ Tuple: The tuple contains precision, recall, f1-score and support.
+ And the type of each item is:
+
+ - torch.Tensor: A tensor for each metric. The shape is (1, ) if
+ ``average`` is not None, and (C, ) if ``average`` is None.
+
+ Notes:
+ If both ``thr`` and ``topk`` are set, use ``thr`` to determine
+ positive predictions. If neither is set, use ``thr=0.5`` as
+ default.
+ """
+ average_options = ['micro', 'macro', None]
+ assert average in average_options, 'Invalid `average` argument, ' \
+ f'please specify from {average_options}.'
+
+ def _format_label(label, is_indices):
+ """format various label to torch.Tensor."""
+ if isinstance(label, np.ndarray):
+ assert label.ndim == 2, 'The shape of `pred` and `target` ' \
+ 'arrays must be (N, num_classes).'
+ label = torch.from_numpy(label)
+ elif isinstance(label, torch.Tensor):
+ assert label.ndim == 2, 'The shape of `pred` and `target` ' \
+ 'tensors must be (N, num_classes).'
+ elif isinstance(label, Sequence):
+ if is_indices:
+ assert num_classes is not None, 'For index-type labels, ' \
+ 'please specify `num_classes`.'
+ label = torch.stack([
+ label_to_onehot(indices, num_classes)
+ for indices in label
+ ])
+ else:
+ label = torch.stack(
+ [to_tensor(onehot) for onehot in label])
+ else:
+ raise TypeError(
+ 'The `pred` and `target` must be type of torch.tensor or '
+ f'np.ndarray or sequence but get {type(label)}.')
+ return label
+
+ pred = _format_label(pred, pred_indices)
+ target = _format_label(target, target_indices).long()
+
+ assert pred.shape == target.shape, \
+ f"The size of pred ({pred.shape}) doesn't match "\
+ f'the target ({target.shape}).'
+
+ if num_classes is not None:
+ assert pred.size(1) == num_classes, \
+ f'The shape of `pred` ({pred.shape}) '\
+ f"doesn't match the num_classes ({num_classes})."
+ num_classes = pred.size(1)
+
+ thr = 0.5 if (thr is None and topk is None) else thr
+
+ if thr is not None:
+ # a label is predicted positive if larger than thr
+ pos_inds = (pred >= thr).long()
+ else:
+ # top-k labels will be predicted positive for any example
+ _, topk_indices = pred.topk(topk)
+ pos_inds = torch.zeros_like(pred).scatter_(1, topk_indices, 1)
+ pos_inds = pos_inds.long()
+
+ return _precision_recall_f1_support(pos_inds, target, average)
+
+
+def _average_precision(pred: torch.Tensor,
+ target: torch.Tensor) -> torch.Tensor:
+ r"""Calculate the average precision for a single class.
+
+ AP summarizes a precision-recall curve as the weighted mean of maximum
+ precisions obtained for any r'>r, where r is the recall:
+
+ .. math::
+ \text{AP} = \sum_n (R_n - R_{n-1}) P_n
+
+ Note that no approximation is involved since the curve is piecewise
+ constant.
+
+ Args:
+ pred (torch.Tensor): The model prediction with shape
+ ``(N, num_classes)``.
+ target (torch.Tensor): The target of predictions with shape
+ ``(N, num_classes)``.
+
+ Returns:
+ torch.Tensor: average precision result.
+ """
+ assert pred.shape == target.shape, \
+ f"The size of pred ({pred.shape}) doesn't match "\
+ f'the target ({target.shape}).'
+
+ # a small value for division by zero errors
+ eps = torch.finfo(torch.float32).eps
+
+ # get rid of -1 target such as difficult sample
+ # that is not wanted in evaluation results.
+ valid_index = target > -1
+ pred = pred[valid_index]
+ target = target[valid_index]
+
+ # sort examples
+ sorted_pred_inds = torch.argsort(pred, dim=0, descending=True)
+ sorted_target = target[sorted_pred_inds]
+
+ # get indexes when gt_true is positive
+ pos_inds = sorted_target == 1
+
+ # Calculate cumulative tp case numbers
+ tps = torch.cumsum(pos_inds, 0)
+ total_pos = tps[-1].item() # the last of tensor may change later
+
+ # Calculate cumulative tp&fp(pred_poss) case numbers
+ pred_pos_nums = torch.arange(1, len(sorted_target) + 1).to(pred.device)
+ pred_pos_nums[pred_pos_nums < eps] = eps
+
+ tps[torch.logical_not(pos_inds)] = 0
+ precision = tps / pred_pos_nums.float()
+ ap = torch.sum(precision, 0) / max(total_pos, eps)
+ return ap
+
+
+@METRICS.register_module()
+class AveragePrecision(BaseMetric):
+ r"""Calculate the average precision with respect of classes.
+
+ AveragePrecision (AP) summarizes a precision-recall curve as the weighted
+ mean of maximum precisions obtained for any r'>r, where r is the recall:
+
+ .. math::
+ \text{AP} = \sum_n (R_n - R_{n-1}) P_n
+
+ Note that no approximation is involved since the curve is piecewise
+ constant.
+
+ Args:
+ average (str | None): How to calculate the final metrics from
+ every category. It supports two modes:
+
+ - `"macro"`: Calculate metrics for each category, and calculate
+ the mean value over all categories. The result of this mode
+ is also called **mAP**.
+ - `None`: Calculate metrics of every category and output directly.
+
+ Defaults to "macro".
+ collect_device (str): Device name used for collecting results from
+ different ranks during distributed training. Must be 'cpu' or
+ 'gpu'. Defaults to 'cpu'.
+ prefix (str, optional): The prefix that will be added in the metric
+ names to disambiguate homonymous metrics of different evaluators.
+ If prefix is not provided in the argument, self.default_prefix
+ will be used instead. Defaults to None.
+
+ References
+ ----------
+ 1. `Wikipedia entry for the Average precision
+ `_
+
+ Examples:
+ >>> import torch
+ >>> from mmpretrain.evaluation import AveragePrecision
+ >>> # --------- The Basic Usage for one-hot pred scores ---------
+ >>> y_pred = torch.Tensor([[0.9, 0.8, 0.3, 0.2],
+ ... [0.1, 0.2, 0.2, 0.1],
+ ... [0.7, 0.5, 0.9, 0.3],
+ ... [0.8, 0.1, 0.1, 0.2]])
+ >>> y_true = torch.Tensor([[1, 1, 0, 0],
+ ... [0, 1, 0, 0],
+ ... [0, 0, 1, 0],
+ ... [1, 0, 0, 0]])
+ >>> AveragePrecision.calculate(y_pred, y_true)
+ tensor(70.833)
+ >>> # ------------------- Use with Evaluator -------------------
+ >>> from mmpretrain.structures import DataSample
+ >>> from mmengine.evaluator import Evaluator
+ >>> data_samples = [
+ ... DataSample().set_pred_score(i).set_gt_score(j)
+ ... for i, j in zip(y_pred, y_true)
+ ... ]
+ >>> evaluator = Evaluator(metrics=AveragePrecision())
+ >>> evaluator.process(data_samples)
+ >>> evaluator.evaluate(5)
+ {'multi-label/mAP': 70.83333587646484}
+ >>> # Evaluate on each class
+ >>> evaluator = Evaluator(metrics=AveragePrecision(average=None))
+ >>> evaluator.process(data_samples)
+ >>> evaluator.evaluate(5)
+ {'multi-label/AP_classwise': [100., 83.33, 100., 0.]}
+ """
+ default_prefix: Optional[str] = 'multi-label'
+
+ def __init__(self,
+ average: Optional[str] = 'macro',
+ collect_device: str = 'cpu',
+ prefix: Optional[str] = None) -> None:
+ super().__init__(collect_device=collect_device, prefix=prefix)
+ self.average = average
+
+ def process(self, data_batch, data_samples: Sequence[dict]):
+ """Process one batch of data samples.
+
+ The processed results should be stored in ``self.results``, which will
+ be used to compute the metrics when all batches have been processed.
+
+ Args:
+ data_batch: A batch of data from the dataloader.
+ data_samples (Sequence[dict]): A batch of outputs from the model.
+ """
+
+ for data_sample in data_samples:
+ result = dict()
+
+ result['pred_score'] = data_sample['pred_score'].clone()
+ num_classes = result['pred_score'].size()[-1]
+
+ if 'gt_score' in data_sample:
+ result['gt_score'] = data_sample['gt_score'].clone()
+ else:
+ result['gt_score'] = label_to_onehot(data_sample['gt_label'],
+ num_classes)
+
+ # Save the result to `self.results`.
+ self.results.append(result)
+
+ def compute_metrics(self, results: List):
+ """Compute the metrics from processed results.
+
+ Args:
+ results (list): The processed results of each batch.
+
+ Returns:
+ Dict: The computed metrics. The keys are the names of the metrics,
+ and the values are corresponding results.
+ """
+ # NOTICE: don't access `self.results` from the method. `self.results`
+ # are a list of results from multiple batch, while the input `results`
+ # are the collected results.
+
+ # concat
+ target = torch.stack([res['gt_score'] for res in results])
+ pred = torch.stack([res['pred_score'] for res in results])
+
+ ap = self.calculate(pred, target, self.average)
+
+ result_metrics = dict()
+
+ if self.average is None:
+ result_metrics['AP_classwise'] = ap.detach().cpu().tolist()
+ else:
+ result_metrics['mAP'] = ap.item()
+
+ return result_metrics
+
+ @staticmethod
+ def calculate(pred: Union[torch.Tensor, np.ndarray],
+ target: Union[torch.Tensor, np.ndarray],
+ average: Optional[str] = 'macro') -> torch.Tensor:
+ r"""Calculate the average precision for a single class.
+
+ Args:
+ pred (torch.Tensor | np.ndarray): The model predictions with
+ shape ``(N, num_classes)``.
+ target (torch.Tensor | np.ndarray): The target of predictions
+ with shape ``(N, num_classes)``.
+ average (str | None): The average method. It supports two modes:
+
+ - `"macro"`: Calculate metrics for each category, and calculate
+ the mean value over all categories. The result of this mode
+ is also called mAP.
+ - `None`: Calculate metrics of every category and output
+ directly.
+
+ Defaults to "macro".
+
+ Returns:
+ torch.Tensor: the average precision of all classes.
+ """
+ average_options = ['macro', None]
+ assert average in average_options, 'Invalid `average` argument, ' \
+ f'please specify from {average_options}.'
+
+ pred = to_tensor(pred)
+ target = to_tensor(target)
+ assert pred.ndim == 2 and pred.shape == target.shape, \
+ 'Both `pred` and `target` should have shape `(N, num_classes)`.'
+
+ num_classes = pred.shape[1]
+ ap = pred.new_zeros(num_classes)
+ for k in range(num_classes):
+ ap[k] = _average_precision(pred[:, k], target[:, k])
+ if average == 'macro':
+ return ap.mean() * 100.0
+ else:
+ return ap * 100
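A short worked check of the AP computation above (made-up scores): ranking the four samples by score puts the positives at ranks 1 and 4, so AP = (1/1 + 2/4) / 2 = 0.75, and ``AveragePrecision.calculate`` reports it scaled by 100.

.. code:: python

    # Made-up scores for one class; the expected AP is (1/1 + 2/4) / 2 = 0.75.
    import torch

    from mmpretrain.evaluation import AveragePrecision  # as in the docstring

    pred = torch.tensor([[0.9], [0.4], [0.8], [0.1]])   # (N, num_classes=1)
    target = torch.tensor([[1], [0], [0], [1]])

    print(AveragePrecision.calculate(pred, target))     # tensor(75.) == 100 * AP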
diff --git a/mmpretrain/evaluation/metrics/multi_task.py b/mmpretrain/evaluation/metrics/multi_task.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e6af7680192883308df5f24b65ec38c9bb65ce6
--- /dev/null
+++ b/mmpretrain/evaluation/metrics/multi_task.py
@@ -0,0 +1,120 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, Sequence
+
+from mmengine.evaluator import BaseMetric
+
+from mmpretrain.registry import METRICS
+
+
+@METRICS.register_module()
+class MultiTasksMetric(BaseMetric):
+ """Metrics for MultiTask
+ Args:
+ task_metrics(dict): a dictionary in the keys are the names of the tasks
+ and the values is a list of the metric corresponds to this task
+ Examples:
+ >>> import torch
+ >>> from mmpretrain.evaluation import MultiTasksMetric
+ # -------------------- The Basic Usage --------------------
+ >>> task_metrics = {
+ 'task0': [dict(type='Accuracy', topk=(1, ))],
+ 'task1': [dict(type='Accuracy', topk=(1, 3))]
+ }
+ >>> pred = [{
+ 'pred_task': {
+ 'task0': torch.tensor([0.7, 0.0, 0.3]),
+ 'task1': torch.tensor([0.5, 0.2, 0.3])
+ },
+ 'gt_task': {
+ 'task0': torch.tensor(0),
+ 'task1': torch.tensor(2)
+ }
+ }, {
+ 'pred_task': {
+ 'task0': torch.tensor([0.0, 0.0, 1.0]),
+ 'task1': torch.tensor([0.0, 0.0, 1.0])
+ },
+ 'gt_task': {
+ 'task0': torch.tensor(2),
+ 'task1': torch.tensor(2)
+ }
+ }]
+ >>> metric = MultiTasksMetric(task_metrics)
+ >>> metric.process(None, pred)
+ >>> results = metric.evaluate(2)
+ results = {
+ 'task0_accuracy/top1': 100.0,
+ 'task1_accuracy/top1': 50.0,
+ 'task1_accuracy/top3': 100.0
+ }
+ """
+
+ def __init__(self,
+ task_metrics: Dict,
+ collect_device: str = 'cpu') -> None:
+ self.task_metrics = task_metrics
+ super().__init__(collect_device=collect_device)
+
+ self._metrics = {}
+ for task_name in self.task_metrics.keys():
+ self._metrics[task_name] = []
+ for metric in self.task_metrics[task_name]:
+ self._metrics[task_name].append(METRICS.build(metric))
+
+ def process(self, data_batch, data_samples: Sequence[dict]):
+ """Process one batch of data samples.
+
+ The processed results should be stored in ``self.results``, which will
+ be used to compute the metrics when all batches have been processed.
+ Args:
+ data_batch: A batch of data from the dataloader.
+ data_samples (Sequence[dict]): A batch of outputs from the model.
+ """
+ for task_name in self.task_metrics.keys():
+ filtered_data_samples = []
+ for data_sample in data_samples:
+ eval_mask = data_sample[task_name]['eval_mask']
+ if eval_mask:
+ filtered_data_samples.append(data_sample[task_name])
+ for metric in self._metrics[task_name]:
+ metric.process(data_batch, filtered_data_samples)
+
+ def compute_metrics(self, results: list) -> dict:
+ raise NotImplementedError(
+ 'compute metrics should not be used here directly')
+
+ def evaluate(self, size):
+ """Evaluate the model performance of the whole dataset after processing
+ all batches.
+
+ Args:
+ size (int): Length of the entire validation dataset. When batch
+ size > 1, the dataloader may pad some data samples to make
+ sure all ranks have the same length of dataset slice. The
+ ``collect_results`` function will drop the padded data based on
+ this size.
+ Returns:
+ dict: Evaluation metrics dict on the val dataset. The keys are
+ "{task_name}_{metric_name}" , and the values
+ are corresponding results.
+ """
+ metrics = {}
+ for task_name in self._metrics:
+ for metric in self._metrics[task_name]:
+ name = metric.__class__.__name__
+ if name == 'MultiTasksMetric' or metric.results:
+ results = metric.evaluate(size)
+ else:
+ results = {metric.__class__.__name__: 0}
+ for key in results:
+ name = f'{task_name}_{key}'
+ if name in metrics:
+ # Inspired by https://github.com/open-mmlab/mmengine/blob/ed20a9cba52ceb371f7c825131636b9e2747172e/mmengine/evaluator/evaluator.py#L84-L87
+ raise ValueError(
+ 'There are multiple metric results with the same '
+ f'metric name {name}. Please make sure all metrics '
+ 'have different prefixes.')
+ metrics[name] = results[key]
+ return metrics
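A hedged config sketch wiring the metric above as the evaluator for a two-task head; the task names and per-task metric choices are illustrative, following the docstring example.

.. code:: python

    # Placeholder evaluator config; task names and metrics are assumptions.
    val_evaluator = dict(
        type='MultiTasksMetric',
        task_metrics={
            'task0': [dict(type='Accuracy', topk=(1, ))],
            'task1': [dict(type='Accuracy', topk=(1, 3))],
        })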
diff --git a/mmpretrain/evaluation/metrics/nocaps.py b/mmpretrain/evaluation/metrics/nocaps.py
new file mode 100644
index 0000000000000000000000000000000000000000..e8e1d0625b66dfa1abe59bd6f83ea2a6c0b3d446
--- /dev/null
+++ b/mmpretrain/evaluation/metrics/nocaps.py
@@ -0,0 +1,59 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional
+
+import mmengine
+
+from mmpretrain.registry import METRICS
+from mmpretrain.utils import require
+from .caption import COCOCaption, save_result
+
+try:
+ from pycocoevalcap.eval import COCOEvalCap
+ from pycocotools.coco import COCO
+except ImportError:
+ COCOEvalCap = None
+ COCO = None
+
+
+@METRICS.register_module()
+class NocapsSave(COCOCaption):
+ """Nocaps evaluation wrapper.
+
+ Save the generated captions and transform into coco format.
+ The dumped file can be submitted to the official evaluation system.
+
+ Args:
+ save_dir (str): The directory to save the result file. Defaults
+ to ``'./'``.
+ collect_device (str): Device name used for collecting results from
+ different ranks during distributed training. Must be 'cpu' or
+ 'gpu'. Defaults to 'cpu'.
+ prefix (str, optional): The prefix that will be added in the metric
+ names to disambiguate homonymous metrics of different evaluators.
+ If prefix is not provided in the argument, self.default_prefix
+ will be used instead. Defaults to None.
+ """
+
+ @require('pycocoevalcap')
+ def __init__(self,
+ save_dir: str = './',
+ collect_device: str = 'cpu',
+ prefix: Optional[str] = None):
+ super(COCOCaption, self).__init__(
+ collect_device=collect_device, prefix=prefix)
+ self.save_dir = save_dir
+
+ def compute_metrics(self, results: List):
+ """Compute the metrics from processed results.
+
+ Args:
+ results (list): The processed results of each batch.
+ """
+ mmengine.mkdir_or_exist(self.save_dir)
+ save_result(
+ result=results,
+ result_dir=self.save_dir,
+ filename='nocap_pred',
+ remove_duplicate='image_id',
+ )
+
+ return dict()
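A hedged usage sketch for the saver above; the output directory is an assumption. The metric dumps ``nocap_pred.json`` via ``save_result`` and intentionally returns an empty dict, since scoring happens on the official server.

.. code:: python

    # Placeholder config fragment; the save_dir value is an assumption.
    test_evaluator = dict(type='NocapsSave', save_dir='./nocaps_submission')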
diff --git a/mmpretrain/evaluation/metrics/retrieval.py b/mmpretrain/evaluation/metrics/retrieval.py
new file mode 100644
index 0000000000000000000000000000000000000000..9813486b521c5b73d7be96901ea4f604bbe2a938
--- /dev/null
+++ b/mmpretrain/evaluation/metrics/retrieval.py
@@ -0,0 +1,445 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional, Sequence, Union
+
+import mmengine
+import numpy as np
+import torch
+from mmengine.evaluator import BaseMetric
+from mmengine.utils import is_seq_of
+
+from mmpretrain.registry import METRICS
+from mmpretrain.structures import label_to_onehot
+from .single_label import to_tensor
+
+
+@METRICS.register_module()
+class RetrievalRecall(BaseMetric):
+ r"""Recall evaluation metric for image retrieval.
+
+ Args:
+ topk (int | Sequence[int]): If the ground truth label matches one of
+ the best **k** predictions, the sample will be regarded as a positive
+ prediction. If the parameter is a tuple, the recall at all of these
+ top-k values will be calculated and output together. Defaults to 1.
+ collect_device (str): Device name used for collecting results from
+ different ranks during distributed training. Must be 'cpu' or
+ 'gpu'. Defaults to 'cpu'.
+ prefix (str, optional): The prefix that will be added in the metric
+ names to disambiguate homonymous metrics of different evaluators.
+ If prefix is not provided in the argument, self.default_prefix
+ will be used instead. Defaults to None.
+
+ Examples:
+ Use in the code:
+
+ >>> import torch
+ >>> from mmpretrain.evaluation import RetrievalRecall
+ >>> # -------------------- The Basic Usage --------------------
+ >>> y_pred = [[0], [1], [2], [3]]
+ >>> y_true = [[0, 1], [2], [1], [0, 3]]
+ >>> RetrievalRecall.calculate(
+ >>> y_pred, y_true, topk=1, pred_indices=True, target_indices=True)
+ [tensor([50.])]
+ >>> # Calculate the recall@1 and recall@5 for non-indices input.
+ >>> y_score = torch.rand((1000, 10))
+ >>> import torch.nn.functional as F
+ >>> y_true = F.one_hot(torch.arange(0, 1000) % 10, num_classes=10)
+ >>> RetrievalRecall.calculate(y_score, y_true, topk=(1, 5))
+ [tensor(9.3000), tensor(48.4000)]
+ >>>
+ >>> # ------------------- Use with Evalutor -------------------
+ >>> from mmpretrain.structures import DataSample
+ >>> from mmengine.evaluator import Evaluator
+ >>> data_samples = [
+ ... DataSample().set_gt_label([0, 1]).set_pred_score(
+ ... torch.rand(10))
+ ... for i in range(1000)
+ ... ]
+ >>> evaluator = Evaluator(metrics=RetrievalRecall(topk=(1, 5)))
+ >>> evaluator.process(data_samples)
+ >>> evaluator.evaluate(1000)
+ {'retrieval/Recall@1': 20.700000762939453,
+ 'retrieval/Recall@5': 78.5999984741211}
+
+ Use in OpenMMLab configs:
+
+ .. code:: python
+
+ val_evaluator = dict(type='RetrievalRecall', topk=(1, 5))
+ test_evaluator = val_evaluator
+ """
+ default_prefix: Optional[str] = 'retrieval'
+
+ def __init__(self,
+ topk: Union[int, Sequence[int]],
+ collect_device: str = 'cpu',
+ prefix: Optional[str] = None) -> None:
+ topk = (topk, ) if isinstance(topk, int) else topk
+
+ for k in topk:
+ if k <= 0:
+ raise ValueError('`topk` must be an integer larger than 0 '
+ 'or a sequence of integers larger than 0.')
+
+ self.topk = topk
+ super().__init__(collect_device=collect_device, prefix=prefix)
+
+ def process(self, data_batch: Sequence[dict],
+ data_samples: Sequence[dict]):
+ """Process one batch of data and predictions.
+
+ The processed results should be stored in ``self.results``, which will
+ be used to compute the metrics when all batches have been processed.
+
+ Args:
+ data_batch (Sequence[dict]): A batch of data from the dataloader.
+ data_samples (Sequence[dict]): A batch of outputs from the model.
+ """
+ for data_sample in data_samples:
+ pred_score = data_sample['pred_score'].clone()
+ gt_label = data_sample['gt_label']
+
+ if 'gt_score' in data_sample:
+ target = data_sample.get('gt_score').clone()
+ else:
+ num_classes = pred_score.size()[-1]
+ target = label_to_onehot(gt_label, num_classes)
+
+ # Because the retrieval output logit vector will be much larger
+ # compared to the normal classification, to save resources, the
+ # evaluation results are computed each batch here and then reduce
+ # all results at the end.
+ result = RetrievalRecall.calculate(
+ pred_score.unsqueeze(0), target.unsqueeze(0), topk=self.topk)
+ self.results.append(result)
+
+ def compute_metrics(self, results: List):
+ """Compute the metrics from processed results.
+
+ Args:
+ results (list): The processed results of each batch.
+
+ Returns:
+ Dict: The computed metrics. The keys are the names of the metrics,
+ and the values are corresponding results.
+ """
+ result_metrics = dict()
+ for i, k in enumerate(self.topk):
+ recall_at_k = sum([r[i].item() for r in results]) / len(results)
+ result_metrics[f'Recall@{k}'] = recall_at_k
+
+ return result_metrics
+
+ @staticmethod
+ def calculate(pred: Union[np.ndarray, torch.Tensor],
+ target: Union[np.ndarray, torch.Tensor],
+ topk: Union[int, Sequence[int]],
+ pred_indices: bool = False,
+ target_indices: bool = False) -> List[float]:
+ """Calculate the average recall.
+
+ Args:
+ pred (torch.Tensor | np.ndarray | Sequence): The prediction
+ results. A :obj:`torch.Tensor` or :obj:`np.ndarray` with
+ shape ``(N, M)`` or a sequence of index/onehot
+ format labels.
+ target (torch.Tensor | np.ndarray | Sequence): The ground truth
+ labels. A :obj:`torch.Tensor` or :obj:`np.ndarray` with
+ shape ``(N, M)`` or a sequence of index/onehot
+ format labels.
+ topk (int, Sequence[int]): Predictions with the k-th highest
+ scores are considered as positive.
+ pred_indices (bool): Whether the ``pred`` is a sequence of
+ category index labels. Defaults to False.
+ target_indices (bool): Whether the ``target`` is a sequence of
+ category index labels. Defaults to False.
+
+ Returns:
+ List[float]: the average recalls.
+ """
+ topk = (topk, ) if isinstance(topk, int) else topk
+ for k in topk:
+ if k <= 0:
+ raise ValueError('`topk` must be an integer larger than 0 '
+ 'or a sequence of integers larger than 0.')
+
+ max_keep = max(topk)
+ pred = _format_pred(pred, max_keep, pred_indices)
+ target = _format_target(target, target_indices)
+
+ assert len(pred) == len(target), (
+ f'Length of `pred`({len(pred)}) and `target` ({len(target)}) '
+ f'must be the same.')
+
+ num_samples = len(pred)
+ results = []
+ for k in topk:
+ recalls = torch.zeros(num_samples)
+ for i, (sample_pred,
+ sample_target) in enumerate(zip(pred, target)):
+ sample_pred = np.array(to_tensor(sample_pred).cpu())
+ sample_target = np.array(to_tensor(sample_target).cpu())
+ recalls[i] = int(np.in1d(sample_pred[:k], sample_target).max())
+ results.append(recalls.mean() * 100)
+ return results
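+
+# Worked example of the per-sample hit logic above (illustrative): with
+# topk=(2, ), a sample whose top-2 predicted indices are [3, 1] and whose
+# target indices are [1, 4] shares index 1, so ``np.in1d(...).max()`` is 1
+# and the sample counts as a hit; Recall@2 is the mean of these 0/1 hits
+# over all samples, scaled to a percentage.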
+
+
+@METRICS.register_module()
+class RetrievalAveragePrecision(BaseMetric):
+ r"""Calculate the average precision for image retrieval.
+
+ Args:
+ topk (int, optional): Predictions with the k-th highest scores are
+ considered as positive.
+ mode (str, optional): The mode to calculate AP, choose from
+ 'IR'(information retrieval) and 'integrate'. Defaults to 'IR'.
+ collect_device (str): Device name used for collecting results from
+ different ranks during distributed training. Must be 'cpu' or
+ 'gpu'. Defaults to 'cpu'.
+ prefix (str, optional): The prefix that will be added in the metric
+ names to disambiguate homonymous metrics of different evaluators.
+ If prefix is not provided in the argument, self.default_prefix
+ will be used instead. Defaults to None.
+
+ Note:
+ If ``mode`` is set to 'IR', the standard AP calculation of
+ information retrieval is used, as in the Wikipedia page[1]; if set
+ to 'integrate', the method integrates over the precision-recall
+ curve by averaging two adjacent precision points and then
+ multiplying by the recall step, like mAP in detection tasks. This
+ is the convention for the Revisited Oxford/Paris datasets[2].
+
+ References:
+ [1] `Wikipedia entry for the Average precision `_
+
+ [2] `The Oxford Buildings Dataset
+ `_
+
+ Examples:
+ Use in code:
+
+ >>> import torch
+ >>> import numpy as np
+ >>> from mmpretrain.evaluation import RetrievalAveragePrecision
+ >>> # using index format inputs
+ >>> pred = [ torch.Tensor([idx for idx in range(100)]) ] * 3
+ >>> target = [[0, 3, 6, 8, 35], [1, 2, 54, 105], [2, 42, 205]]
+ >>> RetrievalAveragePrecision.calculate(pred, target, 10, True, True)
+ 29.246031746031747
+ >>> # using tensor format inputs
+ >>> pred = np.array([np.linspace(0.95, 0.05, 10)] * 2)
+ >>> target = torch.Tensor([[1, 0, 1, 0, 0, 1, 0, 0, 1, 1]] * 2)
+ >>> RetrievalAveragePrecision.calculate(pred, target, 10)
+ 62.222222222222214
+
+ Use in OpenMMLab config files:
+
+ .. code:: python
+
+ val_evaluator = dict(type='RetrievalAveragePrecision', topk=100)
+ test_evaluator = val_evaluator
+ """
+
+ default_prefix: Optional[str] = 'retrieval'
+
+ def __init__(self,
+ topk: Optional[int] = None,
+ mode: Optional[str] = 'IR',
+ collect_device: str = 'cpu',
+ prefix: Optional[str] = None) -> None:
+ if topk is None or (isinstance(topk, int) and topk <= 0):
+ raise ValueError('`topk` must be an integer larger than 0.')
+
+ mode_options = ['IR', 'integrate']
+ assert mode in mode_options, \
+ f'Invalid `mode` argument, please specify from {mode_options}.'
+
+ self.topk = topk
+ self.mode = mode
+ super().__init__(collect_device=collect_device, prefix=prefix)
+
+ def process(self, data_batch: Sequence[dict],
+ data_samples: Sequence[dict]):
+ """Process one batch of data and predictions.
+
+ The processed results should be stored in ``self.results``, which will
+ be used to compute the metrics when all batches have been processed.
+ Args:
+ data_batch (Sequence[dict]): A batch of data from the dataloader.
+ data_samples (Sequence[dict]): A batch of outputs from the model.
+ """
+ for data_sample in data_samples:
+ pred_score = data_sample.get('pred_score').clone()
+
+ if 'gt_score' in data_sample:
+ target = data_sample.get('gt_score').clone()
+ else:
+ gt_label = data_sample.get('gt_label')
+ num_classes = pred_score.size()[-1]
+ target = label_to_onehot(gt_label, num_classes)
+
+ # Because the retrieval output logit vector will be much larger
+ # compared to the normal classification, to save resources, the
+ # evaluation results are computed each batch here and then reduce
+ # all results at the end.
+ result = RetrievalAveragePrecision.calculate(
+ pred_score.unsqueeze(0),
+ target.unsqueeze(0),
+ self.topk,
+ mode=self.mode)
+ self.results.append(result)
+
+ def compute_metrics(self, results: List):
+ """Compute the metrics from processed results.
+
+ Args:
+ results (list): The processed results of each batch.
+ Returns:
+ Dict: The computed metrics. The keys are the names of the metrics,
+ and the values are corresponding results.
+ """
+ result_metrics = dict()
+ result_metrics[f'mAP@{self.topk}'] = np.mean(self.results).item()
+
+ return result_metrics
+
+ @staticmethod
+ def calculate(pred: Union[np.ndarray, torch.Tensor],
+ target: Union[np.ndarray, torch.Tensor],
+ topk: Optional[int] = None,
+ pred_indices: bool = False,
+ target_indices: bool = False,
+ mode: str = 'IR') -> float:
+ """Calculate the average precision.
+ Args:
+ pred (torch.Tensor | np.ndarray | Sequence): The prediction
+ results. A :obj:`torch.Tensor` or :obj:`np.ndarray` with
+ shape ``(N, M)`` or a sequence of index/onehot
+ format labels.
+ target (torch.Tensor | np.ndarray | Sequence): The ground truth
+ labels. A :obj:`torch.Tensor` or :obj:`np.ndarray` with
+ shape ``(N, M)`` or a sequence of index/onehot
+ format labels.
+ topk (int, optional): Predictions with the k-th highest scores
+ are considered as positive.
+ pred_indices (bool): Whether the ``pred`` is a sequence of
+ category index labels. Defaults to False.
+ target_indices (bool): Whether the ``target`` is a sequence of
+ category index labels. Defaults to False.
+ mode (Optional[str]): The mode to calculate AP, choose from
+ 'IR'(information retrieval) and 'integrate'. Defaults to 'IR'.
+
+ Note:
+ If ``mode`` is set to 'IR', the standard AP calculation of
+ information retrieval is used, as in the Wikipedia page; if set to
+ 'integrate', the method integrates over the precision-recall curve
+ by averaging two adjacent precision points and then multiplying by
+ the recall step, like mAP in detection tasks. This is the
+ convention for the Revisited Oxford/Paris datasets.
+
+ Returns:
+ float: the average precision of the query image.
+
+ References:
+ [1] `Wikipedia entry for Average precision(information_retrieval)
+ `_
+ [2] `The Oxford Buildings Dataset
+ `_
+ """
+ old_precision = i / rank if rank > 0 else 1
+ cur_precision = (i + 1) / (rank + 1)
+ prediction = (old_precision + cur_precision) / 2
+ ap += prediction
+ ap = ap / len(target)
+
+ return ap * 100
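+
+# Illustrative numbers for the two AP modes described in the docstrings above
+# (assuming hits at 0-based ranks 0 and 2 with two ground-truth items, and
+# the standard IR precision (i + 1) / (rank + 1) per hit):
+#   'IR':        ap = (1 + 2/3) / 2                 ~= 0.833 -> 83.3
+#   'integrate': ap = (1 + (1/2 + 2/3) / 2) / 2     ~= 0.792 -> 79.2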
+
+
+def _format_pred(label, topk=None, is_indices=False):
+ """format various label to List[indices]."""
+ if is_indices:
+ assert isinstance(label, Sequence), \
+ '`pred` must be Sequence of indices when' \
+ f' `pred_indices` set to True, but get {type(label)}'
+ for i, sample_pred in enumerate(label):
+ assert is_seq_of(sample_pred, int) or isinstance(
+ sample_pred, (np.ndarray, torch.Tensor)), \
+ '`pred` should be Sequence of indices when `pred_indices`' \
+ f'set to True. but pred[{i}] is {sample_pred}'
+ if topk:
+ label[i] = sample_pred[:min(topk, len(sample_pred))]
+ return label
+ if isinstance(label, np.ndarray):
+ label = torch.from_numpy(label)
+ elif not isinstance(label, torch.Tensor):
+ raise TypeError(f'The pred must be type of torch.tensor, '
+ f'np.ndarray or Sequence but get {type(label)}.')
+ topk = topk if topk else label.size()[-1]
+ _, indices = label.topk(topk)
+ return indices
+
+
+def _format_target(label, is_indices=False):
+ """format various label to List[indices]."""
+ if is_indices:
+ assert isinstance(label, Sequence), \
+ '`target` must be Sequence of indices when' \
+ f' `target_indices` set to True, but get {type(label)}'
+ for i, sample_gt in enumerate(label):
+ assert is_seq_of(sample_gt, int) or isinstance(
+ sample_gt, (np.ndarray, torch.Tensor)), \
+ '`target` should be Sequence of indices when ' \
+ f'`target_indices` set to True. but target[{i}] is {sample_gt}'
+ return label
+
+ if isinstance(label, np.ndarray):
+ label = torch.from_numpy(label)
+ elif isinstance(label, Sequence) and not mmengine.is_str(label):
+ label = torch.tensor(label)
+ elif not isinstance(label, torch.Tensor):
+ raise TypeError(f'The pred must be type of torch.tensor, '
+ f'np.ndarray or Sequence but get {type(label)}.')
+
+ indices = [sample_gt.nonzero().squeeze(-1) for sample_gt in label]
+ return indices
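+
+# Both helpers above normalise inputs to per-sample index tensors: a one-hot
+# target row such as [0, 1, 0, 1] becomes tensor([1, 3]), while a score row
+# passed to `_format_pred` is reduced to the indices of its top-k scores in
+# ranked order.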
diff --git a/mmpretrain/evaluation/metrics/scienceqa.py b/mmpretrain/evaluation/metrics/scienceqa.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebf01c78cc88e5ce5e232fe837a0d77293386112
--- /dev/null
+++ b/mmpretrain/evaluation/metrics/scienceqa.py
@@ -0,0 +1,170 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import random
+from typing import List, Optional
+
+from mmengine.evaluator import BaseMetric
+
+from mmpretrain.registry import METRICS
+
+
+def get_pred_idx(prediction: str, choices: List[str],
+ options: List[str]) -> int: # noqa
+ """Get the index (e.g. 2) from the prediction (e.g. 'C')
+
+ Args:
+ prediction (str): The prediction from the model,
+ from ['A', 'B', 'C', 'D', 'E']
+ choices (List(str)): The choices for the question,
+ from ['A', 'B', 'C', 'D', 'E']
+ options (List(str)): The options for the question,
+ from ['A', 'B', 'C', 'D', 'E']
+
+ Returns:
+ int: The index of the prediction, from [0, 1, 2, 3, 4]
+ """
+ if prediction in options[:len(choices)]:
+ return options.index(prediction)
+ else:
+ return random.choice(range(len(choices)))
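+
+# For example, with three choices and options ['A', 'B', 'C', 'D', 'E'],
+# get_pred_idx('C', choices, options) returns 2, while an answer outside the
+# valid range (e.g. 'E' here, or free-form text) falls back to a random index
+# in [0, 2].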
+
+
+@METRICS.register_module()
+class ScienceQAMetric(BaseMetric):
+ """Evaluation Metric for ScienceQA.
+
+ Args:
+ options (List(str)): Options for each question. Defaults to
+ ["A", "B", "C", "D", "E"].
+ collect_device (str): Device name used for collecting results from
+ different ranks during distributed training. Must be 'cpu' or
+ 'gpu'. Defaults to 'cpu'.
+ prefix (str, optional): The prefix that will be added in the metric
+ names to disambiguate homonymous metrics of different evaluators.
+ If prefix is not provided in the argument, self.default_prefix
+ will be used instead. Defaults to None.
+ """
+
+ def __init__(self,
+ options: List[str] = ['A', 'B', 'C', 'D', 'E'],
+ collect_device: str = 'cpu',
+ prefix: Optional[str] = None) -> None:
+ super().__init__(collect_device=collect_device, prefix=prefix)
+ self.options = options
+
+ def process(self, data_batch, data_samples) -> None:
+ """Process one batch of data samples.
+
+ data_samples should contain the following keys:
+ 1. pred_answer (str): The prediction from the model,
+ from ['A', 'B', 'C', 'D', 'E']
+ 2. choices (List(str)): The choices for the question,
+ from ['A', 'B', 'C', 'D', 'E']
+ 3. grade (int): The grade for the question, from grade1 to grade12
+ 4. subject (str): The subject for the question, from
+ ['natural science', 'social science', 'language science']
+ 5. answer (str): The answer for the question, from
+ ['A', 'B', 'C', 'D', 'E']
+ 6. hint (str): The hint for the question
+ 7. has_image (bool): Whether or not the question has image
+
+
+ The processed results should be stored in ``self.results``, which will
+ be used to compute the metrics when all batches have been processed.
+
+ Args:
+ data_batch: A batch of data from the dataloader.
+ data_samples (Sequence[dict]): A batch of outputs from the model.
+ """
+ for data_sample in data_samples:
+ result = dict()
+ choices = data_sample.get('choices')
+ result['prediction'] = get_pred_idx(
+ data_sample.get('pred_answer'), choices, self.options)
+ result['grade'] = data_sample.get('grade')
+ result['subject'] = data_sample.get('subject')
+ result['answer'] = data_sample.get('gt_answer')
+ hint = data_sample.get('hint')
+ has_image = data_sample.get('has_image', False)
+ result['no_context'] = True if not has_image and len(
+ hint) == 0 else False # noqa
+ result['has_text'] = True if len(hint) > 0 else False
+ result['has_image'] = has_image
+
+ # Save the result to `self.results`.
+ self.results.append(result)
+
+ def compute_metrics(self, results: List) -> dict:
+ """Compute the metrics from processed results.
+
+ Args:
+ results (list): The processed results of each batch.
+
+ Returns:
+ Dict: The computed metrics. The keys are the names of the metrics,
+ and the values are corresponding results.
+ """
+ # NOTICE: don't access `self.results` from the method.
+ metrics = dict()
+
+ all_acc = []
+ acc_natural = []
+ acc_social = []
+ acc_language = []
+ acc_has_text = []
+ acc_has_image = []
+ acc_no_context = []
+ acc_grade_1_6 = []
+ acc_grade_7_12 = []
+
+ for result in results:
+ correct = result['prediction'] == result['answer']
+ all_acc.append(correct)
+ # different subjects
+ if result['subject'] == 'natural science':
+ acc_natural.append(correct)
+ elif result['subject'] == 'social science':
+ acc_social.append(correct)
+ elif result['subject'] == 'language science':
+ acc_language.append(correct)
+
+ # different context
+ if result['has_text']:
+ acc_has_text.append(correct)
+ elif result['has_image']:
+ acc_has_image.append(correct)
+ elif result['no_context']:
+ acc_no_context.append(correct)
+
+ # different grade
+ if result['grade'] in [
+ 'grade1', 'grade2', 'grade3', 'grade4', 'grade5', 'grade6'
+ ]:
+ acc_grade_1_6.append(correct)
+ elif result['grade'] in [
+ 'grade7', 'grade8', 'grade9', 'grade10', 'grade11',
+ 'grade12'
+ ]:
+ acc_grade_7_12.append(correct)
+
+ metrics['all_acc'] = sum(all_acc) / len(all_acc)
+ if len(acc_natural) > 0:
+ metrics['acc_natural'] = sum(acc_natural) / len(acc_natural)
+ if len(acc_social) > 0:
+ metrics['acc_social'] = sum(acc_social) / len(acc_social)
+ if len(acc_language) > 0:
+ metrics['acc_language'] = sum(acc_language) / len(acc_language)
+ if len(acc_has_text) > 0:
+ metrics['acc_has_text'] = sum(acc_has_text) / len(acc_has_text)
+ if len(acc_has_image) > 0:
+ metrics['acc_has_image'] = sum(acc_has_image) / len(acc_has_image)
+ if len(acc_no_context) > 0:
+ metrics['acc_no_context'] = sum(acc_no_context) / len(
+ acc_no_context)
+ if len(acc_grade_1_6) > 0:
+ metrics['acc_grade_1_6'] = sum(acc_grade_1_6) / len(acc_grade_1_6)
+ if len(acc_grade_7_12) > 0:
+ metrics['acc_grade_7_12'] = sum(acc_grade_7_12) / len(
+ acc_grade_7_12)
+
+ return metrics
diff --git a/mmpretrain/evaluation/metrics/shape_bias_label.py b/mmpretrain/evaluation/metrics/shape_bias_label.py
new file mode 100644
index 0000000000000000000000000000000000000000..27c80a36073a9e6edd5e6583e213ed93374b165e
--- /dev/null
+++ b/mmpretrain/evaluation/metrics/shape_bias_label.py
@@ -0,0 +1,172 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import csv
+import os
+import os.path as osp
+from typing import List, Sequence
+
+import numpy as np
+import torch
+from mmengine.dist.utils import get_rank
+from mmengine.evaluator import BaseMetric
+
+from mmpretrain.registry import METRICS
+
+
+@METRICS.register_module()
+class ShapeBiasMetric(BaseMetric):
+ """Evaluate the model on ``cue_conflict`` dataset.
+
+ This module will evaluate the model on an OOD dataset, cue_conflict, in
+ order to measure the shape bias of the model. In addition to computing the
+ Top-1 accuracy, this module also generates a csv file to record the
+ detailed prediction results, such that this csv file can be used to
+ generate the shape bias curve.
+
+ Args:
+ csv_dir (str): The directory to save the csv file.
+ model_name (str): The name of the csv file. Please note that the
+ model name should be a unique identifier.
+ dataset_name (str): The name of the dataset. Default: 'cue_conflict'.
+ """
+
+ # mapping several classes from ImageNet-1K to the same category
+ airplane_indices = [404]
+ bear_indices = [294, 295, 296, 297]
+ bicycle_indices = [444, 671]
+ bird_indices = [
+ 8, 10, 11, 12, 13, 14, 15, 16, 18, 19, 20, 22, 23, 24, 80, 81, 82, 83,
+ 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 98, 99, 100, 127, 128, 129,
+ 130, 131, 132, 133, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144,
+ 145
+ ]
+ boat_indices = [472, 554, 625, 814, 914]
+ bottle_indices = [440, 720, 737, 898, 899, 901, 907]
+ car_indices = [436, 511, 817]
+ cat_indices = [281, 282, 283, 284, 285, 286]
+ chair_indices = [423, 559, 765, 857]
+ clock_indices = [409, 530, 892]
+ dog_indices = [
+ 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165,
+ 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179,
+ 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 193, 194,
+ 195, 196, 197, 198, 199, 200, 201, 202, 203, 205, 206, 207, 208, 209,
+ 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223,
+ 224, 225, 226, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238,
+ 239, 240, 241, 243, 244, 245, 246, 247, 248, 249, 250, 252, 253, 254,
+ 255, 256, 257, 259, 261, 262, 263, 265, 266, 267, 268
+ ]
+ elephant_indices = [385, 386]
+ keyboard_indices = [508, 878]
+ knife_indices = [499]
+ oven_indices = [766]
+ truck_indices = [555, 569, 656, 675, 717, 734, 864, 867]
+
+ def __init__(self,
+ csv_dir: str,
+ model_name: str,
+ dataset_name: str = 'cue_conflict',
+ **kwargs) -> None:
+ super().__init__(**kwargs)
+
+ self.categories = sorted([
+ 'knife', 'keyboard', 'elephant', 'bicycle', 'airplane', 'clock',
+ 'oven', 'chair', 'bear', 'boat', 'cat', 'bottle', 'truck', 'car',
+ 'bird', 'dog'
+ ])
+ self.csv_dir = csv_dir
+ self.model_name = model_name
+ self.dataset_name = dataset_name
+ if get_rank() == 0:
+ self.csv_path = self.create_csv()
+
+ def process(self, data_batch, data_samples: Sequence[dict]) -> None:
+ """Process one batch of data samples.
+
+ The processed results should be stored in ``self.results``, which will
+ be used to compute the metrics when all batches have been processed.
+
+ Args:
+ data_batch: A batch of data from the dataloader.
+ data_samples (Sequence[dict]): A batch of outputs from the model.
+ """
+ for data_sample in data_samples:
+ result = dict()
+ if 'pred_score' in data_sample:
+ result['pred_score'] = data_sample['pred_score'].cpu()
+ else:
+ result['pred_label'] = data_sample['pred_label'].cpu()
+ result['gt_label'] = data_sample['gt_label'].cpu()
+ result['gt_category'] = data_sample['img_path'].split('/')[-2]
+ result['img_name'] = data_sample['img_path'].split('/')[-1]
+
+ aggregated_category_probabilities = []
+ # get the prediction for each category of current instance
+ for category in self.categories:
+ category_indices = getattr(self, f'{category}_indices')
+ category_probabilities = torch.gather(
+ result['pred_score'], 0,
+ torch.tensor(category_indices)).mean()
+ aggregated_category_probabilities.append(
+ category_probabilities)
+ # sort the probabilities in descending order
+ pred_indices = torch.stack(aggregated_category_probabilities
+ ).argsort(descending=True).numpy()
+ result['pred_category'] = np.take(self.categories, pred_indices)
+
+ # Save the result to `self.results`.
+ self.results.append(result)
+
+ def create_csv(self) -> str:
+ """Create a csv file to store the results."""
+ session_name = 'session-1'
+ csv_path = osp.join(
+ self.csv_dir, self.dataset_name + '_' + self.model_name + '_' +
+ session_name + '.csv')
+ if osp.exists(csv_path):
+ os.remove(csv_path)
+ directory = osp.dirname(csv_path)
+ if not osp.exists(directory):
+ os.makedirs(directory, exist_ok=True)
+ with open(csv_path, 'w') as f:
+ writer = csv.writer(f)
+ writer.writerow([
+ 'subj', 'session', 'trial', 'rt', 'object_response',
+ 'category', 'condition', 'imagename'
+ ])
+ return csv_path
+
+ def dump_results_to_csv(self, results: List[dict]) -> None:
+ """Dump the results to a csv file.
+
+ Args:
+ results (List[dict]): A list of results.
+ """
+ for i, result in enumerate(results):
+ img_name = result['img_name']
+ category = result['gt_category']
+ condition = 'NaN'
+ with open(self.csv_path, 'a') as f:
+ writer = csv.writer(f)
+ writer.writerow([
+ self.model_name, 1, i + 1, 'NaN',
+ result['pred_category'][0], category, condition, img_name
+ ])
+
+ def compute_metrics(self, results: List[dict]) -> dict:
+ """Compute the metrics from the results.
+
+ Args:
+ results (List[dict]): A list of results.
+
+ Returns:
+ dict: A dict of metrics.
+ """
+ if get_rank() == 0:
+ self.dump_results_to_csv(results)
+ metrics = dict()
+ metrics['accuracy/top1'] = np.mean([
+ result['pred_category'][0] == result['gt_category']
+ for result in results
+ ])
+
+ return metrics
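+
+# A minimal config sketch (the directory and model name below are
+# illustrative, not from the original file). The metric reports the
+# 16-category top-1 accuracy and additionally writes
+# `<dataset_name>_<model_name>_session-1.csv` into `csv_dir`:
+#
+#   val_evaluator = dict(
+#       type='ShapeBiasMetric',
+#       csv_dir='work_dirs/shape_bias',
+#       model_name='vit-base_cue-conflict')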
diff --git a/mmpretrain/evaluation/metrics/single_label.py b/mmpretrain/evaluation/metrics/single_label.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9329b9567e698a4e3ebdb7d77f0f8404b81ad4c
--- /dev/null
+++ b/mmpretrain/evaluation/metrics/single_label.py
@@ -0,0 +1,776 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from itertools import product
+from typing import List, Optional, Sequence, Union
+
+import mmengine
+import numpy as np
+import torch
+import torch.nn.functional as F
+from mmengine.evaluator import BaseMetric
+
+from mmpretrain.registry import METRICS
+
+
+def to_tensor(value):
+ """Convert value to torch.Tensor."""
+ if isinstance(value, np.ndarray):
+ value = torch.from_numpy(value)
+ elif isinstance(value, Sequence) and not mmengine.is_str(value):
+ value = torch.tensor(value)
+ elif not isinstance(value, torch.Tensor):
+ raise TypeError(f'{type(value)} is not an available argument.')
+ return value
+
+
+def _precision_recall_f1_support(pred_positive, gt_positive, average):
+ """calculate base classification task metrics, such as precision, recall,
+ f1_score, support."""
+ average_options = ['micro', 'macro', None]
+ assert average in average_options, 'Invalid `average` argument, ' \
+ f'please specify from {average_options}.'
+
+ # Ignore targets labeled -1 (e.g. difficult samples) that should not
+ # be counted in the evaluation results. This only applies to
+ # multi-label evaluation and does not affect single-label behavior.
+ ignored_index = gt_positive == -1
+ pred_positive[ignored_index] = 0
+ gt_positive[ignored_index] = 0
+
+ class_correct = (pred_positive & gt_positive)
+ if average == 'micro':
+ tp_sum = class_correct.sum()
+ pred_sum = pred_positive.sum()
+ gt_sum = gt_positive.sum()
+ else:
+ tp_sum = class_correct.sum(0)
+ pred_sum = pred_positive.sum(0)
+ gt_sum = gt_positive.sum(0)
+
+ precision = tp_sum / torch.clamp(pred_sum, min=1).float() * 100
+ recall = tp_sum / torch.clamp(gt_sum, min=1).float() * 100
+ f1_score = 2 * precision * recall / torch.clamp(
+ precision + recall, min=torch.finfo(torch.float32).eps)
+ if average in ['macro', 'micro']:
+ precision = precision.mean(0)
+ recall = recall.mean(0)
+ f1_score = f1_score.mean(0)
+ support = gt_sum.sum(0)
+ else:
+ support = gt_sum
+ return precision, recall, f1_score, support
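+
+# Toy example for the averaging modes above (illustrative): with per-class
+# TP = [1, 0], predicted positives = [2, 1] and ground-truth positives
+# = [1, 1], 'macro' precision is mean(50.0, 0.0) = 25.0, while 'micro'
+# precision is 100 * sum(TP) / sum(predicted) = 100 / 3 ~= 33.3.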
+
+
+@METRICS.register_module()
+class Accuracy(BaseMetric):
+ r"""Accuracy evaluation metric.
+
+ For either binary classification or multi-class classification, the
+ accuracy is the fraction of correct predictions in all predictions:
+
+ .. math::
+
+ \text{Accuracy} = \frac{N_{\text{correct}}}{N_{\text{all}}}
+
+ Args:
+ topk (int | Sequence[int]): If the ground truth label matches one of
+ the best **k** predictions, the sample will be regard as a positive
+ prediction. If the parameter is a tuple, all of top-k accuracy will
+ be calculated and outputted together. Defaults to 1.
+ thrs (Sequence[float | None] | float | None): If a float, predictions
+ with score lower than the threshold will be regard as the negative
+ prediction. If None, not apply threshold. If the parameter is a
+ tuple, accuracy based on all thresholds will be calculated and
+ outputted together. Defaults to 0.
+ collect_device (str): Device name used for collecting results from
+ different ranks during distributed training. Must be 'cpu' or
+ 'gpu'. Defaults to 'cpu'.
+ prefix (str, optional): The prefix that will be added in the metric
+ names to disambiguate homonymous metrics of different evaluators.
+ If prefix is not provided in the argument, self.default_prefix
+ will be used instead. Defaults to None.
+
+ Examples:
+ >>> import torch
+ >>> from mmpretrain.evaluation import Accuracy
+ >>> # -------------------- The Basic Usage --------------------
+ >>> y_pred = [0, 2, 1, 3]
+ >>> y_true = [0, 1, 2, 3]
+ >>> Accuracy.calculate(y_pred, y_true)
+ tensor([50.])
+ >>> # Calculate the top1 and top5 accuracy.
+ >>> y_score = torch.rand((1000, 10))
+ >>> y_true = torch.zeros((1000, ))
+ >>> Accuracy.calculate(y_score, y_true, topk=(1, 5))
+ [[tensor([9.9000])], [tensor([51.5000])]]
+ >>>
+ >>> # ------------------- Use with Evalutor -------------------
+ >>> from mmpretrain.structures import DataSample
+ >>> from mmengine.evaluator import Evaluator
+ >>> data_samples = [
+ ... DataSample().set_gt_label(0).set_pred_score(torch.rand(10))
+ ... for i in range(1000)
+ ... ]
+ >>> evaluator = Evaluator(metrics=Accuracy(topk=(1, 5)))
+ >>> evaluator.process(data_samples)
+ >>> evaluator.evaluate(1000)
+ {
+ 'accuracy/top1': 9.300000190734863,
+ 'accuracy/top5': 51.20000076293945
+ }
+ """
+ default_prefix: Optional[str] = 'accuracy'
+
+ def __init__(self,
+ topk: Union[int, Sequence[int]] = (1, ),
+ thrs: Union[float, Sequence[Union[float, None]], None] = 0.,
+ collect_device: str = 'cpu',
+ prefix: Optional[str] = None) -> None:
+ super().__init__(collect_device=collect_device, prefix=prefix)
+
+ if isinstance(topk, int):
+ self.topk = (topk, )
+ else:
+ self.topk = tuple(topk)
+
+ if isinstance(thrs, float) or thrs is None:
+ self.thrs = (thrs, )
+ else:
+ self.thrs = tuple(thrs)
+
+ def process(self, data_batch, data_samples: Sequence[dict]):
+ """Process one batch of data samples.
+
+ The processed results should be stored in ``self.results``, which will
+ be used to compute the metrics when all batches have been processed.
+
+ Args:
+ data_batch: A batch of data from the dataloader.
+ data_samples (Sequence[dict]): A batch of outputs from the model.
+ """
+
+ for data_sample in data_samples:
+ result = dict()
+ if 'pred_score' in data_sample:
+ result['pred_score'] = data_sample['pred_score'].cpu()
+ else:
+ result['pred_label'] = data_sample['pred_label'].cpu()
+ result['gt_label'] = data_sample['gt_label'].cpu()
+ # Save the result to `self.results`.
+ self.results.append(result)
+
+ def compute_metrics(self, results: List):
+ """Compute the metrics from processed results.
+
+ Args:
+ results (list): The processed results of each batch.
+
+ Returns:
+ Dict: The computed metrics. The keys are the names of the metrics,
+ and the values are corresponding results.
+ """
+ # NOTICE: don't access `self.results` from the method.
+ metrics = {}
+
+ # concat
+ target = torch.cat([res['gt_label'] for res in results])
+ if 'pred_score' in results[0]:
+ pred = torch.stack([res['pred_score'] for res in results])
+
+ try:
+ acc = self.calculate(pred, target, self.topk, self.thrs)
+ except ValueError as e:
+ # If the topk is invalid.
+ raise ValueError(
+ str(e) + ' Please check the `val_evaluator` and '
+ '`test_evaluator` fields in your config file.')
+
+ multi_thrs = len(self.thrs) > 1
+ for i, k in enumerate(self.topk):
+ for j, thr in enumerate(self.thrs):
+ name = f'top{k}'
+ if multi_thrs:
+ name += '_no-thr' if thr is None else f'_thr-{thr:.2f}'
+ metrics[name] = acc[i][j].item()
+ else:
+ # If only the label is available in `pred_label`.
+ pred = torch.cat([res['pred_label'] for res in results])
+ acc = self.calculate(pred, target, self.topk, self.thrs)
+ metrics['top1'] = acc.item()
+
+ return metrics
+
+ @staticmethod
+ def calculate(
+ pred: Union[torch.Tensor, np.ndarray, Sequence],
+ target: Union[torch.Tensor, np.ndarray, Sequence],
+ topk: Sequence[int] = (1, ),
+ thrs: Sequence[Union[float, None]] = (0., ),
+ ) -> Union[torch.Tensor, List[List[torch.Tensor]]]:
+ """Calculate the accuracy.
+
+ Args:
+ pred (torch.Tensor | np.ndarray | Sequence): The prediction
+ results. It can be labels (N, ), or scores of every
+ class (N, C).
+ target (torch.Tensor | np.ndarray | Sequence): The target of
+ each prediction with shape (N, ).
+ topk (Sequence[int]): Predictions with the k-th highest scores
+ are considered as positive. Defaults to (1, ).
+ thrs (Sequence[float | None]): Predictions with scores under
+ the thresholds are considered negative. It's only used
+ when ``pred`` is scores. None means no thresholds.
+ Defaults to (0., ).
+
+ Returns:
+ torch.Tensor | List[List[torch.Tensor]]: Accuracy.
+
+ - torch.Tensor: If the ``pred`` is a sequence of label instead of
+ score (number of dimensions is 1). Only return a top-1 accuracy
+ tensor, and ignore the arguments ``topk`` and ``thrs``.
+ - List[List[torch.Tensor]]: If the ``pred`` is a sequence of score
+ (number of dimensions is 2). Return the accuracy on each ``topk``
+ and ``thrs``. And the first dim is ``topk``, the second dim is
+ ``thrs``.
+ """
+
+ pred = to_tensor(pred)
+ target = to_tensor(target).to(torch.int64)
+ num = pred.size(0)
+ assert pred.size(0) == target.size(0), \
+ f"The size of pred ({pred.size(0)}) doesn't match "\
+ f'the target ({target.size(0)}).'
+
+ if pred.ndim == 1:
+ # For pred label, ignore topk and thrs
+ pred_label = pred.int()
+ correct = pred.eq(target).float().sum(0, keepdim=True)
+ acc = correct.mul_(100. / num)
+ return acc
+ else:
+ # For pred score, calculate on all topk and thresholds.
+ pred = pred.float()
+ maxk = max(topk)
+
+ if maxk > pred.size(1):
+ raise ValueError(
+ f'Top-{maxk} accuracy is unavailable since the number of '
+ f'categories is {pred.size(1)}.')
+
+ pred_score, pred_label = pred.topk(maxk, dim=1)
+ pred_label = pred_label.t()
+ correct = pred_label.eq(target.view(1, -1).expand_as(pred_label))
+ results = []
+ for k in topk:
+ results.append([])
+ for thr in thrs:
+ # Only prediction values larger than thr are counted
+ # as correct
+ _correct = correct
+ if thr is not None:
+ _correct = _correct & (pred_score.t() > thr)
+ correct_k = _correct[:k].reshape(-1).float().sum(
+ 0, keepdim=True)
+ acc = correct_k.mul_(100. / num)
+ results[-1].append(acc)
+ return results
+
+
+@METRICS.register_module()
+class SingleLabelMetric(BaseMetric):
+ r"""A collection of precision, recall, f1-score and support for
+ single-label tasks.
+
+ The collection of metrics is for single-label multi-class classification.
+ And all these metrics are based on the confusion matrix of every category:
+
+ .. image:: ../../_static/image/confusion-matrix.png
+ :width: 60%
+ :align: center
+
+ All metrics can be formulated using the variables above:
+
+ **Precision** is the fraction of correct predictions in all predictions:
+
+ .. math::
+ \text{Precision} = \frac{TP}{TP+FP}
+
+ **Recall** is the fraction of correct predictions in all targets:
+
+ .. math::
+ \text{Recall} = \frac{TP}{TP+FN}
+
+ **F1-score** is the harmonic mean of the precision and recall:
+
+ .. math::
+ \text{F1-score} = \frac{2\times\text{Recall}\times\text{Precision}}{\text{Recall}+\text{Precision}}
+
+ **Support** is the number of samples:
+
+ .. math::
+ \text{Support} = TP + TN + FN + FP
+
+ Args:
+ thrs (Sequence[float | None] | float | None): If a float, predictions
+ with score lower than the threshold will be regard as the negative
+ prediction. If None, only the top-1 prediction will be regard as
+ the positive prediction. If the parameter is a tuple, accuracy
+ based on all thresholds will be calculated and outputted together.
+ Defaults to 0.
+ items (Sequence[str]): The detailed metric items to evaluate, select
+ from "precision", "recall", "f1-score" and "support".
+ Defaults to ``('precision', 'recall', 'f1-score')``.
+ average (str | None): How to calculate the final metrics from the
+ confusion matrix of every category. It supports three modes:
+
+ - `"macro"`: Calculate metrics for each category, and calculate
+ the mean value over all categories.
+ - `"micro"`: Average the confusion matrix over all categories and
+ calculate metrics on the mean confusion matrix.
+ - `None`: Calculate metrics of every category and output directly.
+
+ Defaults to "macro".
+ num_classes (int, optional): The number of classes. Defaults to None.
+ collect_device (str): Device name used for collecting results from
+ different ranks during distributed training. Must be 'cpu' or
+ 'gpu'. Defaults to 'cpu'.
+ prefix (str, optional): The prefix that will be added in the metric
+ names to disambiguate homonymous metrics of different evaluators.
+ If prefix is not provided in the argument, self.default_prefix
+ will be used instead. Defaults to None.
+
+ Examples:
+ >>> import torch
+ >>> from mmpretrain.evaluation import SingleLabelMetric
+ >>> # -------------------- The Basic Usage --------------------
+ >>> y_pred = [0, 1, 1, 3]
+ >>> y_true = [0, 2, 1, 3]
+ >>> # Output precision, recall, f1-score and support.
+ >>> SingleLabelMetric.calculate(y_pred, y_true, num_classes=4)
+ (tensor(62.5000), tensor(75.), tensor(66.6667), tensor(4))
+ >>> # Calculate with different thresholds.
+ >>> y_score = torch.rand((1000, 10))
+ >>> y_true = torch.zeros((1000, ))
+ >>> SingleLabelMetric.calculate(y_score, y_true, thrs=(0., 0.9))
+ [(tensor(10.), tensor(0.9500), tensor(1.7352), tensor(1000)),
+ (tensor(10.), tensor(0.5500), tensor(1.0427), tensor(1000))]
+ >>>
+ >>> # ------------------- Use with Evalutor -------------------
+ >>> from mmpretrain.structures import DataSample
+ >>> from mmengine.evaluator import Evaluator
+ >>> data_samples = [
+ ... DataSample().set_gt_label(i%5).set_pred_score(torch.rand(5))
+ ... for i in range(1000)
+ ... ]
+ >>> evaluator = Evaluator(metrics=SingleLabelMetric())
+ >>> evaluator.process(data_samples)
+ >>> evaluator.evaluate(1000)
+ {'single-label/precision': 19.650691986083984,
+ 'single-label/recall': 19.600000381469727,
+ 'single-label/f1-score': 19.619548797607422}
+ >>> # Evaluate on each class
+ >>> evaluator = Evaluator(metrics=SingleLabelMetric(average=None))
+ >>> evaluator.process(data_samples)
+ >>> evaluator.evaluate(1000)
+ {
+ 'single-label/precision_classwise': [21.1, 18.7, 17.8, 19.4, 16.1],
+ 'single-label/recall_classwise': [18.5, 18.5, 17.0, 20.0, 18.0],
+ 'single-label/f1-score_classwise': [19.7, 18.6, 17.1, 19.7, 17.0]
+ }
+ """ # noqa: E501
+ default_prefix: Optional[str] = 'single-label'
+
+ def __init__(self,
+ thrs: Union[float, Sequence[Union[float, None]], None] = 0.,
+ items: Sequence[str] = ('precision', 'recall', 'f1-score'),
+ average: Optional[str] = 'macro',
+ num_classes: Optional[int] = None,
+ collect_device: str = 'cpu',
+ prefix: Optional[str] = None) -> None:
+ super().__init__(collect_device=collect_device, prefix=prefix)
+
+ if isinstance(thrs, float) or thrs is None:
+ self.thrs = (thrs, )
+ else:
+ self.thrs = tuple(thrs)
+
+ for item in items:
+ assert item in ['precision', 'recall', 'f1-score', 'support'], \
+ f'The metric {item} is not supported by `SingleLabelMetric`,' \
+ ' please specify from "precision", "recall", "f1-score" and ' \
+ '"support".'
+ self.items = tuple(items)
+ self.average = average
+ self.num_classes = num_classes
+
+ def process(self, data_batch, data_samples: Sequence[dict]):
+ """Process one batch of data samples.
+
+ The processed results should be stored in ``self.results``, which will
+ be used to compute the metrics when all batches have been processed.
+
+ Args:
+ data_batch: A batch of data from the dataloader.
+ data_samples (Sequence[dict]): A batch of outputs from the model.
+ """
+
+ for data_sample in data_samples:
+ result = dict()
+ if 'pred_score' in data_sample:
+ result['pred_score'] = data_sample['pred_score'].cpu()
+ else:
+ num_classes = self.num_classes or data_sample.get(
+ 'num_classes')
+ assert num_classes is not None, \
+ 'The `num_classes` must be specified if no `pred_score`.'
+ result['pred_label'] = data_sample['pred_label'].cpu()
+ result['num_classes'] = num_classes
+ result['gt_label'] = data_sample['gt_label'].cpu()
+ # Save the result to `self.results`.
+ self.results.append(result)
+
+ def compute_metrics(self, results: List):
+ """Compute the metrics from processed results.
+
+ Args:
+ results (list): The processed results of each batch.
+
+ Returns:
+ Dict: The computed metrics. The keys are the names of the metrics,
+ and the values are corresponding results.
+ """
+ # NOTICE: don't access `self.results` from the method. `self.results`
+ # are a list of results from multiple batch, while the input `results`
+ # are the collected results.
+ metrics = {}
+
+ def pack_results(precision, recall, f1_score, support):
+ single_metrics = {}
+ if 'precision' in self.items:
+ single_metrics['precision'] = precision
+ if 'recall' in self.items:
+ single_metrics['recall'] = recall
+ if 'f1-score' in self.items:
+ single_metrics['f1-score'] = f1_score
+ if 'support' in self.items:
+ single_metrics['support'] = support
+ return single_metrics
+
+ # concat
+ target = torch.cat([res['gt_label'] for res in results])
+ if 'pred_score' in results[0]:
+ pred = torch.stack([res['pred_score'] for res in results])
+ metrics_list = self.calculate(
+ pred, target, thrs=self.thrs, average=self.average)
+
+ multi_thrs = len(self.thrs) > 1
+ for i, thr in enumerate(self.thrs):
+ if multi_thrs:
+ suffix = '_no-thr' if thr is None else f'_thr-{thr:.2f}'
+ else:
+ suffix = ''
+
+ for k, v in pack_results(*metrics_list[i]).items():
+ metrics[k + suffix] = v
+ else:
+ # If only the label is available in `pred_label`.
+ pred = torch.cat([res['pred_label'] for res in results])
+ res = self.calculate(
+ pred,
+ target,
+ average=self.average,
+ num_classes=results[0]['num_classes'])
+ metrics = pack_results(*res)
+
+ result_metrics = dict()
+ for k, v in metrics.items():
+
+ if self.average is None:
+ result_metrics[k + '_classwise'] = v.cpu().detach().tolist()
+ elif self.average == 'micro':
+ result_metrics[k + f'_{self.average}'] = v.item()
+ else:
+ result_metrics[k] = v.item()
+
+ return result_metrics
+
+ @staticmethod
+ def calculate(
+ pred: Union[torch.Tensor, np.ndarray, Sequence],
+ target: Union[torch.Tensor, np.ndarray, Sequence],
+ thrs: Sequence[Union[float, None]] = (0., ),
+ average: Optional[str] = 'macro',
+ num_classes: Optional[int] = None,
+ ) -> Union[torch.Tensor, List[torch.Tensor]]:
+ """Calculate the precision, recall, f1-score and support.
+
+ Args:
+ pred (torch.Tensor | np.ndarray | Sequence): The prediction
+ results. It can be labels (N, ), or scores of every
+ class (N, C).
+ target (torch.Tensor | np.ndarray | Sequence): The target of
+ each prediction with shape (N, ).
+ thrs (Sequence[float | None]): Predictions with scores under
+ the thresholds are considered negative. It's only used
+ when ``pred`` is scores. None means no thresholds.
+ Defaults to (0., ).
+ average (str | None): How to calculate the final metrics from
+ the confusion matrix of every category. It supports three
+ modes:
+
+ - `"macro"`: Calculate metrics for each category, and calculate
+ the mean value over all categories.
+ - `"micro"`: Average the confusion matrix over all categories
+ and calculate metrics on the mean confusion matrix.
+ - `None`: Calculate metrics of every category and output
+ directly.
+
+ Defaults to "macro".
+ num_classes (int, optional): The number of classes. If the ``pred``
+ is label instead of scores, this argument is required.
+ Defaults to None.
+
+ Returns:
+ Tuple: The tuple contains precision, recall and f1-score.
+ And the type of each item is:
+
+ - torch.Tensor: If the ``pred`` is a sequence of label instead of
+ score (number of dimensions is 1). Only returns a tensor for
+ each metric. The shape is (1, ) if ``average`` is not None, and
+ (C, ) if ``average`` is None.
+ - List[torch.Tensor]: If the ``pred`` is a sequence of score
+ (number of dimensions is 2). Return the metrics on each ``thrs``.
+ The shape of tensor is (1, ) if ``average`` is not None, and
+ (C, ) if ``average`` is None.
+ """
+ average_options = ['micro', 'macro', None]
+ assert average in average_options, 'Invalid `average` argument, ' \
+ f'please specify from {average_options}.'
+
+ pred = to_tensor(pred)
+ target = to_tensor(target).to(torch.int64)
+ assert pred.size(0) == target.size(0), \
+ f"The size of pred ({pred.size(0)}) doesn't match "\
+ f'the target ({target.size(0)}).'
+
+ if pred.ndim == 1:
+ assert num_classes is not None, \
+ 'Please specify the `num_classes` if the `pred` is labels ' \
+ 'instead of scores.'
+ gt_positive = F.one_hot(target.flatten(), num_classes)
+ pred_positive = F.one_hot(pred.to(torch.int64), num_classes)
+ return _precision_recall_f1_support(pred_positive, gt_positive,
+ average)
+ else:
+ # For pred score, calculate on all thresholds.
+ num_classes = pred.size(1)
+ pred_score, pred_label = torch.topk(pred, k=1)
+ pred_score = pred_score.flatten()
+ pred_label = pred_label.flatten()
+
+ gt_positive = F.one_hot(target.flatten(), num_classes)
+
+ results = []
+ for thr in thrs:
+ pred_positive = F.one_hot(pred_label, num_classes)
+ if thr is not None:
+ pred_positive[pred_score <= thr] = 0
+ results.append(
+ _precision_recall_f1_support(pred_positive, gt_positive,
+ average))
+
+ return results
+
+
+@METRICS.register_module()
+class ConfusionMatrix(BaseMetric):
+ r"""A metric to calculate confusion matrix for single-label tasks.
+
+ Args:
+ num_classes (int, optional): The number of classes. Defaults to None.
+ collect_device (str): Device name used for collecting results from
+ different ranks during distributed training. Must be 'cpu' or
+ 'gpu'. Defaults to 'cpu'.
+ prefix (str, optional): The prefix that will be added in the metric
+ names to disambiguate homonymous metrics of different evaluators.
+ If prefix is not provided in the argument, self.default_prefix
+ will be used instead. Defaults to None.
+
+ Examples:
+
+ 1. The basic usage.
+
+ >>> import torch
+ >>> from mmpretrain.evaluation import ConfusionMatrix
+ >>> y_pred = [0, 1, 1, 3]
+ >>> y_true = [0, 2, 1, 3]
+ >>> ConfusionMatrix.calculate(y_pred, y_true, num_classes=4)
+ tensor([[1, 0, 0, 0],
+ [0, 1, 0, 0],
+ [0, 1, 0, 0],
+ [0, 0, 0, 1]])
+ >>> # plot the confusion matrix
+ >>> import matplotlib.pyplot as plt
+ >>> y_score = torch.rand((1000, 10))
+ >>> y_true = torch.randint(10, (1000, ))
+ >>> matrix = ConfusionMatrix.calculate(y_score, y_true)
+ >>> ConfusionMatrix().plot(matrix)
+ >>> plt.show()
+
+ 2. In the config file
+
+ .. code:: python
+
+ val_evaluator = dict(type='ConfusionMatrix')
+ test_evaluator = dict(type='ConfusionMatrix')
+ """ # noqa: E501
+ default_prefix = 'confusion_matrix'
+
+ def __init__(self,
+ num_classes: Optional[int] = None,
+ collect_device: str = 'cpu',
+ prefix: Optional[str] = None) -> None:
+ super().__init__(collect_device, prefix)
+
+ self.num_classes = num_classes
+
+ def process(self, data_batch, data_samples: Sequence[dict]) -> None:
+ for data_sample in data_samples:
+ if 'pred_score' in data_sample:
+ pred_score = data_sample['pred_score']
+ pred_label = pred_score.argmax(dim=0, keepdim=True)
+ self.num_classes = pred_score.size(0)
+ else:
+ pred_label = data_sample['pred_label']
+
+ self.results.append({
+ 'pred_label': pred_label,
+ 'gt_label': data_sample['gt_label'],
+ })
+
+ def compute_metrics(self, results: list) -> dict:
+ pred_labels = []
+ gt_labels = []
+ for result in results:
+ pred_labels.append(result['pred_label'])
+ gt_labels.append(result['gt_label'])
+ confusion_matrix = ConfusionMatrix.calculate(
+ torch.cat(pred_labels),
+ torch.cat(gt_labels),
+ num_classes=self.num_classes)
+ return {'result': confusion_matrix}
+
+ @staticmethod
+ def calculate(pred, target, num_classes=None) -> torch.Tensor:
+ """Calculate the confusion matrix for single-label task.
+
+ Args:
+ pred (torch.Tensor | np.ndarray | Sequence): The prediction
+ results. It can be labels (N, ), or scores of every
+ class (N, C).
+ target (torch.Tensor | np.ndarray | Sequence): The target of
+ each prediction with shape (N, ).
+ num_classes (int, optional): The number of classes. If the ``pred``
+ is label instead of scores, this argument is required.
+ Defaults to None.
+
+ Returns:
+ torch.Tensor: The confusion matrix.
+ """
+ pred = to_tensor(pred)
+ target_label = to_tensor(target).int()
+
+ assert pred.size(0) == target_label.size(0), \
+ f"The size of pred ({pred.size(0)}) doesn't match "\
+ f'the target ({target_label.size(0)}).'
+ assert target_label.ndim == 1
+
+ if pred.ndim == 1:
+ assert num_classes is not None, \
+ 'Please specify the `num_classes` if the `pred` is labels ' \
+ 'instead of scores.'
+ pred_label = pred
+ else:
+ num_classes = num_classes or pred.size(1)
+ pred_label = torch.argmax(pred, dim=1).flatten()
+
+ with torch.no_grad():
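+ # The (target, pred) pair is encoded as a single linear index
+ # ``target * num_classes + pred``; ``bincount`` histograms these
+ # indices and the reshape below yields a matrix whose rows are
+ # true labels and whose columns are predicted labels.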
+ indices = num_classes * target_label + pred_label
+ matrix = torch.bincount(indices, minlength=num_classes**2)
+ matrix = matrix.reshape(num_classes, num_classes)
+
+ return matrix
+
+ @staticmethod
+ def plot(confusion_matrix: torch.Tensor,
+ include_values: bool = False,
+ cmap: str = 'viridis',
+ classes: Optional[List[str]] = None,
+ colorbar: bool = True,
+ show: bool = True):
+ """Draw a confusion matrix by matplotlib.
+
+ Modified from `Scikit-Learn
+ `_
+
+ Args:
+ confusion_matrix (torch.Tensor): The confusion matrix to draw.
+ include_values (bool): Whether to draw the values in the figure.
+ Defaults to False.
+ cmap (str): The color map to use. Defaults to use "viridis".
+ classes (list[str], optional): The names of categories.
+ Defaults to None, which means to use index number.
+ colorbar (bool): Whether to show the colorbar. Defaults to True.
+ show (bool): Whether to show the figure immediately.
+ Defaults to True.
+ """ # noqa: E501
+ import matplotlib.pyplot as plt
+
+ fig, ax = plt.subplots(figsize=(10, 10))
+
+ num_classes = confusion_matrix.size(0)
+
+ im_ = ax.imshow(confusion_matrix, interpolation='nearest', cmap=cmap)
+ text_ = None
+ cmap_min, cmap_max = im_.cmap(0), im_.cmap(1.0)
+
+ if include_values:
+ text_ = np.empty_like(confusion_matrix, dtype=object)
+
+ # print text with appropriate color depending on background
+ thresh = (confusion_matrix.max() + confusion_matrix.min()) / 2.0
+
+ for i, j in product(range(num_classes), range(num_classes)):
+ color = cmap_max if confusion_matrix[i,
+ j] < thresh else cmap_min
+
+ text_cm = format(confusion_matrix[i, j], '.2g')
+ text_d = format(confusion_matrix[i, j], 'd')
+ if len(text_d) < len(text_cm):
+ text_cm = text_d
+
+ text_[i, j] = ax.text(
+ j, i, text_cm, ha='center', va='center', color=color)
+
+ display_labels = classes or np.arange(num_classes)
+
+ if colorbar:
+ fig.colorbar(im_, ax=ax)
+ ax.set(
+ xticks=np.arange(num_classes),
+ yticks=np.arange(num_classes),
+ xticklabels=display_labels,
+ yticklabels=display_labels,
+ ylabel='True label',
+ xlabel='Predicted label',
+ )
+ ax.invert_yaxis()
+ ax.xaxis.tick_top()
+
+ ax.set_ylim((num_classes - 0.5, -0.5))
+ # Automatically rotate the x labels.
+ fig.autofmt_xdate(ha='center')
+
+ if show:
+ plt.show()
+ return fig
diff --git a/mmpretrain/evaluation/metrics/visual_grounding_eval.py b/mmpretrain/evaluation/metrics/visual_grounding_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad16e5adf4660496b3a984087294ed9c0fee6537
--- /dev/null
+++ b/mmpretrain/evaluation/metrics/visual_grounding_eval.py
@@ -0,0 +1,85 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List
+
+import torch
+import torchvision.ops.boxes as boxes
+from mmengine.evaluator import BaseMetric
+
+from mmpretrain.registry import METRICS
+
+
+def aligned_box_iou(boxes1: torch.Tensor, boxes2: torch.Tensor):
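+ """Compute the IoU of aligned (paired) boxes.
+
+ Unlike ``torchvision.ops.boxes.box_iou``, which returns a full pairwise
+ matrix, this compares ``boxes1[i]`` only with ``boxes2[i]`` and returns a
+ tensor of shape ``(B, )``. Boxes are expected in ``(x1, y1, x2, y2)``
+ format.
+ """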
+ area1 = boxes.box_area(boxes1)
+ area2 = boxes.box_area(boxes2)
+
+ lt = torch.max(boxes1[:, :2], boxes2[:, :2]) # (B, 2)
+ rb = torch.min(boxes1[:, 2:], boxes2[:, 2:]) # (B, 2)
+
+ wh = boxes._upcast(rb - lt).clamp(min=0) # (B, 2)
+ inter = wh[:, 0] * wh[:, 1] # (B, )
+
+ union = area1 + area2 - inter
+ iou = inter / union
+ return iou
+
+
+@METRICS.register_module()
+class VisualGroundingMetric(BaseMetric):
+ """Visual Grounding evaluator.
+
+ Calculate the box mIoU and box grounding accuracy for visual grounding
+ models.
+
+ Args:
+ collect_device (str): Device name used for collecting results from
+ different ranks during distributed training. Must be 'cpu' or
+ 'gpu'. Defaults to 'cpu'.
+ prefix (str, optional): The prefix that will be added in the metric
+ names to disambiguate homonymous metrics of different evaluators.
+ If prefix is not provided in the argument, self.default_prefix
+ will be used instead. Defaults to None.
+ """
+ default_prefix = 'visual-grounding'
+
+ def process(self, data_batch, data_samples):
+ """Process one batch of data samples.
+
+ The processed results should be stored in ``self.results``, which will
+ be used to compute the metrics when all batches have been processed.
+
+ Args:
+ data_batch: A batch of data from the dataloader.
+ data_samples (Sequence[dict]): A batch of outputs from the model.
+ """
+ for preds in data_samples:
+
+ pred_box = preds['pred_bboxes'].squeeze()
+ box_gt = torch.Tensor(preds['gt_bboxes']).squeeze()
+
+ result = {
+ 'box': pred_box.to('cpu').squeeze(),
+ 'box_target': box_gt.squeeze(),
+ }
+
+ self.results.append(result)
+
+ def compute_metrics(self, results: List):
+ """Compute the metrics from processed results.
+
+ Args:
+ results (list): The processed results of each batch.
+
+ Returns:
+ Dict: The computed metrics. The keys are the names of the metrics,
+ and the values are corresponding results.
+ """
+ pred_boxes = torch.stack([each['box'] for each in results])
+ gt_boxes = torch.stack([each['box_target'] for each in results])
+ iou = aligned_box_iou(pred_boxes, gt_boxes)
+ accu_num = torch.sum(iou >= 0.5)
+
+ miou = torch.mean(iou)
+ acc = accu_num / len(gt_boxes)
+ coco_val = {'miou': miou, 'acc': acc}
+ return coco_val
diff --git a/mmpretrain/evaluation/metrics/voc_multi_label.py b/mmpretrain/evaluation/metrics/voc_multi_label.py
new file mode 100644
index 0000000000000000000000000000000000000000..1034852722796271c7ade9d75c3442cce8f1d0d1
--- /dev/null
+++ b/mmpretrain/evaluation/metrics/voc_multi_label.py
@@ -0,0 +1,98 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional, Sequence
+
+from mmpretrain.registry import METRICS
+from mmpretrain.structures import label_to_onehot
+from .multi_label import AveragePrecision, MultiLabelMetric
+
+
+class VOCMetricMixin:
+ """A mixin class for VOC dataset metrics, VOC annotations have extra
+ `difficult` attribute for each object, therefore, extra option is needed
+ for calculating VOC metrics.
+
+ Args:
+ difficult_as_positive (Optional[bool]): Whether to map the difficult
+ labels as positive in one-hot ground truth for evaluation. If it
+ is set to True, map difficult gt labels to positive ones (1); if it
+ is set to False, map difficult gt labels to negative ones (0).
+ Defaults to None, in which case the difficult labels are set to '-1'.
+ """
+
+ def __init__(self,
+ *arg,
+ difficult_as_positive: Optional[bool] = None,
+ **kwarg):
+ self.difficult_as_positive = difficult_as_positive
+ super().__init__(*arg, **kwarg)
+
+ def process(self, data_batch, data_samples: Sequence[dict]):
+ """Process one batch of data samples.
+
+ The processed results should be stored in ``self.results``, which will
+ be used to compute the metrics when all batches have been processed.
+
+ Args:
+ data_batch: A batch of data from the dataloader.
+ data_samples (Sequence[dict]): A batch of outputs from the model.
+ """
+ for data_sample in data_samples:
+ result = dict()
+ gt_label = data_sample['gt_label']
+ gt_label_difficult = data_sample['gt_label_difficult']
+
+ result['pred_score'] = data_sample['pred_score'].clone()
+ num_classes = result['pred_score'].size()[-1]
+
+ if 'gt_score' in data_sample:
+ result['gt_score'] = data_sample['gt_score'].clone()
+ else:
+ result['gt_score'] = label_to_onehot(gt_label, num_classes)
+
+ # VOC annotations label all the objects in a single image;
+ # therefore, some categories appear in both difficult and
+ # non-difficult objects. Here we treat labels that only exist
+ # in difficult objects as difficult labels.
+ difficult_label = set(gt_label_difficult) - (
+ set(gt_label_difficult) & set(gt_label.tolist()))
+
+ # set difficult label for better eval
+ if self.difficult_as_positive is None:
+ result['gt_score'][[*difficult_label]] = -1
+ elif self.difficult_as_positive:
+ result['gt_score'][[*difficult_label]] = 1
+
+ # Save the result to `self.results`.
+ self.results.append(result)
+
+
+@METRICS.register_module()
+class VOCMultiLabelMetric(VOCMetricMixin, MultiLabelMetric):
+ """A collection of metrics for multi-label multi-class classification task
+ based on confusion matrix for VOC dataset.
+
+ It includes precision, recall, f1-score and support.
+
+ Args:
+ difficult_as_postive (Optional[bool]): Whether to map the difficult
+ labels as positive in one-hot ground truth for evaluation. If it
+ set to True, map difficult gt labels to positive ones(1), If it
+ set to False, map difficult gt labels to negative ones(0).
+ Defaults to None, the difficult labels will be set to '-1'.
+ **kwarg: Refers to `MultiLabelMetric` for detailed docstrings.
+ """
+
+
+@METRICS.register_module()
+class VOCAveragePrecision(VOCMetricMixin, AveragePrecision):
+ """Calculate the average precision with respect of classes for VOC dataset.
+
+ Args:
+ difficult_as_postive (Optional[bool]): Whether to map the difficult
+ labels as positive in one-hot ground truth for evaluation. If it
+ set to True, map difficult gt labels to positive ones(1), If it
+ set to False, map difficult gt labels to negative ones(0).
+ Defaults to None, the difficult labels will be set to '-1'.
+ **kwarg: Refers to `AveragePrecision` for detailed docstrings.
+ """
diff --git a/mmpretrain/evaluation/metrics/vqa.py b/mmpretrain/evaluation/metrics/vqa.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd77ba9bc23e013c41ac095810740bdb71d33fb3
--- /dev/null
+++ b/mmpretrain/evaluation/metrics/vqa.py
@@ -0,0 +1,315 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# Partly adopted from https://github.com/GT-Vision-Lab/VQA
+# Copyright (c) 2014, Aishwarya Agrawal
+from typing import List, Optional
+
+import mmengine
+from mmengine.evaluator import BaseMetric
+from mmengine.logging import MMLogger
+
+from mmpretrain.registry import METRICS
+
+
+def _process_punctuation(inText):
+ import re
+ outText = inText
+ punct = [
+ ';', r'/', '[', ']', '"', '{', '}', '(', ')', '=', '+', '\\', '_', '-',
+ '>', '<', '@', '`', ',', '?', '!'
+ ]
+ commaStrip = re.compile('(\d)(,)(\d)') # noqa: W605
+ periodStrip = re.compile('(?!<=\d)(\.)(?!\d)') # noqa: W605
+ for p in punct:
+ if (p + ' ' in inText or ' ' + p in inText) or (re.search(
+ commaStrip, inText) is not None):
+ outText = outText.replace(p, '')
+ else:
+ outText = outText.replace(p, ' ')
+ outText = periodStrip.sub('', outText, re.UNICODE)
+ return outText
+
+
+def _process_digit_article(inText):
+ outText = []
+ tempText = inText.lower().split()
+ articles = ['a', 'an', 'the']
+ manualMap = {
+ 'none': '0',
+ 'zero': '0',
+ 'one': '1',
+ 'two': '2',
+ 'three': '3',
+ 'four': '4',
+ 'five': '5',
+ 'six': '6',
+ 'seven': '7',
+ 'eight': '8',
+ 'nine': '9',
+ 'ten': '10',
+ }
+ contractions = {
+ 'aint': "ain't",
+ 'arent': "aren't",
+ 'cant': "can't",
+ 'couldve': "could've",
+ 'couldnt': "couldn't",
+ "couldn'tve": "couldn't've",
+ "couldnt've": "couldn't've",
+ 'didnt': "didn't",
+ 'doesnt': "doesn't",
+ 'dont': "don't",
+ 'hadnt': "hadn't",
+ "hadnt've": "hadn't've",
+ "hadn'tve": "hadn't've",
+ 'hasnt': "hasn't",
+ 'havent': "haven't",
+ 'hed': "he'd",
+ "hed've": "he'd've",
+ "he'dve": "he'd've",
+ 'hes': "he's",
+ 'howd': "how'd",
+ 'howll': "how'll",
+ 'hows': "how's",
+ "Id've": "I'd've",
+ "I'dve": "I'd've",
+ 'Im': "I'm",
+ 'Ive': "I've",
+ 'isnt': "isn't",
+ 'itd': "it'd",
+ "itd've": "it'd've",
+ "it'dve": "it'd've",
+ 'itll': "it'll",
+ "let's": "let's",
+ 'maam': "ma'am",
+ 'mightnt': "mightn't",
+ "mightnt've": "mightn't've",
+ "mightn'tve": "mightn't've",
+ 'mightve': "might've",
+ 'mustnt': "mustn't",
+ 'mustve': "must've",
+ 'neednt': "needn't",
+ 'notve': "not've",
+ 'oclock': "o'clock",
+ 'oughtnt': "oughtn't",
+ "ow's'at": "'ow's'at",
+ "'ows'at": "'ow's'at",
+ "'ow'sat": "'ow's'at",
+ 'shant': "shan't",
+ "shed've": "she'd've",
+ "she'dve": "she'd've",
+ "she's": "she's",
+ 'shouldve': "should've",
+ 'shouldnt': "shouldn't",
+ "shouldnt've": "shouldn't've",
+ "shouldn'tve": "shouldn't've",
+ "somebody'd": 'somebodyd',
+ "somebodyd've": "somebody'd've",
+ "somebody'dve": "somebody'd've",
+ 'somebodyll': "somebody'll",
+ 'somebodys': "somebody's",
+ 'someoned': "someone'd",
+ "someoned've": "someone'd've",
+ "someone'dve": "someone'd've",
+ 'someonell': "someone'll",
+ 'someones': "someone's",
+ 'somethingd': "something'd",
+ "somethingd've": "something'd've",
+ "something'dve": "something'd've",
+ 'somethingll': "something'll",
+ 'thats': "that's",
+ 'thered': "there'd",
+ "thered've": "there'd've",
+ "there'dve": "there'd've",
+ 'therere': "there're",
+ 'theres': "there's",
+ 'theyd': "they'd",
+ "theyd've": "they'd've",
+ "they'dve": "they'd've",
+ 'theyll': "they'll",
+ 'theyre': "they're",
+ 'theyve': "they've",
+ 'twas': "'twas",
+ 'wasnt': "wasn't",
+ "wed've": "we'd've",
+ "we'dve": "we'd've",
+ 'weve': "we've",
+ 'werent': "weren't",
+ 'whatll': "what'll",
+ 'whatre': "what're",
+ 'whats': "what's",
+ 'whatve': "what've",
+ 'whens': "when's",
+ 'whered': "where'd",
+ 'wheres': "where's",
+ 'whereve': "where've",
+ 'whod': "who'd",
+ "whod've": "who'd've",
+ "who'dve": "who'd've",
+ 'wholl': "who'll",
+ 'whos': "who's",
+ 'whove': "who've",
+ 'whyll': "why'll",
+ 'whyre': "why're",
+ 'whys': "why's",
+ 'wont': "won't",
+ 'wouldve': "would've",
+ 'wouldnt': "wouldn't",
+ "wouldnt've": "wouldn't've",
+ "wouldn'tve": "wouldn't've",
+ 'yall': "y'all",
+ "yall'll": "y'all'll",
+ "y'allll": "y'all'll",
+ "yall'd've": "y'all'd've",
+ "y'alld've": "y'all'd've",
+ "y'all'dve": "y'all'd've",
+ 'youd': "you'd",
+ "youd've": "you'd've",
+ "you'dve": "you'd've",
+ 'youll': "you'll",
+ 'youre': "you're",
+ 'youve': "you've",
+ }
+ for word in tempText:
+ word = manualMap.setdefault(word, word)
+ if word not in articles:
+ outText.append(word)
+ for wordId, word in enumerate(outText):
+ if word in contractions:
+ outText[wordId] = contractions[word]
+ outText = ' '.join(outText)
+ return outText
+
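+# Illustrative only: the two helpers above normalise free-form answers before
+# comparison, stripping punctuation and articles and mapping number words to
+# digits.
+#   >>> _process_digit_article(_process_punctuation('There are two dogs!'))
+#   'there are 2 dogs'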
+
+@METRICS.register_module()
+class VQAAcc(BaseMetric):
+    '''VQA accuracy metric.
+
+    Args:
+        full_score_weight (float): The total matched answer weight required
+            for a prediction to receive the full score. Defaults to 0.3.
+        collect_device (str): Device name used for collecting results from
+            different ranks during distributed training. Must be 'cpu' or
+            'gpu'. Defaults to 'cpu'.
+        prefix (str, optional): The prefix that will be added in the metric
+            names to disambiguate homonymous metrics of different evaluators.
+            If prefix is not provided in the argument, self.default_prefix
+            will be used instead. Defaults to None.
+    '''
+ default_prefix = 'VQA'
+
+ def __init__(self,
+ full_score_weight: float = 0.3,
+ collect_device: str = 'cpu',
+ prefix: Optional[str] = None):
+ super().__init__(collect_device=collect_device, prefix=prefix)
+ self.full_score_weight = full_score_weight
+
+ def process(self, data_batch, data_samples):
+ """Process one batch of data samples.
+
+ The processed results should be stored in ``self.results``, which will
+        be used to compute the metrics when all batches have been processed.
+
+ Args:
+ data_batch: A batch of data from the dataloader.
+ data_samples (Sequence[dict]): A batch of outputs from the model.
+ """
+ for sample in data_samples:
+ gt_answer = sample.get('gt_answer')
+ gt_answer_weight = sample.get('gt_answer_weight')
+ if isinstance(gt_answer, str):
+ gt_answer = [gt_answer]
+ if gt_answer_weight is None:
+ gt_answer_weight = [1. / (len(gt_answer))] * len(gt_answer)
+
+ result = {
+ 'pred_answer': sample.get('pred_answer'),
+ 'gt_answer': gt_answer,
+ 'gt_answer_weight': gt_answer_weight,
+ }
+
+ self.results.append(result)
+
+ def compute_metrics(self, results: List):
+ """Compute the metrics from processed results.
+
+ Args:
+            results (List): The processed results of each batch.
+
+ Returns:
+ Dict: The computed metrics. The keys are the names of the metrics,
+ and the values are corresponding results.
+ """
+ acc = []
+ for result in results:
+ pred_answer = self._process_answer(result['pred_answer'])
+ gt_answer = [
+ self._process_answer(answer) for answer in result['gt_answer']
+ ]
+ answer_weight = result['gt_answer_weight']
+
+ weight_sum = 0
+ for i, gt in enumerate(gt_answer):
+ if gt == pred_answer:
+ weight_sum += answer_weight[i]
+ vqa_acc = min(1.0, weight_sum / self.full_score_weight)
+ acc.append(vqa_acc)
+
+ accuracy = sum(acc) / len(acc) * 100
+
+ metrics = {'acc': accuracy}
+ return metrics
+
+ def _process_answer(self, answer):
+ answer = answer.replace('\n', ' ')
+ answer = answer.replace('\t', ' ')
+ answer = answer.strip()
+ answer = _process_punctuation(answer)
+ answer = _process_digit_article(answer)
+ return answer
+
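+# Worked example (illustrative): with the default per-answer weight of
+# 1 / len(gt_answer) and ``full_score_weight=0.3``, ten annotated answers of
+# which three match the normalised prediction give
+# min(1.0, 3 * 0.1 / 0.3) = 1.0, i.e. the usual "min(1, #matches / 3)"
+# VQA accuracy rule.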
+
+@METRICS.register_module()
+class ReportVQA(BaseMetric):
+ """Dump VQA result to the standard json format for VQA evaluation.
+
+ Args:
+ file_path (str): The file path to save the result file.
+ collect_device (str): Device name used for collecting results from
+ different ranks during distributed training. Must be 'cpu' or
+ 'gpu'. Defaults to 'cpu'.
+ prefix (str, optional): The prefix that will be added in the metric
+ names to disambiguate homonymous metrics of different evaluators.
+ If prefix is not provided in the argument, self.default_prefix
+            will be used instead. Defaults to None.
+ """
+ default_prefix = 'VQA'
+
+ def __init__(self,
+ file_path: str,
+ collect_device: str = 'cpu',
+ prefix: Optional[str] = None):
+ super().__init__(collect_device=collect_device, prefix=prefix)
+ if not file_path.endswith('.json'):
+ raise ValueError('The output file must be a json file.')
+ self.file_path = file_path
+
+ def process(self, data_batch, data_samples) -> None:
+        """Collect the question id and predicted answer of each sample."""
+ for sample in data_samples:
+ question_id = sample['question_id']
+ pred_answer = sample['pred_answer']
+
+ result = {
+ 'question_id': int(question_id),
+ 'answer': pred_answer,
+ }
+
+ self.results.append(result)
+
+ def compute_metrics(self, results: List):
+ """Dump the result to json file."""
+ mmengine.dump(results, self.file_path)
+ logger = MMLogger.get_current_instance()
+        logger.info(f'Results have been saved to {self.file_path}.')
+ return {}
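+
+# A minimal usage sketch (the output path below is only an example):
+# ``ReportVQA`` only dumps predictions to JSON and returns no metric values,
+# so it is meant for exporting test-set answers rather than scoring them.
+#   test_evaluator = dict(type='ReportVQA', file_path='vqa_test.json')
+#   # produces: [{"question_id": 1, "answer": "yes"}, ...]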
diff --git a/mmpretrain/models/__init__.py b/mmpretrain/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba05735b26a96cf532486f6f31d0d93bf6d30781
--- /dev/null
+++ b/mmpretrain/models/__init__.py
@@ -0,0 +1,19 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .backbones import * # noqa: F401,F403
+from .builder import (BACKBONES, CLASSIFIERS, HEADS, LOSSES, NECKS,
+ build_backbone, build_classifier, build_head, build_loss,
+ build_neck)
+from .classifiers import * # noqa: F401,F403
+from .heads import * # noqa: F401,F403
+from .losses import * # noqa: F401,F403
+from .multimodal import * # noqa: F401,F403
+from .necks import * # noqa: F401,F403
+from .retrievers import * # noqa: F401,F403
+from .selfsup import * # noqa: F401,F403
+from .tta import * # noqa: F401,F403
+from .utils import * # noqa: F401,F403
+
+__all__ = [
+ 'BACKBONES', 'HEADS', 'NECKS', 'LOSSES', 'CLASSIFIERS', 'build_backbone',
+ 'build_head', 'build_neck', 'build_loss', 'build_classifier'
+]
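+
+# A hedged sketch (assuming the usual mmpretrain builder behaviour): the
+# legacy ``build_*`` helpers dispatch to the ``MODELS`` registry, so a
+# backbone can be built from a plain config dict, e.g.:
+#   backbone = build_backbone(dict(type='ResNet', depth=50))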
diff --git a/mmpretrain/models/__pycache__/__init__.cpython-38.pyc b/mmpretrain/models/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1c41a876a642a73a4f788bdd3865f4259626e85c
Binary files /dev/null and b/mmpretrain/models/__pycache__/__init__.cpython-38.pyc differ
diff --git a/mmpretrain/models/__pycache__/builder.cpython-38.pyc b/mmpretrain/models/__pycache__/builder.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ca6c2ea8fbc421d18eba37af09ec6a2c9296bc38
Binary files /dev/null and b/mmpretrain/models/__pycache__/builder.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__init__.py b/mmpretrain/models/backbones/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..60e37fb7b6e15cadd0eef4a3c9c79c856fbf4247
--- /dev/null
+++ b/mmpretrain/models/backbones/__init__.py
@@ -0,0 +1,129 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .alexnet import AlexNet
+from .beit import BEiTViT
+from .conformer import Conformer
+from .convmixer import ConvMixer
+from .convnext import ConvNeXt
+from .cspnet import CSPDarkNet, CSPNet, CSPResNet, CSPResNeXt
+from .davit import DaViT
+from .deit import DistilledVisionTransformer
+from .deit3 import DeiT3
+from .densenet import DenseNet
+from .edgenext import EdgeNeXt
+from .efficientformer import EfficientFormer
+from .efficientnet import EfficientNet
+from .efficientnet_v2 import EfficientNetV2
+from .hivit import HiViT
+from .hornet import HorNet
+from .hrnet import HRNet
+from .inception_v3 import InceptionV3
+from .lenet import LeNet5
+from .levit import LeViT
+from .mixmim import MixMIMTransformer
+from .mlp_mixer import MlpMixer
+from .mobilenet_v2 import MobileNetV2
+from .mobilenet_v3 import MobileNetV3
+from .mobileone import MobileOne
+from .mobilevit import MobileViT
+from .mvit import MViT
+from .poolformer import PoolFormer
+from .regnet import RegNet
+from .replknet import RepLKNet
+from .repmlp import RepMLPNet
+from .repvgg import RepVGG
+from .res2net import Res2Net
+from .resnest import ResNeSt
+from .resnet import ResNet, ResNetV1c, ResNetV1d
+from .resnet_cifar import ResNet_CIFAR
+from .resnext import ResNeXt
+from .revvit import RevVisionTransformer
+from .riformer import RIFormer
+from .seresnet import SEResNet
+from .seresnext import SEResNeXt
+from .shufflenet_v1 import ShuffleNetV1
+from .shufflenet_v2 import ShuffleNetV2
+from .sparse_convnext import SparseConvNeXt
+from .sparse_resnet import SparseResNet
+from .swin_transformer import SwinTransformer
+from .swin_transformer_v2 import SwinTransformerV2
+from .t2t_vit import T2T_ViT
+from .timm_backbone import TIMMBackbone
+from .tinyvit import TinyViT
+from .tnt import TNT
+from .twins import PCPVT, SVT
+from .van import VAN
+from .vgg import VGG
+from .vig import PyramidVig, Vig
+from .vision_transformer import VisionTransformer
+from .vit_eva02 import ViTEVA02
+from .vit_sam import ViTSAM
+from .xcit import XCiT
+
+__all__ = [
+ 'LeNet5',
+ 'AlexNet',
+ 'VGG',
+ 'RegNet',
+ 'ResNet',
+ 'ResNeXt',
+ 'ResNetV1d',
+ 'ResNeSt',
+ 'ResNet_CIFAR',
+ 'SEResNet',
+ 'SEResNeXt',
+ 'ShuffleNetV1',
+ 'ShuffleNetV2',
+ 'MobileNetV2',
+ 'MobileNetV3',
+ 'VisionTransformer',
+ 'SwinTransformer',
+ 'TNT',
+ 'TIMMBackbone',
+ 'T2T_ViT',
+ 'Res2Net',
+ 'RepVGG',
+ 'Conformer',
+ 'MlpMixer',
+ 'DistilledVisionTransformer',
+ 'PCPVT',
+ 'SVT',
+ 'EfficientNet',
+ 'EfficientNetV2',
+ 'ConvNeXt',
+ 'HRNet',
+ 'ResNetV1c',
+ 'ConvMixer',
+ 'EdgeNeXt',
+ 'CSPDarkNet',
+ 'CSPResNet',
+ 'CSPResNeXt',
+ 'CSPNet',
+ 'RepLKNet',
+ 'RepMLPNet',
+ 'PoolFormer',
+ 'RIFormer',
+ 'DenseNet',
+ 'VAN',
+ 'InceptionV3',
+ 'MobileOne',
+ 'EfficientFormer',
+ 'SwinTransformerV2',
+ 'MViT',
+ 'DeiT3',
+ 'HorNet',
+ 'MobileViT',
+ 'DaViT',
+ 'BEiTViT',
+ 'RevVisionTransformer',
+ 'MixMIMTransformer',
+ 'TinyViT',
+ 'LeViT',
+ 'Vig',
+ 'PyramidVig',
+ 'XCiT',
+ 'ViTSAM',
+ 'ViTEVA02',
+ 'HiViT',
+ 'SparseResNet',
+ 'SparseConvNeXt',
+]
diff --git a/mmpretrain/models/backbones/__pycache__/__init__.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a18015ba1008f366339d501214f73c9dc36c0617
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/__init__.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/alexnet.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/alexnet.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c8252336c00e00abc793eb11a7293a5663019042
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/alexnet.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/base_backbone.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/base_backbone.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1a5a615152cc30549993dc9119a62e59de5f60b4
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/base_backbone.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/beit.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/beit.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1df3344f2f3ea918b6d0c8a5deb9c2060759012e
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/beit.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/conformer.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/conformer.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..30f2cfba565d203df5964dac416504857a937e26
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/conformer.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/convmixer.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/convmixer.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..77489e2b4e9c55950d7826bfe177b98f09e66019
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/convmixer.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/convnext.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/convnext.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9f2ab5471ab21c9fa8ac7b93f1b290794a6d0309
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/convnext.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/cspnet.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/cspnet.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e436a622fd452b53f6bb80ade3e803bb8fdd841f
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/cspnet.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/davit.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/davit.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ca66272b4bfa46fc3e4b8eefe5c1046a1de38712
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/davit.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/deit.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/deit.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6fe28841c8abc0c0c88a5f0442178acef2d7a276
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/deit.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/deit3.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/deit3.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b937fdd2cb043f91bb42e642285b3027fa6b51f1
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/deit3.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/densenet.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/densenet.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c503034f759b2cfbcb757d5502de1b4ffa65b601
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/densenet.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/edgenext.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/edgenext.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2ef53e8353e168a44e4916564dfaf1e0fcf6a540
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/edgenext.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/efficientformer.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/efficientformer.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fa4d9250ffde11fd47f2f62147a34db35b369749
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/efficientformer.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/efficientnet.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/efficientnet.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8e77b8d52c0eff923fc0eaa14b17afec772f1512
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/efficientnet.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/efficientnet_v2.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/efficientnet_v2.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..36b782fb056e32738cf7b567a635bd06624f95f5
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/efficientnet_v2.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/hivit.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/hivit.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..24e537dbf1c7edf853ed1af9ae690ed60ecd7950
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/hivit.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/hornet.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/hornet.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e3186c22a7964393ca3b4c36f7ff9eda9bb620ff
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/hornet.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/hrnet.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/hrnet.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d5a7d68884b2e20fbd361c9ce71a29ac6d32b627
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/hrnet.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/inception_v3.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/inception_v3.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dc09fd4f405b724744e9de9ff79cbc53bd48883e
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/inception_v3.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/lenet.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/lenet.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7d6ecc0ebe6386b5c9cf85df200e87afe8fd2eb8
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/lenet.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/levit.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/levit.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..060e55647661641895839d884f02f02b2bda9069
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/levit.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/mixmim.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/mixmim.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f18cced1cb4c0c754e2f23abfe1c0822991e96e6
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/mixmim.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/mlp_mixer.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/mlp_mixer.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ced907fb4d9214d35b4565f242f3ed2ef941246c
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/mlp_mixer.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/mobilenet_v2.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/mobilenet_v2.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e63399bd69b103ad71bf826014967c7214cfa682
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/mobilenet_v2.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/mobilenet_v3.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/mobilenet_v3.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..24d773a2b4305cdfcae87c82d5f3dd49183fea84
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/mobilenet_v3.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/mobileone.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/mobileone.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..64bdcefa7ca0812e5ff42260616cb5a8adee58b0
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/mobileone.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/mobilevit.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/mobilevit.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7537bdce36b7680cf66ecc26bc5cb94b9bfc86b6
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/mobilevit.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/mvit.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/mvit.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f315d6b86d25569c9d8943b21efb26ef4e755246
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/mvit.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/poolformer.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/poolformer.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9b0bd0f5096e651c18b11a34747371fed5264378
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/poolformer.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/regnet.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/regnet.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..411e00066295c655a7ff6dfa02e437c32f9b5980
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/regnet.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/replknet.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/replknet.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0ef5ee77a958c77fb22e03866d43aa42cb8088b3
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/replknet.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/repmlp.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/repmlp.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..739b78287c2da67f7b970a24278ed5e93afddfec
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/repmlp.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/repvgg.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/repvgg.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..30ef3b932d009635d9e4b03e9c8edd6f899f6cd2
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/repvgg.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/res2net.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/res2net.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a1dbcc4def3590141f489a8c95d97a5eaba60245
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/res2net.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/resnest.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/resnest.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fe35027e6548be56b8d8e2b34a04260e447a9cd9
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/resnest.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/resnet.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/resnet.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4b1d529dd69bd33fe0346b5ae7ae7708d0f1c5eb
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/resnet.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/resnet_cifar.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/resnet_cifar.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8a0ba27cdca835380c50713eaf8575ce8a92e974
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/resnet_cifar.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/resnext.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/resnext.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..301b46e4540092cec011402493ac7a9c703cd643
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/resnext.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/revvit.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/revvit.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9dd40b6e7c27389f7a5157fa58d719ecf7880530
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/revvit.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/riformer.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/riformer.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..26bda9f9ae80d91236aca39cc3ac448a2a779122
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/riformer.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/seresnet.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/seresnet.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f6929011e417b1127ba452cc77848d68928d7f95
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/seresnet.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/seresnext.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/seresnext.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..09e7556b191544153b1ceb39a82ea0649d4b717c
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/seresnext.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/shufflenet_v1.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/shufflenet_v1.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..96de19f10dfd4c35eba969159ae03ea6c66742d9
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/shufflenet_v1.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/shufflenet_v2.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/shufflenet_v2.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d589fb82cae8f529c64aa309c9e8837da0330d35
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/shufflenet_v2.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/sparse_convnext.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/sparse_convnext.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..278f96e9720e827c7a33b3199fdf837299cad7c5
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/sparse_convnext.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/sparse_resnet.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/sparse_resnet.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..31a1546e7fbe78b4207aa6f2f1ba91f2770500c5
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/sparse_resnet.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/swin_transformer.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/swin_transformer.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6207cd2f5bab62b11a9323541eae626a93ee6487
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/swin_transformer.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/swin_transformer_v2.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/swin_transformer_v2.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e07f5d0b8dac62c6b17dcc130ad1ff547bd23a1e
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/swin_transformer_v2.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/t2t_vit.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/t2t_vit.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bbff543a621817397f9afbd16e1fd678a2c1b4d2
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/t2t_vit.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/timm_backbone.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/timm_backbone.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1e3ca7b55658448bad1c96c967fb30ce0ca83e30
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/timm_backbone.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/tinyvit.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/tinyvit.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f6ba4ca22ab620df1d8f85807dd7169c0bd4114c
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/tinyvit.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/tnt.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/tnt.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8c3acc5611e9bfdf0179549fc94b9437ca05eda5
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/tnt.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/twins.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/twins.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..08f9fc896877413f4dc9f5e78e27c80858df0074
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/twins.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/van.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/van.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f86dc3fc60bde313c4a13519c6959f1fac861152
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/van.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/vgg.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/vgg.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..66383d035cf49ff7bcf1932d7a395c682ce74d2a
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/vgg.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/vig.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/vig.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0aa1ab0b901c39b4e67e1689c9db21fd4fdee9e5
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/vig.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/vision_transformer.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/vision_transformer.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..30a69c3464095df54c17ae4d08d9e1720e321f27
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/vision_transformer.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/vit_eva02.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/vit_eva02.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c5deadcacffec2b134c8314608dc82dfd7bc87d9
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/vit_eva02.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/vit_sam.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/vit_sam.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a8129a755899798c5dafe11defd88b2f607a03f8
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/vit_sam.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/__pycache__/xcit.cpython-38.pyc b/mmpretrain/models/backbones/__pycache__/xcit.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d4e5da85c50f64c34ec9ed8f20de44a0592900f3
Binary files /dev/null and b/mmpretrain/models/backbones/__pycache__/xcit.cpython-38.pyc differ
diff --git a/mmpretrain/models/backbones/alexnet.py b/mmpretrain/models/backbones/alexnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7c2891fdd2c878e243331f572f6e3e562232d46
--- /dev/null
+++ b/mmpretrain/models/backbones/alexnet.py
@@ -0,0 +1,56 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+
+from mmpretrain.registry import MODELS
+from .base_backbone import BaseBackbone
+
+
+@MODELS.register_module()
+class AlexNet(BaseBackbone):
+    """AlexNet backbone.
+
+ The input for AlexNet is a 224x224 RGB image.
+
+ Args:
+ num_classes (int): number of classes for classification.
+ The default value is -1, which uses the backbone as
+ a feature extractor without the top classifier.
+ """
+
+ def __init__(self, num_classes=-1):
+ super(AlexNet, self).__init__()
+ self.num_classes = num_classes
+ self.features = nn.Sequential(
+ nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
+ nn.ReLU(inplace=True),
+ nn.MaxPool2d(kernel_size=3, stride=2),
+ nn.Conv2d(64, 192, kernel_size=5, padding=2),
+ nn.ReLU(inplace=True),
+ nn.MaxPool2d(kernel_size=3, stride=2),
+ nn.Conv2d(192, 384, kernel_size=3, padding=1),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(384, 256, kernel_size=3, padding=1),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(256, 256, kernel_size=3, padding=1),
+ nn.ReLU(inplace=True),
+ nn.MaxPool2d(kernel_size=3, stride=2),
+ )
+ if self.num_classes > 0:
+ self.classifier = nn.Sequential(
+ nn.Dropout(),
+ nn.Linear(256 * 6 * 6, 4096),
+ nn.ReLU(inplace=True),
+ nn.Dropout(),
+ nn.Linear(4096, 4096),
+ nn.ReLU(inplace=True),
+ nn.Linear(4096, num_classes),
+ )
+
+ def forward(self, x):
+
+ x = self.features(x)
+ if self.num_classes > 0:
+ x = x.view(x.size(0), 256 * 6 * 6)
+ x = self.classifier(x)
+
+ return (x, )
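+
+# A minimal usage sketch (illustrative; the 6x6 spatial size assumes the
+# documented 224x224 RGB input):
+#   >>> import torch
+#   >>> model = AlexNet(num_classes=-1)
+#   >>> feats = model(torch.rand(1, 3, 224, 224))
+#   >>> feats[0].shape
+#   torch.Size([1, 256, 6, 6])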
diff --git a/mmpretrain/models/backbones/base_backbone.py b/mmpretrain/models/backbones/base_backbone.py
new file mode 100644
index 0000000000000000000000000000000000000000..751aa956ba2ad178ea9e40875b6e610ee7bbbcd3
--- /dev/null
+++ b/mmpretrain/models/backbones/base_backbone.py
@@ -0,0 +1,33 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABCMeta, abstractmethod
+
+from mmengine.model import BaseModule
+
+
+class BaseBackbone(BaseModule, metaclass=ABCMeta):
+ """Base backbone.
+
+ This class defines the basic functions of a backbone. Any backbone that
+ inherits this class should at least define its own `forward` function.
+ """
+
+ def __init__(self, init_cfg=None):
+ super(BaseBackbone, self).__init__(init_cfg)
+
+ @abstractmethod
+ def forward(self, x):
+ """Forward computation.
+
+ Args:
+            x (tensor | tuple[tensor]): x could be a torch.Tensor or a tuple
+                of torch.Tensor, containing input data for forward
+                computation.
+ """
+ pass
+
+ def train(self, mode=True):
+ """Set module status before forward computation.
+
+ Args:
+ mode (bool): Whether it is train_mode or test_mode
+ """
+ super(BaseBackbone, self).train(mode)
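+
+# A hedged sketch of the expected subclass contract (names are illustrative,
+# assuming ``import torch.nn as nn``): subclasses only need to implement
+# ``forward`` and, by convention in this repo, return a tuple of feature maps.
+#   class TinyBackbone(BaseBackbone):
+#       def __init__(self, init_cfg=None):
+#           super().__init__(init_cfg)
+#           self.conv = nn.Conv2d(3, 16, 3, padding=1)
+#
+#       def forward(self, x):
+#           return (self.conv(x), )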
diff --git a/mmpretrain/models/backbones/beit.py b/mmpretrain/models/backbones/beit.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f64ae2029b8be47d938fdef25aed9c0058ef307
--- /dev/null
+++ b/mmpretrain/models/backbones/beit.py
@@ -0,0 +1,522 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional, Sequence, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+from mmcv.cnn.bricks.drop import build_dropout
+from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
+from mmengine.model import BaseModule, ModuleList
+
+from mmpretrain.registry import MODELS
+from ..utils import (BEiTAttention, build_norm_layer, resize_pos_embed,
+ resize_relative_position_bias_table, to_2tuple)
+from .vision_transformer import TransformerEncoderLayer, VisionTransformer
+
+
+class RelativePositionBias(BaseModule):
+ """Relative Position Bias.
+
+ This module is copied from
+ https://github.com/microsoft/unilm/blob/master/beit/modeling_finetune.py#L209.
+
+ Args:
+ window_size (Sequence[int]): The window size of the relative
+ position bias.
+ num_heads (int): The number of head in multi-head attention.
+        with_cls_token (bool): Whether the backbone uses a cls_token.
+            Defaults to True.
+ """
+
+ def __init__(
+ self,
+ window_size: Sequence[int],
+ num_heads: int,
+ with_cls_token: bool = True,
+ ) -> None:
+ super().__init__()
+ self.window_size = window_size
+ if with_cls_token:
+ num_extra_tokens = 3
+ else:
+ num_extra_tokens = 0
+ # cls to token & token to cls & cls to cls
+ self.num_relative_distance = (2 * window_size[0] - 1) * (
+ 2 * window_size[1] - 1) + num_extra_tokens
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros(self.num_relative_distance,
+ num_heads)) # 2*Wh-1 * 2*Ww-1, nH
+
+ # get pair-wise relative position index for each
+ # token inside the window
+ coords_h = torch.arange(window_size[0])
+ coords_w = torch.arange(window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = coords_flatten[:, :, None] -\
+ coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(
+ 1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
+ if with_cls_token:
+ relative_position_index = torch.zeros(
+ size=(window_size[0] * window_size[1] + 1, ) * 2,
+ dtype=relative_coords.dtype)
+ relative_position_index[1:, 1:] = relative_coords.sum(
+ -1) # Wh*Ww, Wh*Ww
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
+ relative_position_index[0, 0] = self.num_relative_distance - 1
+ else:
+ relative_position_index = torch.zeros(
+ size=(window_size[0] * window_size[1], ) * 2,
+ dtype=relative_coords.dtype)
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+
+ self.register_buffer('relative_position_index',
+ relative_position_index)
+
+ def forward(self) -> torch.Tensor:
+ # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = self.relative_position_bias_table[
+ self.relative_position_index.view(-1)].view(
+ self.window_size[0] * self.window_size[1] + 1,
+ self.window_size[0] * self.window_size[1] + 1, -1)
+ return relative_position_bias.permute(
+ 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+
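+# Illustrative numbers (assuming the default 224x224 input and 16x16 patches,
+# i.e. a 14x14 window with a cls token): the bias table above has
+# (2 * 14 - 1) ** 2 + 3 = 732 rows, one per relative distance plus three
+# cls-token terms, and ``forward()`` returns a tensor of shape
+# (num_heads, 14 * 14 + 1, 14 * 14 + 1).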
+
+class BEiTTransformerEncoderLayer(TransformerEncoderLayer):
+ """Implements one encoder layer in BEiT.
+
+    Compared with the conventional ``TransformerEncoderLayer``, this module
+    adds learnable scaling weights to the shortcut connections. In addition,
+    ``BEiTAttention`` is used to replace the original ``MultiheadAttention``
+    in ``TransformerEncoderLayer``.
+
+ Args:
+ embed_dims (int): The feature dimension.
+ num_heads (int): Parallel attention heads.
+ feedforward_channels (int): The hidden dimension for FFNs.
+ layer_scale_init_value (float): The initialization value for
+ the learnable scaling of attention and FFN. 1 means no scaling.
+ drop_rate (float): Probability of an element to be zeroed
+ after the feed forward layer. Defaults to 0.
+ window_size (tuple[int]): The height and width of the window.
+ Defaults to None.
+        use_rel_pos_bias (bool): Whether to use a unique relative position
+            bias; if False, use the shared relative position bias defined in
+            the backbone.
+        attn_drop_rate (float): The dropout rate for the attention layer.
+            Defaults to 0.0.
+        drop_path_rate (float): Stochastic depth rate. Defaults to 0.0.
+        num_fcs (int): The number of fully-connected layers for FFNs.
+            Defaults to 2.
+        bias (bool | str): The option to add learnable bias for q, k, v. If
+            bias is True, it will add learnable bias. If bias is 'qv_bias', it
+            will only add learnable bias for q, v. If bias is False, it will
+            not add bias for q, k, v. Defaults to 'qv_bias'.
+ act_cfg (dict): The activation config for FFNs.
+ Defaults to ``dict(type='GELU')``.
+ norm_cfg (dict): Config dict for normalization layer.
+ Defaults to dict(type='LN').
+ attn_cfg (dict): The configuration for the attention layer.
+ Defaults to an empty dict.
+ ffn_cfg (dict): The configuration for the ffn layer.
+ Defaults to ``dict(add_identity=False)``.
+ init_cfg (dict or List[dict], optional): Initialization config dict.
+ Defaults to None.
+ """
+
+ def __init__(self,
+ embed_dims: int,
+ num_heads: int,
+ feedforward_channels: int,
+ layer_scale_init_value: float,
+ window_size: Tuple[int, int],
+ use_rel_pos_bias: bool,
+ drop_rate: float = 0.,
+ attn_drop_rate: float = 0.,
+ drop_path_rate: float = 0.,
+ num_fcs: int = 2,
+ bias: Union[str, bool] = 'qv_bias',
+ act_cfg: dict = dict(type='GELU'),
+ norm_cfg: dict = dict(type='LN'),
+ attn_cfg: dict = dict(),
+ ffn_cfg: dict = dict(add_identity=False),
+ init_cfg: Optional[Union[dict, List[dict]]] = None) -> None:
+ super().__init__(
+ embed_dims=embed_dims,
+ num_heads=num_heads,
+ feedforward_channels=feedforward_channels,
+ attn_drop_rate=attn_drop_rate,
+ drop_path_rate=0.,
+ drop_rate=0.,
+ num_fcs=num_fcs,
+ act_cfg=act_cfg,
+ norm_cfg=norm_cfg,
+ init_cfg=init_cfg)
+
+ attn_cfg = {
+ 'window_size': window_size,
+ 'use_rel_pos_bias': use_rel_pos_bias,
+ 'qk_scale': None,
+ 'embed_dims': embed_dims,
+ 'num_heads': num_heads,
+ 'attn_drop': attn_drop_rate,
+ 'proj_drop': drop_rate,
+ 'bias': bias,
+ **attn_cfg,
+ }
+ self.attn = BEiTAttention(**attn_cfg)
+
+ ffn_cfg = {
+ 'embed_dims': embed_dims,
+ 'feedforward_channels': feedforward_channels,
+ 'num_fcs': num_fcs,
+ 'ffn_drop': drop_rate,
+ 'dropout_layer': dict(type='DropPath', drop_prob=drop_path_rate),
+ 'act_cfg': act_cfg,
+ **ffn_cfg,
+ }
+ self.ffn = FFN(**ffn_cfg)
+
+ # NOTE: drop path for stochastic depth, we shall see if
+ # this is better than dropout here
+ dropout_layer = dict(type='DropPath', drop_prob=drop_path_rate)
+ self.drop_path = build_dropout(
+ dropout_layer) if dropout_layer else nn.Identity()
+
+ if layer_scale_init_value > 0:
+ self.gamma_1 = nn.Parameter(
+ layer_scale_init_value * torch.ones((embed_dims)),
+ requires_grad=True)
+ self.gamma_2 = nn.Parameter(
+ layer_scale_init_value * torch.ones((embed_dims)),
+ requires_grad=True)
+ else:
+ self.gamma_1, self.gamma_2 = None, None
+
+ def forward(self, x: torch.Tensor,
+ rel_pos_bias: torch.Tensor) -> torch.Tensor:
+ if self.gamma_1 is None:
+ x = x + self.drop_path(
+ self.attn(self.ln1(x), rel_pos_bias=rel_pos_bias))
+ x = x + self.drop_path(self.ffn(self.ln2(x)))
+ else:
+ x = x + self.drop_path(self.gamma_1 * self.attn(
+ self.ln1(x), rel_pos_bias=rel_pos_bias))
+ x = x + self.drop_path(self.gamma_2 * self.ffn(self.ln2(x)))
+ return x
+
+
+@MODELS.register_module()
+class BEiTViT(VisionTransformer):
+ """Backbone for BEiT.
+
+    A PyTorch implementation of `BEiT: BERT Pre-Training of Image
+    Transformers` and `BEiT v2: Masked Image Modeling with Vector-Quantized
+    Visual Tokenizers`.
+
+ Args:
+ arch (str | dict): BEiT architecture. If use string, choose from
+ 'base', 'large'. If use dict, it should have below keys:
+
+ - **embed_dims** (int): The dimensions of embedding.
+ - **num_layers** (int): The number of transformer encoder layers.
+ - **num_heads** (int): The number of heads in attention modules.
+ - **feedforward_channels** (int): The hidden dimensions in
+ feedforward modules.
+
+ Defaults to 'base'.
+ img_size (int | tuple): The expected input image shape. Because we
+ support dynamic input shape, just set the argument to the most
+ common input image shape. Defaults to 224.
+ patch_size (int | tuple): The patch size in patch embedding.
+ Defaults to 16.
+ in_channels (int): The num of input channels. Defaults to 3.
+ out_indices (Sequence | int): Output from which stages.
+ Defaults to -1, means the last stage.
+ drop_rate (float): Probability of an element to be zeroed.
+ Defaults to 0.
+        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
+        bias (bool | str): The option to add learnable bias for q, k, v. If
+            bias is True, it will add learnable bias. If bias is 'qv_bias', it
+            will only add learnable bias for q, v. If bias is False, it will
+            not add bias for q, k, v. Defaults to 'qv_bias'.
+ norm_cfg (dict): Config dict for normalization layer.
+ Defaults to ``dict(type='LN')``.
+        final_norm (bool): Whether to add an additional layer to normalize the
+            final feature map. Defaults to False.
+ out_type (str): The type of output features. Please choose from
+
+ - ``"cls_token"``: The class token tensor with shape (B, C).
+ - ``"featmap"``: The feature map tensor from the patch tokens
+ with shape (B, C, H, W).
+ - ``"avg_featmap"``: The global averaged feature map tensor
+ with shape (B, C).
+ - ``"raw"``: The raw feature tensor includes patch tokens and
+ class tokens with shape (B, L, C).
+
+ Defaults to ``"avg_featmap"``.
+ with_cls_token (bool): Whether concatenating class token into image
+ tokens as transformer input. Defaults to True.
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+ -1 means not freezing any parameters. Defaults to -1.
+ use_abs_pos_emb (bool): Use position embedding like vanilla ViT.
+ Defaults to False.
+ use_rel_pos_bias (bool): Use relative position embedding in each
+ transformer encoder layer. Defaults to True.
+ use_shared_rel_pos_bias (bool): Use shared relative position embedding,
+ all transformer encoder layers share the same relative position
+ embedding. Defaults to False.
+ layer_scale_init_value (float): The initialization value for
+ the learnable scaling of attention and FFN. Defaults to 0.1.
+        interpolate_mode (str): Select the interpolation mode for position
+            embedding vector resize. Defaults to "bicubic".
+        patch_cfg (dict): Configs of patch embedding. Defaults to an empty
+            dict.
+ layer_cfgs (Sequence | dict): Configs of each transformer layer in
+ encoder. Defaults to an empty dict.
+ init_cfg (dict, optional): Initialization config dict.
+ Defaults to None.
+ """
+
+ def __init__(self,
+ arch='base',
+ img_size=224,
+ patch_size=16,
+ in_channels=3,
+ out_indices=-1,
+ drop_rate=0,
+ drop_path_rate=0,
+ bias='qv_bias',
+ norm_cfg=dict(type='LN', eps=1e-6),
+ final_norm=False,
+ out_type='avg_featmap',
+ with_cls_token=True,
+ frozen_stages=-1,
+ use_abs_pos_emb=False,
+ use_rel_pos_bias=True,
+ use_shared_rel_pos_bias=False,
+ layer_scale_init_value=0.1,
+ interpolate_mode='bicubic',
+ patch_cfg=dict(),
+ layer_cfgs=dict(),
+ init_cfg=None):
+ super(VisionTransformer, self).__init__(init_cfg)
+
+ if isinstance(arch, str):
+ arch = arch.lower()
+ assert arch in set(self.arch_zoo), \
+ f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
+ self.arch_settings = self.arch_zoo[arch]
+ else:
+ essential_keys = {
+ 'embed_dims', 'num_layers', 'num_heads', 'feedforward_channels'
+ }
+ assert isinstance(arch, dict) and essential_keys <= set(arch), \
+ f'Custom arch needs a dict with keys {essential_keys}'
+ self.arch_settings = arch
+
+ self.embed_dims = self.arch_settings['embed_dims']
+ self.num_layers = self.arch_settings['num_layers']
+ self.img_size = to_2tuple(img_size)
+
+ # Set patch embedding
+ _patch_cfg = dict(
+ in_channels=in_channels,
+ input_size=img_size,
+ embed_dims=self.embed_dims,
+ conv_type='Conv2d',
+ kernel_size=patch_size,
+ stride=patch_size,
+ )
+ _patch_cfg.update(patch_cfg)
+ self.patch_embed = PatchEmbed(**_patch_cfg)
+ self.patch_resolution = self.patch_embed.init_out_size
+ num_patches = self.patch_resolution[0] * self.patch_resolution[1]
+
+ # Set out type
+ if out_type not in self.OUT_TYPES:
+ raise ValueError(f'Unsupported `out_type` {out_type}, please '
+ f'choose from {self.OUT_TYPES}')
+ self.out_type = out_type
+
+ # Set cls token
+ if with_cls_token:
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
+ self.num_extra_tokens = 1
+ elif out_type != 'cls_token':
+ self.cls_token = None
+ self.num_extra_tokens = 0
+ else:
+ raise ValueError(
+ 'with_cls_token must be True when `out_type="cls_token"`.')
+
+ # Set position embedding
+ self.interpolate_mode = interpolate_mode
+ if use_abs_pos_emb:
+ self.pos_embed = nn.Parameter(
+ torch.zeros(1, num_patches + self.num_extra_tokens,
+ self.embed_dims))
+ self._register_load_state_dict_pre_hook(self._prepare_pos_embed)
+ else:
+ self.pos_embed = None
+ self.drop_after_pos = nn.Dropout(p=drop_rate)
+
+ assert not (use_rel_pos_bias and use_shared_rel_pos_bias), (
+ '`use_rel_pos_bias` and `use_shared_rel_pos_bias` cannot be set '
+ 'to True at the same time')
+ self.use_rel_pos_bias = use_rel_pos_bias
+
+ if use_shared_rel_pos_bias:
+ self.rel_pos_bias = RelativePositionBias(
+ window_size=self.patch_resolution,
+ num_heads=self.arch_settings['num_heads'])
+ else:
+ self.rel_pos_bias = None
+ self._register_load_state_dict_pre_hook(
+ self._prepare_relative_position_bias_table)
+
+ if isinstance(out_indices, int):
+ out_indices = [out_indices]
+ assert isinstance(out_indices, Sequence), \
+            f'"out_indices" must be a sequence or int, ' \
+ f'get {type(out_indices)} instead.'
+ for i, index in enumerate(out_indices):
+ if index < 0:
+ out_indices[i] = self.num_layers + index
+ assert 0 <= out_indices[i] <= self.num_layers, \
+ f'Invalid out_indices {index}'
+ self.out_indices = out_indices
+
+ # stochastic depth decay rule
+ dpr = np.linspace(0, drop_path_rate, self.num_layers)
+
+ self.layers = ModuleList()
+ if isinstance(layer_cfgs, dict):
+ layer_cfgs = [layer_cfgs] * self.num_layers
+ for i in range(self.num_layers):
+ _layer_cfg = dict(
+ embed_dims=self.embed_dims,
+ num_heads=self.arch_settings['num_heads'],
+ feedforward_channels=self.
+ arch_settings['feedforward_channels'],
+ layer_scale_init_value=layer_scale_init_value,
+ window_size=self.patch_resolution,
+ use_rel_pos_bias=use_rel_pos_bias,
+ drop_rate=drop_rate,
+ drop_path_rate=dpr[i],
+ bias=bias,
+ norm_cfg=norm_cfg)
+ _layer_cfg.update(layer_cfgs[i])
+ self.layers.append(BEiTTransformerEncoderLayer(**_layer_cfg))
+
+ self.frozen_stages = frozen_stages
+ self.final_norm = final_norm
+ if final_norm:
+ self.ln1 = build_norm_layer(norm_cfg, self.embed_dims)
+
+ if out_type == 'avg_featmap':
+ self.ln2 = build_norm_layer(norm_cfg, self.embed_dims)
+
+ # freeze stages only when self.frozen_stages > 0
+ if self.frozen_stages > 0:
+ self._freeze_stages()
+
+ def forward(self, x):
+ B = x.shape[0]
+ x, patch_resolution = self.patch_embed(x)
+
+ if self.cls_token is not None:
+ # stole cls_tokens impl from Phil Wang, thanks
+ cls_token = self.cls_token.expand(B, -1, -1)
+ x = torch.cat((cls_token, x), dim=1)
+
+ if self.pos_embed is not None:
+ x = x + resize_pos_embed(
+ self.pos_embed,
+ self.patch_resolution,
+ patch_resolution,
+ mode=self.interpolate_mode,
+ num_extra_tokens=self.num_extra_tokens)
+ x = self.drop_after_pos(x)
+
+ rel_pos_bias = self.rel_pos_bias() \
+ if self.rel_pos_bias is not None else None
+
+ outs = []
+ for i, layer in enumerate(self.layers):
+ x = layer(x, rel_pos_bias)
+
+ if i == len(self.layers) - 1 and self.final_norm:
+ x = self.ln1(x)
+
+ if i in self.out_indices:
+ outs.append(self._format_output(x, patch_resolution))
+
+ return tuple(outs)
+
+ def _format_output(self, x, hw):
+ if self.out_type == 'raw':
+ return x
+ if self.out_type == 'cls_token':
+ return x[:, 0]
+
+ patch_token = x[:, self.num_extra_tokens:]
+ if self.out_type == 'featmap':
+ B = x.size(0)
+ # (B, N, C) -> (B, H, W, C) -> (B, C, H, W)
+ return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2)
+ if self.out_type == 'avg_featmap':
+ return self.ln2(patch_token.mean(dim=1))
+
+ def _prepare_relative_position_bias_table(self, state_dict, prefix, *args,
+ **kwargs):
+ from mmengine.logging import MMLogger
+ logger = MMLogger.get_current_instance()
+
+ if self.use_rel_pos_bias and 'rel_pos_bias.relative_position_bias_table' in state_dict: # noqa:E501
+ logger.info('Expand the shared relative position embedding to '
+ 'each transformer block.')
+ rel_pos_bias = state_dict[
+ 'rel_pos_bias.relative_position_bias_table']
+ for i in range(self.num_layers):
+ state_dict[
+ f'layers.{i}.attn.relative_position_bias_table'] = \
+ rel_pos_bias.clone()
+ state_dict.pop('rel_pos_bias.relative_position_bias_table')
+ state_dict.pop('rel_pos_bias.relative_position_index')
+
+ state_dict_model = self.state_dict()
+ all_keys = list(state_dict_model.keys())
+ for key in all_keys:
+ if 'relative_position_bias_table' in key:
+ ckpt_key = prefix + key
+ if ckpt_key not in state_dict:
+ continue
+ rel_pos_bias_pretrained = state_dict[ckpt_key]
+ rel_pos_bias_current = state_dict_model[key]
+ L1, nH1 = rel_pos_bias_pretrained.size()
+ L2, nH2 = rel_pos_bias_current.size()
+ src_size = int((L1 - 3)**0.5)
+ dst_size = int((L2 - 3)**0.5)
+ if L1 != L2:
+ extra_tokens = rel_pos_bias_pretrained[-3:, :]
+ rel_pos_bias = rel_pos_bias_pretrained[:-3, :]
+
+ new_rel_pos_bias = resize_relative_position_bias_table(
+ src_size, dst_size, rel_pos_bias, nH1)
+ new_rel_pos_bias = torch.cat(
+ (new_rel_pos_bias, extra_tokens), dim=0)
+ logger.info('Resize the relative_position_bias_table from '
+ f'{state_dict[ckpt_key].shape} to '
+ f'{new_rel_pos_bias.shape}')
+ state_dict[ckpt_key] = new_rel_pos_bias
+
+                    # The index buffer needs to be re-generated.
+ index_buffer = ckpt_key.replace('bias_table', 'index')
+ if index_buffer in state_dict:
+ del state_dict[index_buffer]
diff --git a/mmpretrain/models/backbones/conformer.py b/mmpretrain/models/backbones/conformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..eda72b0595b6923a7f1f563ae7186ca533f85023
--- /dev/null
+++ b/mmpretrain/models/backbones/conformer.py
@@ -0,0 +1,621 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Sequence
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import build_activation_layer, build_norm_layer
+from mmcv.cnn.bricks.drop import DropPath
+from mmcv.cnn.bricks.transformer import AdaptivePadding
+from mmengine.model import BaseModule
+from mmengine.model.weight_init import trunc_normal_
+
+from mmpretrain.registry import MODELS
+from .base_backbone import BaseBackbone
+from .vision_transformer import TransformerEncoderLayer
+
+
+class ConvBlock(BaseModule):
+    """Basic convolution block used in Conformer.
+
+    This block includes three convolution modules, and supports three new
+    functions:
+    1. Returns the output of both the final layer and the second convolution
+    module.
+    2. Fuses the input of the second convolution module with an extra input
+    feature map.
+    3. Supports adding an extra convolution module to the identity
+    connection.
+
+ Args:
+ in_channels (int): The number of input channels.
+ out_channels (int): The number of output channels.
+        stride (int): The stride of the second convolution module.
+            Defaults to 1.
+        groups (int): The groups of the second convolution module.
+            Defaults to 1.
+        drop_path_rate (float): The rate of the DropPath layer. Defaults to 0.
+        with_residual_conv (bool): Whether to add an extra convolution module
+            to the identity connection. Defaults to False.
+        norm_cfg (dict): The config of normalization layers.
+            Defaults to ``dict(type='BN', eps=1e-6)``.
+        act_cfg (dict): The config of activation functions.
+            Defaults to ``dict(type='ReLU', inplace=True)``.
+ init_cfg (dict, optional): The extra config to initialize the module.
+ Defaults to None.
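+
+    Example:
+        An illustrative forward pass (when ``in_channels != out_channels``
+        or ``stride > 1``, set ``with_residual_conv=True`` so the identity
+        branch matches the main branch; the second returned tensor is the
+        output of the second convolution module):
+
+        >>> import torch
+        >>> block = ConvBlock(64, 256, with_residual_conv=True)
+        >>> x = torch.rand(1, 64, 56, 56)
+        >>> out, out_conv2 = block(x)
+        >>> print(tuple(out.shape), tuple(out_conv2.shape))
+        (1, 256, 56, 56) (1, 64, 56, 56)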
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ stride=1,
+ groups=1,
+ drop_path_rate=0.,
+ with_residual_conv=False,
+ norm_cfg=dict(type='BN', eps=1e-6),
+ act_cfg=dict(type='ReLU', inplace=True),
+ init_cfg=None):
+ super(ConvBlock, self).__init__(init_cfg=init_cfg)
+
+ expansion = 4
+ mid_channels = out_channels // expansion
+
+ self.conv1 = nn.Conv2d(
+ in_channels,
+ mid_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=False)
+ self.bn1 = build_norm_layer(norm_cfg, mid_channels)[1]
+ self.act1 = build_activation_layer(act_cfg)
+
+ self.conv2 = nn.Conv2d(
+ mid_channels,
+ mid_channels,
+ kernel_size=3,
+ stride=stride,
+ groups=groups,
+ padding=1,
+ bias=False)
+ self.bn2 = build_norm_layer(norm_cfg, mid_channels)[1]
+ self.act2 = build_activation_layer(act_cfg)
+
+ self.conv3 = nn.Conv2d(
+ mid_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=False)
+ self.bn3 = build_norm_layer(norm_cfg, out_channels)[1]
+ self.act3 = build_activation_layer(act_cfg)
+
+ if with_residual_conv:
+ self.residual_conv = nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=stride,
+ padding=0,
+ bias=False)
+ self.residual_bn = build_norm_layer(norm_cfg, out_channels)[1]
+
+ self.with_residual_conv = with_residual_conv
+ self.drop_path = DropPath(
+ drop_path_rate) if drop_path_rate > 0. else nn.Identity()
+
+ def zero_init_last_bn(self):
+ nn.init.zeros_(self.bn3.weight)
+
+ def forward(self, x, fusion_features=None, out_conv2=True):
+ identity = x
+
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.act1(x)
+
+ x = self.conv2(x) if fusion_features is None else self.conv2(
+ x + fusion_features)
+ x = self.bn2(x)
+ x2 = self.act2(x)
+
+ x = self.conv3(x2)
+ x = self.bn3(x)
+
+ if self.drop_path is not None:
+ x = self.drop_path(x)
+
+ if self.with_residual_conv:
+ identity = self.residual_conv(identity)
+ identity = self.residual_bn(identity)
+
+ x += identity
+ x = self.act3(x)
+
+ if out_conv2:
+ return x, x2
+ else:
+ return x
+
+
+class FCUDown(BaseModule):
+ """CNN feature maps -> Transformer patch embeddings."""
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ down_stride,
+ with_cls_token=True,
+ norm_cfg=dict(type='LN', eps=1e-6),
+ act_cfg=dict(type='GELU'),
+ init_cfg=None):
+ super(FCUDown, self).__init__(init_cfg=init_cfg)
+ self.down_stride = down_stride
+ self.with_cls_token = with_cls_token
+
+ self.conv_project = nn.Conv2d(
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+ self.sample_pooling = nn.AvgPool2d(
+ kernel_size=down_stride, stride=down_stride)
+
+ self.ln = build_norm_layer(norm_cfg, out_channels)[1]
+ self.act = build_activation_layer(act_cfg)
+
+ def forward(self, x, x_t):
+ x = self.conv_project(x) # [N, C, H, W]
+
+ x = self.sample_pooling(x).flatten(2).transpose(1, 2)
+ x = self.ln(x)
+ x = self.act(x)
+
+ if self.with_cls_token:
+ x = torch.cat([x_t[:, 0][:, None, :], x], dim=1)
+
+ return x
+
+
+class FCUUp(BaseModule):
+ """Transformer patch embeddings -> CNN feature maps."""
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ up_stride,
+ with_cls_token=True,
+ norm_cfg=dict(type='BN', eps=1e-6),
+ act_cfg=dict(type='ReLU', inplace=True),
+ init_cfg=None):
+ super(FCUUp, self).__init__(init_cfg=init_cfg)
+
+ self.up_stride = up_stride
+ self.with_cls_token = with_cls_token
+
+ self.conv_project = nn.Conv2d(
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+ self.bn = build_norm_layer(norm_cfg, out_channels)[1]
+ self.act = build_activation_layer(act_cfg)
+
+ def forward(self, x, H, W):
+ B, _, C = x.shape
+ # [N, 197, 384] -> [N, 196, 384] -> [N, 384, 196] -> [N, 384, 14, 14]
+ if self.with_cls_token:
+ x_r = x[:, 1:].transpose(1, 2).reshape(B, C, H, W)
+ else:
+ x_r = x.transpose(1, 2).reshape(B, C, H, W)
+
+ x_r = self.act(self.bn(self.conv_project(x_r)))
+
+ return F.interpolate(
+ x_r, size=(H * self.up_stride, W * self.up_stride))
+
+
+class ConvTransBlock(BaseModule):
+ """Basic module for Conformer.
+
+    This module is a fusion of a CNN block and a transformer encoder block.
+
+ Args:
+ in_channels (int): The number of input channels in conv blocks.
+ out_channels (int): The number of output channels in conv blocks.
+ embed_dims (int): The embedding dimension in transformer blocks.
+ conv_stride (int): The stride of conv2d layers. Defaults to 1.
+ groups (int): The groups of conv blocks. Defaults to 1.
+ with_residual_conv (bool): Whether to add a conv-bn layer to the
+ identity connect in the conv block. Defaults to False.
+ down_stride (int): The stride of the downsample pooling layer.
+ Defaults to 4.
+ num_heads (int): The number of heads in transformer attention layers.
+ Defaults to 12.
+ mlp_ratio (float): The expansion ratio in transformer FFN module.
+ Defaults to 4.
+ qkv_bias (bool): Enable bias for qkv if True. Defaults to False.
+ with_cls_token (bool): Whether use class token or not.
+ Defaults to True.
+ drop_rate (float): The dropout rate of the output projection and
+ FFN in the transformer block. Defaults to 0.
+ attn_drop_rate (float): The dropout rate after the attention
+ calculation in the transformer block. Defaults to 0.
+        drop_path_rate (float): The drop path rate in both the conv block
+            and the transformer block. Defaults to 0.
+ last_fusion (bool): Whether this block is the last stage. If so,
+ downsample the fusion feature map.
+ init_cfg (dict, optional): The extra config to initialize the module.
+ Defaults to None.
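+
+    Example:
+        An illustrative forward pass (shapes assume a 56x56 CNN feature map
+        and the default ``down_stride=4``, which gives 14x14 = 196 patch
+        tokens plus one class token on the transformer branch):
+
+        >>> import torch
+        >>> block = ConvTransBlock(
+        ...     in_channels=256, out_channels=256, embed_dims=384,
+        ...     num_heads=6)
+        >>> cnn_feat = torch.rand(1, 256, 56, 56)
+        >>> trans_feat = torch.rand(1, 197, 384)
+        >>> cnn_out, trans_out = block(cnn_feat, trans_feat)
+        >>> print(tuple(cnn_out.shape), tuple(trans_out.shape))
+        (1, 256, 56, 56) (1, 197, 384)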
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ embed_dims,
+ conv_stride=1,
+ groups=1,
+ with_residual_conv=False,
+ down_stride=4,
+ num_heads=12,
+ mlp_ratio=4.,
+ qkv_bias=False,
+ with_cls_token=True,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.,
+ last_fusion=False,
+ init_cfg=None):
+ super(ConvTransBlock, self).__init__(init_cfg=init_cfg)
+ expansion = 4
+ self.cnn_block = ConvBlock(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ with_residual_conv=with_residual_conv,
+ stride=conv_stride,
+ groups=groups)
+
+ if last_fusion:
+ self.fusion_block = ConvBlock(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ stride=2,
+ with_residual_conv=True,
+ groups=groups,
+ drop_path_rate=drop_path_rate)
+ else:
+ self.fusion_block = ConvBlock(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ groups=groups,
+ drop_path_rate=drop_path_rate)
+
+ self.squeeze_block = FCUDown(
+ in_channels=out_channels // expansion,
+ out_channels=embed_dims,
+ down_stride=down_stride,
+ with_cls_token=with_cls_token)
+
+ self.expand_block = FCUUp(
+ in_channels=embed_dims,
+ out_channels=out_channels // expansion,
+ up_stride=down_stride,
+ with_cls_token=with_cls_token)
+
+ self.trans_block = TransformerEncoderLayer(
+ embed_dims=embed_dims,
+ num_heads=num_heads,
+ feedforward_channels=int(embed_dims * mlp_ratio),
+ drop_rate=drop_rate,
+ drop_path_rate=drop_path_rate,
+ attn_drop_rate=attn_drop_rate,
+ qkv_bias=qkv_bias,
+ norm_cfg=dict(type='LN', eps=1e-6))
+
+ self.down_stride = down_stride
+ self.embed_dim = embed_dims
+ self.last_fusion = last_fusion
+
+ def forward(self, cnn_input, trans_input):
+ x, x_conv2 = self.cnn_block(cnn_input, out_conv2=True)
+
+ _, _, H, W = x_conv2.shape
+
+ # Convert the feature map of conv2 to transformer embedding
+ # and concat with class token.
+ conv2_embedding = self.squeeze_block(x_conv2, trans_input)
+
+ trans_output = self.trans_block(conv2_embedding + trans_input)
+
+ # Convert the transformer output embedding to feature map
+ trans_features = self.expand_block(trans_output, H // self.down_stride,
+ W // self.down_stride)
+ x = self.fusion_block(
+ x, fusion_features=trans_features, out_conv2=False)
+
+ return x, trans_output
+
+
+@MODELS.register_module()
+class Conformer(BaseBackbone):
+ """Conformer backbone.
+
+    A PyTorch implementation of: `Conformer: Local Features Coupling Global
+    Representations for Visual Recognition `_
+
+ Args:
+ arch (str | dict): Conformer architecture. Defaults to 'tiny'.
+ patch_size (int): The patch size. Defaults to 16.
+ base_channels (int): The base number of channels in CNN network.
+ Defaults to 64.
+ mlp_ratio (float): The expansion ratio of FFN network in transformer
+ block. Defaults to 4.
+ with_cls_token (bool): Whether use class token or not.
+ Defaults to True.
+ drop_path_rate (float): stochastic depth rate. Defaults to 0.
+ out_indices (Sequence | int): Output from which stages.
+ Defaults to -1, means the last stage.
+ init_cfg (dict, optional): Initialization config dict.
+ Defaults to None.
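+
+    Example:
+        An illustrative usage sketch (assuming the backbone is importable
+        from ``mmpretrain.models`` like the other backbones; each selected
+        stage yields a ``[conv_feature, transformer_feature]`` pair, and the
+        shapes below are for the 'tiny' arch with a 224x224 input):
+
+        >>> import torch
+        >>> from mmpretrain.models import Conformer
+        >>> model = Conformer(arch='tiny')
+        >>> inputs = torch.rand(1, 3, 224, 224)
+        >>> conv_feat, trans_feat = model(inputs)[-1]
+        >>> print(tuple(conv_feat.shape), tuple(trans_feat.shape))
+        (1, 256) (1, 384)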
+ """
+ arch_zoo = {
+ **dict.fromkeys(['t', 'tiny'],
+ {'embed_dims': 384,
+ 'channel_ratio': 1,
+ 'num_heads': 6,
+ 'depths': 12
+ }),
+ **dict.fromkeys(['s', 'small'],
+ {'embed_dims': 384,
+ 'channel_ratio': 4,
+ 'num_heads': 6,
+ 'depths': 12
+ }),
+ **dict.fromkeys(['b', 'base'],
+ {'embed_dims': 576,
+ 'channel_ratio': 6,
+ 'num_heads': 9,
+ 'depths': 12
+ }),
+ } # yapf: disable
+
+ _version = 1
+
+ def __init__(self,
+ arch='tiny',
+ patch_size=16,
+ base_channels=64,
+ mlp_ratio=4.,
+ qkv_bias=True,
+ with_cls_token=True,
+ drop_path_rate=0.,
+ norm_eval=True,
+ frozen_stages=0,
+ out_indices=-1,
+ init_cfg=None):
+
+ super().__init__(init_cfg=init_cfg)
+
+ if isinstance(arch, str):
+ arch = arch.lower()
+ assert arch in set(self.arch_zoo), \
+ f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
+ self.arch_settings = self.arch_zoo[arch]
+ else:
+ essential_keys = {
+ 'embed_dims', 'depths', 'num_heads', 'channel_ratio'
+ }
+ assert isinstance(arch, dict) and set(arch) == essential_keys, \
+ f'Custom arch needs a dict with keys {essential_keys}'
+ self.arch_settings = arch
+
+ self.num_features = self.embed_dims = self.arch_settings['embed_dims']
+ self.depths = self.arch_settings['depths']
+ self.num_heads = self.arch_settings['num_heads']
+ self.channel_ratio = self.arch_settings['channel_ratio']
+
+ if isinstance(out_indices, int):
+ out_indices = [out_indices]
+        assert isinstance(out_indices, Sequence), \
+            f'"out_indices" must be a sequence or int, ' \
+            f'got {type(out_indices)} instead.'
+ for i, index in enumerate(out_indices):
+ if index < 0:
+ out_indices[i] = self.depths + index + 1
+ assert out_indices[i] >= 0, f'Invalid out_indices {index}'
+ self.out_indices = out_indices
+
+ self.norm_eval = norm_eval
+ self.frozen_stages = frozen_stages
+
+ self.with_cls_token = with_cls_token
+ if self.with_cls_token:
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
+
+ # stochastic depth decay rule
+ self.trans_dpr = [
+ x.item() for x in torch.linspace(0, drop_path_rate, self.depths)
+ ]
+
+ # Stem stage: get the feature maps by conv block
+ self.conv1 = nn.Conv2d(
+ 3, 64, kernel_size=7, stride=2, padding=3,
+ bias=False) # 1 / 2 [112, 112]
+ self.bn1 = nn.BatchNorm2d(64)
+ self.act1 = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(
+ kernel_size=3, stride=2, padding=1) # 1 / 4 [56, 56]
+
+ assert patch_size % 16 == 0, 'The patch size of Conformer must ' \
+ 'be divisible by 16.'
+ trans_down_stride = patch_size // 4
+
+ # To solve the issue #680
+ # Auto pad the feature map to be divisible by trans_down_stride
+ self.auto_pad = AdaptivePadding(trans_down_stride, trans_down_stride)
+
+ # 1 stage
+ stage1_channels = int(base_channels * self.channel_ratio)
+ self.conv_1 = ConvBlock(
+ in_channels=64,
+ out_channels=stage1_channels,
+ with_residual_conv=True,
+ stride=1)
+ self.trans_patch_conv = nn.Conv2d(
+ 64,
+ self.embed_dims,
+ kernel_size=trans_down_stride,
+ stride=trans_down_stride,
+ padding=0)
+
+ self.trans_1 = TransformerEncoderLayer(
+ embed_dims=self.embed_dims,
+ num_heads=self.num_heads,
+ feedforward_channels=int(self.embed_dims * mlp_ratio),
+ drop_path_rate=self.trans_dpr[0],
+ qkv_bias=qkv_bias,
+ norm_cfg=dict(type='LN', eps=1e-6))
+
+ # 2~4 stage
+ init_stage = 2
+ fin_stage = self.depths // 3 + 1
+ for i in range(init_stage, fin_stage):
+ self.add_module(
+ f'conv_trans_{i}',
+ ConvTransBlock(
+ in_channels=stage1_channels,
+ out_channels=stage1_channels,
+ embed_dims=self.embed_dims,
+ conv_stride=1,
+ with_residual_conv=False,
+ down_stride=trans_down_stride,
+ num_heads=self.num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop_path_rate=self.trans_dpr[i - 1],
+ with_cls_token=self.with_cls_token))
+
+ stage2_channels = int(base_channels * self.channel_ratio * 2)
+ # 5~8 stage
+ init_stage = fin_stage # 5
+ fin_stage = fin_stage + self.depths // 3 # 9
+ for i in range(init_stage, fin_stage):
+ if i == init_stage:
+ conv_stride = 2
+ in_channels = stage1_channels
+ else:
+ conv_stride = 1
+ in_channels = stage2_channels
+
+ with_residual_conv = True if i == init_stage else False
+ self.add_module(
+ f'conv_trans_{i}',
+ ConvTransBlock(
+ in_channels=in_channels,
+ out_channels=stage2_channels,
+ embed_dims=self.embed_dims,
+ conv_stride=conv_stride,
+ with_residual_conv=with_residual_conv,
+ down_stride=trans_down_stride // 2,
+ num_heads=self.num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop_path_rate=self.trans_dpr[i - 1],
+ with_cls_token=self.with_cls_token))
+
+ stage3_channels = int(base_channels * self.channel_ratio * 2 * 2)
+ # 9~12 stage
+ init_stage = fin_stage # 9
+ fin_stage = fin_stage + self.depths // 3 # 13
+ for i in range(init_stage, fin_stage):
+ if i == init_stage:
+ conv_stride = 2
+ in_channels = stage2_channels
+ with_residual_conv = True
+ else:
+ conv_stride = 1
+ in_channels = stage3_channels
+ with_residual_conv = False
+
+ last_fusion = (i == self.depths)
+
+ self.add_module(
+ f'conv_trans_{i}',
+ ConvTransBlock(
+ in_channels=in_channels,
+ out_channels=stage3_channels,
+ embed_dims=self.embed_dims,
+ conv_stride=conv_stride,
+ with_residual_conv=with_residual_conv,
+ down_stride=trans_down_stride // 4,
+ num_heads=self.num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop_path_rate=self.trans_dpr[i - 1],
+ with_cls_token=self.with_cls_token,
+ last_fusion=last_fusion))
+ self.fin_stage = fin_stage
+
+ self.pooling = nn.AdaptiveAvgPool2d(1)
+ self.trans_norm = nn.LayerNorm(self.embed_dims)
+
+ if self.with_cls_token:
+ trunc_normal_(self.cls_token, std=.02)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+ elif isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(
+ m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.constant_(m.weight, 1.)
+ nn.init.constant_(m.bias, 0.)
+
+ if hasattr(m, 'zero_init_last_bn'):
+ m.zero_init_last_bn()
+
+ def init_weights(self):
+ super(Conformer, self).init_weights()
+
+ if (isinstance(self.init_cfg, dict)
+ and self.init_cfg['type'] == 'Pretrained'):
+ # Suppress default init if use pretrained model.
+ return
+ self.apply(self._init_weights)
+
+ def forward(self, x):
+ output = []
+ B = x.shape[0]
+ if self.with_cls_token:
+ cls_tokens = self.cls_token.expand(B, -1, -1)
+
+ # stem
+ x_base = self.maxpool(self.act1(self.bn1(self.conv1(x))))
+ x_base = self.auto_pad(x_base)
+
+ # 1 stage [N, 64, 56, 56] -> [N, 128, 56, 56]
+ x = self.conv_1(x_base, out_conv2=False)
+ x_t = self.trans_patch_conv(x_base).flatten(2).transpose(1, 2)
+ if self.with_cls_token:
+ x_t = torch.cat([cls_tokens, x_t], dim=1)
+ x_t = self.trans_1(x_t)
+
+ # 2 ~ final
+ for i in range(2, self.fin_stage):
+ stage = getattr(self, f'conv_trans_{i}')
+ x, x_t = stage(x, x_t)
+ if i in self.out_indices:
+ if self.with_cls_token:
+ output.append([
+ self.pooling(x).flatten(1),
+ self.trans_norm(x_t)[:, 0]
+ ])
+ else:
+ # if no class token, use the mean patch token
+ # as the transformer feature.
+ output.append([
+ self.pooling(x).flatten(1),
+ self.trans_norm(x_t).mean(dim=1)
+ ])
+
+ return tuple(output)
diff --git a/mmpretrain/models/backbones/convmixer.py b/mmpretrain/models/backbones/convmixer.py
new file mode 100644
index 0000000000000000000000000000000000000000..480050d5ce1aa29f190dbc24ec1413573d541cb1
--- /dev/null
+++ b/mmpretrain/models/backbones/convmixer.py
@@ -0,0 +1,176 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Sequence
+
+import torch
+import torch.nn as nn
+from mmcv.cnn.bricks import (Conv2dAdaptivePadding, build_activation_layer,
+ build_norm_layer)
+from mmengine.utils import digit_version
+
+from mmpretrain.registry import MODELS
+from .base_backbone import BaseBackbone
+
+
+class Residual(nn.Module):
+
+ def __init__(self, fn):
+ super().__init__()
+ self.fn = fn
+
+ def forward(self, x):
+ return self.fn(x) + x
+
+
+@MODELS.register_module()
+class ConvMixer(BaseBackbone):
+    """ConvMixer backbone.
+
+    A PyTorch implementation of: `Patches Are All You Need?
+    `_
+
+ Modified from the `official repo
+ `_
+ and `timm
+ `_.
+
+ Args:
+ arch (str | dict): The model's architecture. If string, it should be
+            one of the architectures in ``ConvMixer.arch_settings``. And if
+            dict, it should include the following keys:
+
+ - embed_dims (int): The dimensions of patch embedding.
+ - depth (int): Number of repetitions of ConvMixer Layer.
+ - patch_size (int): The patch size.
+ - kernel_size (int): The kernel size of depthwise conv layers.
+
+ Defaults to '768/32'.
+ in_channels (int): Number of input image channels. Defaults to 3.
+ patch_size (int): The size of one patch in the patch embed layer.
+ Defaults to 7.
+ norm_cfg (dict): The config dict for norm layers.
+ Defaults to ``dict(type='BN')``.
+ act_cfg (dict): The config dict for activation after each convolution.
+ Defaults to ``dict(type='GELU')``.
+ out_indices (Sequence | int): Output from which stages.
+ Defaults to -1, means the last stage.
+ frozen_stages (int): Stages to be frozen (all param fixed).
+ Defaults to 0, which means not freezing any parameters.
+ init_cfg (dict, optional): Initialization config dict.
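+
+    Example:
+        An illustrative usage sketch (assuming the backbone is importable
+        from ``mmpretrain.models``; with the default '768/32' arch a
+        224x224 input is split into a 32x32 token map by the patch stem,
+        and every ConvMixer layer keeps that resolution):
+
+        >>> import torch
+        >>> from mmpretrain.models import ConvMixer
+        >>> model = ConvMixer(arch='768/32')
+        >>> inputs = torch.rand(1, 3, 224, 224)
+        >>> outs = model(inputs)
+        >>> print(tuple(outs[-1].shape))
+        (1, 768, 32, 32)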
+ """
+ arch_settings = {
+ '768/32': {
+ 'embed_dims': 768,
+ 'depth': 32,
+ 'patch_size': 7,
+ 'kernel_size': 7
+ },
+ '1024/20': {
+ 'embed_dims': 1024,
+ 'depth': 20,
+ 'patch_size': 14,
+ 'kernel_size': 9
+ },
+ '1536/20': {
+ 'embed_dims': 1536,
+ 'depth': 20,
+ 'patch_size': 7,
+ 'kernel_size': 9
+ },
+ }
+
+ def __init__(self,
+ arch='768/32',
+ in_channels=3,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='GELU'),
+ out_indices=-1,
+ frozen_stages=0,
+ init_cfg=None):
+ super().__init__(init_cfg=init_cfg)
+
+ if isinstance(arch, str):
+ assert arch in self.arch_settings, \
+ f'Unavailable arch, please choose from ' \
+ f'({set(self.arch_settings)}) or pass a dict.'
+ arch = self.arch_settings[arch]
+ elif isinstance(arch, dict):
+ essential_keys = {
+ 'embed_dims', 'depth', 'patch_size', 'kernel_size'
+ }
+ assert isinstance(arch, dict) and essential_keys <= set(arch), \
+ f'Custom arch needs a dict with keys {essential_keys}'
+
+ self.embed_dims = arch['embed_dims']
+ self.depth = arch['depth']
+ self.patch_size = arch['patch_size']
+ self.kernel_size = arch['kernel_size']
+ self.act = build_activation_layer(act_cfg)
+
+ # check out indices and frozen stages
+ if isinstance(out_indices, int):
+ out_indices = [out_indices]
+        assert isinstance(out_indices, Sequence), \
+            f'"out_indices" must be a sequence or int, ' \
+            f'got {type(out_indices)} instead.'
+ for i, index in enumerate(out_indices):
+ if index < 0:
+ out_indices[i] = self.depth + index
+ assert out_indices[i] >= 0, f'Invalid out_indices {index}'
+ self.out_indices = out_indices
+ self.frozen_stages = frozen_stages
+
+ # Set stem layers
+ self.stem = nn.Sequential(
+ nn.Conv2d(
+ in_channels,
+ self.embed_dims,
+ kernel_size=self.patch_size,
+ stride=self.patch_size), self.act,
+ build_norm_layer(norm_cfg, self.embed_dims)[1])
+
+ # Set conv2d according to torch version
+ convfunc = nn.Conv2d
+ if digit_version(torch.__version__) < digit_version('1.9.0'):
+ convfunc = Conv2dAdaptivePadding
+
+ # Repetitions of ConvMixer Layer
+ self.stages = nn.Sequential(*[
+ nn.Sequential(
+ Residual(
+ nn.Sequential(
+ convfunc(
+ self.embed_dims,
+ self.embed_dims,
+ self.kernel_size,
+ groups=self.embed_dims,
+ padding='same'), self.act,
+ build_norm_layer(norm_cfg, self.embed_dims)[1])),
+ nn.Conv2d(self.embed_dims, self.embed_dims, kernel_size=1),
+ self.act,
+ build_norm_layer(norm_cfg, self.embed_dims)[1])
+ for _ in range(self.depth)
+ ])
+
+ self._freeze_stages()
+
+ def forward(self, x):
+ x = self.stem(x)
+ outs = []
+ for i, stage in enumerate(self.stages):
+ x = stage(x)
+ if i in self.out_indices:
+ outs.append(x)
+
+ # x = self.pooling(x).flatten(1)
+ return tuple(outs)
+
+ def train(self, mode=True):
+ super(ConvMixer, self).train(mode)
+ self._freeze_stages()
+
+ def _freeze_stages(self):
+ for i in range(self.frozen_stages):
+ stage = self.stages[i]
+ stage.eval()
+ for param in stage.parameters():
+ param.requires_grad = False
diff --git a/mmpretrain/models/backbones/convnext.py b/mmpretrain/models/backbones/convnext.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a954f5b980186a86565a228669c6917bda14f68
--- /dev/null
+++ b/mmpretrain/models/backbones/convnext.py
@@ -0,0 +1,412 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from functools import partial
+from itertools import chain
+from typing import Sequence
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+from mmcv.cnn.bricks import DropPath
+from mmengine.model import BaseModule, ModuleList, Sequential
+
+from mmpretrain.registry import MODELS
+from ..utils import GRN, build_norm_layer
+from .base_backbone import BaseBackbone
+
+
+class ConvNeXtBlock(BaseModule):
+ """ConvNeXt Block.
+
+ Args:
+ in_channels (int): The number of input channels.
+ dw_conv_cfg (dict): Config of depthwise convolution.
+ Defaults to ``dict(kernel_size=7, padding=3)``.
+ norm_cfg (dict): The config dict for norm layers.
+ Defaults to ``dict(type='LN2d', eps=1e-6)``.
+        act_cfg (dict): The config dict for the activation between the
+            pointwise convolutions. Defaults to ``dict(type='GELU')``.
+        mlp_ratio (float): The expansion ratio of the two pointwise
+            convolution layers. Defaults to 4.
+ linear_pw_conv (bool): Whether to use linear layer to do pointwise
+ convolution. More details can be found in the note.
+ Defaults to True.
+ drop_path_rate (float): Stochastic depth rate. Defaults to 0.
+        layer_scale_init_value (float): Init value for Layer Scale.
+            Defaults to 1e-6.
+        use_grn (bool): Whether to add Global Response Normalization in the
+            block. Defaults to False.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save
+            some memory while slowing down the training speed.
+            Defaults to False.
+
+ Note:
+ There are two equivalent implementations:
+
+ 1. DwConv -> LayerNorm -> 1x1 Conv -> GELU -> 1x1 Conv;
+ all outputs are in (N, C, H, W).
+ 2. DwConv -> LayerNorm -> Permute to (N, H, W, C) -> Linear -> GELU
+ -> Linear; Permute back
+
+ As default, we use the second to align with the official repository.
+ And it may be slightly faster.
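+
+        An illustrative sketch of this equivalence (not part of the block
+        itself): a 1x1 convolution and a linear layer that share the same
+        weights give the same result on a channels-last permuted tensor.
+
+        >>> import torch
+        >>> import torch.nn as nn
+        >>> conv = nn.Conv2d(8, 16, kernel_size=1)
+        >>> linear = nn.Linear(8, 16)
+        >>> linear.weight.data = conv.weight.data.view(16, 8)
+        >>> linear.bias.data = conv.bias.data.clone()
+        >>> x = torch.rand(2, 8, 14, 14)
+        >>> y_conv = conv(x)
+        >>> y_linear = linear(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+        >>> torch.allclose(y_conv, y_linear, atol=1e-6)
+        True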
+ """
+
+ def __init__(self,
+ in_channels,
+ dw_conv_cfg=dict(kernel_size=7, padding=3),
+ norm_cfg=dict(type='LN2d', eps=1e-6),
+ act_cfg=dict(type='GELU'),
+ mlp_ratio=4.,
+ linear_pw_conv=True,
+ drop_path_rate=0.,
+ layer_scale_init_value=1e-6,
+ use_grn=False,
+ with_cp=False):
+ super().__init__()
+ self.with_cp = with_cp
+
+ self.depthwise_conv = nn.Conv2d(
+ in_channels, in_channels, groups=in_channels, **dw_conv_cfg)
+
+ self.linear_pw_conv = linear_pw_conv
+ self.norm = build_norm_layer(norm_cfg, in_channels)
+
+ mid_channels = int(mlp_ratio * in_channels)
+ if self.linear_pw_conv:
+ # Use linear layer to do pointwise conv.
+ pw_conv = nn.Linear
+ else:
+ pw_conv = partial(nn.Conv2d, kernel_size=1)
+
+ self.pointwise_conv1 = pw_conv(in_channels, mid_channels)
+ self.act = MODELS.build(act_cfg)
+ self.pointwise_conv2 = pw_conv(mid_channels, in_channels)
+
+ if use_grn:
+ self.grn = GRN(mid_channels)
+ else:
+ self.grn = None
+
+ self.gamma = nn.Parameter(
+ layer_scale_init_value * torch.ones((in_channels)),
+ requires_grad=True) if layer_scale_init_value > 0 else None
+
+ self.drop_path = DropPath(
+ drop_path_rate) if drop_path_rate > 0. else nn.Identity()
+
+ def forward(self, x):
+
+ def _inner_forward(x):
+ shortcut = x
+ x = self.depthwise_conv(x)
+
+ if self.linear_pw_conv:
+ x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
+ x = self.norm(x, data_format='channel_last')
+ x = self.pointwise_conv1(x)
+ x = self.act(x)
+ if self.grn is not None:
+ x = self.grn(x, data_format='channel_last')
+ x = self.pointwise_conv2(x)
+ x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
+ else:
+ x = self.norm(x, data_format='channel_first')
+ x = self.pointwise_conv1(x)
+ x = self.act(x)
+
+ if self.grn is not None:
+ x = self.grn(x, data_format='channel_first')
+ x = self.pointwise_conv2(x)
+
+ if self.gamma is not None:
+ x = x.mul(self.gamma.view(1, -1, 1, 1))
+
+ x = shortcut + self.drop_path(x)
+ return x
+
+ if self.with_cp and x.requires_grad:
+ x = cp.checkpoint(_inner_forward, x)
+ else:
+ x = _inner_forward(x)
+ return x
+
+
+@MODELS.register_module()
+class ConvNeXt(BaseBackbone):
+ """ConvNeXt v1&v2 backbone.
+
+ A PyTorch implementation of `A ConvNet for the 2020s
+ `_ and
+ `ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders
+ `_
+
+ Modified from the `official repo
+ `_
+ and `timm
+ `_.
+
+ To use ConvNeXt v2, please set ``use_grn=True`` and ``layer_scale_init_value=0.``.
+
+ Args:
+ arch (str | dict): The model's architecture. If string, it should be
+ one of architecture in ``ConvNeXt.arch_settings``. And if dict, it
+ should include the following two keys:
+
+ - depths (list[int]): Number of blocks at each stage.
+ - channels (list[int]): The number of channels at each stage.
+
+ Defaults to 'tiny'.
+ in_channels (int): Number of input image channels. Defaults to 3.
+ stem_patch_size (int): The size of one patch in the stem layer.
+ Defaults to 4.
+ norm_cfg (dict): The config dict for norm layers.
+ Defaults to ``dict(type='LN2d', eps=1e-6)``.
+        act_cfg (dict): The config dict for the activation between the
+            pointwise convolutions. Defaults to ``dict(type='GELU')``.
+ linear_pw_conv (bool): Whether to use linear layer to do pointwise
+ convolution. Defaults to True.
+ use_grn (bool): Whether to add Global Response Normalization in the
+ blocks. Defaults to False.
+ drop_path_rate (float): Stochastic depth rate. Defaults to 0.
+ layer_scale_init_value (float): Init value for Layer Scale.
+ Defaults to 1e-6.
+ out_indices (Sequence | int): Output from which stages.
+ Defaults to -1, means the last stage.
+ frozen_stages (int): Stages to be frozen (all param fixed).
+ Defaults to 0, which means not freezing any parameters.
+ gap_before_final_norm (bool): Whether to globally average the feature
+ map before the final norm layer. In the official repo, it's only
+ used in classification task. Defaults to True.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Defaults to False.
+ init_cfg (dict, optional): Initialization config dict
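+
+    Example:
+        An illustrative multi-scale usage sketch (assuming the backbone is
+        importable from ``mmpretrain.models``; with
+        ``gap_before_final_norm=False`` every selected stage returns a
+        feature map, and the shapes below are for the 'tiny' arch with a
+        224x224 input):
+
+        >>> import torch
+        >>> from mmpretrain.models import ConvNeXt
+        >>> model = ConvNeXt(arch='tiny', out_indices=(0, 1, 2, 3),
+        ...                  gap_before_final_norm=False)
+        >>> inputs = torch.rand(1, 3, 224, 224)
+        >>> for out in model(inputs):
+        ...     print(tuple(out.shape))
+        (1, 96, 56, 56)
+        (1, 192, 28, 28)
+        (1, 384, 14, 14)
+        (1, 768, 7, 7)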
+ """ # noqa: E501
+ arch_settings = {
+ 'atto': {
+ 'depths': [2, 2, 6, 2],
+ 'channels': [40, 80, 160, 320]
+ },
+ 'femto': {
+ 'depths': [2, 2, 6, 2],
+ 'channels': [48, 96, 192, 384]
+ },
+ 'pico': {
+ 'depths': [2, 2, 6, 2],
+ 'channels': [64, 128, 256, 512]
+ },
+ 'nano': {
+ 'depths': [2, 2, 8, 2],
+ 'channels': [80, 160, 320, 640]
+ },
+ 'tiny': {
+ 'depths': [3, 3, 9, 3],
+ 'channels': [96, 192, 384, 768]
+ },
+ 'small': {
+ 'depths': [3, 3, 27, 3],
+ 'channels': [96, 192, 384, 768]
+ },
+ 'base': {
+ 'depths': [3, 3, 27, 3],
+ 'channels': [128, 256, 512, 1024]
+ },
+ 'large': {
+ 'depths': [3, 3, 27, 3],
+ 'channels': [192, 384, 768, 1536]
+ },
+ 'xlarge': {
+ 'depths': [3, 3, 27, 3],
+ 'channels': [256, 512, 1024, 2048]
+ },
+ 'huge': {
+ 'depths': [3, 3, 27, 3],
+ 'channels': [352, 704, 1408, 2816]
+ }
+ }
+
+ def __init__(self,
+ arch='tiny',
+ in_channels=3,
+ stem_patch_size=4,
+ norm_cfg=dict(type='LN2d', eps=1e-6),
+ act_cfg=dict(type='GELU'),
+ linear_pw_conv=True,
+ use_grn=False,
+ drop_path_rate=0.,
+ layer_scale_init_value=1e-6,
+ out_indices=-1,
+ frozen_stages=0,
+ gap_before_final_norm=True,
+ with_cp=False,
+ init_cfg=[
+ dict(
+ type='TruncNormal',
+ layer=['Conv2d', 'Linear'],
+ std=.02,
+ bias=0.),
+ dict(
+ type='Constant', layer=['LayerNorm'], val=1.,
+ bias=0.),
+ ]):
+ super().__init__(init_cfg=init_cfg)
+
+ if isinstance(arch, str):
+ assert arch in self.arch_settings, \
+ f'Unavailable arch, please choose from ' \
+ f'({set(self.arch_settings)}) or pass a dict.'
+ arch = self.arch_settings[arch]
+ elif isinstance(arch, dict):
+ assert 'depths' in arch and 'channels' in arch, \
+ f'The arch dict must have "depths" and "channels", ' \
+ f'but got {list(arch.keys())}.'
+
+ self.depths = arch['depths']
+ self.channels = arch['channels']
+ assert (isinstance(self.depths, Sequence)
+ and isinstance(self.channels, Sequence)
+ and len(self.depths) == len(self.channels)), \
+ f'The "depths" ({self.depths}) and "channels" ({self.channels}) ' \
+ 'should be both sequence with the same length.'
+
+ self.num_stages = len(self.depths)
+
+ if isinstance(out_indices, int):
+ out_indices = [out_indices]
+        assert isinstance(out_indices, Sequence), \
+            f'"out_indices" must be a sequence or int, ' \
+            f'got {type(out_indices)} instead.'
+ for i, index in enumerate(out_indices):
+ if index < 0:
+ out_indices[i] = 4 + index
+ assert out_indices[i] >= 0, f'Invalid out_indices {index}'
+ self.out_indices = out_indices
+
+ self.frozen_stages = frozen_stages
+ self.gap_before_final_norm = gap_before_final_norm
+
+ # stochastic depth decay rule
+ dpr = [
+ x.item()
+ for x in torch.linspace(0, drop_path_rate, sum(self.depths))
+ ]
+ block_idx = 0
+
+ # 4 downsample layers between stages, including the stem layer.
+ self.downsample_layers = ModuleList()
+ stem = nn.Sequential(
+ nn.Conv2d(
+ in_channels,
+ self.channels[0],
+ kernel_size=stem_patch_size,
+ stride=stem_patch_size),
+ build_norm_layer(norm_cfg, self.channels[0]),
+ )
+ self.downsample_layers.append(stem)
+
+ # 4 feature resolution stages, each consisting of multiple residual
+ # blocks
+ self.stages = nn.ModuleList()
+
+ for i in range(self.num_stages):
+ depth = self.depths[i]
+ channels = self.channels[i]
+
+ if i >= 1:
+ downsample_layer = nn.Sequential(
+ build_norm_layer(norm_cfg, self.channels[i - 1]),
+ nn.Conv2d(
+ self.channels[i - 1],
+ channels,
+ kernel_size=2,
+ stride=2),
+ )
+ self.downsample_layers.append(downsample_layer)
+
+ stage = Sequential(*[
+ ConvNeXtBlock(
+ in_channels=channels,
+ drop_path_rate=dpr[block_idx + j],
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ linear_pw_conv=linear_pw_conv,
+ layer_scale_init_value=layer_scale_init_value,
+ use_grn=use_grn,
+ with_cp=with_cp) for j in range(depth)
+ ])
+ block_idx += depth
+
+ self.stages.append(stage)
+
+ if i in self.out_indices:
+ norm_layer = build_norm_layer(norm_cfg, channels)
+ self.add_module(f'norm{i}', norm_layer)
+
+ self._freeze_stages()
+
+ def forward(self, x):
+ outs = []
+ for i, stage in enumerate(self.stages):
+ x = self.downsample_layers[i](x)
+ x = stage(x)
+ if i in self.out_indices:
+ norm_layer = getattr(self, f'norm{i}')
+ if self.gap_before_final_norm:
+ gap = x.mean([-2, -1], keepdim=True)
+ outs.append(norm_layer(gap).flatten(1))
+ else:
+ outs.append(norm_layer(x))
+
+ return tuple(outs)
+
+ def _freeze_stages(self):
+ for i in range(self.frozen_stages):
+ downsample_layer = self.downsample_layers[i]
+ stage = self.stages[i]
+ downsample_layer.eval()
+ stage.eval()
+ for param in chain(downsample_layer.parameters(),
+ stage.parameters()):
+ param.requires_grad = False
+
+ def train(self, mode=True):
+ super(ConvNeXt, self).train(mode)
+ self._freeze_stages()
+
+ def get_layer_depth(self, param_name: str, prefix: str = ''):
+ """Get the layer-wise depth of a parameter.
+
+ Args:
+ param_name (str): The name of the parameter.
+ prefix (str): The prefix for the parameter.
+ Defaults to an empty string.
+
+ Returns:
+ Tuple[int, int]: The layer-wise depth and the num of layers.
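+
+        Example:
+            An illustrative call (for the 'tiny' arch the second-to-last
+            stage has 9 blocks, so the maximum layer id is 6; parameter
+            names outside the backbone, e.g. a hypothetical
+            ``head.fc.weight``, fall into the last layer group):
+
+            >>> model = ConvNeXt(arch='tiny')
+            >>> model.get_layer_depth('downsample_layers.0.0.weight')
+            (0, 8)
+            >>> model.get_layer_depth('head.fc.weight')
+            (7, 8)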
+ """
+
+ max_layer_id = 12 if self.depths[-2] > 9 else 6
+
+ if not param_name.startswith(prefix):
+ # For subsequent module like head
+ return max_layer_id + 1, max_layer_id + 2
+
+ param_name = param_name[len(prefix):]
+ if param_name.startswith('downsample_layers'):
+ stage_id = int(param_name.split('.')[1])
+ if stage_id == 0:
+ layer_id = 0
+ elif stage_id == 1 or stage_id == 2:
+ layer_id = stage_id + 1
+ else: # stage_id == 3:
+ layer_id = max_layer_id
+
+ elif param_name.startswith('stages'):
+ stage_id = int(param_name.split('.')[1])
+ block_id = int(param_name.split('.')[2])
+ if stage_id == 0 or stage_id == 1:
+ layer_id = stage_id + 1
+ elif stage_id == 2:
+ layer_id = 3 + block_id // 3
+ else: # stage_id == 3:
+ layer_id = max_layer_id
+
+ # final norm layer
+ else:
+ layer_id = max_layer_id + 1
+
+ return layer_id, max_layer_id + 2
diff --git a/mmpretrain/models/backbones/cspnet.py b/mmpretrain/models/backbones/cspnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..7492e97702c28861dcce2808207a35e67f32f752
--- /dev/null
+++ b/mmpretrain/models/backbones/cspnet.py
@@ -0,0 +1,679 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+from typing import Sequence
+
+import torch
+import torch.nn as nn
+from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
+from mmcv.cnn.bricks import DropPath
+from mmengine.model import BaseModule, Sequential
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from mmpretrain.registry import MODELS
+from ..utils import to_ntuple
+from .resnet import Bottleneck as ResNetBottleneck
+from .resnext import Bottleneck as ResNeXtBottleneck
+
+eps = 1.0e-5
+
+
+class DarknetBottleneck(BaseModule):
+ """The basic bottleneck block used in Darknet. Each DarknetBottleneck
+ consists of two ConvModules and the input is added to the final output.
+    Each ConvModule is composed of Conv, BN, and LeakyReLU. The first conv
+    layer has a 1x1 kernel size and the second one has a 3x3 kernel size.
+
+ Args:
+ in_channels (int): The input channels of this Module.
+ out_channels (int): The output channels of this Module.
+        expansion (int): The ratio of ``out_channels/mid_channels``, where
+            ``mid_channels`` is the input/output channels of conv2.
+            Defaults to 2.
+ add_identity (bool): Whether to add identity to the out.
+ Defaults to True.
+ use_depthwise (bool): Whether to use depthwise separable convolution.
+ Defaults to False.
+ conv_cfg (dict): Config dict for convolution layer. Defaults to None,
+ which means using conv2d.
+ drop_path_rate (float): The ratio of the drop path layer. Default: 0.
+ norm_cfg (dict): Config dict for normalization layer.
+ Defaults to ``dict(type='BN', eps=1e-5)``.
+        act_cfg (dict): Config dict for activation layer.
+            Defaults to ``dict(type='LeakyReLU', inplace=True)``.
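+
+    Example:
+        An illustrative forward pass (with the default ``expansion=2`` the
+        hidden layer has half of the output channels, and the identity is
+        added because ``in_channels == out_channels``):
+
+        >>> import torch
+        >>> block = DarknetBottleneck(64, 64)
+        >>> x = torch.rand(1, 64, 32, 32)
+        >>> print(tuple(block(x).shape))
+        (1, 64, 32, 32)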
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ expansion=2,
+ add_identity=True,
+ use_depthwise=False,
+ conv_cfg=None,
+ drop_path_rate=0,
+ norm_cfg=dict(type='BN', eps=1e-5),
+ act_cfg=dict(type='LeakyReLU', inplace=True),
+ init_cfg=None):
+ super().__init__(init_cfg)
+ hidden_channels = int(out_channels / expansion)
+ conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule
+ self.conv1 = ConvModule(
+ in_channels,
+ hidden_channels,
+ 1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ self.conv2 = conv(
+ hidden_channels,
+ out_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ self.add_identity = \
+ add_identity and in_channels == out_channels
+
+ self.drop_path = DropPath(drop_prob=drop_path_rate
+ ) if drop_path_rate > eps else nn.Identity()
+
+ def forward(self, x):
+ identity = x
+ out = self.conv1(x)
+ out = self.conv2(out)
+ out = self.drop_path(out)
+
+ if self.add_identity:
+ return out + identity
+ else:
+ return out
+
+
+class CSPStage(BaseModule):
+ """Cross Stage Partial Stage.
+
+ .. code:: text
+
+ Downsample Convolution (optional)
+ |
+ |
+ Expand Convolution
+ |
+ |
+ Split to xa, xb
+ | \
+ | \
+ | blocks(xb)
+ | /
+ | / transition
+ | /
+ Concat xa, blocks(xb)
+ |
+ Transition Convolution
+
+ Args:
+ block_fn (nn.module): The basic block function in the Stage.
+ in_channels (int): The input channels of the CSP layer.
+ out_channels (int): The output channels of the CSP layer.
+        has_downsampler (bool): Whether to add a downsampler in the stage.
+            Default: True.
+ down_growth (bool): Whether to expand the channels in the
+ downsampler layer of the stage. Default: False.
+ expand_ratio (float): The expand ratio to adjust the number of
+ channels of the expand conv layer. Default: 0.5
+        bottle_ratio (float): Ratio to adjust the number of channels of the
+            hidden layer. Default: 2.
+ block_dpr (float): The ratio of the drop path layer in the
+ blocks of the stage. Default: 0.
+ num_blocks (int): Number of blocks. Default: 1
+ conv_cfg (dict, optional): Config dict for convolution layer.
+ Default: None, which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='BN')
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='LeakyReLU', inplace=True)
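+
+    Example:
+        An illustrative stage built with ``DarknetBottleneck`` (the optional
+        downsampler halves the spatial size, and the split/concat path keeps
+        the channel count at ``out_channels``):
+
+        >>> import torch
+        >>> stage = CSPStage(
+        ...     DarknetBottleneck, in_channels=32, out_channels=64,
+        ...     num_blocks=1, expand_ratio=2, down_growth=True)
+        >>> x = torch.rand(1, 32, 56, 56)
+        >>> print(tuple(stage(x).shape))
+        (1, 64, 28, 28)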
+ """
+
+ def __init__(self,
+ block_fn,
+ in_channels,
+ out_channels,
+ has_downsampler=True,
+ down_growth=False,
+ expand_ratio=0.5,
+ bottle_ratio=2,
+ num_blocks=1,
+ block_dpr=0,
+ block_args={},
+ conv_cfg=None,
+ norm_cfg=dict(type='BN', eps=1e-5),
+ act_cfg=dict(type='LeakyReLU', inplace=True),
+ init_cfg=None):
+ super().__init__(init_cfg)
+ # grow downsample channels to output channels
+ down_channels = out_channels if down_growth else in_channels
+ block_dpr = to_ntuple(num_blocks)(block_dpr)
+
+ if has_downsampler:
+ self.downsample_conv = ConvModule(
+ in_channels=in_channels,
+ out_channels=down_channels,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ groups=32 if block_fn is ResNeXtBottleneck else 1,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ else:
+ self.downsample_conv = nn.Identity()
+
+ exp_channels = int(down_channels * expand_ratio)
+ self.expand_conv = ConvModule(
+ in_channels=down_channels,
+ out_channels=exp_channels,
+ kernel_size=1,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg if block_fn is DarknetBottleneck else None)
+
+ assert exp_channels % 2 == 0, \
+ 'The channel number before blocks must be divisible by 2.'
+ block_channels = exp_channels // 2
+ blocks = []
+ for i in range(num_blocks):
+ block_cfg = dict(
+ in_channels=block_channels,
+ out_channels=block_channels,
+ expansion=bottle_ratio,
+ drop_path_rate=block_dpr[i],
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ **block_args)
+ blocks.append(block_fn(**block_cfg))
+ self.blocks = Sequential(*blocks)
+ self.atfer_blocks_conv = ConvModule(
+ block_channels,
+ block_channels,
+ 1,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+
+ self.final_conv = ConvModule(
+ 2 * block_channels,
+ out_channels,
+ 1,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+
+ def forward(self, x):
+ x = self.downsample_conv(x)
+ x = self.expand_conv(x)
+
+ split = x.shape[1] // 2
+ xa, xb = x[:, :split], x[:, split:]
+
+ xb = self.blocks(xb)
+ xb = self.atfer_blocks_conv(xb).contiguous()
+
+ x_final = torch.cat((xa, xb), dim=1)
+ return self.final_conv(x_final)
+
+
+class CSPNet(BaseModule):
+ """The abstract CSP Network class.
+
+ A Pytorch implementation of `CSPNet: A New Backbone that can Enhance
+ Learning Capability of CNN `_
+
+    This class is an abstract class because the Cross Stage Partial Network
+    (CSPNet) is a kind of universal network structure, and you can plug in
+    different kinds of network blocks to implement networks like CSPResNet,
+    CSPResNeXt and CSPDarkNet.
+
+ Args:
+ arch (dict): The architecture of the CSPNet.
+ It should have the following keys:
+
+ - block_fn (Callable): A function or class to return a block
+ module, and it should accept at least ``in_channels``,
+ ``out_channels``, ``expansion``, ``drop_path_rate``, ``norm_cfg``
+ and ``act_cfg``.
+ - in_channels (Tuple[int]): The number of input channels of each
+ stage.
+ - out_channels (Tuple[int]): The number of output channels of each
+ stage.
+ - num_blocks (Tuple[int]): The number of blocks in each stage.
+ - expansion_ratio (float | Tuple[float]): The expansion ratio in
+ the expand convolution of each stage. Defaults to 0.5.
+ - bottle_ratio (float | Tuple[float]): The expansion ratio of
+ blocks in each stage. Defaults to 2.
+ - has_downsampler (bool | Tuple[bool]): Whether to add a
+ downsample convolution in each stage. Defaults to True
+ - down_growth (bool | Tuple[bool]): Whether to expand the channels
+ in the downsampler layer of each stage. Defaults to False.
+ - block_args (dict | Tuple[dict], optional): The extra arguments to
+ the blocks in each stage. Defaults to None.
+
+ stem_fn (Callable): A function or class to return a stem module.
+ And it should accept ``in_channels``.
+ in_channels (int): Number of input image channels. Defaults to 3.
+ out_indices (int | Sequence[int]): Output from which stages.
+ Defaults to -1, which means the last stage.
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+ -1 means not freezing any parameters. Defaults to -1.
+ conv_cfg (dict, optional): The config dict for conv layers in blocks.
+ Defaults to None, which means use Conv2d.
+ norm_cfg (dict): The config dict for norm layers.
+ Defaults to ``dict(type='BN', eps=1e-5)``.
+ act_cfg (dict): The config dict for activation functions.
+ Defaults to ``dict(type='LeakyReLU', inplace=True)``.
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Defaults to False.
+ init_cfg (dict, optional): The initialization settings.
+ Defaults to ``dict(type='Kaiming', layer='Conv2d'))``.
+
+ Example:
+ >>> from functools import partial
+ >>> import torch
+ >>> import torch.nn as nn
+ >>> from mmpretrain.models import CSPNet
+ >>> from mmpretrain.models.backbones.resnet import Bottleneck
+ >>>
+ >>> # A simple example to build CSPNet.
+ >>> arch = dict(
+ ... block_fn=Bottleneck,
+ ... in_channels=[32, 64],
+ ... out_channels=[64, 128],
+ ... num_blocks=[3, 4]
+ ... )
+ >>> stem_fn = partial(nn.Conv2d, out_channels=32, kernel_size=3)
+ >>> model = CSPNet(arch=arch, stem_fn=stem_fn, out_indices=(0, 1))
+ >>> inputs = torch.rand(1, 3, 224, 224)
+ >>> outs = model(inputs)
+ >>> for out in outs:
+        ...     print(tuple(out.shape))
+ ...
+ (1, 64, 111, 111)
+ (1, 128, 56, 56)
+ """
+
+ def __init__(self,
+ arch,
+ stem_fn,
+ in_channels=3,
+ out_indices=-1,
+ frozen_stages=-1,
+ drop_path_rate=0.,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN', eps=1e-5),
+ act_cfg=dict(type='LeakyReLU', inplace=True),
+ norm_eval=False,
+ init_cfg=dict(type='Kaiming', layer='Conv2d')):
+ super().__init__(init_cfg=init_cfg)
+ self.arch = self.expand_arch(arch)
+ self.num_stages = len(self.arch['in_channels'])
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ self.norm_eval = norm_eval
+ if frozen_stages not in range(-1, self.num_stages):
+ raise ValueError('frozen_stages must be in range(-1, '
+ f'{self.num_stages}). But received '
+ f'{frozen_stages}')
+ self.frozen_stages = frozen_stages
+
+ self.stem = stem_fn(in_channels)
+
+ stages = []
+ depths = self.arch['num_blocks']
+ dpr = torch.linspace(0, drop_path_rate, sum(depths)).split(depths)
+
+ for i in range(self.num_stages):
+ stage_cfg = {k: v[i] for k, v in self.arch.items()}
+ csp_stage = CSPStage(
+ **stage_cfg,
+ block_dpr=dpr[i].tolist(),
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ init_cfg=init_cfg)
+ stages.append(csp_stage)
+ self.stages = Sequential(*stages)
+
+ if isinstance(out_indices, int):
+ out_indices = [out_indices]
+        assert isinstance(out_indices, Sequence), \
+            f'"out_indices" must be a sequence or int, ' \
+            f'got {type(out_indices)} instead.'
+ out_indices = list(out_indices)
+ for i, index in enumerate(out_indices):
+ if index < 0:
+ out_indices[i] = len(self.stages) + index
+ assert 0 <= out_indices[i] <= len(self.stages), \
+ f'Invalid out_indices {index}.'
+ self.out_indices = out_indices
+
+ @staticmethod
+ def expand_arch(arch):
+ num_stages = len(arch['in_channels'])
+
+ def to_tuple(x, name=''):
+ if isinstance(x, (list, tuple)):
+ assert len(x) == num_stages, \
+ f'The length of {name} ({len(x)}) does not ' \
+ f'equals to the number of stages ({num_stages})'
+ return tuple(x)
+ else:
+ return (x, ) * num_stages
+
+ full_arch = {k: to_tuple(v, k) for k, v in arch.items()}
+ if 'block_args' not in full_arch:
+ full_arch['block_args'] = to_tuple({})
+ return full_arch
+
+ def _freeze_stages(self):
+ if self.frozen_stages >= 0:
+ self.stem.eval()
+ for param in self.stem.parameters():
+ param.requires_grad = False
+
+ for i in range(self.frozen_stages + 1):
+ m = self.stages[i]
+ m.eval()
+ for param in m.parameters():
+ param.requires_grad = False
+
+ def train(self, mode=True):
+ super(CSPNet, self).train(mode)
+ self._freeze_stages()
+ if mode and self.norm_eval:
+ for m in self.modules():
+ if isinstance(m, _BatchNorm):
+ m.eval()
+
+ def forward(self, x):
+ outs = []
+
+ x = self.stem(x)
+ for i, stage in enumerate(self.stages):
+ x = stage(x)
+ if i in self.out_indices:
+ outs.append(x)
+ return tuple(outs)
+
+
+@MODELS.register_module()
+class CSPDarkNet(CSPNet):
+ """CSP-Darknet backbone used in YOLOv4.
+
+ Args:
+ depth (int): Depth of CSP-Darknet. Default: 53.
+ in_channels (int): Number of input image channels. Default: 3.
+        out_indices (Sequence[int]): Output from which stages.
+            Default: (4, ).
+        frozen_stages (int): Stages to be frozen (stop grad and set eval
+            mode). -1 means not freezing any parameters. Default: -1.
+        conv_cfg (dict): Config dict for convolution layer. Default: None.
+        norm_cfg (dict): Dictionary to construct and config norm layer.
+            Default: dict(type='BN', eps=1e-5).
+        act_cfg (dict): Config dict for activation layer.
+            Default: dict(type='LeakyReLU', inplace=True).
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None.
+
+ Example:
+ >>> from mmpretrain.models import CSPDarkNet
+ >>> import torch
+ >>> model = CSPDarkNet(depth=53, out_indices=(0, 1, 2, 3, 4))
+ >>> model.eval()
+ >>> inputs = torch.rand(1, 3, 416, 416)
+ >>> level_outputs = model(inputs)
+ >>> for level_out in level_outputs:
+ ... print(tuple(level_out.shape))
+ ...
+ (1, 64, 208, 208)
+ (1, 128, 104, 104)
+ (1, 256, 52, 52)
+ (1, 512, 26, 26)
+ (1, 1024, 13, 13)
+ """
+ arch_settings = {
+ 53:
+ dict(
+ block_fn=DarknetBottleneck,
+ in_channels=(32, 64, 128, 256, 512),
+ out_channels=(64, 128, 256, 512, 1024),
+ num_blocks=(1, 2, 8, 8, 4),
+ expand_ratio=(2, 1, 1, 1, 1),
+ bottle_ratio=(2, 1, 1, 1, 1),
+ has_downsampler=True,
+ down_growth=True,
+ ),
+ }
+
+ def __init__(self,
+ depth,
+ in_channels=3,
+ out_indices=(4, ),
+ frozen_stages=-1,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN', eps=1e-5),
+ act_cfg=dict(type='LeakyReLU', inplace=True),
+ norm_eval=False,
+ init_cfg=dict(
+ type='Kaiming',
+ layer='Conv2d',
+ a=math.sqrt(5),
+ distribution='uniform',
+ mode='fan_in',
+ nonlinearity='leaky_relu')):
+
+ assert depth in self.arch_settings, 'depth must be one of ' \
+ f'{list(self.arch_settings.keys())}, but get {depth}.'
+
+ super().__init__(
+ arch=self.arch_settings[depth],
+ stem_fn=self._make_stem_layer,
+ in_channels=in_channels,
+ out_indices=out_indices,
+ frozen_stages=frozen_stages,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ norm_eval=norm_eval,
+ init_cfg=init_cfg)
+
+ def _make_stem_layer(self, in_channels):
+ """using a stride=1 conv as the stem in CSPDarknet."""
+ # `stem_channels` equals to the `in_channels` in the first stage.
+ stem_channels = self.arch['in_channels'][0]
+ stem = ConvModule(
+ in_channels=in_channels,
+ out_channels=stem_channels,
+ kernel_size=3,
+ padding=1,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ return stem
+
+
+@MODELS.register_module()
+class CSPResNet(CSPNet):
+ """CSP-ResNet backbone.
+
+ Args:
+ depth (int): Depth of CSP-ResNet. Default: 50.
+        out_indices (Sequence[int]): Output from which stages.
+            Default: (3, ).
+        frozen_stages (int): Stages to be frozen (stop grad and set eval
+            mode). -1 means not freezing any parameters. Default: -1.
+        conv_cfg (dict): Config dict for convolution layer. Default: None.
+        norm_cfg (dict): Dictionary to construct and config norm layer.
+            Default: dict(type='BN', eps=1e-5).
+        act_cfg (dict): Config dict for activation layer.
+            Default: dict(type='LeakyReLU', inplace=True).
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None.
+ Example:
+ >>> from mmpretrain.models import CSPResNet
+ >>> import torch
+ >>> model = CSPResNet(depth=50, out_indices=(0, 1, 2, 3))
+ >>> model.eval()
+ >>> inputs = torch.rand(1, 3, 416, 416)
+ >>> level_outputs = model(inputs)
+ >>> for level_out in level_outputs:
+ ... print(tuple(level_out.shape))
+ ...
+ (1, 128, 104, 104)
+ (1, 256, 52, 52)
+ (1, 512, 26, 26)
+ (1, 1024, 13, 13)
+ """
+ arch_settings = {
+ 50:
+ dict(
+ block_fn=ResNetBottleneck,
+ in_channels=(64, 128, 256, 512),
+ out_channels=(128, 256, 512, 1024),
+ num_blocks=(3, 3, 5, 2),
+ expand_ratio=4,
+ bottle_ratio=2,
+ has_downsampler=(False, True, True, True),
+ down_growth=False),
+ }
+
+ def __init__(self,
+ depth,
+ in_channels=3,
+ out_indices=(3, ),
+ frozen_stages=-1,
+ deep_stem=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN', eps=1e-5),
+ act_cfg=dict(type='LeakyReLU', inplace=True),
+ norm_eval=False,
+ init_cfg=dict(type='Kaiming', layer='Conv2d')):
+ assert depth in self.arch_settings, 'depth must be one of ' \
+ f'{list(self.arch_settings.keys())}, but get {depth}.'
+ self.deep_stem = deep_stem
+
+ super().__init__(
+ arch=self.arch_settings[depth],
+ stem_fn=self._make_stem_layer,
+ in_channels=in_channels,
+ out_indices=out_indices,
+ frozen_stages=frozen_stages,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ norm_eval=norm_eval,
+ init_cfg=init_cfg)
+
+ def _make_stem_layer(self, in_channels):
+ # `stem_channels` equals to the `in_channels` in the first stage.
+ stem_channels = self.arch['in_channels'][0]
+ if self.deep_stem:
+ stem = nn.Sequential(
+ ConvModule(
+ in_channels,
+ stem_channels // 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg),
+ ConvModule(
+ stem_channels // 2,
+ stem_channels // 2,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg),
+ ConvModule(
+ stem_channels // 2,
+ stem_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg))
+ else:
+ stem = nn.Sequential(
+ ConvModule(
+ in_channels,
+ stem_channels,
+ kernel_size=7,
+ stride=2,
+ padding=3,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg),
+ nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
+ return stem
+
+
+@MODELS.register_module()
+class CSPResNeXt(CSPResNet):
+ """CSP-ResNeXt backbone.
+
+ Args:
+ depth (int): Depth of CSP-ResNeXt. Default: 50.
+        out_indices (Sequence[int]): Output from which stages.
+            Default: (3, ).
+        frozen_stages (int): Stages to be frozen (stop grad and set eval
+            mode). -1 means not freezing any parameters. Default: -1.
+        conv_cfg (dict): Config dict for convolution layer. Default: None.
+        norm_cfg (dict): Dictionary to construct and config norm layer.
+            Default: dict(type='BN', eps=1e-5).
+        act_cfg (dict): Config dict for activation layer.
+            Default: dict(type='LeakyReLU', inplace=True).
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None.
+ Example:
+ >>> from mmpretrain.models import CSPResNeXt
+ >>> import torch
+ >>> model = CSPResNeXt(depth=50, out_indices=(0, 1, 2, 3))
+ >>> model.eval()
+ >>> inputs = torch.rand(1, 3, 224, 224)
+ >>> level_outputs = model(inputs)
+ >>> for level_out in level_outputs:
+ ... print(tuple(level_out.shape))
+ ...
+ (1, 256, 56, 56)
+ (1, 512, 28, 28)
+ (1, 1024, 14, 14)
+ (1, 2048, 7, 7)
+ """
+ arch_settings = {
+ 50:
+ dict(
+ block_fn=ResNeXtBottleneck,
+ in_channels=(64, 256, 512, 1024),
+ out_channels=(256, 512, 1024, 2048),
+ num_blocks=(3, 3, 5, 2),
+ expand_ratio=(4, 2, 2, 2),
+ bottle_ratio=4,
+ has_downsampler=(False, True, True, True),
+ down_growth=False,
+ # the base_channels is changed from 64 to 32 in CSPNet
+ block_args=dict(base_channels=32),
+ ),
+ }
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
diff --git a/mmpretrain/models/backbones/davit.py b/mmpretrain/models/backbones/davit.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf25e2ed7137fb403e38801b50b355c4306331d6
--- /dev/null
+++ b/mmpretrain/models/backbones/davit.py
@@ -0,0 +1,834 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from copy import deepcopy
+from typing import Sequence, Tuple
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+from mmcv.cnn import build_conv_layer, build_norm_layer
+from mmcv.cnn.bricks import Conv2d
+from mmcv.cnn.bricks.transformer import FFN, AdaptivePadding, PatchEmbed
+from mmengine.model import BaseModule, ModuleList
+from mmengine.utils import to_2tuple
+from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
+
+from mmpretrain.models.backbones.base_backbone import BaseBackbone
+from mmpretrain.registry import MODELS
+from ..utils import ShiftWindowMSA
+
+
+class DaViTWindowMSA(BaseModule):
+ """Window based multi-head self-attention (W-MSA) module for DaViT.
+
+ The differences between DaViTWindowMSA & WindowMSA:
+ 1. Without relative position bias.
+
+ Args:
+ embed_dims (int): Number of input channels.
+ window_size (tuple[int]): The height and width of the window.
+ num_heads (int): Number of attention heads.
+ qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
+ Defaults to True.
+ qk_scale (float, optional): Override default qk scale of
+ ``head_dim ** -0.5`` if set. Defaults to None.
+ attn_drop (float, optional): Dropout ratio of attention weight.
+ Defaults to 0.
+ proj_drop (float, optional): Dropout ratio of output. Defaults to 0.
+ init_cfg (dict, optional): The extra config for initialization.
+ Defaults to None.
+ """
+
+ def __init__(self,
+ embed_dims,
+ window_size,
+ num_heads,
+ qkv_bias=True,
+ qk_scale=None,
+ attn_drop=0.,
+ proj_drop=0.,
+ init_cfg=None):
+
+ super().__init__(init_cfg)
+ self.embed_dims = embed_dims
+ self.window_size = window_size # Wh, Ww
+ self.num_heads = num_heads
+ head_embed_dims = embed_dims // num_heads
+ self.scale = qk_scale or head_embed_dims**-0.5
+
+ self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(embed_dims, embed_dims)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ self.softmax = nn.Softmax(dim=-1)
+
+ def forward(self, x, mask=None):
+ """
+ Args:
+
+ x (tensor): input features with shape of (num_windows*B, N, C)
+ mask (tensor, Optional): mask with shape of (num_windows, Wh*Ww,
+ Wh*Ww), value should be between (-inf, 0].
+ """
+ B_, N, C = x.shape
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads,
+ C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[
+ 2] # make torchscript happy (cannot use tensor as tuple)
+
+ q = q * self.scale
+ attn = (q @ k.transpose(-2, -1))
+
+ if mask is not None:
+ nW = mask.shape[0]
+ attn = attn.view(B_ // nW, nW, self.num_heads, N,
+ N) + mask.unsqueeze(1).unsqueeze(0)
+ attn = attn.view(-1, self.num_heads, N, N)
+ attn = self.softmax(attn)
+ else:
+ attn = self.softmax(attn)
+
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+ @staticmethod
+ def double_step_seq(step1, len1, step2, len2):
+ seq1 = torch.arange(0, step1 * len1, step1)
+ seq2 = torch.arange(0, step2 * len2, step2)
+ return (seq1[:, None] + seq2[None, :]).reshape(1, -1)
+
+
+class ConvPosEnc(BaseModule):
+ """DaViT conv pos encode block.
+
+ Args:
+ embed_dims (int): Number of input channels.
+ kernel_size (int): The kernel size of the first convolution.
+ Defaults to 3.
+ init_cfg (dict, optional): The extra config for initialization.
+ Defaults to None.
+ """
+
+ def __init__(self, embed_dims, kernel_size=3, init_cfg=None):
+ super(ConvPosEnc, self).__init__(init_cfg)
+ self.proj = Conv2d(
+ embed_dims,
+ embed_dims,
+ kernel_size,
+ stride=1,
+ padding=kernel_size // 2,
+ groups=embed_dims)
+
+ def forward(self, x, size: Tuple[int, int]):
+ B, N, C = x.shape
+ H, W = size
+ assert N == H * W
+
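+        # A depth-wise 3x3 convolution over the 2D feature map serves as a
+        # convolutional position encoding; its output is added back to the
+        # token sequence as a residual.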
+ feat = x.transpose(1, 2).view(B, C, H, W)
+ feat = self.proj(feat)
+ feat = feat.flatten(2).transpose(1, 2)
+ x = x + feat
+ return x
+
+
+class DaViTDownSample(BaseModule):
+ """DaViT down sampole block.
+
+ Args:
+ in_channels (int): The number of input channels.
+ out_channels (int): The number of output channels.
+ conv_type (str): The type of convolution
+ to generate patch embedding. Default: "Conv2d".
+        kernel_size (int): The kernel size of the convolution.
+            Defaults to 2.
+        stride (int): The stride of the convolution.
+            Defaults to 2.
+        padding (int | tuple | str): The padding length of the
+            embedding conv. When it is a string, it specifies the mode
+            of adaptive padding; "same" and "corner" are supported.
+            Defaults to "same".
+ dilation (int): Dilation of the convolution layers. Defaults to 1.
+ bias (bool): Bias of embed conv. Default: True.
+ norm_cfg (dict, optional): Config dict for normalization layer.
+ Defaults to ``dict(type='LN')``.
+ init_cfg (dict, optional): The extra config for initialization.
+ Defaults to None.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ conv_type='Conv2d',
+ kernel_size=2,
+ stride=2,
+ padding='same',
+ dilation=1,
+ bias=True,
+ norm_cfg=None,
+ init_cfg=None):
+ super().__init__(init_cfg=init_cfg)
+ self.out_channels = out_channels
+ if stride is None:
+ stride = kernel_size
+
+ kernel_size = to_2tuple(kernel_size)
+ stride = to_2tuple(stride)
+ dilation = to_2tuple(dilation)
+
+ if isinstance(padding, str):
+ self.adaptive_padding = AdaptivePadding(
+ kernel_size=kernel_size,
+ stride=stride,
+ dilation=dilation,
+ padding=padding)
+ # disable the padding of conv
+ padding = 0
+ else:
+ self.adaptive_padding = None
+ padding = to_2tuple(padding)
+
+ self.projection = build_conv_layer(
+ dict(type=conv_type),
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ bias=bias)
+
+ if norm_cfg is not None:
+ self.norm = build_norm_layer(norm_cfg, in_channels)[1]
+ else:
+ self.norm = None
+
+ def forward(self, x, input_size):
+ if self.adaptive_padding:
+ x = self.adaptive_padding(x)
+ H, W = input_size
+ B, L, C = x.shape
+ assert L == H * W, 'input feature has wrong size'
+
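+        # A norm layer is applied before the strided projection. Note that
+        # `self.norm` is only built when `norm_cfg` is given; the caller in
+        # this file (DaViTBlockSequence) always passes dict(type='LN').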
+ x = self.norm(x)
+ x = x.reshape(B, H, W, C).permute(0, 3, 1, 2).contiguous()
+
+ x = self.projection(x)
+ output_size = (x.size(2), x.size(3))
+ x = x.flatten(2).transpose(1, 2)
+ return x, output_size
+
+
+class ChannelAttention(BaseModule):
+ """DaViT channel attention.
+
+ Args:
+ embed_dims (int): Number of input channels.
+ num_heads (int): Number of attention heads.
+        qkv_bias (bool): Enable bias for qkv if True. Defaults to False.
+ init_cfg (dict, optional): The extra config for initialization.
+ Defaults to None.
+ """
+
+ def __init__(self, embed_dims, num_heads=8, qkv_bias=False, init_cfg=None):
+ super().__init__(init_cfg)
+ self.embed_dims = embed_dims
+ self.num_heads = num_heads
+ self.head_dims = embed_dims // num_heads
+ self.scale = self.head_dims**-0.5
+
+ self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
+ self.proj = nn.Linear(embed_dims, embed_dims)
+
+ def forward(self, x):
+ B, N, _ = x.shape
+
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
+ self.head_dims).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2]
+
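+        # Channel attention: the attention map is computed between channel
+        # groups (shape head_dims x head_dims) instead of between spatial
+        # tokens, so its cost scales with the channel dimension rather than
+        # with the number of tokens N.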
+ k = k * self.scale
+ attention = k.transpose(-1, -2) @ v
+ attention = attention.softmax(dim=-1)
+
+ x = (attention @ q.transpose(-1, -2)).transpose(-1, -2)
+ x = x.transpose(1, 2).reshape(B, N, self.embed_dims)
+ x = self.proj(x)
+ return x
+
+
+class ChannelBlock(BaseModule):
+ """DaViT channel attention block.
+
+ Args:
+ embed_dims (int): Number of input channels.
+ num_heads (int): Number of attention heads.
+        ffn_ratio (float): The expansion ratio of feedforward network hidden
+            layer channels. Defaults to 4.
+        qkv_bias (bool): Enable bias for qkv if True. Defaults to False.
+ drop_path (float): The drop path rate after attention and ffn.
+ Defaults to 0.
+ ffn_cfgs (dict): The extra config of FFN. Defaults to empty dict.
+ norm_cfg (dict): The config of norm layers.
+ Defaults to ``dict(type='LN')``.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Defaults to False.
+ init_cfg (dict, optional): The extra config for initialization.
+ Defaults to None.
+ """
+
+ def __init__(self,
+ embed_dims,
+ num_heads,
+ ffn_ratio=4.,
+ qkv_bias=False,
+ drop_path=0.,
+ ffn_cfgs=dict(),
+ norm_cfg=dict(type='LN'),
+ with_cp=False,
+ init_cfg=None):
+ super().__init__(init_cfg)
+ self.with_cp = with_cp
+
+ self.cpe1 = ConvPosEnc(embed_dims=embed_dims, kernel_size=3)
+ self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
+ self.attn = ChannelAttention(
+ embed_dims, num_heads=num_heads, qkv_bias=qkv_bias)
+ self.cpe2 = ConvPosEnc(embed_dims=embed_dims, kernel_size=3)
+
+ _ffn_cfgs = {
+ 'embed_dims': embed_dims,
+ 'feedforward_channels': int(embed_dims * ffn_ratio),
+ 'num_fcs': 2,
+ 'ffn_drop': 0,
+ 'dropout_layer': dict(type='DropPath', drop_prob=drop_path),
+ 'act_cfg': dict(type='GELU'),
+ **ffn_cfgs
+ }
+ self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
+ self.ffn = FFN(**_ffn_cfgs)
+
+ def forward(self, x, hw_shape):
+
+ def _inner_forward(x):
+ x = self.cpe1(x, hw_shape)
+ identity = x
+ x = self.norm1(x)
+ x = self.attn(x)
+ x = x + identity
+
+ x = self.cpe2(x, hw_shape)
+ identity = x
+ x = self.norm2(x)
+ x = self.ffn(x, identity=identity)
+
+ return x
+
+ if self.with_cp and x.requires_grad:
+ x = cp.checkpoint(_inner_forward, x)
+ else:
+ x = _inner_forward(x)
+
+ return x
+
+
+class SpatialBlock(BaseModule):
+ """DaViT spatial attention block.
+
+ Args:
+ embed_dims (int): Number of input channels.
+ num_heads (int): Number of attention heads.
+ window_size (int): The height and width of the window. Defaults to 7.
+ ffn_ratio (float): The expansion ratio of feedforward network hidden
+ layer channels. Defaults to 4.
+ qkv_bias (bool): enable bias for qkv if True. Defaults to True.
+ drop_path (float): The drop path rate after attention and ffn.
+ Defaults to 0.
+        pad_small_map (bool): If True, pad small feature maps to the window
+            size, which is commonly used in detection and segmentation. If
+            False, avoid shifting the window and shrink the window size to
+            the size of the feature map, which is commonly used in
+            classification. Defaults to False.
+ attn_cfgs (dict): The extra config of Shift Window-MSA.
+ Defaults to empty dict.
+ ffn_cfgs (dict): The extra config of FFN. Defaults to empty dict.
+ norm_cfg (dict): The config of norm layers.
+ Defaults to ``dict(type='LN')``.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Defaults to False.
+ init_cfg (dict, optional): The extra config for initialization.
+ Defaults to None.
+ """
+
+ def __init__(self,
+ embed_dims,
+ num_heads,
+ window_size=7,
+ ffn_ratio=4.,
+ qkv_bias=True,
+ drop_path=0.,
+ pad_small_map=False,
+ attn_cfgs=dict(),
+ ffn_cfgs=dict(),
+ norm_cfg=dict(type='LN'),
+ with_cp=False,
+ init_cfg=None):
+
+ super(SpatialBlock, self).__init__(init_cfg)
+ self.with_cp = with_cp
+
+ self.cpe1 = ConvPosEnc(embed_dims=embed_dims, kernel_size=3)
+ self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
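+        # DaViT uses plain (non-shifted) window attention, hence the fixed
+        # `shift_size=0` in the window-MSA config below.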
+ _attn_cfgs = {
+ 'embed_dims': embed_dims,
+ 'num_heads': num_heads,
+ 'shift_size': 0,
+ 'window_size': window_size,
+ 'dropout_layer': dict(type='DropPath', drop_prob=drop_path),
+ 'qkv_bias': qkv_bias,
+ 'pad_small_map': pad_small_map,
+ 'window_msa': DaViTWindowMSA,
+ **attn_cfgs
+ }
+ self.attn = ShiftWindowMSA(**_attn_cfgs)
+ self.cpe2 = ConvPosEnc(embed_dims=embed_dims, kernel_size=3)
+
+ _ffn_cfgs = {
+ 'embed_dims': embed_dims,
+ 'feedforward_channels': int(embed_dims * ffn_ratio),
+ 'num_fcs': 2,
+ 'ffn_drop': 0,
+ 'dropout_layer': dict(type='DropPath', drop_prob=drop_path),
+ 'act_cfg': dict(type='GELU'),
+ **ffn_cfgs
+ }
+ self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
+ self.ffn = FFN(**_ffn_cfgs)
+
+ def forward(self, x, hw_shape):
+
+ def _inner_forward(x):
+ x = self.cpe1(x, hw_shape)
+ identity = x
+ x = self.norm1(x)
+ x = self.attn(x, hw_shape)
+ x = x + identity
+
+ x = self.cpe2(x, hw_shape)
+ identity = x
+ x = self.norm2(x)
+ x = self.ffn(x, identity=identity)
+
+ return x
+
+ if self.with_cp and x.requires_grad:
+ x = cp.checkpoint(_inner_forward, x)
+ else:
+ x = _inner_forward(x)
+
+ return x
+
+
+class DaViTBlock(BaseModule):
+ """DaViT block.
+
+ Args:
+ embed_dims (int): Number of input channels.
+ num_heads (int): Number of attention heads.
+ window_size (int): The height and width of the window. Defaults to 7.
+ ffn_ratio (float): The expansion ratio of feedforward network hidden
+ layer channels. Defaults to 4.
+ qkv_bias (bool): enable bias for qkv if True. Defaults to True.
+ drop_path (float): The drop path rate after attention and ffn.
+ Defaults to 0.
+        pad_small_map (bool): If True, pad small feature maps to the window
+            size, which is commonly used in detection and segmentation. If
+            False, avoid shifting the window and shrink the window size to
+            the size of the feature map, which is commonly used in
+            classification. Defaults to False.
+ attn_cfgs (dict): The extra config of Shift Window-MSA.
+ Defaults to empty dict.
+ ffn_cfgs (dict): The extra config of FFN. Defaults to empty dict.
+ norm_cfg (dict): The config of norm layers.
+ Defaults to ``dict(type='LN')``.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Defaults to False.
+ init_cfg (dict, optional): The extra config for initialization.
+ Defaults to None.
+ """
+
+ def __init__(self,
+ embed_dims,
+ num_heads,
+ window_size=7,
+ ffn_ratio=4.,
+ qkv_bias=True,
+ drop_path=0.,
+ pad_small_map=False,
+ attn_cfgs=dict(),
+ ffn_cfgs=dict(),
+ norm_cfg=dict(type='LN'),
+ with_cp=False,
+ init_cfg=None):
+
+ super(DaViTBlock, self).__init__(init_cfg)
+ self.spatial_block = SpatialBlock(
+ embed_dims,
+ num_heads,
+ window_size=window_size,
+ ffn_ratio=ffn_ratio,
+ qkv_bias=qkv_bias,
+ drop_path=drop_path,
+ pad_small_map=pad_small_map,
+ attn_cfgs=attn_cfgs,
+ ffn_cfgs=ffn_cfgs,
+ norm_cfg=norm_cfg,
+ with_cp=with_cp)
+ self.channel_block = ChannelBlock(
+ embed_dims,
+ num_heads,
+ ffn_ratio=ffn_ratio,
+ qkv_bias=qkv_bias,
+ drop_path=drop_path,
+ ffn_cfgs=ffn_cfgs,
+ norm_cfg=norm_cfg,
+ with_cp=False)
+
+ def forward(self, x, hw_shape):
+ x = self.spatial_block(x, hw_shape)
+ x = self.channel_block(x, hw_shape)
+
+ return x
+
+
+class DaViTBlockSequence(BaseModule):
+ """Module with successive DaViT blocks and downsample layer.
+
+ Args:
+ embed_dims (int): Number of input channels.
+ depth (int): Number of successive DaViT blocks.
+ num_heads (int): Number of attention heads.
+ window_size (int): The height and width of the window. Defaults to 7.
+ ffn_ratio (float): The expansion ratio of feedforward network hidden
+ layer channels. Defaults to 4.
+ qkv_bias (bool): enable bias for qkv if True. Defaults to True.
+ downsample (bool): Downsample the output of blocks by patch merging.
+ Defaults to False.
+ downsample_cfg (dict): The extra config of the patch merging layer.
+ Defaults to empty dict.
+ drop_paths (Sequence[float] | float): The drop path rate in each block.
+ Defaults to 0.
+ block_cfgs (Sequence[dict] | dict): The extra config of each block.
+ Defaults to empty dicts.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Defaults to False.
+        pad_small_map (bool): If True, pad small feature maps to the window
+            size, which is commonly used in detection and segmentation. If
+            False, avoid shifting the window and shrink the window size to
+            the size of the feature map, which is commonly used in
+            classification. Defaults to False.
+ init_cfg (dict, optional): The extra config for initialization.
+ Defaults to None.
+ """
+
+ def __init__(self,
+ embed_dims,
+ depth,
+ num_heads,
+ window_size=7,
+ ffn_ratio=4.,
+ qkv_bias=True,
+ downsample=False,
+ downsample_cfg=dict(),
+ drop_paths=0.,
+ block_cfgs=dict(),
+ with_cp=False,
+ pad_small_map=False,
+ init_cfg=None):
+ super().__init__(init_cfg)
+
+ if not isinstance(drop_paths, Sequence):
+ drop_paths = [drop_paths] * depth
+
+ if not isinstance(block_cfgs, Sequence):
+ block_cfgs = [deepcopy(block_cfgs) for _ in range(depth)]
+
+ self.embed_dims = embed_dims
+ self.blocks = ModuleList()
+ for i in range(depth):
+ _block_cfg = {
+ 'embed_dims': embed_dims,
+ 'num_heads': num_heads,
+ 'window_size': window_size,
+ 'ffn_ratio': ffn_ratio,
+ 'qkv_bias': qkv_bias,
+ 'drop_path': drop_paths[i],
+ 'with_cp': with_cp,
+ 'pad_small_map': pad_small_map,
+ **block_cfgs[i]
+ }
+ block = DaViTBlock(**_block_cfg)
+ self.blocks.append(block)
+
+ if downsample:
+ _downsample_cfg = {
+ 'in_channels': embed_dims,
+ 'out_channels': 2 * embed_dims,
+ 'norm_cfg': dict(type='LN'),
+ **downsample_cfg
+ }
+ self.downsample = DaViTDownSample(**_downsample_cfg)
+ else:
+ self.downsample = None
+
+ def forward(self, x, in_shape, do_downsample=True):
+ for block in self.blocks:
+ x = block(x, in_shape)
+
+ if self.downsample is not None and do_downsample:
+ x, out_shape = self.downsample(x, in_shape)
+ else:
+ out_shape = in_shape
+ return x, out_shape
+
+ @property
+ def out_channels(self):
+ if self.downsample:
+ return self.downsample.out_channels
+ else:
+ return self.embed_dims
+
+
+@MODELS.register_module()
+class DaViT(BaseBackbone):
+ """DaViT.
+
+ A PyTorch implement of : `DaViT: Dual Attention Vision Transformers
+ `_
+
+ Inspiration from
+ https://github.com/dingmyu/davit
+
+ Args:
+ arch (str | dict): DaViT architecture. If use string, choose from
+            'tiny', 'small', 'base', 'large', 'huge' and 'giant'. If use dict,
+ it should have below keys:
+
+ - **embed_dims** (int): The dimensions of embedding.
+ - **depths** (List[int]): The number of blocks in each stage.
+ - **num_heads** (List[int]): The number of heads in attention
+ modules of each stage.
+
+ Defaults to 't'.
+ patch_size (int | tuple): The patch size in patch embedding.
+ Defaults to 4.
+ in_channels (int): The num of input channels. Defaults to 3.
+ window_size (int): The height and width of the window. Defaults to 7.
+ ffn_ratio (float): The expansion ratio of feedforward network hidden
+ layer channels. Defaults to 4.
+ qkv_bias (bool): Whether to add bias for qkv in attention modules.
+ Defaults to True.
+ drop_path_rate (float): Stochastic depth rate. Defaults to 0.1.
+ out_after_downsample (bool): Whether to output the feature map of a
+ stage after the following downsample layer. Defaults to False.
+        pad_small_map (bool): If True, pad small feature maps to the window
+            size, which is commonly used in detection and segmentation. If
+            False, avoid shifting the window and shrink the window size to
+            the size of the feature map, which is commonly used in
+            classification. Defaults to False.
+ norm_cfg (dict): Config dict for normalization layer for all output
+ features. Defaults to ``dict(type='LN')``
+ stage_cfgs (Sequence[dict] | dict): Extra config dict for each
+ stage. Defaults to an empty dict.
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+ -1 means not freezing any parameters. Defaults to -1.
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Defaults to False.
+ out_indices (Sequence | int): Output from which stages.
+            Defaults to ``(3, )``, which means the last stage.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Defaults to False.
+ init_cfg (dict, optional): The Config for initialization.
+ Defaults to None.
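+
+    A minimal usage sketch is given below; the output shape noted in the
+    comment assumes the default 'tiny' settings and a 224x224 input, and is
+    an expected value rather than a verified output.
+
+    Example:
+        >>> import torch
+        >>> from mmpretrain.models import DaViT
+        >>> model = DaViT(arch='t')
+        >>> model.eval()
+        >>> inputs = torch.rand(1, 3, 224, 224)
+        >>> outputs = model(inputs)
+        >>> # With the default out_indices=(3, ), a single feature map is
+        >>> # returned, expected to have shape (1, 768, 7, 7).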
+ """
+ arch_zoo = {
+ **dict.fromkeys(['t', 'tiny'], {
+ 'embed_dims': 96,
+ 'depths': [1, 1, 3, 1],
+ 'num_heads': [3, 6, 12, 24]
+ }),
+ **dict.fromkeys(['s', 'small'], {
+ 'embed_dims': 96,
+ 'depths': [1, 1, 9, 1],
+ 'num_heads': [3, 6, 12, 24]
+ }),
+ **dict.fromkeys(['b', 'base'], {
+ 'embed_dims': 128,
+ 'depths': [1, 1, 9, 1],
+ 'num_heads': [4, 8, 16, 32]
+ }),
+ **dict.fromkeys(
+ ['l', 'large'], {
+ 'embed_dims': 192,
+ 'depths': [1, 1, 9, 1],
+ 'num_heads': [6, 12, 24, 48]
+ }),
+ **dict.fromkeys(
+ ['h', 'huge'], {
+ 'embed_dims': 256,
+ 'depths': [1, 1, 9, 1],
+ 'num_heads': [8, 16, 32, 64]
+ }),
+ **dict.fromkeys(
+ ['g', 'giant'], {
+ 'embed_dims': 384,
+ 'depths': [1, 1, 12, 3],
+ 'num_heads': [12, 24, 48, 96]
+ }),
+ }
+
+ def __init__(self,
+ arch='t',
+ patch_size=4,
+ in_channels=3,
+ window_size=7,
+ ffn_ratio=4.,
+ qkv_bias=True,
+ drop_path_rate=0.1,
+ out_after_downsample=False,
+ pad_small_map=False,
+ norm_cfg=dict(type='LN'),
+ stage_cfgs=dict(),
+ frozen_stages=-1,
+ norm_eval=False,
+ out_indices=(3, ),
+ with_cp=False,
+ init_cfg=None):
+ super().__init__(init_cfg)
+
+ if isinstance(arch, str):
+ arch = arch.lower()
+ assert arch in set(self.arch_zoo), \
+ f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
+ self.arch_settings = self.arch_zoo[arch]
+ else:
+ essential_keys = {'embed_dims', 'depths', 'num_heads'}
+ assert isinstance(arch, dict) and essential_keys <= set(arch), \
+ f'Custom arch needs a dict with keys {essential_keys}'
+ self.arch_settings = arch
+
+ self.embed_dims = self.arch_settings['embed_dims']
+ self.depths = self.arch_settings['depths']
+ self.num_heads = self.arch_settings['num_heads']
+ self.num_layers = len(self.depths)
+ self.out_indices = out_indices
+ self.out_after_downsample = out_after_downsample
+ self.frozen_stages = frozen_stages
+ self.norm_eval = norm_eval
+
+ # stochastic depth decay rule
+ total_depth = sum(self.depths)
+ dpr = [
+ x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
+ ] # stochastic depth decay rule
+
+ _patch_cfg = dict(
+ in_channels=in_channels,
+ embed_dims=self.embed_dims,
+ conv_type='Conv2d',
+ kernel_size=7,
+ stride=patch_size,
+ padding='same',
+ norm_cfg=dict(type='LN'),
+ )
+ self.patch_embed = PatchEmbed(**_patch_cfg)
+
+ self.stages = ModuleList()
+ embed_dims = [self.embed_dims]
+ for i, (depth,
+ num_heads) in enumerate(zip(self.depths, self.num_heads)):
+ if isinstance(stage_cfgs, Sequence):
+ stage_cfg = stage_cfgs[i]
+ else:
+ stage_cfg = deepcopy(stage_cfgs)
+ downsample = True if i < self.num_layers - 1 else False
+ _stage_cfg = {
+ 'embed_dims': embed_dims[-1],
+ 'depth': depth,
+ 'num_heads': num_heads,
+ 'window_size': window_size,
+ 'ffn_ratio': ffn_ratio,
+ 'qkv_bias': qkv_bias,
+ 'downsample': downsample,
+ 'drop_paths': dpr[:depth],
+ 'with_cp': with_cp,
+ 'pad_small_map': pad_small_map,
+ **stage_cfg
+ }
+
+ stage = DaViTBlockSequence(**_stage_cfg)
+ self.stages.append(stage)
+
+ dpr = dpr[depth:]
+ embed_dims.append(stage.out_channels)
+
+ self.num_features = embed_dims[:-1]
+
+ # add a norm layer for each output
+ for i in out_indices:
+ if norm_cfg is not None:
+ norm_layer = build_norm_layer(norm_cfg,
+ self.num_features[i])[1]
+ else:
+ norm_layer = nn.Identity()
+
+ self.add_module(f'norm{i}', norm_layer)
+
+ def train(self, mode=True):
+ super().train(mode)
+ self._freeze_stages()
+ if mode and self.norm_eval:
+ for m in self.modules():
+                # trick: eval() has an effect on BatchNorm layers only
+ if isinstance(m, _BatchNorm):
+ m.eval()
+
+ def _freeze_stages(self):
+ if self.frozen_stages >= 0:
+ self.patch_embed.eval()
+ for param in self.patch_embed.parameters():
+ param.requires_grad = False
+
+ for i in range(0, self.frozen_stages + 1):
+ m = self.stages[i]
+ m.eval()
+ for param in m.parameters():
+ param.requires_grad = False
+ for i in self.out_indices:
+ if i <= self.frozen_stages:
+ for param in getattr(self, f'norm{i}').parameters():
+ param.requires_grad = False
+
+ def forward(self, x):
+ x, hw_shape = self.patch_embed(x)
+
+ outs = []
+ for i, stage in enumerate(self.stages):
+ x, hw_shape = stage(
+ x, hw_shape, do_downsample=self.out_after_downsample)
+ if i in self.out_indices:
+ norm_layer = getattr(self, f'norm{i}')
+ out = norm_layer(x)
+ out = out.view(-1, *hw_shape,
+ self.num_features[i]).permute(0, 3, 1,
+ 2).contiguous()
+ outs.append(out)
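+            # With out_after_downsample=False (the default), features are
+            # collected before each stage's downsample layer, which is then
+            # applied separately here.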
+ if stage.downsample is not None and not self.out_after_downsample:
+ x, hw_shape = stage.downsample(x, hw_shape)
+
+ return tuple(outs)
diff --git a/mmpretrain/models/backbones/deit.py b/mmpretrain/models/backbones/deit.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ae340829bece31536d0c0ac119ffe635bce82e0
--- /dev/null
+++ b/mmpretrain/models/backbones/deit.py
@@ -0,0 +1,116 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+from mmengine.model.weight_init import trunc_normal_
+
+from mmpretrain.registry import MODELS
+from .vision_transformer import VisionTransformer
+
+
+@MODELS.register_module()
+class DistilledVisionTransformer(VisionTransformer):
+ """Distilled Vision Transformer.
+
+    A PyTorch implementation of: `Training data-efficient image transformers &
+ distillation through attention `_
+
+ Args:
+ arch (str | dict): Vision Transformer architecture. If use string,
+ choose from 'small', 'base', 'large', 'deit-tiny', 'deit-small'
+ and 'deit-base'. If use dict, it should have below keys:
+
+ - **embed_dims** (int): The dimensions of embedding.
+ - **num_layers** (int): The number of transformer encoder layers.
+ - **num_heads** (int): The number of heads in attention modules.
+ - **feedforward_channels** (int): The hidden dimensions in
+ feedforward modules.
+
+ Defaults to 'deit-base'.
+ img_size (int | tuple): The expected input image shape. Because we
+ support dynamic input shape, just set the argument to the most
+ common input image shape. Defaults to 224.
+ patch_size (int | tuple): The patch size in patch embedding.
+ Defaults to 16.
+ in_channels (int): The num of input channels. Defaults to 3.
+ out_indices (Sequence | int): Output from which stages.
+ Defaults to -1, means the last stage.
+ drop_rate (float): Probability of an element to be zeroed.
+ Defaults to 0.
+ drop_path_rate (float): stochastic depth rate. Defaults to 0.
+ qkv_bias (bool): Whether to add bias for qkv in attention modules.
+ Defaults to True.
+ norm_cfg (dict): Config dict for normalization layer.
+ Defaults to ``dict(type='LN')``.
+        final_norm (bool): Whether to add an additional layer to normalize
+            the final feature map. Defaults to True.
+ out_type (str): The type of output features. Please choose from
+
+ - ``"cls_token"``: A tuple with the class token and the
+ distillation token. The shapes of both tensor are (B, C).
+ - ``"featmap"``: The feature map tensor from the patch tokens
+ with shape (B, C, H, W).
+ - ``"avg_featmap"``: The global averaged feature map tensor
+ with shape (B, C).
+ - ``"raw"``: The raw feature tensor includes patch tokens and
+ class tokens with shape (B, L, C).
+
+ Defaults to ``"cls_token"``.
+        interpolate_mode (str): Select the interpolation mode for position
+            embedding vector resize. Defaults to "bicubic".
+        patch_cfg (dict): Configs of patch embedding. Defaults to an empty
+            dict.
+ layer_cfgs (Sequence | dict): Configs of each transformer layer in
+ encoder. Defaults to an empty dict.
+ init_cfg (dict, optional): Initialization config dict.
+ Defaults to None.
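+
+    A minimal usage sketch is given below; the shapes noted in the comment
+    assume the default 'deit-base' settings and a 224x224 input, and should
+    be treated as expected values rather than verified outputs.
+
+    Example:
+        >>> import torch
+        >>> from mmpretrain.models import DistilledVisionTransformer
+        >>> model = DistilledVisionTransformer()
+        >>> inputs = torch.rand(1, 3, 224, 224)
+        >>> outputs = model(inputs)
+        >>> cls_token, dist_token = outputs[0]
+        >>> # With the default out_type='cls_token', each element of
+        >>> # `outputs` is a (cls_token, dist_token) pair, both expected to
+        >>> # have shape (1, 768).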
+ """
+ num_extra_tokens = 2 # class token and distillation token
+
+ def __init__(self, arch='deit-base', *args, **kwargs):
+ super(DistilledVisionTransformer, self).__init__(
+ arch=arch,
+ with_cls_token=True,
+ *args,
+ **kwargs,
+ )
+ self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
+
+ def forward(self, x):
+ B = x.shape[0]
+ x, patch_resolution = self.patch_embed(x)
+
+ # stole cls_tokens impl from Phil Wang, thanks
+ cls_tokens = self.cls_token.expand(B, -1, -1)
+ dist_token = self.dist_token.expand(B, -1, -1)
+ x = torch.cat((cls_tokens, dist_token, x), dim=1)
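+        # Both the class token and the distillation token are prepended, so
+        # the position embedding below is resized with num_extra_tokens=2.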
+ x = x + self.resize_pos_embed(
+ self.pos_embed,
+ self.patch_resolution,
+ patch_resolution,
+ mode=self.interpolate_mode,
+ num_extra_tokens=self.num_extra_tokens)
+ x = self.drop_after_pos(x)
+
+ outs = []
+ for i, layer in enumerate(self.layers):
+ x = layer(x)
+
+ if i == len(self.layers) - 1 and self.final_norm:
+ x = self.ln1(x)
+
+ if i in self.out_indices:
+ outs.append(self._format_output(x, patch_resolution))
+
+ return tuple(outs)
+
+ def _format_output(self, x, hw):
+ if self.out_type == 'cls_token':
+ return x[:, 0], x[:, 1]
+
+ return super()._format_output(x, hw)
+
+ def init_weights(self):
+ super(DistilledVisionTransformer, self).init_weights()
+
+ if not (isinstance(self.init_cfg, dict)
+ and self.init_cfg['type'] == 'Pretrained'):
+ trunc_normal_(self.dist_token, std=0.02)
diff --git a/mmpretrain/models/backbones/deit3.py b/mmpretrain/models/backbones/deit3.py
new file mode 100644
index 0000000000000000000000000000000000000000..acedabe42d66a8073f34b1b0ae87501522fcc1b5
--- /dev/null
+++ b/mmpretrain/models/backbones/deit3.py
@@ -0,0 +1,454 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Sequence
+
+import numpy as np
+import torch
+from mmcv.cnn import Linear, build_activation_layer
+from mmcv.cnn.bricks.drop import build_dropout
+from mmcv.cnn.bricks.transformer import PatchEmbed
+from mmengine.model import BaseModule, ModuleList, Sequential
+from mmengine.utils import deprecated_api_warning
+from torch import nn
+
+from mmpretrain.registry import MODELS
+from ..utils import (LayerScale, MultiheadAttention, build_norm_layer,
+ resize_pos_embed, to_2tuple)
+from .vision_transformer import VisionTransformer
+
+
+class DeiT3FFN(BaseModule):
+ """FFN for DeiT3.
+
+ The differences between DeiT3FFN & FFN:
+ 1. Use LayerScale.
+
+ Args:
+ embed_dims (int): The feature dimension. Same as
+ `MultiheadAttention`. Defaults: 256.
+ feedforward_channels (int): The hidden dimension of FFNs.
+ Defaults: 1024.
+ num_fcs (int, optional): The number of fully-connected layers in
+ FFNs. Default: 2.
+ act_cfg (dict, optional): The activation config for FFNs.
+ Default: dict(type='ReLU')
+ ffn_drop (float, optional): Probability of an element to be
+ zeroed in FFN. Default 0.0.
+ add_identity (bool, optional): Whether to add the
+ identity connection. Default: `True`.
+ dropout_layer (obj:`ConfigDict`): The dropout_layer used
+ when adding the shortcut.
+ use_layer_scale (bool): Whether to use layer_scale in
+ DeiT3FFN. Defaults to True.
+ init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+ Default: None.
+ """
+
+ @deprecated_api_warning(
+ {
+ 'dropout': 'ffn_drop',
+ 'add_residual': 'add_identity'
+ },
+ cls_name='FFN')
+ def __init__(self,
+ embed_dims=256,
+ feedforward_channels=1024,
+ num_fcs=2,
+ act_cfg=dict(type='ReLU', inplace=True),
+ ffn_drop=0.,
+ dropout_layer=None,
+ add_identity=True,
+ use_layer_scale=True,
+ init_cfg=None,
+ **kwargs):
+ super().__init__(init_cfg)
+ assert num_fcs >= 2, 'num_fcs should be no less ' \
+ f'than 2. got {num_fcs}.'
+ self.embed_dims = embed_dims
+ self.feedforward_channels = feedforward_channels
+ self.num_fcs = num_fcs
+ self.act_cfg = act_cfg
+ self.activate = build_activation_layer(act_cfg)
+
+ layers = []
+ in_channels = embed_dims
+ for _ in range(num_fcs - 1):
+ layers.append(
+ Sequential(
+ Linear(in_channels, feedforward_channels), self.activate,
+ nn.Dropout(ffn_drop)))
+ in_channels = feedforward_channels
+ layers.append(Linear(feedforward_channels, embed_dims))
+ layers.append(nn.Dropout(ffn_drop))
+ self.layers = Sequential(*layers)
+ self.dropout_layer = build_dropout(
+ dropout_layer) if dropout_layer else torch.nn.Identity()
+ self.add_identity = add_identity
+
+ if use_layer_scale:
+ self.gamma2 = LayerScale(embed_dims)
+ else:
+ self.gamma2 = nn.Identity()
+
+ @deprecated_api_warning({'residual': 'identity'}, cls_name='FFN')
+ def forward(self, x, identity=None):
+ """Forward function for `FFN`.
+
+        The function adds x to the output tensor if `identity` is None.
+ """
+ out = self.layers(x)
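+        # LayerScale: `gamma2` applies a learnable per-channel scaling to the
+        # FFN output before the residual connection; it is an identity when
+        # `use_layer_scale=False`.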
+ out = self.gamma2(out)
+ if not self.add_identity:
+ return self.dropout_layer(out)
+ if identity is None:
+ identity = x
+ return identity + self.dropout_layer(out)
+
+
+class DeiT3TransformerEncoderLayer(BaseModule):
+ """Implements one encoder layer in DeiT3.
+
+ The differences between DeiT3TransformerEncoderLayer &
+ TransformerEncoderLayer:
+ 1. Use LayerScale.
+
+ Args:
+ embed_dims (int): The feature dimension
+ num_heads (int): Parallel attention heads
+ feedforward_channels (int): The hidden dimension for FFNs
+ drop_rate (float): Probability of an element to be zeroed
+ after the feed forward layer. Defaults to 0.
+ attn_drop_rate (float): The drop out rate for attention output weights.
+ Defaults to 0.
+ drop_path_rate (float): Stochastic depth rate. Defaults to 0.
+ num_fcs (int): The number of fully-connected layers for FFNs.
+ Defaults to 2.
+ qkv_bias (bool): enable bias for qkv if True. Defaults to True.
+ use_layer_scale (bool): Whether to use layer_scale in
+ DeiT3TransformerEncoderLayer. Defaults to True.
+ act_cfg (dict): The activation config for FFNs.
+ Defaults to ``dict(type='GELU')``.
+ norm_cfg (dict): Config dict for normalization layer.
+ Defaults to ``dict(type='LN')``.
+ init_cfg (dict, optional): Initialization config dict.
+ Defaults to None.
+ """
+
+ def __init__(self,
+ embed_dims,
+ num_heads,
+ feedforward_channels,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.,
+ num_fcs=2,
+ qkv_bias=True,
+ use_layer_scale=True,
+ act_cfg=dict(type='GELU'),
+ norm_cfg=dict(type='LN'),
+ init_cfg=None):
+ super(DeiT3TransformerEncoderLayer, self).__init__(init_cfg=init_cfg)
+
+ self.embed_dims = embed_dims
+
+ self.ln1 = build_norm_layer(norm_cfg, self.embed_dims)
+
+ self.attn = MultiheadAttention(
+ embed_dims=embed_dims,
+ num_heads=num_heads,
+ attn_drop=attn_drop_rate,
+ proj_drop=drop_rate,
+ dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
+ qkv_bias=qkv_bias,
+ use_layer_scale=use_layer_scale)
+
+ self.ln2 = build_norm_layer(norm_cfg, self.embed_dims)
+
+ self.ffn = DeiT3FFN(
+ embed_dims=embed_dims,
+ feedforward_channels=feedforward_channels,
+ num_fcs=num_fcs,
+ ffn_drop=drop_rate,
+ dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
+ act_cfg=act_cfg,
+ use_layer_scale=use_layer_scale)
+
+ def init_weights(self):
+ super(DeiT3TransformerEncoderLayer, self).init_weights()
+ for m in self.ffn.modules():
+ if isinstance(m, nn.Linear):
+ nn.init.xavier_uniform_(m.weight)
+ nn.init.normal_(m.bias, std=1e-6)
+
+ def forward(self, x):
+ x = x + self.attn(self.ln1(x))
+        x = self.ffn(self.ln2(x), identity=x)
+ return x
+
+
+@MODELS.register_module()
+class DeiT3(VisionTransformer):
+ """DeiT3 backbone.
+
+    A PyTorch implementation of: `DeiT III: Revenge of the ViT
+ `_
+
+ The differences between DeiT3 & VisionTransformer:
+
+ 1. Use LayerScale.
+ 2. Concat cls token after adding pos_embed.
+
+ Args:
+ arch (str | dict): DeiT3 architecture. If use string,
+ choose from 'small', 'base', 'medium', 'large' and 'huge'.
+ If use dict, it should have below keys:
+
+ - **embed_dims** (int): The dimensions of embedding.
+ - **num_layers** (int): The number of transformer encoder layers.
+ - **num_heads** (int): The number of heads in attention modules.
+ - **feedforward_channels** (int): The hidden dimensions in
+ feedforward modules.
+
+ Defaults to 'base'.
+ img_size (int | tuple): The expected input image shape. Because we
+ support dynamic input shape, just set the argument to the most
+ common input image shape. Defaults to 224.
+ patch_size (int | tuple): The patch size in patch embedding.
+ Defaults to 16.
+ in_channels (int): The num of input channels. Defaults to 3.
+ out_indices (Sequence | int): Output from which stages.
+ Defaults to -1, means the last stage.
+ drop_rate (float): Probability of an element to be zeroed.
+ Defaults to 0.
+ drop_path_rate (float): stochastic depth rate. Defaults to 0.
+ qkv_bias (bool): Whether to add bias for qkv in attention modules.
+ Defaults to True.
+ norm_cfg (dict): Config dict for normalization layer.
+ Defaults to ``dict(type='LN')``.
+        final_norm (bool): Whether to add an additional layer to normalize
+            the final feature map. Defaults to True.
+ out_type (str): The type of output features. Please choose from
+
+ - ``"cls_token"``: The class token tensor with shape (B, C).
+ - ``"featmap"``: The feature map tensor from the patch tokens
+ with shape (B, C, H, W).
+ - ``"avg_featmap"``: The global averaged feature map tensor
+ with shape (B, C).
+ - ``"raw"``: The raw feature tensor includes patch tokens and
+ class tokens with shape (B, L, C).
+
+ Defaults to ``"cls_token"``.
+ with_cls_token (bool): Whether concatenating class token into image
+ tokens as transformer input. Defaults to True.
+ use_layer_scale (bool): Whether to use layer_scale in DeiT3.
+ Defaults to True.
+        interpolate_mode (str): Select the interpolation mode for position
+            embedding vector resize. Defaults to "bicubic".
+        patch_cfg (dict): Configs of patch embedding. Defaults to an empty
+            dict.
+ layer_cfgs (Sequence | dict): Configs of each transformer layer in
+ encoder. Defaults to an empty dict.
+ init_cfg (dict, optional): Initialization config dict.
+ Defaults to None.
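+
+    A minimal usage sketch is given below; the shape noted in the comment
+    assumes the default 'base' settings and a 224x224 input, and is an
+    expected value rather than a verified output.
+
+    Example:
+        >>> import torch
+        >>> from mmpretrain.models import DeiT3
+        >>> model = DeiT3(arch='base')
+        >>> inputs = torch.rand(1, 3, 224, 224)
+        >>> outputs = model(inputs)
+        >>> # With the default out_type='cls_token', `outputs` holds the
+        >>> # class token of the last layer, expected to have shape (1, 768).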
+ """
+ arch_zoo = {
+ **dict.fromkeys(
+ ['s', 'small'], {
+ 'embed_dims': 384,
+ 'num_layers': 12,
+ 'num_heads': 6,
+ 'feedforward_channels': 1536,
+ }),
+ **dict.fromkeys(
+ ['m', 'medium'], {
+ 'embed_dims': 512,
+ 'num_layers': 12,
+ 'num_heads': 8,
+ 'feedforward_channels': 2048,
+ }),
+ **dict.fromkeys(
+ ['b', 'base'], {
+ 'embed_dims': 768,
+ 'num_layers': 12,
+ 'num_heads': 12,
+ 'feedforward_channels': 3072
+ }),
+ **dict.fromkeys(
+ ['l', 'large'], {
+ 'embed_dims': 1024,
+ 'num_layers': 24,
+ 'num_heads': 16,
+ 'feedforward_channels': 4096
+ }),
+ **dict.fromkeys(
+ ['h', 'huge'], {
+ 'embed_dims': 1280,
+ 'num_layers': 32,
+ 'num_heads': 16,
+ 'feedforward_channels': 5120
+ }),
+ }
+ num_extra_tokens = 1 # class token
+
+ def __init__(self,
+ arch='base',
+ img_size=224,
+ patch_size=16,
+ in_channels=3,
+ out_indices=-1,
+ drop_rate=0.,
+ drop_path_rate=0.,
+ qkv_bias=True,
+ norm_cfg=dict(type='LN', eps=1e-6),
+ final_norm=True,
+ out_type='cls_token',
+ with_cls_token=True,
+ use_layer_scale=True,
+ interpolate_mode='bicubic',
+ patch_cfg=dict(),
+ layer_cfgs=dict(),
+ init_cfg=None):
+ super(VisionTransformer, self).__init__(init_cfg)
+
+ if isinstance(arch, str):
+ arch = arch.lower()
+ assert arch in set(self.arch_zoo), \
+ f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
+ self.arch_settings = self.arch_zoo[arch]
+ else:
+ essential_keys = {
+ 'embed_dims', 'num_layers', 'num_heads', 'feedforward_channels'
+ }
+ assert isinstance(arch, dict) and essential_keys <= set(arch), \
+ f'Custom arch needs a dict with keys {essential_keys}'
+ self.arch_settings = arch
+
+ self.embed_dims = self.arch_settings['embed_dims']
+ self.num_layers = self.arch_settings['num_layers']
+ self.img_size = to_2tuple(img_size)
+
+ # Set patch embedding
+ _patch_cfg = dict(
+ in_channels=in_channels,
+ input_size=img_size,
+ embed_dims=self.embed_dims,
+ conv_type='Conv2d',
+ kernel_size=patch_size,
+ stride=patch_size,
+ )
+ _patch_cfg.update(patch_cfg)
+ self.patch_embed = PatchEmbed(**_patch_cfg)
+ self.patch_resolution = self.patch_embed.init_out_size
+ num_patches = self.patch_resolution[0] * self.patch_resolution[1]
+
+ # Set out type
+ if out_type not in self.OUT_TYPES:
+ raise ValueError(f'Unsupported `out_type` {out_type}, please '
+ f'choose from {self.OUT_TYPES}')
+ self.out_type = out_type
+
+ # Set cls token
+ if with_cls_token:
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
+ elif out_type != 'cls_token':
+ self.cls_token = None
+ self.num_extra_tokens = 0
+ else:
+ raise ValueError(
+ 'with_cls_token must be True when `out_type="cls_token"`.')
+
+ # Set position embedding
+ self.interpolate_mode = interpolate_mode
+ self.pos_embed = nn.Parameter(
+ torch.zeros(1, num_patches, self.embed_dims))
+ self._register_load_state_dict_pre_hook(self._prepare_pos_embed)
+
+ self.drop_after_pos = nn.Dropout(p=drop_rate)
+
+ if isinstance(out_indices, int):
+ out_indices = [out_indices]
+ assert isinstance(out_indices, Sequence), \
+ f'"out_indices" must by a sequence or int, ' \
+ f'get {type(out_indices)} instead.'
+ for i, index in enumerate(out_indices):
+ if index < 0:
+ out_indices[i] = self.num_layers + index
+ assert 0 <= out_indices[i] <= self.num_layers, \
+ f'Invalid out_indices {index}'
+ self.out_indices = out_indices
+
+ # stochastic depth decay rule
+ dpr = np.linspace(0, drop_path_rate, self.num_layers)
+
+ self.layers = ModuleList()
+ if isinstance(layer_cfgs, dict):
+ layer_cfgs = [layer_cfgs] * self.num_layers
+ for i in range(self.num_layers):
+ _layer_cfg = dict(
+ embed_dims=self.embed_dims,
+ num_heads=self.arch_settings['num_heads'],
+ feedforward_channels=self.
+ arch_settings['feedforward_channels'],
+ drop_rate=drop_rate,
+ drop_path_rate=dpr[i],
+ qkv_bias=qkv_bias,
+ norm_cfg=norm_cfg,
+ use_layer_scale=use_layer_scale)
+ _layer_cfg.update(layer_cfgs[i])
+ self.layers.append(DeiT3TransformerEncoderLayer(**_layer_cfg))
+
+ self.final_norm = final_norm
+ if final_norm:
+ self.ln1 = build_norm_layer(norm_cfg, self.embed_dims)
+
+ def forward(self, x):
+ B = x.shape[0]
+ x, patch_resolution = self.patch_embed(x)
+
+ x = x + resize_pos_embed(
+ self.pos_embed,
+ self.patch_resolution,
+ patch_resolution,
+ mode=self.interpolate_mode,
+ num_extra_tokens=0)
+ x = self.drop_after_pos(x)
+
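+        # Unlike the standard ViT, DeiT3 adds the position embedding to the
+        # patch tokens first and concatenates the class token afterwards,
+        # which is why `num_extra_tokens=0` is used when resizing above.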
+ if self.cls_token is not None:
+ # stole cls_tokens impl from Phil Wang, thanks
+ cls_tokens = self.cls_token.expand(B, -1, -1)
+ x = torch.cat((cls_tokens, x), dim=1)
+
+ outs = []
+ for i, layer in enumerate(self.layers):
+ x = layer(x)
+
+ if i == len(self.layers) - 1 and self.final_norm:
+ x = self.ln1(x)
+
+ if i in self.out_indices:
+ outs.append(self._format_output(x, patch_resolution))
+
+ return tuple(outs)
+
+ def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs):
+ name = prefix + 'pos_embed'
+ if name not in state_dict.keys():
+ return
+
+ ckpt_pos_embed_shape = state_dict[name].shape
+ if self.pos_embed.shape != ckpt_pos_embed_shape:
+ from mmengine.logging import MMLogger
+ logger = MMLogger.get_current_instance()
+ logger.info(
+ f'Resize the pos_embed shape from {ckpt_pos_embed_shape} '
+ f'to {self.pos_embed.shape}.')
+
+ ckpt_pos_embed_shape = to_2tuple(
+ int(np.sqrt(ckpt_pos_embed_shape[1])))
+ pos_embed_shape = self.patch_embed.init_out_size
+
+ state_dict[name] = resize_pos_embed(
+ state_dict[name],
+ ckpt_pos_embed_shape,
+ pos_embed_shape,
+ self.interpolate_mode,
+ num_extra_tokens=0, # The cls token adding is after pos_embed
+ )
diff --git a/mmpretrain/models/backbones/densenet.py b/mmpretrain/models/backbones/densenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9f05302f9b84cd38c7c03701fc21ffd109c1620
--- /dev/null
+++ b/mmpretrain/models/backbones/densenet.py
@@ -0,0 +1,332 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+from itertools import chain
+from typing import Sequence
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as cp
+from mmcv.cnn.bricks import build_activation_layer, build_norm_layer
+from torch.jit.annotations import List
+
+from mmpretrain.registry import MODELS
+from .base_backbone import BaseBackbone
+
+
+class DenseLayer(BaseBackbone):
+ """DenseBlock layers."""
+
+ def __init__(self,
+ in_channels,
+ growth_rate,
+ bn_size,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ drop_rate=0.,
+ memory_efficient=False):
+ super(DenseLayer, self).__init__()
+
+ self.norm1 = build_norm_layer(norm_cfg, in_channels)[1]
+ self.conv1 = nn.Conv2d(
+ in_channels,
+ bn_size * growth_rate,
+ kernel_size=1,
+ stride=1,
+ bias=False)
+ self.act = build_activation_layer(act_cfg)
+ self.norm2 = build_norm_layer(norm_cfg, bn_size * growth_rate)[1]
+ self.conv2 = nn.Conv2d(
+ bn_size * growth_rate,
+ growth_rate,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False)
+ self.drop_rate = float(drop_rate)
+ self.memory_efficient = memory_efficient
+
+ def bottleneck_fn(self, xs):
+ # type: (List[torch.Tensor]) -> torch.Tensor
+ concated_features = torch.cat(xs, 1)
+ bottleneck_output = self.conv1(
+ self.act(self.norm1(concated_features))) # noqa: T484
+ return bottleneck_output
+
+ # todo: rewrite when torchscript supports any
+ def any_requires_grad(self, x):
+ # type: (List[torch.Tensor]) -> bool
+ for tensor in x:
+ if tensor.requires_grad:
+ return True
+ return False
+
+ # This decorator indicates to the compiler that a function or method
+ # should be ignored and replaced with the raising of an exception.
+ # Here this function is incompatible with torchscript.
+ @torch.jit.unused # noqa: T484
+ def call_checkpoint_bottleneck(self, x):
+ # type: (List[torch.Tensor]) -> torch.Tensor
+ def closure(*xs):
+ return self.bottleneck_fn(xs)
+
+ # Here use torch.utils.checkpoint to rerun a forward-pass during
+ # backward in bottleneck to save memories.
+ return cp.checkpoint(closure, *x)
+
+ def forward(self, x): # noqa: F811
+ # type: (List[torch.Tensor]) -> torch.Tensor
+ # assert input features is a list of Tensor
+ assert isinstance(x, list)
+
+ if self.memory_efficient and self.any_requires_grad(x):
+ if torch.jit.is_scripting():
+ raise Exception('Memory Efficient not supported in JIT')
+ bottleneck_output = self.call_checkpoint_bottleneck(x)
+ else:
+ bottleneck_output = self.bottleneck_fn(x)
+
+ new_features = self.conv2(self.act(self.norm2(bottleneck_output)))
+ if self.drop_rate > 0:
+ new_features = F.dropout(
+ new_features, p=self.drop_rate, training=self.training)
+ return new_features
+
+
+class DenseBlock(nn.Module):
+ """DenseNet Blocks."""
+
+ def __init__(self,
+ num_layers,
+ in_channels,
+ bn_size,
+ growth_rate,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ drop_rate=0.,
+ memory_efficient=False):
+ super(DenseBlock, self).__init__()
+ self.block = nn.ModuleList([
+ DenseLayer(
+ in_channels + i * growth_rate,
+ growth_rate=growth_rate,
+ bn_size=bn_size,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ drop_rate=drop_rate,
+ memory_efficient=memory_efficient) for i in range(num_layers)
+ ])
+
+ def forward(self, init_features):
+ features = [init_features]
+ for layer in self.block:
+ new_features = layer(features)
+ features.append(new_features)
+ return torch.cat(features, 1)
+
+
+class DenseTransition(nn.Sequential):
+ """DenseNet Transition Layers."""
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU')):
+ super(DenseTransition, self).__init__()
+ self.add_module('norm', build_norm_layer(norm_cfg, in_channels)[1])
+ self.add_module('act', build_activation_layer(act_cfg))
+ self.add_module(
+ 'conv',
+ nn.Conv2d(
+ in_channels, out_channels, kernel_size=1, stride=1,
+ bias=False))
+ self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))
+
+
+@MODELS.register_module()
+class DenseNet(BaseBackbone):
+ """DenseNet.
+
+    A PyTorch implementation of: `Densely Connected Convolutional Networks
+ `_
+
+ Modified from the `official repo
+ `_
+ and `pytorch
+ `_.
+
+ Args:
+ arch (str | dict): The model's architecture. If string, it should be
+ one of architecture in ``DenseNet.arch_settings``. And if dict, it
+ should include the following two keys:
+
+            - growth_rate (int): Each layer of a DenseBlock produces `k`
+                feature maps, where `k` is referred to as the growth rate of
+                the network.
+ - depths (list[int]): Number of repeated layers in each DenseBlock.
+ - init_channels (int): The output channels of stem layers.
+
+ Defaults to '121'.
+ in_channels (int): Number of input image channels. Defaults to 3.
+ bn_size (int): Refers to channel expansion parameter of 1x1
+ convolution layer. Defaults to 4.
+ drop_rate (float): Drop rate of Dropout Layer. Defaults to 0.
+ compression_factor (float): The reduction rate of transition layers.
+ Defaults to 0.5.
+ memory_efficient (bool): If True, uses checkpointing. Much more memory
+ efficient, but slower. Defaults to False.
+ See `"paper" `_.
+ norm_cfg (dict): The config dict for norm layers.
+ Defaults to ``dict(type='BN')``.
+ act_cfg (dict): The config dict for activation after each convolution.
+ Defaults to ``dict(type='ReLU')``.
+ out_indices (Sequence | int): Output from which stages.
+ Defaults to -1, means the last stage.
+ frozen_stages (int): Stages to be frozen (all param fixed).
+ Defaults to 0, which means not freezing any parameters.
+ init_cfg (dict, optional): Initialization config dict.
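+
+    A minimal usage sketch is given below; the shape noted in the comment
+    assumes the default '121' settings and a 224x224 input, and is an
+    expected value rather than a verified output.
+
+    Example:
+        >>> import torch
+        >>> from mmpretrain.models import DenseNet
+        >>> model = DenseNet(arch='121')
+        >>> model.eval()
+        >>> inputs = torch.rand(1, 3, 224, 224)
+        >>> outputs = model(inputs)
+        >>> # With the default out_indices=-1, a single feature map is
+        >>> # returned, expected to have shape (1, 1024, 7, 7).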
+ """
+ arch_settings = {
+ '121': {
+ 'growth_rate': 32,
+ 'depths': [6, 12, 24, 16],
+ 'init_channels': 64,
+ },
+ '169': {
+ 'growth_rate': 32,
+ 'depths': [6, 12, 32, 32],
+ 'init_channels': 64,
+ },
+ '201': {
+ 'growth_rate': 32,
+ 'depths': [6, 12, 48, 32],
+ 'init_channels': 64,
+ },
+ '161': {
+ 'growth_rate': 48,
+ 'depths': [6, 12, 36, 24],
+ 'init_channels': 96,
+ },
+ }
+
+ def __init__(self,
+ arch='121',
+ in_channels=3,
+ bn_size=4,
+ drop_rate=0,
+ compression_factor=0.5,
+ memory_efficient=False,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ out_indices=-1,
+ frozen_stages=0,
+ init_cfg=None):
+ super().__init__(init_cfg=init_cfg)
+
+ if isinstance(arch, str):
+ assert arch in self.arch_settings, \
+ f'Unavailable arch, please choose from ' \
+ f'({set(self.arch_settings)}) or pass a dict.'
+ arch = self.arch_settings[arch]
+ elif isinstance(arch, dict):
+ essential_keys = {'growth_rate', 'depths', 'init_channels'}
+ assert isinstance(arch, dict) and essential_keys <= set(arch), \
+ f'Custom arch needs a dict with keys {essential_keys}'
+
+ self.growth_rate = arch['growth_rate']
+ self.depths = arch['depths']
+ self.init_channels = arch['init_channels']
+ self.act = build_activation_layer(act_cfg)
+
+ self.num_stages = len(self.depths)
+
+ # check out indices and frozen stages
+ if isinstance(out_indices, int):
+ out_indices = [out_indices]
+ assert isinstance(out_indices, Sequence), \
+ f'"out_indices" must by a sequence or int, ' \
+ f'get {type(out_indices)} instead.'
+ for i, index in enumerate(out_indices):
+ if index < 0:
+ out_indices[i] = self.num_stages + index
+ assert out_indices[i] >= 0, f'Invalid out_indices {index}'
+ self.out_indices = out_indices
+ self.frozen_stages = frozen_stages
+
+ # Set stem layers
+ self.stem = nn.Sequential(
+ nn.Conv2d(
+ in_channels,
+ self.init_channels,
+ kernel_size=7,
+ stride=2,
+ padding=3,
+ bias=False),
+ build_norm_layer(norm_cfg, self.init_channels)[1], self.act,
+ nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
+
+ # Repetitions of DenseNet Blocks
+ self.stages = nn.ModuleList()
+ self.transitions = nn.ModuleList()
+
+ channels = self.init_channels
+ for i in range(self.num_stages):
+ depth = self.depths[i]
+
+ stage = DenseBlock(
+ num_layers=depth,
+ in_channels=channels,
+ bn_size=bn_size,
+ growth_rate=self.growth_rate,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ drop_rate=drop_rate,
+ memory_efficient=memory_efficient)
+ self.stages.append(stage)
+ channels += depth * self.growth_rate
+
+ if i != self.num_stages - 1:
+ transition = DenseTransition(
+ in_channels=channels,
+ out_channels=math.floor(channels * compression_factor),
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ )
+ channels = math.floor(channels * compression_factor)
+ else:
+                # The final layers after the last dense block are just norm
+                # and activation. Unlike the paper, the original repo also
+                # puts these in a transition layer, whereas torchvision takes
+                # them out. We treat them as a transition layer here.
+ transition = nn.Sequential(
+ build_norm_layer(norm_cfg, channels)[1],
+ self.act,
+ )
+ self.transitions.append(transition)
+
+ self._freeze_stages()
+
+ def forward(self, x):
+ x = self.stem(x)
+ outs = []
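+        # Each stage is a DenseBlock followed by a transition layer; the last
+        # transition is only norm + activation (no 1x1 conv or pooling).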
+ for i in range(self.num_stages):
+ x = self.stages[i](x)
+ x = self.transitions[i](x)
+ if i in self.out_indices:
+ outs.append(x)
+
+ return tuple(outs)
+
+ def _freeze_stages(self):
+ for i in range(self.frozen_stages):
+ downsample_layer = self.transitions[i]
+ stage = self.stages[i]
+ downsample_layer.eval()
+ stage.eval()
+ for param in chain(downsample_layer.parameters(),
+ stage.parameters()):
+ param.requires_grad = False
+
+ def train(self, mode=True):
+ super(DenseNet, self).train(mode)
+ self._freeze_stages()
diff --git a/mmpretrain/models/backbones/edgenext.py b/mmpretrain/models/backbones/edgenext.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad4e768e7561eb49da3603f4394faaebed7c9251
--- /dev/null
+++ b/mmpretrain/models/backbones/edgenext.py
@@ -0,0 +1,398 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+from itertools import chain
+from typing import Sequence
+
+import torch
+import torch.nn as nn
+from mmcv.cnn.bricks import DropPath
+from mmengine.model import BaseModule, ModuleList, Sequential
+
+from mmpretrain.registry import MODELS
+from ..utils import (ChannelMultiheadAttention, PositionEncodingFourier,
+ build_norm_layer)
+from .base_backbone import BaseBackbone
+from .convnext import ConvNeXtBlock
+
+
+class SDTAEncoder(BaseModule):
+ """A PyTorch implementation of split depth-wise transpose attention (SDTA)
+ encoder.
+
+ Inspiration from
+ https://github.com/mmaaz60/EdgeNeXt
+ Args:
+ in_channel (int): Number of input channels.
+ drop_path_rate (float): Stochastic depth dropout rate.
+ Defaults to 0.
+ layer_scale_init_value (float): Initial value of layer scale.
+ Defaults to 1e-6.
+ mlp_ratio (int): Number of channels ratio in the MLP.
+ Defaults to 4.
+ use_pos_emb (bool): Whether to use position encoding.
+ Defaults to True.
+ num_heads (int): Number of heads in the multihead attention.
+ Defaults to 8.
+ qkv_bias (bool): Whether to use bias in the multihead attention.
+ Defaults to True.
+ attn_drop (float): Dropout rate of the attention.
+ Defaults to 0.
+ proj_drop (float): Dropout rate of the projection.
+ Defaults to 0.
+ norm_cfg (dict): Dictionary to construct normalization layer.
+ Defaults to ``dict(type='LN')``.
+ act_cfg (dict): Dictionary to construct activation layer.
+ Defaults to ``dict(type='GELU')``.
+        scales (int): Number of scales. Defaults to 1.
+ """
+
+ def __init__(self,
+ in_channel,
+ drop_path_rate=0.,
+ layer_scale_init_value=1e-6,
+ mlp_ratio=4,
+ use_pos_emb=True,
+ num_heads=8,
+ qkv_bias=True,
+ attn_drop=0.,
+ proj_drop=0.,
+ norm_cfg=dict(type='LN'),
+ act_cfg=dict(type='GELU'),
+ scales=1,
+ init_cfg=None):
+ super(SDTAEncoder, self).__init__(init_cfg=init_cfg)
+ conv_channels = max(
+ int(math.ceil(in_channel / scales)),
+ int(math.floor(in_channel // scales)))
+ self.conv_channels = conv_channels
+ self.num_convs = scales if scales == 1 else scales - 1
+
+ self.conv_modules = ModuleList()
+ for i in range(self.num_convs):
+ self.conv_modules.append(
+ nn.Conv2d(
+ conv_channels,
+ conv_channels,
+ kernel_size=3,
+ padding=1,
+ groups=conv_channels))
+
+ self.pos_embed = PositionEncodingFourier(
+ embed_dims=in_channel) if use_pos_emb else None
+
+ self.norm_csa = build_norm_layer(norm_cfg, in_channel)
+ self.gamma_csa = nn.Parameter(
+ layer_scale_init_value * torch.ones(in_channel),
+ requires_grad=True) if layer_scale_init_value > 0 else None
+ self.csa = ChannelMultiheadAttention(
+ embed_dims=in_channel,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ attn_drop=attn_drop,
+ proj_drop=proj_drop)
+
+ self.norm = build_norm_layer(norm_cfg, in_channel)
+ self.pointwise_conv1 = nn.Linear(in_channel, mlp_ratio * in_channel)
+ self.act = MODELS.build(act_cfg)
+ self.pointwise_conv2 = nn.Linear(mlp_ratio * in_channel, in_channel)
+ self.gamma = nn.Parameter(
+ layer_scale_init_value * torch.ones(in_channel),
+ requires_grad=True) if layer_scale_init_value > 0 else None
+ self.drop_path = DropPath(
+ drop_path_rate) if drop_path_rate > 0. else nn.Identity()
+
+ def forward(self, x):
+ shortcut = x
+ spx = torch.split(x, self.conv_channels, dim=1)
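+        # Split depth-wise convolutions: all but the last channel split are
+        # processed by cascaded 3x3 depth-wise convs (Res2Net style), while
+        # the last split is concatenated back unchanged.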
+ for i in range(self.num_convs):
+ if i == 0:
+ sp = spx[i]
+ else:
+ sp = sp + spx[i]
+ sp = self.conv_modules[i](sp)
+ if i == 0:
+ out = sp
+ else:
+ out = torch.cat((out, sp), 1)
+
+ x = torch.cat((out, spx[self.num_convs]), 1)
+
+ # Channel Self-attention
+ B, C, H, W = x.shape
+ x = x.reshape(B, C, H * W).permute(0, 2, 1)
+ if self.pos_embed:
+ pos_encoding = self.pos_embed((B, H, W))
+ pos_encoding = pos_encoding.reshape(B, -1,
+ x.shape[1]).permute(0, 2, 1)
+ x += pos_encoding
+
+ x = x + self.drop_path(self.gamma_csa * self.csa(self.norm_csa(x)))
+ x = x.reshape(B, H, W, C)
+
+ # Inverted Bottleneck
+ x = self.norm(x)
+ x = self.pointwise_conv1(x)
+ x = self.act(x)
+ x = self.pointwise_conv2(x)
+
+ if self.gamma is not None:
+ x = self.gamma * x
+ x = x.permute(0, 3, 1, 2) # (B, H, W, C) -> (B, C, H, W)
+
+ x = shortcut + self.drop_path(x)
+
+ return x
+
+
+@MODELS.register_module()
+class EdgeNeXt(BaseBackbone):
+ """EdgeNeXt.
+
+ A PyTorch implementation of: `EdgeNeXt: Efficiently Amalgamated
+ CNN-Transformer Architecture for Mobile Vision Applications
+    <https://arxiv.org/abs/2206.10589>`_
+
+ Inspiration from
+ https://github.com/mmaaz60/EdgeNeXt
+
+ Args:
+ arch (str | dict): The model's architecture. If string, it should be
+ one of architectures in ``EdgeNeXt.arch_settings``.
+ And if dict, it should include the following keys:
+
+ - channels (list[int]): The number of channels at each stage.
+ - depths (list[int]): The number of blocks at each stage.
+ - num_heads (list[int]): The number of heads at each stage.
+
+ Defaults to 'xxsmall'.
+ in_channels (int): The number of input channels.
+ Defaults to 3.
+ global_blocks (list[int]): The number of global blocks.
+ Defaults to [0, 1, 1, 1].
+ global_block_type (list[str]): The type of global blocks.
+ Defaults to ['None', 'SDTA', 'SDTA', 'SDTA'].
+ drop_path_rate (float): Stochastic depth dropout rate.
+ Defaults to 0.
+ layer_scale_init_value (float): Initial value of layer scale.
+ Defaults to 1e-6.
+        linear_pw_conv (bool): Whether to use linear layer to do pointwise
+            convolution. Defaults to True.
+ mlp_ratio (int): The number of channel ratio in MLP layers.
+ Defaults to 4.
+        conv_kernel_sizes (list[int]): The kernel size of convolutional layers
+            at each stage. Defaults to [3, 5, 7, 9].
+ use_pos_embd_csa (list[bool]): Whether to use positional embedding in
+ Channel Self-Attention. Defaults to [False, True, False, False].
+        use_pos_embd_global (bool): Whether to use positional embedding for
+            the whole network. Defaults to False.
+ d2_scales (list[int]): The number of channel groups used for SDTA at
+ each stage. Defaults to [2, 2, 3, 4].
+ norm_cfg (dict): The config of normalization layer.
+ Defaults to ``dict(type='LN2d', eps=1e-6)``.
+ out_indices (Sequence | int): Output from which stages.
+            Defaults to -1, which means the last stage.
+ frozen_stages (int): Stages to be frozen (all param fixed).
+ Defaults to 0, which means not freezing any parameters.
+ gap_before_final_norm (bool): Whether to globally average the feature
+ map before the final norm layer. Defaults to True.
+ act_cfg (dict): The config of activation layer.
+ Defaults to ``dict(type='GELU')``.
+ init_cfg (dict, optional): Config for initialization.
+ Defaults to None.
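+
+    Example:
+        >>> # A minimal usage sketch added for illustration; the printed shape
+        >>> # assumes the default 'xxsmall' arch, ``out_indices=-1`` and a
+        >>> # 224x224 input.
+        >>> from mmpretrain.models import EdgeNeXt
+        >>> import torch
+        >>> model = EdgeNeXt(arch='xxsmall')
+        >>> model.eval()
+        >>> inputs = torch.rand(1, 3, 224, 224)
+        >>> outputs = model(inputs)
+        >>> print(tuple(outputs[-1].shape))
+        (1, 168)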
+ """
+ arch_settings = {
+ 'xxsmall': { # parameters: 1.3M
+ 'channels': [24, 48, 88, 168],
+ 'depths': [2, 2, 6, 2],
+ 'num_heads': [4, 4, 4, 4]
+ },
+ 'xsmall': { # parameters: 2.3M
+ 'channels': [32, 64, 100, 192],
+ 'depths': [3, 3, 9, 3],
+ 'num_heads': [4, 4, 4, 4]
+ },
+ 'small': { # parameters: 5.6M
+ 'channels': [48, 96, 160, 304],
+ 'depths': [3, 3, 9, 3],
+ 'num_heads': [8, 8, 8, 8]
+ },
+ 'base': { # parameters: 18.51M
+ 'channels': [80, 160, 288, 584],
+ 'depths': [3, 3, 9, 3],
+ 'num_heads': [8, 8, 8, 8]
+ },
+ }
+
+ def __init__(self,
+ arch='xxsmall',
+ in_channels=3,
+ global_blocks=[0, 1, 1, 1],
+ global_block_type=['None', 'SDTA', 'SDTA', 'SDTA'],
+ drop_path_rate=0.,
+ layer_scale_init_value=1e-6,
+ linear_pw_conv=True,
+ mlp_ratio=4,
+ conv_kernel_sizes=[3, 5, 7, 9],
+ use_pos_embd_csa=[False, True, False, False],
+ use_pos_embd_global=False,
+ d2_scales=[2, 2, 3, 4],
+ norm_cfg=dict(type='LN2d', eps=1e-6),
+ out_indices=-1,
+ frozen_stages=0,
+ gap_before_final_norm=True,
+ act_cfg=dict(type='GELU'),
+ init_cfg=None):
+ super(EdgeNeXt, self).__init__(init_cfg=init_cfg)
+
+ if isinstance(arch, str):
+ arch = arch.lower()
+ assert arch in self.arch_settings, \
+ f'Arch {arch} is not in default archs ' \
+ f'{set(self.arch_settings)}'
+ self.arch_settings = self.arch_settings[arch]
+ elif isinstance(arch, dict):
+ essential_keys = {'channels', 'depths', 'num_heads'}
+ assert isinstance(arch, dict) and set(arch) == essential_keys, \
+ f'Custom arch needs a dict with keys {essential_keys}'
+ self.arch_settings = arch
+
+ self.channels = self.arch_settings['channels']
+ self.depths = self.arch_settings['depths']
+ self.num_heads = self.arch_settings['num_heads']
+ self.num_layers = len(self.depths)
+ self.use_pos_embd_global = use_pos_embd_global
+
+ for g in global_block_type:
+ assert g in ['None',
+ 'SDTA'], f'Global block type {g} is not supported'
+
+ self.num_stages = len(self.depths)
+
+ if isinstance(out_indices, int):
+ out_indices = [out_indices]
+ assert isinstance(out_indices, Sequence), \
+            f'"out_indices" must be a sequence or int, ' \
+            f'got {type(out_indices)} instead.'
+ for i, index in enumerate(out_indices):
+ if index < 0:
+ out_indices[i] = 4 + index
+ assert out_indices[i] >= 0, f'Invalid out_indices {index}'
+ self.out_indices = out_indices
+
+ self.frozen_stages = frozen_stages
+ self.gap_before_final_norm = gap_before_final_norm
+
+ if self.use_pos_embd_global:
+ self.pos_embed = PositionEncodingFourier(
+ embed_dims=self.channels[0])
+ else:
+ self.pos_embed = None
+
+ # stochastic depth decay rule
+ dpr = [
+ x.item()
+ for x in torch.linspace(0, drop_path_rate, sum(self.depths))
+ ]
+
+ self.downsample_layers = ModuleList()
+ stem = nn.Sequential(
+ nn.Conv2d(in_channels, self.channels[0], kernel_size=4, stride=4),
+ build_norm_layer(norm_cfg, self.channels[0]),
+ )
+ self.downsample_layers.append(stem)
+
+ self.stages = ModuleList()
+ block_idx = 0
+ for i in range(self.num_stages):
+ depth = self.depths[i]
+ channels = self.channels[i]
+
+ if i >= 1:
+ downsample_layer = nn.Sequential(
+ build_norm_layer(norm_cfg, self.channels[i - 1]),
+ nn.Conv2d(
+ self.channels[i - 1],
+ channels,
+ kernel_size=2,
+ stride=2,
+ ))
+ self.downsample_layers.append(downsample_layer)
+
+ stage_blocks = []
+ for j in range(depth):
+ if j > depth - global_blocks[i] - 1:
+ stage_blocks.append(
+ SDTAEncoder(
+ in_channel=channels,
+ drop_path_rate=dpr[block_idx + j],
+ mlp_ratio=mlp_ratio,
+ scales=d2_scales[i],
+ use_pos_emb=use_pos_embd_csa[i],
+ num_heads=self.num_heads[i],
+ ))
+ else:
+ dw_conv_cfg = dict(
+ kernel_size=conv_kernel_sizes[i],
+ padding=conv_kernel_sizes[i] // 2,
+ )
+ stage_blocks.append(
+ ConvNeXtBlock(
+ in_channels=channels,
+ dw_conv_cfg=dw_conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ linear_pw_conv=linear_pw_conv,
+ drop_path_rate=dpr[block_idx + j],
+ layer_scale_init_value=layer_scale_init_value,
+ ))
+ block_idx += depth
+
+ stage_blocks = Sequential(*stage_blocks)
+ self.stages.append(stage_blocks)
+
+ if i in self.out_indices:
+ out_norm_cfg = dict(type='LN') if self.gap_before_final_norm \
+ else norm_cfg
+ norm_layer = build_norm_layer(out_norm_cfg, channels)
+ self.add_module(f'norm{i}', norm_layer)
+
+ def init_weights(self) -> None:
+ # TODO: need to be implemented in the future
+ return super().init_weights()
+
+ def forward(self, x):
+ outs = []
+ for i, stage in enumerate(self.stages):
+ x = self.downsample_layers[i](x)
+ x = stage(x)
+ if self.pos_embed and i == 0:
+ B, _, H, W = x.shape
+ x += self.pos_embed((B, H, W))
+
+ if i in self.out_indices:
+ norm_layer = getattr(self, f'norm{i}')
+ if self.gap_before_final_norm:
+ gap = x.mean([-2, -1], keepdim=True)
+ outs.append(norm_layer(gap.flatten(1)))
+ else:
+ # The output of LayerNorm2d may be discontiguous, which
+ # may cause some problem in the downstream tasks
+ outs.append(norm_layer(x).contiguous())
+
+ return tuple(outs)
+
+ def _freeze_stages(self):
+ for i in range(self.frozen_stages):
+ downsample_layer = self.downsample_layers[i]
+ stage = self.stages[i]
+ downsample_layer.eval()
+ stage.eval()
+ for param in chain(downsample_layer.parameters(),
+ stage.parameters()):
+ param.requires_grad = False
+
+ def train(self, mode=True):
+ super(EdgeNeXt, self).train(mode)
+ self._freeze_stages()
diff --git a/mmpretrain/models/backbones/efficientformer.py b/mmpretrain/models/backbones/efficientformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2525c8faaa745ff5404e91004421f2360dd1c41
--- /dev/null
+++ b/mmpretrain/models/backbones/efficientformer.py
@@ -0,0 +1,606 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import itertools
+from typing import Optional, Sequence
+
+import torch
+import torch.nn as nn
+from mmcv.cnn.bricks import (ConvModule, DropPath, build_activation_layer,
+ build_norm_layer)
+from mmengine.model import BaseModule, ModuleList, Sequential
+
+from mmpretrain.registry import MODELS
+from ..utils import LayerScale
+from .base_backbone import BaseBackbone
+from .poolformer import Pooling
+
+
+class AttentionWithBias(BaseModule):
+ """Multi-head Attention Module with attention_bias.
+
+ Args:
+ embed_dims (int): The embedding dimension.
+ num_heads (int): Parallel attention heads. Defaults to 8.
+ key_dim (int): The dimension of q, k. Defaults to 32.
+        attn_ratio (float): The dimension of v equals
+ ``key_dim * attn_ratio``. Defaults to 4.
+ resolution (int): The height and width of attention_bias.
+ Defaults to 7.
+ init_cfg (dict, optional): The Config for initialization.
+ Defaults to None.
+ """
+
+ def __init__(self,
+ embed_dims,
+ num_heads=8,
+ key_dim=32,
+ attn_ratio=4.,
+ resolution=7,
+ init_cfg=None):
+ super().__init__(init_cfg=init_cfg)
+ self.num_heads = num_heads
+ self.scale = key_dim**-0.5
+ self.attn_ratio = attn_ratio
+ self.key_dim = key_dim
+ self.nh_kd = key_dim * num_heads
+ self.d = int(attn_ratio * key_dim)
+ self.dh = int(attn_ratio * key_dim) * num_heads
+ h = self.dh + self.nh_kd * 2
+ self.qkv = nn.Linear(embed_dims, h)
+ self.proj = nn.Linear(self.dh, embed_dims)
+
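+        # Build the relative-position bias table: every (query, key) pair on
+        # the ``resolution x resolution`` grid is indexed by its absolute
+        # offset (|dx|, |dy|); e.g. resolution=7 gives 49 tokens and 49
+        # distinct offsets shared by all 49 * 49 pairs.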
+ points = list(itertools.product(range(resolution), range(resolution)))
+ N = len(points)
+ attention_offsets = {}
+ idxs = []
+ for p1 in points:
+ for p2 in points:
+ offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
+ if offset not in attention_offsets:
+ attention_offsets[offset] = len(attention_offsets)
+ idxs.append(attention_offsets[offset])
+ self.attention_biases = nn.Parameter(
+ torch.zeros(num_heads, len(attention_offsets)))
+ self.register_buffer('attention_bias_idxs',
+ torch.LongTensor(idxs).view(N, N))
+
+ @torch.no_grad()
+ def train(self, mode=True):
+        """Switch the module between training and evaluation mode."""
+ super().train(mode)
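+        # Cache the gathered attention biases as ``self.ab`` for inference;
+        # when switching back to training mode the cache is dropped and the
+        # learnable bias table is indexed on every forward pass instead.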
+ if mode and hasattr(self, 'ab'):
+ del self.ab
+ else:
+ self.ab = self.attention_biases[:, self.attention_bias_idxs]
+
+ def forward(self, x):
+ """forward function.
+
+ Args:
+ x (tensor): input features with shape of (B, N, C)
+ """
+ B, N, _ = x.shape
+ qkv = self.qkv(x)
+ qkv = qkv.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
+ q, k, v = qkv.split([self.key_dim, self.key_dim, self.d], dim=-1)
+
+ attn = ((q @ k.transpose(-2, -1)) * self.scale +
+ (self.attention_biases[:, self.attention_bias_idxs]
+ if self.training else self.ab))
+ attn = attn.softmax(dim=-1)
+ x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
+ x = self.proj(x)
+ return x
+
+
+class Flat(nn.Module):
+    """Flatten the input from (B, C, H, W) to (B, H*W, C)."""
+
+ def __init__(self, ):
+ super().__init__()
+
+ def forward(self, x: torch.Tensor):
+ x = x.flatten(2).transpose(1, 2)
+ return x
+
+
+class LinearMlp(BaseModule):
+    """MLP implemented with linear layers.
+
+ The shape of input and output tensor are (B, N, C).
+
+ Args:
+ in_features (int): Dimension of input features.
+ hidden_features (int): Dimension of hidden features.
+ out_features (int): Dimension of output features.
+ act_cfg (dict): The config dict for activation between pointwise
+ convolution. Defaults to ``dict(type='GELU')``.
+ drop (float): Dropout rate. Defaults to 0.0.
+ init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+ Default: None.
+ """
+
+ def __init__(self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_cfg=dict(type='GELU'),
+ drop=0.,
+ init_cfg=None):
+ super().__init__(init_cfg=init_cfg)
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = build_activation_layer(act_cfg)
+ self.drop1 = nn.Dropout(drop)
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop2 = nn.Dropout(drop)
+
+ def forward(self, x):
+ """
+ Args:
+ x (torch.Tensor): input tensor with shape (B, N, C).
+
+ Returns:
+ torch.Tensor: output tensor with shape (B, N, C).
+ """
+ x = self.drop1(self.act(self.fc1(x)))
+ x = self.drop2(self.fc2(x))
+ return x
+
+
+class ConvMlp(BaseModule):
+    """MLP implemented with 1x1 convolutions.
+
+ Args:
+ in_features (int): Dimension of input features.
+ hidden_features (int): Dimension of hidden features.
+ out_features (int): Dimension of output features.
+ norm_cfg (dict): Config dict for normalization layer.
+ Defaults to ``dict(type='BN')``.
+ act_cfg (dict): The config dict for activation between pointwise
+ convolution. Defaults to ``dict(type='GELU')``.
+ drop (float): Dropout rate. Defaults to 0.0.
+ init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+ Default: None.
+ """
+
+ def __init__(self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='GELU'),
+ drop=0.,
+ init_cfg=None):
+ super().__init__(init_cfg=init_cfg)
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
+ self.act = build_activation_layer(act_cfg)
+ self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
+ self.norm1 = build_norm_layer(norm_cfg, hidden_features)[1]
+ self.norm2 = build_norm_layer(norm_cfg, out_features)[1]
+
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ """
+ Args:
+ x (torch.Tensor): input tensor with shape (B, C, H, W).
+
+ Returns:
+ torch.Tensor: output tensor with shape (B, C, H, W).
+ """
+
+ x = self.act(self.norm1(self.fc1(x)))
+ x = self.drop(x)
+ x = self.norm2(self.fc2(x))
+ x = self.drop(x)
+ return x
+
+
+class Meta3D(BaseModule):
+    """MetaFormer block operating on 3-D input, i.e. a ``torch.Tensor`` with
+    shape (B, N, C)."""
+
+ def __init__(self,
+ dim,
+ mlp_ratio=4.,
+ norm_cfg=dict(type='LN'),
+ act_cfg=dict(type='GELU'),
+ drop=0.,
+ drop_path=0.,
+ use_layer_scale=True,
+ init_cfg=None):
+ super().__init__(init_cfg=init_cfg)
+ self.norm1 = build_norm_layer(norm_cfg, dim)[1]
+ self.token_mixer = AttentionWithBias(dim)
+ self.norm2 = build_norm_layer(norm_cfg, dim)[1]
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = LinearMlp(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_cfg=act_cfg,
+ drop=drop)
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0. \
+ else nn.Identity()
+ if use_layer_scale:
+ self.ls1 = LayerScale(dim)
+ self.ls2 = LayerScale(dim)
+ else:
+ self.ls1, self.ls2 = nn.Identity(), nn.Identity()
+
+ def forward(self, x):
+ x = x + self.drop_path(self.ls1(self.token_mixer(self.norm1(x))))
+ x = x + self.drop_path(self.ls2(self.mlp(self.norm2(x))))
+ return x
+
+
+class Meta4D(BaseModule):
+    """MetaFormer block operating on 4-D input, i.e. a ``torch.Tensor`` with
+    shape (B, C, H, W)."""
+
+ def __init__(self,
+ dim,
+ pool_size=3,
+ mlp_ratio=4.,
+ act_cfg=dict(type='GELU'),
+ drop=0.,
+ drop_path=0.,
+ use_layer_scale=True,
+ init_cfg=None):
+ super().__init__(init_cfg=init_cfg)
+
+ self.token_mixer = Pooling(pool_size=pool_size)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = ConvMlp(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_cfg=act_cfg,
+ drop=drop)
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0. \
+ else nn.Identity()
+ if use_layer_scale:
+ self.ls1 = LayerScale(dim, data_format='channels_first')
+ self.ls2 = LayerScale(dim, data_format='channels_first')
+ else:
+ self.ls1, self.ls2 = nn.Identity(), nn.Identity()
+
+ def forward(self, x):
+ x = x + self.drop_path(self.ls1(self.token_mixer(x)))
+ x = x + self.drop_path(self.ls2(self.mlp(x)))
+ return x
+
+
+def basic_blocks(in_channels,
+ out_channels,
+ index,
+ layers,
+ pool_size=3,
+ mlp_ratio=4.,
+ act_cfg=dict(type='GELU'),
+ drop_rate=.0,
+ drop_path_rate=0.,
+ use_layer_scale=True,
+ vit_num=1,
+ has_downsamper=False):
+ """generate EfficientFormer blocks for a stage."""
+ blocks = []
+ if has_downsamper:
+ blocks.append(
+ ConvModule(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=True,
+ norm_cfg=dict(type='BN'),
+ act_cfg=None))
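+    # Note: in the last stage (index == 3) the trailing ``vit_num`` blocks are
+    # token-based Meta3D blocks, so a ``Flat`` module is inserted right before
+    # the first of them to turn (B, C, H, W) features into (B, H*W, C).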
+ if index == 3 and vit_num == layers[index]:
+ blocks.append(Flat())
+ for block_idx in range(layers[index]):
+ block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / (
+ sum(layers) - 1)
+ if index == 3 and layers[index] - block_idx <= vit_num:
+ blocks.append(
+ Meta3D(
+ out_channels,
+ mlp_ratio=mlp_ratio,
+ act_cfg=act_cfg,
+ drop=drop_rate,
+ drop_path=block_dpr,
+ use_layer_scale=use_layer_scale,
+ ))
+ else:
+ blocks.append(
+ Meta4D(
+ out_channels,
+ pool_size=pool_size,
+ act_cfg=act_cfg,
+ drop=drop_rate,
+ drop_path=block_dpr,
+ use_layer_scale=use_layer_scale))
+ if index == 3 and layers[index] - block_idx - 1 == vit_num:
+ blocks.append(Flat())
+ blocks = nn.Sequential(*blocks)
+ return blocks
+
+
+@MODELS.register_module()
+class EfficientFormer(BaseBackbone):
+ """EfficientFormer.
+
+ A PyTorch implementation of EfficientFormer introduced by:
+    `EfficientFormer: Vision Transformers at MobileNet Speed
+    <https://arxiv.org/abs/2206.01191>`_
+
+ Modified from the `official repo
+    <https://github.com/snap-research/EfficientFormer>`_.
+
+ Args:
+ arch (str | dict): The model's architecture. If string, it should be
+ one of architecture in ``EfficientFormer.arch_settings``. And if dict,
+ it should include the following 4 keys:
+
+ - layers (list[int]): Number of blocks at each stage.
+ - embed_dims (list[int]): The number of channels at each stage.
+ - downsamples (list[int]): Has downsample or not in the four stages.
+ - vit_num (int): The num of vit blocks in the last stage.
+
+ Defaults to 'l1'.
+
+ in_channels (int): The num of input channels. Defaults to 3.
+ pool_size (int): The pooling size of ``Meta4D`` blocks. Defaults to 3.
+        mlp_ratios (int): The expansion ratio of the MLP hidden dimension in
+            ``Meta4D`` and ``Meta3D`` blocks. Defaults to 4.
+ reshape_last_feat (bool): Whether to reshape the feature map from
+ (B, N, C) to (B, C, H, W) in the last stage, when the ``vit-num``
+ in ``arch`` is not 0. Defaults to False. Usually set to True
+ in downstream tasks.
+ out_indices (Sequence[int]): Output from which stages.
+ Defaults to -1.
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+ -1 means not freezing any parameters. Defaults to -1.
+ act_cfg (dict): The config dict for activation between pointwise
+ convolution. Defaults to ``dict(type='GELU')``.
+ drop_rate (float): Dropout rate. Defaults to 0.
+ drop_path_rate (float): Stochastic depth rate. Defaults to 0.
+        use_layer_scale (bool): Whether to use layer scale in the MetaFormer
+            blocks. Defaults to True.
+ init_cfg (dict, optional): Initialization config dict.
+ Defaults to None.
+
+ Example:
+ >>> from mmpretrain.models import EfficientFormer
+ >>> import torch
+ >>> inputs = torch.rand((1, 3, 224, 224))
+ >>> # build EfficientFormer backbone for classification task
+ >>> model = EfficientFormer(arch="l1")
+ >>> model.eval()
+ >>> level_outputs = model(inputs)
+ >>> for level_out in level_outputs:
+ ... print(tuple(level_out.shape))
+ (1, 448, 49)
+ >>> # build EfficientFormer backbone for downstream task
+ >>> model = EfficientFormer(
+ >>> arch="l3",
+ >>> out_indices=(0, 1, 2, 3),
+ >>> reshape_last_feat=True)
+ >>> model.eval()
+ >>> level_outputs = model(inputs)
+ >>> for level_out in level_outputs:
+ ... print(tuple(level_out.shape))
+ (1, 64, 56, 56)
+ (1, 128, 28, 28)
+ (1, 320, 14, 14)
+ (1, 512, 7, 7)
+ """ # noqa: E501
+
+ # --layers: [x,x,x,x], numbers of layers for the four stages
+ # --embed_dims: [x,x,x,x], embedding dims for the four stages
+ # --downsamples: [x,x,x,x], has downsample or not in the four stages
+ # --vit_num:(int), the num of vit blocks in the last stage
+ arch_settings = {
+ 'l1': {
+ 'layers': [3, 2, 6, 4],
+ 'embed_dims': [48, 96, 224, 448],
+ 'downsamples': [False, True, True, True],
+ 'vit_num': 1,
+ },
+ 'l3': {
+ 'layers': [4, 4, 12, 6],
+ 'embed_dims': [64, 128, 320, 512],
+ 'downsamples': [False, True, True, True],
+ 'vit_num': 4,
+ },
+ 'l7': {
+ 'layers': [6, 6, 18, 8],
+ 'embed_dims': [96, 192, 384, 768],
+ 'downsamples': [False, True, True, True],
+ 'vit_num': 8,
+ },
+ }
+
+ def __init__(self,
+ arch='l1',
+ in_channels=3,
+ pool_size=3,
+ mlp_ratios=4,
+ reshape_last_feat=False,
+ out_indices=-1,
+ frozen_stages=-1,
+ act_cfg=dict(type='GELU'),
+ drop_rate=0.,
+ drop_path_rate=0.,
+ use_layer_scale=True,
+ init_cfg=None):
+
+ super().__init__(init_cfg=init_cfg)
+ self.num_extra_tokens = 0 # no cls_token, no dist_token
+
+ if isinstance(arch, str):
+ assert arch in self.arch_settings, \
+ f'Unavailable arch, please choose from ' \
+ f'({set(self.arch_settings)}) or pass a dict.'
+ arch = self.arch_settings[arch]
+ elif isinstance(arch, dict):
+ default_keys = set(self.arch_settings['l1'].keys())
+ assert set(arch.keys()) == default_keys, \
+ f'The arch dict must have {default_keys}, ' \
+ f'but got {list(arch.keys())}.'
+
+ self.layers = arch['layers']
+ self.embed_dims = arch['embed_dims']
+ self.downsamples = arch['downsamples']
+ assert isinstance(self.layers, list) and isinstance(
+ self.embed_dims, list) and isinstance(self.downsamples, list)
+ assert len(self.layers) == len(self.embed_dims) == len(
+ self.downsamples)
+
+ self.vit_num = arch['vit_num']
+ self.reshape_last_feat = reshape_last_feat
+
+ assert self.vit_num >= 0, "'vit_num' must be an integer " \
+ 'greater than or equal to 0.'
+ assert self.vit_num <= self.layers[-1], (
+            "'vit_num' must not exceed the last stage's layer number")
+
+ self._make_stem(in_channels, self.embed_dims[0])
+
+ # set the main block in network
+ network = []
+ for i in range(len(self.layers)):
+ if i != 0:
+ in_channels = self.embed_dims[i - 1]
+ else:
+ in_channels = self.embed_dims[i]
+ out_channels = self.embed_dims[i]
+ stage = basic_blocks(
+ in_channels,
+ out_channels,
+ i,
+ self.layers,
+ pool_size=pool_size,
+ mlp_ratio=mlp_ratios,
+ act_cfg=act_cfg,
+ drop_rate=drop_rate,
+ drop_path_rate=drop_path_rate,
+ vit_num=self.vit_num,
+ use_layer_scale=use_layer_scale,
+ has_downsamper=self.downsamples[i])
+ network.append(stage)
+
+ self.network = ModuleList(network)
+
+ if isinstance(out_indices, int):
+ out_indices = [out_indices]
+ assert isinstance(out_indices, Sequence), \
+            f'"out_indices" must be a sequence or int, ' \
+            f'got {type(out_indices)} instead.'
+ for i, index in enumerate(out_indices):
+ if index < 0:
+ out_indices[i] = 4 + index
+ assert out_indices[i] >= 0, f'Invalid out_indices {index}'
+
+ self.out_indices = out_indices
+ for i_layer in self.out_indices:
+ if not self.reshape_last_feat and \
+ i_layer == 3 and self.vit_num > 0:
+ layer = build_norm_layer(
+ dict(type='LN'), self.embed_dims[i_layer])[1]
+ else:
+ # use GN with 1 group as channel-first LN2D
+ layer = build_norm_layer(
+ dict(type='GN', num_groups=1), self.embed_dims[i_layer])[1]
+
+ layer_name = f'norm{i_layer}'
+ self.add_module(layer_name, layer)
+
+ self.frozen_stages = frozen_stages
+ self._freeze_stages()
+
+ def _make_stem(self, in_channels: int, stem_channels: int):
+        """Make the two-layer Conv-BN-ReLU stem."""
+ self.patch_embed = Sequential(
+ ConvModule(
+ in_channels,
+ stem_channels // 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=True,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ inplace=True),
+ ConvModule(
+ stem_channels // 2,
+ stem_channels,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=True,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ inplace=True))
+
+ def forward_tokens(self, x):
+ outs = []
+ for idx, block in enumerate(self.network):
+ if idx == len(self.network) - 1:
+ N, _, H, W = x.shape
+ if self.downsamples[idx]:
+ H, W = H // 2, W // 2
+ x = block(x)
+ if idx in self.out_indices:
+ norm_layer = getattr(self, f'norm{idx}')
+
+ if idx == len(self.network) - 1 and x.dim() == 3:
+                    # when ``vit_num`` > 0 and in the last stage,
+                    # if ``self.reshape_last_feat`` is True, reshape the
+                    # features to `BCHW` format before the final normalization.
+                    # if ``self.reshape_last_feat`` is False, do
+                    # normalization directly and permute the features to `BCN`.
+ if self.reshape_last_feat:
+ x = x.permute((0, 2, 1)).reshape(N, -1, H, W)
+ x_out = norm_layer(x)
+ else:
+ x_out = norm_layer(x).permute((0, 2, 1))
+ else:
+ x_out = norm_layer(x)
+
+ outs.append(x_out.contiguous())
+ return tuple(outs)
+
+ def forward(self, x):
+ # input embedding
+ x = self.patch_embed(x)
+ # through stages
+ x = self.forward_tokens(x)
+ return x
+
+ def _freeze_stages(self):
+ if self.frozen_stages >= 0:
+ self.patch_embed.eval()
+ for param in self.patch_embed.parameters():
+ param.requires_grad = False
+
+ for i in range(self.frozen_stages):
+ # Include both block and downsample layer.
+ module = self.network[i]
+ module.eval()
+ for param in module.parameters():
+ param.requires_grad = False
+ if i in self.out_indices:
+ norm_layer = getattr(self, f'norm{i}')
+ norm_layer.eval()
+ for param in norm_layer.parameters():
+ param.requires_grad = False
+
+ def train(self, mode=True):
+ super(EfficientFormer, self).train(mode)
+ self._freeze_stages()
diff --git a/mmpretrain/models/backbones/efficientnet.py b/mmpretrain/models/backbones/efficientnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ec7ee81186610f7adb8af92325471d794509ddc
--- /dev/null
+++ b/mmpretrain/models/backbones/efficientnet.py
@@ -0,0 +1,410 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import math
+from functools import partial
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+from mmcv.cnn.bricks import ConvModule, DropPath
+from mmengine.model import BaseModule, Sequential
+
+from mmpretrain.models.backbones.base_backbone import BaseBackbone
+from mmpretrain.models.utils import InvertedResidual, SELayer, make_divisible
+from mmpretrain.registry import MODELS
+
+
+class EdgeResidual(BaseModule):
+ """Edge Residual Block.
+
+ Args:
+ in_channels (int): The input channels of this module.
+ out_channels (int): The output channels of this module.
+ mid_channels (int): The input channels of the second convolution.
+ kernel_size (int): The kernel size of the first convolution.
+ Defaults to 3.
+ stride (int): The stride of the first convolution. Defaults to 1.
+ se_cfg (dict, optional): Config dict for se layer. Defaults to None,
+ which means no se layer.
+ with_residual (bool): Use residual connection. Defaults to True.
+ conv_cfg (dict, optional): Config dict for convolution layer.
+ Defaults to None, which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer.
+ Defaults to ``dict(type='BN')``.
+ act_cfg (dict): Config dict for activation layer.
+ Defaults to ``dict(type='ReLU')``.
+ drop_path_rate (float): stochastic depth rate. Defaults to 0.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Defaults to False.
+ init_cfg (dict | list[dict], optional): Initialization config dict.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ mid_channels,
+ kernel_size=3,
+ stride=1,
+ se_cfg=None,
+ with_residual=True,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ drop_path_rate=0.,
+ with_cp=False,
+ init_cfg=None):
+ super(EdgeResidual, self).__init__(init_cfg=init_cfg)
+ assert stride in [1, 2]
+ self.with_cp = with_cp
+ self.drop_path = DropPath(
+ drop_path_rate) if drop_path_rate > 0 else nn.Identity()
+ self.with_se = se_cfg is not None
+ self.with_residual = (
+ stride == 1 and in_channels == out_channels and with_residual)
+
+ if self.with_se:
+ assert isinstance(se_cfg, dict)
+
+ self.conv1 = ConvModule(
+ in_channels=in_channels,
+ out_channels=mid_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=kernel_size // 2,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+
+ if self.with_se:
+ self.se = SELayer(**se_cfg)
+
+ self.conv2 = ConvModule(
+ in_channels=mid_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ conv_cfg=None,
+ norm_cfg=norm_cfg,
+ act_cfg=None)
+
+ def forward(self, x):
+
+ def _inner_forward(x):
+ out = x
+ out = self.conv1(out)
+
+ if self.with_se:
+ out = self.se(out)
+
+ out = self.conv2(out)
+
+ if self.with_residual:
+ return x + self.drop_path(out)
+ else:
+ return out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ return out
+
+
+def model_scaling(layer_setting, arch_setting):
+ """Scaling operation to the layer's parameters according to the
+ arch_setting."""
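+    # Worked example (illustrative): with ``arch_setting = (1.1, 1.2, 260)``,
+    # i.e. the 'b2' factors, a 40-channel block is widened to
+    # ``make_divisible(40 * 1.1, 8) = 48`` channels and a 2-block stage is
+    # deepened to ``ceil(2 * 1.2) = 3`` blocks.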
+ # scale width
+ new_layer_setting = copy.deepcopy(layer_setting)
+ for layer_cfg in new_layer_setting:
+ for block_cfg in layer_cfg:
+ block_cfg[1] = make_divisible(block_cfg[1] * arch_setting[0], 8)
+
+ # scale depth
+ split_layer_setting = [new_layer_setting[0]]
+ for layer_cfg in new_layer_setting[1:-1]:
+ tmp_index = [0]
+ for i in range(len(layer_cfg) - 1):
+ if layer_cfg[i + 1][1] != layer_cfg[i][1]:
+ tmp_index.append(i + 1)
+ tmp_index.append(len(layer_cfg))
+ for i in range(len(tmp_index) - 1):
+ split_layer_setting.append(layer_cfg[tmp_index[i]:tmp_index[i +
+ 1]])
+ split_layer_setting.append(new_layer_setting[-1])
+
+ num_of_layers = [len(layer_cfg) for layer_cfg in split_layer_setting[1:-1]]
+ new_layers = [
+ int(math.ceil(arch_setting[1] * num)) for num in num_of_layers
+ ]
+
+ merge_layer_setting = [split_layer_setting[0]]
+ for i, layer_cfg in enumerate(split_layer_setting[1:-1]):
+ if new_layers[i] <= num_of_layers[i]:
+ tmp_layer_cfg = layer_cfg[:new_layers[i]]
+ else:
+ tmp_layer_cfg = copy.deepcopy(layer_cfg) + [layer_cfg[-1]] * (
+ new_layers[i] - num_of_layers[i])
+ if tmp_layer_cfg[0][3] == 1 and i != 0:
+ merge_layer_setting[-1] += tmp_layer_cfg.copy()
+ else:
+ merge_layer_setting.append(tmp_layer_cfg.copy())
+ merge_layer_setting.append(split_layer_setting[-1])
+
+ return merge_layer_setting
+
+
+@MODELS.register_module()
+class EfficientNet(BaseBackbone):
+ """EfficientNet backbone.
+
+ Args:
+ arch (str): Architecture of efficientnet. Defaults to b0.
+ out_indices (Sequence[int]): Output from which stages.
+ Defaults to (6, ).
+ frozen_stages (int): Stages to be frozen (all param fixed).
+ Defaults to 0, which means not freezing any parameters.
+ conv_cfg (dict): Config dict for convolution layer.
+ Defaults to None, which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer.
+ Defaults to dict(type='BN').
+ act_cfg (dict): Config dict for activation layer.
+ Defaults to dict(type='Swish').
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Defaults to False.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Defaults to False.
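+
+    Example:
+        >>> # A minimal usage sketch added for illustration; the printed shape
+        >>> # assumes ``arch='b0'``, the default ``out_indices=(6, )`` and a
+        >>> # 224x224 input.
+        >>> from mmpretrain.models import EfficientNet
+        >>> import torch
+        >>> model = EfficientNet(arch='b0')
+        >>> model.eval()
+        >>> outs = model(torch.rand(1, 3, 224, 224))
+        >>> print(tuple(outs[-1].shape))
+        (1, 1280, 7, 7)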
+ """
+
+ # Parameters to build layers.
+ # 'b' represents the architecture of normal EfficientNet family includes
+ # 'b0', 'b1', 'b2', 'b3', 'b4', 'b5', 'b6', 'b7', 'b8'.
+ # 'e' represents the architecture of EfficientNet-EdgeTPU including 'es',
+ # 'em', 'el'.
+ # 6 parameters are needed to construct a layer, From left to right:
+ # - kernel_size: The kernel size of the block
+ # - out_channel: The number of out_channels of the block
+    # - se_ratio: The squeeze ratio of SELayer.
+ # - stride: The stride of the block
+ # - expand_ratio: The expand_ratio of the mid_channels
+ # - block_type: -1: Not a block, 0: InvertedResidual, 1: EdgeResidual
+ layer_settings = {
+ 'b': [[[3, 32, 0, 2, 0, -1]],
+ [[3, 16, 4, 1, 1, 0]],
+ [[3, 24, 4, 2, 6, 0],
+ [3, 24, 4, 1, 6, 0]],
+ [[5, 40, 4, 2, 6, 0],
+ [5, 40, 4, 1, 6, 0]],
+ [[3, 80, 4, 2, 6, 0],
+ [3, 80, 4, 1, 6, 0],
+ [3, 80, 4, 1, 6, 0],
+ [5, 112, 4, 1, 6, 0],
+ [5, 112, 4, 1, 6, 0],
+ [5, 112, 4, 1, 6, 0]],
+ [[5, 192, 4, 2, 6, 0],
+ [5, 192, 4, 1, 6, 0],
+ [5, 192, 4, 1, 6, 0],
+ [5, 192, 4, 1, 6, 0],
+ [3, 320, 4, 1, 6, 0]],
+ [[1, 1280, 0, 1, 0, -1]]
+ ],
+ 'e': [[[3, 32, 0, 2, 0, -1]],
+ [[3, 24, 0, 1, 3, 1]],
+ [[3, 32, 0, 2, 8, 1],
+ [3, 32, 0, 1, 8, 1]],
+ [[3, 48, 0, 2, 8, 1],
+ [3, 48, 0, 1, 8, 1],
+ [3, 48, 0, 1, 8, 1],
+ [3, 48, 0, 1, 8, 1]],
+ [[5, 96, 0, 2, 8, 0],
+ [5, 96, 0, 1, 8, 0],
+ [5, 96, 0, 1, 8, 0],
+ [5, 96, 0, 1, 8, 0],
+ [5, 96, 0, 1, 8, 0],
+ [5, 144, 0, 1, 8, 0],
+ [5, 144, 0, 1, 8, 0],
+ [5, 144, 0, 1, 8, 0],
+ [5, 144, 0, 1, 8, 0]],
+ [[5, 192, 0, 2, 8, 0],
+ [5, 192, 0, 1, 8, 0]],
+ [[1, 1280, 0, 1, 0, -1]]
+ ]
+ } # yapf: disable
+
+ # Parameters to build different kinds of architecture.
+ # From left to right: scaling factor for width, scaling factor for depth,
+ # resolution.
+ arch_settings = {
+ 'b0': (1.0, 1.0, 224),
+ 'b1': (1.0, 1.1, 240),
+ 'b2': (1.1, 1.2, 260),
+ 'b3': (1.2, 1.4, 300),
+ 'b4': (1.4, 1.8, 380),
+ 'b5': (1.6, 2.2, 456),
+ 'b6': (1.8, 2.6, 528),
+ 'b7': (2.0, 3.1, 600),
+ 'b8': (2.2, 3.6, 672),
+ 'l2': (4.3, 5.3, 800),
+ 'es': (1.0, 1.0, 224),
+ 'em': (1.0, 1.1, 240),
+ 'el': (1.2, 1.4, 300)
+ }
+
+ def __init__(self,
+ arch='b0',
+ drop_path_rate=0.,
+ out_indices=(6, ),
+ frozen_stages=0,
+ conv_cfg=dict(type='Conv2dAdaptivePadding'),
+ norm_cfg=dict(type='BN', eps=1e-3),
+ act_cfg=dict(type='Swish'),
+ norm_eval=False,
+ with_cp=False,
+ init_cfg=[
+ dict(type='Kaiming', layer='Conv2d'),
+ dict(
+ type='Constant',
+ layer=['_BatchNorm', 'GroupNorm'],
+ val=1)
+ ]):
+ super(EfficientNet, self).__init__(init_cfg)
+ assert arch in self.arch_settings, \
+ f'"{arch}" is not one of the arch_settings ' \
+ f'({", ".join(self.arch_settings.keys())})'
+ self.arch_setting = self.arch_settings[arch]
+ # layer_settings of arch='l2' is 'b'
+ self.layer_setting = self.layer_settings['b' if arch ==
+ 'l2' else arch[:1]]
+ for index in out_indices:
+ if index not in range(0, len(self.layer_setting)):
+                raise ValueError('the item in out_indices must be in '
+ f'range(0, {len(self.layer_setting)}). '
+ f'But received {index}')
+
+ if frozen_stages not in range(len(self.layer_setting) + 1):
+ raise ValueError('frozen_stages must be in range(0, '
+ f'{len(self.layer_setting) + 1}). '
+ f'But received {frozen_stages}')
+ self.drop_path_rate = drop_path_rate
+ self.out_indices = out_indices
+ self.frozen_stages = frozen_stages
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ self.norm_eval = norm_eval
+ self.with_cp = with_cp
+
+ self.layer_setting = model_scaling(self.layer_setting,
+ self.arch_setting)
+ block_cfg_0 = self.layer_setting[0][0]
+ block_cfg_last = self.layer_setting[-1][0]
+ self.in_channels = make_divisible(block_cfg_0[1], 8)
+ self.out_channels = block_cfg_last[1]
+ self.layers = nn.ModuleList()
+ self.layers.append(
+ ConvModule(
+ in_channels=3,
+ out_channels=self.in_channels,
+ kernel_size=block_cfg_0[0],
+ stride=block_cfg_0[3],
+ padding=block_cfg_0[0] // 2,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg))
+ self.make_layer()
+ self.layers.append(
+ ConvModule(
+ in_channels=self.in_channels,
+ out_channels=self.out_channels,
+ kernel_size=block_cfg_last[0],
+ stride=block_cfg_last[3],
+ padding=block_cfg_last[0] // 2,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg))
+
+ def make_layer(self):
+ # Without the first and the final conv block.
+ layer_setting = self.layer_setting[1:-1]
+
+ total_num_blocks = sum([len(x) for x in layer_setting])
+ block_idx = 0
+ dpr = [
+ x.item()
+ for x in torch.linspace(0, self.drop_path_rate, total_num_blocks)
+ ] # stochastic depth decay rule
+
+ for layer_cfg in layer_setting:
+ layer = []
+ for i, block_cfg in enumerate(layer_cfg):
+ (kernel_size, out_channels, se_ratio, stride, expand_ratio,
+ block_type) = block_cfg
+
+ mid_channels = int(self.in_channels * expand_ratio)
+ out_channels = make_divisible(out_channels, 8)
+ if se_ratio <= 0:
+ se_cfg = None
+ else:
+ se_cfg = dict(
+ channels=mid_channels,
+ ratio=expand_ratio * se_ratio,
+ divisor=1,
+ act_cfg=(self.act_cfg, dict(type='Sigmoid')))
+ if block_type == 1: # edge tpu
+ if i > 0 and expand_ratio == 3:
+ with_residual = False
+ expand_ratio = 4
+ else:
+ with_residual = True
+ mid_channels = int(self.in_channels * expand_ratio)
+ if se_cfg is not None:
+ se_cfg = dict(
+ channels=mid_channels,
+ ratio=se_ratio * expand_ratio,
+ divisor=1,
+ act_cfg=(self.act_cfg, dict(type='Sigmoid')))
+ block = partial(EdgeResidual, with_residual=with_residual)
+ else:
+ block = InvertedResidual
+ layer.append(
+ block(
+ in_channels=self.in_channels,
+ out_channels=out_channels,
+ mid_channels=mid_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ se_cfg=se_cfg,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg,
+ drop_path_rate=dpr[block_idx],
+ with_cp=self.with_cp))
+ self.in_channels = out_channels
+ block_idx += 1
+ self.layers.append(Sequential(*layer))
+
+ def forward(self, x):
+ outs = []
+ for i, layer in enumerate(self.layers):
+ x = layer(x)
+ if i in self.out_indices:
+ outs.append(x)
+
+ return tuple(outs)
+
+ def _freeze_stages(self):
+ for i in range(self.frozen_stages):
+ m = self.layers[i]
+ m.eval()
+ for param in m.parameters():
+ param.requires_grad = False
+
+ def train(self, mode=True):
+ super(EfficientNet, self).train(mode)
+ self._freeze_stages()
+ if mode and self.norm_eval:
+ for m in self.modules():
+ if isinstance(m, nn.BatchNorm2d):
+ m.eval()
diff --git a/mmpretrain/models/backbones/efficientnet_v2.py b/mmpretrain/models/backbones/efficientnet_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..fec002a4dac46f756f00ed8f596b37028ba18c37
--- /dev/null
+++ b/mmpretrain/models/backbones/efficientnet_v2.py
@@ -0,0 +1,343 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Sequence, Tuple
+
+import torch
+import torch.nn as nn
+from mmcv.cnn.bricks import ConvModule, DropPath
+from mmengine.model import Sequential
+from torch import Tensor
+
+from mmpretrain.registry import MODELS
+from ..utils import InvertedResidual as MBConv
+from .base_backbone import BaseBackbone
+from .efficientnet import EdgeResidual as FusedMBConv
+
+
+class EnhancedConvModule(ConvModule):
+    """ConvModule with a shortcut connection and DropPath.
+
+ Args:
+ in_channels (int): Number of channels in the input feature map.
+ Same as that in ``nn._ConvNd``.
+ out_channels (int): Number of channels produced by the convolution.
+ Same as that in ``nn._ConvNd``.
+ kernel_size (int | tuple[int]): Size of the convolving kernel.
+ Same as that in ``nn._ConvNd``.
+ stride (int | tuple[int]): Stride of the convolution.
+ Same as that in ``nn._ConvNd``.
+        has_skip (bool): Whether to add a shortcut connection.
+            Defaults to False.
+        drop_path_rate (float): Stochastic depth rate. Defaults to 0.0.
+ padding (int | tuple[int]): Zero-padding added to both sides of
+ the input. Same as that in ``nn._ConvNd``.
+ dilation (int | tuple[int]): Spacing between kernel elements.
+ Same as that in ``nn._ConvNd``.
+ groups (int): Number of blocked connections from input channels to
+ output channels. Same as that in ``nn._ConvNd``.
+ bias (bool | str): If specified as `auto`, it will be decided by the
+ norm_cfg. Bias will be set as True if `norm_cfg` is None, otherwise
+ False. Default: "auto".
+ conv_cfg (dict): Config dict for convolution layer. Default: None,
+ which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer. Default: None.
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='ReLU').
+ inplace (bool): Whether to use inplace mode for activation.
+ Default: True.
+ with_spectral_norm (bool): Whether use spectral norm in conv module.
+ Default: False.
+ padding_mode (str): If the `padding_mode` has not been supported by
+ current `Conv2d` in PyTorch, we will use our own padding layer
+ instead. Currently, we support ['zeros', 'circular'] with official
+ implementation and ['reflect'] with our own implementation.
+ Default: 'zeros'.
+ order (tuple[str]): The order of conv/norm/activation layers. It is a
+ sequence of "conv", "norm" and "act". Common examples are
+ ("conv", "norm", "act") and ("act", "conv", "norm").
+ Default: ('conv', 'norm', 'act').
+ """
+
+ def __init__(self, *args, has_skip=False, drop_path_rate=0, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.has_skip = has_skip
+ if self.has_skip and (self.in_channels != self.out_channels
+ or self.stride != (1, 1)):
+ raise ValueError('the stride must be 1 and the `in_channels` and'
+                             ' `out_channels` must be the same, when '
+                             '`has_skip` is True in `EnhancedConvModule`.')
+ self.drop_path = DropPath(
+ drop_path_rate) if drop_path_rate else nn.Identity()
+
+ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
+ short_cut = x
+ x = super().forward(x, **kwargs)
+ if self.has_skip:
+ x = self.drop_path(x) + short_cut
+ return x
+
+
+@MODELS.register_module()
+class EfficientNetV2(BaseBackbone):
+ """EfficientNetV2 backbone.
+
+ A PyTorch implementation of EfficientNetV2 introduced by:
+ `EfficientNetV2: Smaller Models and Faster Training
+    <https://arxiv.org/abs/2104.00298>`_
+
+ Args:
+ arch (str): Architecture of efficientnetv2. Defaults to s.
+ in_channels (int): Number of input image channels. Defaults to 3.
+ drop_path_rate (float): The ratio of the stochastic depth.
+ Defaults to 0.0.
+ out_indices (Sequence[int]): Output from which stages.
+ Defaults to (-1, ).
+ frozen_stages (int): Stages to be frozen (all param fixed).
+ Defaults to 0, which means not freezing any parameters.
+ conv_cfg (dict): Config dict for convolution layer.
+ Defaults to None, which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer.
+ Defaults to dict(type='BN').
+ act_cfg (dict): Config dict for activation layer.
+ Defaults to dict(type='Swish').
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Defaults to False.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Defaults to False.
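+
+    Example:
+        >>> # A minimal usage sketch added for illustration; the printed shape
+        >>> # assumes ``arch='s'``, the default ``out_indices=(-1, )`` and a
+        >>> # 224x224 input.
+        >>> from mmpretrain.models import EfficientNetV2
+        >>> import torch
+        >>> model = EfficientNetV2(arch='s')
+        >>> model.eval()
+        >>> outs = model(torch.rand(1, 3, 224, 224))
+        >>> print(tuple(outs[-1].shape))
+        (1, 1280, 7, 7)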
+ """
+
+ # Parameters to build layers. From left to right:
+ # - repeat (int): The repeat number of the block in the layer
+ # - kernel_size (int): The kernel size of the layer
+ # - stride (int): The stride of the first block of the layer
+ # - expand_ratio (int, float): The expand_ratio of the mid_channels
+ # - in_channel (int): The number of in_channels of the layer
+ # - out_channel (int): The number of out_channels of the layer
+    # - se_ratio (float): The squeeze ratio of SELayer.
+ # - block_type (int): -2: ConvModule, -1: EnhancedConvModule,
+ # 0: FusedMBConv, 1: MBConv
+ arch_settings = {
+ **dict.fromkeys(['small', 's'], [[2, 3, 1, 1, 24, 24, 0.0, -1],
+ [4, 3, 2, 4, 24, 48, 0.0, 0],
+ [4, 3, 2, 4, 48, 64, 0.0, 0],
+ [6, 3, 2, 4, 64, 128, 0.25, 1],
+ [9, 3, 1, 6, 128, 160, 0.25, 1],
+ [15, 3, 2, 6, 160, 256, 0.25, 1],
+ [1, 1, 1, 1, 256, 1280, 0.0, -2]]),
+ **dict.fromkeys(['m', 'medium'], [[3, 3, 1, 1, 24, 24, 0.0, -1],
+ [5, 3, 2, 4, 24, 48, 0.0, 0],
+ [5, 3, 2, 4, 48, 80, 0.0, 0],
+ [7, 3, 2, 4, 80, 160, 0.25, 1],
+ [14, 3, 1, 6, 160, 176, 0.25, 1],
+ [18, 3, 2, 6, 176, 304, 0.25, 1],
+ [5, 3, 1, 6, 304, 512, 0.25, 1],
+ [1, 1, 1, 1, 512, 1280, 0.0, -2]]),
+ **dict.fromkeys(['l', 'large'], [[4, 3, 1, 1, 32, 32, 0.0, -1],
+ [7, 3, 2, 4, 32, 64, 0.0, 0],
+ [7, 3, 2, 4, 64, 96, 0.0, 0],
+ [10, 3, 2, 4, 96, 192, 0.25, 1],
+ [19, 3, 1, 6, 192, 224, 0.25, 1],
+ [25, 3, 2, 6, 224, 384, 0.25, 1],
+ [7, 3, 1, 6, 384, 640, 0.25, 1],
+ [1, 1, 1, 1, 640, 1280, 0.0, -2]]),
+ **dict.fromkeys(['xl'], [[4, 3, 1, 1, 32, 32, 0.0, -1],
+ [8, 3, 2, 4, 32, 64, 0.0, 0],
+ [8, 3, 2, 4, 64, 96, 0.0, 0],
+ [16, 3, 2, 4, 96, 192, 0.25, 1],
+ [24, 3, 1, 6, 192, 256, 0.25, 1],
+ [32, 3, 2, 6, 256, 512, 0.25, 1],
+ [8, 3, 1, 6, 512, 640, 0.25, 1],
+ [1, 1, 1, 1, 640, 1280, 0.0, -2]]),
+ **dict.fromkeys(['b0'], [[1, 3, 1, 1, 32, 16, 0.0, -1],
+ [2, 3, 2, 4, 16, 32, 0.0, 0],
+ [2, 3, 2, 4, 32, 48, 0.0, 0],
+ [3, 3, 2, 4, 48, 96, 0.25, 1],
+ [5, 3, 1, 6, 96, 112, 0.25, 1],
+ [8, 3, 2, 6, 112, 192, 0.25, 1],
+ [1, 1, 1, 1, 192, 1280, 0.0, -2]]),
+ **dict.fromkeys(['b1'], [[2, 3, 1, 1, 32, 16, 0.0, -1],
+ [3, 3, 2, 4, 16, 32, 0.0, 0],
+ [3, 3, 2, 4, 32, 48, 0.0, 0],
+ [4, 3, 2, 4, 48, 96, 0.25, 1],
+ [6, 3, 1, 6, 96, 112, 0.25, 1],
+ [9, 3, 2, 6, 112, 192, 0.25, 1],
+ [1, 1, 1, 1, 192, 1280, 0.0, -2]]),
+ **dict.fromkeys(['b2'], [[2, 3, 1, 1, 32, 16, 0.0, -1],
+ [3, 3, 2, 4, 16, 32, 0.0, 0],
+ [3, 3, 2, 4, 32, 56, 0.0, 0],
+ [4, 3, 2, 4, 56, 104, 0.25, 1],
+ [6, 3, 1, 6, 104, 120, 0.25, 1],
+ [10, 3, 2, 6, 120, 208, 0.25, 1],
+ [1, 1, 1, 1, 208, 1408, 0.0, -2]]),
+ **dict.fromkeys(['b3'], [[2, 3, 1, 1, 40, 16, 0.0, -1],
+ [3, 3, 2, 4, 16, 40, 0.0, 0],
+ [3, 3, 2, 4, 40, 56, 0.0, 0],
+ [5, 3, 2, 4, 56, 112, 0.25, 1],
+ [7, 3, 1, 6, 112, 136, 0.25, 1],
+ [12, 3, 2, 6, 136, 232, 0.25, 1],
+ [1, 1, 1, 1, 232, 1536, 0.0, -2]])
+ }
+
+ def __init__(self,
+ arch: str = 's',
+ in_channels: int = 3,
+ drop_path_rate: float = 0.,
+ out_indices: Sequence[int] = (-1, ),
+ frozen_stages: int = 0,
+ conv_cfg=dict(type='Conv2dAdaptivePadding'),
+ norm_cfg=dict(type='BN', eps=1e-3, momentum=0.1),
+ act_cfg=dict(type='Swish'),
+ norm_eval: bool = False,
+ with_cp: bool = False,
+ init_cfg=[
+ dict(type='Kaiming', layer='Conv2d'),
+ dict(
+ type='Constant',
+ layer=['_BatchNorm', 'GroupNorm'],
+ val=1)
+ ]):
+ super(EfficientNetV2, self).__init__(init_cfg)
+ assert arch in self.arch_settings, \
+ f'"{arch}" is not one of the arch_settings ' \
+ f'({", ".join(self.arch_settings.keys())})'
+ self.arch = self.arch_settings[arch]
+ if frozen_stages not in range(len(self.arch) + 1):
+ raise ValueError('frozen_stages must be in range(0, '
+                             f'{len(self.arch) + 1}), but got {frozen_stages}')
+ self.drop_path_rate = drop_path_rate
+ self.frozen_stages = frozen_stages
+ self.norm_eval = norm_eval
+ self.with_cp = with_cp
+
+ self.layers = nn.ModuleList()
+ assert self.arch[-1][-1] == -2, \
+            f'the last block_type of `arch_setting` must be -2, ' \
+            f'but got `{self.arch[-1][-1]}`'
+ self.in_channels = in_channels
+ self.out_channels = self.arch[-1][5]
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+
+ self.make_layers()
+
+        # There are len(self.arch) + 1 layers in the backbone, namely the
+        # stem conv, len(self.arch) - 1 stages and the final conv.
+ if isinstance(out_indices, int):
+ out_indices = [out_indices]
+ assert isinstance(out_indices, Sequence), \
+            f'"out_indices" must be a sequence or int, ' \
+            f'got {type(out_indices)} instead.'
+ out_indices = list(out_indices)
+ for i, index in enumerate(out_indices):
+ if index < 0:
+ out_indices[i] = len(self.layers) + index
+ assert 0 <= out_indices[i] <= len(self.layers), \
+ f'Invalid out_indices {index}.'
+ self.out_indices = out_indices
+
+ def make_layers(self, ):
+ # make the first layer
+ self.layers.append(
+ ConvModule(
+ in_channels=self.in_channels,
+ out_channels=self.arch[0][4],
+ kernel_size=3,
+ stride=2,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg))
+
+ in_channels = self.arch[0][4]
+ layer_setting = self.arch[:-1]
+
+ total_num_blocks = sum([x[0] for x in layer_setting])
+ block_idx = 0
+ dpr = [
+ x.item()
+ for x in torch.linspace(0, self.drop_path_rate, total_num_blocks)
+ ] # stochastic depth decay rule
+
+ for layer_cfg in layer_setting:
+ layer = []
+ (repeat, kernel_size, stride, expand_ratio, _, out_channels,
+ se_ratio, block_type) = layer_cfg
+ for i in range(repeat):
+ stride = stride if i == 0 else 1
+ if block_type == -1:
+ has_skip = stride == 1 and in_channels == out_channels
+ droppath_rate = dpr[block_idx] if has_skip else 0.0
+ layer.append(
+ EnhancedConvModule(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ has_skip=has_skip,
+ drop_path_rate=droppath_rate,
+ stride=stride,
+ padding=1,
+ conv_cfg=None,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg))
+ in_channels = out_channels
+ else:
+ mid_channels = int(in_channels * expand_ratio)
+ se_cfg = None
+ if block_type != 0 and se_ratio > 0:
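+                        # ``ratio`` here is relative to ``mid_channels``, so
+                        # the SE squeeze width works out to
+                        # ``in_channels * se_ratio`` (e.g. 1/4 of the input).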
+ se_cfg = dict(
+ channels=mid_channels,
+ ratio=expand_ratio * (1.0 / se_ratio),
+ divisor=1,
+ act_cfg=(self.act_cfg, dict(type='Sigmoid')))
+ block = FusedMBConv if block_type == 0 else MBConv
+ conv_cfg = self.conv_cfg if stride == 2 else None
+ layer.append(
+ block(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ mid_channels=mid_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ se_cfg=se_cfg,
+ conv_cfg=conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg,
+ drop_path_rate=dpr[block_idx],
+ with_cp=self.with_cp))
+ in_channels = out_channels
+ block_idx += 1
+ self.layers.append(Sequential(*layer))
+
+ # make the last layer
+ self.layers.append(
+ ConvModule(
+ in_channels=in_channels,
+ out_channels=self.out_channels,
+ kernel_size=self.arch[-1][1],
+ stride=self.arch[-1][2],
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg))
+
+ def forward(self, x: Tensor) -> Tuple[Tensor]:
+ outs = []
+ for i, layer in enumerate(self.layers):
+ x = layer(x)
+ if i in self.out_indices:
+ outs.append(x)
+
+ return tuple(outs)
+
+ def _freeze_stages(self):
+ for i in range(self.frozen_stages):
+ m = self.layers[i]
+ m.eval()
+ for param in m.parameters():
+ param.requires_grad = False
+
+ def train(self, mode=True):
+ super(EfficientNetV2, self).train(mode)
+ self._freeze_stages()
+ if mode and self.norm_eval:
+ for m in self.modules():
+ if isinstance(m, nn.BatchNorm2d):
+ m.eval()
diff --git a/mmpretrain/models/backbones/hivit.py b/mmpretrain/models/backbones/hivit.py
new file mode 100644
index 0000000000000000000000000000000000000000..981cbf819138ace2c2e8441e7e65f927883c55fd
--- /dev/null
+++ b/mmpretrain/models/backbones/hivit.py
@@ -0,0 +1,656 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+import torch
+import torch.nn as nn
+from mmcv.cnn.bricks import DropPath
+from mmengine.model.weight_init import trunc_normal_
+
+from mmpretrain.registry import MODELS
+from ..utils import build_norm_layer, to_2tuple
+from .base_backbone import BaseBackbone
+
+
+class Mlp(nn.Module):
+ """MLP block.
+
+ Args:
+ in_features (int): Number of input dims.
+ hidden_features (int): Number of hidden dims.
+        out_features (int): Number of output dims.
+ act_layer: MLP activation layer.
+ drop (float): MLP dropout rate.
+ """
+
+ def __init__(self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ act_layer=nn.GELU,
+ drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class Attention(nn.Module):
+ """Attention.
+
+ Args:
+        input_size (int): Input size.
+ dim (int): Number of input dims.
+ num_heads (int): Number of attention heads.
+ qkv_bias (bool): Enable bias for qkv projections if True.
+        qk_scale (float): Override the default qk scale of
+            ``head_dim ** -0.5`` if set. Defaults to None.
+ attn_drop (float): The drop out rate for attention output weights.
+ Defaults to 0.
+ proj_drop (float): Probability of an element to be zeroed
+ after the feed forward layer. Defaults to 0.
+ rpe (bool): If True, add relative position embedding to
+ the patch embedding.
+ """
+
+ def __init__(self,
+ input_size,
+ dim,
+ num_heads,
+ qkv_bias=True,
+ qk_scale=None,
+ attn_drop=0.,
+ proj_drop=0.,
+ rpe=True):
+ super().__init__()
+ self.input_size = input_size
+ self.dim = dim
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim**-0.5
+
+ # define a parameter table of relative position bias
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros((2 * input_size - 1) *
+ (2 * input_size - 1), num_heads)) if rpe else None
+ if rpe:
+ coords_h = torch.arange(input_size)
+ coords_w = torch.arange(input_size)
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
+ coords_flatten = torch.flatten(coords, 1)
+ relative_coords = coords_flatten[:, :,
+ None] - coords_flatten[:, None, :]
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous()
+ relative_coords[:, :, 0] += input_size - 1
+ relative_coords[:, :, 1] += input_size - 1
+ relative_coords[:, :, 0] *= 2 * input_size - 1
+ relative_position_index = relative_coords.sum(-1)
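+            # ``relative_position_index`` maps every (query, key) pair on the
+            # ``input_size x input_size`` grid to one of the
+            # ``(2 * input_size - 1) ** 2`` entries of the bias table above.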
+ self.register_buffer('relative_position_index',
+ relative_position_index)
+
+ trunc_normal_(self.relative_position_bias_table, std=.02)
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+ self.softmax = nn.Softmax(dim=-1)
+
+ def forward(self, x, rpe_index=None, mask=None):
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
+ C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[
+ 2] # make torchscript happy (cannot use tensor as tuple)
+
+ q = q * self.scale
+ attn = (q @ k.transpose(-2, -1))
+
+ if rpe_index is not None:
+ rpe_index = self.relative_position_index.view(-1)
+ S = int(math.sqrt(rpe_index.size(-1)))
+ relative_position_bias = self.relative_position_bias_table[
+ rpe_index].view(-1, S, S, self.num_heads)
+ relative_position_bias = relative_position_bias.permute(
+ 0, 3, 1, 2).contiguous()
+ attn = attn + relative_position_bias
+ if mask is not None:
+ mask = mask.bool()
+ attn = attn.masked_fill(~mask[:, None, None, :], float('-inf'))
+ attn = self.softmax(attn)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class BlockWithRPE(nn.Module):
+ """HiViT block.
+
+ Args:
+ input_size (int): Input size.
+ dim (int): Number of input dims.
+ num_heads (int): Number of attention heads.
+ mlp_ratio (int): Ratio of MLP hidden dim to embedding dim.
+ qkv_bias (bool): Enable bias for qkv projections if True.
+        qk_scale (float): Override the default qk scale of
+            ``head_dim ** -0.5`` if set. Defaults to None.
+ drop (float): Probability of an element to be zeroed
+ after the feed forward layer. Defaults to 0.
+ attn_drop (float): The drop out rate for attention output weights.
+ Defaults to 0.
+ drop_path (float): Stochastic depth rate. Defaults to 0.
+ rpe (bool): If True, add relative position embedding to
+ the patch embedding.
+ layer_scale_init_value (float): Layer-scale init values. Defaults to 0.
+ act_layer: MLP activation layer.
+ norm_cfg (dict): Config dict for normalization layer.
+ Defaults to ``dict(type='LN')``.
+ """
+
+ def __init__(self,
+ input_size,
+ dim,
+ num_heads=0.,
+ mlp_ratio=4.,
+ qkv_bias=True,
+ qk_scale=None,
+ drop=0.,
+ attn_drop=0.,
+ drop_path=0.,
+ rpe=True,
+ layer_scale_init_value=0.0,
+ act_layer=nn.GELU,
+ norm_cfg=dict(type='LN')):
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ self.mlp_ratio = mlp_ratio
+
+ with_attn = num_heads > 0.
+
+ self.norm1 = build_norm_layer(norm_cfg, dim) if with_attn else None
+ self.attn = Attention(
+ input_size,
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ rpe=rpe,
+ ) if with_attn else None
+
+ self.drop_path = DropPath(
+ drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = build_norm_layer(norm_cfg, dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop)
+
+ if layer_scale_init_value > 0:
+ self.gamma_1 = nn.Parameter(
+ layer_scale_init_value * torch.ones(
+ (dim)), requires_grad=True) if with_attn else None
+ self.gamma_2 = nn.Parameter(
+ layer_scale_init_value * torch.ones((dim)), requires_grad=True)
+ else:
+ self.gamma_1, self.gamma_2 = None, None
+
+ def forward(self, x, rpe_index=None, mask=None):
+ if self.attn is not None:
+ if self.gamma_1 is not None:
+ x = x + self.drop_path(
+ self.gamma_1 * self.attn(self.norm1(x), rpe_index, mask))
+ else:
+ x = x + self.drop_path(
+ self.attn(self.norm1(x), rpe_index, mask))
+ if self.gamma_2 is not None:
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
+ else:
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ return x
+
+
+class PatchEmbed(nn.Module):
+ """PatchEmbed for HiViT.
+
+ Args:
+ img_size (int): Input image size.
+ patch_size (int): Patch size. Defaults to 16.
+ inner_patches (int): Inner patch. Defaults to 4.
+ in_chans (int): Number of image input channels.
+ embed_dim (int): Transformer embedding dimension.
+ norm_cfg (dict): Config dict for normalization layer.
+ Defaults to ``dict(type='LN')``.
+ kernel_size (int): Kernel size.
+ pad_size (int): Pad size.
+ """
+
+ def __init__(self,
+ img_size=224,
+ patch_size=16,
+ inner_patches=4,
+ in_chans=3,
+ embed_dim=128,
+ norm_cfg=None,
+ kernel_size=None,
+ pad_size=None):
+ super().__init__()
+ img_size = to_2tuple(img_size) if not isinstance(img_size,
+ tuple) else img_size
+ patch_size = to_2tuple(patch_size)
+ patches_resolution = [
+ img_size[0] // patch_size[0], img_size[1] // patch_size[1]
+ ]
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.inner_patches = inner_patches
+ self.patches_resolution = patches_resolution
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ conv_size = [size // inner_patches for size in patch_size]
+ kernel_size = kernel_size or conv_size
+ pad_size = pad_size or 0
+ self.proj = nn.Conv2d(
+ in_chans,
+ embed_dim,
+ kernel_size=kernel_size,
+ stride=conv_size,
+ padding=pad_size)
+ if norm_cfg is not None:
+ self.norm = build_norm_layer(norm_cfg, embed_dim)
+ else:
+ self.norm = None
+
+ def forward(self, x):
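+ # Patchify with a strided conv (stride = patch_size // inner_patches),
+ # then regroup the conv outputs into
+ # (B, num_patches, inner_patches, inner_patches, embed_dim) so every
+ # patch token keeps an inner grid of sub-tokens.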
+ B, C, H, W = x.shape
+ patches_resolution = (H // self.patch_size[0], W // self.patch_size[1])
+ num_patches = patches_resolution[0] * patches_resolution[1]
+ x = self.proj(x).view(
+ B,
+ -1,
+ patches_resolution[0],
+ self.inner_patches,
+ patches_resolution[1],
+ self.inner_patches,
+ ).permute(0, 2, 4, 3, 5, 1).reshape(B, num_patches, self.inner_patches,
+ self.inner_patches, -1)
+ if self.norm is not None:
+ x = self.norm(x)
+ return x
+
+
+class PatchMerge(nn.Module):
+ """PatchMerge for HiViT.
+
+ Args:
+ dim (int): Number of input channels.
+ norm_cfg (dict): Config dict for normalization layer.
+ """
+
+ def __init__(self, dim, norm_cfg):
+ super().__init__()
+ self.norm = build_norm_layer(norm_cfg, dim * 4)
+ self.reduction = nn.Linear(dim * 4, dim * 2, bias=False)
+
+ def forward(self, x, *args, **kwargs):
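+ # Merge each 2x2 group of neighboring tokens (inner sub-tokens in the
+ # stem stages, spatial tokens in the main stage): concatenate them
+ # along channels (C -> 4C), normalize, and reduce to 2C with a linear
+ # layer, halving the token resolution per side.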
+ is_main_stage = len(x.shape) == 3
+ if is_main_stage:
+ B, N, C = x.shape
+ S = int(math.sqrt(N))
+ x = x.reshape(B, S // 2, 2, S // 2, 2, C) \
+ .permute(0, 1, 3, 2, 4, 5) \
+ .reshape(B, -1, 2, 2, C)
+ x0 = x[..., 0::2, 0::2, :]
+ x1 = x[..., 1::2, 0::2, :]
+ x2 = x[..., 0::2, 1::2, :]
+ x3 = x[..., 1::2, 1::2, :]
+
+ x = torch.cat([x0, x1, x2, x3], dim=-1)
+ x = self.norm(x)
+ x = self.reduction(x)
+
+ if is_main_stage:
+ x = x[:, :, 0, 0, :]
+ return x
+
+
+@MODELS.register_module()
+class HiViT(BaseBackbone):
+ """HiViT.
+
+ A PyTorch implementation of: `HiViT: A Simple and More Efficient Design
+ of Hierarchical Vision Transformer `_.
+
+ Args:
+ arch (str | dict): HiViT architecture. If use string, choose
+ from 'tiny', 'small', 'base' and 'large'. If use dict, it should
+ have below keys:
+
+ - **embed_dims** (int): The dimensions of embedding.
+ - **depths** (List[int]): The number of blocks in each stage.
+ - **num_heads** (int): The number of heads in attention
+ modules of each stage.
+
+ Defaults to 'tiny'.
+ img_size (int): Input image size.
+ patch_size (int): Patch size. Defaults to 16.
+ inner_patches (int): Number of inner patches along each side of
+ a patch token. Defaults to 4.
+ in_chans (int): Number of image input channels.
+ embed_dim (int): Transformer embedding dimension.
+ depths (list[int]): Number of successive HiViT blocks.
+ num_heads (int): Number of attention heads.
+ stem_mlp_ratio (int): Ratio of MLP hidden dim to embedding dim
+ in the first two stages.
+ mlp_ratio (int): Ratio of MLP hidden dim to embedding dim in
+ the last stage.
+ qkv_bias (bool): Enable bias for qkv projections if True.
+ qk_scale (float, optional): Override the default scale factor of
+ ``q @ k`` if set. Defaults to None.
+ drop_rate (float): Probability of an element to be zeroed
+ after the feed forward layer. Defaults to 0.
+ attn_drop_rate (float): The drop out rate for attention output weights.
+ Defaults to 0.
+ drop_path_rate (float): Stochastic depth rate. Defaults to 0.
+ norm_cfg (dict): Config dict for normalization layer.
+ Defaults to ``dict(type='LN')``.
+ ape (bool): If True, add absolute position embedding to
+ the patch embedding.
+ rpe (bool): If True, add relative position embedding to
+ the patch embedding.
+ patch_norm (bool): If True, use norm_cfg for normalization layer.
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+ -1 means not freezing any parameters. Defaults to -1.
+ kernel_size (int): Kernel size.
+ pad_size (int): Pad size.
+ layer_scale_init_value (float): Layer-scale init values. Defaults to 0.
+ init_cfg (dict, optional): The extra config for initialization.
+ Defaults to None.
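+
+ Example:
+ A minimal usage sketch (assuming the backbone is importable from
+ ``mmpretrain.models``); the shape below assumes the default 'base'
+ arch with 224x224 inputs and a patch size of 16:
+
+ >>> import torch
+ >>> from mmpretrain.models import HiViT
+ >>> model = HiViT(arch='base', img_size=224)
+ >>> inputs = torch.rand(1, 3, 224, 224)
+ >>> outputs = model(inputs)
+ >>> feat = outputs[-1] # expected shape: (1, 512, 14, 14)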
+ """
+ arch_zoo = {
+ **dict.fromkeys(['t', 'tiny'],
+ {'embed_dims': 384,
+ 'depths': [1, 1, 10],
+ 'num_heads': 6}),
+ **dict.fromkeys(['s', 'small'],
+ {'embed_dims': 384,
+ 'depths': [2, 2, 20],
+ 'num_heads': 6}),
+ **dict.fromkeys(['b', 'base'],
+ {'embed_dims': 512,
+ 'depths': [2, 2, 24],
+ 'num_heads': 8}),
+ **dict.fromkeys(['l', 'large'],
+ {'embed_dims': 768,
+ 'depths': [2, 2, 40],
+ 'num_heads': 12}),
+ } # yapf: disable
+
+ num_extra_tokens = 0
+
+ def __init__(self,
+ arch='base',
+ img_size=224,
+ patch_size=16,
+ inner_patches=4,
+ in_chans=3,
+ stem_mlp_ratio=3.,
+ mlp_ratio=4.,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.0,
+ norm_cfg=dict(type='LN'),
+ out_indices=[23],
+ ape=True,
+ rpe=False,
+ patch_norm=True,
+ frozen_stages=-1,
+ kernel_size=None,
+ pad_size=None,
+ layer_scale_init_value=0.0,
+ init_cfg=None):
+ super(HiViT, self).__init__(init_cfg=init_cfg)
+
+ if isinstance(arch, str):
+ arch = arch.lower()
+ assert arch in set(self.arch_zoo), \
+ f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
+ self.arch_settings = self.arch_zoo[arch]
+ else:
+ essential_keys = {'embed_dims', 'depths', 'num_heads'}
+ assert isinstance(arch, dict) and set(arch) == essential_keys, \
+ f'Custom arch needs a dict with keys {essential_keys}'
+ self.arch_settings = arch
+ self.embed_dims = self.arch_settings['embed_dims']
+ self.depths = self.arch_settings['depths']
+ self.num_heads = self.arch_settings['num_heads']
+
+ self.num_stages = len(self.depths)
+ self.ape = ape
+ self.rpe = rpe
+ self.patch_size = patch_size
+ self.num_features = self.embed_dims
+ self.mlp_ratio = mlp_ratio
+ self.num_main_blocks = self.depths[-1]
+ self.out_indices = out_indices
+ self.out_indices[-1] = self.depths[-1] - 1
+
+ img_size = to_2tuple(img_size) if not isinstance(img_size,
+ tuple) else img_size
+
+ embed_dim = self.embed_dims // 2**(self.num_stages - 1)
+ # split image into non-overlapping patches
+ self.patch_embed = PatchEmbed(
+ img_size=img_size,
+ patch_size=patch_size,
+ inner_patches=inner_patches,
+ in_chans=in_chans,
+ embed_dim=embed_dim,
+ norm_cfg=norm_cfg if patch_norm else None,
+ kernel_size=kernel_size,
+ pad_size=pad_size)
+ num_patches = self.patch_embed.num_patches
+ Hp, Wp = self.patch_embed.patches_resolution
+
+ if rpe:
+ assert Hp == Wp, \
+ 'If you use relative position, make sure H == W of input size'
+
+ # absolute position embedding
+ if ape:
+ self.pos_embed = nn.Parameter(
+ torch.zeros(1, num_patches, self.num_features))
+ trunc_normal_(self.pos_embed, std=.02)
+ if rpe:
+ # get pair-wise relative position index for each token inside the
+ # window
+ coords_h = torch.arange(Hp)
+ coords_w = torch.arange(Wp)
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
+ coords_flatten = torch.flatten(coords, 1)
+ relative_coords = coords_flatten[:, :,
+ None] - coords_flatten[:, None, :]
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous()
+ relative_coords[:, :, 0] += Hp - 1
+ relative_coords[:, :, 1] += Wp - 1
+ relative_coords[:, :, 0] *= 2 * Wp - 1
+ relative_position_index = relative_coords.sum(-1)
+ self.register_buffer('relative_position_index',
+ relative_position_index)
+
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ # stochastic depth
+ dpr = iter(
+ x.item()
+ for x in torch.linspace(0, drop_path_rate,
+ sum(self.depths) + sum(self.depths[:-1])))
+
+ # build blocks
+ self.blocks = nn.ModuleList()
+ for stage_i, stage_depth in enumerate(self.depths):
+ is_main_stage = embed_dim == self.num_features
+ nhead = self.num_heads if is_main_stage else 0
+ ratio = mlp_ratio if is_main_stage else stem_mlp_ratio
+ # every block not in main stage includes two mlp blocks
+ stage_depth = stage_depth if is_main_stage else stage_depth * 2
+ for _ in range(stage_depth):
+ self.blocks.append(
+ BlockWithRPE(
+ Hp,
+ embed_dim,
+ nhead,
+ ratio,
+ qkv_bias,
+ qk_scale,
+ drop=drop_rate,
+ attn_drop=attn_drop_rate,
+ drop_path=next(dpr),
+ rpe=rpe,
+ norm_cfg=norm_cfg,
+ layer_scale_init_value=layer_scale_init_value,
+ ))
+ if stage_i + 1 < self.num_stages:
+ self.blocks.append(PatchMerge(embed_dim, norm_cfg))
+ embed_dim *= 2
+
+ self.frozen_stages = frozen_stages
+ if self.frozen_stages > 0:
+ self._freeze_stages()
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ def interpolate_pos_encoding(self, x, h, w):
+ npatch = x.shape[1]
+ N = self.pos_embed.shape[1]
+ if npatch == N and w == h:
+ return self.pos_embed
+ patch_pos_embed = self.pos_embed
+ dim = x.shape[-1]
+ w0 = w // self.patch_size
+ h0 = h // self.patch_size
+ # we add a small number to avoid floating point error in interpolation
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
+ w0, h0 = w0 + 0.1, h0 + 0.1
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)),
+ dim).permute(0, 3, 1, 2),
+ scale_factor=(h0 / math.sqrt(N), w0 / math.sqrt(N)),
+ mode='bicubic',
+ )
+ assert int(h0) == patch_pos_embed.shape[-2] and int(
+ w0) == patch_pos_embed.shape[-1]
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ return patch_pos_embed
+
+ def forward(self, x):
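+ # Stem stages run MLP-only blocks on (B, N, inner, inner, C) tokens
+ # with PatchMerge in between; the main stage then squeezes the inner
+ # grid to (B, N, C), optionally adds the (interpolated) absolute
+ # position embedding and runs the attention blocks.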
+ B, C, H, W = x.shape
+ Hp, Wp = H // self.patch_size, W // self.patch_size
+
+ x = self.patch_embed(x)
+
+ outs = []
+ for i, blk in enumerate(self.blocks[:-self.num_main_blocks]):
+ x = blk(x)
+ if i in self.out_indices:
+ x = x.reshape(B, Hp, Wp, *x.shape[-3:]).permute(
+ 0, 5, 1, 3, 2, 4).reshape(B, -1, Hp * x.shape[-3],
+ Wp * x.shape[-2]).contiguous()
+ outs.append(x)
+
+ x = x[..., 0, 0, :]
+ if self.ape:
+ x = x + self.interpolate_pos_encoding(x, H, W)
+ x = self.pos_drop(x)
+
+ rpe_index = True if self.rpe else None
+
+ for i, blk in enumerate(self.blocks[-self.num_main_blocks:]):
+ x = blk(x, rpe_index)
+ if i in self.out_indices:
+ x = x.transpose(1, 2).view(B, -1, Hp, Wp).contiguous()
+ outs.append(x)
+
+ return tuple(outs)
+
+ def _freeze_stages(self):
+ # freeze position embedding
+ if getattr(self, 'pos_embed', None) is not None:
+ self.pos_embed.requires_grad = False
+ # set dropout to eval mode
+ self.pos_drop.eval()
+ # freeze patch embedding
+ self.patch_embed.eval()
+ for param in self.patch_embed.parameters():
+ param.requires_grad = False
+ # freeze layers
+ for i in range(1, self.frozen_stages + 1):
+ m = self.blocks[i - 1]
+ m.eval()
+ for param in m.parameters():
+ param.requires_grad = False
+ # freeze the last layer norm, if this backbone defines one
+ if getattr(self, 'fc_norm', None) is not None:
+ for param in self.fc_norm.parameters():
+ param.requires_grad = False
+
+ def get_layer_depth(self, param_name: str, prefix: str = ''):
+ """Get the layer-wise depth of a parameter.
+
+ Args:
+ param_name (str): The name of the parameter.
+ prefix (str): The prefix for the parameter.
+ Defaults to an empty string.
+
+ Returns:
+ Tuple[int, int]: The layer-wise depth and the num of layers.
+
+ Note:
+ The first depth is the stem module (``layer_depth=0``), and the
+ last depth is the subsequent module (``layer_depth=num_layers-1``)
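+
+ Example (a brief sketch; assumes block parameters are prefixed
+ with ``blocks.``):
+
+ >>> model = HiViT(arch='base')
+ >>> depth, num_layers = model.get_layer_depth('pos_embed')
+ >>> assert depth == 0 and num_layers == len(model.blocks) + 2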
+ """
+ self.num_layers = len(self.blocks)
+ num_layers = self.num_layers + 2
+
+ if not param_name.startswith(prefix):
+ # For subsequent module like head
+ return num_layers - 1, num_layers
+
+ param_name = param_name[len(prefix):]
+
+ if param_name == 'pos_embed':
+ layer_depth = 0
+ elif param_name.startswith('patch_embed'):
+ layer_depth = 0
+ elif param_name.startswith('blocks'):
+ layer_id = int(param_name.split('.')[1])
+ layer_depth = layer_id + 1
+ else:
+ layer_depth = num_layers - 1
+
+ return layer_depth, num_layers
diff --git a/mmpretrain/models/backbones/hornet.py b/mmpretrain/models/backbones/hornet.py
new file mode 100644
index 0000000000000000000000000000000000000000..460f2dc57975712b5eae8308e2fca9c38b89a3e2
--- /dev/null
+++ b/mmpretrain/models/backbones/hornet.py
@@ -0,0 +1,500 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# Adapted from official impl at https://github.com/raoyongming/HorNet.
+try:
+ import torch.fft
+ fft = True
+except ImportError:
+ fft = None
+
+import copy
+from functools import partial
+from typing import Sequence
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from mmcv.cnn.bricks import DropPath
+
+from mmpretrain.models.backbones.base_backbone import BaseBackbone
+from mmpretrain.registry import MODELS
+from ..utils import LayerScale
+
+
+def get_dwconv(dim, kernel_size, bias=True):
+ """build a pepth-wise convolution."""
+ return nn.Conv2d(
+ dim,
+ dim,
+ kernel_size=kernel_size,
+ padding=(kernel_size - 1) // 2,
+ bias=bias,
+ groups=dim)
+
+
+class HorNetLayerNorm(nn.Module):
+ """An implementation of LayerNorm of HorNet.
+
+ The differences between HorNetLayerNorm & torch LayerNorm:
+ 1. Supports two data formats: channels_last or channels_first.
+ Args:
+ normalized_shape (int or list or torch.Size): The shape of the
+ feature dimension to be normalized (typically the channel number).
+ eps (float): a value added to the denominator for numerical stability.
+ Defaults to 1e-5.
+ data_format (str): The ordering of the dimensions in the inputs.
+ channels_last corresponds to inputs with shape (batch_size, height,
+ width, channels) while channels_first corresponds to inputs with
+ shape (batch_size, channels, height, width).
+ Defaults to 'channels_last'.
+ """
+
+ def __init__(self,
+ normalized_shape,
+ eps=1e-6,
+ data_format='channels_last'):
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
+ self.eps = eps
+ self.data_format = data_format
+ if self.data_format not in ['channels_last', 'channels_first']:
+ raise ValueError(
+ 'data_format must be channels_last or channels_first')
+ self.normalized_shape = (normalized_shape, )
+
+ def forward(self, x):
+ if self.data_format == 'channels_last':
+ return F.layer_norm(x, self.normalized_shape, self.weight,
+ self.bias, self.eps)
+ elif self.data_format == 'channels_first':
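+ # Manual LayerNorm for NCHW input: normalize every spatial position
+ # over the channel dim, then apply the per-channel affine parameters.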
+ u = x.mean(1, keepdim=True)
+ s = (x - u).pow(2).mean(1, keepdim=True)
+ x = (x - u) / torch.sqrt(s + self.eps)
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
+ return x
+
+
+class GlobalLocalFilter(nn.Module):
+ """A GlobalLocalFilter of HorNet.
+
+ Args:
+ dim (int): Number of input channels.
+ h (int): Height of complex_weight.
+ Defaults to 14.
+ w (int): Width of complex_weight.
+ Defaults to 8.
+ """
+
+ def __init__(self, dim, h=14, w=8):
+ super().__init__()
+ self.dw = nn.Conv2d(
+ dim // 2,
+ dim // 2,
+ kernel_size=3,
+ padding=1,
+ bias=False,
+ groups=dim // 2)
+ self.complex_weight = nn.Parameter(
+ torch.randn(dim // 2, h, w, 2, dtype=torch.float32) * 0.02)
+ self.pre_norm = HorNetLayerNorm(
+ dim, eps=1e-6, data_format='channels_first')
+ self.post_norm = HorNetLayerNorm(
+ dim, eps=1e-6, data_format='channels_first')
+
+ def forward(self, x):
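+ # Split channels in half: one half goes through a local 3x3
+ # depth-wise conv, the other is filtered globally in the frequency
+ # domain with a learned complex weight, then the two halves are
+ # re-interleaved and normalized.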
+ x = self.pre_norm(x)
+ x1, x2 = torch.chunk(x, 2, dim=1)
+ x1 = self.dw(x1)
+
+ x2 = x2.to(torch.float32)
+ B, C, a, b = x2.shape
+ x2 = torch.fft.rfft2(x2, dim=(2, 3), norm='ortho')
+
+ weight = self.complex_weight
+ if not weight.shape[1:3] == x2.shape[2:4]:
+ weight = F.interpolate(
+ weight.permute(3, 0, 1, 2),
+ size=x2.shape[2:4],
+ mode='bilinear',
+ align_corners=True).permute(1, 2, 3, 0)
+
+ weight = torch.view_as_complex(weight.contiguous())
+
+ x2 = x2 * weight
+ x2 = torch.fft.irfft2(x2, s=(a, b), dim=(2, 3), norm='ortho')
+
+ x = torch.cat([x1.unsqueeze(2), x2.unsqueeze(2)],
+ dim=2).reshape(B, 2 * C, a, b)
+ x = self.post_norm(x)
+ return x
+
+
+class gnConv(nn.Module):
+ """A gnConv of HorNet.
+
+ Args:
+ dim (int): Number of input channels.
+ order (int): Order of gnConv.
+ Defaults to 5.
+ dw_cfg (dict): The Config for dw conv.
+ Defaults to ``dict(type='DW', kernel_size=7)``.
+ scale (float): Scaling parameter of gflayer outputs.
+ Defaults to 1.0.
+ """
+
+ def __init__(self,
+ dim,
+ order=5,
+ dw_cfg=dict(type='DW', kernel_size=7),
+ scale=1.0):
+ super().__init__()
+ self.order = order
+ self.dims = [dim // 2**i for i in range(order)]
+ self.dims.reverse()
+ self.proj_in = nn.Conv2d(dim, 2 * dim, 1)
+
+ cfg = copy.deepcopy(dw_cfg)
+ dw_type = cfg.pop('type')
+ assert dw_type in ['DW', 'GF'],\
+ 'dw_type should be `DW` or `GF`'
+ if dw_type == 'DW':
+ self.dwconv = get_dwconv(sum(self.dims), **cfg)
+ elif dw_type == 'GF':
+ self.dwconv = GlobalLocalFilter(sum(self.dims), **cfg)
+
+ self.proj_out = nn.Conv2d(dim, dim, 1)
+
+ self.projs = nn.ModuleList([
+ nn.Conv2d(self.dims[i], self.dims[i + 1], 1)
+ for i in range(order - 1)
+ ])
+
+ self.scale = scale
+
+ def forward(self, x):
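+ # Recursive gated convolution: project to 2 * dim, split off a
+ # low-dimensional gate y, run a single grouped depth-wise conv over
+ # the remaining features, then repeatedly gate (element-wise
+ # multiply) and project up through self.projs, doubling the channel
+ # width at every order.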
+ x = self.proj_in(x)
+ y, x = torch.split(x, (self.dims[0], sum(self.dims)), dim=1)
+
+ x = self.dwconv(x) * self.scale
+
+ dw_list = torch.split(x, self.dims, dim=1)
+ x = y * dw_list[0]
+
+ for i in range(self.order - 1):
+ x = self.projs[i](x) * dw_list[i + 1]
+
+ x = self.proj_out(x)
+
+ return x
+
+
+class HorNetBlock(nn.Module):
+ """A block of HorNet.
+
+ Args:
+ dim (int): Number of input channels.
+ order (int): Order of gnConv.
+ Defaults to 5.
+ dw_cfg (dict): The Config for dw conv.
+ Defaults to ``dict(type='DW', kernel_size=7)``.
+ scale (float): Scaling parameter of gflayer outputs.
+ Defaults to 1.0.
+ drop_path_rate (float): Stochastic depth rate. Defaults to 0.
+ use_layer_scale (bool): Whether to use LayerScale in the HorNet
+ block. Defaults to True.
+ """
+
+ def __init__(self,
+ dim,
+ order=5,
+ dw_cfg=dict(type='DW', kernel_size=7),
+ scale=1.0,
+ drop_path_rate=0.,
+ use_layer_scale=True):
+ super().__init__()
+ self.out_channels = dim
+
+ self.norm1 = HorNetLayerNorm(
+ dim, eps=1e-6, data_format='channels_first')
+ self.gnconv = gnConv(dim, order, dw_cfg, scale)
+ self.norm2 = HorNetLayerNorm(dim, eps=1e-6)
+ self.pwconv1 = nn.Linear(dim, 4 * dim)
+ self.act = nn.GELU()
+ self.pwconv2 = nn.Linear(4 * dim, dim)
+
+ if use_layer_scale:
+ self.gamma1 = LayerScale(dim, data_format='channels_first')
+ self.gamma2 = LayerScale(dim)
+ else:
+ self.gamma1, self.gamma2 = nn.Identity(), nn.Identity()
+
+ self.drop_path = DropPath(
+ drop_path_rate) if drop_path_rate > 0. else nn.Identity()
+
+ def forward(self, x):
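+ # Two residual branches: gnConv spatial mixing in channels-first
+ # layout, then a ConvNeXt-style pointwise MLP in channels-last
+ # layout, each with optional LayerScale and stochastic depth.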
+ x = x + self.drop_path(self.gamma1(self.gnconv(self.norm1(x))))
+
+ input = x
+ x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
+ x = self.norm2(x)
+ x = self.pwconv1(x)
+ x = self.act(x)
+ x = self.pwconv2(x)
+ x = self.gamma2(x)
+ x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
+
+ x = input + self.drop_path(x)
+ return x
+
+
+@MODELS.register_module()
+class HorNet(BaseBackbone):
+ """HorNet backbone.
+
+ A PyTorch implementation of the paper `HorNet: Efficient High-Order
+ Spatial Interactions with Recursive Gated Convolutions
+ `_.
+ Inspiration from https://github.com/raoyongming/HorNet
+
+ Args:
+ arch (str | dict): HorNet architecture.
+
+ If use string, choose from 'tiny', 'small', 'base' and 'large'.
+ If use dict, it should have below keys:
+
+ - **base_dim** (int): The base dimensions of embedding.
+ - **depths** (List[int]): The number of blocks in each stage.
+ - **orders** (List[int]): The number of order of gnConv in each
+ stage.
+ - **dw_cfg** (List[dict]): The Config for dw conv.
+
+ Defaults to 'tiny'.
+ in_channels (int): Number of input image channels. Defaults to 3.
+ drop_path_rate (float): Stochastic depth rate. Defaults to 0.
+ scale (float): Scaling parameter of gflayer outputs. Defaults to 1/3.
+ use_layer_scale (bool): Whether to use LayerScale in the HorNet
+ block. Defaults to True.
+ out_indices (Sequence[int]): Output from which stages.
+ Default: ``(3, )``.
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+ -1 means not freezing any parameters. Defaults to -1.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Defaults to False.
+ gap_before_final_norm (bool): Whether to globally average the feature
+ map before the final norm layer. In the official repo, it's only
+ used in classification task. Defaults to True.
+ init_cfg (dict, optional): The Config for initialization.
+ Defaults to None.
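+
+ Example:
+ A minimal usage sketch (assuming the backbone is importable from
+ ``mmpretrain.models``); with the default
+ ``gap_before_final_norm=True`` each output is a globally pooled
+ vector (512 channels for the 'tiny' arch at 224x224 input):
+
+ >>> import torch
+ >>> from mmpretrain.models import HorNet
+ >>> model = HorNet(arch='tiny')
+ >>> inputs = torch.rand(1, 3, 224, 224)
+ >>> outputs = model(inputs)
+ >>> feat = outputs[-1] # expected shape: (1, 512)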
+ """
+ arch_zoo = {
+ **dict.fromkeys(['t', 'tiny'],
+ {'base_dim': 64,
+ 'depths': [2, 3, 18, 2],
+ 'orders': [2, 3, 4, 5],
+ 'dw_cfg': [dict(type='DW', kernel_size=7)] * 4}),
+ **dict.fromkeys(['t-gf', 'tiny-gf'],
+ {'base_dim': 64,
+ 'depths': [2, 3, 18, 2],
+ 'orders': [2, 3, 4, 5],
+ 'dw_cfg': [
+ dict(type='DW', kernel_size=7),
+ dict(type='DW', kernel_size=7),
+ dict(type='GF', h=14, w=8),
+ dict(type='GF', h=7, w=4)]}),
+ **dict.fromkeys(['s', 'small'],
+ {'base_dim': 96,
+ 'depths': [2, 3, 18, 2],
+ 'orders': [2, 3, 4, 5],
+ 'dw_cfg': [dict(type='DW', kernel_size=7)] * 4}),
+ **dict.fromkeys(['s-gf', 'small-gf'],
+ {'base_dim': 96,
+ 'depths': [2, 3, 18, 2],
+ 'orders': [2, 3, 4, 5],
+ 'dw_cfg': [
+ dict(type='DW', kernel_size=7),
+ dict(type='DW', kernel_size=7),
+ dict(type='GF', h=14, w=8),
+ dict(type='GF', h=7, w=4)]}),
+ **dict.fromkeys(['b', 'base'],
+ {'base_dim': 128,
+ 'depths': [2, 3, 18, 2],
+ 'orders': [2, 3, 4, 5],
+ 'dw_cfg': [dict(type='DW', kernel_size=7)] * 4}),
+ **dict.fromkeys(['b-gf', 'base-gf'],
+ {'base_dim': 128,
+ 'depths': [2, 3, 18, 2],
+ 'orders': [2, 3, 4, 5],
+ 'dw_cfg': [
+ dict(type='DW', kernel_size=7),
+ dict(type='DW', kernel_size=7),
+ dict(type='GF', h=14, w=8),
+ dict(type='GF', h=7, w=4)]}),
+ **dict.fromkeys(['b-gf384', 'base-gf384'],
+ {'base_dim': 128,
+ 'depths': [2, 3, 18, 2],
+ 'orders': [2, 3, 4, 5],
+ 'dw_cfg': [
+ dict(type='DW', kernel_size=7),
+ dict(type='DW', kernel_size=7),
+ dict(type='GF', h=24, w=12),
+ dict(type='GF', h=13, w=7)]}),
+ **dict.fromkeys(['l', 'large'],
+ {'base_dim': 192,
+ 'depths': [2, 3, 18, 2],
+ 'orders': [2, 3, 4, 5],
+ 'dw_cfg': [dict(type='DW', kernel_size=7)] * 4}),
+ **dict.fromkeys(['l-gf', 'large-gf'],
+ {'base_dim': 192,
+ 'depths': [2, 3, 18, 2],
+ 'orders': [2, 3, 4, 5],
+ 'dw_cfg': [
+ dict(type='DW', kernel_size=7),
+ dict(type='DW', kernel_size=7),
+ dict(type='GF', h=14, w=8),
+ dict(type='GF', h=7, w=4)]}),
+ **dict.fromkeys(['l-gf384', 'large-gf384'],
+ {'base_dim': 192,
+ 'depths': [2, 3, 18, 2],
+ 'orders': [2, 3, 4, 5],
+ 'dw_cfg': [
+ dict(type='DW', kernel_size=7),
+ dict(type='DW', kernel_size=7),
+ dict(type='GF', h=24, w=12),
+ dict(type='GF', h=13, w=7)]}),
+ } # yapf: disable
+
+ def __init__(self,
+ arch='tiny',
+ in_channels=3,
+ drop_path_rate=0.,
+ scale=1 / 3,
+ use_layer_scale=True,
+ out_indices=(3, ),
+ frozen_stages=-1,
+ with_cp=False,
+ gap_before_final_norm=True,
+ init_cfg=None):
+ super().__init__(init_cfg=init_cfg)
+ if fft is None:
+ raise RuntimeError(
+ 'Failed to import torch.fft. Please install "torch>=1.7".')
+
+ if isinstance(arch, str):
+ arch = arch.lower()
+ assert arch in set(self.arch_zoo), \
+ f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
+ self.arch_settings = self.arch_zoo[arch]
+ else:
+ essential_keys = {'base_dim', 'depths', 'orders', 'dw_cfg'}
+ assert isinstance(arch, dict) and set(arch) == essential_keys, \
+ f'Custom arch needs a dict with keys {essential_keys}'
+ self.arch_settings = arch
+
+ self.scale = scale
+ self.out_indices = out_indices
+ self.frozen_stages = frozen_stages
+ self.with_cp = with_cp
+ self.gap_before_final_norm = gap_before_final_norm
+
+ base_dim = self.arch_settings['base_dim']
+ dims = list(map(lambda x: 2**x * base_dim, range(4)))
+
+ self.downsample_layers = nn.ModuleList()
+ stem = nn.Sequential(
+ nn.Conv2d(in_channels, dims[0], kernel_size=4, stride=4),
+ HorNetLayerNorm(dims[0], eps=1e-6, data_format='channels_first'))
+ self.downsample_layers.append(stem)
+ for i in range(3):
+ downsample_layer = nn.Sequential(
+ HorNetLayerNorm(
+ dims[i], eps=1e-6, data_format='channels_first'),
+ nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2),
+ )
+ self.downsample_layers.append(downsample_layer)
+
+ total_depth = sum(self.arch_settings['depths'])
+ dpr = [
+ x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
+ ] # stochastic depth decay rule
+
+ cur_block_idx = 0
+ self.stages = nn.ModuleList()
+ for i in range(4):
+ stage = nn.Sequential(*[
+ HorNetBlock(
+ dim=dims[i],
+ order=self.arch_settings['orders'][i],
+ dw_cfg=self.arch_settings['dw_cfg'][i],
+ scale=self.scale,
+ drop_path_rate=dpr[cur_block_idx + j],
+ use_layer_scale=use_layer_scale)
+ for j in range(self.arch_settings['depths'][i])
+ ])
+ self.stages.append(stage)
+ cur_block_idx += self.arch_settings['depths'][i]
+
+ if isinstance(out_indices, int):
+ out_indices = [out_indices]
+ assert isinstance(out_indices, Sequence), \
+ f'"out_indices" must by a sequence or int, ' \
+ f'get {type(out_indices)} instead.'
+ out_indices = list(out_indices)
+ for i, index in enumerate(out_indices):
+ if index < 0:
+ out_indices[i] = len(self.stages) + index
+ assert 0 <= out_indices[i] <= len(self.stages), \
+ f'Invalid out_indices {index}.'
+ self.out_indices = out_indices
+
+ norm_layer = partial(
+ HorNetLayerNorm, eps=1e-6, data_format='channels_first')
+ for i_layer in out_indices:
+ layer = norm_layer(dims[i_layer])
+ layer_name = f'norm{i_layer}'
+ self.add_module(layer_name, layer)
+
+ def train(self, mode=True):
+ super(HorNet, self).train(mode)
+ self._freeze_stages()
+
+ def _freeze_stages(self):
+ for i in range(0, self.frozen_stages + 1):
+ # freeze patch embed
+ m = self.downsample_layers[i]
+ m.eval()
+ for param in m.parameters():
+ param.requires_grad = False
+
+ # freeze blocks
+ m = self.stages[i]
+ m.eval()
+ for param in m.parameters():
+ param.requires_grad = False
+
+ if i in self.out_indices:
+ # freeze norm
+ m = getattr(self, f'norm{i}')
+ m.eval()
+ for param in m.parameters():
+ param.requires_grad = False
+
+ def forward(self, x):
+ outs = []
+ for i in range(4):
+ x = self.downsample_layers[i](x)
+ if self.with_cp:
+ x = checkpoint.checkpoint_sequential(self.stages[i],
+ len(self.stages[i]), x)
+ else:
+ x = self.stages[i](x)
+ if i in self.out_indices:
+ norm_layer = getattr(self, f'norm{i}')
+ if self.gap_before_final_norm:
+ gap = x.mean([-2, -1], keepdim=True)
+ outs.append(norm_layer(gap).flatten(1))
+ else:
+ # The output of LayerNorm2d may be discontiguous, which
+ # may cause some problem in the downstream tasks
+ outs.append(norm_layer(x).contiguous())
+ return tuple(outs)
diff --git a/mmpretrain/models/backbones/hrnet.py b/mmpretrain/models/backbones/hrnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..99afa908531326f05ff1c977f0146a528683af43
--- /dev/null
+++ b/mmpretrain/models/backbones/hrnet.py
@@ -0,0 +1,563 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+from mmcv.cnn import build_conv_layer, build_norm_layer
+from mmengine.model import BaseModule, ModuleList, Sequential
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from mmpretrain.registry import MODELS
+from .resnet import BasicBlock, Bottleneck, ResLayer, get_expansion
+
+
+class HRModule(BaseModule):
+ """High-Resolution Module for HRNet.
+
+ In this module, every branch has 4 BasicBlocks/Bottlenecks. Fusion/Exchange
+ is in this module.
+
+ Args:
+ num_branches (int): The number of branches.
+ block (``BaseModule``): Convolution block module.
+ num_blocks (tuple): The number of blocks in each branch.
+ The length must be equal to ``num_branches``.
+ num_channels (tuple): The number of base channels in each branch.
+ The length must be equal to ``num_branches``.
+ multiscale_output (bool): Whether to output multi-level features
+ produced by multiple branches. If False, only the first level
+ feature will be output. Defaults to True.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Defaults to False.
+ conv_cfg (dict, optional): Dictionary to construct and config conv
+ layer. Defaults to None.
+ norm_cfg (dict): Dictionary to construct and config norm layer.
+ Defaults to ``dict(type='BN')``.
+ block_init_cfg (dict, optional): The initialization configs of every
+ blocks. Defaults to None.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Defaults to None.
+ """
+
+ def __init__(self,
+ num_branches,
+ block,
+ num_blocks,
+ in_channels,
+ num_channels,
+ multiscale_output=True,
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ block_init_cfg=None,
+ init_cfg=None):
+ super(HRModule, self).__init__(init_cfg)
+ self.block_init_cfg = block_init_cfg
+ self._check_branches(num_branches, num_blocks, in_channels,
+ num_channels)
+
+ self.in_channels = in_channels
+ self.num_branches = num_branches
+
+ self.multiscale_output = multiscale_output
+ self.norm_cfg = norm_cfg
+ self.conv_cfg = conv_cfg
+ self.with_cp = with_cp
+ self.branches = self._make_branches(num_branches, block, num_blocks,
+ num_channels)
+ self.fuse_layers = self._make_fuse_layers()
+ self.relu = nn.ReLU(inplace=False)
+
+ def _check_branches(self, num_branches, num_blocks, in_channels,
+ num_channels):
+ if num_branches != len(num_blocks):
+ error_msg = f'NUM_BRANCHES({num_branches}) ' \
+ f'!= NUM_BLOCKS({len(num_blocks)})'
+ raise ValueError(error_msg)
+
+ if num_branches != len(num_channels):
+ error_msg = f'NUM_BRANCHES({num_branches}) ' \
+ f'!= NUM_CHANNELS({len(num_channels)})'
+ raise ValueError(error_msg)
+
+ if num_branches != len(in_channels):
+ error_msg = f'NUM_BRANCHES({num_branches}) ' \
+ f'!= NUM_INCHANNELS({len(in_channels)})'
+ raise ValueError(error_msg)
+
+ def _make_branches(self, num_branches, block, num_blocks, num_channels):
+ branches = []
+
+ for i in range(num_branches):
+ out_channels = num_channels[i] * get_expansion(block)
+ branches.append(
+ ResLayer(
+ block=block,
+ num_blocks=num_blocks[i],
+ in_channels=self.in_channels[i],
+ out_channels=out_channels,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ with_cp=self.with_cp,
+ init_cfg=self.block_init_cfg,
+ ))
+
+ return ModuleList(branches)
+
+ def _make_fuse_layers(self):
+ if self.num_branches == 1:
+ return None
+
+ num_branches = self.num_branches
+ in_channels = self.in_channels
+ fuse_layers = []
+ num_out_branches = num_branches if self.multiscale_output else 1
+ for i in range(num_out_branches):
+ fuse_layer = []
+ for j in range(num_branches):
+ if j > i:
+ # Upsample the feature maps of smaller scales.
+ fuse_layer.append(
+ nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ in_channels[j],
+ in_channels[i],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=False),
+ build_norm_layer(self.norm_cfg, in_channels[i])[1],
+ nn.Upsample(
+ scale_factor=2**(j - i), mode='nearest')))
+ elif j == i:
+ # Keep the feature map with the same scale.
+ fuse_layer.append(None)
+ else:
+ # Downsample the feature maps of larger scales.
+ conv_downsamples = []
+ for k in range(i - j):
+ # Use stacked convolution layers to downsample.
+ if k == i - j - 1:
+ conv_downsamples.append(
+ nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ in_channels[j],
+ in_channels[i],
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False),
+ build_norm_layer(self.norm_cfg,
+ in_channels[i])[1]))
+ else:
+ conv_downsamples.append(
+ nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ in_channels[j],
+ in_channels[j],
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False),
+ build_norm_layer(self.norm_cfg,
+ in_channels[j])[1],
+ nn.ReLU(inplace=False)))
+ fuse_layer.append(nn.Sequential(*conv_downsamples))
+ fuse_layers.append(nn.ModuleList(fuse_layer))
+
+ return nn.ModuleList(fuse_layers)
+
+ def forward(self, x):
+ """Forward function."""
+ if self.num_branches == 1:
+ return [self.branches[0](x[0])]
+
+ for i in range(self.num_branches):
+ x[i] = self.branches[i](x[i])
+
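+ # Fuse: each output branch sums the (up/down-sampled) features of
+ # every branch at its own resolution, followed by ReLU.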
+ x_fuse = []
+ for i in range(len(self.fuse_layers)):
+ y = 0
+ for j in range(self.num_branches):
+ if i == j:
+ y += x[j]
+ else:
+ y += self.fuse_layers[i][j](x[j])
+ x_fuse.append(self.relu(y))
+ return x_fuse
+
+
+@MODELS.register_module()
+class HRNet(BaseModule):
+ """HRNet backbone.
+
+ `High-Resolution Representations for Labeling Pixels and Regions
+ `_.
+
+ Args:
+ arch (str): The preset HRNet architecture, includes 'w18', 'w30',
+ 'w32', 'w40', 'w44', 'w48', 'w64'. It will only be used if
+ extra is ``None``. Defaults to 'w32'.
+ extra (dict, optional): Detailed configuration for each stage of HRNet.
+ There must be 4 stages, the configuration for each stage must have
+ 5 keys:
+
+ - num_modules (int): The number of HRModule in this stage.
+ - num_branches (int): The number of branches in the HRModule.
+ - block (str): The type of convolution block. Please choose between
+ 'BOTTLENECK' and 'BASIC'.
+ - num_blocks (tuple): The number of blocks in each branch.
+ The length must be equal to num_branches.
+ - num_channels (tuple): The number of base channels in each branch.
+ The length must be equal to num_branches.
+
+ Defaults to None.
+ in_channels (int): Number of input image channels. Defaults to 3.
+ conv_cfg (dict, optional): Dictionary to construct and config conv
+ layer. Defaults to None.
+ norm_cfg (dict): Dictionary to construct and config norm layer.
+ Defaults to ``dict(type='BN')``.
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Defaults to False.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Defaults to False.
+ zero_init_residual (bool): Whether to use zero init for last norm layer
+ in resblocks to let them behave as identity. Defaults to False.
+ multiscale_output (bool): Whether to output multi-level features
+ produced by multiple branches. If False, only the first level
+ feature will be output. Defaults to True.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Defaults to None.
+
+ Example:
+ >>> import torch
+ >>> from mmpretrain.models import HRNet
+ >>> extra = dict(
+ >>> stage1=dict(
+ >>> num_modules=1,
+ >>> num_branches=1,
+ >>> block='BOTTLENECK',
+ >>> num_blocks=(4, ),
+ >>> num_channels=(64, )),
+ >>> stage2=dict(
+ >>> num_modules=1,
+ >>> num_branches=2,
+ >>> block='BASIC',
+ >>> num_blocks=(4, 4),
+ >>> num_channels=(32, 64)),
+ >>> stage3=dict(
+ >>> num_modules=4,
+ >>> num_branches=3,
+ >>> block='BASIC',
+ >>> num_blocks=(4, 4, 4),
+ >>> num_channels=(32, 64, 128)),
+ >>> stage4=dict(
+ >>> num_modules=3,
+ >>> num_branches=4,
+ >>> block='BASIC',
+ >>> num_blocks=(4, 4, 4, 4),
+ >>> num_channels=(32, 64, 128, 256)))
+ >>> self = HRNet(extra, in_channels=1)
+ >>> self.eval()
+ >>> inputs = torch.rand(1, 1, 32, 32)
+ >>> level_outputs = self.forward(inputs)
+ >>> for level_out in level_outputs:
+ ... print(tuple(level_out.shape))
+ (1, 32, 8, 8)
+ (1, 64, 4, 4)
+ (1, 128, 2, 2)
+ (1, 256, 1, 1)
+ """
+
+ blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck}
+ arch_zoo = {
+ # num_modules, num_branches, block, num_blocks, num_channels
+ 'w18': [[1, 1, 'BOTTLENECK', (4, ), (64, )],
+ [1, 2, 'BASIC', (4, 4), (18, 36)],
+ [4, 3, 'BASIC', (4, 4, 4), (18, 36, 72)],
+ [3, 4, 'BASIC', (4, 4, 4, 4), (18, 36, 72, 144)]],
+ 'w30': [[1, 1, 'BOTTLENECK', (4, ), (64, )],
+ [1, 2, 'BASIC', (4, 4), (30, 60)],
+ [4, 3, 'BASIC', (4, 4, 4), (30, 60, 120)],
+ [3, 4, 'BASIC', (4, 4, 4, 4), (30, 60, 120, 240)]],
+ 'w32': [[1, 1, 'BOTTLENECK', (4, ), (64, )],
+ [1, 2, 'BASIC', (4, 4), (32, 64)],
+ [4, 3, 'BASIC', (4, 4, 4), (32, 64, 128)],
+ [3, 4, 'BASIC', (4, 4, 4, 4), (32, 64, 128, 256)]],
+ 'w40': [[1, 1, 'BOTTLENECK', (4, ), (64, )],
+ [1, 2, 'BASIC', (4, 4), (40, 80)],
+ [4, 3, 'BASIC', (4, 4, 4), (40, 80, 160)],
+ [3, 4, 'BASIC', (4, 4, 4, 4), (40, 80, 160, 320)]],
+ 'w44': [[1, 1, 'BOTTLENECK', (4, ), (64, )],
+ [1, 2, 'BASIC', (4, 4), (44, 88)],
+ [4, 3, 'BASIC', (4, 4, 4), (44, 88, 176)],
+ [3, 4, 'BASIC', (4, 4, 4, 4), (44, 88, 176, 352)]],
+ 'w48': [[1, 1, 'BOTTLENECK', (4, ), (64, )],
+ [1, 2, 'BASIC', (4, 4), (48, 96)],
+ [4, 3, 'BASIC', (4, 4, 4), (48, 96, 192)],
+ [3, 4, 'BASIC', (4, 4, 4, 4), (48, 96, 192, 384)]],
+ 'w64': [[1, 1, 'BOTTLENECK', (4, ), (64, )],
+ [1, 2, 'BASIC', (4, 4), (64, 128)],
+ [4, 3, 'BASIC', (4, 4, 4), (64, 128, 256)],
+ [3, 4, 'BASIC', (4, 4, 4, 4), (64, 128, 256, 512)]],
+ } # yapf:disable
+
+ def __init__(self,
+ arch='w32',
+ extra=None,
+ in_channels=3,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ norm_eval=False,
+ with_cp=False,
+ zero_init_residual=False,
+ multiscale_output=True,
+ init_cfg=[
+ dict(type='Kaiming', layer='Conv2d'),
+ dict(
+ type='Constant',
+ val=1,
+ layer=['_BatchNorm', 'GroupNorm'])
+ ]):
+ super(HRNet, self).__init__(init_cfg)
+
+ extra = self.parse_arch(arch, extra)
+
+ # Assert configurations of 4 stages are in extra
+ for i in range(1, 5):
+ assert f'stage{i}' in extra, f'Missing stage{i} config in "extra".'
+ # Assert whether the length of `num_blocks` and `num_channels` are
+ # equal to `num_branches`
+ cfg = extra[f'stage{i}']
+ assert len(cfg['num_blocks']) == cfg['num_branches'] and \
+ len(cfg['num_channels']) == cfg['num_branches']
+
+ self.extra = extra
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.norm_eval = norm_eval
+ self.with_cp = with_cp
+ self.zero_init_residual = zero_init_residual
+
+ # -------------------- stem net --------------------
+ self.conv1 = build_conv_layer(
+ self.conv_cfg,
+ in_channels,
+ out_channels=64,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False)
+
+ self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
+ self.add_module(self.norm1_name, norm1)
+
+ self.conv2 = build_conv_layer(
+ self.conv_cfg,
+ in_channels=64,
+ out_channels=64,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False)
+
+ self.norm2_name, norm2 = build_norm_layer(self.norm_cfg, 64, postfix=2)
+ self.add_module(self.norm2_name, norm2)
+ self.relu = nn.ReLU(inplace=True)
+
+ # -------------------- stage 1 --------------------
+ self.stage1_cfg = self.extra['stage1']
+ base_channels = self.stage1_cfg['num_channels']
+ block_type = self.stage1_cfg['block']
+ num_blocks = self.stage1_cfg['num_blocks']
+
+ block = self.blocks_dict[block_type]
+ num_channels = [
+ channel * get_expansion(block) for channel in base_channels
+ ]
+ # To align with the original code, use layer1 instead of stage1 here.
+ self.layer1 = ResLayer(
+ block,
+ in_channels=64,
+ out_channels=num_channels[0],
+ num_blocks=num_blocks[0])
+ pre_num_channels = num_channels
+
+ # -------------------- stage 2~4 --------------------
+ for i in range(2, 5):
+ stage_cfg = self.extra[f'stage{i}']
+ base_channels = stage_cfg['num_channels']
+ block = self.blocks_dict[stage_cfg['block']]
+ multiscale_output_ = multiscale_output if i == 4 else True
+
+ num_channels = [
+ channel * get_expansion(block) for channel in base_channels
+ ]
+ # The transition layer from layer1 to stage2
+ transition = self._make_transition_layer(pre_num_channels,
+ num_channels)
+ self.add_module(f'transition{i-1}', transition)
+ stage = self._make_stage(
+ stage_cfg, num_channels, multiscale_output=multiscale_output_)
+ self.add_module(f'stage{i}', stage)
+
+ pre_num_channels = num_channels
+
+ @property
+ def norm1(self):
+ """nn.Module: the normalization layer named "norm1" """
+ return getattr(self, self.norm1_name)
+
+ @property
+ def norm2(self):
+ """nn.Module: the normalization layer named "norm2" """
+ return getattr(self, self.norm2_name)
+
+ def _make_transition_layer(self, num_channels_pre_layer,
+ num_channels_cur_layer):
+ num_branches_cur = len(num_channels_cur_layer)
+ num_branches_pre = len(num_channels_pre_layer)
+
+ transition_layers = []
+ for i in range(num_branches_cur):
+ if i < num_branches_pre:
+ # For existing scale branches,
+ # add conv block when the channels are not the same.
+ if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
+ transition_layers.append(
+ nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ num_channels_pre_layer[i],
+ num_channels_cur_layer[i],
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False),
+ build_norm_layer(self.norm_cfg,
+ num_channels_cur_layer[i])[1],
+ nn.ReLU(inplace=True)))
+ else:
+ transition_layers.append(nn.Identity())
+ else:
+ # For new scale branches, add stacked downsample conv blocks.
+ # For example, num_branches_pre = 2, for the 4th branch, add
+ # stacked two downsample conv blocks.
+ conv_downsamples = []
+ for j in range(i + 1 - num_branches_pre):
+ in_channels = num_channels_pre_layer[-1]
+ out_channels = num_channels_cur_layer[i] \
+ if j == i - num_branches_pre else in_channels
+ conv_downsamples.append(
+ nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False),
+ build_norm_layer(self.norm_cfg, out_channels)[1],
+ nn.ReLU(inplace=True)))
+ transition_layers.append(nn.Sequential(*conv_downsamples))
+
+ return nn.ModuleList(transition_layers)
+
+ def _make_stage(self, layer_config, in_channels, multiscale_output=True):
+ num_modules = layer_config['num_modules']
+ num_branches = layer_config['num_branches']
+ num_blocks = layer_config['num_blocks']
+ num_channels = layer_config['num_channels']
+ block = self.blocks_dict[layer_config['block']]
+
+ hr_modules = []
+ block_init_cfg = None
+ if self.zero_init_residual:
+ if block is BasicBlock:
+ block_init_cfg = dict(
+ type='Constant', val=0, override=dict(name='norm2'))
+ elif block is Bottleneck:
+ block_init_cfg = dict(
+ type='Constant', val=0, override=dict(name='norm3'))
+
+ for i in range(num_modules):
+ # multi_scale_output is only used for the last module
+ if not multiscale_output and i == num_modules - 1:
+ reset_multiscale_output = False
+ else:
+ reset_multiscale_output = True
+
+ hr_modules.append(
+ HRModule(
+ num_branches,
+ block,
+ num_blocks,
+ in_channels,
+ num_channels,
+ reset_multiscale_output,
+ with_cp=self.with_cp,
+ norm_cfg=self.norm_cfg,
+ conv_cfg=self.conv_cfg,
+ block_init_cfg=block_init_cfg))
+
+ return Sequential(*hr_modules)
+
+ def forward(self, x):
+ """Forward function."""
+ x = self.conv1(x)
+ x = self.norm1(x)
+ x = self.relu(x)
+ x = self.conv2(x)
+ x = self.norm2(x)
+ x = self.relu(x)
+ x = self.layer1(x)
+
+ x_list = [x]
+
+ for i in range(2, 5):
+ # Apply transition
+ transition = getattr(self, f'transition{i-1}')
+ inputs = []
+ for j, layer in enumerate(transition):
+ if j < len(x_list):
+ inputs.append(layer(x_list[j]))
+ else:
+ inputs.append(layer(x_list[-1]))
+ # Forward HRModule
+ stage = getattr(self, f'stage{i}')
+ x_list = stage(inputs)
+
+ return tuple(x_list)
+
+ def train(self, mode=True):
+ """Convert the model into training mode will keeping the normalization
+ layer freezed."""
+ super(HRNet, self).train(mode)
+ if mode and self.norm_eval:
+ for m in self.modules():
+ # trick: eval have effect on BatchNorm only
+ if isinstance(m, _BatchNorm):
+ m.eval()
+
+ def parse_arch(self, arch, extra=None):
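+ # Expand a preset arch name (e.g. 'w32') into the 4-stage `extra`
+ # dict; an explicitly passed `extra` always takes precedence.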
+ if extra is not None:
+ return extra
+
+ assert arch in self.arch_zoo, \
+ ('Invalid arch, please choose arch from '
+ f'{list(self.arch_zoo.keys())}, or specify `extra` '
+ 'argument directly.')
+
+ extra = dict()
+ for i, stage_setting in enumerate(self.arch_zoo[arch], start=1):
+ extra[f'stage{i}'] = dict(
+ num_modules=stage_setting[0],
+ num_branches=stage_setting[1],
+ block=stage_setting[2],
+ num_blocks=stage_setting[3],
+ num_channels=stage_setting[4],
+ )
+
+ return extra
diff --git a/mmpretrain/models/backbones/inception_v3.py b/mmpretrain/models/backbones/inception_v3.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d6c04b9fba4b50fce31539d14874dc7a47a539a
--- /dev/null
+++ b/mmpretrain/models/backbones/inception_v3.py
@@ -0,0 +1,501 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+from mmcv.cnn import build_conv_layer
+from mmengine.model import BaseModule
+
+from mmpretrain.registry import MODELS
+from .base_backbone import BaseBackbone
+
+
+class BasicConv2d(BaseModule):
+ """A basic convolution block including convolution, batch norm and ReLU.
+
+ Args:
+ in_channels (int): The number of input channels.
+ out_channels (int): The number of output channels.
+ conv_cfg (dict, optional): The config of convolution layer.
+ Defaults to None, which means to use ``nn.Conv2d``.
+ init_cfg (dict, optional): The config of initialization.
+ Defaults to None.
+ **kwargs: Other keyword arguments of the convolution layer.
+ """
+
+ def __init__(self,
+ in_channels: int,
+ out_channels: int,
+ conv_cfg: Optional[dict] = None,
+ init_cfg: Optional[dict] = None,
+ **kwargs) -> None:
+ super().__init__(init_cfg=init_cfg)
+ self.conv = build_conv_layer(
+ conv_cfg, in_channels, out_channels, bias=False, **kwargs)
+ self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
+ self.relu = nn.ReLU(inplace=True)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Forward function."""
+ x = self.conv(x)
+ x = self.bn(x)
+ return self.relu(x)
+
+
+class InceptionA(BaseModule):
+ """Type-A Inception block.
+
+ Args:
+ in_channels (int): The number of input channels.
+ pool_features (int): The number of channels in pooling branch.
+ conv_cfg (dict, optional): The convolution layer config in the
+ :class:`BasicConv2d` block. Defaults to None.
+ init_cfg (dict, optional): The config of initialization.
+ Defaults to None.
+ """
+
+ def __init__(self,
+ in_channels: int,
+ pool_features: int,
+ conv_cfg: Optional[dict] = None,
+ init_cfg: Optional[dict] = None):
+ super().__init__(init_cfg=init_cfg)
+ self.branch1x1 = BasicConv2d(
+ in_channels, 64, kernel_size=1, conv_cfg=conv_cfg)
+
+ self.branch5x5_1 = BasicConv2d(
+ in_channels, 48, kernel_size=1, conv_cfg=conv_cfg)
+ self.branch5x5_2 = BasicConv2d(
+ 48, 64, kernel_size=5, padding=2, conv_cfg=conv_cfg)
+
+ self.branch3x3dbl_1 = BasicConv2d(
+ in_channels, 64, kernel_size=1, conv_cfg=conv_cfg)
+ self.branch3x3dbl_2 = BasicConv2d(
+ 64, 96, kernel_size=3, padding=1, conv_cfg=conv_cfg)
+ self.branch3x3dbl_3 = BasicConv2d(
+ 96, 96, kernel_size=3, padding=1, conv_cfg=conv_cfg)
+
+ self.branch_pool_downsample = nn.AvgPool2d(
+ kernel_size=3, stride=1, padding=1)
+ self.branch_pool = BasicConv2d(
+ in_channels, pool_features, kernel_size=1, conv_cfg=conv_cfg)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Forward function."""
+ branch1x1 = self.branch1x1(x)
+
+ branch5x5 = self.branch5x5_1(x)
+ branch5x5 = self.branch5x5_2(branch5x5)
+
+ branch3x3dbl = self.branch3x3dbl_1(x)
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+ branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
+
+ branch_pool = self.branch_pool_downsample(x)
+ branch_pool = self.branch_pool(branch_pool)
+
+ outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
+ return torch.cat(outputs, 1)
+
+
+class InceptionB(BaseModule):
+ """Type-B Inception block.
+
+ Args:
+ in_channels (int): The number of input channels.
+ conv_cfg (dict, optional): The convolution layer config in the
+ :class:`BasicConv2d` block. Defaults to None.
+ init_cfg (dict, optional): The config of initialization.
+ Defaults to None.
+ """
+
+ def __init__(self,
+ in_channels: int,
+ conv_cfg: Optional[dict] = None,
+ init_cfg: Optional[dict] = None):
+ super().__init__(init_cfg=init_cfg)
+ self.branch3x3 = BasicConv2d(
+ in_channels, 384, kernel_size=3, stride=2, conv_cfg=conv_cfg)
+
+ self.branch3x3dbl_1 = BasicConv2d(
+ in_channels, 64, kernel_size=1, conv_cfg=conv_cfg)
+ self.branch3x3dbl_2 = BasicConv2d(
+ 64, 96, kernel_size=3, padding=1, conv_cfg=conv_cfg)
+ self.branch3x3dbl_3 = BasicConv2d(
+ 96, 96, kernel_size=3, stride=2, conv_cfg=conv_cfg)
+
+ self.branch_pool = nn.MaxPool2d(kernel_size=3, stride=2)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Forward function."""
+ branch3x3 = self.branch3x3(x)
+
+ branch3x3dbl = self.branch3x3dbl_1(x)
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+ branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
+
+ branch_pool = self.branch_pool(x)
+
+ outputs = [branch3x3, branch3x3dbl, branch_pool]
+ return torch.cat(outputs, 1)
+
+
+class InceptionC(BaseModule):
+ """Type-C Inception block.
+
+ Args:
+ in_channels (int): The number of input channels.
+ channels_7x7 (int): The number of channels in 7x7 convolution branch.
+ conv_cfg (dict, optional): The convolution layer config in the
+ :class:`BasicConv2d` block. Defaults to None.
+ init_cfg (dict, optional): The config of initialization.
+ Defaults to None.
+ """
+
+ def __init__(self,
+ in_channels: int,
+ channels_7x7: int,
+ conv_cfg: Optional[dict] = None,
+ init_cfg=None):
+ super().__init__(init_cfg=init_cfg)
+ self.branch1x1 = BasicConv2d(
+ in_channels, 192, kernel_size=1, conv_cfg=conv_cfg)
+
+ c7 = channels_7x7
+ self.branch7x7_1 = BasicConv2d(
+ in_channels, c7, kernel_size=1, conv_cfg=conv_cfg)
+ self.branch7x7_2 = BasicConv2d(
+ c7, c7, kernel_size=(1, 7), padding=(0, 3), conv_cfg=conv_cfg)
+ self.branch7x7_3 = BasicConv2d(
+ c7, 192, kernel_size=(7, 1), padding=(3, 0), conv_cfg=conv_cfg)
+
+ self.branch7x7dbl_1 = BasicConv2d(
+ in_channels, c7, kernel_size=1, conv_cfg=conv_cfg)
+ self.branch7x7dbl_2 = BasicConv2d(
+ c7, c7, kernel_size=(7, 1), padding=(3, 0), conv_cfg=conv_cfg)
+ self.branch7x7dbl_3 = BasicConv2d(
+ c7, c7, kernel_size=(1, 7), padding=(0, 3), conv_cfg=conv_cfg)
+ self.branch7x7dbl_4 = BasicConv2d(
+ c7, c7, kernel_size=(7, 1), padding=(3, 0), conv_cfg=conv_cfg)
+ self.branch7x7dbl_5 = BasicConv2d(
+ c7, 192, kernel_size=(1, 7), padding=(0, 3), conv_cfg=conv_cfg)
+
+ self.branch_pool_downsample = nn.AvgPool2d(
+ kernel_size=3, stride=1, padding=1)
+ self.branch_pool = BasicConv2d(
+ in_channels, 192, kernel_size=1, conv_cfg=conv_cfg)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Forward function."""
+ branch1x1 = self.branch1x1(x)
+
+ branch7x7 = self.branch7x7_1(x)
+ branch7x7 = self.branch7x7_2(branch7x7)
+ branch7x7 = self.branch7x7_3(branch7x7)
+
+ branch7x7dbl = self.branch7x7dbl_1(x)
+ branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
+ branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
+ branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
+ branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
+
+ branch_pool = self.branch_pool_downsample(x)
+ branch_pool = self.branch_pool(branch_pool)
+
+ outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
+ return torch.cat(outputs, 1)
+
+
+class InceptionD(BaseModule):
+ """Type-D Inception block.
+
+ Args:
+ in_channels (int): The number of input channels.
+ conv_cfg (dict, optional): The convolution layer config in the
+ :class:`BasicConv2d` block. Defaults to None.
+ init_cfg (dict, optional): The config of initialization.
+ Defaults to None.
+ """
+
+ def __init__(self,
+ in_channels: int,
+ conv_cfg: Optional[dict] = None,
+ init_cfg: Optional[dict] = None):
+ super().__init__(init_cfg=init_cfg)
+ self.branch3x3_1 = BasicConv2d(
+ in_channels, 192, kernel_size=1, conv_cfg=conv_cfg)
+ self.branch3x3_2 = BasicConv2d(
+ 192, 320, kernel_size=3, stride=2, conv_cfg=conv_cfg)
+
+ self.branch7x7x3_1 = BasicConv2d(
+ in_channels, 192, kernel_size=1, conv_cfg=conv_cfg)
+ self.branch7x7x3_2 = BasicConv2d(
+ 192, 192, kernel_size=(1, 7), padding=(0, 3), conv_cfg=conv_cfg)
+ self.branch7x7x3_3 = BasicConv2d(
+ 192, 192, kernel_size=(7, 1), padding=(3, 0), conv_cfg=conv_cfg)
+ self.branch7x7x3_4 = BasicConv2d(
+ 192, 192, kernel_size=3, stride=2, conv_cfg=conv_cfg)
+
+ self.branch_pool = nn.MaxPool2d(kernel_size=3, stride=2)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Forward function."""
+ branch3x3 = self.branch3x3_1(x)
+ branch3x3 = self.branch3x3_2(branch3x3)
+
+ branch7x7x3 = self.branch7x7x3_1(x)
+ branch7x7x3 = self.branch7x7x3_2(branch7x7x3)
+ branch7x7x3 = self.branch7x7x3_3(branch7x7x3)
+ branch7x7x3 = self.branch7x7x3_4(branch7x7x3)
+
+ branch_pool = self.branch_pool(x)
+ outputs = [branch3x3, branch7x7x3, branch_pool]
+ return torch.cat(outputs, 1)
+
+
+class InceptionE(BaseModule):
+ """Type-E Inception block.
+
+ Args:
+ in_channels (int): The number of input channels.
+ conv_cfg (dict, optional): The convolution layer config in the
+ :class:`BasicConv2d` block. Defaults to None.
+ init_cfg (dict, optional): The config of initialization.
+ Defaults to None.
+ """
+
+ def __init__(self,
+ in_channels: int,
+ conv_cfg: Optional[dict] = None,
+ init_cfg=None):
+ super().__init__(init_cfg=init_cfg)
+ self.branch1x1 = BasicConv2d(
+ in_channels, 320, kernel_size=1, conv_cfg=conv_cfg)
+
+ self.branch3x3_1 = BasicConv2d(
+ in_channels, 384, kernel_size=1, conv_cfg=conv_cfg)
+ self.branch3x3_2a = BasicConv2d(
+ 384, 384, kernel_size=(1, 3), padding=(0, 1), conv_cfg=conv_cfg)
+ self.branch3x3_2b = BasicConv2d(
+ 384, 384, kernel_size=(3, 1), padding=(1, 0), conv_cfg=conv_cfg)
+
+ self.branch3x3dbl_1 = BasicConv2d(
+ in_channels, 448, kernel_size=1, conv_cfg=conv_cfg)
+ self.branch3x3dbl_2 = BasicConv2d(
+ 448, 384, kernel_size=3, padding=1, conv_cfg=conv_cfg)
+ self.branch3x3dbl_3a = BasicConv2d(
+ 384, 384, kernel_size=(1, 3), padding=(0, 1), conv_cfg=conv_cfg)
+ self.branch3x3dbl_3b = BasicConv2d(
+ 384, 384, kernel_size=(3, 1), padding=(1, 0), conv_cfg=conv_cfg)
+
+ self.branch_pool_downsample = nn.AvgPool2d(
+ kernel_size=3, stride=1, padding=1)
+ self.branch_pool = BasicConv2d(
+ in_channels, 192, kernel_size=1, conv_cfg=conv_cfg)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Forward function."""
+ branch1x1 = self.branch1x1(x)
+
+ branch3x3 = self.branch3x3_1(x)
+ branch3x3 = [
+ self.branch3x3_2a(branch3x3),
+ self.branch3x3_2b(branch3x3),
+ ]
+ branch3x3 = torch.cat(branch3x3, 1)
+
+ branch3x3dbl = self.branch3x3dbl_1(x)
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+ branch3x3dbl = [
+ self.branch3x3dbl_3a(branch3x3dbl),
+ self.branch3x3dbl_3b(branch3x3dbl),
+ ]
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
+
+ branch_pool = self.branch_pool_downsample(x)
+ branch_pool = self.branch_pool(branch_pool)
+
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
+ return torch.cat(outputs, 1)
+
+
+class InceptionAux(BaseModule):
+ """The Inception block for the auxiliary classification branch.
+
+ Args:
+ in_channels (int): The number of input channels.
+ num_classes (int): The number of categories.
+ conv_cfg (dict, optional): The convolution layer config in the
+ :class:`BasicConv2d` block. Defaults to None.
+ init_cfg (dict, optional): The config of initialization.
+ Defaults to trunc normal with ``std=0.01`` for Conv2d layers
+ and trunc normal with ``std=0.001`` for Linear layers.
+ """
+
+ def __init__(self,
+ in_channels: int,
+ num_classes: int,
+ conv_cfg: Optional[dict] = None,
+ init_cfg: Optional[dict] = [
+ dict(type='TruncNormal', layer='Conv2d', std=0.01),
+ dict(type='TruncNormal', layer='Linear', std=0.001)
+ ]):
+ super().__init__(init_cfg=init_cfg)
+ self.downsample = nn.AvgPool2d(kernel_size=5, stride=3)
+ self.conv0 = BasicConv2d(
+ in_channels, 128, kernel_size=1, conv_cfg=conv_cfg)
+ self.conv1 = BasicConv2d(128, 768, kernel_size=5, conv_cfg=conv_cfg)
+ self.gap = nn.AdaptiveAvgPool2d((1, 1))
+ self.fc = nn.Linear(768, num_classes)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Forward function."""
+ # N x 768 x 17 x 17
+ x = self.downsample(x)
+ # N x 768 x 5 x 5
+ x = self.conv0(x)
+ # N x 128 x 5 x 5
+ x = self.conv1(x)
+ # N x 768 x 1 x 1
+ # Adaptive average pooling
+ x = self.gap(x)
+ # N x 768 x 1 x 1
+ x = torch.flatten(x, 1)
+ # N x 768
+ x = self.fc(x)
+ # N x 1000
+ return x
+
+
+@MODELS.register_module()
+class InceptionV3(BaseBackbone):
+ """Inception V3 backbone.
+
+ A PyTorch implementation of `Rethinking the Inception Architecture for
+ Computer Vision `_
+
+ This implementation is modified from
+ https://github.com/pytorch/vision/blob/main/torchvision/models/inception.py.
+ Licensed under the BSD 3-Clause License.
+
+ Args:
+ num_classes (int): The number of categories. Defaults to 1000.
+ aux_logits (bool): Whether to enable the auxiliary branch. If False,
+ the auxiliary logits output will be None. Defaults to False.
+ dropout (float): Dropout rate. Defaults to 0.5.
+ init_cfg (dict, optional): The config of initialization. Defaults
+ to use trunc normal with ``std=0.1`` for all Conv2d and Linear
+ layers and constant with ``val=1`` for all BatchNorm2d layers.
+
+ Example:
+ >>> import torch
+ >>> from mmpretrain.models import build_backbone
+ >>>
+ >>> inputs = torch.rand(2, 3, 299, 299)
+ >>> cfg = dict(type='InceptionV3', num_classes=100)
+ >>> backbone = build_backbone(cfg)
+ >>> aux_out, out = backbone(inputs)
+ >>> # The auxiliary branch is disabled by default.
+ >>> assert aux_out is None
+ >>> print(out.shape)
+ torch.Size([2, 100])
+ >>> cfg = dict(type='InceptionV3', num_classes=100, aux_logits=True)
+ >>> backbone = build_backbone(cfg)
+ >>> aux_out, out = backbone(inputs)
+ >>> print(aux_out.shape, out.shape)
+ torch.Size([2, 100]) torch.Size([2, 100])
+ """
+
+ def __init__(
+ self,
+ num_classes: int = 1000,
+ aux_logits: bool = False,
+ dropout: float = 0.5,
+ init_cfg: Optional[dict] = [
+ dict(type='TruncNormal', layer=['Conv2d', 'Linear'], std=0.1),
+ dict(type='Constant', layer='BatchNorm2d', val=1)
+ ],
+ ) -> None:
+ super().__init__(init_cfg=init_cfg)
+
+ self.aux_logits = aux_logits
+ self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, stride=2)
+ self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3)
+ self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
+ self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2)
+ self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1)
+ self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3)
+ self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2)
+ self.Mixed_5b = InceptionA(192, pool_features=32)
+ self.Mixed_5c = InceptionA(256, pool_features=64)
+ self.Mixed_5d = InceptionA(288, pool_features=64)
+ self.Mixed_6a = InceptionB(288)
+ self.Mixed_6b = InceptionC(768, channels_7x7=128)
+ self.Mixed_6c = InceptionC(768, channels_7x7=160)
+ self.Mixed_6d = InceptionC(768, channels_7x7=160)
+ self.Mixed_6e = InceptionC(768, channels_7x7=192)
+ self.AuxLogits: Optional[nn.Module] = None
+ if aux_logits:
+ self.AuxLogits = InceptionAux(768, num_classes)
+ self.Mixed_7a = InceptionD(768)
+ self.Mixed_7b = InceptionE(1280)
+ self.Mixed_7c = InceptionE(2048)
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+ self.dropout = nn.Dropout(p=dropout)
+ self.fc = nn.Linear(2048, num_classes)
+
+ def forward(
+ self,
+ x: torch.Tensor) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
+ """Forward function."""
+ # N x 3 x 299 x 299
+ x = self.Conv2d_1a_3x3(x)
+ # N x 32 x 149 x 149
+ x = self.Conv2d_2a_3x3(x)
+ # N x 32 x 147 x 147
+ x = self.Conv2d_2b_3x3(x)
+ # N x 64 x 147 x 147
+ x = self.maxpool1(x)
+ # N x 64 x 73 x 73
+ x = self.Conv2d_3b_1x1(x)
+ # N x 80 x 73 x 73
+ x = self.Conv2d_4a_3x3(x)
+ # N x 192 x 71 x 71
+ x = self.maxpool2(x)
+ # N x 192 x 35 x 35
+ x = self.Mixed_5b(x)
+ # N x 256 x 35 x 35
+ x = self.Mixed_5c(x)
+ # N x 288 x 35 x 35
+ x = self.Mixed_5d(x)
+ # N x 288 x 35 x 35
+ x = self.Mixed_6a(x)
+ # N x 768 x 17 x 17
+ x = self.Mixed_6b(x)
+ # N x 768 x 17 x 17
+ x = self.Mixed_6c(x)
+ # N x 768 x 17 x 17
+ x = self.Mixed_6d(x)
+ # N x 768 x 17 x 17
+ x = self.Mixed_6e(x)
+ # N x 768 x 17 x 17
+ aux: Optional[torch.Tensor] = None
+ if self.aux_logits and self.training:
+ aux = self.AuxLogits(x)
+ # N x 768 x 17 x 17
+ x = self.Mixed_7a(x)
+ # N x 1280 x 8 x 8
+ x = self.Mixed_7b(x)
+ # N x 2048 x 8 x 8
+ x = self.Mixed_7c(x)
+ # N x 2048 x 8 x 8
+ # Adaptive average pooling
+ x = self.avgpool(x)
+ # N x 2048 x 1 x 1
+ x = self.dropout(x)
+ # N x 2048 x 1 x 1
+ x = torch.flatten(x, 1)
+ # N x 2048
+ x = self.fc(x)
+ # N x 1000 (num_classes)
+ return aux, x
diff --git a/mmpretrain/models/backbones/lenet.py b/mmpretrain/models/backbones/lenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e423c0b15a60660714617e47fd68857b3a6d1e0
--- /dev/null
+++ b/mmpretrain/models/backbones/lenet.py
@@ -0,0 +1,42 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+
+from mmpretrain.registry import MODELS
+from .base_backbone import BaseBackbone
+
+
+@MODELS.register_module()
+class LeNet5(BaseBackbone):
+ """`LeNet5 `_ backbone.
+
+ The input for LeNet-5 is a 32×32 grayscale image.
+
+ Args:
+ num_classes (int): number of classes for classification.
+ The default value is -1, which uses the backbone as
+ a feature extractor without the top classifier.
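+
+ Example:
+ A minimal usage sketch; the printed shape assumes the 32x32
+ grayscale input described above and ``num_classes=10``.
+
+ >>> import torch
+ >>> from mmpretrain.models import build_backbone
+ >>>
+ >>> cfg = dict(type='LeNet5', num_classes=10)
+ >>> backbone = build_backbone(cfg)
+ >>> out, = backbone(torch.rand(2, 1, 32, 32))
+ >>> print(out.shape)
+ torch.Size([2, 10])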
+ """
+
+ def __init__(self, num_classes=-1):
+ super(LeNet5, self).__init__()
+ self.num_classes = num_classes
+ self.features = nn.Sequential(
+ nn.Conv2d(1, 6, kernel_size=5, stride=1), nn.Tanh(),
+ nn.AvgPool2d(kernel_size=2),
+ nn.Conv2d(6, 16, kernel_size=5, stride=1), nn.Tanh(),
+ nn.AvgPool2d(kernel_size=2),
+ nn.Conv2d(16, 120, kernel_size=5, stride=1), nn.Tanh())
+ if self.num_classes > 0:
+ self.classifier = nn.Sequential(
+ nn.Linear(120, 84),
+ nn.Tanh(),
+ nn.Linear(84, num_classes),
+ )
+
+ def forward(self, x):
+
+ x = self.features(x)
+ if self.num_classes > 0:
+ x = self.classifier(x.squeeze())
+
+ return (x, )
diff --git a/mmpretrain/models/backbones/levit.py b/mmpretrain/models/backbones/levit.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f7aa324e28b1725fb9e67110a26ea2d5c2831bd
--- /dev/null
+++ b/mmpretrain/models/backbones/levit.py
@@ -0,0 +1,522 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import itertools
+
+import torch
+import torch.nn as nn
+from mmcv.cnn import build_activation_layer, fuse_conv_bn
+from mmcv.cnn.bricks import DropPath
+from mmengine.model import BaseModule, ModuleList, Sequential
+
+from mmpretrain.models.backbones.base_backbone import BaseBackbone
+from mmpretrain.registry import MODELS
+from ..utils import build_norm_layer
+
+
+class HybridBackbone(BaseModule):
+
+ def __init__(
+ self,
+ embed_dim,
+ kernel_size=3,
+ stride=2,
+ pad=1,
+ dilation=1,
+ groups=1,
+ act_cfg=dict(type='HSwish'),
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ init_cfg=None,
+ ):
+ super(HybridBackbone, self).__init__(init_cfg=init_cfg)
+
+ self.input_channels = [
+ 3, embed_dim // 8, embed_dim // 4, embed_dim // 2
+ ]
+ self.output_channels = [
+ embed_dim // 8, embed_dim // 4, embed_dim // 2, embed_dim
+ ]
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+
+ self.patch_embed = Sequential()
+
+ for i in range(len(self.input_channels)):
+ conv_bn = ConvolutionBatchNorm(
+ self.input_channels[i],
+ self.output_channels[i],
+ kernel_size=kernel_size,
+ stride=stride,
+ pad=pad,
+ dilation=dilation,
+ groups=groups,
+ norm_cfg=norm_cfg,
+ )
+ self.patch_embed.add_module('%d' % (2 * i), conv_bn)
+ if i < len(self.input_channels) - 1:
+ self.patch_embed.add_module('%d' % (i * 2 + 1),
+ build_activation_layer(act_cfg))
+
+ def forward(self, x):
+ x = self.patch_embed(x)
+ return x
+
+
+class ConvolutionBatchNorm(BaseModule):
+
+ def __init__(
+ self,
+ in_channel,
+ out_channel,
+ kernel_size=3,
+ stride=2,
+ pad=1,
+ dilation=1,
+ groups=1,
+ norm_cfg=dict(type='BN'),
+ ):
+ super(ConvolutionBatchNorm, self).__init__()
+ self.conv = nn.Conv2d(
+ in_channel,
+ out_channel,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=pad,
+ dilation=dilation,
+ groups=groups,
+ bias=False)
+ self.bn = build_norm_layer(norm_cfg, out_channel)
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.bn(x)
+ return x
+
+ @torch.no_grad()
+ def fuse(self):
+ return fuse_conv_bn(self).conv
+
+
+class LinearBatchNorm(BaseModule):
+
+ def __init__(self, in_feature, out_feature, norm_cfg=dict(type='BN1d')):
+ super(LinearBatchNorm, self).__init__()
+ self.linear = nn.Linear(in_feature, out_feature, bias=False)
+ self.bn = build_norm_layer(norm_cfg, out_feature)
+
+ def forward(self, x):
+ x = self.linear(x)
+ x = self.bn(x.flatten(0, 1)).reshape_as(x)
+ return x
+
+ @torch.no_grad()
+ def fuse(self):
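+ # Fold the BatchNorm statistics into the linear weights so that a
+ # single nn.Linear reproduces Linear + BN at inference time.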
+ w = self.bn.weight / (self.bn.running_var + self.bn.eps)**0.5
+ w = self.linear.weight * w[:, None]
+ b = self.bn.bias - self.bn.running_mean * self.bn.weight / \
+ (self.bn.running_var + self.bn.eps) ** 0.5
+
+ factory_kwargs = {
+ 'device': self.linear.weight.device,
+ 'dtype': self.linear.weight.dtype
+ }
+ bias = nn.Parameter(
+ torch.empty(self.linear.out_features, **factory_kwargs))
+ self.linear.register_parameter('bias', bias)
+ self.linear.weight.data.copy_(w)
+ self.linear.bias.data.copy_(b)
+ return self.linear
+
+
+class Residual(BaseModule):
+
+ def __init__(self, block, drop_path_rate=0.):
+ super(Residual, self).__init__()
+ self.block = block
+ if drop_path_rate > 0:
+ self.drop_path = DropPath(drop_path_rate)
+ else:
+ self.drop_path = nn.Identity()
+
+ def forward(self, x):
+ x = x + self.drop_path(self.block(x))
+ return x
+
+
+class Attention(BaseModule):
+
+ def __init__(
+ self,
+ dim,
+ key_dim,
+ num_heads=8,
+ attn_ratio=4,
+ act_cfg=dict(type='HSwish'),
+ resolution=14,
+ ):
+ super(Attention, self).__init__()
+ self.num_heads = num_heads
+ self.scale = key_dim**-0.5
+ self.key_dim = key_dim
+ self.nh_kd = nh_kd = key_dim * num_heads
+ self.d = int(attn_ratio * key_dim)
+ self.dh = int(attn_ratio * key_dim) * num_heads
+ self.attn_ratio = attn_ratio
+ h = self.dh + nh_kd * 2
+ self.qkv = LinearBatchNorm(dim, h)
+ self.proj = nn.Sequential(
+ build_activation_layer(act_cfg), LinearBatchNorm(self.dh, dim))
+
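+ # Learned attention bias: each pair of positions on the
+ # resolution x resolution grid shares one bias per head, indexed by
+ # the absolute (dy, dx) offset between the two positions.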
+ points = list(itertools.product(range(resolution), range(resolution)))
+ N = len(points)
+ attention_offsets = {}
+ idxs = []
+ for p1 in points:
+ for p2 in points:
+ offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
+ if offset not in attention_offsets:
+ attention_offsets[offset] = len(attention_offsets)
+ idxs.append(attention_offsets[offset])
+ self.attention_biases = torch.nn.Parameter(
+ torch.zeros(num_heads, len(attention_offsets)))
+ self.register_buffer('attention_bias_idxs',
+ torch.LongTensor(idxs).view(N, N))
+
+ @torch.no_grad()
+ def train(self, mode=True):
+ """change the mode of model."""
+ super(Attention, self).train(mode)
+ if mode and hasattr(self, 'ab'):
+ del self.ab
+ else:
+ self.ab = self.attention_biases[:, self.attention_bias_idxs]
+
+ def forward(self, x): # x (B,N,C)
+ B, N, C = x.shape # 2 196 128
+ qkv = self.qkv(x) # 2 196 128
+ q, k, v = qkv.view(B, N, self.num_heads, -1).split(
+ [self.key_dim, self.key_dim, self.d],
+ dim=3) # q 2 196 4 16 ; k 2 196 4 16; v 2 196 4 32
+ q = q.permute(0, 2, 1, 3) # 2 4 196 16
+ k = k.permute(0, 2, 1, 3)
+ v = v.permute(0, 2, 1, 3)
+
+ attn = ((q @ k.transpose(-2, -1)) *
+ self.scale # 2 4 196 16 * 2 4 16 196 -> 2 4 196 196
+ + (self.attention_biases[:, self.attention_bias_idxs]
+ if self.training else self.ab))
+ attn = attn.softmax(dim=-1) # 2 4 196 196 -> 2 4 196 196
+ x = (attn @ v).transpose(1, 2).reshape(
+ B, N,
+ self.dh) # 2 4 196 196 * 2 4 196 32 -> 2 4 196 32 -> 2 196 128
+ x = self.proj(x)
+ return x
+
+
+class MLP(nn.Sequential):
+
+ def __init__(self, embed_dim, mlp_ratio, act_cfg=dict(type='HSwish')):
+ super(MLP, self).__init__()
+ h = embed_dim * mlp_ratio
+ self.linear1 = LinearBatchNorm(embed_dim, h)
+ self.activation = build_activation_layer(act_cfg)
+ self.linear2 = LinearBatchNorm(h, embed_dim)
+
+ def forward(self, x):
+ x = self.linear1(x)
+ x = self.activation(x)
+ x = self.linear2(x)
+ return x
+
+
+class Subsample(BaseModule):
+
+ def __init__(self, stride, resolution):
+ super(Subsample, self).__init__()
+ self.stride = stride
+ self.resolution = resolution
+
+ def forward(self, x):
+ B, _, C = x.shape
+ # B, N, C -> B, H, W, C
+ x = x.view(B, self.resolution, self.resolution, C)
+ x = x[:, ::self.stride, ::self.stride]
+ x = x.reshape(B, -1, C) # B, H', W', C -> B, N', C
+ return x
+
+
+class AttentionSubsample(nn.Sequential):
+
+ def __init__(self,
+ in_dim,
+ out_dim,
+ key_dim,
+ num_heads=8,
+ attn_ratio=2,
+ act_cfg=dict(type='HSwish'),
+ stride=2,
+ resolution=14):
+ super(AttentionSubsample, self).__init__()
+ self.num_heads = num_heads
+ self.scale = key_dim**-0.5
+ self.key_dim = key_dim
+ self.nh_kd = nh_kd = key_dim * num_heads
+ self.d = int(attn_ratio * key_dim)
+ self.dh = int(attn_ratio * key_dim) * self.num_heads
+ self.attn_ratio = attn_ratio
+ self.sub_resolution = (resolution - 1) // stride + 1
+ h = self.dh + nh_kd
+ self.kv = LinearBatchNorm(in_dim, h)
+
+ self.q = nn.Sequential(
+ Subsample(stride, resolution), LinearBatchNorm(in_dim, nh_kd))
+ self.proj = nn.Sequential(
+ build_activation_layer(act_cfg), LinearBatchNorm(self.dh, out_dim))
+
+ self.stride = stride
+ self.resolution = resolution
+ points = list(itertools.product(range(resolution), range(resolution)))
+ sub_points = list(
+ itertools.product(
+ range(self.sub_resolution), range(self.sub_resolution)))
+ N = len(points)
+ N_sub = len(sub_points)
+ attention_offsets = {}
+ idxs = []
+ for p1 in sub_points:
+ for p2 in points:
+ size = 1
+ offset = (abs(p1[0] * stride - p2[0] + (size - 1) / 2),
+ abs(p1[1] * stride - p2[1] + (size - 1) / 2))
+ if offset not in attention_offsets:
+ attention_offsets[offset] = len(attention_offsets)
+ idxs.append(attention_offsets[offset])
+ self.attention_biases = torch.nn.Parameter(
+ torch.zeros(num_heads, len(attention_offsets)))
+ self.register_buffer('attention_bias_idxs',
+ torch.LongTensor(idxs).view(N_sub, N))
+
+ @torch.no_grad()
+ def train(self, mode=True):
+ super(AttentionSubsample, self).train(mode)
+ if mode and hasattr(self, 'ab'):
+ del self.ab
+ else:
+ self.ab = self.attention_biases[:, self.attention_bias_idxs]
+
+ def forward(self, x):
+ B, N, C = x.shape
+ k, v = self.kv(x).view(B, N, self.num_heads,
+ -1).split([self.key_dim, self.d], dim=3)
+ k = k.permute(0, 2, 1, 3) # BHNC
+ v = v.permute(0, 2, 1, 3) # BHNC
+ q = self.q(x).view(B, self.sub_resolution**2, self.num_heads,
+ self.key_dim).permute(0, 2, 1, 3)
+
+ attn = (q @ k.transpose(-2, -1)) * self.scale + \
+ (self.attention_biases[:, self.attention_bias_idxs]
+ if self.training else self.ab)
+ attn = attn.softmax(dim=-1)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, -1, self.dh)
+ x = self.proj(x)
+ return x
+
+
+@MODELS.register_module()
+class LeViT(BaseBackbone):
+ """LeViT backbone.
+
+ A PyTorch implementation of `LeViT: A Vision Transformer in ConvNet's
+ Clothing for Faster Inference <https://arxiv.org/abs/2104.01136>`_
+
+ Modified from the official implementation:
+ https://github.com/facebookresearch/LeViT
+
+ Args:
+ arch (str | dict): LeViT architecture.
+
+ If use string, choose from '128s', '128', '192', '256' and '384'.
+ If use dict, it should have below keys:
+
+ - **embed_dims** (List[int]): The embed dimensions of each stage.
+ - **key_dims** (List[int]): The embed dimensions of the key in the
+ attention layers of each stage.
+ - **num_heads** (List[int]): The number of heads in each stage.
+ - **depths** (List[int]): The number of blocks in each stage.
+
+ img_size (int): Input image size. Defaults to 224.
+ patch_size (int | tuple): The patch size. Defaults to 16.
+ attn_ratio (int): Ratio of hidden dimensions of the value in attention
+ layers. Defaults to 2.
+ mlp_ratio (int): Ratio of hidden dimensions in MLP layers.
+ Defaults to 2.
+ act_cfg (dict): The config of activation functions.
+ Defaults to ``dict(type='HSwish')``.
+ hybrid_backbone (callable): A callable object to build the patch embed
+ module. Defaults to use :class:`HybridBackbone`.
+ out_indices (Sequence | int): Output from which stages.
+ Defaults to -1, means the last stage.
+ deploy (bool): Whether to switch the model structure to
+ deployment mode. Defaults to False.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Defaults to None.
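+
+ Example:
+ A minimal usage sketch; the printed shape assumes a 224x224 input
+ and the default last-stage output of the '128s' arch.
+
+ >>> import torch
+ >>> from mmpretrain.models import build_backbone
+ >>>
+ >>> cfg = dict(type='LeViT', arch='128s')
+ >>> backbone = build_backbone(cfg)
+ >>> outs = backbone(torch.rand(1, 3, 224, 224))
+ >>> print(outs[-1].shape)
+ torch.Size([1, 384, 4, 4])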
+ """
+ arch_zoo = {
+ '128s': {
+ 'embed_dims': [128, 256, 384],
+ 'num_heads': [4, 6, 8],
+ 'depths': [2, 3, 4],
+ 'key_dims': [16, 16, 16],
+ },
+ '128': {
+ 'embed_dims': [128, 256, 384],
+ 'num_heads': [4, 8, 12],
+ 'depths': [4, 4, 4],
+ 'key_dims': [16, 16, 16],
+ },
+ '192': {
+ 'embed_dims': [192, 288, 384],
+ 'num_heads': [3, 5, 6],
+ 'depths': [4, 4, 4],
+ 'key_dims': [32, 32, 32],
+ },
+ '256': {
+ 'embed_dims': [256, 384, 512],
+ 'num_heads': [4, 6, 8],
+ 'depths': [4, 4, 4],
+ 'key_dims': [32, 32, 32],
+ },
+ '384': {
+ 'embed_dims': [384, 512, 768],
+ 'num_heads': [6, 9, 12],
+ 'depths': [4, 4, 4],
+ 'key_dims': [32, 32, 32],
+ },
+ }
+
+ def __init__(self,
+ arch,
+ img_size=224,
+ patch_size=16,
+ attn_ratio=2,
+ mlp_ratio=2,
+ act_cfg=dict(type='HSwish'),
+ hybrid_backbone=HybridBackbone,
+ out_indices=-1,
+ deploy=False,
+ drop_path_rate=0,
+ init_cfg=None):
+ super(LeViT, self).__init__(init_cfg=init_cfg)
+
+ if isinstance(arch, str):
+ arch = arch.lower()
+ assert arch in set(self.arch_zoo), \
+ f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
+ self.arch = self.arch_zoo[arch]
+ elif isinstance(arch, dict):
+ essential_keys = {'embed_dim', 'num_heads', 'depth', 'key_dim'}
+ assert isinstance(arch, dict) and set(arch) == essential_keys, \
+ f'Custom arch needs a dict with keys {essential_keys}'
+ self.arch = arch
+ else:
+ raise TypeError('Expect "arch" to be either a string '
+ f'or a dict, got {type(arch)}')
+
+ self.embed_dims = self.arch['embed_dims']
+ self.num_heads = self.arch['num_heads']
+ self.key_dims = self.arch['key_dims']
+ self.depths = self.arch['depths']
+ self.num_stages = len(self.embed_dims)
+ self.drop_path_rate = drop_path_rate
+
+ self.patch_embed = hybrid_backbone(self.embed_dims[0])
+
+ self.resolutions = []
+ resolution = img_size // patch_size
+ self.stages = ModuleList()
+ for i, (embed_dims, key_dims, depth, num_heads) in enumerate(
+ zip(self.embed_dims, self.key_dims, self.depths,
+ self.num_heads)):
+ blocks = []
+ if i > 0:
+ downsample = AttentionSubsample(
+ in_dim=self.embed_dims[i - 1],
+ out_dim=embed_dims,
+ key_dim=key_dims,
+ num_heads=self.embed_dims[i - 1] // key_dims,
+ attn_ratio=4,
+ act_cfg=act_cfg,
+ stride=2,
+ resolution=resolution)
+ blocks.append(downsample)
+ resolution = downsample.sub_resolution
+ if mlp_ratio > 0: # mlp_ratio
+ blocks.append(
+ Residual(
+ MLP(embed_dims, mlp_ratio, act_cfg=act_cfg),
+ self.drop_path_rate))
+ self.resolutions.append(resolution)
+ for _ in range(depth):
+ blocks.append(
+ Residual(
+ Attention(
+ embed_dims,
+ key_dims,
+ num_heads,
+ attn_ratio=attn_ratio,
+ act_cfg=act_cfg,
+ resolution=resolution,
+ ), self.drop_path_rate))
+ if mlp_ratio > 0:
+ blocks.append(
+ Residual(
+ MLP(embed_dims, mlp_ratio, act_cfg=act_cfg),
+ self.drop_path_rate))
+
+ self.stages.append(Sequential(*blocks))
+
+ if isinstance(out_indices, int):
+ out_indices = [out_indices]
+ elif isinstance(out_indices, tuple):
+ out_indices = list(out_indices)
+ elif not isinstance(out_indices, list):
+ raise TypeError('"out_indices" must be a list, tuple or int, '
+ f'got {type(out_indices)} instead.')
+ for i, index in enumerate(out_indices):
+ if index < 0:
+ out_indices[i] = self.num_stages + index
+ assert 0 <= out_indices[i] < self.num_stages, \
+ f'Invalid out_indices {index}.'
+ self.out_indices = out_indices
+
+ self.deploy = False
+ if deploy:
+ self.switch_to_deploy()
+
+ def switch_to_deploy(self):
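+ """Fuse every Conv+BN and Linear+BN pair in place for faster
+ inference; calling it again after fusion is a no-op."""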
+ if self.deploy:
+ return
+ fuse_parameters(self)
+ self.deploy = True
+
+ def forward(self, x):
+ x = self.patch_embed(x)
+ x = x.flatten(2).transpose(1, 2) # B, C, H, W -> B, L, C
+ outs = []
+ for i, stage in enumerate(self.stages):
+ x = stage(x)
+ B, _, C = x.shape
+ if i in self.out_indices:
+ out = x.reshape(B, self.resolutions[i], self.resolutions[i], C)
+ out = out.permute(0, 3, 1, 2).contiguous()
+ outs.append(out)
+
+ return tuple(outs)
+
+
+def fuse_parameters(module):
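+ """Recursively replace each child module that provides a ``fuse()``
+ method with its fused, inference-only counterpart."""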
+ for child_name, child in module.named_children():
+ if hasattr(child, 'fuse'):
+ setattr(module, child_name, child.fuse())
+ else:
+ fuse_parameters(child)
diff --git a/mmpretrain/models/backbones/mixmim.py b/mmpretrain/models/backbones/mixmim.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c67aa0c3a45c5c85adbacb94ae90dc170b2d0bb
--- /dev/null
+++ b/mmpretrain/models/backbones/mixmim.py
@@ -0,0 +1,533 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional, Union
+
+import torch
+from mmcv.cnn import build_norm_layer
+from mmcv.cnn.bricks.drop import DropPath
+from mmcv.cnn.bricks.transformer import PatchEmbed, PatchMerging
+from mmengine.model import BaseModule
+from torch import nn
+from torch.utils.checkpoint import checkpoint
+
+from mmpretrain.registry import MODELS
+from ..utils import WindowMSA, to_2tuple
+from .base_backbone import BaseBackbone
+from .vision_transformer import TransformerEncoderLayer
+
+
+class MixMIMWindowAttention(WindowMSA):
+ """MixMIM Window Attention.
+
+ Compared with WindowMSA, we add some modifications
+ in ``forward`` to meet the requirement of MixMIM during
+ pretraining.
+
+ Implements one window attention in MixMIM.
+
+ Args:
+ embed_dims (int): The feature dimension.
+ window_size (list): The height and width of the window.
+ num_heads (int): The number of heads in attention.
+ qkv_bias (bool): Whether to add bias for qkv in attention modules.
+ Defaults to True.
+ qk_scale (float, optional): Override default qk scale of
+ ``head_dim ** -0.5`` if set. Defaults to None.
+ attn_drop_rate (float): attention drop rate.
+ Defaults to 0.
+ proj_drop_rate (float): Probability of an element to be zeroed.
+ Defaults to 0.
+ init_cfg (dict, optional): Initialization config dict.
+ Defaults to None.
+ """
+
+ def __init__(self,
+ embed_dims,
+ window_size,
+ num_heads,
+ qkv_bias=True,
+ qk_scale=None,
+ attn_drop_rate=0.,
+ proj_drop_rate=0.,
+ init_cfg=None):
+
+ super().__init__(
+ embed_dims=embed_dims,
+ window_size=window_size,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop_rate,
+ proj_drop=proj_drop_rate,
+ init_cfg=init_cfg)
+
+ def forward(self, x, mask=None):
+
+ B_, N, C = x.shape
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads,
+ C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[
+ 2] # make torchscript happy (cannot use tensor as tuple)
+
+ q = q * self.scale
+ attn = (q @ k.transpose(-2, -1))
+
+ relative_position_bias = self.relative_position_bias_table[
+ self.relative_position_index.view(-1)].view(
+ self.window_size[0] * self.window_size[1],
+ self.window_size[0] * self.window_size[1],
+ -1) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(
+ 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+ attn = attn + relative_position_bias.unsqueeze(0)
+
+ if mask is not None:
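+ # ``mask`` marks which of the two mixed images each token comes
+ # from; ``mask_new`` becomes 1 for token pairs from different
+ # images, so the subtraction suppresses cross-image attention.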
+ mask = mask.reshape(B_, 1, 1, N)
+ mask_new = mask * mask.transpose(
+ 2, 3) + (1 - mask) * (1 - mask).transpose(2, 3)
+ mask_new = 1 - mask_new
+
+ if mask_new.dtype == torch.float16:
+ attn = attn - 65500 * mask_new
+ else:
+ attn = attn - 1e30 * mask_new
+
+ attn = self.softmax(attn)
+ else:
+ attn = self.softmax(attn)
+
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class MixMIMBlock(TransformerEncoderLayer):
+ """MixMIM Block. Implements one block in MixMIM.
+
+ Args:
+ embed_dims (int): The feature dimension.
+ input_resolution (tuple): Input resolution of this layer.
+ num_heads (int): The number of heads in attention.
+ window_size (list): The height and width of the window.
+ mlp_ratio (int): The MLP ratio in FFN.
+ num_fcs (int): The number of linear layers in a block.
+ qkv_bias (bool): Whether to add bias for qkv in attention modules.
+ Defaults to True.
+ proj_drop_rate (float): Probability of an element to be zeroed.
+ Defaults to 0.
+ attn_drop_rate (float): attention drop rate.
+ Defaults to 0.
+ drop_path_rate (float): stochastic depth rate.
+ Defaults to 0.
+ norm_cfg (dict): Config dict for normalization layer.
+ Defaults to ``dict(type='LN')``.
+ init_cfg (dict, optional): Initialization config dict.
+ Defaults to None.
+ """
+
+ def __init__(self,
+ embed_dims,
+ input_resolution,
+ num_heads,
+ window_size=7,
+ mlp_ratio=4.,
+ num_fcs=2,
+ qkv_bias=True,
+ proj_drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.,
+ act_cfg=dict(type='GELU'),
+ norm_cfg=dict(type='LN'),
+ init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
+
+ super().__init__(
+ embed_dims=embed_dims,
+ num_heads=num_heads,
+ feedforward_channels=int(mlp_ratio * embed_dims),
+ drop_rate=proj_drop_rate,
+ attn_drop_rate=attn_drop_rate,
+ drop_path_rate=drop_path_rate,
+ num_fcs=num_fcs,
+ qkv_bias=qkv_bias,
+ act_cfg=act_cfg,
+ norm_cfg=norm_cfg,
+ init_cfg=init_cfg)
+
+ self.embed_dims = embed_dims
+ self.input_resolution = input_resolution
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.mlp_ratio = mlp_ratio
+
+ if min(self.input_resolution) <= self.window_size:
+ self.window_size = min(self.input_resolution)
+
+ self.attn = MixMIMWindowAttention(
+ embed_dims=embed_dims,
+ window_size=to_2tuple(self.window_size),
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ attn_drop_rate=attn_drop_rate,
+ proj_drop_rate=proj_drop_rate)
+
+ self.drop_path = DropPath(
+ drop_path_rate) if drop_path_rate > 0. else nn.Identity()
+
+ @staticmethod
+ def window_reverse(windows, H, W, window_size):
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
+ x = windows.view(B, H // window_size, W // window_size, window_size,
+ window_size, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+ return x
+
+ @staticmethod
+ def window_partition(x, window_size):
+ B, H, W, C = x.shape
+ x = x.view(B, H // window_size, window_size, W // window_size,
+ window_size, C)
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
+ windows = windows.view(-1, window_size, window_size, C)
+ return windows
+
+ def forward(self, x, attn_mask=None):
+ H, W = self.input_resolution
+ B, L, C = x.shape
+
+ shortcut = x
+ x = self.ln1(x)
+ x = x.view(B, H, W, C)
+
+ # partition windows
+ x_windows = self.window_partition(
+ x, self.window_size) # nW*B, window_size, window_size, C
+ x_windows = x_windows.view(-1, self.window_size * self.window_size,
+ C) # nW*B, window_size*window_size, C
+ if attn_mask is not None:
+ attn_mask = attn_mask.repeat(B, 1, 1) # B, N, 1
+ attn_mask = attn_mask.view(B, H, W, 1)
+ attn_mask = self.window_partition(attn_mask, self.window_size)
+ attn_mask = attn_mask.view(-1, self.window_size * self.window_size,
+ 1)
+
+ # W-MSA/SW-MSA
+ attn_windows = self.attn(
+ x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
+
+ # merge windows
+ attn_windows = attn_windows.view(-1, self.window_size,
+ self.window_size, C)
+ x = self.window_reverse(attn_windows, H, W,
+ self.window_size) # B H' W' C
+
+ x = x.view(B, H * W, C)
+
+ x = shortcut + self.drop_path(x)
+
+ x = self.ffn(self.norm2(x), identity=x) # ffn contains DropPath
+
+ return x
+
+
+class MixMIMLayer(BaseModule):
+ """Implements one MixMIM layer, which may contain several MixMIM blocks.
+
+ Args:
+ embed_dims (int): The feature dimension.
+ input_resolution (tuple): Input resolution of this layer.
+ depth (int): The number of blocks in this layer.
+ num_heads (int): The number of heads in attention.
+ window_size (list): The height and width of the window.
+ mlp_ratio (int): The MLP ratio in FFN.
+ qkv_bias (bool): Whether to add bias for qkv in attention modules.
+ Defaults to True.
+ proj_drop_rate (float): Probability of an element to be zeroed.
+ Defaults to 0.
+ attn_drop_rate (float): attention drop rate.
+ Defaults to 0.
+ drop_path_rate (float): stochastic depth rate.
+ Defaults to 0.
+ norm_cfg (dict): Config dict for normalization layer.
+ Defaults to ``dict(type='LN')``.
+ downsample (class, optional): Downsample the output of blocks by
+ patch merging. Defaults to None.
+ use_checkpoint (bool): Whether to use checkpointing to reduce GPU
+ memory cost. Defaults to False.
+ init_cfg (dict, optional): Initialization config dict.
+ Defaults to None.
+ """
+
+ def __init__(self,
+ embed_dims: int,
+ input_resolution: int,
+ depth: int,
+ num_heads: int,
+ window_size: int,
+ mlp_ratio=4.,
+ qkv_bias=True,
+ proj_drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=[0.],
+ norm_cfg=dict(type='LN'),
+ downsample=None,
+ use_checkpoint=False,
+ init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
+ super().__init__(init_cfg=init_cfg)
+ self.embed_dims = embed_dims
+ self.input_resolution = input_resolution
+ self.depth = depth
+ self.use_checkpoint = use_checkpoint
+
+ # build blocks
+ self.blocks = nn.ModuleList()
+ for i in range(depth):
+ self.blocks.append(
+ MixMIMBlock(
+ embed_dims=embed_dims,
+ input_resolution=input_resolution,
+ num_heads=num_heads,
+ window_size=window_size,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ proj_drop_rate=proj_drop_rate,
+ attn_drop_rate=attn_drop_rate,
+ drop_path_rate=drop_path_rate[i],
+ norm_cfg=norm_cfg))
+ # patch merging layer
+ if downsample is not None:
+ self.downsample = downsample(
+ in_channels=embed_dims,
+ out_channels=2 * embed_dims,
+ norm_cfg=norm_cfg)
+ else:
+ self.downsample = None
+
+ def forward(self, x, attn_mask=None):
+ for blk in self.blocks:
+ if self.use_checkpoint:
+ x = checkpoint(blk, x, attn_mask)
+ else:
+ x = blk(x, attn_mask=attn_mask)
+ if self.downsample is not None:
+ x, _ = self.downsample(x, self.input_resolution)
+ return x
+
+ def extra_repr(self) -> str:
+ return f'dim={self.embed_dims}, \
+ input_resolution={self.input_resolution}, depth={self.depth}'
+
+
+@MODELS.register_module()
+class MixMIMTransformer(BaseBackbone):
+ """MixMIM backbone.
+
+ A PyTorch implementation of `MixMIM: Mixed and Masked Image
+ Modeling for Efficient Visual Representation Learning
+ `_
+
+ Args:
+ arch (str | dict): MixMIM architecture. If use string,
+ choose from 'base','large' and 'huge'.
+ If use dict, it should have below keys:
+
+ - **embed_dims** (int): The dimensions of embedding.
+ - **depths** (int): The number of transformer encoder layers.
+ - **num_heads** (int): The number of heads in attention modules.
+
+ Defaults to 'base'.
+ mlp_ratio (int): The mlp ratio in FFN. Defaults to 4.
+ img_size (int | tuple): The expected input image shape. Because we
+ support dynamic input shape, just set the argument to the most
+ common input image shape. Defaults to 224.
+ patch_size (int | tuple): The patch size in patch embedding.
+ Defaults to 16.
+ in_channels (int): The num of input channels. Defaults to 3.
+ window_size (list): The height and width of the window.
+ qkv_bias (bool): Whether to add bias for qkv in attention modules.
+ Defaults to True.
+ patch_cfg (dict): Extra config dict for patch embedding.
+ Defaults to an empty dict.
+ norm_cfg (dict): Config dict for normalization layer.
+ Defaults to ``dict(type='LN')``.
+ drop_rate (float): Probability of an element to be zeroed.
+ Defaults to 0.
+ drop_path_rate (float): stochastic depth rate. Defaults to 0.
+ attn_drop_rate (float): attention drop rate. Defaults to 0.
+ use_checkpoint (bool): Whether to use checkpointing to reduce GPU
+ memory cost. Defaults to False.
+ init_cfg (dict, optional): Initialization config dict.
+ Defaults to None.
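+
+ Example:
+ A minimal usage sketch; the printed shape assumes a 224x224 input
+ and the globally pooled feature of the 'base' arch.
+
+ >>> import torch
+ >>> from mmpretrain.models import build_backbone
+ >>>
+ >>> cfg = dict(type='MixMIMTransformer', arch='base')
+ >>> backbone = build_backbone(cfg)
+ >>> out, = backbone(torch.rand(1, 3, 224, 224))
+ >>> print(out.shape)
+ torch.Size([1, 1024])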
+ """
+ arch_zoo = {
+ **dict.fromkeys(
+ ['b', 'base'], {
+ 'embed_dims': 128,
+ 'depths': [2, 2, 18, 2],
+ 'num_heads': [4, 8, 16, 32]
+ }),
+ **dict.fromkeys(
+ ['l', 'large'], {
+ 'embed_dims': 192,
+ 'depths': [2, 2, 18, 2],
+ 'num_heads': [6, 12, 24, 48]
+ }),
+ **dict.fromkeys(
+ ['h', 'huge'], {
+ 'embed_dims': 352,
+ 'depths': [2, 2, 18, 2],
+ 'num_heads': [11, 22, 44, 88]
+ }),
+ }
+
+ def __init__(
+ self,
+ arch='base',
+ mlp_ratio=4,
+ img_size=224,
+ patch_size=4,
+ in_channels=3,
+ window_size=[14, 14, 14, 7],
+ qkv_bias=True,
+ patch_cfg=dict(),
+ norm_cfg=dict(type='LN'),
+ drop_rate=0.0,
+ drop_path_rate=0.0,
+ attn_drop_rate=0.0,
+ use_checkpoint=False,
+ init_cfg: Optional[dict] = None,
+ ) -> None:
+ super(MixMIMTransformer, self).__init__(init_cfg=init_cfg)
+
+ if isinstance(arch, str):
+ arch = arch.lower()
+ assert arch in set(self.arch_zoo), \
+ f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
+ self.arch_settings = self.arch_zoo[arch]
+ else:
+ essential_keys = {'embed_dims', 'depths', 'num_heads'}
+ assert isinstance(arch, dict) and essential_keys <= set(arch), \
+ f'Custom arch needs a dict with keys {essential_keys}'
+ self.arch_settings = arch
+
+ self.embed_dims = self.arch_settings['embed_dims']
+ self.depths = self.arch_settings['depths']
+ self.num_heads = self.arch_settings['num_heads']
+
+ self.encoder_stride = 32
+
+ self.num_layers = len(self.depths)
+ self.qkv_bias = qkv_bias
+ self.drop_rate = drop_rate
+ self.attn_drop_rate = attn_drop_rate
+ self.use_checkpoint = use_checkpoint
+ self.mlp_ratio = mlp_ratio
+ self.window_size = window_size
+
+ _patch_cfg = dict(
+ in_channels=in_channels,
+ input_size=img_size,
+ embed_dims=self.embed_dims,
+ conv_type='Conv2d',
+ kernel_size=patch_size,
+ stride=patch_size,
+ norm_cfg=dict(type='LN'),
+ )
+ _patch_cfg.update(patch_cfg)
+ self.patch_embed = PatchEmbed(**_patch_cfg)
+ self.patch_resolution = self.patch_embed.init_out_size
+
+ self.dpr = [
+ x.item()
+ for x in torch.linspace(0, drop_path_rate, sum(self.depths))
+ ]
+ self.layers = nn.ModuleList()
+ for i_layer in range(self.num_layers):
+ self.layers.append(
+ MixMIMLayer(
+ embed_dims=int(self.embed_dims * 2**i_layer),
+ input_resolution=(self.patch_resolution[0] // (2**i_layer),
+ self.patch_resolution[1] //
+ (2**i_layer)),
+ depth=self.depths[i_layer],
+ num_heads=self.num_heads[i_layer],
+ window_size=self.window_size[i_layer],
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=self.qkv_bias,
+ proj_drop_rate=self.drop_rate,
+ attn_drop_rate=self.attn_drop_rate,
+ drop_path_rate=self.dpr[sum(self.depths[:i_layer]
+ ):sum(self.depths[:i_layer +
+ 1])],
+ norm_cfg=norm_cfg,
+ downsample=PatchMerging if
+ (i_layer < self.num_layers - 1) else None,
+ use_checkpoint=self.use_checkpoint))
+
+ self.num_features = int(self.embed_dims * 2**(self.num_layers - 1))
+ self.drop_after_pos = nn.Dropout(p=self.drop_rate)
+
+ self.avgpool = nn.AdaptiveAvgPool1d(1)
+ self.num_patches = self.patch_resolution[0] * self.patch_resolution[1]
+ self.absolute_pos_embed = nn.Parameter(
+ torch.zeros(1, self.num_patches, self.embed_dims),
+ requires_grad=False)
+
+ _, self.norm = build_norm_layer(norm_cfg, self.num_features)
+
+ def forward(self, x: torch.Tensor):
+ x, _ = self.patch_embed(x)
+
+ x = x + self.absolute_pos_embed
+ x = self.drop_after_pos(x)
+
+ for layer in self.layers:
+ x = layer(x, attn_mask=None)
+
+ x = self.norm(x)
+ x = self.avgpool(x.transpose(1, 2)) # B C 1
+ x = torch.flatten(x, 1)
+
+ return (x, )
+
+ def get_layer_depth(self, param_name: str, prefix: str = ''):
+ """Get the layer-wise depth of a parameter.
+
+ Args:
+ param_name (str): The name of the parameter.
+ prefix (str): The prefix for the parameter.
+ Defaults to an empty string.
+
+ Returns:
+ Tuple[int, int]: The layer-wise depth and the num of layers.
+
+ Note:
+ The first depth is the stem module (``layer_depth=0``), and the
+ last depth is the subsequent module (``layer_depth=num_layers-1``)
+ """
+ num_layers = sum(self.depths) + 2
+
+ if not param_name.startswith(prefix):
+ # For subsequent module like neck and head
+ if param_name.startswith('neck'):
+ return num_layers - 2, num_layers
+ else:
+ return num_layers - 1, num_layers
+
+ param_name = param_name[len(prefix):]
+
+ stem_layers = ('patch_embed', 'absolute_pos_embed', 'pos_embed')
+ if any(stem in param_name for stem in stem_layers):
+ layer_depth = 0
+ elif param_name.startswith('layers'):
+ layer_id = int(param_name.split('.')[1])
+ block_id = param_name.split('.')[3]
+
+ if block_id in ('downsample', 'reduction', 'norm'):
+ layer_depth = sum(self.depths[:layer_id + 1])
+ else:
+ layer_depth = sum(self.depths[:layer_id]) + int(block_id) + 1
+ else:
+ layer_depth = num_layers - 2
+
+ return layer_depth, num_layers
diff --git a/mmpretrain/models/backbones/mlp_mixer.py b/mmpretrain/models/backbones/mlp_mixer.py
new file mode 100644
index 0000000000000000000000000000000000000000..26fb8ce0186c2451a5698c413ebf2bc24f33b6ec
--- /dev/null
+++ b/mmpretrain/models/backbones/mlp_mixer.py
@@ -0,0 +1,263 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Sequence
+
+import torch.nn as nn
+from mmcv.cnn import build_norm_layer
+from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
+from mmengine.model import BaseModule, ModuleList
+
+from mmpretrain.registry import MODELS
+from ..utils import to_2tuple
+from .base_backbone import BaseBackbone
+
+
+class MixerBlock(BaseModule):
+ """Mlp-Mixer basic block.
+
+ Basic module of `MLP-Mixer: An all-MLP Architecture for Vision
+ <https://arxiv.org/abs/2105.01601>`_
+
+ Args:
+ num_tokens (int): The number of patched tokens
+ embed_dims (int): The feature dimension
+ tokens_mlp_dims (int): The hidden dimension for tokens FFNs
+ channels_mlp_dims (int): The hidden dimension for channels FFNs
+ drop_rate (float): Probability of an element to be zeroed
+ after the feed forward layer. Defaults to 0.
+ drop_path_rate (float): Stochastic depth rate. Defaults to 0.
+ num_fcs (int): The number of fully-connected layers for FFNs.
+ Defaults to 2.
+ act_cfg (dict): The activation config for FFNs.
+ Defaults to ``dict(type='GELU')``.
+ norm_cfg (dict): Config dict for normalization layer.
+ Defaults to ``dict(type='LN')``.
+ init_cfg (dict, optional): Initialization config dict.
+ Defaults to None.
+ """
+
+ def __init__(self,
+ num_tokens,
+ embed_dims,
+ tokens_mlp_dims,
+ channels_mlp_dims,
+ drop_rate=0.,
+ drop_path_rate=0.,
+ num_fcs=2,
+ act_cfg=dict(type='GELU'),
+ norm_cfg=dict(type='LN'),
+ init_cfg=None):
+ super(MixerBlock, self).__init__(init_cfg=init_cfg)
+
+ self.norm1_name, norm1 = build_norm_layer(
+ norm_cfg, embed_dims, postfix=1)
+ self.add_module(self.norm1_name, norm1)
+ self.token_mix = FFN(
+ embed_dims=num_tokens,
+ feedforward_channels=tokens_mlp_dims,
+ num_fcs=num_fcs,
+ ffn_drop=drop_rate,
+ dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
+ act_cfg=act_cfg,
+ add_identity=False)
+
+ self.norm2_name, norm2 = build_norm_layer(
+ norm_cfg, embed_dims, postfix=2)
+ self.add_module(self.norm2_name, norm2)
+ self.channel_mix = FFN(
+ embed_dims=embed_dims,
+ feedforward_channels=channels_mlp_dims,
+ num_fcs=num_fcs,
+ ffn_drop=drop_rate,
+ dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
+ act_cfg=act_cfg)
+
+ @property
+ def norm1(self):
+ return getattr(self, self.norm1_name)
+
+ @property
+ def norm2(self):
+ return getattr(self, self.norm2_name)
+
+ def init_weights(self):
+ super(MixerBlock, self).init_weights()
+ for m in self.token_mix.modules():
+ if isinstance(m, nn.Linear):
+ nn.init.xavier_uniform_(m.weight)
+ nn.init.normal_(m.bias, std=1e-6)
+ for m in self.channel_mix.modules():
+ if isinstance(m, nn.Linear):
+ nn.init.xavier_uniform_(m.weight)
+ nn.init.normal_(m.bias, std=1e-6)
+
+ def forward(self, x):
+ out = self.norm1(x).transpose(1, 2)
+ x = x + self.token_mix(out).transpose(1, 2)
+ x = self.channel_mix(self.norm2(x), identity=x)
+ return x
+
+
+@MODELS.register_module()
+class MlpMixer(BaseBackbone):
+ """Mlp-Mixer backbone.
+
+ A PyTorch implementation of `MLP-Mixer: An all-MLP Architecture for Vision
+ <https://arxiv.org/abs/2105.01601>`_
+
+ Args:
+ arch (str | dict): MLP Mixer architecture. If use string, choose from
+ 'small', 'base' and 'large'. If use dict, it should have below
+ keys:
+
+ - **embed_dims** (int): The dimensions of embedding.
+ - **num_layers** (int): The number of MLP blocks.
+ - **tokens_mlp_dims** (int): The hidden dimensions for tokens FFNs.
+ - **channels_mlp_dims** (int): The hidden dimensions for
+ channels FFNs.
+
+ Defaults to 'base'.
+ img_size (int | tuple): The input image shape. Defaults to 224.
+ patch_size (int | tuple): The patch size in patch embedding.
+ Defaults to 16.
+ out_indices (Sequence | int): Output from which layer.
+ Defaults to -1, means the last layer.
+ drop_rate (float): Probability of an element to be zeroed.
+ Defaults to 0.
+ drop_path_rate (float): stochastic depth rate. Defaults to 0.
+ norm_cfg (dict): Config dict for normalization layer.
+ Defaults to ``dict(type='LN')``.
+ act_cfg (dict): The activation config for FFNs. Default GELU.
+ patch_cfg (dict): Configs of patch embedding. Defaults to an empty dict.
+ layer_cfgs (Sequence | dict): Configs of each mixer block layer.
+ Defaults to an empty dict.
+ init_cfg (dict, optional): Initialization config dict.
+ Defaults to None.
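+
+ Example:
+ A minimal usage sketch; the printed shape assumes a 224x224 input,
+ which gives 196 patches of size 16 for the 'base' arch.
+
+ >>> import torch
+ >>> from mmpretrain.models import build_backbone
+ >>>
+ >>> cfg = dict(type='MlpMixer', arch='base')
+ >>> backbone = build_backbone(cfg)
+ >>> outs = backbone(torch.rand(1, 3, 224, 224))
+ >>> print(outs[-1].shape)
+ torch.Size([1, 768, 196])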
+ """
+
+ arch_zoo = {
+ **dict.fromkeys(
+ ['s', 'small'], {
+ 'embed_dims': 512,
+ 'num_layers': 8,
+ 'tokens_mlp_dims': 256,
+ 'channels_mlp_dims': 2048,
+ }),
+ **dict.fromkeys(
+ ['b', 'base'], {
+ 'embed_dims': 768,
+ 'num_layers': 12,
+ 'tokens_mlp_dims': 384,
+ 'channels_mlp_dims': 3072,
+ }),
+ **dict.fromkeys(
+ ['l', 'large'], {
+ 'embed_dims': 1024,
+ 'num_layers': 24,
+ 'tokens_mlp_dims': 512,
+ 'channels_mlp_dims': 4096,
+ }),
+ }
+
+ def __init__(self,
+ arch='base',
+ img_size=224,
+ patch_size=16,
+ out_indices=-1,
+ drop_rate=0.,
+ drop_path_rate=0.,
+ norm_cfg=dict(type='LN'),
+ act_cfg=dict(type='GELU'),
+ patch_cfg=dict(),
+ layer_cfgs=dict(),
+ init_cfg=None):
+ super(MlpMixer, self).__init__(init_cfg)
+
+ if isinstance(arch, str):
+ arch = arch.lower()
+ assert arch in set(self.arch_zoo), \
+ f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
+ self.arch_settings = self.arch_zoo[arch]
+ else:
+ essential_keys = {
+ 'embed_dims', 'num_layers', 'tokens_mlp_dims',
+ 'channels_mlp_dims'
+ }
+ assert isinstance(arch, dict) and set(arch) == essential_keys, \
+ f'Custom arch needs a dict with keys {essential_keys}'
+ self.arch_settings = arch
+
+ self.embed_dims = self.arch_settings['embed_dims']
+ self.num_layers = self.arch_settings['num_layers']
+ self.tokens_mlp_dims = self.arch_settings['tokens_mlp_dims']
+ self.channels_mlp_dims = self.arch_settings['channels_mlp_dims']
+
+ self.img_size = to_2tuple(img_size)
+
+ _patch_cfg = dict(
+ input_size=img_size,
+ embed_dims=self.embed_dims,
+ conv_type='Conv2d',
+ kernel_size=patch_size,
+ stride=patch_size,
+ )
+ _patch_cfg.update(patch_cfg)
+ self.patch_embed = PatchEmbed(**_patch_cfg)
+ self.patch_resolution = self.patch_embed.init_out_size
+ num_patches = self.patch_resolution[0] * self.patch_resolution[1]
+
+ if isinstance(out_indices, int):
+ out_indices = [out_indices]
+ assert isinstance(out_indices, Sequence), \
+ f'"out_indices" must be a sequence or int, ' \
+ f'got {type(out_indices)} instead.'
+ for i, index in enumerate(out_indices):
+ if index < 0:
+ out_indices[i] = self.num_layers + index
+ assert out_indices[i] >= 0, f'Invalid out_indices {index}'
+ else:
+ assert index < self.num_layers, f'Invalid out_indices {index}'
+ self.out_indices = out_indices
+
+ self.layers = ModuleList()
+ if isinstance(layer_cfgs, dict):
+ layer_cfgs = [layer_cfgs] * self.num_layers
+ for i in range(self.num_layers):
+ _layer_cfg = dict(
+ num_tokens=num_patches,
+ embed_dims=self.embed_dims,
+ tokens_mlp_dims=self.tokens_mlp_dims,
+ channels_mlp_dims=self.channels_mlp_dims,
+ drop_rate=drop_rate,
+ drop_path_rate=drop_path_rate,
+ act_cfg=act_cfg,
+ norm_cfg=norm_cfg,
+ )
+ _layer_cfg.update(layer_cfgs[i])
+ self.layers.append(MixerBlock(**_layer_cfg))
+
+ self.norm1_name, norm1 = build_norm_layer(
+ norm_cfg, self.embed_dims, postfix=1)
+ self.add_module(self.norm1_name, norm1)
+
+ @property
+ def norm1(self):
+ return getattr(self, self.norm1_name)
+
+ def forward(self, x):
+ assert x.shape[2:] == self.img_size, \
+ "The MLP-Mixer doesn't support dynamic input shape. " \
+ f'Please input images with shape {self.img_size}'
+ x, _ = self.patch_embed(x)
+
+ outs = []
+ for i, layer in enumerate(self.layers):
+ x = layer(x)
+
+ if i == len(self.layers) - 1:
+ x = self.norm1(x)
+
+ if i in self.out_indices:
+ out = x.transpose(1, 2)
+ outs.append(out)
+
+ return tuple(outs)
diff --git a/mmpretrain/models/backbones/mobilenet_v2.py b/mmpretrain/models/backbones/mobilenet_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..bca1418a13c4ed81c4666e7f53b0417c36b2e99b
--- /dev/null
+++ b/mmpretrain/models/backbones/mobilenet_v2.py
@@ -0,0 +1,264 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+from mmcv.cnn import ConvModule
+from mmengine.model import BaseModule
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from mmpretrain.models.utils import make_divisible
+from mmpretrain.registry import MODELS
+from .base_backbone import BaseBackbone
+
+
+class InvertedResidual(BaseModule):
+ """InvertedResidual block for MobileNetV2.
+
+ Args:
+ in_channels (int): The input channels of the InvertedResidual block.
+ out_channels (int): The output channels of the InvertedResidual block.
+ stride (int): Stride of the middle (first) 3x3 convolution.
+ expand_ratio (int): Adjusts the number of channels of the hidden
+ layer in InvertedResidual by this amount.
+ conv_cfg (dict, optional): Config dict for convolution layer.
+ Default: None, which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='BN').
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='ReLU6').
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+
+ Returns:
+ Tensor: The output tensor
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ stride,
+ expand_ratio,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU6'),
+ with_cp=False,
+ init_cfg=None):
+ super(InvertedResidual, self).__init__(init_cfg)
+ self.stride = stride
+ assert stride in [1, 2], f'stride must in [1, 2]. ' \
+ f'But received {stride}.'
+ self.with_cp = with_cp
+ self.use_res_connect = self.stride == 1 and in_channels == out_channels
+ hidden_dim = int(round(in_channels * expand_ratio))
+
+ layers = []
+ if expand_ratio != 1:
+ layers.append(
+ ConvModule(
+ in_channels=in_channels,
+ out_channels=hidden_dim,
+ kernel_size=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg))
+ layers.extend([
+ ConvModule(
+ in_channels=hidden_dim,
+ out_channels=hidden_dim,
+ kernel_size=3,
+ stride=stride,
+ padding=1,
+ groups=hidden_dim,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg),
+ ConvModule(
+ in_channels=hidden_dim,
+ out_channels=out_channels,
+ kernel_size=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=None)
+ ])
+ self.conv = nn.Sequential(*layers)
+
+ def forward(self, x):
+
+ def _inner_forward(x):
+ if self.use_res_connect:
+ return x + self.conv(x)
+ else:
+ return self.conv(x)
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ return out
+
+
+@MODELS.register_module()
+class MobileNetV2(BaseBackbone):
+ """MobileNetV2 backbone.
+
+ Args:
+ widen_factor (float): Width multiplier, multiply number of
+ channels in each layer by this amount. Default: 1.0.
+ out_indices (None or Sequence[int]): Output from which stages.
+ Default: (7, ).
+ frozen_stages (int): Stages to be frozen (all param fixed).
+ Default: -1, which means not freezing any parameters.
+ conv_cfg (dict, optional): Config dict for convolution layer.
+ Default: None, which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='BN').
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='ReLU6').
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Default: False.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
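+
+ Example:
+ A minimal usage sketch; the printed shape assumes a 224x224 input
+ and the default ``out_indices=(7, )`` (the final 1x1 conv).
+
+ >>> import torch
+ >>> from mmpretrain.models import build_backbone
+ >>>
+ >>> cfg = dict(type='MobileNetV2')
+ >>> backbone = build_backbone(cfg)
+ >>> outs = backbone(torch.rand(1, 3, 224, 224))
+ >>> print(outs[-1].shape)
+ torch.Size([1, 1280, 7, 7])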
+ """
+
+ # Parameters to build layers. 4 parameters are needed to construct a
+ # layer, from left to right: expand_ratio, channel, num_blocks, stride.
+ arch_settings = [[1, 16, 1, 1], [6, 24, 2, 2], [6, 32, 3, 2],
+ [6, 64, 4, 2], [6, 96, 3, 1], [6, 160, 3, 2],
+ [6, 320, 1, 1]]
+
+ def __init__(self,
+ widen_factor=1.,
+ out_indices=(7, ),
+ frozen_stages=-1,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU6'),
+ norm_eval=False,
+ with_cp=False,
+ init_cfg=[
+ dict(type='Kaiming', layer=['Conv2d']),
+ dict(
+ type='Constant',
+ val=1,
+ layer=['_BatchNorm', 'GroupNorm'])
+ ]):
+ super(MobileNetV2, self).__init__(init_cfg)
+ self.widen_factor = widen_factor
+ self.out_indices = out_indices
+ for index in out_indices:
+ if index not in range(0, 8):
+ raise ValueError('the item in out_indices must in '
+ f'range(0, 8). But received {index}')
+
+ if frozen_stages not in range(-1, 8):
+ raise ValueError('frozen_stages must be in range(-1, 8). '
+ f'But received {frozen_stages}')
+ self.out_indices = out_indices
+ self.frozen_stages = frozen_stages
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ self.norm_eval = norm_eval
+ self.with_cp = with_cp
+
+ self.in_channels = make_divisible(32 * widen_factor, 8)
+
+ self.conv1 = ConvModule(
+ in_channels=3,
+ out_channels=self.in_channels,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ self.layers = []
+
+ for i, layer_cfg in enumerate(self.arch_settings):
+ expand_ratio, channel, num_blocks, stride = layer_cfg
+ out_channels = make_divisible(channel * widen_factor, 8)
+ inverted_res_layer = self.make_layer(
+ out_channels=out_channels,
+ num_blocks=num_blocks,
+ stride=stride,
+ expand_ratio=expand_ratio)
+ layer_name = f'layer{i + 1}'
+ self.add_module(layer_name, inverted_res_layer)
+ self.layers.append(layer_name)
+
+ if widen_factor > 1.0:
+ self.out_channel = int(1280 * widen_factor)
+ else:
+ self.out_channel = 1280
+
+ layer = ConvModule(
+ in_channels=self.in_channels,
+ out_channels=self.out_channel,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.add_module('conv2', layer)
+ self.layers.append('conv2')
+
+ def make_layer(self, out_channels, num_blocks, stride, expand_ratio):
+ """Stack InvertedResidual blocks to build a layer for MobileNetV2.
+
+ Args:
+ out_channels (int): out_channels of block.
+ num_blocks (int): number of blocks.
+ stride (int): stride of the first block. Default: 1
+ expand_ratio (int): Expand the number of channels of the
+ hidden layer in InvertedResidual by this ratio. Default: 6.
+ """
+ layers = []
+ for i in range(num_blocks):
+ if i >= 1:
+ stride = 1
+ layers.append(
+ InvertedResidual(
+ self.in_channels,
+ out_channels,
+ stride,
+ expand_ratio=expand_ratio,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg,
+ with_cp=self.with_cp))
+ self.in_channels = out_channels
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x = self.conv1(x)
+
+ outs = []
+ for i, layer_name in enumerate(self.layers):
+ layer = getattr(self, layer_name)
+ x = layer(x)
+ if i in self.out_indices:
+ outs.append(x)
+
+ return tuple(outs)
+
+ def _freeze_stages(self):
+ if self.frozen_stages >= 0:
+ for param in self.conv1.parameters():
+ param.requires_grad = False
+ for i in range(1, self.frozen_stages + 1):
+ layer = getattr(self, f'layer{i}')
+ layer.eval()
+ for param in layer.parameters():
+ param.requires_grad = False
+
+ def train(self, mode=True):
+ super(MobileNetV2, self).train(mode)
+ self._freeze_stages()
+ if mode and self.norm_eval:
+ for m in self.modules():
+ if isinstance(m, _BatchNorm):
+ m.eval()
diff --git a/mmpretrain/models/backbones/mobilenet_v3.py b/mmpretrain/models/backbones/mobilenet_v3.py
new file mode 100644
index 0000000000000000000000000000000000000000..577dba94040dec5ecda9388819b8b5205f307dce
--- /dev/null
+++ b/mmpretrain/models/backbones/mobilenet_v3.py
@@ -0,0 +1,217 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmcv.cnn import ConvModule
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from mmpretrain.registry import MODELS
+from ..utils import InvertedResidual
+from .base_backbone import BaseBackbone
+
+
+@MODELS.register_module()
+class MobileNetV3(BaseBackbone):
+ """MobileNetV3 backbone.
+
+ Args:
+ arch (str): Architecture of MobileNetV3, choose from 'small',
+ 'small_075', 'small_050' and 'large'. Default: 'small'.
+ conv_cfg (dict, optional): Config dict for convolution layer.
+ Default: None, which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='BN').
+ out_indices (None or Sequence[int]): Output from which stages.
+ Default: None, which means output tensors from the final stage.
+ frozen_stages (int): Stages to be frozen (all param fixed).
+ Default: -1, which means not freezing any parameters.
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Default: False.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save
+ some memory while slowing down the training speed.
+ Default: False.
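+
+ Example:
+ A minimal usage sketch; the printed shape assumes a 224x224 input
+ and the default output stage of the 'small' arch.
+
+ >>> import torch
+ >>> from mmpretrain.models import build_backbone
+ >>>
+ >>> cfg = dict(type='MobileNetV3', arch='small')
+ >>> backbone = build_backbone(cfg)
+ >>> outs = backbone(torch.rand(1, 3, 224, 224))
+ >>> print(outs[-1].shape)
+ torch.Size([1, 576, 7, 7])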
+ """
+ # Parameters to build each block:
+ # [kernel size, mid channels, out channels, with_se, act type, stride]
+ arch_settings = {
+ 'small': [[3, 16, 16, True, 'ReLU', 2],
+ [3, 72, 24, False, 'ReLU', 2],
+ [3, 88, 24, False, 'ReLU', 1],
+ [5, 96, 40, True, 'HSwish', 2],
+ [5, 240, 40, True, 'HSwish', 1],
+ [5, 240, 40, True, 'HSwish', 1],
+ [5, 120, 48, True, 'HSwish', 1],
+ [5, 144, 48, True, 'HSwish', 1],
+ [5, 288, 96, True, 'HSwish', 2],
+ [5, 576, 96, True, 'HSwish', 1],
+ [5, 576, 96, True, 'HSwish', 1]],
+ 'small_075': [[3, 16, 16, True, 'ReLU', 2],
+ [3, 72, 24, False, 'ReLU', 2],
+ [3, 88, 24, False, 'ReLU', 1],
+ [5, 96, 32, True, 'HSwish', 2],
+ [5, 192, 32, True, 'HSwish', 1],
+ [5, 192, 32, True, 'HSwish', 1],
+ [5, 96, 40, True, 'HSwish', 1],
+ [5, 120, 40, True, 'HSwish', 1],
+ [5, 240, 72, True, 'HSwish', 2],
+ [5, 432, 72, True, 'HSwish', 1],
+ [5, 432, 72, True, 'HSwish', 1]],
+ 'small_050': [[3, 16, 8, True, 'ReLU', 2],
+ [3, 40, 16, False, 'ReLU', 2],
+ [3, 56, 16, False, 'ReLU', 1],
+ [5, 64, 24, True, 'HSwish', 2],
+ [5, 144, 24, True, 'HSwish', 1],
+ [5, 144, 24, True, 'HSwish', 1],
+ [5, 72, 24, True, 'HSwish', 1],
+ [5, 72, 24, True, 'HSwish', 1],
+ [5, 144, 48, True, 'HSwish', 2],
+ [5, 288, 48, True, 'HSwish', 1],
+ [5, 288, 48, True, 'HSwish', 1]],
+ 'large': [[3, 16, 16, False, 'ReLU', 1],
+ [3, 64, 24, False, 'ReLU', 2],
+ [3, 72, 24, False, 'ReLU', 1],
+ [5, 72, 40, True, 'ReLU', 2],
+ [5, 120, 40, True, 'ReLU', 1],
+ [5, 120, 40, True, 'ReLU', 1],
+ [3, 240, 80, False, 'HSwish', 2],
+ [3, 200, 80, False, 'HSwish', 1],
+ [3, 184, 80, False, 'HSwish', 1],
+ [3, 184, 80, False, 'HSwish', 1],
+ [3, 480, 112, True, 'HSwish', 1],
+ [3, 672, 112, True, 'HSwish', 1],
+ [5, 672, 160, True, 'HSwish', 2],
+ [5, 960, 160, True, 'HSwish', 1],
+ [5, 960, 160, True, 'HSwish', 1]]
+ } # yapf: disable
+
+ def __init__(self,
+ arch='small',
+ conv_cfg=None,
+ norm_cfg=dict(type='BN', eps=0.001, momentum=0.01),
+ out_indices=None,
+ frozen_stages=-1,
+ norm_eval=False,
+ with_cp=False,
+ init_cfg=[
+ dict(
+ type='Kaiming',
+ layer=['Conv2d'],
+ nonlinearity='leaky_relu'),
+ dict(type='Normal', layer=['Linear'], std=0.01),
+ dict(type='Constant', layer=['BatchNorm2d'], val=1)
+ ]):
+ super(MobileNetV3, self).__init__(init_cfg)
+ assert arch in self.arch_settings
+ if out_indices is None:
+ out_indices = (12, ) if 'small' in arch else (16, )
+ for order, index in enumerate(out_indices):
+ if index not in range(0, len(self.arch_settings[arch]) + 2):
+ raise ValueError(
+ 'the item in out_indices must in '
+ f'range(0, {len(self.arch_settings[arch]) + 2}). '
+ f'But received {index}')
+
+ if frozen_stages not in range(-1, len(self.arch_settings[arch]) + 2):
+ raise ValueError('frozen_stages must be in range(-1, '
+ f'{len(self.arch_settings[arch]) + 2}). '
+ f'But received {frozen_stages}')
+ self.arch = arch
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.out_indices = out_indices
+ self.frozen_stages = frozen_stages
+ self.norm_eval = norm_eval
+ self.with_cp = with_cp
+
+ self.layers = self._make_layer()
+ self.feat_dim = self.arch_settings[arch][-1][1]
+
+ def _make_layer(self):
+ layers = []
+ layer_setting = self.arch_settings[self.arch]
+ in_channels = 16
+
+ layer = ConvModule(
+ in_channels=3,
+ out_channels=in_channels,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=dict(type='HSwish'))
+ self.add_module('layer0', layer)
+ layers.append('layer0')
+
+ for i, params in enumerate(layer_setting):
+ (kernel_size, mid_channels, out_channels, with_se, act,
+ stride) = params
+ if with_se:
+ se_cfg = dict(
+ channels=mid_channels,
+ ratio=4,
+ act_cfg=(dict(type='ReLU'),
+ dict(
+ type='HSigmoid',
+ bias=3,
+ divisor=6,
+ min_value=0,
+ max_value=1)))
+ else:
+ se_cfg = None
+
+ layer = InvertedResidual(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ mid_channels=mid_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ se_cfg=se_cfg,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=dict(type=act),
+ with_cp=self.with_cp)
+ in_channels = out_channels
+ layer_name = 'layer{}'.format(i + 1)
+ self.add_module(layer_name, layer)
+ layers.append(layer_name)
+
+ # Build the last layer before pooling
+ # TODO: No dilation
+ layer = ConvModule(
+ in_channels=in_channels,
+ out_channels=mid_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=dict(type='HSwish'))
+ layer_name = 'layer{}'.format(len(layer_setting) + 1)
+ self.add_module(layer_name, layer)
+ layers.append(layer_name)
+
+ return layers
+
+ def forward(self, x):
+ outs = []
+ for i, layer_name in enumerate(self.layers):
+ layer = getattr(self, layer_name)
+ x = layer(x)
+ if i in self.out_indices:
+ outs.append(x)
+
+ return tuple(outs)
+
+ def _freeze_stages(self):
+ for i in range(0, self.frozen_stages + 1):
+ layer = getattr(self, f'layer{i}')
+ layer.eval()
+ for param in layer.parameters():
+ param.requires_grad = False
+
+ def train(self, mode=True):
+ super(MobileNetV3, self).train(mode)
+ self._freeze_stages()
+ if mode and self.norm_eval:
+ for m in self.modules():
+ if isinstance(m, _BatchNorm):
+ m.eval()
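A quick way to sanity-check the MobileNetV3 backbone added above is a dummy forward pass. The sketch below assumes the class is re-exported as ``mmpretrain.models.MobileNetV3`` (an assumption about the package layout; adjust the import if needed). With ``arch='large'`` the default ``out_indices=(16,)`` selects the final 1x1 ConvModule, so the output should be a 960-channel map at 1/32 resolution.

```python
import torch
from mmpretrain.models import MobileNetV3  # assumed import path

model = MobileNetV3(arch='large')  # default out_indices=(16,) for 'large'
model.eval()

x = torch.rand(1, 3, 224, 224)
feats = model(x)
# layer0 plus the four stride-2 inverted residual blocks give stride 32,
# and the last ConvModule expands to 960 channels.
print(tuple(feats[0].shape))  # expected: (1, 960, 7, 7)
```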
diff --git a/mmpretrain/models/backbones/mobileone.py b/mmpretrain/models/backbones/mobileone.py
new file mode 100644
index 0000000000000000000000000000000000000000..1111441af82d43a49d15ecbb5dc0778fc9f87596
--- /dev/null
+++ b/mmpretrain/models/backbones/mobileone.py
@@ -0,0 +1,515 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# Modified from official impl https://github.com/apple/ml-mobileone/blob/main/mobileone.py # noqa: E501
+from typing import Optional, Sequence
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import build_activation_layer, build_conv_layer, build_norm_layer
+from mmengine.model import BaseModule, ModuleList, Sequential
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from mmpretrain.registry import MODELS
+from ..utils.se_layer import SELayer
+from .base_backbone import BaseBackbone
+
+
+class MobileOneBlock(BaseModule):
+ """MobileOne block for MobileOne backbone.
+
+ Args:
+ in_channels (int): The input channels of the block.
+ out_channels (int): The output channels of the block.
+ kernel_size (int): The kernel size of the convs in the block. If the
+ kernel size is larger than 1, there will be a ``branch_scale`` in
+ the block.
+ num_convs (int): Number of the convolution branches in the block.
+ stride (int): Stride of convolution layers. Defaults to 1.
+ padding (int): Padding of the convolution layers. Defaults to 1.
+ dilation (int): Dilation of the convolution layers. Defaults to 1.
+ groups (int): Groups of the convolution layers. Defaults to 1.
+ se_cfg (None or dict): The configuration of the se module.
+ Defaults to None.
+ norm_cfg (dict): Configuration to construct and config norm layer.
+ Defaults to ``dict(type='BN')``.
+ act_cfg (dict): Config dict for activation layer.
+ Defaults to ``dict(type='ReLU')``.
+ deploy (bool): Whether the model structure is in the deployment mode.
+ Defaults to False.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Defaults to None.
+ """
+
+ def __init__(self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ num_convs: int,
+ stride: int = 1,
+ padding: int = 1,
+ dilation: int = 1,
+ groups: int = 1,
+ se_cfg: Optional[dict] = None,
+ conv_cfg: Optional[dict] = None,
+ norm_cfg: Optional[dict] = dict(type='BN'),
+ act_cfg: Optional[dict] = dict(type='ReLU'),
+ deploy: bool = False,
+ init_cfg: Optional[dict] = None):
+ super(MobileOneBlock, self).__init__(init_cfg)
+
+ assert se_cfg is None or isinstance(se_cfg, dict)
+ if se_cfg is not None:
+ self.se = SELayer(channels=out_channels, **se_cfg)
+ else:
+ self.se = nn.Identity()
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = kernel_size
+ self.num_conv_branches = num_convs
+ self.stride = stride
+ self.padding = padding
+ self.se_cfg = se_cfg
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ self.deploy = deploy
+ self.groups = groups
+ self.dilation = dilation
+
+ if deploy:
+ self.branch_reparam = build_conv_layer(
+ conv_cfg,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ groups=self.groups,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ bias=True)
+ else:
+ # judge if input shape and output shape are the same.
+ # If true, add a normalized identity shortcut.
+ if out_channels == in_channels and stride == 1:
+ self.branch_norm = build_norm_layer(norm_cfg, in_channels)[1]
+ else:
+ self.branch_norm = None
+
+ self.branch_scale = None
+ if kernel_size > 1:
+ self.branch_scale = self.create_conv_bn(kernel_size=1)
+
+ self.branch_conv_list = ModuleList()
+ for _ in range(num_convs):
+ self.branch_conv_list.append(
+ self.create_conv_bn(
+ kernel_size=kernel_size,
+ padding=padding,
+ dilation=dilation))
+
+ self.act = build_activation_layer(act_cfg)
+
+ def create_conv_bn(self, kernel_size, dilation=1, padding=0):
+ """Create a (conv + bn) Sequential layer."""
+ conv_bn = Sequential()
+ conv_bn.add_module(
+ 'conv',
+ build_conv_layer(
+ self.conv_cfg,
+ in_channels=self.in_channels,
+ out_channels=self.out_channels,
+ kernel_size=kernel_size,
+ groups=self.groups,
+ stride=self.stride,
+ dilation=dilation,
+ padding=padding,
+ bias=False))
+ conv_bn.add_module(
+ 'norm',
+ build_norm_layer(self.norm_cfg, num_features=self.out_channels)[1])
+
+ return conv_bn
+
+ def forward(self, x):
+
+ def _inner_forward(inputs):
+ if self.deploy:
+ return self.branch_reparam(inputs)
+
+ inner_out = 0
+ if self.branch_norm is not None:
+ inner_out = self.branch_norm(inputs)
+
+ if self.branch_scale is not None:
+ inner_out += self.branch_scale(inputs)
+
+ for branch_conv in self.branch_conv_list:
+ inner_out += branch_conv(inputs)
+
+ return inner_out
+
+ return self.act(self.se(_inner_forward(x)))
+
+ def switch_to_deploy(self):
+ """Switch the model structure from training mode to deployment mode."""
+ if self.deploy:
+ return
+ assert self.norm_cfg['type'] == 'BN', \
+ "Switch is not allowed when norm_cfg['type'] != 'BN'."
+
+ reparam_weight, reparam_bias = self.reparameterize()
+ self.branch_reparam = build_conv_layer(
+ self.conv_cfg,
+ self.in_channels,
+ self.out_channels,
+ kernel_size=self.kernel_size,
+ stride=self.stride,
+ padding=self.padding,
+ dilation=self.dilation,
+ groups=self.groups,
+ bias=True)
+ self.branch_reparam.weight.data = reparam_weight
+ self.branch_reparam.bias.data = reparam_bias
+
+ for param in self.parameters():
+ param.detach_()
+ delattr(self, 'branch_conv_list')
+ if hasattr(self, 'branch_scale'):
+ delattr(self, 'branch_scale')
+ delattr(self, 'branch_norm')
+
+ self.deploy = True
+
+ def reparameterize(self):
+ """Fuse all the parameters of all branches.
+
+ Returns:
+ tuple[torch.Tensor, torch.Tensor]: Parameters after fusion of all
+ branches. The first element is the weight and the second is
+ the bias.
+ """
+ weight_conv, bias_conv = 0, 0
+ for branch_conv in self.branch_conv_list:
+ weight, bias = self._fuse_conv_bn(branch_conv)
+ weight_conv += weight
+ bias_conv += bias
+
+ weight_scale, bias_scale = 0, 0
+ if self.branch_scale is not None:
+ weight_scale, bias_scale = self._fuse_conv_bn(self.branch_scale)
+ # Pad scale branch kernel to match conv branch kernel size.
+ pad = self.kernel_size // 2
+ weight_scale = F.pad(weight_scale, [pad, pad, pad, pad])
+
+ weight_norm, bias_norm = 0, 0
+ if self.branch_norm:
+ tmp_conv_bn = self._norm_to_conv(self.branch_norm)
+ weight_norm, bias_norm = self._fuse_conv_bn(tmp_conv_bn)
+
+ return (weight_conv + weight_scale + weight_norm,
+ bias_conv + bias_scale + bias_norm)
+
+ def _fuse_conv_bn(self, branch):
+ """Fuse the parameters in a branch with a conv and bn.
+
+ Args:
+ branch (mmcv.runner.Sequential): A branch with conv and bn.
+
+ Returns:
+ tuple[torch.Tensor, torch.Tensor]: The parameters obtained after
+ fusing the parameters of conv and bn in one branch.
+ The first element is the weight and the second is the bias.
+ """
+ if branch is None:
+ return 0, 0
+ kernel = branch.conv.weight
+ running_mean = branch.norm.running_mean
+ running_var = branch.norm.running_var
+ gamma = branch.norm.weight
+ beta = branch.norm.bias
+ eps = branch.norm.eps
+
+ std = (running_var + eps).sqrt()
+ fused_weight = (gamma / std).reshape(-1, 1, 1, 1) * kernel
+ fused_bias = beta - running_mean * gamma / std
+
+ return fused_weight, fused_bias
+
+ def _norm_to_conv(self, branch_norm):
+ """Convert a norm layer to a conv-bn sequence with kernel size
+ ``self.kernel_size``.
+
+ Args:
+ branch_norm (nn.BatchNorm2d): A branch only with bn in the block.
+
+ Returns:
+ (mmcv.runner.Sequential): a sequential with conv and bn.
+ """
+ input_dim = self.in_channels // self.groups
+ conv_weight = torch.zeros(
+ (self.in_channels, input_dim, self.kernel_size, self.kernel_size),
+ dtype=branch_norm.weight.dtype)
+
+ for i in range(self.in_channels):
+ conv_weight[i, i % input_dim, self.kernel_size // 2,
+ self.kernel_size // 2] = 1
+ conv_weight = conv_weight.to(branch_norm.weight.device)
+
+ tmp_conv = self.create_conv_bn(kernel_size=self.kernel_size)
+ tmp_conv.conv.weight.data = conv_weight
+ tmp_conv.norm = branch_norm
+ return tmp_conv
+
+
+@MODELS.register_module()
+class MobileOne(BaseBackbone):
+ """MobileOne backbone.
+
+ A PyTorch implementation of: `An Improved One millisecond Mobile Backbone
+ <https://arxiv.org/abs/2206.04040>`_
+
+ Args:
+ arch (str | dict): MobileOne architecture. If use string, choose
+ from 's0', 's1', 's2', 's3' and 's4'. If use dict, it should
+ have below keys:
+
+ - num_blocks (Sequence[int]): Number of blocks in each stage.
+ - width_factor (Sequence[float]): Width factor in each stage.
+ - num_conv_branches (Sequence[int]): Number of conv branches
+ in each stage.
+ - num_se_blocks (Sequence[int]): Number of SE layers in each
+ stage; the SE layers are placed in the last blocks of each
+ stage.
+
+ Defaults to 's0'.
+ in_channels (int): Number of input image channels. Default: 3.
+ out_indices (Sequence[int] | int): Output from which stages.
+ Defaults to ``(3, )``.
+ frozen_stages (int): Stages to be frozen (all param fixed). -1 means
+ not freezing any parameters. Defaults to -1.
+ conv_cfg (dict | None): The config dict for conv layers.
+ Defaults to None.
+ norm_cfg (dict): The config dict for norm layers.
+ Defaults to ``dict(type='BN')``.
+ act_cfg (dict): Config dict for activation layer.
+ Defaults to ``dict(type='ReLU')``.
+ deploy (bool): Whether to switch the model structure to deployment
+ mode. Defaults to False.
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Defaults to False.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+
+ Example:
+ >>> from mmpretrain.models import MobileOne
+ >>> import torch
+ >>> x = torch.rand(1, 3, 224, 224)
+ >>> model = MobileOne("s0", out_indices=(0, 1, 2, 3))
+ >>> model.eval()
+ >>> outputs = model(x)
+ >>> for out in outputs:
+ ... print(tuple(out.shape))
+ (1, 48, 56, 56)
+ (1, 128, 28, 28)
+ (1, 256, 14, 14)
+ (1, 1024, 7, 7)
+ """
+
+ arch_zoo = {
+ 's0':
+ dict(
+ num_blocks=[2, 8, 10, 1],
+ width_factor=[0.75, 1.0, 1.0, 2.0],
+ num_conv_branches=[4, 4, 4, 4],
+ num_se_blocks=[0, 0, 0, 0]),
+ 's1':
+ dict(
+ num_blocks=[2, 8, 10, 1],
+ width_factor=[1.5, 1.5, 2.0, 2.5],
+ num_conv_branches=[1, 1, 1, 1],
+ num_se_blocks=[0, 0, 0, 0]),
+ 's2':
+ dict(
+ num_blocks=[2, 8, 10, 1],
+ width_factor=[1.5, 2.0, 2.5, 4.0],
+ num_conv_branches=[1, 1, 1, 1],
+ num_se_blocks=[0, 0, 0, 0]),
+ 's3':
+ dict(
+ num_blocks=[2, 8, 10, 1],
+ width_factor=[2.0, 2.5, 3.0, 4.0],
+ num_conv_branches=[1, 1, 1, 1],
+ num_se_blocks=[0, 0, 0, 0]),
+ 's4':
+ dict(
+ num_blocks=[2, 8, 10, 1],
+ width_factor=[3.0, 3.5, 3.5, 4.0],
+ num_conv_branches=[1, 1, 1, 1],
+ num_se_blocks=[0, 0, 5, 1])
+ }
+
+ def __init__(self,
+ arch,
+ in_channels=3,
+ out_indices=(3, ),
+ frozen_stages=-1,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ se_cfg=dict(ratio=16),
+ deploy=False,
+ norm_eval=False,
+ init_cfg=[
+ dict(type='Kaiming', layer=['Conv2d']),
+ dict(type='Constant', val=1, layer=['_BatchNorm'])
+ ]):
+ super(MobileOne, self).__init__(init_cfg)
+
+ if isinstance(arch, str):
+ assert arch in self.arch_zoo, f'"arch": "{arch}"' \
+ f' is not one of the {list(self.arch_zoo.keys())}'
+ arch = self.arch_zoo[arch]
+ elif not isinstance(arch, dict):
+ raise TypeError('Expect "arch" to be either a string '
+ f'or a dict, got {type(arch)}')
+
+ self.arch = arch
+ for k, value in self.arch.items():
+ assert isinstance(value, list) and len(value) == 4, \
+ f'the value of {k} in arch must be a list with 4 items.'
+
+ self.in_channels = in_channels
+ self.deploy = deploy
+ self.frozen_stages = frozen_stages
+ self.norm_eval = norm_eval
+
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.se_cfg = se_cfg
+ self.act_cfg = act_cfg
+
+ base_channels = [64, 128, 256, 512]
+ channels = min(64,
+ int(base_channels[0] * self.arch['width_factor'][0]))
+ self.stage0 = MobileOneBlock(
+ self.in_channels,
+ channels,
+ stride=2,
+ kernel_size=3,
+ num_convs=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ deploy=deploy)
+
+ self.in_planes = channels
+ self.stages = []
+ for i, num_blocks in enumerate(self.arch['num_blocks']):
+ planes = int(base_channels[i] * self.arch['width_factor'][i])
+
+ stage = self._make_stage(planes, num_blocks,
+ arch['num_se_blocks'][i],
+ arch['num_conv_branches'][i])
+
+ stage_name = f'stage{i + 1}'
+ self.add_module(stage_name, stage)
+ self.stages.append(stage_name)
+
+ if isinstance(out_indices, int):
+ out_indices = [out_indices]
+ assert isinstance(out_indices, Sequence), \
+ f'"out_indices" must be a sequence or int, ' \
+ f'get {type(out_indices)} instead.'
+ out_indices = list(out_indices)
+ for i, index in enumerate(out_indices):
+ if index < 0:
+ out_indices[i] = len(self.stages) + index
+ assert 0 <= out_indices[i] <= len(self.stages), \
+ f'Invalid out_indices {index}.'
+ self.out_indices = out_indices
+
+ def _make_stage(self, planes, num_blocks, num_se, num_conv_branches):
+ strides = [2] + [1] * (num_blocks - 1)
+ if num_se > num_blocks:
+ raise ValueError('Number of SE blocks cannot '
+ 'exceed number of layers.')
+ blocks = []
+ for i in range(num_blocks):
+ use_se = False
+ if i >= (num_blocks - num_se):
+ use_se = True
+
+ blocks.append(
+ # Depthwise conv
+ MobileOneBlock(
+ in_channels=self.in_planes,
+ out_channels=self.in_planes,
+ kernel_size=3,
+ num_convs=num_conv_branches,
+ stride=strides[i],
+ padding=1,
+ groups=self.in_planes,
+ se_cfg=self.se_cfg if use_se else None,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg,
+ deploy=self.deploy))
+
+ blocks.append(
+ # Pointwise conv
+ MobileOneBlock(
+ in_channels=self.in_planes,
+ out_channels=planes,
+ kernel_size=1,
+ num_convs=num_conv_branches,
+ stride=1,
+ padding=0,
+ se_cfg=self.se_cfg if use_se else None,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg,
+ deploy=self.deploy))
+
+ self.in_planes = planes
+
+ return Sequential(*blocks)
+
+ def forward(self, x):
+ x = self.stage0(x)
+ outs = []
+ for i, stage_name in enumerate(self.stages):
+ stage = getattr(self, stage_name)
+ x = stage(x)
+ if i in self.out_indices:
+ outs.append(x)
+
+ return tuple(outs)
+
+ def _freeze_stages(self):
+ if self.frozen_stages >= 0:
+ self.stage0.eval()
+ for param in self.stage0.parameters():
+ param.requires_grad = False
+ for i in range(self.frozen_stages):
+ stage = getattr(self, f'stage{i+1}')
+ stage.eval()
+ for param in stage.parameters():
+ param.requires_grad = False
+
+ def train(self, mode=True):
+ """Switch the model to train mode or not."""
+ super(MobileOne, self).train(mode)
+ self._freeze_stages()
+ if mode and self.norm_eval:
+ for m in self.modules():
+ if isinstance(m, _BatchNorm):
+ m.eval()
+
+ def switch_to_deploy(self):
+ """Switch the model to deploy mode, which has a smaller number of
+ parameters and computations."""
+ for m in self.modules():
+ if isinstance(m, MobileOneBlock):
+ m.switch_to_deploy()
+ self.deploy = True
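The point of the multi-branch ``MobileOneBlock`` is that ``switch_to_deploy()`` collapses it into a single convolution at inference time: each conv+BN branch is fused (the conv weight is scaled by gamma/std and the BN shift is folded into the bias) and the branches are summed. A minimal equivalence sketch, assuming the backbone is importable as in its own docstring example above:

```python
import torch
from mmpretrain.models import MobileOne  # assumed import path

model = MobileOne('s0', out_indices=(3, ))
model.eval()  # BN must use running stats for the fused conv to match

x = torch.rand(1, 3, 224, 224)
out_train = model(x)[0]

model.switch_to_deploy()  # every MobileOneBlock becomes one conv + bias
out_deploy = model(x)[0]

print(tuple(out_deploy.shape))                      # (1, 1024, 7, 7) for 's0'
print(float((out_train - out_deploy).abs().max()))  # should be tiny (fp round-off only)
```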
diff --git a/mmpretrain/models/backbones/mobilevit.py b/mmpretrain/models/backbones/mobilevit.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e4043fe46049a4d1bddecc6b7b3768236318e82
--- /dev/null
+++ b/mmpretrain/models/backbones/mobilevit.py
@@ -0,0 +1,431 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+from typing import Callable, Optional, Sequence
+
+import torch
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule, build_norm_layer
+from torch import nn
+
+from mmpretrain.registry import MODELS
+from .base_backbone import BaseBackbone
+from .mobilenet_v2 import InvertedResidual
+from .vision_transformer import TransformerEncoderLayer
+
+
+class MobileVitBlock(nn.Module):
+ """MobileViT block.
+
+ According to the paper, the MobileViT block consists of a local
+ representation, a transformer-as-convolution layer (a global
+ representation with unfolding and folding), and a final fusion layer.
+
+ Args:
+ in_channels (int): Number of input image channels.
+ transformer_dim (int): Number of transformer channels.
+ ffn_dim (int): Number of ffn channels in transformer block.
+ out_channels (int): Number of channels in output.
+ conv_ksize (int): Conv kernel size in local representation
+ and fusion. Defaults to 3.
+ conv_cfg (dict, optional): Config dict for convolution layer.
+ Defaults to None, which means using conv2d.
+ norm_cfg (dict, optional): Config dict for normalization layer.
+ Defaults to dict(type='BN').
+ act_cfg (dict, optional): Config dict for activation layer.
+ Defaults to dict(type='Swish').
+ num_transformer_blocks (int): Number of transformer blocks in
+ a MobileViT block. Defaults to 2.
+ patch_size (int): Patch size for unfolding and folding.
+ Defaults to 2.
+ num_heads (int): Number of heads in global representation.
+ Defaults to 4.
+ drop_rate (float): Probability of an element to be zeroed
+ after the feed forward layer. Defaults to 0.
+ attn_drop_rate (float): The dropout rate for attention output weights.
+ Defaults to 0.
+ drop_path_rate (float): Stochastic depth rate. Defaults to 0.
+ no_fusion (bool): Whether to remove the fusion layer.
+ Defaults to False.
+ transformer_norm_cfg (dict, optional): Config dict for normalization
+ layer in transformer. Defaults to dict(type='LN').
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ transformer_dim: int,
+ ffn_dim: int,
+ out_channels: int,
+ conv_ksize: int = 3,
+ conv_cfg: Optional[dict] = None,
+ norm_cfg: Optional[dict] = dict(type='BN'),
+ act_cfg: Optional[dict] = dict(type='Swish'),
+ num_transformer_blocks: int = 2,
+ patch_size: int = 2,
+ num_heads: int = 4,
+ drop_rate: float = 0.,
+ attn_drop_rate: float = 0.,
+ drop_path_rate: float = 0.,
+ no_fusion: bool = False,
+ transformer_norm_cfg: Callable = dict(type='LN'),
+ ):
+ super(MobileVitBlock, self).__init__()
+
+ self.local_rep = nn.Sequential(
+ ConvModule(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ kernel_size=conv_ksize,
+ padding=int((conv_ksize - 1) / 2),
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg),
+ ConvModule(
+ in_channels=in_channels,
+ out_channels=transformer_dim,
+ kernel_size=1,
+ bias=False,
+ conv_cfg=conv_cfg,
+ norm_cfg=None,
+ act_cfg=None),
+ )
+
+ global_rep = [
+ TransformerEncoderLayer(
+ embed_dims=transformer_dim,
+ num_heads=num_heads,
+ feedforward_channels=ffn_dim,
+ drop_rate=drop_rate,
+ attn_drop_rate=attn_drop_rate,
+ drop_path_rate=drop_path_rate,
+ qkv_bias=True,
+ act_cfg=dict(type='Swish'),
+ norm_cfg=transformer_norm_cfg)
+ for _ in range(num_transformer_blocks)
+ ]
+ global_rep.append(
+ build_norm_layer(transformer_norm_cfg, transformer_dim)[1])
+ self.global_rep = nn.Sequential(*global_rep)
+
+ self.conv_proj = ConvModule(
+ in_channels=transformer_dim,
+ out_channels=out_channels,
+ kernel_size=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+
+ if no_fusion:
+ self.conv_fusion = None
+ else:
+ self.conv_fusion = ConvModule(
+ in_channels=in_channels + out_channels,
+ out_channels=out_channels,
+ kernel_size=conv_ksize,
+ padding=int((conv_ksize - 1) / 2),
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+
+ self.patch_size = (patch_size, patch_size)
+ self.patch_area = self.patch_size[0] * self.patch_size[1]
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ shortcut = x
+
+ # Local representation
+ x = self.local_rep(x)
+
+ # Unfold (feature map -> patches)
+ patch_h, patch_w = self.patch_size
+ B, C, H, W = x.shape
+ new_h, new_w = math.ceil(H / patch_h) * patch_h, math.ceil(
+ W / patch_w) * patch_w
+ num_patch_h, num_patch_w = new_h // patch_h, new_w // patch_w # n_h, n_w # noqa
+ num_patches = num_patch_h * num_patch_w # N
+ interpolate = False
+ if new_h != H or new_w != W:
+ # Note: Padding can be done, but then it needs to be handled in attention function. # noqa
+ x = F.interpolate(
+ x, size=(new_h, new_w), mode='bilinear', align_corners=False)
+ interpolate = True
+
+ # [B, C, H, W] --> [B * C * n_h, n_w, p_h, p_w]
+ x = x.reshape(B * C * num_patch_h, patch_h, num_patch_w,
+ patch_w).transpose(1, 2)
+ # [B * C * n_h, n_w, p_h, p_w] --> [BP, N, C] where P = p_h * p_w and N = n_h * n_w # noqa
+ x = x.reshape(B, C, num_patches,
+ self.patch_area).transpose(1, 3).reshape(
+ B * self.patch_area, num_patches, -1)
+
+ # Global representations
+ x = self.global_rep(x)
+
+ # Fold (patch -> feature map)
+ # [B, P, N, C] --> [B*C*n_h, n_w, p_h, p_w]
+ x = x.contiguous().view(B, self.patch_area, num_patches, -1)
+ x = x.transpose(1, 3).reshape(B * C * num_patch_h, num_patch_w,
+ patch_h, patch_w)
+ # [B*C*n_h, n_w, p_h, p_w] --> [B*C*n_h, p_h, n_w, p_w] --> [B, C, H, W] # noqa
+ x = x.transpose(1, 2).reshape(B, C, num_patch_h * patch_h,
+ num_patch_w * patch_w)
+ if interpolate:
+ x = F.interpolate(
+ x, size=(H, W), mode='bilinear', align_corners=False)
+
+ x = self.conv_proj(x)
+ if self.conv_fusion is not None:
+ x = self.conv_fusion(torch.cat((shortcut, x), dim=1))
+ return x
+
+
+@MODELS.register_module()
+class MobileViT(BaseBackbone):
+ """MobileViT backbone.
+
+ A PyTorch implementation of: `MobileViT: Light-weight, General-purpose,
+ and Mobile-friendly Vision Transformer <https://arxiv.org/abs/2110.02178>`_
+
+ Modified from the `official repo
+ `_
+ and `timm
+ `_.
+
+ Args:
+ arch (str | List[list]): Architecture of MobileViT.
+
+ - If a string, choose from "small", "x_small" and "xx_small".
+
+ - If a list, every item should be also a list, and the first item
+ of the sub-list can be chosen from "mobilenetv2" and "mobilevit",
+ which indicates the type of this layer sequence. If "mobilenetv2",
+ the other items are the arguments of :attr:`~MobileViT.make_mobilenetv2_layer`
+ (except ``in_channels``) and if "mobilevit", the other items are
+ the arguments of :attr:`~MobileViT.make_mobilevit_layer`
+ (except ``in_channels``).
+
+ Defaults to "small".
+ in_channels (int): Number of input image channels. Defaults to 3.
+ stem_channels (int): Channels of stem layer. Defaults to 16.
+ last_exp_factor (int): Channels expand factor of last layer.
+ Defaults to 4.
+ out_indices (Sequence[int]): Output from which stages.
+ Defaults to (4, ).
+ frozen_stages (int): Stages to be frozen (all param fixed).
+ Defaults to -1, which means not freezing any parameters.
+ conv_cfg (dict, optional): Config dict for convolution layer.
+ Defaults to None, which means using conv2d.
+ norm_cfg (dict, optional): Config dict for normalization layer.
+ Defaults to dict(type='BN').
+ act_cfg (dict, optional): Config dict for activation layer.
+ Defaults to dict(type='Swish').
+ init_cfg (dict, optional): Initialization config dict.
+ """ # noqa
+
+ # Parameters to build layers. The first param is the type of layer.
+ # For `mobilenetv2` layer, the rest params from left to right are:
+ # out channels, stride, num of blocks, expand_ratio.
+ # For `mobilevit` layer, the rest params from left to right are:
+ # out channels, stride, transformer_channels, ffn channels,
+ # num of transformer blocks, expand_ratio.
+ arch_settings = {
+ 'small': [
+ ['mobilenetv2', 32, 1, 1, 4],
+ ['mobilenetv2', 64, 2, 3, 4],
+ ['mobilevit', 96, 2, 144, 288, 2, 4],
+ ['mobilevit', 128, 2, 192, 384, 4, 4],
+ ['mobilevit', 160, 2, 240, 480, 3, 4],
+ ],
+ 'x_small': [
+ ['mobilenetv2', 32, 1, 1, 4],
+ ['mobilenetv2', 48, 2, 3, 4],
+ ['mobilevit', 64, 2, 96, 192, 2, 4],
+ ['mobilevit', 80, 2, 120, 240, 4, 4],
+ ['mobilevit', 96, 2, 144, 288, 3, 4],
+ ],
+ 'xx_small': [
+ ['mobilenetv2', 16, 1, 1, 2],
+ ['mobilenetv2', 24, 2, 3, 2],
+ ['mobilevit', 48, 2, 64, 128, 2, 2],
+ ['mobilevit', 64, 2, 80, 160, 4, 2],
+ ['mobilevit', 80, 2, 96, 192, 3, 2],
+ ]
+ }
+
+ def __init__(self,
+ arch='small',
+ in_channels=3,
+ stem_channels=16,
+ last_exp_factor=4,
+ out_indices=(4, ),
+ frozen_stages=-1,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='Swish'),
+ init_cfg=[
+ dict(type='Kaiming', layer=['Conv2d']),
+ dict(
+ type='Constant',
+ val=1,
+ layer=['_BatchNorm', 'GroupNorm'])
+ ]):
+ super(MobileViT, self).__init__(init_cfg)
+ if isinstance(arch, str):
+ arch = arch.lower()
+ assert arch in self.arch_settings, \
+ f'Unavailable arch, please choose from ' \
+ f'({set(self.arch_settings)}) or pass a list.'
+ arch = self.arch_settings[arch]
+
+ self.arch = arch
+ self.num_stages = len(arch)
+
+ # check out indices and frozen stages
+ if isinstance(out_indices, int):
+ out_indices = [out_indices]
+ assert isinstance(out_indices, Sequence), \
+ f'"out_indices" must be a sequence or int, ' \
+ f'get {type(out_indices)} instead.'
+ out_indices = list(out_indices)
+ for i, index in enumerate(out_indices):
+ if index < 0:
+ out_indices[i] = self.num_stages + index
+ assert out_indices[i] >= 0, f'Invalid out_indices {index}'
+ self.out_indices = out_indices
+
+ if frozen_stages not in range(-1, self.num_stages):
+ raise ValueError('frozen_stages must be in range(-1, '
+ f'{self.num_stages}). '
+ f'But received {frozen_stages}')
+ self.frozen_stages = frozen_stages
+
+ _make_layer_func = {
+ 'mobilenetv2': self.make_mobilenetv2_layer,
+ 'mobilevit': self.make_mobilevit_layer,
+ }
+
+ self.stem = ConvModule(
+ in_channels=in_channels,
+ out_channels=stem_channels,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+
+ in_channels = stem_channels
+ layers = []
+ for i, layer_settings in enumerate(arch):
+ layer_type, settings = layer_settings[0], layer_settings[1:]
+ layer, out_channels = _make_layer_func[layer_type](in_channels,
+ *settings)
+ layers.append(layer)
+ in_channels = out_channels
+ self.layers = nn.Sequential(*layers)
+
+ self.conv_1x1_exp = ConvModule(
+ in_channels=in_channels,
+ out_channels=last_exp_factor * in_channels,
+ kernel_size=1,
+ stride=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+
+ @staticmethod
+ def make_mobilevit_layer(in_channels,
+ out_channels,
+ stride,
+ transformer_dim,
+ ffn_dim,
+ num_transformer_blocks,
+ expand_ratio=4):
+ """Build mobilevit layer, which consists of one InvertedResidual and
+ one MobileVitBlock.
+
+ Args:
+ in_channels (int): The input channels.
+ out_channels (int): The output channels.
+ stride (int): The stride of the first 3x3 convolution in the
+ ``InvertedResidual`` layers.
+ transformer_dim (int): The channels of the transformer layers.
+ ffn_dim (int): The mid-channels of the feedforward network in
+ transformer layers.
+ num_transformer_blocks (int): The number of transformer blocks.
+ expand_ratio (int): adjusts number of channels of the hidden layer
+ in ``InvertedResidual`` by this amount. Defaults to 4.
+ """
+ layer = []
+ layer.append(
+ InvertedResidual(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ stride=stride,
+ expand_ratio=expand_ratio,
+ act_cfg=dict(type='Swish'),
+ ))
+ layer.append(
+ MobileVitBlock(
+ in_channels=out_channels,
+ transformer_dim=transformer_dim,
+ ffn_dim=ffn_dim,
+ out_channels=out_channels,
+ num_transformer_blocks=num_transformer_blocks,
+ ))
+ return nn.Sequential(*layer), out_channels
+
+ @staticmethod
+ def make_mobilenetv2_layer(in_channels,
+ out_channels,
+ stride,
+ num_blocks,
+ expand_ratio=4):
+ """Build mobilenetv2 layer, which consists of several InvertedResidual
+ layers.
+
+ Args:
+ in_channels (int): The input channels.
+ out_channels (int): The output channels.
+ stride (int): The stride of the first 3x3 convolution in the
+ ``InvertedResidual`` layers.
+ num_blocks (int): The number of ``InvertedResidual`` blocks.
+ expand_ratio (int): adjusts number of channels of the hidden layer
+ in ``InvertedResidual`` by this amount. Defaults to 4.
+ """
+ layer = []
+ for i in range(num_blocks):
+ stride = stride if i == 0 else 1
+
+ layer.append(
+ InvertedResidual(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ stride=stride,
+ expand_ratio=expand_ratio,
+ act_cfg=dict(type='Swish'),
+ ))
+ in_channels = out_channels
+ return nn.Sequential(*layer), out_channels
+
+ def _freeze_stages(self):
+ for i in range(0, self.frozen_stages):
+ layer = self.layers[i]
+ layer.eval()
+ for param in layer.parameters():
+ param.requires_grad = False
+
+ def train(self, mode=True):
+ super(MobileViT, self).train(mode)
+ self._freeze_stages()
+
+ def forward(self, x):
+ x = self.stem(x)
+ outs = []
+ for i, layer in enumerate(self.layers):
+ x = layer(x)
+ if i == len(self.layers) - 1:
+ x = self.conv_1x1_exp(x)
+ if i in self.out_indices:
+ outs.append(x)
+
+ return tuple(outs)
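The trickiest part of ``MobileVitBlock.forward`` is the unfold/fold bookkeeping that turns a feature map into per-pixel patch tokens and back. The standalone sketch below (plain ``torch``, no mmpretrain imports) repeats exactly the reshapes used above and checks that folding inverts unfolding when the spatial size is a multiple of the patch size:

```python
import torch

B, C, H, W = 2, 8, 6, 6
p_h = p_w = 2                    # patch_size
n_h, n_w = H // p_h, W // p_w    # patches per spatial dim
N, P = n_h * n_w, p_h * p_w      # num_patches, patch_area

x = torch.randn(B, C, H, W)

# Unfold: (B, C, H, W) -> (B*P, N, C), as in MobileVitBlock.forward
u = x.reshape(B * C * n_h, p_h, n_w, p_w).transpose(1, 2)
u = u.reshape(B, C, N, P).transpose(1, 3).reshape(B * P, N, -1)

# Fold: (B*P, N, C) -> (B, C, H, W)
f = u.contiguous().view(B, P, N, -1)
f = f.transpose(1, 3).reshape(B * C * n_h, n_w, p_h, p_w)
f = f.transpose(1, 2).reshape(B, C, n_h * p_h, n_w * p_w)

assert torch.equal(f, x)  # fold is the exact inverse of unfold
```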
diff --git a/mmpretrain/models/backbones/mvit.py b/mmpretrain/models/backbones/mvit.py
new file mode 100644
index 0000000000000000000000000000000000000000..68aee97ddf3077ca58e488f38e9d9422b171d691
--- /dev/null
+++ b/mmpretrain/models/backbones/mvit.py
@@ -0,0 +1,700 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional, Sequence
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import build_activation_layer, build_norm_layer
+from mmcv.cnn.bricks import DropPath
+from mmcv.cnn.bricks.transformer import PatchEmbed
+from mmengine.model import BaseModule, ModuleList
+from mmengine.model.weight_init import trunc_normal_
+from mmengine.utils import to_2tuple
+
+from ..builder import BACKBONES
+from ..utils import resize_pos_embed
+from .base_backbone import BaseBackbone
+
+
+def resize_decomposed_rel_pos(rel_pos, q_size, k_size):
+ """Get relative positional embeddings according to the relative positions
+ of query and key sizes.
+
+ Args:
+ q_size (int): size of query q.
+ k_size (int): size of key k.
+ rel_pos (Tensor): relative position embeddings (L, C).
+
+ Returns:
+ Extracted positional embeddings according to relative positions.
+ """
+ max_rel_dist = int(2 * max(q_size, k_size) - 1)
+ # Interpolate rel pos if needed.
+ if rel_pos.shape[0] != max_rel_dist:
+ # Interpolate rel pos.
+ resized = F.interpolate(
+ # (L, C) -> (1, C, L)
+ rel_pos.transpose(0, 1).unsqueeze(0),
+ size=max_rel_dist,
+ mode='linear',
+ )
+ # (1, C, L) -> (L, C)
+ resized = resized.squeeze(0).transpose(0, 1)
+ else:
+ resized = rel_pos
+
+ # Scale the coords with short length if shapes for q and k are different.
+ q_h_ratio = max(k_size / q_size, 1.0)
+ k_h_ratio = max(q_size / k_size, 1.0)
+ q_coords = torch.arange(q_size)[:, None] * q_h_ratio
+ k_coords = torch.arange(k_size)[None, :] * k_h_ratio
+ relative_coords = (q_coords - k_coords) + (k_size - 1) * k_h_ratio
+
+ return resized[relative_coords.long()]
+
+
+def add_decomposed_rel_pos(attn,
+ q,
+ q_shape,
+ k_shape,
+ rel_pos_h,
+ rel_pos_w,
+ has_cls_token=False):
+ """Spatial Relative Positional Embeddings."""
+ sp_idx = 1 if has_cls_token else 0
+ B, num_heads, _, C = q.shape
+ q_h, q_w = q_shape
+ k_h, k_w = k_shape
+
+ Rh = resize_decomposed_rel_pos(rel_pos_h, q_h, k_h)
+ Rw = resize_decomposed_rel_pos(rel_pos_w, q_w, k_w)
+
+ r_q = q[:, :, sp_idx:].reshape(B, num_heads, q_h, q_w, C)
+ rel_h = torch.einsum('byhwc,hkc->byhwk', r_q, Rh)
+ rel_w = torch.einsum('byhwc,wkc->byhwk', r_q, Rw)
+ rel_pos_embed = rel_h[:, :, :, :, :, None] + rel_w[:, :, :, :, None, :]
+
+ attn_map = attn[:, :, sp_idx:, sp_idx:].view(B, -1, q_h, q_w, k_h, k_w)
+ attn_map += rel_pos_embed
+ attn[:, :, sp_idx:, sp_idx:] = attn_map.view(B, -1, q_h * q_w, k_h * k_w)
+
+ return attn
+
+
+class MLP(BaseModule):
+ """Two-layer multilayer perceptron.
+
+ Comparing with :class:`mmcv.cnn.bricks.transformer.FFN`, this class allows
+ different input and output channel numbers.
+
+ Args:
+ in_channels (int): The number of input channels.
+ hidden_channels (int, optional): The number of hidden layer channels.
+ If None, same as the ``in_channels``. Defaults to None.
+ out_channels (int, optional): The number of output channels. If None,
+ same as the ``in_channels``. Defaults to None.
+ act_cfg (dict): The config of activation function.
+ Defaults to ``dict(type='GELU')``.
+ init_cfg (dict, optional): The config of weight initialization.
+ Defaults to None.
+ """
+
+ def __init__(self,
+ in_channels,
+ hidden_channels=None,
+ out_channels=None,
+ act_cfg=dict(type='GELU'),
+ init_cfg=None):
+ super().__init__(init_cfg=init_cfg)
+ out_channels = out_channels or in_channels
+ hidden_channels = hidden_channels or in_channels
+ self.fc1 = nn.Linear(in_channels, hidden_channels)
+ self.act = build_activation_layer(act_cfg)
+ self.fc2 = nn.Linear(hidden_channels, out_channels)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.fc2(x)
+ return x
+
+
+def attention_pool(x: torch.Tensor,
+ pool: nn.Module,
+ in_size: tuple,
+ norm: Optional[nn.Module] = None):
+ """Pooling the feature tokens.
+
+ Args:
+ x (torch.Tensor): The input tensor, should be with shape
+ ``(B, num_heads, L, C)`` or ``(B, L, C)``.
+ pool (nn.Module): The pooling module.
+ in_size (Tuple[int]): The shape of the input feature map.
+ norm (nn.Module, optional): The normalization module.
+ Defaults to None.
+ """
+ ndim = x.ndim
+ if ndim == 4:
+ B, num_heads, L, C = x.shape
+ elif ndim == 3:
+ num_heads = 1
+ B, L, C = x.shape
+ else:
+ raise RuntimeError(f'Unsupported input dimension {x.shape}')
+
+ H, W = in_size
+ assert L == H * W
+
+ # (B, num_heads, H*W, C) -> (B*num_heads, C, H, W)
+ x = x.reshape(B * num_heads, H, W, C).permute(0, 3, 1, 2).contiguous()
+ x = pool(x)
+ out_size = x.shape[-2:]
+
+ # (B*num_heads, C, H', W') -> (B, num_heads, H'*W', C)
+ x = x.reshape(B, num_heads, C, -1).transpose(2, 3)
+
+ if norm is not None:
+ x = norm(x)
+
+ if ndim == 3:
+ x = x.squeeze(1)
+
+ return x, out_size
+
+
+class MultiScaleAttention(BaseModule):
+ """Multiscale Multi-head Attention block.
+
+ Args:
+ in_dims (int): Number of input channels.
+ out_dims (int): Number of output channels.
+ num_heads (int): Number of attention heads.
+ qkv_bias (bool): If True, add a learnable bias to query, key and
+ value. Defaults to True.
+ norm_cfg (dict): The config of normalization layers.
+ Defaults to ``dict(type='LN')``.
+ pool_kernel (tuple): kernel size for qkv pooling layers.
+ Defaults to (3, 3).
+ stride_q (int): stride size for q pooling layer. Defaults to 1.
+ stride_kv (int): stride size for kv pooling layer. Defaults to 1.
+ rel_pos_spatial (bool): Whether to enable the spatial relative
+ position embedding. Defaults to True.
+ residual_pooling (bool): Whether to enable the residual connection
+ after attention pooling. Defaults to True.
+ input_size (Tuple[int], optional): The input resolution, necessary
+ if enable the ``rel_pos_spatial``. Defaults to None.
+ rel_pos_zero_init (bool): If True, zero initialize relative
+ positional parameters. Defaults to False.
+ init_cfg (dict, optional): The config of weight initialization.
+ Defaults to None.
+ """
+
+ def __init__(self,
+ in_dims,
+ out_dims,
+ num_heads,
+ qkv_bias=True,
+ norm_cfg=dict(type='LN'),
+ pool_kernel=(3, 3),
+ stride_q=1,
+ stride_kv=1,
+ rel_pos_spatial=False,
+ residual_pooling=True,
+ input_size=None,
+ rel_pos_zero_init=False,
+ init_cfg=None):
+ super().__init__(init_cfg=init_cfg)
+ self.num_heads = num_heads
+ self.in_dims = in_dims
+ self.out_dims = out_dims
+
+ head_dim = out_dims // num_heads
+ self.scale = head_dim**-0.5
+
+ self.qkv = nn.Linear(in_dims, out_dims * 3, bias=qkv_bias)
+ self.proj = nn.Linear(out_dims, out_dims)
+
+ # qkv pooling
+ pool_padding = [k // 2 for k in pool_kernel]
+ pool_dims = out_dims // num_heads
+
+ def build_pooling(stride):
+ pool = nn.Conv2d(
+ pool_dims,
+ pool_dims,
+ pool_kernel,
+ stride=stride,
+ padding=pool_padding,
+ groups=pool_dims,
+ bias=False,
+ )
+ norm = build_norm_layer(norm_cfg, pool_dims)[1]
+ return pool, norm
+
+ self.pool_q, self.norm_q = build_pooling(stride_q)
+ self.pool_k, self.norm_k = build_pooling(stride_kv)
+ self.pool_v, self.norm_v = build_pooling(stride_kv)
+
+ self.residual_pooling = residual_pooling
+
+ self.rel_pos_spatial = rel_pos_spatial
+ self.rel_pos_zero_init = rel_pos_zero_init
+ if self.rel_pos_spatial:
+ # initialize relative positional embeddings
+ assert input_size[0] == input_size[1]
+
+ size = input_size[0]
+ rel_dim = 2 * max(size // stride_q, size // stride_kv) - 1
+ self.rel_pos_h = nn.Parameter(torch.zeros(rel_dim, head_dim))
+ self.rel_pos_w = nn.Parameter(torch.zeros(rel_dim, head_dim))
+
+ def init_weights(self):
+ """Weight initialization."""
+ super().init_weights()
+
+ if (isinstance(self.init_cfg, dict)
+ and self.init_cfg['type'] == 'Pretrained'):
+ # Suppress rel_pos_zero_init if use pretrained model.
+ return
+
+ if not self.rel_pos_zero_init:
+ trunc_normal_(self.rel_pos_h, std=0.02)
+ trunc_normal_(self.rel_pos_w, std=0.02)
+
+ def forward(self, x, in_size):
+ """Forward the MultiScaleAttention."""
+ B, N, _ = x.shape # (B, H*W, C)
+
+ # qkv: (B, H*W, 3, num_heads, C)
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, -1)
+ # q, k, v: (B, num_heads, H*W, C)
+ q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
+
+ q, q_shape = attention_pool(q, self.pool_q, in_size, norm=self.norm_q)
+ k, k_shape = attention_pool(k, self.pool_k, in_size, norm=self.norm_k)
+ v, v_shape = attention_pool(v, self.pool_v, in_size, norm=self.norm_v)
+
+ attn = (q * self.scale) @ k.transpose(-2, -1)
+ if self.rel_pos_spatial:
+ attn = add_decomposed_rel_pos(attn, q, q_shape, k_shape,
+ self.rel_pos_h, self.rel_pos_w)
+
+ attn = attn.softmax(dim=-1)
+ x = attn @ v
+
+ if self.residual_pooling:
+ x = x + q
+
+ # (B, num_heads, H'*W', C'//num_heads) -> (B, H'*W', C')
+ x = x.transpose(1, 2).reshape(B, -1, self.out_dims)
+ x = self.proj(x)
+
+ return x, q_shape
+
+
+class MultiScaleBlock(BaseModule):
+ """Multiscale Transformer blocks.
+
+ Args:
+ in_dims (int): Number of input channels.
+ out_dims (int): Number of output channels.
+ num_heads (int): Number of attention heads.
+ mlp_ratio (float): Ratio of hidden dimensions in MLP layers.
+ Defaults to 4.0.
+ qkv_bias (bool): If True, add a learnable bias to query, key and
+ value. Defaults to True.
+ drop_path (float): Stochastic depth rate. Defaults to 0.
+ norm_cfg (dict): The config of normalization layers.
+ Defaults to ``dict(type='LN')``.
+ act_cfg (dict): The config of activation function.
+ Defaults to ``dict(type='GELU')``.
+ qkv_pool_kernel (tuple): kernel size for qkv pooling layers.
+ Defaults to (3, 3).
+ stride_q (int): stride size for q pooling layer. Defaults to 1.
+ stride_kv (int): stride size for kv pooling layer. Defaults to 1.
+ rel_pos_spatial (bool): Whether to enable the spatial relative
+ position embedding. Defaults to True.
+ residual_pooling (bool): Whether to enable the residual connection
+ after attention pooling. Defaults to True.
+ dim_mul_in_attention (bool): Whether to multiply the ``embed_dims`` in
+ attention layers. If False, multiply it in MLP layers.
+ Defaults to True.
+ input_size (Tuple[int], optional): The input resolution, necessary
+ if enable the ``rel_pos_spatial``. Defaults to None.
+ rel_pos_zero_init (bool): If True, zero initialize relative
+ positional parameters. Defaults to False.
+ init_cfg (dict, optional): The config of weight initialization.
+ Defaults to None.
+ """
+
+ def __init__(
+ self,
+ in_dims,
+ out_dims,
+ num_heads,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ drop_path=0.0,
+ norm_cfg=dict(type='LN'),
+ act_cfg=dict(type='GELU'),
+ qkv_pool_kernel=(3, 3),
+ stride_q=1,
+ stride_kv=1,
+ rel_pos_spatial=True,
+ residual_pooling=True,
+ dim_mul_in_attention=True,
+ input_size=None,
+ rel_pos_zero_init=False,
+ init_cfg=None,
+ ):
+ super().__init__(init_cfg=init_cfg)
+ self.in_dims = in_dims
+ self.out_dims = out_dims
+ self.norm1 = build_norm_layer(norm_cfg, in_dims)[1]
+ self.dim_mul_in_attention = dim_mul_in_attention
+
+ attn_dims = out_dims if dim_mul_in_attention else in_dims
+ self.attn = MultiScaleAttention(
+ in_dims,
+ attn_dims,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ norm_cfg=norm_cfg,
+ pool_kernel=qkv_pool_kernel,
+ stride_q=stride_q,
+ stride_kv=stride_kv,
+ rel_pos_spatial=rel_pos_spatial,
+ residual_pooling=residual_pooling,
+ input_size=input_size,
+ rel_pos_zero_init=rel_pos_zero_init)
+ self.drop_path = DropPath(
+ drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.norm2 = build_norm_layer(norm_cfg, attn_dims)[1]
+
+ self.mlp = MLP(
+ in_channels=attn_dims,
+ hidden_channels=int(attn_dims * mlp_ratio),
+ out_channels=out_dims,
+ act_cfg=act_cfg)
+
+ if in_dims != out_dims:
+ self.proj = nn.Linear(in_dims, out_dims)
+ else:
+ self.proj = None
+
+ if stride_q > 1:
+ kernel_skip = stride_q + 1
+ padding_skip = int(kernel_skip // 2)
+ self.pool_skip = nn.MaxPool2d(
+ kernel_skip, stride_q, padding_skip, ceil_mode=False)
+
+ if input_size is not None:
+ input_size = to_2tuple(input_size)
+ out_size = [size // stride_q for size in input_size]
+ self.init_out_size = out_size
+ else:
+ self.init_out_size = None
+ else:
+ self.pool_skip = None
+ self.init_out_size = input_size
+
+ def forward(self, x, in_size):
+ x_norm = self.norm1(x)
+ x_attn, out_size = self.attn(x_norm, in_size)
+
+ if self.dim_mul_in_attention and self.proj is not None:
+ skip = self.proj(x_norm)
+ else:
+ skip = x
+
+ if self.pool_skip is not None:
+ skip, _ = attention_pool(skip, self.pool_skip, in_size)
+
+ x = skip + self.drop_path(x_attn)
+ x_norm = self.norm2(x)
+ x_mlp = self.mlp(x_norm)
+
+ if not self.dim_mul_in_attention and self.proj is not None:
+ skip = self.proj(x_norm)
+ else:
+ skip = x
+
+ x = skip + self.drop_path(x_mlp)
+
+ return x, out_size
+
+
+@BACKBONES.register_module()
+class MViT(BaseBackbone):
+ """Multi-scale ViT v2.
+
+ A PyTorch implementation of: `MViTv2: Improved Multiscale Vision Transformers
+ for Classification and Detection <https://arxiv.org/abs/2112.01526>`_
+
+ Inspiration from `the official implementation
+ `_ and `the detectron2
+ implementation `_
+
+ Args:
+ arch (str | dict): MViT architecture. If use string, choose
+ from 'tiny', 'small', 'base' and 'large'. If use dict, it should
+ have below keys:
+
+ - **embed_dims** (int): The dimensions of embedding.
+ - **num_layers** (int): The number of layers.
+ - **num_heads** (int): The number of heads in attention
+ modules of the initial layer.
+ - **downscale_indices** (List[int]): The layer indices to downscale
+ the feature map.
+
+ Defaults to 'base'.
+ img_size (int): The expected input image shape. Defaults to 224.
+ in_channels (int): The num of input channels. Defaults to 3.
+ out_scales (int | Sequence[int]): The output scale indices.
+ They should not exceed the length of ``downscale_indices``.
+ Defaults to -1, which means the last scale.
+ drop_path_rate (float): Stochastic depth rate. Defaults to 0.1.
+ use_abs_pos_embed (bool): If True, add absolute position embedding to
+ the patch embedding. Defaults to False.
+ interpolate_mode (str): Select the interpolate mode for absolute
+ position embedding vector resize. Defaults to "bicubic".
+ pool_kernel (tuple): kernel size for qkv pooling layers.
+ Defaults to (3, 3).
+ dim_mul (int): The magnification for ``embed_dims`` in the downscale
+ layers. Defaults to 2.
+ head_mul (int): The magnification for ``num_heads`` in the downscale
+ layers. Defaults to 2.
+ adaptive_kv_stride (int): The stride size for kv pooling in the initial
+ layer. Defaults to 4.
+ rel_pos_spatial (bool): Whether to enable the spatial relative position
+ embedding. Defaults to True.
+ residual_pooling (bool): Whether to enable the residual connection
+ after attention pooling. Defaults to True.
+ dim_mul_in_attention (bool): Whether to multiply the ``embed_dims`` in
+ attention layers. If False, multiply it in MLP layers.
+ Defaults to True.
+ rel_pos_zero_init (bool): If True, zero initialize relative
+ positional parameters. Defaults to False.
+ mlp_ratio (float): Ratio of hidden dimensions in MLP layers.
+ Defaults to 4.0.
+ qkv_bias (bool): enable bias for qkv if True. Defaults to True.
+ norm_cfg (dict): Config dict for normalization layer for all output
+ features. Defaults to ``dict(type='LN', eps=1e-6)``.
+ patch_cfg (dict): Config dict for the patch embedding layer.
+ Defaults to ``dict(kernel_size=7, stride=4, padding=3)``.
+ init_cfg (dict, optional): The Config for initialization.
+ Defaults to None.
+
+ Examples:
+ >>> import torch
+ >>> from mmpretrain.models import build_backbone
+ >>>
+ >>> cfg = dict(type='MViT', arch='tiny', out_scales=[0, 1, 2, 3])
+ >>> model = build_backbone(cfg)
+ >>> inputs = torch.rand(1, 3, 224, 224)
+ >>> outputs = model(inputs)
+ >>> for i, output in enumerate(outputs):
+ >>> print(f'scale{i}: {output.shape}')
+ scale0: torch.Size([1, 96, 56, 56])
+ scale1: torch.Size([1, 192, 28, 28])
+ scale2: torch.Size([1, 384, 14, 14])
+ scale3: torch.Size([1, 768, 7, 7])
+ """
+ arch_zoo = {
+ 'tiny': {
+ 'embed_dims': 96,
+ 'num_layers': 10,
+ 'num_heads': 1,
+ 'downscale_indices': [1, 3, 8]
+ },
+ 'small': {
+ 'embed_dims': 96,
+ 'num_layers': 16,
+ 'num_heads': 1,
+ 'downscale_indices': [1, 3, 14]
+ },
+ 'base': {
+ 'embed_dims': 96,
+ 'num_layers': 24,
+ 'num_heads': 1,
+ 'downscale_indices': [2, 5, 21]
+ },
+ 'large': {
+ 'embed_dims': 144,
+ 'num_layers': 48,
+ 'num_heads': 2,
+ 'downscale_indices': [2, 8, 44]
+ },
+ }
+ num_extra_tokens = 0
+
+ def __init__(self,
+ arch='base',
+ img_size=224,
+ in_channels=3,
+ out_scales=-1,
+ drop_path_rate=0.,
+ use_abs_pos_embed=False,
+ interpolate_mode='bicubic',
+ pool_kernel=(3, 3),
+ dim_mul=2,
+ head_mul=2,
+ adaptive_kv_stride=4,
+ rel_pos_spatial=True,
+ residual_pooling=True,
+ dim_mul_in_attention=True,
+ rel_pos_zero_init=False,
+ mlp_ratio=4.,
+ qkv_bias=True,
+ norm_cfg=dict(type='LN', eps=1e-6),
+ patch_cfg=dict(kernel_size=7, stride=4, padding=3),
+ init_cfg=None):
+ super().__init__(init_cfg)
+
+ if isinstance(arch, str):
+ arch = arch.lower()
+ assert arch in set(self.arch_zoo), \
+ f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
+ self.arch_settings = self.arch_zoo[arch]
+ else:
+ essential_keys = {
+ 'embed_dims', 'num_layers', 'num_heads', 'downscale_indices'
+ }
+ assert isinstance(arch, dict) and essential_keys <= set(arch), \
+ f'Custom arch needs a dict with keys {essential_keys}'
+ self.arch_settings = arch
+
+ self.embed_dims = self.arch_settings['embed_dims']
+ self.num_layers = self.arch_settings['num_layers']
+ self.num_heads = self.arch_settings['num_heads']
+ self.downscale_indices = self.arch_settings['downscale_indices']
+ self.num_scales = len(self.downscale_indices) + 1
+ self.stage_indices = {
+ index - 1: i
+ for i, index in enumerate(self.downscale_indices)
+ }
+ self.stage_indices[self.num_layers - 1] = self.num_scales - 1
+ self.use_abs_pos_embed = use_abs_pos_embed
+ self.interpolate_mode = interpolate_mode
+
+ if isinstance(out_scales, int):
+ out_scales = [out_scales]
+ assert isinstance(out_scales, Sequence), \
+ f'"out_scales" must be a sequence or int, ' \
+ f'get {type(out_scales)} instead.'
+ out_scales = list(out_scales)
+ for i, index in enumerate(out_scales):
+ if index < 0:
+ out_scales[i] = self.num_scales + index
+ assert 0 <= out_scales[i] <= self.num_scales, \
+ f'Invalid out_scales {index}'
+ self.out_scales = sorted(list(out_scales))
+
+ # Set patch embedding
+ _patch_cfg = dict(
+ in_channels=in_channels,
+ input_size=img_size,
+ embed_dims=self.embed_dims,
+ conv_type='Conv2d',
+ )
+ _patch_cfg.update(patch_cfg)
+ self.patch_embed = PatchEmbed(**_patch_cfg)
+ self.patch_resolution = self.patch_embed.init_out_size
+
+ # Set absolute position embedding
+ if self.use_abs_pos_embed:
+ num_patches = self.patch_resolution[0] * self.patch_resolution[1]
+ self.pos_embed = nn.Parameter(
+ torch.zeros(1, num_patches, self.embed_dims))
+
+ # stochastic depth decay rule
+ dpr = np.linspace(0, drop_path_rate, self.num_layers)
+
+ self.blocks = ModuleList()
+ out_dims_list = [self.embed_dims]
+ num_heads = self.num_heads
+ stride_kv = adaptive_kv_stride
+ input_size = self.patch_resolution
+ for i in range(self.num_layers):
+ if i in self.downscale_indices:
+ num_heads *= head_mul
+ stride_q = 2
+ stride_kv = max(stride_kv // 2, 1)
+ else:
+ stride_q = 1
+
+ # Set output embed_dims
+ if dim_mul_in_attention and i in self.downscale_indices:
+ # multiply embed_dims in downscale layers.
+ out_dims = out_dims_list[-1] * dim_mul
+ elif not dim_mul_in_attention and i + 1 in self.downscale_indices:
+ # multiply embed_dims before downscale layers.
+ out_dims = out_dims_list[-1] * dim_mul
+ else:
+ out_dims = out_dims_list[-1]
+
+ attention_block = MultiScaleBlock(
+ in_dims=out_dims_list[-1],
+ out_dims=out_dims,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop_path=dpr[i],
+ norm_cfg=norm_cfg,
+ qkv_pool_kernel=pool_kernel,
+ stride_q=stride_q,
+ stride_kv=stride_kv,
+ rel_pos_spatial=rel_pos_spatial,
+ residual_pooling=residual_pooling,
+ dim_mul_in_attention=dim_mul_in_attention,
+ input_size=input_size,
+ rel_pos_zero_init=rel_pos_zero_init)
+ self.blocks.append(attention_block)
+
+ input_size = attention_block.init_out_size
+ out_dims_list.append(out_dims)
+
+ if i in self.stage_indices:
+ stage_index = self.stage_indices[i]
+ if stage_index in self.out_scales:
+ norm_layer = build_norm_layer(norm_cfg, out_dims)[1]
+ self.add_module(f'norm{stage_index}', norm_layer)
+
+ def init_weights(self):
+ super().init_weights()
+
+ if (isinstance(self.init_cfg, dict)
+ and self.init_cfg['type'] == 'Pretrained'):
+ # Suppress default init if use pretrained model.
+ return
+
+ if self.use_abs_pos_embed:
+ trunc_normal_(self.pos_embed, std=0.02)
+
+ def forward(self, x):
+ """Forward the MViT."""
+ B = x.shape[0]
+ x, patch_resolution = self.patch_embed(x)
+
+ if self.use_abs_pos_embed:
+ x = x + resize_pos_embed(
+ self.pos_embed,
+ self.patch_resolution,
+ patch_resolution,
+ mode=self.interpolate_mode,
+ num_extra_tokens=self.num_extra_tokens)
+
+ outs = []
+ for i, block in enumerate(self.blocks):
+ x, patch_resolution = block(x, patch_resolution)
+
+ if i in self.stage_indices:
+ stage_index = self.stage_indices[i]
+ if stage_index in self.out_scales:
+ B, _, C = x.shape
+ x = getattr(self, f'norm{stage_index}')(x)
+ out = x.transpose(1, 2).reshape(B, C, *patch_resolution)
+ outs.append(out.contiguous())
+
+ return tuple(outs)
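Inside ``MultiScaleAttention``, the query/key/value pooling is delegated to the module-level helper ``attention_pool``, which reshapes the token sequence into a 2-D map, applies the pooling conv, and flattens it back. A small shape sketch, assuming the file is importable at the path shown in the diff header (adjust the import otherwise); the depthwise conv mirrors what ``build_pooling`` creates for ``stride_q=2``:

```python
import torch
import torch.nn as nn
from mmpretrain.models.backbones.mvit import attention_pool  # assumed module path

B, num_heads, H, W, C = 2, 4, 14, 14, 24   # C = per-head channels
tokens = torch.randn(B, num_heads, H * W, C)

# Depthwise 3x3 pooling conv with stride 2, as used for the q pooling path.
pool = nn.Conv2d(C, C, kernel_size=3, stride=2, padding=1, groups=C, bias=False)

pooled, out_size = attention_pool(tokens, pool, in_size=(H, W))
print(out_size)               # torch.Size([7, 7])
print(tuple(pooled.shape))    # (2, 4, 49, 24) -- tokens pooled onto the 7x7 grid
```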
diff --git a/mmpretrain/models/backbones/poolformer.py b/mmpretrain/models/backbones/poolformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2ad67043dbeb0ce6969c2770853342b30df2a74
--- /dev/null
+++ b/mmpretrain/models/backbones/poolformer.py
@@ -0,0 +1,416 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Sequence
+
+import torch
+import torch.nn as nn
+from mmcv.cnn.bricks import DropPath, build_activation_layer, build_norm_layer
+from mmengine.model import BaseModule
+
+from mmpretrain.registry import MODELS
+from .base_backbone import BaseBackbone
+
+
+class PatchEmbed(nn.Module):
+ """Patch Embedding module implemented by a layer of convolution.
+
+ Input: tensor in shape [B, C, H, W]
+ Output: tensor in shape [B, C, H/stride, W/stride]
+ Args:
+ patch_size (int): Patch size of the patch embedding. Defaults to 16.
+ stride (int): Stride of the patch embedding. Defaults to 16.
+ padding (int): Padding of the patch embedding. Defaults to 0.
+ in_chans (int): Input channels. Defaults to 3.
+ embed_dim (int): Output dimension of the patch embedding.
+ Defaults to 768.
+ norm_layer (module): Normalization module. Defaults to None (not use).
+ """
+
+ def __init__(self,
+ patch_size=16,
+ stride=16,
+ padding=0,
+ in_chans=3,
+ embed_dim=768,
+ norm_layer=None):
+ super().__init__()
+ self.proj = nn.Conv2d(
+ in_chans,
+ embed_dim,
+ kernel_size=patch_size,
+ stride=stride,
+ padding=padding)
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ def forward(self, x):
+ x = self.proj(x)
+ x = self.norm(x)
+ return x
+
+
+class Pooling(nn.Module):
+ """Pooling module.
+
+ Args:
+ pool_size (int): Pooling size. Defaults to 3.
+ """
+
+ def __init__(self, pool_size=3):
+ super().__init__()
+ self.pool = nn.AvgPool2d(
+ pool_size,
+ stride=1,
+ padding=pool_size // 2,
+ count_include_pad=False)
+
+ def forward(self, x):
+ return self.pool(x) - x
+
+
+class Mlp(nn.Module):
+ """MLP implemented with 1x1 convolutions.
+
+ Input: Tensor with shape [B, C, H, W].
+ Output: Tensor with shape [B, C, H, W].
+ Args:
+ in_features (int): Dimension of input features.
+ hidden_features (int): Dimension of hidden features.
+ out_features (int): Dimension of output features.
+ act_cfg (dict): The config dict for activation between pointwise
+ convolution. Defaults to ``dict(type='GELU')``.
+ drop (float): Dropout rate. Defaults to 0.0.
+ """
+
+ def __init__(self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ act_cfg=dict(type='GELU'),
+ drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
+ self.act = build_activation_layer(act_cfg)
+ self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class PoolFormerBlock(BaseModule):
+ """PoolFormer Block.
+
+ Args:
+ dim (int): Embedding dim.
+ pool_size (int): Pooling size. Defaults to 3.
+ mlp_ratio (float): Mlp expansion ratio. Defaults to 4.
+ norm_cfg (dict): The config dict for norm layers.
+ Defaults to ``dict(type='GN', num_groups=1)``.
+ act_cfg (dict): The config dict for activation between pointwise
+ convolution. Defaults to ``dict(type='GELU')``.
+ drop (float): Dropout rate. Defaults to 0.
+ drop_path (float): Stochastic depth rate. Defaults to 0.
+ layer_scale_init_value (float): Init value for Layer Scale.
+ Defaults to 1e-5.
+ """
+
+ def __init__(self,
+ dim,
+ pool_size=3,
+ mlp_ratio=4.,
+ norm_cfg=dict(type='GN', num_groups=1),
+ act_cfg=dict(type='GELU'),
+ drop=0.,
+ drop_path=0.,
+ layer_scale_init_value=1e-5):
+
+ super().__init__()
+
+ self.norm1 = build_norm_layer(norm_cfg, dim)[1]
+ self.token_mixer = Pooling(pool_size=pool_size)
+ self.norm2 = build_norm_layer(norm_cfg, dim)[1]
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_cfg=act_cfg,
+ drop=drop)
+
+ # The following two techniques are useful to train deep PoolFormers.
+ self.drop_path = DropPath(drop_path) if drop_path > 0. \
+ else nn.Identity()
+ self.layer_scale_1 = nn.Parameter(
+ layer_scale_init_value * torch.ones((dim)), requires_grad=True)
+ self.layer_scale_2 = nn.Parameter(
+ layer_scale_init_value * torch.ones((dim)), requires_grad=True)
+
+ def forward(self, x):
+ x = x + self.drop_path(
+ self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) *
+ self.token_mixer(self.norm1(x)))
+ x = x + self.drop_path(
+ self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) *
+ self.mlp(self.norm2(x)))
+ return x
+
+
+def basic_blocks(dim,
+ index,
+ layers,
+ pool_size=3,
+ mlp_ratio=4.,
+ norm_cfg=dict(type='GN', num_groups=1),
+ act_cfg=dict(type='GELU'),
+ drop_rate=.0,
+ drop_path_rate=0.,
+ layer_scale_init_value=1e-5):
+ """Generate PoolFormer blocks for a stage.
+
+ Returns:
+ nn.Sequential: PoolFormer blocks for a stage.
+ """
+ blocks = []
+ for block_idx in range(layers[index]):
+ block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / (
+ sum(layers) - 1)
+ blocks.append(
+ PoolFormerBlock(
+ dim,
+ pool_size=pool_size,
+ mlp_ratio=mlp_ratio,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ drop=drop_rate,
+ drop_path=block_dpr,
+ layer_scale_init_value=layer_scale_init_value,
+ ))
+ blocks = nn.Sequential(*blocks)
+
+ return blocks
+
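Two details above are easy to miss: ``Pooling`` returns ``pool(x) - x``, so a spatially constant input yields an exactly zero token-mixing response, and ``basic_blocks`` chains such blocks into one resolution-preserving stage (downsampling happens in separate patch embeddings between stages). A short sketch, assuming the module is importable at the path in the diff header:

```python
import torch
from mmpretrain.models.backbones.poolformer import Pooling, basic_blocks  # assumed path

# The token mixer only responds to deviations from the local mean.
mixer = Pooling(pool_size=3)
const = torch.ones(1, 4, 8, 8)
print(float(mixer(const).abs().max()))  # 0.0 -- average pooling of a constant is the constant

# One stage of an s12-like model: dim 64, layers[0] = 2 blocks.
stage = basic_blocks(dim=64, index=0, layers=[2, 2, 6, 2])
x = torch.randn(1, 64, 56, 56)
print(tuple(stage(x).shape))  # (1, 64, 56, 56) -- the stage keeps the resolution
```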
+
+@MODELS.register_module()
+class PoolFormer(BaseBackbone):
+ """PoolFormer.
+
+ A PyTorch implementation of PoolFormer introduced by:
+ `MetaFormer is Actually What You Need for Vision <https://arxiv.org/abs/2111.11418>`_
+
+ Modified from the `official repo
+ `.
+
+ Args:
+ arch (str | dict): The model's architecture. If a string, it should be
+ one of the architectures in ``PoolFormer.arch_settings``. If a dict, it
+ should include the following keys:
+
+ - layers (list[int]): Number of blocks at each stage.
+ - embed_dims (list[int]): The number of channels at each stage.
+ - mlp_ratios (list[int]): Expansion ratio of MLPs.
+ - layer_scale_init_value (float): Init value for Layer Scale.
+
+ Defaults to 'S12'.
+
+ norm_cfg (dict): The config dict for norm layers.
+ Defaults to ``dict(type='GN', num_groups=1)``.
+ act_cfg (dict): The config dict for activation between pointwise
+ convolution. Defaults to ``dict(type='GELU')``.
+ in_patch_size (int): The patch size of input image patch embedding.
+ Defaults to 7.
+ in_stride (int): The stride of input image patch embedding.
+ Defaults to 4.
+ in_pad (int): The padding of input image patch embedding.
+ Defaults to 2.
+ down_patch_size (int): The patch size of downsampling patch embedding.
+ Defaults to 3.
+ down_stride (int): The stride of downsampling patch embedding.
+ Defaults to 2.
+ down_pad (int): The padding of downsampling patch embedding.
+ Defaults to 1.
+ drop_rate (float): Dropout rate. Defaults to 0.
+ drop_path_rate (float): Stochastic depth rate. Defaults to 0.
+ out_indices (Sequence | int): Output from which network positions.
+ Indices 0-6 respectively correspond to
+ [stage1, downsampling, stage2, downsampling, stage3, downsampling, stage4].
+ Defaults to -1, which means the last stage.
+ frozen_stages (int): Stages to be frozen (all param fixed).
+ Defaults to 0, which means not freezing any parameters.
+ init_cfg (dict, optional): Initialization config dict
+ """ # noqa: E501
+
+ # --layers: [x, x, x, x], number of blocks in each of the four stages
+ # --embed_dims, --mlp_ratios:
+ # embedding dims and MLP ratios for the four stages
+ arch_settings = {
+ 's12': {
+ 'layers': [2, 2, 6, 2],
+ 'embed_dims': [64, 128, 320, 512],
+ 'mlp_ratios': [4, 4, 4, 4],
+ 'layer_scale_init_value': 1e-5,
+ },
+ 's24': {
+ 'layers': [4, 4, 12, 4],
+ 'embed_dims': [64, 128, 320, 512],
+ 'mlp_ratios': [4, 4, 4, 4],
+ 'layer_scale_init_value': 1e-5,
+ },
+ 's36': {
+ 'layers': [6, 6, 18, 6],
+ 'embed_dims': [64, 128, 320, 512],
+ 'mlp_ratios': [4, 4, 4, 4],
+ 'layer_scale_init_value': 1e-6,
+ },
+ 'm36': {
+ 'layers': [6, 6, 18, 6],
+ 'embed_dims': [96, 192, 384, 768],
+ 'mlp_ratios': [4, 4, 4, 4],
+ 'layer_scale_init_value': 1e-6,
+ },
+ 'm48': {
+ 'layers': [8, 8, 24, 8],
+ 'embed_dims': [96, 192, 384, 768],
+ 'mlp_ratios': [4, 4, 4, 4],
+ 'layer_scale_init_value': 1e-6,
+ },
+ }
+
+ def __init__(self,
+ arch='s12',
+ pool_size=3,
+ norm_cfg=dict(type='GN', num_groups=1),
+ act_cfg=dict(type='GELU'),
+ in_patch_size=7,
+ in_stride=4,
+ in_pad=2,
+ down_patch_size=3,
+ down_stride=2,
+ down_pad=1,
+ drop_rate=0.,
+ drop_path_rate=0.,
+ out_indices=-1,
+ frozen_stages=0,
+ init_cfg=None):
+
+ super().__init__(init_cfg=init_cfg)
+
+ if isinstance(arch, str):
+ assert arch in self.arch_settings, \
+ f'Unavailable arch, please choose from ' \
+ f'({set(self.arch_settings)}) or pass a dict.'
+ arch = self.arch_settings[arch]
+ elif isinstance(arch, dict):
+ assert 'layers' in arch and 'embed_dims' in arch, \
+ f'The arch dict must have "layers" and "embed_dims", ' \
+ f'but got {list(arch.keys())}.'
+
+ layers = arch['layers']
+ embed_dims = arch['embed_dims']
+ mlp_ratios = arch['mlp_ratios'] \
+ if 'mlp_ratios' in arch else [4, 4, 4, 4]
+ layer_scale_init_value = arch['layer_scale_init_value'] \
+ if 'layer_scale_init_value' in arch else 1e-5
+
+ self.patch_embed = PatchEmbed(
+ patch_size=in_patch_size,
+ stride=in_stride,
+ padding=in_pad,
+ in_chans=3,
+ embed_dim=embed_dims[0])
+
+ # set the main block in network
+ network = []
+ for i in range(len(layers)):
+ stage = basic_blocks(
+ embed_dims[i],
+ i,
+ layers,
+ pool_size=pool_size,
+ mlp_ratio=mlp_ratios[i],
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ drop_rate=drop_rate,
+ drop_path_rate=drop_path_rate,
+ layer_scale_init_value=layer_scale_init_value)
+ network.append(stage)
+ if i >= len(layers) - 1:
+ break
+ if embed_dims[i] != embed_dims[i + 1]:
+ # downsampling between two stages
+ network.append(
+ PatchEmbed(
+ patch_size=down_patch_size,
+ stride=down_stride,
+ padding=down_pad,
+ in_chans=embed_dims[i],
+ embed_dim=embed_dims[i + 1]))
+
+ self.network = nn.ModuleList(network)
+
+ if isinstance(out_indices, int):
+ out_indices = [out_indices]
+ assert isinstance(out_indices, Sequence), \
+ f'"out_indices" must be a sequence or int, ' \
+ f'got {type(out_indices)} instead.'
+ for i, index in enumerate(out_indices):
+ if index < 0:
+ out_indices[i] = 7 + index
+ assert out_indices[i] >= 0, f'Invalid out_indices {index}'
+ self.out_indices = out_indices
+ if self.out_indices:
+ for i_layer in self.out_indices:
+ layer = build_norm_layer(norm_cfg,
+ embed_dims[(i_layer + 1) // 2])[1]
+ layer_name = f'norm{i_layer}'
+ self.add_module(layer_name, layer)
+
+ self.frozen_stages = frozen_stages
+ self._freeze_stages()
+
+ def forward_embeddings(self, x):
+ x = self.patch_embed(x)
+ return x
+
+ def forward_tokens(self, x):
+ outs = []
+ for idx, block in enumerate(self.network):
+ x = block(x)
+ if idx in self.out_indices:
+ norm_layer = getattr(self, f'norm{idx}')
+ x_out = norm_layer(x)
+ outs.append(x_out)
+ return tuple(outs)
+
+ def forward(self, x):
+ # input embedding
+ x = self.forward_embeddings(x)
+ # through backbone
+ x = self.forward_tokens(x)
+ return x
+
+ def _freeze_stages(self):
+ if self.frozen_stages >= 0:
+ self.patch_embed.eval()
+ for param in self.patch_embed.parameters():
+ param.requires_grad = False
+
+ for i in range(self.frozen_stages):
+ # Include both block and downsample layer.
+ module = self.network[i]
+ module.eval()
+ for param in module.parameters():
+ param.requires_grad = False
+ if i in self.out_indices:
+ norm_layer = getattr(self, f'norm{i}')
+ norm_layer.eval()
+ for param in norm_layer.parameters():
+ param.requires_grad = False
+
+ def train(self, mode=True):
+ super(PoolFormer, self).train(mode)
+ self._freeze_stages()
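+
+
+# Usage sketch (illustrative, not part of the upstream file). With the
+# default out_indices=-1 the backbone returns a single feature map from the
+# last stage; for a 224x224 input and the 's12' setting the expected shape
+# is roughly (1, 512, 7, 7), i.e. a stride-4 patch embedding followed by
+# three stride-2 downsampling layers.
+def _poolformer_usage_sketch():
+    model = PoolFormer(arch='s12')
+    model.eval()
+    with torch.no_grad():
+        feats = model(torch.rand(1, 3, 224, 224))
+    for feat in feats:
+        print(tuple(feat.shape))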
diff --git a/mmpretrain/models/backbones/regnet.py b/mmpretrain/models/backbones/regnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..85dbdef0bfeb607ecddff1d68d1cf405b61bea65
--- /dev/null
+++ b/mmpretrain/models/backbones/regnet.py
@@ -0,0 +1,312 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import torch.nn as nn
+from mmcv.cnn import build_conv_layer, build_norm_layer
+
+from mmpretrain.registry import MODELS
+from .resnet import ResNet
+from .resnext import Bottleneck
+
+
+@MODELS.register_module()
+class RegNet(ResNet):
+ """RegNet backbone.
+
+ More details can be found in `paper <https://arxiv.org/abs/2003.13678>`_.
+
+ Args:
+ arch (dict): The parameter of RegNets.
+ - w0 (int): initial width
+ - wa (float): slope of width
+ - wm (float): quantization parameter to quantize the width
+ - depth (int): depth of the backbone
+ - group_w (int): width of group
+ - bot_mul (float): bottleneck ratio, i.e. expansion of bottleneck.
+ strides (Sequence[int]): Strides of the first block of each stage.
+ base_channels (int): Base channels after stem layer.
+ in_channels (int): Number of input image channels. Default: 3.
+ dilations (Sequence[int]): Dilation of each stage.
+ out_indices (Sequence[int]): Output from which stages.
+ style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
+ layer is the 3x3 conv layer, otherwise the stride-two layer is
+ the first 1x1 conv layer. Default: "pytorch".
+ frozen_stages (int): Stages to be frozen (all param fixed). -1 means
+ not freezing any parameters. Default: -1.
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: dict(type='BN', requires_grad=True).
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Default: False.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ zero_init_residual (bool): whether to use zero init for last norm layer
+ in resblocks to let them behave as identity. Default: True.
+
+ Example:
+ >>> from mmpretrain.models import RegNet
+ >>> import torch
+ >>> self = RegNet(
+ ...     arch=dict(
+ ...         w0=88,
+ ...         wa=26.31,
+ ...         wm=2.25,
+ ...         group_w=48,
+ ...         depth=25,
+ ...         bot_mul=1.0))
+ >>> self.eval()
+ >>> inputs = torch.rand(1, 3, 32, 32)
+ >>> level_outputs = self.forward(inputs)
+ >>> for level_out in level_outputs:
+ ... print(tuple(level_out.shape))
+ (1, 96, 8, 8)
+ (1, 192, 4, 4)
+ (1, 432, 2, 2)
+ (1, 1008, 1, 1)
+ """
+ arch_settings = {
+ 'regnetx_400mf':
+ dict(w0=24, wa=24.48, wm=2.54, group_w=16, depth=22, bot_mul=1.0),
+ 'regnetx_800mf':
+ dict(w0=56, wa=35.73, wm=2.28, group_w=16, depth=16, bot_mul=1.0),
+ 'regnetx_1.6gf':
+ dict(w0=80, wa=34.01, wm=2.25, group_w=24, depth=18, bot_mul=1.0),
+ 'regnetx_3.2gf':
+ dict(w0=88, wa=26.31, wm=2.25, group_w=48, depth=25, bot_mul=1.0),
+ 'regnetx_4.0gf':
+ dict(w0=96, wa=38.65, wm=2.43, group_w=40, depth=23, bot_mul=1.0),
+ 'regnetx_6.4gf':
+ dict(w0=184, wa=60.83, wm=2.07, group_w=56, depth=17, bot_mul=1.0),
+ 'regnetx_8.0gf':
+ dict(w0=80, wa=49.56, wm=2.88, group_w=120, depth=23, bot_mul=1.0),
+ 'regnetx_12gf':
+ dict(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19, bot_mul=1.0),
+ }
+
+ def __init__(self,
+ arch,
+ in_channels=3,
+ stem_channels=32,
+ base_channels=32,
+ strides=(2, 2, 2, 2),
+ dilations=(1, 1, 1, 1),
+ out_indices=(3, ),
+ style='pytorch',
+ deep_stem=False,
+ avg_down=False,
+ frozen_stages=-1,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ norm_eval=False,
+ with_cp=False,
+ zero_init_residual=True,
+ init_cfg=None):
+ super(ResNet, self).__init__(init_cfg)
+
+ # Generate RegNet parameters first
+ if isinstance(arch, str):
+ assert arch in self.arch_settings, \
+ f'"arch": "{arch}" is not one of the' \
+ ' arch_settings'
+ arch = self.arch_settings[arch]
+ elif not isinstance(arch, dict):
+ raise TypeError('Expect "arch" to be either a string '
+ f'or a dict, got {type(arch)}')
+
+ widths, num_stages = self.generate_regnet(
+ arch['w0'],
+ arch['wa'],
+ arch['wm'],
+ arch['depth'],
+ )
+ # Convert to per stage format
+ stage_widths, stage_blocks = self.get_stages_from_blocks(widths)
+ # Generate group widths and bot muls
+ group_widths = [arch['group_w'] for _ in range(num_stages)]
+ self.bottleneck_ratio = [arch['bot_mul'] for _ in range(num_stages)]
+ # Adjust the compatibility of stage_widths and group_widths
+ stage_widths, group_widths = self.adjust_width_group(
+ stage_widths, self.bottleneck_ratio, group_widths)
+
+ # Group params by stage
+ self.stage_widths = stage_widths
+ self.group_widths = group_widths
+ self.depth = sum(stage_blocks)
+ self.stem_channels = stem_channels
+ self.base_channels = base_channels
+ self.num_stages = num_stages
+ assert num_stages >= 1 and num_stages <= 4
+ self.strides = strides
+ self.dilations = dilations
+ assert len(strides) == len(dilations) == num_stages
+ self.out_indices = out_indices
+ assert max(out_indices) < num_stages
+ self.style = style
+ self.deep_stem = deep_stem
+ if self.deep_stem:
+ raise NotImplementedError(
+ 'deep_stem has not been implemented for RegNet')
+ self.avg_down = avg_down
+ self.frozen_stages = frozen_stages
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.with_cp = with_cp
+ self.norm_eval = norm_eval
+ self.zero_init_residual = zero_init_residual
+ self.stage_blocks = stage_blocks[:num_stages]
+
+ self._make_stem_layer(in_channels, stem_channels)
+
+ _in_channels = stem_channels
+ self.res_layers = []
+ for i, num_blocks in enumerate(self.stage_blocks):
+ stride = self.strides[i]
+ dilation = self.dilations[i]
+ group_width = self.group_widths[i]
+ width = int(round(self.stage_widths[i] * self.bottleneck_ratio[i]))
+ stage_groups = width // group_width
+
+ res_layer = self.make_res_layer(
+ block=Bottleneck,
+ num_blocks=num_blocks,
+ in_channels=_in_channels,
+ out_channels=self.stage_widths[i],
+ expansion=1,
+ stride=stride,
+ dilation=dilation,
+ style=self.style,
+ avg_down=self.avg_down,
+ with_cp=self.with_cp,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ base_channels=self.stage_widths[i],
+ groups=stage_groups,
+ width_per_group=group_width)
+ _in_channels = self.stage_widths[i]
+ layer_name = f'layer{i + 1}'
+ self.add_module(layer_name, res_layer)
+ self.res_layers.append(layer_name)
+
+ self._freeze_stages()
+
+ self.feat_dim = stage_widths[-1]
+
+ def _make_stem_layer(self, in_channels, base_channels):
+ self.conv1 = build_conv_layer(
+ self.conv_cfg,
+ in_channels,
+ base_channels,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False)
+ self.norm1_name, norm1 = build_norm_layer(
+ self.norm_cfg, base_channels, postfix=1)
+ self.add_module(self.norm1_name, norm1)
+ self.relu = nn.ReLU(inplace=True)
+
+ def generate_regnet(self,
+ initial_width,
+ width_slope,
+ width_parameter,
+ depth,
+ divisor=8):
+ """Generates per block width from RegNet parameters.
+
+ Args:
+ initial_width ([int]): Initial width of the backbone
+ width_slope ([float]): Slope of the quantized linear function
+ width_parameter ([int]): Parameter used to quantize the width.
+ depth ([int]): Depth of the backbone.
+ divisor (int): The divisor of channels. Defaults to 8.
+
+ Returns:
+ tuple: tuple containing:
+ - list: Widths of each stage.
+ - int: The number of stages.
+ """
+ assert width_slope >= 0
+ assert initial_width > 0
+ assert width_parameter > 1
+ assert initial_width % divisor == 0
+ widths_cont = np.arange(depth) * width_slope + initial_width
+ ks = np.round(
+ np.log(widths_cont / initial_width) / np.log(width_parameter))
+ widths = initial_width * np.power(width_parameter, ks)
+ widths = np.round(np.divide(widths, divisor)) * divisor
+ num_stages = len(np.unique(widths))
+ widths, widths_cont = widths.astype(int).tolist(), widths_cont.tolist()
+ return widths, num_stages
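+
+    # Worked example (illustrative comment, not in the upstream file): for
+    # the 'regnetx_400mf' setting (w0=24, wa=24.48, wm=2.54, depth=22), the
+    # continuous per-block widths are 24 + 24.48 * j for j = 0..21. Each is
+    # then snapped to w0 * wm**k with k = round(log(w/w0) / log(wm)) and
+    # rounded to a multiple of divisor=8, which collapses the 22 per-block
+    # widths into a small set of distinct values -- one per stage.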
+
+ @staticmethod
+ def quantize_float(number, divisor):
+ """Converts a float to the closest non-zero int divisible by divisor.
+
+ Args:
+ number (int): Original number to be quantized.
+ divisor (int): Divisor used to quantize the number.
+
+ Returns:
+ int: Quantized number that is divisible by divisor.
+ """
+ return int(round(number / divisor) * divisor)
+
+ def adjust_width_group(self, widths, bottleneck_ratio, groups):
+ """Adjusts the compatibility of widths and groups.
+
+ Args:
+ widths (list[int]): Width of each stage.
+ bottleneck_ratio (list[float]): Bottleneck ratio of each stage.
+ groups (list[int]): Number of groups of each stage.
+
+ Returns:
+ tuple(list): The adjusted widths and groups of each stage.
+ """
+ bottleneck_width = [
+ int(w * b) for w, b in zip(widths, bottleneck_ratio)
+ ]
+ groups = [min(g, w_bot) for g, w_bot in zip(groups, bottleneck_width)]
+ bottleneck_width = [
+ self.quantize_float(w_bot, g)
+ for w_bot, g in zip(bottleneck_width, groups)
+ ]
+ widths = [
+ int(w_bot / b)
+ for w_bot, b in zip(bottleneck_width, bottleneck_ratio)
+ ]
+ return widths, groups
+
+ def get_stages_from_blocks(self, widths):
+ """Gets widths/stage_blocks of network at each stage.
+
+ Args:
+ widths (list[int]): Width in each stage.
+
+ Returns:
+ tuple(list): width and depth of each stage
+ """
+ width_diff = [
+ width != width_prev
+ for width, width_prev in zip(widths + [0], [0] + widths)
+ ]
+ stage_widths = [
+ width for width, diff in zip(widths, width_diff[:-1]) if diff
+ ]
+ stage_blocks = np.diff([
+ depth for depth, diff in zip(range(len(width_diff)), width_diff)
+ if diff
+ ]).tolist()
+ return stage_widths, stage_blocks
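+
+    # Illustrative example (comment only, not in the upstream file): if
+    # generate_regnet produced per-block widths such as
+    # [48, 48, 96, 96, 96, 192, ...], this method collapses each run of
+    # equal widths into one stage, giving stage_widths [48, 96, 192, ...]
+    # and stage_blocks [2, 3, ...] (the run lengths).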
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.norm1(x)
+ x = self.relu(x)
+
+ outs = []
+ for i, layer_name in enumerate(self.res_layers):
+ res_layer = getattr(self, layer_name)
+ x = res_layer(x)
+ if i in self.out_indices:
+ outs.append(x)
+
+ return tuple(outs)
diff --git a/mmpretrain/models/backbones/replknet.py b/mmpretrain/models/backbones/replknet.py
new file mode 100644
index 0000000000000000000000000000000000000000..4dce4154fbe1d95806eec118b69ff70f0d74c1c6
--- /dev/null
+++ b/mmpretrain/models/backbones/replknet.py
@@ -0,0 +1,668 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint as checkpoint
+from mmcv.cnn import build_activation_layer, build_norm_layer
+from mmcv.cnn.bricks import DropPath
+from mmengine.model import BaseModule
+from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
+
+from mmpretrain.registry import MODELS
+from .base_backbone import BaseBackbone
+
+
+def conv_bn(in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ padding,
+ groups,
+ dilation=1,
+ norm_cfg=dict(type='BN')):
+ """Construct a sequential conv and bn.
+
+ Args:
+ in_channels (int): Dimension of input features.
+ out_channels (int): Dimension of output features.
+ kernel_size (int): kernel_size of the convolution.
+ stride (int): stride of the convolution.
+ padding (int): padding of the convolution.
+ groups (int): groups of the convolution.
+ dilation (int): dilation of the convolution. Default to 1.
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default to ``dict(type='BN', requires_grad=True)``.
+
+ Returns:
+ nn.Sequential(): A conv layer and a batch norm layer.
+ """
+ if padding is None:
+ padding = kernel_size // 2
+ result = nn.Sequential()
+ result.add_module(
+ 'conv',
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ bias=False))
+ result.add_module('bn', build_norm_layer(norm_cfg, out_channels)[1])
+ return result
+
+
+def conv_bn_relu(in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ padding,
+ groups,
+ dilation=1):
+ """Construct a sequential conv, bn and relu.
+
+ Args:
+ in_channels (int): Dimension of input features.
+ out_channels (int): Dimension of output features.
+ kernel_size (int): kernel_size of the convolution.
+ stride (int): stride of the convolution.
+ padding (int): padding of the convolution.
+ groups (int): groups of the convolution.
+ dilation (int): dilation of the convolution. Default to 1.
+
+ Returns:
+ nn.Sequential(): A conv layer, batch norm layer and a relu function.
+ """
+
+ if padding is None:
+ padding = kernel_size // 2
+ result = conv_bn(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ groups=groups,
+ dilation=dilation)
+ result.add_module('nonlinear', nn.ReLU())
+ return result
+
+
+def fuse_bn(conv, bn):
+ """Fuse the parameters in a branch with a conv and bn.
+
+ Args:
+ conv (nn.Conv2d): The convolution module to fuse.
+ bn (nn.BatchNorm2d): The batch normalization to fuse.
+
+ Returns:
+ tuple[torch.Tensor, torch.Tensor]: The parameters obtained after
+ fusing the parameters of conv and bn in one branch.
+ The first element is the weight and the second is the bias.
+ """
+ kernel = conv.weight
+ running_mean = bn.running_mean
+ running_var = bn.running_var
+ gamma = bn.weight
+ beta = bn.bias
+ eps = bn.eps
+ std = (running_var + eps).sqrt()
+ t = (gamma / std).reshape(-1, 1, 1, 1)
+ return kernel * t, beta - running_mean * gamma / std
+
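+# Numerical sanity check (illustrative sketch, not part of the upstream
+# file): in eval mode, a conv_bn branch and the single convolution obtained
+# from fuse_bn produce the same output, since fusion folds the BN running
+# statistics and affine parameters into the conv weight and a bias term.
+def _fuse_bn_equivalence_check():
+    branch = conv_bn(
+        in_channels=4, out_channels=8, kernel_size=3, stride=1,
+        padding=1, groups=1)
+    branch.eval()  # fusion uses the BN running statistics
+    x = torch.rand(1, 4, 16, 16)
+    weight, bias = fuse_bn(branch.conv, branch.bn)
+    with torch.no_grad():
+        y_ref = branch(x)
+        y_fused = nn.functional.conv2d(x, weight, bias, padding=1)
+    assert torch.allclose(y_ref, y_fused, atol=1e-5)
+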
+
+class ReparamLargeKernelConv(BaseModule):
+ """Super large kernel implemented with large convolutions.
+
+ Input: Tensor with shape [B, C, H, W].
+ Output: Tensor with shape [B, C, H, W].
+
+ Args:
+ in_channels (int): Dimension of input features.
+ out_channels (int): Dimension of output features.
+ kernel_size (int): kernel_size of the large convolution.
+ stride (int): stride of the large convolution.
+ groups (int): groups of the large convolution.
+ small_kernel (int): kernel_size of the small convolution.
+ small_kernel_merged (bool): Whether to switch the model structure to
+ deployment mode (merge the small kernel to the large kernel).
+ Default to False.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Defaults to None
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ groups,
+ small_kernel,
+ small_kernel_merged=False,
+ init_cfg=None):
+ super(ReparamLargeKernelConv, self).__init__(init_cfg)
+ self.kernel_size = kernel_size
+ self.small_kernel = small_kernel
+ self.small_kernel_merged = small_kernel_merged
+ # We assume the conv does not change the feature map size,
+ # so padding = k//2.
+ # Otherwise, you may configure padding as you wish,
+ # and change the padding of small_conv accordingly.
+ padding = kernel_size // 2
+ if small_kernel_merged:
+ self.lkb_reparam = nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=1,
+ groups=groups,
+ bias=True)
+ else:
+ self.lkb_origin = conv_bn(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=1,
+ groups=groups)
+ if small_kernel is not None:
+ assert small_kernel <= kernel_size
+ self.small_conv = conv_bn(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=small_kernel,
+ stride=stride,
+ padding=small_kernel // 2,
+ groups=groups,
+ dilation=1)
+
+ def forward(self, inputs):
+ if hasattr(self, 'lkb_reparam'):
+ out = self.lkb_reparam(inputs)
+ else:
+ out = self.lkb_origin(inputs)
+ if hasattr(self, 'small_conv'):
+ out += self.small_conv(inputs)
+ return out
+
+ def get_equivalent_kernel_bias(self):
+ eq_k, eq_b = fuse_bn(self.lkb_origin.conv, self.lkb_origin.bn)
+ if hasattr(self, 'small_conv'):
+ small_k, small_b = fuse_bn(self.small_conv.conv,
+ self.small_conv.bn)
+ eq_b += small_b
+ # add to the central part
+ eq_k += nn.functional.pad(
+ small_k, [(self.kernel_size - self.small_kernel) // 2] * 4)
+ return eq_k, eq_b
+
+ def merge_kernel(self):
+ """Switch the model structure from training mode to deployment mode."""
+ if self.small_kernel_merged:
+ return
+ eq_k, eq_b = self.get_equivalent_kernel_bias()
+ self.lkb_reparam = nn.Conv2d(
+ in_channels=self.lkb_origin.conv.in_channels,
+ out_channels=self.lkb_origin.conv.out_channels,
+ kernel_size=self.lkb_origin.conv.kernel_size,
+ stride=self.lkb_origin.conv.stride,
+ padding=self.lkb_origin.conv.padding,
+ dilation=self.lkb_origin.conv.dilation,
+ groups=self.lkb_origin.conv.groups,
+ bias=True)
+
+ self.lkb_reparam.weight.data = eq_k
+ self.lkb_reparam.bias.data = eq_b
+ self.__delattr__('lkb_origin')
+ if hasattr(self, 'small_conv'):
+ self.__delattr__('small_conv')
+
+ self.small_kernel_merged = True
+
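+# Illustrative sketch (not part of the upstream file): merge_kernel() folds
+# the parallel small-kernel branch and both BN layers into a single large
+# convolution, so in eval mode the outputs before and after merging should
+# match up to float error. The sizes below are arbitrary example values.
+def _reparam_large_kernel_equivalence_check():
+    conv = ReparamLargeKernelConv(
+        in_channels=8, out_channels=8, kernel_size=13, stride=1,
+        groups=8, small_kernel=5)
+    conv.eval()
+    x = torch.rand(1, 8, 32, 32)
+    with torch.no_grad():
+        y_train_mode = conv(x)
+        conv.merge_kernel()
+        y_deploy_mode = conv(x)
+    assert torch.allclose(y_train_mode, y_deploy_mode, atol=1e-5)
+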
+
+class ConvFFN(BaseModule):
+ """MLP implemented with 1x1 convolutions.
+
+ Input: Tensor with shape [B, C, H, W].
+ Output: Tensor with shape [B, C, H, W].
+
+ Args:
+ in_channels (int): Dimension of input features.
+ internal_channels (int): Dimension of hidden features.
+ out_channels (int): Dimension of output features.
+ drop_path (float): Stochastic depth rate. Defaults to 0.
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default to ``dict(type='BN', requires_grad=True)``.
+ act_cfg (dict): The config dict for activation between pointwise
+ convolution. Defaults to ``dict(type='GELU')``.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Defaults to None.
+ """
+
+ def __init__(self,
+ in_channels,
+ internal_channels,
+ out_channels,
+ drop_path,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='GELU'),
+ init_cfg=None):
+ super(ConvFFN, self).__init__(init_cfg)
+ self.drop_path = DropPath(
+ drop_prob=drop_path) if drop_path > 0. else nn.Identity()
+ self.preffn_bn = build_norm_layer(norm_cfg, in_channels)[1]
+ self.pw1 = conv_bn(
+ in_channels=in_channels,
+ out_channels=internal_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ groups=1)
+ self.pw2 = conv_bn(
+ in_channels=internal_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ groups=1)
+ self.nonlinear = build_activation_layer(act_cfg)
+
+ def forward(self, x):
+ out = self.preffn_bn(x)
+ out = self.pw1(out)
+ out = self.nonlinear(out)
+ out = self.pw2(out)
+ return x + self.drop_path(out)
+
+
+class RepLKBlock(BaseModule):
+ """RepLKBlock for RepLKNet backbone.
+
+ Args:
+ in_channels (int): The input channels of the block.
+ dw_channels (int): The intermediate channels of the block,
+ i.e., input channels of the large kernel convolution.
+ block_lk_size (int): size of the super large kernel. Defaults: 31.
+ small_kernel (int): size of the parallel small kernel. Defaults: 5.
+ drop_path (float): Stochastic depth rate. Defaults: 0.
+ small_kernel_merged (bool): Whether to switch the model structure to
+ deployment mode (merge the small kernel to the large kernel).
+ Default to False.
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default to ``dict(type='BN', requires_grad=True)``.
+ act_cfg (dict): Config dict for activation layer.
+ Default to ``dict(type='ReLU')``.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default to None
+ """
+
+ def __init__(self,
+ in_channels,
+ dw_channels,
+ block_lk_size,
+ small_kernel,
+ drop_path,
+ small_kernel_merged=False,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ init_cfg=None):
+ super(RepLKBlock, self).__init__(init_cfg)
+ self.pw1 = conv_bn_relu(in_channels, dw_channels, 1, 1, 0, groups=1)
+ self.pw2 = conv_bn(dw_channels, in_channels, 1, 1, 0, groups=1)
+ self.large_kernel = ReparamLargeKernelConv(
+ in_channels=dw_channels,
+ out_channels=dw_channels,
+ kernel_size=block_lk_size,
+ stride=1,
+ groups=dw_channels,
+ small_kernel=small_kernel,
+ small_kernel_merged=small_kernel_merged)
+ self.lk_nonlinear = build_activation_layer(act_cfg)
+ self.prelkb_bn = build_norm_layer(norm_cfg, in_channels)[1]
+ self.drop_path = DropPath(
+ drop_prob=drop_path) if drop_path > 0. else nn.Identity()
+ # print('drop path:', self.drop_path)
+
+ def forward(self, x):
+ out = self.prelkb_bn(x)
+ out = self.pw1(out)
+ out = self.large_kernel(out)
+ out = self.lk_nonlinear(out)
+ out = self.pw2(out)
+ return x + self.drop_path(out)
+
+
+class RepLKNetStage(BaseModule):
+ """A stage of RepLKNet, which stacks RepLKBlock and ConvFFN blocks.
+
+ Args:
+ channels (int): The input channels of the stage.
+ num_blocks (int): The number of blocks of the stage.
+ stage_lk_size (int): size of the super large kernel. Defaults: 31.
+ drop_path (float): Stochastic depth rate. Defaults: 0.
+ small_kernel (int): size of the parallel small kernel. Defaults: 5.
+ dw_ratio (float): The intermediate channels
+ expansion ratio of the block. Defaults: 1.
+ ffn_ratio (float): Mlp expansion ratio. Defaults to 4.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default to False.
+ small_kernel_merged (bool): Whether to switch the model structure to
+ deployment mode (merge the small kernel to the large kernel).
+ Default to False.
+ norm_intermediate_features (bool): Whether to construct a norm layer
+ for the intermediate features. If True, the output features of this
+ stage are normalized, which is useful for downstream dense
+ prediction tasks.
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default to ``dict(type='BN', requires_grad=True)``.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default to None
+ """
+
+ def __init__(
+ self,
+ channels,
+ num_blocks,
+ stage_lk_size,
+ drop_path,
+ small_kernel,
+ dw_ratio=1,
+ ffn_ratio=4,
+ with_cp=False, # train with torch.utils.checkpoint to save memory
+ small_kernel_merged=False,
+ norm_intermediate_features=False,
+ norm_cfg=dict(type='BN'),
+ init_cfg=None):
+ super(RepLKNetStage, self).__init__(init_cfg)
+ self.with_cp = with_cp
+ blks = []
+ for i in range(num_blocks):
+ block_drop_path = drop_path[i] if isinstance(drop_path,
+ list) else drop_path
+ # Assume all RepLK Blocks within a stage share the same lk_size.
+ # You may tune it on your own model.
+ replk_block = RepLKBlock(
+ in_channels=channels,
+ dw_channels=int(channels * dw_ratio),
+ block_lk_size=stage_lk_size,
+ small_kernel=small_kernel,
+ drop_path=block_drop_path,
+ small_kernel_merged=small_kernel_merged)
+ convffn_block = ConvFFN(
+ in_channels=channels,
+ internal_channels=int(channels * ffn_ratio),
+ out_channels=channels,
+ drop_path=block_drop_path)
+ blks.append(replk_block)
+ blks.append(convffn_block)
+ self.blocks = nn.ModuleList(blks)
+ if norm_intermediate_features:
+ self.norm = build_norm_layer(norm_cfg, channels)[1]
+ else:
+ self.norm = nn.Identity()
+
+ def forward(self, x):
+ for blk in self.blocks:
+ if self.with_cp:
+ x = checkpoint.checkpoint(blk, x) # Save training memory
+ else:
+ x = blk(x)
+ return x
+
+
+@MODELS.register_module()
+class RepLKNet(BaseBackbone):
+ """RepLKNet backbone.
+
+ A PyTorch implementation of:
+ `Scaling Up Your Kernels to 31x31: Revisiting Large Kernel Design in CNNs
+ <https://arxiv.org/abs/2203.06717>`_
+
+ Args:
+ arch (str | dict): The parameter of RepLKNet.
+ If it's a dict, it should contain the following keys:
+
+ - large_kernel_sizes (Sequence[int]):
+ Large kernel size in each stage.
+ - layers (Sequence[int]): Number of blocks in each stage.
+ - channels (Sequence[int]): Number of channels in each stage.
+ - small_kernel (int): size of the parallel small kernel.
+ - dw_ratio (float): The intermediate channels
+ expansion ratio of the block.
+ in_channels (int): Number of input image channels. Default to 3.
+ ffn_ratio (float): Mlp expansion ratio. Defaults to 4.
+ out_indices (Sequence[int]): Output from which stages.
+ Default to (3, ).
+ strides (Sequence[int]): Strides of the first block of each stage.
+ Default to (2, 2, 2, 2).
+ dilations (Sequence[int]): Dilation of each stage.
+ Default to (1, 1, 1, 1).
+ frozen_stages (int): Stages to be frozen
+ (all param fixed). -1 means not freezing any parameters.
+ Default to -1.
+ conv_cfg (dict | None): The config dict for conv layers.
+ Default to None.
+ norm_cfg (dict): The config dict for norm layers.
+ Default to ``dict(type='BN')``.
+ act_cfg (dict): Config dict for activation layer.
+ Default to ``dict(type='ReLU')``.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default to False.
+ deploy (bool): Whether to switch the model structure to deployment
+ mode. Default to False.
+ norm_intermediate_features (bool): Whether to construct norm layers
+ for the intermediate features. If True, the intermediate features
+ are normalized, which is useful for downstream dense prediction
+ tasks.
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Default to False.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ """
+
+ arch_settings = {
+ '31B':
+ dict(
+ large_kernel_sizes=[31, 29, 27, 13],
+ layers=[2, 2, 18, 2],
+ channels=[128, 256, 512, 1024],
+ small_kernel=5,
+ dw_ratio=1),
+ '31L':
+ dict(
+ large_kernel_sizes=[31, 29, 27, 13],
+ layers=[2, 2, 18, 2],
+ channels=[192, 384, 768, 1536],
+ small_kernel=5,
+ dw_ratio=1),
+ 'XL':
+ dict(
+ large_kernel_sizes=[27, 27, 27, 13],
+ layers=[2, 2, 18, 2],
+ channels=[256, 512, 1024, 2048],
+ small_kernel=None,
+ dw_ratio=1.5),
+ }
+
+ def __init__(self,
+ arch,
+ in_channels=3,
+ ffn_ratio=4,
+ out_indices=(3, ),
+ strides=(2, 2, 2, 2),
+ dilations=(1, 1, 1, 1),
+ frozen_stages=-1,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ with_cp=False,
+ drop_path_rate=0.3,
+ small_kernel_merged=False,
+ norm_intermediate_features=False,
+ norm_eval=False,
+ init_cfg=[
+ dict(type='Kaiming', layer=['Conv2d']),
+ dict(
+ type='Constant',
+ val=1,
+ layer=['_BatchNorm', 'GroupNorm'])
+ ]):
+ super(RepLKNet, self).__init__(init_cfg)
+
+ if isinstance(arch, str):
+ assert arch in self.arch_settings, \
+ f'"arch": "{arch}" is not one of the arch_settings'
+ arch = self.arch_settings[arch]
+ elif not isinstance(arch, dict):
+ raise TypeError('Expect "arch" to be either a string '
+ f'or a dict, got {type(arch)}')
+
+ assert len(arch['layers']) == len(
+ arch['channels']) == len(strides) == len(dilations)
+ assert max(out_indices) < len(arch['layers'])
+
+ self.arch = arch
+ self.in_channels = in_channels
+ self.out_indices = out_indices
+ self.strides = strides
+ self.dilations = dilations
+ self.frozen_stages = frozen_stages
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ self.with_cp = with_cp
+ self.drop_path_rate = drop_path_rate
+ self.small_kernel_merged = small_kernel_merged
+ self.norm_eval = norm_eval
+ self.norm_intermediate_features = norm_intermediate_features
+
+ self.out_indices = out_indices
+
+ base_width = self.arch['channels'][0]
+ self.norm_intermediate_features = norm_intermediate_features
+ self.num_stages = len(self.arch['layers'])
+ self.stem = nn.ModuleList([
+ conv_bn_relu(
+ in_channels=in_channels,
+ out_channels=base_width,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ groups=1),
+ conv_bn_relu(
+ in_channels=base_width,
+ out_channels=base_width,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ groups=base_width),
+ conv_bn_relu(
+ in_channels=base_width,
+ out_channels=base_width,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ groups=1),
+ conv_bn_relu(
+ in_channels=base_width,
+ out_channels=base_width,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ groups=base_width)
+ ])
+ # stochastic depth. We set block-wise drop-path rate.
+ # The higher level blocks are more likely to be dropped.
+ # This implementation follows Swin.
+ dpr = [
+ x.item() for x in torch.linspace(0, drop_path_rate,
+ sum(self.arch['layers']))
+ ]
+ self.stages = nn.ModuleList()
+ self.transitions = nn.ModuleList()
+ for stage_idx in range(self.num_stages):
+ layer = RepLKNetStage(
+ channels=self.arch['channels'][stage_idx],
+ num_blocks=self.arch['layers'][stage_idx],
+ stage_lk_size=self.arch['large_kernel_sizes'][stage_idx],
+ drop_path=dpr[sum(self.arch['layers'][:stage_idx]
+ ):sum(self.arch['layers'][:stage_idx + 1])],
+ small_kernel=self.arch['small_kernel'],
+ dw_ratio=self.arch['dw_ratio'],
+ ffn_ratio=ffn_ratio,
+ with_cp=with_cp,
+ small_kernel_merged=small_kernel_merged,
+ norm_intermediate_features=(stage_idx in out_indices))
+ self.stages.append(layer)
+ if stage_idx < len(self.arch['layers']) - 1:
+ transition = nn.Sequential(
+ conv_bn_relu(
+ self.arch['channels'][stage_idx],
+ self.arch['channels'][stage_idx + 1],
+ 1,
+ 1,
+ 0,
+ groups=1),
+ conv_bn_relu(
+ self.arch['channels'][stage_idx + 1],
+ self.arch['channels'][stage_idx + 1],
+ 3,
+ stride=2,
+ padding=1,
+ groups=self.arch['channels'][stage_idx + 1]))
+ self.transitions.append(transition)
+
+ def forward_features(self, x):
+ x = self.stem[0](x)
+ for stem_layer in self.stem[1:]:
+ if self.with_cp:
+ x = checkpoint.checkpoint(stem_layer, x) # save memory
+ else:
+ x = stem_layer(x)
+
+ # Need the intermediate feature maps
+ outs = []
+ for stage_idx in range(self.num_stages):
+ x = self.stages[stage_idx](x)
+ if stage_idx in self.out_indices:
+ outs.append(self.stages[stage_idx].norm(x))
+ # For RepLKNet-XL normalize the features
+ # before feeding them into the heads
+ if stage_idx < self.num_stages - 1:
+ x = self.transitions[stage_idx](x)
+ return outs
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ return tuple(x)
+
+ def _freeze_stages(self):
+ if self.frozen_stages >= 0:
+ self.stem.eval()
+ for param in self.stem.parameters():
+ param.requires_grad = False
+ for i in range(self.frozen_stages):
+ stage = self.stages[i]
+ stage.eval()
+ for param in stage.parameters():
+ param.requires_grad = False
+
+ def train(self, mode=True):
+ super(RepLKNet, self).train(mode)
+ self._freeze_stages()
+ if mode and self.norm_eval:
+ for m in self.modules():
+ if isinstance(m, _BatchNorm):
+ m.eval()
+
+ def switch_to_deploy(self):
+ for m in self.modules():
+ if hasattr(m, 'merge_kernel'):
+ m.merge_kernel()
+ self.small_kernel_merged = True
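+
+
+# Usage sketch (illustrative, not part of the upstream file). With the
+# default out_indices=(3, ), the backbone returns the last-stage feature
+# map; for a 224x224 input and the '31B' setting the expected shape is
+# roughly (1, 1024, 7, 7).
+def _replknet_usage_sketch():
+    model = RepLKNet(arch='31B')
+    model.eval()
+    x = torch.rand(1, 3, 224, 224)
+    with torch.no_grad():
+        feats = model(x)
+    print([tuple(feat.shape) for feat in feats])
+    # Fold the small-kernel and BN branches into single large convolutions
+    # for deployment; eval-mode outputs are expected to stay numerically
+    # the same.
+    model.switch_to_deploy()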
diff --git a/mmpretrain/models/backbones/repmlp.py b/mmpretrain/models/backbones/repmlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7c06c4875710b33c57f2794c437034d93169b30
--- /dev/null
+++ b/mmpretrain/models/backbones/repmlp.py
@@ -0,0 +1,578 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# Adapted from official impl at https://github.com/DingXiaoH/RepMLP.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import (ConvModule, build_activation_layer, build_conv_layer,
+ build_norm_layer)
+from mmcv.cnn.bricks.transformer import PatchEmbed as _PatchEmbed
+from mmengine.model import BaseModule, ModuleList, Sequential
+
+from mmpretrain.models.utils import SELayer, to_2tuple
+from mmpretrain.registry import MODELS
+
+
+def fuse_bn(conv_or_fc, bn):
+ """fuse conv and bn."""
+ std = (bn.running_var + bn.eps).sqrt()
+ tmp_weight = bn.weight / std
+ tmp_weight = tmp_weight.reshape(-1, 1, 1, 1)
+
+ if len(tmp_weight) == conv_or_fc.weight.size(0):
+ return (conv_or_fc.weight * tmp_weight,
+ bn.bias - bn.running_mean * bn.weight / std)
+ else:
+ # in RepMLPBlock, dim0 of fc3 weights and fc3_bn weights
+ # are different.
+ repeat_times = conv_or_fc.weight.size(0) // len(tmp_weight)
+ repeated = tmp_weight.repeat_interleave(repeat_times, 0)
+ fused_weight = conv_or_fc.weight * repeated
+ bias = bn.bias - bn.running_mean * bn.weight / std
+ fused_bias = (bias).repeat_interleave(repeat_times, 0)
+ return (fused_weight, fused_bias)
+
+
+class PatchEmbed(_PatchEmbed):
+ """Image to Patch Embedding.
+
+ Compared with the default patch embedding (in ViT), the patch embedding of
+ RepMLP has a ReLU and does not convert the output tensor into shape
+ (N, L, C).
+
+ Args:
+ in_channels (int): The num of input channels. Default: 3
+ embed_dims (int): The dimensions of embedding. Default: 768
+ conv_type (str): The type of convolution
+ to generate patch embedding. Default: "Conv2d".
+ kernel_size (int): The kernel_size of embedding conv. Default: 16.
+ stride (int): The slide stride of embedding conv.
+ Default: 16.
+ padding (int | tuple | string): The padding length of
+ embedding conv. When it is a string, it means the mode
+ of adaptive padding, support "same" and "corner" now.
+ Default: "corner".
+ dilation (int): The dilation rate of embedding conv. Default: 1.
+ bias (bool): Bias of embed conv. Default: True.
+ norm_cfg (dict, optional): Config dict for normalization layer.
+ Default: None.
+ input_size (int | tuple | None): The size of input, which will be
+ used to calculate the out size. Only works when `dynamic_size`
+ is False. Default: None.
+ init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization.
+ Default: None.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super(PatchEmbed, self).__init__(*args, **kwargs)
+ self.relu = nn.ReLU()
+
+ def forward(self, x):
+ """
+ Args:
+ x (Tensor): Has shape (B, C, H, W). In most case, C is 3.
+ Returns:
+ tuple: Contains merged results and its spatial shape.
+ - x (Tensor): The output tensor.
+ - out_size (tuple[int]): Spatial shape of x, arrange as
+ (out_h, out_w).
+ """
+
+ if self.adaptive_padding:
+ x = self.adaptive_padding(x)
+
+ x = self.projection(x)
+ if self.norm is not None:
+ x = self.norm(x)
+ x = self.relu(x)
+ out_size = (x.shape[2], x.shape[3])
+ return x, out_size
+
+
+class GlobalPerceptron(SELayer):
+ """GlobalPerceptron implemented by using ``mmpretrain.models.utils.SELayer``.
+
+ Args:
+ input_channels (int): The number of input (and output) channels
+ in the GlobalPerceptron.
+ ratio (int): Squeeze ratio in GlobalPerceptron, the intermediate
+ channel will be ``make_divisible(channels // ratio, divisor)``.
+ """
+
+ def __init__(self, input_channels: int, ratio: int, **kwargs) -> None:
+ super(GlobalPerceptron, self).__init__(
+ channels=input_channels,
+ ratio=ratio,
+ return_weight=True,
+ act_cfg=(dict(type='ReLU'), dict(type='Sigmoid')),
+ **kwargs)
+
+
+class RepMLPBlock(BaseModule):
+ """Basic RepMLPNet, consists of PartitionPerceptron and GlobalPerceptron.
+
+ Args:
+ channels (int): The number of input and the output channels of the
+ block.
+ path_h (int): The height of patches.
+ path_w (int): The width of patches.
+ reparam_conv_kernels (Sequence[int] | None): The kernel sizes of the
+ conv branches in the Local Perceptron. Default: None.
+ globalperceptron_ratio (int): The reduction ratio in the
+ GlobalPerceptron. Default: 4.
+ num_sharesets (int): The number of sharesets in the
+ PartitionPerceptron. Default: 1.
+ conv_cfg (dict, optional): Config dict for convolution layer.
+ Default: None, which means using conv2d.
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: dict(type='BN', requires_grad=True).
+ deploy (bool): Whether to switch the model structure to
+ deployment mode. Default: False.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self,
+ channels,
+ path_h,
+ path_w,
+ reparam_conv_kernels=None,
+ globalperceptron_ratio=4,
+ num_sharesets=1,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ deploy=False,
+ init_cfg=None):
+ super().__init__(init_cfg=init_cfg)
+
+ self.deploy = deploy
+ self.channels = channels
+ self.num_sharesets = num_sharesets
+ self.path_h, self.path_w = path_h, path_w
+ # the input channel of fc3
+ self._path_vec_channles = path_h * path_w * num_sharesets
+
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+
+ self.gp = GlobalPerceptron(
+ input_channels=channels, ratio=globalperceptron_ratio)
+
+ # using a conv layer to implement a fc layer
+ self.fc3 = build_conv_layer(
+ conv_cfg,
+ in_channels=self._path_vec_channles,
+ out_channels=self._path_vec_channles,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=deploy,
+ groups=num_sharesets)
+ if deploy:
+ self.fc3_bn = nn.Identity()
+ else:
+ norm_layer = build_norm_layer(norm_cfg, num_sharesets)[1]
+ self.add_module('fc3_bn', norm_layer)
+
+ self.reparam_conv_kernels = reparam_conv_kernels
+ if not deploy and reparam_conv_kernels is not None:
+ for k in reparam_conv_kernels:
+ conv_branch = ConvModule(
+ in_channels=num_sharesets,
+ out_channels=num_sharesets,
+ kernel_size=k,
+ stride=1,
+ padding=k // 2,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ groups=num_sharesets,
+ act_cfg=None)
+ self.__setattr__('repconv{}'.format(k), conv_branch)
+
+ def partition(self, x, h_parts, w_parts):
+ # convert (N, C, H, W) to (N, h_parts, w_parts, C, path_h, path_w)
+ x = x.reshape(-1, self.channels, h_parts, self.path_h, w_parts,
+ self.path_w)
+ x = x.permute(0, 2, 4, 1, 3, 5)
+ return x
+
+ def partition_affine(self, x, h_parts, w_parts):
+ """perform Partition Perceptron."""
+ fc_inputs = x.reshape(-1, self._path_vec_channles, 1, 1)
+ out = self.fc3(fc_inputs)
+ out = out.reshape(-1, self.num_sharesets, self.path_h, self.path_w)
+ out = self.fc3_bn(out)
+ out = out.reshape(-1, h_parts, w_parts, self.num_sharesets,
+ self.path_h, self.path_w)
+ return out
+
+ def forward(self, inputs):
+ # Global Perceptron
+ global_vec = self.gp(inputs)
+
+ origin_shape = inputs.size()
+ h_parts = origin_shape[2] // self.path_h
+ w_parts = origin_shape[3] // self.path_w
+
+ partitions = self.partition(inputs, h_parts, w_parts)
+
+ # Channel Perceptron
+ fc3_out = self.partition_affine(partitions, h_parts, w_parts)
+
+ # perform Local Perceptron
+ if self.reparam_conv_kernels is not None and not self.deploy:
+ conv_inputs = partitions.reshape(-1, self.num_sharesets,
+ self.path_h, self.path_w)
+ conv_out = 0
+ for k in self.reparam_conv_kernels:
+ conv_branch = self.__getattr__('repconv{}'.format(k))
+ conv_out += conv_branch(conv_inputs)
+ conv_out = conv_out.reshape(-1, h_parts, w_parts,
+ self.num_sharesets, self.path_h,
+ self.path_w)
+ fc3_out += conv_out
+
+ # N, h_parts, w_parts, num_sharesets, out_h, out_w
+ fc3_out = fc3_out.permute(0, 3, 1, 4, 2, 5)
+ out = fc3_out.reshape(*origin_shape)
+ out = out * global_vec
+ return out
+
+ def get_equivalent_fc3(self):
+ """get the equivalent fc3 weight and bias."""
+ fc_weight, fc_bias = fuse_bn(self.fc3, self.fc3_bn)
+ if self.reparam_conv_kernels is not None:
+ largest_k = max(self.reparam_conv_kernels)
+ largest_branch = self.__getattr__('repconv{}'.format(largest_k))
+ total_kernel, total_bias = fuse_bn(largest_branch.conv,
+ largest_branch.bn)
+ for k in self.reparam_conv_kernels:
+ if k != largest_k:
+ k_branch = self.__getattr__('repconv{}'.format(k))
+ kernel, bias = fuse_bn(k_branch.conv, k_branch.bn)
+ total_kernel += F.pad(kernel, [(largest_k - k) // 2] * 4)
+ total_bias += bias
+ rep_weight, rep_bias = self._convert_conv_to_fc(
+ total_kernel, total_bias)
+ final_fc3_weight = rep_weight.reshape_as(fc_weight) + fc_weight
+ final_fc3_bias = rep_bias + fc_bias
+ else:
+ final_fc3_weight = fc_weight
+ final_fc3_bias = fc_bias
+ return final_fc3_weight, final_fc3_bias
+
+ def local_inject(self):
+ """inject the Local Perceptron into Partition Perceptron."""
+ self.deploy = True
+ # Locality Injection
+ fc3_weight, fc3_bias = self.get_equivalent_fc3()
+ # Remove Local Perceptron
+ if self.reparam_conv_kernels is not None:
+ for k in self.reparam_conv_kernels:
+ self.__delattr__('repconv{}'.format(k))
+ self.__delattr__('fc3')
+ self.__delattr__('fc3_bn')
+ self.fc3 = build_conv_layer(
+ self.conv_cfg,
+ self._path_vec_channles,
+ self._path_vec_channles,
+ 1,
+ 1,
+ 0,
+ bias=True,
+ groups=self.num_sharesets)
+ self.fc3_bn = nn.Identity()
+ self.fc3.weight.data = fc3_weight
+ self.fc3.bias.data = fc3_bias
+
+ def _convert_conv_to_fc(self, conv_kernel, conv_bias):
+ """convert conv_k1 to fc, which is still a conv_k2, and the k2 > k1."""
+ in_channels = torch.eye(self.path_h * self.path_w).repeat(
+ 1, self.num_sharesets).reshape(self.path_h * self.path_w,
+ self.num_sharesets, self.path_h,
+ self.path_w).to(conv_kernel.device)
+ fc_k = F.conv2d(
+ in_channels,
+ conv_kernel,
+ padding=(conv_kernel.size(2) // 2, conv_kernel.size(3) // 2),
+ groups=self.num_sharesets)
+ fc_k = fc_k.reshape(self.path_h * self.path_w, self.num_sharesets *
+ self.path_h * self.path_w).t()
+ fc_bias = conv_bias.repeat_interleave(self.path_h * self.path_w)
+ return fc_k, fc_bias
+
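+# Illustrative sketch (not part of the upstream file): local_inject() fuses
+# the Local Perceptron conv branches into the fc3 layer of the Partition
+# Perceptron, so eval-mode outputs before and after injection should match.
+# The channel/patch sizes below are arbitrary small example values
+# (channels must be divisible by num_sharesets, and H/W by path_h/path_w).
+def _repmlp_block_local_inject_check():
+    block = RepMLPBlock(
+        channels=8, path_h=4, path_w=4, reparam_conv_kernels=(1, 3),
+        globalperceptron_ratio=4, num_sharesets=4)
+    block.eval()
+    x = torch.rand(1, 8, 8, 8)
+    with torch.no_grad():
+        y_before = block(x)
+        block.local_inject()
+        y_after = block(x)
+    assert torch.allclose(y_before, y_after, atol=1e-5)
+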
+
+class RepMLPNetUnit(BaseModule):
+ """A basic unit in RepMLPNet : [REPMLPBlock + BN + ConvFFN + BN].
+
+ Args:
+ channels (int): The number of input and the output channels of the
+ unit.
+ path_h (int): The height of patches.
+ path_w (int): The width of patches.
+ reparam_conv_kernels (Sequence[int] | None): The kernel sizes of the
+ conv branches in the Local Perceptron.
+ globalperceptron_ratio (int): The reduction ratio in the
+ GlobalPerceptron.
+ num_sharesets (int): The number of sharesets in the
+ PartitionPerceptron. Default: 1.
+ conv_cfg (dict, optional): Config dict for convolution layer.
+ Default: None, which means using conv2d.
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: dict(type='BN', requires_grad=True).
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='ReLU').
+ deploy (bool): Whether to switch the model structure to
+ deployment mode. Default: False.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self,
+ channels,
+ path_h,
+ path_w,
+ reparam_conv_kernels,
+ globalperceptron_ratio,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ ffn_expand=4,
+ num_sharesets=1,
+ deploy=False,
+ init_cfg=None):
+ super().__init__(init_cfg=init_cfg)
+ self.repmlp_block = RepMLPBlock(
+ channels=channels,
+ path_h=path_h,
+ path_w=path_w,
+ reparam_conv_kernels=reparam_conv_kernels,
+ globalperceptron_ratio=globalperceptron_ratio,
+ num_sharesets=num_sharesets,
+ deploy=deploy)
+ self.ffn_block = ConvFFN(channels, channels * ffn_expand)
+ norm1 = build_norm_layer(norm_cfg, channels)[1]
+ self.add_module('norm1', norm1)
+ norm2 = build_norm_layer(norm_cfg, channels)[1]
+ self.add_module('norm2', norm2)
+
+ def forward(self, x):
+ y = x + self.repmlp_block(self.norm1(x))
+ out = y + self.ffn_block(self.norm2(y))
+ return out
+
+
+class ConvFFN(nn.Module):
+ """ConvFFN implemented by using point-wise convs."""
+
+ def __init__(self,
+ in_channels,
+ hidden_channels=None,
+ out_channels=None,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ act_cfg=dict(type='GELU')):
+ super().__init__()
+ out_features = out_channels or in_channels
+ hidden_features = hidden_channels or in_channels
+ self.ffn_fc1 = ConvModule(
+ in_channels=in_channels,
+ out_channels=hidden_features,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ norm_cfg=norm_cfg,
+ act_cfg=None)
+ self.ffn_fc2 = ConvModule(
+ in_channels=hidden_features,
+ out_channels=out_features,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ norm_cfg=norm_cfg,
+ act_cfg=None)
+ self.act = build_activation_layer(act_cfg)
+
+ def forward(self, x):
+ x = self.ffn_fc1(x)
+ x = self.act(x)
+ x = self.ffn_fc2(x)
+ return x
+
+
+@MODELS.register_module()
+class RepMLPNet(BaseModule):
+ """RepMLPNet backbone.
+
+ A PyTorch implementation of: `RepMLP: Re-parameterizing Convolutions into
+ Fully-connected Layers for Image Recognition
+ <https://arxiv.org/abs/2105.01883>`_
+
+ Args:
+ arch (str | dict): RepMLP architecture. If a string, choose
+ from 'base' and 'b'. If a dict, it should have the following keys:
+
+ - channels (List[int]): Number of channels in each stage.
+ - depths (List[int]): Number of RepMLPNetUnit blocks in each stage.
+ - sharesets_nums (List[int]): Number of sharesets in each stage.
+
+ img_size (int | tuple): The size of input image. Defaults: 224.
+ in_channels (int): Number of input image channels. Default: 3.
+ patch_size (int | tuple): The patch size in patch embedding.
+ Defaults to 4.
+ out_indices (Sequence[int]): Output from which stages.
+ Default: ``(3, )``.
+ reparam_conv_kernels (Sequence[int] | None): The kernel sizes of the
+ conv branches in the Local Perceptron. Default: (3, ).
+ globalperceptron_ratio (int): The reduction ratio in the
+ GlobalPerceptron. Default: 4.
+ conv_cfg (dict | None): The config dict for conv layers. Default: None.
+ norm_cfg (dict): The config dict for norm layers.
+ Default: dict(type='BN', requires_grad=True).
+ patch_cfg (dict): Extra config dict for patch embedding.
+ Defaults to an empty dict.
+ final_norm (bool): Whether to add an additional layer to normalize the
+ final feature map. Defaults to True.
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='ReLU').
+ deploy (bool): Whether to switch the model structure to deployment
+ mode. Default: False.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ """
+ arch_zoo = {
+ **dict.fromkeys(['b', 'base'],
+ {'channels': [96, 192, 384, 768],
+ 'depths': [2, 2, 12, 2],
+ 'sharesets_nums': [1, 4, 32, 128]}),
+ } # yapf: disable
+
+ num_extra_tokens = 0 # there is no cls-token in RepMLP
+
+ def __init__(self,
+ arch,
+ img_size=224,
+ in_channels=3,
+ patch_size=4,
+ out_indices=(3, ),
+ reparam_conv_kernels=(3, ),
+ globalperceptron_ratio=4,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ patch_cfg=dict(),
+ final_norm=True,
+ deploy=False,
+ init_cfg=None):
+ super(RepMLPNet, self).__init__(init_cfg=init_cfg)
+ if isinstance(arch, str):
+ arch = arch.lower()
+ assert arch in set(self.arch_zoo), \
+ f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
+ self.arch_settings = self.arch_zoo[arch]
+ else:
+ essential_keys = {'channels', 'depths', 'sharesets_nums'}
+ assert isinstance(arch, dict) and set(arch) == essential_keys, \
+ f'Custom arch needs a dict with keys {essential_keys}.'
+ self.arch_settings = arch
+
+ self.img_size = to_2tuple(img_size)
+ self.patch_size = to_2tuple(patch_size)
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+
+ self.num_stage = len(self.arch_settings['channels'])
+ for value in self.arch_settings.values():
+ assert isinstance(value, list) and len(value) == self.num_stage, (
+ 'Length of setting item in arch dict must be type of list and'
+ ' have the same length.')
+
+ self.channels = self.arch_settings['channels']
+ self.depths = self.arch_settings['depths']
+ self.sharesets_nums = self.arch_settings['sharesets_nums']
+
+ _patch_cfg = dict(
+ in_channels=in_channels,
+ input_size=self.img_size,
+ embed_dims=self.channels[0],
+ conv_type='Conv2d',
+ kernel_size=self.patch_size,
+ stride=self.patch_size,
+ norm_cfg=self.norm_cfg,
+ bias=False)
+ _patch_cfg.update(patch_cfg)
+ self.patch_embed = PatchEmbed(**_patch_cfg)
+ self.patch_resolution = self.patch_embed.init_out_size
+
+ self.patch_hs = [
+ self.patch_resolution[0] // 2**i for i in range(self.num_stage)
+ ]
+ self.patch_ws = [
+ self.patch_resolution[1] // 2**i for i in range(self.num_stage)
+ ]
+
+ self.stages = ModuleList()
+ self.downsample_layers = ModuleList()
+ for stage_idx in range(self.num_stage):
+ # make stage layers
+ _stage_cfg = dict(
+ channels=self.channels[stage_idx],
+ path_h=self.patch_hs[stage_idx],
+ path_w=self.patch_ws[stage_idx],
+ reparam_conv_kernels=reparam_conv_kernels,
+ globalperceptron_ratio=globalperceptron_ratio,
+ norm_cfg=self.norm_cfg,
+ ffn_expand=4,
+ num_sharesets=self.sharesets_nums[stage_idx],
+ deploy=deploy)
+ stage_blocks = [
+ RepMLPNetUnit(**_stage_cfg)
+ for _ in range(self.depths[stage_idx])
+ ]
+ self.stages.append(Sequential(*stage_blocks))
+
+ # make downsample layers
+ if stage_idx < self.num_stage - 1:
+ self.downsample_layers.append(
+ ConvModule(
+ in_channels=self.channels[stage_idx],
+ out_channels=self.channels[stage_idx + 1],
+ kernel_size=2,
+ stride=2,
+ padding=0,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ inplace=True))
+
+ self.out_indice = out_indices
+
+ if final_norm:
+ norm_layer = build_norm_layer(norm_cfg, self.channels[-1])[1]
+ else:
+ norm_layer = nn.Identity()
+ self.add_module('final_norm', norm_layer)
+
+ def forward(self, x):
+ assert x.shape[2:] == self.img_size, \
+ "The Rep-MLP doesn't support dynamic input shape. " \
+ f'Please input images with shape {self.img_size}'
+
+ outs = []
+
+ x, _ = self.patch_embed(x)
+ for i, stage in enumerate(self.stages):
+ x = stage(x)
+
+ # downsample after each stage except last stage
+ if i < len(self.stages) - 1:
+ downsample = self.downsample_layers[i]
+ x = downsample(x)
+
+ if i in self.out_indice:
+ if self.final_norm and i == len(self.stages) - 1:
+ out = self.final_norm(x)
+ else:
+ out = x
+ outs.append(out)
+
+ return tuple(outs)
+
+ def switch_to_deploy(self):
+ for m in self.modules():
+ if hasattr(m, 'local_inject'):
+ m.local_inject()
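+
+
+# Usage sketch (illustrative, not part of the upstream file). RepMLPNet only
+# accepts the fixed img_size it was built with; with the default
+# out_indices=(3, ) and patch_size=4, the expected last-stage feature shape
+# for a 224x224 input and the 'b' arch is roughly (1, 768, 7, 7).
+def _repmlpnet_usage_sketch():
+    model = RepMLPNet(arch='b', img_size=224)
+    model.eval()
+    with torch.no_grad():
+        feats = model(torch.rand(1, 3, 224, 224))
+    print([tuple(feat.shape) for feat in feats])
+    # Locality injection for deployment (fuses the conv branches into fc3).
+    model.switch_to_deploy()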
diff --git a/mmpretrain/models/backbones/repvgg.py b/mmpretrain/models/backbones/repvgg.py
new file mode 100644
index 0000000000000000000000000000000000000000..67c9d147546eb2839a44749040a1a787ee5ce0ea
--- /dev/null
+++ b/mmpretrain/models/backbones/repvgg.py
@@ -0,0 +1,622 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint as cp
+from mmcv.cnn import (ConvModule, build_activation_layer, build_conv_layer,
+ build_norm_layer)
+from mmengine.model import BaseModule, Sequential
+from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
+from torch import nn
+
+from mmpretrain.registry import MODELS
+from ..utils.se_layer import SELayer
+from .base_backbone import BaseBackbone
+
+
+class RepVGGBlock(BaseModule):
+ """RepVGG block for RepVGG backbone.
+
+ Args:
+ in_channels (int): The input channels of the block.
+ out_channels (int): The output channels of the block.
+ stride (int): Stride of the 3x3 and 1x1 convolution layer. Default: 1.
+ padding (int): Padding of the 3x3 convolution layer.
+ dilation (int): Dilation of the 3x3 convolution layer.
+ groups (int): Groups of the 3x3 and 1x1 convolution layer. Default: 1.
+ padding_mode (str): Padding mode of the 3x3 convolution layer.
+ Default: 'zeros'.
+ se_cfg (None or dict): The configuration of the se module.
+ Default: None.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ conv_cfg (dict, optional): Config dict for convolution layer.
+ Default: None, which means using conv2d.
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: dict(type='BN', requires_grad=True).
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='ReLU').
+ deploy (bool): Whether to switch the model structure to
+ deployment mode. Default: False.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ stride=1,
+ padding=1,
+ dilation=1,
+ groups=1,
+ padding_mode='zeros',
+ se_cfg=None,
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ deploy=False,
+ init_cfg=None):
+ super(RepVGGBlock, self).__init__(init_cfg)
+
+ assert se_cfg is None or isinstance(se_cfg, dict)
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.stride = stride
+ self.padding = padding
+ self.dilation = dilation
+ self.groups = groups
+ self.se_cfg = se_cfg
+ self.with_cp = with_cp
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ self.deploy = deploy
+
+ if deploy:
+ self.branch_reparam = build_conv_layer(
+ conv_cfg,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ bias=True,
+ padding_mode=padding_mode)
+ else:
+ # Check whether the input and output shapes are the same.
+ # If so, add a normalized identity shortcut.
+ if out_channels == in_channels and stride == 1 and \
+ padding == dilation:
+ self.branch_norm = build_norm_layer(norm_cfg, in_channels)[1]
+ else:
+ self.branch_norm = None
+
+ self.branch_3x3 = self.create_conv_bn(
+ kernel_size=3,
+ dilation=dilation,
+ padding=padding,
+ )
+ self.branch_1x1 = self.create_conv_bn(kernel_size=1)
+
+ if se_cfg is not None:
+ self.se_layer = SELayer(channels=out_channels, **se_cfg)
+ else:
+ self.se_layer = None
+
+ self.act = build_activation_layer(act_cfg)
+
+ def create_conv_bn(self, kernel_size, dilation=1, padding=0):
+ conv_bn = Sequential()
+ conv_bn.add_module(
+ 'conv',
+ build_conv_layer(
+ self.conv_cfg,
+ in_channels=self.in_channels,
+ out_channels=self.out_channels,
+ kernel_size=kernel_size,
+ stride=self.stride,
+ dilation=dilation,
+ padding=padding,
+ groups=self.groups,
+ bias=False))
+ conv_bn.add_module(
+ 'norm',
+ build_norm_layer(self.norm_cfg, num_features=self.out_channels)[1])
+
+ return conv_bn
+
+ def forward(self, x):
+
+ def _inner_forward(inputs):
+ if self.deploy:
+ return self.branch_reparam(inputs)
+
+ if self.branch_norm is None:
+ branch_norm_out = 0
+ else:
+ branch_norm_out = self.branch_norm(inputs)
+
+ inner_out = self.branch_3x3(inputs) + self.branch_1x1(
+ inputs) + branch_norm_out
+
+ if self.se_cfg is not None:
+ inner_out = self.se_layer(inner_out)
+
+ return inner_out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ out = self.act(out)
+
+ return out
+
+ def switch_to_deploy(self):
+ """Switch the model structure from training mode to deployment mode."""
+ if self.deploy:
+ return
+ assert self.norm_cfg['type'] == 'BN', \
+ "Switch is not allowed when norm_cfg['type'] != 'BN'."
+
+ reparam_weight, reparam_bias = self.reparameterize()
+ self.branch_reparam = build_conv_layer(
+ self.conv_cfg,
+ self.in_channels,
+ self.out_channels,
+ kernel_size=3,
+ stride=self.stride,
+ padding=self.padding,
+ dilation=self.dilation,
+ groups=self.groups,
+ bias=True)
+ self.branch_reparam.weight.data = reparam_weight
+ self.branch_reparam.bias.data = reparam_bias
+
+ for param in self.parameters():
+ param.detach_()
+ delattr(self, 'branch_3x3')
+ delattr(self, 'branch_1x1')
+ delattr(self, 'branch_norm')
+
+ self.deploy = True
+
+ def reparameterize(self):
+ """Fuse all the parameters of all branches.
+
+ Returns:
+ tuple[torch.Tensor, torch.Tensor]: Parameters after fusion of all
+ branches. The first element is the weight and the second is
+ the bias.
+ """
+ weight_3x3, bias_3x3 = self._fuse_conv_bn(self.branch_3x3)
+ weight_1x1, bias_1x1 = self._fuse_conv_bn(self.branch_1x1)
+ # pad a conv1x1 weight to a conv3x3 weight
+ weight_1x1 = F.pad(weight_1x1, [1, 1, 1, 1], value=0)
+
+ weight_norm, bias_norm = 0, 0
+ if self.branch_norm:
+ tmp_conv_bn = self._norm_to_conv3x3(self.branch_norm)
+ weight_norm, bias_norm = self._fuse_conv_bn(tmp_conv_bn)
+
+ return (weight_3x3 + weight_1x1 + weight_norm,
+ bias_3x3 + bias_1x1 + bias_norm)
+
+ def _fuse_conv_bn(self, branch):
+ """Fuse the parameters in a branch with a conv and bn.
+
+ Args:
+ branch (mmengine.model.Sequential): A branch with conv and bn.
+
+ Returns:
+ tuple[torch.Tensor, torch.Tensor]: The parameters obtained after
+ fusing the parameters of conv and bn in one branch.
+ The first element is the weight and the second is the bias.
+ """
+ if branch is None:
+ return 0, 0
+ conv_weight = branch.conv.weight
+ running_mean = branch.norm.running_mean
+ running_var = branch.norm.running_var
+ gamma = branch.norm.weight
+ beta = branch.norm.bias
+ eps = branch.norm.eps
+
+ std = (running_var + eps).sqrt()
+ fused_weight = (gamma / std).reshape(-1, 1, 1, 1) * conv_weight
+ fused_bias = -running_mean * gamma / std + beta
+
+ return fused_weight, fused_bias
+
+ def _norm_to_conv3x3(self, branch_norm):
+ """Convert a norm layer to a conv3x3-bn sequence.
+
+ Args:
+ branch_norm (nn.BatchNorm2d): A branch only with bn in the block.
+
+ Returns:
+ tmp_conv3x3 (mmengine.model.Sequential): A sequential with conv3x3
+ and bn.
+ """
+ input_dim = self.in_channels // self.groups
+ conv_weight = torch.zeros((self.in_channels, input_dim, 3, 3),
+ dtype=branch_norm.weight.dtype)
+
+ for i in range(self.in_channels):
+ conv_weight[i, i % input_dim, 1, 1] = 1
+ conv_weight = conv_weight.to(branch_norm.weight.device)
+
+ tmp_conv3x3 = self.create_conv_bn(kernel_size=3)
+ tmp_conv3x3.conv.weight.data = conv_weight
+ tmp_conv3x3.norm = branch_norm
+ return tmp_conv3x3
+
+
+class MTSPPF(BaseModule):
+ """MTSPPF block for YOLOX-PAI RepVGG backbone.
+
+ Args:
+ in_channels (int): The input channels of the block.
+ out_channels (int): The output channels of the block.
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: dict(type='BN').
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='ReLU').
+ kernel_size (int): Kernel size of pooling. Default: 5.
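+
+ Example:
+ A shape-only sketch with a random input (spatial size is preserved
+ because the poolings use stride 1):
+
+ >>> import torch
+ >>> from mmpretrain.models.backbones.repvgg import MTSPPF
+ >>> layer = MTSPPF(64, 64)
+ >>> x = torch.rand(1, 64, 32, 32)
+ >>> tuple(layer(x).shape)
+ (1, 64, 32, 32)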
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ kernel_size=5):
+ super().__init__()
+ hidden_features = in_channels // 2 # hidden channels
+ self.conv1 = ConvModule(
+ in_channels,
+ hidden_features,
+ 1,
+ stride=1,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ self.conv2 = ConvModule(
+ hidden_features * 4,
+ out_channels,
+ 1,
+ stride=1,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ self.maxpool = nn.MaxPool2d(
+ kernel_size=kernel_size, stride=1, padding=kernel_size // 2)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ y1 = self.maxpool(x)
+ y2 = self.maxpool(y1)
+ return self.conv2(torch.cat([x, y1, y2, self.maxpool(y2)], 1))
+
+
+@MODELS.register_module()
+class RepVGG(BaseBackbone):
+ """RepVGG backbone.
+
+ A PyTorch impl of: `RepVGG: Making VGG-style ConvNets Great Again
+ <https://arxiv.org/abs/2101.03697>`_
+
+ Args:
+ arch (str | dict): RepVGG architecture. If use string, choose from
+ 'A0', 'A1', 'A2', 'B0', 'B1', 'B1g2', 'B1g4', 'B2', 'B2g2',
+ 'B2g4', 'B3', 'B3g2', 'B3g4' or 'D2se'. If use dict, it should
+ have below keys:
+
+ - **num_blocks** (Sequence[int]): Number of blocks in each stage.
+ - **width_factor** (Sequence[float]): Width multiplier of each stage.
+ - **group_layer_map** (dict | None): Mapping from block index to the
+ number of groups, for the RepVGG blocks that use group convolution.
+ - **se_cfg** (dict | None): SE Layer config.
+ - **stem_channels** (int, optional): The stem channels, the final
+ stem channels will be
+ ``min(stem_channels, base_channels*width_factor[0])``.
+ If not set here, 64 is used by default in the code.
+
+ in_channels (int): Number of input image channels. Defaults to 3.
+ base_channels (int): Base channels of RepVGG backbone, work with
+ width_factor together. Defaults to 64.
+ out_indices (Sequence[int]): Output from which stages.
+ Defaults to ``(3, )``.
+ strides (Sequence[int]): Strides of the first block of each stage.
+ Defaults to ``(2, 2, 2, 2)``.
+ dilations (Sequence[int]): Dilation of each stage.
+ Defaults to ``(1, 1, 1, 1)``.
+ frozen_stages (int): Stages to be frozen (all parameters fixed). -1 means
+ not freezing any parameters. Defaults to -1.
+ conv_cfg (dict | None): The config dict for conv layers.
+ Defaults to None.
+ norm_cfg (dict): The config dict for norm layers.
+ Defaults to ``dict(type='BN')``.
+ act_cfg (dict): Config dict for activation layer.
+ Defaults to ``dict(type='ReLU')``.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Defaults to False.
+ deploy (bool): Whether to switch the model structure to deployment
+ mode. Defaults to False.
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Defaults to False.
+ add_ppf (bool): Whether to use the MTSPPF block. Defaults to False.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Defaults to None.
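+
+ Example:
+ A shape-only sketch with random inputs (feature map sizes follow
+ the usual 1/32 output stride of the 'A0' setting):
+
+ >>> import torch
+ >>> from mmpretrain.models import RepVGG
+ >>> model = RepVGG(arch='A0', out_indices=(3, ))
+ >>> model.eval()
+ >>> inputs = torch.rand(1, 3, 64, 64)
+ >>> outputs = model(inputs)
+ >>> tuple(outputs[-1].shape)
+ (1, 1280, 2, 2)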
+ """
+
+ groupwise_layers = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26]
+ g2_layer_map = {layer: 2 for layer in groupwise_layers}
+ g4_layer_map = {layer: 4 for layer in groupwise_layers}
+
+ arch_settings = {
+ 'A0':
+ dict(
+ num_blocks=[2, 4, 14, 1],
+ width_factor=[0.75, 0.75, 0.75, 2.5],
+ group_layer_map=None,
+ se_cfg=None),
+ 'A1':
+ dict(
+ num_blocks=[2, 4, 14, 1],
+ width_factor=[1, 1, 1, 2.5],
+ group_layer_map=None,
+ se_cfg=None),
+ 'A2':
+ dict(
+ num_blocks=[2, 4, 14, 1],
+ width_factor=[1.5, 1.5, 1.5, 2.75],
+ group_layer_map=None,
+ se_cfg=None),
+ 'B0':
+ dict(
+ num_blocks=[4, 6, 16, 1],
+ width_factor=[1, 1, 1, 2.5],
+ group_layer_map=None,
+ se_cfg=None,
+ stem_channels=64),
+ 'B1':
+ dict(
+ num_blocks=[4, 6, 16, 1],
+ width_factor=[2, 2, 2, 4],
+ group_layer_map=None,
+ se_cfg=None),
+ 'B1g2':
+ dict(
+ num_blocks=[4, 6, 16, 1],
+ width_factor=[2, 2, 2, 4],
+ group_layer_map=g2_layer_map,
+ se_cfg=None),
+ 'B1g4':
+ dict(
+ num_blocks=[4, 6, 16, 1],
+ width_factor=[2, 2, 2, 4],
+ group_layer_map=g4_layer_map,
+ se_cfg=None),
+ 'B2':
+ dict(
+ num_blocks=[4, 6, 16, 1],
+ width_factor=[2.5, 2.5, 2.5, 5],
+ group_layer_map=None,
+ se_cfg=None),
+ 'B2g2':
+ dict(
+ num_blocks=[4, 6, 16, 1],
+ width_factor=[2.5, 2.5, 2.5, 5],
+ group_layer_map=g2_layer_map,
+ se_cfg=None),
+ 'B2g4':
+ dict(
+ num_blocks=[4, 6, 16, 1],
+ width_factor=[2.5, 2.5, 2.5, 5],
+ group_layer_map=g4_layer_map,
+ se_cfg=None),
+ 'B3':
+ dict(
+ num_blocks=[4, 6, 16, 1],
+ width_factor=[3, 3, 3, 5],
+ group_layer_map=None,
+ se_cfg=None),
+ 'B3g2':
+ dict(
+ num_blocks=[4, 6, 16, 1],
+ width_factor=[3, 3, 3, 5],
+ group_layer_map=g2_layer_map,
+ se_cfg=None),
+ 'B3g4':
+ dict(
+ num_blocks=[4, 6, 16, 1],
+ width_factor=[3, 3, 3, 5],
+ group_layer_map=g4_layer_map,
+ se_cfg=None),
+ 'D2se':
+ dict(
+ num_blocks=[8, 14, 24, 1],
+ width_factor=[2.5, 2.5, 2.5, 5],
+ group_layer_map=None,
+ se_cfg=dict(ratio=16, divisor=1)),
+ 'yolox-pai-small':
+ dict(
+ num_blocks=[3, 5, 7, 3],
+ width_factor=[1, 1, 1, 1],
+ group_layer_map=None,
+ se_cfg=None,
+ stem_channels=32),
+ }
+
+ def __init__(self,
+ arch,
+ in_channels=3,
+ base_channels=64,
+ out_indices=(3, ),
+ strides=(2, 2, 2, 2),
+ dilations=(1, 1, 1, 1),
+ frozen_stages=-1,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ with_cp=False,
+ deploy=False,
+ norm_eval=False,
+ add_ppf=False,
+ init_cfg=[
+ dict(type='Kaiming', layer=['Conv2d']),
+ dict(
+ type='Constant',
+ val=1,
+ layer=['_BatchNorm', 'GroupNorm'])
+ ]):
+ super(RepVGG, self).__init__(init_cfg)
+
+ if isinstance(arch, str):
+ assert arch in self.arch_settings, \
+ f'"arch": "{arch}" is not one of the arch_settings'
+ arch = self.arch_settings[arch]
+ elif not isinstance(arch, dict):
+ raise TypeError('Expect "arch" to be either a string '
+ f'or a dict, got {type(arch)}')
+
+ assert len(arch['num_blocks']) == len(
+ arch['width_factor']) == len(strides) == len(dilations)
+ assert max(out_indices) < len(arch['num_blocks'])
+ if arch['group_layer_map'] is not None:
+ assert max(arch['group_layer_map'].keys()) <= sum(
+ arch['num_blocks'])
+
+ if arch['se_cfg'] is not None:
+ assert isinstance(arch['se_cfg'], dict)
+
+ self.base_channels = base_channels
+ self.arch = arch
+ self.in_channels = in_channels
+ self.out_indices = out_indices
+ self.strides = strides
+ self.dilations = dilations
+ self.deploy = deploy
+ self.frozen_stages = frozen_stages
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ self.with_cp = with_cp
+ self.norm_eval = norm_eval
+
+ # Defaults to 64 to prevent BC-breaking if stem_channels is not in
+ # the arch dict; the stem channels should not be larger than those
+ # of stage1.
+ channels = min(
+ arch.get('stem_channels', 64),
+ int(self.base_channels * self.arch['width_factor'][0]))
+ self.stem = RepVGGBlock(
+ self.in_channels,
+ channels,
+ stride=2,
+ se_cfg=arch['se_cfg'],
+ with_cp=with_cp,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ deploy=deploy)
+
+ next_create_block_idx = 1
+ self.stages = []
+ for i in range(len(arch['num_blocks'])):
+ num_blocks = self.arch['num_blocks'][i]
+ stride = self.strides[i]
+ dilation = self.dilations[i]
+ out_channels = int(self.base_channels * 2**i *
+ self.arch['width_factor'][i])
+
+ stage, next_create_block_idx = self._make_stage(
+ channels, out_channels, num_blocks, stride, dilation,
+ next_create_block_idx, init_cfg)
+ stage_name = f'stage_{i + 1}'
+ self.add_module(stage_name, stage)
+ self.stages.append(stage_name)
+
+ channels = out_channels
+
+ if add_ppf:
+ self.ppf = MTSPPF(
+ out_channels,
+ out_channels,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ kernel_size=5)
+ else:
+ self.ppf = nn.Identity()
+
+ def _make_stage(self, in_channels, out_channels, num_blocks, stride,
+ dilation, next_create_block_idx, init_cfg):
+ strides = [stride] + [1] * (num_blocks - 1)
+ dilations = [dilation] * num_blocks
+
+ blocks = []
+ for i in range(num_blocks):
+ groups = self.arch['group_layer_map'].get(
+ next_create_block_idx,
+ 1) if self.arch['group_layer_map'] is not None else 1
+ blocks.append(
+ RepVGGBlock(
+ in_channels,
+ out_channels,
+ stride=strides[i],
+ padding=dilations[i],
+ dilation=dilations[i],
+ groups=groups,
+ se_cfg=self.arch['se_cfg'],
+ with_cp=self.with_cp,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg,
+ deploy=self.deploy,
+ init_cfg=init_cfg))
+ in_channels = out_channels
+ next_create_block_idx += 1
+
+ return Sequential(*blocks), next_create_block_idx
+
+ def forward(self, x):
+ x = self.stem(x)
+ outs = []
+ for i, stage_name in enumerate(self.stages):
+ stage = getattr(self, stage_name)
+ x = stage(x)
+ if i + 1 == len(self.stages):
+ x = self.ppf(x)
+ if i in self.out_indices:
+ outs.append(x)
+
+ return tuple(outs)
+
+ def _freeze_stages(self):
+ if self.frozen_stages >= 0:
+ self.stem.eval()
+ for param in self.stem.parameters():
+ param.requires_grad = False
+ for i in range(self.frozen_stages):
+ stage = getattr(self, f'stage_{i+1}')
+ stage.eval()
+ for param in stage.parameters():
+ param.requires_grad = False
+
+ def train(self, mode=True):
+ super(RepVGG, self).train(mode)
+ self._freeze_stages()
+ if mode and self.norm_eval:
+ for m in self.modules():
+ if isinstance(m, _BatchNorm):
+ m.eval()
+
+ def switch_to_deploy(self):
+ for m in self.modules():
+ if isinstance(m, RepVGGBlock):
+ m.switch_to_deploy()
+ self.deploy = True
diff --git a/mmpretrain/models/backbones/res2net.py b/mmpretrain/models/backbones/res2net.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e9bb6df37a2d2c9d19e613faa50ce0103aff357
--- /dev/null
+++ b/mmpretrain/models/backbones/res2net.py
@@ -0,0 +1,317 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+from mmcv.cnn import build_conv_layer, build_norm_layer
+from mmengine.model import ModuleList, Sequential
+
+from mmpretrain.registry import MODELS
+from .resnet import Bottleneck as _Bottleneck
+from .resnet import ResNet
+
+
+class Bottle2neck(_Bottleneck):
+ expansion = 4
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ scales=4,
+ base_width=26,
+ base_channels=64,
+ stage_type='normal',
+ **kwargs):
+ """Bottle2neck block for Res2Net."""
+ super(Bottle2neck, self).__init__(in_channels, out_channels, **kwargs)
+ assert scales > 1, 'Res2Net degenerates to ResNet when scales = 1.'
+
+ mid_channels = out_channels // self.expansion
+ width = int(math.floor(mid_channels * (base_width / base_channels)))
+
+ self.norm1_name, norm1 = build_norm_layer(
+ self.norm_cfg, width * scales, postfix=1)
+ self.norm3_name, norm3 = build_norm_layer(
+ self.norm_cfg, self.out_channels, postfix=3)
+
+ self.conv1 = build_conv_layer(
+ self.conv_cfg,
+ self.in_channels,
+ width * scales,
+ kernel_size=1,
+ stride=self.conv1_stride,
+ bias=False)
+ self.add_module(self.norm1_name, norm1)
+
+ if stage_type == 'stage':
+ self.pool = nn.AvgPool2d(
+ kernel_size=3, stride=self.conv2_stride, padding=1)
+
+ self.convs = ModuleList()
+ self.bns = ModuleList()
+ for i in range(scales - 1):
+ self.convs.append(
+ build_conv_layer(
+ self.conv_cfg,
+ width,
+ width,
+ kernel_size=3,
+ stride=self.conv2_stride,
+ padding=self.dilation,
+ dilation=self.dilation,
+ bias=False))
+ self.bns.append(
+ build_norm_layer(self.norm_cfg, width, postfix=i + 1)[1])
+
+ self.conv3 = build_conv_layer(
+ self.conv_cfg,
+ width * scales,
+ self.out_channels,
+ kernel_size=1,
+ bias=False)
+ self.add_module(self.norm3_name, norm3)
+
+ self.stage_type = stage_type
+ self.scales = scales
+ self.width = width
+ delattr(self, 'conv2')
+ delattr(self, self.norm2_name)
+
+ def forward(self, x):
+ """Forward function."""
+
+ def _inner_forward(x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.norm1(out)
+ out = self.relu(out)
+
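+ # Split the 1x1 output into ``scales`` groups of ``width`` channels.
+ # All groups except the last pass through their own 3x3 conv; from
+ # the second group on, the previous group's output is added first
+ # (skipped in the 'stage' type), building a hierarchical multi-scale
+ # receptive field. The last group is kept as-is (or pooled).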
+ spx = torch.split(out, self.width, 1)
+ sp = self.convs[0](spx[0].contiguous())
+ sp = self.relu(self.bns[0](sp))
+ out = sp
+ for i in range(1, self.scales - 1):
+ if self.stage_type == 'stage':
+ sp = spx[i]
+ else:
+ sp = sp + spx[i]
+ sp = self.convs[i](sp.contiguous())
+ sp = self.relu(self.bns[i](sp))
+ out = torch.cat((out, sp), 1)
+
+ if self.stage_type == 'normal' and self.scales != 1:
+ out = torch.cat((out, spx[self.scales - 1]), 1)
+ elif self.stage_type == 'stage' and self.scales != 1:
+ out = torch.cat((out, self.pool(spx[self.scales - 1])), 1)
+
+ out = self.conv3(out)
+ out = self.norm3(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+
+ return out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ out = self.relu(out)
+
+ return out
+
+
+class Res2Layer(Sequential):
+ """Res2Layer to build Res2Net style backbone.
+
+ Args:
+ block (nn.Module): Residual block used to build the layer.
+ in_channels (int): Input channels of this layer.
+ out_channels (int): Output channels of this layer.
+ num_blocks (int): Number of blocks.
+ stride (int): stride of the first block. Default: 1
+ avg_down (bool): Use AvgPool instead of stride conv when
+ downsampling in the bottle2neck. Defaults to True.
+ conv_cfg (dict): dictionary to construct and config conv layer.
+ Default: None
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: dict(type='BN')
+ scales (int): Scales used in Res2Net. Default: 4
+ base_width (int): Basic width of each scale. Default: 26
+ drop_path_rate (float or list): stochastic depth rate.
+ Default: 0.
+ """
+
+ def __init__(self,
+ block,
+ in_channels,
+ out_channels,
+ num_blocks,
+ stride=1,
+ avg_down=True,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ scales=4,
+ base_width=26,
+ drop_path_rate=0.0,
+ **kwargs):
+ self.block = block
+
+ if isinstance(drop_path_rate, float):
+ drop_path_rate = [drop_path_rate] * num_blocks
+
+ assert len(drop_path_rate
+ ) == num_blocks, 'Please check the length of drop_path_rate'
+
+ downsample = None
+ if stride != 1 or in_channels != out_channels:
+ if avg_down:
+ downsample = nn.Sequential(
+ nn.AvgPool2d(
+ kernel_size=stride,
+ stride=stride,
+ ceil_mode=True,
+ count_include_pad=False),
+ build_conv_layer(
+ conv_cfg,
+ in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ bias=False),
+ build_norm_layer(norm_cfg, out_channels)[1],
+ )
+ else:
+ downsample = nn.Sequential(
+ build_conv_layer(
+ conv_cfg,
+ in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=stride,
+ bias=False),
+ build_norm_layer(norm_cfg, out_channels)[1],
+ )
+
+ layers = []
+ layers.append(
+ block(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ stride=stride,
+ downsample=downsample,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ scales=scales,
+ base_width=base_width,
+ stage_type='stage',
+ drop_path_rate=drop_path_rate[0],
+ **kwargs))
+ in_channels = out_channels
+ for i in range(1, num_blocks):
+ layers.append(
+ block(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ stride=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ scales=scales,
+ base_width=base_width,
+ drop_path_rate=drop_path_rate[i],
+ **kwargs))
+ super(Res2Layer, self).__init__(*layers)
+
+
+@MODELS.register_module()
+class Res2Net(ResNet):
+ """Res2Net backbone.
+
+ A PyTorch implementation of: `Res2Net: A New Multi-scale Backbone
+ Architecture <https://arxiv.org/abs/1904.01169>`_
+
+ Args:
+ depth (int): Depth of Res2Net, choose from {50, 101, 152}.
+ scales (int): Scales used in Res2Net. Defaults to 4.
+ base_width (int): Basic width of each scale. Defaults to 26.
+ in_channels (int): Number of input image channels. Defaults to 3.
+ num_stages (int): Number of Res2Net stages. Defaults to 4.
+ strides (Sequence[int]): Strides of the first block of each stage.
+ Defaults to ``(1, 2, 2, 2)``.
+ dilations (Sequence[int]): Dilation of each stage.
+ Defaults to ``(1, 1, 1, 1)``.
+ out_indices (Sequence[int]): Output from which stages.
+ Defaults to ``(3, )``.
+ style (str): "pytorch" or "caffe". If set to "pytorch", the stride-two
+ layer is the 3x3 conv layer, otherwise the stride-two layer is
+ the first 1x1 conv layer. Defaults to "pytorch".
+ deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv.
+ Defaults to True.
+ avg_down (bool): Use AvgPool instead of stride conv when
+ downsampling in the bottle2neck. Defaults to True.
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+ -1 means not freezing any parameters. Defaults to -1.
+ norm_cfg (dict): Dictionary to construct and config norm layer.
+ Defaults to ``dict(type='BN', requires_grad=True)``.
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Defaults to False.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Defaults to False.
+ zero_init_residual (bool): Whether to use zero init for last norm layer
+ in resblocks to let them behave as identity. Defaults to True.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Defaults to None.
+
+ Example:
+ >>> from mmpretrain.models import Res2Net
+ >>> import torch
+ >>> model = Res2Net(depth=50,
+ ... scales=4,
+ ... base_width=26,
+ ... out_indices=(0, 1, 2, 3))
+ >>> model.eval()
+ >>> inputs = torch.rand(1, 3, 32, 32)
+ >>> level_outputs = model.forward(inputs)
+ >>> for level_out in level_outputs:
+ ... print(tuple(level_out.shape))
+ (1, 256, 8, 8)
+ (1, 512, 4, 4)
+ (1, 1024, 2, 2)
+ (1, 2048, 1, 1)
+ """
+
+ arch_settings = {
+ 50: (Bottle2neck, (3, 4, 6, 3)),
+ 101: (Bottle2neck, (3, 4, 23, 3)),
+ 152: (Bottle2neck, (3, 8, 36, 3))
+ }
+
+ def __init__(self,
+ scales=4,
+ base_width=26,
+ style='pytorch',
+ deep_stem=True,
+ avg_down=True,
+ init_cfg=None,
+ **kwargs):
+ self.scales = scales
+ self.base_width = base_width
+ super(Res2Net, self).__init__(
+ style=style,
+ deep_stem=deep_stem,
+ avg_down=avg_down,
+ init_cfg=init_cfg,
+ **kwargs)
+
+ def make_res_layer(self, **kwargs):
+ return Res2Layer(
+ scales=self.scales,
+ base_width=self.base_width,
+ base_channels=self.base_channels,
+ **kwargs)
diff --git a/mmpretrain/models/backbones/resnest.py b/mmpretrain/models/backbones/resnest.py
new file mode 100644
index 0000000000000000000000000000000000000000..4bb438f042d606946fd7b69d73568f28563e0efa
--- /dev/null
+++ b/mmpretrain/models/backbones/resnest.py
@@ -0,0 +1,339 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as cp
+from mmcv.cnn import build_conv_layer, build_norm_layer
+
+from mmpretrain.registry import MODELS
+from .resnet import Bottleneck as _Bottleneck
+from .resnet import ResLayer, ResNetV1d
+
+
+class RSoftmax(nn.Module):
+ """Radix Softmax module in ``SplitAttentionConv2d``.
+
+ Args:
+ radix (int): Radix of input.
+ groups (int): Groups of input.
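+
+ Example:
+ A shape-only sketch (the input is the flattened attention logits):
+
+ >>> import torch
+ >>> from mmpretrain.models.backbones.resnest import RSoftmax
+ >>> m = RSoftmax(radix=2, groups=4)
+ >>> x = torch.rand(2, 64)
+ >>> tuple(m(x).shape)
+ (2, 64)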
+ """
+
+ def __init__(self, radix, groups):
+ super().__init__()
+ self.radix = radix
+ self.groups = groups
+
+ def forward(self, x):
+ batch = x.size(0)
+ if self.radix > 1:
+ x = x.view(batch, self.groups, self.radix, -1).transpose(1, 2)
+ x = F.softmax(x, dim=1)
+ x = x.reshape(batch, -1)
+ else:
+ x = torch.sigmoid(x)
+ return x
+
+
+class SplitAttentionConv2d(nn.Module):
+ """Split-Attention Conv2d.
+
+ Args:
+ in_channels (int): Same as nn.Conv2d.
+ channels (int): Output channels, same as out_channels of nn.Conv2d.
+ kernel_size (int | tuple[int]): Same as nn.Conv2d.
+ stride (int | tuple[int]): Same as nn.Conv2d.
+ padding (int | tuple[int]): Same as nn.Conv2d.
+ dilation (int | tuple[int]): Same as nn.Conv2d.
+ groups (int): Same as nn.Conv2d.
+ radix (int): Radix of SplitAttentionConv2d. Default: 2.
+ reduction_factor (int): Reduction factor of SplitAttentionConv2d.
+ Default: 4.
+ conv_cfg (dict, optional): Config dict for convolution layer.
+ Default: None, which means using conv2d.
+ norm_cfg (dict, optional): Config dict for normalization layer.
+ Default: dict(type='BN').
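+
+ Example:
+ A shape-only sketch with a random input:
+
+ >>> import torch
+ >>> from mmpretrain.models.backbones.resnest import SplitAttentionConv2d
+ >>> conv = SplitAttentionConv2d(32, 32, kernel_size=3, padding=1)
+ >>> x = torch.rand(1, 32, 56, 56)
+ >>> tuple(conv(x).shape)
+ (1, 32, 56, 56)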
+ """
+
+ def __init__(self,
+ in_channels,
+ channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ radix=2,
+ reduction_factor=4,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN')):
+ super(SplitAttentionConv2d, self).__init__()
+ inter_channels = max(in_channels * radix // reduction_factor, 32)
+ self.radix = radix
+ self.groups = groups
+ self.channels = channels
+ self.conv = build_conv_layer(
+ conv_cfg,
+ in_channels,
+ channels * radix,
+ kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups * radix,
+ bias=False)
+ self.norm0_name, norm0 = build_norm_layer(
+ norm_cfg, channels * radix, postfix=0)
+ self.add_module(self.norm0_name, norm0)
+ self.relu = nn.ReLU(inplace=True)
+ self.fc1 = build_conv_layer(
+ None, channels, inter_channels, 1, groups=self.groups)
+ self.norm1_name, norm1 = build_norm_layer(
+ norm_cfg, inter_channels, postfix=1)
+ self.add_module(self.norm1_name, norm1)
+ self.fc2 = build_conv_layer(
+ None, inter_channels, channels * radix, 1, groups=self.groups)
+ self.rsoftmax = RSoftmax(radix, groups)
+
+ @property
+ def norm0(self):
+ return getattr(self, self.norm0_name)
+
+ @property
+ def norm1(self):
+ return getattr(self, self.norm1_name)
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.norm0(x)
+ x = self.relu(x)
+
+ batch, rchannel = x.shape[:2]
+ if self.radix > 1:
+ splits = x.view(batch, self.radix, -1, *x.shape[2:])
+ gap = splits.sum(dim=1)
+ else:
+ gap = x
+ gap = F.adaptive_avg_pool2d(gap, 1)
+ gap = self.fc1(gap)
+
+ gap = self.norm1(gap)
+ gap = self.relu(gap)
+
+ atten = self.fc2(gap)
+ atten = self.rsoftmax(atten).view(batch, -1, 1, 1)
+
+ if self.radix > 1:
+ attens = atten.view(batch, self.radix, -1, *atten.shape[2:])
+ out = torch.sum(attens * splits, dim=1)
+ else:
+ out = atten * x
+ return out.contiguous()
+
+
+class Bottleneck(_Bottleneck):
+ """Bottleneck block for ResNeSt.
+
+ Args:
+ in_channels (int): Input channels of this block.
+ out_channels (int): Output channels of this block.
+ groups (int): Groups of conv2.
+ width_per_group (int): Width per group of conv2. 64x4d indicates
+ ``groups=64, width_per_group=4`` and 32x8d indicates
+ ``groups=32, width_per_group=8``.
+ radix (int): Radix of SplitAttentionConv2d. Default: 2.
+ reduction_factor (int): Reduction factor of SplitAttentionConv2d.
+ Default: 4.
+ avg_down_stride (bool): Whether to use average pool for stride in
+ Bottleneck. Default: True.
+ stride (int): stride of the block. Default: 1
+ dilation (int): dilation of convolution. Default: 1
+ downsample (nn.Module, optional): downsample operation on identity
+ branch. Default: None
+ style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
+ layer is the 3x3 conv layer, otherwise the stride-two layer is
+ the first 1x1 conv layer.
+ conv_cfg (dict, optional): dictionary to construct and config conv
+ layer. Default: None
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: dict(type='BN')
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ groups=1,
+ width_per_group=4,
+ base_channels=64,
+ radix=2,
+ reduction_factor=4,
+ avg_down_stride=True,
+ **kwargs):
+ super(Bottleneck, self).__init__(in_channels, out_channels, **kwargs)
+
+ self.groups = groups
+ self.width_per_group = width_per_group
+
+ # For ResNet bottleneck, middle channels are determined by expansion
+ # and out_channels, but for ResNeXt bottleneck, it is determined by
+ # groups and width_per_group and the stage it is located in.
+ if groups != 1:
+ assert self.mid_channels % base_channels == 0
+ self.mid_channels = (
+ groups * width_per_group * self.mid_channels // base_channels)
+
+ self.avg_down_stride = avg_down_stride and self.conv2_stride > 1
+
+ self.norm1_name, norm1 = build_norm_layer(
+ self.norm_cfg, self.mid_channels, postfix=1)
+ self.norm3_name, norm3 = build_norm_layer(
+ self.norm_cfg, self.out_channels, postfix=3)
+
+ self.conv1 = build_conv_layer(
+ self.conv_cfg,
+ self.in_channels,
+ self.mid_channels,
+ kernel_size=1,
+ stride=self.conv1_stride,
+ bias=False)
+ self.add_module(self.norm1_name, norm1)
+ self.conv2 = SplitAttentionConv2d(
+ self.mid_channels,
+ self.mid_channels,
+ kernel_size=3,
+ stride=1 if self.avg_down_stride else self.conv2_stride,
+ padding=self.dilation,
+ dilation=self.dilation,
+ groups=groups,
+ radix=radix,
+ reduction_factor=reduction_factor,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg)
+ delattr(self, self.norm2_name)
+
+ if self.avg_down_stride:
+ self.avd_layer = nn.AvgPool2d(3, self.conv2_stride, padding=1)
+
+ self.conv3 = build_conv_layer(
+ self.conv_cfg,
+ self.mid_channels,
+ self.out_channels,
+ kernel_size=1,
+ bias=False)
+ self.add_module(self.norm3_name, norm3)
+
+ def forward(self, x):
+
+ def _inner_forward(x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.norm1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+
+ if self.avg_down_stride:
+ out = self.avd_layer(out)
+
+ out = self.conv3(out)
+ out = self.norm3(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+
+ return out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ out = self.relu(out)
+
+ return out
+
+
+@MODELS.register_module()
+class ResNeSt(ResNetV1d):
+ """ResNeSt backbone.
+
+ Please refer to the `paper <https://arxiv.org/abs/2004.08955>`__ for
+ details.
+
+ Args:
+ depth (int): Network depth, from {50, 101, 152, 200}.
+ groups (int): Groups of conv2 in Bottleneck. Default: 1.
+ width_per_group (int): Width per group of conv2 in Bottleneck.
+ Default: 4.
+ radix (int): Radix of SplitAttentionConv2d. Default: 2.
+ reduction_factor (int): Reduction factor of SplitAttentionConv2d.
+ Default: 4.
+ avg_down_stride (bool): Whether to use average pool for stride in
+ Bottleneck. Default: True.
+ in_channels (int): Number of input image channels. Default: 3.
+ stem_channels (int): Output channels of the stem layer. Default: 64.
+ num_stages (int): Stages of the network. Default: 4.
+ strides (Sequence[int]): Strides of the first block of each stage.
+ Default: ``(1, 2, 2, 2)``.
+ dilations (Sequence[int]): Dilation of each stage.
+ Default: ``(1, 1, 1, 1)``.
+ out_indices (Sequence[int]): Output from which stages. A tuple of
+ tensors, one per selected stage, is returned.
+ Default: ``(3, )``.
+ style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
+ layer is the 3x3 conv layer, otherwise the stride-two layer is
+ the first 1x1 conv layer.
+ deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv.
+ Default: False.
+ avg_down (bool): Use AvgPool instead of stride conv when
+ downsampling in the bottleneck. Default: False.
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+ -1 means not freezing any parameters. Default: -1.
+ conv_cfg (dict | None): The config dict for conv layers. Default: None.
+ norm_cfg (dict): The config dict for norm layers.
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Default: False.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ zero_init_residual (bool): Whether to use zero init for last norm layer
+ in resblocks to let them behave as identity. Default: True.
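+
+ Example:
+ A shape-only sketch with random inputs:
+
+ >>> from mmpretrain.models import ResNeSt
+ >>> import torch
+ >>> model = ResNeSt(depth=50, out_indices=(0, 1, 2, 3))
+ >>> model.eval()
+ >>> inputs = torch.rand(1, 3, 64, 64)
+ >>> for level_out in model(inputs):
+ ... print(tuple(level_out.shape))
+ (1, 256, 16, 16)
+ (1, 512, 8, 8)
+ (1, 1024, 4, 4)
+ (1, 2048, 2, 2)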
+ """
+
+ arch_settings = {
+ 50: (Bottleneck, (3, 4, 6, 3)),
+ 101: (Bottleneck, (3, 4, 23, 3)),
+ 152: (Bottleneck, (3, 8, 36, 3)),
+ 200: (Bottleneck, (3, 24, 36, 3)),
+ 269: (Bottleneck, (3, 30, 48, 8))
+ }
+
+ def __init__(self,
+ depth,
+ groups=1,
+ width_per_group=4,
+ radix=2,
+ reduction_factor=4,
+ avg_down_stride=True,
+ **kwargs):
+ self.groups = groups
+ self.width_per_group = width_per_group
+ self.radix = radix
+ self.reduction_factor = reduction_factor
+ self.avg_down_stride = avg_down_stride
+ super(ResNeSt, self).__init__(depth=depth, **kwargs)
+
+ def make_res_layer(self, **kwargs):
+ return ResLayer(
+ groups=self.groups,
+ width_per_group=self.width_per_group,
+ base_channels=self.base_channels,
+ radix=self.radix,
+ reduction_factor=self.reduction_factor,
+ avg_down_stride=self.avg_down_stride,
+ **kwargs)
diff --git a/mmpretrain/models/backbones/resnet.py b/mmpretrain/models/backbones/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a254f7c2b76f03974e05194b39fbb802684873a
--- /dev/null
+++ b/mmpretrain/models/backbones/resnet.py
@@ -0,0 +1,768 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+from mmcv.cnn import (ConvModule, build_activation_layer, build_conv_layer,
+ build_norm_layer)
+from mmcv.cnn.bricks import DropPath
+from mmengine.model import BaseModule
+from mmengine.model.weight_init import constant_init
+from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
+
+from mmpretrain.registry import MODELS
+from .base_backbone import BaseBackbone
+
+eps = 1.0e-5
+
+
+class BasicBlock(BaseModule):
+ """BasicBlock for ResNet.
+
+ Args:
+ in_channels (int): Input channels of this block.
+ out_channels (int): Output channels of this block.
+ expansion (int): The ratio of ``out_channels/mid_channels`` where
+ ``mid_channels`` is the output channels of conv1. This is a
+ reserved argument in BasicBlock and should always be 1. Default: 1.
+ stride (int): stride of the block. Default: 1
+ dilation (int): dilation of convolution. Default: 1
+ downsample (nn.Module, optional): downsample operation on identity
+ branch. Default: None.
+ style (str): `pytorch` or `caffe`. It is unused and reserved for
+ unified API with Bottleneck.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed.
+ conv_cfg (dict, optional): dictionary to construct and config conv
+ layer. Default: None
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: dict(type='BN')
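+
+ Example:
+ A shape-only sketch with a random input:
+
+ >>> import torch
+ >>> from mmpretrain.models.backbones.resnet import BasicBlock
+ >>> block = BasicBlock(64, 64)
+ >>> x = torch.rand(1, 64, 8, 8)
+ >>> tuple(block(x).shape)
+ (1, 64, 8, 8)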
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ expansion=1,
+ stride=1,
+ dilation=1,
+ downsample=None,
+ style='pytorch',
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ drop_path_rate=0.0,
+ act_cfg=dict(type='ReLU', inplace=True),
+ init_cfg=None):
+ super(BasicBlock, self).__init__(init_cfg=init_cfg)
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.expansion = expansion
+ assert self.expansion == 1
+ assert out_channels % expansion == 0
+ self.mid_channels = out_channels // expansion
+ self.stride = stride
+ self.dilation = dilation
+ self.style = style
+ self.with_cp = with_cp
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+
+ self.norm1_name, norm1 = build_norm_layer(
+ norm_cfg, self.mid_channels, postfix=1)
+ self.norm2_name, norm2 = build_norm_layer(
+ norm_cfg, out_channels, postfix=2)
+
+ self.conv1 = build_conv_layer(
+ conv_cfg,
+ in_channels,
+ self.mid_channels,
+ 3,
+ stride=stride,
+ padding=dilation,
+ dilation=dilation,
+ bias=False)
+ self.add_module(self.norm1_name, norm1)
+ self.conv2 = build_conv_layer(
+ conv_cfg,
+ self.mid_channels,
+ out_channels,
+ 3,
+ padding=1,
+ bias=False)
+ self.add_module(self.norm2_name, norm2)
+
+ self.relu = build_activation_layer(act_cfg)
+ self.downsample = downsample
+ self.drop_path = DropPath(drop_prob=drop_path_rate
+ ) if drop_path_rate > eps else nn.Identity()
+
+ @property
+ def norm1(self):
+ return getattr(self, self.norm1_name)
+
+ @property
+ def norm2(self):
+ return getattr(self, self.norm2_name)
+
+ def forward(self, x):
+
+ def _inner_forward(x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.norm1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.norm2(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out = self.drop_path(out)
+
+ out += identity
+
+ return out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(BaseModule):
+ """Bottleneck block for ResNet.
+
+ Args:
+ in_channels (int): Input channels of this block.
+ out_channels (int): Output channels of this block.
+ expansion (int): The ratio of ``out_channels/mid_channels`` where
+ ``mid_channels`` is the input/output channels of conv2. Default: 4.
+ stride (int): stride of the block. Default: 1
+ dilation (int): dilation of convolution. Default: 1
+ downsample (nn.Module, optional): downsample operation on identity
+ branch. Default: None.
+ style (str): ``"pytorch"`` or ``"caffe"``. If set to "pytorch", the
+ stride-two layer is the 3x3 conv layer, otherwise the stride-two
+ layer is the first 1x1 conv layer. Default: "pytorch".
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed.
+ conv_cfg (dict, optional): dictionary to construct and config conv
+ layer. Default: None
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: dict(type='BN')
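+
+ Example:
+ A shape-only sketch with a random input (``mid_channels`` is
+ ``out_channels // expansion = 64`` here):
+
+ >>> import torch
+ >>> from mmpretrain.models.backbones.resnet import Bottleneck
+ >>> block = Bottleneck(256, 256)
+ >>> x = torch.rand(1, 256, 8, 8)
+ >>> tuple(block(x).shape)
+ (1, 256, 8, 8)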
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ expansion=4,
+ stride=1,
+ dilation=1,
+ downsample=None,
+ style='pytorch',
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU', inplace=True),
+ drop_path_rate=0.0,
+ init_cfg=None):
+ super(Bottleneck, self).__init__(init_cfg=init_cfg)
+ assert style in ['pytorch', 'caffe']
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.expansion = expansion
+ assert out_channels % expansion == 0
+ self.mid_channels = out_channels // expansion
+ self.stride = stride
+ self.dilation = dilation
+ self.style = style
+ self.with_cp = with_cp
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+
+ if self.style == 'pytorch':
+ self.conv1_stride = 1
+ self.conv2_stride = stride
+ else:
+ self.conv1_stride = stride
+ self.conv2_stride = 1
+
+ self.norm1_name, norm1 = build_norm_layer(
+ norm_cfg, self.mid_channels, postfix=1)
+ self.norm2_name, norm2 = build_norm_layer(
+ norm_cfg, self.mid_channels, postfix=2)
+ self.norm3_name, norm3 = build_norm_layer(
+ norm_cfg, out_channels, postfix=3)
+
+ self.conv1 = build_conv_layer(
+ conv_cfg,
+ in_channels,
+ self.mid_channels,
+ kernel_size=1,
+ stride=self.conv1_stride,
+ bias=False)
+ self.add_module(self.norm1_name, norm1)
+ self.conv2 = build_conv_layer(
+ conv_cfg,
+ self.mid_channels,
+ self.mid_channels,
+ kernel_size=3,
+ stride=self.conv2_stride,
+ padding=dilation,
+ dilation=dilation,
+ bias=False)
+
+ self.add_module(self.norm2_name, norm2)
+ self.conv3 = build_conv_layer(
+ conv_cfg,
+ self.mid_channels,
+ out_channels,
+ kernel_size=1,
+ bias=False)
+ self.add_module(self.norm3_name, norm3)
+
+ self.relu = build_activation_layer(act_cfg)
+ self.downsample = downsample
+ self.drop_path = DropPath(drop_prob=drop_path_rate
+ ) if drop_path_rate > eps else nn.Identity()
+
+ @property
+ def norm1(self):
+ return getattr(self, self.norm1_name)
+
+ @property
+ def norm2(self):
+ return getattr(self, self.norm2_name)
+
+ @property
+ def norm3(self):
+ return getattr(self, self.norm3_name)
+
+ def forward(self, x):
+
+ def _inner_forward(x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.norm1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.norm2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.norm3(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out = self.drop_path(out)
+
+ out += identity
+
+ return out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ out = self.relu(out)
+
+ return out
+
+
+def get_expansion(block, expansion=None):
+ """Get the expansion of a residual block.
+
+ The block expansion will be obtained by the following order:
+
+ 1. If ``expansion`` is given, just return it.
+ 2. If ``block`` has the attribute ``expansion``, then return
+ ``block.expansion``.
+ 3. Return the default value according to the block type:
+ 1 for ``BasicBlock`` and 4 for ``Bottleneck``.
+
+ Args:
+ block (class): The block class.
+ expansion (int | None): The given expansion ratio.
+
+ Returns:
+ int: The expansion of the block.
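+
+ Example:
+ >>> from mmpretrain.models.backbones.resnet import (BasicBlock,
+ ... Bottleneck, get_expansion)
+ >>> get_expansion(BasicBlock)
+ 1
+ >>> get_expansion(Bottleneck)
+ 4
+ >>> get_expansion(Bottleneck, expansion=2)
+ 2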
+ """
+ if isinstance(expansion, int):
+ assert expansion > 0
+ elif expansion is None:
+ if hasattr(block, 'expansion'):
+ expansion = block.expansion
+ elif issubclass(block, BasicBlock):
+ expansion = 1
+ elif issubclass(block, Bottleneck):
+ expansion = 4
+ else:
+ raise TypeError(f'expansion is not specified for {block.__name__}')
+ else:
+ raise TypeError('expansion must be an integer or None')
+
+ return expansion
+
+
+class ResLayer(nn.Sequential):
+ """ResLayer to build ResNet style backbone.
+
+ Args:
+ block (nn.Module): Residual block used to build ResLayer.
+ num_blocks (int): Number of blocks.
+ in_channels (int): Input channels of this block.
+ out_channels (int): Output channels of this block.
+ expansion (int, optional): The expansion for BasicBlock/Bottleneck.
+ If not specified, it will first be obtained via
+ ``block.expansion``. If the block has no attribute "expansion",
+ the following default values will be used: 1 for BasicBlock and
+ 4 for Bottleneck. Default: None.
+ stride (int): stride of the first block. Default: 1.
+ avg_down (bool): Use AvgPool instead of stride conv when
+ downsampling in the bottleneck. Default: False
+ conv_cfg (dict, optional): dictionary to construct and config conv
+ layer. Default: None
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: dict(type='BN')
+ drop_path_rate (float or list): stochastic depth rate.
+ Default: 0.
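+
+ Example:
+ A shape-only sketch with a random input:
+
+ >>> import torch
+ >>> from mmpretrain.models.backbones.resnet import BasicBlock, ResLayer
+ >>> layer = ResLayer(BasicBlock, num_blocks=2, in_channels=64,
+ ... out_channels=64)
+ >>> x = torch.rand(1, 64, 8, 8)
+ >>> tuple(layer(x).shape)
+ (1, 64, 8, 8)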
+ """
+
+ def __init__(self,
+ block,
+ num_blocks,
+ in_channels,
+ out_channels,
+ expansion=None,
+ stride=1,
+ avg_down=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ drop_path_rate=0.0,
+ **kwargs):
+ self.block = block
+ self.expansion = get_expansion(block, expansion)
+
+ if isinstance(drop_path_rate, float):
+ drop_path_rate = [drop_path_rate] * num_blocks
+
+ assert len(drop_path_rate
+ ) == num_blocks, 'Please check the length of drop_path_rate'
+
+ downsample = None
+ if stride != 1 or in_channels != out_channels:
+ downsample = []
+ conv_stride = stride
+ if avg_down and stride != 1:
+ conv_stride = 1
+ downsample.append(
+ nn.AvgPool2d(
+ kernel_size=stride,
+ stride=stride,
+ ceil_mode=True,
+ count_include_pad=False))
+ downsample.extend([
+ build_conv_layer(
+ conv_cfg,
+ in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=conv_stride,
+ bias=False),
+ build_norm_layer(norm_cfg, out_channels)[1]
+ ])
+ downsample = nn.Sequential(*downsample)
+
+ layers = []
+ layers.append(
+ block(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ expansion=self.expansion,
+ stride=stride,
+ downsample=downsample,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ drop_path_rate=drop_path_rate[0],
+ **kwargs))
+ in_channels = out_channels
+ for i in range(1, num_blocks):
+ layers.append(
+ block(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ expansion=self.expansion,
+ stride=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ drop_path_rate=drop_path_rate[i],
+ **kwargs))
+ super(ResLayer, self).__init__(*layers)
+
+
+@MODELS.register_module()
+class ResNet(BaseBackbone):
+ """ResNet backbone.
+
+ Please refer to the `paper <https://arxiv.org/abs/1512.03385>`__ for
+ details.
+
+ Args:
+ depth (int): Network depth, from {18, 34, 50, 101, 152}.
+ in_channels (int): Number of input image channels. Default: 3.
+ stem_channels (int): Output channels of the stem layer. Default: 64.
+ base_channels (int): Middle channels of the first stage. Default: 64.
+ num_stages (int): Stages of the network. Default: 4.
+ strides (Sequence[int]): Strides of the first block of each stage.
+ Default: ``(1, 2, 2, 2)``.
+ dilations (Sequence[int]): Dilation of each stage.
+ Default: ``(1, 1, 1, 1)``.
+ out_indices (Sequence[int]): Output from which stages.
+ Default: ``(3, )``.
+ style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
+ layer is the 3x3 conv layer, otherwise the stride-two layer is
+ the first 1x1 conv layer.
+ deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv.
+ Default: False.
+ avg_down (bool): Use AvgPool instead of stride conv when
+ downsampling in the bottleneck. Default: False.
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+ -1 means not freezing any parameters. Default: -1.
+ conv_cfg (dict | None): The config dict for conv layers. Default: None.
+ norm_cfg (dict): The config dict for norm layers.
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Default: False.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ zero_init_residual (bool): Whether to use zero init for last norm layer
+ in resblocks to let them behave as identity. Default: True.
+
+ Example:
+ >>> from mmpretrain.models import ResNet
+ >>> import torch
+ >>> self = ResNet(depth=18)
+ >>> self.eval()
+ >>> inputs = torch.rand(1, 3, 32, 32)
+ >>> level_outputs = self.forward(inputs)
+ >>> for level_out in level_outputs:
+ ... print(tuple(level_out.shape))
+ (1, 64, 8, 8)
+ (1, 128, 4, 4)
+ (1, 256, 2, 2)
+ (1, 512, 1, 1)
+ """
+
+ arch_settings = {
+ 18: (BasicBlock, (2, 2, 2, 2)),
+ 34: (BasicBlock, (3, 4, 6, 3)),
+ 50: (Bottleneck, (3, 4, 6, 3)),
+ 101: (Bottleneck, (3, 4, 23, 3)),
+ 152: (Bottleneck, (3, 8, 36, 3))
+ }
+
+ def __init__(self,
+ depth,
+ in_channels=3,
+ stem_channels=64,
+ base_channels=64,
+ expansion=None,
+ num_stages=4,
+ strides=(1, 2, 2, 2),
+ dilations=(1, 1, 1, 1),
+ out_indices=(3, ),
+ style='pytorch',
+ deep_stem=False,
+ avg_down=False,
+ frozen_stages=-1,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ norm_eval=False,
+ with_cp=False,
+ zero_init_residual=True,
+ init_cfg=[
+ dict(type='Kaiming', layer=['Conv2d']),
+ dict(
+ type='Constant',
+ val=1,
+ layer=['_BatchNorm', 'GroupNorm'])
+ ],
+ drop_path_rate=0.0):
+ super(ResNet, self).__init__(init_cfg)
+ if depth not in self.arch_settings:
+ raise KeyError(f'invalid depth {depth} for resnet')
+ self.depth = depth
+ self.stem_channels = stem_channels
+ self.base_channels = base_channels
+ self.num_stages = num_stages
+ assert num_stages >= 1 and num_stages <= 4
+ self.strides = strides
+ self.dilations = dilations
+ assert len(strides) == len(dilations) == num_stages
+ self.out_indices = out_indices
+ assert max(out_indices) < num_stages
+ self.style = style
+ self.deep_stem = deep_stem
+ self.avg_down = avg_down
+ self.frozen_stages = frozen_stages
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.with_cp = with_cp
+ self.norm_eval = norm_eval
+ self.zero_init_residual = zero_init_residual
+ self.block, stage_blocks = self.arch_settings[depth]
+ self.stage_blocks = stage_blocks[:num_stages]
+ self.expansion = get_expansion(self.block, expansion)
+
+ self._make_stem_layer(in_channels, stem_channels)
+
+ self.res_layers = []
+ _in_channels = stem_channels
+ _out_channels = base_channels * self.expansion
+
+ # stochastic depth decay rule
+ total_depth = sum(stage_blocks)
+ dpr = [
+ x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
+ ]
+
+ for i, num_blocks in enumerate(self.stage_blocks):
+ stride = strides[i]
+ dilation = dilations[i]
+ res_layer = self.make_res_layer(
+ block=self.block,
+ num_blocks=num_blocks,
+ in_channels=_in_channels,
+ out_channels=_out_channels,
+ expansion=self.expansion,
+ stride=stride,
+ dilation=dilation,
+ style=self.style,
+ avg_down=self.avg_down,
+ with_cp=with_cp,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ drop_path_rate=dpr[:num_blocks])
+ _in_channels = _out_channels
+ _out_channels *= 2
+ dpr = dpr[num_blocks:]
+ layer_name = f'layer{i + 1}'
+ self.add_module(layer_name, res_layer)
+ self.res_layers.append(layer_name)
+
+ self._freeze_stages()
+
+ self.feat_dim = res_layer[-1].out_channels
+
+ def make_res_layer(self, **kwargs):
+ return ResLayer(**kwargs)
+
+ @property
+ def norm1(self):
+ return getattr(self, self.norm1_name)
+
+ def _make_stem_layer(self, in_channels, stem_channels):
+ if self.deep_stem:
+ self.stem = nn.Sequential(
+ ConvModule(
+ in_channels,
+ stem_channels // 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ inplace=True),
+ ConvModule(
+ stem_channels // 2,
+ stem_channels // 2,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ inplace=True),
+ ConvModule(
+ stem_channels // 2,
+ stem_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ inplace=True))
+ else:
+ self.conv1 = build_conv_layer(
+ self.conv_cfg,
+ in_channels,
+ stem_channels,
+ kernel_size=7,
+ stride=2,
+ padding=3,
+ bias=False)
+ self.norm1_name, norm1 = build_norm_layer(
+ self.norm_cfg, stem_channels, postfix=1)
+ self.add_module(self.norm1_name, norm1)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+ def _freeze_stages(self):
+ if self.frozen_stages >= 0:
+ if self.deep_stem:
+ self.stem.eval()
+ for param in self.stem.parameters():
+ param.requires_grad = False
+ else:
+ self.norm1.eval()
+ for m in [self.conv1, self.norm1]:
+ for param in m.parameters():
+ param.requires_grad = False
+
+ for i in range(1, self.frozen_stages + 1):
+ m = getattr(self, f'layer{i}')
+ m.eval()
+ for param in m.parameters():
+ param.requires_grad = False
+
+ def init_weights(self):
+ super(ResNet, self).init_weights()
+
+ if (isinstance(self.init_cfg, dict)
+ and self.init_cfg['type'] == 'Pretrained'):
+ # Suppress zero_init_residual if use pretrained model.
+ return
+
+ if self.zero_init_residual:
+ for m in self.modules():
+ if isinstance(m, Bottleneck):
+ constant_init(m.norm3, 0)
+ elif isinstance(m, BasicBlock):
+ constant_init(m.norm2, 0)
+
+ def forward(self, x):
+ if self.deep_stem:
+ x = self.stem(x)
+ else:
+ x = self.conv1(x)
+ x = self.norm1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+ outs = []
+ for i, layer_name in enumerate(self.res_layers):
+ res_layer = getattr(self, layer_name)
+ x = res_layer(x)
+ if i in self.out_indices:
+ outs.append(x)
+ return tuple(outs)
+
+ def train(self, mode=True):
+ super(ResNet, self).train(mode)
+ self._freeze_stages()
+ if mode and self.norm_eval:
+ for m in self.modules():
+ # trick: eval() has an effect on BatchNorm layers only
+ if isinstance(m, _BatchNorm):
+ m.eval()
+
+ def get_layer_depth(self, param_name: str, prefix: str = ''):
+ """Get the layer id to set the different learning rates for ResNet.
+
+ ResNet stages:
+ 50 : [3, 4, 6, 3]
+ 101 : [3, 4, 23, 3]
+ 152 : [3, 8, 36, 3]
+ 200 : [3, 24, 36, 3]
+ eca269d: [3, 30, 48, 8]
+
+ Args:
+ param_name (str): The name of the parameter.
+ prefix (str): The prefix for the parameter.
+ Defaults to an empty string.
+
+ Returns:
+ Tuple[int, int]: The layer-wise depth and the num of layers.
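+
+ Example:
+ A rough sketch for ResNet-50; the parameter name below is a
+ hypothetical example used only to illustrate the mapping:
+
+ >>> model = ResNet(depth=50)
+ >>> model.get_layer_depth('backbone.layer2.1.conv1.weight', 'backbone.')
+ (2, 8)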
+ """
+ depths = self.stage_blocks
+ if depths[1] == 4 and depths[2] == 6:
+ blk2, blk3 = 2, 3
+ elif depths[1] == 4 and depths[2] == 23:
+ blk2, blk3 = 2, 3
+ elif depths[1] == 8 and depths[2] == 36:
+ blk2, blk3 = 4, 4
+ elif depths[1] == 24 and depths[2] == 36:
+ blk2, blk3 = 4, 4
+ elif depths[1] == 30 and depths[2] == 48:
+ blk2, blk3 = 5, 6
+ else:
+ raise NotImplementedError
+
+ N2, N3 = math.ceil(depths[1] / blk2 -
+ 1e-5), math.ceil(depths[2] / blk3 - 1e-5)
+ N = 2 + N2 + N3 # r50: 2 + 2 + 2 = 6
+ max_layer_id = N + 1 # r50: 2 + 2 + 2 + 1(like head) = 7
+
+ if not param_name.startswith(prefix):
+ # For subsequent module like head
+ return max_layer_id, max_layer_id + 1
+
+ if param_name.startswith('backbone.layer'):
+ stage_id = int(param_name.split('.')[1][5:])
+ block_id = int(param_name.split('.')[2])
+
+ if stage_id == 1:
+ layer_id = 1
+ elif stage_id == 2:
+ layer_id = 2 + block_id // blk2 # r50: 2, 3
+ elif stage_id == 3:
+ layer_id = 2 + N2 + block_id // blk3 # r50: 4, 5
+ else: # stage_id == 4
+ layer_id = N # r50: 6
+ return layer_id, max_layer_id + 1
+
+ else:
+ return 0, max_layer_id + 1
+
+
+@MODELS.register_module()
+class ResNetV1c(ResNet):
+ """ResNetV1c backbone.
+
+ This variant is described in `Bag of Tricks
+ <https://arxiv.org/abs/1812.01187>`_.
+
+ Compared with default ResNet(ResNetV1b), ResNetV1c replaces the 7x7 conv
+ in the input stem with three 3x3 convs.
+ """
+
+ def __init__(self, **kwargs):
+ super(ResNetV1c, self).__init__(
+ deep_stem=True, avg_down=False, **kwargs)
+
+
+@MODELS.register_module()
+class ResNetV1d(ResNet):
+ """ResNetV1d backbone.
+
+ This variant is described in `Bag of Tricks
+ <https://arxiv.org/abs/1812.01187>`_.
+
+ Compared with default ResNet(ResNetV1b), ResNetV1d replaces the 7x7 conv in
+ the input stem with three 3x3 convs. And in the downsampling block, a 2x2
+ avg_pool with stride 2 is added before conv, whose stride is changed to 1.
+ """
+
+ def __init__(self, **kwargs):
+ super(ResNetV1d, self).__init__(
+ deep_stem=True, avg_down=True, **kwargs)
diff --git a/mmpretrain/models/backbones/resnet_cifar.py b/mmpretrain/models/backbones/resnet_cifar.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f17f92fd76a690ea90977b38ab2ea00345ba903
--- /dev/null
+++ b/mmpretrain/models/backbones/resnet_cifar.py
@@ -0,0 +1,81 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+from mmcv.cnn import build_conv_layer, build_norm_layer
+
+from mmpretrain.registry import MODELS
+from .resnet import ResNet
+
+
+@MODELS.register_module()
+class ResNet_CIFAR(ResNet):
+ """ResNet backbone for CIFAR.
+
+ Compared to standard ResNet, it uses `kernel_size=3` and `stride=1` in
+ conv1, and does not apply max pooling after the stem. It has been shown
+ to be more effective than the standard ResNet in other public
+ codebases, e.g.,
+ `https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py`.
+
+ Args:
+ depth (int): Network depth, from {18, 34, 50, 101, 152}.
+ in_channels (int): Number of input image channels. Default: 3.
+ stem_channels (int): Output channels of the stem layer. Default: 64.
+ base_channels (int): Middle channels of the first stage. Default: 64.
+ num_stages (int): Stages of the network. Default: 4.
+ strides (Sequence[int]): Strides of the first block of each stage.
+ Default: ``(1, 2, 2, 2)``.
+ dilations (Sequence[int]): Dilation of each stage.
+ Default: ``(1, 1, 1, 1)``.
+ out_indices (Sequence[int]): Output from which stages. A tuple of
+ tensors, one per selected stage, is returned.
+ Default: ``(3, )``.
+ style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
+ layer is the 3x3 conv layer, otherwise the stride-two layer is
+ the first 1x1 conv layer.
+ deep_stem (bool): This network has its own specially designed stem, so
+ this argument is asserted to be False.
+ avg_down (bool): Use AvgPool instead of stride conv when
+ downsampling in the bottleneck. Default: False.
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+ -1 means not freezing any parameters. Default: -1.
+ conv_cfg (dict | None): The config dict for conv layers. Default: None.
+ norm_cfg (dict): The config dict for norm layers.
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Default: False.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ zero_init_residual (bool): Whether to use zero init for last norm layer
+ in resblocks to let them behave as identity. Default: True.
+ """
+
+ def __init__(self, depth, deep_stem=False, **kwargs):
+ super(ResNet_CIFAR, self).__init__(
+ depth, deep_stem=deep_stem, **kwargs)
+ assert not self.deep_stem, 'ResNet_CIFAR does not support deep_stem'
+
+ def _make_stem_layer(self, in_channels, base_channels):
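+ # CIFAR stem: a single 3x3 stride-1 conv + BN + ReLU; unlike the default
+ # ResNet stem there is no 7x7 stride-2 conv and no max pooling, so 32x32
+ # inputs keep their full spatial resolution.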
+ self.conv1 = build_conv_layer(
+ self.conv_cfg,
+ in_channels,
+ base_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False)
+ self.norm1_name, norm1 = build_norm_layer(
+ self.norm_cfg, base_channels, postfix=1)
+ self.add_module(self.norm1_name, norm1)
+ self.relu = nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.norm1(x)
+ x = self.relu(x)
+ outs = []
+ for i, layer_name in enumerate(self.res_layers):
+ res_layer = getattr(self, layer_name)
+ x = res_layer(x)
+ if i in self.out_indices:
+ outs.append(x)
+ return tuple(outs)
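+
+
+# Usage sketch (illustrative only, not part of the upstream file; assumes
+# ``torch`` is installed):
+#
+#   import torch
+#   from mmpretrain.models.backbones.resnet_cifar import ResNet_CIFAR
+#
+#   model = ResNet_CIFAR(depth=18)
+#   model.eval()
+#   feats = model(torch.rand(1, 3, 32, 32))
+#   # with the default ``out_indices=(3, )`` this is a 1-tuple whose only
+#   # element has shape (1, 512, 4, 4).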
diff --git a/mmpretrain/models/backbones/resnext.py b/mmpretrain/models/backbones/resnext.py
new file mode 100644
index 0000000000000000000000000000000000000000..8858b7d3dffdcb20677e091fba4f5a1084d086a3
--- /dev/null
+++ b/mmpretrain/models/backbones/resnext.py
@@ -0,0 +1,148 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmcv.cnn import build_conv_layer, build_norm_layer
+
+from mmpretrain.registry import MODELS
+from .resnet import Bottleneck as _Bottleneck
+from .resnet import ResLayer, ResNet
+
+
+class Bottleneck(_Bottleneck):
+ """Bottleneck block for ResNeXt.
+
+ Args:
+ in_channels (int): Input channels of this block.
+ out_channels (int): Output channels of this block.
+ groups (int): Groups of conv2.
+ width_per_group (int): Width per group of conv2. 64x4d indicates
+ ``groups=64, width_per_group=4`` and 32x8d indicates
+ ``groups=32, width_per_group=8``.
+ stride (int): stride of the block. Default: 1
+ dilation (int): dilation of convolution. Default: 1
+ downsample (nn.Module, optional): downsample operation on identity
+ branch. Default: None
+ style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
+ layer is the 3x3 conv layer, otherwise the stride-two layer is
+ the first 1x1 conv layer.
+ conv_cfg (dict, optional): dictionary to construct and config conv
+ layer. Default: None
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: dict(type='BN')
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ base_channels=64,
+ groups=32,
+ width_per_group=4,
+ **kwargs):
+ super(Bottleneck, self).__init__(in_channels, out_channels, **kwargs)
+ self.groups = groups
+ self.width_per_group = width_per_group
+
+ # For the ResNet bottleneck, the middle channels are determined by
+ # expansion and out_channels; for the ResNeXt bottleneck, they are
+ # determined by groups, width_per_group and the stage the block is in.
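+ # Worked example (ResNeXt-50 32x4d, first stage, illustrative): with
+ # out_channels=256 and expansion=4 the ResNet mid_channels is 64, which
+ # is rescaled below to 32 * 4 * 64 // 64 = 128 grouped channels.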
+ if groups != 1:
+ assert self.mid_channels % base_channels == 0
+ self.mid_channels = (
+ groups * width_per_group * self.mid_channels // base_channels)
+
+ self.norm1_name, norm1 = build_norm_layer(
+ self.norm_cfg, self.mid_channels, postfix=1)
+ self.norm2_name, norm2 = build_norm_layer(
+ self.norm_cfg, self.mid_channels, postfix=2)
+ self.norm3_name, norm3 = build_norm_layer(
+ self.norm_cfg, self.out_channels, postfix=3)
+
+ self.conv1 = build_conv_layer(
+ self.conv_cfg,
+ self.in_channels,
+ self.mid_channels,
+ kernel_size=1,
+ stride=self.conv1_stride,
+ bias=False)
+ self.add_module(self.norm1_name, norm1)
+ self.conv2 = build_conv_layer(
+ self.conv_cfg,
+ self.mid_channels,
+ self.mid_channels,
+ kernel_size=3,
+ stride=self.conv2_stride,
+ padding=self.dilation,
+ dilation=self.dilation,
+ groups=groups,
+ bias=False)
+
+ self.add_module(self.norm2_name, norm2)
+ self.conv3 = build_conv_layer(
+ self.conv_cfg,
+ self.mid_channels,
+ self.out_channels,
+ kernel_size=1,
+ bias=False)
+ self.add_module(self.norm3_name, norm3)
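+ # The resulting block is: 1x1 conv (in -> mid), 3x3 grouped conv
+ # (mid -> mid, groups=groups), 1x1 conv (mid -> out), each followed by its
+ # norm layer; the residual addition and activation come from the inherited
+ # ResNet Bottleneck forward.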
+
+
+@MODELS.register_module()
+class ResNeXt(ResNet):
+ """ResNeXt backbone.
+
+ Please refer to the `paper